mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
fix: address review feedback for attrs init instrumentation
- Fix bug: skip attrs classes with init=False (no __init__ to patch) - Deduplicate attrs namespace/name sets into shared constants - Fix _get_attrs_config to resolve import aliases properly - Add test for init=False case with exact expected output
This commit is contained in:
parent
1f2027c731
commit
115cdba481
3 changed files with 65 additions and 8 deletions
|
|
@ -865,12 +865,26 @@ _ATTRS_NAMESPACES = frozenset({"attrs", "attr"})
|
||||||
_ATTRS_DECORATOR_NAMES = frozenset({"define", "mutable", "frozen", "s", "attrs"})
|
_ATTRS_DECORATOR_NAMES = frozenset({"define", "mutable", "frozen", "s", "attrs"})
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_decorator_name(expr_name: str, import_aliases: dict[str, str]) -> str:
|
||||||
|
resolved = import_aliases.get(expr_name)
|
||||||
|
if resolved is not None:
|
||||||
|
return resolved
|
||||||
|
parts = expr_name.split(".")
|
||||||
|
if len(parts) >= 2:
|
||||||
|
root_resolved = import_aliases.get(parts[0])
|
||||||
|
if root_resolved is not None:
|
||||||
|
parts[0] = root_resolved
|
||||||
|
return ".".join(parts)
|
||||||
|
return expr_name
|
||||||
|
|
||||||
|
|
||||||
def _get_attrs_config(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> tuple[bool, bool, bool]:
|
def _get_attrs_config(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> tuple[bool, bool, bool]:
|
||||||
for decorator in class_node.decorator_list:
|
for decorator in class_node.decorator_list:
|
||||||
expr_name = _get_expr_name(decorator)
|
expr_name = _get_expr_name(decorator)
|
||||||
if expr_name is None:
|
if expr_name is None:
|
||||||
continue
|
continue
|
||||||
parts = expr_name.split(".")
|
resolved = _resolve_decorator_name(expr_name, import_aliases)
|
||||||
|
parts = resolved.split(".")
|
||||||
if len(parts) < 2 or parts[-2] not in _ATTRS_NAMESPACES or parts[-1] not in _ATTRS_DECORATOR_NAMES:
|
if len(parts) < 2 or parts[-2] not in _ATTRS_NAMESPACES or parts[-1] not in _ATTRS_DECORATOR_NAMES:
|
||||||
continue
|
continue
|
||||||
init_enabled = True
|
init_enabled = True
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from codeflash.code_utils.code_utils import get_run_tmp_file
|
from codeflash.code_utils.code_utils import get_run_tmp_file
|
||||||
from codeflash.code_utils.formatter import sort_imports
|
from codeflash.code_utils.formatter import sort_imports
|
||||||
|
from codeflash.languages.python.context.code_context_extractor import _ATTRS_DECORATOR_NAMES, _ATTRS_NAMESPACES
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||||
|
|
@ -226,11 +227,15 @@ class InitDecorator(ast.NodeTransformer):
|
||||||
dec_name = self._expr_name(dec)
|
dec_name = self._expr_name(dec)
|
||||||
if dec_name is not None:
|
if dec_name is not None:
|
||||||
parts = dec_name.split(".")
|
parts = dec_name.split(".")
|
||||||
if (
|
if len(parts) >= 2 and parts[-2] in _ATTRS_NAMESPACES and parts[-1] in _ATTRS_DECORATOR_NAMES:
|
||||||
len(parts) >= 2
|
if isinstance(dec, ast.Call):
|
||||||
and parts[-2] in {"attrs", "attr"}
|
for kw in dec.keywords:
|
||||||
and parts[-1] in {"define", "mutable", "frozen", "s", "attrs"}
|
if (
|
||||||
):
|
kw.arg == "init"
|
||||||
|
and isinstance(kw.value, ast.Constant)
|
||||||
|
and kw.value.value is False
|
||||||
|
):
|
||||||
|
return node
|
||||||
self._attrs_classes_to_patch[node.name] = decorator
|
self._attrs_classes_to_patch[node.name] = decorator
|
||||||
self.inserted_decorator = True
|
self.inserted_decorator = True
|
||||||
return node
|
return node
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,8 @@ from pathlib import Path
|
||||||
|
|
||||||
from codeflash.code_utils.code_utils import get_run_tmp_file
|
from codeflash.code_utils.code_utils import get_run_tmp_file
|
||||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||||
from codeflash.models.models import FunctionParent
|
|
||||||
from codeflash.languages.python.instrument_codeflash_capture import instrument_codeflash_capture
|
from codeflash.languages.python.instrument_codeflash_capture import instrument_codeflash_capture
|
||||||
|
from codeflash.models.models import FunctionParent
|
||||||
|
|
||||||
|
|
||||||
def test_add_codeflash_capture():
|
def test_add_codeflash_capture():
|
||||||
|
|
@ -502,7 +502,8 @@ class MyTuple(typing.NamedTuple):
|
||||||
def test_attrs_define_patched_via_module_wrapper():
|
def test_attrs_define_patched_via_module_wrapper():
|
||||||
"""@attrs.define classes must NOT get a synthetic body __init__; instead a module-level
|
"""@attrs.define classes must NOT get a synthetic body __init__; instead a module-level
|
||||||
monkey-patch block is emitted after the class to avoid the __class__ cell TypeError
|
monkey-patch block is emitted after the class to avoid the __class__ cell TypeError
|
||||||
that arises when attrs.define(slots=True) replaces the original class object."""
|
that arises when attrs.define(slots=True) replaces the original class object.
|
||||||
|
"""
|
||||||
original_code = """
|
original_code = """
|
||||||
import attrs
|
import attrs
|
||||||
from attrs.validators import instance_of
|
from attrs.validators import instance_of
|
||||||
|
|
@ -639,6 +640,43 @@ MyAttrClass.__init__ = codeflash_capture(function_name='MyAttrClass.__init__', t
|
||||||
test_path.unlink(missing_ok=True)
|
test_path.unlink(missing_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_attrs_define_init_false_skipped():
|
||||||
|
"""@attrs.define(init=False) should NOT be monkey-patched because attrs won't generate an __init__."""
|
||||||
|
original_code = """
|
||||||
|
import attrs
|
||||||
|
|
||||||
|
@attrs.define(init=False)
|
||||||
|
class ManualInit:
|
||||||
|
x: int = attrs.field()
|
||||||
|
|
||||||
|
def compute(self):
|
||||||
|
return self.x
|
||||||
|
"""
|
||||||
|
expected = """import attrs
|
||||||
|
|
||||||
|
|
||||||
|
@attrs.define(init=False)
|
||||||
|
class ManualInit:
|
||||||
|
x: int = attrs.field()
|
||||||
|
|
||||||
|
def compute(self):
|
||||||
|
return self.x
|
||||||
|
"""
|
||||||
|
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
|
||||||
|
test_path.write_text(original_code)
|
||||||
|
|
||||||
|
function = FunctionToOptimize(
|
||||||
|
function_name="compute", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="ManualInit")]
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
instrument_codeflash_capture(function, {}, test_path.parent)
|
||||||
|
modified_code = test_path.read_text()
|
||||||
|
assert modified_code.strip() == expected.strip()
|
||||||
|
finally:
|
||||||
|
test_path.unlink(missing_ok=True)
|
||||||
|
|
||||||
|
|
||||||
def test_dataclass_with_explicit_init_still_instrumented():
|
def test_dataclass_with_explicit_init_still_instrumented():
|
||||||
"""A dataclass that defines its own __init__ should still be instrumented normally."""
|
"""A dataclass that defines its own __init__ should still be instrumented normally."""
|
||||||
original_code = """
|
original_code = """
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue