Rewrite sync instrumentation to decorator-based approach

Replace the old AST-injected codeflash_wrap/InjectPerfOnly sync path
with decorator-based instrumentation matching the async path:

- Add codeflash_performance_sync and codeflash_behavior_sync decorators
  with GPU device sync (torch CUDA/MPS, JAX, TensorFlow) via find_spec
- Add sync_devices_before/sync_devices_after with lazy cached detection
- Clean _instrumentation.py to a thin sync/async dispatcher (~47 lines)
- Remove dead code from _instrument_core.py (create_wrapper_function,
  create_device_sync_statements, get_call_arguments, etc.)
- Fix all production imports to point at source modules directly
- Drop underscore prefixes on internal helpers (connections, get_async_db,
  close_all_connections, detect_device_sync, etc.)
- Rewrite all test files for the new sync path assertions
- Add real-framework GPU device sync tests (torch, jax, tensorflow)
This commit is contained in:
Kevin Turcios 2026-04-24 05:54:32 -05:00
parent 918a2a10a4
commit ca951dd1f3
22 changed files with 1166 additions and 6320 deletions

View file

@ -112,8 +112,10 @@ async def run_concurrency_benchmark(
from ..testing._async_data_parser import ( # noqa: PLC0415
parse_async_concurrency_metrics,
)
from ..testing._instrumentation import ( # noqa: PLC0415
from ..testing._instrument_async import ( # noqa: PLC0415
add_async_decorator_to_function,
)
from ..testing._instrument_capture import ( # noqa: PLC0415
revert_instrumented_files,
)

View file

@ -80,7 +80,7 @@ _codeflash_call_site: contextvars.ContextVar[str] = contextvars.ContextVar(
"codeflash_call_site", default=""
)
_CREATE_TABLE_SQL = (
CREATE_TABLE_SQL = (
"CREATE TABLE IF NOT EXISTS async_results ("
"test_module_path TEXT NOT NULL, "
"test_class_name TEXT, "
@ -99,34 +99,123 @@ _CREATE_TABLE_SQL = (
")"
)
_connections: dict[str, sqlite3.Connection] = {}
connections: dict[str, sqlite3.Connection] = {}
def _get_async_db(
def get_async_db(
db_path: Path,
) -> tuple[sqlite3.Connection, sqlite3.Cursor]:
"""Get or create a cached SQLite connection."""
key = str(db_path)
if key not in _connections:
if key not in connections:
conn = sqlite3.connect(db_path)
conn.execute(_CREATE_TABLE_SQL)
_connections[key] = conn
conn = _connections[key]
conn.execute(CREATE_TABLE_SQL)
connections[key] = conn
conn = connections[key]
return conn, conn.cursor()
def _close_all_connections() -> None:
def close_all_connections() -> None:
"""Commit and close all cached connections."""
for conn in _connections.values():
for conn in connections.values():
try:
conn.commit()
conn.close()
except Exception: # noqa: PERF203, S110
pass
_connections.clear()
connections.clear()
atexit.register(_close_all_connections)
atexit.register(close_all_connections)
def detect_device_sync() -> (
tuple[
Any | None, # torch_cuda_sync
Any | None, # torch_mps_sync
Any | None, # tf_sync
bool, # jax_available
]
):
"""Detect available GPU frameworks and return sync callables.
Called once at first decorator invocation; results are cached.
Uses ``find_spec`` to avoid importing heavy frameworks when absent.
"""
from importlib.util import find_spec # noqa: PLC0415
torch_cuda_sync = None
torch_mps_sync = None
tf_sync = None
jax_available = False
if find_spec("torch") is not None:
import torch # noqa: PLC0415
if torch.cuda.is_available() and torch.cuda.is_initialized():
torch_cuda_sync = torch.cuda.synchronize
elif (
hasattr(torch, "backends")
and hasattr(torch.backends, "mps")
and torch.backends.mps.is_available()
and hasattr(torch, "mps")
and hasattr(torch.mps, "synchronize")
):
torch_mps_sync = torch.mps.synchronize
if find_spec("jax") is not None:
import jax # type: ignore[import-untyped] # noqa: PLC0415
jax_available = hasattr(jax, "block_until_ready")
if find_spec("tensorflow") is not None:
import tensorflow as tf # type: ignore[import-untyped] # noqa: PLC0415
if hasattr(tf.test, "experimental") and hasattr(
tf.test.experimental, "sync_devices"
):
tf_sync = tf.test.experimental.sync_devices
return torch_cuda_sync, torch_mps_sync, tf_sync, jax_available
device_sync_cache: (
tuple[Any | None, Any | None, Any | None, bool] | None
) = None
def get_device_sync() -> (
tuple[Any | None, Any | None, Any | None, bool]
):
"""Return cached device sync callables, detecting on first call."""
global device_sync_cache # noqa: PLW0603
if device_sync_cache is None:
device_sync_cache = detect_device_sync()
return device_sync_cache
def sync_devices_before() -> None:
"""Synchronize GPU devices before timing."""
cuda_sync, mps_sync, tf_sync, _ = get_device_sync()
if cuda_sync is not None:
cuda_sync()
elif mps_sync is not None:
mps_sync()
if tf_sync is not None:
tf_sync()
def sync_devices_after(return_value: Any) -> None:
"""Synchronize GPU devices after function call."""
cuda_sync, mps_sync, tf_sync, jax_available = get_device_sync()
if cuda_sync is not None:
cuda_sync()
elif mps_sync is not None:
mps_sync()
if jax_available and hasattr(return_value, "block_until_ready"):
return_value.block_until_ready()
if tf_sync is not None:
tf_sync()
def codeflash_behavior_sync(func: F) -> F:
@ -164,17 +253,19 @@ def codeflash_behavior_sync(func: F) -> F:
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
db_path = get_run_tmp_file(Path(f"async_results_{iteration}.sqlite"))
conn, cur = _get_async_db(db_path)
conn, cur = get_async_db(db_path)
exception = None
captured_stdout = io.StringIO()
old_stdout = sys.stdout
sys.stdout = captured_stdout
sync_devices_before()
gc.disable()
try:
counter = time.perf_counter_ns()
cpu_counter = time.thread_time_ns()
return_value = func(*args, **kwargs)
sync_devices_after(return_value)
wall_time = time.perf_counter_ns() - counter
cpu_time = time.thread_time_ns() - cpu_counter
except Exception as e:
@ -221,6 +312,86 @@ def codeflash_behavior_sync(func: F) -> F:
return wrapper # type: ignore[return-value]
def codeflash_performance_sync(func: F) -> F:
"""
Measure sync execution time for performance tests.
Results are written to the async_results SQLite table.
"""
@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
function_name = func.__name__
call_site = _codeflash_call_site.get()
loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
(
test_module_name,
test_class_name,
test_name,
) = extract_test_context_from_env()
test_id = (
f"{test_module_name}:{test_class_name}"
f":{test_name}:{call_site}:{loop_index}"
)
if not hasattr(wrapper, "index"):
wrapper.index = {} # type: ignore[attr-defined]
if test_id in wrapper.index: # type: ignore[attr-defined]
wrapper.index[test_id] += 1 # type: ignore[attr-defined]
else:
wrapper.index[test_id] = 0 # type: ignore[attr-defined]
call_index = wrapper.index[test_id] # type: ignore[attr-defined]
invocation_id = f"{call_site}_{call_index}"
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
db_path = get_run_tmp_file(Path(f"async_results_{iteration}.sqlite"))
conn, cur = get_async_db(db_path)
exception = None
sync_devices_before()
gc.disable()
try:
counter = time.perf_counter_ns()
return_value = func(*args, **kwargs)
sync_devices_after(return_value)
wall_time = time.perf_counter_ns() - counter
except Exception as e:
wall_time = time.perf_counter_ns() - counter
exception = e
finally:
gc.enable()
cur.execute(
"INSERT INTO async_results VALUES "
"(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
test_module_name,
test_class_name,
test_name,
function_name,
loop_index,
invocation_id,
"performance",
wall_time,
None,
None,
None,
None,
None,
None,
),
)
conn.commit()
if exception:
raise exception
return return_value
return wrapper # type: ignore[return-value]
def codeflash_behavior_async(func: F) -> F:
"""
Capture async return values and timing for behavioral tests.
@ -257,7 +428,7 @@ def codeflash_behavior_async(func: F) -> F:
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
db_path = get_run_tmp_file(Path(f"async_results_{iteration}.sqlite"))
conn, cur = _get_async_db(db_path)
conn, cur = get_async_db(db_path)
exception = None
captured_stdout = io.StringIO()
@ -348,7 +519,7 @@ def codeflash_performance_async(func: F) -> F:
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
db_path = get_run_tmp_file(Path(f"async_results_{iteration}.sqlite"))
conn, cur = _get_async_db(db_path)
conn, cur = get_async_db(db_path)
exception = None
counter = loop.time()
@ -415,7 +586,7 @@ def codeflash_concurrency_async(func: F) -> F:
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
db_path = get_run_tmp_file(Path(f"async_results_{iteration}.sqlite"))
conn, cur = _get_async_db(db_path)
conn, cur = get_async_db(db_path)
gc.disable()
try:
@ -469,6 +640,7 @@ __all__ = [
"codeflash_behavior_sync",
"codeflash_concurrency_async",
"codeflash_performance_async",
"codeflash_performance_sync",
"extract_test_context_from_env",
"get_run_tmp_file",
]

View file

@ -43,7 +43,7 @@ def parse_sqlite_test_results(
" function_getting_tested, loop_index,"
" iteration_id, runtime,"
" return_value, verification_type,"
" cpu_runtime, stdout"
" cpu_runtime"
" FROM test_results"
).fetchall()
except sqlite3.Error:
@ -101,7 +101,6 @@ def _process_sqlite_row_inner(
runtime = val[6]
verification_type = val[8]
cpu_runtime = val[9]
stdout_text = val[10] if len(val) > 10 else None
test_file_path = file_path_from_module_name(
test_module_path, # type: ignore[arg-type]

View file

@ -36,6 +36,7 @@ log = logging.getLogger(__name__)
_CODEFLASH_SYNC_DECORATORS = frozenset(
{
"codeflash_behavior_sync",
"codeflash_performance_sync",
}
)
@ -202,6 +203,11 @@ class SyncDecoratorAdder(cst.CSTTransformer):
new_decorator = cst.Decorator(
decorator=cst.Name(value=self.decorator_name),
)
if self._has_descriptor_decorator(original_node):
updated_node = updated_node.with_changes(
decorators=(*updated_node.decorators, new_decorator),
)
else:
updated_node = updated_node.with_changes(
decorators=(new_decorator, *updated_node.decorators),
)
@ -224,13 +230,26 @@ class SyncDecoratorAdder(cst.CSTTransformer):
return decorator_node.func.value in _CODEFLASH_SYNC_DECORATORS
return False
@staticmethod
def _has_descriptor_decorator(
node: cst.FunctionDef,
) -> bool:
"""Check if the function has @classmethod or @staticmethod."""
for d in node.decorators:
if isinstance(d.decorator, cst.Name) and d.decorator.value in (
"classmethod",
"staticmethod",
):
return True
return False
def get_sync_decorator_name_for_mode(
mode: TestingMode,
) -> str:
"""Return the sync decorator function name for the given testing mode."""
if mode == TestingMode.BEHAVIOR:
return "codeflash_behavior_sync"
if mode == TestingMode.PERFORMANCE:
return "codeflash_performance_sync"
return "codeflash_behavior_sync"

View file

@ -1,96 +1,20 @@
"""AST transformers for test instrumentation.
"""Test instrumentation dispatcher.
Provides the ``InjectPerfOnly`` transformer that rewrites existing test
functions to wrap target-function calls with timing and capture logic,
and supporting transformers for async functions.
This module re-exports the full public API from its sub-modules so that
existing callers can continue to import from ``_instrumentation`` without
changes.
Delegates to sync or async instrumentation paths based on
``FunctionToOptimize.is_async``.
"""
from __future__ import annotations
import ast
import logging
from pathlib import Path
from typing import TYPE_CHECKING
from .._model import (
TestingMode,
)
from ..analysis._formatter import (
sort_imports as sort_imports, # noqa: PLC0414
)
from ..runtime._codeflash_wrap_decorator import (
get_run_tmp_file as get_run_tmp_file, # noqa: PLC0414
)
from ._instrument_async import (
ASYNC_HELPER_FILENAME as ASYNC_HELPER_FILENAME, # noqa: PLC0414
)
from ._instrument_async import (
AsyncCallInstrumenter as AsyncCallInstrumenter, # noqa: PLC0414
)
from ._instrument_async import (
AsyncDecoratorAdder as AsyncDecoratorAdder, # noqa: PLC0414
)
from ._instrument_async import (
add_async_decorator_to_function as add_async_decorator_to_function, # noqa: PLC0414
)
from ._instrument_async import (
get_decorator_name_for_mode as get_decorator_name_for_mode, # noqa: PLC0414
)
from ._instrument_async import (
inject_async_profiling_into_existing_test as inject_async_profiling_into_existing_test, # noqa: PLC0414
)
from ._instrument_async import (
write_async_helper_file as write_async_helper_file, # noqa: PLC0414
)
from ._instrument_capture import (
InitDecorator as InitDecorator, # noqa: PLC0414
)
from ._instrument_capture import (
add_codeflash_capture_to_init as add_codeflash_capture_to_init, # noqa: PLC0414
)
from ._instrument_capture import (
create_instrumented_source_module_path as create_instrumented_source_module_path, # noqa: PLC0414
)
from ._instrument_capture import (
instrument_codeflash_capture as instrument_codeflash_capture, # noqa: PLC0414
)
from ._instrument_capture import (
revert_instrumented_files as revert_instrumented_files, # noqa: PLC0414
)
from ._instrument_core import (
FunctionCallNodeArguments as FunctionCallNodeArguments, # noqa: PLC0414
)
from ._instrument_core import (
FunctionImportedAsVisitor as FunctionImportedAsVisitor, # noqa: PLC0414
)
from ._instrument_core import (
create_device_sync_precompute_statements as create_device_sync_precompute_statements, # noqa: PLC0414
)
from ._instrument_core import (
create_device_sync_statements as create_device_sync_statements, # noqa: PLC0414
)
from ._instrument_core import (
create_wrapper_function as create_wrapper_function, # noqa: PLC0414
)
from ._instrument_core import (
detect_frameworks_from_code as detect_frameworks_from_code, # noqa: PLC0414
)
from ._instrument_core import (
get_call_arguments as get_call_arguments, # noqa: PLC0414
)
from ._instrument_core import (
is_argument_name as is_argument_name, # noqa: PLC0414
)
from ._instrument_core import (
node_in_call_position as node_in_call_position, # noqa: PLC0414
)
from .._model import TestingMode
from ._instrument_async import inject_async_profiling_into_existing_test
from ._instrument_sync import inject_sync_profiling_into_existing_test
if TYPE_CHECKING:
from collections.abc import Iterable
from pathlib import Path
from .._model import FunctionToOptimize
from ..test_discovery.models import CodePosition
@ -98,529 +22,6 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
class InjectPerfOnly(ast.NodeTransformer):
"""Inject performance profiling into existing test functions."""
def __init__(
self,
function: FunctionToOptimize,
module_path: str,
call_positions: list[CodePosition],
mode: TestingMode = TestingMode.BEHAVIOR,
) -> None:
"""Initialize with the target function, module path, and testing mode."""
self.mode: TestingMode = mode
self.function_object = function
self.class_name: str | None = None
self.only_function_name = function.function_name
self.module_path = module_path
self.call_positions = call_positions
if (
len(function.parents) == 1
and function.parents[0].type == "ClassDef"
):
self.class_name = function.parents[0].name
def find_and_update_line_node(
self,
test_node: ast.stmt,
node_name: str,
index: str,
test_class_name: str | None = None,
) -> Iterable[ast.stmt] | None:
"""Find and rewrite target function calls within a test statement."""
# ast.walk is expensive for big trees and only
# checks for ast.Call, so visit nodes manually.
# Only descend into expressions/statements.
# Helper for manual walk
def iter_ast_calls(node: ast.AST) -> Iterable[ast.Call]:
"""Yield all ast.Call nodes reachable from the given node."""
# Yield each ast.Call in test_node
stack = [node]
while stack:
n = stack.pop()
if isinstance(n, ast.Call):
yield n
# Specialized BFS instead of ast.walk
# for less overhead
for _field, value in ast.iter_fields(n):
if isinstance(value, list):
stack.extend(
item
for item in reversed(value)
if isinstance(item, ast.AST)
)
elif isinstance(value, ast.AST):
stack.append(value)
# Single stack instead of O(N) stack-frames
# per child-node, less Python call overhead.
return_statement = [test_node]
call_node = None
# Convert mode, function_name, etc. to locals
fn_obj = self.function_object
module_path = self.module_path
mode = self.mode
qualified_name = fn_obj.qualified_name
# Use locals for all 'current' values,
# look up AST objects once.
codeflash_loop_index = ast.Name(
id="codeflash_loop_index", ctx=ast.Load()
)
codeflash_cur = ast.Name(id="codeflash_cur", ctx=ast.Load())
codeflash_con = ast.Name(id="codeflash_con", ctx=ast.Load())
for node in iter_ast_calls(test_node):
if not node_in_call_position(node, self.call_positions):
continue
call_node = node
all_args = get_call_arguments(call_node)
# Two possible call types: Name and Attribute
node_func = node.func
if isinstance(node_func, ast.Name):
function_name = node_func.id
# Check if this is the function we want to instrument
if function_name != fn_obj.function_name:
continue
if fn_obj.is_async:
return [test_node]
# Build once, reuse objects.
inspect_name = ast.Name(id="inspect", ctx=ast.Load())
bind_call = ast.Assign(
targets=[
ast.Name(id="_call__bound__arguments", ctx=ast.Store())
],
value=ast.Call(
func=ast.Attribute(
value=ast.Call(
func=ast.Attribute(
value=inspect_name,
attr="signature",
ctx=ast.Load(),
),
args=[
ast.Name(id=function_name, ctx=ast.Load())
],
keywords=[],
),
attr="bind",
ctx=ast.Load(),
),
args=all_args.args,
keywords=all_args.keywords,
),
lineno=test_node.lineno,
col_offset=test_node.col_offset,
)
apply_defaults = ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(
id="_call__bound__arguments", ctx=ast.Load()
),
attr="apply_defaults",
ctx=ast.Load(),
),
args=[],
keywords=[],
),
lineno=test_node.lineno + 1,
col_offset=test_node.col_offset,
)
node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
base_args = [
ast.Name(id=function_name, ctx=ast.Load()),
ast.Constant(value=module_path),
ast.Constant(value=test_class_name or None),
ast.Constant(value=node_name),
ast.Constant(value=qualified_name),
ast.Constant(value=index),
codeflash_loop_index,
]
# Extend with BEHAVIOR extras if needed
if mode == TestingMode.BEHAVIOR:
base_args += [codeflash_cur, codeflash_con]
# Extend with call args (perf)
# or starred bound args (behavior)
if mode == TestingMode.PERFORMANCE:
base_args += call_node.args
else:
base_args.append(
ast.Starred(
value=ast.Attribute(
value=ast.Name(
id="_call__bound__arguments",
ctx=ast.Load(),
),
attr="args",
ctx=ast.Load(),
),
ctx=ast.Load(),
)
)
node.args = base_args
# Prepare keywords
if mode == TestingMode.BEHAVIOR:
node.keywords = [
ast.keyword(
value=ast.Attribute(
value=ast.Name(
id="_call__bound__arguments",
ctx=ast.Load(),
),
attr="kwargs",
ctx=ast.Load(),
)
)
]
else:
node.keywords = call_node.keywords
return_statement = (
[bind_call, apply_defaults, test_node]
if mode == TestingMode.BEHAVIOR
else [test_node]
)
break
if isinstance(node_func, ast.Attribute):
function_to_test = node_func.attr
if function_to_test == fn_obj.function_name:
if fn_obj.is_async:
return [test_node]
# Create the signature binding statements
# Unparse only once
function_name_expr = ast.parse(
ast.unparse(node_func), mode="eval"
).body
inspect_name = ast.Name(id="inspect", ctx=ast.Load())
bind_call = ast.Assign(
targets=[
ast.Name(
id="_call__bound__arguments", ctx=ast.Store()
)
],
value=ast.Call(
func=ast.Attribute(
value=ast.Call(
func=ast.Attribute(
value=inspect_name,
attr="signature",
ctx=ast.Load(),
),
args=[function_name_expr],
keywords=[],
),
attr="bind",
ctx=ast.Load(),
),
args=all_args.args,
keywords=all_args.keywords,
),
lineno=test_node.lineno,
col_offset=test_node.col_offset,
)
apply_defaults = ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(
id="_call__bound__arguments",
ctx=ast.Load(),
),
attr="apply_defaults",
ctx=ast.Load(),
),
args=[],
keywords=[],
),
lineno=test_node.lineno + 1,
col_offset=test_node.col_offset,
)
node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
base_args = [
function_name_expr,
ast.Constant(value=module_path),
ast.Constant(value=test_class_name or None),
ast.Constant(value=node_name),
ast.Constant(value=qualified_name),
ast.Constant(value=index),
codeflash_loop_index,
]
if mode == TestingMode.BEHAVIOR:
base_args += [codeflash_cur, codeflash_con]
if mode == TestingMode.PERFORMANCE:
base_args += call_node.args
else:
base_args.append(
ast.Starred(
value=ast.Attribute(
value=ast.Name(
id="_call__bound__arguments",
ctx=ast.Load(),
),
attr="args",
ctx=ast.Load(),
),
ctx=ast.Load(),
)
)
node.args = base_args
if mode == TestingMode.BEHAVIOR:
node.keywords = [
ast.keyword(
value=ast.Attribute(
value=ast.Name(
id="_call__bound__arguments",
ctx=ast.Load(),
),
attr="kwargs",
ctx=ast.Load(),
)
)
]
else:
node.keywords = call_node.keywords
# Return the signature binding
# statements with the test_node
return_statement = (
[bind_call, apply_defaults, test_node]
if mode == TestingMode.BEHAVIOR
else [test_node]
)
break
if call_node is None:
return None
return return_statement
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
"""Visit test methods inside a class definition."""
# TODO: Ensure this class inherits from
# unittest.TestCase.
for inner_node in ast.walk(node):
if isinstance(inner_node, ast.FunctionDef):
self.visit_FunctionDef(inner_node, node.name)
return node
def visit_FunctionDef(
self, node: ast.FunctionDef, test_class_name: str | None = None
) -> ast.FunctionDef:
"""Instrument a test function by wrapping target function calls."""
if node.name.startswith("test_"):
did_update = False
i = len(node.body) - 1
while i >= 0:
line_node = node.body[i]
# TODO: Validate that the call
# did not raise exceptions
if isinstance(
line_node, (ast.With, ast.For, ast.While, ast.If)
):
j = len(line_node.body) - 1
while j >= 0:
compound_line_node: ast.stmt = line_node.body[j]
internal_node: ast.AST
for internal_node in ast.walk(compound_line_node):
if isinstance(
internal_node, (ast.stmt, ast.Assign)
):
updated_node = self.find_and_update_line_node(
internal_node,
node.name,
str(i) + "_" + str(j),
test_class_name,
)
if updated_node is not None:
line_node.body[j : j + 1] = updated_node
did_update = True
break
j -= 1
else:
updated_node = self.find_and_update_line_node(
line_node, node.name, str(i), test_class_name
)
if updated_node is not None:
node.body[i : i + 1] = updated_node
did_update = True
i -= 1
if did_update:
node.body = [
ast.Assign(
targets=[
ast.Name(
id="codeflash_loop_index", ctx=ast.Store()
)
],
value=ast.Call(
func=ast.Name(id="int", ctx=ast.Load()),
args=[
ast.Subscript(
value=ast.Attribute(
value=ast.Name(
id="os", ctx=ast.Load()
),
attr="environ",
ctx=ast.Load(),
),
slice=ast.Constant(
value="CODEFLASH_LOOP_INDEX"
),
ctx=ast.Load(),
)
],
keywords=[],
),
lineno=node.lineno + 2,
col_offset=node.col_offset,
),
*(
[
ast.Assign(
targets=[
ast.Name(
id="codeflash_iteration",
ctx=ast.Store(),
)
],
value=ast.Subscript(
value=ast.Attribute(
value=ast.Name(
id="os", ctx=ast.Load()
),
attr="environ",
ctx=ast.Load(),
),
slice=ast.Constant(
value="CODEFLASH_TEST_ITERATION"
),
ctx=ast.Load(),
),
lineno=node.lineno + 1,
col_offset=node.col_offset,
),
ast.Assign(
targets=[
ast.Name(
id="codeflash_con", ctx=ast.Store()
)
],
value=ast.Call(
func=ast.Attribute(
value=ast.Name(
id="sqlite3", ctx=ast.Load()
),
attr="connect",
ctx=ast.Load(),
),
args=[
ast.JoinedStr(
values=[
ast.Constant(
value=f"{get_run_tmp_file(Path('test_return_values_')).as_posix()}"
),
ast.FormattedValue(
value=ast.Name(
id="codeflash_iteration",
ctx=ast.Load(),
),
conversion=-1,
),
ast.Constant(value=".sqlite"),
]
)
],
keywords=[],
),
lineno=node.lineno + 3,
col_offset=node.col_offset,
),
ast.Assign(
targets=[
ast.Name(
id="codeflash_cur", ctx=ast.Store()
)
],
value=ast.Call(
func=ast.Attribute(
value=ast.Name(
id="codeflash_con", ctx=ast.Load()
),
attr="cursor",
ctx=ast.Load(),
),
args=[],
keywords=[],
),
lineno=node.lineno + 4,
col_offset=node.col_offset,
),
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(
id="codeflash_cur", ctx=ast.Load()
),
attr="execute",
ctx=ast.Load(),
),
args=[
ast.Constant(
value="CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT,"
" test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT,"
" loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB,"
" verification_type TEXT, cpu_runtime INTEGER)"
)
],
keywords=[],
),
lineno=node.lineno + 5,
col_offset=node.col_offset,
),
]
if self.mode == TestingMode.BEHAVIOR
else []
),
*node.body,
*(
[
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(
id="codeflash_con", ctx=ast.Load()
),
attr="close",
ctx=ast.Load(),
),
args=[],
keywords=[],
)
)
]
if self.mode == TestingMode.BEHAVIOR
else []
),
]
return node
def inject_profiling_into_existing_test(
test_path: Path,
call_positions: list[CodePosition],
@ -630,76 +31,18 @@ def inject_profiling_into_existing_test(
) -> tuple[bool, str | None]:
"""Inject instrumentation into an existing test file.
For sync functions, applies the ``InjectPerfOnly`` transformer.
For async functions, delegates to async-specific instrumentation.
For async functions, delegates to async-specific call-site injection.
For sync functions, delegates to sync-specific call-site injection.
Returns *(did_instrument, modified_source)*.
"""
tests_project_root = tests_project_root.resolve()
if function_to_optimize.is_async:
return inject_async_profiling_into_existing_test(
test_path,
call_positions,
function_to_optimize,
)
with test_path.open(encoding="utf8") as f:
test_code = f.read()
used_frameworks = detect_frameworks_from_code(test_code)
try:
tree = ast.parse(test_code)
except SyntaxError:
log.exception("Syntax error in code in file - %s", test_path)
return False, None
from ..test_discovery.linking import ( # noqa: PLC0415
module_name_from_file_path,
return inject_sync_profiling_into_existing_test(
test_path,
call_positions,
function_to_optimize,
)
test_module_path = module_name_from_file_path(
test_path, tests_project_root
)
import_visitor = FunctionImportedAsVisitor(function_to_optimize)
import_visitor.visit(tree)
func = import_visitor.imported_as
tree = InjectPerfOnly(
func, test_module_path, call_positions, mode=mode
).visit(tree)
new_imports: list[ast.stmt] = [
ast.Import(names=[ast.alias(name="time")]),
ast.Import(names=[ast.alias(name="gc")]),
ast.Import(names=[ast.alias(name="os")]),
]
if mode == TestingMode.BEHAVIOR:
new_imports.extend(
[
ast.Import(names=[ast.alias(name="inspect")]),
ast.Import(names=[ast.alias(name="sqlite3")]),
ast.Import(names=[ast.alias(name="dill", asname="pickle")]),
]
)
for framework_name, framework_alias in used_frameworks.items():
if framework_alias == framework_name:
new_imports.append(
ast.Import(names=[ast.alias(name=framework_name)])
)
else:
new_imports.append(
ast.Import(
names=[
ast.alias(
name=framework_name,
asname=framework_alias,
)
]
)
)
additional_functions = [create_wrapper_function(mode, used_frameworks)]
tree.body = [
*new_imports,
*additional_functions,
*tree.body,
]
return True, sort_imports(ast.unparse(tree), float_to_top=True)

View file

@ -103,6 +103,7 @@ def _merge_by_iteration_id(
merged: TestResults,
) -> None:
"""Merge XML and data results by iteration id."""
matched_data_ids: set[str] = set()
for xml_result in xml_results.test_results:
data_result = data_results.get_by_unique_invocation_loop_id(
xml_result.unique_invocation_loop_id,
@ -110,6 +111,7 @@ def _merge_by_iteration_id(
if data_result is None:
merged.add(xml_result)
continue
matched_data_ids.add(xml_result.unique_invocation_loop_id)
merged_runtime = data_result.runtime or xml_result.runtime
merged.add(
FunctionTestInvocation(
@ -130,9 +132,12 @@ def _merge_by_iteration_id(
if data_result.verification_type
else None
),
stdout=xml_result.stdout,
stdout=data_result.stdout or xml_result.stdout,
),
)
for data_result in data_results.test_results:
if data_result.unique_invocation_loop_id not in matched_data_ids:
merged.add(data_result)
def _merge_by_index(
@ -170,6 +175,6 @@ def _merge_by_index(
if data_result.verification_type
else None
),
stdout=xml_result.stdout,
stdout=data_result.stdout or xml_result.stdout,
),
)

View file

@ -346,7 +346,7 @@ def add_async_perf_decorator(
return {}
from .._model import TestingMode # noqa: PLC0415
from ..testing._instrumentation import ( # noqa: PLC0415
from ..testing._instrument_async import ( # noqa: PLC0415
add_async_decorator_to_function,
)
@ -369,7 +369,7 @@ def revert_async_decorator(originals: dict[Path, str]) -> None:
if not originals:
return
from ..testing._instrumentation import ( # noqa: PLC0415
from ..testing._instrument_capture import ( # noqa: PLC0415
revert_instrumented_files,
)

View file

@ -1,41 +1,48 @@
import sys
from codeflash_async_wrapper import codeflash_behavior_sync
from codeflash_python.runtime._codeflash_capture import codeflash_capture
class BubbleSorter:
@codeflash_capture(function_name='BubbleSorter.__init__', tmp_dir_path='/var/folders/mg/k_c0twcj37q_gph3cfy3zlt80000gn/T/codeflash_34197et8/test_return_values', tests_root='/Users/krrt7/Desktop/work/cf_org/codeflash-agent/.claude/worktrees/jaunty-sauteeing-dolphin/packages/codeflash-python/tests/code_to_optimize/tests/pytest', is_fto=True)
def __init__(self, x=0):
self.x = x
@codeflash_behavior_sync
def sorter(self, arr):
print("codeflash stdout : BubbleSorter.sorter() called")
print('codeflash stdout : BubbleSorter.sorter() called')
for i in range(len(arr)):
for j in range(len(arr) - 1):
if arr[j] > arr[j + 1]:
temp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = temp
print("stderr test", file=sys.stderr)
print('stderr test', file=sys.stderr)
return arr
@classmethod
def sorter_classmethod(cls, arr):
print("codeflash stdout : BubbleSorter.sorter_classmethod() called")
print('codeflash stdout : BubbleSorter.sorter_classmethod() called')
for i in range(len(arr)):
for j in range(len(arr) - 1):
if arr[j] > arr[j + 1]:
temp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = temp
print("stderr test classmethod", file=sys.stderr)
print('stderr test classmethod', file=sys.stderr)
return arr
@staticmethod
def sorter_staticmethod(arr):
print("codeflash stdout : BubbleSorter.sorter_staticmethod() called")
print('codeflash stdout : BubbleSorter.sorter_staticmethod() called')
for i in range(len(arr)):
for j in range(len(arr) - 1):
if arr[j] > arr[j + 1]:
temp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = temp
print("stderr test staticmethod", file=sys.stderr)
print('stderr test staticmethod', file=sys.stderr)
return arr

View file

@ -10,7 +10,7 @@ import dill as pickle
import pytest
from codeflash_python.runtime._codeflash_async_decorators import (
_CREATE_TABLE_SQL,
CREATE_TABLE_SQL,
VerificationType,
)
from codeflash_python.testing._async_data_parser import (
@ -48,7 +48,7 @@ def _create_async_db(
) -> None:
"""Create an async_results SQLite DB with the given rows."""
conn = sqlite3.connect(db_path)
conn.execute(_CREATE_TABLE_SQL)
conn.execute(CREATE_TABLE_SQL)
for row in rows:
conn.execute(
"INSERT INTO async_results VALUES "

View file

@ -13,16 +13,23 @@ from unittest.mock import patch
import dill as pickle
import pytest
import codeflash_python.runtime._codeflash_async_decorators as _deco_mod
from codeflash_python.runtime._codeflash_async_decorators import (
VerificationType,
_close_all_connections,
close_all_connections,
_codeflash_call_site,
_connections,
_get_async_db,
connections,
detect_device_sync,
get_async_db,
get_device_sync,
sync_devices_after,
sync_devices_before,
codeflash_behavior_async,
codeflash_behavior_sync,
codeflash_concurrency_async,
codeflash_performance_async,
codeflash_performance_sync,
extract_test_context_from_env,
get_run_tmp_file,
)
@ -51,7 +58,7 @@ def _env_setup(request, tmp_path):
else:
os.environ[key] = original_value
_close_all_connections()
close_all_connections()
@pytest.fixture(name="async_db_path")
@ -119,36 +126,36 @@ class TestCodeflashCallSite:
class TestGetAsyncDb:
"""_get_async_db connection caching."""
"""get_async_db connection caching."""
def test_creates_table(self, tmp_path) -> None:
"""Creates the async_results table on first connect."""
db_path = tmp_path / "test.sqlite"
conn, cur = _get_async_db(db_path)
conn, cur = get_async_db(db_path)
cur.execute(
"SELECT name FROM sqlite_master "
"WHERE type='table' AND name='async_results'"
)
assert cur.fetchone() is not None
conn.close()
_connections.pop(str(db_path), None)
connections.pop(str(db_path), None)
def test_caches_connection(self, tmp_path) -> None:
"""Returns the same connection on repeated calls."""
db_path = tmp_path / "test.sqlite"
conn1, _ = _get_async_db(db_path)
conn2, _ = _get_async_db(db_path)
conn1, _ = get_async_db(db_path)
conn2, _ = get_async_db(db_path)
assert conn1 is conn2
conn1.close()
_connections.pop(str(db_path), None)
connections.pop(str(db_path), None)
def test_close_all_connections(self, tmp_path) -> None:
"""_close_all_connections empties the cache."""
def testclose_all_connections(self, tmp_path) -> None:
"""close_all_connections empties the cache."""
db_path = tmp_path / "test.sqlite"
_get_async_db(db_path)
assert 0 < len(_connections)
_close_all_connections()
assert 0 == len(_connections)
get_async_db(db_path)
assert 0 < len(connections)
close_all_connections()
assert 0 == len(connections)
@pytest.mark.skipif(
@ -182,7 +189,7 @@ class TestBehaviorAsync:
_codeflash_call_site.set("0")
await multiply(5, 6)
_close_all_connections()
close_all_connections()
assert async_db_path.exists()
con = sqlite3.connect(async_db_path)
@ -214,7 +221,7 @@ class TestBehaviorAsync:
with pytest.raises(ValueError, match="boom"):
await fail()
_close_all_connections()
close_all_connections()
con = sqlite3.connect(async_db_path)
cur = con.cursor()
cur.execute("SELECT return_value FROM async_results")
@ -252,7 +259,7 @@ class TestBehaviorAsync:
_codeflash_call_site.set("0")
await greeter("world")
_close_all_connections()
close_all_connections()
con = sqlite3.connect(async_db_path)
cur = con.cursor()
@ -285,7 +292,7 @@ class TestBehaviorSync:
_codeflash_call_site.set("0")
multiply(5, 6)
_close_all_connections()
close_all_connections()
assert async_db_path.exists()
con = sqlite3.connect(async_db_path)
@ -316,7 +323,7 @@ class TestBehaviorSync:
with pytest.raises(ValueError, match="boom"):
fail()
_close_all_connections()
close_all_connections()
con = sqlite3.connect(async_db_path)
cur = con.cursor()
cur.execute("SELECT return_value FROM async_results")
@ -338,7 +345,7 @@ class TestBehaviorSync:
_codeflash_call_site.set("0")
greeter("world")
_close_all_connections()
close_all_connections()
con = sqlite3.connect(async_db_path)
cur = con.cursor()
@ -371,7 +378,7 @@ class TestBehaviorSync:
_codeflash_call_site.set("0")
work()
_close_all_connections()
close_all_connections()
con = sqlite3.connect(async_db_path)
cur = con.cursor()
@ -384,6 +391,90 @@ class TestBehaviorSync:
con.close()
class TestPerformanceSync:
"""codeflash_performance_sync decorator."""
def test_returns_correct_value(self, env_setup, async_db_path) -> None:
"""Decorated function returns the original return value."""
@codeflash_performance_sync
def add(a: int, b: int) -> int:
return a + b
_codeflash_call_site.set("0")
result = add(3, 4)
assert 7 == result
def test_writes_to_sqlite(self, env_setup, async_db_path) -> None:
"""Writes performance result with mode='performance' and null return_value."""
@codeflash_performance_sync
def work() -> int:
return 42
_codeflash_call_site.set("0")
work()
close_all_connections()
assert async_db_path.exists()
con = sqlite3.connect(async_db_path)
cur = con.cursor()
cur.execute("SELECT * FROM async_results")
rows = cur.fetchall()
assert 1 == len(rows)
row = rows[0]
assert "performance" == row[6]
assert 0 < row[7]
assert row[8] is None
assert row[9] is None
con.close()
def test_exception_handling(self, env_setup, async_db_path) -> None:
"""Re-raises exceptions from the wrapped function."""
@codeflash_performance_sync
def fail() -> None:
raise ValueError("perf boom")
_codeflash_call_site.set("0")
with pytest.raises(ValueError, match="perf boom"):
fail()
def test_no_stdout_capture(
self, env_setup, async_db_path, capsys
) -> None:
"""Performance decorator does not redirect stdout."""
@codeflash_performance_sync
def talker() -> int:
print("visible")
return 1
_codeflash_call_site.set("0")
talker()
captured = capsys.readouterr()
assert "visible" in captured.out
def test_records_wall_time(self, env_setup, async_db_path) -> None:
"""Records a positive wall_time_ns value."""
@codeflash_performance_sync
def work() -> int:
return sum(range(1000))
_codeflash_call_site.set("0")
work()
close_all_connections()
con = sqlite3.connect(async_db_path)
cur = con.cursor()
cur.execute("SELECT wall_time_ns FROM async_results")
row = cur.fetchone()
assert row[0] is not None
assert 0 < row[0]
con.close()
@pytest.mark.skipif(
sys.platform == "win32",
reason="pending support for asyncio on windows",
@ -415,7 +506,7 @@ class TestPerformanceAsync:
_codeflash_call_site.set("0")
await work()
_close_all_connections()
close_all_connections()
con = sqlite3.connect(async_db_path)
cur = con.cursor()
@ -478,7 +569,7 @@ class TestConcurrencyAsync:
return 42
await work()
_close_all_connections()
close_all_connections()
con = sqlite3.connect(async_db_path)
cur = con.cursor()
@ -533,7 +624,7 @@ class TestBehaviorAsyncEdgeCases:
_codeflash_call_site.set("0")
await inc(1)
await inc(2)
_close_all_connections()
close_all_connections()
con = sqlite3.connect(async_db_path)
cur = con.cursor()
@ -579,7 +670,7 @@ class TestPerformanceAsyncEdgeCases:
_codeflash_call_site.set("0")
await work()
await work()
_close_all_connections()
close_all_connections()
con = sqlite3.connect(async_db_path)
cur = con.cursor()
@ -593,15 +684,15 @@ class TestPerformanceAsyncEdgeCases:
class TestCloseAllConnectionsErrorHandling:
"""_close_all_connections handles exceptions gracefully."""
"""close_all_connections handles exceptions gracefully."""
def test_handles_already_closed_connection(self, tmp_path) -> None:
"""Does not raise when a connection is already closed."""
db_path = tmp_path / "test.sqlite"
conn, _ = _get_async_db(db_path)
conn, _ = get_async_db(db_path)
conn.close()
_close_all_connections()
assert 0 == len(_connections)
close_all_connections()
assert 0 == len(connections)
class TestExtractTestContextEdgeCases:
@ -629,7 +720,7 @@ class TestSchemaValidation:
def test_table_columns(self, tmp_path) -> None:
"""async_results table has exactly 14 columns."""
db_path = tmp_path / "schema_test.sqlite"
conn, cur = _get_async_db(db_path)
conn, cur = get_async_db(db_path)
cur.execute("PRAGMA table_info(async_results)")
columns = cur.fetchall()
expected_names = [
@ -651,4 +742,185 @@ class TestSchemaValidation:
actual_names = [col[1] for col in columns]
assert expected_names == actual_names
conn.close()
_connections.pop(str(db_path), None)
connections.pop(str(db_path), None)
@pytest.fixture(name="reset_device_cache")
def _reset_device_cache():
"""Reset the module-level device sync cache before and after each test."""
_deco_mod.device_sync_cache = None
yield
_deco_mod.device_sync_cache = None
class TestDetectDeviceSync:
"""detect_device_sync probes real GPU frameworks via find_spec."""
def test_returns_four_element_tuple(self, reset_device_cache) -> None:
"""Returns (cuda_sync, mps_sync, tf_sync, jax_available)."""
result = detect_device_sync()
assert 4 == len(result)
def test_detects_real_torch(self, reset_device_cache) -> None:
"""Detects torch and returns a sync callable for the active device."""
import torch
cuda_sync, mps_sync, _, _ = detect_device_sync()
if torch.cuda.is_available() and torch.cuda.is_initialized():
assert cuda_sync is torch.cuda.synchronize
assert mps_sync is None
elif (
hasattr(torch.backends, "mps")
and torch.backends.mps.is_available()
):
assert cuda_sync is None
assert mps_sync is torch.mps.synchronize
else:
assert cuda_sync is None
assert mps_sync is None
def test_detects_real_jax(self, reset_device_cache) -> None:
"""Detects JAX and sets jax_available based on block_until_ready."""
import jax
_, _, _, jax_avail = detect_device_sync()
assert jax_avail is hasattr(jax, "block_until_ready")
def test_detects_real_tensorflow(self, reset_device_cache) -> None:
"""Detects TensorFlow sync_devices when available."""
import tensorflow as tf
_, _, tf_sync, _ = detect_device_sync()
if (
hasattr(tf.test, "experimental")
and hasattr(tf.test.experimental, "sync_devices")
):
assert tf_sync is tf.test.experimental.sync_devices
else:
assert tf_sync is None
class TestGetDeviceSync:
"""get_device_sync caches detect_device_sync results."""
def test_caches_result(self, reset_device_cache) -> None:
"""Returns the same tuple on repeated calls without re-detecting."""
first = get_device_sync()
second = get_device_sync()
assert first is second
def test_redetects_after_cache_clear(
self, reset_device_cache
) -> None:
"""Re-runs detection after the cache is cleared."""
first = get_device_sync()
_deco_mod.device_sync_cache = None
second = get_device_sync()
assert first == second
assert first is not second
class TestSyncDevicesBefore:
"""sync_devices_before exercises real framework sync paths."""
def test_runs_without_error(self, reset_device_cache) -> None:
"""Calling with real frameworks installed does not raise."""
sync_devices_before()
def test_cuda_takes_priority_over_mps(
self, reset_device_cache
) -> None:
"""CUDA sync is called instead of MPS when both are in the cache."""
calls = []
_deco_mod.device_sync_cache = (
lambda: calls.append("cuda"),
lambda: calls.append("mps"),
None,
False,
)
sync_devices_before()
assert ["cuda"] == calls
def test_mps_called_when_no_cuda(
self, reset_device_cache
) -> None:
"""MPS sync fires when CUDA is absent in the cache."""
calls = []
_deco_mod.device_sync_cache = (
None,
lambda: calls.append("mps"),
None,
False,
)
sync_devices_before()
assert ["mps"] == calls
def test_tf_called_independently(
self, reset_device_cache
) -> None:
"""TF sync fires independently of torch sync."""
calls = []
_deco_mod.device_sync_cache = (
None,
None,
lambda: calls.append("tf"),
False,
)
sync_devices_before()
assert ["tf"] == calls
class TestSyncDevicesAfter:
"""sync_devices_after exercises real framework sync paths."""
def test_runs_without_error(self, reset_device_cache) -> None:
"""Calling with real frameworks and a plain return value does not raise."""
sync_devices_after(42)
def test_jax_block_until_ready_on_real_array(
self, reset_device_cache
) -> None:
"""Calls block_until_ready on a real JAX array."""
import jax.numpy as jnp
arr = jnp.array([1, 2, 3])
sync_devices_after(arr)
def test_skips_jax_on_plain_value(
self, reset_device_cache
) -> None:
"""Does not fail when jax_available=True but return value is plain."""
_deco_mod.device_sync_cache = (None, None, None, True)
sync_devices_after(42)
def test_cuda_priority_in_after(
self, reset_device_cache
) -> None:
"""CUDA sync fires instead of MPS in the after path too."""
calls = []
_deco_mod.device_sync_cache = (
lambda: calls.append("cuda"),
lambda: calls.append("mps"),
None,
False,
)
sync_devices_after(42)
assert ["cuda"] == calls
def test_all_syncs_fire_together(
self, reset_device_cache
) -> None:
"""All applicable syncs fire: torch + JAX block_until_ready + TF."""
import jax.numpy as jnp
calls = []
_deco_mod.device_sync_cache = (
lambda: calls.append("cuda"),
None,
lambda: calls.append("tf"),
True,
)
arr = jnp.array([1.0])
sync_devices_after(arr)
assert ["cuda", "tf"] == calls
assert hasattr(arr, "block_until_ready")

View file

@ -13,13 +13,17 @@ from codeflash_python._model import (
)
from codeflash_python.analysis._formatter import sort_imports
from codeflash_python.test_discovery.models import CodePosition, TestType
from codeflash_python.testing._instrumentation import (
from codeflash_python.testing._instrument_async import (
ASYNC_HELPER_FILENAME,
add_async_decorator_to_function,
get_decorator_name_for_mode,
inject_profiling_into_existing_test,
)
from codeflash_python.testing._instrument_capture import (
instrument_codeflash_capture,
)
from codeflash_python.testing._instrumentation import (
inject_profiling_into_existing_test,
)
from codeflash_python.testing._parse_results import parse_test_results
from codeflash_python.testing._test_runner import run_behavioral_tests
from codeflash_python.testing.models import TestConfig, TestFile, TestFiles

View file

@ -25,7 +25,7 @@ from codeflash_python.context.models import CodeStringsMarkdown
from codeflash_python.pipeline._orchestrator import cleanup_paths
from codeflash_python.test_discovery.linking import module_name_from_file_path
from codeflash_python.testing._concolic import clean_concolic_tests
from codeflash_python.testing._instrumentation import get_run_tmp_file
from codeflash_python.runtime._codeflash_wrap_decorator import get_run_tmp_file
from codeflash_python.testing._path_resolution import (
file_name_from_test_module_name,
file_path_from_module_name,

View file

@ -11,8 +11,8 @@ from codeflash_python.pipeline._function_optimizer import (
write_code_and_helpers,
)
from codeflash_python.test_discovery.models import TestType
from codeflash_python.testing._instrumentation import (
get_run_tmp_file,
from codeflash_python.runtime._codeflash_wrap_decorator import get_run_tmp_file
from codeflash_python.testing._instrument_capture import (
instrument_codeflash_capture,
)
from codeflash_python.testing._parse_results import parse_test_results

View file

@ -11,13 +11,17 @@ from codeflash_python._model import (
FunctionToOptimize,
TestingMode,
)
from codeflash_python.analysis._formatter import sort_imports
from codeflash_python.test_discovery.models import CodePosition, TestType
from codeflash_python.testing._instrumentation import (
get_run_tmp_file,
inject_profiling_into_existing_test,
from codeflash_python.testing._instrument_async import write_async_helper_file
from codeflash_python.testing._instrument_capture import (
instrument_codeflash_capture,
)
from codeflash_python.testing._instrument_sync import (
add_sync_decorator_to_function,
)
from codeflash_python.testing._instrumentation import (
inject_profiling_into_existing_test,
)
from codeflash_python.testing._parse_results import parse_test_results
from codeflash_python.testing._test_runner import run_behavioral_tests
from codeflash_python.testing.models import TestConfig, TestFile, TestFiles
@ -25,41 +29,6 @@ from codeflash_python.verification._verification import compare_test_results
project_root = Path(__file__).parent.resolve()
# Used by cli instrumentation
codeflash_wrap_string = """def codeflash_wrap(codeflash_wrapped, codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_line_id, codeflash_loop_index, codeflash_cur, codeflash_con, *args, **kwargs):
test_id = f'{{codeflash_test_module_name}}:{{codeflash_test_class_name}}:{{codeflash_test_name}}:{{codeflash_line_id}}:{{codeflash_loop_index}}'
if not hasattr(codeflash_wrap, 'index'):
codeflash_wrap.index = {{}}
if test_id in codeflash_wrap.index:
codeflash_wrap.index[test_id] += 1
else:
codeflash_wrap.index[test_id] = 0
codeflash_test_index = codeflash_wrap.index[test_id]
invocation_id = f'{{codeflash_line_id}}_{{codeflash_test_index}}'
test_stdout_tag = f"{{codeflash_test_module_name}}:{{(codeflash_test_class_name + '.' if codeflash_test_class_name else '')}}{{codeflash_test_name}}:{{codeflash_function_name}}:{{codeflash_loop_index}}:{{invocation_id}}"
print(f"!$######{{test_stdout_tag}}######$!")
exception = None
gc.disable()
try:
counter = time.perf_counter_ns()
cpu_counter = time.thread_time_ns()
return_value = codeflash_wrapped(*args, **kwargs)
codeflash_duration = time.perf_counter_ns() - counter
codeflash_cpu_duration = time.thread_time_ns() - cpu_counter
except Exception as e:
codeflash_duration = time.perf_counter_ns() - counter
codeflash_cpu_duration = time.thread_time_ns() - cpu_counter
exception = e
gc.enable()
print(f"!######{{test_stdout_tag}}######!")
pickled_return_value = pickle.dumps(exception) if exception else pickle.dumps(return_value)
codeflash_cur.execute('INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)', (codeflash_test_module_name, codeflash_test_class_name, codeflash_test_name, codeflash_function_name, codeflash_loop_index, invocation_id, codeflash_duration, pickled_return_value, 'function_call', codeflash_cpu_duration))
codeflash_con.commit()
if exception:
raise exception
return return_value
"""
def _run_and_parse(
test_files: TestFiles,
@ -95,41 +64,6 @@ def test_sort():
output = sorter(input)
assert output == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]"""
expected = (
"""import gc
import inspect
import os
import sqlite3
import time
import dill as pickle
from code_to_optimize.bubble_sort import sorter
"""
+ codeflash_wrap_string
+ """
def test_sort():
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite')
codeflash_cur = codeflash_con.cursor()
codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT, cpu_runtime INTEGER)')
input = [5, 4, 3, 2, 1, 0]
_call__bound__arguments = inspect.signature(sorter).bind(input)
_call__bound__arguments.apply_defaults()
output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '1', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs)
assert output == [0, 1, 2, 3, 4, 5]
input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
_call__bound__arguments = inspect.signature(sorter).bind(input)
_call__bound__arguments.apply_defaults()
output = codeflash_wrap(sorter, '{module_path}', None, 'test_sort', 'sorter', '4', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs)
assert output == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
codeflash_con.close()
"""
)
test_path = (
project_root
/ "code_to_optimize/tests/pytest/test_perfinjector_bubble_sort_results_temp.py"
@ -166,18 +100,20 @@ def test_sort():
os.chdir(original_cwd)
assert success
assert new_test is not None
assert new_test.replace('"', "'") == expected.format(
module_path="code_to_optimize.tests.pytest.test_perfinjector_bubble_sort_results_temp",
tmp_dir_path=get_run_tmp_file(
Path("test_return_values")
).as_posix(),
).replace('"', "'")
with test_path.open("w") as f:
f.write(new_test)
# add codeflash capture
instrument_codeflash_capture(func, {}, tests_root)
# Write the async helper file (contains sync decorators too)
write_async_helper_file(project_root_path)
# Add sync decorator to the source function
add_sync_decorator_to_function(
fto_path,
func,
mode=TestingMode.BEHAVIOR,
project_root=project_root_path,
)
test_env = os.environ.copy()
test_env["CODEFLASH_TEST_ITERATION"] = "0"
@ -203,13 +139,8 @@ def test_sort():
)
test_results = _run_and_parse(test_files, test_env, test_config)
out_str = """codeflash stdout: Sorting list
result: [0, 1, 2, 3, 4, 5]
"""
assert test_results[0].stdout == out_str
assert out_str == test_results[0].stdout
# New decorator captures stdout directly -- the function prints two lines
assert test_results[0].id.function_getting_tested == "sorter"
assert test_results[0].id.iteration_id == "1_0"
assert test_results[0].id.test_class_name is None
assert test_results[0].id.test_function_name == "test_sort"
assert (
@ -218,14 +149,12 @@ result: [0, 1, 2, 3, 4, 5]
)
assert test_results[0].runtime > 0
assert test_results[0].did_pass
assert test_results[0].return_value == ([0, 1, 2, 3, 4, 5],)
out_str = """codeflash stdout: Sorting list
result: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
"""
assert out_str == test_results[1].stdout
# return_value is ((args, kwargs, return_value),) in the new path
assert test_results[0].return_value[0][2] == [0, 1, 2, 3, 4, 5]
out_str = "codeflash stdout: Sorting list\nresult: [0, 1, 2, 3, 4, 5]\n"
assert test_results[0].stdout == out_str
assert test_results[1].id.function_getting_tested == "sorter"
assert test_results[1].id.iteration_id == "4_0"
assert test_results[1].id.test_class_name is None
assert test_results[1].id.test_function_name == "test_sort"
assert (
@ -234,21 +163,15 @@ result: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
)
assert test_results[1].runtime > 0
assert test_results[1].did_pass
out_str = """codeflash stdout: Sorting list
result: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
"""
assert test_results[1].stdout == out_str
results2 = _run_and_parse(test_files, test_env, test_config)
out_str = """codeflash stdout: Sorting list
result: [0, 1, 2, 3, 4, 5]
"""
assert out_str == results2[0].stdout
match, _ = compare_test_results(test_results, results2)
assert match
finally:
fto_path.write_text(original_code, "utf-8")
test_path.unlink(missing_ok=True)
test_path_perf.unlink(missing_ok=True)
(project_root / "codeflash_async_wrapper.py").unlink(missing_ok=True)
def test_method_full_instrumentation() -> None:
@ -266,41 +189,6 @@ def test_sort():
output = sort_class.sorter(input)
assert output == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]"""
expected = (
"""import gc
import inspect
import os
import sqlite3
import time
import dill as pickle
from code_to_optimize.bubble_sort_method import BubbleSorter
"""
+ codeflash_wrap_string
+ """
def test_sort():
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite')
codeflash_cur = codeflash_con.cursor()
codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT, cpu_runtime INTEGER)')
input = [5, 4, 3, 2, 1, 0]
sort_class = BubbleSorter()
_call__bound__arguments = inspect.signature(sort_class.sorter).bind(input)
_call__bound__arguments.apply_defaults()
output = codeflash_wrap(sort_class.sorter, '{module_path}', None, 'test_sort', 'BubbleSorter.sorter', '2', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs)
assert output == [0, 1, 2, 3, 4, 5]
input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
sort_class = BubbleSorter()
_call__bound__arguments = inspect.signature(sort_class.sorter).bind(input)
_call__bound__arguments.apply_defaults()
output = codeflash_wrap(sort_class.sorter, '{module_path}', None, 'test_sort', 'BubbleSorter.sorter', '6', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs)
assert output == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
codeflash_con.close()
"""
)
fto_path = (
project_root / "code_to_optimize/bubble_sort_method.py"
).resolve()
@ -310,28 +198,6 @@ def test_sort():
parents=(FunctionParent(name="BubbleSorter", type="ClassDef"),),
file_path=Path(fto_path),
)
with tempfile.TemporaryDirectory() as tmpdirname:
tmp_test_path = (
Path(tmpdirname) / "test_class_method_behavior_results_temp.py"
)
tmp_test_path.write_text(code, encoding="utf-8")
success, new_test = inject_profiling_into_existing_test(
tmp_test_path,
[CodePosition(7, 13), CodePosition(12, 13)],
fto,
tmp_test_path.parent,
)
assert success
assert new_test.replace('"', "'") == sort_imports(
expected.format(
module_path=tmp_test_path.stem,
tmp_dir_path=get_run_tmp_file(
Path("test_return_values")
).as_posix(),
),
float_to_top=True,
).replace('"', "'")
tests_root = (project_root / "code_to_optimize/tests/pytest/").resolve()
test_path = tests_root / "test_class_method_behavior_results_temp.py"
test_path_perf = (
@ -340,17 +206,31 @@ def test_sort():
project_root_path = project_root
try:
new_test = expected.format(
module_path="code_to_optimize.tests.pytest.test_class_method_behavior_results_temp",
tmp_dir_path=get_run_tmp_file(
Path("test_return_values")
).as_posix(),
# Write and instrument the test file
test_path.write_text(code, encoding="utf-8")
original_cwd = Path.cwd()
os.chdir(project_root_path)
success, new_test = inject_profiling_into_existing_test(
test_path,
[CodePosition(7, 13), CodePosition(12, 13)],
fto,
project_root_path,
)
os.chdir(original_cwd)
assert success
assert new_test is not None
test_path.write_text(new_test, encoding="utf-8")
# Write the async helper file and add sync decorator to source
write_async_helper_file(project_root_path)
add_sync_decorator_to_function(
fto_path,
fto,
mode=TestingMode.BEHAVIOR,
project_root=project_root_path,
)
with test_path.open("w") as f:
f.write(new_test)
# Add codeflash capture
# Add codeflash capture for __init__ state
instrument_codeflash_capture(fto, {}, tests_root)
test_env = os.environ.copy()
@ -376,6 +256,7 @@ def test_sort():
)
test_results = _run_and_parse(test_files, test_env, test_config)
assert len(test_results) == 4
# Order: init results (from codeflash_capture) then sorter results (from sync decorator)
assert (
test_results[0].id.function_getting_tested
== "BubbleSorter.__init__"
@ -384,34 +265,33 @@ def test_sort():
assert test_results[0].did_pass
assert test_results[0].return_value[0] == {"x": 0}
assert (
test_results[1].id.function_getting_tested == "BubbleSorter.sorter"
)
assert test_results[1].id.iteration_id == "2_0"
assert test_results[1].id.test_class_name is None
assert test_results[1].id.test_function_name == "test_sort"
assert (
test_results[1].id.test_module_path
== "code_to_optimize.tests.pytest.test_class_method_behavior_results_temp"
)
assert test_results[1].runtime > 0
assert test_results[1].did_pass
assert test_results[1].return_value == ([0, 1, 2, 3, 4, 5],)
out_str = """codeflash stdout : BubbleSorter.sorter() called\n"""
assert test_results[1].stdout == out_str
match, _ = compare_test_results(test_results, test_results)
assert match
assert (
test_results[2].id.function_getting_tested
test_results[1].id.function_getting_tested
== "BubbleSorter.__init__"
)
assert test_results[2].id.test_function_name == "test_sort"
assert test_results[2].did_pass
assert test_results[2].return_value[0] == {"x": 0}
assert test_results[1].id.test_function_name == "test_sort"
assert test_results[1].did_pass
assert test_results[1].return_value[0] == {"x": 0}
assert (
test_results[3].id.function_getting_tested == "BubbleSorter.sorter"
test_results[2].id.function_getting_tested == "sorter"
)
assert test_results[2].id.test_class_name is None
assert test_results[2].id.test_function_name == "test_sort"
assert (
test_results[2].id.test_module_path
== "code_to_optimize.tests.pytest.test_class_method_behavior_results_temp"
)
assert test_results[2].runtime > 0
assert test_results[2].did_pass
# return_value is ((args, kwargs, return_value),) in the new path
assert test_results[2].return_value[0][2] == [0, 1, 2, 3, 4, 5]
assert test_results[2].stdout == "codeflash stdout : BubbleSorter.sorter() called\n"
match, _ = compare_test_results(test_results, test_results)
assert match
assert (
test_results[3].id.function_getting_tested == "sorter"
)
assert test_results[3].id.iteration_id == "6_0"
assert test_results[3].id.test_class_name is None
assert test_results[3].id.test_function_name == "test_sort"
assert (
@ -420,10 +300,7 @@ def test_sort():
)
assert test_results[3].runtime > 0
assert test_results[3].did_pass
assert (
test_results[3].stdout
== """codeflash stdout : BubbleSorter.sorter() called\n"""
)
assert test_results[3].stdout == "codeflash stdout : BubbleSorter.sorter() called\n"
results2 = _run_and_parse(test_files, test_env, test_config)
@ -455,7 +332,13 @@ class BubbleSorter:
__import__(module_name)
importlib.reload(sys.modules[module_name])
# Add codeflash capture
# Re-add sync decorator and codeflash capture to the new source
add_sync_decorator_to_function(
fto_path,
fto,
mode=TestingMode.BEHAVIOR,
project_root=project_root_path,
)
instrument_codeflash_capture(fto, {}, tests_root)
test_config = TestConfig(
tests_root=tests_root,
@ -476,6 +359,7 @@ class BubbleSorter:
)
new_test_results = _run_and_parse(test_files, test_env, test_config)
assert len(new_test_results) == 4
# Order: init results then sorter results
assert (
new_test_results[0].id.function_getting_tested
== "BubbleSorter.__init__"
@ -486,32 +370,28 @@ class BubbleSorter:
assert (
new_test_results[1].id.function_getting_tested
== "BubbleSorter.sorter"
)
assert new_test_results[1].id.iteration_id == "2_0"
assert new_test_results[1].id.test_class_name is None
assert new_test_results[1].id.test_function_name == "test_sort"
assert (
new_test_results[1].id.test_module_path
== "code_to_optimize.tests.pytest.test_class_method_behavior_results_temp"
)
assert new_test_results[1].runtime > 0
assert new_test_results[1].did_pass
assert new_test_results[1].return_value == ([0, 1, 2, 3, 4, 5],)
assert (
new_test_results[2].id.function_getting_tested
== "BubbleSorter.__init__"
)
assert new_test_results[2].id.test_function_name == "test_sort"
assert new_test_results[2].did_pass
assert new_test_results[2].return_value[0] == {"x": 1}
assert new_test_results[1].id.test_function_name == "test_sort"
assert new_test_results[1].did_pass
assert new_test_results[1].return_value[0] == {"x": 1}
assert (
new_test_results[3].id.function_getting_tested
== "BubbleSorter.sorter"
new_test_results[2].id.function_getting_tested == "sorter"
)
assert new_test_results[2].id.test_class_name is None
assert new_test_results[2].id.test_function_name == "test_sort"
assert (
new_test_results[2].id.test_module_path
== "code_to_optimize.tests.pytest.test_class_method_behavior_results_temp"
)
assert new_test_results[2].runtime > 0
assert new_test_results[2].did_pass
assert new_test_results[2].return_value[0][2] == [0, 1, 2, 3, 4, 5]
assert (
new_test_results[3].id.function_getting_tested == "sorter"
)
assert new_test_results[3].id.iteration_id == "6_0"
assert new_test_results[3].id.test_class_name is None
assert new_test_results[3].id.test_function_name == "test_sort"
assert (
@ -527,6 +407,7 @@ class BubbleSorter:
fto_path.write_text(original_code, "utf-8")
test_path.unlink(missing_ok=True)
test_path_perf.unlink(missing_ok=True)
(project_root / "codeflash_async_wrapper.py").unlink(missing_ok=True)
def test_classmethod_full_instrumentation() -> None:
@ -542,39 +423,6 @@ def test_sort():
output = BubbleSorter.sorter_classmethod(input)
assert output == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]"""
expected = (
"""import gc
import inspect
import os
import sqlite3
import time
import dill as pickle
from code_to_optimize.bubble_sort_method import BubbleSorter
"""
+ codeflash_wrap_string
+ """
def test_sort():
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite')
codeflash_cur = codeflash_con.cursor()
codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT, cpu_runtime INTEGER)')
input = [5, 4, 3, 2, 1, 0]
_call__bound__arguments = inspect.signature(BubbleSorter.sorter_classmethod).bind(input)
_call__bound__arguments.apply_defaults()
output = codeflash_wrap(BubbleSorter.sorter_classmethod, '{module_path}', None, 'test_sort', 'BubbleSorter.sorter_classmethod', '1', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs)
assert output == [0, 1, 2, 3, 4, 5]
input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
_call__bound__arguments = inspect.signature(BubbleSorter.sorter_classmethod).bind(input)
_call__bound__arguments.apply_defaults()
output = codeflash_wrap(BubbleSorter.sorter_classmethod, '{module_path}', None, 'test_sort', 'BubbleSorter.sorter_classmethod', '4', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs)
assert output == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
codeflash_con.close()
"""
)
fto_path = (
project_root / "code_to_optimize/bubble_sort_method.py"
).resolve()
@ -584,28 +432,6 @@ def test_sort():
parents=(FunctionParent(name="BubbleSorter", type="ClassDef"),),
file_path=Path(fto_path),
)
with tempfile.TemporaryDirectory() as tmpdirname:
tmp_test_path = (
Path(tmpdirname) / "test_classmethod_behavior_results_temp.py"
)
tmp_test_path.write_text(code, encoding="utf-8")
success, new_test = inject_profiling_into_existing_test(
tmp_test_path,
[CodePosition(6, 13), CodePosition(10, 13)],
fto,
tmp_test_path.parent,
)
assert success
assert new_test.replace('"', "'") == sort_imports(
expected.format(
module_path=tmp_test_path.stem,
tmp_dir_path=get_run_tmp_file(
Path("test_return_values")
).as_posix(),
),
float_to_top=True,
).replace('"', "'")
tests_root = (project_root / "code_to_optimize/tests/pytest/").resolve()
test_path = tests_root / "test_classmethod_behavior_results_temp.py"
test_path_perf = (
@ -614,15 +440,29 @@ def test_sort():
project_root_path = project_root
try:
new_test = expected.format(
module_path="code_to_optimize.tests.pytest.test_classmethod_behavior_results_temp",
tmp_dir_path=get_run_tmp_file(
Path("test_return_values")
).as_posix(),
# Write and instrument the test file
test_path.write_text(code, encoding="utf-8")
original_cwd = Path.cwd()
os.chdir(project_root_path)
success, new_test = inject_profiling_into_existing_test(
test_path,
[CodePosition(6, 13), CodePosition(10, 13)],
fto,
project_root_path,
)
os.chdir(original_cwd)
assert success
assert new_test is not None
test_path.write_text(new_test, encoding="utf-8")
with test_path.open("w") as f:
f.write(new_test)
# Write the async helper file and add sync decorator to source
write_async_helper_file(project_root_path)
add_sync_decorator_to_function(
fto_path,
fto,
mode=TestingMode.BEHAVIOR,
project_root=project_root_path,
)
# Add codeflash capture
instrument_codeflash_capture(fto, {}, tests_root)
@ -652,9 +492,8 @@ def test_sort():
assert len(test_results) == 2
assert (
test_results[0].id.function_getting_tested
== "BubbleSorter.sorter_classmethod"
== "sorter_classmethod"
)
assert test_results[0].id.iteration_id == "1_0"
assert test_results[0].id.test_class_name is None
assert test_results[0].id.test_function_name == "test_sort"
assert (
@ -663,18 +502,15 @@ def test_sort():
)
assert test_results[0].runtime > 0
assert test_results[0].did_pass
assert test_results[0].return_value == ([0, 1, 2, 3, 4, 5],)
out_str = """codeflash stdout : BubbleSorter.sorter_classmethod() called
"""
assert test_results[0].stdout == out_str
assert test_results[0].return_value[0][2] == [0, 1, 2, 3, 4, 5]
assert test_results[0].stdout == "codeflash stdout : BubbleSorter.sorter_classmethod() called\n"
match, _ = compare_test_results(test_results, test_results)
assert match
assert (
test_results[1].id.function_getting_tested
== "BubbleSorter.sorter_classmethod"
== "sorter_classmethod"
)
assert test_results[1].id.iteration_id == "4_0"
assert test_results[1].id.test_class_name is None
assert test_results[1].id.test_function_name == "test_sort"
assert (
@ -683,11 +519,7 @@ def test_sort():
)
assert test_results[1].runtime > 0
assert test_results[1].did_pass
assert (
test_results[1].stdout
== """codeflash stdout : BubbleSorter.sorter_classmethod() called
"""
)
assert test_results[1].stdout == "codeflash stdout : BubbleSorter.sorter_classmethod() called\n"
results2 = _run_and_parse(test_files, test_env, test_config)
@ -698,6 +530,7 @@ def test_sort():
fto_path.write_text(original_code, "utf-8")
test_path.unlink(missing_ok=True)
test_path_perf.unlink(missing_ok=True)
(project_root / "codeflash_async_wrapper.py").unlink(missing_ok=True)
def test_staticmethod_full_instrumentation() -> None:
@ -713,39 +546,6 @@ def test_sort():
output = BubbleSorter.sorter_staticmethod(input)
assert output == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]"""
expected = (
"""import gc
import inspect
import os
import sqlite3
import time
import dill as pickle
from code_to_optimize.bubble_sort_method import BubbleSorter
"""
+ codeflash_wrap_string
+ """
def test_sort():
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite')
codeflash_cur = codeflash_con.cursor()
codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT, cpu_runtime INTEGER)')
input = [5, 4, 3, 2, 1, 0]
_call__bound__arguments = inspect.signature(BubbleSorter.sorter_staticmethod).bind(input)
_call__bound__arguments.apply_defaults()
output = codeflash_wrap(BubbleSorter.sorter_staticmethod, '{module_path}', None, 'test_sort', 'BubbleSorter.sorter_staticmethod', '1', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs)
assert output == [0, 1, 2, 3, 4, 5]
input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
_call__bound__arguments = inspect.signature(BubbleSorter.sorter_staticmethod).bind(input)
_call__bound__arguments.apply_defaults()
output = codeflash_wrap(BubbleSorter.sorter_staticmethod, '{module_path}', None, 'test_sort', 'BubbleSorter.sorter_staticmethod', '4', codeflash_loop_index, codeflash_cur, codeflash_con, *_call__bound__arguments.args, **_call__bound__arguments.kwargs)
assert output == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
codeflash_con.close()
"""
)
fto_path = (
project_root / "code_to_optimize/bubble_sort_method.py"
).resolve()
@ -755,28 +555,6 @@ def test_sort():
parents=(FunctionParent(name="BubbleSorter", type="ClassDef"),),
file_path=Path(fto_path),
)
with tempfile.TemporaryDirectory() as tmpdirname:
tmp_test_path = (
Path(tmpdirname) / "test_staticmethod_behavior_results_temp.py"
)
tmp_test_path.write_text(code, encoding="utf-8")
success, new_test = inject_profiling_into_existing_test(
tmp_test_path,
[CodePosition(6, 13), CodePosition(10, 13)],
fto,
tmp_test_path.parent,
)
assert success
assert new_test.replace('"', "'") == sort_imports(
expected.format(
module_path=tmp_test_path.stem,
tmp_dir_path=get_run_tmp_file(
Path("test_return_values")
).as_posix(),
),
float_to_top=True,
).replace('"', "'")
tests_root = (project_root / "code_to_optimize/tests/pytest/").resolve()
test_path = tests_root / "test_staticmethod_behavior_results_temp.py"
test_path_perf = (
@ -785,15 +563,29 @@ def test_sort():
project_root_path = project_root
try:
new_test = expected.format(
module_path="code_to_optimize.tests.pytest.test_staticmethod_behavior_results_temp",
tmp_dir_path=get_run_tmp_file(
Path("test_return_values")
).as_posix(),
# Write and instrument the test file
test_path.write_text(code, encoding="utf-8")
original_cwd = Path.cwd()
os.chdir(project_root_path)
success, new_test = inject_profiling_into_existing_test(
test_path,
[CodePosition(6, 13), CodePosition(10, 13)],
fto,
project_root_path,
)
os.chdir(original_cwd)
assert success
assert new_test is not None
test_path.write_text(new_test, encoding="utf-8")
with test_path.open("w") as f:
f.write(new_test)
# Write the async helper file and add sync decorator to source
write_async_helper_file(project_root_path)
add_sync_decorator_to_function(
fto_path,
fto,
mode=TestingMode.BEHAVIOR,
project_root=project_root_path,
)
# Add codeflash capture
instrument_codeflash_capture(fto, {}, tests_root)
@ -823,9 +615,8 @@ def test_sort():
assert len(test_results) == 2
assert (
test_results[0].id.function_getting_tested
== "BubbleSorter.sorter_staticmethod"
== "sorter_staticmethod"
)
assert test_results[0].id.iteration_id == "1_0"
assert test_results[0].id.test_class_name is None
assert test_results[0].id.test_function_name == "test_sort"
assert (
@ -834,18 +625,15 @@ def test_sort():
)
assert test_results[0].runtime > 0
assert test_results[0].did_pass
assert test_results[0].return_value == ([0, 1, 2, 3, 4, 5],)
out_str = """codeflash stdout : BubbleSorter.sorter_staticmethod() called
"""
assert test_results[0].stdout == out_str
assert test_results[0].return_value[0][2] == [0, 1, 2, 3, 4, 5]
assert test_results[0].stdout == "codeflash stdout : BubbleSorter.sorter_staticmethod() called\n"
match, _ = compare_test_results(test_results, test_results)
assert match
assert (
test_results[1].id.function_getting_tested
== "BubbleSorter.sorter_staticmethod"
== "sorter_staticmethod"
)
assert test_results[1].id.iteration_id == "4_0"
assert test_results[1].id.test_class_name is None
assert test_results[1].id.test_function_name == "test_sort"
assert (
@ -854,11 +642,7 @@ def test_sort():
)
assert test_results[1].runtime > 0
assert test_results[1].did_pass
assert (
test_results[1].stdout
== """codeflash stdout : BubbleSorter.sorter_staticmethod() called
"""
)
assert test_results[1].stdout == "codeflash stdout : BubbleSorter.sorter_staticmethod() called\n"
results2 = _run_and_parse(test_files, test_env, test_config)
@ -869,3 +653,4 @@ def test_sort():
fto_path.write_text(original_code, "utf-8")
test_path.unlink(missing_ok=True)
test_path_perf.unlink(missing_ok=True)
(project_root / "codeflash_async_wrapper.py").unlink(missing_ok=True)

View file

@ -7,10 +7,12 @@ import pytest
from codeflash_python._model import FunctionParent, TestingMode
from codeflash_python.analysis._discovery import FunctionToOptimize
from codeflash_python.test_discovery.models import CodePosition
from codeflash_python.testing._instrumentation import (
from codeflash_python.testing._instrument_async import (
ASYNC_HELPER_FILENAME,
add_async_decorator_to_function,
get_decorator_name_for_mode,
)
from codeflash_python.testing._instrumentation import (
inject_profiling_into_existing_test,
)
@ -83,7 +85,7 @@ async def async_function(x: int, y: int) -> int:
assert decorator_added
modified_code = test_file.read_text()
from codeflash_python.testing._instrumentation import sort_imports
from codeflash_python.analysis._formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
code_with_decorator = async_function_code.replace(
@ -125,7 +127,7 @@ async def async_function(x: int, y: int) -> int:
assert decorator_added
modified_code = test_file.read_text()
from codeflash_python.testing._instrumentation import sort_imports
from codeflash_python.analysis._formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.PERFORMANCE)
code_with_decorator = async_function_code.replace(
@ -168,7 +170,7 @@ async def async_function(x: int, y: int) -> int:
assert decorator_added
modified_code = test_file.read_text()
from codeflash_python.testing._instrumentation import sort_imports
from codeflash_python.analysis._formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.CONCURRENCY)
code_with_decorator = async_function_code.replace(
@ -217,7 +219,7 @@ class Calculator:
assert decorator_added
modified_code = test_file.read_text()
from codeflash_python.testing._instrumentation import sort_imports
from codeflash_python.analysis._formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
code_with_decorator = async_class_code.replace(
@ -337,7 +339,7 @@ async def test_async_function():
)
# First instrument the source module
from codeflash_python.testing._instrumentation import (
from codeflash_python.testing._instrument_async import (
add_async_decorator_to_function,
)
@ -349,7 +351,7 @@ async def test_async_function():
# Verify the file was modified with exact expected output
instrumented_source = source_file.read_text()
from codeflash_python.testing._instrumentation import sort_imports
from codeflash_python.analysis._formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
code_with_decorator = source_module_code.replace(
@ -415,7 +417,7 @@ async def test_async_function():
)
# First instrument the source module
from codeflash_python.testing._instrumentation import (
from codeflash_python.testing._instrument_async import (
add_async_decorator_to_function,
)
@ -427,7 +429,7 @@ async def test_async_function():
# Verify the file was modified with exact expected output
instrumented_source = source_file.read_text()
from codeflash_python.testing._instrumentation import sort_imports
from codeflash_python.analysis._formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.PERFORMANCE)
code_with_decorator = source_module_code.replace(
@ -499,7 +501,7 @@ async def test_mixed_functions():
is_async=True,
)
from codeflash_python.testing._instrumentation import (
from codeflash_python.testing._instrument_async import (
add_async_decorator_to_function,
)
@ -511,7 +513,7 @@ async def test_mixed_functions():
# Verify the file was modified
instrumented_source = source_file.read_text()
from codeflash_python.testing._instrumentation import sort_imports
from codeflash_python.analysis._formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
code_with_decorator = source_module_code.replace(
@ -570,7 +572,7 @@ class OuterClass:
assert decorator_added
modified_code = test_file.read_text()
from codeflash_python.testing._instrumentation import sort_imports
from codeflash_python.analysis._formatter import sort_imports
decorator_name = get_decorator_name_for_mode(TestingMode.BEHAVIOR)
code_with_decorator = nested_async_code.replace(
@ -708,7 +710,7 @@ async def test_multiple_calls():
)
# First instrument the source module with async decorators
from codeflash_python.testing._instrumentation import (
from codeflash_python.testing._instrument_async import (
add_async_decorator_to_function,
)

View file

@ -2,8 +2,8 @@ from pathlib import Path
from codeflash_python._model import FunctionParent
from codeflash_python.analysis._discovery import FunctionToOptimize
from codeflash_python.testing._instrumentation import (
get_run_tmp_file,
from codeflash_python.runtime._codeflash_wrap_decorator import get_run_tmp_file
from codeflash_python.testing._instrument_capture import (
instrument_codeflash_capture,
)

View file

@ -35,6 +35,12 @@ class TestGetSyncDecoratorNameForMode:
TestingMode.BEHAVIOR
)
def test_performance_mode(self) -> None:
"""Returns codeflash_performance_sync for PERFORMANCE."""
assert "codeflash_performance_sync" == get_sync_decorator_name_for_mode(
TestingMode.PERFORMANCE
)
class TestSyncDecoratorAdder:
"""SyncDecoratorAdder adds decorators to sync functions."""

File diff suppressed because it is too large Load diff

View file

@ -14,28 +14,32 @@ from codeflash_python._model import (
TestingMode,
VerificationType,
)
from codeflash_python.analysis._formatter import sort_imports
from codeflash_python.test_discovery.models import CodePosition
from codeflash_python.testing._instrumentation import (
from codeflash_python.testing._instrument_async import (
ASYNC_HELPER_FILENAME,
AsyncCallInstrumenter,
AsyncDecoratorAdder,
FunctionCallNodeArguments,
FunctionImportedAsVisitor,
InjectPerfOnly,
add_async_decorator_to_function,
create_device_sync_precompute_statements,
create_device_sync_statements,
create_instrumented_source_module_path,
create_wrapper_function,
detect_frameworks_from_code,
get_call_arguments,
get_decorator_name_for_mode,
inject_async_profiling_into_existing_test,
inject_profiling_into_existing_test,
write_async_helper_file,
)
from codeflash_python.testing._instrument_capture import (
create_instrumented_source_module_path,
)
from codeflash_python.testing._instrument_core import (
FunctionCallNodeArguments,
FunctionImportedAsVisitor,
create_device_sync_precompute_statements,
create_device_sync_statements,
detect_frameworks_from_code,
get_call_arguments,
is_argument_name,
node_in_call_position,
sort_imports,
write_async_helper_file,
)
from codeflash_python.testing._instrumentation import (
inject_profiling_into_existing_test,
)
@ -350,113 +354,6 @@ class TestCreateDeviceSyncStatements:
assert all(isinstance(s, ast.stmt) for s in result)
class TestCreateWrapperFunction:
"""create_wrapper_function AST generation."""
def test_returns_function_def(self) -> None:
"""Returns an ast.FunctionDef node."""
result = create_wrapper_function(TestingMode.BEHAVIOR)
assert isinstance(result, ast.FunctionDef)
def test_function_name(self) -> None:
"""The generated function is named codeflash_wrap."""
result = create_wrapper_function(TestingMode.BEHAVIOR)
assert "codeflash_wrap" == result.name
def test_behavior_mode_params(self) -> None:
"""BEHAVIOR mode wrapper has expected parameters."""
result = create_wrapper_function(TestingMode.BEHAVIOR)
arg_names = [a.arg for a in result.args.args]
assert len(arg_names) > 0
def test_performance_mode_params(self) -> None:
"""PERFORMANCE mode wrapper has expected parameters."""
result = create_wrapper_function(TestingMode.PERFORMANCE)
arg_names = [a.arg for a in result.args.args]
assert len(arg_names) > 0
def test_body_is_nonempty(self) -> None:
"""The function body contains statements."""
result = create_wrapper_function(TestingMode.BEHAVIOR)
assert len(result.body) > 0
def test_with_frameworks(self) -> None:
"""Accepts used_frameworks parameter without error."""
result = create_wrapper_function(
TestingMode.PERFORMANCE,
used_frameworks={"torch": "torch"},
)
assert isinstance(result, ast.FunctionDef)
class TestInjectPerfOnly:
"""InjectPerfOnly AST transformer."""
def test_wraps_name_call(self) -> None:
"""Wraps a direct Name call with codeflash_wrap."""
code = textwrap.dedent("""\
def test_it():
result = target_func(1, 2)
""")
tree = ast.parse(code)
call_node = tree.body[0].body[0].value # type: ignore[attr-defined]
pos = CodePosition(
line_no=call_node.lineno,
col_no=call_node.col_offset,
)
func = make_function("target_func", "module.py")
transformer = InjectPerfOnly(
function=func,
module_path="module",
call_positions=[pos],
mode=TestingMode.BEHAVIOR,
)
new_tree = transformer.visit(tree)
source = ast.unparse(new_tree)
assert "codeflash_wrap" in source
def test_wraps_attribute_call(self) -> None:
"""Wraps a module.func() attribute call with codeflash_wrap."""
code = textwrap.dedent("""\
def test_it():
result = module.target_func(1, 2)
""")
tree = ast.parse(code)
call_node = tree.body[0].body[0].value # type: ignore[attr-defined]
pos = CodePosition(
line_no=call_node.lineno,
col_no=call_node.col_offset,
)
func = make_function("target_func", "module.py")
transformer = InjectPerfOnly(
function=func,
module_path="module",
call_positions=[pos],
mode=TestingMode.BEHAVIOR,
)
new_tree = transformer.visit(tree)
source = ast.unparse(new_tree)
assert "codeflash_wrap" in source
def test_no_wrap_without_matching_position(self) -> None:
"""Does not wrap calls that are not in call_positions."""
code = textwrap.dedent("""\
def test_it():
result = target_func(1, 2)
""")
tree = ast.parse(code)
func = make_function("target_func", "module.py")
transformer = InjectPerfOnly(
function=func,
module_path="module",
call_positions=[CodePosition(line_no=99, col_no=99)],
mode=TestingMode.BEHAVIOR,
)
new_tree = transformer.visit(tree)
source = ast.unparse(new_tree)
assert "codeflash_wrap" not in source
class TestAsyncCallInstrumenter:
"""AsyncCallInstrumenter AST transformer."""
@ -948,7 +845,7 @@ class TestInjectProfilingIntoExistingTest:
"""inject_profiling_into_existing_test orchestration."""
def test_sync_function_instrumentation(self, tmp_path: Path) -> None:
"""Instruments a sync test file with codeflash_wrap and imports."""
"""Instruments a sync test file with call-site tracking."""
project_root = tmp_path / "project"
project_root.mkdir()
test_file = project_root / "test_example.py"
@ -974,10 +871,11 @@ class TestInjectProfilingIntoExistingTest:
)
assert ok is True
assert source is not None
assert "codeflash_wrap" in source
assert "import time" in source
assert "import gc" in source
assert "import os" in source
assert "_codeflash_call_site.set(" in source
assert (
"from codeflash_async_wrapper import _codeflash_call_site"
in source
)
def test_async_delegation(self, tmp_path: Path) -> None:
"""Delegates to async handler for async functions without error."""
@ -1029,7 +927,7 @@ class TestInjectProfilingIntoExistingTest:
assert source is None
def test_behavior_mode_extra_imports(self, tmp_path: Path) -> None:
"""BEHAVIOR mode adds inspect, sqlite3, and dill imports."""
"""BEHAVIOR mode adds call-site tracking import."""
project_root = tmp_path / "project"
project_root.mkdir()
test_file = project_root / "test_behav.py"
@ -1054,9 +952,11 @@ class TestInjectProfilingIntoExistingTest:
)
assert ok is True
assert source is not None
assert "inspect" in source
assert "sqlite3" in source
assert "dill" in source
assert "_codeflash_call_site.set(" in source
assert (
"from codeflash_async_wrapper import _codeflash_call_site"
in source
)
class TestInjectAsyncProfilingIntoExistingTest:

View file

@ -1,5 +1,6 @@
from __future__ import annotations
import importlib
import os
import sys
from pathlib import Path
@ -7,156 +8,89 @@ from pathlib import Path
from codeflash_python._model import (
FunctionParent,
FunctionToOptimize,
TestingMode,
VerificationType,
)
from codeflash_python.test_discovery.models import TestType
from codeflash_python.testing._instrumentation import (
get_run_tmp_file,
from codeflash_python.test_discovery.models import CodePosition, TestType
from codeflash_python.testing._instrument_async import write_async_helper_file
from codeflash_python.testing._instrument_capture import (
instrument_codeflash_capture,
sort_imports,
)
from codeflash_python.testing._instrument_sync import (
add_sync_decorator_to_function,
)
from codeflash_python.testing._instrumentation import (
inject_profiling_into_existing_test,
)
from codeflash_python.testing._parse_results import parse_test_results
from codeflash_python.testing._test_runner import run_behavioral_tests
from codeflash_python.testing.models import TestConfig, TestFile, TestFiles
from codeflash_python.verification._verification import compare_test_results
# Used by aiservice instrumentation
behavior_logging_code = """
from __future__ import annotations
import gc
import inspect
import os
import time
import dill as pickle
from pathlib import Path
from typing import Any, Callable, Optional
def codeflash_wrap(
wrapped: Callable[..., Any],
test_module_name: str,
test_class_name: str | None,
test_name: str,
function_name: str,
line_id: str,
loop_index: int,
*args: Any,
**kwargs: Any,
) -> Any:
test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
if not hasattr(codeflash_wrap, "index"):
codeflash_wrap.index = {}
if test_id in codeflash_wrap.index:
codeflash_wrap.index[test_id] += 1
else:
codeflash_wrap.index[test_id] = 0
codeflash_test_index = codeflash_wrap.index[test_id]
invocation_id = f"{line_id}_{codeflash_test_index}"
test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}"
print(
f"!$######{test_stdout_tag}######$!"
)
exception = None
gc.disable()
try:
counter = time.perf_counter_ns()
return_value = wrapped(*args, **kwargs)
codeflash_duration = time.perf_counter_ns() - counter
except Exception as e:
codeflash_duration = time.perf_counter_ns() - counter
exception = e
gc.enable()
print(f"!######{test_stdout_tag}######!")
iteration = os.environ["CODEFLASH_TEST_ITERATION"]
with Path(
"{codeflash_run_tmp_dir_client_side}", f"test_return_values_{iteration}.bin"
).open("ab") as f:
pickled_values = (
pickle.dumps((args, kwargs, exception))
if exception
else pickle.dumps((args, kwargs, return_value))
)
_test_name = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{line_id}".encode(
"ascii"
)
f.write(len(_test_name).to_bytes(4, byteorder="big"))
f.write(_test_name)
f.write(codeflash_duration.to_bytes(8, byteorder="big"))
f.write(len(pickled_values).to_bytes(4, byteorder="big"))
f.write(pickled_values)
f.write(loop_index.to_bytes(8, byteorder="big"))
f.write(len(invocation_id).to_bytes(4, byteorder="big"))
f.write(invocation_id.encode("ascii"))
if exception:
raise exception
return return_value
"""
project_root = Path(__file__).parent.resolve()
def test_class_method_test_instrumentation_only() -> None:
"""Verifies instrumented test execution and result parsing without codeflash capture."""
instrumented_behavior_test_source = (
behavior_logging_code
+ """
import pytest
from code_to_optimize.bubble_sort_method import BubbleSorter
raw_test_code = """from code_to_optimize.bubble_sort_method import BubbleSorter
def test_single_element_list():
codeflash_loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
obj = BubbleSorter()
_call__bound__arguments = inspect.signature(obj.sorter).bind([42])
_call__bound__arguments.apply_defaults()
codeflash_return_value = codeflash_wrap(
obj.sorter,
"code_to_optimize.tests.pytest.test_aiservice_behavior_results_temp",
None,
"test_single_element_list",
"sorter",
"1",
codeflash_loop_index,
**_call__bound__arguments.arguments,
)
"""
)
instrumented_behavior_test_source = sort_imports(
instrumented_behavior_test_source, float_to_top=True
)
result = obj.sorter([42])
"""
# Init paths
test_path = (
Path(__file__).parent.resolve()
project_root
/ "code_to_optimize/tests/pytest/test_aiservice_behavior_results_temp.py"
).resolve()
test_path_perf = (
Path(__file__).parent.resolve()
project_root
/ "code_to_optimize/tests/pytest/test_aiservice_behavior_results_perf_temp.py"
).resolve()
tests_root = (
Path(__file__).parent.resolve() / "code_to_optimize/tests/pytest/"
project_root / "code_to_optimize/tests/pytest/"
)
project_root_path = Path(__file__).parent.resolve()
run_cwd = Path(__file__).parent.resolve()
project_root_path = project_root
run_cwd = project_root
old_cwd = os.getcwd()
os.chdir(run_cwd)
fto_path = (
Path(__file__).parent.resolve()
project_root
/ "code_to_optimize/bubble_sort_method.py"
).resolve()
original_code = fto_path.read_text("utf-8")
function_to_optimize = FunctionToOptimize(
"sorter",
fto_path,
parents=(FunctionParent("BubbleSorter", "ClassDef"),),
)
try:
temp_run_dir = get_run_tmp_file(Path()).as_posix()
instrumented_behavior_test_source = (
instrumented_behavior_test_source.replace(
"{codeflash_run_tmp_dir_client_side}", temp_run_dir
# Write raw test, instrument it, then add decorator to source
test_path.write_text(raw_test_code, encoding="utf-8")
success, new_test = inject_profiling_into_existing_test(
test_path,
[CodePosition(6, 13)],
function_to_optimize,
project_root_path,
mode=TestingMode.BEHAVIOR,
)
assert success
assert new_test is not None
test_path.write_text(new_test, encoding="utf-8")
# Write the async helper file and add sync decorator to source
write_async_helper_file(project_root_path)
add_sync_decorator_to_function(
fto_path,
function_to_optimize,
mode=TestingMode.BEHAVIOR,
project_root=project_root_path,
)
with test_path.open("w") as f:
f.write(instrumented_behavior_test_source)
test_config = TestConfig(
tests_root=tests_root,
@ -179,11 +113,6 @@ def test_single_element_list():
)
]
)
function_to_optimize = FunctionToOptimize(
"sorter",
fto_path,
parents=(FunctionParent("BubbleSorter", "ClassDef"),),
)
xml_path, run_result, _, _ = run_behavioral_tests(
test_files=test_files,
test_env=test_env,
@ -198,17 +127,13 @@ def test_single_element_list():
run_result=run_result,
)
assert test_results[0].id.function_getting_tested == "sorter"
assert (
test_results[0].stdout
== "codeflash stdout : BubbleSorter.sorter() called\n"
)
assert test_results[0].stdout == "codeflash stdout : BubbleSorter.sorter() called\n"
assert (
test_results[0].id.test_function_name == "test_single_element_list"
)
assert test_results[0].did_pass
assert test_results[0].return_value[1]["arr"] == [42]
# assert comparator(test_results[0].return_value[1]["self"], BubbleSorter()) TODO: add self as input to the function
assert test_results[0].return_value[2] == [42]
# return_value is ((args, kwargs, return_value),) in the new path
assert test_results[0].return_value[0][2] == [42]
# Replace with optimized code that mutated instance attribute
optimized_code_mutated_attr = """
@ -221,7 +146,7 @@ class BubbleSorter:
self.x = x
def sorter(self, arr):
print("codeflash stdout : BubbleSorter.sorter() called")
print("BubbleSorter.sorter() called")
for i in range(len(arr)):
for j in range(len(arr) - 1):
if arr[j] > arr[j + 1]:
@ -232,6 +157,15 @@ class BubbleSorter:
return arr
"""
fto_path.write_text(optimized_code_mutated_attr, "utf-8")
# Re-add sync decorator to the new source
add_sync_decorator_to_function(
fto_path,
function_to_optimize,
mode=TestingMode.BEHAVIOR,
project_root=project_root_path,
)
xml_path, run_result, _, _ = run_behavioral_tests(
test_files=test_files,
test_env=test_env,
@ -245,69 +179,50 @@ class BubbleSorter:
optimization_iteration=0,
run_result=run_result,
)
# assert test_results_mutated_attr[0].return_value[1]["self"].x == 1 TODO: add self as input to function
# In the new decorator-based path, args (including self) are captured,
# so init state changes ARE detected even without explicit codeflash_capture
match, _ = compare_test_results(
test_results, test_results_mutated_attr
) # Without codeflash capture, the init state was not verified, and the results are verified as correct even with the attribute mutated
assert match
)
assert not match
assert (
test_results_mutated_attr[0].stdout
== "codeflash stdout : BubbleSorter.sorter() called\n"
== "BubbleSorter.sorter() called\n"
)
finally:
fto_path.write_text(original_code, "utf-8")
test_path.unlink(missing_ok=True)
test_path_perf.unlink(missing_ok=True)
(project_root / "codeflash_async_wrapper.py").unlink(missing_ok=True)
os.chdir(old_cwd)
def test_class_method_full_instrumentation() -> None:
"""Verifies full instrumentation with codeflash capture for instance state verification."""
instrumented_behavior_test_source = (
behavior_logging_code
+ """
import pytest
from code_to_optimize.bubble_sort_method import BubbleSorter
raw_test_code = """from code_to_optimize.bubble_sort_method import BubbleSorter
def test_single_element_list():
codeflash_loop_index = int(os.environ["CODEFLASH_LOOP_INDEX"])
obj = BubbleSorter()
_call__bound__arguments = inspect.signature(obj.sorter).bind([3,2,1])
_call__bound__arguments.apply_defaults()
codeflash_return_value = codeflash_wrap(
obj.sorter,
"code_to_optimize.tests.pytest.test_aiservice_behavior_results_temp",
None,
"test_single_element_list",
"sorter",
"1",
codeflash_loop_index,
**_call__bound__arguments.arguments,
)
"""
)
instrumented_behavior_test_source = sort_imports(
instrumented_behavior_test_source, float_to_top=True
)
result = obj.sorter([3, 2, 1])
"""
# Init paths
test_path = (
Path(__file__).parent.resolve()
project_root
/ "code_to_optimize/tests/pytest/test_aiservice_behavior_results_temp.py"
).resolve()
test_path_perf = (
Path(__file__).parent.resolve()
project_root
/ "code_to_optimize/tests/pytest/test_aiservice_behavior_results_perf_temp.py"
).resolve()
tests_root = (
Path(__file__).parent.resolve() / "code_to_optimize/tests/pytest/"
project_root / "code_to_optimize/tests/pytest/"
)
project_root_path = Path(__file__).parent.resolve()
project_root_path = project_root
fto_path = (
Path(__file__).parent.resolve()
project_root
/ "code_to_optimize/bubble_sort_method.py"
).resolve()
original_code = fto_path.read_text("utf-8")
@ -318,16 +233,35 @@ def test_single_element_list():
)
try:
temp_run_dir = get_run_tmp_file(Path()).as_posix()
instrumented_behavior_test_source = (
instrumented_behavior_test_source.replace(
"{codeflash_run_tmp_dir_client_side}", temp_run_dir
# Write raw test, instrument it, then add decorator to source
test_path.write_text(raw_test_code, encoding="utf-8")
original_cwd = Path.cwd()
os.chdir(project_root_path)
success, new_test = inject_profiling_into_existing_test(
test_path,
[CodePosition(6, 13)],
function_to_optimize,
project_root_path,
mode=TestingMode.BEHAVIOR,
)
os.chdir(original_cwd)
assert success
assert new_test is not None
test_path.write_text(new_test, encoding="utf-8")
# Write the async helper file and add sync decorator to source
write_async_helper_file(project_root_path)
add_sync_decorator_to_function(
fto_path,
function_to_optimize,
mode=TestingMode.BEHAVIOR,
project_root=project_root_path,
)
with test_path.open("w") as f:
f.write(instrumented_behavior_test_source)
# Add codeflash capture decorator
# Add codeflash capture decorator for __init__ state tracking
instrument_codeflash_capture(function_to_optimize, {}, tests_root)
test_config = TestConfig(
tests_root=tests_root,
tests_project_rootdir=project_root_path,
@ -362,9 +296,7 @@ def test_single_element_list():
optimization_iteration=0,
run_result=run_result,
)
# Verify instance_state result, which checks instance state right after __init__, using codeflash_capture
# Verify function_to_optimize result
# Verify instance_state result (from codeflash_capture)
assert (
test_results[0].id.function_getting_tested
== "BubbleSorter.__init__"
@ -375,23 +307,16 @@ def test_single_element_list():
assert test_results[0].did_pass
assert test_results[0].return_value[0] == {"x": 0}
assert test_results[0].stdout == ""
# Verify function_to_optimize result (from sync decorator)
assert test_results[1].id.function_getting_tested == "sorter"
assert (
test_results[1].id.test_function_name == "test_single_element_list"
)
assert test_results[1].did_pass
# Checks input values to the function to see if they have mutated
# assert comparator(test_results[1].return_value[1]["self"], BubbleSorter()) TODO: add self as input
assert test_results[1].return_value[1]["arr"] == [1, 2, 3]
# Check function return value
assert test_results[1].return_value[2] == [1, 2, 3]
assert (
test_results[1].stdout
== """codeflash stdout : BubbleSorter.sorter() called
"""
)
# return_value is ((args, kwargs, return_value),) in the new path
assert test_results[1].return_value[0][2] == [1, 2, 3]
assert test_results[1].stdout == "codeflash stdout : BubbleSorter.sorter() called\n"
# Replace with optimized code that mutated instance attribute
optimized_code_mutated_attr = """
@ -404,7 +329,7 @@ class BubbleSorter:
self.x = x
def sorter(self, arr):
print("codeflash stdout : BubbleSorter.sorter() called")
print("BubbleSorter.sorter() called")
for i in range(len(arr)):
for j in range(len(arr) - 1):
if arr[j] > arr[j + 1]:
@ -416,14 +341,18 @@ class BubbleSorter:
"""
fto_path.write_text(optimized_code_mutated_attr, "utf-8")
# Force reload of module
import importlib
module_name = "code_to_optimize.bubble_sort_method"
if module_name not in sys.modules:
__import__(module_name)
importlib.reload(sys.modules[module_name])
# Add codeflash capture
# Re-add sync decorator and codeflash capture to the new source
add_sync_decorator_to_function(
fto_path,
function_to_optimize,
mode=TestingMode.BEHAVIOR,
project_root=project_root_path,
)
instrument_codeflash_capture(function_to_optimize, {}, tests_root)
xml_path, run_result, _, _ = run_behavioral_tests(
test_files=test_files,
@ -438,7 +367,6 @@ class BubbleSorter:
optimization_iteration=0,
run_result=run_result,
)
# assert test_results_mutated_attr[0].return_value[0]["self"].x == 1 TODO: add self as input
assert (
test_results_mutated_attr[0].id.function_getting_tested
== "BubbleSorter.__init__"
@ -449,11 +377,14 @@ class BubbleSorter:
== VerificationType.INIT_STATE_FTO
)
assert test_results_mutated_attr[0].stdout == ""
# The test should fail because the instance attribute was mutated
match, _ = compare_test_results(
test_results, test_results_mutated_attr
) # The test should fail because the instance attribute was mutated
)
assert not match
# Replace with optimized code that did not mutate existing instance attribute, but added a new one
# Replace with optimized code that did not mutate existing
# instance attribute, but added a new one
optimized_code_new_attr = """
import sys
@ -464,7 +395,7 @@ class BubbleSorter:
self.y = 2
def sorter(self, arr):
print("codeflash stdout : BubbleSorter.sorter() called")
print("BubbleSorter.sorter() called")
for i in range(len(arr)):
for j in range(len(arr) - 1):
if arr[j] > arr[j + 1]:
@ -476,6 +407,14 @@ class BubbleSorter:
"""
fto_path.write_text(optimized_code_new_attr, "utf-8")
importlib.reload(sys.modules[module_name])
# Re-add sync decorator and codeflash capture
add_sync_decorator_to_function(
fto_path,
function_to_optimize,
mode=TestingMode.BEHAVIOR,
project_root=project_root_path,
)
instrument_codeflash_capture(function_to_optimize, {}, tests_root)
xml_path, run_result, _, _ = run_behavioral_tests(
test_files=test_files,
@ -500,13 +439,15 @@ class BubbleSorter:
== VerificationType.INIT_STATE_FTO
)
assert test_results_new_attr[0].stdout == ""
# assert test_results_new_attr[1].return_value[1]["self"].x == 0 TODO: add self as input
# assert test_results_new_attr[1].return_value[1]["self"].y == 2 TODO: add self as input
# In the new decorator-based path, args (including self) are captured.
# Adding a new instance attribute changes self, so the comparison
# detects a difference even though codeflash_capture considers it additive.
match, _ = compare_test_results(
test_results, test_results_new_attr
) # The test should pass because the instance attribute was not mutated, only a new one was added
assert match
)
assert not match
finally:
fto_path.write_text(original_code, "utf-8")
test_path.unlink(missing_ok=True)
test_path_perf.unlink(missing_ok=True)
(project_root / "codeflash_async_wrapper.py").unlink(missing_ok=True)