599 lines
No EOL
16 KiB
Python
599 lines
No EOL
16 KiB
Python
from __future__ import annotations
|
|
|
|
import tempfile
|
|
from pathlib import Path
|
|
|
|
from codeflash.benchmarking.instrument_codeflash_trace import add_codeflash_decorator_to_code, \
|
|
instrument_codeflash_trace_decorator
|
|
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
|
|
|
|
|
|
def test_add_decorator_to_normal_function() -> None:
|
|
"""Test adding decorator to a normal function."""
|
|
code = """
|
|
def normal_function():
|
|
return "Hello, World!"
|
|
"""
|
|
|
|
fto = FunctionToOptimize(
|
|
function_name="normal_function",
|
|
file_path=Path("dummy_path.py"),
|
|
parents=[]
|
|
)
|
|
|
|
modified_code = add_codeflash_decorator_to_code(
|
|
code=code,
|
|
functions_to_optimize=[fto]
|
|
)
|
|
|
|
expected_code = """
|
|
from codeflash.benchmarking.codeflash_trace import codeflash_trace
|
|
@codeflash_trace
|
|
def normal_function():
|
|
return "Hello, World!"
|
|
"""
|
|
|
|
assert modified_code.strip() == expected_code.strip()
|
|
|
|
|
|
def test_add_decorator_to_normal_method() -> None:
|
|
"""Test adding decorator to a normal method."""
|
|
code = """
|
|
class TestClass:
|
|
def normal_method(self):
|
|
return "Hello from method"
|
|
"""
|
|
|
|
fto = FunctionToOptimize(
|
|
function_name="normal_method",
|
|
file_path=Path("dummy_path.py"),
|
|
parents=[FunctionParent(name="TestClass", type="ClassDef")]
|
|
)
|
|
|
|
modified_code = add_codeflash_decorator_to_code(
|
|
code=code,
|
|
functions_to_optimize=[fto]
|
|
)
|
|
|
|
expected_code = """
|
|
from codeflash.benchmarking.codeflash_trace import codeflash_trace
|
|
class TestClass:
|
|
@codeflash_trace
|
|
def normal_method(self):
|
|
return "Hello from method"
|
|
"""
|
|
|
|
assert modified_code.strip() == expected_code.strip()
|
|
|
|
|
|
def test_add_decorator_to_classmethod() -> None:
|
|
"""Test adding decorator to a classmethod."""
|
|
code = """
|
|
class TestClass:
|
|
@classmethod
|
|
def class_method(cls):
|
|
return "Hello from classmethod"
|
|
"""
|
|
|
|
fto = FunctionToOptimize(
|
|
function_name="class_method",
|
|
file_path=Path("dummy_path.py"),
|
|
parents=[FunctionParent(name="TestClass", type="ClassDef")]
|
|
)
|
|
|
|
modified_code = add_codeflash_decorator_to_code(
|
|
code=code,
|
|
functions_to_optimize=[fto]
|
|
)
|
|
|
|
expected_code = """
|
|
from codeflash.benchmarking.codeflash_trace import codeflash_trace
|
|
class TestClass:
|
|
@classmethod
|
|
@codeflash_trace
|
|
def class_method(cls):
|
|
return "Hello from classmethod"
|
|
"""
|
|
|
|
assert modified_code.strip() == expected_code.strip()
|
|
|
|
|
|
def test_add_decorator_to_staticmethod() -> None:
|
|
"""Test adding decorator to a staticmethod."""
|
|
code = """
|
|
class TestClass:
|
|
@staticmethod
|
|
def static_method():
|
|
return "Hello from staticmethod"
|
|
"""
|
|
|
|
fto = FunctionToOptimize(
|
|
function_name="static_method",
|
|
file_path=Path("dummy_path.py"),
|
|
parents=[FunctionParent(name="TestClass", type="ClassDef")]
|
|
)
|
|
|
|
modified_code = add_codeflash_decorator_to_code(
|
|
code=code,
|
|
functions_to_optimize=[fto]
|
|
)
|
|
|
|
expected_code = """
|
|
from codeflash.benchmarking.codeflash_trace import codeflash_trace
|
|
class TestClass:
|
|
@staticmethod
|
|
@codeflash_trace
|
|
def static_method():
|
|
return "Hello from staticmethod"
|
|
"""
|
|
|
|
assert modified_code.strip() == expected_code.strip()
|
|
|
|
|
|
def test_add_decorator_to_init_function() -> None:
|
|
"""Test adding decorator to an __init__ function."""
|
|
code = """
|
|
class TestClass:
|
|
def __init__(self, value):
|
|
self.value = value
|
|
"""
|
|
|
|
fto = FunctionToOptimize(
|
|
function_name="__init__",
|
|
file_path=Path("dummy_path.py"),
|
|
parents=[FunctionParent(name="TestClass", type="ClassDef")]
|
|
)
|
|
|
|
modified_code = add_codeflash_decorator_to_code(
|
|
code=code,
|
|
functions_to_optimize=[fto]
|
|
)
|
|
|
|
expected_code = """
|
|
from codeflash.benchmarking.codeflash_trace import codeflash_trace
|
|
class TestClass:
|
|
@codeflash_trace
|
|
def __init__(self, value):
|
|
self.value = value
|
|
"""
|
|
|
|
assert modified_code.strip() == expected_code.strip()
|
|
|
|
|
|
def test_add_decorator_with_multiple_decorators() -> None:
|
|
"""Test adding decorator to a function with multiple existing decorators."""
|
|
code = """
|
|
class TestClass:
|
|
@property
|
|
@other_decorator
|
|
def property_method(self):
|
|
return self._value
|
|
"""
|
|
|
|
fto = FunctionToOptimize(
|
|
function_name="property_method",
|
|
file_path=Path("dummy_path.py"),
|
|
parents=[FunctionParent(name="TestClass", type="ClassDef")]
|
|
)
|
|
|
|
modified_code = add_codeflash_decorator_to_code(
|
|
code=code,
|
|
functions_to_optimize=[fto]
|
|
)
|
|
|
|
expected_code = """
|
|
from codeflash.benchmarking.codeflash_trace import codeflash_trace
|
|
class TestClass:
|
|
@property
|
|
@other_decorator
|
|
@codeflash_trace
|
|
def property_method(self):
|
|
return self._value
|
|
"""
|
|
|
|
assert modified_code.strip() == expected_code.strip()
|
|
|
|
|
|
def test_add_decorator_to_function_in_multiple_classes() -> None:
|
|
"""Test that only the right class's method gets the decorator."""
|
|
code = """
|
|
class TestClass:
|
|
def test_method(self):
|
|
return "This should get decorated"
|
|
|
|
class OtherClass:
|
|
def test_method(self):
|
|
return "This should NOT get decorated"
|
|
"""
|
|
|
|
fto = FunctionToOptimize(
|
|
function_name="test_method",
|
|
file_path=Path("dummy_path.py"),
|
|
parents=[FunctionParent(name="TestClass", type="ClassDef")]
|
|
)
|
|
|
|
modified_code = add_codeflash_decorator_to_code(
|
|
code=code,
|
|
functions_to_optimize=[fto]
|
|
)
|
|
|
|
expected_code = """
|
|
from codeflash.benchmarking.codeflash_trace import codeflash_trace
|
|
class TestClass:
|
|
@codeflash_trace
|
|
def test_method(self):
|
|
return "This should get decorated"
|
|
|
|
class OtherClass:
|
|
def test_method(self):
|
|
return "This should NOT get decorated"
|
|
"""
|
|
|
|
assert modified_code.strip() == expected_code.strip()
|
|
|
|
|
|
def test_add_decorator_to_nonexistent_function() -> None:
|
|
"""Test that code remains unchanged when function doesn't exist."""
|
|
code = """
|
|
def existing_function():
|
|
return "This exists"
|
|
"""
|
|
|
|
fto = FunctionToOptimize(
|
|
function_name="nonexistent_function",
|
|
file_path=Path("dummy_path.py"),
|
|
parents=[]
|
|
)
|
|
|
|
modified_code = add_codeflash_decorator_to_code(
|
|
code=code,
|
|
functions_to_optimize=[fto]
|
|
)
|
|
|
|
# Code should remain unchanged
|
|
assert modified_code.strip() == code.strip()
|
|
|
|
|
|
def test_add_decorator_to_multiple_functions() -> None:
|
|
"""Test adding decorator to multiple functions."""
|
|
code = """
|
|
def function_one():
|
|
return "First function"
|
|
|
|
class TestClass:
|
|
def method_one(self):
|
|
return "First method"
|
|
|
|
def method_two(self):
|
|
return "Second method"
|
|
|
|
def function_two():
|
|
return "Second function"
|
|
"""
|
|
|
|
functions_to_optimize = [
|
|
FunctionToOptimize(
|
|
function_name="function_one",
|
|
file_path=Path("dummy_path.py"),
|
|
parents=[]
|
|
),
|
|
FunctionToOptimize(
|
|
function_name="method_two",
|
|
file_path=Path("dummy_path.py"),
|
|
parents=[FunctionParent(name="TestClass", type="ClassDef")]
|
|
),
|
|
FunctionToOptimize(
|
|
function_name="function_two",
|
|
file_path=Path("dummy_path.py"),
|
|
parents=[]
|
|
)
|
|
]
|
|
|
|
modified_code = add_codeflash_decorator_to_code(
|
|
code=code,
|
|
functions_to_optimize=functions_to_optimize
|
|
)
|
|
|
|
expected_code = """
|
|
from codeflash.benchmarking.codeflash_trace import codeflash_trace
|
|
@codeflash_trace
|
|
def function_one():
|
|
return "First function"
|
|
|
|
class TestClass:
|
|
def method_one(self):
|
|
return "First method"
|
|
|
|
@codeflash_trace
|
|
def method_two(self):
|
|
return "Second method"
|
|
|
|
@codeflash_trace
|
|
def function_two():
|
|
return "Second function"
|
|
"""
|
|
|
|
assert modified_code.strip() == expected_code.strip()
|
|
|
|
|
|
def test_instrument_codeflash_trace_decorator_single_file() -> None:
|
|
"""Test instrumenting codeflash trace decorator on a single file."""
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
# Create a test Python file
|
|
test_file_path = Path(temp_dir) / "test_module.py"
|
|
test_file_content = """
|
|
def function_one():
|
|
return "First function"
|
|
|
|
class TestClass:
|
|
def method_one(self):
|
|
return "First method"
|
|
|
|
def method_two(self):
|
|
return "Second method"
|
|
|
|
def function_two():
|
|
return "Second function"
|
|
"""
|
|
test_file_path.write_text(test_file_content, encoding="utf-8")
|
|
|
|
# Define functions to optimize
|
|
functions_to_optimize = [
|
|
FunctionToOptimize(
|
|
function_name="function_one",
|
|
file_path=test_file_path,
|
|
parents=[]
|
|
),
|
|
FunctionToOptimize(
|
|
function_name="method_two",
|
|
file_path=test_file_path,
|
|
parents=[FunctionParent(name="TestClass", type="ClassDef")]
|
|
)
|
|
]
|
|
|
|
# Execute the function being tested
|
|
instrument_codeflash_trace_decorator({test_file_path: functions_to_optimize})
|
|
|
|
# Read the modified file
|
|
modified_content = test_file_path.read_text(encoding="utf-8")
|
|
|
|
# Define expected content (with isort applied)
|
|
expected_content = """
|
|
from codeflash.benchmarking.codeflash_trace import codeflash_trace
|
|
|
|
|
|
@codeflash_trace
|
|
def function_one():
|
|
return "First function"
|
|
|
|
class TestClass:
|
|
def method_one(self):
|
|
return "First method"
|
|
|
|
@codeflash_trace
|
|
def method_two(self):
|
|
return "Second method"
|
|
|
|
def function_two():
|
|
return "Second function"
|
|
"""
|
|
|
|
# Compare the modified content with expected content
|
|
assert modified_content.strip() == expected_content.strip()
|
|
|
|
|
|
def test_instrument_codeflash_trace_decorator_multiple_files() -> None:
|
|
"""Test instrumenting codeflash trace decorator on multiple files."""
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
# Create first test Python file
|
|
test_file_1_path = Path(temp_dir) / "module_a.py"
|
|
test_file_1_content = """
|
|
def function_a():
|
|
return "Function in module A"
|
|
|
|
class ClassA:
|
|
def method_a(self):
|
|
return "Method in ClassA"
|
|
"""
|
|
test_file_1_path.write_text(test_file_1_content, encoding="utf-8")
|
|
|
|
# Create second test Python file
|
|
test_file_2_path = Path(temp_dir) / "module_b.py"
|
|
test_file_2_content ="""
|
|
def function_b():
|
|
return "Function in module B"
|
|
|
|
class ClassB:
|
|
@staticmethod
|
|
def static_method_b():
|
|
return "Static method in ClassB"
|
|
"""
|
|
test_file_2_path.write_text(test_file_2_content, encoding="utf-8")
|
|
|
|
# Define functions to optimize
|
|
file_to_funcs_to_optimize = {
|
|
test_file_1_path: [
|
|
FunctionToOptimize(
|
|
function_name="function_a",
|
|
file_path=test_file_1_path,
|
|
parents=[]
|
|
)
|
|
],
|
|
test_file_2_path: [
|
|
FunctionToOptimize(
|
|
function_name="static_method_b",
|
|
file_path=test_file_2_path,
|
|
parents=[FunctionParent(name="ClassB", type="ClassDef")]
|
|
)
|
|
]
|
|
}
|
|
|
|
# Execute the function being tested
|
|
instrument_codeflash_trace_decorator(file_to_funcs_to_optimize)
|
|
|
|
# Read the modified files
|
|
modified_content_1 = test_file_1_path.read_text(encoding="utf-8")
|
|
modified_content_2 = test_file_2_path.read_text(encoding="utf-8")
|
|
|
|
# Define expected content for first file (with isort applied)
|
|
expected_content_1 = """
|
|
from codeflash.benchmarking.codeflash_trace import codeflash_trace
|
|
|
|
|
|
@codeflash_trace
|
|
def function_a():
|
|
return "Function in module A"
|
|
|
|
class ClassA:
|
|
def method_a(self):
|
|
return "Method in ClassA"
|
|
"""
|
|
|
|
# Define expected content for second file (with isort applied)
|
|
expected_content_2 = """
|
|
from codeflash.benchmarking.codeflash_trace import codeflash_trace
|
|
|
|
|
|
def function_b():
|
|
return "Function in module B"
|
|
|
|
class ClassB:
|
|
@staticmethod
|
|
@codeflash_trace
|
|
def static_method_b():
|
|
return "Static method in ClassB"
|
|
"""
|
|
|
|
# Compare the modified content with expected content
|
|
assert modified_content_1.strip() == expected_content_1.strip()
|
|
assert modified_content_2.strip() == expected_content_2.strip()
|
|
|
|
|
|
def test_add_decorator_to_method_after_nested_class() -> None:
|
|
"""Test adding decorator to a method that appears after a nested class definition."""
|
|
code = """
|
|
class OuterClass:
|
|
class NestedClass:
|
|
def nested_method(self):
|
|
return "Hello from nested class method"
|
|
|
|
def target_method(self):
|
|
return "Hello from target method after nested class"
|
|
"""
|
|
|
|
fto = FunctionToOptimize(
|
|
function_name="target_method",
|
|
file_path=Path("dummy_path.py"),
|
|
parents=[FunctionParent(name="OuterClass", type="ClassDef")]
|
|
)
|
|
|
|
modified_code = add_codeflash_decorator_to_code(
|
|
code=code,
|
|
functions_to_optimize=[fto]
|
|
)
|
|
|
|
expected_code = """
|
|
from codeflash.benchmarking.codeflash_trace import codeflash_trace
|
|
class OuterClass:
|
|
class NestedClass:
|
|
def nested_method(self):
|
|
return "Hello from nested class method"
|
|
|
|
@codeflash_trace
|
|
def target_method(self):
|
|
return "Hello from target method after nested class"
|
|
"""
|
|
|
|
assert modified_code.strip() == expected_code.strip()
|
|
|
|
|
|
def test_add_decorator_to_function_after_nested_function() -> None:
|
|
"""Test adding decorator to a function that appears after a function with a nested function."""
|
|
code = """
|
|
def function_with_nested():
|
|
def inner_function():
|
|
return "Hello from inner function"
|
|
|
|
return inner_function()
|
|
|
|
def target_function():
|
|
return "Hello from target function after nested function"
|
|
"""
|
|
|
|
fto = FunctionToOptimize(
|
|
function_name="target_function",
|
|
file_path=Path("dummy_path.py"),
|
|
parents=[]
|
|
)
|
|
|
|
modified_code = add_codeflash_decorator_to_code(
|
|
code=code,
|
|
functions_to_optimize=[fto]
|
|
)
|
|
|
|
expected_code = """
|
|
from codeflash.benchmarking.codeflash_trace import codeflash_trace
|
|
def function_with_nested():
|
|
def inner_function():
|
|
return "Hello from inner function"
|
|
|
|
return inner_function()
|
|
|
|
@codeflash_trace
|
|
def target_function():
|
|
return "Hello from target function after nested function"
|
|
"""
|
|
|
|
assert modified_code.strip() == expected_code.strip()
|
|
|
|
|
|
def test_instrument_codeflash_trace_skips_benchmarking_module() -> None:
|
|
"""Test that files in codeflash/benchmarking/ are skipped to avoid circular imports."""
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
# Create a directory structure that mimics codeflash/benchmarking/
|
|
benchmarking_dir = Path(temp_dir) / "codeflash" / "benchmarking"
|
|
benchmarking_dir.mkdir(parents=True)
|
|
|
|
test_file_path = benchmarking_dir / "some_module.py"
|
|
original_content = """
|
|
def some_function():
|
|
return "This should not be modified"
|
|
"""
|
|
test_file_path.write_text(original_content, encoding="utf-8")
|
|
|
|
fto = FunctionToOptimize(
|
|
function_name="some_function",
|
|
file_path=test_file_path,
|
|
parents=[]
|
|
)
|
|
|
|
instrument_codeflash_trace_decorator({test_file_path: [fto]})
|
|
|
|
# File should remain unchanged
|
|
assert test_file_path.read_text(encoding="utf-8") == original_content
|
|
|
|
|
|
def test_instrument_codeflash_trace_skips_picklepatch_module() -> None:
|
|
"""Test that files in codeflash/picklepatch/ are skipped to avoid circular imports."""
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
# Create a directory structure that mimics codeflash/picklepatch/
|
|
picklepatch_dir = Path(temp_dir) / "codeflash" / "picklepatch"
|
|
picklepatch_dir.mkdir(parents=True)
|
|
|
|
test_file_path = picklepatch_dir / "patcher.py"
|
|
original_content = """
|
|
def patch_function():
|
|
return "This should not be modified"
|
|
"""
|
|
test_file_path.write_text(original_content, encoding="utf-8")
|
|
|
|
fto = FunctionToOptimize(
|
|
function_name="patch_function",
|
|
file_path=test_file_path,
|
|
parents=[]
|
|
)
|
|
|
|
instrument_codeflash_trace_decorator({test_file_path: [fto]})
|
|
|
|
# File should remain unchanged
|
|
assert test_file_path.read_text(encoding="utf-8") == original_content |