mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
fix: skip attrs classes in __init__ instrumentation; add attrs support to code_context_extractor
- instrument_codeflash_capture: detect @attrs.define / @attr.s / etc. in the 'no explicit __init__' branch and return early, same as dataclass/NamedTuple. Prevents a TypeError caused by attrs(slots=True) creating a new class whose __class__ cell no longer matches the injected super().__init__ wrapper. - code_context_extractor: add _get_attrs_config() helper; update _collect_synthetic_constructor_type_names, _build_synthetic_init_stub, and _extract_synthetic_init_parameters to handle attrs field conventions (factory= keyword, init=False, kw_only). - tests: add 3 exact-output tests for instrumentation skip behaviour and 3 exact-output tests for attrs stub generation. Co-Authored-By: Oz <oz-agent@warp.dev>
This commit is contained in:
parent
948bfedfa0
commit
dd5e347bbb
4 changed files with 242 additions and 6 deletions
|
|
@ -861,6 +861,33 @@ def _get_dataclass_config(class_node: ast.ClassDef, import_aliases: dict[str, st
|
|||
return False, False, False
|
||||
|
||||
|
||||
_ATTRS_NAMESPACES = frozenset({"attrs", "attr"})
|
||||
_ATTRS_DECORATOR_NAMES = frozenset({"define", "mutable", "frozen", "s", "attrs"})
|
||||
|
||||
|
||||
def _get_attrs_config(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> tuple[bool, bool, bool]:
|
||||
for decorator in class_node.decorator_list:
|
||||
expr_name = _get_expr_name(decorator)
|
||||
if expr_name is None:
|
||||
continue
|
||||
parts = expr_name.split(".")
|
||||
if len(parts) < 2 or parts[-2] not in _ATTRS_NAMESPACES or parts[-1] not in _ATTRS_DECORATOR_NAMES:
|
||||
continue
|
||||
init_enabled = True
|
||||
kw_only = False
|
||||
if isinstance(decorator, ast.Call):
|
||||
for keyword in decorator.keywords:
|
||||
literal_value = _bool_literal(keyword.value)
|
||||
if literal_value is None:
|
||||
continue
|
||||
if keyword.arg == "init":
|
||||
init_enabled = literal_value
|
||||
elif keyword.arg == "kw_only":
|
||||
kw_only = literal_value
|
||||
return True, init_enabled, kw_only
|
||||
return False, False, False
|
||||
|
||||
|
||||
def _is_classvar_annotation(annotation: ast.expr, import_aliases: dict[str, str]) -> bool:
|
||||
annotation_root = annotation.value if isinstance(annotation, ast.Subscript) else annotation
|
||||
return _expr_matches_name(annotation_root, import_aliases, "ClassVar")
|
||||
|
|
@ -885,10 +912,13 @@ def _class_has_explicit_init(class_node: ast.ClassDef) -> bool:
|
|||
|
||||
def _collect_synthetic_constructor_type_names(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> set[str]:
|
||||
is_dataclass, dataclass_init_enabled, _ = _get_dataclass_config(class_node, import_aliases)
|
||||
if not _is_namedtuple_class(class_node, import_aliases) and not is_dataclass:
|
||||
is_attrs, attrs_init_enabled, _ = _get_attrs_config(class_node, import_aliases)
|
||||
if not _is_namedtuple_class(class_node, import_aliases) and not is_dataclass and not is_attrs:
|
||||
return set()
|
||||
if is_dataclass and not dataclass_init_enabled:
|
||||
return set()
|
||||
if is_attrs and not attrs_init_enabled:
|
||||
return set()
|
||||
|
||||
names = set[str]()
|
||||
for item in class_node.body:
|
||||
|
|
@ -939,9 +969,9 @@ def _extract_synthetic_init_parameters(
|
|||
kw_only = literal_value
|
||||
elif keyword.arg == "default":
|
||||
default_value = _get_node_source(keyword.value, module_source)
|
||||
elif keyword.arg == "default_factory":
|
||||
# Default factories still imply an optional constructor parameter, but
|
||||
# the generated __init__ does not use the field() call directly.
|
||||
elif keyword.arg in {"default_factory", "factory"}:
|
||||
# Default factories (dataclass default_factory= / attrs factory=) still imply
|
||||
# an optional constructor parameter.
|
||||
default_value = "..."
|
||||
else:
|
||||
default_value = _get_node_source(item.value, module_source)
|
||||
|
|
@ -960,13 +990,17 @@ def _build_synthetic_init_stub(
|
|||
) -> str | None:
|
||||
is_namedtuple = _is_namedtuple_class(class_node, import_aliases)
|
||||
is_dataclass, dataclass_init_enabled, dataclass_kw_only = _get_dataclass_config(class_node, import_aliases)
|
||||
if not is_namedtuple and not is_dataclass:
|
||||
is_attrs, attrs_init_enabled, attrs_kw_only = _get_attrs_config(class_node, import_aliases)
|
||||
if not is_namedtuple and not is_dataclass and not is_attrs:
|
||||
return None
|
||||
if is_dataclass and not dataclass_init_enabled:
|
||||
return None
|
||||
if is_attrs and not attrs_init_enabled:
|
||||
return None
|
||||
|
||||
kw_only_by_default = dataclass_kw_only or attrs_kw_only
|
||||
parameters = _extract_synthetic_init_parameters(
|
||||
class_node, module_source, import_aliases, kw_only_by_default=dataclass_kw_only
|
||||
class_node, module_source, import_aliases, kw_only_by_default=kw_only_by_default
|
||||
)
|
||||
if not parameters:
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -188,6 +188,27 @@ class InitDecorator(ast.NodeTransformer):
|
|||
if base_name is not None and base_name.endswith("NamedTuple"):
|
||||
return node
|
||||
|
||||
# Skip attrs classes — their __init__ is auto-generated by the decorator at class creation
|
||||
# time. With slots=True (the default for @attrs.define), attrs creates a brand-new class
|
||||
# object, so the __class__ cell baked into the synthesised
|
||||
# `super().__init__(*args, **kwargs)` still refers to the *original* class while `self`
|
||||
# is already an instance of the *new* slots class, producing:
|
||||
# TypeError: super(type, obj): obj (instance of X) is not an instance or subtype of X
|
||||
# TODO: support by injecting a module-level wrapper after the class definition that
|
||||
# captures the attrs-generated __init__ and delegates to it, e.g.:
|
||||
# _orig = ClassName.__init__
|
||||
# ClassName.__init__ = codeflash_capture(...)(lambda self, *a, **kw: _orig(self, *a, **kw))
|
||||
for dec in node.decorator_list:
|
||||
dec_name = self._expr_name(dec)
|
||||
if dec_name is not None:
|
||||
parts = dec_name.split(".")
|
||||
if (
|
||||
len(parts) >= 2
|
||||
and parts[-2] in {"attrs", "attr"}
|
||||
and parts[-1] in {"define", "mutable", "frozen", "s", "attrs"}
|
||||
):
|
||||
return node
|
||||
|
||||
# Create super().__init__(*args, **kwargs) call (use prebuilt AST fragments)
|
||||
super_call = self._super_call_expr
|
||||
# Create the complete function using prebuilt arguments/body but attach the class-specific decorator
|
||||
|
|
|
|||
|
|
@ -5013,6 +5013,68 @@ def process(cfg: ChildConfig) -> str:
|
|||
assert "qualified_name: str" in combined
|
||||
|
||||
|
||||
def test_extract_init_stub_attrs_define(tmp_path: Path) -> None:
|
||||
"""extract_init_stub_from_class produces a synthetic __init__ stub for @attrs.define classes."""
|
||||
source = """
|
||||
import attrs
|
||||
from attrs.validators import instance_of
|
||||
|
||||
@attrs.define(frozen=True)
|
||||
class ImportCST:
|
||||
module: str = attrs.field(converter=str.lower)
|
||||
name: str = attrs.field(validator=[instance_of(str)])
|
||||
as_name: str = attrs.field(validator=[instance_of(str)])
|
||||
|
||||
def to_str(self) -> str:
|
||||
return f"from {self.module} import {self.name}"
|
||||
"""
|
||||
expected = "class ImportCST:\n def __init__(self, module: str, name: str, as_name: str):\n ..."
|
||||
tree = ast.parse(source)
|
||||
stub = extract_init_stub_from_class("ImportCST", source, tree)
|
||||
assert stub == expected
|
||||
|
||||
|
||||
def test_extract_init_stub_attrs_factory_fields(tmp_path: Path) -> None:
|
||||
"""Fields using attrs factory= keyword should appear as optional (= ...) in the stub."""
|
||||
source = """
|
||||
import attrs
|
||||
|
||||
@attrs.define
|
||||
class ClassCST:
|
||||
name: str = attrs.field()
|
||||
methods: list = attrs.field(factory=list)
|
||||
imports: set = attrs.field(factory=set)
|
||||
|
||||
def compute(self) -> int:
|
||||
return len(self.methods)
|
||||
"""
|
||||
expected = "class ClassCST:\n def __init__(self, name: str, methods: list = ..., imports: set = ...):\n ..."
|
||||
tree = ast.parse(source)
|
||||
stub = extract_init_stub_from_class("ClassCST", source, tree)
|
||||
assert stub == expected
|
||||
|
||||
|
||||
def test_extract_init_stub_attrs_init_disabled(tmp_path: Path) -> None:
|
||||
"""When @attrs.define(init=False) but with explicit __init__, the explicit body is returned."""
|
||||
source = """
|
||||
import attrs
|
||||
|
||||
@attrs.define(init=False)
|
||||
class NoAutoInit:
|
||||
x: int = attrs.field()
|
||||
|
||||
def __init__(self, x: int):
|
||||
self.x = x * 2
|
||||
|
||||
def get(self) -> int:
|
||||
return self.x
|
||||
"""
|
||||
expected = "class NoAutoInit:\n def __init__(self, x: int):\n self.x = x * 2"
|
||||
tree = ast.parse(source)
|
||||
stub = extract_init_stub_from_class("NoAutoInit", source, tree)
|
||||
assert stub == expected
|
||||
|
||||
|
||||
def test_enrich_testgen_context_third_party_uses_stubs(tmp_path: Path) -> None:
|
||||
"""Third-party classes should produce compact __init__ stubs, not full class source."""
|
||||
# Use a real third-party package (pydantic) so jedi can actually resolve it
|
||||
|
|
|
|||
|
|
@ -499,6 +499,125 @@ class MyTuple(typing.NamedTuple):
|
|||
test_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def test_attrs_define_no_init_skipped():
|
||||
"""@attrs.define classes have auto-generated __init__; synthesizing super().__init__() breaks
|
||||
because attrs.define(slots=True) creates a new class whose instances fail the __class__ cell
|
||||
check. Instrumentation must skip them."""
|
||||
original_code = """
|
||||
import attrs
|
||||
from attrs.validators import instance_of
|
||||
|
||||
@attrs.define
|
||||
class MyAttrsClass:
|
||||
x: int = attrs.field(validator=[instance_of(int)])
|
||||
y: str = attrs.field(default="hello")
|
||||
|
||||
def compute(self):
|
||||
return self.x
|
||||
"""
|
||||
expected = """import attrs
|
||||
from attrs.validators import instance_of
|
||||
|
||||
|
||||
@attrs.define
|
||||
class MyAttrsClass:
|
||||
x: int = attrs.field(validator=[instance_of(int)])
|
||||
y: str = attrs.field(default='hello')
|
||||
|
||||
def compute(self):
|
||||
return self.x
|
||||
"""
|
||||
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
|
||||
test_path.write_text(original_code)
|
||||
|
||||
function = FunctionToOptimize(
|
||||
function_name="compute", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="MyAttrsClass")]
|
||||
)
|
||||
|
||||
try:
|
||||
instrument_codeflash_capture(function, {}, test_path.parent)
|
||||
modified_code = test_path.read_text()
|
||||
assert modified_code.strip() == expected.strip()
|
||||
finally:
|
||||
test_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def test_attrs_define_frozen_no_init_skipped():
|
||||
"""@attrs.define(frozen=True) should also be skipped."""
|
||||
original_code = """
|
||||
import attrs
|
||||
|
||||
@attrs.define(frozen=True)
|
||||
class FrozenPoint:
|
||||
x: float = attrs.field()
|
||||
y: float = attrs.field()
|
||||
|
||||
def distance(self):
|
||||
return (self.x ** 2 + self.y ** 2) ** 0.5
|
||||
"""
|
||||
expected = """import attrs
|
||||
|
||||
|
||||
@attrs.define(frozen=True)
|
||||
class FrozenPoint:
|
||||
x: float = attrs.field()
|
||||
y: float = attrs.field()
|
||||
|
||||
def distance(self):
|
||||
return (self.x ** 2 + self.y ** 2) ** 0.5
|
||||
"""
|
||||
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
|
||||
test_path.write_text(original_code)
|
||||
|
||||
function = FunctionToOptimize(
|
||||
function_name="distance", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="FrozenPoint")]
|
||||
)
|
||||
|
||||
try:
|
||||
instrument_codeflash_capture(function, {}, test_path.parent)
|
||||
modified_code = test_path.read_text()
|
||||
assert modified_code.strip() == expected.strip()
|
||||
finally:
|
||||
test_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def test_attr_s_no_init_skipped():
|
||||
"""@attr.s classes should also be skipped."""
|
||||
original_code = """
|
||||
import attr
|
||||
|
||||
@attr.s
|
||||
class MyAttrClass:
|
||||
x: int = attr.ib()
|
||||
|
||||
def display(self):
|
||||
return self.x
|
||||
"""
|
||||
expected = """import attr
|
||||
|
||||
|
||||
@attr.s
|
||||
class MyAttrClass:
|
||||
x: int = attr.ib()
|
||||
|
||||
def display(self):
|
||||
return self.x
|
||||
"""
|
||||
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
|
||||
test_path.write_text(original_code)
|
||||
|
||||
function = FunctionToOptimize(
|
||||
function_name="display", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="MyAttrClass")]
|
||||
)
|
||||
|
||||
try:
|
||||
instrument_codeflash_capture(function, {}, test_path.parent)
|
||||
modified_code = test_path.read_text()
|
||||
assert modified_code.strip() == expected.strip()
|
||||
finally:
|
||||
test_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def test_dataclass_with_explicit_init_still_instrumented():
|
||||
"""A dataclass that defines its own __init__ should still be instrumented normally."""
|
||||
original_code = """
|
||||
|
|
|
|||
Loading…
Reference in a new issue