mirror of
https://github.com/codeflash-ai/codeflash-agent.git
synced 2026-05-04 18:25:19 +00:00
Fix pre-existing CI lint and test failures (#40)
* chore: add gitignore entries for local eval repos, e2e fixtures, and env files

* fix: restore clean bubble_sort_method.py test fixture

  The call-site ID commit re-contaminated this file with instrumentation decorators, causing tests to fail with missing CODEFLASH_LOOP_INDEX.

* fix: resolve ruff and mypy errors in codeflash-python

  - Add import-not-found ignores for optional torch/jax imports
  - Extract magic column index to _STDOUT_COLUMN_INDEX constant
  - Fix unused variable in _instrument_sync.py
  - Cast cpu_time_ns to int for mypy arg-type

* fix: add skip markers for optional deps and apply ruff formatting to tests

  Skip torch/jax/tensorflow tests when those packages are not installed. Move has_module helper to conftest.py for reuse across test files. Apply ruff format to all test files that drifted.

* fix: resolve remaining ruff format and mypy errors

  - Add missing blank line in conftest.py (ruff format)
  - Remove unused import-untyped ignore on jax import (mypy unused-ignore)
  - Add type: ignore comments for object-typed SQLite row values

* chore: bump codeflash-python to 0.1.1.dev0
This commit is contained in:
parent
2c9f2ad8de
commit
919a673be2
18 changed files with 172 additions and 175 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -14,3 +14,7 @@ dist-*/
|
|||
dist-v2/
|
||||
.playwright-mcp/
|
||||
.tessl/session-data/
|
||||
evals/repos/
|
||||
packages/codeflash-api/e2e/
|
||||
packages/github-app/.env
|
||||
packages/github-app/codex-config/
|
||||
|
|
|
|||
6
packages/codeflash-python/changelogs/fix-ci-lint.md
Normal file
6
packages/codeflash-python/changelogs/fix-ci-lint.md
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
### Fixes
|
||||
|
||||
- Restore clean bubble_sort_method.py test fixture (remove accidental instrumentation decorators)
|
||||
- Add mypy type-ignore comments for optional torch/jax imports and SQLite row types
|
||||
- Add pytest skip markers for tests requiring optional deps (torch, jax, tensorflow)
|
||||
- Fix ruff format and unused mypy ignore comment
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "codeflash-python"
|
||||
version = "0.1.0"
|
||||
version = "0.1.1.dev0"
|
||||
requires-python = ">=3.9"
|
||||
dependencies = [
|
||||
"codeflash-core",
|
||||
|
|
|
|||
|
|
@ -130,14 +130,12 @@ def close_all_connections() -> None:
|
|||
atexit.register(close_all_connections)
|
||||
|
||||
|
||||
def detect_device_sync() -> (
|
||||
tuple[
|
||||
Any | None, # torch_cuda_sync
|
||||
Any | None, # torch_mps_sync
|
||||
Any | None, # tf_sync
|
||||
bool, # jax_available
|
||||
]
|
||||
):
|
||||
def detect_device_sync() -> tuple[
|
||||
Any | None, # torch_cuda_sync
|
||||
Any | None, # torch_mps_sync
|
||||
Any | None, # tf_sync
|
||||
bool, # jax_available
|
||||
]:
|
||||
"""Detect available GPU frameworks and return sync callables.
|
||||
|
||||
Called once at first decorator invocation; results are cached.
|
||||
|
|
@ -151,7 +149,7 @@ def detect_device_sync() -> (
|
|||
jax_available = False
|
||||
|
||||
if find_spec("torch") is not None:
|
||||
import torch # noqa: PLC0415
|
||||
import torch # type: ignore[import-not-found] # noqa: PLC0415
|
||||
|
||||
if torch.cuda.is_available() and torch.cuda.is_initialized():
|
||||
torch_cuda_sync = torch.cuda.synchronize
|
||||
|
|
@ -165,7 +163,7 @@ def detect_device_sync() -> (
|
|||
torch_mps_sync = torch.mps.synchronize
|
||||
|
||||
if find_spec("jax") is not None:
|
||||
import jax # type: ignore[import-untyped] # noqa: PLC0415
|
||||
import jax # type: ignore[import-not-found] # noqa: PLC0415
|
||||
|
||||
jax_available = hasattr(jax, "block_until_ready")
|
||||
|
||||
|
|
@ -180,14 +178,12 @@ def detect_device_sync() -> (
|
|||
return torch_cuda_sync, torch_mps_sync, tf_sync, jax_available
|
||||
|
||||
|
||||
device_sync_cache: (
|
||||
tuple[Any | None, Any | None, Any | None, bool] | None
|
||||
) = None
|
||||
device_sync_cache: tuple[Any | None, Any | None, Any | None, bool] | None = (
|
||||
None
|
||||
)
|
||||
|
||||
|
||||
def get_device_sync() -> (
|
||||
tuple[Any | None, Any | None, Any | None, bool]
|
||||
):
|
||||
def get_device_sync() -> tuple[Any | None, Any | None, Any | None, bool]:
|
||||
"""Return cached device sync callables, detecting on first call."""
|
||||
global device_sync_cache # noqa: PLW0603
|
||||
if device_sync_cache is None:
|
||||
|
|
@ -253,7 +249,9 @@ def codeflash_behavior_sync(func: F) -> F:
|
|||
invocation_id = f"{call_site}_{call_index}"
|
||||
|
||||
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
|
||||
db_path = get_run_tmp_file(Path(f"codeflash_results_{iteration}.sqlite"))
|
||||
db_path = get_run_tmp_file(
|
||||
Path(f"codeflash_results_{iteration}.sqlite")
|
||||
)
|
||||
conn, cur = get_async_db(db_path)
|
||||
|
||||
exception = None
|
||||
|
|
@ -348,7 +346,9 @@ def codeflash_performance_sync(func: F) -> F:
|
|||
invocation_id = f"{call_site}_{call_index}"
|
||||
|
||||
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
|
||||
db_path = get_run_tmp_file(Path(f"codeflash_results_{iteration}.sqlite"))
|
||||
db_path = get_run_tmp_file(
|
||||
Path(f"codeflash_results_{iteration}.sqlite")
|
||||
)
|
||||
conn, cur = get_async_db(db_path)
|
||||
|
||||
exception = None
|
||||
|
|
@ -430,7 +430,9 @@ def codeflash_behavior_async(func: F) -> F:
|
|||
invocation_id = f"{call_site}_{call_index}"
|
||||
|
||||
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
|
||||
db_path = get_run_tmp_file(Path(f"codeflash_results_{iteration}.sqlite"))
|
||||
db_path = get_run_tmp_file(
|
||||
Path(f"codeflash_results_{iteration}.sqlite")
|
||||
)
|
||||
conn, cur = get_async_db(db_path)
|
||||
|
||||
exception = None
|
||||
|
|
@ -522,7 +524,9 @@ def codeflash_performance_async(func: F) -> F:
|
|||
invocation_id = f"{call_site}_{call_index}"
|
||||
|
||||
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
|
||||
db_path = get_run_tmp_file(Path(f"codeflash_results_{iteration}.sqlite"))
|
||||
db_path = get_run_tmp_file(
|
||||
Path(f"codeflash_results_{iteration}.sqlite")
|
||||
)
|
||||
conn, cur = get_async_db(db_path)
|
||||
|
||||
exception = None
|
||||
|
|
@ -590,7 +594,9 @@ def codeflash_concurrency_async(func: F) -> F:
|
|||
loop_index = int(os.environ.get("CODEFLASH_LOOP_INDEX", "0"))
|
||||
|
||||
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
|
||||
db_path = get_run_tmp_file(Path(f"codeflash_results_{iteration}.sqlite"))
|
||||
db_path = get_run_tmp_file(
|
||||
Path(f"codeflash_results_{iteration}.sqlite")
|
||||
)
|
||||
conn, cur = get_async_db(db_path)
|
||||
|
||||
gc.disable()
|
||||
|
|
|
|||
|
|
@ -19,6 +19,8 @@ if TYPE_CHECKING:
|
|||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_STDOUT_COLUMN_INDEX = 10
|
||||
|
||||
_BEHAVIOR_QUERY = (
|
||||
"SELECT test_module_path, test_class_name,"
|
||||
" test_function_name, function_getting_tested,"
|
||||
|
|
@ -110,7 +112,9 @@ def _process_behavior_row_inner(
|
|||
wall_time_ns = val[6]
|
||||
cpu_time_ns = val[7]
|
||||
verification_type = val[9]
|
||||
stdout_text = val[10] if len(val) > 10 else None
|
||||
stdout_text = (
|
||||
val[_STDOUT_COLUMN_INDEX] if len(val) > _STDOUT_COLUMN_INDEX else None
|
||||
)
|
||||
|
||||
test_file_path = file_path_from_module_name(
|
||||
test_module_path, # type: ignore[arg-type]
|
||||
|
|
@ -168,14 +172,14 @@ def _process_behavior_row_inner(
|
|||
test_framework=test_config.test_framework,
|
||||
test_type=test_type,
|
||||
return_value=ret_val,
|
||||
cpu_runtime=cpu_time_ns or 0,
|
||||
cpu_runtime=int(cpu_time_ns or 0), # type: ignore[call-overload]
|
||||
timed_out=False,
|
||||
verification_type=(
|
||||
VerificationType(verification_type)
|
||||
if verification_type
|
||||
else None
|
||||
),
|
||||
stdout=stdout_text or None,
|
||||
stdout=stdout_text or None, # type: ignore[arg-type]
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ class SyncCallInstrumenter(ast.NodeTransformer):
|
|||
new_body: list[ast.stmt] = []
|
||||
|
||||
for stmt in node.body:
|
||||
call_node, has_target = self._find_target_call(stmt)
|
||||
_, has_target = self._find_target_call(stmt)
|
||||
|
||||
if has_target:
|
||||
call_site_set = ast.Expr(
|
||||
|
|
|
|||
|
|
@ -1,48 +1,41 @@
|
|||
import sys
|
||||
|
||||
from codeflash_async_wrapper import codeflash_behavior_sync
|
||||
|
||||
from codeflash_python.runtime._codeflash_capture import codeflash_capture
|
||||
|
||||
|
||||
class BubbleSorter:
|
||||
|
||||
@codeflash_capture(function_name='BubbleSorter.__init__', tmp_dir_path='/var/folders/mg/k_c0twcj37q_gph3cfy3zlt80000gn/T/codeflash_l3k89hc3/codeflash_results', tests_root='/Users/krrt7/Desktop/work/cf_org/codeflash-agent/.claude/worktrees/jaunty-sauteeing-dolphin/packages/codeflash-python/tests/code_to_optimize/tests/pytest', is_fto=True)
|
||||
def __init__(self, x=0):
|
||||
self.x = x
|
||||
|
||||
@codeflash_behavior_sync
|
||||
def sorter(self, arr):
|
||||
print('codeflash stdout : BubbleSorter.sorter() called')
|
||||
print("codeflash stdout : BubbleSorter.sorter() called")
|
||||
for i in range(len(arr)):
|
||||
for j in range(len(arr) - 1):
|
||||
if arr[j] > arr[j + 1]:
|
||||
temp = arr[j]
|
||||
arr[j] = arr[j + 1]
|
||||
arr[j + 1] = temp
|
||||
print('stderr test', file=sys.stderr)
|
||||
print("stderr test", file=sys.stderr)
|
||||
return arr
|
||||
|
||||
@classmethod
|
||||
def sorter_classmethod(cls, arr):
|
||||
print('codeflash stdout : BubbleSorter.sorter_classmethod() called')
|
||||
print("codeflash stdout : BubbleSorter.sorter_classmethod() called")
|
||||
for i in range(len(arr)):
|
||||
for j in range(len(arr) - 1):
|
||||
if arr[j] > arr[j + 1]:
|
||||
temp = arr[j]
|
||||
arr[j] = arr[j + 1]
|
||||
arr[j + 1] = temp
|
||||
print('stderr test classmethod', file=sys.stderr)
|
||||
print("stderr test classmethod", file=sys.stderr)
|
||||
return arr
|
||||
|
||||
@staticmethod
|
||||
def sorter_staticmethod(arr):
|
||||
print('codeflash stdout : BubbleSorter.sorter_staticmethod() called')
|
||||
print("codeflash stdout : BubbleSorter.sorter_staticmethod() called")
|
||||
for i in range(len(arr)):
|
||||
for j in range(len(arr) - 1):
|
||||
if arr[j] > arr[j + 1]:
|
||||
temp = arr[j]
|
||||
arr[j] = arr[j + 1]
|
||||
arr[j + 1] = temp
|
||||
print('stderr test staticmethod', file=sys.stderr)
|
||||
print("stderr test staticmethod", file=sys.stderr)
|
||||
return arr
|
||||
|
|
|
|||
|
|
@ -3,10 +3,17 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from importlib.util import find_spec
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def has_module(name: str) -> bool:
|
||||
"""Check whether an optional dependency is importable (for skipif markers)."""
|
||||
return find_spec(name) is not None
|
||||
|
||||
|
||||
# Make the code_to_optimize fixture package importable by tests that need it
|
||||
# (e.g. test_comparator.py, test_trace_benchmarks.py).
|
||||
_TESTS_DIR = str(Path(__file__).resolve().parent)
|
||||
|
|
|
|||
|
|
@ -14,25 +14,25 @@ import dill as pickle
|
|||
import pytest
|
||||
|
||||
import codeflash_python.runtime._codeflash_async_decorators as _deco_mod
|
||||
|
||||
from codeflash_python.runtime._codeflash_async_decorators import (
|
||||
VerificationType,
|
||||
close_all_connections,
|
||||
_codeflash_call_site,
|
||||
connections,
|
||||
detect_device_sync,
|
||||
get_async_db,
|
||||
get_device_sync,
|
||||
sync_devices_after,
|
||||
sync_devices_before,
|
||||
close_all_connections,
|
||||
codeflash_behavior_async,
|
||||
codeflash_behavior_sync,
|
||||
codeflash_concurrency_async,
|
||||
codeflash_performance_async,
|
||||
codeflash_performance_sync,
|
||||
connections,
|
||||
detect_device_sync,
|
||||
extract_test_context_from_env,
|
||||
get_async_db,
|
||||
get_device_sync,
|
||||
get_run_tmp_file,
|
||||
sync_devices_after,
|
||||
sync_devices_before,
|
||||
)
|
||||
from conftest import has_module
|
||||
|
||||
|
||||
@pytest.fixture(name="env_setup")
|
||||
|
|
@ -210,7 +210,9 @@ class TestBehaviorAsync:
|
|||
con.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_handling(self, env_setup, results_db_path) -> None:
|
||||
async def test_exception_handling(
|
||||
self, env_setup, results_db_path
|
||||
) -> None:
|
||||
"""Re-raises exceptions and stores them pickled."""
|
||||
|
||||
@codeflash_behavior_async
|
||||
|
|
@ -354,9 +356,7 @@ class TestBehaviorSync:
|
|||
assert "hello world\n" == row[0]
|
||||
con.close()
|
||||
|
||||
def test_no_stdout_leak(
|
||||
self, env_setup, results_db_path, capsys
|
||||
) -> None:
|
||||
def test_no_stdout_leak(self, env_setup, results_db_path, capsys) -> None:
|
||||
"""Sync decorator does not leak stdout to outer scope."""
|
||||
|
||||
@codeflash_behavior_sync
|
||||
|
|
@ -382,9 +382,7 @@ class TestBehaviorSync:
|
|||
|
||||
con = sqlite3.connect(results_db_path)
|
||||
cur = con.cursor()
|
||||
cur.execute(
|
||||
"SELECT cpu_time_ns FROM codeflash_results"
|
||||
)
|
||||
cur.execute("SELECT cpu_time_ns FROM codeflash_results")
|
||||
row = cur.fetchone()
|
||||
assert row[0] is not None
|
||||
assert 0 <= row[0]
|
||||
|
|
@ -646,7 +644,9 @@ class TestPerformanceAsyncEdgeCases:
|
|||
"""Edge cases for codeflash_performance_async."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_handling(self, env_setup, results_db_path) -> None:
|
||||
async def test_exception_handling(
|
||||
self, env_setup, results_db_path
|
||||
) -> None:
|
||||
"""Re-raises exceptions from the wrapped function."""
|
||||
|
||||
@codeflash_performance_async
|
||||
|
|
@ -762,6 +762,7 @@ class TestDetectDeviceSync:
|
|||
result = detect_device_sync()
|
||||
assert 4 == len(result)
|
||||
|
||||
@pytest.mark.skipif(not has_module("torch"), reason="torch not installed")
|
||||
def test_detects_real_torch(self, reset_device_cache) -> None:
|
||||
"""Detects torch and returns a sync callable for the active device."""
|
||||
import torch
|
||||
|
|
@ -780,6 +781,7 @@ class TestDetectDeviceSync:
|
|||
assert cuda_sync is None
|
||||
assert mps_sync is None
|
||||
|
||||
@pytest.mark.skipif(not has_module("jax"), reason="jax not installed")
|
||||
def test_detects_real_jax(self, reset_device_cache) -> None:
|
||||
"""Detects JAX and sets jax_available based on block_until_ready."""
|
||||
import jax
|
||||
|
|
@ -787,14 +789,16 @@ class TestDetectDeviceSync:
|
|||
_, _, _, jax_avail = detect_device_sync()
|
||||
assert jax_avail is hasattr(jax, "block_until_ready")
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not has_module("tensorflow"), reason="tensorflow not installed"
|
||||
)
|
||||
def test_detects_real_tensorflow(self, reset_device_cache) -> None:
|
||||
"""Detects TensorFlow sync_devices when available."""
|
||||
import tensorflow as tf
|
||||
|
||||
_, _, tf_sync, _ = detect_device_sync()
|
||||
if (
|
||||
hasattr(tf.test, "experimental")
|
||||
and hasattr(tf.test.experimental, "sync_devices")
|
||||
if hasattr(tf.test, "experimental") and hasattr(
|
||||
tf.test.experimental, "sync_devices"
|
||||
):
|
||||
assert tf_sync is tf.test.experimental.sync_devices
|
||||
else:
|
||||
|
|
@ -810,9 +814,7 @@ class TestGetDeviceSync:
|
|||
second = get_device_sync()
|
||||
assert first is second
|
||||
|
||||
def test_redetects_after_cache_clear(
|
||||
self, reset_device_cache
|
||||
) -> None:
|
||||
def test_redetects_after_cache_clear(self, reset_device_cache) -> None:
|
||||
"""Re-runs detection after the cache is cleared."""
|
||||
first = get_device_sync()
|
||||
_deco_mod.device_sync_cache = None
|
||||
|
|
@ -828,9 +830,7 @@ class TestSyncDevicesBefore:
|
|||
"""Calling with real frameworks installed does not raise."""
|
||||
sync_devices_before()
|
||||
|
||||
def test_cuda_takes_priority_over_mps(
|
||||
self, reset_device_cache
|
||||
) -> None:
|
||||
def test_cuda_takes_priority_over_mps(self, reset_device_cache) -> None:
|
||||
"""CUDA sync is called instead of MPS when both are in the cache."""
|
||||
calls = []
|
||||
_deco_mod.device_sync_cache = (
|
||||
|
|
@ -842,9 +842,7 @@ class TestSyncDevicesBefore:
|
|||
sync_devices_before()
|
||||
assert ["cuda"] == calls
|
||||
|
||||
def test_mps_called_when_no_cuda(
|
||||
self, reset_device_cache
|
||||
) -> None:
|
||||
def test_mps_called_when_no_cuda(self, reset_device_cache) -> None:
|
||||
"""MPS sync fires when CUDA is absent in the cache."""
|
||||
calls = []
|
||||
_deco_mod.device_sync_cache = (
|
||||
|
|
@ -856,9 +854,7 @@ class TestSyncDevicesBefore:
|
|||
sync_devices_before()
|
||||
assert ["mps"] == calls
|
||||
|
||||
def test_tf_called_independently(
|
||||
self, reset_device_cache
|
||||
) -> None:
|
||||
def test_tf_called_independently(self, reset_device_cache) -> None:
|
||||
"""TF sync fires independently of torch sync."""
|
||||
calls = []
|
||||
_deco_mod.device_sync_cache = (
|
||||
|
|
@ -878,6 +874,7 @@ class TestSyncDevicesAfter:
|
|||
"""Calling with real frameworks and a plain return value does not raise."""
|
||||
sync_devices_after(42)
|
||||
|
||||
@pytest.mark.skipif(not has_module("jax"), reason="jax not installed")
|
||||
def test_jax_block_until_ready_on_real_array(
|
||||
self, reset_device_cache
|
||||
) -> None:
|
||||
|
|
@ -887,16 +884,12 @@ class TestSyncDevicesAfter:
|
|||
arr = jnp.array([1, 2, 3])
|
||||
sync_devices_after(arr)
|
||||
|
||||
def test_skips_jax_on_plain_value(
|
||||
self, reset_device_cache
|
||||
) -> None:
|
||||
def test_skips_jax_on_plain_value(self, reset_device_cache) -> None:
|
||||
"""Does not fail when jax_available=True but return value is plain."""
|
||||
_deco_mod.device_sync_cache = (None, None, None, True)
|
||||
sync_devices_after(42)
|
||||
|
||||
def test_cuda_priority_in_after(
|
||||
self, reset_device_cache
|
||||
) -> None:
|
||||
def test_cuda_priority_in_after(self, reset_device_cache) -> None:
|
||||
"""CUDA sync fires instead of MPS in the after path too."""
|
||||
calls = []
|
||||
_deco_mod.device_sync_cache = (
|
||||
|
|
@ -908,9 +901,8 @@ class TestSyncDevicesAfter:
|
|||
sync_devices_after(42)
|
||||
assert ["cuda"] == calls
|
||||
|
||||
def test_all_syncs_fire_together(
|
||||
self, reset_device_cache
|
||||
) -> None:
|
||||
@pytest.mark.skipif(not has_module("jax"), reason="jax not installed")
|
||||
def test_all_syncs_fire_together(self, reset_device_cache) -> None:
|
||||
"""All applicable syncs fire: torch + JAX block_until_ready + TF."""
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
|
|
|||
|
|
@ -23,9 +23,9 @@ from codeflash_python.analysis._reference_graph import (
|
|||
)
|
||||
from codeflash_python.context.models import CodeStringsMarkdown
|
||||
from codeflash_python.pipeline._orchestrator import cleanup_paths
|
||||
from codeflash_python.runtime._codeflash_wrap_decorator import get_run_tmp_file
|
||||
from codeflash_python.test_discovery.linking import module_name_from_file_path
|
||||
from codeflash_python.testing._concolic import clean_concolic_tests
|
||||
from codeflash_python.runtime._codeflash_wrap_decorator import get_run_tmp_file
|
||||
from codeflash_python.testing._path_resolution import (
|
||||
file_name_from_test_module_name,
|
||||
file_path_from_module_name,
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@ from codeflash_python.analysis._discovery import FunctionToOptimize
|
|||
from codeflash_python.pipeline._function_optimizer import (
|
||||
write_code_and_helpers,
|
||||
)
|
||||
from codeflash_python.test_discovery.models import TestType
|
||||
from codeflash_python.runtime._codeflash_wrap_decorator import get_run_tmp_file
|
||||
from codeflash_python.test_discovery.models import TestType
|
||||
from codeflash_python.testing._instrument_capture import (
|
||||
instrument_codeflash_capture,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -879,5 +879,3 @@ def test_concurrency_ratio_display_formatting() -> None:
|
|||
f"\u2192 {cand_ratio:.2f}x ({conc_gain:+.1f}%)"
|
||||
)
|
||||
assert display_string == "Concurrency ratio: 0.01x \u2192 0.03x (+200.0%)"
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,9 @@ from pathlib import Path
|
|||
|
||||
from codeflash_python._model import FunctionToOptimize, TestingMode
|
||||
from codeflash_python.test_discovery.models import CodePosition
|
||||
from codeflash_python.testing._instrument_core import detect_frameworks_from_code
|
||||
from codeflash_python.testing._instrument_core import (
|
||||
detect_frameworks_from_code,
|
||||
)
|
||||
from codeflash_python.testing._instrumentation import (
|
||||
inject_profiling_into_existing_test,
|
||||
)
|
||||
|
|
@ -200,7 +202,10 @@ def _make_func() -> FunctionToOptimize:
|
|||
|
||||
def _assert_sync_call_site_output(instrumented_code: str) -> None:
|
||||
"""Assert the sync path output has call-site tracking, not codeflash_wrap."""
|
||||
assert "from codeflash_async_wrapper import _codeflash_call_site" in instrumented_code
|
||||
assert (
|
||||
"from codeflash_async_wrapper import _codeflash_call_site"
|
||||
in instrumented_code
|
||||
)
|
||||
assert "_codeflash_call_site.set(" in instrumented_code
|
||||
assert "codeflash_wrap" not in instrumented_code
|
||||
assert "torch.cuda.synchronize" not in instrumented_code
|
||||
|
|
|
|||
|
|
@ -151,7 +151,9 @@ def test_sort():
|
|||
assert test_results[0].did_pass
|
||||
# return_value is ((args, kwargs, return_value),) in the new path
|
||||
assert test_results[0].return_value[0][2] == [0, 1, 2, 3, 4, 5]
|
||||
out_str = "codeflash stdout: Sorting list\nresult: [0, 1, 2, 3, 4, 5]\n"
|
||||
out_str = (
|
||||
"codeflash stdout: Sorting list\nresult: [0, 1, 2, 3, 4, 5]\n"
|
||||
)
|
||||
assert test_results[0].stdout == out_str
|
||||
|
||||
assert test_results[1].id.function_getting_tested == "sorter"
|
||||
|
|
@ -272,9 +274,7 @@ def test_sort():
|
|||
assert test_results[1].did_pass
|
||||
assert test_results[1].return_value[0] == {"x": 0}
|
||||
|
||||
assert (
|
||||
test_results[2].id.function_getting_tested == "sorter"
|
||||
)
|
||||
assert test_results[2].id.function_getting_tested == "sorter"
|
||||
assert test_results[2].id.test_class_name is None
|
||||
assert test_results[2].id.test_function_name == "test_sort"
|
||||
assert (
|
||||
|
|
@ -285,13 +285,14 @@ def test_sort():
|
|||
assert test_results[2].did_pass
|
||||
# return_value is ((args, kwargs, return_value),) in the new path
|
||||
assert test_results[2].return_value[0][2] == [0, 1, 2, 3, 4, 5]
|
||||
assert test_results[2].stdout == "codeflash stdout : BubbleSorter.sorter() called\n"
|
||||
assert (
|
||||
test_results[2].stdout
|
||||
== "codeflash stdout : BubbleSorter.sorter() called\n"
|
||||
)
|
||||
match, _ = compare_test_results(test_results, test_results)
|
||||
assert match
|
||||
|
||||
assert (
|
||||
test_results[3].id.function_getting_tested == "sorter"
|
||||
)
|
||||
assert test_results[3].id.function_getting_tested == "sorter"
|
||||
assert test_results[3].id.test_class_name is None
|
||||
assert test_results[3].id.test_function_name == "test_sort"
|
||||
assert (
|
||||
|
|
@ -300,7 +301,10 @@ def test_sort():
|
|||
)
|
||||
assert test_results[3].runtime > 0
|
||||
assert test_results[3].did_pass
|
||||
assert test_results[3].stdout == "codeflash stdout : BubbleSorter.sorter() called\n"
|
||||
assert (
|
||||
test_results[3].stdout
|
||||
== "codeflash stdout : BubbleSorter.sorter() called\n"
|
||||
)
|
||||
|
||||
results2 = _run_and_parse(test_files, test_env, test_config)
|
||||
|
||||
|
|
@ -376,9 +380,7 @@ class BubbleSorter:
|
|||
assert new_test_results[1].did_pass
|
||||
assert new_test_results[1].return_value[0] == {"x": 1}
|
||||
|
||||
assert (
|
||||
new_test_results[2].id.function_getting_tested == "sorter"
|
||||
)
|
||||
assert new_test_results[2].id.function_getting_tested == "sorter"
|
||||
assert new_test_results[2].id.test_class_name is None
|
||||
assert new_test_results[2].id.test_function_name == "test_sort"
|
||||
assert (
|
||||
|
|
@ -389,9 +391,7 @@ class BubbleSorter:
|
|||
assert new_test_results[2].did_pass
|
||||
assert new_test_results[2].return_value[0][2] == [0, 1, 2, 3, 4, 5]
|
||||
|
||||
assert (
|
||||
new_test_results[3].id.function_getting_tested == "sorter"
|
||||
)
|
||||
assert new_test_results[3].id.function_getting_tested == "sorter"
|
||||
assert new_test_results[3].id.test_class_name is None
|
||||
assert new_test_results[3].id.test_function_name == "test_sort"
|
||||
assert (
|
||||
|
|
@ -491,8 +491,7 @@ def test_sort():
|
|||
test_results = _run_and_parse(test_files, test_env, test_config)
|
||||
assert len(test_results) == 2
|
||||
assert (
|
||||
test_results[0].id.function_getting_tested
|
||||
== "sorter_classmethod"
|
||||
test_results[0].id.function_getting_tested == "sorter_classmethod"
|
||||
)
|
||||
assert test_results[0].id.test_class_name is None
|
||||
assert test_results[0].id.test_function_name == "test_sort"
|
||||
|
|
@ -503,13 +502,15 @@ def test_sort():
|
|||
assert test_results[0].runtime > 0
|
||||
assert test_results[0].did_pass
|
||||
assert test_results[0].return_value[0][2] == [0, 1, 2, 3, 4, 5]
|
||||
assert test_results[0].stdout == "codeflash stdout : BubbleSorter.sorter_classmethod() called\n"
|
||||
assert (
|
||||
test_results[0].stdout
|
||||
== "codeflash stdout : BubbleSorter.sorter_classmethod() called\n"
|
||||
)
|
||||
match, _ = compare_test_results(test_results, test_results)
|
||||
assert match
|
||||
|
||||
assert (
|
||||
test_results[1].id.function_getting_tested
|
||||
== "sorter_classmethod"
|
||||
test_results[1].id.function_getting_tested == "sorter_classmethod"
|
||||
)
|
||||
assert test_results[1].id.test_class_name is None
|
||||
assert test_results[1].id.test_function_name == "test_sort"
|
||||
|
|
@ -519,7 +520,10 @@ def test_sort():
|
|||
)
|
||||
assert test_results[1].runtime > 0
|
||||
assert test_results[1].did_pass
|
||||
assert test_results[1].stdout == "codeflash stdout : BubbleSorter.sorter_classmethod() called\n"
|
||||
assert (
|
||||
test_results[1].stdout
|
||||
== "codeflash stdout : BubbleSorter.sorter_classmethod() called\n"
|
||||
)
|
||||
|
||||
results2 = _run_and_parse(test_files, test_env, test_config)
|
||||
|
||||
|
|
@ -614,8 +618,7 @@ def test_sort():
|
|||
test_results = _run_and_parse(test_files, test_env, test_config)
|
||||
assert len(test_results) == 2
|
||||
assert (
|
||||
test_results[0].id.function_getting_tested
|
||||
== "sorter_staticmethod"
|
||||
test_results[0].id.function_getting_tested == "sorter_staticmethod"
|
||||
)
|
||||
assert test_results[0].id.test_class_name is None
|
||||
assert test_results[0].id.test_function_name == "test_sort"
|
||||
|
|
@ -626,13 +629,15 @@ def test_sort():
|
|||
assert test_results[0].runtime > 0
|
||||
assert test_results[0].did_pass
|
||||
assert test_results[0].return_value[0][2] == [0, 1, 2, 3, 4, 5]
|
||||
assert test_results[0].stdout == "codeflash stdout : BubbleSorter.sorter_staticmethod() called\n"
|
||||
assert (
|
||||
test_results[0].stdout
|
||||
== "codeflash stdout : BubbleSorter.sorter_staticmethod() called\n"
|
||||
)
|
||||
match, _ = compare_test_results(test_results, test_results)
|
||||
assert match
|
||||
|
||||
assert (
|
||||
test_results[1].id.function_getting_tested
|
||||
== "sorter_staticmethod"
|
||||
test_results[1].id.function_getting_tested == "sorter_staticmethod"
|
||||
)
|
||||
assert test_results[1].id.test_class_name is None
|
||||
assert test_results[1].id.test_function_name == "test_sort"
|
||||
|
|
@ -642,7 +647,10 @@ def test_sort():
|
|||
)
|
||||
assert test_results[1].runtime > 0
|
||||
assert test_results[1].did_pass
|
||||
assert test_results[1].stdout == "codeflash stdout : BubbleSorter.sorter_staticmethod() called\n"
|
||||
assert (
|
||||
test_results[1].stdout
|
||||
== "codeflash stdout : BubbleSorter.sorter_staticmethod() called\n"
|
||||
)
|
||||
|
||||
results2 = _run_and_parse(test_files, test_env, test_config)
|
||||
|
||||
|
|
|
|||
|
|
@ -755,17 +755,7 @@ async def test_multiple_calls():
|
|||
assert "_codeflash_call_site.set('14')" in instrumented_test_code
|
||||
assert "_codeflash_call_site.set('15')" in instrumented_test_code
|
||||
|
||||
assert 1 == instrumented_test_code.count(
|
||||
"_codeflash_call_site.set('8')"
|
||||
)
|
||||
assert 1 == instrumented_test_code.count(
|
||||
"_codeflash_call_site.set('13')"
|
||||
)
|
||||
assert 1 == instrumented_test_code.count(
|
||||
"_codeflash_call_site.set('14')"
|
||||
)
|
||||
assert 1 == instrumented_test_code.count(
|
||||
"_codeflash_call_site.set('15')"
|
||||
)
|
||||
|
||||
|
||||
assert 1 == instrumented_test_code.count("_codeflash_call_site.set('8')")
|
||||
assert 1 == instrumented_test_code.count("_codeflash_call_site.set('13')")
|
||||
assert 1 == instrumented_test_code.count("_codeflash_call_site.set('14')")
|
||||
assert 1 == instrumented_test_code.count("_codeflash_call_site.set('15')")
|
||||
|
|
|
|||
|
|
@ -37,8 +37,9 @@ class TestGetSyncDecoratorNameForMode:
|
|||
|
||||
def test_performance_mode(self) -> None:
|
||||
"""Returns codeflash_performance_sync for PERFORMANCE."""
|
||||
assert "codeflash_performance_sync" == get_sync_decorator_name_for_mode(
|
||||
TestingMode.PERFORMANCE
|
||||
assert (
|
||||
"codeflash_performance_sync"
|
||||
== get_sync_decorator_name_for_mode(TestingMode.PERFORMANCE)
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -141,9 +142,7 @@ class TestSyncCallInstrumenter:
|
|||
parents=[],
|
||||
is_async=False,
|
||||
)
|
||||
instrumenter = SyncCallInstrumenter(
|
||||
func, [CodePosition(2, 13)]
|
||||
)
|
||||
instrumenter = SyncCallInstrumenter(func, [CodePosition(2, 13)])
|
||||
tree = instrumenter.visit(tree)
|
||||
|
||||
assert instrumenter.did_instrument
|
||||
|
|
@ -164,9 +163,7 @@ class TestSyncCallInstrumenter:
|
|||
parents=[],
|
||||
is_async=False,
|
||||
)
|
||||
instrumenter = SyncCallInstrumenter(
|
||||
func, [CodePosition(2, 13)]
|
||||
)
|
||||
instrumenter = SyncCallInstrumenter(func, [CodePosition(2, 13)])
|
||||
tree = instrumenter.visit(tree)
|
||||
|
||||
assert instrumenter.did_instrument
|
||||
|
|
@ -206,10 +203,7 @@ class TestSyncCallInstrumenter:
|
|||
|
||||
def test_skips_non_test_functions(self) -> None:
|
||||
"""Does not instrument functions that don't start with test_."""
|
||||
test_code = (
|
||||
"def helper():\n"
|
||||
" return my_func(1)\n"
|
||||
)
|
||||
test_code = "def helper():\n return my_func(1)\n"
|
||||
tree = ast.parse(test_code)
|
||||
func = FunctionToOptimize(
|
||||
function_name="my_func",
|
||||
|
|
@ -217,9 +211,7 @@ class TestSyncCallInstrumenter:
|
|||
parents=[],
|
||||
is_async=False,
|
||||
)
|
||||
instrumenter = SyncCallInstrumenter(
|
||||
func, [CodePosition(2, 11)]
|
||||
)
|
||||
instrumenter = SyncCallInstrumenter(func, [CodePosition(2, 11)])
|
||||
tree = instrumenter.visit(tree)
|
||||
|
||||
assert not instrumenter.did_instrument
|
||||
|
|
@ -238,9 +230,7 @@ class TestSyncCallInstrumenter:
|
|||
parents=[],
|
||||
is_async=False,
|
||||
)
|
||||
instrumenter = SyncCallInstrumenter(
|
||||
func, [CodePosition(3, 17)]
|
||||
)
|
||||
instrumenter = SyncCallInstrumenter(func, [CodePosition(3, 17)])
|
||||
tree = instrumenter.visit(tree)
|
||||
|
||||
assert instrumenter.did_instrument
|
||||
|
|
@ -249,10 +239,7 @@ class TestSyncCallInstrumenter:
|
|||
|
||||
def test_no_match_when_position_wrong(self) -> None:
|
||||
"""Does not instrument if code position doesn't match."""
|
||||
test_code = (
|
||||
"def test_example():\n"
|
||||
" result = my_func(1)\n"
|
||||
)
|
||||
test_code = "def test_example():\n result = my_func(1)\n"
|
||||
tree = ast.parse(test_code)
|
||||
func = FunctionToOptimize(
|
||||
function_name="my_func",
|
||||
|
|
@ -260,9 +247,7 @@ class TestSyncCallInstrumenter:
|
|||
parents=[],
|
||||
is_async=False,
|
||||
)
|
||||
instrumenter = SyncCallInstrumenter(
|
||||
func, [CodePosition(99, 99)]
|
||||
)
|
||||
instrumenter = SyncCallInstrumenter(func, [CodePosition(99, 99)])
|
||||
tree = instrumenter.visit(tree)
|
||||
|
||||
assert not instrumenter.did_instrument
|
||||
|
|
@ -368,9 +353,7 @@ class TestAddSyncDecoratorToFunction:
|
|||
def test_preserves_existing_decorators(self, temp_dir) -> None:
|
||||
"""Adds codeflash decorator below @staticmethod/@classmethod."""
|
||||
source_code = (
|
||||
"@staticmethod\n"
|
||||
"def my_func(x: int) -> int:\n"
|
||||
" return x + 1\n"
|
||||
"@staticmethod\ndef my_func(x: int) -> int:\n return x + 1\n"
|
||||
)
|
||||
source_file = temp_dir / "my_module.py"
|
||||
source_file.write_text(source_code)
|
||||
|
|
@ -450,7 +433,10 @@ class TestInjectSyncProfilingIntoExistingTest:
|
|||
assert success
|
||||
assert instrumented is not None
|
||||
assert "_codeflash_call_site.set('4')" in instrumented
|
||||
assert "from codeflash_async_wrapper import _codeflash_call_site" in instrumented
|
||||
assert (
|
||||
"from codeflash_async_wrapper import _codeflash_call_site"
|
||||
in instrumented
|
||||
)
|
||||
|
||||
def test_multiple_calls_use_line_numbers(self, temp_dir) -> None:
|
||||
"""Multiple calls use their source line numbers as call-site IDs."""
|
||||
|
|
@ -512,10 +498,7 @@ class TestInjectSyncProfilingIntoExistingTest:
|
|||
|
||||
def test_returns_false_when_no_target_calls(self, temp_dir) -> None:
|
||||
"""Returns (False, None) when no target function calls are found."""
|
||||
test_code = (
|
||||
"def test_unrelated():\n"
|
||||
" assert 1 == 1\n"
|
||||
)
|
||||
test_code = "def test_unrelated():\n assert 1 == 1\n"
|
||||
test_file = temp_dir / "test_noop.py"
|
||||
test_file.write_text(test_code)
|
||||
|
||||
|
|
|
|||
|
|
@ -124,7 +124,10 @@ def test_prepare_image_for_yolo():
|
|||
assert new_test is not None
|
||||
_assert_sync_instrumentation_present(new_test)
|
||||
_assert_old_instrumentation_absent(new_test)
|
||||
assert "packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo" in new_test
|
||||
assert (
|
||||
"packagename_ml_yolo_image_reshaping_utils_prepare_image_for_yolo"
|
||||
in new_test
|
||||
)
|
||||
assert "def test_prepare_image_for_yolo" in new_test
|
||||
assert "_codeflash_call_site.set(" in new_test
|
||||
|
||||
|
|
|
|||
|
|
@ -49,16 +49,13 @@ def test_single_element_list():
|
|||
project_root
|
||||
/ "code_to_optimize/tests/pytest/test_aiservice_behavior_results_perf_temp.py"
|
||||
).resolve()
|
||||
tests_root = (
|
||||
project_root / "code_to_optimize/tests/pytest/"
|
||||
)
|
||||
tests_root = project_root / "code_to_optimize/tests/pytest/"
|
||||
project_root_path = project_root
|
||||
run_cwd = project_root
|
||||
old_cwd = os.getcwd()
|
||||
os.chdir(run_cwd)
|
||||
fto_path = (
|
||||
project_root
|
||||
/ "code_to_optimize/bubble_sort_method.py"
|
||||
project_root / "code_to_optimize/bubble_sort_method.py"
|
||||
).resolve()
|
||||
original_code = fto_path.read_text("utf-8")
|
||||
|
||||
|
|
@ -127,7 +124,10 @@ def test_single_element_list():
|
|||
run_result=run_result,
|
||||
)
|
||||
assert test_results[0].id.function_getting_tested == "sorter"
|
||||
assert test_results[0].stdout == "codeflash stdout : BubbleSorter.sorter() called\n"
|
||||
assert (
|
||||
test_results[0].stdout
|
||||
== "codeflash stdout : BubbleSorter.sorter() called\n"
|
||||
)
|
||||
assert (
|
||||
test_results[0].id.test_function_name == "test_single_element_list"
|
||||
)
|
||||
|
|
@ -216,14 +216,11 @@ def test_single_element_list():
|
|||
project_root
|
||||
/ "code_to_optimize/tests/pytest/test_aiservice_behavior_results_perf_temp.py"
|
||||
).resolve()
|
||||
tests_root = (
|
||||
project_root / "code_to_optimize/tests/pytest/"
|
||||
)
|
||||
tests_root = project_root / "code_to_optimize/tests/pytest/"
|
||||
project_root_path = project_root
|
||||
|
||||
fto_path = (
|
||||
project_root
|
||||
/ "code_to_optimize/bubble_sort_method.py"
|
||||
project_root / "code_to_optimize/bubble_sort_method.py"
|
||||
).resolve()
|
||||
original_code = fto_path.read_text("utf-8")
|
||||
function_to_optimize = FunctionToOptimize(
|
||||
|
|
@ -316,7 +313,10 @@ def test_single_element_list():
|
|||
assert test_results[1].did_pass
|
||||
# return_value is ((args, kwargs, return_value),) in the new path
|
||||
assert test_results[1].return_value[0][2] == [1, 2, 3]
|
||||
assert test_results[1].stdout == "codeflash stdout : BubbleSorter.sorter() called\n"
|
||||
assert (
|
||||
test_results[1].stdout
|
||||
== "codeflash stdout : BubbleSorter.sorter() called\n"
|
||||
)
|
||||
|
||||
# Replace with optimized code that mutated instance attribute
|
||||
optimized_code_mutated_attr = """
|
||||
|
|
@ -442,9 +442,7 @@ class BubbleSorter:
|
|||
# In the new decorator-based path, args (including self) are captured.
|
||||
# Adding a new instance attribute changes self, so the comparison
|
||||
# detects a difference even though codeflash_capture considers it additive.
|
||||
match, _ = compare_test_results(
|
||||
test_results, test_results_new_attr
|
||||
)
|
||||
match, _ = compare_test_results(test_results, test_results_new_attr)
|
||||
assert not match
|
||||
finally:
|
||||
fto_path.write_text(original_code, "utf-8")
|
||||
|
|
|
|||
Loading…
Reference in a new issue