feat: include __init__ signatures from directly imported external classes in testgen context

When generating regression tests, the LLM needs to know how to construct
external types used as function parameters. This extends the testgen context
to include __init__ signatures from external (site-packages) classes that
are directly imported, complementing the existing base class init extraction.
This commit is contained in:
Kevin Turcios 2026-02-13 08:42:22 -05:00
parent 1d9824c36c
commit 5449a32ade

View file

@ -70,6 +70,12 @@ def build_testgen_context(
code_strings=testgen_context.code_strings + external_base_inits.code_strings
)
external_class_inits = get_external_class_inits(testgen_context, project_root_path)
if external_class_inits.code_strings:
testgen_context = CodeStringsMarkdown(
code_strings=testgen_context.code_strings + external_class_inits.code_strings
)
return testgen_context
@ -821,6 +827,107 @@ def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_roo
return CodeStringsMarkdown(code_strings=code_strings)
def get_external_class_inits(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown:
    """Extract __init__ methods from directly imported external library classes.

    Parses the concatenated code of ``code_context``, collects every absolute
    ``from X import Y`` statement, discards names whose module is reported as
    part of the project by ``_is_project_module``, imports the remaining
    modules and, for each imported name that resolves to a class whose file
    passes ``path_belongs_to_site_packages``, emits a stub of the form
    ``class <name>:`` followed by the source of its ``__init__``. This helps
    the LLM understand constructor signatures for instantiation in generated
    tests.

    Args:
        code_context: Code strings to scan for ``from ... import ...`` statements.
        project_root_path: Project root, forwarded to ``_is_project_module``.

    Returns:
        A ``CodeStringsMarkdown`` with one ``CodeString`` per extracted class,
        ordered deterministically. It is empty when the context cannot be
        parsed, contains no qualifying imports, or nothing could be extracted.

    """
    import importlib
    import inspect
    import textwrap

    all_code = "\n".join(cs.code for cs in code_context.code_strings)
    try:
        tree = ast.parse(all_code)
    except SyntaxError:
        return CodeStringsMarkdown(code_strings=[])

    # Map the locally bound name -> (module, name as defined in that module).
    # The two names differ for ``from pkg import Foo as Bar``: the attribute
    # to look up on the module is ``Foo`` while the context code refers to ``Bar``.
    imported_names: dict[str, tuple[str, str]] = {}
    # Track classes already defined in the context to avoid duplicates
    existing_classes: set[str] = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom):
            # Relative imports (level > 0) are resolved against the importing
            # package; treating ``from .utils import X`` as the absolute module
            # ``utils`` could import an unrelated package, so skip them.
            if not node.module or node.level:
                continue
            for alias in node.names:
                if alias.name != "*":
                    imported_names[alias.asname or alias.name] = (node.module, alias.name)
        elif isinstance(node, ast.ClassDef):
            existing_classes.add(node.name)

    if not imported_names:
        return CodeStringsMarkdown(code_strings=[])

    # Filter to external-only imports, asking about each module only once
    is_project_cache: dict[str, bool] = {}
    external_imports: list[tuple[str, str, str]] = []
    for local_name, (module_name, original_name) in imported_names.items():
        if local_name in existing_classes:
            continue
        is_project = is_project_cache.get(module_name)
        if is_project is None:
            is_project = _is_project_module(module_name, project_root_path)
            is_project_cache[module_name] = is_project
        if not is_project:
            external_imports.append((local_name, module_name, original_name))

    if not external_imports:
        return CodeStringsMarkdown(code_strings=[])

    code_strings: list[CodeString] = []
    imported_module_cache: dict[str, object] = {}
    # Sort so the emitted context is stable from run to run.
    for local_name, module_name, original_name in sorted(external_imports):
        try:
            module = imported_module_cache.get(module_name)
            if module is None:
                module = importlib.import_module(module_name)
                imported_module_cache[module_name] = module

            # Look the class up under the name it has inside the module,
            # not under the alias it was bound to in the context code.
            cls = getattr(module, original_name, None)
            if cls is None or not inspect.isclass(cls):
                continue

            init_method = getattr(cls, "__init__", None)
            # object.__init__ carries no useful constructor signature
            if init_method is None or init_method is object.__init__:
                continue

            try:
                class_file = Path(inspect.getfile(cls))
            except (OSError, TypeError):
                continue

            if not path_belongs_to_site_packages(class_file):
                continue

            try:
                init_source = textwrap.dedent(inspect.getsource(init_method))
            except (OSError, TypeError):
                continue

            # Report the path relative to site-packages when possible
            parts = class_file.parts
            if "site-packages" in parts:
                idx = parts.index("site-packages")
                class_file = Path(*parts[idx + 1 :])

            # The stub uses the locally bound name because that is the name
            # the code in the context uses to refer to the class.
            class_source = f"class {local_name}:\n" + textwrap.indent(init_source, " ")
            code_strings.append(CodeString(code=class_source, file_path=class_file))
        except (ImportError, AttributeError):
            # ModuleNotFoundError is a subclass of ImportError and is covered here.
            logger.debug(f"Failed to extract __init__ for {module_name}.{original_name}")
            continue

    return CodeStringsMarkdown(code_strings=code_strings)
def _is_project_module(module_name: str, project_root_path: Path) -> bool:
"""Check if a module is part of the project (not external/stdlib)."""
import importlib.util