mirror of
https://github.com/codeflash-ai/codeflash-agent.git
synced 2026-05-04 18:25:19 +00:00
Fix pre-existing CI lint and test failures (#40)
* chore: add gitignore entries for local eval repos, e2e fixtures, and env files * fix: restore clean bubble_sort_method.py test fixture The call-site ID commit re-contaminated this file with instrumentation decorators, causing tests to fail with missing CODEFLASH_LOOP_INDEX. * fix: resolve ruff and mypy errors in codeflash-python - Add import-not-found ignores for optional torch/jax imports - Extract magic column index to _STDOUT_COLUMN_INDEX constant - Fix unused variable in _instrument_sync.py - Cast cpu_time_ns to int for mypy arg-type * fix: add skip markers for optional deps and apply ruff formatting to tests Skip torch/jax/tensorflow tests when those packages are not installed. Move has_module helper to conftest.py for reuse across test files. Apply ruff format to all test files that drifted. * fix: resolve remaining ruff format and mypy errors - Add missing blank line in conftest.py (ruff format) - Remove unused import-untyped ignore on jax import (mypy unused-ignore) - Add type: ignore comments for object-typed SQLite row values * chore: bump codeflash-python to 0.1.1.dev0
This commit is contained in:
parent
2c9f2ad8de
commit
919a673be2
18 changed files with 172 additions and 175 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -14,3 +14,7 @@ dist-*/
|
||||||
dist-v2/
|
dist-v2/
|
||||||
.playwright-mcp/
|
.playwright-mcp/
|
||||||
.tessl/session-data/
|
.tessl/session-data/
|
||||||
|
evals/repos/
|
||||||
|
packages/codeflash-api/e2e/
|
||||||
|
packages/github-app/.env
|
||||||
|
packages/github-app/codex-config/
|
||||||
|
|
|
||||||
6
packages/codeflash-python/changelogs/fix-ci-lint.md
Normal file
6
packages/codeflash-python/changelogs/fix-ci-lint.md
Normal file
|
|
@ -0,0 +1,6 @@
|
||||||
|
### Fixes
|
||||||
|
|
||||||
|
- Restore clean bubble_sort_method.py test fixture (remove accidental instrumentation decorators)
|
||||||
|
- Add mypy type-ignore comments for optional torch/jax imports and SQLite row types
|
||||||
|
- Add pytest skip markers for tests requiring optional deps (torch, jax, tensorflow)
|
||||||
|
- Fix ruff format and unused mypy ignore comment
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[project]
|
[project]
|
||||||
name = "codeflash-python"
|
name = "codeflash-python"
|
||||||
version = "0.1.0"
|
version = "0.1.1.dev0"
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"codeflash-core",
|
"codeflash-core",
|
||||||
|
|
|
||||||
|
|
@ -130,14 +130,12 @@ def close_all_connections() -> None:
|
||||||
atexit.register(close_all_connections)
|
atexit.register(close_all_connections)
|
||||||
|
|
||||||
|
|
||||||
def detect_device_sync() -> (
|
def detect_device_sync() -> tuple[
|
||||||
tuple[
|
Any | None, # torch_cuda_sync
|
||||||
Any | None, # torch_cuda_sync
|
Any | None, # torch_mps_sync
|
||||||
Any | None, # torch_mps_sync
|
Any | None, # tf_sync
|
||||||
Any | None, # tf_sync
|
bool, # jax_available
|
||||||
bool, # jax_available
|
]:
|
||||||
]
|
|
||||||
):
|
|
||||||
"""Detect available GPU frameworks and return sync callables.
|
"""Detect available GPU frameworks and return sync callables.
|
||||||
|
|
||||||
Called once at first decorator invocation; results are cached.
|
Called once at first decorator invocation; results are cached.
|
||||||
|
|
@ -151,7 +149,7 @@ def detect_device_sync() -> (
|
||||||
jax_available = False
|
jax_available = False
|
||||||
|
|
||||||
if find_spec("torch") is not None:
|
if find_spec("torch") is not None:
|
||||||
import torch # noqa: PLC0415
|
import torch # type: ignore[import-not-found] # noqa: PLC0415
|
||||||
|
|
||||||
if torch.cuda.is_available() and torch.cuda.is_initialized():
|
if torch.cuda.is_available() and torch.cuda.is_initialized():
|
||||||
torch_cuda_sync = torch.cuda.synchronize
|
torch_cuda_sync = torch.cuda.synchronize
|
||||||
|
|
@ -165,7 +163,7 @@ def detect_device_sync() -> (
|
||||||
torch_mps_sync = torch.mps.synchronize
|
torch_mps_sync = torch.mps.synchronize
|
||||||
|
|
||||||
if find_spec("jax") is not None:
|
if find_spec("jax") is not None:
|
||||||
import jax # type: ignore[import-untyped] # noqa: PLC0415
|
import jax # type: ignore[import-not-found] # noqa: PLC0415
|
||||||
|
|
||||||
jax_available = hasattr(jax, "block_until_ready")
|
jax_available = hasattr(jax, "block_until_ready")
|
||||||
|
|
||||||
|
|
@ -180,14 +178,12 @@ def detect_device_sync() -> (
|
||||||
return torch_cuda_sync, torch_mps_sync, tf_sync, jax_available
|
return torch_cuda_sync, torch_mps_sync, tf_sync, jax_available
|
||||||
|
|
||||||
|
|
||||||
device_sync_cache: (
|
device_sync_cache: tuple[Any | None, Any | None, Any | None, bool] | None = (
|
||||||
tuple[Any | None, Any | None, Any | None, bool] | None
|
None
|
||||||
) = None
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_device_sync() -> (
|
def get_device_sync() -> tuple[Any | None, Any | None, Any | None, bool]:
|
||||||
tuple[Any | None, Any | None, Any | None, bool]
|
|
||||||
):
|
|
||||||
"""Return cached device sync callables, detecting on first call."""
|
"""Return cached device sync callables, detecting on first call."""
|
||||||
global device_sync_cache # noqa: PLW0603
|
global device_sync_cache # noqa: PLW0603
|
||||||
if device_sync_cache is None:
|
if device_sync_cache is None:
|
||||||
|
|
@ -253,7 +249,9 @@ def codeflash_behavior_sync(func: F) -> F:
|
||||||
invocation_id = f"{call_site}_{call_index}"
|
invocation_id = f"{call_site}_{call_index}"
|
||||||
|
|
||||||
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
|
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
|
||||||
db_path = get_run_tmp_file(Path(f"codeflash_results_{iteration}.sqlite"))
|
db_path = get_run_tmp_file(
|
||||||
|
Path(f"codeflash_results_{iteration}.sqlite")
|
||||||
|
)
|
||||||
conn, cur = get_async_db(db_path)
|
conn, cur = get_async_db(db_path)
|
||||||
|
|
||||||
exception = None
|
exception = None
|
||||||
|
|
@ -348,7 +346,9 @@ def codeflash_performance_sync(func: F) -> F:
|
||||||
invocation_id = f"{call_site}_{call_index}"
|
invocation_id = f"{call_site}_{call_index}"
|
||||||
|
|
||||||
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
|
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
|
||||||
db_path = get_run_tmp_file(Path(f"codeflash_results_{iteration}.sqlite"))
|
db_path = get_run_tmp_file(
|
||||||
|
Path(f"codeflash_results_{iteration}.sqlite")
|
||||||
|
)
|
||||||
conn, cur = get_async_db(db_path)
|
conn, cur = get_async_db(db_path)
|
||||||
|
|
||||||
exception = None
|
exception = None
|
||||||
|
|
@ -430,7 +430,9 @@ def codeflash_behavior_async(func: F) -> F:
|
||||||
invocation_id = f"{call_site}_{call_index}"
|
invocation_id = f"{call_site}_{call_index}"
|
||||||
|
|
||||||
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
|
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
|
||||||
db_path = get_run_tmp_file(Path(f"codeflash_results_{iteration}.sqlite"))
|
db_path = get_run_tmp_file(
|
||||||
|
Path(f"codeflash_results_{iteration}.sqlite")
|
||||||
|
)
|
||||||
conn, cur = get_async_db(db_path)
|
conn, cur = get_async_db(db_path)
|
||||||
|
|
||||||
exception = None
|
exception = None
|
||||||
|
|
@ -522,7 +524,9 @@ def codeflash_performance_async(func: F) -> F:
|
||||||
invocation_id = f"{call_site}_{call_index}"
|
invocation_id = f"{call_site}_{call_index}"
|
||||||
|
|
||||||
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
|
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
|
||||||
db_path = get_run_tmp_file(Path(f"codeflash_results_{iteration}.sqlite"))
|
db_path = get_run_tmp_file(
|
||||||
|
Path(f"codeflash_results_{iteration}.sqlite")
|
||||||
|
)
|
||||||
conn, cur = get_async_db(db_path)
|
conn, cur = get_async_db(db_path)
|
||||||
|
|
||||||
exception = None
|
exception = None
|
||||||
|
|
@ -590,7 +594,9 @@ def codeflash_concurrency_async(func: F) -> F:
|
||||||
loop_index = int(os.environ.get("CODEFLASH_LOOP_INDEX", "0"))
|
loop_index = int(os.environ.get("CODEFLASH_LOOP_INDEX", "0"))
|
||||||
|
|
||||||
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
|
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
|
||||||
db_path = get_run_tmp_file(Path(f"codeflash_results_{iteration}.sqlite"))
|
db_path = get_run_tmp_file(
|
||||||
|
Path(f"codeflash_results_{iteration}.sqlite")
|
||||||
|
)
|
||||||
conn, cur = get_async_db(db_path)
|
conn, cur = get_async_db(db_path)
|
||||||
|
|
||||||
gc.disable()
|
gc.disable()
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,8 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_STDOUT_COLUMN_INDEX = 10
|
||||||
|
|
||||||
_BEHAVIOR_QUERY = (
|
_BEHAVIOR_QUERY = (
|
||||||
"SELECT test_module_path, test_class_name,"
|
"SELECT test_module_path, test_class_name,"
|
||||||
" test_function_name, function_getting_tested,"
|
" test_function_name, function_getting_tested,"
|
||||||
|
|
@ -110,7 +112,9 @@ def _process_behavior_row_inner(
|
||||||
wall_time_ns = val[6]
|
wall_time_ns = val[6]
|
||||||
cpu_time_ns = val[7]
|
cpu_time_ns = val[7]
|
||||||
verification_type = val[9]
|
verification_type = val[9]
|
||||||
stdout_text = val[10] if len(val) > 10 else None
|
stdout_text = (
|
||||||
|
val[_STDOUT_COLUMN_INDEX] if len(val) > _STDOUT_COLUMN_INDEX else None
|
||||||
|
)
|
||||||
|
|
||||||
test_file_path = file_path_from_module_name(
|
test_file_path = file_path_from_module_name(
|
||||||
test_module_path, # type: ignore[arg-type]
|
test_module_path, # type: ignore[arg-type]
|
||||||
|
|
@ -168,14 +172,14 @@ def _process_behavior_row_inner(
|
||||||
test_framework=test_config.test_framework,
|
test_framework=test_config.test_framework,
|
||||||
test_type=test_type,
|
test_type=test_type,
|
||||||
return_value=ret_val,
|
return_value=ret_val,
|
||||||
cpu_runtime=cpu_time_ns or 0,
|
cpu_runtime=int(cpu_time_ns or 0), # type: ignore[call-overload]
|
||||||
timed_out=False,
|
timed_out=False,
|
||||||
verification_type=(
|
verification_type=(
|
||||||
VerificationType(verification_type)
|
VerificationType(verification_type)
|
||||||
if verification_type
|
if verification_type
|
||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
stdout=stdout_text or None,
|
stdout=stdout_text or None, # type: ignore[arg-type]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -80,7 +80,7 @@ class SyncCallInstrumenter(ast.NodeTransformer):
|
||||||
new_body: list[ast.stmt] = []
|
new_body: list[ast.stmt] = []
|
||||||
|
|
||||||
for stmt in node.body:
|
for stmt in node.body:
|
||||||
call_node, has_target = self._find_target_call(stmt)
|
_, has_target = self._find_target_call(stmt)
|
||||||
|
|
||||||
if has_target:
|
if has_target:
|
||||||
call_site_set = ast.Expr(
|
call_site_set = ast.Expr(
|
||||||
|
|
|
||||||
|
|
@ -1,48 +1,41 @@
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from codeflash_async_wrapper import codeflash_behavior_sync
|
|
||||||
|
|
||||||
from codeflash_python.runtime._codeflash_capture import codeflash_capture
|
|
||||||
|
|
||||||
|
|
||||||
class BubbleSorter:
|
class BubbleSorter:
|
||||||
|
|
||||||
@codeflash_capture(function_name='BubbleSorter.__init__', tmp_dir_path='/var/folders/mg/k_c0twcj37q_gph3cfy3zlt80000gn/T/codeflash_l3k89hc3/codeflash_results', tests_root='/Users/krrt7/Desktop/work/cf_org/codeflash-agent/.claude/worktrees/jaunty-sauteeing-dolphin/packages/codeflash-python/tests/code_to_optimize/tests/pytest', is_fto=True)
|
|
||||||
def __init__(self, x=0):
|
def __init__(self, x=0):
|
||||||
self.x = x
|
self.x = x
|
||||||
|
|
||||||
@codeflash_behavior_sync
|
|
||||||
def sorter(self, arr):
|
def sorter(self, arr):
|
||||||
print('codeflash stdout : BubbleSorter.sorter() called')
|
print("codeflash stdout : BubbleSorter.sorter() called")
|
||||||
for i in range(len(arr)):
|
for i in range(len(arr)):
|
||||||
for j in range(len(arr) - 1):
|
for j in range(len(arr) - 1):
|
||||||
if arr[j] > arr[j + 1]:
|
if arr[j] > arr[j + 1]:
|
||||||
temp = arr[j]
|
temp = arr[j]
|
||||||
arr[j] = arr[j + 1]
|
arr[j] = arr[j + 1]
|
||||||
arr[j + 1] = temp
|
arr[j + 1] = temp
|
||||||
print('stderr test', file=sys.stderr)
|
print("stderr test", file=sys.stderr)
|
||||||
return arr
|
return arr
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sorter_classmethod(cls, arr):
|
def sorter_classmethod(cls, arr):
|
||||||
print('codeflash stdout : BubbleSorter.sorter_classmethod() called')
|
print("codeflash stdout : BubbleSorter.sorter_classmethod() called")
|
||||||
for i in range(len(arr)):
|
for i in range(len(arr)):
|
||||||
for j in range(len(arr) - 1):
|
for j in range(len(arr) - 1):
|
||||||
if arr[j] > arr[j + 1]:
|
if arr[j] > arr[j + 1]:
|
||||||
temp = arr[j]
|
temp = arr[j]
|
||||||
arr[j] = arr[j + 1]
|
arr[j] = arr[j + 1]
|
||||||
arr[j + 1] = temp
|
arr[j + 1] = temp
|
||||||
print('stderr test classmethod', file=sys.stderr)
|
print("stderr test classmethod", file=sys.stderr)
|
||||||
return arr
|
return arr
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sorter_staticmethod(arr):
|
def sorter_staticmethod(arr):
|
||||||
print('codeflash stdout : BubbleSorter.sorter_staticmethod() called')
|
print("codeflash stdout : BubbleSorter.sorter_staticmethod() called")
|
||||||
for i in range(len(arr)):
|
for i in range(len(arr)):
|
||||||
for j in range(len(arr) - 1):
|
for j in range(len(arr) - 1):
|
||||||
if arr[j] > arr[j + 1]:
|
if arr[j] > arr[j + 1]:
|
||||||
temp = arr[j]
|
temp = arr[j]
|
||||||
arr[j] = arr[j + 1]
|
arr[j] = arr[j + 1]
|
||||||
arr[j + 1] = temp
|
arr[j + 1] = temp
|
||||||
print('stderr test staticmethod', file=sys.stderr)
|
print("stderr test staticmethod", file=sys.stderr)
|
||||||
return arr
|
return arr
|
||||||
|
|
|
||||||
|
|
@ -3,10 +3,17 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
from importlib.util import find_spec
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def has_module(name: str) -> bool:
|
||||||
|
"""Check whether an optional dependency is importable (for skipif markers)."""
|
||||||
|
return find_spec(name) is not None
|
||||||
|
|
||||||
|
|
||||||
# Make the code_to_optimize fixture package importable by tests that need it
|
# Make the code_to_optimize fixture package importable by tests that need it
|
||||||
# (e.g. test_comparator.py, test_trace_benchmarks.py).
|
# (e.g. test_comparator.py, test_trace_benchmarks.py).
|
||||||
_TESTS_DIR = str(Path(__file__).resolve().parent)
|
_TESTS_DIR = str(Path(__file__).resolve().parent)
|
||||||
|
|
|
||||||
|
|
@ -14,25 +14,25 @@ import dill as pickle
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import codeflash_python.runtime._codeflash_async_decorators as _deco_mod
|
import codeflash_python.runtime._codeflash_async_decorators as _deco_mod
|
||||||
|
|
||||||
from codeflash_python.runtime._codeflash_async_decorators import (
|
from codeflash_python.runtime._codeflash_async_decorators import (
|
||||||
VerificationType,
|
VerificationType,
|
||||||
close_all_connections,
|
|
||||||
_codeflash_call_site,
|
_codeflash_call_site,
|
||||||
connections,
|
close_all_connections,
|
||||||
detect_device_sync,
|
|
||||||
get_async_db,
|
|
||||||
get_device_sync,
|
|
||||||
sync_devices_after,
|
|
||||||
sync_devices_before,
|
|
||||||
codeflash_behavior_async,
|
codeflash_behavior_async,
|
||||||
codeflash_behavior_sync,
|
codeflash_behavior_sync,
|
||||||
codeflash_concurrency_async,
|
codeflash_concurrency_async,
|
||||||
codeflash_performance_async,
|
codeflash_performance_async,
|
||||||
codeflash_performance_sync,
|
codeflash_performance_sync,
|
||||||
|
connections,
|
||||||
|
detect_device_sync,
|
||||||
extract_test_context_from_env,
|
extract_test_context_from_env,
|
||||||
|
get_async_db,
|
||||||
|
get_device_sync,
|
||||||
get_run_tmp_file,
|
get_run_tmp_file,
|
||||||
|
sync_devices_after,
|
||||||
|
sync_devices_before,
|
||||||
)
|
)
|
||||||
|
from conftest import has_module
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="env_setup")
|
@pytest.fixture(name="env_setup")
|
||||||
|
|
@ -210,7 +210,9 @@ class TestBehaviorAsync:
|
||||||
con.close()
|
con.close()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_exception_handling(self, env_setup, results_db_path) -> None:
|
async def test_exception_handling(
|
||||||
|
self, env_setup, results_db_path
|
||||||
|
) -> None:
|
||||||
"""Re-raises exceptions and stores them pickled."""
|
"""Re-raises exceptions and stores them pickled."""
|
||||||
|
|
||||||
@codeflash_behavior_async
|
@codeflash_behavior_async
|
||||||
|
|
@ -354,9 +356,7 @@ class TestBehaviorSync:
|
||||||
assert "hello world\n" == row[0]
|
assert "hello world\n" == row[0]
|
||||||
con.close()
|
con.close()
|
||||||
|
|
||||||
def test_no_stdout_leak(
|
def test_no_stdout_leak(self, env_setup, results_db_path, capsys) -> None:
|
||||||
self, env_setup, results_db_path, capsys
|
|
||||||
) -> None:
|
|
||||||
"""Sync decorator does not leak stdout to outer scope."""
|
"""Sync decorator does not leak stdout to outer scope."""
|
||||||
|
|
||||||
@codeflash_behavior_sync
|
@codeflash_behavior_sync
|
||||||
|
|
@ -382,9 +382,7 @@ class TestBehaviorSync:
|
||||||
|
|
||||||
con = sqlite3.connect(results_db_path)
|
con = sqlite3.connect(results_db_path)
|
||||||
cur = con.cursor()
|
cur = con.cursor()
|
||||||
cur.execute(
|
cur.execute("SELECT cpu_time_ns FROM codeflash_results")
|
||||||
"SELECT cpu_time_ns FROM codeflash_results"
|
|
||||||
)
|
|
||||||
row = cur.fetchone()
|
row = cur.fetchone()
|
||||||
assert row[0] is not None
|
assert row[0] is not None
|
||||||
assert 0 <= row[0]
|
assert 0 <= row[0]
|
||||||
|
|
@ -646,7 +644,9 @@ class TestPerformanceAsyncEdgeCases:
|
||||||
"""Edge cases for codeflash_performance_async."""
|
"""Edge cases for codeflash_performance_async."""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_exception_handling(self, env_setup, results_db_path) -> None:
|
async def test_exception_handling(
|
||||||
|
self, env_setup, results_db_path
|
||||||
|
) -> None:
|
||||||
"""Re-raises exceptions from the wrapped function."""
|
"""Re-raises exceptions from the wrapped function."""
|
||||||
|
|
||||||
@codeflash_performance_async
|
@codeflash_performance_async
|
||||||
|
|
@ -762,6 +762,7 @@ class TestDetectDeviceSync:
|
||||||
result = detect_device_sync()
|
result = detect_device_sync()
|
||||||
assert 4 == len(result)
|
assert 4 == len(result)
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not has_module("torch"), reason="torch not installed")
|
||||||
def test_detects_real_torch(self, reset_device_cache) -> None:
|
def test_detects_real_torch(self, reset_device_cache) -> None:
|
||||||
"""Detects torch and returns a sync callable for the active device."""
|
"""Detects torch and returns a sync callable for the active device."""
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -780,6 +781,7 @@ class TestDetectDeviceSync:
|
||||||
assert cuda_sync is None
|
assert cuda_sync is None
|
||||||
assert mps_sync is None
|
assert mps_sync is None
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not has_module("jax"), reason="jax not installed")
|
||||||
def test_detects_real_jax(self, reset_device_cache) -> None:
|
def test_detects_real_jax(self, reset_device_cache) -> None:
|
||||||
"""Detects JAX and sets jax_available based on block_until_ready."""
|
"""Detects JAX and sets jax_available based on block_until_ready."""
|
||||||
import jax
|
import jax
|
||||||
|
|
@ -787,14 +789,16 @@ class TestDetectDeviceSync:
|
||||||
_, _, _, jax_avail = detect_device_sync()
|
_, _, _, jax_avail = detect_device_sync()
|
||||||
assert jax_avail is hasattr(jax, "block_until_ready")
|
assert jax_avail is hasattr(jax, "block_until_ready")
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not has_module("tensorflow"), reason="tensorflow not installed"
|
||||||
|
)
|
||||||
def test_detects_real_tensorflow(self, reset_device_cache) -> None:
|
def test_detects_real_tensorflow(self, reset_device_cache) -> None:
|
||||||
"""Detects TensorFlow sync_devices when available."""
|
"""Detects TensorFlow sync_devices when available."""
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
_, _, tf_sync, _ = detect_device_sync()
|
_, _, tf_sync, _ = detect_device_sync()
|
||||||
if (
|
if hasattr(tf.test, "experimental") and hasattr(
|
||||||
hasattr(tf.test, "experimental")
|
tf.test.experimental, "sync_devices"
|
||||||
and hasattr(tf.test.experimental, "sync_devices")
|
|
||||||
):
|
):
|
||||||
assert tf_sync is tf.test.experimental.sync_devices
|
assert tf_sync is tf.test.experimental.sync_devices
|
||||||
else:
|
else:
|
||||||
|
|
@ -810,9 +814,7 @@ class TestGetDeviceSync:
|
||||||
second = get_device_sync()
|
second = get_device_sync()
|
||||||
assert first is second
|
assert first is second
|
||||||
|
|
||||||
def test_redetects_after_cache_clear(
|
def test_redetects_after_cache_clear(self, reset_device_cache) -> None:
|
||||||
self, reset_device_cache
|
|
||||||
) -> None:
|
|
||||||
"""Re-runs detection after the cache is cleared."""
|
"""Re-runs detection after the cache is cleared."""
|
||||||
first = get_device_sync()
|
first = get_device_sync()
|
||||||
_deco_mod.device_sync_cache = None
|
_deco_mod.device_sync_cache = None
|
||||||
|
|
@ -828,9 +830,7 @@ class TestSyncDevicesBefore:
|
||||||
"""Calling with real frameworks installed does not raise."""
|
"""Calling with real frameworks installed does not raise."""
|
||||||
sync_devices_before()
|
sync_devices_before()
|
||||||
|
|
||||||
def test_cuda_takes_priority_over_mps(
|
def test_cuda_takes_priority_over_mps(self, reset_device_cache) -> None:
|
||||||
self, reset_device_cache
|
|
||||||
) -> None:
|
|
||||||
"""CUDA sync is called instead of MPS when both are in the cache."""
|
"""CUDA sync is called instead of MPS when both are in the cache."""
|
||||||
calls = []
|
calls = []
|
||||||
_deco_mod.device_sync_cache = (
|
_deco_mod.device_sync_cache = (
|
||||||
|
|
@ -842,9 +842,7 @@ class TestSyncDevicesBefore:
|
||||||
sync_devices_before()
|
sync_devices_before()
|
||||||
assert ["cuda"] == calls
|
assert ["cuda"] == calls
|
||||||
|
|
||||||
def test_mps_called_when_no_cuda(
|
def test_mps_called_when_no_cuda(self, reset_device_cache) -> None:
|
||||||
self, reset_device_cache
|
|
||||||
) -> None:
|
|
||||||
"""MPS sync fires when CUDA is absent in the cache."""
|
"""MPS sync fires when CUDA is absent in the cache."""
|
||||||
calls = []
|
calls = []
|
||||||
_deco_mod.device_sync_cache = (
|
_deco_mod.device_sync_cache = (
|
||||||
|
|
@ -856,9 +854,7 @@ class TestSyncDevicesBefore:
|
||||||
sync_devices_before()
|
sync_devices_before()
|
||||||
assert ["mps"] == calls
|
assert ["mps"] == calls
|
||||||
|
|
||||||
def test_tf_called_independently(
|
def test_tf_called_independently(self, reset_device_cache) -> None:
|
||||||
self, reset_device_cache
|
|
||||||
) -> None:
|
|
||||||
"""TF sync fires independently of torch sync."""
|
"""TF sync fires independently of torch sync."""
|
||||||
calls = []
|
calls = []
|
||||||
_deco_mod.device_sync_cache = (
|
_deco_mod.device_sync_cache = (
|
||||||
|
|
@ -878,6 +874,7 @@ class TestSyncDevicesAfter:
|
||||||
"""Calling with real frameworks and a plain return value does not raise."""
|
"""Calling with real frameworks and a plain return value does not raise."""
|
||||||
sync_devices_after(42)
|
sync_devices_after(42)
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not has_module("jax"), reason="jax not installed")
|
||||||
def test_jax_block_until_ready_on_real_array(
|
def test_jax_block_until_ready_on_real_array(
|
||||||
self, reset_device_cache
|
self, reset_device_cache
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
@ -887,16 +884,12 @@ class TestSyncDevicesAfter:
|
||||||
arr = jnp.array([1, 2, 3])
|
arr = jnp.array([1, 2, 3])
|
||||||
sync_devices_after(arr)
|
sync_devices_after(arr)
|
||||||
|
|
||||||
def test_skips_jax_on_plain_value(
|
def test_skips_jax_on_plain_value(self, reset_device_cache) -> None:
|
||||||
self, reset_device_cache
|
|
||||||
) -> None:
|
|
||||||
"""Does not fail when jax_available=True but return value is plain."""
|
"""Does not fail when jax_available=True but return value is plain."""
|
||||||
_deco_mod.device_sync_cache = (None, None, None, True)
|
_deco_mod.device_sync_cache = (None, None, None, True)
|
||||||
sync_devices_after(42)
|
sync_devices_after(42)
|
||||||
|
|
||||||
def test_cuda_priority_in_after(
|
def test_cuda_priority_in_after(self, reset_device_cache) -> None:
|
||||||
self, reset_device_cache
|
|
||||||
) -> None:
|
|
||||||
"""CUDA sync fires instead of MPS in the after path too."""
|
"""CUDA sync fires instead of MPS in the after path too."""
|
||||||
calls = []
|
calls = []
|
||||||
_deco_mod.device_sync_cache = (
|
_deco_mod.device_sync_cache = (
|
||||||
|
|
@ -908,9 +901,8 @@ class TestSyncDevicesAfter:
|
||||||
sync_devices_after(42)
|
sync_devices_after(42)
|
||||||
assert ["cuda"] == calls
|
assert ["cuda"] == calls
|
||||||
|
|
||||||
def test_all_syncs_fire_together(
|
@pytest.mark.skipif(not has_module("jax"), reason="jax not installed")
|
||||||
self, reset_device_cache
|
def test_all_syncs_fire_together(self, reset_device_cache) -> None:
|
||||||
) -> None:
|
|
||||||
"""All applicable syncs fire: torch + JAX block_until_ready + TF."""
|
"""All applicable syncs fire: torch + JAX block_until_ready + TF."""
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,9 +23,9 @@ from codeflash_python.analysis._reference_graph import (
|
||||||
)
|
)
|
||||||
from codeflash_python.context.models import CodeStringsMarkdown
|
from codeflash_python.context.models import CodeStringsMarkdown
|
||||||
from codeflash_python.pipeline._orchestrator import cleanup_paths
|
from codeflash_python.pipeline._orchestrator import cleanup_paths
|
||||||
|
from codeflash_python.runtime._codeflash_wrap_decorator import get_run_tmp_file
|
||||||
from codeflash_python.test_discovery.linking import module_name_from_file_path
|
from codeflash_python.test_discovery.linking import module_name_from_file_path
|
||||||
from codeflash_python.testing._concolic import clean_concolic_tests
|
from codeflash_python.testing._concolic import clean_concolic_tests
|
||||||
from codeflash_python.runtime._codeflash_wrap_decorator import get_run_tmp_file
|
|
||||||
from codeflash_python.testing._path_resolution import (
|
from codeflash_python.testing._path_resolution import (
|
||||||
file_name_from_test_module_name,
|
file_name_from_test_module_name,
|
||||||
file_path_from_module_name,
|
file_path_from_module_name,
|
||||||
|
|
|
||||||
|
|
@ -10,8 +10,8 @@ from codeflash_python.analysis._discovery import FunctionToOptimize
|
||||||
from codeflash_python.pipeline._function_optimizer import (
|
from codeflash_python.pipeline._function_optimizer import (
|
||||||
write_code_and_helpers,
|
write_code_and_helpers,
|
||||||
)
|
)
|
||||||
from codeflash_python.test_discovery.models import TestType
|
|
||||||
from codeflash_python.runtime._codeflash_wrap_decorator import get_run_tmp_file
|
from codeflash_python.runtime._codeflash_wrap_decorator import get_run_tmp_file
|
||||||
|
from codeflash_python.test_discovery.models import TestType
|
||||||
from codeflash_python.testing._instrument_capture import (
|
from codeflash_python.testing._instrument_capture import (
|
||||||
instrument_codeflash_capture,
|
instrument_codeflash_capture,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -879,5 +879,3 @@ def test_concurrency_ratio_display_formatting() -> None:
|
||||||
f"\u2192 {cand_ratio:.2f}x ({conc_gain:+.1f}%)"
|
f"\u2192 {cand_ratio:.2f}x ({conc_gain:+.1f}%)"
|
||||||
)
|
)
|
||||||
assert display_string == "Concurrency ratio: 0.01x \u2192 0.03x (+200.0%)"
|
assert display_string == "Concurrency ratio: 0.01x \u2192 0.03x (+200.0%)"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,9 @@ from pathlib import Path
|
||||||
|
|
||||||
from codeflash_python._model import FunctionToOptimize, TestingMode
|
from codeflash_python._model import FunctionToOptimize, TestingMode
|
||||||
from codeflash_python.test_discovery.models import CodePosition
|
from codeflash_python.test_discovery.models import CodePosition
|
||||||
from codeflash_python.testing._instrument_core import detect_frameworks_from_code
|
from codeflash_python.testing._instrument_core import (
|
||||||
|
detect_frameworks_from_code,
|
||||||
|
)
|
||||||
from codeflash_python.testing._instrumentation import (
|
from codeflash_python.testing._instrumentation import (
|
||||||
inject_profiling_into_existing_test,
|
inject_profiling_into_existing_test,
|
||||||
)
|
)
|
||||||
|
|
@ -200,7 +202,10 @@ def _make_func() -> FunctionToOptimize:
|
||||||
|
|
||||||
def _assert_sync_call_site_output(instrumented_code: str) -> None:
|
def _assert_sync_call_site_output(instrumented_code: str) -> None:
|
||||||
"""Assert the sync path output has call-site tracking, not codeflash_wrap."""
|
"""Assert the sync path output has call-site tracking, not codeflash_wrap."""
|
||||||
assert "from codeflash_async_wrapper import _codeflash_call_site" in instrumented_code
|
assert (
|
||||||
|
"from codeflash_async_wrapper import _codeflash_call_site"
|
||||||
|
in instrumented_code
|
||||||
|
)
|
||||||
assert "_codeflash_call_site.set(" in instrumented_code
|
assert "_codeflash_call_site.set(" in instrumented_code
|
||||||
assert "codeflash_wrap" not in instrumented_code
|
assert "codeflash_wrap" not in instrumented_code
|
||||||
assert "torch.cuda.synchronize" not in instrumented_code
|
assert "torch.cuda.synchronize" not in instrumented_code
|
||||||
|
|
|
||||||
|
|
@ -151,7 +151,9 @@ def test_sort():
|
||||||
assert test_results[0].did_pass
|
assert test_results[0].did_pass
|
||||||
# return_value is ((args, kwargs, return_value),) in the new path
|
# return_value is ((args, kwargs, return_value),) in the new path
|
||||||
assert test_results[0].return_value[0][2] == [0, 1, 2, 3, 4, 5]
|
assert test_results[0].return_value[0][2] == [0, 1, 2, 3, 4, 5]
|
||||||
out_str = "codeflash stdout: Sorting list\nresult: [0, 1, 2, 3, 4, 5]\n"
|
out_str = (
|
||||||
|
"codeflash stdout: Sorting list\nresult: [0, 1, 2, 3, 4, 5]\n"
|
||||||
|
)
|
||||||
assert test_results[0].stdout == out_str
|
assert test_results[0].stdout == out_str
|
||||||
|
|
||||||
assert test_results[1].id.function_getting_tested == "sorter"
|
assert test_results[1].id.function_getting_tested == "sorter"
|
||||||
|
|
@ -272,9 +274,7 @@ def test_sort():
|
||||||
assert test_results[1].did_pass
|
assert test_results[1].did_pass
|
||||||
assert test_results[1].return_value[0] == {"x": 0}
|
assert test_results[1].return_value[0] == {"x": 0}
|
||||||
|
|
||||||
assert (
|
assert test_results[2].id.function_getting_tested == "sorter"
|
||||||
test_results[2].id.function_getting_tested == "sorter"
|
|
||||||
)
|
|
||||||
assert test_results[2].id.test_class_name is None
|
assert test_results[2].id.test_class_name is None
|
||||||
assert test_results[2].id.test_function_name == "test_sort"
|
assert test_results[2].id.test_function_name == "test_sort"
|
||||||
assert (
|
assert (
|
||||||
|
|
@ -285,13 +285,14 @@ def test_sort():
|
||||||
assert test_results[2].did_pass
|
assert test_results[2].did_pass
|
||||||
# return_value is ((args, kwargs, return_value),) in the new path
|
# return_value is ((args, kwargs, return_value),) in the new path
|
||||||
assert test_results[2].return_value[0][2] == [0, 1, 2, 3, 4, 5]
|
assert test_results[2].return_value[0][2] == [0, 1, 2, 3, 4, 5]
|
||||||
assert test_results[2].stdout == "codeflash stdout : BubbleSorter.sorter() called\n"
|
assert (
|
||||||
|
test_results[2].stdout
|
||||||
|
== "codeflash stdout : BubbleSorter.sorter() called\n"
|
||||||
|
)
|
||||||
match, _ = compare_test_results(test_results, test_results)
|
match, _ = compare_test_results(test_results, test_results)
|
||||||
assert match
|
assert match
|
||||||
|
|
||||||
assert (
|
assert test_results[3].id.function_getting_tested == "sorter"
|
||||||
test_results[3].id.function_getting_tested == "sorter"
|
|
||||||
)
|
|
||||||
assert test_results[3].id.test_class_name is None
|
assert test_results[3].id.test_class_name is None
|
||||||
assert test_results[3].id.test_function_name == "test_sort"
|
assert test_results[3].id.test_function_name == "test_sort"
|
||||||
assert (
|
assert (
|
||||||
|
|
@ -300,7 +301,10 @@ def test_sort():
|
||||||
)
|
)
|
||||||
assert test_results[3].runtime > 0
|
assert test_results[3].runtime > 0
|
||||||
assert test_results[3].did_pass
|
assert test_results[3].did_pass
|
||||||
assert test_results[3].stdout == "codeflash stdout : BubbleSorter.sorter() called\n"
|
assert (
|
||||||
|
test_results[3].stdout
|
||||||
|
== "codeflash stdout : BubbleSorter.sorter() called\n"
|
||||||
|
)
|
||||||
|
|
||||||
results2 = _run_and_parse(test_files, test_env, test_config)
|
results2 = _run_and_parse(test_files, test_env, test_config)
|
||||||
|
|
||||||
|
|
@ -376,9 +380,7 @@ class BubbleSorter:
|
||||||
assert new_test_results[1].did_pass
|
assert new_test_results[1].did_pass
|
||||||
assert new_test_results[1].return_value[0] == {"x": 1}
|
assert new_test_results[1].return_value[0] == {"x": 1}
|
||||||
|
|
||||||
assert (
|
assert new_test_results[2].id.function_getting_tested == "sorter"
|
||||||
new_test_results[2].id.function_getting_tested == "sorter"
|
|
||||||
)
|
|
||||||
assert new_test_results[2].id.test_class_name is None
|
assert new_test_results[2].id.test_class_name is None
|
||||||
assert new_test_results[2].id.test_function_name == "test_sort"
|
assert new_test_results[2].id.test_function_name == "test_sort"
|
||||||
assert (
|
assert (
|
||||||
|
|
@ -389,9 +391,7 @@ class BubbleSorter:
|
||||||
assert new_test_results[2].did_pass
|
assert new_test_results[2].did_pass
|
||||||
assert new_test_results[2].return_value[0][2] == [0, 1, 2, 3, 4, 5]
|
assert new_test_results[2].return_value[0][2] == [0, 1, 2, 3, 4, 5]
|
||||||
|
|
||||||
assert (
|
assert new_test_results[3].id.function_getting_tested == "sorter"
|
||||||
new_test_results[3].id.function_getting_tested == "sorter"
|
|
||||||
)
|
|
||||||
assert new_test_results[3].id.test_class_name is None
|
assert new_test_results[3].id.test_class_name is None
|
||||||
assert new_test_results[3].id.test_function_name == "test_sort"
|
assert new_test_results[3].id.test_function_name == "test_sort"
|
||||||
assert (
|
assert (
|
||||||
|
|
@ -491,8 +491,7 @@ def test_sort():
|
||||||
test_results = _run_and_parse(test_files, test_env, test_config)
|
test_results = _run_and_parse(test_files, test_env, test_config)
|
||||||
assert len(test_results) == 2
|
assert len(test_results) == 2
|
||||||
assert (
|
assert (
|
||||||
test_results[0].id.function_getting_tested
|
test_results[0].id.function_getting_tested == "sorter_classmethod"
|
||||||
== "sorter_classmethod"
|
|
||||||
)
|
)
|
||||||
assert test_results[0].id.test_class_name is None
|
assert test_results[0].id.test_class_name is None
|
||||||
assert test_results[0].id.test_function_name == "test_sort"
|
assert test_results[0].id.test_function_name == "test_sort"
|
||||||
|
|
@ -503,13 +502,15 @@ def test_sort():
|
||||||
assert test_results[0].runtime > 0
|
assert test_results[0].runtime > 0
|
||||||
assert test_results[0].did_pass
|
assert test_results[0].did_pass
|
||||||
assert test_results[0].return_value[0][2] == [0, 1, 2, 3, 4, 5]
|
assert test_results[0].return_value[0][2] == [0, 1, 2, 3, 4, 5]
|
||||||
assert test_results[0].stdout == "codeflash stdout : BubbleSorter.sorter_classmethod() called\n"
|
assert (
|
||||||
|
test_results[0].stdout
|
||||||
|
== "codeflash stdout : BubbleSorter.sorter_classmethod() called\n"
|
||||||
|
)
|
||||||
match, _ = compare_test_results(test_results, test_results)
|
match, _ = compare_test_results(test_results, test_results)
|
||||||
assert match
|
assert match
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
test_results[1].id.function_getting_tested
|
test_results[1].id.function_getting_tested == "sorter_classmethod"
|
||||||
== "sorter_classmethod"
|
|
||||||
)
|
)
|
||||||
assert test_results[1].id.test_class_name is None
|
assert test_results[1].id.test_class_name is None
|
||||||
assert test_results[1].id.test_function_name == "test_sort"
|
assert test_results[1].id.test_function_name == "test_sort"
|
||||||
|
|
@ -519,7 +520,10 @@ def test_sort():
|
||||||
)
|
)
|
||||||
assert test_results[1].runtime > 0
|
assert test_results[1].runtime > 0
|
||||||
assert test_results[1].did_pass
|
assert test_results[1].did_pass
|
||||||
assert test_results[1].stdout == "codeflash stdout : BubbleSorter.sorter_classmethod() called\n"
|
assert (
|
||||||
|
test_results[1].stdout
|
||||||
|
== "codeflash stdout : BubbleSorter.sorter_classmethod() called\n"
|
||||||
|
)
|
||||||
|
|
||||||
results2 = _run_and_parse(test_files, test_env, test_config)
|
results2 = _run_and_parse(test_files, test_env, test_config)
|
||||||
|
|
||||||
|
|
@ -614,8 +618,7 @@ def test_sort():
|
||||||
test_results = _run_and_parse(test_files, test_env, test_config)
|
test_results = _run_and_parse(test_files, test_env, test_config)
|
||||||
assert len(test_results) == 2
|
assert len(test_results) == 2
|
||||||
assert (
|
assert (
|
||||||
test_results[0].id.function_getting_tested
|
test_results[0].id.function_getting_tested == "sorter_staticmethod"
|
||||||
== "sorter_staticmethod"
|
|
||||||
)
|
)
|
||||||
assert test_results[0].id.test_class_name is None
|
assert test_results[0].id.test_class_name is None
|
||||||
assert test_results[0].id.test_function_name == "test_sort"
|
assert test_results[0].id.test_function_name == "test_sort"
|
||||||
|
|
@ -626,13 +629,15 @@ def test_sort():
|
||||||
assert test_results[0].runtime > 0
|
assert test_results[0].runtime > 0
|
||||||
assert test_results[0].did_pass
|
assert test_results[0].did_pass
|
||||||
assert test_results[0].return_value[0][2] == [0, 1, 2, 3, 4, 5]
|
assert test_results[0].return_value[0][2] == [0, 1, 2, 3, 4, 5]
|
||||||
assert test_results[0].stdout == "codeflash stdout : BubbleSorter.sorter_staticmethod() called\n"
|
assert (
|
||||||
|
test_results[0].stdout
|
||||||
|
== "codeflash stdout : BubbleSorter.sorter_staticmethod() called\n"
|
||||||
|
)
|
||||||
match, _ = compare_test_results(test_results, test_results)
|
match, _ = compare_test_results(test_results, test_results)
|
||||||
assert match
|
assert match
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
test_results[1].id.function_getting_tested
|
test_results[1].id.function_getting_tested == "sorter_staticmethod"
|
||||||
== "sorter_staticmethod"
|
|
||||||
)
|
)
|
||||||
assert test_results[1].id.test_class_name is None
|
assert test_results[1].id.test_class_name is None
|
||||||
assert test_results[1].id.test_function_name == "test_sort"
|
assert test_results[1].id.test_function_name == "test_sort"
|
||||||
|
|
@ -642,7 +647,10 @@ def test_sort():
|
||||||
)
|
)
|
||||||
assert test_results[1].runtime > 0
|
assert test_results[1].runtime > 0
|
||||||
assert test_results[1].did_pass
|
assert test_results[1].did_pass
|
||||||
assert test_results[1].stdout == "codeflash stdout : BubbleSorter.sorter_staticmethod() called\n"
|
assert (
|
||||||
|
test_results[1].stdout
|
||||||
|
== "codeflash stdout : BubbleSorter.sorter_staticmethod() called\n"
|
||||||
|
)
|
||||||
|
|
||||||
results2 = _run_and_parse(test_files, test_env, test_config)
|
results2 = _run_and_parse(test_files, test_env, test_config)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -755,17 +755,7 @@ async def test_multiple_calls():
|
||||||
assert "_codeflash_call_site.set('14')" in instrumented_test_code
|
assert "_codeflash_call_site.set('14')" in instrumented_test_code
|
||||||
assert "_codeflash_call_site.set('15')" in instrumented_test_code
|
assert "_codeflash_call_site.set('15')" in instrumented_test_code
|
||||||
|
|
||||||
assert 1 == instrumented_test_code.count(
|
assert 1 == instrumented_test_code.count("_codeflash_call_site.set('8')")
|
||||||
"_codeflash_call_site.set('8')"
|
assert 1 == instrumented_test_code.count("_codeflash_call_site.set('13')")
|
||||||
)
|
assert 1 == instrumented_test_code.count("_codeflash_call_site.set('14')")
|
||||||
assert 1 == instrumented_test_code.count(
|
assert 1 == instrumented_test_code.count("_codeflash_call_site.set('15')")
|
||||||
"_codeflash_call_site.set('13')"
|
|
||||||
)
|
|
||||||
assert 1 == instrumented_test_code.count(
|
|
||||||
"_codeflash_call_site.set('14')"
|
|
||||||
)
|
|
||||||
assert 1 == instrumented_test_code.count(
|
|
||||||
"_codeflash_call_site.set('15')"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -37,8 +37,9 @@ class TestGetSyncDecoratorNameForMode:
|
||||||
|
|
||||||
def test_performance_mode(self) -> None:
|
def test_performance_mode(self) -> None:
|
||||||
"""Returns codeflash_performance_sync for PERFORMANCE."""
|
"""Returns codeflash_performance_sync for PERFORMANCE."""
|
||||||
assert "codeflash_performance_sync" == get_sync_decorator_name_for_mode(
|
assert (
|
||||||
TestingMode.PERFORMANCE
|
"codeflash_performance_sync"
|
||||||
|
== get_sync_decorator_name_for_mode(TestingMode.PERFORMANCE)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -141,9 +142,7 @@ class TestSyncCallInstrumenter:
|
||||||
parents=[],
|
parents=[],
|
||||||
is_async=False,
|
is_async=False,
|
||||||
)
|
)
|
||||||
instrumenter = SyncCallInstrumenter(
|
instrumenter = SyncCallInstrumenter(func, [CodePosition(2, 13)])
|
||||||
func, [CodePosition(2, 13)]
|
|
||||||
)
|
|
||||||
tree = instrumenter.visit(tree)
|
tree = instrumenter.visit(tree)
|
||||||
|
|
||||||
assert instrumenter.did_instrument
|
assert instrumenter.did_instrument
|
||||||
|
|
@ -164,9 +163,7 @@ class TestSyncCallInstrumenter:
|
||||||
parents=[],
|
parents=[],
|
||||||
is_async=False,
|
is_async=False,
|
||||||
)
|
)
|
||||||
instrumenter = SyncCallInstrumenter(
|
instrumenter = SyncCallInstrumenter(func, [CodePosition(2, 13)])
|
||||||
func, [CodePosition(2, 13)]
|
|
||||||
)
|
|
||||||
tree = instrumenter.visit(tree)
|
tree = instrumenter.visit(tree)
|
||||||
|
|
||||||
assert instrumenter.did_instrument
|
assert instrumenter.did_instrument
|
||||||
|
|
@ -206,10 +203,7 @@ class TestSyncCallInstrumenter:
|
||||||
|
|
||||||
def test_skips_non_test_functions(self) -> None:
|
def test_skips_non_test_functions(self) -> None:
|
||||||
"""Does not instrument functions that don't start with test_."""
|
"""Does not instrument functions that don't start with test_."""
|
||||||
test_code = (
|
test_code = "def helper():\n return my_func(1)\n"
|
||||||
"def helper():\n"
|
|
||||||
" return my_func(1)\n"
|
|
||||||
)
|
|
||||||
tree = ast.parse(test_code)
|
tree = ast.parse(test_code)
|
||||||
func = FunctionToOptimize(
|
func = FunctionToOptimize(
|
||||||
function_name="my_func",
|
function_name="my_func",
|
||||||
|
|
@ -217,9 +211,7 @@ class TestSyncCallInstrumenter:
|
||||||
parents=[],
|
parents=[],
|
||||||
is_async=False,
|
is_async=False,
|
||||||
)
|
)
|
||||||
instrumenter = SyncCallInstrumenter(
|
instrumenter = SyncCallInstrumenter(func, [CodePosition(2, 11)])
|
||||||
func, [CodePosition(2, 11)]
|
|
||||||
)
|
|
||||||
tree = instrumenter.visit(tree)
|
tree = instrumenter.visit(tree)
|
||||||
|
|
||||||
assert not instrumenter.did_instrument
|
assert not instrumenter.did_instrument
|
||||||
|
|
@ -238,9 +230,7 @@ class TestSyncCallInstrumenter:
|
||||||
parents=[],
|
parents=[],
|
||||||
is_async=False,
|
is_async=False,
|
||||||
)
|
)
|
||||||
instrumenter = SyncCallInstrumenter(
|
instrumenter = SyncCallInstrumenter(func, [CodePosition(3, 17)])
|
||||||
func, [CodePosition(3, 17)]
|
|
||||||
)
|
|
||||||
tree = instrumenter.visit(tree)
|
tree = instrumenter.visit(tree)
|
||||||
|
|
||||||
assert instrumenter.did_instrument
|
assert instrumenter.did_instrument
|
||||||
|
|
@ -249,10 +239,7 @@ class TestSyncCallInstrumenter:
|
||||||
|
|
||||||
def test_no_match_when_position_wrong(self) -> None:
|
def test_no_match_when_position_wrong(self) -> None:
|
||||||
"""Does not instrument if code position doesn't match."""
|
"""Does not instrument if code position doesn't match."""
|
||||||
test_code = (
|
test_code = "def test_example():\n result = my_func(1)\n"
|
||||||
"def test_example():\n"
|
|
||||||
" result = my_func(1)\n"
|
|
||||||
)
|
|
||||||
tree = ast.parse(test_code)
|
tree = ast.parse(test_code)
|
||||||
func = FunctionToOptimize(
|
func = FunctionToOptimize(
|
||||||
function_name="my_func",
|
function_name="my_func",
|
||||||
|
|
@ -260,9 +247,7 @@ class TestSyncCallInstrumenter:
|
||||||
parents=[],
|
parents=[],
|
||||||
is_async=False,
|
is_async=False,
|
||||||
)
|
)
|
||||||
instrumenter = SyncCallInstrumenter(
|
instrumenter = SyncCallInstrumenter(func, [CodePosition(99, 99)])
|
||||||
func, [CodePosition(99, 99)]
|
|
||||||
)
|
|
||||||
tree = instrumenter.visit(tree)
|
tree = instrumenter.visit(tree)
|
||||||
|
|
||||||
assert not instrumenter.did_instrument
|
assert not instrumenter.did_instrument
|
||||||
|
|
@ -368,9 +353,7 @@ class TestAddSyncDecoratorToFunction:
|
||||||
def test_preserves_existing_decorators(self, temp_dir) -> None:
|
def test_preserves_existing_decorators(self, temp_dir) -> None:
|
||||||
"""Adds codeflash decorator below @staticmethod/@classmethod."""
|
"""Adds codeflash decorator below @staticmethod/@classmethod."""
|
||||||
source_code = (
|
source_code = (
|
||||||
"@staticmethod\n"
|
"@staticmethod\ndef my_func(x: int) -> int:\n return x + 1\n"
|
||||||
"def my_func(x: int) -> int:\n"
|
|
||||||
" return x + 1\n"
|
|
||||||
)
|
)
|
||||||
source_file = temp_dir / "my_module.py"
|
source_file = temp_dir / "my_module.py"
|
||||||
source_file.write_text(source_code)
|
source_file.write_text(source_code)
|
||||||
|
|
@ -450,7 +433,10 @@ class TestInjectSyncProfilingIntoExistingTest:
|
||||||
assert success
|
assert success
|
||||||
assert instrumented is not None
|
assert instrumented is not None
|
||||||
assert "_codeflash_call_site.set('4')" in instrumented
|
assert "_codeflash_call_site.set('4')" in instrumented
|
||||||
assert "from codeflash_async_wrapper import _codeflash_call_site" in instrumented
|
assert (
|
||||||
|
"from codeflash_async_wrapper import _codeflash_call_site"
|
||||||
|
in instrumented
|
||||||
|
)
|
||||||
|
|
||||||
def test_multiple_calls_use_line_numbers(self, temp_dir) -> None:
|
def test_multiple_calls_use_line_numbers(self, temp_dir) -> None:
|
||||||
"""Multiple calls use their source line numbers as call-site IDs."""
|
"""Multiple calls use their source line numbers as call-site IDs."""
|
||||||
|
|
@ -512,10 +498,7 @@ class TestInjectSyncProfilingIntoExistingTest:
|
||||||
|
|
||||||
def test_returns_false_when_no_target_calls(self, temp_dir) -> None:
|
def test_returns_false_when_no_target_calls(self, temp_dir) -> None:
|
||||||
"""Returns (False, None) when no target function calls are found."""
|
"""Returns (False, None) when no target function calls are found."""
|
||||||
test_code = (
|
test_code = "def test_unrelated():\n assert 1 == 1\n"
|
||||||
"def test_unrelated():\n"
|
|
||||||
" assert 1 == 1\n"
|
|
||||||
)
|
|
||||||
test_file = temp_dir / "test_noop.py"
|
test_file = temp_dir / "test_noop.py"
|
||||||
test_file.write_text(test_code)
|
test_file.write_text(test_code)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -124,7 +124,10 @@ def test_prepare_image_for_yolo():
|
||||||
assert new_test is not None
|
assert new_test is not None
|
||||||
_assert_sync_instrumentation_present(new_test)
|
_assert_sync_instrumentation_present(new_test)
|
||||||
_assert_old_instrumentation_absent(new_test)
|
_assert_old_instrumentation_absent(new_test)
|
||||||
assert "packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo" in new_test
|
assert (
|
||||||
|
"packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo"
|
||||||
|
in new_test
|
||||||
|
)
|
||||||
assert "def test_prepare_image_for_yolo" in new_test
|
assert "def test_prepare_image_for_yolo" in new_test
|
||||||
assert "_codeflash_call_site.set(" in new_test
|
assert "_codeflash_call_site.set(" in new_test
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -49,16 +49,13 @@ def test_single_element_list():
|
||||||
project_root
|
project_root
|
||||||
/ "code_to_optimize/tests/pytest/test_aiservice_behavior_results_perf_temp.py"
|
/ "code_to_optimize/tests/pytest/test_aiservice_behavior_results_perf_temp.py"
|
||||||
).resolve()
|
).resolve()
|
||||||
tests_root = (
|
tests_root = project_root / "code_to_optimize/tests/pytest/"
|
||||||
project_root / "code_to_optimize/tests/pytest/"
|
|
||||||
)
|
|
||||||
project_root_path = project_root
|
project_root_path = project_root
|
||||||
run_cwd = project_root
|
run_cwd = project_root
|
||||||
old_cwd = os.getcwd()
|
old_cwd = os.getcwd()
|
||||||
os.chdir(run_cwd)
|
os.chdir(run_cwd)
|
||||||
fto_path = (
|
fto_path = (
|
||||||
project_root
|
project_root / "code_to_optimize/bubble_sort_method.py"
|
||||||
/ "code_to_optimize/bubble_sort_method.py"
|
|
||||||
).resolve()
|
).resolve()
|
||||||
original_code = fto_path.read_text("utf-8")
|
original_code = fto_path.read_text("utf-8")
|
||||||
|
|
||||||
|
|
@ -127,7 +124,10 @@ def test_single_element_list():
|
||||||
run_result=run_result,
|
run_result=run_result,
|
||||||
)
|
)
|
||||||
assert test_results[0].id.function_getting_tested == "sorter"
|
assert test_results[0].id.function_getting_tested == "sorter"
|
||||||
assert test_results[0].stdout == "codeflash stdout : BubbleSorter.sorter() called\n"
|
assert (
|
||||||
|
test_results[0].stdout
|
||||||
|
== "codeflash stdout : BubbleSorter.sorter() called\n"
|
||||||
|
)
|
||||||
assert (
|
assert (
|
||||||
test_results[0].id.test_function_name == "test_single_element_list"
|
test_results[0].id.test_function_name == "test_single_element_list"
|
||||||
)
|
)
|
||||||
|
|
@ -216,14 +216,11 @@ def test_single_element_list():
|
||||||
project_root
|
project_root
|
||||||
/ "code_to_optimize/tests/pytest/test_aiservice_behavior_results_perf_temp.py"
|
/ "code_to_optimize/tests/pytest/test_aiservice_behavior_results_perf_temp.py"
|
||||||
).resolve()
|
).resolve()
|
||||||
tests_root = (
|
tests_root = project_root / "code_to_optimize/tests/pytest/"
|
||||||
project_root / "code_to_optimize/tests/pytest/"
|
|
||||||
)
|
|
||||||
project_root_path = project_root
|
project_root_path = project_root
|
||||||
|
|
||||||
fto_path = (
|
fto_path = (
|
||||||
project_root
|
project_root / "code_to_optimize/bubble_sort_method.py"
|
||||||
/ "code_to_optimize/bubble_sort_method.py"
|
|
||||||
).resolve()
|
).resolve()
|
||||||
original_code = fto_path.read_text("utf-8")
|
original_code = fto_path.read_text("utf-8")
|
||||||
function_to_optimize = FunctionToOptimize(
|
function_to_optimize = FunctionToOptimize(
|
||||||
|
|
@ -316,7 +313,10 @@ def test_single_element_list():
|
||||||
assert test_results[1].did_pass
|
assert test_results[1].did_pass
|
||||||
# return_value is ((args, kwargs, return_value),) in the new path
|
# return_value is ((args, kwargs, return_value),) in the new path
|
||||||
assert test_results[1].return_value[0][2] == [1, 2, 3]
|
assert test_results[1].return_value[0][2] == [1, 2, 3]
|
||||||
assert test_results[1].stdout == "codeflash stdout : BubbleSorter.sorter() called\n"
|
assert (
|
||||||
|
test_results[1].stdout
|
||||||
|
== "codeflash stdout : BubbleSorter.sorter() called\n"
|
||||||
|
)
|
||||||
|
|
||||||
# Replace with optimized code that mutated instance attribute
|
# Replace with optimized code that mutated instance attribute
|
||||||
optimized_code_mutated_attr = """
|
optimized_code_mutated_attr = """
|
||||||
|
|
@ -442,9 +442,7 @@ class BubbleSorter:
|
||||||
# In the new decorator-based path, args (including self) are captured.
|
# In the new decorator-based path, args (including self) are captured.
|
||||||
# Adding a new instance attribute changes self, so the comparison
|
# Adding a new instance attribute changes self, so the comparison
|
||||||
# detects a difference even though codeflash_capture considers it additive.
|
# detects a difference even though codeflash_capture considers it additive.
|
||||||
match, _ = compare_test_results(
|
match, _ = compare_test_results(test_results, test_results_new_attr)
|
||||||
test_results, test_results_new_attr
|
|
||||||
)
|
|
||||||
assert not match
|
assert not match
|
||||||
finally:
|
finally:
|
||||||
fto_path.write_text(original_code, "utf-8")
|
fto_path.write_text(original_code, "utf-8")
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue