mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
feat: extract imported class definitions for testgen context
When generating tests, the LLM now receives class definitions for types imported from project modules. This helps the LLM understand: - Constructor signatures (avoiding incorrect argument guessing) - Base classes (e.g., abstract classes that can't be instantiated) - Class structure for creating proper test instances Previously, the LLM only saw import statements like: from mypackage.elements import Element Now it also sees the actual class definition with constructor details. Changes: - Add get_imported_class_definitions() to extract class definitions from project modules referenced in import statements - Integrate into get_code_optimization_context() to include extracted classes in testgen context - Gracefully handle token limits by dropping class definitions if needed - Add 4 unit tests covering extraction, deduplication, and filtering
This commit is contained in:
parent
2fb90de7ca
commit
0312c37631
2 changed files with 374 additions and 2 deletions
|
|
@ -127,9 +127,20 @@ def get_code_optimization_context(
|
|||
remove_docstrings=False,
|
||||
code_context_type=CodeContextType.TESTGEN,
|
||||
)
|
||||
|
||||
# Extract class definitions for imported types from project modules
|
||||
# This helps the LLM understand class constructors and structure
|
||||
imported_class_context = get_imported_class_definitions(testgen_context, project_root_path)
|
||||
if imported_class_context.code_strings:
|
||||
# Merge imported class definitions into testgen context
|
||||
testgen_context = CodeStringsMarkdown(
|
||||
code_strings=testgen_context.code_strings + imported_class_context.code_strings
|
||||
)
|
||||
|
||||
testgen_markdown_code = testgen_context.markdown
|
||||
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
|
||||
if testgen_code_token_length > testgen_token_limit:
|
||||
# First try removing docstrings
|
||||
testgen_context = extract_code_markdown_context_from_files(
|
||||
helpers_of_fto_dict,
|
||||
helpers_of_helpers_dict,
|
||||
|
|
@ -137,10 +148,27 @@ def get_code_optimization_context(
|
|||
remove_docstrings=True,
|
||||
code_context_type=CodeContextType.TESTGEN,
|
||||
)
|
||||
# Re-extract imported classes (they may still fit)
|
||||
imported_class_context = get_imported_class_definitions(testgen_context, project_root_path)
|
||||
if imported_class_context.code_strings:
|
||||
testgen_context = CodeStringsMarkdown(
|
||||
code_strings=testgen_context.code_strings + imported_class_context.code_strings
|
||||
)
|
||||
testgen_markdown_code = testgen_context.markdown
|
||||
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
|
||||
if testgen_code_token_length > testgen_token_limit:
|
||||
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
|
||||
# If still over limit, try without imported class definitions
|
||||
testgen_context = extract_code_markdown_context_from_files(
|
||||
helpers_of_fto_dict,
|
||||
helpers_of_helpers_dict,
|
||||
project_root_path,
|
||||
remove_docstrings=True,
|
||||
code_context_type=CodeContextType.TESTGEN,
|
||||
)
|
||||
testgen_markdown_code = testgen_context.markdown
|
||||
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
|
||||
if testgen_code_token_length > testgen_token_limit:
|
||||
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
|
||||
code_hash_context = hashing_code_context.markdown
|
||||
code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest()
|
||||
|
||||
|
|
@ -489,6 +517,158 @@ def get_function_sources_from_jedi(
|
|||
return file_path_to_function_source, function_source_list
|
||||
|
||||
|
||||
def get_imported_class_definitions(
    code_context: CodeStringsMarkdown,
    project_root_path: Path,
) -> CodeStringsMarkdown:
    """Extract class definitions for imported types from project modules.

    This function analyzes the imports in the extracted code context and fetches
    class definitions for any classes imported from project modules. This helps
    the LLM understand the actual class structure (constructors, methods, inheritance)
    rather than just seeing import statements.

    Args:
        code_context: The already extracted code context containing imports
        project_root_path: Root path of the project

    Returns:
        CodeStringsMarkdown containing class definitions from imported project modules

    """
    import jedi

    # Concatenate all context code so imports can be analyzed in a single parse.
    all_code = "\n".join(cs.code for cs in code_context.code_strings)

    try:
        tree = ast.parse(all_code)
    except SyntaxError:
        # Best-effort feature: an unparseable context simply contributes nothing.
        return CodeStringsMarkdown(code_strings=[])

    # Map each imported (possibly aliased) name to the module it came from.
    # NOTE(review): for relative imports only node.module is kept; resolution
    # below assumes absolute module names — confirm contexts never use them.
    imported_names: dict[str, str] = {}
    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom) and node.module:
            for alias in node.names:
                if alias.name != "*":
                    imported_name = alias.asname if alias.asname else alias.name
                    imported_names[imported_name] = node.module

    if not imported_names:
        return CodeStringsMarkdown(code_strings=[])

    # Classes already defined in the context must not be extracted again.
    existing_definitions: set[str] = {
        node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef)
    }

    # (file_path, class_name) pairs already emitted, to avoid duplicates.
    extracted_classes: set[tuple[Path, str]] = set()
    # Cache module_name -> (path, source, tree) so the expensive jedi lookup
    # and re-parse happen once per module, not once per imported name.
    module_cache: dict[str, tuple[Path, str, ast.Module] | None] = {}
    class_code_strings: list[CodeString] = []

    for name, module_name in imported_names.items():
        # Skip names the context already defines locally.
        if name in existing_definitions:
            continue

        try:
            if module_name not in module_cache:
                module_cache[module_name] = _resolve_project_module(jedi, module_name, project_root_path)
            resolved = module_cache[module_name]
            if resolved is None:
                # Unresolvable, stdlib, or third-party module: nothing to extract.
                continue
            module_path, module_source, module_tree = resolved

            # Skip if this exact class was already extracted.
            if (module_path, name) in extracted_classes:
                continue

            # Locate the class definition inside the module.
            for class_node in ast.walk(module_tree):
                if isinstance(class_node, ast.ClassDef) and class_node.name == name:
                    source_lines = module_source.split("\n")
                    # Include decorator lines (e.g. @dataclass): node.lineno points
                    # at the `class` keyword, which would silently drop decorators
                    # that change how the class is constructed.
                    first_line = class_node.lineno
                    if class_node.decorator_list:
                        first_line = min(first_line, class_node.decorator_list[0].lineno)
                    class_source = "\n".join(source_lines[first_line - 1 : class_node.end_lineno])

                    # Prepend any imports the class needs (base classes, etc.).
                    class_imports = _extract_imports_for_class(module_tree, class_node, module_source)
                    full_source = class_imports + "\n\n" + class_source if class_imports else class_source

                    class_code_strings.append(
                        CodeString(
                            code=full_source,
                            file_path=module_path,
                        )
                    )
                    extracted_classes.add((module_path, name))
                    break

        except Exception:
            # Best-effort: one failed extraction must not abort context generation.
            logger.debug("Error extracting class definition for %s from %s", name, module_name)
            continue

    return CodeStringsMarkdown(code_strings=class_code_strings)


def _resolve_project_module(jedi, module_name: str, project_root_path: Path):
    """Resolve *module_name* to ``(path, source, parsed_tree)`` if it is a project module.

    Returns ``None`` when the module cannot be resolved, has no file path, lives
    outside the project root, or belongs to site-packages (stdlib/third-party).
    May raise (e.g. ``SyntaxError``, ``OSError``); the caller treats any failure
    as a per-name best-effort miss.
    """
    # Ask jedi to resolve the module via a synthetic import statement.
    probe = f"import {module_name}"
    script = jedi.Script(probe, project=jedi.Project(path=project_root_path))
    definitions = script.goto(1, len(probe))
    if not definitions:
        return None

    module_path = definitions[0].module_path
    if not module_path:
        return None

    # Project modules only: must sit under the project root and must not be an
    # installed package (editable installs can place site-packages under root).
    if not str(module_path).startswith(str(project_root_path) + os.sep):
        return None
    if path_belongs_to_site_packages(module_path):
        return None

    module_source = module_path.read_text(encoding="utf-8")
    return module_path, module_source, ast.parse(module_source)
|
||||
|
||||
|
||||
def _extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str:
    """Extract the import statements that provide a class's base-class names.

    Scans the direct bases of *class_node* (plain names like ``Base`` and
    dotted attributes like ``abc.ABC``, for which the module name is needed)
    and returns the top-level import statements from *module_source* that
    bind those names, joined by newlines. Returns an empty string when the
    class has no imported bases.
    """
    needed_names: set[str] = set()

    # Collect the names the base-class expressions reference.
    for base in class_node.bases:
        if isinstance(base, ast.Name):
            needed_names.add(base.id)
        elif isinstance(base, ast.Attribute):
            # For things like abc.ABC, the imported name is the module (abc).
            if isinstance(base.value, ast.Name):
                needed_names.add(base.value.id)

    # Find top-level imports that provide these names.
    import_lines: list[str] = []
    source_lines = module_source.split("\n")

    def _statement_source(node: ast.stmt) -> str:
        # Capture the FULL statement: parenthesized imports span multiple
        # physical lines, and taking only lineno-1 would emit broken syntax.
        return "\n".join(source_lines[node.lineno - 1 : node.end_lineno])

    for node in module_tree.body:
        if isinstance(node, ast.Import):
            for alias in node.names:
                # `import a.b` binds the top-level name `a` (unless aliased).
                name = alias.asname if alias.asname else alias.name.split(".")[0]
                if name in needed_names:
                    import_lines.append(_statement_source(node))
                    break
        elif isinstance(node, ast.ImportFrom):
            for alias in node.names:
                name = alias.asname if alias.asname else alias.name
                if name in needed_names:
                    import_lines.append(_statement_source(node))
                    break

    return "\n".join(import_lines)
|
||||
|
||||
|
||||
def is_dunder_method(name: str) -> bool:
    """Return True if *name* is an ASCII dunder name (``__x__`` with a non-empty core)."""
    # A dunder needs at least one character between the double underscores,
    # so the minimum valid length is 5 (e.g. "__x__").
    if len(name) <= 4 or not name.isascii():
        return False
    return name.startswith("__") and name.endswith("__")
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,8 @@ from collections import defaultdict
|
|||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from codeflash.context.code_context_extractor import get_code_optimization_context
|
||||
from codeflash.context.code_context_extractor import get_code_optimization_context, get_imported_class_definitions
|
||||
from codeflash.models.models import CodeString, CodeStringsMarkdown
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
|
|
@ -2893,3 +2894,194 @@ def dump_layout(layout_type, layout):
|
|||
assert testgen_context.count("def __init__") >= 2, (
|
||||
"Both __init__ methods should be in testgen context"
|
||||
)
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_extracts_project_classes(tmp_path: Path) -> None:
    """Test that get_imported_class_definitions extracts class definitions from project modules."""
    # Build a small project package with two modules.
    pkg_dir = tmp_path / "mypackage"
    pkg_dir.mkdir()
    (pkg_dir / "__init__.py").write_text("", encoding="utf-8")

    # Module defining an Element-like class hierarchy.
    elements_code = '''
import abc

class Element(abc.ABC):
    """An element in the document."""

    def __init__(self, element_id: str = None):
        self._element_id = element_id
        self.text = ""

    def __str__(self):
        return self.text


class Text(Element):
    """A text element."""

    def __init__(self, text: str, element_id: str = None):
        super().__init__(element_id)
        self.text = text
'''
    (pkg_dir / "elements.py").write_text(elements_code, encoding="utf-8")

    # Sibling module that imports Element from the first module.
    chunking_code = '''
from mypackage.elements import Element

class PreChunk:
    def __init__(self, elements: list[Element]):
        self._elements = elements

class Accumulator:
    def will_fit(self, chunk: PreChunk) -> bool:
        return True
'''
    chunking_path = pkg_dir / "chunking.py"
    chunking_path.write_text(chunking_code, encoding="utf-8")

    # Simulate a testgen context built from the importing module.
    testgen_ctx = CodeStringsMarkdown(
        code_strings=[CodeString(code=chunking_code, file_path=chunking_path)]
    )

    extracted = get_imported_class_definitions(testgen_ctx, tmp_path)

    # Exactly the imported Element class should have been pulled in.
    assert len(extracted.code_strings) == 1, "Should extract exactly one class (Element)"
    element_src = extracted.code_strings[0].code
    assert "class Element" in element_src, "Should contain Element class definition"
    assert "def __init__" in element_src, "Should contain __init__ method"
    assert "element_id" in element_src, "Should contain constructor parameter"
    assert "import abc" in element_src, "Should include necessary imports for base class"
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_skips_existing_definitions(tmp_path: Path) -> None:
    """Test that get_imported_class_definitions skips classes already defined in context."""
    # Minimal project package with one class module.
    pkg_dir = tmp_path / "mypackage"
    pkg_dir.mkdir()
    (pkg_dir / "__init__.py").write_text("", encoding="utf-8")

    elements_code = '''
class Element:
    def __init__(self, text: str):
        self.text = text
'''
    (pkg_dir / "elements.py").write_text(elements_code, encoding="utf-8")

    # Context code that both imports Element AND redefines it locally.
    code_with_local_def = '''
from mypackage.elements import Element

# Local redefinition (this happens when LLM redefines classes)
class Element:
    def __init__(self, text: str):
        self.text = text

class User:
    def process(self, elem: Element):
        pass
'''
    user_path = pkg_dir / "user.py"
    user_path.write_text(code_with_local_def, encoding="utf-8")

    testgen_ctx = CodeStringsMarkdown(
        code_strings=[CodeString(code=code_with_local_def, file_path=user_path)]
    )

    extracted = get_imported_class_definitions(testgen_ctx, tmp_path)

    # The local Element definition wins: nothing should be re-extracted.
    assert len(extracted.code_strings) == 0, "Should not extract classes already defined in context"
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_skips_third_party(tmp_path: Path) -> None:
    """Test that get_imported_class_definitions skips third-party/stdlib imports."""
    # Minimal project package.
    pkg_dir = tmp_path / "mypackage"
    pkg_dir.mkdir()
    (pkg_dir / "__init__.py").write_text("", encoding="utf-8")

    # Context code whose imports are all stdlib, never project modules.
    stdlib_only_code = '''
from pathlib import Path
from typing import Optional
from dataclasses import dataclass

class MyClass:
    def __init__(self, path: Path):
        self.path = path
'''
    main_path = pkg_dir / "main.py"
    main_path.write_text(stdlib_only_code, encoding="utf-8")

    testgen_ctx = CodeStringsMarkdown(
        code_strings=[CodeString(code=stdlib_only_code, file_path=main_path)]
    )

    extracted = get_imported_class_definitions(testgen_ctx, tmp_path)

    # Path, Optional and dataclass all resolve outside the project root.
    assert len(extracted.code_strings) == 0, "Should not extract stdlib/third-party classes"
|
||||
|
||||
|
||||
def test_get_imported_class_definitions_handles_multiple_imports(tmp_path: Path) -> None:
    """Test that get_imported_class_definitions handles multiple class imports."""
    # Project package with a module defining three classes.
    pkg_dir = tmp_path / "mypackage"
    pkg_dir.mkdir()
    (pkg_dir / "__init__.py").write_text("", encoding="utf-8")

    types_code = '''
class TypeA:
    def __init__(self, value: int):
        self.value = value

class TypeB:
    def __init__(self, name: str):
        self.name = name

class TypeC:
    def __init__(self):
        pass
'''
    (pkg_dir / "types.py").write_text(types_code, encoding="utf-8")

    # Context code importing only two of the three classes.
    processor_code = '''
from mypackage.types import TypeA, TypeB

class Processor:
    def process(self, a: TypeA, b: TypeB):
        pass
'''
    processor_path = pkg_dir / "processor.py"
    processor_path.write_text(processor_code, encoding="utf-8")

    testgen_ctx = CodeStringsMarkdown(
        code_strings=[CodeString(code=processor_code, file_path=processor_path)]
    )

    extracted = get_imported_class_definitions(testgen_ctx, tmp_path)

    # Only the two imported classes should appear; TypeC was never imported.
    assert len(extracted.code_strings) == 2, "Should extract exactly two classes (TypeA, TypeB)"
    combined = "\n".join(cs.code for cs in extracted.code_strings)
    assert "class TypeA" in combined, "Should contain TypeA class"
    assert "class TypeB" in combined, "Should contain TypeB class"
    assert "class TypeC" not in combined, "Should NOT contain TypeC (not imported)"
|
||||
|
|
|
|||
Loading…
Reference in a new issue