feat: include __init__ signatures from directly imported external classes in testgen context
When generating regression tests, the LLM needs to know how to construct external types used as function parameters. This extends the testgen context to include __init__ signatures from external (site-packages) classes that are directly imported, complementing the existing base class init extraction.
This commit is contained in:
parent
1d9824c36c
commit
5449a32ade
1 changed files with 107 additions and 0 deletions
|
|
@ -70,6 +70,12 @@ def build_testgen_context(
|
|||
code_strings=testgen_context.code_strings + external_base_inits.code_strings
|
||||
)
|
||||
|
||||
external_class_inits = get_external_class_inits(testgen_context, project_root_path)
|
||||
if external_class_inits.code_strings:
|
||||
testgen_context = CodeStringsMarkdown(
|
||||
code_strings=testgen_context.code_strings + external_class_inits.code_strings
|
||||
)
|
||||
|
||||
return testgen_context
|
||||
|
||||
|
||||
|
|
@ -821,6 +827,107 @@ def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_roo
|
|||
return CodeStringsMarkdown(code_strings=code_strings)
|
||||
|
||||
|
||||
def get_external_class_inits(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown:
|
||||
"""Extract __init__ methods from directly imported external library classes.
|
||||
|
||||
Scans the code context for classes imported from external packages (site-packages) and extracts
|
||||
their __init__ methods. This helps the LLM understand constructor signatures for instantiation
|
||||
in generated tests.
|
||||
"""
|
||||
import importlib
|
||||
import inspect
|
||||
import textwrap
|
||||
|
||||
all_code = "\n".join(cs.code for cs in code_context.code_strings)
|
||||
|
||||
try:
|
||||
tree = ast.parse(all_code)
|
||||
except SyntaxError:
|
||||
return CodeStringsMarkdown(code_strings=[])
|
||||
|
||||
# Collect all from X import Y statements
|
||||
imported_names: dict[str, str] = {}
|
||||
is_project_cache: dict[str, bool] = {}
|
||||
|
||||
# Track classes already defined in the context to avoid duplicates
|
||||
existing_classes: set[str] = set()
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.ImportFrom) and node.module:
|
||||
for alias in node.names:
|
||||
if alias.name != "*":
|
||||
imported_name = alias.asname if alias.asname else alias.name
|
||||
imported_names[imported_name] = node.module
|
||||
elif isinstance(node, ast.ClassDef):
|
||||
existing_classes.add(node.name)
|
||||
|
||||
if not imported_names:
|
||||
return CodeStringsMarkdown(code_strings=[])
|
||||
|
||||
# Filter to external-only imports
|
||||
external_imports: set[tuple[str, str]] = set()
|
||||
for name, module_name in imported_names.items():
|
||||
if name in existing_classes:
|
||||
continue
|
||||
cached = is_project_cache.get(module_name)
|
||||
if cached is None:
|
||||
is_project = _is_project_module(module_name, project_root_path)
|
||||
is_project_cache[module_name] = is_project
|
||||
else:
|
||||
is_project = cached
|
||||
if not is_project:
|
||||
external_imports.add((name, module_name))
|
||||
|
||||
if not external_imports:
|
||||
return CodeStringsMarkdown(code_strings=[])
|
||||
|
||||
code_strings: list[CodeString] = []
|
||||
imported_module_cache: dict[str, object] = {}
|
||||
|
||||
for class_name, module_name in external_imports:
|
||||
try:
|
||||
module = imported_module_cache.get(module_name)
|
||||
if module is None:
|
||||
module = importlib.import_module(module_name)
|
||||
imported_module_cache[module_name] = module
|
||||
|
||||
cls = getattr(module, class_name, None)
|
||||
if cls is None or not inspect.isclass(cls):
|
||||
continue
|
||||
|
||||
init_method = getattr(cls, "__init__", None)
|
||||
if init_method is None or init_method is object.__init__:
|
||||
continue
|
||||
|
||||
try:
|
||||
class_file = Path(inspect.getfile(cls))
|
||||
except (OSError, TypeError):
|
||||
continue
|
||||
|
||||
if not path_belongs_to_site_packages(class_file):
|
||||
continue
|
||||
|
||||
try:
|
||||
init_source = inspect.getsource(init_method)
|
||||
init_source = textwrap.dedent(init_source)
|
||||
except (OSError, TypeError):
|
||||
continue
|
||||
|
||||
parts = class_file.parts
|
||||
if "site-packages" in parts:
|
||||
idx = parts.index("site-packages")
|
||||
class_file = Path(*parts[idx + 1 :])
|
||||
|
||||
class_source = f"class {class_name}:\n" + textwrap.indent(init_source, " ")
|
||||
code_strings.append(CodeString(code=class_source, file_path=class_file))
|
||||
|
||||
except (ImportError, ModuleNotFoundError, AttributeError):
|
||||
logger.debug(f"Failed to extract __init__ for {module_name}.{class_name}")
|
||||
continue
|
||||
|
||||
return CodeStringsMarkdown(code_strings=code_strings)
|
||||
|
||||
|
||||
def _is_project_module(module_name: str, project_root_path: Path) -> bool:
|
||||
"""Check if a module is part of the project (not external/stdlib)."""
|
||||
import importlib.util
|
||||
|
|
|
|||
Loading…
Reference in a new issue