"""Tests for the sync-specific instrumentation module.""" import ast import tempfile from pathlib import Path import pytest from codeflash_python._model import FunctionParent, TestingMode from codeflash_python.analysis._discovery import FunctionToOptimize from codeflash_python.test_discovery.models import CodePosition from codeflash_python.testing._instrument_async import ASYNC_HELPER_FILENAME from codeflash_python.testing._instrument_sync import ( SyncCallInstrumenter, SyncDecoratorAdder, add_sync_decorator_to_function, get_sync_decorator_name_for_mode, inject_sync_profiling_into_existing_test, ) @pytest.fixture(name="temp_dir") def _temp_dir(): """Create a temporary directory for test files.""" with tempfile.TemporaryDirectory() as temp: yield Path(temp) class TestGetSyncDecoratorNameForMode: """get_sync_decorator_name_for_mode returns the correct name.""" def test_behavior_mode(self) -> None: """Returns codeflash_behavior_sync for BEHAVIOR.""" assert "codeflash_behavior_sync" == get_sync_decorator_name_for_mode( TestingMode.BEHAVIOR ) def test_performance_mode(self) -> None: """Returns codeflash_performance_sync for PERFORMANCE.""" assert ( "codeflash_performance_sync" == get_sync_decorator_name_for_mode(TestingMode.PERFORMANCE) ) class TestSyncDecoratorAdder: """SyncDecoratorAdder adds decorators to sync functions.""" def test_adds_decorator_to_sync_function(self, temp_dir) -> None: """Adds the behavior decorator to a sync function.""" import libcst as cst source = "def my_func(x: int) -> int:\n return x + 1\n" func = FunctionToOptimize( function_name="my_func", file_path=Path("test.py"), parents=[], is_async=False, ) module = cst.parse_module(source) adder = SyncDecoratorAdder(func, TestingMode.BEHAVIOR) module = module.visit(adder) assert adder.added_decorator assert "@codeflash_behavior_sync" in module.code def test_skips_async_function(self) -> None: """Does not add a sync decorator to an async function.""" import libcst as cst source = "async def my_func(x: int) -> int:\n return x + 1\n" func = FunctionToOptimize( function_name="my_func", file_path=Path("test.py"), parents=[], is_async=False, ) module = cst.parse_module(source) adder = SyncDecoratorAdder(func, TestingMode.BEHAVIOR) module = module.visit(adder) assert not adder.added_decorator def test_adds_decorator_to_class_method(self) -> None: """Adds decorator to a sync method inside a class.""" import libcst as cst source = ( "class MyClass:\n" " def my_method(self, x: int) -> int:\n" " return x + 1\n" ) func = FunctionToOptimize( function_name="my_method", file_path=Path("test.py"), parents=[FunctionParent(name="MyClass", type="ClassDef")], is_async=False, ) module = cst.parse_module(source) adder = SyncDecoratorAdder(func, TestingMode.BEHAVIOR) module = module.visit(adder) assert adder.added_decorator assert "@codeflash_behavior_sync" in module.code def test_no_duplicate_decorator(self) -> None: """Does not add decorator if one already exists.""" import libcst as cst source = ( "@codeflash_behavior_sync\n" "def my_func(x: int) -> int:\n" " return x + 1\n" ) func = FunctionToOptimize( function_name="my_func", file_path=Path("test.py"), parents=[], is_async=False, ) module = cst.parse_module(source) adder = SyncDecoratorAdder(func, TestingMode.BEHAVIOR) module = module.visit(adder) assert not adder.added_decorator class TestSyncCallInstrumenter: """SyncCallInstrumenter injects _codeflash_call_site.set() calls.""" def test_instruments_direct_call(self) -> None: """Injects call-site set before a direct function call.""" test_code = ( "def test_example():\n" " result = my_func(1)\n" " assert result == 2\n" ) tree = ast.parse(test_code) func = FunctionToOptimize( function_name="my_func", file_path=Path("mod.py"), parents=[], is_async=False, ) instrumenter = SyncCallInstrumenter(func, [CodePosition(2, 13)]) tree = instrumenter.visit(tree) assert instrumenter.did_instrument source = ast.unparse(tree) assert "_codeflash_call_site.set('2')" in source def test_instruments_method_call(self) -> None: """Injects call-site set before obj.method() style calls.""" test_code = ( "def test_example():\n" " result = obj.my_func(1)\n" " assert result == 2\n" ) tree = ast.parse(test_code) func = FunctionToOptimize( function_name="my_func", file_path=Path("mod.py"), parents=[], is_async=False, ) instrumenter = SyncCallInstrumenter(func, [CodePosition(2, 13)]) tree = instrumenter.visit(tree) assert instrumenter.did_instrument source = ast.unparse(tree) assert "_codeflash_call_site.set('2')" in source def test_multiple_calls_get_line_number_indices(self) -> None: """Multiple calls use their source line numbers as call-site IDs.""" test_code = ( "def test_example():\n" " a = my_func(1)\n" " b = my_func(2)\n" " c = my_func(3)\n" ) tree = ast.parse(test_code) func = FunctionToOptimize( function_name="my_func", file_path=Path("mod.py"), parents=[], is_async=False, ) instrumenter = SyncCallInstrumenter( func, [ CodePosition(2, 8), CodePosition(3, 8), CodePosition(4, 8), ], ) tree = instrumenter.visit(tree) assert instrumenter.did_instrument source = ast.unparse(tree) assert "_codeflash_call_site.set('2')" in source assert "_codeflash_call_site.set('3')" in source assert "_codeflash_call_site.set('4')" in source def test_skips_non_test_functions(self) -> None: """Does not instrument functions that don't start with test_.""" test_code = "def helper():\n return my_func(1)\n" tree = ast.parse(test_code) func = FunctionToOptimize( function_name="my_func", file_path=Path("mod.py"), parents=[], is_async=False, ) instrumenter = SyncCallInstrumenter(func, [CodePosition(2, 11)]) tree = instrumenter.visit(tree) assert not instrumenter.did_instrument def test_instruments_inside_class(self) -> None: """Instruments test methods inside a test class.""" test_code = ( "class TestFoo:\n" " def test_bar(self):\n" " result = my_func(1)\n" ) tree = ast.parse(test_code) func = FunctionToOptimize( function_name="my_func", file_path=Path("mod.py"), parents=[], is_async=False, ) instrumenter = SyncCallInstrumenter(func, [CodePosition(3, 17)]) tree = instrumenter.visit(tree) assert instrumenter.did_instrument source = ast.unparse(tree) assert "_codeflash_call_site.set('3')" in source def test_no_match_when_position_wrong(self) -> None: """Does not instrument if code position doesn't match.""" test_code = "def test_example():\n result = my_func(1)\n" tree = ast.parse(test_code) func = FunctionToOptimize( function_name="my_func", file_path=Path("mod.py"), parents=[], is_async=False, ) instrumenter = SyncCallInstrumenter(func, [CodePosition(99, 99)]) tree = instrumenter.visit(tree) assert not instrumenter.did_instrument class TestAddSyncDecoratorToFunction: """add_sync_decorator_to_function decorates sync source files.""" def test_decorates_sync_function(self, temp_dir) -> None: """Adds decorator and import to a sync function source file.""" source_code = "def my_func(x: int) -> int:\n return x + 1\n" source_file = temp_dir / "my_module.py" source_file.write_text(source_code) func = FunctionToOptimize( function_name="my_func", file_path=source_file, parents=[], is_async=False, ) success, originals = add_sync_decorator_to_function( source_file, func, TestingMode.BEHAVIOR ) assert success assert source_file in originals assert originals[source_file] == source_code modified = source_file.read_text() assert "@codeflash_behavior_sync" in modified assert "from codeflash_async_wrapper import" in modified assert (temp_dir / ASYNC_HELPER_FILENAME).exists() def test_skips_async_function(self, temp_dir) -> None: """Returns False for async functions.""" source_code = "async def my_func(x: int) -> int:\n return x + 1\n" source_file = temp_dir / "my_module.py" source_file.write_text(source_code) func = FunctionToOptimize( function_name="my_func", file_path=source_file, parents=[], is_async=True, ) success, originals = add_sync_decorator_to_function( source_file, func, TestingMode.BEHAVIOR ) assert not success assert {} == originals def test_decorates_class_method(self, temp_dir) -> None: """Adds decorator to a sync method inside a class.""" source_code = ( "class Calc:\n" " def add(self, a: int, b: int) -> int:\n" " return a + b\n" ) source_file = temp_dir / "calc.py" source_file.write_text(source_code) func = FunctionToOptimize( function_name="add", file_path=source_file, parents=[FunctionParent(name="Calc", type="ClassDef")], is_async=False, ) success, _ = add_sync_decorator_to_function( source_file, func, TestingMode.BEHAVIOR ) assert success modified = source_file.read_text() assert "@codeflash_behavior_sync" in modified def test_uses_project_root_for_helper(self, temp_dir) -> None: """Copies helper file to project_root when specified.""" subdir = temp_dir / "subdir" subdir.mkdir() source_code = "def my_func(x: int) -> int:\n return x + 1\n" source_file = subdir / "my_module.py" source_file.write_text(source_code) func = FunctionToOptimize( function_name="my_func", file_path=source_file, parents=[], is_async=False, ) success, _ = add_sync_decorator_to_function( source_file, func, TestingMode.BEHAVIOR, project_root=temp_dir ) assert success assert (temp_dir / ASYNC_HELPER_FILENAME).exists() assert not (subdir / ASYNC_HELPER_FILENAME).exists() def test_preserves_existing_decorators(self, temp_dir) -> None: """Adds codeflash decorator below @staticmethod/@classmethod.""" source_code = ( "@staticmethod\ndef my_func(x: int) -> int:\n return x + 1\n" ) source_file = temp_dir / "my_module.py" source_file.write_text(source_code) func = FunctionToOptimize( function_name="my_func", file_path=source_file, parents=[], is_async=False, ) success, _ = add_sync_decorator_to_function( source_file, func, TestingMode.BEHAVIOR ) assert success modified = source_file.read_text() cf_pos = modified.find("@codeflash_behavior_sync") sm_pos = modified.find("@staticmethod") assert sm_pos < cf_pos def test_no_duplicate_decorator(self, temp_dir) -> None: """Does not add decorator if already present.""" source_code = ( "@codeflash_behavior_sync\n" "def my_func(x: int) -> int:\n" " return x + 1\n" ) source_file = temp_dir / "my_module.py" source_file.write_text(source_code) func = FunctionToOptimize( function_name="my_func", file_path=source_file, parents=[], is_async=False, ) success, _ = add_sync_decorator_to_function( source_file, func, TestingMode.BEHAVIOR ) assert not success class TestInjectSyncProfilingIntoExistingTest: """inject_sync_profiling_into_existing_test instruments test files.""" def test_injects_call_site_tracking(self, temp_dir) -> None: """Injects _codeflash_call_site.set() and import into a test file.""" source_code = "def my_func(x: int) -> int:\n return x + 1\n" source_file = temp_dir / "my_module.py" source_file.write_text(source_code) test_code = ( "from my_module import my_func\n\n" "def test_my_func():\n" " result = my_func(1)\n" " assert result == 2\n" ) test_file = temp_dir / "test_my_module.py" test_file.write_text(test_code) func = FunctionToOptimize( function_name="my_func", file_path=Path("my_module.py"), parents=[], is_async=False, ) success, instrumented = inject_sync_profiling_into_existing_test( test_file, [CodePosition(4, 13)], func, ) assert success assert instrumented is not None assert "_codeflash_call_site.set('4')" in instrumented assert ( "from codeflash_async_wrapper import _codeflash_call_site" in instrumented ) def test_multiple_calls_use_line_numbers(self, temp_dir) -> None: """Multiple calls use their source line numbers as call-site IDs.""" source_code = "def my_func(x: int) -> int:\n return x + 1\n" source_file = temp_dir / "my_module.py" source_file.write_text(source_code) test_code = ( "from my_module import my_func\n\n" "def test_my_func():\n" " a = my_func(1)\n" " b = my_func(2)\n" " c = my_func(3)\n" ) test_file = temp_dir / "test_my_module.py" test_file.write_text(test_code) func = FunctionToOptimize( function_name="my_func", file_path=Path("my_module.py"), parents=[], is_async=False, ) success, instrumented = inject_sync_profiling_into_existing_test( test_file, [ CodePosition(4, 8), CodePosition(5, 8), CodePosition(6, 8), ], func, ) assert success assert instrumented is not None assert "_codeflash_call_site.set('4')" in instrumented assert "_codeflash_call_site.set('5')" in instrumented assert "_codeflash_call_site.set('6')" in instrumented def test_returns_false_for_syntax_error(self, temp_dir) -> None: """Returns (False, None) when the test file has a syntax error.""" test_file = temp_dir / "test_bad.py" test_file.write_text("def test_foo(\n") func = FunctionToOptimize( function_name="my_func", file_path=Path("mod.py"), parents=[], is_async=False, ) success, result = inject_sync_profiling_into_existing_test( test_file, [CodePosition(1, 0)], func ) assert not success assert result is None def test_returns_false_when_no_target_calls(self, temp_dir) -> None: """Returns (False, None) when no target function calls are found.""" test_code = "def test_unrelated():\n assert 1 == 1\n" test_file = temp_dir / "test_noop.py" test_file.write_text(test_code) func = FunctionToOptimize( function_name="my_func", file_path=Path("mod.py"), parents=[], is_async=False, ) success, result = inject_sync_profiling_into_existing_test( test_file, [CodePosition(2, 0)], func ) assert not success assert result is None