fix: recursively extract base classes for imported dataclasses

When extracting imported class definitions for testgen context, also
extract base classes from the same module. This ensures the full
inheritance chain is available for understanding constructor signatures.

For example, when LLMConfig inherits from LLMConfigBase, both classes
are now included in the context so the LLM can see all required
positional arguments from parent classes.
This commit is contained in:
Kevin Turcios 2026-01-23 07:25:56 -05:00
parent 722d05345d
commit 4e310324b4

View file

@ -526,6 +526,10 @@ def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_ro
the LLM understand the actual class structure (constructors, methods, inheritance)
rather than just seeing import statements.
Also recursively extracts base classes when a class inherits from another class
in the same module, ensuring the full inheritance chain is available for
understanding constructor signatures.
Args:
code_context: The already extracted code context containing imports
project_root_path: Root path of the project
@ -568,6 +572,68 @@ def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_ro
class_code_strings: list[CodeString] = []
module_cache: dict[Path, tuple[str, ast.Module]] = {}
def get_module_source_and_tree(module_path: Path) -> tuple[str, ast.Module] | None:
if module_path in module_cache:
return module_cache[module_path]
try:
module_source = module_path.read_text(encoding="utf-8")
module_tree = ast.parse(module_source)
except Exception:
return None
else:
module_cache[module_path] = (module_source, module_tree)
return module_source, module_tree
def extract_class_and_bases(
class_name: str, module_path: Path, module_source: str, module_tree: ast.Module
) -> None:
"""Extract a class and its base classes recursively from the same module."""
# Skip if already extracted
if (module_path, class_name) in extracted_classes:
return
# Find the class definition in the module
class_node = None
for node in ast.walk(module_tree):
if isinstance(node, ast.ClassDef) and node.name == class_name:
class_node = node
break
if class_node is None:
return
# First, recursively extract base classes from the same module
for base in class_node.bases:
base_name = None
if isinstance(base, ast.Name):
base_name = base.id
elif isinstance(base, ast.Attribute):
# For module.ClassName, we skip (cross-module inheritance)
continue
if base_name and base_name not in existing_definitions:
# Check if base class is defined in the same module
extract_class_and_bases(base_name, module_path, module_source, module_tree)
# Now extract this class (after its bases, so base classes appear first)
if (module_path, class_name) in extracted_classes:
return # Already added by another path
lines = module_source.split("\n")
start_line = class_node.lineno
if class_node.decorator_list:
start_line = min(d.lineno for d in class_node.decorator_list)
class_source = "\n".join(lines[start_line - 1 : class_node.end_lineno])
# Extract imports for the class
class_imports = extract_imports_for_class(module_tree, class_node, module_source)
full_source = class_imports + "\n\n" + class_source if class_imports else class_source
class_code_strings.append(CodeString(code=full_source, file_path=module_path))
extracted_classes.add((module_path, class_name))
for name, module_name in imported_names.items():
# Skip if already defined in context
if name in existing_definitions:
@ -593,32 +659,14 @@ def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_ro
if path_belongs_to_site_packages(module_path):
continue
# Skip if we've already extracted this class
if (module_path, name) in extracted_classes:
# Get module source and tree
result = get_module_source_and_tree(module_path)
if result is None:
continue
module_source, module_tree = result
# Parse the module to find the class definition
module_source = module_path.read_text(encoding="utf-8")
module_tree = ast.parse(module_source)
for node in ast.walk(module_tree):
if isinstance(node, ast.ClassDef) and node.name == name:
# Extract the class source code, including decorators
lines = module_source.split("\n")
# Decorators start before the class line, use first decorator line if present
start_line = node.lineno
if node.decorator_list:
start_line = min(d.lineno for d in node.decorator_list)
class_source = "\n".join(lines[start_line - 1 : node.end_lineno])
# Also extract any necessary imports for the class (base classes, type hints)
class_imports = extract_imports_for_class(module_tree, node, module_source)
full_source = class_imports + "\n\n" + class_source if class_imports else class_source
class_code_strings.append(CodeString(code=full_source, file_path=module_path))
extracted_classes.add((module_path, name))
break
# Extract the class and its base classes
extract_class_and_bases(name, module_path, module_source, module_tree)
except Exception:
logger.debug(f"Error extracting class definition for {name} from {module_name}")