"""Tests for adding the ``codeflash_trace`` decorator to source strings and files."""

from __future__ import annotations

import tempfile
from pathlib import Path

from codeflash.benchmarking.instrument_codeflash_trace import (
    add_codeflash_decorator_to_code,
    instrument_codeflash_trace_decorator,
)
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize

# NOTE(review): every exact comparison below goes through ``.strip()``, so only
# the interior layout of each fixture matters. The blank lines that follow the
# import in the file-based expectations presumably come from isort's
# formatting -- confirm against the instrumenter's actual output.


def test_add_decorator_to_normal_function() -> None:
    """Test adding decorator to a normal function."""
    code = """
def normal_function():
    return "Hello, World!"
"""

    fto = FunctionToOptimize(function_name="normal_function", file_path=Path("dummy_path.py"), parents=[])
    modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])

    expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@codeflash_trace
def normal_function():
    return "Hello, World!"
"""

    assert modified_code.strip() == expected_code.strip()


def test_add_decorator_to_normal_method() -> None:
    """Test adding decorator to a normal method."""
    code = """
class TestClass:
    def normal_method(self):
        return "Hello from method"
"""

    fto = FunctionToOptimize(
        function_name="normal_method",
        file_path=Path("dummy_path.py"),
        parents=[FunctionParent(name="TestClass", type="ClassDef")],
    )
    modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])

    expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
class TestClass:
    @codeflash_trace
    def normal_method(self):
        return "Hello from method"
"""

    assert modified_code.strip() == expected_code.strip()


def test_add_decorator_to_classmethod() -> None:
    """Test adding decorator to a classmethod."""
    code = """
class TestClass:
    @classmethod
    def class_method(cls):
        return "Hello from classmethod"
"""

    fto = FunctionToOptimize(
        function_name="class_method",
        file_path=Path("dummy_path.py"),
        parents=[FunctionParent(name="TestClass", type="ClassDef")],
    )
    modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])

    # The trace decorator is expected below the existing ``@classmethod``.
    expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
class TestClass:
    @classmethod
    @codeflash_trace
    def class_method(cls):
        return "Hello from classmethod"
"""

    assert modified_code.strip() == expected_code.strip()


def test_add_decorator_to_staticmethod() -> None:
    """Test adding decorator to a staticmethod."""
    code = """
class TestClass:
    @staticmethod
    def static_method():
        return "Hello from staticmethod"
"""

    fto = FunctionToOptimize(
        function_name="static_method",
        file_path=Path("dummy_path.py"),
        parents=[FunctionParent(name="TestClass", type="ClassDef")],
    )
    modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])

    # The trace decorator is expected below the existing ``@staticmethod``.
    expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
class TestClass:
    @staticmethod
    @codeflash_trace
    def static_method():
        return "Hello from staticmethod"
"""

    assert modified_code.strip() == expected_code.strip()


def test_add_decorator_to_init_function() -> None:
    """Test adding decorator to an __init__ function."""
    code = """
class TestClass:
    def __init__(self, value):
        self.value = value
"""

    fto = FunctionToOptimize(
        function_name="__init__",
        file_path=Path("dummy_path.py"),
        parents=[FunctionParent(name="TestClass", type="ClassDef")],
    )
    modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])

    expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
class TestClass:
    @codeflash_trace
    def __init__(self, value):
        self.value = value
"""

    assert modified_code.strip() == expected_code.strip()


def test_add_decorator_with_multiple_decorators() -> None:
    """Test adding decorator to a function with multiple existing decorators."""
    code = """
class TestClass:
    @property
    @other_decorator
    def property_method(self):
        return self._value
"""

    fto = FunctionToOptimize(
        function_name="property_method",
        file_path=Path("dummy_path.py"),
        parents=[FunctionParent(name="TestClass", type="ClassDef")],
    )
    modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])

    # The trace decorator is expected last, i.e. closest to the ``def``.
    expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
class TestClass:
    @property
    @other_decorator
    @codeflash_trace
    def property_method(self):
        return self._value
"""

    assert modified_code.strip() == expected_code.strip()


def test_add_decorator_to_function_in_multiple_classes() -> None:
    """Test that only the right class's method gets the decorator."""
    code = """
class TestClass:
    def test_method(self):
        return "This should get decorated"

class OtherClass:
    def test_method(self):
        return "This should NOT get decorated"
"""

    fto = FunctionToOptimize(
        function_name="test_method",
        file_path=Path("dummy_path.py"),
        parents=[FunctionParent(name="TestClass", type="ClassDef")],
    )
    modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])

    expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
class TestClass:
    @codeflash_trace
    def test_method(self):
        return "This should get decorated"

class OtherClass:
    def test_method(self):
        return "This should NOT get decorated"
"""

    assert modified_code.strip() == expected_code.strip()


def test_add_decorator_to_nonexistent_function() -> None:
    """Test that code remains unchanged when function doesn't exist."""
    code = """
def existing_function():
    return "This exists"
"""

    fto = FunctionToOptimize(function_name="nonexistent_function", file_path=Path("dummy_path.py"), parents=[])
    modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])

    # Code should remain unchanged
    assert modified_code.strip() == code.strip()


def test_add_decorator_to_multiple_functions() -> None:
    """Test adding decorator to multiple functions."""
    code = """
def function_one():
    return "First function"

class TestClass:
    def method_one(self):
        return "First method"

    def method_two(self):
        return "Second method"

def function_two():
    return "Second function"
"""

    functions_to_optimize = [
        FunctionToOptimize(function_name="function_one", file_path=Path("dummy_path.py"), parents=[]),
        FunctionToOptimize(
            function_name="method_two",
            file_path=Path("dummy_path.py"),
            parents=[FunctionParent(name="TestClass", type="ClassDef")],
        ),
        FunctionToOptimize(function_name="function_two", file_path=Path("dummy_path.py"), parents=[]),
    ]

    modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=functions_to_optimize)

    # ``method_one`` is not in the list above, so it must stay undecorated.
    expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@codeflash_trace
def function_one():
    return "First function"

class TestClass:
    def method_one(self):
        return "First method"

    @codeflash_trace
    def method_two(self):
        return "Second method"

@codeflash_trace
def function_two():
    return "Second function"
"""

    assert modified_code.strip() == expected_code.strip()


def test_instrument_codeflash_trace_decorator_single_file() -> None:
    """Test instrumenting codeflash trace decorator on a single file."""
    with tempfile.TemporaryDirectory() as temp_dir:
        # Create a test Python file
        test_file_path = Path(temp_dir) / "test_module.py"
        test_file_content = """
def function_one():
    return "First function"

class TestClass:
    def method_one(self):
        return "First method"

    def method_two(self):
        return "Second method"

def function_two():
    return "Second function"
"""
        test_file_path.write_text(test_file_content, encoding="utf-8")

        # Define functions to optimize
        functions_to_optimize = [
            FunctionToOptimize(function_name="function_one", file_path=test_file_path, parents=[]),
            FunctionToOptimize(
                function_name="method_two",
                file_path=test_file_path,
                parents=[FunctionParent(name="TestClass", type="ClassDef")],
            ),
        ]

        # Execute the function being tested
        instrument_codeflash_trace_decorator({test_file_path: functions_to_optimize})

        # Read the modified file
        modified_content = test_file_path.read_text(encoding="utf-8")

        # Define expected content (with isort applied)
        expected_content = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace


@codeflash_trace
def function_one():
    return "First function"

class TestClass:
    def method_one(self):
        return "First method"

    @codeflash_trace
    def method_two(self):
        return "Second method"

def function_two():
    return "Second function"
"""

        # Compare the modified content with expected content
        assert modified_content.strip() == expected_content.strip()


def test_instrument_codeflash_trace_decorator_multiple_files() -> None:
    """Test instrumenting codeflash trace decorator on multiple files."""
    with tempfile.TemporaryDirectory() as temp_dir:
        # Create first test Python file
        test_file_1_path = Path(temp_dir) / "module_a.py"
        test_file_1_content = """
def function_a():
    return "Function in module A"

class ClassA:
    def method_a(self):
        return "Method in ClassA"
"""
        test_file_1_path.write_text(test_file_1_content, encoding="utf-8")

        # Create second test Python file
        test_file_2_path = Path(temp_dir) / "module_b.py"
        test_file_2_content = """
def function_b():
    return "Function in module B"

class ClassB:
    @staticmethod
    def static_method_b():
        return "Static method in ClassB"
"""
        test_file_2_path.write_text(test_file_2_content, encoding="utf-8")

        # Define functions to optimize
        file_to_funcs_to_optimize = {
            test_file_1_path: [FunctionToOptimize(function_name="function_a", file_path=test_file_1_path, parents=[])],
            test_file_2_path: [
                FunctionToOptimize(
                    function_name="static_method_b",
                    file_path=test_file_2_path,
                    parents=[FunctionParent(name="ClassB", type="ClassDef")],
                )
            ],
        }

        # Execute the function being tested
        instrument_codeflash_trace_decorator(file_to_funcs_to_optimize)

        # Read the modified files
        modified_content_1 = test_file_1_path.read_text(encoding="utf-8")
        modified_content_2 = test_file_2_path.read_text(encoding="utf-8")

        # Define expected content for first file (with isort applied)
        expected_content_1 = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace


@codeflash_trace
def function_a():
    return "Function in module A"

class ClassA:
    def method_a(self):
        return "Method in ClassA"
"""

        # Define expected content for second file (with isort applied)
        expected_content_2 = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace


def function_b():
    return "Function in module B"

class ClassB:
    @staticmethod
    @codeflash_trace
    def static_method_b():
        return "Static method in ClassB"
"""

        # Compare the modified content with expected content
        assert modified_content_1.strip() == expected_content_1.strip()
        assert modified_content_2.strip() == expected_content_2.strip()


def test_add_decorator_to_method_after_nested_class() -> None:
    """Test adding decorator to a method that appears after a nested class definition."""
    code = """
class OuterClass:
    class NestedClass:
        def nested_method(self):
            return "Hello from nested class method"

    def target_method(self):
        return "Hello from target method after nested class"
"""

    fto = FunctionToOptimize(
        function_name="target_method",
        file_path=Path("dummy_path.py"),
        parents=[FunctionParent(name="OuterClass", type="ClassDef")],
    )
    modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])

    expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
class OuterClass:
    class NestedClass:
        def nested_method(self):
            return "Hello from nested class method"

    @codeflash_trace
    def target_method(self):
        return "Hello from target method after nested class"
"""

    assert modified_code.strip() == expected_code.strip()


def test_add_decorator_to_function_after_nested_function() -> None:
    """Test adding decorator to a function that appears after a function with a nested function."""
    code = """
def function_with_nested():
    def inner_function():
        return "Hello from inner function"

    return inner_function()

def target_function():
    return "Hello from target function after nested function"
"""

    fto = FunctionToOptimize(function_name="target_function", file_path=Path("dummy_path.py"), parents=[])
    modified_code = add_codeflash_decorator_to_code(code=code, functions_to_optimize=[fto])

    expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
def function_with_nested():
    def inner_function():
        return "Hello from inner function"

    return inner_function()

@codeflash_trace
def target_function():
    return "Hello from target function after nested function"
"""

    assert modified_code.strip() == expected_code.strip()


def test_instrument_codeflash_trace_skips_benchmarking_module() -> None:
    """Test that files in codeflash/benchmarking/ are skipped to avoid circular imports."""
    with tempfile.TemporaryDirectory() as temp_dir:
        # Create a directory structure that mimics codeflash/benchmarking/
        benchmarking_dir = Path(temp_dir) / "codeflash" / "benchmarking"
        benchmarking_dir.mkdir(parents=True)

        test_file_path = benchmarking_dir / "some_module.py"
        original_content = """
def some_function():
    return "This should not be modified"
"""
        test_file_path.write_text(original_content, encoding="utf-8")

        fto = FunctionToOptimize(function_name="some_function", file_path=test_file_path, parents=[])
        instrument_codeflash_trace_decorator({test_file_path: [fto]})

        # File should remain unchanged
        assert test_file_path.read_text(encoding="utf-8") == original_content


def test_instrument_codeflash_trace_skips_picklepatch_module() -> None:
    """Test that files in codeflash/picklepatch/ are skipped to avoid circular imports."""
    with tempfile.TemporaryDirectory() as temp_dir:
        # Create a directory structure that mimics codeflash/picklepatch/
        picklepatch_dir = Path(temp_dir) / "codeflash" / "picklepatch"
        picklepatch_dir.mkdir(parents=True)

        test_file_path = picklepatch_dir / "patcher.py"
        original_content = """
def patch_function():
    return "This should not be modified"
"""
        test_file_path.write_text(original_content, encoding="utf-8")

        fto = FunctionToOptimize(function_name="patch_function", file_path=test_file_path, parents=[])
        instrument_codeflash_trace_decorator({test_file_path: [fto]})

        # File should remain unchanged
        assert test_file_path.read_text(encoding="utf-8") == original_content


def test_instrument_codeflash_trace_nested_codeflash_path_skips_benchmarking() -> None:
    """Test that nested codeflash paths like /project/codeflash/codeflash/benchmarking/ are skipped.

    The rpartition logic should find the LAST 'codeflash' in the path.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # Create nested structure: project_codeflash/codeflash/benchmarking/
        nested_dir = Path(temp_dir) / "project_codeflash" / "codeflash" / "benchmarking"
        nested_dir.mkdir(parents=True)

        test_file_path = nested_dir / "trace_module.py"
        original_content = """
def trace_func():
    return "Should not be modified"
"""
        test_file_path.write_text(original_content, encoding="utf-8")

        fto = FunctionToOptimize(function_name="trace_func", file_path=test_file_path, parents=[])
        instrument_codeflash_trace_decorator({test_file_path: [fto]})

        # File should remain unchanged because last /codeflash/ is followed by benchmarking
        assert test_file_path.read_text(encoding="utf-8") == original_content


def test_instrument_codeflash_trace_nested_codeflash_path_instruments_other_modules() -> None:
    """Test that nested codeflash paths with non-skipped modules ARE instrumented.

    The rpartition logic should allow instrumentation when the submodule is not benchmarking/picklepatch.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        # Create nested structure: project_codeflash/codeflash/other_module/
        nested_dir = Path(temp_dir) / "project_codeflash" / "codeflash" / "other_module"
        nested_dir.mkdir(parents=True)

        test_file_path = nested_dir / "utils.py"
        original_content = """
def util_func():
    return "Should be modified"
"""
        test_file_path.write_text(original_content, encoding="utf-8")

        fto = FunctionToOptimize(function_name="util_func", file_path=test_file_path, parents=[])
        instrument_codeflash_trace_decorator({test_file_path: [fto]})

        # File SHOULD be modified because other_module is not in skip list
        modified_content = test_file_path.read_text(encoding="utf-8")
        # Check the import line itself: a bare "codeflash_trace" substring check
        # is already implied by the "@codeflash_trace" check and proves nothing.
        assert "from codeflash.benchmarking.codeflash_trace import codeflash_trace" in modified_content
        assert "@codeflash_trace" in modified_content


def test_instrument_codeflash_trace_no_codeflash_in_path() -> None:
    """Test that paths without 'codeflash' directory are instrumented normally."""
    with tempfile.TemporaryDirectory() as temp_dir:
        # Create a path with no 'codeflash' directory
        project_dir = Path(temp_dir) / "myproject" / "src"
        project_dir.mkdir(parents=True)

        test_file_path = project_dir / "main.py"
        original_content = """
def main_func():
    return "Should be modified"
"""
        test_file_path.write_text(original_content, encoding="utf-8")

        fto = FunctionToOptimize(function_name="main_func", file_path=test_file_path, parents=[])
        instrument_codeflash_trace_decorator({test_file_path: [fto]})

        # File SHOULD be modified
        modified_content = test_file_path.read_text(encoding="utf-8")
        # Check the import line itself: a bare "codeflash_trace" substring check
        # is already implied by the "@codeflash_trace" check and proves nothing.
        assert "from codeflash.benchmarking.codeflash_trace import codeflash_trace" in modified_content
        assert "@codeflash_trace" in modified_content