feat: improve test generation context for external library types

Extend extract_parameter_type_constructors to scan function bodies for
isinstance/type() patterns and collect base class names from enclosing
classes. Add one-level transitive stub extraction so the LLM also sees
constructor signatures for types referenced in __init__ parameters.

In enrich_testgen_context, branch on source: project classes get full
definitions, third-party (site-packages) classes get compact __init__
stubs to avoid blowing token limits.
This commit is contained in:
Kevin Turcios 2026-02-26 09:40:04 -05:00
parent d284be521d
commit 5cee1b5b48
2 changed files with 270 additions and 8 deletions

View file

@ -837,6 +837,41 @@ def extract_parameter_type_constructors(
if func_node.args.kwarg:
type_names |= collect_type_names_from_annotation(func_node.args.kwarg.annotation)
# Scan function body for isinstance(x, SomeType) and type(x) is/== SomeType patterns
for body_node in ast.walk(func_node):
if (
isinstance(body_node, ast.Call)
and isinstance(body_node.func, ast.Name)
and body_node.func.id == "isinstance"
):
if len(body_node.args) >= 2:
second_arg = body_node.args[1]
if isinstance(second_arg, ast.Name):
type_names.add(second_arg.id)
elif isinstance(second_arg, ast.Tuple):
for elt in second_arg.elts:
if isinstance(elt, ast.Name):
type_names.add(elt.id)
elif isinstance(body_node, ast.Compare):
# type(x) is/== SomeType
if (
isinstance(body_node.left, ast.Call)
and isinstance(body_node.left.func, ast.Name)
and body_node.left.func.id == "type"
):
for comparator in body_node.comparators:
if isinstance(comparator, ast.Name):
type_names.add(comparator.id)
# Collect base class names from enclosing class (if this is a method)
if function_to_optimize.class_name is not None:
for top_node in ast.walk(tree):
if isinstance(top_node, ast.ClassDef) and top_node.name == function_to_optimize.class_name:
for base in top_node.bases:
if isinstance(base, ast.Name):
type_names.add(base.id)
break
type_names -= BUILTIN_AND_TYPING_NAMES
type_names -= existing_class_names
if not type_names:
@ -881,6 +916,58 @@ def extract_parameter_type_constructors(
logger.debug(f"Error extracting constructor stub for {type_name} from {module_name}")
continue
# Transitive extraction (one level): for each extracted stub, find __init__ param types and extract their stubs
# Build an extended import map that includes imports from source modules of already-extracted stubs
transitive_import_map = dict(import_map)
for _, cached_tree in module_cache.values():
for cache_node in ast.walk(cached_tree):
if isinstance(cache_node, ast.ImportFrom) and cache_node.module:
for alias in cache_node.names:
name = alias.asname if alias.asname else alias.name
if name not in transitive_import_map:
transitive_import_map[name] = cache_node.module
emitted_names = type_names | existing_class_names | BUILTIN_AND_TYPING_NAMES
transitive_type_names: set[str] = set()
for cs in code_strings:
try:
stub_tree = ast.parse(cs.code)
except SyntaxError:
continue
for stub_node in ast.walk(stub_tree):
if isinstance(stub_node, (ast.FunctionDef, ast.AsyncFunctionDef)) and stub_node.name in (
"__init__",
"__post_init__",
):
for arg in stub_node.args.args + stub_node.args.posonlyargs + stub_node.args.kwonlyargs:
transitive_type_names |= collect_type_names_from_annotation(arg.annotation)
transitive_type_names -= emitted_names
for type_name in sorted(transitive_type_names):
module_name = transitive_import_map.get(type_name)
if not module_name:
continue
try:
script_code = f"from {module_name} import {type_name}"
script = jedi.Script(script_code, project=jedi.Project(path=project_root_path))
definitions = script.goto(1, len(f"from {module_name} import ") + len(type_name), follow_imports=True)
if not definitions:
continue
module_path = definitions[0].module_path
if not module_path:
continue
if module_path in module_cache:
mod_source, mod_tree = module_cache[module_path]
else:
mod_source = module_path.read_text(encoding="utf-8")
mod_tree = ast.parse(mod_source)
module_cache[module_path] = (mod_source, mod_tree)
stub = extract_init_stub_from_class(type_name, mod_source, mod_tree)
if stub:
code_strings.append(CodeString(code=stub, file_path=module_path))
except Exception:
logger.debug(f"Error extracting transitive constructor stub for {type_name} from {module_name}")
continue
return CodeStringsMarkdown(code_strings=code_strings)
@ -1004,12 +1091,23 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path:
continue
module_source, module_tree = mod_result
extract_class_and_bases(name, module_path, module_source, module_tree)
if (module_path, name) not in extracted_classes:
resolved_class = resolve_instance_class_name(name, module_tree)
if resolved_class and resolved_class not in existing_classes:
extract_class_and_bases(resolved_class, module_path, module_source, module_tree)
if is_project:
extract_class_and_bases(name, module_path, module_source, module_tree)
if (module_path, name) not in extracted_classes:
resolved_class = resolve_instance_class_name(name, module_tree)
if resolved_class and resolved_class not in existing_classes:
extract_class_and_bases(resolved_class, module_path, module_source, module_tree)
elif is_third_party:
target_name = name
if not any(isinstance(n, ast.ClassDef) and n.name == name for n in ast.walk(module_tree)):
resolved_class = resolve_instance_class_name(name, module_tree)
if resolved_class:
target_name = resolved_class
if target_name not in emitted_class_names:
stub = extract_init_stub_from_class(target_name, module_source, module_tree)
if stub:
code_strings.append(CodeString(code=stub, file_path=module_path))
emitted_class_names.add(target_name)
except Exception:
logger.debug(f"Error extracting class definition for {name} from {module_name}")

View file

@ -9,8 +9,6 @@ from pathlib import Path
import pytest
from codeflash.languages.python.static_analysis.code_extractor import GlobalAssignmentCollector, add_global_assignments
from codeflash.languages.python.static_analysis.code_replacer import replace_functions_and_add_imports
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.python.context.code_context_extractor import (
collect_type_names_from_annotation,
@ -20,6 +18,8 @@ from codeflash.languages.python.context.code_context_extractor import (
get_code_optimization_context,
resolve_instance_class_name,
)
from codeflash.languages.python.static_analysis.code_extractor import GlobalAssignmentCollector, add_global_assignments
from codeflash.languages.python.static_analysis.code_replacer import replace_functions_and_add_imports
from codeflash.models.models import CodeString, CodeStringsMarkdown, FunctionParent
from codeflash.optimization.optimizer import Optimizer
@ -4701,3 +4701,167 @@ def get_log_level() -> str:
combined = "\n".join(cs.code for cs in result.code_strings)
assert "class AppConfig:" in combined
assert "@property" in combined
def test_extract_parameter_type_constructors_isinstance_single(tmp_path: Path) -> None:
"""isinstance(x, SomeType) in function body should be picked up."""
pkg = tmp_path / "mypkg"
pkg.mkdir()
(pkg / "__init__.py").write_text("", encoding="utf-8")
(pkg / "models.py").write_text(
"class Widget:\n def __init__(self, size: int):\n self.size = size\n",
encoding="utf-8",
)
(pkg / "processor.py").write_text(
"from mypkg.models import Widget\n\ndef check(x) -> bool:\n return isinstance(x, Widget)\n",
encoding="utf-8",
)
fto = FunctionToOptimize(
function_name="check", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
)
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
assert len(result.code_strings) == 1
assert "class Widget:" in result.code_strings[0].code
assert "__init__" in result.code_strings[0].code
def test_extract_parameter_type_constructors_isinstance_tuple(tmp_path: Path) -> None:
"""isinstance(x, (TypeA, TypeB)) should pick up both types."""
pkg = tmp_path / "mypkg"
pkg.mkdir()
(pkg / "__init__.py").write_text("", encoding="utf-8")
(pkg / "models.py").write_text(
"class Alpha:\n def __init__(self, a: int):\n self.a = a\n\n"
"class Beta:\n def __init__(self, b: str):\n self.b = b\n",
encoding="utf-8",
)
(pkg / "processor.py").write_text(
"from mypkg.models import Alpha, Beta\n\ndef check(x) -> bool:\n return isinstance(x, (Alpha, Beta))\n",
encoding="utf-8",
)
fto = FunctionToOptimize(
function_name="check", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
)
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
assert len(result.code_strings) == 2
combined = "\n".join(cs.code for cs in result.code_strings)
assert "class Alpha:" in combined
assert "class Beta:" in combined
def test_extract_parameter_type_constructors_type_is_pattern(tmp_path: Path) -> None:
"""type(x) is SomeType pattern should be picked up."""
pkg = tmp_path / "mypkg"
pkg.mkdir()
(pkg / "__init__.py").write_text("", encoding="utf-8")
(pkg / "models.py").write_text(
"class Gadget:\n def __init__(self, val: float):\n self.val = val\n",
encoding="utf-8",
)
(pkg / "processor.py").write_text(
"from mypkg.models import Gadget\n\ndef check(x) -> bool:\n return type(x) is Gadget\n",
encoding="utf-8",
)
fto = FunctionToOptimize(
function_name="check", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
)
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
assert len(result.code_strings) == 1
assert "class Gadget:" in result.code_strings[0].code
def test_extract_parameter_type_constructors_base_classes(tmp_path: Path) -> None:
"""Base classes of enclosing class should be picked up for methods."""
pkg = tmp_path / "mypkg"
pkg.mkdir()
(pkg / "__init__.py").write_text("", encoding="utf-8")
(pkg / "base.py").write_text(
"class BaseProcessor:\n def __init__(self, config: str):\n self.config = config\n",
encoding="utf-8",
)
(pkg / "child.py").write_text(
"from mypkg.base import BaseProcessor\n\nclass ChildProcessor(BaseProcessor):\n"
" def process(self) -> str:\n return self.config\n",
encoding="utf-8",
)
fto = FunctionToOptimize(
function_name="process",
file_path=(pkg / "child.py").resolve(),
starting_line=4,
ending_line=5,
parents=[FunctionParent(name="ChildProcessor", type="ClassDef")],
)
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
assert len(result.code_strings) == 1
assert "class BaseProcessor:" in result.code_strings[0].code
def test_extract_parameter_type_constructors_isinstance_builtins_excluded(tmp_path: Path) -> None:
"""Isinstance with builtins (int, str, etc.) should not produce stubs."""
pkg = tmp_path / "mypkg"
pkg.mkdir()
(pkg / "__init__.py").write_text("", encoding="utf-8")
(pkg / "func.py").write_text(
"def check(x) -> bool:\n return isinstance(x, (int, str, float))\n",
encoding="utf-8",
)
fto = FunctionToOptimize(
function_name="check", file_path=(pkg / "func.py").resolve(), starting_line=1, ending_line=2
)
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
assert len(result.code_strings) == 0
def test_extract_parameter_type_constructors_transitive(tmp_path: Path) -> None:
"""Transitive extraction: if Widget.__init__ takes a Config, Config's stub should also appear."""
pkg = tmp_path / "mypkg"
pkg.mkdir()
(pkg / "__init__.py").write_text("", encoding="utf-8")
(pkg / "config.py").write_text(
"class Config:\n def __init__(self, debug: bool = False):\n self.debug = debug\n",
encoding="utf-8",
)
(pkg / "models.py").write_text(
"from mypkg.config import Config\n\n"
"class Widget:\n def __init__(self, cfg: Config):\n self.cfg = cfg\n",
encoding="utf-8",
)
(pkg / "processor.py").write_text(
"from mypkg.models import Widget\n\ndef process(w: Widget) -> str:\n return str(w)\n",
encoding="utf-8",
)
fto = FunctionToOptimize(
function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
)
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
combined = "\n".join(cs.code for cs in result.code_strings)
assert "class Widget:" in combined
assert "class Config:" in combined
def test_enrich_testgen_context_third_party_uses_stubs(tmp_path: Path) -> None:
"""Third-party classes should produce compact __init__ stubs, not full class source."""
# Use a real third-party package (pydantic) so jedi can actually resolve it
context_code = (
"from pydantic import BaseModel\n\n"
"class MyModel(BaseModel):\n"
" name: str\n\n"
"def process(m: MyModel) -> str:\n"
" return m.name\n"
)
consumer_path = tmp_path / "consumer.py"
consumer_path.write_text(context_code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=context_code, file_path=consumer_path)])
result = enrich_testgen_context(context, tmp_path)
# BaseModel lives in site-packages so should get stub treatment (compact __init__),
# not the full class definition with hundreds of methods
for cs in result.code_strings:
if "BaseModel" in cs.code:
assert "class BaseModel:" in cs.code
assert "__init__" in cs.code
# Full BaseModel has many methods; stubs should only have __init__/properties
assert "model_dump" not in cs.code
break