fix: recursively extract base classes for imported dataclasses
When extracting imported class definitions for testgen context, also extract base classes from the same module. This ensures the full inheritance chain is available for understanding constructor signatures. For example, when LLMConfig inherits from LLMConfigBase, both classes are now included in the context so the LLM can see all required positional arguments from parent classes.
This commit is contained in:
parent
722d05345d
commit
4e310324b4
1 changed files with 72 additions and 24 deletions
|
|
@ -526,6 +526,10 @@ def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_ro
|
|||
the LLM understand the actual class structure (constructors, methods, inheritance)
|
||||
rather than just seeing import statements.
|
||||
|
||||
Also recursively extracts base classes when a class inherits from another class
|
||||
in the same module, ensuring the full inheritance chain is available for
|
||||
understanding constructor signatures.
|
||||
|
||||
Args:
|
||||
code_context: The already extracted code context containing imports
|
||||
project_root_path: Root path of the project
|
||||
|
|
@ -568,6 +572,68 @@ def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_ro
|
|||
|
||||
class_code_strings: list[CodeString] = []
|
||||
|
||||
module_cache: dict[Path, tuple[str, ast.Module]] = {}
|
||||
|
||||
def get_module_source_and_tree(module_path: Path) -> tuple[str, ast.Module] | None:
|
||||
if module_path in module_cache:
|
||||
return module_cache[module_path]
|
||||
try:
|
||||
module_source = module_path.read_text(encoding="utf-8")
|
||||
module_tree = ast.parse(module_source)
|
||||
except Exception:
|
||||
return None
|
||||
else:
|
||||
module_cache[module_path] = (module_source, module_tree)
|
||||
return module_source, module_tree
|
||||
|
||||
def extract_class_and_bases(
|
||||
class_name: str, module_path: Path, module_source: str, module_tree: ast.Module
|
||||
) -> None:
|
||||
"""Extract a class and its base classes recursively from the same module."""
|
||||
# Skip if already extracted
|
||||
if (module_path, class_name) in extracted_classes:
|
||||
return
|
||||
|
||||
# Find the class definition in the module
|
||||
class_node = None
|
||||
for node in ast.walk(module_tree):
|
||||
if isinstance(node, ast.ClassDef) and node.name == class_name:
|
||||
class_node = node
|
||||
break
|
||||
|
||||
if class_node is None:
|
||||
return
|
||||
|
||||
# First, recursively extract base classes from the same module
|
||||
for base in class_node.bases:
|
||||
base_name = None
|
||||
if isinstance(base, ast.Name):
|
||||
base_name = base.id
|
||||
elif isinstance(base, ast.Attribute):
|
||||
# For module.ClassName, we skip (cross-module inheritance)
|
||||
continue
|
||||
|
||||
if base_name and base_name not in existing_definitions:
|
||||
# Check if base class is defined in the same module
|
||||
extract_class_and_bases(base_name, module_path, module_source, module_tree)
|
||||
|
||||
# Now extract this class (after its bases, so base classes appear first)
|
||||
if (module_path, class_name) in extracted_classes:
|
||||
return # Already added by another path
|
||||
|
||||
lines = module_source.split("\n")
|
||||
start_line = class_node.lineno
|
||||
if class_node.decorator_list:
|
||||
start_line = min(d.lineno for d in class_node.decorator_list)
|
||||
class_source = "\n".join(lines[start_line - 1 : class_node.end_lineno])
|
||||
|
||||
# Extract imports for the class
|
||||
class_imports = extract_imports_for_class(module_tree, class_node, module_source)
|
||||
full_source = class_imports + "\n\n" + class_source if class_imports else class_source
|
||||
|
||||
class_code_strings.append(CodeString(code=full_source, file_path=module_path))
|
||||
extracted_classes.add((module_path, class_name))
|
||||
|
||||
for name, module_name in imported_names.items():
|
||||
# Skip if already defined in context
|
||||
if name in existing_definitions:
|
||||
|
|
@ -593,32 +659,14 @@ def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_ro
|
|||
if path_belongs_to_site_packages(module_path):
|
||||
continue
|
||||
|
||||
# Skip if we've already extracted this class
|
||||
if (module_path, name) in extracted_classes:
|
||||
# Get module source and tree
|
||||
result = get_module_source_and_tree(module_path)
|
||||
if result is None:
|
||||
continue
|
||||
module_source, module_tree = result
|
||||
|
||||
# Parse the module to find the class definition
|
||||
module_source = module_path.read_text(encoding="utf-8")
|
||||
module_tree = ast.parse(module_source)
|
||||
|
||||
for node in ast.walk(module_tree):
|
||||
if isinstance(node, ast.ClassDef) and node.name == name:
|
||||
# Extract the class source code, including decorators
|
||||
lines = module_source.split("\n")
|
||||
# Decorators start before the class line, use first decorator line if present
|
||||
start_line = node.lineno
|
||||
if node.decorator_list:
|
||||
start_line = min(d.lineno for d in node.decorator_list)
|
||||
class_source = "\n".join(lines[start_line - 1 : node.end_lineno])
|
||||
|
||||
# Also extract any necessary imports for the class (base classes, type hints)
|
||||
class_imports = extract_imports_for_class(module_tree, node, module_source)
|
||||
|
||||
full_source = class_imports + "\n\n" + class_source if class_imports else class_source
|
||||
|
||||
class_code_strings.append(CodeString(code=full_source, file_path=module_path))
|
||||
extracted_classes.add((module_path, name))
|
||||
break
|
||||
# Extract the class and its base classes
|
||||
extract_class_and_bases(name, module_path, module_source, module_tree)
|
||||
|
||||
except Exception:
|
||||
logger.debug(f"Error extracting class definition for {name} from {module_name}")
|
||||
|
|
|
|||
Loading…
Reference in a new issue