mirror of
https://github.com/codeflash-ai/codeflash-agent.git
synced 2026-05-04 18:25:19 +00:00
* chore: add gitignore entries for local eval repos, e2e fixtures, and env files * fix: restore clean bubble_sort_method.py test fixture The call-site ID commit re-contaminated this file with instrumentation decorators, causing tests to fail with missing CODEFLASH_LOOP_INDEX. * fix: resolve ruff and mypy errors in codeflash-python - Add import-not-found ignores for optional torch/jax imports - Extract magic column index to _STDOUT_COLUMN_INDEX constant - Fix unused variable in _instrument_sync.py - Cast cpu_time_ns to int for mypy arg-type * fix: add skip markers for optional deps and apply ruff formatting to tests Skip torch/jax/tensorflow tests when those packages are not installed. Move has_module helper to conftest.py for reuse across test files. Apply ruff format to all test files that drifted. * fix: resolve remaining ruff format and mypy errors - Add missing blank line in conftest.py (ruff format) - Remove unused import-untyped ignore on jax import (mypy unused-ignore) - Add type: ignore comments for object-typed SQLite row values * chore: bump codeflash-python to 0.1.1.dev0
340 lines
11 KiB
Python
340 lines
11 KiB
Python
"""Sync-specific instrumentation: AST transformers, decorators, and helpers.
|
|
|
|
Provides ``SyncCallInstrumenter`` for injecting ``_codeflash_call_site``
|
|
contextvar assignments before sync function calls, ``SyncDecoratorAdder``
|
|
for adding sync behavior decorators via libcst, and high-level functions
|
|
for instrumenting sync test and source files.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import ast
|
|
import logging
|
|
from typing import TYPE_CHECKING
|
|
|
|
import libcst as cst
|
|
|
|
from .._model import (
|
|
FunctionToOptimize,
|
|
TestingMode,
|
|
)
|
|
from ..analysis._formatter import sort_imports
|
|
from ._instrument_async import ASYNC_HELPER_FILENAME, write_async_helper_file
|
|
from ._instrument_core import (
|
|
FunctionImportedAsVisitor,
|
|
node_in_call_position,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from pathlib import Path
|
|
|
|
from ..test_discovery.models import CodePosition
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
# Names of the sync decorators this module injects. The decorator adder checks
# existing decorators against this set so a function is not decorated twice.
_CODEFLASH_SYNC_DECORATORS = frozenset(
    (
        "codeflash_behavior_sync",
        "codeflash_performance_sync",
    ),
)
|
|
|
|
|
|
class SyncCallInstrumenter(ast.NodeTransformer):
|
|
"""AST transformer that injects call-site tracking before sync calls."""
|
|
|
|
def __init__(
|
|
self,
|
|
function: FunctionToOptimize,
|
|
call_positions: list[CodePosition],
|
|
) -> None:
|
|
"""Initialize with the target sync function and call positions."""
|
|
self.function_object = function
|
|
self.call_positions = call_positions
|
|
self.did_instrument = False
|
|
|
|
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
|
|
"""Recurse into class bodies to find test methods."""
|
|
return self.generic_visit(node) # type: ignore[return-value]
|
|
|
|
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
|
|
"""Instrument sync test functions that call the target function."""
|
|
if not node.name.startswith("test_"):
|
|
return node
|
|
return self._process_test_function(node) # type: ignore[return-value]
|
|
|
|
def visit_AsyncFunctionDef(
|
|
self, node: ast.AsyncFunctionDef
|
|
) -> ast.AsyncFunctionDef:
|
|
"""Instrument async test functions that call the target sync function."""
|
|
if not node.name.startswith("test_"):
|
|
return node
|
|
return self._process_test_function(node) # type: ignore[return-value]
|
|
|
|
def _process_test_function(
|
|
self,
|
|
node: ast.FunctionDef | ast.AsyncFunctionDef,
|
|
) -> ast.FunctionDef | ast.AsyncFunctionDef:
|
|
"""Add _codeflash_call_site.set() calls before target function calls."""
|
|
new_body: list[ast.stmt] = []
|
|
|
|
for stmt in node.body:
|
|
_, has_target = self._find_target_call(stmt)
|
|
|
|
if has_target:
|
|
call_site_set = ast.Expr(
|
|
value=ast.Call(
|
|
func=ast.Attribute(
|
|
value=ast.Name(
|
|
id="_codeflash_call_site",
|
|
ctx=ast.Load(),
|
|
),
|
|
attr="set",
|
|
ctx=ast.Load(),
|
|
),
|
|
args=[
|
|
ast.Constant(
|
|
value=f"{stmt.lineno}",
|
|
),
|
|
],
|
|
keywords=[],
|
|
),
|
|
lineno=stmt.lineno,
|
|
)
|
|
new_body.append(call_site_set)
|
|
self.did_instrument = True
|
|
|
|
new_body.append(stmt)
|
|
|
|
node.body = new_body
|
|
return node
|
|
|
|
def _is_target_call(self, call_node: ast.Call) -> bool:
|
|
"""Check if this call node is calling our target sync function."""
|
|
if isinstance(call_node.func, ast.Name):
|
|
return call_node.func.id == self.function_object.function_name
|
|
if isinstance(call_node.func, ast.Attribute):
|
|
return call_node.func.attr == self.function_object.function_name
|
|
return False
|
|
|
|
def _find_target_call(
|
|
self,
|
|
stmt: ast.stmt,
|
|
) -> tuple[ast.stmt, bool]:
|
|
"""Search a statement for direct calls to the target function."""
|
|
stack: list[ast.AST] = [stmt]
|
|
while stack:
|
|
node = stack.pop()
|
|
if (
|
|
isinstance(node, ast.Call)
|
|
and self._is_target_call(node)
|
|
and node_in_call_position(node, self.call_positions)
|
|
):
|
|
return stmt, True
|
|
for fname in node._fields:
|
|
child = getattr(node, fname, None)
|
|
if isinstance(child, list):
|
|
stack.extend(child)
|
|
elif isinstance(child, ast.AST):
|
|
stack.append(child)
|
|
return stmt, False
|
|
|
|
|
|
class SyncDecoratorAdder(cst.CSTTransformer):
    """libcst transformer that decorates the target sync function.

    The transformer tracks the chain of enclosing class/function names while
    walking the module and adds the mode-specific codeflash decorator to the
    non-async function whose chain equals the target's qualified name.
    """

    def __init__(
        self,
        function: FunctionToOptimize,
        mode: TestingMode = TestingMode.BEHAVIOR,
    ) -> None:
        """Record the target's qualified name and pick the decorator name."""
        super().__init__()
        self.qualified_name_parts = function.qualified_name.split(".")
        # Names of the class/function scopes currently being visited.
        self.context_stack: list[str] = []
        self.added_decorator = False
        self.decorator_name = get_sync_decorator_name_for_mode(mode)

    def visit_ClassDef(  # noqa: N802
        self,
        node: cst.ClassDef,
    ) -> None:
        """Enter a class scope."""
        self.context_stack.append(node.name.value)

    def leave_ClassDef(  # noqa: N802
        self,
        original_node: cst.ClassDef,
        updated_node: cst.ClassDef,
    ) -> cst.ClassDef:
        """Leave a class scope; the node itself is returned unchanged."""
        self.context_stack.pop()
        return updated_node

    def visit_FunctionDef(  # noqa: N802
        self,
        node: cst.FunctionDef,
    ) -> None:
        """Enter a function scope."""
        self.context_stack.append(node.name.value)

    def leave_FunctionDef(  # noqa: N802
        self,
        original_node: cst.FunctionDef,
        updated_node: cst.FunctionDef,
    ) -> cst.FunctionDef:
        """Decorate the function when it is the (undecorated) sync target."""
        # Decide while the function's own name is still on the stack.
        should_decorate = (
            original_node.asynchronous is None
            and self.context_stack == self.qualified_name_parts
            and not any(
                self._is_codeflash_decorator(existing.decorator)
                for existing in original_node.decorators
            )
        )
        self.context_stack.pop()

        if not should_decorate:
            return updated_node

        injected = cst.Decorator(
            decorator=cst.Name(value=self.decorator_name),
        )
        if self._has_descriptor_decorator(original_node):
            # Keep @classmethod/@staticmethod first: the new decorator goes
            # last in the list, i.e. closest to the function.
            decorators = (*updated_node.decorators, injected)
        else:
            decorators = (injected, *updated_node.decorators)

        self.added_decorator = True
        return updated_node.with_changes(decorators=decorators)

    @staticmethod
    def _is_codeflash_decorator(
        decorator_node: cst.BaseExpression,
    ) -> bool:
        """Tell whether a decorator expression names a codeflash sync decorator.

        Both the bare form (``@name``) and the call form (``@name(...)``) are
        recognised; anything else is not.
        """
        named = (
            decorator_node.func
            if isinstance(decorator_node, cst.Call)
            else decorator_node
        )
        return (
            isinstance(named, cst.Name)
            and named.value in _CODEFLASH_SYNC_DECORATORS
        )

    @staticmethod
    def _has_descriptor_decorator(
        node: cst.FunctionDef,
    ) -> bool:
        """Tell whether the function carries @classmethod or @staticmethod."""
        return any(
            isinstance(entry.decorator, cst.Name)
            and entry.decorator.value in ("classmethod", "staticmethod")
            for entry in node.decorators
        )
|
|
|
|
|
|
def get_sync_decorator_name_for_mode(
    mode: TestingMode,
) -> str:
    """Map a testing mode to the name of its sync decorator function.

    ``TestingMode.PERFORMANCE`` selects the performance decorator; every
    other mode falls back to the behavior decorator.
    """
    return (
        "codeflash_performance_sync"
        if mode == TestingMode.PERFORMANCE
        else "codeflash_behavior_sync"
    )
|
|
|
|
|
|
def inject_sync_profiling_into_existing_test(
    test_path: Path,
    call_positions: list[CodePosition],
    function_to_optimize: FunctionToOptimize,
) -> tuple[bool, str | None]:
    """Rewrite a test file so target calls are preceded by call-site markers.

    Returns ``(True, new_source)`` when at least one marker was inserted.
    Returns ``(False, None)`` when the file does not parse or when no
    matching call was found. The file on disk is only read, never written.
    """
    with test_path.open(encoding="utf8") as handle:
        original_source = handle.read()

    try:
        module_ast = ast.parse(original_source)
    except SyntaxError:
        log.exception(
            "Syntax error in code in file - %s",
            test_path,
        )
        return False, None

    # The visitor's ``imported_as`` result is what the instrumenter matches
    # call names against (presumably the target under its import alias —
    # confirm against FunctionImportedAsVisitor).
    alias_visitor = FunctionImportedAsVisitor(function_to_optimize)
    alias_visitor.visit(module_ast)

    instrumenter = SyncCallInstrumenter(
        alias_visitor.imported_as,
        call_positions,
    )
    module_ast = instrumenter.visit(module_ast)
    if not instrumenter.did_instrument:
        return False, None

    # The injected markers reference ``_codeflash_call_site``; import it from
    # the helper module and let sort_imports place the import.
    helper_import = ast.ImportFrom(
        module=ASYNC_HELPER_FILENAME.removesuffix(".py"),
        names=[ast.alias(name="_codeflash_call_site")],
        level=0,
    )
    module_ast.body = [helper_import, *module_ast.body]

    rewritten = ast.unparse(module_ast)
    return True, sort_imports(rewritten, float_to_top=True)
|
|
|
|
|
|
def add_sync_decorator_to_function(
    source_path: Path,
    function: FunctionToOptimize,
    mode: TestingMode = TestingMode.BEHAVIOR,
    project_root: Path | None = None,
) -> tuple[bool, dict[Path, str]]:
    """Decorate *function* in its source file with the sync instrumentation.

    When the decorator is added, the helper file is written (into
    *project_root* if given, otherwise next to the source file), an import
    of the decorator is prepended, and the source file is overwritten.

    Returns ``(True, originals)`` in that case, where *originals* maps the
    source path to its content before modification. Returns ``(False, {})``
    for async functions, when nothing was decorated, or when any step before
    the final write raised (the error is logged).
    """
    if function.is_async:
        return False, {}

    try:
        with source_path.open(encoding="utf8") as src:
            source_code = src.read()

        parsed = cst.parse_module(source_code)
        adder = SyncDecoratorAdder(function, mode)
        parsed = parsed.visit(adder)

        if adder.added_decorator:
            target_dir = (
                source_path.parent if project_root is None else project_root
            )
            write_async_helper_file(target_dir)

            decorator_name = get_sync_decorator_name_for_mode(mode)
            helper_module = ASYNC_HELPER_FILENAME.removesuffix(".py")
            import_node = cst.parse_statement(
                f"from {helper_module} import {decorator_name}"
            )
            parsed = parsed.with_changes(
                body=[import_node, *parsed.body],
            )

        modified_code = sort_imports(code=parsed.code, float_to_top=True)
    except Exception:
        log.exception(
            "Error adding sync decorator to function %s",
            function.qualified_name,
        )
        return False, {}

    if not adder.added_decorator:
        return False, {}

    # Capture the pre-modification content before overwriting the file.
    originals: dict[Path, str] = {source_path: source_code}
    with source_path.open("w", encoding="utf8") as out:
        out.write(modified_code)
    return True, originals
|