feat: extract imported class definitions for testgen context

When generating tests, the LLM now receives class definitions for
types imported from project modules. This helps the LLM understand:
- Constructor signatures (avoiding incorrect argument guessing)
- Base classes (e.g., abstract classes that can't be instantiated)
- Class structure for creating proper test instances

Previously, the LLM only saw import statements like:
  from mypackage.elements import Element

Now it also sees the actual class definition with constructor details.

Changes:
- Add get_imported_class_definitions() to extract class definitions
  from project modules referenced in import statements
- Integrate into get_code_optimization_context() to include extracted
  classes in testgen context
- Gracefully handle token limits by dropping class definitions if needed
- Add 4 unit tests covering extraction, deduplication, and filtering
This commit is contained in:
Kevin Turcios 2026-01-07 16:09:29 -05:00
parent 2fb90de7ca
commit 0312c37631
2 changed files with 374 additions and 2 deletions

View file

@ -127,9 +127,20 @@ def get_code_optimization_context(
remove_docstrings=False,
code_context_type=CodeContextType.TESTGEN,
)
# Extract class definitions for imported types from project modules
# This helps the LLM understand class constructors and structure
imported_class_context = get_imported_class_definitions(testgen_context, project_root_path)
if imported_class_context.code_strings:
# Merge imported class definitions into testgen context
testgen_context = CodeStringsMarkdown(
code_strings=testgen_context.code_strings + imported_class_context.code_strings
)
testgen_markdown_code = testgen_context.markdown
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
if testgen_code_token_length > testgen_token_limit:
# First try removing docstrings
testgen_context = extract_code_markdown_context_from_files(
helpers_of_fto_dict,
helpers_of_helpers_dict,
@ -137,10 +148,27 @@ def get_code_optimization_context(
remove_docstrings=True,
code_context_type=CodeContextType.TESTGEN,
)
# Re-extract imported classes (they may still fit)
imported_class_context = get_imported_class_definitions(testgen_context, project_root_path)
if imported_class_context.code_strings:
testgen_context = CodeStringsMarkdown(
code_strings=testgen_context.code_strings + imported_class_context.code_strings
)
testgen_markdown_code = testgen_context.markdown
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
if testgen_code_token_length > testgen_token_limit:
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
# If still over limit, try without imported class definitions
testgen_context = extract_code_markdown_context_from_files(
helpers_of_fto_dict,
helpers_of_helpers_dict,
project_root_path,
remove_docstrings=True,
code_context_type=CodeContextType.TESTGEN,
)
testgen_markdown_code = testgen_context.markdown
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
if testgen_code_token_length > testgen_token_limit:
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
code_hash_context = hashing_code_context.markdown
code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest()
@ -489,6 +517,158 @@ def get_function_sources_from_jedi(
return file_path_to_function_source, function_source_list
def get_imported_class_definitions(
    code_context: CodeStringsMarkdown,
    project_root_path: Path,
) -> CodeStringsMarkdown:
    """Extract class definitions for imported types from project modules.

    This function analyzes the imports in the extracted code context and fetches
    class definitions for any classes imported from project modules. This helps
    the LLM understand the actual class structure (constructors, methods, inheritance)
    rather than just seeing import statements.

    Args:
        code_context: The already extracted code context containing imports
        project_root_path: Root path of the project

    Returns:
        CodeStringsMarkdown containing class definitions from imported project modules
        (empty if the context does not parse or nothing qualifies).
    """
    import jedi

    # Concatenate all context code so imports and local definitions are analyzed together.
    all_code = "\n".join(cs.code for cs in code_context.code_strings)

    # The context may be partial/ill-formed; treat a parse failure as "nothing to extract".
    try:
        tree = ast.parse(all_code)
    except SyntaxError:
        return CodeStringsMarkdown(code_strings=[])

    # local name -> (source module, original class name).
    # Tracking the original name separately is what makes aliased imports work:
    # for `from x import Element as El` the class defined in module x is still
    # named "Element", so "Element" (not "El") must be searched for below.
    imported_names: dict[str, tuple[str, str]] = {}
    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom) and node.module:
            for alias in node.names:
                if alias.name != "*":
                    local_name = alias.asname if alias.asname else alias.name
                    imported_names[local_name] = (node.module, alias.name)
    if not imported_names:
        return CodeStringsMarkdown(code_strings=[])

    # Avoid emitting the same class twice: (file_path, original class name).
    extracted_classes: set[tuple[Path, str]] = set()
    # Names already defined in the context itself (e.g. the LLM redefined a class
    # locally) must not be shadowed by an extracted duplicate.
    existing_definitions: set[str] = {
        n.name for n in ast.walk(tree) if isinstance(n, ast.ClassDef)
    }

    class_code_strings: list[CodeString] = []
    for local_name, (module_name, original_name) in imported_names.items():
        # Skip if already defined in context
        if local_name in existing_definitions:
            continue
        try:
            # Resolve the module file with Jedi by "goto" on a synthetic import.
            test_code = f"import {module_name}"
            script = jedi.Script(test_code, project=jedi.Project(path=project_root_path))
            definitions = script.goto(1, len(test_code))
            if not definitions:
                continue
            module_path = definitions[0].module_path
            if not module_path:
                continue
            # Only project modules qualify: must live under the project root
            # and not inside site-packages (stdlib/third-party are excluded).
            if not str(module_path).startswith(str(project_root_path) + os.sep):
                continue
            if path_belongs_to_site_packages(module_path):
                continue
            if (module_path, original_name) in extracted_classes:
                continue
            # Parse the resolved module and slice out the class source verbatim.
            module_source = module_path.read_text(encoding="utf-8")
            module_tree = ast.parse(module_source)
            for node in ast.walk(module_tree):
                if isinstance(node, ast.ClassDef) and node.name == original_name:
                    lines = module_source.split("\n")
                    class_source = "\n".join(lines[node.lineno - 1 : node.end_lineno])
                    # Also pull in the imports the class itself needs (base classes).
                    class_imports = _extract_imports_for_class(module_tree, node, module_source)
                    full_source = (
                        class_imports + "\n\n" + class_source if class_imports else class_source
                    )
                    class_code_strings.append(
                        CodeString(
                            code=full_source,
                            file_path=module_path,
                        )
                    )
                    extracted_classes.add((module_path, original_name))
                    break
        except Exception:
            # Best-effort enrichment: any resolution/IO/parse failure for one
            # import must not abort extraction of the others.
            logger.debug(f"Error extracting class definition for {local_name} from {module_name}")
            continue
    return CodeStringsMarkdown(code_strings=class_code_strings)
def _extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str:
"""Extract import statements needed for a class definition.
This extracts imports for base classes and commonly used type annotations.
"""
needed_names: set[str] = set()
# Get base class names
for base in class_node.bases:
if isinstance(base, ast.Name):
needed_names.add(base.id)
elif isinstance(base, ast.Attribute):
# For things like abc.ABC, we need the module name
if isinstance(base.value, ast.Name):
needed_names.add(base.value.id)
# Find imports that provide these names
import_lines: list[str] = []
source_lines = module_source.split("\n")
for node in module_tree.body:
if isinstance(node, ast.Import):
for alias in node.names:
name = alias.asname if alias.asname else alias.name.split(".")[0]
if name in needed_names:
import_lines.append(source_lines[node.lineno - 1])
break
elif isinstance(node, ast.ImportFrom):
for alias in node.names:
name = alias.asname if alias.asname else alias.name
if name in needed_names:
import_lines.append(source_lines[node.lineno - 1])
break
return "\n".join(import_lines)
def is_dunder_method(name: str) -> bool:
    """Return True when *name* looks like a dunder (``__x__``): ASCII, longer than 4 chars."""
    is_wrapped = name.startswith("__") and name.endswith("__")
    return is_wrapped and name.isascii() and len(name) > 4

View file

@ -7,7 +7,8 @@ from collections import defaultdict
from pathlib import Path
import pytest
from codeflash.context.code_context_extractor import get_code_optimization_context
from codeflash.context.code_context_extractor import get_code_optimization_context, get_imported_class_definitions
from codeflash.models.models import CodeString, CodeStringsMarkdown
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent
from codeflash.optimization.optimizer import Optimizer
@ -2893,3 +2894,194 @@ def dump_layout(layout_type, layout):
assert testgen_context.count("def __init__") >= 2, (
"Both __init__ methods should be in testgen context"
)
def test_get_imported_class_definitions_extracts_project_classes(tmp_path: Path) -> None:
    """Test that get_imported_class_definitions extracts class definitions from project modules."""
    # Create a package structure with two modules
    package_dir = tmp_path / "mypackage"
    package_dir.mkdir()
    (package_dir / "__init__.py").write_text("", encoding="utf-8")

    # Create a module with a class definition (simulating Element-like class)
    elements_code = '''
import abc


class Element(abc.ABC):
    """An element in the document."""

    def __init__(self, element_id: str = None):
        self._element_id = element_id
        self.text = ""

    def __str__(self):
        return self.text


class Text(Element):
    """A text element."""

    def __init__(self, text: str, element_id: str = None):
        super().__init__(element_id)
        self.text = text
'''
    elements_path = package_dir / "elements.py"
    elements_path.write_text(elements_code, encoding="utf-8")

    # Create another module that imports from elements
    chunking_code = '''
from mypackage.elements import Element


class PreChunk:
    def __init__(self, elements: list[Element]):
        self._elements = elements


class Accumulator:
    def will_fit(self, chunk: PreChunk) -> bool:
        return True
'''
    chunking_path = package_dir / "chunking.py"
    chunking_path.write_text(chunking_code, encoding="utf-8")

    # Create CodeStringsMarkdown from the chunking module (simulating testgen context)
    context = CodeStringsMarkdown(
        code_strings=[CodeString(code=chunking_code, file_path=chunking_path)]
    )

    # Call get_imported_class_definitions
    result = get_imported_class_definitions(context, tmp_path)

    # Verify Element class was extracted (only Element is imported; Text is not)
    assert len(result.code_strings) == 1, "Should extract exactly one class (Element)"
    extracted_code = result.code_strings[0].code

    # Verify the extracted code contains the Element class
    assert "class Element" in extracted_code, "Should contain Element class definition"
    assert "def __init__" in extracted_code, "Should contain __init__ method"
    assert "element_id" in extracted_code, "Should contain constructor parameter"
    assert "import abc" in extracted_code, "Should include necessary imports for base class"
def test_get_imported_class_definitions_skips_existing_definitions(tmp_path: Path) -> None:
    """Test that get_imported_class_definitions skips classes already defined in context."""
    # Create a package structure
    package_dir = tmp_path / "mypackage"
    package_dir.mkdir()
    (package_dir / "__init__.py").write_text("", encoding="utf-8")

    # Create a module with a class definition
    elements_code = '''
class Element:
    def __init__(self, text: str):
        self.text = text
'''
    elements_path = package_dir / "elements.py"
    elements_path.write_text(elements_code, encoding="utf-8")

    # Create code that imports Element but also redefines it locally
    code_with_local_def = '''
from mypackage.elements import Element


# Local redefinition (this happens when LLM redefines classes)
class Element:
    def __init__(self, text: str):
        self.text = text


class User:
    def process(self, elem: Element):
        pass
'''
    code_path = package_dir / "user.py"
    code_path.write_text(code_with_local_def, encoding="utf-8")

    context = CodeStringsMarkdown(
        code_strings=[CodeString(code=code_with_local_def, file_path=code_path)]
    )

    # Call get_imported_class_definitions
    result = get_imported_class_definitions(context, tmp_path)

    # Should NOT extract Element since it's already defined locally
    assert len(result.code_strings) == 0, "Should not extract classes already defined in context"
def test_get_imported_class_definitions_skips_third_party(tmp_path: Path) -> None:
    """Test that get_imported_class_definitions skips third-party/stdlib imports."""
    # Create a simple package
    package_dir = tmp_path / "mypackage"
    package_dir.mkdir()
    (package_dir / "__init__.py").write_text("", encoding="utf-8")

    # Code with stdlib/third-party imports
    code = '''
from pathlib import Path
from typing import Optional
from dataclasses import dataclass


class MyClass:
    def __init__(self, path: Path):
        self.path = path
'''
    code_path = package_dir / "main.py"
    code_path.write_text(code, encoding="utf-8")

    context = CodeStringsMarkdown(
        code_strings=[CodeString(code=code, file_path=code_path)]
    )

    # Call get_imported_class_definitions
    result = get_imported_class_definitions(context, tmp_path)

    # Should not extract any classes (Path, Optional, dataclass are stdlib/third-party)
    assert len(result.code_strings) == 0, "Should not extract stdlib/third-party classes"
def test_get_imported_class_definitions_handles_multiple_imports(tmp_path: Path) -> None:
    """Test that get_imported_class_definitions handles multiple class imports."""
    # Create a package structure
    package_dir = tmp_path / "mypackage"
    package_dir.mkdir()
    (package_dir / "__init__.py").write_text("", encoding="utf-8")

    # Create a module with multiple class definitions
    types_code = '''
class TypeA:
    def __init__(self, value: int):
        self.value = value


class TypeB:
    def __init__(self, name: str):
        self.name = name


class TypeC:
    def __init__(self):
        pass
'''
    types_path = package_dir / "types.py"
    types_path.write_text(types_code, encoding="utf-8")

    # Create code that imports multiple classes (TypeC deliberately left out)
    code = '''
from mypackage.types import TypeA, TypeB


class Processor:
    def process(self, a: TypeA, b: TypeB):
        pass
'''
    code_path = package_dir / "processor.py"
    code_path.write_text(code, encoding="utf-8")

    context = CodeStringsMarkdown(
        code_strings=[CodeString(code=code, file_path=code_path)]
    )

    # Call get_imported_class_definitions
    result = get_imported_class_definitions(context, tmp_path)

    # Should extract both TypeA and TypeB (but not TypeC since it's not imported)
    assert len(result.code_strings) == 2, "Should extract exactly two classes (TypeA, TypeB)"
    all_extracted_code = "\n".join(cs.code for cs in result.code_strings)
    assert "class TypeA" in all_extracted_code, "Should contain TypeA class"
    assert "class TypeB" in all_extracted_code, "Should contain TypeB class"
    assert "class TypeC" not in all_extracted_code, "Should NOT contain TypeC (not imported)"