codeflash-agent/packages/codeflash-python/tests/test_enrichment.py
Kevin Turcios 6b73b07d15 fix: deduplicate code across codeflash-core and codeflash-python
- Extract _parse_candidates helper in _client.py (used by get_candidates and optimize_with_line_profiler)
- Parameterize URL resolution in _http.py (_resolve_url_from_env replaces two near-identical functions)
- Delegate get_repo_owner_and_name to parse_repo_owner_and_name in _git.py
- Simplify _par_apply_fns to delegate to _apply_fns in danom/stream.py
- Remove duplicate performance_gain from _verification.py (use codeflash_core's version)
- Extract _extract_pytest_error helper in _verification.py (replaces duplicated 6-line block)
- Consolidate collect_names_from_annotation into collect_type_names_from_annotation in _ast_helpers.py
- Add ast.Attribute handling and relax BinOp guard in collect_type_names_from_annotation
- Add unit tests for all extracted helpers
2026-04-23 22:39:50 -05:00

738 lines
23 KiB
Python

"""Tests for _context.enrichment — testgen context enrichment."""
from __future__ import annotations
import ast
import textwrap
from pathlib import Path
from codeflash_python._model import FunctionToOptimize
from codeflash_python.context._ast_helpers import (
collect_existing_class_names,
collect_import_aliases,
collect_type_names_from_annotation,
find_class_node_by_name,
)
from codeflash_python.context._class_analysis import (
build_import_from_map,
build_synthetic_init_stub,
get_attrs_config,
get_class_start_line,
get_dataclass_config,
is_namedtuple_class,
resolve_instance_class_name,
should_use_raw_project_class_context,
)
from codeflash_python.context.enrichment import (
collect_type_names_from_function,
enrich_testgen_context,
extract_function_stub_snippet,
extract_imports_for_class,
extract_init_stub_from_class,
extract_parameter_type_constructors,
)
from codeflash_python.context.models import (
CodeString,
CodeStringsMarkdown,
)
def _parse(code: str) -> ast.Module:
    """Dedent *code* and parse it into a module AST.

    Lets tests write indented triple-quoted snippets inline.
    """
    dedented = textwrap.dedent(code)
    return ast.parse(dedented)
class TestCollectImportAliases:
    """Tests for collect_import_aliases."""

    def test_import(self) -> None:
        """``import os.path`` binds ``os`` to the dotted name ``os.path``."""
        aliases = collect_import_aliases(_parse("import os.path"))
        assert {"os": "os.path"} == aliases

    def test_import_from(self) -> None:
        """A from-import maps the imported name to ``module.name``."""
        aliases = collect_import_aliases(_parse("from pathlib import Path"))
        assert {"Path": "pathlib.Path"} == aliases

    def test_alias(self) -> None:
        """An ``as`` alias becomes the key instead of the original name."""
        module = _parse("from pathlib import Path as P")
        aliases = collect_import_aliases(module)
        assert {"P": "pathlib.Path"} == aliases
class TestFindClassNodeByName:
    """Tests for find_class_node_by_name."""

    def test_top_level(self) -> None:
        """Finds a top-level class."""
        tree = _parse("class Foo: pass")
        node = find_class_node_by_name("Foo", tree)
        assert node is not None
        assert "Foo" == node.name

    def test_nested(self) -> None:
        """Finds a class nested inside another class."""
        # The fixture must keep its nesting indentation: without it the
        # snippet is not valid Python and ``Inner`` would not be nested.
        tree = _parse(
            """\
            class Outer:
                class Inner:
                    pass
            """
        )
        node = find_class_node_by_name("Inner", tree)
        assert node is not None
        assert "Inner" == node.name

    def test_missing(self) -> None:
        """Returns None when class is absent."""
        tree = _parse("x = 1")
        assert find_class_node_by_name("Foo", tree) is None
class TestCollectExistingClassNames:
    """Tests for collect_existing_class_names."""

    def test_multiple_classes(self) -> None:
        """Every class defined in the module is reported by name."""
        module = _parse("class A: pass\nclass B: pass\n")
        found = collect_existing_class_names(module)
        assert {"A", "B"} == found
class TestCollectTypeNamesFromAnnotation:
    """Tests for collect_type_names_from_annotation."""

    @staticmethod
    def _annotation(code: str) -> ast.expr:
        """Return the annotation of the single annotated statement in *code*."""
        stmt = _parse(code).body[0]
        # Narrow the statement type so ``.annotation`` is known to exist.
        assert isinstance(stmt, ast.AnnAssign)
        return stmt.annotation

    def test_name(self) -> None:
        """A bare name annotation yields that name."""
        names = collect_type_names_from_annotation(self._annotation("x: Foo"))
        assert {"Foo"} == names

    def test_subscript(self) -> None:
        """A generic subscript yields both the container and its argument."""
        annotation = self._annotation("x: List[int]")
        assert {"List", "int"} == collect_type_names_from_annotation(
            annotation
        )

    def test_bitor_union(self) -> None:
        """A PEP 604 ``A | B`` union yields both operands."""
        annotation = self._annotation("x: int | str")
        assert {"int", "str"} == collect_type_names_from_annotation(
            annotation
        )

    def test_none(self) -> None:
        """Passing None yields an empty set."""
        names = collect_type_names_from_annotation(None)
        assert set() == names

    def test_attribute_collects_module(self) -> None:
        """For ``typing.Optional`` the module name ``typing`` is collected."""
        module_ref = ast.Name(id="typing")
        node = ast.Attribute(value=module_ref, attr="Optional")
        assert {"typing"} == collect_type_names_from_annotation(node)
class TestDeclarativeClassDetection:
    """Tests for NamedTuple, dataclass, and attrs detection.

    Each fixture is an indented snippet; the class-body indentation is
    required for the snippet to parse at all.
    """

    def test_namedtuple(self) -> None:
        """Detects NamedTuple base class."""
        tree = _parse(
            """\
            from typing import NamedTuple
            class Point(NamedTuple):
                x: int
                y: int
            """
        )
        aliases = collect_import_aliases(tree)
        node = find_class_node_by_name("Point", tree)
        assert node is not None
        assert is_namedtuple_class(node, aliases) is True

    def test_dataclass(self) -> None:
        """Detects @dataclass decorator with default init/kw_only flags."""
        tree = _parse(
            """\
            from dataclasses import dataclass
            @dataclass
            class Config:
                name: str
            """
        )
        aliases = collect_import_aliases(tree)
        node = find_class_node_by_name("Config", tree)
        assert node is not None
        is_dc, init_enabled, kw_only = get_dataclass_config(node, aliases)
        assert is_dc is True
        assert init_enabled is True
        assert kw_only is False

    def test_dataclass_no_init(self) -> None:
        """Detects @dataclass(init=False)."""
        tree = _parse(
            """\
            from dataclasses import dataclass
            @dataclass(init=False)
            class Config:
                name: str
            """
        )
        aliases = collect_import_aliases(tree)
        node = find_class_node_by_name("Config", tree)
        assert node is not None
        is_dc, init_enabled, _ = get_dataclass_config(node, aliases)
        assert is_dc is True
        assert init_enabled is False

    def test_attrs_frozen(self) -> None:
        """Detects @attrs.frozen decorator."""
        tree = _parse(
            """\
            import attrs
            @attrs.frozen
            class Point:
                x: int
            """
        )
        aliases = collect_import_aliases(tree)
        node = find_class_node_by_name("Point", tree)
        assert node is not None
        is_at, init_enabled, kw_only = get_attrs_config(node, aliases)
        assert is_at is True
        assert init_enabled is True
        assert kw_only is False
class TestBuildSyntheticInitStub:
    """Tests for build_synthetic_init_stub."""

    @staticmethod
    def _build_stub(source: str, class_name: str) -> str | None:
        """Parse *source* and build the synthetic stub for *class_name*.

        Asserts that the class exists so a broken fixture fails loudly
        instead of being reported as "no stub".
        """
        tree = ast.parse(source)
        aliases = collect_import_aliases(tree)
        node = find_class_node_by_name(class_name, tree)
        assert node is not None
        return build_synthetic_init_stub(node, source, aliases)

    def test_dataclass_stub(self) -> None:
        """Generates __init__ for a dataclass."""
        source = textwrap.dedent(
            """\
            from dataclasses import dataclass
            @dataclass
            class Point:
                x: int
                y: int
            """
        )
        stub = self._build_stub(source, "Point")
        assert stub is not None
        assert "def __init__(self, x: int, y: int):" in stub

    def test_dataclass_with_default(self) -> None:
        """Includes defaults in synthetic __init__."""
        source = textwrap.dedent(
            """\
            from dataclasses import dataclass
            @dataclass
            class Config:
                name: str
                debug: bool = False
            """
        )
        stub = self._build_stub(source, "Config")
        assert stub is not None
        assert "debug: bool = False" in stub

    def test_namedtuple_stub(self) -> None:
        """Generates __init__ for a NamedTuple."""
        source = textwrap.dedent(
            """\
            from typing import NamedTuple
            class Pair(NamedTuple):
                a: str
                b: str
            """
        )
        stub = self._build_stub(source, "Pair")
        assert stub is not None
        assert "def __init__(self, a: str, b: str):" in stub

    def test_kw_only(self) -> None:
        """Generates __init__ with *, for kw_only dataclass."""
        source = textwrap.dedent(
            """\
            from dataclasses import dataclass
            @dataclass(kw_only=True)
            class Opts:
                a: int
                b: int
            """
        )
        stub = self._build_stub(source, "Opts")
        assert stub is not None
        assert "*, a: int" in stub

    def test_plain_class_returns_none(self) -> None:
        """Returns None for a plain class."""
        source = textwrap.dedent(
            """\
            class Plain:
                x: int
            """
        )
        assert self._build_stub(source, "Plain") is None

    def test_attrs_define(self) -> None:
        """Generates __init__ for an attrs.define class."""
        source = textwrap.dedent(
            """\
            import attrs
            @attrs.define
            class Widget:
                name: str
                count: int
            """
        )
        stub = self._build_stub(source, "Widget")
        assert stub is not None
        assert "name: str" in stub
        assert "count: int" in stub
class TestExtractInitStubFromClass:
    """Tests for extract_init_stub_from_class."""

    def test_explicit_init(self) -> None:
        """Extracts an existing __init__ together with the class header."""
        source = textwrap.dedent(
            """\
            class Foo:
                def __init__(self, x: int) -> None:
                    self.x = x
            """
        )
        tree = ast.parse(source)
        result = extract_init_stub_from_class("Foo", source, tree)
        assert result is not None
        assert "class Foo:" in result
        assert "def __init__(self, x: int)" in result

    def test_dataclass_synthetic(self) -> None:
        """Synthesizes __init__ for a dataclass."""
        source = textwrap.dedent(
            """\
            from dataclasses import dataclass
            @dataclass
            class Bar:
                name: str
                value: int = 0
            """
        )
        tree = ast.parse(source)
        result = extract_init_stub_from_class("Bar", source, tree)
        assert result is not None
        assert "class Bar:" in result
        assert "name: str" in result

    def test_missing_class(self) -> None:
        """Returns None when the class doesn't exist."""
        source = "x = 1"
        tree = ast.parse(source)
        assert extract_init_stub_from_class("Missing", source, tree) is None

    def test_includes_post_init(self) -> None:
        """Includes __post_init__ in the output."""
        source = textwrap.dedent(
            """\
            from dataclasses import dataclass
            @dataclass
            class Validated:
                x: int
                def __post_init__(self):
                    if self.x < 0:
                        raise ValueError
            """
        )
        tree = ast.parse(source)
        result = extract_init_stub_from_class("Validated", source, tree)
        assert result is not None
        assert "__post_init__" in result
class TestExtractFunctionStubSnippet:
    """Tests for extract_function_stub_snippet."""

    def test_plain_function(self) -> None:
        """Extracts a function's source lines."""
        source = textwrap.dedent(
            """\
            def foo(x: int) -> int:
                return x + 1
            """
        )
        tree = ast.parse(source)
        fn = tree.body[0]
        assert isinstance(fn, ast.FunctionDef)
        lines = source.splitlines()
        snippet = extract_function_stub_snippet(fn, lines)
        assert "def foo(x: int) -> int:" in snippet
        assert "return x + 1" in snippet

    def test_decorated_function(self) -> None:
        """Includes decorator lines."""
        source = textwrap.dedent(
            """\
            @property
            def name(self) -> str:
                return self._name
            """
        )
        tree = ast.parse(source)
        fn = tree.body[0]
        assert isinstance(fn, ast.FunctionDef)
        lines = source.splitlines()
        snippet = extract_function_stub_snippet(fn, lines)
        assert "@property" in snippet
class TestGetClassStartLine:
    """Tests for get_class_start_line."""

    def test_no_decorators(self) -> None:
        """Start line is the class keyword line."""
        tree = _parse("class Foo: pass")
        node = find_class_node_by_name("Foo", tree)
        assert node is not None
        assert 1 == get_class_start_line(node)

    def test_with_decorator(self) -> None:
        """Start line is the first decorator line."""
        # No blank line after the import on purpose: the decorator must sit
        # on line 2 for the expected value below to hold.
        source = textwrap.dedent(
            """\
            from dataclasses import dataclass
            @dataclass
            class Foo:
                x: int
            """
        )
        tree = ast.parse(source)
        node = find_class_node_by_name("Foo", tree)
        assert node is not None
        assert 2 == get_class_start_line(node)
class TestResolveInstanceClassName:
    """Tests for resolve_instance_class_name."""

    def test_call_assignment(self) -> None:
        """``name = SomeClass()`` resolves to the called class name."""
        module = _parse("instance = MyClass()")
        resolved = resolve_instance_class_name("instance", module)
        assert "MyClass" == resolved

    def test_annotated_assignment(self) -> None:
        """``name: SomeClass`` resolves to the annotated class name."""
        module = _parse("instance: MyClass")
        resolved = resolve_instance_class_name("instance", module)
        assert "MyClass" == resolved

    def test_not_found(self) -> None:
        """A name that is never bound in the module resolves to None."""
        module = _parse("x = 1")
        resolved = resolve_instance_class_name("y", module)
        assert resolved is None
class TestBuildImportFromMap:
    """Tests for build_import_from_map."""

    def test_from_import(self) -> None:
        """The imported name is mapped to the module it came from."""
        mapping = build_import_from_map(_parse("from pathlib import Path"))
        assert {"Path": "pathlib"} == mapping

    def test_alias(self) -> None:
        """With ``as``, the alias (not the original name) is the key."""
        module = _parse("from pathlib import Path as P")
        mapping = build_import_from_map(module)
        assert {"P": "pathlib"} == mapping
class TestExtractImportsForClass:
    """Tests for extract_imports_for_class."""

    def test_base_class_import(self) -> None:
        """Extracts import needed for a base class."""
        source = textwrap.dedent(
            """\
            from abc import ABC
            class MyClass(ABC):
                pass
            """
        )
        tree = ast.parse(source)
        node = find_class_node_by_name("MyClass", tree)
        assert node is not None
        result = extract_imports_for_class(tree, node, source)
        assert "from abc import ABC" in result

    def test_decorator_import(self) -> None:
        """Extracts import needed for a decorator."""
        source = textwrap.dedent(
            """\
            from dataclasses import dataclass
            @dataclass
            class Point:
                x: int
            """
        )
        tree = ast.parse(source)
        node = find_class_node_by_name("Point", tree)
        assert node is not None
        result = extract_imports_for_class(tree, node, source)
        assert "from dataclasses import dataclass" in result
class TestShouldUseRawProjectClassContext:
    """Tests for should_use_raw_project_class_context."""

    def test_decorated_class(self) -> None:
        """A decorated (dataclass) class gets raw context."""
        source = textwrap.dedent(
            """\
            from dataclasses import dataclass
            @dataclass
            class Foo:
                x: int
            """
        )
        tree = ast.parse(source)
        aliases = collect_import_aliases(tree)
        node = find_class_node_by_name("Foo", tree)
        assert node is not None
        assert should_use_raw_project_class_context(node, aliases) is True

    def test_plain_large_class(self) -> None:
        """A large plain class without special features returns False."""
        # Build a class with many methods (> MAX_RAW_PROJECT_CLASS_BODY_ITEMS)
        # The single leading space is the method indentation inside the
        # generated class body; it only has to be consistent across lines.
        methods = "\n".join(
            f" def method_{i}(self): pass" for i in range(20)
        )
        source = f"class Big:\n{methods}\n"
        tree = ast.parse(source)
        aliases = collect_import_aliases(tree)
        node = find_class_node_by_name("Big", tree)
        assert node is not None
        assert should_use_raw_project_class_context(node, aliases) is False
class TestCollectTypeNamesFromFunction:
    """Tests for collect_type_names_from_function."""

    def test_annotation_types(self) -> None:
        """Collects types from parameter annotations."""
        source = textwrap.dedent(
            """\
            def process(items: MyList, config: Config) -> Result:
                return Result()
            """
        )
        tree = ast.parse(source)
        fn = tree.body[0]
        assert isinstance(fn, ast.FunctionDef)
        names = collect_type_names_from_function(fn, tree, None)
        assert "MyList" in names
        assert "Config" in names

    def test_isinstance_types(self) -> None:
        """Collects types from isinstance checks."""
        source = textwrap.dedent(
            """\
            def check(x):
                if isinstance(x, MyType):
                    pass
            """
        )
        tree = ast.parse(source)
        fn = tree.body[0]
        assert isinstance(fn, ast.FunctionDef)
        names = collect_type_names_from_function(fn, tree, None)
        assert "MyType" in names

    def test_isinstance_tuple(self) -> None:
        """Collects types from isinstance with tuple of types."""
        source = textwrap.dedent(
            """\
            def check(x):
                if isinstance(x, (TypeA, TypeB)):
                    pass
            """
        )
        tree = ast.parse(source)
        fn = tree.body[0]
        assert isinstance(fn, ast.FunctionDef)
        names = collect_type_names_from_function(fn, tree, None)
        assert "TypeA" in names
        assert "TypeB" in names

    def test_class_bases(self) -> None:
        """Collects base class types when class_name is provided."""
        source = textwrap.dedent(
            """\
            class Parent:
                pass
            class Child(Parent):
                def method(self):
                    pass
            """
        )
        tree = ast.parse(source)
        # body[1] is ``Child``; its first body item is ``method``.
        fn = tree.body[1].body[0]  # type: ignore[union-attr]
        assert isinstance(fn, ast.FunctionDef)
        names = collect_type_names_from_function(fn, tree, "Child")
        assert "Parent" in names
class TestEnrichTestgenContext:
    """Tests for enrich_testgen_context."""

    def test_project_class_resolution(
        self,
        tmp_path: Path,
    ) -> None:
        """Resolves a project class via Jedi and extracts source."""
        models = tmp_path / "models.py"
        models.write_text(
            textwrap.dedent(
                """\
                class Widget:
                    def __init__(self, name: str) -> None:
                        self.name = name
                """
            ),
            encoding="utf-8",
        )
        code = textwrap.dedent(
            """\
            from models import Widget
            def process(w: Widget) -> str:
                return w.name
            """
        )
        context = CodeStringsMarkdown(code_strings=[CodeString(code=code)])
        result = enrich_testgen_context(context, tmp_path)
        # NOTE(review): the content check only runs when enrichment returned
        # something, so this test passes vacuously on an empty result.
        if result.code_strings:
            combined = "\n".join(cs.code for cs in result.code_strings)
            assert "Widget" in combined

    def test_empty_context(self) -> None:
        """Returns empty result for empty input."""
        context = CodeStringsMarkdown(code_strings=[])
        result = enrich_testgen_context(context, Path("/nonexistent"))
        assert [] == result.code_strings

    def test_syntax_error_in_context(self) -> None:
        """Returns empty result for unparseable code."""
        context = CodeStringsMarkdown(
            code_strings=[CodeString(code="def broken(")]
        )
        result = enrich_testgen_context(context, Path("/nonexistent"))
        assert [] == result.code_strings
class TestExtractParameterTypeConstructors:
    """Tests for extract_parameter_type_constructors."""

    def test_dataclass_type_in_signature(
        self,
        tmp_path: Path,
    ) -> None:
        """Extracts __init__ stub for a dataclass used in signature."""
        models = tmp_path / "models.py"
        models.write_text(
            textwrap.dedent(
                """\
                from dataclasses import dataclass
                @dataclass
                class Config:
                    name: str
                    debug: bool = False
                """
            ),
            encoding="utf-8",
        )
        main = tmp_path / "main.py"
        # The blank line after the import puts ``def process`` on line 3,
        # which is what ``starting_line=3`` below refers to.
        main.write_text(
            textwrap.dedent(
                """\
                from models import Config

                def process(config: Config) -> str:
                    return config.name
                """
            ),
            encoding="utf-8",
        )
        fn = FunctionToOptimize(
            function_name="process",
            file_path=main,
            starting_line=3,
        )
        result = extract_parameter_type_constructors(fn, tmp_path, set())
        # NOTE(review): the content check only runs when something was
        # extracted, so this test passes vacuously on an empty result.
        if result.code_strings:
            combined = "\n".join(cs.code for cs in result.code_strings)
            assert "Config" in combined

    def test_builtin_types_ignored(
        self,
        tmp_path: Path,
    ) -> None:
        """Builtin types are not resolved."""
        main = tmp_path / "main.py"
        main.write_text(
            textwrap.dedent(
                """\
                def add(x: int, y: int) -> int:
                    return x + y
                """
            ),
            encoding="utf-8",
        )
        fn = FunctionToOptimize(
            function_name="add",
            file_path=main,
            starting_line=1,
        )
        result = extract_parameter_type_constructors(fn, tmp_path, set())
        assert [] == result.code_strings

    def test_missing_function(
        self,
        tmp_path: Path,
    ) -> None:
        """Returns empty result when function is not found."""
        main = tmp_path / "main.py"
        main.write_text("x = 1\n", encoding="utf-8")
        fn = FunctionToOptimize(
            function_name="nonexistent",
            file_path=main,
        )
        result = extract_parameter_type_constructors(fn, tmp_path, set())
        assert [] == result.code_strings