feat: extract __init__ from external library base classes for test context

Add get_external_base_class_inits to extract __init__ methods from external
library base classes (e.g., collections.UserDict) when project classes inherit
from them. This helps the LLM understand constructor signatures for mocking.
This commit is contained in:
Kevin Turcios 2026-01-23 20:46:51 -05:00
parent 6009b83f20
commit dbc88ad105
2 changed files with 219 additions and 2 deletions

View file

@ -5,6 +5,7 @@ import hashlib
import os
from collections import defaultdict
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, cast
import libcst as cst
@ -29,8 +30,6 @@ from codeflash.models.models import (
from codeflash.optimization.function_context import belongs_to_function_qualified
if TYPE_CHECKING:
from pathlib import Path
from jedi.api.classes import Name
from libcst import CSTNode
@ -138,6 +137,14 @@ def get_code_optimization_context(
code_strings=testgen_context.code_strings + imported_class_context.code_strings
)
# Extract __init__ methods from external library base classes
# This helps the LLM understand how to mock/test classes that inherit from external libraries
external_base_inits = get_external_base_class_inits(testgen_context, project_root_path)
if external_base_inits.code_strings:
testgen_context = CodeStringsMarkdown(
code_strings=testgen_context.code_strings + external_base_inits.code_strings
)
testgen_markdown_code = testgen_context.markdown
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
if testgen_code_token_length > testgen_token_limit:
@ -155,6 +162,12 @@ def get_code_optimization_context(
testgen_context = CodeStringsMarkdown(
code_strings=testgen_context.code_strings + imported_class_context.code_strings
)
# Re-extract external base class inits
external_base_inits = get_external_base_class_inits(testgen_context, project_root_path)
if external_base_inits.code_strings:
testgen_context = CodeStringsMarkdown(
code_strings=testgen_context.code_strings + external_base_inits.code_strings
)
testgen_markdown_code = testgen_context.markdown
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
if testgen_code_token_length > testgen_token_limit:
@ -675,6 +688,107 @@ def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_ro
return CodeStringsMarkdown(code_strings=class_code_strings)
def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown:
    """Extract __init__ methods from external library base classes.

    Scans the code context for classes that inherit from external libraries and extracts
    just their __init__ methods. This helps the LLM understand constructor signatures
    for mocking or instantiation.

    Base classes are resolved from absolute ``from module import Name [as Alias]``
    imports and from ``module.Name`` references backed by ``import module [as alias]``.
    Relative imports are ignored because they always point inside the project.

    Args:
        code_context: Code strings to scan; their code is concatenated and parsed as one module.
        project_root_path: Project root, used to decide whether a module is project-local.

    Returns:
        One stub ``class <Name>:`` per distinct external base class whose ``__init__``
        source could be retrieved, containing only that ``__init__``. Empty when the
        code does not parse, nothing qualifies, or no source is available.
    """
    import importlib
    import inspect
    import textwrap

    all_code = "\n".join(cs.code for cs in code_context.code_strings)
    try:
        tree = ast.parse(all_code)
    except SyntaxError:
        return CodeStringsMarkdown(code_strings=[])

    # ast.walk is breadth-first, so a top-level class can be visited before an import
    # nested in e.g. a try/if block. Collect every import first, then inspect classes.
    nodes = list(ast.walk(tree))

    # Local name -> (module, real attribute name); the real name is needed because an
    # alias such as `UserDict as UD` does not exist as an attribute of the module.
    from_imports: dict[str, tuple[str, str]] = {}
    # Local name bound by `import x` / `import x.y as z` -> module to import.
    module_aliases: dict[str, str] = {}
    for node in nodes:
        if isinstance(node, ast.ImportFrom):
            if not node.module or node.level:
                continue
            for alias in node.names:
                if alias.name != "*":
                    from_imports[alias.asname or alias.name] = (node.module, alias.name)
        elif isinstance(node, ast.Import):
            for alias in node.names:
                if alias.asname:
                    module_aliases[alias.asname] = alias.name
                else:
                    # `import a.b` binds only the top-level package name `a`.
                    top_level = alias.name.partition(".")[0]
                    module_aliases[top_level] = top_level

    external_bases: list[tuple[str, str]] = []
    seen: set[tuple[str, str]] = set()
    for node in nodes:
        if not isinstance(node, ast.ClassDef):
            continue
        for base in node.bases:
            target: tuple[str, str] | None = None
            if isinstance(base, ast.Name):
                target = from_imports.get(base.id)
            elif isinstance(base, ast.Attribute) and isinstance(base.value, ast.Name):
                if base.value.id in module_aliases:
                    target = (module_aliases[base.value.id], base.attr)
                else:
                    # Fallback kept from the previous behaviour: match on the attribute name.
                    target = from_imports.get(base.attr)
            if target is None or target in seen:
                continue
            # De-duplicate up front so a base class is resolved and imported at most once.
            seen.add(target)
            if not _is_project_module(target[0], project_root_path):
                external_bases.append(target)

    if not external_bases:
        return CodeStringsMarkdown(code_strings=[])

    code_strings: list[CodeString] = []
    for module_name, class_name in external_bases:
        try:
            module = importlib.import_module(module_name)
            base_class = getattr(module, class_name, None)
            if base_class is None:
                continue
            init_method = getattr(base_class, "__init__", None)
            if init_method is None:
                continue
            try:
                # Raises TypeError/OSError for builtins and C-implemented classes.
                init_source = textwrap.dedent(inspect.getsource(init_method))
                class_file = Path(inspect.getfile(base_class))
            except (OSError, TypeError):
                continue
            parts = class_file.parts
            if "site-packages" in parts:
                # Report the path relative to site-packages instead of the absolute path.
                class_file = Path(*parts[parts.index("site-packages") + 1 :])
            class_source = f"class {class_name}:\n" + textwrap.indent(init_source, "    ")
            code_strings.append(CodeString(code=class_source, file_path=class_file))
        except (ImportError, ModuleNotFoundError, AttributeError):
            logger.debug(f"Failed to extract __init__ for {module_name}.{class_name}")
            continue

    return CodeStringsMarkdown(code_strings=code_strings)
def _is_project_module(module_name: str, project_root_path: Path) -> bool:
"""Check if a module is part of the project (not external/stdlib)."""
import importlib.util
try:
spec = importlib.util.find_spec(module_name)
except (ImportError, ModuleNotFoundError, ValueError):
return False
else:
if spec is None or spec.origin is None:
return False
module_path = Path(spec.origin)
# Check if the module is in site-packages (external dependency)
# This must be checked first because .venv/site-packages is under project root
if path_belongs_to_site_packages(module_path):
return False
# Check if the module is within the project root
return str(module_path).startswith(str(project_root_path) + os.sep)
def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str:
"""Extract import statements needed for a class definition.

View file

@ -14,6 +14,7 @@ from codeflash.context.code_context_extractor import (
collect_names_from_annotation,
extract_imports_for_class,
get_code_optimization_context,
get_external_base_class_inits,
get_imported_class_definitions,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
@ -3717,3 +3718,105 @@ class ConfigRegistry:
assert "supports_vision: bool" in all_extracted_code, "Should include supports_vision field from Parent"
assert "litellm_params:" in all_extracted_code, "Should include litellm_params field from Child"
assert "model_list: list" in all_extracted_code, "Should include model_list field from Router"
def test_get_external_base_class_inits_extracts_userdict(tmp_path: Path) -> None:
    """A subclass of collections.UserDict yields one stub class carrying UserDict's __init__."""
    source = """from collections import UserDict
class MyCustomDict(UserDict):
    pass
"""
    module_file = tmp_path / "mydict.py"
    module_file.write_text(source, encoding="utf-8")
    markdown = CodeStringsMarkdown(code_strings=[CodeString(code=source, file_path=module_file)])

    extracted = get_external_base_class_inits(markdown, tmp_path).code_strings

    assert len(extracted) == 1
    stub = extracted[0]
    expected_stub = """\
class UserDict:
    def __init__(self, dict=None, /, **kwargs):
        self.data = {}
        if dict is not None:
            self.update(dict)
        if kwargs:
            self.update(kwargs)
"""
    assert stub.code == expected_stub
    # The reported path points at the stdlib package that defines UserDict.
    assert stub.file_path.as_posix().endswith("collections/__init__.py")
def test_get_external_base_class_inits_skips_project_classes(tmp_path: Path) -> None:
    """Returns empty when base class is from the project, not external.

    The base module is actually created under ``tmp_path`` and made importable, and
    its ``__init__`` has retrievable source. The result can therefore only be empty
    because the module was recognised as project-local, not because the import failed.
    """
    import importlib
    import sys

    # A distinctive name avoids resolving to an unrelated installed module.
    base_module = "cf_fixture_project_base"
    base_code = """class ProjectBase:
    def __init__(self, value):
        self.value = value
"""
    (tmp_path / f"{base_module}.py").write_text(base_code, encoding="utf-8")

    child_code = f"""from {base_module} import ProjectBase
class Child(ProjectBase):
    pass
"""
    child_path = tmp_path / "child.py"
    child_path.write_text(child_code, encoding="utf-8")
    context = CodeStringsMarkdown(code_strings=[CodeString(code=child_code, file_path=child_path)])

    root_entry = str(tmp_path)
    sys.path.insert(0, root_entry)
    importlib.invalidate_caches()
    try:
        result = get_external_base_class_inits(context, tmp_path)
    finally:
        # Leave the interpreter's import state as it was for the other tests.
        sys.path.remove(root_entry)
        sys.path_importer_cache.pop(root_entry, None)
        sys.modules.pop(base_module, None)
        importlib.invalidate_caches()

    assert result.code_strings == []
def test_get_external_base_class_inits_skips_builtins(tmp_path: Path) -> None:
    """Yields nothing for a builtin base such as list, which is never brought in by an import."""
    source = """class MyList(list):
    pass
"""
    module_file = tmp_path / "mylist.py"
    module_file.write_text(source, encoding="utf-8")
    markdown = CodeStringsMarkdown(code_strings=[CodeString(code=source, file_path=module_file)])

    extracted = get_external_base_class_inits(markdown, tmp_path)

    assert extracted.code_strings == []
def test_get_external_base_class_inits_deduplicates(tmp_path: Path) -> None:
    """Two subclasses of the same external base produce a single extracted stub."""
    source = """from collections import UserDict
class MyDict1(UserDict):
    pass
class MyDict2(UserDict):
    pass
"""
    module_file = tmp_path / "mydicts.py"
    module_file.write_text(source, encoding="utf-8")
    markdown = CodeStringsMarkdown(code_strings=[CodeString(code=source, file_path=module_file)])

    extracted = get_external_base_class_inits(markdown, tmp_path).code_strings

    # Both subclasses share UserDict, so only one stub may be emitted.
    assert len(extracted) == 1
    expected_stub = """\
class UserDict:
    def __init__(self, dict=None, /, **kwargs):
        self.data = {}
        if dict is not None:
            self.update(dict)
        if kwargs:
            self.update(kwargs)
"""
    assert extracted[0].code == expected_stub
def test_get_external_base_class_inits_empty_when_no_inheritance(tmp_path: Path) -> None:
    """A class with no base classes yields no extracted stubs."""
    source = """class SimpleClass:
    pass
"""
    module_file = tmp_path / "simple.py"
    module_file.write_text(source, encoding="utf-8")
    markdown = CodeStringsMarkdown(code_strings=[CodeString(code=source, file_path=module_file)])

    extracted = get_external_base_class_inits(markdown, tmp_path)

    assert extracted.code_strings == []