mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
feat: resolve transitive type dependencies in get_external_class_inits
Add BFS-based transitive resolution so that classes referenced in __init__ type annotations of imported external classes are also extracted. This gives the LLM the constructor signatures it needs to instantiate parameter types.
This commit is contained in:
parent
8eb1c86245
commit
e837ad9d17
2 changed files with 284 additions and 30 deletions
|
|
@ -827,16 +827,117 @@ def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_roo
|
|||
return CodeStringsMarkdown(code_strings=code_strings)
|
||||
|
||||
|
||||
MAX_TRANSITIVE_DEPTH = 2
|
||||
|
||||
|
||||
def extract_classes_from_type_hint(hint: object) -> list[type]:
|
||||
"""Recursively extract concrete class objects from a type annotation.
|
||||
|
||||
Unwraps Optional, Union, List, Dict, Callable, Annotated, etc.
|
||||
Filters out builtins and typing module types.
|
||||
"""
|
||||
import typing
|
||||
|
||||
classes: list[type] = []
|
||||
origin = getattr(hint, "__origin__", None)
|
||||
args = getattr(hint, "__args__", None)
|
||||
|
||||
if origin is not None and args:
|
||||
for arg in args:
|
||||
classes.extend(extract_classes_from_type_hint(arg))
|
||||
elif isinstance(hint, type):
|
||||
module = getattr(hint, "__module__", "")
|
||||
if module not in ("builtins", "typing", "typing_extensions", "types"):
|
||||
classes.append(hint)
|
||||
# Handle typing.Annotated on older Pythons where __origin__ may not be set
|
||||
if hasattr(typing, "get_args") and origin is None and args is None:
|
||||
try:
|
||||
inner_args = typing.get_args(hint)
|
||||
if inner_args:
|
||||
for arg in inner_args:
|
||||
classes.extend(extract_classes_from_type_hint(arg))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return classes
|
||||
|
||||
|
||||
def resolve_transitive_type_deps(cls: type) -> list[type]:
|
||||
"""Find external classes referenced in cls.__init__ type annotations.
|
||||
|
||||
Returns classes from site-packages that have a custom __init__.
|
||||
"""
|
||||
import inspect
|
||||
import typing
|
||||
|
||||
try:
|
||||
init_method = getattr(cls, "__init__")
|
||||
hints = typing.get_type_hints(init_method)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
deps: list[type] = []
|
||||
for param_name, hint in hints.items():
|
||||
if param_name == "return":
|
||||
continue
|
||||
for dep_cls in extract_classes_from_type_hint(hint):
|
||||
if dep_cls is cls:
|
||||
continue
|
||||
init_method = getattr(dep_cls, "__init__", None)
|
||||
if init_method is None or init_method is object.__init__:
|
||||
continue
|
||||
try:
|
||||
class_file = Path(inspect.getfile(dep_cls))
|
||||
except (OSError, TypeError):
|
||||
continue
|
||||
if not path_belongs_to_site_packages(class_file):
|
||||
continue
|
||||
deps.append(dep_cls)
|
||||
|
||||
return deps
|
||||
|
||||
|
||||
def extract_init_stub_for_class(cls: type, class_name: str) -> CodeString | None:
|
||||
"""Extract a stub containing the class definition with only its __init__ method."""
|
||||
import inspect
|
||||
import textwrap
|
||||
|
||||
init_method = getattr(cls, "__init__", None)
|
||||
if init_method is None or init_method is object.__init__:
|
||||
return None
|
||||
|
||||
try:
|
||||
class_file = Path(inspect.getfile(cls))
|
||||
except (OSError, TypeError):
|
||||
return None
|
||||
|
||||
if not path_belongs_to_site_packages(class_file):
|
||||
return None
|
||||
|
||||
try:
|
||||
init_source = inspect.getsource(init_method)
|
||||
init_source = textwrap.dedent(init_source)
|
||||
except (OSError, TypeError):
|
||||
return None
|
||||
|
||||
parts = class_file.parts
|
||||
if "site-packages" in parts:
|
||||
idx = parts.index("site-packages")
|
||||
class_file = Path(*parts[idx + 1 :])
|
||||
|
||||
class_source = f"class {class_name}:\n" + textwrap.indent(init_source, " ")
|
||||
return CodeString(code=class_source, file_path=class_file)
|
||||
|
||||
|
||||
def get_external_class_inits(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown:
|
||||
"""Extract __init__ methods from directly imported external library classes.
|
||||
|
||||
Scans the code context for classes imported from external packages (site-packages) and extracts
|
||||
their __init__ methods. This helps the LLM understand constructor signatures for instantiation
|
||||
in generated tests.
|
||||
their __init__ methods, including transitive type dependencies found in __init__ annotations.
|
||||
This helps the LLM understand constructor signatures for instantiation in generated tests.
|
||||
"""
|
||||
import importlib
|
||||
import inspect
|
||||
import textwrap
|
||||
|
||||
all_code = "\n".join(cs.code for cs in code_context.code_strings)
|
||||
|
||||
|
|
@ -883,7 +984,13 @@ def get_external_class_inits(code_context: CodeStringsMarkdown, project_root_pat
|
|||
|
||||
code_strings: list[CodeString] = []
|
||||
imported_module_cache: dict[str, object] = {}
|
||||
processed_classes: set[type] = set()
|
||||
emitted_names: set[str] = set()
|
||||
|
||||
# BFS worklist: (class_object, class_name, depth)
|
||||
worklist: list[tuple[type, str, int]] = []
|
||||
|
||||
# Seed the worklist with directly imported classes
|
||||
for class_name, module_name in external_imports:
|
||||
try:
|
||||
module = imported_module_cache.get(module_name)
|
||||
|
|
@ -895,36 +1002,32 @@ def get_external_class_inits(code_context: CodeStringsMarkdown, project_root_pat
|
|||
if cls is None or not inspect.isclass(cls):
|
||||
continue
|
||||
|
||||
init_method = getattr(cls, "__init__", None)
|
||||
if init_method is None or init_method is object.__init__:
|
||||
continue
|
||||
|
||||
try:
|
||||
class_file = Path(inspect.getfile(cls))
|
||||
except (OSError, TypeError):
|
||||
continue
|
||||
|
||||
if not path_belongs_to_site_packages(class_file):
|
||||
continue
|
||||
|
||||
try:
|
||||
init_source = inspect.getsource(init_method)
|
||||
init_source = textwrap.dedent(init_source)
|
||||
except (OSError, TypeError):
|
||||
continue
|
||||
|
||||
parts = class_file.parts
|
||||
if "site-packages" in parts:
|
||||
idx = parts.index("site-packages")
|
||||
class_file = Path(*parts[idx + 1 :])
|
||||
|
||||
class_source = f"class {class_name}:\n" + textwrap.indent(init_source, " ")
|
||||
code_strings.append(CodeString(code=class_source, file_path=class_file))
|
||||
|
||||
worklist.append((cls, class_name, 0))
|
||||
except (ImportError, ModuleNotFoundError, AttributeError):
|
||||
logger.debug(f"Failed to extract __init__ for {module_name}.{class_name}")
|
||||
logger.debug(f"Failed to import {module_name}.{class_name}")
|
||||
continue
|
||||
|
||||
while worklist:
|
||||
cls, class_name, depth = worklist.pop(0)
|
||||
|
||||
if cls in processed_classes:
|
||||
continue
|
||||
processed_classes.add(cls)
|
||||
|
||||
stub = extract_init_stub_for_class(cls, class_name)
|
||||
if stub is None:
|
||||
continue
|
||||
|
||||
if class_name not in emitted_names:
|
||||
code_strings.append(stub)
|
||||
emitted_names.add(class_name)
|
||||
|
||||
# Resolve transitive type dependencies up to MAX_TRANSITIVE_DEPTH
|
||||
if depth < MAX_TRANSITIVE_DEPTH:
|
||||
for dep_cls in resolve_transitive_type_deps(cls):
|
||||
if dep_cls not in processed_classes:
|
||||
worklist.append((dep_cls, dep_cls.__name__, depth + 1))
|
||||
|
||||
return CodeStringsMarkdown(code_strings=code_strings)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,11 +12,13 @@ from codeflash.code_utils.code_extractor import GlobalAssignmentCollector, add_g
|
|||
from codeflash.code_utils.code_replacer import replace_functions_and_add_imports
|
||||
from codeflash.context.code_context_extractor import (
|
||||
collect_names_from_annotation,
|
||||
extract_classes_from_type_hint,
|
||||
extract_imports_for_class,
|
||||
get_code_optimization_context,
|
||||
get_external_base_class_inits,
|
||||
get_external_class_inits,
|
||||
get_imported_class_definitions,
|
||||
resolve_transitive_type_deps,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodeString, CodeStringsMarkdown, FunctionParent
|
||||
|
|
@ -4752,3 +4754,152 @@ def test_get_external_class_inits_empty_when_no_imports(tmp_path: Path) -> None:
|
|||
result = get_external_class_inits(context, tmp_path)
|
||||
|
||||
assert result.code_strings == []
|
||||
|
||||
|
||||
# --- Tests for extract_classes_from_type_hint ---
|
||||
|
||||
|
||||
def test_extract_classes_from_type_hint_plain_class() -> None:
|
||||
"""Extracts a plain class directly."""
|
||||
from click import Option
|
||||
|
||||
result = extract_classes_from_type_hint(Option)
|
||||
assert Option in result
|
||||
|
||||
|
||||
def test_extract_classes_from_type_hint_optional() -> None:
|
||||
"""Unwraps Optional[X] to find X."""
|
||||
from typing import Optional
|
||||
|
||||
from click import Option
|
||||
|
||||
result = extract_classes_from_type_hint(Optional[Option])
|
||||
assert Option in result
|
||||
|
||||
|
||||
def test_extract_classes_from_type_hint_union() -> None:
|
||||
"""Unwraps Union[X, Y] to find both X and Y."""
|
||||
from typing import Union
|
||||
|
||||
from click import Command, Option
|
||||
|
||||
result = extract_classes_from_type_hint(Union[Option, Command])
|
||||
assert Option in result
|
||||
assert Command in result
|
||||
|
||||
|
||||
def test_extract_classes_from_type_hint_list() -> None:
|
||||
"""Unwraps List[X] to find X."""
|
||||
from typing import List
|
||||
|
||||
from click import Option
|
||||
|
||||
result = extract_classes_from_type_hint(List[Option])
|
||||
assert Option in result
|
||||
|
||||
|
||||
def test_extract_classes_from_type_hint_filters_builtins() -> None:
|
||||
"""Filters out builtins like str, int, None."""
|
||||
from typing import Optional
|
||||
|
||||
result = extract_classes_from_type_hint(Optional[str])
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
def test_extract_classes_from_type_hint_callable() -> None:
|
||||
"""Handles bare Callable without error."""
|
||||
from typing import Callable
|
||||
|
||||
result = extract_classes_from_type_hint(Callable)
|
||||
assert isinstance(result, list)
|
||||
|
||||
|
||||
def test_extract_classes_from_type_hint_callable_with_args() -> None:
|
||||
"""Unwraps Callable[[X], Y] to find classes."""
|
||||
from typing import Callable
|
||||
|
||||
from click import Context
|
||||
|
||||
result = extract_classes_from_type_hint(Callable[[Context], None])
|
||||
assert Context in result
|
||||
|
||||
|
||||
# --- Tests for resolve_transitive_type_deps ---
|
||||
|
||||
|
||||
def test_resolve_transitive_type_deps_click_context() -> None:
|
||||
"""click.Context.__init__ references Command, which should be found."""
|
||||
from click import Command, Context
|
||||
|
||||
deps = resolve_transitive_type_deps(Context)
|
||||
dep_names = {cls.__name__ for cls in deps}
|
||||
assert "Command" in dep_names or Command in deps
|
||||
|
||||
|
||||
def test_resolve_transitive_type_deps_handles_failure_gracefully() -> None:
|
||||
"""Returns empty list for a class where get_type_hints fails."""
|
||||
|
||||
class BadClass:
|
||||
def __init__(self, x: "NonexistentType") -> None: # type: ignore[name-defined] # noqa: F821
|
||||
pass
|
||||
|
||||
result = resolve_transitive_type_deps(BadClass)
|
||||
assert result == []
|
||||
|
||||
|
||||
# --- Integration tests for transitive resolution in get_external_class_inits ---
|
||||
|
||||
|
||||
def test_get_external_class_inits_transitive_deps(tmp_path: Path) -> None:
|
||||
"""Extracts transitive type dependencies from __init__ annotations."""
|
||||
code = """from click import Context
|
||||
|
||||
def my_func(ctx: Context) -> None:
|
||||
pass
|
||||
"""
|
||||
code_path = tmp_path / "myfunc.py"
|
||||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
|
||||
class_names = {cs.code.split("\n")[0].replace("class ", "").rstrip(":") for cs in result.code_strings}
|
||||
assert "Context" in class_names
|
||||
# Command is a transitive dep via Context.__init__
|
||||
assert "Command" in class_names
|
||||
|
||||
|
||||
def test_get_external_class_inits_no_infinite_loops(tmp_path: Path) -> None:
|
||||
"""Handles classes with circular type references without infinite loops."""
|
||||
# click.Context references Command, and Command references Context back
|
||||
# This should terminate without issues due to the processed_classes set
|
||||
code = """from click import Context
|
||||
|
||||
def my_func(ctx: Context) -> None:
|
||||
pass
|
||||
"""
|
||||
code_path = tmp_path / "myfunc.py"
|
||||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
|
||||
# Should complete without hanging; just verify we got results
|
||||
assert len(result.code_strings) >= 1
|
||||
|
||||
|
||||
def test_get_external_class_inits_no_duplicate_stubs(tmp_path: Path) -> None:
|
||||
"""Does not emit duplicate stubs for the same class name."""
|
||||
code = """from click import Context
|
||||
|
||||
def my_func(ctx: Context) -> None:
|
||||
pass
|
||||
"""
|
||||
code_path = tmp_path / "myfunc.py"
|
||||
code_path.write_text(code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=code, file_path=code_path)])
|
||||
result = get_external_class_inits(context, tmp_path)
|
||||
|
||||
class_names = [cs.code.split("\n")[0].replace("class ", "").rstrip(":") for cs in result.code_strings]
|
||||
assert len(class_names) == len(set(class_names)), f"Duplicate class stubs found: {class_names}"
|
||||
|
|
|
|||
Loading…
Reference in a new issue