fix: handle annotated assignments in GlobalAssignmentCollector

GlobalAssignmentCollector only handled cst.Assign but not cst.AnnAssign
(annotated assignments like `X: int = 1`). When the LLM generated
optimizations with annotated module-level variables, these weren't
copied to the target file, causing NameError at runtime.

- Add visit_AnnAssign to GlobalAssignmentCollector
- Add leave_AnnAssign to GlobalAssignmentTransformer
- Update type hints to include cst.AnnAssign
- Add test for annotated assignment handling
This commit is contained in:
Kevin Turcios 2026-01-23 18:46:49 -05:00
parent 412779d7ba
commit 9f929c2151
2 changed files with 163 additions and 100 deletions

View file

@ -30,7 +30,7 @@ class GlobalAssignmentCollector(cst.CSTVisitor):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
self.assignments: dict[str, cst.Assign] = {} self.assignments: dict[str, cst.Assign | cst.AnnAssign] = {}
self.assignment_order: list[str] = [] self.assignment_order: list[str] = []
# Track scope depth to identify global assignments # Track scope depth to identify global assignments
self.scope_depth = 0 self.scope_depth = 0
@ -72,6 +72,21 @@ class GlobalAssignmentCollector(cst.CSTVisitor):
self.assignment_order.append(name) self.assignment_order.append(name)
return True return True
def visit_AnnAssign(self, node: cst.AnnAssign) -> Optional[bool]:
# Handle annotated assignments like: _CACHE: Dict[str, int] = {}
# Only process module-level annotated assignments with a value
if (
self.scope_depth == 0
and self.if_else_depth == 0
and isinstance(node.target, cst.Name)
and node.value is not None
):
name = node.target.value
self.assignments[name] = node
if name not in self.assignment_order:
self.assignment_order.append(name)
return True
def find_insertion_index_after_imports(node: cst.Module) -> int: def find_insertion_index_after_imports(node: cst.Module) -> int:
"""Find the position of the last import statement in the top-level of the module.""" """Find the position of the last import statement in the top-level of the module."""
@ -103,7 +118,7 @@ def find_insertion_index_after_imports(node: cst.Module) -> int:
class GlobalAssignmentTransformer(cst.CSTTransformer): class GlobalAssignmentTransformer(cst.CSTTransformer):
"""Transforms global assignments in the original file with those from the new file.""" """Transforms global assignments in the original file with those from the new file."""
def __init__(self, new_assignments: dict[str, cst.Assign], new_assignment_order: list[str]) -> None: def __init__(self, new_assignments: dict[str, cst.Assign | cst.AnnAssign], new_assignment_order: list[str]) -> None:
super().__init__() super().__init__()
self.new_assignments = new_assignments self.new_assignments = new_assignments
self.new_assignment_order = new_assignment_order self.new_assignment_order = new_assignment_order
@ -150,6 +165,19 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
return updated_node return updated_node
def leave_AnnAssign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.CSTNode:
if self.scope_depth > 0 or self.if_else_depth > 0:
return updated_node
# Check if this is a global annotated assignment we need to replace
if isinstance(original_node.target, cst.Name):
name = original_node.target.value
if name in self.new_assignments:
self.processed_assignments.add(name)
return self.new_assignments[name]
return updated_node
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002 def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
# Add any new assignments that weren't in the original file # Add any new assignments that weren't in the original file
new_statements = list(updated_node.body) new_statements = list(updated_node.body)

View file

@ -7,18 +7,18 @@ from collections import defaultdict
from pathlib import Path from pathlib import Path
import pytest import pytest
from codeflash.code_utils.code_extractor import GlobalAssignmentCollector, add_global_assignments
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
from codeflash.context.code_context_extractor import ( from codeflash.context.code_context_extractor import (
get_code_optimization_context,
get_imported_class_definitions,
collect_names_from_annotation, collect_names_from_annotation,
extract_imports_for_class, extract_imports_for_class,
get_code_optimization_context,
get_imported_class_definitions,
) )
from codeflash.models.models import CodeString, CodeStringsMarkdown
from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent from codeflash.models.models import CodeString, CodeStringsMarkdown, FunctionParent
from codeflash.optimization.optimizer import Optimizer from codeflash.optimization.optimizer import Optimizer
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
from codeflash.code_utils.code_extractor import add_global_assignments, GlobalAssignmentCollector
class HelperClass: class HelperClass:
@ -91,7 +91,10 @@ def test_code_replacement10() -> None:
code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent) code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent)
qualified_names = {func.qualified_name for func in code_ctx.helper_functions} qualified_names = {func.qualified_name for func in code_ctx.helper_functions}
# HelperClass.__init__ is now tracked because HelperClass(self.name) instantiates the class # HelperClass.__init__ is now tracked because HelperClass(self.name) instantiates the class
assert qualified_names == {"HelperClass.helper_method", "HelperClass.__init__"} # Nested method should not be in here assert qualified_names == {
"HelperClass.helper_method",
"HelperClass.__init__",
} # Nested method should not be in here
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context hashing_context = code_ctx.hashing_code_context
@ -234,7 +237,7 @@ def test_bubble_sort_helper() -> None:
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context hashing_context = code_ctx.hashing_code_context
expected_read_write_context = f""" expected_read_write_context = """
```python:code_to_optimize/code_directories/retriever/bubble_sort_with_math.py ```python:code_to_optimize/code_directories/retriever/bubble_sort_with_math.py
import math import math
@ -1108,7 +1111,9 @@ class HelperClass:
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
# the global x variable shouldn't be included in any context type # the global x variable shouldn't be included in any context type
assert code_ctx.read_writable_code.flat == '''# file: test_code.py assert (
code_ctx.read_writable_code.flat
== '''# file: test_code.py
class MyClass: class MyClass:
def __init__(self): def __init__(self):
self.x = 1 self.x = 1
@ -1123,7 +1128,10 @@ class HelperClass:
def helper_method(self): def helper_method(self):
return self.x return self.x
''' '''
assert code_ctx.testgen_context.flat == '''# file: test_code.py )
assert (
code_ctx.testgen_context.flat
== '''# file: test_code.py
class MyClass: class MyClass:
"""A class with a helper method. """ """A class with a helper method. """
def __init__(self): def __init__(self):
@ -1143,6 +1151,7 @@ class HelperClass:
def helper_method(self): def helper_method(self):
return self.x return self.x
''' '''
)
def test_repo_helper() -> None: def test_repo_helper() -> None:
@ -2353,9 +2362,7 @@ def standalone_function():
assert '"""Helper method with docstring."""' not in hashing_context, ( assert '"""Helper method with docstring."""' not in hashing_context, (
"Docstrings should be removed from helper functions" "Docstrings should be removed from helper functions"
) )
assert '"""Process data method."""' not in hashing_context, ( assert '"""Process data method."""' not in hashing_context, "Docstrings should be removed from helper class methods"
"Docstrings should be removed from helper class methods"
)
def test_hashing_code_context_with_nested_classes(tmp_path: Path) -> None: def test_hashing_code_context_with_nested_classes(tmp_path: Path) -> None:
@ -2593,16 +2600,21 @@ def test_circular_deps():
optimized_code = Path(path_to_root / "optimized.py").read_text(encoding="utf-8") optimized_code = Path(path_to_root / "optimized.py").read_text(encoding="utf-8")
content = Path(file_abs_path).read_text(encoding="utf-8") content = Path(file_abs_path).read_text(encoding="utf-8")
new_code = replace_functions_and_add_imports( new_code = replace_functions_and_add_imports(
source_code= add_global_assignments(optimized_code, content), source_code=add_global_assignments(optimized_code, content),
function_names= ["ApiClient.get_console_url"], function_names=["ApiClient.get_console_url"],
optimized_code= optimized_code, optimized_code=optimized_code,
module_abspath= Path(file_abs_path), module_abspath=Path(file_abs_path),
preexisting_objects= {('ApiClient', ()), ('get_console_url', (FunctionParent(name='ApiClient', type='ClassDef'),))}, preexisting_objects={
project_root_path= Path(path_to_root), ("ApiClient", ()),
("get_console_url", (FunctionParent(name="ApiClient", type="ClassDef"),)),
},
project_root_path=Path(path_to_root),
) )
assert "import ApiClient" not in new_code, "Error: Circular dependency found" assert "import ApiClient" not in new_code, "Error: Circular dependency found"
assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist" assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist"
def test_global_assignment_collector_with_async_function(): def test_global_assignment_collector_with_async_function():
"""Test GlobalAssignmentCollector correctly identifies global assignments outside async functions.""" """Test GlobalAssignmentCollector correctly identifies global assignments outside async functions."""
import libcst as cst import libcst as cst
@ -2750,6 +2762,59 @@ FINAL_ASSIGNMENT = {"data": "value"}
assert collector.assignment_order == expected_order assert collector.assignment_order == expected_order
def test_global_assignment_collector_annotated_assignments():
"""Test GlobalAssignmentCollector correctly handles annotated assignments (AnnAssign)."""
import libcst as cst
source_code = """
# Regular global assignment
REGULAR_VAR = "regular"
# Annotated global assignments
TYPED_VAR: str = "typed"
CACHE: dict[str, int] = {}
SENTINEL: object = object()
# Annotated without value (type declaration only) - should NOT be collected
DECLARED_ONLY: int
def some_function():
# Annotated assignment inside function - should not be collected
local_typed: str = "local"
return local_typed
class SomeClass:
# Class-level annotated assignment - should not be collected
class_attr: str = "class"
# Another regular assignment
FINAL_VAR = 123
"""
tree = cst.parse_module(source_code)
collector = GlobalAssignmentCollector()
tree.visit(collector)
# Should collect both regular and annotated global assignments with values
assert len(collector.assignments) == 5
assert "REGULAR_VAR" in collector.assignments
assert "TYPED_VAR" in collector.assignments
assert "CACHE" in collector.assignments
assert "SENTINEL" in collector.assignments
assert "FINAL_VAR" in collector.assignments
# Should not collect type declarations without values
assert "DECLARED_ONLY" not in collector.assignments
# Should not collect assignments from inside functions or classes
assert "local_typed" not in collector.assignments
assert "class_attr" not in collector.assignments
# Verify correct order
expected_order = ["REGULAR_VAR", "TYPED_VAR", "CACHE", "SENTINEL", "FINAL_VAR"]
assert collector.assignment_order == expected_order
def test_class_instantiation_includes_init_as_helper(tmp_path: Path) -> None: def test_class_instantiation_includes_init_as_helper(tmp_path: Path) -> None:
"""Test that when a class is instantiated, its __init__ method is tracked as a helper. """Test that when a class is instantiated, its __init__ method is tracked as a helper.
@ -2790,11 +2855,7 @@ def target_function():
) )
) )
function_to_optimize = FunctionToOptimize( function_to_optimize = FunctionToOptimize(
function_name="target_function", function_name="target_function", file_path=file_path, parents=[], starting_line=None, ending_line=None
file_path=file_path,
parents=[],
starting_line=None,
ending_line=None,
) )
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
@ -2808,15 +2869,11 @@ def target_function():
# The testgen context should contain the class with __init__ (critical for LLM to know constructor) # The testgen context should contain the class with __init__ (critical for LLM to know constructor)
testgen_context = code_ctx.testgen_context.markdown testgen_context = code_ctx.testgen_context.markdown
assert "class DataDumper:" in testgen_context, "DataDumper class should be in testgen context" assert "class DataDumper:" in testgen_context, "DataDumper class should be in testgen context"
assert "def __init__(self, data):" in testgen_context, ( assert "def __init__(self, data):" in testgen_context, "__init__ method should be included in testgen context"
"__init__ method should be included in testgen context"
)
# The hashing context should NOT contain __init__ (excluded for stability) # The hashing context should NOT contain __init__ (excluded for stability)
hashing_context = code_ctx.hashing_code_context hashing_context = code_ctx.hashing_code_context
assert "__init__" not in hashing_context, ( assert "__init__" not in hashing_context, "__init__ should NOT be in hashing context (excluded for hash stability)"
"__init__ should NOT be in hashing context (excluded for hash stability)"
)
def test_class_instantiation_preserves_full_class_in_testgen(tmp_path: Path) -> None: def test_class_instantiation_preserves_full_class_in_testgen(tmp_path: Path) -> None:
@ -2870,11 +2927,7 @@ def dump_layout(layout_type, layout):
) )
) )
function_to_optimize = FunctionToOptimize( function_to_optimize = FunctionToOptimize(
function_name="dump_layout", function_name="dump_layout", file_path=file_path, parents=[], starting_line=None, ending_line=None
file_path=file_path,
parents=[],
starting_line=None,
ending_line=None,
) )
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
@ -2884,9 +2937,7 @@ def dump_layout(layout_type, layout):
assert "ObjectDetectionLayoutDumper.__init__" in qualified_names, ( assert "ObjectDetectionLayoutDumper.__init__" in qualified_names, (
"ObjectDetectionLayoutDumper.__init__ should be tracked" "ObjectDetectionLayoutDumper.__init__ should be tracked"
) )
assert "LayoutDumper.__init__" in qualified_names, ( assert "LayoutDumper.__init__" in qualified_names, "LayoutDumper.__init__ should be tracked"
"LayoutDumper.__init__ should be tracked"
)
# The testgen context should include both classes with their __init__ methods # The testgen context should include both classes with their __init__ methods
testgen_context = code_ctx.testgen_context.markdown testgen_context = code_ctx.testgen_context.markdown
@ -2896,9 +2947,7 @@ def dump_layout(layout_type, layout):
) )
# Both __init__ methods should be in the testgen context (so LLM knows constructor signatures) # Both __init__ methods should be in the testgen context (so LLM knows constructor signatures)
assert testgen_context.count("def __init__") >= 2, ( assert testgen_context.count("def __init__") >= 2, "Both __init__ methods should be in testgen context"
"Both __init__ methods should be in testgen context"
)
def test_get_imported_class_definitions_extracts_project_classes(tmp_path: Path) -> None: def test_get_imported_class_definitions_extracts_project_classes(tmp_path: Path) -> None:
@ -2934,7 +2983,7 @@ class Text(Element):
elements_path.write_text(elements_code, encoding="utf-8") elements_path.write_text(elements_code, encoding="utf-8")
# Create another module that imports from elements # Create another module that imports from elements
chunking_code = ''' chunking_code = """
from mypackage.elements import Element from mypackage.elements import Element
class PreChunk: class PreChunk:
@ -2944,14 +2993,12 @@ class PreChunk:
class Accumulator: class Accumulator:
def will_fit(self, chunk: PreChunk) -> bool: def will_fit(self, chunk: PreChunk) -> bool:
return True return True
''' """
chunking_path = package_dir / "chunking.py" chunking_path = package_dir / "chunking.py"
chunking_path.write_text(chunking_code, encoding="utf-8") chunking_path.write_text(chunking_code, encoding="utf-8")
# Create CodeStringsMarkdown from the chunking module (simulating testgen context) # Create CodeStringsMarkdown from the chunking module (simulating testgen context)
context = CodeStringsMarkdown( context = CodeStringsMarkdown(code_strings=[CodeString(code=chunking_code, file_path=chunking_path)])
code_strings=[CodeString(code=chunking_code, file_path=chunking_path)]
)
# Call get_imported_class_definitions # Call get_imported_class_definitions
result = get_imported_class_definitions(context, tmp_path) result = get_imported_class_definitions(context, tmp_path)
@ -2975,16 +3022,16 @@ def test_get_imported_class_definitions_skips_existing_definitions(tmp_path: Pat
(package_dir / "__init__.py").write_text("", encoding="utf-8") (package_dir / "__init__.py").write_text("", encoding="utf-8")
# Create a module with a class definition # Create a module with a class definition
elements_code = ''' elements_code = """
class Element: class Element:
def __init__(self, text: str): def __init__(self, text: str):
self.text = text self.text = text
''' """
elements_path = package_dir / "elements.py" elements_path = package_dir / "elements.py"
elements_path.write_text(elements_code, encoding="utf-8") elements_path.write_text(elements_code, encoding="utf-8")
# Create code that imports Element but also redefines it locally # Create code that imports Element but also redefines it locally
code_with_local_def = ''' code_with_local_def = """
from mypackage.elements import Element from mypackage.elements import Element
# Local redefinition (this happens when LLM redefines classes) # Local redefinition (this happens when LLM redefines classes)
@ -2995,13 +3042,11 @@ class Element:
class User: class User:
def process(self, elem: Element): def process(self, elem: Element):
pass pass
''' """
code_path = package_dir / "user.py" code_path = package_dir / "user.py"
code_path.write_text(code_with_local_def, encoding="utf-8") code_path.write_text(code_with_local_def, encoding="utf-8")
context = CodeStringsMarkdown( context = CodeStringsMarkdown(code_strings=[CodeString(code=code_with_local_def, file_path=code_path)])
code_strings=[CodeString(code=code_with_local_def, file_path=code_path)]
)
# Call get_imported_class_definitions # Call get_imported_class_definitions
result = get_imported_class_definitions(context, tmp_path) result = get_imported_class_definitions(context, tmp_path)
@ -3018,7 +3063,7 @@ def test_get_imported_class_definitions_skips_third_party(tmp_path: Path) -> Non
(package_dir / "__init__.py").write_text("", encoding="utf-8") (package_dir / "__init__.py").write_text("", encoding="utf-8")
# Code with stdlib/third-party imports # Code with stdlib/third-party imports
code = ''' code = """
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from dataclasses import dataclass from dataclasses import dataclass
@ -3026,13 +3071,11 @@ from dataclasses import dataclass
class MyClass: class MyClass:
def __init__(self, path: Path): def __init__(self, path: Path):
self.path = path self.path = path
''' """
code_path = package_dir / "main.py" code_path = package_dir / "main.py"
code_path.write_text(code, encoding="utf-8") code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown( context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
code_strings=[CodeString(code=code, file_path=code_path)]
)
# Call get_imported_class_definitions # Call get_imported_class_definitions
result = get_imported_class_definitions(context, tmp_path) result = get_imported_class_definitions(context, tmp_path)
@ -3049,7 +3092,7 @@ def test_get_imported_class_definitions_handles_multiple_imports(tmp_path: Path)
(package_dir / "__init__.py").write_text("", encoding="utf-8") (package_dir / "__init__.py").write_text("", encoding="utf-8")
# Create a module with multiple class definitions # Create a module with multiple class definitions
types_code = ''' types_code = """
class TypeA: class TypeA:
def __init__(self, value: int): def __init__(self, value: int):
self.value = value self.value = value
@ -3061,24 +3104,22 @@ class TypeB:
class TypeC: class TypeC:
def __init__(self): def __init__(self):
pass pass
''' """
types_path = package_dir / "types.py" types_path = package_dir / "types.py"
types_path.write_text(types_code, encoding="utf-8") types_path.write_text(types_code, encoding="utf-8")
# Create code that imports multiple classes # Create code that imports multiple classes
code = ''' code = """
from mypackage.types import TypeA, TypeB from mypackage.types import TypeA, TypeB
class Processor: class Processor:
def process(self, a: TypeA, b: TypeB): def process(self, a: TypeA, b: TypeB):
pass pass
''' """
code_path = package_dir / "processor.py" code_path = package_dir / "processor.py"
code_path.write_text(code, encoding="utf-8") code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown( context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
code_strings=[CodeString(code=code, file_path=code_path)]
)
# Call get_imported_class_definitions # Call get_imported_class_definitions
result = get_imported_class_definitions(context, tmp_path) result = get_imported_class_definitions(context, tmp_path)
@ -3100,7 +3141,7 @@ def test_get_imported_class_definitions_includes_dataclass_decorators(tmp_path:
(package_dir / "__init__.py").write_text("", encoding="utf-8") (package_dir / "__init__.py").write_text("", encoding="utf-8")
# Create a module with dataclass definitions (like LLMConfig in skyvern) # Create a module with dataclass definitions (like LLMConfig in skyvern)
models_code = '''from dataclasses import dataclass, field models_code = """from dataclasses import dataclass, field
from typing import Optional from typing import Optional
@dataclass(frozen=True) @dataclass(frozen=True)
@ -3114,23 +3155,21 @@ class LLMConfigBase:
class LLMConfig(LLMConfigBase): class LLMConfig(LLMConfigBase):
litellm_params: Optional[dict] = field(default=None) litellm_params: Optional[dict] = field(default=None)
max_tokens: int | None = None max_tokens: int | None = None
''' """
models_path = package_dir / "models.py" models_path = package_dir / "models.py"
models_path.write_text(models_code, encoding="utf-8") models_path.write_text(models_code, encoding="utf-8")
# Create code that imports the dataclass # Create code that imports the dataclass
code = '''from mypackage.models import LLMConfig code = """from mypackage.models import LLMConfig
class ConfigRegistry: class ConfigRegistry:
def get_config(self) -> LLMConfig: def get_config(self) -> LLMConfig:
pass pass
''' """
code_path = package_dir / "registry.py" code_path = package_dir / "registry.py"
code_path.write_text(code, encoding="utf-8") code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown( context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
code_strings=[CodeString(code=code, file_path=code_path)]
)
# Call get_imported_class_definitions # Call get_imported_class_definitions
result = get_imported_class_definitions(context, tmp_path) result = get_imported_class_definitions(context, tmp_path)
@ -3165,7 +3204,7 @@ def test_get_imported_class_definitions_extracts_imports_for_decorated_classes(t
(package_dir / "__init__.py").write_text("", encoding="utf-8") (package_dir / "__init__.py").write_text("", encoding="utf-8")
# Create a module with decorated class that uses field() and various type annotations # Create a module with decorated class that uses field() and various type annotations
models_code = '''from dataclasses import dataclass, field models_code = """from dataclasses import dataclass, field
from typing import Optional, List from typing import Optional, List
@dataclass @dataclass
@ -3173,22 +3212,20 @@ class Config:
name: str name: str
values: List[int] = field(default_factory=list) values: List[int] = field(default_factory=list)
description: Optional[str] = None description: Optional[str] = None
''' """
models_path = package_dir / "models.py" models_path = package_dir / "models.py"
models_path.write_text(models_code, encoding="utf-8") models_path.write_text(models_code, encoding="utf-8")
# Create code that imports the class # Create code that imports the class
code = '''from mypackage.models import Config code = """from mypackage.models import Config
def create_config() -> Config: def create_config() -> Config:
return Config(name="test") return Config(name="test")
''' """
code_path = package_dir / "main.py" code_path = package_dir / "main.py"
code_path.write_text(code, encoding="utf-8") code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown( context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
code_strings=[CodeString(code=code, file_path=code_path)]
)
result = get_imported_class_definitions(context, tmp_path) result = get_imported_class_definitions(context, tmp_path)
@ -3282,12 +3319,12 @@ class TestExtractImportsForClass:
"""Test that base class imports are extracted.""" """Test that base class imports are extracted."""
import ast import ast
module_source = '''from abc import ABC module_source = """from abc import ABC
from mypackage import BaseClass from mypackage import BaseClass
class MyClass(BaseClass, ABC): class MyClass(BaseClass, ABC):
pass pass
''' """
tree = ast.parse(module_source) tree = ast.parse(module_source)
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef)) class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
result = extract_imports_for_class(tree, class_node, module_source) result = extract_imports_for_class(tree, class_node, module_source)
@ -3298,13 +3335,13 @@ class MyClass(BaseClass, ABC):
"""Test that decorator imports are extracted.""" """Test that decorator imports are extracted."""
import ast import ast
module_source = '''from dataclasses import dataclass module_source = """from dataclasses import dataclass
from functools import lru_cache from functools import lru_cache
@dataclass @dataclass
class MyClass: class MyClass:
name: str name: str
''' """
tree = ast.parse(module_source) tree = ast.parse(module_source)
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef)) class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
result = extract_imports_for_class(tree, class_node, module_source) result = extract_imports_for_class(tree, class_node, module_source)
@ -3314,14 +3351,14 @@ class MyClass:
"""Test that type annotation imports are extracted.""" """Test that type annotation imports are extracted."""
import ast import ast
module_source = '''from typing import Optional, List module_source = """from typing import Optional, List
from mypackage.models import Config from mypackage.models import Config
@dataclass @dataclass
class MyClass: class MyClass:
config: Optional[Config] config: Optional[Config]
items: List[str] items: List[str]
''' """
tree = ast.parse(module_source) tree = ast.parse(module_source)
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef)) class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
result = extract_imports_for_class(tree, class_node, module_source) result = extract_imports_for_class(tree, class_node, module_source)
@ -3332,13 +3369,13 @@ class MyClass:
"""Test that field() function imports are extracted for dataclasses.""" """Test that field() function imports are extracted for dataclasses."""
import ast import ast
module_source = '''from dataclasses import dataclass, field module_source = """from dataclasses import dataclass, field
from typing import List from typing import List
@dataclass @dataclass
class MyClass: class MyClass:
items: List[str] = field(default_factory=list) items: List[str] = field(default_factory=list)
''' """
tree = ast.parse(module_source) tree = ast.parse(module_source)
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef)) class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
result = extract_imports_for_class(tree, class_node, module_source) result = extract_imports_for_class(tree, class_node, module_source)
@ -3348,13 +3385,13 @@ class MyClass:
"""Test that duplicate imports are not included.""" """Test that duplicate imports are not included."""
import ast import ast
module_source = '''from typing import Optional module_source = """from typing import Optional
@dataclass @dataclass
class MyClass: class MyClass:
field1: Optional[str] field1: Optional[str]
field2: Optional[int] field2: Optional[int]
''' """
tree = ast.parse(module_source) tree = ast.parse(module_source)
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef)) class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
result = extract_imports_for_class(tree, class_node, module_source) result = extract_imports_for_class(tree, class_node, module_source)
@ -3368,7 +3405,7 @@ def test_get_imported_class_definitions_multiple_decorators(tmp_path: Path) -> N
package_dir.mkdir() package_dir.mkdir()
(package_dir / "__init__.py").write_text("", encoding="utf-8") (package_dir / "__init__.py").write_text("", encoding="utf-8")
models_code = '''from dataclasses import dataclass models_code = """from dataclasses import dataclass
from functools import total_ordering from functools import total_ordering
@total_ordering @total_ordering
@ -3379,21 +3416,19 @@ class OrderedConfig:
def __lt__(self, other): def __lt__(self, other):
return self.priority < other.priority return self.priority < other.priority
''' """
models_path = package_dir / "models.py" models_path = package_dir / "models.py"
models_path.write_text(models_code, encoding="utf-8") models_path.write_text(models_code, encoding="utf-8")
code = '''from mypackage.models import OrderedConfig code = """from mypackage.models import OrderedConfig
def sort_configs(configs: list[OrderedConfig]) -> list[OrderedConfig]: def sort_configs(configs: list[OrderedConfig]) -> list[OrderedConfig]:
return sorted(configs) return sorted(configs)
''' """
code_path = package_dir / "main.py" code_path = package_dir / "main.py"
code_path.write_text(code, encoding="utf-8") code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown( context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
code_strings=[CodeString(code=code, file_path=code_path)]
)
result = get_imported_class_definitions(context, tmp_path) result = get_imported_class_definitions(context, tmp_path)
@ -3452,7 +3487,7 @@ class RouterConfig(ParentConfig):
models_path.write_text(models_code, encoding="utf-8") models_path.write_text(models_code, encoding="utf-8")
# Create code that imports only the child classes (not the base classes) # Create code that imports only the child classes (not the base classes)
code = '''from mypackage.models import ChildConfig, RouterConfig code = """from mypackage.models import ChildConfig, RouterConfig
class ConfigRegistry: class ConfigRegistry:
def get_child_config(self) -> ChildConfig: def get_child_config(self) -> ChildConfig:
@ -3460,7 +3495,7 @@ class ConfigRegistry:
def get_router_config(self) -> RouterConfig: def get_router_config(self) -> RouterConfig:
pass pass
''' """
code_path = package_dir / "registry.py" code_path = package_dir / "registry.py"
code_path.write_text(code, encoding="utf-8") code_path.write_text(code, encoding="utf-8")