feat: add sync instrumentation module with decorator-based approach

New _instrument_sync.py mirrors the async instrumentation pattern:
- SyncCallInstrumenter injects _codeflash_call_site.set() before sync calls
- SyncDecoratorAdder applies @codeflash_behavior_sync via libcst
- add_sync_decorator_to_function() decorates source files
- inject_sync_profiling_into_existing_test() instruments test files

Reuses the same helper file (codeflash_async_wrapper.py) since both
sync and async decorators live in _codeflash_async_decorators.py.
This commit is contained in:
Kevin Turcios 2026-04-24 04:54:45 -05:00
parent 8c218038e9
commit 918a2a10a4
2 changed files with 856 additions and 0 deletions

View file

@ -0,0 +1,328 @@
"""Sync-specific instrumentation: AST transformers, decorators, and helpers.
Provides ``SyncCallInstrumenter`` for injecting ``_codeflash_call_site``
contextvar assignments before sync function calls, ``SyncDecoratorAdder``
for adding sync behavior decorators via libcst, and high-level functions
for instrumenting sync test and source files.
"""
from __future__ import annotations
import ast
import logging
from typing import TYPE_CHECKING
import libcst as cst
from .._model import (
FunctionToOptimize,
TestingMode,
)
from ..analysis._formatter import sort_imports
from ._instrument_async import ASYNC_HELPER_FILENAME, write_async_helper_file
from ._instrument_core import (
FunctionImportedAsVisitor,
node_in_call_position,
)
if TYPE_CHECKING:
from pathlib import Path
from ..test_discovery.models import CodePosition
log = logging.getLogger(__name__)
# Decorator names this module treats as "already instrumented".
# SyncDecoratorAdder compares a function's existing decorators against this
# set so the same function is never decorated twice.
_CODEFLASH_SYNC_DECORATORS = frozenset(
{
"codeflash_behavior_sync",
}
)
class SyncCallInstrumenter(ast.NodeTransformer):
"""AST transformer that injects call-site tracking before sync calls."""
def __init__(
self,
function: FunctionToOptimize,
call_positions: list[CodePosition],
) -> None:
"""Initialize with the target sync function and call positions."""
self.function_object = function
self.call_positions = call_positions
self.did_instrument = False
self.call_counter: dict[str, int] = {}
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
"""Recurse into class bodies to find test methods."""
return self.generic_visit(node) # type: ignore[return-value]
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
"""Instrument sync test functions that call the target function."""
if not node.name.startswith("test_"):
return node
return self._process_test_function(node) # type: ignore[return-value]
def visit_AsyncFunctionDef(
self, node: ast.AsyncFunctionDef
) -> ast.AsyncFunctionDef:
"""Instrument async test functions that call the target sync function."""
if not node.name.startswith("test_"):
return node
return self._process_test_function(node) # type: ignore[return-value]
def _process_test_function(
self,
node: ast.FunctionDef | ast.AsyncFunctionDef,
) -> ast.FunctionDef | ast.AsyncFunctionDef:
"""Add _codeflash_call_site.set() calls before target function calls."""
if node.name not in self.call_counter:
self.call_counter[node.name] = 0
new_body: list[ast.stmt] = []
for stmt in node.body:
_, has_target = self._find_target_call(stmt)
if has_target:
current_call_index = self.call_counter[node.name]
self.call_counter[node.name] += 1
call_site_set = ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(
id="_codeflash_call_site",
ctx=ast.Load(),
),
attr="set",
ctx=ast.Load(),
),
args=[
ast.Constant(
value=f"{current_call_index}",
),
],
keywords=[],
),
lineno=stmt.lineno,
)
new_body.append(call_site_set)
self.did_instrument = True
new_body.append(stmt)
node.body = new_body
return node
def _is_target_call(self, call_node: ast.Call) -> bool:
"""Check if this call node is calling our target sync function."""
if isinstance(call_node.func, ast.Name):
return call_node.func.id == self.function_object.function_name
if isinstance(call_node.func, ast.Attribute):
return call_node.func.attr == self.function_object.function_name
return False
def _find_target_call(
self,
stmt: ast.stmt,
) -> tuple[ast.stmt, bool]:
"""Search a statement for direct calls to the target function."""
stack: list[ast.AST] = [stmt]
while stack:
node = stack.pop()
if (
isinstance(node, ast.Call)
and self._is_target_call(node)
and node_in_call_position(node, self.call_positions)
):
return stmt, True
for fname in node._fields:
child = getattr(node, fname, None)
if isinstance(child, list):
stack.extend(child)
elif isinstance(child, ast.AST):
stack.append(child)
return stmt, False
class SyncDecoratorAdder(cst.CSTTransformer):
    """Transformer that adds a sync decorator to sync function definitions.

    The target is located by tracking the names of the enclosing classes
    and functions during traversal and comparing that stack with the dotted
    ``qualified_name`` of the function to decorate.
    """

    def __init__(
        self,
        function: FunctionToOptimize,
        mode: TestingMode = TestingMode.BEHAVIOR,
    ) -> None:
        """Initialize the transformer."""
        super().__init__()
        self.qualified_name_parts = function.qualified_name.split(".")
        # Names of the class/function scopes currently being traversed.
        self.context_stack: list[str] = []
        # Set once the decorator has actually been inserted.
        self.added_decorator = False
        self.decorator_name = get_sync_decorator_name_for_mode(mode)

    def visit_ClassDef(  # noqa: N802
        self,
        node: cst.ClassDef,
    ) -> None:
        """Push class name onto the context stack."""
        self.context_stack.append(node.name.value)

    def leave_ClassDef(  # noqa: N802
        self,
        original_node: cst.ClassDef,
        updated_node: cst.ClassDef,
    ) -> cst.ClassDef:
        """Pop class name from the context stack."""
        self.context_stack.pop()
        return updated_node

    def visit_FunctionDef(  # noqa: N802
        self,
        node: cst.FunctionDef,
    ) -> None:
        """Push function name onto the context stack."""
        self.context_stack.append(node.name.value)

    def leave_FunctionDef(  # noqa: N802
        self,
        original_node: cst.FunctionDef,
        updated_node: cst.FunctionDef,
    ) -> cst.FunctionDef:
        """Add the sync decorator if the function matches the target.

        A function qualifies when it is not ``async``, its scope stack
        equals the target's qualified name, and it does not already carry
        one of the codeflash sync decorators.
        """
        is_sync_target = (
            original_node.asynchronous is None
            and self.context_stack == self.qualified_name_parts
        )
        if is_sync_target and not self._has_codeflash_decorator(original_node):
            codeflash_decorator = cst.Decorator(
                decorator=cst.Name(value=self.decorator_name),
            )
            # Prepend so the new decorator ends up above the existing ones.
            updated_node = updated_node.with_changes(
                decorators=(codeflash_decorator, *updated_node.decorators),
            )
            self.added_decorator = True

        # Leave this function's scope whether or not it was decorated.
        self.context_stack.pop()
        return updated_node

    def _has_codeflash_decorator(self, function_node: cst.FunctionDef) -> bool:
        """Return True if *function_node* already has a codeflash decorator."""
        for existing in function_node.decorators:
            if self._is_codeflash_decorator(existing.decorator):
                return True
        return False

    @staticmethod
    def _is_codeflash_decorator(
        decorator_node: cst.BaseExpression,
    ) -> bool:
        """Check if a decorator is one of the codeflash sync decorators.

        Recognizes the bare form (``@name``) and the called form
        (``@name(...)``); anything else is reported as not matching.
        """
        # For ``@name(...)`` the interesting part is the callee.
        named = (
            decorator_node.func
            if isinstance(decorator_node, cst.Call)
            else decorator_node
        )
        return (
            isinstance(named, cst.Name)
            and named.value in _CODEFLASH_SYNC_DECORATORS
        )
def get_sync_decorator_name_for_mode(
    mode: TestingMode,  # noqa: ARG001
) -> str:
    """Return the sync decorator function name for the given testing mode.

    Only one sync decorator exists, so every mode resolves to
    ``codeflash_behavior_sync``.  *mode* stays in the signature because
    callers pass it; add a branch here if a mode-specific sync decorator is
    introduced.
    """
    # The former ``if mode == TestingMode.BEHAVIOR`` check returned this same
    # literal on both paths, so the conditional carried no information.
    return "codeflash_behavior_sync"
def inject_sync_profiling_into_existing_test(
    test_path: Path,
    call_positions: list[CodePosition],
    function_to_optimize: FunctionToOptimize,
) -> tuple[bool, str | None]:
    """Inject call-site tracking for sync function calls in a test file.

    Reads *test_path*, inserts ``_codeflash_call_site.set(...)`` statements
    via ``SyncCallInstrumenter`` and prepends the import of
    ``_codeflash_call_site`` from the helper module.  The file itself is
    not modified.

    Returns ``(True, instrumented_source)`` on success, or
    ``(False, None)`` when the file has a syntax error or no matching call
    was instrumented.
    """
    test_code = test_path.read_text(encoding="utf8")

    try:
        tree = ast.parse(test_code)
    except SyntaxError:
        log.exception(
            "Syntax error in code in file - %s",
            test_path,
        )
        return False, None

    # The visitor exposes the function object to match against as
    # ``imported_as`` after walking the module.
    import_visitor = FunctionImportedAsVisitor(function_to_optimize)
    import_visitor.visit(tree)

    instrumenter = SyncCallInstrumenter(
        import_visitor.imported_as,
        call_positions,
    )
    tree = instrumenter.visit(tree)
    if not instrumenter.did_instrument:
        return False, None

    helper_import = ast.ImportFrom(
        module=ASYNC_HELPER_FILENAME.removesuffix(".py"),
        names=[ast.alias(name="_codeflash_call_site")],
        level=0,
    )
    tree.body = [helper_import, *tree.body]

    instrumented_source = ast.unparse(tree)
    return True, sort_imports(instrumented_source, float_to_top=True)
def add_sync_decorator_to_function(
    source_path: Path,
    function: FunctionToOptimize,
    mode: TestingMode = TestingMode.BEHAVIOR,
    project_root: Path | None = None,
) -> tuple[bool, dict[Path, str]]:
    """Add a sync instrumentation decorator to *function*.

    Writes the helper file (into *project_root* when given, otherwise next
    to *source_path*), prepends the decorator import and rewrites
    *source_path* with the decorated function.

    Returns ``(True, originals)`` if the decorator was added, where
    *originals* maps each modified file to its content before modification.
    Returns ``(False, {})`` for async functions, when nothing was decorated,
    or when reading/transforming the source fails (the error is logged).
    """
    if function.is_async:
        return False, {}

    try:
        source_code = source_path.read_text(encoding="utf8")

        module = cst.parse_module(source_code)
        decorator_transformer = SyncDecoratorAdder(function, mode)
        module = module.visit(decorator_transformer)

        if decorator_transformer.added_decorator:
            write_async_helper_file(
                source_path.parent if project_root is None else project_root
            )
            decorator_name = get_sync_decorator_name_for_mode(mode)
            helper_module = ASYNC_HELPER_FILENAME.removesuffix(".py")
            import_node = cst.parse_statement(
                f"from {helper_module} import {decorator_name}"
            )
            # Prepend the import; sort_imports below is given the result.
            module = module.with_changes(body=[import_node, *module.body])

        modified_code = sort_imports(code=module.code, float_to_top=True)
    except Exception:
        log.exception(
            "Error adding sync decorator to function %s",
            function.qualified_name,
        )
        return False, {}

    if not decorator_transformer.added_decorator:
        return False, {}

    # Capture the pre-modification content before overwriting the file.
    originals: dict[Path, str] = {source_path: source_code}
    source_path.write_text(modified_code, encoding="utf8")
    return True, originals

View file

@ -0,0 +1,528 @@
"""Tests for the sync-specific instrumentation module."""
import ast
import tempfile
from pathlib import Path
import pytest
from codeflash_python._model import FunctionParent, TestingMode
from codeflash_python.analysis._discovery import FunctionToOptimize
from codeflash_python.test_discovery.models import CodePosition
from codeflash_python.testing._instrument_async import ASYNC_HELPER_FILENAME
from codeflash_python.testing._instrument_sync import (
SyncCallInstrumenter,
SyncDecoratorAdder,
add_sync_decorator_to_function,
get_sync_decorator_name_for_mode,
inject_sync_profiling_into_existing_test,
)
@pytest.fixture(name="temp_dir")
def _temp_dir():
    """Yield a fresh temporary directory as a ``Path``.

    The directory is removed when the requesting test finishes.
    """
    with tempfile.TemporaryDirectory() as tmp_name:
        yield Path(tmp_name)
class TestGetSyncDecoratorNameForMode:
    """get_sync_decorator_name_for_mode returns the correct name."""

    def test_behavior_mode(self) -> None:
        """BEHAVIOR mode maps to ``codeflash_behavior_sync``."""
        decorator_name = get_sync_decorator_name_for_mode(TestingMode.BEHAVIOR)
        assert decorator_name == "codeflash_behavior_sync"
class TestSyncDecoratorAdder:
    """SyncDecoratorAdder adds decorators to sync functions."""

    @staticmethod
    def _target(name: str, parents=()) -> FunctionToOptimize:
        """Build a sync ``FunctionToOptimize`` named *name* under *parents*."""
        return FunctionToOptimize(
            function_name=name,
            file_path=Path("test.py"),
            parents=list(parents),
            is_async=False,
        )

    @staticmethod
    def _apply(source: str, func: FunctionToOptimize):
        """Run the transformer over *source*; return ``(adder, new_code)``."""
        import libcst as cst

        adder = SyncDecoratorAdder(func, TestingMode.BEHAVIOR)
        module = cst.parse_module(source).visit(adder)
        return adder, module.code

    def test_adds_decorator_to_sync_function(self) -> None:
        """Adds the behavior decorator to a sync function."""
        source = "def my_func(x: int) -> int:\n    return x + 1\n"

        adder, code = self._apply(source, self._target("my_func"))

        assert adder.added_decorator
        assert "@codeflash_behavior_sync" in code

    def test_skips_async_function(self) -> None:
        """Does not add a sync decorator to an async function."""
        source = "async def my_func(x: int) -> int:\n    return x + 1\n"

        adder, code = self._apply(source, self._target("my_func"))

        assert not adder.added_decorator
        assert "@codeflash_behavior_sync" not in code

    def test_adds_decorator_to_class_method(self) -> None:
        """Adds decorator to a sync method inside a class."""
        source = (
            "class MyClass:\n"
            "    def my_method(self, x: int) -> int:\n"
            "        return x + 1\n"
        )
        func = self._target(
            "my_method",
            parents=[FunctionParent(name="MyClass", type="ClassDef")],
        )

        adder, code = self._apply(source, func)

        assert adder.added_decorator
        assert "@codeflash_behavior_sync" in code

    def test_no_duplicate_decorator(self) -> None:
        """Does not add decorator if one already exists."""
        source = (
            "@codeflash_behavior_sync\n"
            "def my_func(x: int) -> int:\n"
            "    return x + 1\n"
        )

        adder, code = self._apply(source, self._target("my_func"))

        assert not adder.added_decorator
        # The pre-existing decorator must remain the only one.
        assert code.count("@codeflash_behavior_sync") == 1
class TestSyncCallInstrumenter:
    """SyncCallInstrumenter injects _codeflash_call_site.set() calls."""

    @staticmethod
    def _instrument(test_code: str, positions: list):
        """Instrument *test_code* for ``my_func``.

        Returns ``(instrumenter, unparsed_source)`` so tests can check both
        the ``did_instrument`` flag and the generated code.
        """
        func = FunctionToOptimize(
            function_name="my_func",
            file_path=Path("mod.py"),
            parents=[],
            is_async=False,
        )
        instrumenter = SyncCallInstrumenter(func, positions)
        tree = instrumenter.visit(ast.parse(test_code))
        return instrumenter, ast.unparse(tree)

    def test_instruments_direct_call(self) -> None:
        """Injects call-site set before a direct function call."""
        test_code = (
            "def test_example():\n"
            "    result = my_func(1)\n"
            "    assert result == 2\n"
        )

        instrumenter, source = self._instrument(
            test_code, [CodePosition(2, 13)]
        )

        assert instrumenter.did_instrument
        assert "_codeflash_call_site.set('0')" in source

    def test_instruments_method_call(self) -> None:
        """Injects call-site set before obj.method() style calls."""
        test_code = (
            "def test_example():\n"
            "    result = obj.my_func(1)\n"
            "    assert result == 2\n"
        )

        instrumenter, source = self._instrument(
            test_code, [CodePosition(2, 13)]
        )

        assert instrumenter.did_instrument
        assert "_codeflash_call_site.set('0')" in source

    def test_multiple_calls_get_incremented_indices(self) -> None:
        """Multiple calls in the same test get sequential indices."""
        test_code = (
            "def test_example():\n"
            "    a = my_func(1)\n"
            "    b = my_func(2)\n"
            "    c = my_func(3)\n"
        )
        positions = [
            CodePosition(2, 8),
            CodePosition(3, 8),
            CodePosition(4, 8),
        ]

        instrumenter, source = self._instrument(test_code, positions)

        assert instrumenter.did_instrument
        markers = [f"_codeflash_call_site.set('{i}')" for i in range(3)]
        assert all(marker in source for marker in markers)
        # Indices must follow the order of the calls in the test body.
        offsets = [source.index(marker) for marker in markers]
        assert offsets == sorted(offsets)

    def test_skips_non_test_functions(self) -> None:
        """Does not instrument functions that don't start with test_."""
        test_code = (
            "def helper():\n"
            "    return my_func(1)\n"
        )

        instrumenter, source = self._instrument(
            test_code, [CodePosition(2, 11)]
        )

        assert not instrumenter.did_instrument
        assert "_codeflash_call_site" not in source

    def test_instruments_inside_class(self) -> None:
        """Instruments test methods inside a test class."""
        test_code = (
            "class TestFoo:\n"
            "    def test_bar(self):\n"
            "        result = my_func(1)\n"
        )

        instrumenter, source = self._instrument(
            test_code, [CodePosition(3, 17)]
        )

        assert instrumenter.did_instrument
        assert "_codeflash_call_site.set('0')" in source

    def test_no_match_when_position_wrong(self) -> None:
        """Does not instrument if code position doesn't match."""
        test_code = (
            "def test_example():\n"
            "    result = my_func(1)\n"
        )

        instrumenter, source = self._instrument(
            test_code, [CodePosition(99, 99)]
        )

        assert not instrumenter.did_instrument
        assert "_codeflash_call_site" not in source
class TestAddSyncDecoratorToFunction:
    """add_sync_decorator_to_function decorates sync source files."""

    @staticmethod
    def _write_source(directory: Path, filename: str, code: str) -> Path:
        """Write *code* to ``directory / filename`` and return the path."""
        source_file = directory / filename
        source_file.write_text(code, encoding="utf8")
        return source_file

    @staticmethod
    def _target(
        name: str,
        source_file: Path,
        parents=(),
        is_async: bool = False,
    ) -> FunctionToOptimize:
        """Build the ``FunctionToOptimize`` describing *name* in *source_file*."""
        return FunctionToOptimize(
            function_name=name,
            file_path=source_file,
            parents=list(parents),
            is_async=is_async,
        )

    def test_decorates_sync_function(self, temp_dir) -> None:
        """Adds decorator and import to a sync function source file."""
        source_code = "def my_func(x: int) -> int:\n    return x + 1\n"
        source_file = self._write_source(temp_dir, "my_module.py", source_code)
        func = self._target("my_func", source_file)

        success, originals = add_sync_decorator_to_function(
            source_file, func, TestingMode.BEHAVIOR
        )

        assert success
        assert source_file in originals
        assert originals[source_file] == source_code
        modified = source_file.read_text(encoding="utf8")
        assert "@codeflash_behavior_sync" in modified
        assert "from codeflash_async_wrapper import" in modified
        assert (temp_dir / ASYNC_HELPER_FILENAME).exists()

    def test_skips_async_function(self, temp_dir) -> None:
        """Returns False for async functions and leaves the file untouched."""
        source_code = "async def my_func(x: int) -> int:\n    return x + 1\n"
        source_file = self._write_source(temp_dir, "my_module.py", source_code)
        func = self._target("my_func", source_file, is_async=True)

        success, originals = add_sync_decorator_to_function(
            source_file, func, TestingMode.BEHAVIOR
        )

        assert not success
        assert originals == {}
        assert source_file.read_text(encoding="utf8") == source_code

    def test_decorates_class_method(self, temp_dir) -> None:
        """Adds decorator to a sync method inside a class."""
        source_code = (
            "class Calc:\n"
            "    def add(self, a: int, b: int) -> int:\n"
            "        return a + b\n"
        )
        source_file = self._write_source(temp_dir, "calc.py", source_code)
        func = self._target(
            "add",
            source_file,
            parents=[FunctionParent(name="Calc", type="ClassDef")],
        )

        success, _ = add_sync_decorator_to_function(
            source_file, func, TestingMode.BEHAVIOR
        )

        assert success
        assert "@codeflash_behavior_sync" in source_file.read_text(
            encoding="utf8"
        )

    def test_uses_project_root_for_helper(self, temp_dir) -> None:
        """Copies helper file to project_root when specified."""
        subdir = temp_dir / "subdir"
        subdir.mkdir()
        source_code = "def my_func(x: int) -> int:\n    return x + 1\n"
        source_file = self._write_source(subdir, "my_module.py", source_code)
        func = self._target("my_func", source_file)

        success, _ = add_sync_decorator_to_function(
            source_file, func, TestingMode.BEHAVIOR, project_root=temp_dir
        )

        assert success
        assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
        assert not (subdir / ASYNC_HELPER_FILENAME).exists()

    def test_preserves_existing_decorators(self, temp_dir) -> None:
        """Adds codeflash decorator above existing decorators."""
        source_code = (
            "@staticmethod\n"
            "def my_func(x: int) -> int:\n"
            "    return x + 1\n"
        )
        source_file = self._write_source(temp_dir, "my_module.py", source_code)
        func = self._target("my_func", source_file)

        success, _ = add_sync_decorator_to_function(
            source_file, func, TestingMode.BEHAVIOR
        )

        assert success
        modified = source_file.read_text(encoding="utf8")
        cf_pos = modified.find("@codeflash_behavior_sync")
        sm_pos = modified.find("@staticmethod")
        # ``find`` returns -1 when the text is missing; require both
        # decorators to be present so the ordering check cannot pass
        # vacuously.
        assert 0 <= cf_pos < sm_pos

    def test_no_duplicate_decorator(self, temp_dir) -> None:
        """Does not add decorator if already present."""
        source_code = (
            "@codeflash_behavior_sync\n"
            "def my_func(x: int) -> int:\n"
            "    return x + 1\n"
        )
        source_file = self._write_source(temp_dir, "my_module.py", source_code)
        func = self._target("my_func", source_file)

        success, originals = add_sync_decorator_to_function(
            source_file, func, TestingMode.BEHAVIOR
        )

        assert not success
        assert originals == {}
        assert source_file.read_text(encoding="utf8") == source_code
class TestInjectSyncProfilingIntoExistingTest:
    """inject_sync_profiling_into_existing_test instruments test files."""

    # Source of the module under test, shared by the tests that need it.
    _SOURCE_CODE = "def my_func(x: int) -> int:\n    return x + 1\n"

    @staticmethod
    def _target(file_name: str) -> FunctionToOptimize:
        """Build the sync ``my_func`` target located in *file_name*."""
        return FunctionToOptimize(
            function_name="my_func",
            file_path=Path(file_name),
            parents=[],
            is_async=False,
        )

    @staticmethod
    def _write(directory: Path, filename: str, code: str) -> Path:
        """Write *code* to ``directory / filename`` and return the path."""
        path = directory / filename
        path.write_text(code, encoding="utf8")
        return path

    def test_injects_call_site_tracking(self, temp_dir) -> None:
        """Injects _codeflash_call_site.set() and import into a test file."""
        self._write(temp_dir, "my_module.py", self._SOURCE_CODE)
        test_code = (
            "from my_module import my_func\n\n"
            "def test_my_func():\n"
            "    result = my_func(1)\n"
            "    assert result == 2\n"
        )
        test_file = self._write(temp_dir, "test_my_module.py", test_code)

        success, instrumented = inject_sync_profiling_into_existing_test(
            test_file,
            [CodePosition(4, 13)],
            self._target("my_module.py"),
        )

        assert success
        assert instrumented is not None
        assert "_codeflash_call_site.set('0')" in instrumented
        assert (
            "from codeflash_async_wrapper import _codeflash_call_site"
            in instrumented
        )

    def test_multiple_calls_numbered_sequentially(self, temp_dir) -> None:
        """Multiple calls get sequential call-site indices."""
        self._write(temp_dir, "my_module.py", self._SOURCE_CODE)
        test_code = (
            "from my_module import my_func\n\n"
            "def test_my_func():\n"
            "    a = my_func(1)\n"
            "    b = my_func(2)\n"
            "    c = my_func(3)\n"
        )
        test_file = self._write(temp_dir, "test_my_module.py", test_code)

        success, instrumented = inject_sync_profiling_into_existing_test(
            test_file,
            [
                CodePosition(4, 8),
                CodePosition(5, 8),
                CodePosition(6, 8),
            ],
            self._target("my_module.py"),
        )

        assert success
        assert instrumented is not None
        for index in range(3):
            assert f"_codeflash_call_site.set('{index}')" in instrumented

    def test_returns_false_for_syntax_error(self, temp_dir) -> None:
        """Returns (False, None) when the test file has a syntax error."""
        test_file = self._write(temp_dir, "test_bad.py", "def test_foo(\n")

        success, result = inject_sync_profiling_into_existing_test(
            test_file, [CodePosition(1, 0)], self._target("mod.py")
        )

        assert not success
        assert result is None

    def test_returns_false_when_no_target_calls(self, temp_dir) -> None:
        """Returns (False, None) when no target function calls are found."""
        test_code = (
            "def test_unrelated():\n"
            "    assert 1 == 1\n"
        )
        test_file = self._write(temp_dir, "test_noop.py", test_code)

        success, result = inject_sync_profiling_into_existing_test(
            test_file, [CodePosition(2, 0)], self._target("mod.py")
        )

        assert not success
        assert result is None