codeflash-agent/packages/codeflash-python/src/codeflash_python/testing/_instrument_sync.py
Kevin Turcios 919a673be2
Fix pre-existing CI lint and test failures (#40)
* chore: add gitignore entries for local eval repos, e2e fixtures, and env files

* fix: restore clean bubble_sort_method.py test fixture

The call-site ID commit re-contaminated this file with instrumentation
decorators, causing tests to fail with missing CODEFLASH_LOOP_INDEX.

* fix: resolve ruff and mypy errors in codeflash-python

- Add import-not-found ignores for optional torch/jax imports
- Extract magic column index to _STDOUT_COLUMN_INDEX constant
- Fix unused variable in _instrument_sync.py
- Cast cpu_time_ns to int for mypy arg-type

* fix: add skip markers for optional deps and apply ruff formatting to tests

Skip torch/jax/tensorflow tests when those packages are not installed.
Move has_module helper to conftest.py for reuse across test files.
Apply ruff format to all test files that drifted.

* fix: resolve remaining ruff format and mypy errors

- Add missing blank line in conftest.py (ruff format)
- Remove unused import-untyped ignore on jax import (mypy unused-ignore)
- Add type: ignore comments for object-typed SQLite row values

* chore: bump codeflash-python to 0.1.1.dev0
2026-04-28 18:39:46 -05:00

340 lines
11 KiB
Python

"""Sync-specific instrumentation: AST transformers, decorators, and helpers.
Provides ``SyncCallInstrumenter`` for injecting ``_codeflash_call_site``
contextvar assignments before sync function calls, ``SyncDecoratorAdder``
for adding sync behavior decorators via libcst, and high-level functions
for instrumenting sync test and source files.
"""
from __future__ import annotations
import ast
import logging
from typing import TYPE_CHECKING
import libcst as cst
from .._model import (
FunctionToOptimize,
TestingMode,
)
from ..analysis._formatter import sort_imports
from ._instrument_async import ASYNC_HELPER_FILENAME, write_async_helper_file
from ._instrument_core import (
FunctionImportedAsVisitor,
node_in_call_position,
)
if TYPE_CHECKING:
from pathlib import Path
from ..test_discovery.models import CodePosition
log = logging.getLogger(__name__)

# Names of the decorators this module injects into source files. Existing
# decorators with one of these names are detected so that a function is
# never wrapped twice.
_CODEFLASH_SYNC_DECORATORS = frozenset(
    ("codeflash_behavior_sync", "codeflash_performance_sync")
)
class SyncCallInstrumenter(ast.NodeTransformer):
"""AST transformer that injects call-site tracking before sync calls."""
def __init__(
self,
function: FunctionToOptimize,
call_positions: list[CodePosition],
) -> None:
"""Initialize with the target sync function and call positions."""
self.function_object = function
self.call_positions = call_positions
self.did_instrument = False
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
"""Recurse into class bodies to find test methods."""
return self.generic_visit(node) # type: ignore[return-value]
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
"""Instrument sync test functions that call the target function."""
if not node.name.startswith("test_"):
return node
return self._process_test_function(node) # type: ignore[return-value]
def visit_AsyncFunctionDef(
self, node: ast.AsyncFunctionDef
) -> ast.AsyncFunctionDef:
"""Instrument async test functions that call the target sync function."""
if not node.name.startswith("test_"):
return node
return self._process_test_function(node) # type: ignore[return-value]
def _process_test_function(
self,
node: ast.FunctionDef | ast.AsyncFunctionDef,
) -> ast.FunctionDef | ast.AsyncFunctionDef:
"""Add _codeflash_call_site.set() calls before target function calls."""
new_body: list[ast.stmt] = []
for stmt in node.body:
_, has_target = self._find_target_call(stmt)
if has_target:
call_site_set = ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(
id="_codeflash_call_site",
ctx=ast.Load(),
),
attr="set",
ctx=ast.Load(),
),
args=[
ast.Constant(
value=f"{stmt.lineno}",
),
],
keywords=[],
),
lineno=stmt.lineno,
)
new_body.append(call_site_set)
self.did_instrument = True
new_body.append(stmt)
node.body = new_body
return node
def _is_target_call(self, call_node: ast.Call) -> bool:
"""Check if this call node is calling our target sync function."""
if isinstance(call_node.func, ast.Name):
return call_node.func.id == self.function_object.function_name
if isinstance(call_node.func, ast.Attribute):
return call_node.func.attr == self.function_object.function_name
return False
def _find_target_call(
self,
stmt: ast.stmt,
) -> tuple[ast.stmt, bool]:
"""Search a statement for direct calls to the target function."""
stack: list[ast.AST] = [stmt]
while stack:
node = stack.pop()
if (
isinstance(node, ast.Call)
and self._is_target_call(node)
and node_in_call_position(node, self.call_positions)
):
return stmt, True
for fname in node._fields:
child = getattr(node, fname, None)
if isinstance(child, list):
stack.extend(child)
elif isinstance(child, ast.AST):
stack.append(child)
return stmt, False
class SyncDecoratorAdder(cst.CSTTransformer):
    """Attach a codeflash sync decorator to one function in a module.

    While walking the module, the transformer keeps the chain of enclosing
    class and function names. When it leaves a non-``async`` function whose
    name chain equals ``function.qualified_name`` split on ``"."``, and that
    function is not already decorated with a codeflash sync decorator, it
    adds ``@<decorator_name>`` and sets ``added_decorator`` to ``True``.
    """

    def __init__(
        self,
        function: FunctionToOptimize,
        mode: TestingMode = TestingMode.BEHAVIOR,
    ) -> None:
        """Prepare the transformer for *function* under testing *mode*."""
        super().__init__()
        self.qualified_name_parts = function.qualified_name.split(".")
        self.context_stack: list[str] = []
        self.added_decorator = False
        self.decorator_name = get_sync_decorator_name_for_mode(mode)

    def _enter_scope(self, scope_name: str) -> None:
        """Record that the walk has entered a class or function *scope_name*."""
        self.context_stack.append(scope_name)

    def visit_ClassDef(  # noqa: N802
        self,
        node: cst.ClassDef,
    ) -> None:
        """Enter a class scope before its body is visited."""
        self._enter_scope(node.name.value)

    def leave_ClassDef(  # noqa: N802
        self,
        original_node: cst.ClassDef,
        updated_node: cst.ClassDef,
    ) -> cst.ClassDef:
        """Leave a class scope; the class node itself is not modified."""
        del self.context_stack[-1]
        return updated_node

    def visit_FunctionDef(  # noqa: N802
        self,
        node: cst.FunctionDef,
    ) -> None:
        """Enter a function scope before its body is visited."""
        self._enter_scope(node.name.value)

    def leave_FunctionDef(  # noqa: N802
        self,
        original_node: cst.FunctionDef,
        updated_node: cst.FunctionDef,
    ) -> cst.FunctionDef:
        """Decorate the function if it is the (sync, undecorated) target.

        A function that already has bare ``@classmethod`` or ``@staticmethod``
        gets the new decorator last (innermost), so it wraps the plain
        function. Any other function gets it first (outermost).
        """
        # Compare before popping: the stack still includes this function.
        matches_target = self.context_stack == self.qualified_name_parts
        del self.context_stack[-1]

        if original_node.asynchronous is not None or not matches_target:
            return updated_node
        already_wrapped = any(
            self._is_codeflash_decorator(existing.decorator)
            for existing in original_node.decorators
        )
        if already_wrapped:
            return updated_node

        marker = cst.Decorator(decorator=cst.Name(value=self.decorator_name))
        if self._has_descriptor_decorator(original_node):
            decorators = (*updated_node.decorators, marker)
        else:
            decorators = (marker, *updated_node.decorators)
        self.added_decorator = True
        return updated_node.with_changes(decorators=decorators)

    @staticmethod
    def _is_codeflash_decorator(
        decorator_node: cst.BaseExpression,
    ) -> bool:
        """Tell whether a decorator is one of the codeflash sync decorators.

        Both the bare form ``@name`` and the called form ``@name(...)`` are
        recognized. Attribute forms like ``@mod.name`` are not.
        """
        candidate = (
            decorator_node.func
            if isinstance(decorator_node, cst.Call)
            else decorator_node
        )
        return (
            isinstance(candidate, cst.Name)
            and candidate.value in _CODEFLASH_SYNC_DECORATORS
        )

    @staticmethod
    def _has_descriptor_decorator(
        node: cst.FunctionDef,
    ) -> bool:
        """Tell whether *node* carries a bare @classmethod or @staticmethod."""
        return any(
            isinstance(entry.decorator, cst.Name)
            and entry.decorator.value in ("classmethod", "staticmethod")
            for entry in node.decorators
        )
def get_sync_decorator_name_for_mode(
    mode: TestingMode,
) -> str:
    """Map a testing mode to the name of its sync decorator.

    ``TestingMode.PERFORMANCE`` selects ``codeflash_performance_sync``.
    Every other mode falls back to ``codeflash_behavior_sync``.
    """
    return (
        "codeflash_performance_sync"
        if mode == TestingMode.PERFORMANCE
        else "codeflash_behavior_sync"
    )
def inject_sync_profiling_into_existing_test(
    test_path: Path,
    call_positions: list[CodePosition],
    function_to_optimize: FunctionToOptimize,
) -> tuple[bool, str | None]:
    """Rewrite a test file so sync target calls record their call site.

    The file is parsed and the name under which the target function is
    imported is resolved. ``SyncCallInstrumenter`` is then run over the
    module. If anything was instrumented, an import of
    ``_codeflash_call_site`` from the helper module is added and the
    import-sorted source is returned. The file on disk is not modified.

    Returns:
        ``(True, new_source)`` when at least one call site was instrumented.
        ``(False, None)`` when the file does not parse or no tracked call was
        found.
    """
    with test_path.open(encoding="utf8") as handle:
        source_text = handle.read()

    try:
        module_ast = ast.parse(source_text)
    except SyntaxError:
        log.exception("Syntax error in code in file - %s", test_path)
        return False, None

    # Resolve the local name (alias) the test file uses for the target.
    alias_finder = FunctionImportedAsVisitor(function_to_optimize)
    alias_finder.visit(module_ast)

    instrumenter = SyncCallInstrumenter(alias_finder.imported_as, call_positions)
    module_ast = instrumenter.visit(module_ast)
    if not instrumenter.did_instrument:
        return False, None

    helper_module = ASYNC_HELPER_FILENAME.removesuffix(".py")
    module_ast.body.insert(
        0,
        ast.ImportFrom(
            module=helper_module,
            names=[ast.alias(name="_codeflash_call_site")],
            level=0,
        ),
    )
    # float_to_top lets sort_imports move the new import into the import block.
    return True, sort_imports(ast.unparse(module_ast), float_to_top=True)
def add_sync_decorator_to_function(
    source_path: Path,
    function: FunctionToOptimize,
    mode: TestingMode = TestingMode.BEHAVIOR,
    project_root: Path | None = None,
) -> tuple[bool, dict[Path, str]]:
    """Add a sync instrumentation decorator to *function*.

    Steps, when the target function is found:

    1. Write the helper file into *project_root*, or into the source file's
       directory when *project_root* is ``None``.
    2. Prepend ``from <helper> import <decorator>`` to the module and sort
       its imports.
    3. Overwrite *source_path* with the result.

    Async functions are skipped. Errors while reading, parsing, transforming
    or writing the helper are logged and reported as ``(False, {})``. An
    error while writing the source file itself propagates to the caller.

    Returns:
        ``(True, originals)`` when the decorator was added, where
        *originals* maps the modified source file to its content before
        modification. Otherwise ``(False, {})``.
    """
    if function.is_async:
        return False, {}

    try:
        with source_path.open(encoding="utf8") as handle:
            original_source = handle.read()
        parsed = cst.parse_module(original_source)
        adder = SyncDecoratorAdder(function, mode)
        tree = parsed.visit(adder)
        if adder.added_decorator:
            write_async_helper_file(
                source_path.parent if project_root is None else project_root
            )
            helper_module = ASYNC_HELPER_FILENAME.removesuffix(".py")
            wanted = get_sync_decorator_name_for_mode(mode)
            import_line = cst.parse_statement(
                f"from {helper_module} import {wanted}"
            )
            tree = tree.with_changes(body=[import_line, *tree.body])
        rewritten = sort_imports(code=tree.code, float_to_top=True)
    except Exception:
        log.exception(
            "Error adding sync decorator to function %s",
            function.qualified_name,
        )
        return False, {}

    if not adder.added_decorator:
        return False, {}
    # Writing happens outside the try on purpose: write failures are not
    # swallowed.
    with source_path.open("w", encoding="utf8") as handle:
        handle.write(rewritten)
    return True, {source_path: original_source}