mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
feat: improve test generation context for external library types
Extend extract_parameter_type_constructors to scan function bodies for isinstance/type() patterns and collect base class names from enclosing classes. Add one-level transitive stub extraction so the LLM also sees constructor signatures for types referenced in __init__ parameters. In enrich_testgen_context, branch on source: project classes get full definitions, third-party (site-packages) classes get compact __init__ stubs to avoid blowing token limits.
This commit is contained in:
parent
d284be521d
commit
5cee1b5b48
2 changed files with 270 additions and 8 deletions
|
|
@ -837,6 +837,41 @@ def extract_parameter_type_constructors(
|
|||
if func_node.args.kwarg:
|
||||
type_names |= collect_type_names_from_annotation(func_node.args.kwarg.annotation)
|
||||
|
||||
# Scan function body for isinstance(x, SomeType) and type(x) is/== SomeType patterns
|
||||
for body_node in ast.walk(func_node):
|
||||
if (
|
||||
isinstance(body_node, ast.Call)
|
||||
and isinstance(body_node.func, ast.Name)
|
||||
and body_node.func.id == "isinstance"
|
||||
):
|
||||
if len(body_node.args) >= 2:
|
||||
second_arg = body_node.args[1]
|
||||
if isinstance(second_arg, ast.Name):
|
||||
type_names.add(second_arg.id)
|
||||
elif isinstance(second_arg, ast.Tuple):
|
||||
for elt in second_arg.elts:
|
||||
if isinstance(elt, ast.Name):
|
||||
type_names.add(elt.id)
|
||||
elif isinstance(body_node, ast.Compare):
|
||||
# type(x) is/== SomeType
|
||||
if (
|
||||
isinstance(body_node.left, ast.Call)
|
||||
and isinstance(body_node.left.func, ast.Name)
|
||||
and body_node.left.func.id == "type"
|
||||
):
|
||||
for comparator in body_node.comparators:
|
||||
if isinstance(comparator, ast.Name):
|
||||
type_names.add(comparator.id)
|
||||
|
||||
# Collect base class names from enclosing class (if this is a method)
|
||||
if function_to_optimize.class_name is not None:
|
||||
for top_node in ast.walk(tree):
|
||||
if isinstance(top_node, ast.ClassDef) and top_node.name == function_to_optimize.class_name:
|
||||
for base in top_node.bases:
|
||||
if isinstance(base, ast.Name):
|
||||
type_names.add(base.id)
|
||||
break
|
||||
|
||||
type_names -= BUILTIN_AND_TYPING_NAMES
|
||||
type_names -= existing_class_names
|
||||
if not type_names:
|
||||
|
|
@ -881,6 +916,58 @@ def extract_parameter_type_constructors(
|
|||
logger.debug(f"Error extracting constructor stub for {type_name} from {module_name}")
|
||||
continue
|
||||
|
||||
# Transitive extraction (one level): for each extracted stub, find __init__ param types and extract their stubs
|
||||
# Build an extended import map that includes imports from source modules of already-extracted stubs
|
||||
transitive_import_map = dict(import_map)
|
||||
for _, cached_tree in module_cache.values():
|
||||
for cache_node in ast.walk(cached_tree):
|
||||
if isinstance(cache_node, ast.ImportFrom) and cache_node.module:
|
||||
for alias in cache_node.names:
|
||||
name = alias.asname if alias.asname else alias.name
|
||||
if name not in transitive_import_map:
|
||||
transitive_import_map[name] = cache_node.module
|
||||
|
||||
emitted_names = type_names | existing_class_names | BUILTIN_AND_TYPING_NAMES
|
||||
transitive_type_names: set[str] = set()
|
||||
for cs in code_strings:
|
||||
try:
|
||||
stub_tree = ast.parse(cs.code)
|
||||
except SyntaxError:
|
||||
continue
|
||||
for stub_node in ast.walk(stub_tree):
|
||||
if isinstance(stub_node, (ast.FunctionDef, ast.AsyncFunctionDef)) and stub_node.name in (
|
||||
"__init__",
|
||||
"__post_init__",
|
||||
):
|
||||
for arg in stub_node.args.args + stub_node.args.posonlyargs + stub_node.args.kwonlyargs:
|
||||
transitive_type_names |= collect_type_names_from_annotation(arg.annotation)
|
||||
transitive_type_names -= emitted_names
|
||||
for type_name in sorted(transitive_type_names):
|
||||
module_name = transitive_import_map.get(type_name)
|
||||
if not module_name:
|
||||
continue
|
||||
try:
|
||||
script_code = f"from {module_name} import {type_name}"
|
||||
script = jedi.Script(script_code, project=jedi.Project(path=project_root_path))
|
||||
definitions = script.goto(1, len(f"from {module_name} import ") + len(type_name), follow_imports=True)
|
||||
if not definitions:
|
||||
continue
|
||||
module_path = definitions[0].module_path
|
||||
if not module_path:
|
||||
continue
|
||||
if module_path in module_cache:
|
||||
mod_source, mod_tree = module_cache[module_path]
|
||||
else:
|
||||
mod_source = module_path.read_text(encoding="utf-8")
|
||||
mod_tree = ast.parse(mod_source)
|
||||
module_cache[module_path] = (mod_source, mod_tree)
|
||||
stub = extract_init_stub_from_class(type_name, mod_source, mod_tree)
|
||||
if stub:
|
||||
code_strings.append(CodeString(code=stub, file_path=module_path))
|
||||
except Exception:
|
||||
logger.debug(f"Error extracting transitive constructor stub for {type_name} from {module_name}")
|
||||
continue
|
||||
|
||||
return CodeStringsMarkdown(code_strings=code_strings)
|
||||
|
||||
|
||||
|
|
@ -1004,12 +1091,23 @@ def enrich_testgen_context(code_context: CodeStringsMarkdown, project_root_path:
|
|||
continue
|
||||
module_source, module_tree = mod_result
|
||||
|
||||
extract_class_and_bases(name, module_path, module_source, module_tree)
|
||||
|
||||
if (module_path, name) not in extracted_classes:
|
||||
resolved_class = resolve_instance_class_name(name, module_tree)
|
||||
if resolved_class and resolved_class not in existing_classes:
|
||||
extract_class_and_bases(resolved_class, module_path, module_source, module_tree)
|
||||
if is_project:
|
||||
extract_class_and_bases(name, module_path, module_source, module_tree)
|
||||
if (module_path, name) not in extracted_classes:
|
||||
resolved_class = resolve_instance_class_name(name, module_tree)
|
||||
if resolved_class and resolved_class not in existing_classes:
|
||||
extract_class_and_bases(resolved_class, module_path, module_source, module_tree)
|
||||
elif is_third_party:
|
||||
target_name = name
|
||||
if not any(isinstance(n, ast.ClassDef) and n.name == name for n in ast.walk(module_tree)):
|
||||
resolved_class = resolve_instance_class_name(name, module_tree)
|
||||
if resolved_class:
|
||||
target_name = resolved_class
|
||||
if target_name not in emitted_class_names:
|
||||
stub = extract_init_stub_from_class(target_name, module_source, module_tree)
|
||||
if stub:
|
||||
code_strings.append(CodeString(code=stub, file_path=module_path))
|
||||
emitted_class_names.add(target_name)
|
||||
|
||||
except Exception:
|
||||
logger.debug(f"Error extracting class definition for {name} from {module_name}")
|
||||
|
|
|
|||
|
|
@ -9,8 +9,6 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from codeflash.languages.python.static_analysis.code_extractor import GlobalAssignmentCollector, add_global_assignments
|
||||
from codeflash.languages.python.static_analysis.code_replacer import replace_functions_and_add_imports
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.python.context.code_context_extractor import (
|
||||
collect_type_names_from_annotation,
|
||||
|
|
@ -20,6 +18,8 @@ from codeflash.languages.python.context.code_context_extractor import (
|
|||
get_code_optimization_context,
|
||||
resolve_instance_class_name,
|
||||
)
|
||||
from codeflash.languages.python.static_analysis.code_extractor import GlobalAssignmentCollector, add_global_assignments
|
||||
from codeflash.languages.python.static_analysis.code_replacer import replace_functions_and_add_imports
|
||||
from codeflash.models.models import CodeString, CodeStringsMarkdown, FunctionParent
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
|
||||
|
|
@ -4701,3 +4701,167 @@ def get_log_level() -> str:
|
|||
combined = "\n".join(cs.code for cs in result.code_strings)
|
||||
assert "class AppConfig:" in combined
|
||||
assert "@property" in combined
|
||||
|
||||
def test_extract_parameter_type_constructors_isinstance_single(tmp_path: Path) -> None:
|
||||
"""isinstance(x, SomeType) in function body should be picked up."""
|
||||
pkg = tmp_path / "mypkg"
|
||||
pkg.mkdir()
|
||||
(pkg / "__init__.py").write_text("", encoding="utf-8")
|
||||
(pkg / "models.py").write_text(
|
||||
"class Widget:\n def __init__(self, size: int):\n self.size = size\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
(pkg / "processor.py").write_text(
|
||||
"from mypkg.models import Widget\n\ndef check(x) -> bool:\n return isinstance(x, Widget)\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
fto = FunctionToOptimize(
|
||||
function_name="check", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
|
||||
)
|
||||
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
|
||||
assert len(result.code_strings) == 1
|
||||
assert "class Widget:" in result.code_strings[0].code
|
||||
assert "__init__" in result.code_strings[0].code
|
||||
|
||||
|
||||
def test_extract_parameter_type_constructors_isinstance_tuple(tmp_path: Path) -> None:
|
||||
"""isinstance(x, (TypeA, TypeB)) should pick up both types."""
|
||||
pkg = tmp_path / "mypkg"
|
||||
pkg.mkdir()
|
||||
(pkg / "__init__.py").write_text("", encoding="utf-8")
|
||||
(pkg / "models.py").write_text(
|
||||
"class Alpha:\n def __init__(self, a: int):\n self.a = a\n\n"
|
||||
"class Beta:\n def __init__(self, b: str):\n self.b = b\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
(pkg / "processor.py").write_text(
|
||||
"from mypkg.models import Alpha, Beta\n\ndef check(x) -> bool:\n return isinstance(x, (Alpha, Beta))\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
fto = FunctionToOptimize(
|
||||
function_name="check", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
|
||||
)
|
||||
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
|
||||
assert len(result.code_strings) == 2
|
||||
combined = "\n".join(cs.code for cs in result.code_strings)
|
||||
assert "class Alpha:" in combined
|
||||
assert "class Beta:" in combined
|
||||
|
||||
|
||||
def test_extract_parameter_type_constructors_type_is_pattern(tmp_path: Path) -> None:
|
||||
"""type(x) is SomeType pattern should be picked up."""
|
||||
pkg = tmp_path / "mypkg"
|
||||
pkg.mkdir()
|
||||
(pkg / "__init__.py").write_text("", encoding="utf-8")
|
||||
(pkg / "models.py").write_text(
|
||||
"class Gadget:\n def __init__(self, val: float):\n self.val = val\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
(pkg / "processor.py").write_text(
|
||||
"from mypkg.models import Gadget\n\ndef check(x) -> bool:\n return type(x) is Gadget\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
fto = FunctionToOptimize(
|
||||
function_name="check", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
|
||||
)
|
||||
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
|
||||
assert len(result.code_strings) == 1
|
||||
assert "class Gadget:" in result.code_strings[0].code
|
||||
|
||||
|
||||
def test_extract_parameter_type_constructors_base_classes(tmp_path: Path) -> None:
|
||||
"""Base classes of enclosing class should be picked up for methods."""
|
||||
pkg = tmp_path / "mypkg"
|
||||
pkg.mkdir()
|
||||
(pkg / "__init__.py").write_text("", encoding="utf-8")
|
||||
(pkg / "base.py").write_text(
|
||||
"class BaseProcessor:\n def __init__(self, config: str):\n self.config = config\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
(pkg / "child.py").write_text(
|
||||
"from mypkg.base import BaseProcessor\n\nclass ChildProcessor(BaseProcessor):\n"
|
||||
" def process(self) -> str:\n return self.config\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
fto = FunctionToOptimize(
|
||||
function_name="process",
|
||||
file_path=(pkg / "child.py").resolve(),
|
||||
starting_line=4,
|
||||
ending_line=5,
|
||||
parents=[FunctionParent(name="ChildProcessor", type="ClassDef")],
|
||||
)
|
||||
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
|
||||
assert len(result.code_strings) == 1
|
||||
assert "class BaseProcessor:" in result.code_strings[0].code
|
||||
|
||||
|
||||
def test_extract_parameter_type_constructors_isinstance_builtins_excluded(tmp_path: Path) -> None:
|
||||
"""Isinstance with builtins (int, str, etc.) should not produce stubs."""
|
||||
pkg = tmp_path / "mypkg"
|
||||
pkg.mkdir()
|
||||
(pkg / "__init__.py").write_text("", encoding="utf-8")
|
||||
(pkg / "func.py").write_text(
|
||||
"def check(x) -> bool:\n return isinstance(x, (int, str, float))\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
fto = FunctionToOptimize(
|
||||
function_name="check", file_path=(pkg / "func.py").resolve(), starting_line=1, ending_line=2
|
||||
)
|
||||
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
|
||||
assert len(result.code_strings) == 0
|
||||
|
||||
|
||||
def test_extract_parameter_type_constructors_transitive(tmp_path: Path) -> None:
|
||||
"""Transitive extraction: if Widget.__init__ takes a Config, Config's stub should also appear."""
|
||||
pkg = tmp_path / "mypkg"
|
||||
pkg.mkdir()
|
||||
(pkg / "__init__.py").write_text("", encoding="utf-8")
|
||||
(pkg / "config.py").write_text(
|
||||
"class Config:\n def __init__(self, debug: bool = False):\n self.debug = debug\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
(pkg / "models.py").write_text(
|
||||
"from mypkg.config import Config\n\n"
|
||||
"class Widget:\n def __init__(self, cfg: Config):\n self.cfg = cfg\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
(pkg / "processor.py").write_text(
|
||||
"from mypkg.models import Widget\n\ndef process(w: Widget) -> str:\n return str(w)\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
fto = FunctionToOptimize(
|
||||
function_name="process", file_path=(pkg / "processor.py").resolve(), starting_line=3, ending_line=4
|
||||
)
|
||||
result = extract_parameter_type_constructors(fto, tmp_path.resolve(), set())
|
||||
combined = "\n".join(cs.code for cs in result.code_strings)
|
||||
assert "class Widget:" in combined
|
||||
assert "class Config:" in combined
|
||||
|
||||
|
||||
|
||||
|
||||
def test_enrich_testgen_context_third_party_uses_stubs(tmp_path: Path) -> None:
|
||||
"""Third-party classes should produce compact __init__ stubs, not full class source."""
|
||||
# Use a real third-party package (pydantic) so jedi can actually resolve it
|
||||
context_code = (
|
||||
"from pydantic import BaseModel\n\n"
|
||||
"class MyModel(BaseModel):\n"
|
||||
" name: str\n\n"
|
||||
"def process(m: MyModel) -> str:\n"
|
||||
" return m.name\n"
|
||||
)
|
||||
consumer_path = tmp_path / "consumer.py"
|
||||
consumer_path.write_text(context_code, encoding="utf-8")
|
||||
|
||||
context = CodeStringsMarkdown(code_strings=[CodeString(code=context_code, file_path=consumer_path)])
|
||||
result = enrich_testgen_context(context, tmp_path)
|
||||
|
||||
# BaseModel lives in site-packages so should get stub treatment (compact __init__),
|
||||
# not the full class definition with hundreds of methods
|
||||
for cs in result.code_strings:
|
||||
if "BaseModel" in cs.code:
|
||||
assert "class BaseModel:" in cs.code
|
||||
assert "__init__" in cs.code
|
||||
# Full BaseModel has many methods; stubs should only have __init__/properties
|
||||
assert "model_dump" not in cs.code
|
||||
break
|
||||
|
|
|
|||
Loading…
Reference in a new issue