feat: resolve transitive type dependencies in get_external_class_inits

Add BFS-based transitive resolution so that classes referenced in __init__
type annotations of imported external classes are also extracted. This gives
the LLM the constructor signatures it needs to instantiate parameter types.
This commit is contained in:
Kevin Turcios 2026-02-13 09:35:30 -05:00
parent 8eb1c86245
commit e837ad9d17
2 changed files with 284 additions and 30 deletions

View file

@ -827,16 +827,117 @@ def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_roo
return CodeStringsMarkdown(code_strings=code_strings)
MAX_TRANSITIVE_DEPTH = 2
def extract_classes_from_type_hint(hint: object) -> list[type]:
"""Recursively extract concrete class objects from a type annotation.
Unwraps Optional, Union, List, Dict, Callable, Annotated, etc.
Filters out builtins and typing module types.
"""
import typing
classes: list[type] = []
origin = getattr(hint, "__origin__", None)
args = getattr(hint, "__args__", None)
if origin is not None and args:
for arg in args:
classes.extend(extract_classes_from_type_hint(arg))
elif isinstance(hint, type):
module = getattr(hint, "__module__", "")
if module not in ("builtins", "typing", "typing_extensions", "types"):
classes.append(hint)
# Handle typing.Annotated on older Pythons where __origin__ may not be set
if hasattr(typing, "get_args") and origin is None and args is None:
try:
inner_args = typing.get_args(hint)
if inner_args:
for arg in inner_args:
classes.extend(extract_classes_from_type_hint(arg))
except Exception:
pass
return classes
def resolve_transitive_type_deps(cls: type) -> list[type]:
"""Find external classes referenced in cls.__init__ type annotations.
Returns classes from site-packages that have a custom __init__.
"""
import inspect
import typing
try:
init_method = getattr(cls, "__init__")
hints = typing.get_type_hints(init_method)
except Exception:
return []
deps: list[type] = []
for param_name, hint in hints.items():
if param_name == "return":
continue
for dep_cls in extract_classes_from_type_hint(hint):
if dep_cls is cls:
continue
init_method = getattr(dep_cls, "__init__", None)
if init_method is None or init_method is object.__init__:
continue
try:
class_file = Path(inspect.getfile(dep_cls))
except (OSError, TypeError):
continue
if not path_belongs_to_site_packages(class_file):
continue
deps.append(dep_cls)
return deps
def extract_init_stub_for_class(cls: type, class_name: str) -> CodeString | None:
"""Extract a stub containing the class definition with only its __init__ method."""
import inspect
import textwrap
init_method = getattr(cls, "__init__", None)
if init_method is None or init_method is object.__init__:
return None
try:
class_file = Path(inspect.getfile(cls))
except (OSError, TypeError):
return None
if not path_belongs_to_site_packages(class_file):
return None
try:
init_source = inspect.getsource(init_method)
init_source = textwrap.dedent(init_source)
except (OSError, TypeError):
return None
parts = class_file.parts
if "site-packages" in parts:
idx = parts.index("site-packages")
class_file = Path(*parts[idx + 1 :])
class_source = f"class {class_name}:\n" + textwrap.indent(init_source, " ")
return CodeString(code=class_source, file_path=class_file)
def get_external_class_inits(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown:
"""Extract __init__ methods from directly imported external library classes.
Scans the code context for classes imported from external packages (site-packages) and extracts
their __init__ methods. This helps the LLM understand constructor signatures for instantiation
in generated tests.
their __init__ methods, including transitive type dependencies found in __init__ annotations.
This helps the LLM understand constructor signatures for instantiation in generated tests.
"""
import importlib
import inspect
import textwrap
all_code = "\n".join(cs.code for cs in code_context.code_strings)
@ -883,7 +984,13 @@ def get_external_class_inits(code_context: CodeStringsMarkdown, project_root_pat
code_strings: list[CodeString] = []
imported_module_cache: dict[str, object] = {}
processed_classes: set[type] = set()
emitted_names: set[str] = set()
# BFS worklist: (class_object, class_name, depth)
worklist: list[tuple[type, str, int]] = []
# Seed the worklist with directly imported classes
for class_name, module_name in external_imports:
try:
module = imported_module_cache.get(module_name)
@ -895,36 +1002,32 @@ def get_external_class_inits(code_context: CodeStringsMarkdown, project_root_pat
if cls is None or not inspect.isclass(cls):
continue
init_method = getattr(cls, "__init__", None)
if init_method is None or init_method is object.__init__:
continue
try:
class_file = Path(inspect.getfile(cls))
except (OSError, TypeError):
continue
if not path_belongs_to_site_packages(class_file):
continue
try:
init_source = inspect.getsource(init_method)
init_source = textwrap.dedent(init_source)
except (OSError, TypeError):
continue
parts = class_file.parts
if "site-packages" in parts:
idx = parts.index("site-packages")
class_file = Path(*parts[idx + 1 :])
class_source = f"class {class_name}:\n" + textwrap.indent(init_source, " ")
code_strings.append(CodeString(code=class_source, file_path=class_file))
worklist.append((cls, class_name, 0))
except (ImportError, ModuleNotFoundError, AttributeError):
logger.debug(f"Failed to extract __init__ for {module_name}.{class_name}")
logger.debug(f"Failed to import {module_name}.{class_name}")
continue
while worklist:
cls, class_name, depth = worklist.pop(0)
if cls in processed_classes:
continue
processed_classes.add(cls)
stub = extract_init_stub_for_class(cls, class_name)
if stub is None:
continue
if class_name not in emitted_names:
code_strings.append(stub)
emitted_names.add(class_name)
# Resolve transitive type dependencies up to MAX_TRANSITIVE_DEPTH
if depth < MAX_TRANSITIVE_DEPTH:
for dep_cls in resolve_transitive_type_deps(cls):
if dep_cls not in processed_classes:
worklist.append((dep_cls, dep_cls.__name__, depth + 1))
return CodeStringsMarkdown(code_strings=code_strings)

View file

@ -12,11 +12,13 @@ from codeflash.code_utils.code_extractor import GlobalAssignmentCollector, add_g
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
from codeflash.context.code_context_extractor import (
collect_names_from_annotation,
extract_classes_from_type_hint,
extract_imports_for_class,
get_code_optimization_context,
get_external_base_class_inits,
get_external_class_inits,
get_imported_class_definitions,
resolve_transitive_type_deps,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeString, CodeStringsMarkdown, FunctionParent
@ -4752,3 +4754,152 @@ def test_get_external_class_inits_empty_when_no_imports(tmp_path: Path) -> None:
result = get_external_class_inits(context, tmp_path)
assert result.code_strings == []
# --- Tests for extract_classes_from_type_hint ---
def test_extract_classes_from_type_hint_plain_class() -> None:
"""Extracts a plain class directly."""
from click import Option
result = extract_classes_from_type_hint(Option)
assert Option in result
def test_extract_classes_from_type_hint_optional() -> None:
"""Unwraps Optional[X] to find X."""
from typing import Optional
from click import Option
result = extract_classes_from_type_hint(Optional[Option])
assert Option in result
def test_extract_classes_from_type_hint_union() -> None:
"""Unwraps Union[X, Y] to find both X and Y."""
from typing import Union
from click import Command, Option
result = extract_classes_from_type_hint(Union[Option, Command])
assert Option in result
assert Command in result
def test_extract_classes_from_type_hint_list() -> None:
"""Unwraps List[X] to find X."""
from typing import List
from click import Option
result = extract_classes_from_type_hint(List[Option])
assert Option in result
def test_extract_classes_from_type_hint_filters_builtins() -> None:
"""Filters out builtins like str, int, None."""
from typing import Optional
result = extract_classes_from_type_hint(Optional[str])
assert len(result) == 0
def test_extract_classes_from_type_hint_callable() -> None:
"""Handles bare Callable without error."""
from typing import Callable
result = extract_classes_from_type_hint(Callable)
assert isinstance(result, list)
def test_extract_classes_from_type_hint_callable_with_args() -> None:
"""Unwraps Callable[[X], Y] to find classes."""
from typing import Callable
from click import Context
result = extract_classes_from_type_hint(Callable[[Context], None])
assert Context in result
# --- Tests for resolve_transitive_type_deps ---
def test_resolve_transitive_type_deps_click_context() -> None:
"""click.Context.__init__ references Command, which should be found."""
from click import Command, Context
deps = resolve_transitive_type_deps(Context)
dep_names = {cls.__name__ for cls in deps}
assert "Command" in dep_names or Command in deps
def test_resolve_transitive_type_deps_handles_failure_gracefully() -> None:
"""Returns empty list for a class where get_type_hints fails."""
class BadClass:
def __init__(self, x: "NonexistentType") -> None: # type: ignore[name-defined] # noqa: F821
pass
result = resolve_transitive_type_deps(BadClass)
assert result == []
# --- Integration tests for transitive resolution in get_external_class_inits ---
def test_get_external_class_inits_transitive_deps(tmp_path: Path) -> None:
"""Extracts transitive type dependencies from __init__ annotations."""
code = """from click import Context
def my_func(ctx: Context) -> None:
pass
"""
code_path = tmp_path / "myfunc.py"
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_class_inits(context, tmp_path)
class_names = {cs.code.split("\n")[0].replace("class ", "").rstrip(":") for cs in result.code_strings}
assert "Context" in class_names
# Command is a transitive dep via Context.__init__
assert "Command" in class_names
def test_get_external_class_inits_no_infinite_loops(tmp_path: Path) -> None:
"""Handles classes with circular type references without infinite loops."""
# click.Context references Command, and Command references Context back
# This should terminate without issues due to the processed_classes set
code = """from click import Context
def my_func(ctx: Context) -> None:
pass
"""
code_path = tmp_path / "myfunc.py"
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_class_inits(context, tmp_path)
# Should complete without hanging; just verify we got results
assert len(result.code_strings) >= 1
def test_get_external_class_inits_no_duplicate_stubs(tmp_path: Path) -> None:
"""Does not emit duplicate stubs for the same class name."""
code = """from click import Context
def my_func(ctx: Context) -> None:
pass
"""
code_path = tmp_path / "myfunc.py"
code_path.write_text(code, encoding="utf-8")
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
result = get_external_class_inits(context, tmp_path)
class_names = [cs.code.split("\n")[0].replace("class ", "").rstrip(":") for cs in result.code_strings]
assert len(class_names) == len(set(class_names)), f"Duplicate class stubs found: {class_names}"