mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
fix: handle module-level function definitions in add_global_assignments
Add GlobalFunctionCollector and GlobalFunctionTransformer to collect and insert module-level function definitions introduced by LLM optimizations. This fixes a NameError raised when the optimized code introduces new module-level helper functions (for example, functions decorated with @lru_cache) that the optimized method calls but that were never copied into the destination module.
This commit is contained in:
parent
9f929c2151
commit
6009b83f20
2 changed files with 304 additions and 8 deletions
|
|
@ -25,6 +25,96 @@ if TYPE_CHECKING:
|
|||
from codeflash.models.models import FunctionSource
|
||||
|
||||
|
||||
class GlobalFunctionCollector(cst.CSTVisitor):
    """Collect function definitions that are not nested inside a class or another function.

    After visiting a module:
      * ``functions`` maps each collected name to its ``FunctionDef`` node; when a
        name is defined more than once, the last definition wins.
      * ``function_order`` lists the collected names in order of first appearance.
      * ``scope_depth`` counts the function/class definitions currently open.
    """

    def __init__(self) -> None:
        super().__init__()
        self.functions: dict[str, cst.FunctionDef] = {}
        self.function_order: list[str] = []
        self.scope_depth = 0

    def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
        """Record ``node`` when no function/class encloses it, then skip its body."""
        if self.scope_depth == 0:
            name = node.name.value
            # `functions` and `function_order` always hold the same names, so the
            # dict gives an O(1) "seen before?" check instead of scanning the list.
            if name not in self.functions:
                self.function_order.append(name)
            self.functions[name] = node
        self.scope_depth += 1
        # Nothing inside a function body can be module-level, so do not descend.
        # libcst still calls leave_FunctionDef, which keeps scope_depth balanced.
        return False

    def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:  # noqa: ARG002
        """Close the function scope opened by ``visit_FunctionDef``."""
        self.scope_depth -= 1

    def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:  # noqa: ARG002
        """Open a class scope and skip its body; methods are never module-level."""
        self.scope_depth += 1
        # leave_ClassDef is still invoked after a False return, so the counter stays balanced.
        return False

    def leave_ClassDef(self, original_node: cst.ClassDef) -> None:  # noqa: ARG002
        """Close the class scope opened by ``visit_ClassDef``."""
        self.scope_depth -= 1
|
||||
|
||||
|
||||
class GlobalFunctionTransformer(cst.CSTTransformer):
    """Replace or insert module-level functions in the visited module.

    A function that is not nested in a class or another function and whose name is
    a key of ``new_functions`` is replaced by the supplied node. Supplied functions
    that were never replaced are inserted when the module is left, following
    ``new_function_order``.
    """

    def __init__(self, new_functions: dict[str, cst.FunctionDef], new_function_order: list[str]) -> None:
        super().__init__()
        self.new_functions = new_functions
        self.new_function_order = new_function_order
        # Names that were replaced in place and therefore must not be inserted again.
        self.processed_functions: set[str] = set()
        # Number of function/class definitions currently open.
        self.scope_depth = 0

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:  # noqa: ARG002
        """Enter a function scope."""
        self.scope_depth += 1

    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
        """Return the replacement for a non-nested function, or the node unchanged."""
        self.scope_depth -= 1
        func_name = original_node.name.value
        still_nested = self.scope_depth > 0
        if still_nested or func_name not in self.new_functions:
            return updated_node

        self.processed_functions.add(func_name)
        return self.new_functions[func_name]

    def visit_ClassDef(self, node: cst.ClassDef) -> None:  # noqa: ARG002
        """Enter a class scope."""
        self.scope_depth += 1

    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:  # noqa: ARG002
        """Leave a class scope; the class itself is returned untouched."""
        self.scope_depth -= 1
        return updated_node

    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:  # noqa: ARG002
        """Insert every supplied function that was not replaced in place."""
        pending = []
        for func_name in self.new_function_order:
            if func_name in self.processed_functions or func_name not in self.new_functions:
                continue
            func = self.new_functions[func_name]
            # Prepend one empty line so the inserted function is visually separated.
            pending.append(func.with_changes(leading_lines=[cst.EmptyLine(), *func.leading_lines]))

        body = list(updated_node.body)
        if pending:
            # Default to just after the imports; move past the last top-level
            # function or class definition when the module has one.
            insert_at = find_insertion_index_after_imports(updated_node)
            for position, statement in enumerate(body, start=1):
                if isinstance(statement, (cst.FunctionDef, cst.ClassDef)):
                    insert_at = position
            body[insert_at:insert_at] = pending

        return updated_node.with_changes(body=body)
|
||||
|
||||
|
||||
class GlobalAssignmentCollector(cst.CSTVisitor):
|
||||
"""Collects all global assignment statements."""
|
||||
|
||||
|
|
@ -439,17 +529,41 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
|
|||
|
||||
# Parse the src_module_code once only (already done above: src_module)
|
||||
# Collect assignments from the new file
|
||||
new_collector = GlobalAssignmentCollector()
|
||||
src_module.visit(new_collector)
|
||||
# Only create transformer if there are assignments to insert/transform
|
||||
if not new_collector.assignments: # nothing to transform
|
||||
new_assignment_collector = GlobalAssignmentCollector()
|
||||
src_module.visit(new_assignment_collector)
|
||||
|
||||
# Collect module-level functions from both source and destination
|
||||
src_function_collector = GlobalFunctionCollector()
|
||||
src_module.visit(src_function_collector)
|
||||
|
||||
dst_function_collector = GlobalFunctionCollector()
|
||||
original_module.visit(dst_function_collector)
|
||||
|
||||
# Filter out functions that already exist in the destination (only add truly new functions)
|
||||
new_functions = {
|
||||
name: func
|
||||
for name, func in src_function_collector.functions.items()
|
||||
if name not in dst_function_collector.functions
|
||||
}
|
||||
new_function_order = [name for name in src_function_collector.function_order if name in new_functions]
|
||||
|
||||
# If there are no assignments and no new functions, return the current code
|
||||
if not new_assignment_collector.assignments and not new_functions:
|
||||
return mod_dst_code
|
||||
|
||||
# Transform the original destination module
|
||||
transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order)
|
||||
transformed_module = original_module.visit(transformer)
|
||||
# Transform assignments if any
|
||||
if new_assignment_collector.assignments:
|
||||
transformer = GlobalAssignmentTransformer(
|
||||
new_assignment_collector.assignments, new_assignment_collector.assignment_order
|
||||
)
|
||||
original_module = original_module.visit(transformer)
|
||||
|
||||
return transformed_module.code
|
||||
# Transform functions if any
|
||||
if new_functions:
|
||||
function_transformer = GlobalFunctionTransformer(new_functions, new_function_order)
|
||||
original_module = original_module.visit(function_transformer)
|
||||
|
||||
return original_module.code
|
||||
|
||||
|
||||
def resolve_star_import(module_name: str, project_root: Path) -> set[str]:
|
||||
|
|
|
|||
|
|
@ -2815,6 +2815,188 @@ FINAL_VAR = 123
|
|||
assert collector.assignment_order == expected_order
|
||||
|
||||
|
||||
def test_global_function_collector():
    """GlobalFunctionCollector keeps top-level functions and ignores methods and nested functions."""
    import libcst as cst

    from codeflash.code_utils.code_extractor import GlobalFunctionCollector

    source_code = """
# Module-level functions
def helper_function():
    return "helper"

def another_helper(x: int) -> str:
    return str(x)

class SomeClass:
    def method(self):
        # This is a method, not a module-level function
        return "method"

    def another_method(self):
        # Also a method
        def nested_function():
            # Nested function inside method
            return "nested"
        return nested_function()

def final_function():
    def inner_function():
        # This is a nested function, not module-level
        return "inner"
    return inner_function()
"""

    visitor = GlobalFunctionCollector()
    cst.parse_module(source_code).visit(visitor)

    # Exactly the three top-level functions are collected.
    assert set(visitor.functions) == {"helper_function", "another_helper", "final_function"}

    # Methods and nested functions are left out.
    for excluded in ("method", "another_method", "nested_function", "inner_function"):
        assert excluded not in visitor.functions

    # Names are reported in definition order.
    assert visitor.function_order == ["helper_function", "another_helper", "final_function"]
|
||||
|
||||
|
||||
def test_add_global_assignments_with_new_functions():
    """A helper function that exists only in the source is appended to the destination."""
    optimized_src = """\
from functools import lru_cache

class SkyvernPage:
    @staticmethod
    def action_wrap(action):
        return _get_decorator_for_action(action)

@lru_cache(maxsize=None)
def _get_decorator_for_action(action):
    def decorator(fn):
        return fn
    return decorator
"""

    original_dst = """\
from functools import lru_cache

class SkyvernPage:
    @staticmethod
    def action_wrap(action):
        # Original implementation
        return action
"""

    # The destination class is untouched; only the new helper is added after it.
    expected = """\
from functools import lru_cache

class SkyvernPage:
    @staticmethod
    def action_wrap(action):
        # Original implementation
        return action


@lru_cache(maxsize=None)
def _get_decorator_for_action(action):
    def decorator(fn):
        return fn
    return decorator
"""

    assert add_global_assignments(optimized_src, original_dst) == expected
|
||||
|
||||
|
||||
def test_add_global_assignments_does_not_duplicate_existing_functions():
    """A function already defined in the destination keeps its body and is not added twice."""
    optimized_src = """\
def helper():
    return "source_helper"

def existing_function():
    return "source_existing"
"""

    original_dst = """\
def existing_function():
    return "dest_existing"

class MyClass:
    pass
"""

    # Only `helper` is new; `existing_function` stays as the destination wrote it.
    expected = """\
def existing_function():
    return "dest_existing"

class MyClass:
    pass

def helper():
    return "source_helper"
"""

    assert add_global_assignments(optimized_src, original_dst) == expected
|
||||
|
||||
|
||||
def test_add_global_assignments_with_decorated_functions():
    """Decorated and plain helper functions are both added, together with a new global assignment."""
    optimized_src = """\
from functools import lru_cache
from typing import Callable

_LOCAL_CACHE: dict[str, int] = {}

@lru_cache(maxsize=128)
def cached_helper(x: int) -> int:
    return x * 2

def regular_helper():
    return "regular"
"""

    original_dst = """\
from typing import Any

class MyClass:
    def method(self):
        return cached_helper(5)
"""

    # The assignment lands after the imports; both helpers follow the class.
    expected = """\
from typing import Any

_LOCAL_CACHE: dict[str, int] = {}

class MyClass:
    def method(self):
        return cached_helper(5)


@lru_cache(maxsize=128)
def cached_helper(x: int) -> int:
    return x * 2


def regular_helper():
    return "regular"
"""

    assert add_global_assignments(optimized_src, original_dst) == expected
|
||||
|
||||
|
||||
def test_class_instantiation_includes_init_as_helper(tmp_path: Path) -> None:
|
||||
"""Test that when a class is instantiated, its __init__ method is tracked as a helper.
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue