"""Sync-specific instrumentation: AST transformers, decorators, and helpers. Provides ``SyncCallInstrumenter`` for injecting ``_codeflash_call_site`` contextvar assignments before sync function calls, ``SyncDecoratorAdder`` for adding sync behavior decorators via libcst, and high-level functions for instrumenting sync test and source files. """ from __future__ import annotations import ast import logging from typing import TYPE_CHECKING import libcst as cst from .._model import ( FunctionToOptimize, TestingMode, ) from ..analysis._formatter import sort_imports from ._instrument_async import ASYNC_HELPER_FILENAME, write_async_helper_file from ._instrument_core import ( FunctionImportedAsVisitor, node_in_call_position, ) if TYPE_CHECKING: from pathlib import Path from ..test_discovery.models import CodePosition log = logging.getLogger(__name__) _CODEFLASH_SYNC_DECORATORS = frozenset( { "codeflash_behavior_sync", "codeflash_performance_sync", } ) class SyncCallInstrumenter(ast.NodeTransformer): """AST transformer that injects call-site tracking before sync calls.""" def __init__( self, function: FunctionToOptimize, call_positions: list[CodePosition], ) -> None: """Initialize with the target sync function and call positions.""" self.function_object = function self.call_positions = call_positions self.did_instrument = False def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: """Recurse into class bodies to find test methods.""" return self.generic_visit(node) # type: ignore[return-value] def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef: """Instrument sync test functions that call the target function.""" if not node.name.startswith("test_"): return node return self._process_test_function(node) # type: ignore[return-value] def visit_AsyncFunctionDef( self, node: ast.AsyncFunctionDef ) -> ast.AsyncFunctionDef: """Instrument async test functions that call the target sync function.""" if not node.name.startswith("test_"): return node return self._process_test_function(node) # type: ignore[return-value] def _process_test_function( self, node: ast.FunctionDef | ast.AsyncFunctionDef, ) -> ast.FunctionDef | ast.AsyncFunctionDef: """Add _codeflash_call_site.set() calls before target function calls.""" new_body: list[ast.stmt] = [] for stmt in node.body: _, has_target = self._find_target_call(stmt) if has_target: call_site_set = ast.Expr( value=ast.Call( func=ast.Attribute( value=ast.Name( id="_codeflash_call_site", ctx=ast.Load(), ), attr="set", ctx=ast.Load(), ), args=[ ast.Constant( value=f"{stmt.lineno}", ), ], keywords=[], ), lineno=stmt.lineno, ) new_body.append(call_site_set) self.did_instrument = True new_body.append(stmt) node.body = new_body return node def _is_target_call(self, call_node: ast.Call) -> bool: """Check if this call node is calling our target sync function.""" if isinstance(call_node.func, ast.Name): return call_node.func.id == self.function_object.function_name if isinstance(call_node.func, ast.Attribute): return call_node.func.attr == self.function_object.function_name return False def _find_target_call( self, stmt: ast.stmt, ) -> tuple[ast.stmt, bool]: """Search a statement for direct calls to the target function.""" stack: list[ast.AST] = [stmt] while stack: node = stack.pop() if ( isinstance(node, ast.Call) and self._is_target_call(node) and node_in_call_position(node, self.call_positions) ): return stmt, True for fname in node._fields: child = getattr(node, fname, None) if isinstance(child, list): stack.extend(child) elif isinstance(child, ast.AST): stack.append(child) return stmt, False class SyncDecoratorAdder(cst.CSTTransformer): """Transformer that adds a sync decorator to sync function definitions.""" def __init__( self, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR, ) -> None: """Initialize the transformer.""" super().__init__() self.qualified_name_parts = function.qualified_name.split(".") self.context_stack: list[str] = [] self.added_decorator = False self.decorator_name = get_sync_decorator_name_for_mode(mode) def visit_ClassDef( # noqa: N802 self, node: cst.ClassDef, ) -> None: """Push class name onto the context stack.""" self.context_stack.append(node.name.value) def leave_ClassDef( # noqa: N802 self, original_node: cst.ClassDef, updated_node: cst.ClassDef, ) -> cst.ClassDef: """Pop class name from the context stack.""" self.context_stack.pop() return updated_node def visit_FunctionDef( # noqa: N802 self, node: cst.FunctionDef, ) -> None: """Push function name onto the context stack.""" self.context_stack.append(node.name.value) def leave_FunctionDef( # noqa: N802 self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef, ) -> cst.FunctionDef: """Add the sync decorator if the function matches the target.""" if ( original_node.asynchronous is None and self.context_stack == self.qualified_name_parts and not any( self._is_codeflash_decorator(d.decorator) for d in original_node.decorators ) ): new_decorator = cst.Decorator( decorator=cst.Name(value=self.decorator_name), ) if self._has_descriptor_decorator(original_node): updated_node = updated_node.with_changes( decorators=(*updated_node.decorators, new_decorator), ) else: updated_node = updated_node.with_changes( decorators=(new_decorator, *updated_node.decorators), ) self.added_decorator = True self.context_stack.pop() return updated_node @staticmethod def _is_codeflash_decorator( decorator_node: cst.BaseExpression, ) -> bool: """Check if a decorator is one of the codeflash sync decorators.""" if isinstance(decorator_node, cst.Name): return decorator_node.value in _CODEFLASH_SYNC_DECORATORS if isinstance(decorator_node, cst.Call) and isinstance( decorator_node.func, cst.Name, ): return decorator_node.func.value in _CODEFLASH_SYNC_DECORATORS return False @staticmethod def _has_descriptor_decorator( node: cst.FunctionDef, ) -> bool: """Check if the function has @classmethod or @staticmethod.""" for d in node.decorators: if isinstance(d.decorator, cst.Name) and d.decorator.value in ( "classmethod", "staticmethod", ): return True return False def get_sync_decorator_name_for_mode( mode: TestingMode, ) -> str: """Return the sync decorator function name for the given testing mode.""" if mode == TestingMode.PERFORMANCE: return "codeflash_performance_sync" return "codeflash_behavior_sync" def inject_sync_profiling_into_existing_test( test_path: Path, call_positions: list[CodePosition], function_to_optimize: FunctionToOptimize, ) -> tuple[bool, str | None]: """Inject call-site tracking for sync function calls in a test file.""" with test_path.open(encoding="utf8") as f: test_code = f.read() try: tree = ast.parse(test_code) except SyntaxError: log.exception( "Syntax error in code in file - %s", test_path, ) return False, None import_visitor = FunctionImportedAsVisitor(function_to_optimize) import_visitor.visit(tree) func = import_visitor.imported_as sync_instrumenter = SyncCallInstrumenter(func, call_positions) tree = sync_instrumenter.visit(tree) if not sync_instrumenter.did_instrument: return False, None new_imports = [ ast.ImportFrom( module=ASYNC_HELPER_FILENAME.removesuffix(".py"), names=[ast.alias(name="_codeflash_call_site")], level=0, ), ] tree.body = [*new_imports, *tree.body] return True, sort_imports(ast.unparse(tree), float_to_top=True) def add_sync_decorator_to_function( source_path: Path, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR, project_root: Path | None = None, ) -> tuple[bool, dict[Path, str]]: """Add a sync instrumentation decorator to *function*. Writes the helper file and adds the appropriate import and decorator. Returns ``(True, originals)`` if the decorator was added, where *originals* maps each modified file to its content before modification. """ if function.is_async: return False, {} try: with source_path.open(encoding="utf8") as f: source_code = f.read() module = cst.parse_module(source_code) decorator_transformer = SyncDecoratorAdder(function, mode) module = module.visit(decorator_transformer) if decorator_transformer.added_decorator: helper_dir = ( project_root if project_root is not None else source_path.parent ) write_async_helper_file(helper_dir) decorator_name = get_sync_decorator_name_for_mode(mode) import_stmt = ASYNC_HELPER_FILENAME.removesuffix(".py") import_node = cst.parse_statement( f"from {import_stmt} import {decorator_name}" ) module = module.with_changes( body=[import_node, *list(module.body)] ) modified_code = sort_imports(code=module.code, float_to_top=True) except Exception: log.exception( "Error adding sync decorator to function %s", function.qualified_name, ) return False, {} else: if decorator_transformer.added_decorator: originals: dict[Path, str] = {source_path: source_code} with source_path.open("w", encoding="utf8") as f: f.write(modified_code) return True, originals return False, {}