decorators & annotations

This commit is contained in:
Kevin Turcios 2026-01-23 06:52:45 -05:00
parent 2ed55ed3bb
commit 722d05345d
2 changed files with 364 additions and 8 deletions

View file

@ -603,12 +603,16 @@ def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_ro
for node in ast.walk(module_tree):
if isinstance(node, ast.ClassDef) and node.name == name:
# Extract the class source code
# Extract the class source code, including decorators
lines = module_source.split("\n")
class_source = "\n".join(lines[node.lineno - 1 : node.end_lineno])
# Decorators start before the class line, use first decorator line if present
start_line = node.lineno
if node.decorator_list:
start_line = min(d.lineno for d in node.decorator_list)
class_source = "\n".join(lines[start_line - 1 : node.end_lineno])
# Also extract any necessary imports for the class (base classes, type hints)
class_imports = _extract_imports_for_class(module_tree, node, module_source)
class_imports = extract_imports_for_class(module_tree, node, module_source)
full_source = class_imports + "\n\n" + class_source if class_imports else class_source
@ -623,10 +627,10 @@ def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_ro
return CodeStringsMarkdown(code_strings=class_code_strings)
def _extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str:
def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str:
"""Extract import statements needed for a class definition.
This extracts imports for base classes and commonly used type annotations.
This extracts imports for base classes, decorators, and type annotations.
"""
needed_names: set[str] = set()
@ -638,27 +642,65 @@ def _extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef
# For things like abc.ABC, we need the module name
needed_names.add(base.value.id)
# Get decorator names (e.g., dataclass, field)
for decorator in class_node.decorator_list:
if isinstance(decorator, ast.Name):
needed_names.add(decorator.id)
elif isinstance(decorator, ast.Call):
if isinstance(decorator.func, ast.Name):
needed_names.add(decorator.func.id)
elif isinstance(decorator.func, ast.Attribute) and isinstance(decorator.func.value, ast.Name):
needed_names.add(decorator.func.value.id)
# Get type annotation names from class body (for dataclass fields)
for item in ast.walk(class_node):
if isinstance(item, ast.AnnAssign) and item.annotation:
collect_names_from_annotation(item.annotation, needed_names)
# Also check for field() calls which are common in dataclasses
if isinstance(item, ast.Call) and isinstance(item.func, ast.Name):
needed_names.add(item.func.id)
# Find imports that provide these names
import_lines: list[str] = []
source_lines = module_source.split("\n")
added_imports: set[int] = set() # Track line numbers to avoid duplicates
for node in module_tree.body:
if isinstance(node, ast.Import):
for alias in node.names:
name = alias.asname if alias.asname else alias.name.split(".")[0]
if name in needed_names:
if name in needed_names and node.lineno not in added_imports:
import_lines.append(source_lines[node.lineno - 1])
added_imports.add(node.lineno)
break
elif isinstance(node, ast.ImportFrom):
for alias in node.names:
name = alias.asname if alias.asname else alias.name
if name in needed_names:
if name in needed_names and node.lineno not in added_imports:
import_lines.append(source_lines[node.lineno - 1])
added_imports.add(node.lineno)
break
return "\n".join(import_lines)
def collect_names_from_annotation(node: ast.expr, names: set[str]) -> None:
    """Recursively collect the top-level names referenced by a type annotation.

    The collected names are the identifiers that an ``import`` statement could
    bind, so for a dotted reference only the root is recorded
    (``pkg.models.Config`` -> ``pkg``).

    Handled shapes:
      * plain names (``MyClass``)
      * dotted attribute chains of any depth (``typing.Optional``, ``a.b.C``)
      * subscripts (``List[int]``, ``Dict[str, List[X]]``)
      * tuples and lists of arguments (``Callable[[A, B], C]``)
      * binary operators, used for ``X | Y`` unions
      * quoted forward references (``"Config"``, ``Optional["Config"]``)

    String arguments of ``Literal[...]`` are values, not types, so they are
    not parsed. Other string constants (e.g. ``Annotated`` metadata) that
    happen to be valid expressions are parsed and may contribute extra names;
    the result is a deliberate over-approximation.

    Args:
        node: The annotation expression to inspect.
        names: Set that is updated in place with every name found.
    """
    if isinstance(node, ast.Name):
        names.add(node.id)
    elif isinstance(node, ast.Attribute):
        # Only the root of a dotted chain can be provided by an import.
        root = node.value
        while isinstance(root, ast.Attribute):
            root = root.value
        if isinstance(root, ast.Name):
            names.add(root.id)
    elif isinstance(node, ast.Subscript):
        collect_names_from_annotation(node.value, names)
        target = node.value
        is_literal = (isinstance(target, ast.Name) and target.id == "Literal") or (
            isinstance(target, ast.Attribute) and target.attr == "Literal"
        )
        if is_literal:
            # Literal arguments are values; skip constants so "r"/"w" are not
            # mistaken for forward references, but keep e.g. Enum members.
            args = node.slice.elts if isinstance(node.slice, ast.Tuple) else [node.slice]
            for arg in args:
                if not isinstance(arg, ast.Constant):
                    collect_names_from_annotation(arg, names)
        else:
            collect_names_from_annotation(node.slice, names)
    elif isinstance(node, (ast.Tuple, ast.List)):
        # ast.List appears for Callable parameter lists: Callable[[A, B], C].
        for elt in node.elts:
            collect_names_from_annotation(elt, names)
    elif isinstance(node, ast.BinOp):  # For Union types with | syntax
        collect_names_from_annotation(node.left, names)
        collect_names_from_annotation(node.right, names)
    elif isinstance(node, ast.Constant) and isinstance(node.value, str):
        # Quoted forward reference: parse the string as an expression and
        # collect from it. Unparseable strings are simply not annotations.
        try:
            forward_ref = ast.parse(node.value, mode="eval")
        except (SyntaxError, ValueError):
            return
        collect_names_from_annotation(forward_ref.body, names)
def is_dunder_method(name: str) -> bool:
    """Return True when *name* is an ASCII identifier wrapped in double underscores.

    The name must be longer than four characters, so a bare ``"____"`` does
    not qualify.
    """
    if len(name) <= 4 or not name.isascii():
        return False
    return name[:2] == "__" and name[-2:] == "__"

View file

@ -7,7 +7,12 @@ from collections import defaultdict
from pathlib import Path
import pytest
from codeflash.context.code_context_extractor import get_code_optimization_context, get_imported_class_definitions
from codeflash.context.code_context_extractor import (
get_code_optimization_context,
get_imported_class_definitions,
collect_names_from_annotation,
extract_imports_for_class,
)
from codeflash.models.models import CodeString, CodeStringsMarkdown
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent
@ -3085,3 +3090,312 @@ class Processor:
assert "class TypeA" in all_extracted_code, "Should contain TypeA class"
assert "class TypeB" in all_extracted_code, "Should contain TypeB class"
assert "class TypeC" not in all_extracted_code, "Should NOT contain TypeC (not imported)"
def test_get_imported_class_definitions_includes_dataclass_decorators(tmp_path: Path) -> None:
    """Test that get_imported_class_definitions includes decorators when extracting dataclasses."""
    # Create a package structure
    package_dir = tmp_path / "mypackage"
    package_dir.mkdir()
    (package_dir / "__init__.py").write_text("", encoding="utf-8")

    # Create a module with dataclass definitions (like LLMConfig in skyvern)
    models_code = '''from dataclasses import dataclass, field
from typing import Optional

@dataclass(frozen=True)
class LLMConfigBase:
    model_name: str
    required_env_vars: list[str]
    supports_vision: bool
    add_assistant_prefix: bool

@dataclass(frozen=True)
class LLMConfig(LLMConfigBase):
    litellm_params: Optional[dict] = field(default=None)
    max_tokens: int | None = None
'''
    models_path = package_dir / "models.py"
    models_path.write_text(models_code, encoding="utf-8")

    # Create code that imports the dataclass
    code = '''from mypackage.models import LLMConfig

class ConfigRegistry:
    def get_config(self) -> LLMConfig:
        pass
'''
    code_path = package_dir / "registry.py"
    code_path.write_text(code, encoding="utf-8")

    context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])

    # Call get_imported_class_definitions
    result = get_imported_class_definitions(context, tmp_path)

    # Should extract LLMConfig
    assert len(result.code_strings) == 1, "Should extract exactly one class (LLMConfig)"
    extracted_code = result.code_strings[0].code

    # Verify the extracted code includes the @dataclass decorator
    assert "@dataclass(frozen=True)" in extracted_code, (
        "Should include @dataclass decorator - this is critical for LLM to understand constructor"
    )
    assert "class LLMConfig" in extracted_code, "Should contain LLMConfig class definition"

    # Verify imports are included for dataclass-related items
    assert "from dataclasses import" in extracted_code, "Should include dataclasses import"
    # Check the import line itself: the bare word "Optional" also occurs in the
    # class body (Optional[dict]), so testing for it alone could never fail.
    assert "from typing import Optional" in extracted_code, "Should include type annotation imports"
def test_get_imported_class_definitions_extracts_imports_for_decorated_classes(tmp_path: Path) -> None:
    """Test that extracted decorated classes carry their decorator, field() and type annotation imports."""
    # Create a package structure
    package_dir = tmp_path / "mypackage"
    package_dir.mkdir()
    (package_dir / "__init__.py").write_text("", encoding="utf-8")

    # Create a module with decorated class that uses field() and various type annotations
    models_code = '''from dataclasses import dataclass, field
from typing import Optional, List

@dataclass
class Config:
    name: str
    values: List[int] = field(default_factory=list)
    description: Optional[str] = None
'''
    models_path = package_dir / "models.py"
    models_path.write_text(models_code, encoding="utf-8")

    # Create code that imports the class
    code = '''from mypackage.models import Config

def create_config() -> Config:
    return Config(name="test")
'''
    code_path = package_dir / "main.py"
    code_path.write_text(code, encoding="utf-8")

    context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
    result = get_imported_class_definitions(context, tmp_path)

    assert len(result.code_strings) == 1, "Should extract Config class"
    extracted_code = result.code_strings[0].code

    # The extracted code should include the decorator
    assert "@dataclass" in extracted_code, "Should include @dataclass decorator"

    # The imports should include dataclass and field (both come from one import line)
    assert "from dataclasses import dataclass, field" in extracted_code, (
        "Should include dataclasses import for decorator and field()"
    )
    # The imports should include the names used in the field annotations
    assert "from typing import Optional, List" in extracted_code, "Should include type annotation imports"
class TestCollectNamesFromAnnotation:
    """Tests for the collect_names_from_annotation helper function."""

    @staticmethod
    def _names_in(annotation_source: str) -> set[str]:
        """Collect names from ``annotation_source`` used as a parameter annotation."""
        import ast

        func_def = ast.parse(f"def f(x: {annotation_source}): pass").body[0]
        found: set[str] = set()
        collect_names_from_annotation(func_def.args.args[0].annotation, found)
        return found

    def test_simple_name(self):
        """Test extracting a simple type name."""
        assert "MyClass" in self._names_in("MyClass")

    def test_subscript_type(self):
        """Test extracting names from generic types like List[int]."""
        found = self._names_in("List[int]")
        assert "List" in found
        assert "int" in found

    def test_optional_type(self):
        """Test extracting names from Optional[MyClass]."""
        found = self._names_in("Optional[MyClass]")
        assert "Optional" in found
        assert "MyClass" in found

    def test_union_type_with_pipe(self):
        """Test extracting names from union types with | syntax."""
        # int | str | None becomes BinOp nodes
        found = self._names_in("int | str | None")
        assert "int" in found
        assert "str" in found

    def test_nested_generic_types(self):
        """Test extracting names from nested generics like Dict[str, List[MyClass]]."""
        found = self._names_in("Dict[str, List[MyClass]]")
        assert "Dict" in found
        assert "str" in found
        assert "List" in found
        assert "MyClass" in found

    def test_tuple_annotation(self):
        """Test extracting names from tuple type hints."""
        found = self._names_in("tuple[int, str, MyClass]")
        assert "tuple" in found
        assert "int" in found
        assert "str" in found
        assert "MyClass" in found
class TestExtractImportsForClass:
    """Tests for the extract_imports_for_class helper function."""

    @staticmethod
    def _imports_for_first_class(module_source: str) -> str:
        """Run extract_imports_for_class on the first class found in ``module_source``."""
        import ast

        tree = ast.parse(module_source)
        class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
        return extract_imports_for_class(tree, class_node, module_source)

    def test_extracts_base_class_imports(self):
        """Test that base class imports are extracted."""
        module_source = '''from abc import ABC
from mypackage import BaseClass

class MyClass(BaseClass, ABC):
    pass
'''
        result = self._imports_for_first_class(module_source)
        assert "from abc import ABC" in result
        assert "from mypackage import BaseClass" in result

    def test_extracts_decorator_imports(self):
        """Test that decorator imports are extracted and unrelated imports are not."""
        module_source = '''from dataclasses import dataclass
from functools import lru_cache

@dataclass
class MyClass:
    name: str
'''
        result = self._imports_for_first_class(module_source)
        assert "from dataclasses import dataclass" in result
        # lru_cache is imported by the module but never used by the class.
        assert "lru_cache" not in result

    def test_extracts_type_annotation_imports(self):
        """Test that type annotation imports are extracted."""
        module_source = '''from typing import Optional, List
from mypackage.models import Config

@dataclass
class MyClass:
    config: Optional[Config]
    items: List[str]
'''
        result = self._imports_for_first_class(module_source)
        assert "from typing import Optional, List" in result
        assert "from mypackage.models import Config" in result

    def test_extracts_field_function_imports(self):
        """Test that field() function imports are extracted for dataclasses."""
        module_source = '''from dataclasses import dataclass, field
from typing import List

@dataclass
class MyClass:
    items: List[str] = field(default_factory=list)
'''
        result = self._imports_for_first_class(module_source)
        assert "from dataclasses import dataclass, field" in result

    def test_no_duplicate_imports(self):
        """Test that an import line is emitted once even when it provides several needed names."""
        module_source = '''from typing import Optional, List

@dataclass
class MyClass:
    field1: Optional[str]
    field2: Optional[int]
    field3: List[str]
'''
        result = self._imports_for_first_class(module_source)
        # Both Optional and List are needed and both come from the same line;
        # the line must still appear exactly once.
        assert result.count("from typing import Optional, List") == 1
        assert result.splitlines() == ["from typing import Optional, List"]
def test_get_imported_class_definitions_multiple_decorators(tmp_path: Path) -> None:
    """Test that classes with multiple decorators are extracted correctly."""
    package_dir = tmp_path / "mypackage"
    package_dir.mkdir()
    (package_dir / "__init__.py").write_text("", encoding="utf-8")

    models_code = '''from dataclasses import dataclass
from functools import total_ordering

@total_ordering
@dataclass
class OrderedConfig:
    name: str
    priority: int

    def __lt__(self, other):
        return self.priority < other.priority
'''
    models_path = package_dir / "models.py"
    models_path.write_text(models_code, encoding="utf-8")

    code = '''from mypackage.models import OrderedConfig

def sort_configs(configs: list[OrderedConfig]) -> list[OrderedConfig]:
    return sorted(configs)
'''
    code_path = package_dir / "main.py"
    code_path.write_text(code, encoding="utf-8")

    context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
    result = get_imported_class_definitions(context, tmp_path)

    assert len(result.code_strings) == 1
    extracted_code = result.code_strings[0].code

    # Both decorators should be included
    assert "@total_ordering" in extracted_code, "Should include @total_ordering decorator"
    assert "@dataclass" in extracted_code, "Should include @dataclass decorator"
    assert "class OrderedConfig" in extracted_code

    # The decorators must sit directly above the class, outermost first; this
    # fails if extraction starts at the wrong decorator line.
    assert "@total_ordering\n@dataclass\nclass OrderedConfig:" in extracted_code, (
        "Decorators should precede the class definition in source order"
    )
    # Each decorator needs its import to be understandable in isolation
    assert "from functools import total_ordering" in extracted_code, "Should include total_ordering import"
    assert "from dataclasses import dataclass" in extracted_code, "Should include dataclass import"