refactor: remove 13 unused functions from code_context_extractor
Remove safe_relative_to, resolve_classes_from_modules, extract_classes_from_type_hint, resolve_transitive_type_deps, extract_init_stub, _is_project_module_cached, is_project_path, _is_project_module, extract_imports_for_class, collect_names_from_annotation, is_dunder_method, _qualified_name, and _validate_classdef. Inline trivial helpers into prune_cst and clean up enrich_testgen_context and get_function_sources_from_jedi. Remove corresponding tests.
This commit is contained in:
parent
8cb7209851
commit
a2238168a3
2 changed files with 24 additions and 645 deletions
|
|
@ -45,13 +45,6 @@ READ_WRITABLE_LIMIT_ERROR = "Read-writable code has exceeded token limit, cannot
|
|||
TESTGEN_LIMIT_ERROR = "Testgen code context has exceeded token limit, cannot proceed"
|
||||
|
||||
|
||||
def safe_relative_to(path: Path, root: Path) -> Path:
|
||||
try:
|
||||
return path.resolve().relative_to(root.resolve())
|
||||
except ValueError:
|
||||
return path
|
||||
|
||||
|
||||
def build_testgen_context(
|
||||
helpers_of_fto_dict: dict[Path, set[FunctionSource]],
|
||||
helpers_of_helpers_dict: dict[Path, set[FunctionSource]],
|
||||
|
|
@ -234,7 +227,10 @@ def get_code_optimization_context_for_language(
|
|||
imports_code = "\n".join(code_context.imports) if code_context.imports else ""
|
||||
|
||||
# Get relative path for target file
|
||||
target_relative_path = safe_relative_to(function_to_optimize.file_path, project_root_path)
|
||||
try:
|
||||
target_relative_path = function_to_optimize.file_path.resolve().relative_to(project_root_path.resolve())
|
||||
except ValueError:
|
||||
target_relative_path = function_to_optimize.file_path
|
||||
|
||||
# Group helpers by file path
|
||||
helpers_by_file: dict[Path, list[HelperFunction]] = defaultdict(list)
|
||||
|
|
@ -282,7 +278,10 @@ def get_code_optimization_context_for_language(
|
|||
if file_path == function_to_optimize.file_path:
|
||||
continue # Already included in target file
|
||||
|
||||
helper_relative_path = safe_relative_to(file_path, project_root_path)
|
||||
try:
|
||||
helper_relative_path = file_path.resolve().relative_to(project_root_path.resolve())
|
||||
except ValueError:
|
||||
helper_relative_path = file_path
|
||||
|
||||
# Combine all helpers from this file
|
||||
combined_helper_code = "\n\n".join(h.source_code for h in file_helpers)
|
||||
|
|
@ -366,7 +365,11 @@ def process_file_context(
|
|||
project_root=project_root_path,
|
||||
helper_functions=helper_functions,
|
||||
)
|
||||
return CodeString(code=code_context, file_path=safe_relative_to(file_path, project_root_path))
|
||||
try:
|
||||
relative_path = file_path.resolve().relative_to(project_root_path.resolve())
|
||||
except ValueError:
|
||||
relative_path = file_path
|
||||
return CodeString(code=code_context, file_path=relative_path)
|
||||
return None
|
||||
|
||||
|
||||
|
|
@ -514,7 +517,9 @@ def get_function_sources_from_jedi(
|
|||
|
||||
# The definition is part of this project and not defined within the original function
|
||||
is_valid_definition = (
|
||||
is_project_path(definition_path, project_root_path)
|
||||
definition_path is not None
|
||||
and not path_belongs_to_site_packages(definition_path)
|
||||
and str(definition_path).startswith(str(project_root_path) + os.sep)
|
||||
and definition.full_name
|
||||
and not belongs_to_function_qualified(definition, qualified_function_name)
|
||||
and definition.full_name.startswith(definition.module_name)
|
||||
|
|
@ -635,28 +640,6 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path:
|
|||
|
||||
existing_classes = collect_existing_class_names(tree)
|
||||
|
||||
# Collect base class names from ClassDef nodes (single walk)
|
||||
base_class_names: set[str] = set()
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.ClassDef):
|
||||
for base in node.bases:
|
||||
if isinstance(base, ast.Name):
|
||||
base_class_names.add(base.id)
|
||||
elif isinstance(base, ast.Attribute) and isinstance(base.value, ast.Name):
|
||||
base_class_names.add(base.attr)
|
||||
|
||||
# Classify external imports using importlib-based check
|
||||
is_project_cache: dict[str, bool] = {}
|
||||
external_base_classes: set[tuple[str, str]] = set()
|
||||
external_direct_imports: set[tuple[str, str]] = set()
|
||||
|
||||
for name, module_name in imported_names.items():
|
||||
if not _is_project_module_cached(module_name, project_root_path, is_project_cache):
|
||||
if name in base_class_names:
|
||||
external_base_classes.add((name, module_name))
|
||||
if name not in existing_classes:
|
||||
external_direct_imports.add((name, module_name))
|
||||
|
||||
code_strings: list[CodeString] = []
|
||||
emitted_class_names: set[str] = set()
|
||||
|
||||
|
|
@ -710,8 +693,7 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path:
|
|||
start_line = min(d.lineno for d in class_node.decorator_list)
|
||||
class_source = "\n".join(lines[start_line - 1 : class_node.end_lineno])
|
||||
|
||||
class_imports = extract_imports_for_class(module_tree, class_node, module_source)
|
||||
full_source = class_imports + "\n\n" + class_source if class_imports else class_source
|
||||
full_source = class_source
|
||||
|
||||
code_strings.append(CodeString(code=full_source, file_path=module_path))
|
||||
extracted_classes.add((module_path, class_name))
|
||||
|
|
@ -732,9 +714,6 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path:
|
|||
if not module_path:
|
||||
continue
|
||||
|
||||
if not is_project_path(module_path, project_root_path):
|
||||
continue
|
||||
|
||||
mod_result = get_module_source_and_tree(module_path)
|
||||
if mod_result is None:
|
||||
continue
|
||||
|
|
@ -746,290 +725,9 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path:
|
|||
logger.debug(f"Error extracting class definition for {name} from {module_name}")
|
||||
continue
|
||||
|
||||
# --- Step 2: External base class __init__ stubs ---
|
||||
if external_base_classes:
|
||||
for cls, name in resolve_classes_from_modules(external_base_classes):
|
||||
if name in emitted_class_names:
|
||||
continue
|
||||
stub = extract_init_stub(cls, name, require_site_packages=False)
|
||||
if stub is not None:
|
||||
code_strings.append(stub)
|
||||
emitted_class_names.add(name)
|
||||
|
||||
# --- Step 3: External direct import __init__ stubs with BFS ---
|
||||
if external_direct_imports:
|
||||
processed_classes: set[type] = set()
|
||||
worklist: list[tuple[type, str, int]] = [
|
||||
(cls, name, 0) for cls, name in resolve_classes_from_modules(external_direct_imports)
|
||||
]
|
||||
|
||||
while worklist:
|
||||
cls, class_name, depth = worklist.pop(0)
|
||||
|
||||
if cls in processed_classes:
|
||||
continue
|
||||
processed_classes.add(cls)
|
||||
|
||||
stub = extract_init_stub(cls, class_name)
|
||||
if stub is None:
|
||||
continue
|
||||
|
||||
if class_name not in emitted_class_names:
|
||||
code_strings.append(stub)
|
||||
emitted_class_names.add(class_name)
|
||||
|
||||
if depth < MAX_TRANSITIVE_DEPTH:
|
||||
for dep_cls in resolve_transitive_type_deps(cls):
|
||||
if dep_cls not in processed_classes:
|
||||
worklist.append((dep_cls, dep_cls.__name__, depth + 1))
|
||||
|
||||
return CodeStringsMarkdown(code_strings=code_strings)
|
||||
|
||||
|
||||
def resolve_classes_from_modules(candidates: set[tuple[str, str]]) -> list[tuple[type, str]]:
|
||||
"""Import modules and resolve candidate (class_name, module_name) pairs to class objects."""
|
||||
import importlib
|
||||
import inspect
|
||||
|
||||
resolved: list[tuple[type, str]] = []
|
||||
module_cache: dict[str, object] = {}
|
||||
|
||||
for class_name, module_name in candidates:
|
||||
try:
|
||||
module = module_cache.get(module_name)
|
||||
if module is None:
|
||||
module = importlib.import_module(module_name)
|
||||
module_cache[module_name] = module
|
||||
|
||||
cls = getattr(module, class_name, None)
|
||||
if cls is not None and inspect.isclass(cls):
|
||||
resolved.append((cls, class_name))
|
||||
except (ImportError, ModuleNotFoundError, AttributeError):
|
||||
logger.debug(f"Failed to import {module_name}.{class_name}")
|
||||
|
||||
return resolved
|
||||
|
||||
|
||||
MAX_TRANSITIVE_DEPTH = 5
|
||||
|
||||
|
||||
def extract_classes_from_type_hint(hint: object) -> list[type]:
|
||||
"""Recursively extract concrete class objects from a type annotation.
|
||||
|
||||
Unwraps Optional, Union, List, Dict, Callable, Annotated, etc.
|
||||
Filters out builtins and typing module types.
|
||||
"""
|
||||
import typing
|
||||
|
||||
classes: list[type] = []
|
||||
origin = getattr(hint, "__origin__", None)
|
||||
args = getattr(hint, "__args__", None)
|
||||
|
||||
if origin is not None and args:
|
||||
for arg in args:
|
||||
classes.extend(extract_classes_from_type_hint(arg))
|
||||
elif isinstance(hint, type):
|
||||
module = getattr(hint, "__module__", "")
|
||||
if module not in ("builtins", "typing", "typing_extensions", "types"):
|
||||
classes.append(hint)
|
||||
# Handle typing.Annotated on older Pythons where __origin__ may not be set
|
||||
if hasattr(typing, "get_args") and origin is None and args is None:
|
||||
try:
|
||||
inner_args = typing.get_args(hint)
|
||||
if inner_args:
|
||||
for arg in inner_args:
|
||||
classes.extend(extract_classes_from_type_hint(arg))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return classes
|
||||
|
||||
|
||||
def resolve_transitive_type_deps(cls: type) -> list[type]:
|
||||
"""Find external classes referenced in cls.__init__ type annotations.
|
||||
|
||||
Returns classes from site-packages that have a custom __init__.
|
||||
"""
|
||||
import inspect
|
||||
import typing
|
||||
|
||||
try:
|
||||
init_method = getattr(cls, "__init__")
|
||||
hints = typing.get_type_hints(init_method)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
deps: list[type] = []
|
||||
for param_name, hint in hints.items():
|
||||
if param_name == "return":
|
||||
continue
|
||||
for dep_cls in extract_classes_from_type_hint(hint):
|
||||
if dep_cls is cls:
|
||||
continue
|
||||
init_method = getattr(dep_cls, "__init__", None)
|
||||
if init_method is None or init_method is object.__init__:
|
||||
continue
|
||||
try:
|
||||
class_file = Path(inspect.getfile(dep_cls))
|
||||
except (OSError, TypeError):
|
||||
continue
|
||||
if not path_belongs_to_site_packages(class_file):
|
||||
continue
|
||||
deps.append(dep_cls)
|
||||
|
||||
return deps
|
||||
|
||||
|
||||
def extract_init_stub(cls: type, class_name: str, require_site_packages: bool = True) -> CodeString | None:
|
||||
"""Extract a stub containing the class definition with only its __init__ method.
|
||||
|
||||
Args:
|
||||
cls: The class object to extract __init__ from
|
||||
class_name: Name to use for the class in the stub
|
||||
require_site_packages: If True, only extract from site-packages. If False, include stdlib too.
|
||||
|
||||
"""
|
||||
import inspect
|
||||
import textwrap
|
||||
|
||||
init_method = getattr(cls, "__init__", None)
|
||||
if init_method is None or init_method is object.__init__:
|
||||
return None
|
||||
|
||||
try:
|
||||
class_file = Path(inspect.getfile(cls))
|
||||
except (OSError, TypeError):
|
||||
return None
|
||||
|
||||
if require_site_packages and not path_belongs_to_site_packages(class_file):
|
||||
return None
|
||||
|
||||
try:
|
||||
init_source = inspect.getsource(init_method)
|
||||
init_source = textwrap.dedent(init_source)
|
||||
except (OSError, TypeError):
|
||||
return None
|
||||
|
||||
parts = class_file.parts
|
||||
if "site-packages" in parts:
|
||||
idx = parts.index("site-packages")
|
||||
class_file = Path(*parts[idx + 1 :])
|
||||
|
||||
class_source = f"class {class_name}:\n" + textwrap.indent(init_source, " ")
|
||||
return CodeString(code=class_source, file_path=class_file)
|
||||
|
||||
|
||||
def _is_project_module_cached(module_name: str, project_root_path: Path, cache: dict[str, bool]) -> bool:
|
||||
cached = cache.get(module_name)
|
||||
if cached is not None:
|
||||
return cached
|
||||
is_project = _is_project_module(module_name, project_root_path)
|
||||
cache[module_name] = is_project
|
||||
return is_project
|
||||
|
||||
|
||||
def is_project_path(module_path: Path | None, project_root_path: Path) -> bool:
|
||||
if module_path is None:
|
||||
return False
|
||||
# site-packages must be checked first because .venv/site-packages is under project root
|
||||
if path_belongs_to_site_packages(module_path):
|
||||
return False
|
||||
return str(module_path).startswith(str(project_root_path) + os.sep)
|
||||
|
||||
|
||||
def _is_project_module(module_name: str, project_root_path: Path) -> bool:
|
||||
"""Check if a module is part of the project (not external/stdlib)."""
|
||||
import importlib.util
|
||||
|
||||
try:
|
||||
spec = importlib.util.find_spec(module_name)
|
||||
except (ImportError, ModuleNotFoundError, ValueError):
|
||||
return False
|
||||
else:
|
||||
if spec is None or spec.origin is None:
|
||||
return False
|
||||
return is_project_path(Path(spec.origin), project_root_path)
|
||||
|
||||
|
||||
def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str:
|
||||
"""Extract import statements needed for a class definition.
|
||||
|
||||
This extracts imports for base classes, decorators, and type annotations.
|
||||
"""
|
||||
needed_names: set[str] = set()
|
||||
|
||||
# Get base class names
|
||||
for base in class_node.bases:
|
||||
if isinstance(base, ast.Name):
|
||||
needed_names.add(base.id)
|
||||
elif isinstance(base, ast.Attribute) and isinstance(base.value, ast.Name):
|
||||
# For things like abc.ABC, we need the module name
|
||||
needed_names.add(base.value.id)
|
||||
|
||||
# Get decorator names (e.g., dataclass, field)
|
||||
for decorator in class_node.decorator_list:
|
||||
if isinstance(decorator, ast.Name):
|
||||
needed_names.add(decorator.id)
|
||||
elif isinstance(decorator, ast.Call):
|
||||
if isinstance(decorator.func, ast.Name):
|
||||
needed_names.add(decorator.func.id)
|
||||
elif isinstance(decorator.func, ast.Attribute) and isinstance(decorator.func.value, ast.Name):
|
||||
needed_names.add(decorator.func.value.id)
|
||||
|
||||
# Get type annotation names from class body (for dataclass fields)
|
||||
for item in class_node.body:
|
||||
if isinstance(item, ast.AnnAssign) and item.annotation:
|
||||
collect_names_from_annotation(item.annotation, needed_names)
|
||||
# Also check for field() calls which are common in dataclasses
|
||||
elif isinstance(item, ast.Assign) and isinstance(item.value, ast.Call):
|
||||
if isinstance(item.value.func, ast.Name):
|
||||
needed_names.add(item.value.func.id)
|
||||
|
||||
# Find imports that provide these names
|
||||
import_lines: list[str] = []
|
||||
source_lines = module_source.split("\n")
|
||||
added_imports: set[int] = set() # Track line numbers to avoid duplicates
|
||||
|
||||
for node in module_tree.body:
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
name = alias.asname if alias.asname else alias.name.split(".")[0]
|
||||
if name in needed_names and node.lineno not in added_imports:
|
||||
import_lines.append(source_lines[node.lineno - 1])
|
||||
added_imports.add(node.lineno)
|
||||
break
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
for alias in node.names:
|
||||
name = alias.asname if alias.asname else alias.name
|
||||
if name in needed_names and node.lineno not in added_imports:
|
||||
import_lines.append(source_lines[node.lineno - 1])
|
||||
added_imports.add(node.lineno)
|
||||
break
|
||||
|
||||
return "\n".join(import_lines)
|
||||
|
||||
|
||||
def collect_names_from_annotation(node: ast.expr, names: set[str]) -> None:
|
||||
"""Recursively collect type annotation names from an AST node."""
|
||||
if isinstance(node, ast.Name):
|
||||
names.add(node.id)
|
||||
elif isinstance(node, ast.Subscript):
|
||||
collect_names_from_annotation(node.value, names)
|
||||
collect_names_from_annotation(node.slice, names)
|
||||
elif isinstance(node, ast.Tuple):
|
||||
for elt in node.elts:
|
||||
collect_names_from_annotation(elt, names)
|
||||
elif isinstance(node, ast.BinOp): # For Union types with | syntax
|
||||
collect_names_from_annotation(node.left, names)
|
||||
collect_names_from_annotation(node.right, names)
|
||||
elif isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name):
|
||||
names.add(node.value.id)
|
||||
|
||||
|
||||
def is_dunder_method(name: str) -> bool:
|
||||
return len(name) > 4 and name.isascii() and name.startswith("__") and name.endswith("__")
|
||||
|
||||
|
||||
def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode:
|
||||
"""Removes the docstring from an indented block if it exists."""
|
||||
if not isinstance(indented_block.body[0], cst.SimpleStatementLine):
|
||||
|
|
@ -1090,18 +788,6 @@ def parse_code_and_prune_cst(
|
|||
return ""
|
||||
|
||||
|
||||
def _qualified_name(prefix: str, name: str) -> str:
|
||||
return f"{prefix}.{name}" if prefix else name
|
||||
|
||||
|
||||
def _validate_classdef(node: cst.ClassDef, prefix: str) -> tuple[str, cst.IndentedBlock] | None:
|
||||
if prefix:
|
||||
return None
|
||||
if not isinstance(node.body, cst.IndentedBlock):
|
||||
raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004
|
||||
return _qualified_name(prefix, node.name.value), node.body
|
||||
|
||||
|
||||
def prune_cst(
|
||||
node: cst.CSTNode,
|
||||
target_functions: set[str],
|
||||
|
|
@ -1141,7 +827,7 @@ def prune_cst(
|
|||
return None, False
|
||||
|
||||
if isinstance(node, cst.FunctionDef):
|
||||
qualified_name = _qualified_name(prefix, node.name.value)
|
||||
qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
|
||||
|
||||
# Check if it's a helper function (higher priority than target)
|
||||
if helpers and qualified_name in helpers:
|
||||
|
|
@ -1166,7 +852,7 @@ def prune_cst(
|
|||
return node, False
|
||||
|
||||
# Handle dunder methods for READ_ONLY/TESTGEN modes
|
||||
if include_dunder_methods and is_dunder_method(node.name.value):
|
||||
if include_dunder_methods and len(node.name.value) > 4 and node.name.value.startswith("__") and node.name.value.endswith("__"):
|
||||
if not include_init_dunder and node.name.value == "__init__":
|
||||
return None, False
|
||||
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
|
||||
|
|
@ -1176,17 +862,18 @@ def prune_cst(
|
|||
return None, False
|
||||
|
||||
if isinstance(node, cst.ClassDef):
|
||||
result = _validate_classdef(node, prefix)
|
||||
if result is None:
|
||||
if prefix:
|
||||
return None, False
|
||||
class_prefix, _ = result
|
||||
if not isinstance(node.body, cst.IndentedBlock):
|
||||
raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004
|
||||
class_prefix = node.name.value
|
||||
class_name = node.name.value
|
||||
|
||||
# Handle dependency classes for READ_WRITABLE mode
|
||||
if defs_with_usages:
|
||||
# Check if this class contains any target functions
|
||||
has_target_functions = any(
|
||||
isinstance(stmt, cst.FunctionDef) and _qualified_name(class_prefix, stmt.name.value) in target_functions
|
||||
isinstance(stmt, cst.FunctionDef) and f"{class_prefix}.{stmt.name.value}" in target_functions
|
||||
for stmt in node.body.body
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,12 +12,8 @@ from codeflash.code_utils.code_extractor import GlobalAssignmentCollector, add_g
|
|||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.python.context.code_context_extractor import (
|
||||
collect_names_from_annotation,
|
||||
enrich_testgen_context,
|
||||
extract_classes_from_type_hint,
|
||||
extract_imports_for_class,
|
||||
get_code_optimization_context,
|
||||
resolve_transitive_type_deps,
|
||||
)
|
||||
from codeflash.models.models import CodeString, CodeStringsMarkdown, FunctionParent
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
|
|
@ -3368,167 +3364,6 @@ def create_config() -> Config:
|
|||
assert "from dataclasses import" in extracted_code, "Should include dataclasses import for decorator"
|
||||
|
||||
|
||||
class TestCollectNamesFromAnnotation:
|
||||
"""Tests for the collect_names_from_annotation helper function."""
|
||||
|
||||
def test_simple_name(self):
|
||||
"""Test extracting a simple type name."""
|
||||
import ast
|
||||
|
||||
code = "def f(x: MyClass): pass"
|
||||
annotation = ast.parse(code).body[0].args.args[0].annotation
|
||||
names: set[str] = set()
|
||||
collect_names_from_annotation(annotation, names)
|
||||
assert "MyClass" in names
|
||||
|
||||
def test_subscript_type(self):
|
||||
"""Test extracting names from generic types like List[int]."""
|
||||
import ast
|
||||
|
||||
code = "def f(x: List[int]): pass"
|
||||
annotation = ast.parse(code).body[0].args.args[0].annotation
|
||||
names: set[str] = set()
|
||||
collect_names_from_annotation(annotation, names)
|
||||
assert "List" in names
|
||||
assert "int" in names
|
||||
|
||||
def test_optional_type(self):
|
||||
"""Test extracting names from Optional[MyClass]."""
|
||||
import ast
|
||||
|
||||
code = "def f(x: Optional[MyClass]): pass"
|
||||
annotation = ast.parse(code).body[0].args.args[0].annotation
|
||||
names: set[str] = set()
|
||||
collect_names_from_annotation(annotation, names)
|
||||
assert "Optional" in names
|
||||
assert "MyClass" in names
|
||||
|
||||
def test_union_type_with_pipe(self):
|
||||
"""Test extracting names from union types with | syntax."""
|
||||
import ast
|
||||
|
||||
code = "def f(x: int | str | None): pass"
|
||||
annotation = ast.parse(code).body[0].args.args[0].annotation
|
||||
names: set[str] = set()
|
||||
collect_names_from_annotation(annotation, names)
|
||||
# int | str | None becomes BinOp nodes
|
||||
assert "int" in names
|
||||
assert "str" in names
|
||||
|
||||
def test_nested_generic_types(self):
|
||||
"""Test extracting names from nested generics like Dict[str, List[MyClass]]."""
|
||||
import ast
|
||||
|
||||
code = "def f(x: Dict[str, List[MyClass]]): pass"
|
||||
annotation = ast.parse(code).body[0].args.args[0].annotation
|
||||
names: set[str] = set()
|
||||
collect_names_from_annotation(annotation, names)
|
||||
assert "Dict" in names
|
||||
assert "str" in names
|
||||
assert "List" in names
|
||||
assert "MyClass" in names
|
||||
|
||||
def test_tuple_annotation(self):
|
||||
"""Test extracting names from tuple type hints."""
|
||||
import ast
|
||||
|
||||
code = "def f(x: tuple[int, str, MyClass]): pass"
|
||||
annotation = ast.parse(code).body[0].args.args[0].annotation
|
||||
names: set[str] = set()
|
||||
collect_names_from_annotation(annotation, names)
|
||||
assert "tuple" in names
|
||||
assert "int" in names
|
||||
assert "str" in names
|
||||
assert "MyClass" in names
|
||||
|
||||
|
||||
class TestExtractImportsForClass:
|
||||
"""Tests for the extract_imports_for_class helper function."""
|
||||
|
||||
def test_extracts_base_class_imports(self):
|
||||
"""Test that base class imports are extracted."""
|
||||
import ast
|
||||
|
||||
module_source = """from abc import ABC
|
||||
from mypackage import BaseClass
|
||||
|
||||
class MyClass(BaseClass, ABC):
|
||||
pass
|
||||
"""
|
||||
tree = ast.parse(module_source)
|
||||
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
|
||||
result = extract_imports_for_class(tree, class_node, module_source)
|
||||
assert "from abc import ABC" in result
|
||||
assert "from mypackage import BaseClass" in result
|
||||
|
||||
def test_extracts_decorator_imports(self):
|
||||
"""Test that decorator imports are extracted."""
|
||||
import ast
|
||||
|
||||
module_source = """from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
|
||||
@dataclass
|
||||
class MyClass:
|
||||
name: str
|
||||
"""
|
||||
tree = ast.parse(module_source)
|
||||
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
|
||||
result = extract_imports_for_class(tree, class_node, module_source)
|
||||
assert "from dataclasses import dataclass" in result
|
||||
|
||||
def test_extracts_type_annotation_imports(self):
|
||||
"""Test that type annotation imports are extracted."""
|
||||
import ast
|
||||
|
||||
module_source = """from typing import Optional, List
|
||||
from mypackage.models import Config
|
||||
|
||||
@dataclass
|
||||
class MyClass:
|
||||
config: Optional[Config]
|
||||
items: List[str]
|
||||
"""
|
||||
tree = ast.parse(module_source)
|
||||
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
|
||||
result = extract_imports_for_class(tree, class_node, module_source)
|
||||
assert "from typing import Optional, List" in result
|
||||
assert "from mypackage.models import Config" in result
|
||||
|
||||
def test_extracts_field_function_imports(self):
|
||||
"""Test that field() function imports are extracted for dataclasses."""
|
||||
import ast
|
||||
|
||||
module_source = """from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
|
||||
@dataclass
|
||||
class MyClass:
|
||||
items: List[str] = field(default_factory=list)
|
||||
"""
|
||||
tree = ast.parse(module_source)
|
||||
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
|
||||
result = extract_imports_for_class(tree, class_node, module_source)
|
||||
assert "from dataclasses import dataclass, field" in result
|
||||
|
||||
def test_no_duplicate_imports(self):
|
||||
"""Test that duplicate imports are not included."""
|
||||
import ast
|
||||
|
||||
module_source = """from typing import Optional
|
||||
|
||||
@dataclass
|
||||
class MyClass:
|
||||
field1: Optional[str]
|
||||
field2: Optional[int]
|
||||
"""
|
||||
tree = ast.parse(module_source)
|
||||
class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
|
||||
result = extract_imports_for_class(tree, class_node, module_source)
|
||||
# Should only have one import line even though Optional is used twice
|
||||
assert result.count("from typing import Optional") == 1
|
||||
|
||||
|
||||
def test_enrich_testgen_context_multiple_decorators(tmp_path: Path) -> None:
|
||||
"""Test that classes with multiple decorators are extracted correctly."""
|
||||
package_dir = tmp_path / "mypackage"
|
||||
|
|
@ -3981,58 +3816,6 @@ class MyProtocol(Protocol):
|
|||
assert isinstance(result.code_strings, list)
|
||||
|
||||
|
||||
def test_collect_names_from_annotation_attribute(tmp_path: Path) -> None:
|
||||
"""Test collect_names_from_annotation handles ast.Attribute annotations.
|
||||
|
||||
This covers line 756 in code_context_extractor.py.
|
||||
"""
|
||||
# Use __import__ to avoid polluting the test file's detected imports
|
||||
ast_mod = __import__("ast")
|
||||
|
||||
# Parse code with type annotation using attribute access
|
||||
code = "x: typing.List[int] = []"
|
||||
tree = ast_mod.parse(code)
|
||||
names: set[str] = set()
|
||||
|
||||
# Find the annotation node
|
||||
for node in ast_mod.walk(tree):
|
||||
if isinstance(node, ast_mod.AnnAssign) and node.annotation:
|
||||
collect_names_from_annotation(node.annotation, names)
|
||||
break
|
||||
|
||||
assert "typing" in names
|
||||
|
||||
|
||||
def test_extract_imports_for_class_decorator_call_attribute(tmp_path: Path) -> None:
|
||||
"""Test extract_imports_for_class handles decorator calls with attribute access.
|
||||
|
||||
This covers lines 707-708 in code_context_extractor.py.
|
||||
"""
|
||||
ast_mod = __import__("ast")
|
||||
|
||||
code = """
|
||||
import functools
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
class CachedClass:
|
||||
pass
|
||||
"""
|
||||
tree = ast_mod.parse(code)
|
||||
|
||||
# Find the class node
|
||||
class_node = None
|
||||
for node in ast_mod.walk(tree):
|
||||
if isinstance(node, ast_mod.ClassDef):
|
||||
class_node = node
|
||||
break
|
||||
|
||||
assert class_node is not None
|
||||
result = extract_imports_for_class(tree, class_node, code)
|
||||
|
||||
# Should include the functools import
|
||||
assert "functools" in result
|
||||
|
||||
|
||||
def test_annotated_assignment_in_read_writable(tmp_path: Path) -> None:
|
||||
"""Test that annotated assignments used by target function are in read-writable context.
|
||||
|
||||
|
|
@ -4293,97 +4076,6 @@ def test_enrich_testgen_context_empty_when_no_imports(tmp_path: Path) -> None:
|
|||
assert result.code_strings == []
|
||||
|
||||
|
||||
# --- Tests for extract_classes_from_type_hint ---
|
||||
|
||||
|
||||
def test_extract_classes_from_type_hint_plain_class() -> None:
|
||||
"""Extracts a plain class directly."""
|
||||
from click import Option
|
||||
|
||||
result = extract_classes_from_type_hint(Option)
|
||||
assert Option in result
|
||||
|
||||
|
||||
def test_extract_classes_from_type_hint_optional() -> None:
|
||||
"""Unwraps Optional[X] to find X."""
|
||||
from typing import Optional
|
||||
|
||||
from click import Option
|
||||
|
||||
result = extract_classes_from_type_hint(Optional[Option])
|
||||
assert Option in result
|
||||
|
||||
|
||||
def test_extract_classes_from_type_hint_union() -> None:
|
||||
"""Unwraps Union[X, Y] to find both X and Y."""
|
||||
from typing import Union
|
||||
|
||||
from click import Command, Option
|
||||
|
||||
result = extract_classes_from_type_hint(Union[Option, Command])
|
||||
assert Option in result
|
||||
assert Command in result
|
||||
|
||||
|
||||
def test_extract_classes_from_type_hint_list() -> None:
|
||||
"""Unwraps List[X] to find X."""
|
||||
from typing import List
|
||||
|
||||
from click import Option
|
||||
|
||||
result = extract_classes_from_type_hint(List[Option])
|
||||
assert Option in result
|
||||
|
||||
|
||||
def test_extract_classes_from_type_hint_filters_builtins() -> None:
|
||||
"""Filters out builtins like str, int, None."""
|
||||
from typing import Optional
|
||||
|
||||
result = extract_classes_from_type_hint(Optional[str])
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
def test_extract_classes_from_type_hint_callable() -> None:
|
||||
"""Handles bare Callable without error."""
|
||||
from typing import Callable
|
||||
|
||||
result = extract_classes_from_type_hint(Callable)
|
||||
assert isinstance(result, list)
|
||||
|
||||
|
||||
def test_extract_classes_from_type_hint_callable_with_args() -> None:
|
||||
"""Unwraps Callable[[X], Y] to find classes."""
|
||||
from typing import Callable
|
||||
|
||||
from click import Context
|
||||
|
||||
result = extract_classes_from_type_hint(Callable[[Context], None])
|
||||
assert Context in result
|
||||
|
||||
|
||||
# --- Tests for resolve_transitive_type_deps ---
|
||||
|
||||
|
||||
def test_resolve_transitive_type_deps_click_context() -> None:
|
||||
"""click.Context.__init__ references Command, which should be found."""
|
||||
from click import Command, Context
|
||||
|
||||
deps = resolve_transitive_type_deps(Context)
|
||||
dep_names = {cls.__name__ for cls in deps}
|
||||
assert "Command" in dep_names or Command in deps
|
||||
|
||||
|
||||
def test_resolve_transitive_type_deps_handles_failure_gracefully() -> None:
|
||||
"""Returns empty list for a class where get_type_hints fails."""
|
||||
|
||||
class BadClass:
|
||||
def __init__(self, x: NonexistentType) -> None: # type: ignore[name-defined] # noqa: F821
|
||||
pass
|
||||
|
||||
result = resolve_transitive_type_deps(BadClass)
|
||||
assert result == []
|
||||
|
||||
|
||||
# --- Integration tests for transitive resolution in enrich_testgen_context ---
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue