mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
feat: extract __init__ from external library base classes for test context
Add get_external_base_class_inits to extract __init__ methods from external library base classes (e.g., collections.UserDict) when project classes inherit from them. This helps the LLM understand constructor signatures for mocking.
This commit is contained in:
parent
6009b83f20
commit
dbc88ad105
2 changed files with 219 additions and 2 deletions
|
|
@ -5,6 +5,7 @@ import hashlib
|
|||
import os
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import libcst as cst
|
||||
|
|
@ -29,8 +30,6 @@ from codeflash.models.models import (
|
|||
from codeflash.optimization.function_context import belongs_to_function_qualified
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from jedi.api.classes import Name
|
||||
from libcst import CSTNode
|
||||
|
||||
|
|
@ -138,6 +137,14 @@ def get_code_optimization_context(
|
|||
code_strings=testgen_context.code_strings + imported_class_context.code_strings
|
||||
)
|
||||
|
||||
# Extract __init__ methods from external library base classes
|
||||
# This helps the LLM understand how to mock/test classes that inherit from external libraries
|
||||
external_base_inits = get_external_base_class_inits(testgen_context, project_root_path)
|
||||
if external_base_inits.code_strings:
|
||||
testgen_context = CodeStringsMarkdown(
|
||||
code_strings=testgen_context.code_strings + external_base_inits.code_strings
|
||||
)
|
||||
|
||||
testgen_markdown_code = testgen_context.markdown
|
||||
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
|
||||
if testgen_code_token_length > testgen_token_limit:
|
||||
|
|
@ -155,6 +162,12 @@ def get_code_optimization_context(
|
|||
testgen_context = CodeStringsMarkdown(
|
||||
code_strings=testgen_context.code_strings + imported_class_context.code_strings
|
||||
)
|
||||
# Re-extract external base class inits
|
||||
external_base_inits = get_external_base_class_inits(testgen_context, project_root_path)
|
||||
if external_base_inits.code_strings:
|
||||
testgen_context = CodeStringsMarkdown(
|
||||
code_strings=testgen_context.code_strings + external_base_inits.code_strings
|
||||
)
|
||||
testgen_markdown_code = testgen_context.markdown
|
||||
testgen_code_token_length = encoded_tokens_len(testgen_markdown_code)
|
||||
if testgen_code_token_length > testgen_token_limit:
|
||||
|
|
@ -675,6 +688,107 @@ def get_imported_class_definitions(code_context: CodeStringsMarkdown, project_ro
|
|||
return CodeStringsMarkdown(code_strings=class_code_strings)
|
||||
|
||||
|
||||
def get_external_base_class_inits(code_context: CodeStringsMarkdown, project_root_path: Path) -> CodeStringsMarkdown:
    """Extract ``__init__`` methods from external library base classes.

    Parses the concatenated code of ``code_context``, finds classes whose bases are
    imported from modules outside the project, imports those modules and wraps the
    source of each base class's ``__init__`` in a minimal ``class <Name>:`` stub.
    This helps the LLM understand constructor signatures for mocking or instantiation.

    Bases are resolved when written as a name bound by an absolute
    ``from module import Name [as Alias]`` or as ``module.Name`` where ``module`` is
    bound by ``import module [as alias]``. Relative imports are ignored because they
    always refer to the project itself. The stub uses the class's real name in its
    defining module, even when the scanned code refers to it through an alias.

    Args:
        code_context: Code strings to scan for imports and class definitions.
        project_root_path: Project root used to tell project modules from external ones.

    Returns:
        One ``CodeString`` per distinct external base class whose ``__init__`` source
        could be retrieved. ``file_path`` is the file defining the base class, made
        relative to ``site-packages`` when it lives inside one. The result is empty
        when the code does not parse or no external base class qualifies.

    """
    import importlib
    import inspect
    import textwrap

    all_code = "\n".join(cs.code for cs in code_context.code_strings)

    try:
        tree = ast.parse(all_code)
    except SyntaxError:
        return CodeStringsMarkdown(code_strings=[])

    nodes = list(ast.walk(tree))

    # Pass 1: collect every import binding before looking at classes, so an import
    # nested in an ``if``/``try`` block or located in a later code string is still
    # known when the class that uses it is examined.
    from_imports: dict[str, tuple[str, str]] = {}  # local name -> (module, real name)
    module_aliases: dict[str, str] = {}  # local name -> module it is bound to
    for node in nodes:
        if isinstance(node, ast.ImportFrom):
            # ``level > 0`` marks a relative import, which is project-local by definition.
            if not node.module or node.level:
                continue
            for alias in node.names:
                if alias.name != "*":
                    from_imports[alias.asname or alias.name] = (node.module, alias.name)
        elif isinstance(node, ast.Import):
            for alias in node.names:
                if alias.asname:
                    module_aliases[alias.asname] = alias.name
                else:
                    # ``import a.b`` binds only the top-level package name ``a``.
                    top_level = alias.name.partition(".")[0]
                    module_aliases[top_level] = top_level

    # Pass 2: resolve class bases to (module, class name) pairs, keeping first-seen order.
    external_bases: list[tuple[str, str]] = []
    seen: set[tuple[str, str]] = set()
    for node in nodes:
        if not isinstance(node, ast.ClassDef):
            continue
        for base in node.bases:
            target: tuple[str, str] | None = None
            if isinstance(base, ast.Name):
                target = from_imports.get(base.id)
            elif isinstance(base, ast.Attribute) and isinstance(base.value, ast.Name):
                if base.value.id in module_aliases:
                    target = (module_aliases[base.value.id], base.attr)
                else:
                    # Fall back to matching the attribute name against from-imports.
                    target = from_imports.get(base.attr)

            if target is None or target in seen:
                continue
            seen.add(target)
            if not _is_project_module(target[0], project_root_path):
                external_bases.append(target)

    if not external_bases:
        return CodeStringsMarkdown(code_strings=[])

    code_strings: list[CodeString] = []

    for module_name, base_name in external_bases:
        try:
            module = importlib.import_module(module_name)
            base_class = getattr(module, base_name, None)
            init_method = None if base_class is None else getattr(base_class, "__init__", None)
        except Exception:  # noqa: BLE001
            # Importing runs third-party module code, which may fail in arbitrary ways;
            # this extraction is best-effort, so one bad module must not abort the rest.
            logger.debug(f"Failed to extract __init__ for {module_name}.{base_name}")
            continue

        if init_method is None:
            continue

        try:
            init_source = textwrap.dedent(inspect.getsource(init_method))
            class_file = Path(inspect.getfile(base_class))
        except (OSError, TypeError):
            # No retrievable Python source (e.g. builtin or C-implemented classes).
            continue

        # Report site-packages files relative to site-packages rather than as absolute paths.
        parts = class_file.parts
        if "site-packages" in parts:
            idx = parts.index("site-packages")
            class_file = Path(*parts[idx + 1 :])

        class_source = f"class {base_name}:\n" + textwrap.indent(init_source, "    ")
        code_strings.append(CodeString(code=class_source, file_path=class_file))

    return CodeStringsMarkdown(code_strings=code_strings)
|
||||
|
||||
|
||||
def _is_project_module(module_name: str, project_root_path: Path) -> bool:
    """Report whether ``module_name`` resolves to a file inside the project tree.

    Returns ``False`` when the module spec cannot be found, has no origin, or points
    into a ``site-packages`` directory; otherwise returns whether the origin path
    string starts with the project root followed by the path separator.
    """
    import importlib.util

    try:
        module_spec = importlib.util.find_spec(module_name)
    except (ImportError, ModuleNotFoundError, ValueError):
        return False

    origin = None if module_spec is None else module_spec.origin
    if origin is None:
        return False

    origin_path = Path(origin)

    # site-packages is ruled out before the project-root test because a virtualenv
    # (e.g. .venv/site-packages) can itself live under the project root.
    if path_belongs_to_site_packages(origin_path):
        return False

    # Inside the project only if the origin sits strictly below the root directory.
    root_prefix = str(project_root_path) + os.sep
    return str(origin_path).startswith(root_prefix)
|
||||
|
||||
|
||||
def extract_imports_for_class(module_tree: ast.Module, class_node: ast.ClassDef, module_source: str) -> str:
|
||||
"""Extract import statements needed for a class definition.
|
||||
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from codeflash.context.code_context_extractor import (
|
|||
collect_names_from_annotation,
|
||||
extract_imports_for_class,
|
||||
get_code_optimization_context,
|
||||
get_external_base_class_inits,
|
||||
get_imported_class_definitions,
|
||||
)
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
|
@ -3717,3 +3718,105 @@ class ConfigRegistry:
|
|||
assert "supports_vision: bool" in all_extracted_code, "Should include supports_vision field from Parent"
|
||||
assert "litellm_params:" in all_extracted_code, "Should include litellm_params field from Child"
|
||||
assert "model_list: list" in all_extracted_code, "Should include model_list field from Router"
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_extracts_userdict(tmp_path: Path) -> None:
    """A class inheriting from collections.UserDict yields a stub with UserDict's __init__."""
    expected_code = """\
class UserDict:
    def __init__(self, dict=None, /, **kwargs):
        self.data = {}
        if dict is not None:
            self.update(dict)
        if kwargs:
            self.update(kwargs)
"""
    source = """from collections import UserDict

class MyCustomDict(UserDict):
    pass
"""
    source_file = tmp_path / "mydict.py"
    source_file.write_text(source, encoding="utf-8")

    snippet = CodeString(code=source, file_path=source_file)
    extracted = get_external_base_class_inits(CodeStringsMarkdown(code_strings=[snippet]), tmp_path)

    # Exactly one stub is produced, for UserDict itself.
    assert len(extracted.code_strings) == 1
    stub = extracted.code_strings[0]

    assert stub.code == expected_code
    # The stub is attributed to the stdlib file that defines UserDict.
    assert stub.file_path.as_posix().endswith("collections/__init__.py")
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_skips_project_classes(tmp_path: Path) -> None:
    """No stub is produced for a base class imported from a project-style module.

    Only ``child.py`` is written to disk here; the ``base`` module it imports from
    is not created by this test.
    """
    source = """from base import ProjectBase

class Child(ProjectBase):
    pass
"""
    source_file = tmp_path / "child.py"
    source_file.write_text(source, encoding="utf-8")

    snippet = CodeString(code=source, file_path=source_file)
    extracted = get_external_base_class_inits(CodeStringsMarkdown(code_strings=[snippet]), tmp_path)

    assert extracted.code_strings == []
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_skips_builtins(tmp_path: Path) -> None:
    """Subclassing a builtin such as ``list`` (never imported) produces no stub."""
    source = """class MyList(list):
    pass
"""
    source_file = tmp_path / "mylist.py"
    source_file.write_text(source, encoding="utf-8")

    snippet = CodeString(code=source, file_path=source_file)
    extracted = get_external_base_class_inits(CodeStringsMarkdown(code_strings=[snippet]), tmp_path)

    assert extracted.code_strings == []
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_deduplicates(tmp_path: Path) -> None:
    """Two classes sharing the same external base yield a single stub for that base."""
    expected_code = """\
class UserDict:
    def __init__(self, dict=None, /, **kwargs):
        self.data = {}
        if dict is not None:
            self.update(dict)
        if kwargs:
            self.update(kwargs)
"""
    source = """from collections import UserDict

class MyDict1(UserDict):
    pass

class MyDict2(UserDict):
    pass
"""
    source_file = tmp_path / "mydicts.py"
    source_file.write_text(source, encoding="utf-8")

    snippet = CodeString(code=source, file_path=source_file)
    extracted = get_external_base_class_inits(CodeStringsMarkdown(code_strings=[snippet]), tmp_path)

    # UserDict is inherited twice but must be extracted only once.
    assert len(extracted.code_strings) == 1
    assert extracted.code_strings[0].code == expected_code
|
||||
|
||||
|
||||
def test_get_external_base_class_inits_empty_when_no_inheritance(tmp_path: Path) -> None:
    """A class without any base classes produces no stubs."""
    source = """class SimpleClass:
    pass
"""
    source_file = tmp_path / "simple.py"
    source_file.write_text(source, encoding="utf-8")

    snippet = CodeString(code=source, file_path=source_file)
    extracted = get_external_base_class_inits(CodeStringsMarkdown(code_strings=[snippet]), tmp_path)

    assert extracted.code_strings == []
|
||||
|
|
|
|||
Loading…
Reference in a new issue