fix: address review feedback for attrs init instrumentation

- Fix bug: skip attrs classes with init=False (no __init__ to patch)
- Deduplicate attrs namespace/name sets into shared constants
- Fix _get_attrs_config to resolve import aliases properly
- Add test for init=False case with exact expected output
This commit is contained in:
Kevin Turcios 2026-03-18 03:34:44 -06:00
parent 1f2027c731
commit 115cdba481
3 changed files with 65 additions and 8 deletions

View file

@ -865,12 +865,26 @@ _ATTRS_NAMESPACES = frozenset({"attrs", "attr"})
_ATTRS_DECORATOR_NAMES = frozenset({"define", "mutable", "frozen", "s", "attrs"})
def _resolve_decorator_name(expr_name: str, import_aliases: dict[str, str]) -> str:
resolved = import_aliases.get(expr_name)
if resolved is not None:
return resolved
parts = expr_name.split(".")
if len(parts) >= 2:
root_resolved = import_aliases.get(parts[0])
if root_resolved is not None:
parts[0] = root_resolved
return ".".join(parts)
return expr_name
def _get_attrs_config(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> tuple[bool, bool, bool]:
for decorator in class_node.decorator_list:
expr_name = _get_expr_name(decorator)
if expr_name is None:
continue
parts = expr_name.split(".")
resolved = _resolve_decorator_name(expr_name, import_aliases)
parts = resolved.split(".")
if len(parts) < 2 or parts[-2] not in _ATTRS_NAMESPACES or parts[-1] not in _ATTRS_DECORATOR_NAMES:
continue
init_enabled = True

View file

@ -6,6 +6,7 @@ from typing import TYPE_CHECKING
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.code_utils.formatter import sort_imports
from codeflash.languages.python.context.code_context_extractor import _ATTRS_DECORATOR_NAMES, _ATTRS_NAMESPACES
if TYPE_CHECKING:
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
@ -226,11 +227,15 @@ class InitDecorator(ast.NodeTransformer):
dec_name = self._expr_name(dec)
if dec_name is not None:
parts = dec_name.split(".")
if (
len(parts) >= 2
and parts[-2] in {"attrs", "attr"}
and parts[-1] in {"define", "mutable", "frozen", "s", "attrs"}
):
if len(parts) >= 2 and parts[-2] in _ATTRS_NAMESPACES and parts[-1] in _ATTRS_DECORATOR_NAMES:
if isinstance(dec, ast.Call):
for kw in dec.keywords:
if (
kw.arg == "init"
and isinstance(kw.value, ast.Constant)
and kw.value.value is False
):
return node
self._attrs_classes_to_patch[node.name] = decorator
self.inserted_decorator = True
return node

View file

@ -2,8 +2,8 @@ from pathlib import Path
from codeflash.code_utils.code_utils import get_run_tmp_file
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent
from codeflash.languages.python.instrument_codeflash_capture import instrument_codeflash_capture
from codeflash.models.models import FunctionParent
def test_add_codeflash_capture():
@ -502,7 +502,8 @@ class MyTuple(typing.NamedTuple):
def test_attrs_define_patched_via_module_wrapper():
"""@attrs.define classes must NOT get a synthetic body __init__; instead a module-level
monkey-patch block is emitted after the class to avoid the __class__ cell TypeError
that arises when attrs.define(slots=True) replaces the original class object."""
that arises when attrs.define(slots=True) replaces the original class object.
"""
original_code = """
import attrs
from attrs.validators import instance_of
@ -639,6 +640,43 @@ MyAttrClass.__init__ = codeflash_capture(function_name='MyAttrClass.__init__', t
test_path.unlink(missing_ok=True)
def test_attrs_define_init_false_skipped():
"""@attrs.define(init=False) should NOT be monkey-patched because attrs won't generate an __init__."""
original_code = """
import attrs
@attrs.define(init=False)
class ManualInit:
x: int = attrs.field()
def compute(self):
return self.x
"""
expected = """import attrs
@attrs.define(init=False)
class ManualInit:
x: int = attrs.field()
def compute(self):
return self.x
"""
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
test_path.write_text(original_code)
function = FunctionToOptimize(
function_name="compute", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="ManualInit")]
)
try:
instrument_codeflash_capture(function, {}, test_path.parent)
modified_code = test_path.read_text()
assert modified_code.strip() == expected.strip()
finally:
test_path.unlink(missing_ok=True)
def test_dataclass_with_explicit_init_still_instrumented():
"""A dataclass that defines its own __init__ should still be instrumented normally."""
original_code = """