mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
decorators & annotations
This commit is contained in:
parent
2ed55ed3bb
commit
722d05345d
2 changed files with 364 additions and 8 deletions
|
|
@ -603,12 +603,16 @@ def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_ro
|
|||
|
||||
for node in ast.walk(module_tree):
|
||||
if isinstance(node, ast.ClassDef) and node.name == name:
|
||||
# Extract the class source code
|
||||
# Extract the class source code, including decorators
|
||||
lines = module_source.split("\n")
|
||||
class_source = "\n".join(lines[node.lineno - 1 : node.end_lineno])
|
||||
# Decorators start before the class line, use first decorator line if present
|
||||
start_line = node.lineno
|
||||
if node.decorator_list:
|
||||
start_line = min(d.lineno for d in node.decorator_list)
|
||||
class_source = "\n".join(lines[start_line - 1 : node.end_lineno])
|
||||
|
||||
# Also extract any necessary imports for the class (base classes, type hints)
|
||||
class_imports = _extract_imports_for_class(module_tree, node, module_source)
|
||||
class_imports = extract_imports_for_class(module_tree, node, module_source)
|
||||
|
||||
full_source = class_imports + "\n\n" + class_source if class_imports else class_source
|
||||
|
||||
|
|
@ -623,10 +627,10 @@ def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_ro
|
|||
return CodeStringsMarkdown(code_strings=class_code_strings)
|
||||
|
||||
|
||||
def _extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str:
|
||||
def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str:
|
||||
"""Extract import statements needed for a class definition.
|
||||
|
||||
This extracts imports for base classes and commonly used type annotations.
|
||||
This extracts imports for base classes, decorators, and type annotations.
|
||||
"""
|
||||
needed_names: set[str] = set()
|
||||
|
||||
|
|
@ -638,27 +642,65 @@ def _extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef
|
|||
# For things like abc.ABC, we need the module name
|
||||
needed_names.add(base.value.id)
|
||||
|
||||
# Get decorator names (e.g., dataclass, field)
|
||||
for decorator in class_node.decorator_list:
|
||||
if isinstance(decorator, ast.Name):
|
||||
needed_names.add(decorator.id)
|
||||
elif isinstance(decorator, ast.Call):
|
||||
if isinstance(decorator.func, ast.Name):
|
||||
needed_names.add(decorator.func.id)
|
||||
elif isinstance(decorator.func, ast.Attribute) and isinstance(decorator.func.value, ast.Name):
|
||||
needed_names.add(decorator.func.value.id)
|
||||
|
||||
# Get type annotation names from class body (for dataclass fields)
|
||||
for item in ast.walk(class_node):
|
||||
if isinstance(item, ast.AnnAssign) and item.annotation:
|
||||
collect_names_from_annotation(item.annotation, needed_names)
|
||||
# Also check for field() calls which are common in dataclasses
|
||||
if isinstance(item, ast.Call) and isinstance(item.func, ast.Name):
|
||||
needed_names.add(item.func.id)
|
||||
|
||||
# Find imports that provide these names
|
||||
import_lines: list[str] = []
|
||||
source_lines = module_source.split("\n")
|
||||
added_imports: set[int] = set() # Track line numbers to avoid duplicates
|
||||
|
||||
for node in module_tree.body:
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
name = alias.asname if alias.asname else alias.name.split(".")[0]
|
||||
if name in needed_names:
|
||||
if name in needed_names and node.lineno not in added_imports:
|
||||
import_lines.append(source_lines[node.lineno - 1])
|
||||
added_imports.add(node.lineno)
|
||||
break
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
for alias in node.names:
|
||||
name = alias.asname if alias.asname else alias.name
|
||||
if name in needed_names:
|
||||
if name in needed_names and node.lineno not in added_imports:
|
||||
import_lines.append(source_lines[node.lineno - 1])
|
||||
added_imports.add(node.lineno)
|
||||
break
|
||||
|
||||
return "\n".join(import_lines)
|
||||
|
||||
|
||||
def collect_names_from_annotation(node: ast.expr, names: set[str]) -> None:
|
||||
"""Recursively collect type annotation names from an AST node."""
|
||||
if isinstance(node, ast.Name):
|
||||
names.add(node.id)
|
||||
elif isinstance(node, ast.Subscript):
|
||||
collect_names_from_annotation(node.value, names)
|
||||
collect_names_from_annotation(node.slice, names)
|
||||
elif isinstance(node, ast.Tuple):
|
||||
for elt in node.elts:
|
||||
collect_names_from_annotation(elt, names)
|
||||
elif isinstance(node, ast.BinOp): # For Union types with | syntax
|
||||
collect_names_from_annotation(node.left, names)
|
||||
collect_names_from_annotation(node.right, names)
|
||||
elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
|
||||
names.add(node.value.id)
|
||||
|
||||
|
||||
def is_dunder_method(name: str) -> bool:
|
||||
return len(name) > 4 and name.isascii() and name.startswith("__") and name.endswith("__")
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,12 @@ from collections import defaultdict
|
|||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from codeflash.context.code_context_extractor import get_code_optimization_context, get_imported_class_definitions
|
||||
from codeflash.context.code_context_extractor import (
|
||||
get_code_optimization_context,
|
||||
get_imported_class_definitions,
|
||||
collect_names_from_annotation,
|
||||
extract_imports_for_class,
|
||||
)
|
||||
from codeflash.models.models import CodeString, CodeStringsMarkdown
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
|
@ -3085,3 +3090,312 @@ class Processor:
|
|||
assert "class TypeA" in all_extracted_code, "Should contain TypeA class"
|
||||
assert "class TypeB" in all_extracted_code, "Should contain TypeB class"
|
||||
assert "class TypeC" not in all_extracted_code, "Should NOT contain TypeC (not imported)"
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_includes_dataclass_decorators(tmp_path: Path) -> None:
|
||||
"""Test that get_imported_class_definitions includes decorators when extracting dataclasses."""
|
||||
# Create a package structure
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||
|
||||
# Create a module with dataclass definitions (like LLMConfig in skyvern)
|
||||
models_code = '''from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LLMConfigBase:
|
||||
model_name: str
|
||||
required_env_vars: list[str]
|
||||
supports_vision: bool
|
||||
add_assistant_prefix: bool
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LLMConfig(LLMConfigBase):
|
||||
litellm_params: Optional[dict] = field(default=None)
|
||||
max_tokens: int | None = None
|
||||
'''
|
||||
models_path = package_dir / "models.py"
|
||||
models_path.write_text(models_code, encoding="utf-8")
|
||||
|
||||
# Create code that imports the dataclass
|
||||
code = '''from mypackage.models import LLMConfig
|
||||
|
||||
class ConfigRegistry:
|
||||
def get_config(self) -> LLMConfig:
|
||||
pass
|
||||
'''
|
||||
code_path = package_dir / "registry.py"
|
||||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(
|
||||
code_strings=[CodeString(code=code, file_path=code_path)]
|
||||
)
|
||||
|
||||
# Call get_imported_class_definitions
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
|
||||
# Should extract LLMConfig
|
||||
assert len(result.code_strings) == 1, "Should extract exactly one class (LLMConfig)"
|
||||
extracted_code = result.code_strings[0].code
|
||||
|
||||
# Verify the extracted code includes the @dataclass decorator
|
||||
assert "@dataclass(frozen=True)" in extracted_code, (
|
||||
"Should include @dataclass decorator - this is critical for LLM to understand constructor"
|
||||
)
|
||||
assert "class LLMConfig" in extracted_code, "Should contain LLMConfig class definition"
|
||||
|
||||
# Verify imports are included for dataclass-related items
|
||||
assert "from dataclasses import" in extracted_code, "Should include dataclasses import"
|
||||
assert "Optional" in extracted_code or "from typing import" in extracted_code, (
|
||||
"Should include type annotation imports"
|
||||
)
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_extracts_imports_for_decorated_classes(tmp_path: Path) -> None:
|
||||
"""Test that extract_imports_for_class includes decorator and type annotation imports."""
|
||||
# Create a package structure
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||
|
||||
# Create a module with decorated class that uses field() and various type annotations
|
||||
models_code = '''from dataclasses import dataclass, field
|
||||
from typing import Optional, List
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
name: str
|
||||
values: List[int] = field(default_factory=list)
|
||||
description: Optional[str] = None
|
||||
'''
|
||||
models_path = package_dir / "models.py"
|
||||
models_path.write_text(models_code, encoding="utf-8")
|
||||
|
||||
# Create code that imports the class
|
||||
code = '''from mypackage.models import Config
|
||||
|
||||
def create_config() -> Config:
|
||||
return Config(name="test")
|
||||
'''
|
||||
code_path = package_dir / "main.py"
|
||||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(
|
||||
code_strings=[CodeString(code=code, file_path=code_path)]
|
||||
)
|
||||
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
|
||||
assert len(result.code_strings) == 1, "Should extract Config class"
|
||||
extracted_code = result.code_strings[0].code
|
||||
|
||||
# The extracted code should include the decorator
|
||||
assert "@dataclass" in extracted_code, "Should include @dataclass decorator"
|
||||
# The imports should include dataclass and field
|
||||
assert "from dataclasses import" in extracted_code, "Should include dataclasses import for decorator"
|
||||
|
||||
|
||||
class TestCollectNamesFromAnnotation:
|
||||
"""Tests for the collect_names_from_annotation helper function."""
|
||||
|
||||
def test_simple_name(self):
|
||||
"""Test extracting a simple type name."""
|
||||
import ast
|
||||
|
||||
code = "def f(x: MyClass): pass"
|
||||
annotation = ast.parse(code).body[0].args.args[0].annotation
|
||||
names: set[str] = set()
|
||||
collect_names_from_annotation(annotation, names)
|
||||
assert "MyClass" in names
|
||||
|
||||
def test_subscript_type(self):
|
||||
"""Test extracting names from generic types like List[int]."""
|
||||
import ast
|
||||
|
||||
code = "def f(x: List[int]): pass"
|
||||
annotation = ast.parse(code).body[0].args.args[0].annotation
|
||||
names: set[str] = set()
|
||||
collect_names_from_annotation(annotation, names)
|
||||
assert "List" in names
|
||||
assert "int" in names
|
||||
|
||||
def test_optional_type(self):
|
||||
"""Test extracting names from Optional[MyClass]."""
|
||||
import ast
|
||||
|
||||
code = "def f(x: Optional[MyClass]): pass"
|
||||
annotation = ast.parse(code).body[0].args.args[0].annotation
|
||||
names: set[str] = set()
|
||||
collect_names_from_annotation(annotation, names)
|
||||
assert "Optional" in names
|
||||
assert "MyClass" in names
|
||||
|
||||
def test_union_type_with_pipe(self):
|
||||
"""Test extracting names from union types with | syntax."""
|
||||
import ast
|
||||
|
||||
code = "def f(x: int | str | None): pass"
|
||||
annotation = ast.parse(code).body[0].args.args[0].annotation
|
||||
names: set[str] = set()
|
||||
collect_names_from_annotation(annotation, names)
|
||||
# int | str | None becomes BinOp nodes
|
||||
assert "int" in names
|
||||
assert "str" in names
|
||||
|
||||
def test_nested_generic_types(self):
|
||||
"""Test extracting names from nested generics like Dict[str, List[MyClass]]."""
|
||||
import ast
|
||||
|
||||
code = "def f(x: Dict[str, List[MyClass]]): pass"
|
||||
annotation = ast.parse(code).body[0].args.args[0].annotation
|
||||
names: set[str] = set()
|
||||
collect_names_from_annotation(annotation, names)
|
||||
assert "Dict" in names
|
||||
assert "str" in names
|
||||
assert "List" in names
|
||||
assert "MyClass" in names
|
||||
|
||||
def test_tuple_annotation(self):
|
||||
"""Test extracting names from tuple type hints."""
|
||||
import ast
|
||||
|
||||
code = "def f(x: tuple[int, str, MyClass]): pass"
|
||||
annotation = ast.parse(code).body[0].args.args[0].annotation
|
||||
names: set[str] = set()
|
||||
collect_names_from_annotation(annotation, names)
|
||||
assert "tuple" in names
|
||||
assert "int" in names
|
||||
assert "str" in names
|
||||
assert "MyClass" in names
|
||||
|
||||
|
||||
class TestExtractImportsForClass:
|
||||
"""Tests for the extract_imports_for_class helper function."""
|
||||
|
||||
def test_extracts_base_class_imports(self):
|
||||
"""Test that base class imports are extracted."""
|
||||
import ast
|
||||
|
||||
module_source = '''from abc import ABC
|
||||
from mypackage import BaseClass
|
||||
|
||||
class MyClass(BaseClass, ABC):
|
||||
pass
|
||||
'''
|
||||
tree = ast.parse(module_source)
|
||||
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
|
||||
result = extract_imports_for_class(tree, class_node, module_source)
|
||||
assert "from abc import ABC" in result
|
||||
assert "from mypackage import BaseClass" in result
|
||||
|
||||
def test_extracts_decorator_imports(self):
|
||||
"""Test that decorator imports are extracted."""
|
||||
import ast
|
||||
|
||||
module_source = '''from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
|
||||
@dataclass
|
||||
class MyClass:
|
||||
name: str
|
||||
'''
|
||||
tree = ast.parse(module_source)
|
||||
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
|
||||
result = extract_imports_for_class(tree, class_node, module_source)
|
||||
assert "from dataclasses import dataclass" in result
|
||||
|
||||
def test_extracts_type_annotation_imports(self):
|
||||
"""Test that type annotation imports are extracted."""
|
||||
import ast
|
||||
|
||||
module_source = '''from typing import Optional, List
|
||||
from mypackage.models import Config
|
||||
|
||||
@dataclass
|
||||
class MyClass:
|
||||
config: Optional[Config]
|
||||
items: List[str]
|
||||
'''
|
||||
tree = ast.parse(module_source)
|
||||
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
|
||||
result = extract_imports_for_class(tree, class_node, module_source)
|
||||
assert "from typing import Optional, List" in result
|
||||
assert "from mypackage.models import Config" in result
|
||||
|
||||
def test_extracts_field_function_imports(self):
|
||||
"""Test that field() function imports are extracted for dataclasses."""
|
||||
import ast
|
||||
|
||||
module_source = '''from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
@dataclass
|
||||
class MyClass:
|
||||
items: List[str] = field(default_factory=list)
|
||||
'''
|
||||
tree = ast.parse(module_source)
|
||||
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
|
||||
result = extract_imports_for_class(tree, class_node, module_source)
|
||||
assert "from dataclasses import dataclass, field" in result
|
||||
|
||||
def test_no_duplicate_imports(self):
|
||||
"""Test that duplicate imports are not included."""
|
||||
import ast
|
||||
|
||||
module_source = '''from typing import Optional
|
||||
|
||||
@dataclass
|
||||
class MyClass:
|
||||
field1: Optional[str]
|
||||
field2: Optional[int]
|
||||
'''
|
||||
tree = ast.parse(module_source)
|
||||
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
|
||||
result = extract_imports_for_class(tree, class_node, module_source)
|
||||
# Should only have one import line even though Optional is used twice
|
||||
assert result.count("from typing import Optional") == 1
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_multiple_decorators(tmp_path: Path) -> None:
|
||||
"""Test that classes with multiple decorators are extracted correctly."""
|
||||
package_dir = tmp_path / "mypackage"
|
||||
package_dir.mkdir()
|
||||
(package_dir / "__init__.py").write_text("", encoding="utf-8")
|
||||
|
||||
models_code = '''from dataclasses import dataclass
|
||||
from functools import total_ordering
|
||||
|
||||
@total_ordering
|
||||
@dataclass
|
||||
class OrderedConfig:
|
||||
name: str
|
||||
priority: int
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.priority < other.priority
|
||||
'''
|
||||
models_path = package_dir / "models.py"
|
||||
models_path.write_text(models_code, encoding="utf-8")
|
||||
|
||||
code = '''from mypackage.models import OrderedConfig
|
||||
|
||||
def sort_configs(configs: list[OrderedConfig]) -> list[OrderedConfig]:
|
||||
return sorted(configs)
|
||||
'''
|
||||
code_path = package_dir / "main.py"
|
||||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(
|
||||
code_strings=[CodeString(code=code, file_path=code_path)]
|
||||
)
|
||||
|
||||
result = get_imported_class_definitions(context, tmp_path)
|
||||
|
||||
assert len(result.code_strings) == 1
|
||||
extracted_code = result.code_strings[0].code
|
||||
|
||||
# Both decorators should be included
|
||||
assert "@total_ordering" in extracted_code, "Should include @total_ordering decorator"
|
||||
assert "@dataclass" in extracted_code, "Should include @dataclass decorator"
|
||||
assert "class OrderedConfig" in extracted_code
|
||||
|
|
|
|||
Loading…
Reference in a new issue