mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Improve testgen constructor context extraction
This commit is contained in:
parent
cee12fe430
commit
282f2ba713
2 changed files with 661 additions and 65 deletions
|
|
@ -621,55 +621,450 @@ def collect_type_names_from_annotation(node: ast.expr | None) -> set[str]:
|
|||
return set()
|
||||
|
||||
|
||||
def extract_init_stub_from_class(class_name: str, module_source: str, module_tree: ast.Module) -> str | None:
|
||||
class_node = None
|
||||
# Raw (full-source) context for a project class is only emitted when the class is
# small enough to paste wholesale; larger classes fall back to compact stubs.
MAX_RAW_PROJECT_CLASS_BODY_ITEMS = 8
MAX_RAW_PROJECT_CLASS_LINES = 40
|
||||
|
||||
|
||||
def _get_expr_name(node: ast.AST | None) -> str | None:
|
||||
if node is None:
|
||||
return None
|
||||
if isinstance(node, ast.Name):
|
||||
return node.id
|
||||
if isinstance(node, ast.Attribute):
|
||||
parent_name = _get_expr_name(node.value)
|
||||
return node.attr if parent_name is None else f"{parent_name}.{node.attr}"
|
||||
if isinstance(node, ast.Call):
|
||||
return _get_expr_name(node.func)
|
||||
return None
|
||||
|
||||
|
||||
def _collect_import_aliases(module_tree: ast.Module) -> dict[str, str]:
|
||||
aliases: dict[str, str] = {}
|
||||
for node in module_tree.body:
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
bound_name = alias.asname if alias.asname else alias.name.split(".")[0]
|
||||
aliases[bound_name] = alias.name
|
||||
elif isinstance(node, ast.ImportFrom) and node.module:
|
||||
for alias in node.names:
|
||||
bound_name = alias.asname if alias.asname else alias.name
|
||||
aliases[bound_name] = f"{node.module}.{alias.name}"
|
||||
return aliases
|
||||
|
||||
|
||||
def _find_class_node_by_name(class_name: str, module_tree: ast.Module) -> ast.ClassDef | None:
|
||||
# Use a deque-based BFS to find the first matching ClassDef (preserves ast.walk order)
|
||||
q: deque[ast.AST] = deque([module_tree])
|
||||
while q:
|
||||
candidate = q.popleft()
|
||||
if isinstance(candidate, ast.ClassDef) and candidate.name == class_name:
|
||||
class_node = candidate
|
||||
break
|
||||
return candidate
|
||||
q.extend(ast.iter_child_nodes(candidate))
|
||||
return None
|
||||
|
||||
|
||||
def _expr_matches_name(node: ast.AST | None, import_aliases: dict[str, str], suffix: str) -> bool:
    """True when the expression names *suffix* directly, dotted, or via an import alias."""
    name = _get_expr_name(node)
    if name is None:
        return False
    if name == suffix or name.endswith(f".{suffix}"):
        return True
    resolved = import_aliases.get(name)
    if resolved is None:
        return False
    return resolved == suffix or resolved.endswith(f".{suffix}")
|
||||
|
||||
|
||||
def _get_node_source(node: ast.AST | None, module_source: str, fallback: str = "...") -> str:
|
||||
if node is None:
|
||||
return fallback
|
||||
source_segment = ast.get_source_segment(module_source, node)
|
||||
if source_segment is not None:
|
||||
return source_segment
|
||||
try:
|
||||
return ast.unparse(node)
|
||||
except Exception:
|
||||
return fallback
|
||||
|
||||
|
||||
def _bool_literal(node: ast.AST) -> bool | None:
|
||||
if isinstance(node, ast.Constant) and isinstance(node.value, bool):
|
||||
return node.value
|
||||
return None
|
||||
|
||||
|
||||
def _is_namedtuple_class(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> bool:
    """True when any base class resolves to NamedTuple."""
    for base in class_node.bases:
        if _expr_matches_name(base, import_aliases, "NamedTuple"):
            return True
    return False
|
||||
|
||||
|
||||
def _get_dataclass_config(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> tuple[bool, bool, bool]:
    """Inspect decorators for @dataclass; return (is_dataclass, init_enabled, kw_only).

    Only literal True/False keyword values are honoured; dynamic values keep the
    dataclass defaults (init=True, kw_only=False).
    """
    for decorator in class_node.decorator_list:
        if not _expr_matches_name(decorator, import_aliases, "dataclass"):
            continue
        settings = {"init": True, "kw_only": False}
        if isinstance(decorator, ast.Call):
            for keyword in decorator.keywords:
                if keyword.arg not in settings:
                    continue
                literal = _bool_literal(keyword.value)
                if literal is not None:
                    settings[keyword.arg] = literal
        return True, settings["init"], settings["kw_only"]
    return False, False, False
|
||||
|
||||
|
||||
def _is_classvar_annotation(annotation: ast.expr, import_aliases: dict[str, str]) -> bool:
    """True for ClassVar or ClassVar[...] annotations (excluded from synthesized __init__)."""
    if isinstance(annotation, ast.Subscript):
        annotation = annotation.value
    return _expr_matches_name(annotation, import_aliases, "ClassVar")
|
||||
|
||||
|
||||
def _is_project_path(module_path: Path, project_root_path: Path) -> bool:
|
||||
return str(module_path.resolve()).startswith(str(project_root_path.resolve()) + os.sep)
|
||||
|
||||
|
||||
def _get_class_start_line(class_node: ast.ClassDef) -> int:
|
||||
start_line = class_node.lineno
|
||||
if class_node.decorator_list:
|
||||
for decorator in class_node.decorator_list:
|
||||
start_line = min(start_line, decorator.lineno)
|
||||
return start_line
|
||||
|
||||
|
||||
def _class_has_explicit_init(class_node: ast.ClassDef) -> bool:
|
||||
return any(isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)) and item.name == "__init__" for item in class_node.body)
|
||||
|
||||
|
||||
def _collect_synthetic_constructor_type_names(
    class_node: ast.ClassDef, import_aliases: dict[str, str]
) -> set[str]:
    """Collect type names referenced by the synthesized constructor of a NamedTuple/dataclass.

    Returns an empty set for classes that are neither, or for dataclasses declared with
    init=False. ClassVar annotations and fields marked field(init=False) are skipped.
    """
    is_dataclass, dataclass_init_enabled, _ = _get_dataclass_config(class_node, import_aliases)
    if not _is_namedtuple_class(class_node, import_aliases) and not is_dataclass:
        return set()
    if is_dataclass and not dataclass_init_enabled:
        return set()

    names: set[str] = set()
    for item in class_node.body:
        if not isinstance(item, ast.AnnAssign) or not isinstance(item.target, ast.Name) or item.annotation is None:
            continue
        if _is_classvar_annotation(item.annotation, import_aliases):
            continue

        include_in_init = True
        # A field(...) default can opt the attribute out of __init__ via init=False.
        if isinstance(item.value, ast.Call) and _expr_matches_name(item.value.func, import_aliases, "field"):
            for keyword in item.value.keywords:
                if keyword.arg != "init":
                    continue
                literal_value = _bool_literal(keyword.value)
                if literal_value is not None:
                    include_in_init = literal_value
                break

        if include_in_init:
            names |= collect_type_names_from_annotation(item.annotation)

    return names
|
||||
|
||||
|
||||
def _extract_synthetic_init_parameters(
    class_node: ast.ClassDef, module_source: str, import_aliases: dict[str, str], *, kw_only_by_default: bool
) -> list[tuple[str, str, str | None, bool]]:
    """Derive (name, annotation_source, default_source, kw_only) for each synthesized __init__ parameter.

    ClassVar fields and field(init=False) fields are excluded; field(default=...) and
    field(default_factory=...) make the parameter optional.
    """
    extracted: list[tuple[str, str, str | None, bool]] = []
    for item in class_node.body:
        if not isinstance(item, ast.AnnAssign) or not isinstance(item.target, ast.Name):
            continue
        if _is_classvar_annotation(item.annotation, import_aliases):
            continue

        include_in_init = True
        kw_only = kw_only_by_default
        default_value: str | None = None
        if item.value is not None:
            is_field_call = isinstance(item.value, ast.Call) and _expr_matches_name(
                item.value.func, import_aliases, "field"
            )
            if is_field_call:
                for keyword in item.value.keywords:
                    if keyword.arg == "init":
                        literal = _bool_literal(keyword.value)
                        if literal is not None:
                            include_in_init = literal
                    elif keyword.arg == "kw_only":
                        literal = _bool_literal(keyword.value)
                        if literal is not None:
                            kw_only = literal
                    elif keyword.arg == "default":
                        default_value = _get_node_source(keyword.value, module_source)
                    elif keyword.arg == "default_factory":
                        # Default factories still imply an optional constructor parameter, but
                        # the generated __init__ does not use the field() call directly.
                        default_value = "..."
            else:
                default_value = _get_node_source(item.value, module_source)

        if include_in_init:
            extracted.append(
                (item.target.id, _get_node_source(item.annotation, module_source, "Any"), default_value, kw_only)
            )

    return extracted
|
||||
|
||||
|
||||
def _build_synthetic_init_stub(
    class_node: ast.ClassDef, module_source: str, import_aliases: dict[str, str]
) -> str | None:
    """Render a stub __init__ for a NamedTuple/dataclass that has no explicit one.

    Returns None when the class is neither kind, when @dataclass(init=False) is set,
    or when no constructor parameters can be derived.
    """
    is_namedtuple = _is_namedtuple_class(class_node, import_aliases)
    is_dataclass, dataclass_init_enabled, dataclass_kw_only = _get_dataclass_config(class_node, import_aliases)
    if not (is_namedtuple or is_dataclass):
        return None
    if is_dataclass and not dataclass_init_enabled:
        return None

    parameters = _extract_synthetic_init_parameters(
        class_node, module_source, import_aliases, kw_only_by_default=dataclass_kw_only
    )
    if not parameters:
        return None

    rendered = ["self"]
    marker_emitted = False
    for param_name, annotation_source, default_source, kw_only in parameters:
        # A single bare-star marker precedes the first keyword-only parameter.
        if kw_only and not marker_emitted:
            rendered.append("*")
            marker_emitted = True
        piece = f"{param_name}: {annotation_source}"
        if default_source is not None:
            piece = f"{piece} = {default_source}"
        rendered.append(piece)

    return f"    def __init__({', '.join(rendered)}):\n        ..."
|
||||
|
||||
|
||||
def _extract_function_stub_snippet(
|
||||
fn_node: ast.FunctionDef | ast.AsyncFunctionDef, module_lines: list[str]
|
||||
) -> str:
|
||||
start_line = fn_node.lineno
|
||||
if fn_node.decorator_list:
|
||||
for decorator in fn_node.decorator_list:
|
||||
start_line = min(start_line, decorator.lineno)
|
||||
return "\n".join(module_lines[start_line - 1 : fn_node.end_lineno])
|
||||
|
||||
|
||||
def _extract_raw_class_context(class_node: ast.ClassDef, module_source: str, module_tree: ast.Module) -> str:
    """Full class source (decorators included), prefixed with the imports the class needs."""
    start = _get_class_start_line(class_node) - 1
    class_source = "\n".join(module_source.splitlines()[start : class_node.end_lineno])
    needed_imports = extract_imports_for_class(module_tree, class_node, module_source)
    return f"{needed_imports}\n\n{class_source}" if needed_imports else class_source
|
||||
|
||||
|
||||
def _has_non_property_method_decorator(
    fn_node: ast.FunctionDef | ast.AsyncFunctionDef, import_aliases: dict[str, str]
) -> bool:
    """True when any decorator is something other than @property / @x.setter / @x.deleter."""
    for decorator in fn_node.decorator_list:
        if _expr_matches_name(decorator, import_aliases, "property"):
            continue
        decorator_name = _get_expr_name(decorator)
        if decorator_name is not None and decorator_name.endswith((".setter", ".deleter")):
            continue
        # Anything else (cached_property, custom wrappers, ...) counts.
        return True
    return False
|
||||
|
||||
|
||||
def _has_descriptor_like_class_fields(class_node: ast.ClassDef) -> bool:
|
||||
for item in class_node.body:
|
||||
if isinstance(item, ast.Assign) and isinstance(item.value, ast.Call):
|
||||
return True
|
||||
if isinstance(item, ast.AnnAssign) and isinstance(item.value, ast.Call):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _should_use_raw_project_class_context(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> bool:
    """Decide whether to emit the class's raw source instead of a compact stub.

    Raw context is chosen for: small classes with an explicit __init__, NamedTuples,
    dataclasses, decorated classes, classes with descriptor-like fields, and classes
    whose methods carry non-property decorators.
    """
    line_count = class_node.end_lineno - _get_class_start_line(class_node) + 1
    small_enough = (
        line_count <= MAX_RAW_PROJECT_CLASS_LINES and len(class_node.body) <= MAX_RAW_PROJECT_CLASS_BODY_ITEMS
    )

    if small_enough and _class_has_explicit_init(class_node):
        return True
    if _is_namedtuple_class(class_node, import_aliases):
        return True
    if _get_dataclass_config(class_node, import_aliases)[0]:
        return True
    if class_node.decorator_list:
        return True
    if _has_descriptor_like_class_fields(class_node):
        return True

    return any(
        isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef))
        and _has_non_property_method_decorator(item, import_aliases)
        for item in class_node.body
    )
|
||||
|
||||
|
||||
def extract_init_stub_from_class(class_name: str, module_source: str, module_tree: ast.Module) -> str | None:
    """Build a compact class stub containing constructor-relevant members only.

    The stub includes any explicit __init__, __post_init__, and @property methods.
    When no explicit __init__ exists, a synthetic one is generated for NamedTuple
    and dataclass classes. Returns None when the class is absent or nothing
    constructor-relevant can be extracted.
    """
    class_node = _find_class_node_by_name(class_name, module_tree)
    if class_node is None:
        return None

    lines = module_source.splitlines()
    import_aliases = _collect_import_aliases(module_tree)
    explicit_init_nodes: list[ast.FunctionDef | ast.AsyncFunctionDef] = []
    support_nodes: list[ast.FunctionDef | ast.AsyncFunctionDef] = []
    for item in class_node.body:
        if not isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        if item.name == "__init__":
            explicit_init_nodes.append(item)
            support_nodes.append(item)
            continue
        if item.name == "__post_init__":
            support_nodes.append(item)
            continue
        # Check decorators explicitly to avoid generator overhead.
        for d in item.decorator_list:
            if (isinstance(d, ast.Name) and d.id == "property") or (
                isinstance(d, ast.Attribute) and d.attr == "property"
            ):
                support_nodes.append(item)
                break

    snippets: list[str] = []
    if explicit_init_nodes:
        for fn_node in support_nodes:
            snippets.append(_extract_function_stub_snippet(fn_node, lines))
    else:
        # No explicit __init__: synthesize one for NamedTuple/dataclass classes.
        synthetic_init = _build_synthetic_init_stub(class_node, module_source, import_aliases)
        if synthetic_init is not None:
            snippets.append(synthetic_init)
        for fn_node in support_nodes:
            snippets.append(_extract_function_stub_snippet(fn_node, lines))

    if not snippets:
        return None

    return f"class {class_name}:\n" + "\n".join(snippets)
|
||||
|
||||
|
||||
def _get_module_source_and_tree(
|
||||
module_path: Path, module_cache: dict[Path, tuple[str, ast.Module]]
|
||||
) -> tuple[str, ast.Module] | None:
|
||||
if module_path in module_cache:
|
||||
return module_cache[module_path]
|
||||
try:
|
||||
module_source = module_path.read_text(encoding="utf-8")
|
||||
module_tree = ast.parse(module_source)
|
||||
except Exception:
|
||||
return None
|
||||
module_cache[module_path] = (module_source, module_tree)
|
||||
return module_source, module_tree
|
||||
|
||||
|
||||
def _resolve_imported_class_reference(
    base_expr_name: str,
    current_module_tree: ast.Module,
    current_module_path: Path,
    project_root_path: Path,
    module_cache: dict[Path, tuple[str, ast.Module]],
) -> tuple[str, Path] | None:
    """Resolve a (possibly dotted/aliased) class reference to (class_name, defining_module_path).

    A plain name defined in the current module wins outright; otherwise the name is
    expanded through import aliases and located with jedi. Only project-local
    definitions that can be read and parsed are returned.
    """
    import jedi  # deferred: only needed on this path

    import_aliases = _collect_import_aliases(current_module_tree)
    class_name = base_expr_name.rsplit(".", 1)[-1]
    if "." not in base_expr_name and _find_class_node_by_name(class_name, current_module_tree) is not None:
        return class_name, current_module_path

    resolved_name = base_expr_name
    if base_expr_name in import_aliases:
        resolved_name = import_aliases[base_expr_name]
    elif "." in base_expr_name:
        # `alias.Attr` style: expand only the leading segment.
        head, tail = base_expr_name.split(".", 1)
        if head in import_aliases:
            resolved_name = f"{import_aliases[head]}.{tail}"

    if "." not in resolved_name:
        return None

    module_name, class_name = resolved_name.rsplit(".", 1)
    try:
        prefix = f"from {module_name} import "
        script = jedi.Script(f"{prefix}{class_name}", project=jedi.Project(path=project_root_path))
        definitions = script.goto(1, len(prefix) + len(class_name), follow_imports=True)
    except Exception:
        return None

    if not definitions or definitions[0].module_path is None:
        return None
    module_path = definitions[0].module_path
    if not _is_project_path(module_path, project_root_path):
        return None
    if _get_module_source_and_tree(module_path, module_cache) is None:
        return None
    return class_name, module_path
|
||||
|
||||
|
||||
def _append_project_class_context(
    class_name: str,
    module_path: Path,
    project_root_path: Path,
    module_cache: dict[Path, tuple[str, ast.Module]],
    existing_class_names: set[str],
    emitted_classes: set[tuple[Path, str]],
    emitted_class_names: set[str],
    code_strings: list[CodeString],
) -> bool:
    """Append raw source context for a project class and, recursively, its project base classes.

    Returns True when the class was found (an already-emitted class counts as success);
    False when the module cannot be read/parsed or the class is absent from it.
    """
    module_result = _get_module_source_and_tree(module_path, module_cache)
    if module_result is None:
        return False
    module_source, module_tree = module_result
    class_node = _find_class_node_by_name(class_name, module_tree)
    if class_node is None:
        return False

    class_key = (module_path, class_name)
    if class_key in emitted_classes or class_name in existing_class_names:
        return True

    # Emit resolvable project base classes first so the generated context reads top-down.
    for base in class_node.bases:
        base_expr_name = _get_expr_name(base)
        if base_expr_name is None:
            continue
        resolved = _resolve_imported_class_reference(
            base_expr_name,
            module_tree,
            module_path,
            project_root_path,
            module_cache,
        )
        if resolved is None:
            continue
        base_name, base_module_path = resolved
        if base_name in existing_class_names:
            continue
        _append_project_class_context(
            base_name,
            base_module_path,
            project_root_path,
            module_cache,
            existing_class_names,
            emitted_classes,
            emitted_class_names,
            code_strings,
        )

    code_strings.append(
        CodeString(code=_extract_raw_class_context(class_node, module_source, module_tree), file_path=module_path)
    )
    emitted_classes.add(class_key)
    emitted_class_names.add(class_name)
    return True
|
||||
|
||||
|
||||
def extract_parameter_type_constructors(
|
||||
function_to_optimize: FunctionToOptimize, project_root_path: Path, existing_class_names: set[str]
|
||||
) -> CodeStringsMarkdown:
|
||||
|
|
@ -751,35 +1146,68 @@ def extract_parameter_type_constructors(
|
|||
|
||||
code_strings: list[CodeString] = []
|
||||
module_cache: dict[Path, tuple[str, ast.Module]] = {}
|
||||
emitted_classes: set[tuple[Path, str]] = set()
|
||||
emitted_class_names: set[str] = set()
|
||||
|
||||
for type_name in sorted(type_names):
|
||||
module_name = import_map.get(type_name)
|
||||
if not module_name:
|
||||
continue
|
||||
def append_type_context(type_name: str, module_name: str, *, transitive: bool = False) -> None:
|
||||
try:
|
||||
script_code = f"from {module_name} import {type_name}"
|
||||
script = jedi.Script(script_code, project=jedi.Project(path=project_root_path))
|
||||
definitions = script.goto(1, len(f"from {module_name} import ") + len(type_name), follow_imports=True)
|
||||
if not definitions:
|
||||
continue
|
||||
return
|
||||
|
||||
module_path = definitions[0].module_path
|
||||
if not module_path:
|
||||
continue
|
||||
return
|
||||
resolved_module = module_path.resolve()
|
||||
module_str = str(resolved_module)
|
||||
is_project = _is_project_path(module_path, project_root_path)
|
||||
is_third_party = "site-packages" in module_str
|
||||
if transitive and not is_project and not is_third_party:
|
||||
return
|
||||
|
||||
if module_path in module_cache:
|
||||
mod_source, mod_tree = module_cache[module_path]
|
||||
else:
|
||||
mod_source = module_path.read_text(encoding="utf-8")
|
||||
mod_tree = ast.parse(mod_source)
|
||||
module_cache[module_path] = (mod_source, mod_tree)
|
||||
module_result = _get_module_source_and_tree(module_path, module_cache)
|
||||
if module_result is None:
|
||||
return
|
||||
mod_source, mod_tree = module_result
|
||||
|
||||
class_key = (module_path, type_name)
|
||||
if class_key in emitted_classes or type_name in existing_class_names:
|
||||
return
|
||||
|
||||
class_node = _find_class_node_by_name(type_name, mod_tree)
|
||||
if class_node is not None and is_project:
|
||||
import_aliases = _collect_import_aliases(mod_tree)
|
||||
if _should_use_raw_project_class_context(class_node, import_aliases):
|
||||
if _append_project_class_context(
|
||||
type_name,
|
||||
module_path,
|
||||
project_root_path,
|
||||
module_cache,
|
||||
existing_class_names,
|
||||
emitted_classes,
|
||||
emitted_class_names,
|
||||
code_strings,
|
||||
):
|
||||
return
|
||||
|
||||
stub = extract_init_stub_from_class(type_name, mod_source, mod_tree)
|
||||
if stub:
|
||||
code_strings.append(CodeString(code=stub, file_path=module_path))
|
||||
emitted_classes.add(class_key)
|
||||
emitted_class_names.add(type_name)
|
||||
except Exception:
|
||||
logger.debug(f"Error extracting constructor stub for {type_name} from {module_name}")
|
||||
if transitive:
|
||||
logger.debug(f"Error extracting transitive constructor stub for {type_name} from {module_name}")
|
||||
else:
|
||||
logger.debug(f"Error extracting constructor stub for {type_name} from {module_name}")
|
||||
|
||||
for type_name in sorted(type_names):
|
||||
module_name = import_map.get(type_name)
|
||||
if not module_name:
|
||||
continue
|
||||
append_type_context(type_name, module_name)
|
||||
|
||||
# Transitive extraction (one level): for each extracted stub, find __init__ param types and extract their stubs
|
||||
# Build an extended import map that includes imports from source modules of already-extracted stubs
|
||||
|
|
@ -792,13 +1220,14 @@ def extract_parameter_type_constructors(
|
|||
if name not in transitive_import_map:
|
||||
transitive_import_map[name] = cache_node.module
|
||||
|
||||
emitted_names = type_names | existing_class_names | BUILTIN_AND_TYPING_NAMES
|
||||
emitted_names = type_names | existing_class_names | emitted_class_names | BUILTIN_AND_TYPING_NAMES
|
||||
transitive_type_names: set[str] = set()
|
||||
for cs in code_strings:
|
||||
try:
|
||||
stub_tree = ast.parse(cs.code)
|
||||
except SyntaxError:
|
||||
continue
|
||||
import_aliases = _collect_import_aliases(stub_tree)
|
||||
for stub_node in ast.walk(stub_tree):
|
||||
if isinstance(stub_node, (ast.FunctionDef, ast.AsyncFunctionDef)) and stub_node.name in (
|
||||
"__init__",
|
||||
|
|
@ -806,32 +1235,14 @@ def extract_parameter_type_constructors(
|
|||
):
|
||||
for arg in stub_node.args.args + stub_node.args.posonlyargs + stub_node.args.kwonlyargs:
|
||||
transitive_type_names |= collect_type_names_from_annotation(arg.annotation)
|
||||
elif isinstance(stub_node, ast.ClassDef):
|
||||
transitive_type_names |= _collect_synthetic_constructor_type_names(stub_node, import_aliases)
|
||||
transitive_type_names -= emitted_names
|
||||
for type_name in sorted(transitive_type_names):
|
||||
module_name = transitive_import_map.get(type_name)
|
||||
if not module_name:
|
||||
continue
|
||||
try:
|
||||
script_code = f"from {module_name} import {type_name}"
|
||||
script = jedi.Script(script_code, project=jedi.Project(path=project_root_path))
|
||||
definitions = script.goto(1, len(f"from {module_name} import ") + len(type_name), follow_imports=True)
|
||||
if not definitions:
|
||||
continue
|
||||
module_path = definitions[0].module_path
|
||||
if not module_path:
|
||||
continue
|
||||
if module_path in module_cache:
|
||||
mod_source, mod_tree = module_cache[module_path]
|
||||
else:
|
||||
mod_source = module_path.read_text(encoding="utf-8")
|
||||
mod_tree = ast.parse(mod_source)
|
||||
module_cache[module_path] = (mod_source, mod_tree)
|
||||
stub = extract_init_stub_from_class(type_name, mod_source, mod_tree)
|
||||
if stub:
|
||||
code_strings.append(CodeString(code=stub, file_path=module_path))
|
||||
except Exception:
|
||||
logger.debug(f"Error extracting transitive constructor stub for {type_name} from {module_name}")
|
||||
continue
|
||||
append_type_context(type_name, module_name, transitive=True)
|
||||
|
||||
return CodeStringsMarkdown(code_strings=code_strings)
|
||||
|
||||
|
|
|
|||
|
|
@ -4501,6 +4501,104 @@ def process(w: Widget) -> str:
|
|||
assert "size" in code
|
||||
|
||||
|
||||
def test_extract_parameter_type_constructors_stdlib_type(tmp_path: Path) -> None:
    """A stdlib parameter type (argparse.Namespace) yields a compact __init__ stub."""
    pkg = tmp_path / "mypkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("", encoding="utf-8")
    (pkg / "processor.py").write_text(
        """from argparse import Namespace

def process(ns: Namespace) -> str:
    return str(ns)
""",
        encoding="utf-8",
    )

    fto = FunctionToOptimize(
        function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
    )
    result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
    assert len(result.code_strings) == 1
    code = result.code_strings[0].code
    assert "class Namespace:" in code
    assert "def __init__(self, **kwargs):" in code
|
||||
|
||||
|
||||
def test_extract_parameter_type_constructors_namedtuple_project_type(tmp_path: Path) -> None:
    """A project NamedTuple parameter type is emitted with its raw class source."""
    pkg = tmp_path / "mypkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("", encoding="utf-8")
    (pkg / "models.py").write_text(
        """from pathlib import Path
from typing import NamedTuple

class FunctionNode(NamedTuple):
    file_path: Path
    qualified_name: str
""",
        encoding="utf-8",
    )
    (pkg / "processor.py").write_text(
        """from mypkg.models import FunctionNode

def process(node: FunctionNode) -> str:
    return node.qualified_name
""",
        encoding="utf-8",
    )

    fto = FunctionToOptimize(
        function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
    )
    result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
    assert len(result.code_strings) == 1
    code = result.code_strings[0].code
    assert "class FunctionNode(NamedTuple):" in code
    assert "file_path: Path" in code
    assert "qualified_name: str" in code
|
||||
|
||||
|
||||
def test_extract_parameter_type_constructors_uses_raw_project_context_for_small_class(tmp_path: Path) -> None:
    """A small decorated project class with an explicit __init__ is emitted as raw source."""
    pkg = tmp_path / "mypkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("", encoding="utf-8")
    (pkg / "models.py").write_text(
        """from functools import total_ordering

@total_ordering
class Rank:
    def __init__(self, value: int):
        self.value = value

    def __lt__(self, other: "Rank") -> bool:
        return self.value < other.value

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Rank) and self.value == other.value
""",
        encoding="utf-8",
    )
    (pkg / "processor.py").write_text(
        """from mypkg.models import Rank

def process(rank: Rank) -> int:
    return rank.value
""",
        encoding="utf-8",
    )

    fto = FunctionToOptimize(
        function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
    )
    result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
    assert len(result.code_strings) == 1
    code = result.code_strings[0].code
    assert "from functools import total_ordering" in code
    assert "@total_ordering" in code
    assert "def __lt__" in code
    assert "def __eq__" in code
|
||||
|
||||
|
||||
def test_extract_parameter_type_constructors_excludes_builtins(tmp_path: Path) -> None:
|
||||
pkg = tmp_path / "mypkg"
|
||||
pkg.mkdir()
|
||||
|
|
@ -4788,6 +4886,48 @@ def test_extract_parameter_type_constructors_base_classes(tmp_path: Path) -> Non
|
|||
assert "class BaseProcessor:" in result.code_strings[0].code
|
||||
|
||||
|
||||
def test_extract_parameter_type_constructors_attribute_base_prefers_imported_project_class(tmp_path: Path) -> None:
    """An attribute base (ext.Base) resolves to the imported module's class, not a same-named local."""
    pkg = tmp_path / "mypkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("", encoding="utf-8")
    (pkg / "external.py").write_text(
        """class Base:
    def __init__(self, x: int):
        self.x = x
""",
        encoding="utf-8",
    )
    (pkg / "models.py").write_text(
        """import mypkg.external as ext

class Base:
    pass

class Child(ext.Base):
    def __init__(self, x: int):
        super().__init__(x)
""",
        encoding="utf-8",
    )
    (pkg / "processor.py").write_text(
        """from mypkg.models import Child

def process(c: Child) -> int:
    return c.x
""",
        encoding="utf-8",
    )

    fto = FunctionToOptimize(
        function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
    )
    result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
    combined = "\n".join(cs.code for cs in result.code_strings)
    assert "class Child(ext.Base):" in combined
    assert "self.x = x" in combined
    assert "class Base:\n    pass" not in combined
|
||||
|
||||
|
||||
def test_extract_parameter_type_constructors_isinstance_builtins_excluded(tmp_path: Path) -> None:
|
||||
"""Isinstance with builtins (int, str, etc.) should not produce stubs."""
|
||||
pkg = tmp_path / "mypkg"
|
||||
|
|
@ -4828,6 +4968,51 @@ def test_extract_parameter_type_constructors_transitive(tmp_path: Path) -> None:
|
|||
assert "class Config:" in combined
|
||||
|
||||
|
||||
def test_extract_parameter_type_constructors_uses_raw_project_context_for_dataclass_inheritance(tmp_path: Path) -> None:
    """Dataclass parameter types pull in raw source for the class and its dataclass base."""
    pkg = tmp_path / "mypkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("", encoding="utf-8")
    (pkg / "base.py").write_text(
        """from dataclasses import dataclass
from pathlib import Path

@dataclass
class BaseConfig:
    file_path: Path
""",
        encoding="utf-8",
    )
    (pkg / "models.py").write_text(
        """from dataclasses import dataclass
from mypkg.base import BaseConfig

@dataclass
class ChildConfig(BaseConfig):
    qualified_name: str
""",
        encoding="utf-8",
    )
    (pkg / "processor.py").write_text(
        """from mypkg.models import ChildConfig

def process(cfg: ChildConfig) -> str:
    return cfg.qualified_name
""",
        encoding="utf-8",
    )

    fto = FunctionToOptimize(
        function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
    )
    result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
    combined = "\n".join(cs.code for cs in result.code_strings)
    assert "@dataclass" in combined
    assert "class BaseConfig" in combined
    assert "file_path: Path" in combined
    assert "class ChildConfig(BaseConfig):" in combined
    assert "qualified_name: str" in combined
|
||||
|
||||
|
||||
def test_enrich_testgen_context_third_party_uses_stubs(tmp_path: Path) -> None:
|
||||
"""Third-party classes should produce compact __init__ stubs, not full class source."""
|
||||
# Use a real third-party package (pydantic) so jedi can actually resolve it
|
||||
|
|
|
|||
Loading…
Reference in a new issue