mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
few missing things
This commit is contained in:
parent
6bd39f7791
commit
c6c9d9559f
5 changed files with 671 additions and 171 deletions
|
|
@ -40,6 +40,30 @@ matches_re_start = re.compile(r"!\$######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)
|
|||
matches_re_end = re.compile(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!")
|
||||
|
||||
|
||||
start_pattern = re.compile(r"!\$######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+)######\$!")
|
||||
end_pattern = re.compile(r"!######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+):([^:]+)######!")
|
||||
|
||||
|
||||
def calculate_function_throughput_from_test_results(test_results: TestResults, function_name: str) -> int:
|
||||
"""Calculate function throughput from TestResults by extracting performance stdout.
|
||||
|
||||
A completed execution is defined as having both a start tag and matching end tag from performance wrappers.
|
||||
Start: !$######test_module:test_function:function_name:loop_index:iteration_id######$!
|
||||
End: !######test_module:test_function:function_name:loop_index:iteration_id:duration######!
|
||||
"""
|
||||
start_matches = start_pattern.findall(test_results.perf_stdout or "")
|
||||
end_matches = end_pattern.findall(test_results.perf_stdout or "")
|
||||
|
||||
end_matches_truncated = [end_match[:5] for end_match in end_matches]
|
||||
end_matches_set = set(end_matches_truncated)
|
||||
|
||||
function_throughput = 0
|
||||
for start_match in start_matches:
|
||||
if start_match in end_matches_set and len(start_match) > 2 and start_match[2] == function_name:
|
||||
function_throughput += 1
|
||||
return function_throughput
|
||||
|
||||
|
||||
def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults:
|
||||
test_results = TestResults()
|
||||
if not file_location.exists():
|
||||
|
|
|
|||
|
|
@ -1902,4 +1902,210 @@ def test_bubble_sort(input, expected_output):
|
|||
|
||||
# Check that comments were added
|
||||
modified_source = result.generated_tests[0].generated_original_test_source
|
||||
assert modified_source == expected
|
||||
assert modified_source == expected
|
||||
|
||||
def test_async_basic_runtime_comment_addition(self, test_config):
|
||||
"""Test basic functionality of adding runtime comments to async test functions."""
|
||||
os.chdir(test_config.project_root_path)
|
||||
test_source = """async def test_async_bubble_sort():
|
||||
codeflash_output = await async_bubble_sort([3, 1, 2])
|
||||
assert codeflash_output == [1, 2, 3]
|
||||
"""
|
||||
|
||||
generated_test = GeneratedTests(
|
||||
generated_original_test_source=test_source,
|
||||
instrumented_behavior_test_source="",
|
||||
instrumented_perf_test_source="",
|
||||
behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
|
||||
perf_file_path=test_config.tests_root / "test_perf.py",
|
||||
)
|
||||
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
|
||||
|
||||
original_test_results = TestResults()
|
||||
optimized_test_results = TestResults()
|
||||
|
||||
original_invocation = self.create_test_invocation("test_async_bubble_sort", 500_000, iteration_id='0') # 500μs
|
||||
optimized_invocation = self.create_test_invocation("test_async_bubble_sort", 300_000, iteration_id='0') # 300μs
|
||||
|
||||
original_test_results.add(original_invocation)
|
||||
optimized_test_results.add(optimized_invocation)
|
||||
original_runtimes = original_test_results.usable_runtime_data_by_test_case()
|
||||
optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
|
||||
result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
|
||||
|
||||
modified_source = result.generated_tests[0].generated_original_test_source
|
||||
assert "# 500μs -> 300μs" in modified_source
|
||||
assert "codeflash_output = await async_bubble_sort([3, 1, 2]) # 500μs -> 300μs" in modified_source
|
||||
|
||||
def test_async_multiple_test_functions(self, test_config):
|
||||
os.chdir(test_config.project_root_path)
|
||||
test_source = """async def test_async_bubble_sort():
|
||||
codeflash_output = await async_quick_sort([3, 1, 2])
|
||||
assert codeflash_output == [1, 2, 3]
|
||||
|
||||
async def test_async_quick_sort():
|
||||
codeflash_output = await async_quick_sort([5, 2, 8])
|
||||
assert codeflash_output == [2, 5, 8]
|
||||
|
||||
def helper_function():
|
||||
return "not a test"
|
||||
"""
|
||||
generated_test = GeneratedTests(
|
||||
generated_original_test_source=test_source,
|
||||
instrumented_behavior_test_source="",
|
||||
instrumented_perf_test_source="",
|
||||
behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
|
||||
perf_file_path=test_config.tests_root / "test_perf.py"
|
||||
)
|
||||
|
||||
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
|
||||
|
||||
original_test_results = TestResults()
|
||||
optimized_test_results = TestResults()
|
||||
|
||||
original_test_results.add(self.create_test_invocation("test_async_bubble_sort", 500_000, iteration_id='0'))
|
||||
original_test_results.add(self.create_test_invocation("test_async_quick_sort", 800_000, iteration_id='0'))
|
||||
|
||||
optimized_test_results.add(self.create_test_invocation("test_async_bubble_sort", 300_000, iteration_id='0'))
|
||||
optimized_test_results.add(self.create_test_invocation("test_async_quick_sort", 600_000, iteration_id='0'))
|
||||
|
||||
original_runtimes = original_test_results.usable_runtime_data_by_test_case()
|
||||
optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
|
||||
|
||||
result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
|
||||
|
||||
modified_source = result.generated_tests[0].generated_original_test_source
|
||||
|
||||
assert "# 500μs -> 300μs" in modified_source
|
||||
assert "# 800μs -> 600μs" in modified_source
|
||||
assert (
|
||||
"helper_function():" in modified_source
|
||||
and "# " not in modified_source.split("helper_function():")[1].split("\n")[0]
|
||||
)
|
||||
|
||||
def test_async_class_method(self, test_config):
|
||||
os.chdir(test_config.project_root_path)
|
||||
test_source = '''class TestAsyncClass:
|
||||
async def test_async_function(self):
|
||||
codeflash_output = await some_async_function()
|
||||
assert codeflash_output == expected
|
||||
'''
|
||||
generated_test = GeneratedTests(
|
||||
generated_original_test_source=test_source,
|
||||
instrumented_behavior_test_source="",
|
||||
instrumented_perf_test_source="",
|
||||
behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
|
||||
perf_file_path=test_config.tests_root / "test_perf.py"
|
||||
)
|
||||
|
||||
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
|
||||
|
||||
invocation_id = InvocationId(
|
||||
test_module_path="tests.test_module__unit_test_0",
|
||||
test_class_name="TestAsyncClass",
|
||||
test_function_name="test_async_function",
|
||||
function_getting_tested="some_async_function",
|
||||
iteration_id="0",
|
||||
)
|
||||
|
||||
original_runtimes = {invocation_id: [2000000000]} # 2s in nanoseconds
|
||||
optimized_runtimes = {invocation_id: [1000000000]} # 1s in nanoseconds
|
||||
|
||||
result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
|
||||
|
||||
expected_source = '''class TestAsyncClass:
|
||||
async def test_async_function(self):
|
||||
codeflash_output = await some_async_function() # 2.00s -> 1.00s (100% faster)
|
||||
assert codeflash_output == expected
|
||||
'''
|
||||
|
||||
assert len(result.generated_tests) == 1
|
||||
assert result.generated_tests[0].generated_original_test_source == expected_source
|
||||
|
||||
def test_async_mixed_sync_and_async_functions(self, test_config):
|
||||
os.chdir(test_config.project_root_path)
|
||||
test_source = """def test_sync_function():
|
||||
codeflash_output = sync_function([1, 2, 3])
|
||||
assert codeflash_output == [1, 2, 3]
|
||||
|
||||
async def test_async_function():
|
||||
codeflash_output = await async_function([4, 5, 6])
|
||||
assert codeflash_output == [4, 5, 6]
|
||||
|
||||
def test_another_sync():
|
||||
result = another_sync_func()
|
||||
assert result is True
|
||||
"""
|
||||
generated_test = GeneratedTests(
|
||||
generated_original_test_source=test_source,
|
||||
instrumented_behavior_test_source="",
|
||||
instrumented_perf_test_source="",
|
||||
behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
|
||||
perf_file_path=test_config.tests_root / "test_perf.py"
|
||||
)
|
||||
|
||||
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
|
||||
|
||||
original_test_results = TestResults()
|
||||
optimized_test_results = TestResults()
|
||||
|
||||
# Add test invocations for all test functions
|
||||
original_test_results.add(self.create_test_invocation("test_sync_function", 400_000, iteration_id='0'))
|
||||
original_test_results.add(self.create_test_invocation("test_async_function", 600_000, iteration_id='0'))
|
||||
original_test_results.add(self.create_test_invocation("test_another_sync", 200_000, iteration_id='0'))
|
||||
|
||||
optimized_test_results.add(self.create_test_invocation("test_sync_function", 200_000, iteration_id='0'))
|
||||
optimized_test_results.add(self.create_test_invocation("test_async_function", 300_000, iteration_id='0'))
|
||||
optimized_test_results.add(self.create_test_invocation("test_another_sync", 100_000, iteration_id='0'))
|
||||
|
||||
original_runtimes = original_test_results.usable_runtime_data_by_test_case()
|
||||
optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
|
||||
|
||||
result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
|
||||
|
||||
modified_source = result.generated_tests[0].generated_original_test_source
|
||||
|
||||
assert "# 400μs -> 200μs" in modified_source
|
||||
assert "# 600μs -> 300μs" in modified_source
|
||||
assert "# 200μs -> 100μs" in modified_source
|
||||
|
||||
assert "async def test_async_function():" in modified_source
|
||||
assert "await async_function([4, 5, 6])" in modified_source
|
||||
|
||||
def test_async_complex_await_patterns(self, test_config):
|
||||
os.chdir(test_config.project_root_path)
|
||||
test_source = """async def test_complex_async():
|
||||
# Multiple await calls
|
||||
result1 = await async_func1()
|
||||
codeflash_output = await async_func2(result1)
|
||||
result3 = await async_func3(codeflash_output)
|
||||
assert result3 == expected
|
||||
|
||||
# Await in context manager
|
||||
async with async_context() as ctx:
|
||||
final_result = await ctx.process()
|
||||
assert final_result is not None
|
||||
"""
|
||||
generated_test = GeneratedTests(
|
||||
generated_original_test_source=test_source,
|
||||
instrumented_behavior_test_source="",
|
||||
instrumented_perf_test_source="",
|
||||
behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
|
||||
perf_file_path=test_config.tests_root / "test_perf.py"
|
||||
)
|
||||
|
||||
generated_tests = GeneratedTestsList(generated_tests=[generated_test])
|
||||
|
||||
original_test_results = TestResults()
|
||||
optimized_test_results = TestResults()
|
||||
|
||||
original_test_results.add(self.create_test_invocation("test_complex_async", 750_000, iteration_id='1')) # 750μs
|
||||
optimized_test_results.add(self.create_test_invocation("test_complex_async", 450_000, iteration_id='1')) # 450μs
|
||||
|
||||
original_runtimes = original_test_results.usable_runtime_data_by_test_case()
|
||||
optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
|
||||
|
||||
result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
|
||||
|
||||
modified_source = result.generated_tests[0].generated_original_test_source
|
||||
assert "# 750μs -> 450μs" in modified_source
|
||||
|
|
@ -12,7 +12,7 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
|||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
||||
from codeflash.code_utils.code_extractor import add_global_assignments
|
||||
from codeflash.code_utils.code_extractor import add_global_assignments, GlobalAssignmentCollector
|
||||
|
||||
|
||||
class HelperClass:
|
||||
|
|
@ -1800,9 +1800,10 @@ def get_system_details():
|
|||
|
||||
# Set up the optimizer
|
||||
file_path = main_file_path.resolve()
|
||||
project_root = package_dir.resolve()
|
||||
opt = Optimizer(
|
||||
Namespace(
|
||||
project_root=package_dir.resolve(),
|
||||
project_root=project_root,
|
||||
disable_telemetry=True,
|
||||
tests_root="tests",
|
||||
test_framework="pytest",
|
||||
|
|
@ -1826,8 +1827,10 @@ def get_system_details():
|
|||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
hashing_context = code_ctx.hashing_code_context
|
||||
# The expected contexts
|
||||
# Resolve both paths to handle symlink issues on macOS
|
||||
relative_path = file_path.relative_to(project_root)
|
||||
expected_read_write_context = f"""
|
||||
```python:{main_file_path.resolve().relative_to(opt.args.project_root.resolve())}
|
||||
```python:{relative_path}
|
||||
import utility_module
|
||||
|
||||
class Calculator:
|
||||
|
|
@ -2045,9 +2048,10 @@ def get_system_details():
|
|||
|
||||
# Set up the optimizer
|
||||
file_path = main_file_path.resolve()
|
||||
project_root = package_dir.resolve()
|
||||
opt = Optimizer(
|
||||
Namespace(
|
||||
project_root=package_dir.resolve(),
|
||||
project_root=project_root,
|
||||
disable_telemetry=True,
|
||||
tests_root="tests",
|
||||
test_framework="pytest",
|
||||
|
|
@ -2070,6 +2074,7 @@ def get_system_details():
|
|||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||
# The expected contexts
|
||||
relative_path = file_path.relative_to(project_root)
|
||||
expected_read_write_context = f"""
|
||||
```python:utility_module.py
|
||||
# Function that will be used in the main code
|
||||
|
|
@ -2096,7 +2101,7 @@ def select_precision(precision, fallback_precision):
|
|||
else:
|
||||
return DEFAULT_PRECISION
|
||||
```
|
||||
```python:{main_file_path.resolve().relative_to(opt.args.project_root.resolve())}
|
||||
```python:{relative_path}
|
||||
import utility_module
|
||||
|
||||
class Calculator:
|
||||
|
|
@ -2477,3 +2482,148 @@ def test_circular_deps():
|
|||
assert "import ApiClient" not in new_code, "Error: Circular dependency found"
|
||||
|
||||
assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist"
|
||||
def test_global_assignment_collector_with_async_function():
|
||||
"""Test GlobalAssignmentCollector correctly identifies global assignments outside async functions."""
|
||||
import libcst as cst
|
||||
|
||||
source_code = """
|
||||
# Global assignment
|
||||
GLOBAL_VAR = "global_value"
|
||||
OTHER_GLOBAL = 42
|
||||
|
||||
async def async_function():
|
||||
# This should not be collected (inside async function)
|
||||
local_var = "local_value"
|
||||
INNER_ASSIGNMENT = "should_not_be_global"
|
||||
return local_var
|
||||
|
||||
# Another global assignment
|
||||
ANOTHER_GLOBAL = "another_global"
|
||||
"""
|
||||
|
||||
tree = cst.parse_module(source_code)
|
||||
collector = GlobalAssignmentCollector()
|
||||
tree.visit(collector)
|
||||
|
||||
# Should collect global assignments but not the ones inside async function
|
||||
assert len(collector.assignments) == 3
|
||||
assert "GLOBAL_VAR" in collector.assignments
|
||||
assert "OTHER_GLOBAL" in collector.assignments
|
||||
assert "ANOTHER_GLOBAL" in collector.assignments
|
||||
|
||||
# Should not collect assignments from inside async function
|
||||
assert "local_var" not in collector.assignments
|
||||
assert "INNER_ASSIGNMENT" not in collector.assignments
|
||||
|
||||
# Verify assignment order
|
||||
expected_order = ["GLOBAL_VAR", "OTHER_GLOBAL", "ANOTHER_GLOBAL"]
|
||||
assert collector.assignment_order == expected_order
|
||||
|
||||
|
||||
def test_global_assignment_collector_nested_async_functions():
|
||||
"""Test GlobalAssignmentCollector handles nested async functions correctly."""
|
||||
import libcst as cst
|
||||
|
||||
source_code = """
|
||||
# Global assignment
|
||||
CONFIG = {"key": "value"}
|
||||
|
||||
def sync_function():
|
||||
# Inside sync function - should not be collected
|
||||
sync_local = "sync"
|
||||
|
||||
async def nested_async():
|
||||
# Inside nested async function - should not be collected
|
||||
nested_var = "nested"
|
||||
return nested_var
|
||||
|
||||
return sync_local
|
||||
|
||||
async def async_function():
|
||||
# Inside async function - should not be collected
|
||||
async_local = "async"
|
||||
|
||||
def nested_sync():
|
||||
# Inside nested function - should not be collected
|
||||
deeply_nested = "deep"
|
||||
return deeply_nested
|
||||
|
||||
return async_local
|
||||
|
||||
# Another global assignment
|
||||
FINAL_GLOBAL = "final"
|
||||
"""
|
||||
|
||||
tree = cst.parse_module(source_code)
|
||||
collector = GlobalAssignmentCollector()
|
||||
tree.visit(collector)
|
||||
|
||||
# Should only collect global-level assignments
|
||||
assert len(collector.assignments) == 2
|
||||
assert "CONFIG" in collector.assignments
|
||||
assert "FINAL_GLOBAL" in collector.assignments
|
||||
|
||||
# Should not collect any assignments from inside functions
|
||||
assert "sync_local" not in collector.assignments
|
||||
assert "nested_var" not in collector.assignments
|
||||
assert "async_local" not in collector.assignments
|
||||
assert "deeply_nested" not in collector.assignments
|
||||
|
||||
|
||||
def test_global_assignment_collector_mixed_async_sync_with_classes():
|
||||
"""Test GlobalAssignmentCollector with async functions, sync functions, and classes."""
|
||||
import libcst as cst
|
||||
|
||||
source_code = """
|
||||
# Global assignments
|
||||
GLOBAL_CONSTANT = "constant"
|
||||
|
||||
class TestClass:
|
||||
# Class-level assignment - should not be collected
|
||||
class_var = "class_value"
|
||||
|
||||
def sync_method(self):
|
||||
# Method assignment - should not be collected
|
||||
method_var = "method"
|
||||
return method_var
|
||||
|
||||
async def async_method(self):
|
||||
# Async method assignment - should not be collected
|
||||
async_method_var = "async_method"
|
||||
return async_method_var
|
||||
|
||||
def sync_function():
|
||||
# Function assignment - should not be collected
|
||||
func_var = "function"
|
||||
return func_var
|
||||
|
||||
async def async_function():
|
||||
# Async function assignment - should not be collected
|
||||
async_func_var = "async_function"
|
||||
return async_func_var
|
||||
|
||||
# More global assignments
|
||||
ANOTHER_CONSTANT = 100
|
||||
FINAL_ASSIGNMENT = {"data": "value"}
|
||||
"""
|
||||
|
||||
tree = cst.parse_module(source_code)
|
||||
collector = GlobalAssignmentCollector()
|
||||
tree.visit(collector)
|
||||
|
||||
# Should only collect global-level assignments
|
||||
assert len(collector.assignments) == 3
|
||||
assert "GLOBAL_CONSTANT" in collector.assignments
|
||||
assert "ANOTHER_CONSTANT" in collector.assignments
|
||||
assert "FINAL_ASSIGNMENT" in collector.assignments
|
||||
|
||||
# Should not collect assignments from inside any scoped blocks
|
||||
assert "class_var" not in collector.assignments
|
||||
assert "method_var" not in collector.assignments
|
||||
assert "async_method_var" not in collector.assignments
|
||||
assert "func_var" not in collector.assignments
|
||||
assert "async_func_var" not in collector.assignments
|
||||
|
||||
# Verify correct order
|
||||
expected_order = ["GLOBAL_CONSTANT", "ANOTHER_CONSTANT", "FINAL_ASSIGNMENT"]
|
||||
assert collector.assignment_order == expected_order
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from codeflash.code_utils.code_replacer import (
|
|||
is_zero_diff,
|
||||
replace_functions_and_add_imports,
|
||||
replace_functions_in_file,
|
||||
OptimFunctionCollector,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent
|
||||
|
|
@ -1707,6 +1708,7 @@ print("Hello world")
|
|||
"""
|
||||
expected_code = """import numpy as np
|
||||
|
||||
print("Hello world")
|
||||
a=2
|
||||
print("Hello world")
|
||||
def some_fn():
|
||||
|
|
@ -1782,6 +1784,7 @@ print("Hello world")
|
|||
"""
|
||||
expected_code = """import numpy as np
|
||||
|
||||
print("Hello world")
|
||||
print("Hello world")
|
||||
def some_fn():
|
||||
a=np.zeros(10)
|
||||
|
|
@ -1860,6 +1863,7 @@ print("Hello world")
|
|||
"""
|
||||
expected_code = """import numpy as np
|
||||
|
||||
print("Hello world")
|
||||
a=3
|
||||
print("Hello world")
|
||||
def some_fn():
|
||||
|
|
@ -1937,6 +1941,7 @@ print("Hello world")
|
|||
"""
|
||||
expected_code = """import numpy as np
|
||||
|
||||
print("Hello world")
|
||||
a=2
|
||||
print("Hello world")
|
||||
def some_fn():
|
||||
|
|
@ -2015,6 +2020,7 @@ print("Hello world")
|
|||
"""
|
||||
expected_code = """import numpy as np
|
||||
|
||||
print("Hello world")
|
||||
a=3
|
||||
print("Hello world")
|
||||
def some_fn():
|
||||
|
|
@ -2101,6 +2107,7 @@ print("Hello world")
|
|||
|
||||
a = 6
|
||||
|
||||
print("Hello world")
|
||||
if 2<3:
|
||||
a=4
|
||||
else:
|
||||
|
|
@ -3448,156 +3455,173 @@ def hydrate_input_text_actions_with_field_names(
|
|||
|
||||
assert new_code == expected
|
||||
|
||||
def test_duplicate_global_assignments_when_reverting_helpers():
|
||||
root_dir = Path(__file__).parent.parent.resolve()
|
||||
main_file = Path(root_dir / "code_to_optimize/temp_main.py").resolve()
|
||||
|
||||
original_code = '''"""Chunking objects not specific to a particular chunking strategy."""
|
||||
from __future__ import annotations
|
||||
import collections
|
||||
import copy
|
||||
from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast
|
||||
import regex
|
||||
from typing_extensions import Self, TypeAlias
|
||||
from unstructured.utils import lazyproperty
|
||||
from unstructured.documents.elements import Element
|
||||
# ================================================================================================
|
||||
# MODEL
|
||||
# ================================================================================================
|
||||
CHUNK_MAX_CHARS_DEFAULT: int = 500
|
||||
# ================================================================================================
|
||||
# PRE-CHUNKER
|
||||
# ================================================================================================
|
||||
class PreChunker:
|
||||
"""Gathers sequential elements into pre-chunks as length constraints allow.
|
||||
The pre-chunker's responsibilities are:
|
||||
- **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on
|
||||
either side of those boundaries into different sections. In this case, the primary indicator
|
||||
of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a
|
||||
semantic boundary when `multipage_sections` is `False`.
|
||||
- **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit
|
||||
into sections as big as possible without exceeding the chunk window size.
|
||||
- **Minimize chunks that must be split mid-text.** Precompute the text length of each section
|
||||
and only produce a section that exceeds the chunk window size when there is a single element
|
||||
with text longer than that window.
|
||||
A Table element is placed into a section by itself. CheckBox elements are dropped.
|
||||
The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates
|
||||
a new "section", hence the "by-title" designation.
|
||||
"""
|
||||
def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
|
||||
self._elements = elements
|
||||
self._opts = opts
|
||||
@lazyproperty
|
||||
def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
|
||||
"""The semantic-boundary detectors to be applied to break pre-chunks."""
|
||||
return self._opts.boundary_predicates
|
||||
def _is_in_new_semantic_unit(self, element: Element) -> bool:
|
||||
"""True when `element` begins a new semantic unit such as a section or page."""
|
||||
# -- all detectors need to be called to update state and avoid double counting
|
||||
# -- boundaries that happen to coincide, like Table and new section on same element.
|
||||
# -- Using `any()` would short-circuit on first True.
|
||||
semantic_boundaries = [pred(element) for pred in self._boundary_predicates]
|
||||
return any(semantic_boundaries)
|
||||
'''
|
||||
main_file.write_text(original_code, encoding="utf-8")
|
||||
optim_code = f'''```python:{main_file.relative_to(root_dir)}
|
||||
# ================================================================================================
|
||||
# PRE-CHUNKER
|
||||
# ================================================================================================
|
||||
from __future__ import annotations
|
||||
from typing import Iterable
|
||||
from unstructured.documents.elements import Element
|
||||
from unstructured.utils import lazyproperty
|
||||
class PreChunker:
|
||||
def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
|
||||
self._elements = elements
|
||||
self._opts = opts
|
||||
@lazyproperty
|
||||
def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
|
||||
"""The semantic-boundary detectors to be applied to break pre-chunks."""
|
||||
return self._opts.boundary_predicates
|
||||
def _is_in_new_semantic_unit(self, element: Element) -> bool:
|
||||
"""True when `element` begins a new semantic unit such as a section or page."""
|
||||
# Use generator expression for lower memory usage and avoid building intermediate list
|
||||
for pred in self._boundary_predicates:
|
||||
if pred(element):
|
||||
return True
|
||||
return False
|
||||
```
|
||||
'''
|
||||
# OptimFunctionCollector async function tests
|
||||
def test_optim_function_collector_with_async_functions():
|
||||
"""Test OptimFunctionCollector correctly collects async functions."""
|
||||
import libcst as cst
|
||||
|
||||
source_code = """
|
||||
def sync_function():
|
||||
return "sync"
|
||||
|
||||
func = FunctionToOptimize(function_name="_is_in_new_semantic_unit", parents=[FunctionParent("PreChunker", "ClassDef")], file_path=main_file)
|
||||
test_config = TestConfig(
|
||||
tests_root=root_dir / "tests/pytest",
|
||||
tests_project_rootdir=root_dir,
|
||||
project_root_path=root_dir,
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
async def async_function():
|
||||
return "async"
|
||||
|
||||
class TestClass:
|
||||
def sync_method(self):
|
||||
return "sync_method"
|
||||
|
||||
async def async_method(self):
|
||||
return "async_method"
|
||||
"""
|
||||
|
||||
tree = cst.parse_module(source_code)
|
||||
collector = OptimFunctionCollector(
|
||||
function_names={(None, "sync_function"), (None, "async_function"), ("TestClass", "sync_method"), ("TestClass", "async_method")},
|
||||
preexisting_objects=None
|
||||
)
|
||||
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
|
||||
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
|
||||
tree.visit(collector)
|
||||
|
||||
# Should collect both sync and async functions
|
||||
assert len(collector.modified_functions) == 4
|
||||
assert (None, "sync_function") in collector.modified_functions
|
||||
assert (None, "async_function") in collector.modified_functions
|
||||
assert ("TestClass", "sync_method") in collector.modified_functions
|
||||
assert ("TestClass", "async_method") in collector.modified_functions
|
||||
|
||||
original_helper_code: dict[Path, str] = {}
|
||||
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
|
||||
for helper_function_path in helper_function_paths:
|
||||
with helper_function_path.open(encoding="utf8") as f:
|
||||
helper_code = f.read()
|
||||
original_helper_code[helper_function_path] = helper_code
|
||||
|
||||
func_optimizer.args = Args()
|
||||
func_optimizer.replace_function_and_helpers_with_optimized_code(
|
||||
code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optim_code), original_helper_code=original_helper_code
|
||||
def test_optim_function_collector_new_async_functions():
|
||||
"""Test OptimFunctionCollector identifies new async functions not in preexisting objects."""
|
||||
import libcst as cst
|
||||
|
||||
source_code = """
|
||||
def existing_function():
|
||||
return "existing"
|
||||
|
||||
async def new_async_function():
|
||||
return "new_async"
|
||||
|
||||
def new_sync_function():
|
||||
return "new_sync"
|
||||
|
||||
class ExistingClass:
|
||||
async def new_class_async_method(self):
|
||||
return "new_class_async"
|
||||
"""
|
||||
|
||||
# Only existing_function is in preexisting objects
|
||||
preexisting_objects = {("existing_function", ())}
|
||||
|
||||
tree = cst.parse_module(source_code)
|
||||
collector = OptimFunctionCollector(
|
||||
function_names=set(), # Not looking for specific functions
|
||||
preexisting_objects=preexisting_objects
|
||||
)
|
||||
tree.visit(collector)
|
||||
|
||||
# Should identify new functions (both sync and async)
|
||||
assert len(collector.new_functions) == 2
|
||||
function_names = [func.name.value for func in collector.new_functions]
|
||||
assert "new_async_function" in function_names
|
||||
assert "new_sync_function" in function_names
|
||||
|
||||
# Should identify new class methods
|
||||
assert "ExistingClass" in collector.new_class_functions
|
||||
assert len(collector.new_class_functions["ExistingClass"]) == 1
|
||||
assert collector.new_class_functions["ExistingClass"][0].name.value == "new_class_async_method"
|
||||
|
||||
|
||||
new_code = main_file.read_text(encoding="utf-8")
|
||||
main_file.unlink(missing_ok=True)
|
||||
def test_optim_function_collector_mixed_scenarios():
|
||||
"""Test OptimFunctionCollector with complex mix of sync/async functions and classes."""
|
||||
import libcst as cst
|
||||
|
||||
source_code = """
|
||||
# Global functions
|
||||
def global_sync():
|
||||
pass
|
||||
|
||||
expected = '''"""Chunking objects not specific to a particular chunking strategy."""
|
||||
from __future__ import annotations
|
||||
import collections
|
||||
import copy
|
||||
from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast
|
||||
import regex
|
||||
from typing_extensions import Self, TypeAlias
|
||||
from unstructured.utils import lazyproperty
|
||||
from unstructured.documents.elements import Element
|
||||
# ================================================================================================
|
||||
# MODEL
|
||||
# ================================================================================================
|
||||
CHUNK_MAX_CHARS_DEFAULT: int = 500
|
||||
# ================================================================================================
|
||||
# PRE-CHUNKER
|
||||
# ================================================================================================
|
||||
class PreChunker:
|
||||
"""Gathers sequential elements into pre-chunks as length constraints allow.
|
||||
The pre-chunker's responsibilities are:
|
||||
- **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on
|
||||
either side of those boundaries into different sections. In this case, the primary indicator
|
||||
of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a
|
||||
semantic boundary when `multipage_sections` is `False`.
|
||||
- **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit
|
||||
into sections as big as possible without exceeding the chunk window size.
|
||||
- **Minimize chunks that must be split mid-text.** Precompute the text length of each section
|
||||
and only produce a section that exceeds the chunk window size when there is a single element
|
||||
with text longer than that window.
|
||||
A Table element is placed into a section by itself. CheckBox elements are dropped.
|
||||
The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates
|
||||
a new "section", hence the "by-title" designation.
|
||||
"""
|
||||
def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
|
||||
self._elements = elements
|
||||
self._opts = opts
|
||||
@lazyproperty
|
||||
def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
|
||||
"""The semantic-boundary detectors to be applied to break pre-chunks."""
|
||||
return self._opts.boundary_predicates
|
||||
def _is_in_new_semantic_unit(self, element: Element) -> bool:
|
||||
"""True when `element` begins a new semantic unit such as a section or page."""
|
||||
# Use generator expression for lower memory usage and avoid building intermediate list
|
||||
for pred in self._boundary_predicates:
|
||||
if pred(element):
|
||||
return True
|
||||
return False
|
||||
async def global_async():
|
||||
pass
|
||||
|
||||
class ParentClass:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def sync_method(self):
|
||||
pass
|
||||
|
||||
async def async_method(self):
|
||||
pass
|
||||
|
||||
class ChildClass:
|
||||
async def child_async_method(self):
|
||||
pass
|
||||
|
||||
def child_sync_method(self):
|
||||
pass
|
||||
"""
|
||||
|
||||
# Looking for specific functions
|
||||
function_names = {
|
||||
(None, "global_sync"),
|
||||
(None, "global_async"),
|
||||
("ParentClass", "sync_method"),
|
||||
("ParentClass", "async_method"),
|
||||
("ChildClass", "child_async_method")
|
||||
}
|
||||
|
||||
tree = cst.parse_module(source_code)
|
||||
collector = OptimFunctionCollector(
|
||||
function_names=function_names,
|
||||
preexisting_objects=None
|
||||
)
|
||||
tree.visit(collector)
|
||||
|
||||
# Should collect all specified functions (mix of sync and async)
|
||||
assert len(collector.modified_functions) == 5
|
||||
assert (None, "global_sync") in collector.modified_functions
|
||||
assert (None, "global_async") in collector.modified_functions
|
||||
assert ("ParentClass", "sync_method") in collector.modified_functions
|
||||
assert ("ParentClass", "async_method") in collector.modified_functions
|
||||
assert ("ChildClass", "child_async_method") in collector.modified_functions
|
||||
|
||||
# Should collect __init__ method
|
||||
assert "ParentClass" in collector.modified_init_functions
|
||||
|
||||
|
||||
|
||||
def test_is_zero_diff_async_sleep():
|
||||
original_code = '''
|
||||
import time
|
||||
|
||||
async def task():
|
||||
time.sleep(1)
|
||||
return "done"
|
||||
'''
|
||||
assert new_code == expected
|
||||
optimized_code = '''
|
||||
import asyncio
|
||||
|
||||
async def task():
|
||||
await asyncio.sleep(1)
|
||||
return "done"
|
||||
'''
|
||||
assert not is_zero_diff(original_code, optimized_code)
|
||||
|
||||
def test_is_zero_diff_with_equivalent_code():
|
||||
original_code = '''
|
||||
import asyncio
|
||||
|
||||
async def task():
|
||||
await asyncio.sleep(1)
|
||||
return "done"
|
||||
'''
|
||||
optimized_code = '''
|
||||
import asyncio
|
||||
|
||||
async def task():
|
||||
"""A task that does something."""
|
||||
await asyncio.sleep(1)
|
||||
return "done"
|
||||
'''
|
||||
assert is_zero_diff(original_code, optimized_code)
|
||||
|
|
@ -17,10 +17,10 @@ from codeflash.code_utils.code_utils import (
|
|||
is_class_defined_in_file,
|
||||
module_name_from_file_path,
|
||||
path_belongs_to_site_packages,
|
||||
has_any_async_functions,
|
||||
validate_python_code,
|
||||
)
|
||||
from codeflash.code_utils.concolic_utils import clean_concolic_tests
|
||||
from codeflash.code_utils.coverage_utils import generate_candidates, prepare_coverage_files
|
||||
from codeflash.code_utils.coverage_utils import extract_dependent_function, generate_candidates, prepare_coverage_files
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -308,6 +308,86 @@ def my_function():
|
|||
assert is_class_defined_in_file("MyClass", test_file) is False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_code_context():
|
||||
"""Mock CodeOptimizationContext for testing extract_dependent_function."""
|
||||
from unittest.mock import MagicMock
|
||||
from codeflash.models.models import CodeOptimizationContext
|
||||
|
||||
context = MagicMock(spec=CodeOptimizationContext)
|
||||
context.preexisting_objects = []
|
||||
return context
|
||||
|
||||
|
||||
def test_extract_dependent_function_sync_and_async(mock_code_context):
|
||||
"""Test extract_dependent_function with both sync and async functions."""
|
||||
# Test sync function extraction
|
||||
mock_code_context.testgen_context_code = """
|
||||
def main_function():
|
||||
pass
|
||||
|
||||
def helper_function():
|
||||
pass
|
||||
"""
|
||||
assert extract_dependent_function("main_function", mock_code_context) == "helper_function"
|
||||
|
||||
# Test async function extraction
|
||||
mock_code_context.testgen_context_code = """
|
||||
def main_function():
|
||||
pass
|
||||
|
||||
async def async_helper_function():
|
||||
pass
|
||||
"""
|
||||
assert extract_dependent_function("main_function", mock_code_context) == "async_helper_function"
|
||||
|
||||
|
||||
def test_extract_dependent_function_edge_cases(mock_code_context):
|
||||
"""Test extract_dependent_function edge cases."""
|
||||
# No dependent functions
|
||||
mock_code_context.testgen_context_code = """
|
||||
def main_function():
|
||||
pass
|
||||
"""
|
||||
assert extract_dependent_function("main_function", mock_code_context) is False
|
||||
|
||||
# Multiple dependent functions
|
||||
mock_code_context.testgen_context_code = """
|
||||
def main_function():
|
||||
pass
|
||||
|
||||
def helper1():
|
||||
pass
|
||||
|
||||
async def helper2():
|
||||
pass
|
||||
"""
|
||||
assert extract_dependent_function("main_function", mock_code_context) is False
|
||||
|
||||
|
||||
def test_extract_dependent_function_mixed_scenarios(mock_code_context):
|
||||
"""Test extract_dependent_function with mixed sync/async scenarios."""
|
||||
# Async main with sync helper
|
||||
mock_code_context.testgen_context_code = """
|
||||
async def async_main():
|
||||
pass
|
||||
|
||||
def sync_helper():
|
||||
pass
|
||||
"""
|
||||
assert extract_dependent_function("async_main", mock_code_context) == "sync_helper"
|
||||
|
||||
# Only async functions
|
||||
mock_code_context.testgen_context_code = """
|
||||
async def async_main():
|
||||
pass
|
||||
|
||||
async def async_helper():
|
||||
pass
|
||||
"""
|
||||
assert extract_dependent_function("async_main", mock_code_context) == "async_helper"
|
||||
|
||||
|
||||
def test_is_class_defined_in_file_with_non_existing_file() -> None:
|
||||
non_existing_file = Path("/non/existing/file.py")
|
||||
|
||||
|
|
@ -445,25 +525,41 @@ def test_Grammar_copy():
|
|||
assert cleaned_code == expected_cleaned_code.strip()
|
||||
|
||||
|
||||
def test_has_any_async_functions_with_async_code() -> None:
|
||||
def test_validate_python_code_valid() -> None:
|
||||
code = "def hello():\n return 'world'"
|
||||
result = validate_python_code(code)
|
||||
assert result == code
|
||||
|
||||
|
||||
def test_validate_python_code_invalid() -> None:
|
||||
code = "def hello(:\n return 'world'"
|
||||
with pytest.raises(ValueError, match="Invalid Python code"):
|
||||
validate_python_code(code)
|
||||
|
||||
|
||||
def test_validate_python_code_empty() -> None:
|
||||
code = ""
|
||||
result = validate_python_code(code)
|
||||
assert result == code
|
||||
|
||||
|
||||
def test_validate_python_code_complex_invalid() -> None:
|
||||
code = "if True\n print('missing colon')"
|
||||
with pytest.raises(ValueError, match="Invalid Python code.*line 1.*column 8"):
|
||||
validate_python_code(code)
|
||||
|
||||
|
||||
def test_validate_python_code_valid_complex() -> None:
|
||||
code = """
|
||||
def normal_function():
|
||||
pass
|
||||
|
||||
async def async_function():
|
||||
pass
|
||||
def calculate(a, b):
|
||||
if a > b:
|
||||
return a + b
|
||||
else:
|
||||
return a * b
|
||||
|
||||
class MyClass:
|
||||
def __init__(self):
|
||||
self.value = 42
|
||||
"""
|
||||
result = has_any_async_functions(code)
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_has_any_async_functions_without_async_code() -> None:
|
||||
code = """
|
||||
def normal_function():
|
||||
pass
|
||||
|
||||
def another_function():
|
||||
pass
|
||||
"""
|
||||
result = has_any_async_functions(code)
|
||||
assert result is False
|
||||
result = validate_python_code(code)
|
||||
assert result == code
|
||||
|
|
|
|||
Loading…
Reference in a new issue