Improve testgen constructor context extraction

This commit is contained in:
Kevin Turcios 2026-03-16 00:47:17 -06:00
parent cee12fe430
commit 282f2ba713
2 changed files with 661 additions and 65 deletions

View file

@ -621,55 +621,450 @@ def collect_type_names_from_annotation(node: ast.expr | None) -> set[str]:
return set()
def extract_init_stub_from_class(class_name: str, module_source: str, module_tree: ast.Module) -> str | None:
class_node = None
# Heuristic limits for emitting a project's class source verbatim: classes at or
# under these sizes are cheaper to include in full than to summarize as stubs.
MAX_RAW_PROJECT_CLASS_BODY_ITEMS = 8
MAX_RAW_PROJECT_CLASS_LINES = 40
def _get_expr_name(node: ast.AST | None) -> str | None:
if node is None:
return None
if isinstance(node, ast.Name):
return node.id
if isinstance(node, ast.Attribute):
parent_name = _get_expr_name(node.value)
return node.attr if parent_name is None else f"{parent_name}.{node.attr}"
if isinstance(node, ast.Call):
return _get_expr_name(node.func)
return None
def _collect_import_aliases(module_tree: ast.Module) -> dict[str, str]:
aliases: dict[str, str] = {}
for node in module_tree.body:
if isinstance(node, ast.Import):
for alias in node.names:
bound_name = alias.asname if alias.asname else alias.name.split(".")[0]
aliases[bound_name] = alias.name
elif isinstance(node, ast.ImportFrom) and node.module:
for alias in node.names:
bound_name = alias.asname if alias.asname else alias.name
aliases[bound_name] = f"{node.module}.{alias.name}"
return aliases
def _find_class_node_by_name(class_name: str, module_tree: ast.Module) -> ast.ClassDef | None:
# Use a deque-based BFS to find the first matching ClassDef (preserves ast.walk order)
q: deque[ast.AST] = deque([module_tree])
while q:
candidate = q.popleft()
if isinstance(candidate, ast.ClassDef) and candidate.name == class_name:
class_node = candidate
break
return candidate
q.extend(ast.iter_child_nodes(candidate))
return None
def _expr_matches_name(node: ast.AST | None, import_aliases: dict[str, str], suffix: str) -> bool:
    """True when the expression's dotted name is `suffix` or ends with `.suffix`,
    either directly or after resolving through the module's import aliases."""
    expr_name = _get_expr_name(node)
    if expr_name is None:
        return False

    def matches(candidate: str) -> bool:
        return candidate == suffix or candidate.endswith(f".{suffix}")

    if matches(expr_name):
        return True
    resolved = import_aliases.get(expr_name)
    return resolved is not None and matches(resolved)
def _get_node_source(node: ast.AST | None, module_source: str, fallback: str = "...") -> str:
if node is None:
return fallback
source_segment = ast.get_source_segment(module_source, node)
if source_segment is not None:
return source_segment
try:
return ast.unparse(node)
except Exception:
return fallback
def _bool_literal(node: ast.AST) -> bool | None:
if isinstance(node, ast.Constant) and isinstance(node.value, bool):
return node.value
return None
def _is_namedtuple_class(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> bool:
    """Whether the class lists typing.NamedTuple among its bases."""
    for base in class_node.bases:
        if _expr_matches_name(base, import_aliases, "NamedTuple"):
            return True
    return False
def _get_dataclass_config(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> tuple[bool, bool, bool]:
    """Inspect the class's decorators for @dataclass.

    Returns (is_dataclass, init_enabled, kw_only).  The last two reflect literal
    boolean keyword arguments on the first matching decorator; non-literal
    values are ignored, so the defaults (init=True, kw_only=False) stand.
    """
    for decorator in class_node.decorator_list:
        if not _expr_matches_name(decorator, import_aliases, "dataclass"):
            continue
        options = {"init": True, "kw_only": False}
        if isinstance(decorator, ast.Call):
            for keyword in decorator.keywords:
                if keyword.arg not in options:
                    continue
                flag = _bool_literal(keyword.value)
                if flag is not None:
                    options[keyword.arg] = flag
        return True, options["init"], options["kw_only"]
    return False, False, False
def _is_classvar_annotation(annotation: ast.expr, import_aliases: dict[str, str]) -> bool:
    """Detect ClassVar / ClassVar[...] annotations, honoring import aliases."""
    root = annotation
    if isinstance(root, ast.Subscript):
        # ClassVar[int] -> look at the subscripted value, ClassVar itself.
        root = root.value
    return _expr_matches_name(root, import_aliases, "ClassVar")
def _is_project_path(module_path: Path, project_root_path: Path) -> bool:
return str(module_path.resolve()).startswith(str(project_root_path.resolve()) + os.sep)
def _get_class_start_line(class_node: ast.ClassDef) -> int:
start_line = class_node.lineno
if class_node.decorator_list:
for decorator in class_node.decorator_list:
start_line = min(start_line, decorator.lineno)
return start_line
def _class_has_explicit_init(class_node: ast.ClassDef) -> bool:
return any(isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)) and item.name == "__init__" for item in class_node.body)
def _collect_synthetic_constructor_type_names(
    class_node: ast.ClassDef, import_aliases: dict[str, str]
) -> set[str]:
    """Type names referenced by the synthesized constructor of a NamedTuple or
    @dataclass class: annotations of fields that participate in __init__."""
    is_dataclass, dataclass_init_enabled, _ = _get_dataclass_config(class_node, import_aliases)
    if not _is_namedtuple_class(class_node, import_aliases) and not is_dataclass:
        return set()
    if is_dataclass and not dataclass_init_enabled:
        # @dataclass(init=False): no constructor is generated at all.
        return set()
    collected: set[str] = set()
    for item in class_node.body:
        if not isinstance(item, ast.AnnAssign) or not isinstance(item.target, ast.Name) or item.annotation is None:
            continue
        if _is_classvar_annotation(item.annotation, import_aliases):
            continue
        include_in_init = True
        value = item.value
        if isinstance(value, ast.Call) and _expr_matches_name(value.func, import_aliases, "field"):
            # field(init=...) with a literal bool controls participation.
            for keyword in value.keywords:
                if keyword.arg != "init":
                    continue
                flag = _bool_literal(keyword.value)
                if flag is not None:
                    include_in_init = flag
                break
        if include_in_init:
            collected |= collect_type_names_from_annotation(item.annotation)
    return collected
def _extract_synthetic_init_parameters(
    class_node: ast.ClassDef, module_source: str, import_aliases: dict[str, str], *, kw_only_by_default: bool
) -> list[tuple[str, str, str | None, bool]]:
    """Derive synthesized-__init__ parameters from annotated class fields.

    Returns (name, annotation_source, default_source_or_None, kw_only) tuples in
    class-body order.  ClassVar fields are skipped.  `field(...)` calls are
    inspected for literal init=/kw_only= flags and default=/default_factory=
    values; any other assigned value becomes the default expression verbatim.
    Keywords are processed in order, so a later duplicate keyword overrides an
    earlier one.
    """
    parameters: list[tuple[str, str, str | None, bool]] = []
    for item in class_node.body:
        # Only simple annotated names (`x: T` / `x: T = ...`) become parameters.
        if not isinstance(item, ast.AnnAssign) or not isinstance(item.target, ast.Name):
            continue
        if _is_classvar_annotation(item.annotation, import_aliases):
            continue
        include_in_init = True
        kw_only = kw_only_by_default
        default_value: str | None = None
        if item.value is not None:
            if isinstance(item.value, ast.Call) and _expr_matches_name(item.value.func, import_aliases, "field"):
                for keyword in item.value.keywords:
                    if keyword.arg == "init":
                        literal_value = _bool_literal(keyword.value)
                        if literal_value is not None:
                            include_in_init = literal_value
                    elif keyword.arg == "kw_only":
                        literal_value = _bool_literal(keyword.value)
                        if literal_value is not None:
                            kw_only = literal_value
                    elif keyword.arg == "default":
                        default_value = _get_node_source(keyword.value, module_source)
                    elif keyword.arg == "default_factory":
                        # Default factories still imply an optional constructor parameter, but
                        # the generated __init__ does not use the field() call directly.
                        default_value = "..."
            else:
                default_value = _get_node_source(item.value, module_source)
        if not include_in_init:
            continue
        parameters.append((item.target.id, _get_node_source(item.annotation, module_source, "Any"), default_value, kw_only))
    return parameters
def _build_synthetic_init_stub(
    class_node: ast.ClassDef, module_source: str, import_aliases: dict[str, str]
) -> str | None:
    """Render a stub `def __init__` body for a NamedTuple or @dataclass class.

    Returns None when the class is neither, when @dataclass(init=False) disables
    the generated constructor, or when no parameters could be derived.
    """
    is_namedtuple = _is_namedtuple_class(class_node, import_aliases)
    is_dataclass, dataclass_init_enabled, dataclass_kw_only = _get_dataclass_config(class_node, import_aliases)
    if not is_namedtuple and not is_dataclass:
        return None
    if is_dataclass and not dataclass_init_enabled:
        return None
    parameters = _extract_synthetic_init_parameters(
        class_node,
        module_source,
        import_aliases,
        kw_only_by_default=dataclass_kw_only,
    )
    if not parameters:
        return None
    # Fix: dataclasses move kw_only fields after positional ones in the generated
    # __init__.  Emitting the `*` marker at the first kw-only field in class-body
    # order would wrongly make later positional fields keyword-only (and could
    # even produce an invalid non-default-after-default signature), so partition
    # the parameters first and keep relative order within each group.
    positional = [p for p in parameters if not p[3]]
    keyword_only = [p for p in parameters if p[3]]
    signature_parts = ["self"]
    for param_name, annotation_source, default_value, _kw in positional:
        part = f"{param_name}: {annotation_source}"
        if default_value is not None:
            part += f" = {default_value}"
        signature_parts.append(part)
    if keyword_only:
        signature_parts.append("*")
        for param_name, annotation_source, default_value, _kw in keyword_only:
            part = f"{param_name}: {annotation_source}"
            if default_value is not None:
                part += f" = {default_value}"
            signature_parts.append(part)
    signature = ", ".join(signature_parts)
    # Indented for splicing under a `class X:` header.
    return f"    def __init__({signature}):\n        ..."
def _extract_function_stub_snippet(
fn_node: ast.FunctionDef | ast.AsyncFunctionDef, module_lines: list[str]
) -> str:
start_line = fn_node.lineno
if fn_node.decorator_list:
for decorator in fn_node.decorator_list:
start_line = min(start_line, decorator.lineno)
return "\n".join(module_lines[start_line - 1 : fn_node.end_lineno])
def _extract_raw_class_context(class_node: ast.ClassDef, module_source: str, module_tree: ast.Module) -> str:
    """Full class source (decorators included), prefixed with the imports it needs."""
    module_lines = module_source.splitlines()
    start_index = _get_class_start_line(class_node) - 1
    class_source = "\n".join(module_lines[start_index : class_node.end_lineno])
    needed_imports = extract_imports_for_class(module_tree, class_node, module_source)
    return f"{needed_imports}\n\n{class_source}" if needed_imports else class_source
def _has_non_property_method_decorator(
    fn_node: ast.FunctionDef | ast.AsyncFunctionDef, import_aliases: dict[str, str]
) -> bool:
    """True if any decorator is something other than property / *.setter / *.deleter."""
    for decorator in fn_node.decorator_list:
        if _expr_matches_name(decorator, import_aliases, "property"):
            continue
        decorator_name = _get_expr_name(decorator)
        if decorator_name is not None and decorator_name.endswith((".setter", ".deleter")):
            continue
        # Anything else (e.g. @lru_cache, @abstractmethod) changes behavior.
        return True
    return False
def _has_descriptor_like_class_fields(class_node: ast.ClassDef) -> bool:
for item in class_node.body:
if isinstance(item, ast.Assign) and isinstance(item.value, ast.Call):
return True
if isinstance(item, ast.AnnAssign) and isinstance(item.value, ast.Call):
return True
return False
def _should_use_raw_project_class_context(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> bool:
    """Decide whether to include the class source verbatim instead of a stub.

    Verbatim source is preferred for small classes with an explicit __init__,
    NamedTuples, dataclasses, decorated classes, classes with descriptor-like
    fields, and classes whose methods carry non-property decorators — cases
    where a constructor stub would drop behavior-relevant information.
    """
    start_line = _get_class_start_line(class_node)
    class_line_count = class_node.end_lineno - start_line + 1
    small_enough = (
        class_line_count <= MAX_RAW_PROJECT_CLASS_LINES
        and len(class_node.body) <= MAX_RAW_PROJECT_CLASS_BODY_ITEMS
    )
    if small_enough and _class_has_explicit_init(class_node):
        return True
    if _is_namedtuple_class(class_node, import_aliases):
        return True
    if _get_dataclass_config(class_node, import_aliases)[0]:
        return True
    if class_node.decorator_list or _has_descriptor_like_class_fields(class_node):
        return True
    return any(
        isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef))
        and _has_non_property_method_decorator(item, import_aliases)
        for item in class_node.body
    )
def extract_init_stub_from_class(class_name: str, module_source: str, module_tree: ast.Module) -> str | None:
    """Build a compact `class <name>:` snippet containing constructor context.

    Includes the explicit __init__ when present; otherwise synthesizes one for
    NamedTuple/@dataclass classes.  __post_init__ and @property methods are
    appended in either case.  Returns None when the class is absent or nothing
    useful was found.

    Note: the original block had statements from the pre-refactor version
    interleaved (a duplicate `relevant_nodes` scan and snippet loop that made
    parts of the new logic unreachable); this is the cleaned-up version.
    """
    class_node = _find_class_node_by_name(class_name, module_tree)
    if class_node is None:
        return None
    lines = module_source.splitlines()
    import_aliases = _collect_import_aliases(module_tree)
    explicit_init_nodes: list[ast.FunctionDef | ast.AsyncFunctionDef] = []
    support_nodes: list[ast.FunctionDef | ast.AsyncFunctionDef] = []
    for item in class_node.body:
        if not isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        if item.name == "__init__":
            explicit_init_nodes.append(item)
            support_nodes.append(item)
            continue
        if item.name == "__post_init__":
            support_nodes.append(item)
            continue
        # Check decorators explicitly to avoid generator overhead
        for d in item.decorator_list:
            if (isinstance(d, ast.Name) and d.id == "property") or (
                isinstance(d, ast.Attribute) and d.attr == "property"
            ):
                support_nodes.append(item)
                break
    snippets: list[str] = []
    if not explicit_init_nodes:
        # No hand-written __init__: try a synthesized one (NamedTuple/dataclass).
        synthetic_init = _build_synthetic_init_stub(class_node, module_source, import_aliases)
        if synthetic_init is not None:
            snippets.append(synthetic_init)
    for fn_node in support_nodes:
        snippets.append(_extract_function_stub_snippet(fn_node, lines))
    if not snippets:
        return None
    return f"class {class_name}:\n" + "\n".join(snippets)
def _get_module_source_and_tree(
module_path: Path, module_cache: dict[Path, tuple[str, ast.Module]]
) -> tuple[str, ast.Module] | None:
if module_path in module_cache:
return module_cache[module_path]
try:
module_source = module_path.read_text(encoding="utf-8")
module_tree = ast.parse(module_source)
except Exception:
return None
module_cache[module_path] = (module_source, module_tree)
return module_source, module_tree
def _resolve_imported_class_reference(
    base_expr_name: str,
    current_module_tree: ast.Module,
    current_module_path: Path,
    project_root_path: Path,
    module_cache: dict[Path, tuple[str, ast.Module]],
) -> tuple[str, Path] | None:
    """Resolve a (possibly dotted/aliased) class reference to (class_name, module_path).

    A bare name defined in the current module wins over import resolution.
    Otherwise the name is expanded through the module's import aliases and jedi
    is asked to locate the defining module via a synthetic import statement.
    Returns None when the reference cannot be resolved to a parsable module
    inside the project.
    """
    import jedi
    import_aliases = _collect_import_aliases(current_module_tree)
    class_name = base_expr_name.rsplit(".", 1)[-1]
    if "." not in base_expr_name and _find_class_node_by_name(class_name, current_module_tree) is not None:
        return class_name, current_module_path
    # Expand `alias` or `alias.tail` through the import table.
    resolved_name = base_expr_name
    if base_expr_name in import_aliases:
        resolved_name = import_aliases[base_expr_name]
    elif "." in base_expr_name:
        head, tail = base_expr_name.split(".", 1)
        if head in import_aliases:
            resolved_name = f"{import_aliases[head]}.{tail}"
    if "." not in resolved_name:
        # Unqualified and not defined locally: nothing left to chase.
        return None
    module_name, class_name = resolved_name.rsplit(".", 1)
    try:
        script_code = f"from {module_name} import {class_name}"
        script = jedi.Script(script_code, project=jedi.Project(path=project_root_path))
        definitions = script.goto(1, len(f"from {module_name} import ") + len(class_name), follow_imports=True)
    except Exception:
        return None
    if not definitions or definitions[0].module_path is None:
        return None
    module_path = definitions[0].module_path
    # Only project-local definitions are useful as extractable context.
    if not _is_project_path(module_path, project_root_path):
        return None
    # Ensure the module actually parses before reporting success.
    if _get_module_source_and_tree(module_path, module_cache) is None:
        return None
    return class_name, module_path
def _append_project_class_context(
    class_name: str,
    module_path: Path,
    project_root_path: Path,
    module_cache: dict[Path, tuple[str, ast.Module]],
    existing_class_names: set[str],
    emitted_classes: set[tuple[Path, str]],
    emitted_class_names: set[str],
    code_strings: list[CodeString],
) -> bool:
    """Append raw source context for a project class, base classes first (recursive).

    Returns True when the class was located (already-emitted or already-known
    classes count as success without re-emitting); False when the module cannot
    be read/parsed or the class is not defined in it.
    """
    module_result = _get_module_source_and_tree(module_path, module_cache)
    if module_result is None:
        return False
    module_source, module_tree = module_result
    class_node = _find_class_node_by_name(class_name, module_tree)
    if class_node is None:
        return False
    class_key = (module_path, class_name)
    # Deduplicate by (path, name) and skip classes the caller already has.
    if class_key in emitted_classes or class_name in existing_class_names:
        return True
    # Emit resolvable base classes before the class itself so the combined
    # context reads top-down; emitted_classes guards against base-class cycles.
    for base in class_node.bases:
        base_expr_name = _get_expr_name(base)
        if base_expr_name is None:
            continue
        resolved = _resolve_imported_class_reference(
            base_expr_name,
            module_tree,
            module_path,
            project_root_path,
            module_cache,
        )
        if resolved is None:
            continue
        base_name, base_module_path = resolved
        if base_name in existing_class_names:
            continue
        _append_project_class_context(
            base_name,
            base_module_path,
            project_root_path,
            module_cache,
            existing_class_names,
            emitted_classes,
            emitted_class_names,
            code_strings,
        )
    code_strings.append(CodeString(code=_extract_raw_class_context(class_node, module_source, module_tree), file_path=module_path))
    emitted_classes.add(class_key)
    emitted_class_names.add(class_name)
    return True
def extract_parameter_type_constructors(
function_to_optimize: FunctionToOptimize, project_root_path: Path, existing_class_names: set[str]
) -> CodeStringsMarkdown:
@ -751,35 +1146,68 @@ def extract_parameter_type_constructors(
code_strings: list[CodeString] = []
module_cache: dict[Path, tuple[str, ast.Module]] = {}
emitted_classes: set[tuple[Path, str]] = set()
emitted_class_names: set[str] = set()
for type_name in sorted(type_names):
module_name = import_map.get(type_name)
if not module_name:
continue
def append_type_context(type_name: str, module_name: str, *, transitive: bool = False) -> None:
try:
script_code = f"from {module_name} import {type_name}"
script = jedi.Script(script_code, project=jedi.Project(path=project_root_path))
definitions = script.goto(1, len(f"from {module_name} import ") + len(type_name), follow_imports=True)
if not definitions:
continue
return
module_path = definitions[0].module_path
if not module_path:
continue
return
resolved_module = module_path.resolve()
module_str = str(resolved_module)
is_project = _is_project_path(module_path, project_root_path)
is_third_party = "site-packages" in module_str
if transitive and not is_project and not is_third_party:
return
if module_path in module_cache:
mod_source, mod_tree = module_cache[module_path]
else:
mod_source = module_path.read_text(encoding="utf-8")
mod_tree = ast.parse(mod_source)
module_cache[module_path] = (mod_source, mod_tree)
module_result = _get_module_source_and_tree(module_path, module_cache)
if module_result is None:
return
mod_source, mod_tree = module_result
class_key = (module_path, type_name)
if class_key in emitted_classes or type_name in existing_class_names:
return
class_node = _find_class_node_by_name(type_name, mod_tree)
if class_node is not None and is_project:
import_aliases = _collect_import_aliases(mod_tree)
if _should_use_raw_project_class_context(class_node, import_aliases):
if _append_project_class_context(
type_name,
module_path,
project_root_path,
module_cache,
existing_class_names,
emitted_classes,
emitted_class_names,
code_strings,
):
return
stub = extract_init_stub_from_class(type_name, mod_source, mod_tree)
if stub:
code_strings.append(CodeString(code=stub, file_path=module_path))
emitted_classes.add(class_key)
emitted_class_names.add(type_name)
except Exception:
logger.debug(f"Error extracting constructor stub for {type_name} from {module_name}")
if transitive:
logger.debug(f"Error extracting transitive constructor stub for {type_name} from {module_name}")
else:
logger.debug(f"Error extracting constructor stub for {type_name} from {module_name}")
for type_name in sorted(type_names):
module_name = import_map.get(type_name)
if not module_name:
continue
append_type_context(type_name, module_name)
# Transitive extraction (one level): for each extracted stub, find __init__ param types and extract their stubs
# Build an extended import map that includes imports from source modules of already-extracted stubs
@ -792,13 +1220,14 @@ def extract_parameter_type_constructors(
if name not in transitive_import_map:
transitive_import_map[name] = cache_node.module
emitted_names = type_names | existing_class_names | BUILTIN_AND_TYPING_NAMES
emitted_names = type_names | existing_class_names | emitted_class_names | BUILTIN_AND_TYPING_NAMES
transitive_type_names: set[str] = set()
for cs in code_strings:
try:
stub_tree = ast.parse(cs.code)
except SyntaxError:
continue
import_aliases = _collect_import_aliases(stub_tree)
for stub_node in ast.walk(stub_tree):
if isinstance(stub_node, (ast.FunctionDef, ast.AsyncFunctionDef)) and stub_node.name in (
"__init__",
@ -806,32 +1235,14 @@ def extract_parameter_type_constructors(
):
for arg in stub_node.args.args + stub_node.args.posonlyargs + stub_node.args.kwonlyargs:
transitive_type_names |= collect_type_names_from_annotation(arg.annotation)
elif isinstance(stub_node, ast.ClassDef):
transitive_type_names |= _collect_synthetic_constructor_type_names(stub_node, import_aliases)
transitive_type_names -= emitted_names
for type_name in sorted(transitive_type_names):
module_name = transitive_import_map.get(type_name)
if not module_name:
continue
try:
script_code = f"from {module_name} import {type_name}"
script = jedi.Script(script_code, project=jedi.Project(path=project_root_path))
definitions = script.goto(1, len(f"from {module_name} import ") + len(type_name), follow_imports=True)
if not definitions:
continue
module_path = definitions[0].module_path
if not module_path:
continue
if module_path in module_cache:
mod_source, mod_tree = module_cache[module_path]
else:
mod_source = module_path.read_text(encoding="utf-8")
mod_tree = ast.parse(mod_source)
module_cache[module_path] = (mod_source, mod_tree)
stub = extract_init_stub_from_class(type_name, mod_source, mod_tree)
if stub:
code_strings.append(CodeString(code=stub, file_path=module_path))
except Exception:
logger.debug(f"Error extracting transitive constructor stub for {type_name} from {module_name}")
continue
append_type_context(type_name, module_name, transitive=True)
return CodeStringsMarkdown(code_strings=code_strings)

View file

@ -4501,6 +4501,104 @@ def process(w: Widget) -> str:
assert "size" in code
def test_extract_parameter_type_constructors_stdlib_type(tmp_path: Path) -> None:
    """Stdlib parameter types (argparse.Namespace) produce a compact __init__ stub."""
    pkg = tmp_path / "mypkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("", encoding="utf-8")
    # `def process` must land on line 3 to match the FunctionToOptimize fixture below.
    (pkg / "processor.py").write_text(
        """from argparse import Namespace

def process(ns: Namespace) -> str:
    return str(ns)
""",
        encoding="utf-8",
    )
    fto = FunctionToOptimize(
        function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
    )
    result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
    assert len(result.code_strings) == 1
    code = result.code_strings[0].code
    assert "class Namespace:" in code
    assert "def __init__(self, **kwargs):" in code
def test_extract_parameter_type_constructors_namedtuple_project_type(tmp_path: Path) -> None:
    """Project NamedTuple parameter types are emitted as raw class source with fields."""
    pkg = tmp_path / "mypkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("", encoding="utf-8")
    (pkg / "models.py").write_text(
        """from pathlib import Path
from typing import NamedTuple

class FunctionNode(NamedTuple):
    file_path: Path
    qualified_name: str
""",
        encoding="utf-8",
    )
    # `def process` must land on line 3 to match the FunctionToOptimize fixture below.
    (pkg / "processor.py").write_text(
        """from mypkg.models import FunctionNode

def process(node: FunctionNode) -> str:
    return node.qualified_name
""",
        encoding="utf-8",
    )
    fto = FunctionToOptimize(
        function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
    )
    result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
    assert len(result.code_strings) == 1
    code = result.code_strings[0].code
    assert "class FunctionNode(NamedTuple):" in code
    assert "file_path: Path" in code
    assert "qualified_name: str" in code
def test_extract_parameter_type_constructors_uses_raw_project_context_for_small_class(tmp_path: Path) -> None:
    """Small decorated project classes are emitted verbatim (imports, decorator, methods)."""
    pkg = tmp_path / "mypkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("", encoding="utf-8")
    (pkg / "models.py").write_text(
        """from functools import total_ordering

@total_ordering
class Rank:
    def __init__(self, value: int):
        self.value = value

    def __lt__(self, other: "Rank") -> bool:
        return self.value < other.value

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Rank) and self.value == other.value
""",
        encoding="utf-8",
    )
    # `def process` must land on line 3 to match the FunctionToOptimize fixture below.
    (pkg / "processor.py").write_text(
        """from mypkg.models import Rank

def process(rank: Rank) -> int:
    return rank.value
""",
        encoding="utf-8",
    )
    fto = FunctionToOptimize(
        function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
    )
    result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
    assert len(result.code_strings) == 1
    code = result.code_strings[0].code
    assert "from functools import total_ordering" in code
    assert "@total_ordering" in code
    assert "def __lt__" in code
    assert "def __eq__" in code
def test_extract_parameter_type_constructors_excludes_builtins(tmp_path: Path) -> None:
pkg = tmp_path / "mypkg"
pkg.mkdir()
@ -4788,6 +4886,48 @@ def test_extract_parameter_type_constructors_base_classes(tmp_path: Path) -> Non
assert "class BaseProcessor:" in result.code_strings[0].code
def test_extract_parameter_type_constructors_attribute_base_prefers_imported_project_class(tmp_path: Path) -> None:
    """A dotted base (ext.Base) resolves to the imported module's class, not a
    same-named local class."""
    pkg = tmp_path / "mypkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("", encoding="utf-8")
    (pkg / "external.py").write_text(
        """class Base:
    def __init__(self, x: int):
        self.x = x
""",
        encoding="utf-8",
    )
    (pkg / "models.py").write_text(
        """import mypkg.external as ext

class Base:
    pass

class Child(ext.Base):
    def __init__(self, x: int):
        super().__init__(x)
""",
        encoding="utf-8",
    )
    # `def process` must land on line 3 to match the FunctionToOptimize fixture below.
    (pkg / "processor.py").write_text(
        """from mypkg.models import Child

def process(c: Child) -> int:
    return c.x
""",
        encoding="utf-8",
    )
    fto = FunctionToOptimize(
        function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
    )
    result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
    combined = "\n".join(cs.code for cs in result.code_strings)
    assert "class Child(ext.Base):" in combined
    assert "self.x = x" in combined
    # The local placeholder Base must not be emitted.
    assert "class Base:\n    pass" not in combined
def test_extract_parameter_type_constructors_isinstance_builtins_excluded(tmp_path: Path) -> None:
"""Isinstance with builtins (int, str, etc.) should not produce stubs."""
pkg = tmp_path / "mypkg"
@ -4828,6 +4968,51 @@ def test_extract_parameter_type_constructors_transitive(tmp_path: Path) -> None:
assert "class Config:" in combined
def test_extract_parameter_type_constructors_uses_raw_project_context_for_dataclass_inheritance(tmp_path: Path) -> None:
    """Dataclass inheritance emits raw source for both base and child classes."""
    pkg = tmp_path / "mypkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("", encoding="utf-8")
    (pkg / "base.py").write_text(
        """from dataclasses import dataclass
from pathlib import Path

@dataclass
class BaseConfig:
    file_path: Path
""",
        encoding="utf-8",
    )
    (pkg / "models.py").write_text(
        """from dataclasses import dataclass
from mypkg.base import BaseConfig

@dataclass
class ChildConfig(BaseConfig):
    qualified_name: str
""",
        encoding="utf-8",
    )
    # `def process` must land on line 3 to match the FunctionToOptimize fixture below.
    (pkg / "processor.py").write_text(
        """from mypkg.models import ChildConfig

def process(cfg: ChildConfig) -> str:
    return cfg.qualified_name
""",
        encoding="utf-8",
    )
    fto = FunctionToOptimize(
        function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
    )
    result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
    combined = "\n".join(cs.code for cs in result.code_strings)
    assert "@dataclass" in combined
    assert "class BaseConfig" in combined
    assert "file_path: Path" in combined
    assert "class ChildConfig(BaseConfig):" in combined
    assert "qualified_name: str" in combined
def test_enrich_testgen_context_third_party_uses_stubs(tmp_path: Path) -> None:
"""Third-party classes should produce compact __init__ stubs, not full class source."""
# Use a real third-party package (pydantic) so jedi can actually resolve it