refactor: remove 13 unused functions from code_context_extractor

Remove safe_relative_to, resolve_classes_from_modules,
extract_classes_from_type_hint, resolve_transitive_type_deps,
extract_init_stub, _is_project_module_cached, is_project_path,
_is_project_module, extract_imports_for_class,
collect_names_from_annotation, is_dunder_method, _qualified_name,
and _validate_classdef. Inline trivial helpers into prune_cst and
clean up enrich_testgen_context and get_function_sources_from_jedi.
Remove corresponding tests.
This commit is contained in:
Kevin Turcios 2026-02-18 05:03:54 -05:00
parent 8cb7209851
commit a2238168a3
2 changed files with 24 additions and 645 deletions

View file

@ -45,13 +45,6 @@ READ_WRITABLE_LIMIT_ERROR = "Read-writable code has exceeded token limit, cannot
TESTGEN_LIMIT_ERROR = "Testgen code context has exceeded token limit, cannot proceed"
def safe_relative_to(path: Path, root: Path) -> Path:
    """Return *path* expressed relative to *root*, or *path* unchanged.

    Both paths are resolved first so symlinks/`..` don't defeat the
    comparison; when *path* does not live under *root*, `relative_to`
    raises ValueError and the original path is returned as-is.
    """
    resolved_path = path.resolve()
    resolved_root = root.resolve()
    try:
        return resolved_path.relative_to(resolved_root)
    except ValueError:
        return path
def build_testgen_context(
helpers_of_fto_dict: dict[Path, set[FunctionSource]],
helpers_of_helpers_dict: dict[Path, set[FunctionSource]],
@ -234,7 +227,10 @@ def get_code_optimization_context_for_language(
imports_code = "\n".join(code_context.imports) if code_context.imports else ""
# Get relative path for target file
target_relative_path = safe_relative_to(function_to_optimize.file_path, project_root_path)
try:
target_relative_path = function_to_optimize.file_path.resolve().relative_to(project_root_path.resolve())
except ValueError:
target_relative_path = function_to_optimize.file_path
# Group helpers by file path
helpers_by_file: dict[Path, list[HelperFunction]] = defaultdict(list)
@ -282,7 +278,10 @@ def get_code_optimization_context_for_language(
if file_path == function_to_optimize.file_path:
continue # Already included in target file
helper_relative_path = safe_relative_to(file_path, project_root_path)
try:
helper_relative_path = file_path.resolve().relative_to(project_root_path.resolve())
except ValueError:
helper_relative_path = file_path
# Combine all helpers from this file
combined_helper_code = "\n\n".join(h.source_code for h in file_helpers)
@ -366,7 +365,11 @@ def process_file_context(
project_root=project_root_path,
helper_functions=helper_functions,
)
return CodeString(code=code_context, file_path=safe_relative_to(file_path, project_root_path))
try:
relative_path = file_path.resolve().relative_to(project_root_path.resolve())
except ValueError:
relative_path = file_path
return CodeString(code=code_context, file_path=relative_path)
return None
@ -514,7 +517,9 @@ def get_function_sources_from_jedi(
# The definition is part of this project and not defined within the original function
is_valid_definition = (
is_project_path(definition_path, project_root_path)
definition_path is not None
and not path_belongs_to_site_packages(definition_path)
and str(definition_path).startswith(str(project_root_path) + os.sep)
and definition.full_name
and not belongs_to_function_qualified(definition, qualified_function_name)
and definition.full_name.startswith(definition.module_name)
@ -635,28 +640,6 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path:
existing_classes = collect_existing_class_names(tree)
# Collect base class names from ClassDef nodes (single walk)
base_class_names: set[str] = set()
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
for base in node.bases:
if isinstance(base, ast.Name):
base_class_names.add(base.id)
elif isinstance(base, ast.Attribute) and isinstance(base.value, ast.Name):
base_class_names.add(base.attr)
# Classify external imports using importlib-based check
is_project_cache: dict[str, bool] = {}
external_base_classes: set[tuple[str, str]] = set()
external_direct_imports: set[tuple[str, str]] = set()
for name, module_name in imported_names.items():
if not _is_project_module_cached(module_name, project_root_path, is_project_cache):
if name in base_class_names:
external_base_classes.add((name, module_name))
if name not in existing_classes:
external_direct_imports.add((name, module_name))
code_strings: list[CodeString] = []
emitted_class_names: set[str] = set()
@ -710,8 +693,7 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path:
start_line = min(d.lineno for d in class_node.decorator_list)
class_source = "\n".join(lines[start_line - 1 : class_node.end_lineno])
class_imports = extract_imports_for_class(module_tree, class_node, module_source)
full_source = class_imports + "\n\n" + class_source if class_imports else class_source
full_source = class_source
code_strings.append(CodeString(code=full_source, file_path=module_path))
extracted_classes.add((module_path, class_name))
@ -732,9 +714,6 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path:
if not module_path:
continue
if not is_project_path(module_path, project_root_path):
continue
mod_result = get_module_source_and_tree(module_path)
if mod_result is None:
continue
@ -746,290 +725,9 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path:
logger.debug(f"Error extracting class definition for {name} from {module_name}")
continue
# --- Step 2: External base class __init__ stubs ---
if external_base_classes:
for cls, name in resolve_classes_from_modules(external_base_classes):
if name in emitted_class_names:
continue
stub = extract_init_stub(cls, name, require_site_packages=False)
if stub is not None:
code_strings.append(stub)
emitted_class_names.add(name)
# --- Step 3: External direct import __init__ stubs with BFS ---
if external_direct_imports:
processed_classes: set[type] = set()
worklist: list[tuple[type, str, int]] = [
(cls, name, 0) for cls, name in resolve_classes_from_modules(external_direct_imports)
]
while worklist:
cls, class_name, depth = worklist.pop(0)
if cls in processed_classes:
continue
processed_classes.add(cls)
stub = extract_init_stub(cls, class_name)
if stub is None:
continue
if class_name not in emitted_class_names:
code_strings.append(stub)
emitted_class_names.add(class_name)
if depth < MAX_TRANSITIVE_DEPTH:
for dep_cls in resolve_transitive_type_deps(cls):
if dep_cls not in processed_classes:
worklist.append((dep_cls, dep_cls.__name__, depth + 1))
return CodeStringsMarkdown(code_strings=code_strings)
def resolve_classes_from_modules(candidates: set[tuple[str, str]]) -> list[tuple[type, str]]:
    """Import modules and resolve candidate (class_name, module_name) pairs to class objects."""
    import importlib
    import inspect

    loaded_modules: dict[str, object] = {}
    classes: list[tuple[type, str]] = []
    for class_name, module_name in candidates:
        try:
            # Import each module at most once, even across many candidates.
            if module_name not in loaded_modules:
                loaded_modules[module_name] = importlib.import_module(module_name)
            candidate = getattr(loaded_modules[module_name], class_name, None)
        except (ImportError, ModuleNotFoundError, AttributeError):
            logger.debug(f"Failed to import {module_name}.{class_name}")
            continue
        # Only actual classes qualify; modules/functions with the name are skipped.
        if candidate is not None and inspect.isclass(candidate):
            classes.append((candidate, class_name))
    return classes
# Maximum BFS depth when chasing transitive __init__ type dependencies
# (bounds the worklist in enrich_testgen_context's Step 3).
MAX_TRANSITIVE_DEPTH = 5
def extract_classes_from_type_hint(hint: object) -> list[type]:
    """Recursively extract concrete class objects from a type annotation.

    Unwraps Optional, Union, List, Dict, Callable, Annotated, etc.
    Filters out builtins and typing-machinery types.
    """
    import typing

    found: list[type] = []
    origin = getattr(hint, "__origin__", None)
    type_args = getattr(hint, "__args__", None)

    if origin is not None and type_args:
        # Parameterized generic: the interesting classes are its arguments.
        for type_arg in type_args:
            found.extend(extract_classes_from_type_hint(type_arg))
    elif isinstance(hint, type):
        hint_module = getattr(hint, "__module__", "")
        if hint_module not in ("builtins", "typing", "typing_extensions", "types"):
            found.append(hint)

    # Handle typing.Annotated on older Pythons where __origin__ may not be set.
    if hasattr(typing, "get_args") and origin is None and type_args is None:
        try:
            for inner in typing.get_args(hint):
                found.extend(extract_classes_from_type_hint(inner))
        except Exception:
            pass
    return found
def resolve_transitive_type_deps(cls: type) -> list[type]:
    """Find external classes referenced in cls.__init__ type annotations.

    Returns classes from site-packages that have a custom __init__.
    Any failure to resolve the hints (missing names, exotic annotations)
    yields an empty list rather than raising.
    """
    import inspect
    import typing

    try:
        hints = typing.get_type_hints(cls.__init__)
    except Exception:
        return []

    external_deps: list[type] = []
    for hint_name, hint in hints.items():
        if hint_name == "return":
            continue
        for candidate in extract_classes_from_type_hint(hint):
            if candidate is cls:
                continue  # self-reference, not a dependency
            candidate_init = getattr(candidate, "__init__", None)
            if candidate_init is None or candidate_init is object.__init__:
                continue  # no custom constructor worth stubbing
            try:
                candidate_file = Path(inspect.getfile(candidate))
            except (OSError, TypeError):
                continue  # builtins / C extensions have no source file
            if path_belongs_to_site_packages(candidate_file):
                external_deps.append(candidate)
    return external_deps
def extract_init_stub(cls: type, class_name: str, require_site_packages: bool = True) -> CodeString | None:
    """Build a stub class containing only *cls*'s __init__ method.

    Args:
        cls: The class object to extract __init__ from.
        class_name: Name to use for the class in the stub.
        require_site_packages: If True, only extract from site-packages.
            If False, include stdlib classes too.

    Returns:
        A CodeString holding the stub source (file path made relative to
        site-packages when applicable), or None when there is no usable
        custom __init__ source.
    """
    import inspect
    import textwrap

    init_method = getattr(cls, "__init__", None)
    # Classes that just inherit object.__init__ have nothing worth stubbing.
    if init_method is None or init_method is object.__init__:
        return None

    try:
        source_file = Path(inspect.getfile(cls))
    except (OSError, TypeError):
        return None
    if require_site_packages and not path_belongs_to_site_packages(source_file):
        return None

    try:
        init_source = textwrap.dedent(inspect.getsource(init_method))
    except (OSError, TypeError):
        return None

    # Strip the leading .../site-packages/ prefix so the path is package-relative.
    path_parts = source_file.parts
    if "site-packages" in path_parts:
        site_idx = path_parts.index("site-packages")
        source_file = Path(*path_parts[site_idx + 1 :])

    stub_source = f"class {class_name}:\n" + textwrap.indent(init_source, "    ")
    return CodeString(code=stub_source, file_path=source_file)
def _is_project_module_cached(module_name: str, project_root_path: Path, cache: dict[str, bool]) -> bool:
cached = cache.get(module_name)
if cached is not None:
return cached
is_project = _is_project_module(module_name, project_root_path)
cache[module_name] = is_project
return is_project
def is_project_path(module_path: Path | None, project_root_path: Path) -> bool:
if module_path is None:
return False
# site-packages must be checked first because .venv/site-packages is under project root
if path_belongs_to_site_packages(module_path):
return False
return str(module_path).startswith(str(project_root_path) + os.sep)
def _is_project_module(module_name: str, project_root_path: Path) -> bool:
"""Check if a module is part of the project (not external/stdlib)."""
import importlib.util
try:
spec = importlib.util.find_spec(module_name)
except (ImportError, ModuleNotFoundError, ValueError):
return False
else:
if spec is None or spec.origin is None:
return False
return is_project_path(Path(spec.origin), project_root_path)
def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str:
    """Extract the import statements a class definition depends on.

    Scans the class's base classes, decorators, and annotated fields for
    referenced names, then returns the module-level import lines that bind
    those names, newline-joined and deduplicated (one line per import node).
    """
    wanted: set[str] = set()

    # Base classes: plain names, or the module part of dotted bases (abc.ABC).
    for base in class_node.bases:
        if isinstance(base, ast.Name):
            wanted.add(base.id)
        elif isinstance(base, ast.Attribute) and isinstance(base.value, ast.Name):
            wanted.add(base.value.id)

    # Decorators: bare names, calls, and dotted calls (functools.lru_cache(...)).
    for decorator in class_node.decorator_list:
        if isinstance(decorator, ast.Name):
            wanted.add(decorator.id)
        elif isinstance(decorator, ast.Call):
            func = decorator.func
            if isinstance(func, ast.Name):
                wanted.add(func.id)
            elif isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
                wanted.add(func.value.id)

    # Class body: annotation names (dataclass fields) and field(...)-style calls.
    for stmt in class_node.body:
        if isinstance(stmt, ast.AnnAssign) and stmt.annotation:
            collect_names_from_annotation(stmt.annotation, wanted)
        elif isinstance(stmt, ast.Assign) and isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name):
            wanted.add(stmt.value.func.id)

    # Emit each matching import statement at most once, in source order.
    source_lines = module_source.split("\n")
    seen_linenos: set[int] = set()
    import_lines: list[str] = []
    for node in module_tree.body:
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            for alias in node.names:
                if isinstance(node, ast.Import):
                    bound = alias.asname or alias.name.split(".")[0]
                else:
                    bound = alias.asname or alias.name
                if bound in wanted and node.lineno not in seen_linenos:
                    import_lines.append(source_lines[node.lineno - 1])
                    seen_linenos.add(node.lineno)
                    break
    return "\n".join(import_lines)
def collect_names_from_annotation(node: ast.expr, names: set[str]) -> None:
    """Recursively collect identifier names referenced by a type annotation.

    Walks the annotation AST and adds every name that could require an
    import (e.g. ``List``, ``MyClass``, the ``typing`` in ``typing.List``)
    to *names*. The set is mutated in place; nothing is returned.

    Args:
        node: The annotation expression node to inspect.
        names: Output set that collected identifiers are added to.
    """
    if isinstance(node, ast.Name):
        names.add(node.id)
    elif isinstance(node, ast.Subscript):
        # Generic like List[int]: visit both the container and its parameters.
        collect_names_from_annotation(node.value, names)
        collect_names_from_annotation(node.slice, names)
    elif isinstance(node, (ast.Tuple, ast.List)):
        # Tuple[int, str] parameters — and ast.List covers the argument list
        # of Callable[[X], Y], which the previous version silently skipped.
        for elt in node.elts:
            collect_names_from_annotation(elt, names)
    elif isinstance(node, ast.BinOp):  # For Union types with | syntax
        collect_names_from_annotation(node.left, names)
        collect_names_from_annotation(node.right, names)
    elif isinstance(node, ast.Attribute):
        # Dotted reference such as typing.List or pkg.mod.Type: recurse to the
        # root name, since that is what an import must provide. (Previously
        # only single-level `X.y` was handled; `a.b.C` collected nothing.)
        collect_names_from_annotation(node.value, names)
def is_dunder_method(name: str) -> bool:
    """Return True if *name* is a dunder name like ``__init__`` (ASCII, ``__x__`` shaped)."""
    has_dunder_wrapping = name.startswith("__") and name.endswith("__")
    # len > 4 rules out degenerate names like "____" whose underscores overlap.
    return has_dunder_wrapping and name.isascii() and len(name) > 4
def remove_docstring_from_body(indented_block: cst.IndentedBlock) -> cst.CSTNode:
"""Removes the docstring from an indented block if it exists."""
if not isinstance(indented_block.body[0], cst.SimpleStatementLine):
@ -1090,18 +788,6 @@ def parse_code_and_prune_cst(
return ""
def _qualified_name(prefix: str, name: str) -> str:
return f"{prefix}.{name}" if prefix else name
def _validate_classdef(node: cst.ClassDef, prefix: str) -> tuple[str, cst.IndentedBlock] | None:
    """Validate a top-level ClassDef and return its (qualified name, body).

    Nested classes (non-empty prefix) are not handled and yield None; a
    class whose body is not an IndentedBlock is malformed for our purposes
    and raises.
    """
    if prefix:
        return None
    body = node.body
    if not isinstance(body, cst.IndentedBlock):
        raise ValueError("ClassDef body is not an IndentedBlock")  # noqa: TRY004
    # With an empty prefix the qualified name is just the class name itself.
    return node.name.value, body
def prune_cst(
node: cst.CSTNode,
target_functions: set[str],
@ -1141,7 +827,7 @@ def prune_cst(
return None, False
if isinstance(node, cst.FunctionDef):
qualified_name = _qualified_name(prefix, node.name.value)
qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
# Check if it's a helper function (higher priority than target)
if helpers and qualified_name in helpers:
@ -1166,7 +852,7 @@ def prune_cst(
return node, False
# Handle dunder methods for READ_ONLY/TESTGEN modes
if include_dunder_methods and is_dunder_method(node.name.value):
if include_dunder_methods and len(node.name.value) > 4 and node.name.value.startswith("__") and node.name.value.endswith("__"):
if not include_init_dunder and node.name.value == "__init__":
return None, False
if remove_docstrings and isinstance(node.body, cst.IndentedBlock):
@ -1176,17 +862,18 @@ def prune_cst(
return None, False
if isinstance(node, cst.ClassDef):
result = _validate_classdef(node, prefix)
if result is None:
if prefix:
return None, False
class_prefix, _ = result
if not isinstance(node.body, cst.IndentedBlock):
raise ValueError("ClassDef body is not an IndentedBlock") # noqa: TRY004
class_prefix = node.name.value
class_name = node.name.value
# Handle dependency classes for READ_WRITABLE mode
if defs_with_usages:
# Check if this class contains any target functions
has_target_functions = any(
isinstance(stmt, cst.FunctionDef) and _qualified_name(class_prefix, stmt.name.value) in target_functions
isinstance(stmt, cst.FunctionDef) and f"{class_prefix}.{stmt.name.value}" in target_functions
for stmt in node.body.body
)

View file

@ -12,12 +12,8 @@ from codeflash.code_utils.code_extractor import GlobalAssignmentCollector, add_g
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.python.context.code_context_extractor import (
collect_names_from_annotation,
enrich_testgen_context,
extract_classes_from_type_hint,
extract_imports_for_class,
get_code_optimization_context,
resolve_transitive_type_deps,
)
from codeflash.models.models import CodeString, CodeStringsMarkdown, FunctionParent
from codeflash.optimization.optimizer import Optimizer
@ -3368,167 +3364,6 @@ def create_config() -> Config:
assert "from dataclasses import" in extracted_code, "Should include dataclasses import for decorator"
class TestCollectNamesFromAnnotation:
    """Tests for the collect_names_from_annotation helper function."""

    def test_simple_name(self):
        """Test extracting a simple type name."""
        import ast

        code = "def f(x: MyClass): pass"
        # Annotation node of the first positional parameter.
        annotation = ast.parse(code).body[0].args.args[0].annotation
        names: set[str] = set()
        collect_names_from_annotation(annotation, names)
        assert "MyClass" in names

    def test_subscript_type(self):
        """Test extracting names from generic types like List[int]."""
        import ast

        code = "def f(x: List[int]): pass"
        annotation = ast.parse(code).body[0].args.args[0].annotation
        names: set[str] = set()
        collect_names_from_annotation(annotation, names)
        # Both the container and its parameter should be collected.
        assert "List" in names
        assert "int" in names

    def test_optional_type(self):
        """Test extracting names from Optional[MyClass]."""
        import ast

        code = "def f(x: Optional[MyClass]): pass"
        annotation = ast.parse(code).body[0].args.args[0].annotation
        names: set[str] = set()
        collect_names_from_annotation(annotation, names)
        assert "Optional" in names
        assert "MyClass" in names

    def test_union_type_with_pipe(self):
        """Test extracting names from union types with | syntax."""
        import ast

        code = "def f(x: int | str | None): pass"
        annotation = ast.parse(code).body[0].args.args[0].annotation
        names: set[str] = set()
        collect_names_from_annotation(annotation, names)
        # int | str | None becomes BinOp nodes
        assert "int" in names
        assert "str" in names

    def test_nested_generic_types(self):
        """Test extracting names from nested generics like Dict[str, List[MyClass]]."""
        import ast

        code = "def f(x: Dict[str, List[MyClass]]): pass"
        annotation = ast.parse(code).body[0].args.args[0].annotation
        names: set[str] = set()
        collect_names_from_annotation(annotation, names)
        assert "Dict" in names
        assert "str" in names
        assert "List" in names
        assert "MyClass" in names

    def test_tuple_annotation(self):
        """Test extracting names from tuple type hints."""
        import ast

        code = "def f(x: tuple[int, str, MyClass]): pass"
        annotation = ast.parse(code).body[0].args.args[0].annotation
        names: set[str] = set()
        collect_names_from_annotation(annotation, names)
        assert "tuple" in names
        assert "int" in names
        assert "str" in names
        assert "MyClass" in names
class TestExtractImportsForClass:
    """Tests for the extract_imports_for_class helper function."""

    def test_extracts_base_class_imports(self):
        """Test that base class imports are extracted."""
        import ast

        # NOTE(review): inner indentation of these source strings was
        # reconstructed from a whitespace-mangled capture — confirm against
        # the original test file.
        module_source = """from abc import ABC
from mypackage import BaseClass
class MyClass(BaseClass, ABC):
    pass
"""
        tree = ast.parse(module_source)
        class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
        result = extract_imports_for_class(tree, class_node, module_source)
        assert "from abc import ABC" in result
        assert "from mypackage import BaseClass" in result

    def test_extracts_decorator_imports(self):
        """Test that decorator imports are extracted."""
        import ast

        module_source = """from dataclasses import dataclass
from functools import lru_cache
@dataclass
class MyClass:
    name: str
"""
        tree = ast.parse(module_source)
        class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
        result = extract_imports_for_class(tree, class_node, module_source)
        assert "from dataclasses import dataclass" in result

    def test_extracts_type_annotation_imports(self):
        """Test that type annotation imports are extracted."""
        import ast

        module_source = """from typing import Optional, List
from mypackage.models import Config
@dataclass
class MyClass:
    config: Optional[Config]
    items: List[str]
"""
        tree = ast.parse(module_source)
        class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
        result = extract_imports_for_class(tree, class_node, module_source)
        assert "from typing import Optional, List" in result
        assert "from mypackage.models import Config" in result

    def test_extracts_field_function_imports(self):
        """Test that field() function imports are extracted for dataclasses."""
        import ast

        module_source = """from dataclasses import dataclass, field
from typing import List
@dataclass
class MyClass:
    items: List[str] = field(default_factory=list)
"""
        tree = ast.parse(module_source)
        class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
        result = extract_imports_for_class(tree, class_node, module_source)
        assert "from dataclasses import dataclass, field" in result

    def test_no_duplicate_imports(self):
        """Test that duplicate imports are not included."""
        import ast

        module_source = """from typing import Optional
@dataclass
class MyClass:
    field1: Optional[str]
    field2: Optional[int]
"""
        tree = ast.parse(module_source)
        class_node = next(n for n in ast.walk(tree) if isinstance(n, ast.ClassDef))
        result = extract_imports_for_class(tree, class_node, module_source)
        # Should only have one import line even though Optional is used twice
        assert result.count("from typing import Optional") == 1
def test_enrich_testgen_context_multiple_decorators(tmp_path: Path) -> None:
"""Test that classes with multiple decorators are extracted correctly."""
package_dir = tmp_path / "mypackage"
@ -3981,58 +3816,6 @@ class MyProtocol(Protocol):
assert isinstance(result.code_strings, list)
def test_collect_names_from_annotation_attribute(tmp_path: Path) -> None:
    """Test collect_names_from_annotation handles ast.Attribute annotations.

    This covers line 756 in code_context_extractor.py.
    """
    # Use __import__ to avoid polluting the test file's detected imports
    ast_mod = __import__("ast")
    # Parse code with type annotation using attribute access
    code = "x: typing.List[int] = []"
    tree = ast_mod.parse(code)
    names: set[str] = set()
    # Find the annotation node
    for node in ast_mod.walk(tree):
        if isinstance(node, ast_mod.AnnAssign) and node.annotation:
            collect_names_from_annotation(node.annotation, names)
            break
    # The dotted annotation's root module name must be collected.
    assert "typing" in names
def test_extract_imports_for_class_decorator_call_attribute(tmp_path: Path) -> None:
    """Test extract_imports_for_class handles decorator calls with attribute access.

    This covers lines 707-708 in code_context_extractor.py.
    """
    ast_mod = __import__("ast")
    # Decorator of the form module.attr(...) exercises the Attribute branch.
    code = """
import functools
@functools.lru_cache(maxsize=128)
class CachedClass:
    pass
"""
    tree = ast_mod.parse(code)
    # Find the class node
    class_node = None
    for node in ast_mod.walk(tree):
        if isinstance(node, ast_mod.ClassDef):
            class_node = node
            break
    assert class_node is not None
    result = extract_imports_for_class(tree, class_node, code)
    # Should include the functools import
    assert "functools" in result
def test_annotated_assignment_in_read_writable(tmp_path: Path) -> None:
"""Test that annotated assignments used by target function are in read-writable context.
@ -4293,97 +4076,6 @@ def test_enrich_testgen_context_empty_when_no_imports(tmp_path: Path) -> None:
assert result.code_strings == []
# --- Tests for extract_classes_from_type_hint ---
def test_extract_classes_from_type_hint_plain_class() -> None:
    """Extracts a plain class directly."""
    from click import Option

    # A bare (non-generic) class should come back as itself.
    result = extract_classes_from_type_hint(Option)
    assert Option in result
def test_extract_classes_from_type_hint_optional() -> None:
    """Unwraps Optional[X] to find X."""
    from typing import Optional

    from click import Option

    result = extract_classes_from_type_hint(Optional[Option])
    assert Option in result
def test_extract_classes_from_type_hint_union() -> None:
    """Unwraps Union[X, Y] to find both X and Y."""
    from typing import Union

    from click import Command, Option

    result = extract_classes_from_type_hint(Union[Option, Command])
    assert Option in result
    assert Command in result
def test_extract_classes_from_type_hint_list() -> None:
    """Unwraps List[X] to find X."""
    from typing import List

    from click import Option

    result = extract_classes_from_type_hint(List[Option])
    assert Option in result
def test_extract_classes_from_type_hint_filters_builtins() -> None:
    """Filters out builtins like str, int, None."""
    from typing import Optional

    # Optional[str] unwraps to (str, NoneType), both from builtins.
    result = extract_classes_from_type_hint(Optional[str])
    assert len(result) == 0
def test_extract_classes_from_type_hint_callable() -> None:
    """Handles bare Callable without error."""
    from typing import Callable

    # Unparameterized Callable has no args; result is just an empty list.
    result = extract_classes_from_type_hint(Callable)
    assert isinstance(result, list)
def test_extract_classes_from_type_hint_callable_with_args() -> None:
    """Unwraps Callable[[X], Y] to find classes."""
    from typing import Callable

    from click import Context

    result = extract_classes_from_type_hint(Callable[[Context], None])
    assert Context in result
# --- Tests for resolve_transitive_type_deps ---
def test_resolve_transitive_type_deps_click_context() -> None:
    """click.Context.__init__ references Command, which should be found."""
    from click import Command, Context

    deps = resolve_transitive_type_deps(Context)
    dep_names = {cls.__name__ for cls in deps}
    # Accept either name match or identity match in case click re-exports.
    assert "Command" in dep_names or Command in deps
def test_resolve_transitive_type_deps_handles_failure_gracefully() -> None:
    """Returns empty list for a class where get_type_hints fails."""

    class BadClass:
        # The undefined annotation makes typing.get_type_hints raise NameError.
        def __init__(self, x: NonexistentType) -> None:  # type: ignore[name-defined] # noqa: F821
            pass

    result = resolve_transitive_type_deps(BadClass)
    assert result == []
# --- Integration tests for transitive resolution in enrich_testgen_context ---