mirror of
https://github.com/codeflash-ai/codeflash-agent.git
synced 2026-05-04 18:25:19 +00:00
- Extract _parse_candidates helper in _client.py (used by get_candidates and optimize_with_line_profiler) - Parameterize URL resolution in _http.py (_resolve_url_from_env replaces two near-identical functions) - Delegate get_repo_owner_and_name to parse_repo_owner_and_name in _git.py - Simplify _par_apply_fns to delegate to _apply_fns in danom/stream.py - Remove duplicate performance_gain from _verification.py (use codeflash_core's version) - Extract _extract_pytest_error helper in _verification.py (replaces duplicated 6-line block) - Consolidate collect_names_from_annotation into collect_type_names_from_annotation in _ast_helpers.py - Add ast.Attribute handling and relax BinOp guard in collect_type_names_from_annotation - Add unit tests for all extracted helpers
738 lines
23 KiB
Python
738 lines
23 KiB
Python
"""Tests for _context.enrichment — testgen context enrichment."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import ast
|
|
import textwrap
|
|
from pathlib import Path
|
|
|
|
from codeflash_python._model import FunctionToOptimize
|
|
from codeflash_python.context._ast_helpers import (
|
|
collect_existing_class_names,
|
|
collect_import_aliases,
|
|
collect_type_names_from_annotation,
|
|
find_class_node_by_name,
|
|
)
|
|
from codeflash_python.context._class_analysis import (
|
|
build_import_from_map,
|
|
build_synthetic_init_stub,
|
|
get_attrs_config,
|
|
get_class_start_line,
|
|
get_dataclass_config,
|
|
is_namedtuple_class,
|
|
resolve_instance_class_name,
|
|
should_use_raw_project_class_context,
|
|
)
|
|
from codeflash_python.context.enrichment import (
|
|
collect_type_names_from_function,
|
|
enrich_testgen_context,
|
|
extract_function_stub_snippet,
|
|
extract_imports_for_class,
|
|
extract_init_stub_from_class,
|
|
extract_parameter_type_constructors,
|
|
)
|
|
from codeflash_python.context.models import (
|
|
CodeString,
|
|
CodeStringsMarkdown,
|
|
)
|
|
|
|
|
|
def _parse(code: str) -> ast.Module:
|
|
return ast.parse(textwrap.dedent(code))
|
|
|
|
|
|
class TestCollectImportAliases:
    """Tests for collect_import_aliases."""

    def test_import(self) -> None:
        """``import a.b`` binds the top-level name to the full dotted path."""
        module = _parse("import os.path")
        assert {"os": "os.path"} == collect_import_aliases(module)

    def test_import_from(self) -> None:
        """``from m import n`` binds ``n`` to ``m.n``."""
        module = _parse("from pathlib import Path")
        assert {"Path": "pathlib.Path"} == collect_import_aliases(module)

    def test_alias(self) -> None:
        """An ``as`` clause changes the bound name but keeps the target."""
        module = _parse("from pathlib import Path as P")
        assert {"P": "pathlib.Path"} == collect_import_aliases(module)
|
|
|
|
|
|
class TestFindClassNodeByName:
    """Tests for find_class_node_by_name."""

    def test_top_level(self) -> None:
        """A class defined at module scope is located."""
        found = find_class_node_by_name("Foo", _parse("class Foo: pass"))
        assert found is not None
        assert "Foo" == found.name

    def test_nested(self) -> None:
        """A class defined inside another class body is located too."""
        module = _parse(
            """\
            class Outer:
                class Inner:
                    pass
            """
        )
        found = find_class_node_by_name("Inner", module)
        assert found is not None
        assert "Inner" == found.name

    def test_missing(self) -> None:
        """None is returned when no class of that name exists."""
        assert find_class_node_by_name("Foo", _parse("x = 1")) is None
|
|
|
|
|
|
class TestCollectExistingClassNames:
    """Tests for collect_existing_class_names."""

    def test_multiple_classes(self) -> None:
        """Every class defined in the module is reported by name."""
        source = """\
            class A: pass
            class B: pass
            """
        assert {"A", "B"} == collect_existing_class_names(_parse(source))
|
|
|
|
|
|
class TestCollectTypeNamesFromAnnotation:
    """Tests for collect_type_names_from_annotation."""

    @staticmethod
    def _annotation_of(code: str) -> ast.expr:
        """Return the annotation of the first (annotated) statement."""
        first = _parse(code).body[0]
        assert isinstance(first, ast.AnnAssign)
        return first.annotation

    def test_name(self) -> None:
        """A bare name annotation yields exactly that name."""
        annotation = self._annotation_of("x: Foo")
        assert {"Foo"} == collect_type_names_from_annotation(annotation)

    def test_subscript(self) -> None:
        """A generic subscript yields the container and its argument."""
        annotation = self._annotation_of("x: List[int]")
        assert {"List", "int"} == collect_type_names_from_annotation(annotation)

    def test_bitor_union(self) -> None:
        """A PEP 604 ``X | Y`` union yields every member."""
        annotation = self._annotation_of("x: int | str")
        assert {"int", "str"} == collect_type_names_from_annotation(annotation)

    def test_none(self) -> None:
        """A missing annotation (None) yields an empty set."""
        assert set() == collect_type_names_from_annotation(None)

    def test_attribute_collects_module(self) -> None:
        """A dotted annotation like ``typing.Optional`` yields its base name."""
        dotted = ast.Attribute(value=ast.Name(id="typing"), attr="Optional")
        assert {"typing"} == collect_type_names_from_annotation(dotted)
|
|
|
|
|
|
class TestDeclarativeClassDetection:
    """Tests for NamedTuple, dataclass, and attrs detection."""

    @staticmethod
    def _locate(
        class_name: str, code: str
    ) -> tuple[ast.ClassDef, dict[str, str]]:
        """Return the named class node and the import aliases of *code*."""
        module = _parse(code)
        class_node = find_class_node_by_name(class_name, module)
        assert class_node is not None
        return class_node, collect_import_aliases(module)

    def test_namedtuple(self) -> None:
        """A class inheriting from typing.NamedTuple is recognised."""
        class_node, aliases = self._locate(
            "Point",
            """\
            from typing import NamedTuple
            class Point(NamedTuple):
                x: int
                y: int
            """,
        )
        assert is_namedtuple_class(class_node, aliases) is True

    def test_dataclass(self) -> None:
        """A bare @dataclass is detected with init on and kw_only off."""
        class_node, aliases = self._locate(
            "Config",
            """\
            from dataclasses import dataclass
            @dataclass
            class Config:
                name: str
            """,
        )
        found, init_on, kw_only = get_dataclass_config(class_node, aliases)
        assert found is True
        assert init_on is True
        assert kw_only is False

    def test_dataclass_no_init(self) -> None:
        """@dataclass(init=False) is detected with init switched off."""
        class_node, aliases = self._locate(
            "Config",
            """\
            from dataclasses import dataclass
            @dataclass(init=False)
            class Config:
                name: str
            """,
        )
        found, init_on, _ = get_dataclass_config(class_node, aliases)
        assert found is True
        assert init_on is False

    def test_attrs_frozen(self) -> None:
        """@attrs.frozen is detected with init on and kw_only off."""
        class_node, aliases = self._locate(
            "Point",
            """\
            import attrs
            @attrs.frozen
            class Point:
                x: int
            """,
        )
        found, init_on, kw_only = get_attrs_config(class_node, aliases)
        assert found is True
        assert init_on is True
        assert kw_only is False
|
|
|
|
|
|
class TestBuildSyntheticInitStub:
    """Tests for build_synthetic_init_stub."""

    @staticmethod
    def _stub_for(class_name: str, code: str) -> str | None:
        """Dedent *code*, find *class_name*, and build its synthetic stub."""
        source = textwrap.dedent(code)
        module = ast.parse(source)
        class_node = find_class_node_by_name(class_name, module)
        assert class_node is not None
        return build_synthetic_init_stub(
            class_node, source, collect_import_aliases(module)
        )

    def test_dataclass_stub(self) -> None:
        """A dataclass gets an __init__ taking each field in order."""
        stub = self._stub_for(
            "Point",
            """\
            from dataclasses import dataclass
            @dataclass
            class Point:
                x: int
                y: int
            """,
        )
        assert stub is not None
        assert "def __init__(self, x: int, y: int):" in stub

    def test_dataclass_with_default(self) -> None:
        """Field defaults are carried into the synthetic __init__."""
        stub = self._stub_for(
            "Config",
            """\
            from dataclasses import dataclass
            @dataclass
            class Config:
                name: str
                debug: bool = False
            """,
        )
        assert stub is not None
        assert "debug: bool = False" in stub

    def test_namedtuple_stub(self) -> None:
        """A NamedTuple gets an __init__ taking each field in order."""
        stub = self._stub_for(
            "Pair",
            """\
            from typing import NamedTuple
            class Pair(NamedTuple):
                a: str
                b: str
            """,
        )
        assert stub is not None
        assert "def __init__(self, a: str, b: str):" in stub

    def test_kw_only(self) -> None:
        """A kw_only dataclass gets a ``*,`` marker before its fields."""
        stub = self._stub_for(
            "Opts",
            """\
            from dataclasses import dataclass
            @dataclass(kw_only=True)
            class Opts:
                a: int
                b: int
            """,
        )
        assert stub is not None
        assert "*, a: int" in stub

    def test_plain_class_returns_none(self) -> None:
        """An ordinary class yields no synthetic stub at all."""
        stub = self._stub_for(
            "Plain",
            """\
            class Plain:
                x: int
            """,
        )
        assert stub is None

    def test_attrs_define(self) -> None:
        """An attrs.define class gets an __init__ covering its fields."""
        stub = self._stub_for(
            "Widget",
            """\
            import attrs
            @attrs.define
            class Widget:
                name: str
                count: int
            """,
        )
        assert stub is not None
        assert "name: str" in stub
        assert "count: int" in stub
|
|
|
|
|
|
class TestExtractInitStubFromClass:
    """Tests for extract_init_stub_from_class."""

    @staticmethod
    def _extract(class_name: str, code: str) -> str | None:
        """Dedent *code* and extract the init stub for *class_name*."""
        source = textwrap.dedent(code)
        return extract_init_stub_from_class(
            class_name, source, ast.parse(source)
        )

    def test_explicit_init(self) -> None:
        """A hand-written __init__ is copied over as-is."""
        extracted = self._extract(
            "Foo",
            """\
            class Foo:
                def __init__(self, x: int) -> None:
                    self.x = x
            """,
        )
        assert extracted is not None
        assert "class Foo:" in extracted
        assert "def __init__(self, x: int)" in extracted

    def test_dataclass_synthetic(self) -> None:
        """A dataclass without __init__ gets a synthesized one."""
        extracted = self._extract(
            "Bar",
            """\
            from dataclasses import dataclass
            @dataclass
            class Bar:
                name: str
                value: int = 0
            """,
        )
        assert extracted is not None
        assert "class Bar:" in extracted
        assert "name: str" in extracted

    def test_missing_class(self) -> None:
        """An absent class yields None."""
        assert self._extract("Missing", "x = 1") is None

    def test_includes_post_init(self) -> None:
        """A __post_init__ hook is kept alongside the stub."""
        extracted = self._extract(
            "Validated",
            """\
            from dataclasses import dataclass
            @dataclass
            class Validated:
                x: int
                def __post_init__(self):
                    if self.x < 0:
                        raise ValueError
            """,
        )
        assert extracted is not None
        assert "__post_init__" in extracted
|
|
|
|
|
|
class TestExtractFunctionStubSnippet:
    """Tests for extract_function_stub_snippet."""

    @staticmethod
    def _snippet_of(code: str) -> str:
        """Dedent *code* and return the stub snippet of its first def."""
        source = textwrap.dedent(code)
        func = ast.parse(source).body[0]
        assert isinstance(func, ast.FunctionDef)
        return extract_function_stub_snippet(func, source.splitlines())

    def test_plain_function(self) -> None:
        """The function's own source lines are returned."""
        snippet = self._snippet_of(
            """\
            def foo(x: int) -> int:
                return x + 1
            """
        )
        assert "def foo(x: int) -> int:" in snippet
        assert "return x + 1" in snippet

    def test_decorated_function(self) -> None:
        """Decorator lines above the def are included."""
        snippet = self._snippet_of(
            """\
            @property
            def name(self) -> str:
                return self._name
            """
        )
        assert "@property" in snippet
|
|
|
|
|
|
class TestGetClassStartLine:
    """Tests for get_class_start_line."""

    def test_no_decorators(self) -> None:
        """Without decorators the ``class`` keyword line is the start."""
        class_node = find_class_node_by_name("Foo", _parse("class Foo: pass"))
        assert class_node is not None
        assert 1 == get_class_start_line(class_node)

    def test_with_decorator(self) -> None:
        """The first decorator line, not the ``class`` line, is the start."""
        # Line 1 is the import, line 2 the decorator, line 3 the class.
        module = _parse(
            """\
            from dataclasses import dataclass
            @dataclass
            class Foo:
                x: int
            """
        )
        class_node = find_class_node_by_name("Foo", module)
        assert class_node is not None
        assert 2 == get_class_start_line(class_node)
|
|
|
|
|
|
class TestResolveInstanceClassName:
    """Tests for resolve_instance_class_name."""

    def test_call_assignment(self) -> None:
        """``name = SomeClass()`` resolves to ``SomeClass``."""
        module = _parse("instance = MyClass()")
        assert "MyClass" == resolve_instance_class_name("instance", module)

    def test_annotated_assignment(self) -> None:
        """``name: SomeClass`` resolves to ``SomeClass``."""
        module = _parse("instance: MyClass")
        assert "MyClass" == resolve_instance_class_name("instance", module)

    def test_not_found(self) -> None:
        """A name that is never bound resolves to None."""
        assert resolve_instance_class_name("y", _parse("x = 1")) is None
|
|
|
|
|
|
class TestBuildImportFromMap:
    """Tests for build_import_from_map."""

    def test_from_import(self) -> None:
        """Each from-imported name maps to its source module."""
        assert {"Path": "pathlib"} == build_import_from_map(
            _parse("from pathlib import Path")
        )

    def test_alias(self) -> None:
        """The ``as`` name, not the original, becomes the key."""
        assert {"P": "pathlib"} == build_import_from_map(
            _parse("from pathlib import Path as P")
        )
|
|
|
|
|
|
class TestExtractImportsForClass:
    """Tests for extract_imports_for_class."""

    def test_base_class_import(self) -> None:
        """The import that provides a base class is carried along."""
        source = textwrap.dedent(
            """\
            from abc import ABC
            class MyClass(ABC):
                pass
            """
        )
        module = ast.parse(source)
        target = find_class_node_by_name("MyClass", module)
        assert target is not None
        imports = extract_imports_for_class(module, target, source)
        assert "from abc import ABC" in imports

    def test_decorator_import(self) -> None:
        """The import that provides a class decorator is carried along."""
        source = textwrap.dedent(
            """\
            from dataclasses import dataclass
            @dataclass
            class Point:
                x: int
            """
        )
        module = ast.parse(source)
        target = find_class_node_by_name("Point", module)
        assert target is not None
        imports = extract_imports_for_class(module, target, source)
        assert "from dataclasses import dataclass" in imports
|
|
|
|
|
|
class TestShouldUseRawProjectClassContext:
    """Tests for should_use_raw_project_class_context."""

    def test_decorated_class(self) -> None:
        """A decorated class is always given raw context."""
        module = _parse(
            """\
            from dataclasses import dataclass
            @dataclass
            class Foo:
                x: int
            """
        )
        target = find_class_node_by_name("Foo", module)
        assert target is not None
        aliases = collect_import_aliases(module)
        assert should_use_raw_project_class_context(target, aliases) is True

    def test_plain_large_class(self) -> None:
        """A large undecorated class with no special features gets False."""
        # Twenty trivial methods, intended to exceed
        # MAX_RAW_PROJECT_CLASS_BODY_ITEMS.
        method_lines = "".join(
            f"    def method_{i}(self): pass\n" for i in range(20)
        )
        module = ast.parse(f"class Big:\n{method_lines}")
        target = find_class_node_by_name("Big", module)
        assert target is not None
        aliases = collect_import_aliases(module)
        assert should_use_raw_project_class_context(target, aliases) is False
|
|
|
|
|
|
class TestCollectTypeNamesFromFunction:
    """Tests for collect_type_names_from_function."""

    @staticmethod
    def _first_function(module: ast.Module) -> ast.FunctionDef:
        """Return the first statement of *module*, asserting it is a def."""
        func = module.body[0]
        assert isinstance(func, ast.FunctionDef)
        return func

    def test_annotation_types(self) -> None:
        """Parameter annotations contribute their type names."""
        module = _parse(
            """\
            def process(items: MyList, config: Config) -> Result:
                return Result()
            """
        )
        names = collect_type_names_from_function(
            self._first_function(module), module, None
        )
        assert "MyList" in names
        assert "Config" in names

    def test_isinstance_types(self) -> None:
        """The class argument of an isinstance() call is collected."""
        module = _parse(
            """\
            def check(x):
                if isinstance(x, MyType):
                    pass
            """
        )
        names = collect_type_names_from_function(
            self._first_function(module), module, None
        )
        assert "MyType" in names

    def test_isinstance_tuple(self) -> None:
        """Every member of an isinstance() class tuple is collected."""
        module = _parse(
            """\
            def check(x):
                if isinstance(x, (TypeA, TypeB)):
                    pass
            """
        )
        names = collect_type_names_from_function(
            self._first_function(module), module, None
        )
        assert "TypeA" in names
        assert "TypeB" in names

    def test_class_bases(self) -> None:
        """Given an enclosing class name, its base classes are collected."""
        module = _parse(
            """\
            class Parent:
                pass
            class Child(Parent):
                def method(self):
                    pass
            """
        )
        child = module.body[1]
        assert isinstance(child, ast.ClassDef)
        method = child.body[0]
        assert isinstance(method, ast.FunctionDef)
        names = collect_type_names_from_function(method, module, "Child")
        assert "Parent" in names
|
|
|
|
|
|
class TestEnrichTestgenContext:
    """Tests for enrich_testgen_context."""

    def test_project_class_resolution(
        self,
        tmp_path: Path,
    ) -> None:
        """A class imported from a project module is resolved via Jedi."""
        (tmp_path / "models.py").write_text(
            textwrap.dedent(
                """\
                class Widget:
                    def __init__(self, name: str) -> None:
                        self.name = name
                """
            ),
            encoding="utf-8",
        )
        snippet = textwrap.dedent(
            """\
            from models import Widget
            def process(w: Widget) -> str:
                return w.name
            """
        )
        enriched = enrich_testgen_context(
            CodeStringsMarkdown(code_strings=[CodeString(code=snippet)]),
            tmp_path,
        )
        # NOTE(review): this check is conditional, so the test also passes
        # when nothing is resolved; it only rejects a wrong resolution.
        pieces = enriched.code_strings
        if pieces:
            assert "Widget" in "\n".join(piece.code for piece in pieces)

    def test_empty_context(self) -> None:
        """An empty input context yields an empty result."""
        empty = CodeStringsMarkdown(code_strings=[])
        outcome = enrich_testgen_context(empty, Path("/nonexistent"))
        assert [] == outcome.code_strings

    def test_syntax_error_in_context(self) -> None:
        """Code that fails to parse yields an empty result."""
        broken = CodeStringsMarkdown(
            code_strings=[CodeString(code="def broken(")]
        )
        outcome = enrich_testgen_context(broken, Path("/nonexistent"))
        assert [] == outcome.code_strings
|
|
|
|
|
|
class TestExtractParameterTypeConstructors:
    """Tests for extract_parameter_type_constructors."""

    @staticmethod
    def _write_module(path: Path, code: str) -> Path:
        """Write dedented *code* to *path* as UTF-8 and return *path*."""
        path.write_text(textwrap.dedent(code), encoding="utf-8")
        return path

    def test_dataclass_type_in_signature(
        self,
        tmp_path: Path,
    ) -> None:
        """A dataclass used as a parameter type yields its __init__ stub."""
        self._write_module(
            tmp_path / "models.py",
            """\
            from dataclasses import dataclass

            @dataclass
            class Config:
                name: str
                debug: bool = False
            """,
        )
        main_file = self._write_module(
            tmp_path / "main.py",
            """\
            from models import Config

            def process(config: Config) -> str:
                return config.name
            """,
        )
        # ``process`` sits on line 3 of main.py (import, blank line, def).
        target = FunctionToOptimize(
            function_name="process",
            file_path=main_file,
            starting_line=3,
        )
        constructors = extract_parameter_type_constructors(
            target, tmp_path, set()
        )
        # NOTE(review): this check is conditional, so the test also passes
        # when nothing is extracted; it only rejects a wrong extraction.
        found = constructors.code_strings
        if found:
            assert "Config" in "\n".join(part.code for part in found)

    def test_builtin_types_ignored(
        self,
        tmp_path: Path,
    ) -> None:
        """Parameters annotated only with builtins yield nothing."""
        main_file = self._write_module(
            tmp_path / "main.py",
            """\
            def add(x: int, y: int) -> int:
                return x + y
            """,
        )
        target = FunctionToOptimize(
            function_name="add",
            file_path=main_file,
            starting_line=1,
        )
        constructors = extract_parameter_type_constructors(
            target, tmp_path, set()
        )
        assert [] == constructors.code_strings

    def test_missing_function(
        self,
        tmp_path: Path,
    ) -> None:
        """A function absent from its file yields nothing."""
        main_file = self._write_module(tmp_path / "main.py", "x = 1\n")
        target = FunctionToOptimize(
            function_name="nonexistent",
            file_path=main_file,
        )
        constructors = extract_parameter_type_constructors(
            target, tmp_path, set()
        )
        assert [] == constructors.code_strings
|