mirror of
https://github.com/codeflash-ai/codeflash-agent.git
synced 2026-05-04 18:25:19 +00:00
feat: add sync instrumentation module with decorator-based approach
New _instrument_sync.py mirrors the async instrumentation pattern: - SyncCallInstrumenter injects _codeflash_call_site.set() before sync calls - SyncDecoratorAdder applies @codeflash_behavior_sync via libcst - add_sync_decorator_to_function() decorates source files - inject_sync_profiling_into_existing_test() instruments test files Reuses the same helper file (codeflash_async_wrapper.py) since both sync and async decorators live in _codeflash_async_decorators.py.
This commit is contained in:
parent
8c218038e9
commit
918a2a10a4
2 changed files with 856 additions and 0 deletions
|
|
@ -0,0 +1,328 @@
|
|||
"""Sync-specific instrumentation: AST transformers, decorators, and helpers.
|
||||
|
||||
Provides ``SyncCallInstrumenter`` for injecting ``_codeflash_call_site``
|
||||
contextvar assignments before sync function calls, ``SyncDecoratorAdder``
|
||||
for adding sync behavior decorators via libcst, and high-level functions
|
||||
for instrumenting sync test and source files.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import libcst as cst
|
||||
|
||||
from .._model import (
|
||||
FunctionToOptimize,
|
||||
TestingMode,
|
||||
)
|
||||
from ..analysis._formatter import sort_imports
|
||||
from ._instrument_async import ASYNC_HELPER_FILENAME, write_async_helper_file
|
||||
from ._instrument_core import (
|
||||
FunctionImportedAsVisitor,
|
||||
node_in_call_position,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from ..test_discovery.models import CodePosition
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_CODEFLASH_SYNC_DECORATORS = frozenset(
|
||||
{
|
||||
"codeflash_behavior_sync",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class SyncCallInstrumenter(ast.NodeTransformer):
|
||||
"""AST transformer that injects call-site tracking before sync calls."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
function: FunctionToOptimize,
|
||||
call_positions: list[CodePosition],
|
||||
) -> None:
|
||||
"""Initialize with the target sync function and call positions."""
|
||||
self.function_object = function
|
||||
self.call_positions = call_positions
|
||||
self.did_instrument = False
|
||||
self.call_counter: dict[str, int] = {}
|
||||
|
||||
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
|
||||
"""Recurse into class bodies to find test methods."""
|
||||
return self.generic_visit(node) # type: ignore[return-value]
|
||||
|
||||
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
|
||||
"""Instrument sync test functions that call the target function."""
|
||||
if not node.name.startswith("test_"):
|
||||
return node
|
||||
return self._process_test_function(node) # type: ignore[return-value]
|
||||
|
||||
def visit_AsyncFunctionDef(
|
||||
self, node: ast.AsyncFunctionDef
|
||||
) -> ast.AsyncFunctionDef:
|
||||
"""Instrument async test functions that call the target sync function."""
|
||||
if not node.name.startswith("test_"):
|
||||
return node
|
||||
return self._process_test_function(node) # type: ignore[return-value]
|
||||
|
||||
def _process_test_function(
|
||||
self,
|
||||
node: ast.FunctionDef | ast.AsyncFunctionDef,
|
||||
) -> ast.FunctionDef | ast.AsyncFunctionDef:
|
||||
"""Add _codeflash_call_site.set() calls before target function calls."""
|
||||
if node.name not in self.call_counter:
|
||||
self.call_counter[node.name] = 0
|
||||
|
||||
new_body: list[ast.stmt] = []
|
||||
|
||||
for stmt in node.body:
|
||||
_, has_target = self._find_target_call(stmt)
|
||||
|
||||
if has_target:
|
||||
current_call_index = self.call_counter[node.name]
|
||||
self.call_counter[node.name] += 1
|
||||
|
||||
call_site_set = ast.Expr(
|
||||
value=ast.Call(
|
||||
func=ast.Attribute(
|
||||
value=ast.Name(
|
||||
id="_codeflash_call_site",
|
||||
ctx=ast.Load(),
|
||||
),
|
||||
attr="set",
|
||||
ctx=ast.Load(),
|
||||
),
|
||||
args=[
|
||||
ast.Constant(
|
||||
value=f"{current_call_index}",
|
||||
),
|
||||
],
|
||||
keywords=[],
|
||||
),
|
||||
lineno=stmt.lineno,
|
||||
)
|
||||
new_body.append(call_site_set)
|
||||
self.did_instrument = True
|
||||
|
||||
new_body.append(stmt)
|
||||
|
||||
node.body = new_body
|
||||
return node
|
||||
|
||||
def _is_target_call(self, call_node: ast.Call) -> bool:
|
||||
"""Check if this call node is calling our target sync function."""
|
||||
if isinstance(call_node.func, ast.Name):
|
||||
return call_node.func.id == self.function_object.function_name
|
||||
if isinstance(call_node.func, ast.Attribute):
|
||||
return call_node.func.attr == self.function_object.function_name
|
||||
return False
|
||||
|
||||
def _find_target_call(
|
||||
self,
|
||||
stmt: ast.stmt,
|
||||
) -> tuple[ast.stmt, bool]:
|
||||
"""Search a statement for direct calls to the target function."""
|
||||
stack: list[ast.AST] = [stmt]
|
||||
while stack:
|
||||
node = stack.pop()
|
||||
if (
|
||||
isinstance(node, ast.Call)
|
||||
and self._is_target_call(node)
|
||||
and node_in_call_position(node, self.call_positions)
|
||||
):
|
||||
return stmt, True
|
||||
for fname in node._fields:
|
||||
child = getattr(node, fname, None)
|
||||
if isinstance(child, list):
|
||||
stack.extend(child)
|
||||
elif isinstance(child, ast.AST):
|
||||
stack.append(child)
|
||||
return stmt, False
|
||||
|
||||
|
||||
class SyncDecoratorAdder(cst.CSTTransformer):
|
||||
"""Transformer that adds a sync decorator to sync function definitions."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
function: FunctionToOptimize,
|
||||
mode: TestingMode = TestingMode.BEHAVIOR,
|
||||
) -> None:
|
||||
"""Initialize the transformer."""
|
||||
super().__init__()
|
||||
self.qualified_name_parts = function.qualified_name.split(".")
|
||||
self.context_stack: list[str] = []
|
||||
self.added_decorator = False
|
||||
self.decorator_name = get_sync_decorator_name_for_mode(mode)
|
||||
|
||||
def visit_ClassDef( # noqa: N802
|
||||
self,
|
||||
node: cst.ClassDef,
|
||||
) -> None:
|
||||
"""Push class name onto the context stack."""
|
||||
self.context_stack.append(node.name.value)
|
||||
|
||||
def leave_ClassDef( # noqa: N802
|
||||
self,
|
||||
original_node: cst.ClassDef,
|
||||
updated_node: cst.ClassDef,
|
||||
) -> cst.ClassDef:
|
||||
"""Pop class name from the context stack."""
|
||||
self.context_stack.pop()
|
||||
return updated_node
|
||||
|
||||
def visit_FunctionDef( # noqa: N802
|
||||
self,
|
||||
node: cst.FunctionDef,
|
||||
) -> None:
|
||||
"""Push function name onto the context stack."""
|
||||
self.context_stack.append(node.name.value)
|
||||
|
||||
def leave_FunctionDef( # noqa: N802
|
||||
self,
|
||||
original_node: cst.FunctionDef,
|
||||
updated_node: cst.FunctionDef,
|
||||
) -> cst.FunctionDef:
|
||||
"""Add the sync decorator if the function matches the target."""
|
||||
if (
|
||||
original_node.asynchronous is None
|
||||
and self.context_stack == self.qualified_name_parts
|
||||
and not any(
|
||||
self._is_codeflash_decorator(d.decorator)
|
||||
for d in original_node.decorators
|
||||
)
|
||||
):
|
||||
new_decorator = cst.Decorator(
|
||||
decorator=cst.Name(value=self.decorator_name),
|
||||
)
|
||||
updated_node = updated_node.with_changes(
|
||||
decorators=(new_decorator, *updated_node.decorators),
|
||||
)
|
||||
self.added_decorator = True
|
||||
|
||||
self.context_stack.pop()
|
||||
return updated_node
|
||||
|
||||
@staticmethod
|
||||
def _is_codeflash_decorator(
|
||||
decorator_node: cst.BaseExpression,
|
||||
) -> bool:
|
||||
"""Check if a decorator is one of the codeflash sync decorators."""
|
||||
if isinstance(decorator_node, cst.Name):
|
||||
return decorator_node.value in _CODEFLASH_SYNC_DECORATORS
|
||||
if isinstance(decorator_node, cst.Call) and isinstance(
|
||||
decorator_node.func,
|
||||
cst.Name,
|
||||
):
|
||||
return decorator_node.func.value in _CODEFLASH_SYNC_DECORATORS
|
||||
return False
|
||||
|
||||
|
||||
def get_sync_decorator_name_for_mode(
|
||||
mode: TestingMode,
|
||||
) -> str:
|
||||
"""Return the sync decorator function name for the given testing mode."""
|
||||
if mode == TestingMode.BEHAVIOR:
|
||||
return "codeflash_behavior_sync"
|
||||
return "codeflash_behavior_sync"
|
||||
|
||||
|
||||
def inject_sync_profiling_into_existing_test(
|
||||
test_path: Path,
|
||||
call_positions: list[CodePosition],
|
||||
function_to_optimize: FunctionToOptimize,
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Inject call-site tracking for sync function calls in a test file."""
|
||||
with test_path.open(encoding="utf8") as f:
|
||||
test_code = f.read()
|
||||
|
||||
try:
|
||||
tree = ast.parse(test_code)
|
||||
except SyntaxError:
|
||||
log.exception(
|
||||
"Syntax error in code in file - %s",
|
||||
test_path,
|
||||
)
|
||||
return False, None
|
||||
|
||||
import_visitor = FunctionImportedAsVisitor(function_to_optimize)
|
||||
import_visitor.visit(tree)
|
||||
func = import_visitor.imported_as
|
||||
|
||||
sync_instrumenter = SyncCallInstrumenter(func, call_positions)
|
||||
tree = sync_instrumenter.visit(tree)
|
||||
|
||||
if not sync_instrumenter.did_instrument:
|
||||
return False, None
|
||||
|
||||
new_imports = [
|
||||
ast.ImportFrom(
|
||||
module=ASYNC_HELPER_FILENAME.removesuffix(".py"),
|
||||
names=[ast.alias(name="_codeflash_call_site")],
|
||||
level=0,
|
||||
),
|
||||
]
|
||||
tree.body = [*new_imports, *tree.body]
|
||||
return True, sort_imports(ast.unparse(tree), float_to_top=True)
|
||||
|
||||
|
||||
def add_sync_decorator_to_function(
|
||||
source_path: Path,
|
||||
function: FunctionToOptimize,
|
||||
mode: TestingMode = TestingMode.BEHAVIOR,
|
||||
project_root: Path | None = None,
|
||||
) -> tuple[bool, dict[Path, str]]:
|
||||
"""Add a sync instrumentation decorator to *function*.
|
||||
|
||||
Writes the helper file and adds the appropriate import and decorator.
|
||||
Returns ``(True, originals)`` if the decorator was added, where
|
||||
*originals* maps each modified file to its content before modification.
|
||||
"""
|
||||
if function.is_async:
|
||||
return False, {}
|
||||
|
||||
try:
|
||||
with source_path.open(encoding="utf8") as f:
|
||||
source_code = f.read()
|
||||
|
||||
module = cst.parse_module(source_code)
|
||||
decorator_transformer = SyncDecoratorAdder(function, mode)
|
||||
module = module.visit(decorator_transformer)
|
||||
|
||||
if decorator_transformer.added_decorator:
|
||||
helper_dir = (
|
||||
project_root
|
||||
if project_root is not None
|
||||
else source_path.parent
|
||||
)
|
||||
write_async_helper_file(helper_dir)
|
||||
decorator_name = get_sync_decorator_name_for_mode(mode)
|
||||
import_stmt = ASYNC_HELPER_FILENAME.removesuffix(".py")
|
||||
import_node = cst.parse_statement(
|
||||
f"from {import_stmt} import {decorator_name}"
|
||||
)
|
||||
module = module.with_changes(
|
||||
body=[import_node, *list(module.body)]
|
||||
)
|
||||
|
||||
modified_code = sort_imports(code=module.code, float_to_top=True)
|
||||
except Exception:
|
||||
log.exception(
|
||||
"Error adding sync decorator to function %s",
|
||||
function.qualified_name,
|
||||
)
|
||||
return False, {}
|
||||
else:
|
||||
if decorator_transformer.added_decorator:
|
||||
originals: dict[Path, str] = {source_path: source_code}
|
||||
with source_path.open("w", encoding="utf8") as f:
|
||||
f.write(modified_code)
|
||||
return True, originals
|
||||
return False, {}
|
||||
528
packages/codeflash-python/tests/test_instrument_sync_tests.py
Normal file
528
packages/codeflash-python/tests/test_instrument_sync_tests.py
Normal file
|
|
@ -0,0 +1,528 @@
|
|||
"""Tests for the sync-specific instrumentation module."""
|
||||
|
||||
import ast
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash_python._model import FunctionParent, TestingMode
|
||||
from codeflash_python.analysis._discovery import FunctionToOptimize
|
||||
from codeflash_python.test_discovery.models import CodePosition
|
||||
from codeflash_python.testing._instrument_async import ASYNC_HELPER_FILENAME
|
||||
from codeflash_python.testing._instrument_sync import (
|
||||
SyncCallInstrumenter,
|
||||
SyncDecoratorAdder,
|
||||
add_sync_decorator_to_function,
|
||||
get_sync_decorator_name_for_mode,
|
||||
inject_sync_profiling_into_existing_test,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="temp_dir")
|
||||
def _temp_dir():
|
||||
"""Create a temporary directory for test files."""
|
||||
with tempfile.TemporaryDirectory() as temp:
|
||||
yield Path(temp)
|
||||
|
||||
|
||||
class TestGetSyncDecoratorNameForMode:
|
||||
"""get_sync_decorator_name_for_mode returns the correct name."""
|
||||
|
||||
def test_behavior_mode(self) -> None:
|
||||
"""Returns codeflash_behavior_sync for BEHAVIOR."""
|
||||
assert "codeflash_behavior_sync" == get_sync_decorator_name_for_mode(
|
||||
TestingMode.BEHAVIOR
|
||||
)
|
||||
|
||||
|
||||
class TestSyncDecoratorAdder:
|
||||
"""SyncDecoratorAdder adds decorators to sync functions."""
|
||||
|
||||
def test_adds_decorator_to_sync_function(self, temp_dir) -> None:
|
||||
"""Adds the behavior decorator to a sync function."""
|
||||
import libcst as cst
|
||||
|
||||
source = "def my_func(x: int) -> int:\n return x + 1\n"
|
||||
func = FunctionToOptimize(
|
||||
function_name="my_func",
|
||||
file_path=Path("test.py"),
|
||||
parents=[],
|
||||
is_async=False,
|
||||
)
|
||||
module = cst.parse_module(source)
|
||||
adder = SyncDecoratorAdder(func, TestingMode.BEHAVIOR)
|
||||
module = module.visit(adder)
|
||||
|
||||
assert adder.added_decorator
|
||||
assert "@codeflash_behavior_sync" in module.code
|
||||
|
||||
def test_skips_async_function(self) -> None:
|
||||
"""Does not add a sync decorator to an async function."""
|
||||
import libcst as cst
|
||||
|
||||
source = "async def my_func(x: int) -> int:\n return x + 1\n"
|
||||
func = FunctionToOptimize(
|
||||
function_name="my_func",
|
||||
file_path=Path("test.py"),
|
||||
parents=[],
|
||||
is_async=False,
|
||||
)
|
||||
module = cst.parse_module(source)
|
||||
adder = SyncDecoratorAdder(func, TestingMode.BEHAVIOR)
|
||||
module = module.visit(adder)
|
||||
|
||||
assert not adder.added_decorator
|
||||
|
||||
def test_adds_decorator_to_class_method(self) -> None:
|
||||
"""Adds decorator to a sync method inside a class."""
|
||||
import libcst as cst
|
||||
|
||||
source = (
|
||||
"class MyClass:\n"
|
||||
" def my_method(self, x: int) -> int:\n"
|
||||
" return x + 1\n"
|
||||
)
|
||||
func = FunctionToOptimize(
|
||||
function_name="my_method",
|
||||
file_path=Path("test.py"),
|
||||
parents=[FunctionParent(name="MyClass", type="ClassDef")],
|
||||
is_async=False,
|
||||
)
|
||||
module = cst.parse_module(source)
|
||||
adder = SyncDecoratorAdder(func, TestingMode.BEHAVIOR)
|
||||
module = module.visit(adder)
|
||||
|
||||
assert adder.added_decorator
|
||||
assert "@codeflash_behavior_sync" in module.code
|
||||
|
||||
def test_no_duplicate_decorator(self) -> None:
|
||||
"""Does not add decorator if one already exists."""
|
||||
import libcst as cst
|
||||
|
||||
source = (
|
||||
"@codeflash_behavior_sync\n"
|
||||
"def my_func(x: int) -> int:\n"
|
||||
" return x + 1\n"
|
||||
)
|
||||
func = FunctionToOptimize(
|
||||
function_name="my_func",
|
||||
file_path=Path("test.py"),
|
||||
parents=[],
|
||||
is_async=False,
|
||||
)
|
||||
module = cst.parse_module(source)
|
||||
adder = SyncDecoratorAdder(func, TestingMode.BEHAVIOR)
|
||||
module = module.visit(adder)
|
||||
|
||||
assert not adder.added_decorator
|
||||
|
||||
|
||||
class TestSyncCallInstrumenter:
|
||||
"""SyncCallInstrumenter injects _codeflash_call_site.set() calls."""
|
||||
|
||||
def test_instruments_direct_call(self) -> None:
|
||||
"""Injects call-site set before a direct function call."""
|
||||
test_code = (
|
||||
"def test_example():\n"
|
||||
" result = my_func(1)\n"
|
||||
" assert result == 2\n"
|
||||
)
|
||||
tree = ast.parse(test_code)
|
||||
func = FunctionToOptimize(
|
||||
function_name="my_func",
|
||||
file_path=Path("mod.py"),
|
||||
parents=[],
|
||||
is_async=False,
|
||||
)
|
||||
instrumenter = SyncCallInstrumenter(
|
||||
func, [CodePosition(2, 13)]
|
||||
)
|
||||
tree = instrumenter.visit(tree)
|
||||
|
||||
assert instrumenter.did_instrument
|
||||
source = ast.unparse(tree)
|
||||
assert "_codeflash_call_site.set('0')" in source
|
||||
|
||||
def test_instruments_method_call(self) -> None:
|
||||
"""Injects call-site set before obj.method() style calls."""
|
||||
test_code = (
|
||||
"def test_example():\n"
|
||||
" result = obj.my_func(1)\n"
|
||||
" assert result == 2\n"
|
||||
)
|
||||
tree = ast.parse(test_code)
|
||||
func = FunctionToOptimize(
|
||||
function_name="my_func",
|
||||
file_path=Path("mod.py"),
|
||||
parents=[],
|
||||
is_async=False,
|
||||
)
|
||||
instrumenter = SyncCallInstrumenter(
|
||||
func, [CodePosition(2, 13)]
|
||||
)
|
||||
tree = instrumenter.visit(tree)
|
||||
|
||||
assert instrumenter.did_instrument
|
||||
source = ast.unparse(tree)
|
||||
assert "_codeflash_call_site.set('0')" in source
|
||||
|
||||
def test_multiple_calls_get_incremented_indices(self) -> None:
|
||||
"""Multiple calls in the same test get sequential indices."""
|
||||
test_code = (
|
||||
"def test_example():\n"
|
||||
" a = my_func(1)\n"
|
||||
" b = my_func(2)\n"
|
||||
" c = my_func(3)\n"
|
||||
)
|
||||
tree = ast.parse(test_code)
|
||||
func = FunctionToOptimize(
|
||||
function_name="my_func",
|
||||
file_path=Path("mod.py"),
|
||||
parents=[],
|
||||
is_async=False,
|
||||
)
|
||||
instrumenter = SyncCallInstrumenter(
|
||||
func,
|
||||
[
|
||||
CodePosition(2, 8),
|
||||
CodePosition(3, 8),
|
||||
CodePosition(4, 8),
|
||||
],
|
||||
)
|
||||
tree = instrumenter.visit(tree)
|
||||
|
||||
assert instrumenter.did_instrument
|
||||
source = ast.unparse(tree)
|
||||
assert "_codeflash_call_site.set('0')" in source
|
||||
assert "_codeflash_call_site.set('1')" in source
|
||||
assert "_codeflash_call_site.set('2')" in source
|
||||
|
||||
def test_skips_non_test_functions(self) -> None:
|
||||
"""Does not instrument functions that don't start with test_."""
|
||||
test_code = (
|
||||
"def helper():\n"
|
||||
" return my_func(1)\n"
|
||||
)
|
||||
tree = ast.parse(test_code)
|
||||
func = FunctionToOptimize(
|
||||
function_name="my_func",
|
||||
file_path=Path("mod.py"),
|
||||
parents=[],
|
||||
is_async=False,
|
||||
)
|
||||
instrumenter = SyncCallInstrumenter(
|
||||
func, [CodePosition(2, 11)]
|
||||
)
|
||||
tree = instrumenter.visit(tree)
|
||||
|
||||
assert not instrumenter.did_instrument
|
||||
|
||||
def test_instruments_inside_class(self) -> None:
|
||||
"""Instruments test methods inside a test class."""
|
||||
test_code = (
|
||||
"class TestFoo:\n"
|
||||
" def test_bar(self):\n"
|
||||
" result = my_func(1)\n"
|
||||
)
|
||||
tree = ast.parse(test_code)
|
||||
func = FunctionToOptimize(
|
||||
function_name="my_func",
|
||||
file_path=Path("mod.py"),
|
||||
parents=[],
|
||||
is_async=False,
|
||||
)
|
||||
instrumenter = SyncCallInstrumenter(
|
||||
func, [CodePosition(3, 17)]
|
||||
)
|
||||
tree = instrumenter.visit(tree)
|
||||
|
||||
assert instrumenter.did_instrument
|
||||
source = ast.unparse(tree)
|
||||
assert "_codeflash_call_site.set('0')" in source
|
||||
|
||||
def test_no_match_when_position_wrong(self) -> None:
|
||||
"""Does not instrument if code position doesn't match."""
|
||||
test_code = (
|
||||
"def test_example():\n"
|
||||
" result = my_func(1)\n"
|
||||
)
|
||||
tree = ast.parse(test_code)
|
||||
func = FunctionToOptimize(
|
||||
function_name="my_func",
|
||||
file_path=Path("mod.py"),
|
||||
parents=[],
|
||||
is_async=False,
|
||||
)
|
||||
instrumenter = SyncCallInstrumenter(
|
||||
func, [CodePosition(99, 99)]
|
||||
)
|
||||
tree = instrumenter.visit(tree)
|
||||
|
||||
assert not instrumenter.did_instrument
|
||||
|
||||
|
||||
class TestAddSyncDecoratorToFunction:
|
||||
"""add_sync_decorator_to_function decorates sync source files."""
|
||||
|
||||
def test_decorates_sync_function(self, temp_dir) -> None:
|
||||
"""Adds decorator and import to a sync function source file."""
|
||||
source_code = "def my_func(x: int) -> int:\n return x + 1\n"
|
||||
source_file = temp_dir / "my_module.py"
|
||||
source_file.write_text(source_code)
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="my_func",
|
||||
file_path=source_file,
|
||||
parents=[],
|
||||
is_async=False,
|
||||
)
|
||||
|
||||
success, originals = add_sync_decorator_to_function(
|
||||
source_file, func, TestingMode.BEHAVIOR
|
||||
)
|
||||
|
||||
assert success
|
||||
assert source_file in originals
|
||||
assert originals[source_file] == source_code
|
||||
|
||||
modified = source_file.read_text()
|
||||
assert "@codeflash_behavior_sync" in modified
|
||||
assert "from codeflash_async_wrapper import" in modified
|
||||
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
|
||||
|
||||
def test_skips_async_function(self, temp_dir) -> None:
|
||||
"""Returns False for async functions."""
|
||||
source_code = "async def my_func(x: int) -> int:\n return x + 1\n"
|
||||
source_file = temp_dir / "my_module.py"
|
||||
source_file.write_text(source_code)
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="my_func",
|
||||
file_path=source_file,
|
||||
parents=[],
|
||||
is_async=True,
|
||||
)
|
||||
|
||||
success, originals = add_sync_decorator_to_function(
|
||||
source_file, func, TestingMode.BEHAVIOR
|
||||
)
|
||||
|
||||
assert not success
|
||||
assert {} == originals
|
||||
|
||||
def test_decorates_class_method(self, temp_dir) -> None:
|
||||
"""Adds decorator to a sync method inside a class."""
|
||||
source_code = (
|
||||
"class Calc:\n"
|
||||
" def add(self, a: int, b: int) -> int:\n"
|
||||
" return a + b\n"
|
||||
)
|
||||
source_file = temp_dir / "calc.py"
|
||||
source_file.write_text(source_code)
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="add",
|
||||
file_path=source_file,
|
||||
parents=[FunctionParent(name="Calc", type="ClassDef")],
|
||||
is_async=False,
|
||||
)
|
||||
|
||||
success, _ = add_sync_decorator_to_function(
|
||||
source_file, func, TestingMode.BEHAVIOR
|
||||
)
|
||||
|
||||
assert success
|
||||
modified = source_file.read_text()
|
||||
assert "@codeflash_behavior_sync" in modified
|
||||
|
||||
def test_uses_project_root_for_helper(self, temp_dir) -> None:
|
||||
"""Copies helper file to project_root when specified."""
|
||||
subdir = temp_dir / "subdir"
|
||||
subdir.mkdir()
|
||||
source_code = "def my_func(x: int) -> int:\n return x + 1\n"
|
||||
source_file = subdir / "my_module.py"
|
||||
source_file.write_text(source_code)
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="my_func",
|
||||
file_path=source_file,
|
||||
parents=[],
|
||||
is_async=False,
|
||||
)
|
||||
|
||||
success, _ = add_sync_decorator_to_function(
|
||||
source_file, func, TestingMode.BEHAVIOR, project_root=temp_dir
|
||||
)
|
||||
|
||||
assert success
|
||||
assert (temp_dir / ASYNC_HELPER_FILENAME).exists()
|
||||
assert not (subdir / ASYNC_HELPER_FILENAME).exists()
|
||||
|
||||
def test_preserves_existing_decorators(self, temp_dir) -> None:
|
||||
"""Adds codeflash decorator above existing decorators."""
|
||||
source_code = (
|
||||
"@staticmethod\n"
|
||||
"def my_func(x: int) -> int:\n"
|
||||
" return x + 1\n"
|
||||
)
|
||||
source_file = temp_dir / "my_module.py"
|
||||
source_file.write_text(source_code)
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="my_func",
|
||||
file_path=source_file,
|
||||
parents=[],
|
||||
is_async=False,
|
||||
)
|
||||
|
||||
success, _ = add_sync_decorator_to_function(
|
||||
source_file, func, TestingMode.BEHAVIOR
|
||||
)
|
||||
|
||||
assert success
|
||||
modified = source_file.read_text()
|
||||
cf_pos = modified.find("@codeflash_behavior_sync")
|
||||
sm_pos = modified.find("@staticmethod")
|
||||
assert cf_pos < sm_pos
|
||||
|
||||
def test_no_duplicate_decorator(self, temp_dir) -> None:
|
||||
"""Does not add decorator if already present."""
|
||||
source_code = (
|
||||
"@codeflash_behavior_sync\n"
|
||||
"def my_func(x: int) -> int:\n"
|
||||
" return x + 1\n"
|
||||
)
|
||||
source_file = temp_dir / "my_module.py"
|
||||
source_file.write_text(source_code)
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="my_func",
|
||||
file_path=source_file,
|
||||
parents=[],
|
||||
is_async=False,
|
||||
)
|
||||
|
||||
success, _ = add_sync_decorator_to_function(
|
||||
source_file, func, TestingMode.BEHAVIOR
|
||||
)
|
||||
|
||||
assert not success
|
||||
|
||||
|
||||
class TestInjectSyncProfilingIntoExistingTest:
|
||||
"""inject_sync_profiling_into_existing_test instruments test files."""
|
||||
|
||||
def test_injects_call_site_tracking(self, temp_dir) -> None:
|
||||
"""Injects _codeflash_call_site.set() and import into a test file."""
|
||||
source_code = "def my_func(x: int) -> int:\n return x + 1\n"
|
||||
source_file = temp_dir / "my_module.py"
|
||||
source_file.write_text(source_code)
|
||||
|
||||
test_code = (
|
||||
"from my_module import my_func\n\n"
|
||||
"def test_my_func():\n"
|
||||
" result = my_func(1)\n"
|
||||
" assert result == 2\n"
|
||||
)
|
||||
test_file = temp_dir / "test_my_module.py"
|
||||
test_file.write_text(test_code)
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="my_func",
|
||||
file_path=Path("my_module.py"),
|
||||
parents=[],
|
||||
is_async=False,
|
||||
)
|
||||
|
||||
success, instrumented = inject_sync_profiling_into_existing_test(
|
||||
test_file,
|
||||
[CodePosition(4, 13)],
|
||||
func,
|
||||
)
|
||||
|
||||
assert success
|
||||
assert instrumented is not None
|
||||
assert "_codeflash_call_site.set('0')" in instrumented
|
||||
assert "from codeflash_async_wrapper import _codeflash_call_site" in instrumented
|
||||
|
||||
def test_multiple_calls_numbered_sequentially(self, temp_dir) -> None:
|
||||
"""Multiple calls get sequential call-site indices."""
|
||||
source_code = "def my_func(x: int) -> int:\n return x + 1\n"
|
||||
source_file = temp_dir / "my_module.py"
|
||||
source_file.write_text(source_code)
|
||||
|
||||
test_code = (
|
||||
"from my_module import my_func\n\n"
|
||||
"def test_my_func():\n"
|
||||
" a = my_func(1)\n"
|
||||
" b = my_func(2)\n"
|
||||
" c = my_func(3)\n"
|
||||
)
|
||||
test_file = temp_dir / "test_my_module.py"
|
||||
test_file.write_text(test_code)
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="my_func",
|
||||
file_path=Path("my_module.py"),
|
||||
parents=[],
|
||||
is_async=False,
|
||||
)
|
||||
|
||||
success, instrumented = inject_sync_profiling_into_existing_test(
|
||||
test_file,
|
||||
[
|
||||
CodePosition(4, 8),
|
||||
CodePosition(5, 8),
|
||||
CodePosition(6, 8),
|
||||
],
|
||||
func,
|
||||
)
|
||||
|
||||
assert success
|
||||
assert instrumented is not None
|
||||
assert "_codeflash_call_site.set('0')" in instrumented
|
||||
assert "_codeflash_call_site.set('1')" in instrumented
|
||||
assert "_codeflash_call_site.set('2')" in instrumented
|
||||
|
||||
def test_returns_false_for_syntax_error(self, temp_dir) -> None:
|
||||
"""Returns (False, None) when the test file has a syntax error."""
|
||||
test_file = temp_dir / "test_bad.py"
|
||||
test_file.write_text("def test_foo(\n")
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="my_func",
|
||||
file_path=Path("mod.py"),
|
||||
parents=[],
|
||||
is_async=False,
|
||||
)
|
||||
|
||||
success, result = inject_sync_profiling_into_existing_test(
|
||||
test_file, [CodePosition(1, 0)], func
|
||||
)
|
||||
|
||||
assert not success
|
||||
assert result is None
|
||||
|
||||
def test_returns_false_when_no_target_calls(self, temp_dir) -> None:
|
||||
"""Returns (False, None) when no target function calls are found."""
|
||||
test_code = (
|
||||
"def test_unrelated():\n"
|
||||
" assert 1 == 1\n"
|
||||
)
|
||||
test_file = temp_dir / "test_noop.py"
|
||||
test_file.write_text(test_code)
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="my_func",
|
||||
file_path=Path("mod.py"),
|
||||
parents=[],
|
||||
is_async=False,
|
||||
)
|
||||
|
||||
success, result = inject_sync_profiling_into_existing_test(
|
||||
test_file, [CodePosition(2, 0)], func
|
||||
)
|
||||
|
||||
assert not success
|
||||
assert result is None
|
||||
Loading…
Reference in a new issue