fix: handle annotated assignments in GlobalAssignmentCollector

GlobalAssignmentCollector only handled cst.Assign but not cst.AnnAssign
(annotated assignments like `X: int = 1`). When the LLM generated
optimizations with annotated module-level variables, these weren't
copied to the target file, causing NameError at runtime.

- Add visit_AnnAssign to GlobalAssignmentCollector
- Add leave_AnnAssign to GlobalAssignmentTransformer
- Update type hints to include cst.AnnAssign
- Add test for annotated assignment handling
This commit is contained in:
Kevin Turcios 2026-01-23 18:46:49 -05:00
parent 412779d7ba
commit 9f929c2151
2 changed files with 163 additions and 100 deletions

View file

@ -30,7 +30,7 @@ class GlobalAssignmentCollector(cst.CSTVisitor):
def __init__(self) -> None:
super().__init__()
self.assignments: dict[str, cst.Assign] = {}
self.assignments: dict[str, cst.Assign | cst.AnnAssign] = {}
self.assignment_order: list[str] = []
# Track scope depth to identify global assignments
self.scope_depth = 0
@ -72,6 +72,21 @@ class GlobalAssignmentCollector(cst.CSTVisitor):
self.assignment_order.append(name)
return True
def visit_AnnAssign(self, node: cst.AnnAssign) -> Optional[bool]:
# Handle annotated assignments like: _CACHE: Dict[str, int] = {}
# Only process module-level annotated assignments with a value
if (
self.scope_depth == 0
and self.if_else_depth == 0
and isinstance(node.target, cst.Name)
and node.value is not None
):
name = node.target.value
self.assignments[name] = node
if name not in self.assignment_order:
self.assignment_order.append(name)
return True
def find_insertion_index_after_imports(node: cst.Module) -> int:
"""Find the position of the last import statement in the top-level of the module."""
@ -103,7 +118,7 @@ def find_insertion_index_after_imports(node: cst.Module) -> int:
class GlobalAssignmentTransformer(cst.CSTTransformer):
"""Transforms global assignments in the original file with those from the new file."""
def __init__(self, new_assignments: dict[str, cst.Assign], new_assignment_order: list[str]) -> None:
def __init__(self, new_assignments: dict[str, cst.Assign | cst.AnnAssign], new_assignment_order: list[str]) -> None:
super().__init__()
self.new_assignments = new_assignments
self.new_assignment_order = new_assignment_order
@ -150,6 +165,19 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
return updated_node
def leave_AnnAssign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.CSTNode:
if self.scope_depth > 0 or self.if_else_depth > 0:
return updated_node
# Check if this is a global annotated assignment we need to replace
if isinstance(original_node.target, cst.Name):
name = original_node.target.value
if name in self.new_assignments:
self.processed_assignments.add(name)
return self.new_assignments[name]
return updated_node
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
# Add any new assignments that weren't in the original file
new_statements = list(updated_node.body)

View file

@ -7,18 +7,18 @@ from collections import defaultdict
from pathlib import Path
import pytest
from codeflash.code_utils.code_extractor import GlobalAssignmentCollector, add_global_assignments
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
from codeflash.context.code_context_extractor import (
get_code_optimization_context,
get_imported_class_definitions,
collect_names_from_annotation,
extract_imports_for_class,
get_code_optimization_context,
get_imported_class_definitions,
)
from codeflash.models.models import CodeString, CodeStringsMarkdown
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent
from codeflash.models.models import CodeString, CodeStringsMarkdown, FunctionParent
from codeflash.optimization.optimizer import Optimizer
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
from codeflash.code_utils.code_extractor import add_global_assignments, GlobalAssignmentCollector
class HelperClass:
@ -91,7 +91,10 @@ def test_code_replacement10() -> None:
code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent)
qualified_names = {func.qualified_name for func in code_ctx.helper_functions}
# HelperClass.__init__ is now tracked because HelperClass(self.name) instantiates the class
assert qualified_names == {"HelperClass.helper_method", "HelperClass.__init__"} # Nested method should not be in here
assert qualified_names == {
"HelperClass.helper_method",
"HelperClass.__init__",
} # Nested method should not be in here
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
@ -234,7 +237,7 @@ def test_bubble_sort_helper() -> None:
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context
expected_read_write_context = f"""
expected_read_write_context = """
```python:code_to_optimize/code_directories/retriever/bubble_sort_with_math.py
import math
@ -1108,7 +1111,9 @@ class HelperClass:
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
# the global x variable shouldn't be included in any context type
assert code_ctx.read_writable_code.flat == '''# file: test_code.py
assert (
code_ctx.read_writable_code.flat
== '''# file: test_code.py
class MyClass:
def __init__(self):
self.x = 1
@ -1123,7 +1128,10 @@ class HelperClass:
def helper_method(self):
return self.x
'''
assert code_ctx.testgen_context.flat == '''# file: test_code.py
)
assert (
code_ctx.testgen_context.flat
== '''# file: test_code.py
class MyClass:
"""A class with a helper method. """
def __init__(self):
@ -1143,6 +1151,7 @@ class HelperClass:
def helper_method(self):
return self.x
'''
)
def test_repo_helper() -> None:
@ -2353,9 +2362,7 @@ def standalone_function():
assert '"""Helper method with docstring."""' not in hashing_context, (
"Docstrings should be removed from helper functions"
)
assert '"""Process data method."""' not in hashing_context, (
"Docstrings should be removed from helper class methods"
)
assert '"""Process data method."""' not in hashing_context, "Docstrings should be removed from helper class methods"
def test_hashing_code_context_with_nested_classes(tmp_path: Path) -> None:
@ -2593,16 +2600,21 @@ def test_circular_deps():
optimized_code = Path(path_to_root / "optimized.py").read_text(encoding="utf-8")
content = Path(file_abs_path).read_text(encoding="utf-8")
new_code = replace_functions_and_add_imports(
source_code= add_global_assignments(optimized_code, content),
function_names= ["ApiClient.get_console_url"],
optimized_code= optimized_code,
module_abspath= Path(file_abs_path),
preexisting_objects= {('ApiClient', ()), ('get_console_url', (FunctionParent(name='ApiClient', type='ClassDef'),))},
project_root_path= Path(path_to_root),
source_code=add_global_assignments(optimized_code, content),
function_names=["ApiClient.get_console_url"],
optimized_code=optimized_code,
module_abspath=Path(file_abs_path),
preexisting_objects={
("ApiClient", ()),
("get_console_url", (FunctionParent(name="ApiClient", type="ClassDef"),)),
},
project_root_path=Path(path_to_root),
)
assert "import ApiClient" not in new_code, "Error: Circular dependency found"
assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist"
def test_global_assignment_collector_with_async_function():
"""Test GlobalAssignmentCollector correctly identifies global assignments outside async functions."""
import libcst as cst
@ -2750,6 +2762,59 @@ FINAL_ASSIGNMENT = {"data": "value"}
assert collector.assignment_order == expected_order
def test_global_assignment_collector_annotated_assignments():
"""Test GlobalAssignmentCollector correctly handles annotated assignments (AnnAssign)."""
import libcst as cst
source_code = """
# Regular global assignment
REGULAR_VAR = "regular"
# Annotated global assignments
TYPED_VAR: str = "typed"
CACHE: dict[str, int] = {}
SENTINEL: object = object()
# Annotated without value (type declaration only) - should NOT be collected
DECLARED_ONLY: int
def some_function():
# Annotated assignment inside function - should not be collected
local_typed: str = "local"
return local_typed
class SomeClass:
# Class-level annotated assignment - should not be collected
class_attr: str = "class"
# Another regular assignment
FINAL_VAR = 123
"""
tree = cst.parse_module(source_code)
collector = GlobalAssignmentCollector()
tree.visit(collector)
# Should collect both regular and annotated global assignments with values
assert len(collector.assignments) == 5
assert "REGULAR_VAR" in collector.assignments
assert "TYPED_VAR" in collector.assignments
assert "CACHE" in collector.assignments
assert "SENTINEL" in collector.assignments
assert "FINAL_VAR" in collector.assignments
# Should not collect type declarations without values
assert "DECLARED_ONLY" not in collector.assignments
# Should not collect assignments from inside functions or classes
assert "local_typed" not in collector.assignments
assert "class_attr" not in collector.assignments
# Verify correct order
expected_order = ["REGULAR_VAR", "TYPED_VAR", "CACHE", "SENTINEL", "FINAL_VAR"]
assert collector.assignment_order == expected_order
def test_class_instantiation_includes_init_as_helper(tmp_path: Path) -> None:
"""Test that when a class is instantiated, its __init__ method is tracked as a helper.
@ -2790,11 +2855,7 @@ def target_function():
)
)
function_to_optimize = FunctionToOptimize(
function_name="target_function",
file_path=file_path,
parents=[],
starting_line=None,
ending_line=None,
function_name="target_function", file_path=file_path, parents=[], starting_line=None, ending_line=None
)
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
@ -2808,15 +2869,11 @@ def target_function():
# The testgen context should contain the class with __init__ (critical for LLM to know constructor)
testgen_context = code_ctx.testgen_context.markdown
assert "class DataDumper:" in testgen_context, "DataDumper class should be in testgen context"
assert "def __init__(self, data):" in testgen_context, (
"__init__ method should be included in testgen context"
)
assert "def __init__(self, data):" in testgen_context, "__init__ method should be included in testgen context"
# The hashing context should NOT contain __init__ (excluded for stability)
hashing_context = code_ctx.hashing_code_context
assert "__init__" not in hashing_context, (
"__init__ should NOT be in hashing context (excluded for hash stability)"
)
assert "__init__" not in hashing_context, "__init__ should NOT be in hashing context (excluded for hash stability)"
def test_class_instantiation_preserves_full_class_in_testgen(tmp_path: Path) -> None:
@ -2870,11 +2927,7 @@ def dump_layout(layout_type, layout):
)
)
function_to_optimize = FunctionToOptimize(
function_name="dump_layout",
file_path=file_path,
parents=[],
starting_line=None,
ending_line=None,
function_name="dump_layout", file_path=file_path, parents=[], starting_line=None, ending_line=None
)
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
@ -2884,9 +2937,7 @@ def dump_layout(layout_type, layout):
assert "ObjectDetectionLayoutDumper.__init__" in qualified_names, (
"ObjectDetectionLayoutDumper.__init__ should be tracked"
)
assert "LayoutDumper.__init__" in qualified_names, (
"LayoutDumper.__init__ should be tracked"
)
assert "LayoutDumper.__init__" in qualified_names, "LayoutDumper.__init__ should be tracked"
# The testgen context should include both classes with their __init__ methods
testgen_context = code_ctx.testgen_context.markdown
@ -2896,9 +2947,7 @@ def dump_layout(layout_type, layout):
)
# Both __init__ methods should be in the testgen context (so LLM knows constructor signatures)
assert testgen_context.count("def __init__") >= 2, (
"Both __init__ methods should be in testgen context"
)
assert testgen_context.count("def __init__") >= 2, "Both __init__ methods should be in testgen context"
def test_get_imported_class_definitions_extracts_project_classes(tmp_path: Path) -> None:
@ -2934,7 +2983,7 @@ class Text(Element):
elements_path.write_text(elements_code, encoding="utf-8")
# Create another module that imports from elements
chunking_code = '''
chunking_code = """
from mypackage.elements import Element
class PreChunk:
@ -2944,14 +2993,12 @@ class PreChunk:
class Accumulator:
def will_fit(self, chunk: PreChunk) -> bool:
return True
'''
"""
chunking_path = package_dir / "chunking.py"
chunking_path.write_text(chunking_code, encoding="utf-8")
# Create CodeStringsMarkdown from the chunking module (simulating testgen context)
context = CodeStringsMarkdown(
code_strings=[CodeString(code=chunking_code, file_path=chunking_path)]
)
context = CodeStringsMarkdown(code_strings=[CodeString(code=chunking_code, file_path=chunking_path)])
# Call get_imported_class_definitions
result = get_imported_class_definitions(context, tmp_path)
@ -2975,16 +3022,16 @@ def test_get_imported_class_definitions_skips_existing_definitions(tmp_path: Pat
(package_dir / "__init__.py").write_text("", encoding="utf-8")
# Create a module with a class definition
elements_code = '''
elements_code = """
class Element:
def __init__(self, text: str):
self.text = text
'''
"""
elements_path = package_dir / "elements.py"
elements_path.write_text(elements_code, encoding="utf-8")
# Create code that imports Element but also redefines it locally
code_with_local_def = '''
code_with_local_def = """
from mypackage.elements import Element
# Local redefinition (this happens when LLM redefines classes)
@ -2995,13 +3042,11 @@ class Element:
class User:
def process(self, elem: Element):
pass
'''
"""
code_path = package_dir / "user.py"
code_path.write_text(code_with_local_def, encoding="utf-8")
context = CodeStringsMarkdown(
code_strings=[CodeString(code=code_with_local_def, file_path=code_path)]
)
context = CodeStringsMarkdown(code_strings=[CodeString(code=code_with_local_def, file_path=code_path)])
# Call get_imported_class_definitions
result = get_imported_class_definitions(context, tmp_path)
@ -3018,7 +3063,7 @@ def test_get_imported_class_definitions_skips_third_party(tmp_path: Path) -> Non
(package_dir / "__init__.py").write_text("", encoding="utf-8")
# Code with stdlib/third-party imports
code = '''
code = """
from pathlib import Path
from typing import Optional
from dataclasses import dataclass
@ -3026,13 +3071,11 @@ from dataclasses import dataclass
class MyClass:
def __init__(self, path: Path):
self.path = path
'''
"""
code_path = package_dir / "main.py"
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(
code_strings=[CodeString(code=code, file_path=code_path)]
)
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
# Call get_imported_class_definitions
result = get_imported_class_definitions(context, tmp_path)
@ -3049,7 +3092,7 @@ def test_get_imported_class_definitions_handles_multiple_imports(tmp_path: Path)
(package_dir / "__init__.py").write_text("", encoding="utf-8")
# Create a module with multiple class definitions
types_code = '''
types_code = """
class TypeA:
def __init__(self, value: int):
self.value = value
@ -3061,24 +3104,22 @@ class TypeB:
class TypeC:
def __init__(self):
pass
'''
"""
types_path = package_dir / "types.py"
types_path.write_text(types_code, encoding="utf-8")
# Create code that imports multiple classes
code = '''
code = """
from mypackage.types import TypeA, TypeB
class Processor:
def process(self, a: TypeA, b: TypeB):
pass
'''
"""
code_path = package_dir / "processor.py"
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(
code_strings=[CodeString(code=code, file_path=code_path)]
)
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
# Call get_imported_class_definitions
result = get_imported_class_definitions(context, tmp_path)
@ -3100,7 +3141,7 @@ def test_get_imported_class_definitions_includes_dataclass_decorators(tmp_path:
(package_dir / "__init__.py").write_text("", encoding="utf-8")
# Create a module with dataclass definitions (like LLMConfig in skyvern)
models_code = '''from dataclasses import dataclass, field
models_code = """from dataclasses import dataclass, field
from typing import Optional
@dataclass(frozen=True)
@ -3114,23 +3155,21 @@ class LLMConfigBase:
class LLMConfig(LLMConfigBase):
litellm_params: Optional[dict] = field(default=None)
max_tokens: int | None = None
'''
"""
models_path = package_dir / "models.py"
models_path.write_text(models_code, encoding="utf-8")
# Create code that imports the dataclass
code = '''from mypackage.models import LLMConfig
code = """from mypackage.models import LLMConfig
class ConfigRegistry:
def get_config(self) -> LLMConfig:
pass
'''
"""
code_path = package_dir / "registry.py"
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(
code_strings=[CodeString(code=code, file_path=code_path)]
)
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
# Call get_imported_class_definitions
result = get_imported_class_definitions(context, tmp_path)
@ -3165,7 +3204,7 @@ def test_get_imported_class_definitions_extracts_imports_for_decorated_classes(t
(package_dir / "__init__.py").write_text("", encoding="utf-8")
# Create a module with decorated class that uses field() and various type annotations
models_code = '''from dataclasses import dataclass, field
models_code = """from dataclasses import dataclass, field
from typing import Optional, List
@dataclass
@ -3173,22 +3212,20 @@ class Config:
name: str
values: List[int] = field(default_factory=list)
description: Optional[str] = None
'''
"""
models_path = package_dir / "models.py"
models_path.write_text(models_code, encoding="utf-8")
# Create code that imports the class
code = '''from mypackage.models import Config
code = """from mypackage.models import Config
def create_config() -> Config:
return Config(name="test")
'''
"""
code_path = package_dir / "main.py"
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(
code_strings=[CodeString(code=code, file_path=code_path)]
)
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_imported_class_definitions(context, tmp_path)
@ -3282,12 +3319,12 @@ class TestExtractImportsForClass:
"""Test that base class imports are extracted."""
import ast
module_source = '''from abc import ABC
module_source = """from abc import ABC
from mypackage import BaseClass
class MyClass(BaseClass, ABC):
pass
'''
"""
tree = ast.parse(module_source)
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
result = extract_imports_for_class(tree, class_node, module_source)
@ -3298,13 +3335,13 @@ class MyClass(BaseClass, ABC):
"""Test that decorator imports are extracted."""
import ast
module_source = '''from dataclasses import dataclass
module_source = """from dataclasses import dataclass
from functools import lru_cache
@dataclass
class MyClass:
name: str
'''
"""
tree = ast.parse(module_source)
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
result = extract_imports_for_class(tree, class_node, module_source)
@ -3314,14 +3351,14 @@ class MyClass:
"""Test that type annotation imports are extracted."""
import ast
module_source = '''from typing import Optional, List
module_source = """from typing import Optional, List
from mypackage.models import Config
@dataclass
class MyClass:
config: Optional[Config]
items: List[str]
'''
"""
tree = ast.parse(module_source)
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
result = extract_imports_for_class(tree, class_node, module_source)
@ -3332,13 +3369,13 @@ class MyClass:
"""Test that field() function imports are extracted for dataclasses."""
import ast
module_source = '''from dataclasses import dataclass, field
module_source = """from dataclasses import dataclass, field
from typing import List
@dataclass
class MyClass:
items: List[str] = field(default_factory=list)
'''
"""
tree = ast.parse(module_source)
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
result = extract_imports_for_class(tree, class_node, module_source)
@ -3348,13 +3385,13 @@ class MyClass:
"""Test that duplicate imports are not included."""
import ast
module_source = '''from typing import Optional
module_source = """from typing import Optional
@dataclass
class MyClass:
field1: Optional[str]
field2: Optional[int]
'''
"""
tree = ast.parse(module_source)
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
result = extract_imports_for_class(tree, class_node, module_source)
@ -3368,7 +3405,7 @@ def test_get_imported_class_definitions_multiple_decorators(tmp_path: Path) -> N
package_dir.mkdir()
(package_dir / "__init__.py").write_text("", encoding="utf-8")
models_code = '''from dataclasses import dataclass
models_code = """from dataclasses import dataclass
from functools import total_ordering
@total_ordering
@ -3379,21 +3416,19 @@ class OrderedConfig:
def __lt__(self, other):
return self.priority < other.priority
'''
"""
models_path = package_dir / "models.py"
models_path.write_text(models_code, encoding="utf-8")
code = '''from mypackage.models import OrderedConfig
code = """from mypackage.models import OrderedConfig
def sort_configs(configs: list[OrderedConfig]) -> list[OrderedConfig]:
return sorted(configs)
'''
"""
code_path = package_dir / "main.py"
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(
code_strings=[CodeString(code=code, file_path=code_path)]
)
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_imported_class_definitions(context, tmp_path)
@ -3452,7 +3487,7 @@ class RouterConfig(ParentConfig):
models_path.write_text(models_code, encoding="utf-8")
# Create code that imports only the child classes (not the base classes)
code = '''from mypackage.models import ChildConfig, RouterConfig
code = """from mypackage.models import ChildConfig, RouterConfig
class ConfigRegistry:
def get_child_config(self) -> ChildConfig:
@ -3460,7 +3495,7 @@ class ConfigRegistry:
def get_router_config(self) -> RouterConfig:
pass
'''
"""
code_path = package_dir / "registry.py"
code_path.write_text(code, encoding="utf-8")