Add a few missing things

This commit is contained in:
Kevin Turcios 2025-09-26 16:25:28 -07:00
parent 6bd39f7791
commit c6c9d9559f
5 changed files with 671 additions and 171 deletions

View file

@ -40,6 +40,30 @@ matches_re_start = re.compile(r"!\$######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)
matches_re_end = re.compile(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!")
# Tag formats emitted by the performance-instrumentation wrappers:
#   start: !$######module:test:function:loop_index:iteration_id######$!
#   end:   !######module:test:function:loop_index:iteration_id:duration######!
start_pattern = re.compile(r"!\$######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+)######\$!")
end_pattern = re.compile(r"!######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+):([^:]+)######!")


def calculate_function_throughput_from_test_results(test_results: TestResults, function_name: str) -> int:
    """Count completed executions of *function_name* in the performance stdout.

    A completed execution is a start tag whose five fields
    (module, test, function, loop_index, iteration_id) also appear in a
    matching end tag; the end tag carries a sixth ``duration`` field that
    is ignored for matching purposes.

    Start: !$######test_module:test_function:function_name:loop_index:iteration_id######$!
    End:   !######test_module:test_function:function_name:loop_index:iteration_id:duration######!
    """
    # Hoist the stdout lookup so both scans read the same snapshot.
    perf_stdout = test_results.perf_stdout or ""
    start_matches = start_pattern.findall(perf_stdout)
    # Drop the trailing duration field so end tags compare equal to start tags.
    completed = {end_match[:5] for end_match in end_pattern.findall(perf_stdout)}
    return sum(
        1
        for start_match in start_matches
        if start_match in completed and len(start_match) > 2 and start_match[2] == function_name
    )
def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults:
test_results = TestResults()
if not file_location.exists():

View file

@ -1903,3 +1903,209 @@ def test_bubble_sort(input, expected_output):
# Check that comments were added
modified_source = result.generated_tests[0].generated_original_test_source
assert modified_source == expected
def test_async_basic_runtime_comment_addition(self, test_config):
    """Test basic functionality of adding runtime comments to async test functions."""
    # Paths inside the generated tests are resolved relative to the project root.
    os.chdir(test_config.project_root_path)
    # Generated async test whose awaited call should receive a timing comment.
    test_source = """async def test_async_bubble_sort():
    codeflash_output = await async_bubble_sort([3, 1, 2])
    assert codeflash_output == [1, 2, 3]
"""
    generated_test = GeneratedTests(
        generated_original_test_source=test_source,
        instrumented_behavior_test_source="",
        instrumented_perf_test_source="",
        behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
        perf_file_path=test_config.tests_root / "test_perf.py",
    )
    generated_tests = GeneratedTestsList(generated_tests=[generated_test])
    original_test_results = TestResults()
    optimized_test_results = TestResults()
    # One invocation per side: original is slower than optimized.
    original_invocation = self.create_test_invocation("test_async_bubble_sort", 500_000, iteration_id='0')  # 500μs
    optimized_invocation = self.create_test_invocation("test_async_bubble_sort", 300_000, iteration_id='0')  # 300μs
    original_test_results.add(original_invocation)
    optimized_test_results.add(optimized_invocation)
    original_runtimes = original_test_results.usable_runtime_data_by_test_case()
    optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
    result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
    modified_source = result.generated_tests[0].generated_original_test_source
    # The awaited call line is annotated with "original -> optimized" runtimes.
    assert "# 500μs -> 300μs" in modified_source
    assert "codeflash_output = await async_bubble_sort([3, 1, 2]) # 500μs -> 300μs" in modified_source
def test_async_multiple_test_functions(self, test_config):
    """Runtime comments are added to every async test function, but not to non-test helpers."""
    os.chdir(test_config.project_root_path)
    # Two async tests plus a plain helper that must remain unannotated.
    test_source = """async def test_async_bubble_sort():
    codeflash_output = await async_quick_sort([3, 1, 2])
    assert codeflash_output == [1, 2, 3]
async def test_async_quick_sort():
    codeflash_output = await async_quick_sort([5, 2, 8])
    assert codeflash_output == [2, 5, 8]
def helper_function():
    return "not a test"
"""
    generated_test = GeneratedTests(
        generated_original_test_source=test_source,
        instrumented_behavior_test_source="",
        instrumented_perf_test_source="",
        behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
        perf_file_path=test_config.tests_root / "test_perf.py"
    )
    generated_tests = GeneratedTestsList(generated_tests=[generated_test])
    original_test_results = TestResults()
    optimized_test_results = TestResults()
    # Separate original/optimized timings for each async test.
    original_test_results.add(self.create_test_invocation("test_async_bubble_sort", 500_000, iteration_id='0'))
    original_test_results.add(self.create_test_invocation("test_async_quick_sort", 800_000, iteration_id='0'))
    optimized_test_results.add(self.create_test_invocation("test_async_bubble_sort", 300_000, iteration_id='0'))
    optimized_test_results.add(self.create_test_invocation("test_async_quick_sort", 600_000, iteration_id='0'))
    original_runtimes = original_test_results.usable_runtime_data_by_test_case()
    optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
    result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
    modified_source = result.generated_tests[0].generated_original_test_source
    # Each test function gets its own timing annotation.
    assert "# 500μs -> 300μs" in modified_source
    assert "# 800μs -> 600μs" in modified_source
    # The non-test helper's definition line must carry no comment.
    assert (
        "helper_function():" in modified_source
        and "# " not in modified_source.split("helper_function():")[1].split("\n")[0]
    )
def test_async_class_method(self, test_config):
    """Runtime comments are added to async test methods defined inside a test class."""
    os.chdir(test_config.project_root_path)
    test_source = '''class TestAsyncClass:
    async def test_async_function(self):
        codeflash_output = await some_async_function()
        assert codeflash_output == expected
'''
    generated_test = GeneratedTests(
        generated_original_test_source=test_source,
        instrumented_behavior_test_source="",
        instrumented_perf_test_source="",
        behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
        perf_file_path=test_config.tests_root / "test_perf.py"
    )
    generated_tests = GeneratedTestsList(generated_tests=[generated_test])
    # Build the invocation id by hand so the class name is part of the lookup key.
    invocation_id = InvocationId(
        test_module_path="tests.test_module__unit_test_0",
        test_class_name="TestAsyncClass",
        test_function_name="test_async_function",
        function_getting_tested="some_async_function",
        iteration_id="0",
    )
    original_runtimes = {invocation_id: [2000000000]}  # 2s in nanoseconds
    optimized_runtimes = {invocation_id: [1000000000]}  # 1s in nanoseconds
    result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
    # The awaited line inside the class method gains the runtime/speedup comment.
    expected_source = '''class TestAsyncClass:
    async def test_async_function(self):
        codeflash_output = await some_async_function() # 2.00s -> 1.00s (100% faster)
        assert codeflash_output == expected
'''
    assert len(result.generated_tests) == 1
    assert result.generated_tests[0].generated_original_test_source == expected_source
def test_async_mixed_sync_and_async_functions(self, test_config):
    """Runtime comments are added to sync and async test functions alike in one file."""
    os.chdir(test_config.project_root_path)
    test_source = """def test_sync_function():
    codeflash_output = sync_function([1, 2, 3])
    assert codeflash_output == [1, 2, 3]
async def test_async_function():
    codeflash_output = await async_function([4, 5, 6])
    assert codeflash_output == [4, 5, 6]
def test_another_sync():
    result = another_sync_func()
    assert result is True
"""
    generated_test = GeneratedTests(
        generated_original_test_source=test_source,
        instrumented_behavior_test_source="",
        instrumented_perf_test_source="",
        behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
        perf_file_path=test_config.tests_root / "test_perf.py"
    )
    generated_tests = GeneratedTestsList(generated_tests=[generated_test])
    original_test_results = TestResults()
    optimized_test_results = TestResults()
    # Add test invocations for all test functions
    original_test_results.add(self.create_test_invocation("test_sync_function", 400_000, iteration_id='0'))
    original_test_results.add(self.create_test_invocation("test_async_function", 600_000, iteration_id='0'))
    original_test_results.add(self.create_test_invocation("test_another_sync", 200_000, iteration_id='0'))
    optimized_test_results.add(self.create_test_invocation("test_sync_function", 200_000, iteration_id='0'))
    optimized_test_results.add(self.create_test_invocation("test_async_function", 300_000, iteration_id='0'))
    optimized_test_results.add(self.create_test_invocation("test_another_sync", 100_000, iteration_id='0'))
    original_runtimes = original_test_results.usable_runtime_data_by_test_case()
    optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
    result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
    modified_source = result.generated_tests[0].generated_original_test_source
    # Every test function — sync or async — gets its own timing annotation.
    assert "# 400μs -> 200μs" in modified_source
    assert "# 600μs -> 300μs" in modified_source
    assert "# 200μs -> 100μs" in modified_source
    # The async test's structure is preserved by the rewrite.
    assert "async def test_async_function():" in modified_source
    assert "await async_function([4, 5, 6])" in modified_source
def test_async_complex_await_patterns(self, test_config):
    """Comments are still added when a test mixes several awaits and an async context manager."""
    os.chdir(test_config.project_root_path)
    test_source = """async def test_complex_async():
    # Multiple await calls
    result1 = await async_func1()
    codeflash_output = await async_func2(result1)
    result3 = await async_func3(codeflash_output)
    assert result3 == expected
    # Await in context manager
    async with async_context() as ctx:
        final_result = await ctx.process()
        assert final_result is not None
"""
    generated_test = GeneratedTests(
        generated_original_test_source=test_source,
        instrumented_behavior_test_source="",
        instrumented_perf_test_source="",
        behavior_file_path=test_config.tests_root / "test_module__unit_test_0.py",
        perf_file_path=test_config.tests_root / "test_perf.py"
    )
    generated_tests = GeneratedTestsList(generated_tests=[generated_test])
    original_test_results = TestResults()
    optimized_test_results = TestResults()
    # Note the non-zero iteration id: annotation must not depend on iteration 0.
    original_test_results.add(self.create_test_invocation("test_complex_async", 750_000, iteration_id='1'))  # 750μs
    optimized_test_results.add(self.create_test_invocation("test_complex_async", 450_000, iteration_id='1'))  # 450μs
    original_runtimes = original_test_results.usable_runtime_data_by_test_case()
    optimized_runtimes = optimized_test_results.usable_runtime_data_by_test_case()
    result = add_runtime_comments_to_generated_tests(generated_tests, original_runtimes, optimized_runtimes)
    modified_source = result.generated_tests[0].generated_original_test_source
    # The timing annotation lands somewhere in the rewritten test body.
    assert "# 750μs -> 450μs" in modified_source

View file

@ -12,7 +12,7 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent
from codeflash.optimization.optimizer import Optimizer
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
from codeflash.code_utils.code_extractor import add_global_assignments
from codeflash.code_utils.code_extractor import add_global_assignments, GlobalAssignmentCollector
class HelperClass:
@ -1800,9 +1800,10 @@ def get_system_details():
# Set up the optimizer
file_path = main_file_path.resolve()
project_root = package_dir.resolve()
opt = Optimizer(
Namespace(
project_root=package_dir.resolve(),
project_root=project_root,
disable_telemetry=True,
tests_root="tests",
test_framework="pytest",
@ -1826,8 +1827,10 @@ def get_system_details():
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
# The expected contexts
# Resolve both paths to handle symlink issues on macOS
relative_path = file_path.relative_to(project_root)
expected_read_write_context = f"""
```python:{main_file_path.resolve().relative_to(opt.args.project_root.resolve())}
```python:{relative_path}
import utility_module
class Calculator:
@ -2045,9 +2048,10 @@ def get_system_details():
# Set up the optimizer
file_path = main_file_path.resolve()
project_root = package_dir.resolve()
opt = Optimizer(
Namespace(
project_root=package_dir.resolve(),
project_root=project_root,
disable_telemetry=True,
tests_root="tests",
test_framework="pytest",
@ -2070,6 +2074,7 @@ def get_system_details():
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
# The expected contexts
relative_path = file_path.relative_to(project_root)
expected_read_write_context = f"""
```python:utility_module.py
# Function that will be used in the main code
@ -2096,7 +2101,7 @@ def select_precision(precision, fallback_precision):
else:
return DEFAULT_PRECISION
```
```python:{main_file_path.resolve().relative_to(opt.args.project_root.resolve())}
```python:{relative_path}
import utility_module
class Calculator:
@ -2477,3 +2482,148 @@ def test_circular_deps():
assert "import ApiClient" not in new_code, "Error: Circular dependency found"
assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist"
def test_global_assignment_collector_with_async_function():
    """GlobalAssignmentCollector records module-level assignments only,
    skipping every name bound inside an async function body."""
    import libcst as cst

    source_code = """
# Global assignment
GLOBAL_VAR = "global_value"
OTHER_GLOBAL = 42
async def async_function():
    # This should not be collected (inside async function)
    local_var = "local_value"
    INNER_ASSIGNMENT = "should_not_be_global"
    return local_var
# Another global assignment
ANOTHER_GLOBAL = "another_global"
"""
    visitor = GlobalAssignmentCollector()
    cst.parse_module(source_code).visit(visitor)
    collected = visitor.assignments
    # Exactly the three module-level assignments are collected...
    assert len(collected) == 3
    for global_name in ("GLOBAL_VAR", "OTHER_GLOBAL", "ANOTHER_GLOBAL"):
        assert global_name in collected
    # ...while names bound inside the async function are not.
    for inner_name in ("local_var", "INNER_ASSIGNMENT"):
        assert inner_name not in collected
    # Collection preserves source order.
    assert visitor.assignment_order == ["GLOBAL_VAR", "OTHER_GLOBAL", "ANOTHER_GLOBAL"]
def test_global_assignment_collector_nested_async_functions():
    """Nested sync/async function bodies are all excluded by GlobalAssignmentCollector."""
    import libcst as cst

    source_code = """
# Global assignment
CONFIG = {"key": "value"}
def sync_function():
    # Inside sync function - should not be collected
    sync_local = "sync"
    async def nested_async():
        # Inside nested async function - should not be collected
        nested_var = "nested"
        return nested_var
    return sync_local
async def async_function():
    # Inside async function - should not be collected
    async_local = "async"
    def nested_sync():
        # Inside nested function - should not be collected
        deeply_nested = "deep"
        return deeply_nested
    return async_local
# Another global assignment
FINAL_GLOBAL = "final"
"""
    visitor = GlobalAssignmentCollector()
    cst.parse_module(source_code).visit(visitor)
    # Only the two module-level assignments survive.
    assert len(visitor.assignments) == 2
    for global_name in ("CONFIG", "FINAL_GLOBAL"):
        assert global_name in visitor.assignments
    # Nothing bound inside any function body — at any nesting depth — is recorded.
    for scoped_name in ("sync_local", "nested_var", "async_local", "deeply_nested"):
        assert scoped_name not in visitor.assignments
def test_global_assignment_collector_mixed_async_sync_with_classes():
    """Class bodies, methods, and functions — sync or async — are all excluded."""
    import libcst as cst

    source_code = """
# Global assignments
GLOBAL_CONSTANT = "constant"
class TestClass:
    # Class-level assignment - should not be collected
    class_var = "class_value"
    def sync_method(self):
        # Method assignment - should not be collected
        method_var = "method"
        return method_var
    async def async_method(self):
        # Async method assignment - should not be collected
        async_method_var = "async_method"
        return async_method_var
def sync_function():
    # Function assignment - should not be collected
    func_var = "function"
    return func_var
async def async_function():
    # Async function assignment - should not be collected
    async_func_var = "async_function"
    return async_func_var
# More global assignments
ANOTHER_CONSTANT = 100
FINAL_ASSIGNMENT = {"data": "value"}
"""
    visitor = GlobalAssignmentCollector()
    cst.parse_module(source_code).visit(visitor)
    expected_globals = ["GLOBAL_CONSTANT", "ANOTHER_CONSTANT", "FINAL_ASSIGNMENT"]
    # Exactly the three module-level assignments, in source order.
    assert len(visitor.assignments) == 3
    for global_name in expected_globals:
        assert global_name in visitor.assignments
    assert visitor.assignment_order == expected_globals
    # No class-, method-, or function-scoped bindings leak through.
    for scoped_name in ("class_var", "method_var", "async_method_var", "func_var", "async_func_var"):
        assert scoped_name not in visitor.assignments

View file

@ -12,6 +12,7 @@ from codeflash.code_utils.code_replacer import (
is_zero_diff,
replace_functions_and_add_imports,
replace_functions_in_file,
OptimFunctionCollector,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent
@ -1707,6 +1708,7 @@ print("Hello world")
"""
expected_code = """import numpy as np
print("Hello world")
a=2
print("Hello world")
def some_fn():
@ -1782,6 +1784,7 @@ print("Hello world")
"""
expected_code = """import numpy as np
print("Hello world")
print("Hello world")
def some_fn():
a=np.zeros(10)
@ -1860,6 +1863,7 @@ print("Hello world")
"""
expected_code = """import numpy as np
print("Hello world")
a=3
print("Hello world")
def some_fn():
@ -1937,6 +1941,7 @@ print("Hello world")
"""
expected_code = """import numpy as np
print("Hello world")
a=2
print("Hello world")
def some_fn():
@ -2015,6 +2020,7 @@ print("Hello world")
"""
expected_code = """import numpy as np
print("Hello world")
a=3
print("Hello world")
def some_fn():
@ -2101,6 +2107,7 @@ print("Hello world")
a = 6
print("Hello world")
if 2<3:
a=4
else:
@ -3448,156 +3455,173 @@ def hydrate_input_text_actions_with_field_names(
assert new_code == expected
def test_duplicate_global_assignments_when_reverting_helpers():
root_dir = Path(__file__).parent.parent.resolve()
main_file = Path(root_dir / "code_to_optimize/temp_main.py").resolve()
original_code = '''"""Chunking objects not specific to a particular chunking strategy."""
from __future__ import annotations
import collections
import copy
from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast
import regex
from typing_extensions import Self, TypeAlias
from unstructured.utils import lazyproperty
from unstructured.documents.elements import Element
# ================================================================================================
# MODEL
# ================================================================================================
CHUNK_MAX_CHARS_DEFAULT: int = 500
# ================================================================================================
# PRE-CHUNKER
# ================================================================================================
class PreChunker:
"""Gathers sequential elements into pre-chunks as length constraints allow.
The pre-chunker's responsibilities are:
- **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on
either side of those boundaries into different sections. In this case, the primary indicator
of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a
semantic boundary when `multipage_sections` is `False`.
- **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit
into sections as big as possible without exceeding the chunk window size.
- **Minimize chunks that must be split mid-text.** Precompute the text length of each section
and only produce a section that exceeds the chunk window size when there is a single element
with text longer than that window.
A Table element is placed into a section by itself. CheckBox elements are dropped.
The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates
a new "section", hence the "by-title" designation.
"""
def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
self._elements = elements
self._opts = opts
@lazyproperty
def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
"""The semantic-boundary detectors to be applied to break pre-chunks."""
return self._opts.boundary_predicates
def _is_in_new_semantic_unit(self, element: Element) -> bool:
"""True when `element` begins a new semantic unit such as a section or page."""
# -- all detectors need to be called to update state and avoid double counting
# -- boundaries that happen to coincide, like Table and new section on same element.
# -- Using `any()` would short-circuit on first True.
semantic_boundaries = [pred(element) for pred in self._boundary_predicates]
return any(semantic_boundaries)
'''
main_file.write_text(original_code, encoding="utf-8")
optim_code = f'''```python:{main_file.relative_to(root_dir)}
# ================================================================================================
# PRE-CHUNKER
# ================================================================================================
from __future__ import annotations
from typing import Iterable
from unstructured.documents.elements import Element
from unstructured.utils import lazyproperty
class PreChunker:
def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
self._elements = elements
self._opts = opts
@lazyproperty
def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
"""The semantic-boundary detectors to be applied to break pre-chunks."""
return self._opts.boundary_predicates
def _is_in_new_semantic_unit(self, element: Element) -> bool:
"""True when `element` begins a new semantic unit such as a section or page."""
# Use generator expression for lower memory usage and avoid building intermediate list
for pred in self._boundary_predicates:
if pred(element):
return True
return False
```
'''
# OptimFunctionCollector async function tests
def test_optim_function_collector_with_async_functions():
"""Test OptimFunctionCollector correctly collects async functions."""
import libcst as cst
func = FunctionToOptimize(function_name="_is_in_new_semantic_unit", parents=[FunctionParent("PreChunker", "ClassDef")], file_path=main_file)
test_config = TestConfig(
tests_root=root_dir / "tests/pytest",
tests_project_rootdir=root_dir,
project_root_path=root_dir,
test_framework="pytest",
pytest_cmd="pytest",
source_code = """
def sync_function():
return "sync"
async def async_function():
return "async"
class TestClass:
def sync_method(self):
return "sync_method"
async def async_method(self):
return "async_method"
"""
tree = cst.parse_module(source_code)
collector = OptimFunctionCollector(
function_names={(None, "sync_function"), (None, "async_function"), ("TestClass", "sync_method"), ("TestClass", "async_method")},
preexisting_objects=None
)
func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config)
code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap()
tree.visit(collector)
original_helper_code: dict[Path, str] = {}
helper_function_paths = {hf.file_path for hf in code_context.helper_functions}
for helper_function_path in helper_function_paths:
with helper_function_path.open(encoding="utf8") as f:
helper_code = f.read()
original_helper_code[helper_function_path] = helper_code
# Should collect both sync and async functions
assert len(collector.modified_functions) == 4
assert (None, "sync_function") in collector.modified_functions
assert (None, "async_function") in collector.modified_functions
assert ("TestClass", "sync_method") in collector.modified_functions
assert ("TestClass", "async_method") in collector.modified_functions
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optim_code), original_helper_code=original_helper_code
def test_optim_function_collector_new_async_functions():
    """OptimFunctionCollector flags functions absent from preexisting_objects as new."""
    import libcst as cst

    source_code = """
def existing_function():
    return "existing"
async def new_async_function():
    return "new_async"
def new_sync_function():
    return "new_sync"
class ExistingClass:
    async def new_class_async_method(self):
        return "new_class_async"
"""
    # Only existing_function is in preexisting objects
    preexisting_objects = {("existing_function", ())}
    visitor = OptimFunctionCollector(
        function_names=set(),  # Not looking for specific functions
        preexisting_objects=preexisting_objects,
    )
    cst.parse_module(source_code).visit(visitor)
    # Both the new sync and the new async top-level functions are reported.
    assert len(visitor.new_functions) == 2
    reported_names = {fn.name.value for fn in visitor.new_functions}
    assert "new_async_function" in reported_names
    assert "new_sync_function" in reported_names
    # The new async method is attributed to its enclosing class.
    assert "ExistingClass" in visitor.new_class_functions
    assert len(visitor.new_class_functions["ExistingClass"]) == 1
    assert visitor.new_class_functions["ExistingClass"][0].name.value == "new_class_async_method"
new_code = main_file.read_text(encoding="utf-8")
main_file.unlink(missing_ok=True)
def test_optim_function_collector_mixed_scenarios():
"""Test OptimFunctionCollector with complex mix of sync/async functions and classes."""
import libcst as cst
expected = '''"""Chunking objects not specific to a particular chunking strategy."""
from __future__ import annotations
import collections
import copy
from typing import Any, Callable, DefaultDict, Iterable, Iterator, cast
import regex
from typing_extensions import Self, TypeAlias
from unstructured.utils import lazyproperty
from unstructured.documents.elements import Element
# ================================================================================================
# MODEL
# ================================================================================================
CHUNK_MAX_CHARS_DEFAULT: int = 500
# ================================================================================================
# PRE-CHUNKER
# ================================================================================================
class PreChunker:
"""Gathers sequential elements into pre-chunks as length constraints allow.
The pre-chunker's responsibilities are:
- **Segregate semantic units.** Identify semantic unit boundaries and segregate elements on
either side of those boundaries into different sections. In this case, the primary indicator
of a semantic boundary is a `Title` element. A page-break (change in page-number) is also a
semantic boundary when `multipage_sections` is `False`.
- **Minimize chunk count for each semantic unit.** Group the elements within a semantic unit
into sections as big as possible without exceeding the chunk window size.
- **Minimize chunks that must be split mid-text.** Precompute the text length of each section
and only produce a section that exceeds the chunk window size when there is a single element
with text longer than that window.
A Table element is placed into a section by itself. CheckBox elements are dropped.
The "by-title" strategy specifies breaking on section boundaries; a `Title` element indicates
a new "section", hence the "by-title" designation.
"""
def __init__(self, elements: Iterable[Element], opts: ChunkingOptions):
self._elements = elements
self._opts = opts
@lazyproperty
def _boundary_predicates(self) -> tuple[BoundaryPredicate, ...]:
"""The semantic-boundary detectors to be applied to break pre-chunks."""
return self._opts.boundary_predicates
def _is_in_new_semantic_unit(self, element: Element) -> bool:
"""True when `element` begins a new semantic unit such as a section or page."""
# Use generator expression for lower memory usage and avoid building intermediate list
for pred in self._boundary_predicates:
if pred(element):
return True
return False
source_code = """
# Global functions
def global_sync():
pass
async def global_async():
pass
class ParentClass:
def __init__(self):
pass
def sync_method(self):
pass
async def async_method(self):
pass
class ChildClass:
async def child_async_method(self):
pass
def child_sync_method(self):
pass
"""
# Looking for specific functions
function_names = {
(None, "global_sync"),
(None, "global_async"),
("ParentClass", "sync_method"),
("ParentClass", "async_method"),
("ChildClass", "child_async_method")
}
tree = cst.parse_module(source_code)
collector = OptimFunctionCollector(
function_names=function_names,
preexisting_objects=None
)
tree.visit(collector)
# Should collect all specified functions (mix of sync and async)
assert len(collector.modified_functions) == 5
assert (None, "global_sync") in collector.modified_functions
assert (None, "global_async") in collector.modified_functions
assert ("ParentClass", "sync_method") in collector.modified_functions
assert ("ParentClass", "async_method") in collector.modified_functions
assert ("ChildClass", "child_async_method") in collector.modified_functions
# Should collect __init__ method
assert "ParentClass" in collector.modified_init_functions
def test_is_zero_diff_async_sleep():
    """Swapping a blocking time.sleep for await asyncio.sleep is a real
    behavioral change, so is_zero_diff must report a non-zero diff.

    Note: a stray `assert new_code == expected` line (merge residue
    referencing names undefined in this test) was removed — it would have
    raised NameError before the real assertion ran.
    """
    original_code = '''
import time
async def task():
    time.sleep(1)
    return "done"
'''
    optimized_code = '''
import asyncio
async def task():
    await asyncio.sleep(1)
    return "done"
'''
    # Blocking vs. cooperative sleep is a semantic difference, not cosmetic.
    assert not is_zero_diff(original_code, optimized_code)
def test_is_zero_diff_with_equivalent_code():
    """Adding only a docstring must be treated as a zero (cosmetic) diff."""
    before = '''
import asyncio
async def task():
    await asyncio.sleep(1)
    return "done"
'''
    # Identical to `before` except for a docstring on the coroutine.
    after = '''
import asyncio
async def task():
    """A task that does something."""
    await asyncio.sleep(1)
    return "done"
'''
    assert is_zero_diff(before, after)

View file

@ -17,10 +17,10 @@ from codeflash.code_utils.code_utils import (
is_class_defined_in_file,
module_name_from_file_path,
path_belongs_to_site_packages,
has_any_async_functions,
validate_python_code,
)
from codeflash.code_utils.concolic_utils import clean_concolic_tests
from codeflash.code_utils.coverage_utils import generate_candidates, prepare_coverage_files
from codeflash.code_utils.coverage_utils import extract_dependent_function, generate_candidates, prepare_coverage_files
@pytest.fixture
@ -308,6 +308,86 @@ def my_function():
assert is_class_defined_in_file("MyClass", test_file) is False
@pytest.fixture
def mock_code_context():
    """Mock CodeOptimizationContext for testing extract_dependent_function."""
    from unittest.mock import MagicMock

    from codeflash.models.models import CodeOptimizationContext

    # spec= keeps attribute access honest; preexisting_objects starts empty.
    mock_ctx = MagicMock(spec=CodeOptimizationContext)
    mock_ctx.preexisting_objects = []
    return mock_ctx
def test_extract_dependent_function_sync_and_async(mock_code_context):
    """Test extract_dependent_function with both sync and async functions."""
    sync_case = (
        """
def main_function():
    pass
def helper_function():
    pass
""",
        "helper_function",
    )
    async_case = (
        """
def main_function():
    pass
async def async_helper_function():
    pass
""",
        "async_helper_function",
    )
    # The single dependent helper is returned by name, sync or async alike.
    for context_code, expected_name in (sync_case, async_case):
        mock_code_context.testgen_context_code = context_code
        assert extract_dependent_function("main_function", mock_code_context) == expected_name
def test_extract_dependent_function_edge_cases(mock_code_context):
    """Edge cases: no dependent helper at all, or an ambiguous pair of helpers."""
    # With no dependent function present, extraction reports failure (False).
    mock_code_context.testgen_context_code = """
def main_function():
    pass
"""
    outcome = extract_dependent_function("main_function", mock_code_context)
    assert outcome is False
    # With more than one candidate helper, extraction also reports failure.
    mock_code_context.testgen_context_code = """
def main_function():
    pass
def helper1():
    pass
async def helper2():
    pass
"""
    outcome = extract_dependent_function("main_function", mock_code_context)
    assert outcome is False
def test_extract_dependent_function_mixed_scenarios(mock_code_context):
    """An async main works with a sync helper and with an async helper alike."""
    scenarios = (
        (
            """
async def async_main():
    pass
def sync_helper():
    pass
""",
            "sync_helper",
        ),
        (
            """
async def async_main():
    pass
async def async_helper():
    pass
""",
            "async_helper",
        ),
    )
    for context_code, expected_helper in scenarios:
        mock_code_context.testgen_context_code = context_code
        assert extract_dependent_function("async_main", mock_code_context) == expected_helper
def test_is_class_defined_in_file_with_non_existing_file() -> None:
non_existing_file = Path("/non/existing/file.py")
@ -445,25 +525,41 @@ def test_Grammar_copy():
assert cleaned_code == expected_cleaned_code.strip()
def test_has_any_async_functions_with_async_code() -> None:
def test_validate_python_code_valid() -> None:
    """A syntactically valid snippet is returned unchanged."""
    snippet = "def hello():\n return 'world'"
    assert validate_python_code(snippet) == snippet
def test_validate_python_code_invalid() -> None:
    """A snippet with a syntax error raises ValueError."""
    bad_snippet = "def hello(:\n return 'world'"
    # The malformed parameter list must be rejected.
    with pytest.raises(ValueError, match="Invalid Python code"):
        validate_python_code(bad_snippet)
def test_validate_python_code_empty() -> None:
    """The empty string is valid Python and is returned as-is."""
    empty = ""
    assert validate_python_code(empty) == empty
def test_validate_python_code_complex_invalid() -> None:
    """The ValueError message reports the error location (line and column)."""
    bad_snippet = "if True\n print('missing colon')"
    with pytest.raises(ValueError, match="Invalid Python code.*line 1.*column 8"):
        validate_python_code(bad_snippet)
def test_validate_python_code_valid_complex() -> None:
    """A realistic module (functions, branches, an async def, a class) validates cleanly.

    Note: this body was interleaved with residue of the removed
    `test_has_any_async_functions_*` tests (calls to a deleted helper);
    that residue was removed so the function is self-consistent again.
    """
    code = """
def normal_function():
    pass
def calculate(a, b):
    if a > b:
        return a + b
    else:
        return a * b
async def async_function():
    pass
class MyClass:
    def __init__(self):
        self.value = 42
"""
    # Valid code is echoed back unchanged.
    result = validate_python_code(code)
    assert result == code