fix: handle annotated assignments in GlobalAssignmentCollector
GlobalAssignmentCollector only handled cst.Assign, not cst.AnnAssign (annotated assignments such as `X: int = 1`). When the LLM generated optimizations containing annotated module-level variables, those variables were not copied to the target file, causing a NameError at runtime.

- Add visit_AnnAssign to GlobalAssignmentCollector
- Add leave_AnnAssign to GlobalAssignmentTransformer
- Update type hints to include cst.AnnAssign
- Add test for annotated assignment handling
This commit is contained in:
parent
412779d7ba
commit
9f929c2151
2 changed files with 163 additions and 100 deletions
|
|
@ -30,7 +30,7 @@ class GlobalAssignmentCollector(cst.CSTVisitor):
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.assignments: dict[str, cst.Assign] = {}
|
self.assignments: dict[str, cst.Assign | cst.AnnAssign] = {}
|
||||||
self.assignment_order: list[str] = []
|
self.assignment_order: list[str] = []
|
||||||
# Track scope depth to identify global assignments
|
# Track scope depth to identify global assignments
|
||||||
self.scope_depth = 0
|
self.scope_depth = 0
|
||||||
|
|
@ -72,6 +72,21 @@ class GlobalAssignmentCollector(cst.CSTVisitor):
|
||||||
self.assignment_order.append(name)
|
self.assignment_order.append(name)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def visit_AnnAssign(self, node: cst.AnnAssign) -> Optional[bool]:
    """Record an annotated assignment such as ``_CACHE: Dict[str, int] = {}``.

    The node is stored only when all of the following hold:
    the visitor is at scope depth 0, it is not inside an if/else block,
    the target is a plain name, and the statement actually assigns a value
    (a bare declaration like ``X: int`` is ignored).

    Always returns True, so traversal continues into the node's children.
    """
    # Anything nested in a scope or in an if/else block is not a global assignment.
    if self.scope_depth != 0 or self.if_else_depth != 0:
        return True

    target = node.target
    # Skip non-name targets (e.g. attributes) and value-less type declarations.
    if not isinstance(target, cst.Name) or node.value is None:
        return True

    var_name = target.value
    # A later assignment to the same name overwrites the stored node,
    # but the name keeps its first position in the ordering.
    self.assignments[var_name] = node
    if var_name not in self.assignment_order:
        self.assignment_order.append(var_name)
    return True
|
||||||
|
|
||||||
|
|
||||||
def find_insertion_index_after_imports(node: cst.Module) -> int:
|
def find_insertion_index_after_imports(node: cst.Module) -> int:
|
||||||
"""Find the position of the last import statement in the top-level of the module."""
|
"""Find the position of the last import statement in the top-level of the module."""
|
||||||
|
|
@ -103,7 +118,7 @@ def find_insertion_index_after_imports(node: cst.Module) -> int:
|
||||||
class GlobalAssignmentTransformer(cst.CSTTransformer):
|
class GlobalAssignmentTransformer(cst.CSTTransformer):
|
||||||
"""Transforms global assignments in the original file with those from the new file."""
|
"""Transforms global assignments in the original file with those from the new file."""
|
||||||
|
|
||||||
def __init__(self, new_assignments: dict[str, cst.Assign], new_assignment_order: list[str]) -> None:
|
def __init__(self, new_assignments: dict[str, cst.Assign | cst.AnnAssign], new_assignment_order: list[str]) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.new_assignments = new_assignments
|
self.new_assignments = new_assignments
|
||||||
self.new_assignment_order = new_assignment_order
|
self.new_assignment_order = new_assignment_order
|
||||||
|
|
@ -150,6 +165,19 @@ class GlobalAssignmentTransformer(cst.CSTTransformer):
|
||||||
|
|
||||||
return updated_node
|
return updated_node
|
||||||
|
|
||||||
|
def leave_AnnAssign(self, original_node: cst.AnnAssign, updated_node: cst.AnnAssign) -> cst.CSTNode:
    """Swap a global annotated assignment for its counterpart from the new file.

    Returns ``updated_node`` untouched when the statement is nested inside a
    scope or an if/else block, when its target is not a plain name, or when
    no new assignment exists for that name. Otherwise the name is recorded in
    ``processed_assignments`` and the node stored under that name in
    ``new_assignments`` is returned in its place.
    """
    # Only top-level statements outside if/else blocks are candidates.
    if self.scope_depth > 0 or self.if_else_depth > 0:
        return updated_node

    target = original_node.target
    if not isinstance(target, cst.Name):
        return updated_node

    var_name = target.value
    if var_name not in self.new_assignments:
        return updated_node

    # Record the name as handled before handing back the replacement node.
    self.processed_assignments.add(var_name)
    return self.new_assignments[var_name]
|
||||||
|
|
||||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
|
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
|
||||||
# Add any new assignments that weren't in the original file
|
# Add any new assignments that weren't in the original file
|
||||||
new_statements = list(updated_node.body)
|
new_statements = list(updated_node.body)
|
||||||
|
|
|
||||||
|
|
@ -7,18 +7,18 @@ from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from codeflash.code_utils.code_extractor import GlobalAssignmentCollector, add_global_assignments
|
||||||
|
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
||||||
from codeflash.context.code_context_extractor import (
|
from codeflash.context.code_context_extractor import (
|
||||||
get_code_optimization_context,
|
|
||||||
get_imported_class_definitions,
|
|
||||||
collect_names_from_annotation,
|
collect_names_from_annotation,
|
||||||
extract_imports_for_class,
|
extract_imports_for_class,
|
||||||
|
get_code_optimization_context,
|
||||||
|
get_imported_class_definitions,
|
||||||
)
|
)
|
||||||
from codeflash.models.models import CodeString, CodeStringsMarkdown
|
|
||||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||||
from codeflash.models.models import FunctionParent
|
from codeflash.models.models import CodeString, CodeStringsMarkdown, FunctionParent
|
||||||
from codeflash.optimization.optimizer import Optimizer
|
from codeflash.optimization.optimizer import Optimizer
|
||||||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
|
||||||
from codeflash.code_utils.code_extractor import add_global_assignments, GlobalAssignmentCollector
|
|
||||||
|
|
||||||
|
|
||||||
class HelperClass:
|
class HelperClass:
|
||||||
|
|
@ -91,7 +91,10 @@ def test_code_replacement10() -> None:
|
||||||
code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent)
|
code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent)
|
||||||
qualified_names = {func.qualified_name for func in code_ctx.helper_functions}
|
qualified_names = {func.qualified_name for func in code_ctx.helper_functions}
|
||||||
# HelperClass.__init__ is now tracked because HelperClass(self.name) instantiates the class
|
# HelperClass.__init__ is now tracked because HelperClass(self.name) instantiates the class
|
||||||
assert qualified_names == {"HelperClass.helper_method", "HelperClass.__init__"} # Nested method should not be in here
|
assert qualified_names == {
|
||||||
|
"HelperClass.helper_method",
|
||||||
|
"HelperClass.__init__",
|
||||||
|
} # Nested method should not be in here
|
||||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||||
hashing_context = code_ctx.hashing_code_context
|
hashing_context = code_ctx.hashing_code_context
|
||||||
|
|
||||||
|
|
@ -234,7 +237,7 @@ def test_bubble_sort_helper() -> None:
|
||||||
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
|
||||||
hashing_context = code_ctx.hashing_code_context
|
hashing_context = code_ctx.hashing_code_context
|
||||||
|
|
||||||
expected_read_write_context = f"""
|
expected_read_write_context = """
|
||||||
```python:code_to_optimize/code_directories/retriever/bubble_sort_with_math.py
|
```python:code_to_optimize/code_directories/retriever/bubble_sort_with_math.py
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
|
@ -1108,7 +1111,9 @@ class HelperClass:
|
||||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||||
|
|
||||||
# the global x variable shouldn't be included in any context type
|
# the global x variable shouldn't be included in any context type
|
||||||
assert code_ctx.read_writable_code.flat == '''# file: test_code.py
|
assert (
|
||||||
|
code_ctx.read_writable_code.flat
|
||||||
|
== '''# file: test_code.py
|
||||||
class MyClass:
|
class MyClass:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.x = 1
|
self.x = 1
|
||||||
|
|
@ -1123,7 +1128,10 @@ class HelperClass:
|
||||||
def helper_method(self):
|
def helper_method(self):
|
||||||
return self.x
|
return self.x
|
||||||
'''
|
'''
|
||||||
assert code_ctx.testgen_context.flat == '''# file: test_code.py
|
)
|
||||||
|
assert (
|
||||||
|
code_ctx.testgen_context.flat
|
||||||
|
== '''# file: test_code.py
|
||||||
class MyClass:
|
class MyClass:
|
||||||
"""A class with a helper method. """
|
"""A class with a helper method. """
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
@ -1143,6 +1151,7 @@ class HelperClass:
|
||||||
def helper_method(self):
|
def helper_method(self):
|
||||||
return self.x
|
return self.x
|
||||||
'''
|
'''
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_repo_helper() -> None:
|
def test_repo_helper() -> None:
|
||||||
|
|
@ -2353,9 +2362,7 @@ def standalone_function():
|
||||||
assert '"""Helper method with docstring."""' not in hashing_context, (
|
assert '"""Helper method with docstring."""' not in hashing_context, (
|
||||||
"Docstrings should be removed from helper functions"
|
"Docstrings should be removed from helper functions"
|
||||||
)
|
)
|
||||||
assert '"""Process data method."""' not in hashing_context, (
|
assert '"""Process data method."""' not in hashing_context, "Docstrings should be removed from helper class methods"
|
||||||
"Docstrings should be removed from helper class methods"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_hashing_code_context_with_nested_classes(tmp_path: Path) -> None:
|
def test_hashing_code_context_with_nested_classes(tmp_path: Path) -> None:
|
||||||
|
|
@ -2593,16 +2600,21 @@ def test_circular_deps():
|
||||||
optimized_code = Path(path_to_root / "optimized.py").read_text(encoding="utf-8")
|
optimized_code = Path(path_to_root / "optimized.py").read_text(encoding="utf-8")
|
||||||
content = Path(file_abs_path).read_text(encoding="utf-8")
|
content = Path(file_abs_path).read_text(encoding="utf-8")
|
||||||
new_code = replace_functions_and_add_imports(
|
new_code = replace_functions_and_add_imports(
|
||||||
source_code= add_global_assignments(optimized_code, content),
|
source_code=add_global_assignments(optimized_code, content),
|
||||||
function_names= ["ApiClient.get_console_url"],
|
function_names=["ApiClient.get_console_url"],
|
||||||
optimized_code= optimized_code,
|
optimized_code=optimized_code,
|
||||||
module_abspath= Path(file_abs_path),
|
module_abspath=Path(file_abs_path),
|
||||||
preexisting_objects= {('ApiClient', ()), ('get_console_url', (FunctionParent(name='ApiClient', type='ClassDef'),))},
|
preexisting_objects={
|
||||||
project_root_path= Path(path_to_root),
|
("ApiClient", ()),
|
||||||
|
("get_console_url", (FunctionParent(name="ApiClient", type="ClassDef"),)),
|
||||||
|
},
|
||||||
|
project_root_path=Path(path_to_root),
|
||||||
)
|
)
|
||||||
assert "import ApiClient" not in new_code, "Error: Circular dependency found"
|
assert "import ApiClient" not in new_code, "Error: Circular dependency found"
|
||||||
|
|
||||||
assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist"
|
assert "import urllib.parse" in new_code, "Make sure imports for optimization global assignments exist"
|
||||||
|
|
||||||
|
|
||||||
def test_global_assignment_collector_with_async_function():
|
def test_global_assignment_collector_with_async_function():
|
||||||
"""Test GlobalAssignmentCollector correctly identifies global assignments outside async functions."""
|
"""Test GlobalAssignmentCollector correctly identifies global assignments outside async functions."""
|
||||||
import libcst as cst
|
import libcst as cst
|
||||||
|
|
@ -2750,6 +2762,59 @@ FINAL_ASSIGNMENT = {"data": "value"}
|
||||||
assert collector.assignment_order == expected_order
|
assert collector.assignment_order == expected_order
|
||||||
|
|
||||||
|
|
||||||
|
def test_global_assignment_collector_annotated_assignments():
    """Check that GlobalAssignmentCollector picks up module-level AnnAssign nodes.

    Annotated assignments with a value must be collected alongside plain
    assignments; bare type declarations and anything inside a function or
    class body must be left out.
    """
    import libcst as cst

    source_code = """
# Regular global assignment
REGULAR_VAR = "regular"

# Annotated global assignments
TYPED_VAR: str = "typed"
CACHE: dict[str, int] = {}
SENTINEL: object = object()

# Annotated without value (type declaration only) - should NOT be collected
DECLARED_ONLY: int

def some_function():
    # Annotated assignment inside function - should not be collected
    local_typed: str = "local"
    return local_typed

class SomeClass:
    # Class-level annotated assignment - should not be collected
    class_attr: str = "class"

# Another regular assignment
FINAL_VAR = 123
"""

    visitor = GlobalAssignmentCollector()
    cst.parse_module(source_code).visit(visitor)
    collected = visitor.assignments

    # Regular and annotated globals with values, in source order
    expected_order = ["REGULAR_VAR", "TYPED_VAR", "CACHE", "SENTINEL", "FINAL_VAR"]
    assert len(collected) == len(expected_order)
    for name in expected_order:
        assert name in collected

    # Value-less declarations and function/class-local names are excluded
    for name in ("DECLARED_ONLY", "local_typed", "class_attr"):
        assert name not in collected

    # Order of first appearance is preserved
    assert visitor.assignment_order == expected_order
|
||||||
|
|
||||||
|
|
||||||
def test_class_instantiation_includes_init_as_helper(tmp_path: Path) -> None:
|
def test_class_instantiation_includes_init_as_helper(tmp_path: Path) -> None:
|
||||||
"""Test that when a class is instantiated, its __init__ method is tracked as a helper.
|
"""Test that when a class is instantiated, its __init__ method is tracked as a helper.
|
||||||
|
|
||||||
|
|
@ -2790,11 +2855,7 @@ def target_function():
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
function_to_optimize = FunctionToOptimize(
|
function_to_optimize = FunctionToOptimize(
|
||||||
function_name="target_function",
|
function_name="target_function", file_path=file_path, parents=[], starting_line=None, ending_line=None
|
||||||
file_path=file_path,
|
|
||||||
parents=[],
|
|
||||||
starting_line=None,
|
|
||||||
ending_line=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||||
|
|
@ -2808,15 +2869,11 @@ def target_function():
|
||||||
# The testgen context should contain the class with __init__ (critical for LLM to know constructor)
|
# The testgen context should contain the class with __init__ (critical for LLM to know constructor)
|
||||||
testgen_context = code_ctx.testgen_context.markdown
|
testgen_context = code_ctx.testgen_context.markdown
|
||||||
assert "class DataDumper:" in testgen_context, "DataDumper class should be in testgen context"
|
assert "class DataDumper:" in testgen_context, "DataDumper class should be in testgen context"
|
||||||
assert "def __init__(self, data):" in testgen_context, (
|
assert "def __init__(self, data):" in testgen_context, "__init__ method should be included in testgen context"
|
||||||
"__init__ method should be included in testgen context"
|
|
||||||
)
|
|
||||||
|
|
||||||
# The hashing context should NOT contain __init__ (excluded for stability)
|
# The hashing context should NOT contain __init__ (excluded for stability)
|
||||||
hashing_context = code_ctx.hashing_code_context
|
hashing_context = code_ctx.hashing_code_context
|
||||||
assert "__init__" not in hashing_context, (
|
assert "__init__" not in hashing_context, "__init__ should NOT be in hashing context (excluded for hash stability)"
|
||||||
"__init__ should NOT be in hashing context (excluded for hash stability)"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_class_instantiation_preserves_full_class_in_testgen(tmp_path: Path) -> None:
|
def test_class_instantiation_preserves_full_class_in_testgen(tmp_path: Path) -> None:
|
||||||
|
|
@ -2870,11 +2927,7 @@ def dump_layout(layout_type, layout):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
function_to_optimize = FunctionToOptimize(
|
function_to_optimize = FunctionToOptimize(
|
||||||
function_name="dump_layout",
|
function_name="dump_layout", file_path=file_path, parents=[], starting_line=None, ending_line=None
|
||||||
file_path=file_path,
|
|
||||||
parents=[],
|
|
||||||
starting_line=None,
|
|
||||||
ending_line=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
|
||||||
|
|
@ -2884,9 +2937,7 @@ def dump_layout(layout_type, layout):
|
||||||
assert "ObjectDetectionLayoutDumper.__init__" in qualified_names, (
|
assert "ObjectDetectionLayoutDumper.__init__" in qualified_names, (
|
||||||
"ObjectDetectionLayoutDumper.__init__ should be tracked"
|
"ObjectDetectionLayoutDumper.__init__ should be tracked"
|
||||||
)
|
)
|
||||||
assert "LayoutDumper.__init__" in qualified_names, (
|
assert "LayoutDumper.__init__" in qualified_names, "LayoutDumper.__init__ should be tracked"
|
||||||
"LayoutDumper.__init__ should be tracked"
|
|
||||||
)
|
|
||||||
|
|
||||||
# The testgen context should include both classes with their __init__ methods
|
# The testgen context should include both classes with their __init__ methods
|
||||||
testgen_context = code_ctx.testgen_context.markdown
|
testgen_context = code_ctx.testgen_context.markdown
|
||||||
|
|
@ -2896,9 +2947,7 @@ def dump_layout(layout_type, layout):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Both __init__ methods should be in the testgen context (so LLM knows constructor signatures)
|
# Both __init__ methods should be in the testgen context (so LLM knows constructor signatures)
|
||||||
assert testgen_context.count("def __init__") >= 2, (
|
assert testgen_context.count("def __init__") >= 2, "Both __init__ methods should be in testgen context"
|
||||||
"Both __init__ methods should be in testgen context"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_imported_class_definitions_extracts_project_classes(tmp_path: Path) -> None:
|
def test_get_imported_class_definitions_extracts_project_classes(tmp_path: Path) -> None:
|
||||||
|
|
@ -2934,7 +2983,7 @@ class Text(Element):
|
||||||
elements_path.write_text(elements_code, encoding="utf-8")
|
elements_path.write_text(elements_code, encoding="utf-8")
|
||||||
|
|
||||||
# Create another module that imports from elements
|
# Create another module that imports from elements
|
||||||
chunking_code = '''
|
chunking_code = """
|
||||||
from mypackage.elements import Element
|
from mypackage.elements import Element
|
||||||
|
|
||||||
class PreChunk:
|
class PreChunk:
|
||||||
|
|
@ -2944,14 +2993,12 @@ class PreChunk:
|
||||||
class Accumulator:
|
class Accumulator:
|
||||||
def will_fit(self, chunk: PreChunk) -> bool:
|
def will_fit(self, chunk: PreChunk) -> bool:
|
||||||
return True
|
return True
|
||||||
'''
|
"""
|
||||||
chunking_path = package_dir / "chunking.py"
|
chunking_path = package_dir / "chunking.py"
|
||||||
chunking_path.write_text(chunking_code, encoding="utf-8")
|
chunking_path.write_text(chunking_code, encoding="utf-8")
|
||||||
|
|
||||||
# Create CodeStringsMarkdown from the chunking module (simulating testgen context)
|
# Create CodeStringsMarkdown from the chunking module (simulating testgen context)
|
||||||
context = CodeStringsMarkdown(
|
context = CodeStringsMarkdown(code_strings=[CodeString(code=chunking_code, file_path=chunking_path)])
|
||||||
code_strings=[CodeString(code=chunking_code, file_path=chunking_path)]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Call get_imported_class_definitions
|
# Call get_imported_class_definitions
|
||||||
result = get_imported_class_definitions(context, tmp_path)
|
result = get_imported_class_definitions(context, tmp_path)
|
||||||
|
|
@ -2975,16 +3022,16 @@ def test_get_imported_class_definitions_skips_existing_definitions(tmp_path: Pat
|
||||||
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||||
|
|
||||||
# Create a module with a class definition
|
# Create a module with a class definition
|
||||||
elements_code = '''
|
elements_code = """
|
||||||
class Element:
|
class Element:
|
||||||
def __init__(self, text: str):
|
def __init__(self, text: str):
|
||||||
self.text = text
|
self.text = text
|
||||||
'''
|
"""
|
||||||
elements_path = package_dir / "elements.py"
|
elements_path = package_dir / "elements.py"
|
||||||
elements_path.write_text(elements_code, encoding="utf-8")
|
elements_path.write_text(elements_code, encoding="utf-8")
|
||||||
|
|
||||||
# Create code that imports Element but also redefines it locally
|
# Create code that imports Element but also redefines it locally
|
||||||
code_with_local_def = '''
|
code_with_local_def = """
|
||||||
from mypackage.elements import Element
|
from mypackage.elements import Element
|
||||||
|
|
||||||
# Local redefinition (this happens when LLM redefines classes)
|
# Local redefinition (this happens when LLM redefines classes)
|
||||||
|
|
@ -2995,13 +3042,11 @@ class Element:
|
||||||
class User:
|
class User:
|
||||||
def process(self, elem: Element):
|
def process(self, elem: Element):
|
||||||
pass
|
pass
|
||||||
'''
|
"""
|
||||||
code_path = package_dir / "user.py"
|
code_path = package_dir / "user.py"
|
||||||
code_path.write_text(code_with_local_def, encoding="utf-8")
|
code_path.write_text(code_with_local_def, encoding="utf-8")
|
||||||
|
|
||||||
context = CodeStringsMarkdown(
|
context = CodeStringsMarkdown(code_strings=[CodeString(code=code_with_local_def, file_path=code_path)])
|
||||||
code_strings=[CodeString(code=code_with_local_def, file_path=code_path)]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Call get_imported_class_definitions
|
# Call get_imported_class_definitions
|
||||||
result = get_imported_class_definitions(context, tmp_path)
|
result = get_imported_class_definitions(context, tmp_path)
|
||||||
|
|
@ -3018,7 +3063,7 @@ def test_get_imported_class_definitions_skips_third_party(tmp_path: Path) -> Non
|
||||||
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||||
|
|
||||||
# Code with stdlib/third-party imports
|
# Code with stdlib/third-party imports
|
||||||
code = '''
|
code = """
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
@ -3026,13 +3071,11 @@ from dataclasses import dataclass
|
||||||
class MyClass:
|
class MyClass:
|
||||||
def __init__(self, path: Path):
|
def __init__(self, path: Path):
|
||||||
self.path = path
|
self.path = path
|
||||||
'''
|
"""
|
||||||
code_path = package_dir / "main.py"
|
code_path = package_dir / "main.py"
|
||||||
code_path.write_text(code, encoding="utf-8")
|
code_path.write_text(code, encoding="utf-8")
|
||||||
|
|
||||||
context = CodeStringsMarkdown(
|
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||||
code_strings=[CodeString(code=code, file_path=code_path)]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Call get_imported_class_definitions
|
# Call get_imported_class_definitions
|
||||||
result = get_imported_class_definitions(context, tmp_path)
|
result = get_imported_class_definitions(context, tmp_path)
|
||||||
|
|
@ -3049,7 +3092,7 @@ def test_get_imported_class_definitions_handles_multiple_imports(tmp_path: Path)
|
||||||
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||||
|
|
||||||
# Create a module with multiple class definitions
|
# Create a module with multiple class definitions
|
||||||
types_code = '''
|
types_code = """
|
||||||
class TypeA:
|
class TypeA:
|
||||||
def __init__(self, value: int):
|
def __init__(self, value: int):
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
@ -3061,24 +3104,22 @@ class TypeB:
|
||||||
class TypeC:
|
class TypeC:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
'''
|
"""
|
||||||
types_path = package_dir / "types.py"
|
types_path = package_dir / "types.py"
|
||||||
types_path.write_text(types_code, encoding="utf-8")
|
types_path.write_text(types_code, encoding="utf-8")
|
||||||
|
|
||||||
# Create code that imports multiple classes
|
# Create code that imports multiple classes
|
||||||
code = '''
|
code = """
|
||||||
from mypackage.types import TypeA, TypeB
|
from mypackage.types import TypeA, TypeB
|
||||||
|
|
||||||
class Processor:
|
class Processor:
|
||||||
def process(self, a: TypeA, b: TypeB):
|
def process(self, a: TypeA, b: TypeB):
|
||||||
pass
|
pass
|
||||||
'''
|
"""
|
||||||
code_path = package_dir / "processor.py"
|
code_path = package_dir / "processor.py"
|
||||||
code_path.write_text(code, encoding="utf-8")
|
code_path.write_text(code, encoding="utf-8")
|
||||||
|
|
||||||
context = CodeStringsMarkdown(
|
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||||
code_strings=[CodeString(code=code, file_path=code_path)]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Call get_imported_class_definitions
|
# Call get_imported_class_definitions
|
||||||
result = get_imported_class_definitions(context, tmp_path)
|
result = get_imported_class_definitions(context, tmp_path)
|
||||||
|
|
@ -3100,7 +3141,7 @@ def test_get_imported_class_definitions_includes_dataclass_decorators(tmp_path:
|
||||||
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||||
|
|
||||||
# Create a module with dataclass definitions (like LLMConfig in skyvern)
|
# Create a module with dataclass definitions (like LLMConfig in skyvern)
|
||||||
models_code = '''from dataclasses import dataclass, field
|
models_code = """from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
|
|
@ -3114,23 +3155,21 @@ class LLMConfigBase:
|
||||||
class LLMConfig(LLMConfigBase):
|
class LLMConfig(LLMConfigBase):
|
||||||
litellm_params: Optional[dict] = field(default=None)
|
litellm_params: Optional[dict] = field(default=None)
|
||||||
max_tokens: int | None = None
|
max_tokens: int | None = None
|
||||||
'''
|
"""
|
||||||
models_path = package_dir / "models.py"
|
models_path = package_dir / "models.py"
|
||||||
models_path.write_text(models_code, encoding="utf-8")
|
models_path.write_text(models_code, encoding="utf-8")
|
||||||
|
|
||||||
# Create code that imports the dataclass
|
# Create code that imports the dataclass
|
||||||
code = '''from mypackage.models import LLMConfig
|
code = """from mypackage.models import LLMConfig
|
||||||
|
|
||||||
class ConfigRegistry:
|
class ConfigRegistry:
|
||||||
def get_config(self) -> LLMConfig:
|
def get_config(self) -> LLMConfig:
|
||||||
pass
|
pass
|
||||||
'''
|
"""
|
||||||
code_path = package_dir / "registry.py"
|
code_path = package_dir / "registry.py"
|
||||||
code_path.write_text(code, encoding="utf-8")
|
code_path.write_text(code, encoding="utf-8")
|
||||||
|
|
||||||
context = CodeStringsMarkdown(
|
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||||
code_strings=[CodeString(code=code, file_path=code_path)]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Call get_imported_class_definitions
|
# Call get_imported_class_definitions
|
||||||
result = get_imported_class_definitions(context, tmp_path)
|
result = get_imported_class_definitions(context, tmp_path)
|
||||||
|
|
@ -3165,7 +3204,7 @@ def test_get_imported_class_definitions_extracts_imports_for_decorated_classes(t
|
||||||
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||||
|
|
||||||
# Create a module with decorated class that uses field() and various type annotations
|
# Create a module with decorated class that uses field() and various type annotations
|
||||||
models_code = '''from dataclasses import dataclass, field
|
models_code = """from dataclasses import dataclass, field
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -3173,22 +3212,20 @@ class Config:
|
||||||
name: str
|
name: str
|
||||||
values: List[int] = field(default_factory=list)
|
values: List[int] = field(default_factory=list)
|
||||||
description: Optional[str] = None
|
description: Optional[str] = None
|
||||||
'''
|
"""
|
||||||
models_path = package_dir / "models.py"
|
models_path = package_dir / "models.py"
|
||||||
models_path.write_text(models_code, encoding="utf-8")
|
models_path.write_text(models_code, encoding="utf-8")
|
||||||
|
|
||||||
# Create code that imports the class
|
# Create code that imports the class
|
||||||
code = '''from mypackage.models import Config
|
code = """from mypackage.models import Config
|
||||||
|
|
||||||
def create_config() -> Config:
|
def create_config() -> Config:
|
||||||
return Config(name="test")
|
return Config(name="test")
|
||||||
'''
|
"""
|
||||||
code_path = package_dir / "main.py"
|
code_path = package_dir / "main.py"
|
||||||
code_path.write_text(code, encoding="utf-8")
|
code_path.write_text(code, encoding="utf-8")
|
||||||
|
|
||||||
context = CodeStringsMarkdown(
|
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||||
code_strings=[CodeString(code=code, file_path=code_path)]
|
|
||||||
)
|
|
||||||
|
|
||||||
result = get_imported_class_definitions(context, tmp_path)
|
result = get_imported_class_definitions(context, tmp_path)
|
||||||
|
|
||||||
|
|
@ -3282,12 +3319,12 @@ class TestExtractImportsForClass:
|
||||||
"""Test that base class imports are extracted."""
|
"""Test that base class imports are extracted."""
|
||||||
import ast
|
import ast
|
||||||
|
|
||||||
module_source = '''from abc import ABC
|
module_source = """from abc import ABC
|
||||||
from mypackage import BaseClass
|
from mypackage import BaseClass
|
||||||
|
|
||||||
class MyClass(BaseClass, ABC):
|
class MyClass(BaseClass, ABC):
|
||||||
pass
|
pass
|
||||||
'''
|
"""
|
||||||
tree = ast.parse(module_source)
|
tree = ast.parse(module_source)
|
||||||
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
|
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
|
||||||
result = extract_imports_for_class(tree, class_node, module_source)
|
result = extract_imports_for_class(tree, class_node, module_source)
|
||||||
|
|
@ -3298,13 +3335,13 @@ class MyClass(BaseClass, ABC):
|
||||||
"""Test that decorator imports are extracted."""
|
"""Test that decorator imports are extracted."""
|
||||||
import ast
|
import ast
|
||||||
|
|
||||||
module_source = '''from dataclasses import dataclass
|
module_source = """from dataclasses import dataclass
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MyClass:
|
class MyClass:
|
||||||
name: str
|
name: str
|
||||||
'''
|
"""
|
||||||
tree = ast.parse(module_source)
|
tree = ast.parse(module_source)
|
||||||
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
|
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
|
||||||
result = extract_imports_for_class(tree, class_node, module_source)
|
result = extract_imports_for_class(tree, class_node, module_source)
|
||||||
|
|
@ -3314,14 +3351,14 @@ class MyClass:
|
||||||
"""Test that type annotation imports are extracted."""
|
"""Test that type annotation imports are extracted."""
|
||||||
import ast
|
import ast
|
||||||
|
|
||||||
module_source = '''from typing import Optional, List
|
module_source = """from typing import Optional, List
|
||||||
from mypackage.models import Config
|
from mypackage.models import Config
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MyClass:
|
class MyClass:
|
||||||
config: Optional[Config]
|
config: Optional[Config]
|
||||||
items: List[str]
|
items: List[str]
|
||||||
'''
|
"""
|
||||||
tree = ast.parse(module_source)
|
tree = ast.parse(module_source)
|
||||||
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
|
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
|
||||||
result = extract_imports_for_class(tree, class_node, module_source)
|
result = extract_imports_for_class(tree, class_node, module_source)
|
||||||
|
|
@ -3332,13 +3369,13 @@ class MyClass:
|
||||||
"""Test that field() function imports are extracted for dataclasses."""
|
"""Test that field() function imports are extracted for dataclasses."""
|
||||||
import ast
|
import ast
|
||||||
|
|
||||||
module_source = '''from dataclasses import dataclass, field
|
module_source = """from dataclasses import dataclass, field
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MyClass:
|
class MyClass:
|
||||||
items: List[str] = field(default_factory=list)
|
items: List[str] = field(default_factory=list)
|
||||||
'''
|
"""
|
||||||
tree = ast.parse(module_source)
|
tree = ast.parse(module_source)
|
||||||
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
|
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
|
||||||
result = extract_imports_for_class(tree, class_node, module_source)
|
result = extract_imports_for_class(tree, class_node, module_source)
|
||||||
|
|
@ -3348,13 +3385,13 @@ class MyClass:
|
||||||
"""Test that duplicate imports are not included."""
|
"""Test that duplicate imports are not included."""
|
||||||
import ast
|
import ast
|
||||||
|
|
||||||
module_source = '''from typing import Optional
|
module_source = """from typing import Optional
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MyClass:
|
class MyClass:
|
||||||
field1: Optional[str]
|
field1: Optional[str]
|
||||||
field2: Optional[int]
|
field2: Optional[int]
|
||||||
'''
|
"""
|
||||||
tree = ast.parse(module_source)
|
tree = ast.parse(module_source)
|
||||||
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
|
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
|
||||||
result = extract_imports_for_class(tree, class_node, module_source)
|
result = extract_imports_for_class(tree, class_node, module_source)
|
||||||
|
|
@ -3368,7 +3405,7 @@ def test_get_imported_class_definitions_multiple_decorators(tmp_path: Path) -> N
|
||||||
package_dir.mkdir()
|
package_dir.mkdir()
|
||||||
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||||
|
|
||||||
models_code = '''from dataclasses import dataclass
|
models_code = """from dataclasses import dataclass
|
||||||
from functools import total_ordering
|
from functools import total_ordering
|
||||||
|
|
||||||
@total_ordering
|
@total_ordering
|
||||||
|
|
@ -3379,21 +3416,19 @@ class OrderedConfig:
|
||||||
|
|
||||||
def __lt__(self, other):
|
def __lt__(self, other):
|
||||||
return self.priority < other.priority
|
return self.priority < other.priority
|
||||||
'''
|
"""
|
||||||
models_path = package_dir / "models.py"
|
models_path = package_dir / "models.py"
|
||||||
models_path.write_text(models_code, encoding="utf-8")
|
models_path.write_text(models_code, encoding="utf-8")
|
||||||
|
|
||||||
code = '''from mypackage.models import OrderedConfig
|
code = """from mypackage.models import OrderedConfig
|
||||||
|
|
||||||
def sort_configs(configs: list[OrderedConfig]) -> list[OrderedConfig]:
|
def sort_configs(configs: list[OrderedConfig]) -> list[OrderedConfig]:
|
||||||
return sorted(configs)
|
return sorted(configs)
|
||||||
'''
|
"""
|
||||||
code_path = package_dir / "main.py"
|
code_path = package_dir / "main.py"
|
||||||
code_path.write_text(code, encoding="utf-8")
|
code_path.write_text(code, encoding="utf-8")
|
||||||
|
|
||||||
context = CodeStringsMarkdown(
|
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||||
code_strings=[CodeString(code=code, file_path=code_path)]
|
|
||||||
)
|
|
||||||
|
|
||||||
result = get_imported_class_definitions(context, tmp_path)
|
result = get_imported_class_definitions(context, tmp_path)
|
||||||
|
|
||||||
|
|
@ -3452,7 +3487,7 @@ class RouterConfig(ParentConfig):
|
||||||
models_path.write_text(models_code, encoding="utf-8")
|
models_path.write_text(models_code, encoding="utf-8")
|
||||||
|
|
||||||
# Create code that imports only the child classes (not the base classes)
|
# Create code that imports only the child classes (not the base classes)
|
||||||
code = '''from mypackage.models import ChildConfig, RouterConfig
|
code = """from mypackage.models import ChildConfig, RouterConfig
|
||||||
|
|
||||||
class ConfigRegistry:
|
class ConfigRegistry:
|
||||||
def get_child_config(self) -> ChildConfig:
|
def get_child_config(self) -> ChildConfig:
|
||||||
|
|
@ -3460,7 +3495,7 @@ class ConfigRegistry:
|
||||||
|
|
||||||
def get_router_config(self) -> RouterConfig:
|
def get_router_config(self) -> RouterConfig:
|
||||||
pass
|
pass
|
||||||
'''
|
"""
|
||||||
code_path = package_dir / "registry.py"
|
code_path = package_dir / "registry.py"
|
||||||
code_path.write_text(code, encoding="utf-8")
|
code_path.write_text(code, encoding="utf-8")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue