mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
fix: address review feedback for attrs init instrumentation
- Fix bug: skip attrs classes with init=False (no __init__ to patch) - Deduplicate attrs namespace/name sets into shared constants - Fix _get_attrs_config to resolve import aliases properly - Add test for init=False case with exact expected output
This commit is contained in:
parent
1f2027c731
commit
115cdba481
3 changed files with 65 additions and 8 deletions
|
|
@ -865,12 +865,26 @@ _ATTRS_NAMESPACES = frozenset({"attrs", "attr"})
|
|||
_ATTRS_DECORATOR_NAMES = frozenset({"define", "mutable", "frozen", "s", "attrs"})
|
||||
|
||||
|
||||
def _resolve_decorator_name(expr_name: str, import_aliases: dict[str, str]) -> str:
|
||||
resolved = import_aliases.get(expr_name)
|
||||
if resolved is not None:
|
||||
return resolved
|
||||
parts = expr_name.split(".")
|
||||
if len(parts) >= 2:
|
||||
root_resolved = import_aliases.get(parts[0])
|
||||
if root_resolved is not None:
|
||||
parts[0] = root_resolved
|
||||
return ".".join(parts)
|
||||
return expr_name
|
||||
|
||||
|
||||
def _get_attrs_config(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> tuple[bool, bool, bool]:
|
||||
for decorator in class_node.decorator_list:
|
||||
expr_name = _get_expr_name(decorator)
|
||||
if expr_name is None:
|
||||
continue
|
||||
parts = expr_name.split(".")
|
||||
resolved = _resolve_decorator_name(expr_name, import_aliases)
|
||||
parts = resolved.split(".")
|
||||
if len(parts) < 2 or parts[-2] not in _ATTRS_NAMESPACES or parts[-1] not in _ATTRS_DECORATOR_NAMES:
|
||||
continue
|
||||
init_enabled = True
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from typing import TYPE_CHECKING
|
|||
|
||||
from codeflash.code_utils.code_utils import get_run_tmp_file
|
||||
from codeflash.code_utils.formatter import sort_imports
|
||||
from codeflash.languages.python.context.code_context_extractor import _ATTRS_DECORATOR_NAMES, _ATTRS_NAMESPACES
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
|
|
@ -226,11 +227,15 @@ class InitDecorator(ast.NodeTransformer):
|
|||
dec_name = self._expr_name(dec)
|
||||
if dec_name is not None:
|
||||
parts = dec_name.split(".")
|
||||
if len(parts) >= 2 and parts[-2] in _ATTRS_NAMESPACES and parts[-1] in _ATTRS_DECORATOR_NAMES:
|
||||
if isinstance(dec, ast.Call):
|
||||
for kw in dec.keywords:
|
||||
if (
|
||||
len(parts) >= 2
|
||||
and parts[-2] in {"attrs", "attr"}
|
||||
and parts[-1] in {"define", "mutable", "frozen", "s", "attrs"}
|
||||
kw.arg == "init"
|
||||
and isinstance(kw.value, ast.Constant)
|
||||
and kw.value.value is False
|
||||
):
|
||||
return node
|
||||
self._attrs_classes_to_patch[node.name] = decorator
|
||||
self.inserted_decorator = True
|
||||
return node
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@ from pathlib import Path
|
|||
|
||||
from codeflash.code_utils.code_utils import get_run_tmp_file
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import FunctionParent
|
||||
from codeflash.languages.python.instrument_codeflash_capture import instrument_codeflash_capture
|
||||
from codeflash.models.models import FunctionParent
|
||||
|
||||
|
||||
def test_add_codeflash_capture():
|
||||
|
|
@ -502,7 +502,8 @@ class MyTuple(typing.NamedTuple):
|
|||
def test_attrs_define_patched_via_module_wrapper():
|
||||
"""@attrs.define classes must NOT get a synthetic body __init__; instead a module-level
|
||||
monkey-patch block is emitted after the class to avoid the __class__ cell TypeError
|
||||
that arises when attrs.define(slots=True) replaces the original class object."""
|
||||
that arises when attrs.define(slots=True) replaces the original class object.
|
||||
"""
|
||||
original_code = """
|
||||
import attrs
|
||||
from attrs.validators import instance_of
|
||||
|
|
@ -639,6 +640,43 @@ MyAttrClass.__init__ = codeflash_capture(function_name='MyAttrClass.__init__', t
|
|||
test_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def test_attrs_define_init_false_skipped():
|
||||
"""@attrs.define(init=False) should NOT be monkey-patched because attrs won't generate an __init__."""
|
||||
original_code = """
|
||||
import attrs
|
||||
|
||||
@attrs.define(init=False)
|
||||
class ManualInit:
|
||||
x: int = attrs.field()
|
||||
|
||||
def compute(self):
|
||||
return self.x
|
||||
"""
|
||||
expected = """import attrs
|
||||
|
||||
|
||||
@attrs.define(init=False)
|
||||
class ManualInit:
|
||||
x: int = attrs.field()
|
||||
|
||||
def compute(self):
|
||||
return self.x
|
||||
"""
|
||||
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
|
||||
test_path.write_text(original_code)
|
||||
|
||||
function = FunctionToOptimize(
|
||||
function_name="compute", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="ManualInit")]
|
||||
)
|
||||
|
||||
try:
|
||||
instrument_codeflash_capture(function, {}, test_path.parent)
|
||||
modified_code = test_path.read_text()
|
||||
assert modified_code.strip() == expected.strip()
|
||||
finally:
|
||||
test_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def test_dataclass_with_explicit_init_still_instrumented():
|
||||
"""A dataclass that defines its own __init__ should still be instrumented normally."""
|
||||
original_code = """
|
||||
|
|
|
|||
Loading…
Reference in a new issue