fix: handle module-level function definitions in add_global_assignments

Add GlobalFunctionCollector and GlobalFunctionTransformer to collect and
insert module-level function definitions introduced by LLM optimizations.
This fixes a NameError raised when optimized code introduces new helper
functions (e.g. @lru_cache-decorated functions) that the optimized method calls.
This commit is contained in:
Kevin Turcios 2026-01-23 18:55:47 -05:00
parent 9f929c2151
commit 6009b83f20
2 changed files with 304 additions and 8 deletions

View file

@ -25,6 +25,96 @@ if TYPE_CHECKING:
from codeflash.models.models import FunctionSource
class GlobalFunctionCollector(cst.CSTVisitor):
    """Record the ``def`` statements that are not nested in a class or another function.

    Attributes:
        functions: Maps a function name to its ``cst.FunctionDef`` node. When the
            same name is defined more than once, the later node overwrites the earlier.
        function_order: Function names in the order they were first encountered.
        scope_depth: Number of class/function definitions currently being traversed.
            Only ``FunctionDef`` and ``ClassDef`` nodes change it, so a ``def`` placed
            inside e.g. a top-level ``if`` block is still recorded.

    """

    def __init__(self) -> None:
        super().__init__()
        self.functions: dict[str, cst.FunctionDef] = {}
        self.function_order: list[str] = []
        self.scope_depth = 0

    def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
        # Decide before bumping the depth: depth 0 means no enclosing class/function.
        is_top_level = self.scope_depth == 0
        self.scope_depth += 1
        if is_top_level:
            func_name = node.name.value
            if func_name not in self.function_order:
                self.function_order.append(func_name)
            self.functions[func_name] = node
        # Keep descending so nested definitions are balanced by leave_FunctionDef.
        return True

    def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:  # noqa: ARG002
        self.scope_depth -= 1

    def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:  # noqa: ARG002
        # Anything defined inside a class body is a method, not a module-level function.
        self.scope_depth += 1
        return True

    def leave_ClassDef(self, original_node: cst.ClassDef) -> None:  # noqa: ARG002
        self.scope_depth -= 1
class GlobalFunctionTransformer(cst.CSTTransformer):
    """Replace or insert module-level functions in the module being visited.

    A function of the visited module that is not nested in a class or another
    function and whose name is a key of ``new_functions`` is swapped for the
    supplied node. Supplied functions that were never matched are inserted into
    the module body when the module is left.

    Args:
        new_functions: Function name -> ``cst.FunctionDef`` node to install.
        new_function_order: Names in the order unmatched functions should be inserted.

    """

    def __init__(self, new_functions: dict[str, cst.FunctionDef], new_function_order: list[str]) -> None:
        super().__init__()
        self.new_functions = new_functions
        self.new_function_order = new_function_order
        # Names that were already swapped in place and therefore must not be inserted again.
        self.processed_functions: set[str] = set()
        # Number of class/function definitions currently being traversed.
        self.scope_depth = 0

    def visit_FunctionDef(self, node: cst.FunctionDef) -> None:  # noqa: ARG002
        self.scope_depth += 1

    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
        self.scope_depth -= 1
        func_name = original_node.name.value
        # Nested functions/methods and unknown names are passed through untouched.
        if self.scope_depth > 0 or func_name not in self.new_functions:
            return updated_node
        self.processed_functions.add(func_name)
        return self.new_functions[func_name]

    def visit_ClassDef(self, node: cst.ClassDef) -> None:  # noqa: ARG002
        self.scope_depth += 1

    def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:  # noqa: ARG002
        self.scope_depth -= 1
        return updated_node

    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:  # noqa: ARG002
        # Functions supplied by the caller that did not replace anything in this module.
        pending = [
            self.new_functions[func_name]
            for func_name in self.new_function_order
            if func_name in self.new_functions and func_name not in self.processed_functions
        ]
        body = list(updated_node.body)

        if pending:
            # Default slot comes from find_insertion_index_after_imports; it is overridden
            # by the slot right after the last top-level function or class, if any.
            insert_at = find_insertion_index_after_imports(updated_node)
            for next_slot, statement in enumerate(body, start=1):
                if isinstance(statement, (cst.FunctionDef, cst.ClassDef)):
                    insert_at = next_slot

            # Prefix every inserted function with one empty line.
            spaced = [
                func.with_changes(leading_lines=[cst.EmptyLine(), *func.leading_lines]) for func in pending
            ]
            body[insert_at:insert_at] = spaced

        return updated_node.with_changes(body=body)
class GlobalAssignmentCollector(cst.CSTVisitor):
"""Collects all global assignment statements."""
@ -439,17 +529,41 @@ def add_global_assignments(src_module_code: str, dst_module_code: str) -> str:
# Parse the src_module_code once only (already done above: src_module)
# Collect assignments from the new file
new_collector = GlobalAssignmentCollector()
src_module.visit(new_collector)
# Only create transformer if there are assignments to insert/transform
if not new_collector.assignments: # nothing to transform
new_assignment_collector = GlobalAssignmentCollector()
src_module.visit(new_assignment_collector)
# Collect module-level functions from both source and destination
src_function_collector = GlobalFunctionCollector()
src_module.visit(src_function_collector)
dst_function_collector = GlobalFunctionCollector()
original_module.visit(dst_function_collector)
# Filter out functions that already exist in the destination (only add truly new functions)
new_functions = {
name: func
for name, func in src_function_collector.functions.items()
if name not in dst_function_collector.functions
}
new_function_order = [name for name in src_function_collector.function_order if name in new_functions]
# If there are no assignments and no new functions, return the current code
if not new_assignment_collector.assignments and not new_functions:
return mod_dst_code
# Transform the original destination module
transformer = GlobalAssignmentTransformer(new_collector.assignments, new_collector.assignment_order)
transformed_module = original_module.visit(transformer)
# Transform assignments if any
if new_assignment_collector.assignments:
transformer = GlobalAssignmentTransformer(
new_assignment_collector.assignments, new_assignment_collector.assignment_order
)
original_module = original_module.visit(transformer)
return transformed_module.code
# Transform functions if any
if new_functions:
function_transformer = GlobalFunctionTransformer(new_functions, new_function_order)
original_module = original_module.visit(function_transformer)
return original_module.code
def resolve_star_import(module_name: str, project_root: Path) -> set[str]:

View file

@ -2815,6 +2815,188 @@ FINAL_VAR = 123
assert collector.assignment_order == expected_order
def test_global_function_collector():
"""Test GlobalFunctionCollector correctly collects module-level function definitions."""
import libcst as cst
from codeflash.code_utils.code_extractor import GlobalFunctionCollector
# Fixture: three top-level functions, one class with two methods, and two nested
# functions (one inside a method, one inside a top-level function).
source_code = """
# Module-level functions
def helper_function():
return "helper"
def another_helper(x: int) -> str:
return str(x)
class SomeClass:
def method(self):
# This is a method, not a module-level function
return "method"
def another_method(self):
# Also a method
def nested_function():
# Nested function inside method
return "nested"
return nested_function()
def final_function():
def inner_function():
# This is a nested function, not module-level
return "inner"
return inner_function()
"""
# Parse the fixture and walk it with the collector under test.
tree = cst.parse_module(source_code)
collector = GlobalFunctionCollector()
tree.visit(collector)
# Should collect only module-level functions
assert len(collector.functions) == 3
assert "helper_function" in collector.functions
assert "another_helper" in collector.functions
assert "final_function" in collector.functions
# Should not collect methods or nested functions
assert "method" not in collector.functions
assert "another_method" not in collector.functions
assert "nested_function" not in collector.functions
assert "inner_function" not in collector.functions
# Verify correct order
# (function_order is expected to follow the order of appearance in the fixture)
expected_order = ["helper_function", "another_helper", "final_function"]
assert collector.function_order == expected_order
def test_add_global_assignments_with_new_functions():
"""Test add_global_assignments correctly adds new module-level functions."""
# Source module: action_wrap calls a module-level helper decorated with
# @lru_cache that the destination module does not define.
source_code = """\
from functools import lru_cache
class SkyvernPage:
@staticmethod
def action_wrap(action):
return _get_decorator_for_action(action)
@lru_cache(maxsize=None)
def _get_decorator_for_action(action):
def decorator(fn):
return fn
return decorator
"""
# Destination module: contains the class only, with its own action_wrap body.
destination_code = """\
from functools import lru_cache
class SkyvernPage:
@staticmethod
def action_wrap(action):
# Original implementation
return action
"""
# Expected output: the destination's action_wrap body is kept as-is and the
# missing helper (decorator included) is added after the class.
expected = """\
from functools import lru_cache
class SkyvernPage:
@staticmethod
def action_wrap(action):
# Original implementation
return action
@lru_cache(maxsize=None)
def _get_decorator_for_action(action):
def decorator(fn):
return fn
return decorator
"""
result = add_global_assignments(source_code, destination_code)
assert result == expected
def test_add_global_assignments_does_not_duplicate_existing_functions():
"""Test add_global_assignments does not duplicate functions that already exist in destination."""
# Source module: defines helper (absent from the destination) and
# existing_function (present in the destination with a different body).
source_code = """\
def helper():
return "source_helper"
def existing_function():
return "source_existing"
"""
destination_code = """\
def existing_function():
return "dest_existing"
class MyClass:
pass
"""
# Expected output: existing_function keeps the destination body ("dest_existing")
# and appears once; helper is added after the last definition (MyClass).
expected = """\
def existing_function():
return "dest_existing"
class MyClass:
pass
def helper():
return "source_helper"
"""
result = add_global_assignments(source_code, destination_code)
assert result == expected
def test_add_global_assignments_with_decorated_functions():
"""Test add_global_assignments correctly adds decorated functions."""
# Source module: one annotated global assignment plus two module-level
# functions, the first of which carries an @lru_cache decorator.
source_code = """\
from functools import lru_cache
from typing import Callable
_LOCAL_CACHE: dict[str, int] = {}
@lru_cache(maxsize=128)
def cached_helper(x: int) -> int:
return x * 2
def regular_helper():
return "regular"
"""
# Destination module: references cached_helper but defines neither helper
# nor the _LOCAL_CACHE global.
destination_code = """\
from typing import Any
class MyClass:
def method(self):
return cached_helper(5)
"""
# Expected output: the destination's import line is unchanged (the source's
# imports do not appear), _LOCAL_CACHE sits between the import and the class,
# and both functions follow the class in source order with the decorator kept.
expected = """\
from typing import Any
_LOCAL_CACHE: dict[str, int] = {}
class MyClass:
def method(self):
return cached_helper(5)
@lru_cache(maxsize=128)
def cached_helper(x: int) -> int:
return x * 2
def regular_helper():
return "regular"
"""
result = add_global_assignments(source_code, destination_code)
assert result == expected
def test_class_instantiation_includes_init_as_helper(tmp_path: Path) -> None:
"""Test that when a class is instantiated, its __init__ method is tracked as a helper.