fix: skip attrs classes in __init__ instrumentation; add attrs support to code_context_extractor

- instrument_codeflash_capture: detect @attrs.define / @attr.s / etc. in the
  'no explicit __init__' branch and return early, same as dataclass/NamedTuple.
  Prevents a TypeError caused by attrs(slots=True) creating a new class whose
  __class__ cell no longer matches the injected super().__init__ wrapper.

- code_context_extractor: add _get_attrs_config() helper; update
  _collect_synthetic_constructor_type_names, _build_synthetic_init_stub, and
  _extract_synthetic_init_parameters to handle attrs field conventions
  (factory= keyword, init=False, kw_only).

- tests: add 3 exact-output tests for instrumentation skip behaviour and
  3 exact-output tests for attrs stub generation.

Co-Authored-By: Oz <oz-agent@warp.dev>
This commit is contained in:
Kevin Turcios 2026-03-18 01:33:40 -06:00
parent 948bfedfa0
commit dd5e347bbb
4 changed files with 242 additions and 6 deletions

View file

@ -861,6 +861,33 @@ def _get_dataclass_config(class_node: ast.ClassDef, import_aliases: dict[str, st
return False, False, False
# Top-level module names under which the attrs package is importable.
_ATTRS_NAMESPACES = frozenset({"attrs", "attr"})
# Class decorators from those modules that are treated as generating an __init__.
_ATTRS_DECORATOR_NAMES = frozenset({"define", "mutable", "frozen", "s", "attrs"})


def _is_attrs_decorator_name(dotted_name: str) -> bool:
    """Return True when *dotted_name* ends in ``<attrs namespace>.<attrs class decorator>``."""
    parts = dotted_name.split(".")
    return len(parts) >= 2 and parts[-2] in _ATTRS_NAMESPACES and parts[-1] in _ATTRS_DECORATOR_NAMES


def _get_attrs_config(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> tuple[bool, bool, bool]:
    """Detect an attrs class decorator and read its ``init`` / ``kw_only`` options.

    Args:
        class_node: The class definition whose decorators are inspected.
        import_aliases: Import alias table for the module. It is consulted so that
            ``from attrs import define`` / ``@define`` and ``import attrs as a`` /
            ``@a.define`` are recognised in addition to the fully spelled forms.

    Returns:
        ``(is_attrs, init_enabled, kw_only)``. When no attrs decorator is found the
        result is ``(False, False, False)``. Otherwise ``init_enabled`` defaults to
        ``True`` and ``kw_only`` to ``False`` unless the decorator call overrides
        them with a boolean literal.
    """
    for decorator in class_node.decorator_list:
        expr_name = _get_expr_name(decorator)
        if expr_name is None:
            continue

        # Try the name as written first, then its alias-resolved spellings.
        # assumes import_aliases maps a local name to the dotted name it was
        # imported from — TODO confirm against the code that builds it.
        candidate_names = [expr_name]
        aliased_name = import_aliases.get(expr_name)
        if aliased_name:
            candidate_names.append(aliased_name)
        head, dot, tail = expr_name.partition(".")
        if dot:
            aliased_head = import_aliases.get(head)
            if aliased_head:
                candidate_names.append(f"{aliased_head}.{tail}")

        if not any(_is_attrs_decorator_name(name) for name in candidate_names):
            continue

        init_enabled = True
        kw_only = False
        if isinstance(decorator, ast.Call):
            for keyword in decorator.keywords:
                literal_value = _bool_literal(keyword.value)
                # Non-literal option values cannot be evaluated statically; keep the default.
                if literal_value is None:
                    continue
                if keyword.arg == "init":
                    init_enabled = literal_value
                elif keyword.arg == "kw_only":
                    kw_only = literal_value
        return True, init_enabled, kw_only
    return False, False, False
def _is_classvar_annotation(annotation: ast.expr, import_aliases: dict[str, str]) -> bool:
    """Return whether *annotation* refers to ``ClassVar``, either bare or subscripted."""
    if isinstance(annotation, ast.Subscript):
        # For ``ClassVar[T]`` the name to compare is the subscripted object, not the subscript.
        target = annotation.value
    else:
        target = annotation
    return _expr_matches_name(target, import_aliases, "ClassVar")
@ -885,10 +912,13 @@ def _class_has_explicit_init(class_node: ast.ClassDef) -> bool:
def _collect_synthetic_constructor_type_names(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> set[str]:
is_dataclass, dataclass_init_enabled, _ = _get_dataclass_config(class_node, import_aliases)
if not _is_namedtuple_class(class_node, import_aliases) and not is_dataclass:
is_attrs, attrs_init_enabled, _ = _get_attrs_config(class_node, import_aliases)
if not _is_namedtuple_class(class_node, import_aliases) and not is_dataclass and not is_attrs:
return set()
if is_dataclass and not dataclass_init_enabled:
return set()
if is_attrs and not attrs_init_enabled:
return set()
names = set[str]()
for item in class_node.body:
@ -939,9 +969,9 @@ def _extract_synthetic_init_parameters(
kw_only = literal_value
elif keyword.arg == "default":
default_value = _get_node_source(keyword.value, module_source)
elif keyword.arg == "default_factory":
# Default factories still imply an optional constructor parameter, but
# the generated __init__ does not use the field() call directly.
elif keyword.arg in {"default_factory", "factory"}:
# Default factories (dataclass default_factory= / attrs factory=) still imply
# an optional constructor parameter.
default_value = "..."
else:
default_value = _get_node_source(item.value, module_source)
@ -960,13 +990,17 @@ def _build_synthetic_init_stub(
) -> str | None:
is_namedtuple = _is_namedtuple_class(class_node, import_aliases)
is_dataclass, dataclass_init_enabled, dataclass_kw_only = _get_dataclass_config(class_node, import_aliases)
if not is_namedtuple and not is_dataclass:
is_attrs, attrs_init_enabled, attrs_kw_only = _get_attrs_config(class_node, import_aliases)
if not is_namedtuple and not is_dataclass and not is_attrs:
return None
if is_dataclass and not dataclass_init_enabled:
return None
if is_attrs and not attrs_init_enabled:
return None
kw_only_by_default = dataclass_kw_only or attrs_kw_only
parameters = _extract_synthetic_init_parameters(
class_node, module_source, import_aliases, kw_only_by_default=dataclass_kw_only
class_node, module_source, import_aliases, kw_only_by_default=kw_only_by_default
)
if not parameters:
return None

View file

@ -188,6 +188,27 @@ class InitDecorator(ast.NodeTransformer):
if base_name is not None and base_name.endswith("NamedTuple"):
return node
# Skip attrs classes — their __init__ is auto-generated by the decorator at class creation
# time. With slots=True (the default for @attrs.define), attrs creates a brand-new class
# object, so the __class__ cell baked into the synthesised
# `super().__init__(*args, **kwargs)` still refers to the *original* class while `self`
# is already an instance of the *new* slots class, producing:
# TypeError: super(type, obj): obj (instance of X) is not an instance or subtype of X
# TODO: support by injecting a module-level wrapper after the class definition that
# captures the attrs-generated __init__ and delegates to it, e.g.:
# _orig = ClassName.__init__
# ClassName.__init__ = codeflash_capture(...)(lambda self, *a, **kw: _orig(self, *a, **kw))
for dec in node.decorator_list:
dec_name = self._expr_name(dec)
if dec_name is not None:
parts = dec_name.split(".")
if (
len(parts) >= 2
and parts[-2] in {"attrs", "attr"}
and parts[-1] in {"define", "mutable", "frozen", "s", "attrs"}
):
return node
# Create super().__init__(*args, **kwargs) call (use prebuilt AST fragments)
super_call = self._super_call_expr
# Create the complete function using prebuilt arguments/body but attach the class-specific decorator

View file

@ -5013,6 +5013,68 @@ def process(cfg: ChildConfig) -> str:
assert "qualified_name: str" in combined
def test_extract_init_stub_attrs_define(tmp_path: Path) -> None:
    """An ``@attrs.define`` class yields a synthetic ``__init__`` stub built from its fields."""
    # NOTE(review): whitespace inside the literals below looks collapsed by extraction —
    # confirm against the repository before relying on them.
    module_source = """
import attrs
from attrs.validators import instance_of
@attrs.define(frozen=True)
class ImportCST:
module: str = attrs.field(converter=str.lower)
name: str = attrs.field(validator=[instance_of(str)])
as_name: str = attrs.field(validator=[instance_of(str)])
def to_str(self) -> str:
return f"from {self.module} import {self.name}"
"""
    parsed_module = ast.parse(module_source)
    generated_stub = extract_init_stub_from_class("ImportCST", module_source, parsed_module)
    # All three fields are expected as plain parameters, in declaration order.
    assert generated_stub == "class ImportCST:\n def __init__(self, module: str, name: str, as_name: str):\n ..."
def test_extract_init_stub_attrs_factory_fields(tmp_path: Path) -> None:
    """Fields declared with attrs' ``factory=`` keyword show up as optional (``= ...``) in the stub."""
    # NOTE(review): whitespace inside the literals below looks collapsed by extraction —
    # confirm against the repository before relying on them.
    module_source = """
import attrs
@attrs.define
class ClassCST:
name: str = attrs.field()
methods: list = attrs.field(factory=list)
imports: set = attrs.field(factory=set)
def compute(self) -> int:
return len(self.methods)
"""
    parsed_module = ast.parse(module_source)
    generated_stub = extract_init_stub_from_class("ClassCST", module_source, parsed_module)
    # ``name`` stays required; the two factory-backed fields carry a ``...`` default.
    assert generated_stub == "class ClassCST:\n def __init__(self, name: str, methods: list = ..., imports: set = ...):\n ..."
def test_extract_init_stub_attrs_init_disabled(tmp_path: Path) -> None:
    """With ``@attrs.define(init=False)`` and a hand-written ``__init__``, that explicit ``__init__`` is returned."""
    # NOTE(review): whitespace inside the literals below looks collapsed by extraction —
    # confirm against the repository before relying on them.
    module_source = """
import attrs
@attrs.define(init=False)
class NoAutoInit:
x: int = attrs.field()
def __init__(self, x: int):
self.x = x * 2
def get(self) -> int:
return self.x
"""
    parsed_module = ast.parse(module_source)
    generated_stub = extract_init_stub_from_class("NoAutoInit", module_source, parsed_module)
    # The expected stub carries the explicit body rather than a synthetic ``...``.
    assert generated_stub == "class NoAutoInit:\n def __init__(self, x: int):\n self.x = x * 2"
def test_enrich_testgen_context_third_party_uses_stubs(tmp_path: Path) -> None:
"""Third-party classes should produce compact __init__ stubs, not full class source."""
# Use a real third-party package (pydantic) so jedi can actually resolve it

View file

@ -499,6 +499,125 @@ class MyTuple(typing.NamedTuple):
test_path.unlink(missing_ok=True)
def test_attrs_define_no_init_skipped():
    """Instrumentation must leave an ``@attrs.define`` class without an explicit ``__init__`` untouched.

    attrs generates ``__init__`` itself, and with ``slots=True`` it creates a new class, so a
    synthesized ``super().__init__()`` wrapper would fail its ``__class__`` cell check.
    """
    # NOTE(review): whitespace inside the literals below looks collapsed by extraction —
    # confirm against the repository before relying on them.
    source_before = """
import attrs
from attrs.validators import instance_of
@attrs.define
class MyAttrsClass:
x: int = attrs.field(validator=[instance_of(int)])
y: str = attrs.field(default="hello")
def compute(self):
return self.x
"""
    source_after = """import attrs
from attrs.validators import instance_of
@attrs.define
class MyAttrsClass:
x: int = attrs.field(validator=[instance_of(int)])
y: str = attrs.field(default='hello')
def compute(self):
return self.x
"""
    target_file = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
    target_file.write_text(source_before)
    method_under_test = FunctionToOptimize(
        function_name="compute",
        file_path=target_file,
        parents=[FunctionParent(type="ClassDef", name="MyAttrsClass")],
    )
    try:
        instrument_codeflash_capture(method_under_test, {}, target_file.parent)
        assert target_file.read_text().strip() == source_after.strip()
    finally:
        # Always remove the scratch file, even when the assertion fails.
        target_file.unlink(missing_ok=True)
def test_attrs_define_frozen_no_init_skipped():
    """The decorator call form ``@attrs.define(frozen=True)`` is skipped by instrumentation as well."""
    # NOTE(review): whitespace inside the literals below looks collapsed by extraction —
    # confirm against the repository before relying on them.
    source_before = """
import attrs
@attrs.define(frozen=True)
class FrozenPoint:
x: float = attrs.field()
y: float = attrs.field()
def distance(self):
return (self.x ** 2 + self.y ** 2) ** 0.5
"""
    source_after = """import attrs
@attrs.define(frozen=True)
class FrozenPoint:
x: float = attrs.field()
y: float = attrs.field()
def distance(self):
return (self.x ** 2 + self.y ** 2) ** 0.5
"""
    target_file = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
    target_file.write_text(source_before)
    method_under_test = FunctionToOptimize(
        function_name="distance",
        file_path=target_file,
        parents=[FunctionParent(type="ClassDef", name="FrozenPoint")],
    )
    try:
        instrument_codeflash_capture(method_under_test, {}, target_file.parent)
        assert target_file.read_text().strip() == source_after.strip()
    finally:
        # Always remove the scratch file, even when the assertion fails.
        target_file.unlink(missing_ok=True)
def test_attr_s_no_init_skipped():
    """Classes decorated with the legacy ``@attr.s`` spelling are skipped by instrumentation too."""
    # NOTE(review): whitespace inside the literals below looks collapsed by extraction —
    # confirm against the repository before relying on them.
    source_before = """
import attr
@attr.s
class MyAttrClass:
x: int = attr.ib()
def display(self):
return self.x
"""
    source_after = """import attr
@attr.s
class MyAttrClass:
x: int = attr.ib()
def display(self):
return self.x
"""
    target_file = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
    target_file.write_text(source_before)
    method_under_test = FunctionToOptimize(
        function_name="display",
        file_path=target_file,
        parents=[FunctionParent(type="ClassDef", name="MyAttrClass")],
    )
    try:
        instrument_codeflash_capture(method_under_test, {}, target_file.parent)
        assert target_file.read_text().strip() == source_after.strip()
    finally:
        # Always remove the scratch file, even when the assertion fails.
        target_file.unlink(missing_ok=True)
def test_dataclass_with_explicit_init_still_instrumented():
"""A dataclass that defines its own __init__ should still be instrumented normally."""
original_code = """