diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index 367595218..5f2b14ca1 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -865,12 +865,26 @@ _ATTRS_NAMESPACES = frozenset({"attrs", "attr"}) _ATTRS_DECORATOR_NAMES = frozenset({"define", "mutable", "frozen", "s", "attrs"}) +def _resolve_decorator_name(expr_name: str, import_aliases: dict[str, str]) -> str: + resolved = import_aliases.get(expr_name) + if resolved is not None: + return resolved + parts = expr_name.split(".") + if len(parts) >= 2: + root_resolved = import_aliases.get(parts[0]) + if root_resolved is not None: + parts[0] = root_resolved + return ".".join(parts) + return expr_name + + def _get_attrs_config(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> tuple[bool, bool, bool]: for decorator in class_node.decorator_list: expr_name = _get_expr_name(decorator) if expr_name is None: continue - parts = expr_name.split(".") + resolved = _resolve_decorator_name(expr_name, import_aliases) + parts = resolved.split(".") if len(parts) < 2 or parts[-2] not in _ATTRS_NAMESPACES or parts[-1] not in _ATTRS_DECORATOR_NAMES: continue init_enabled = True diff --git a/codeflash/languages/python/instrument_codeflash_capture.py b/codeflash/languages/python/instrument_codeflash_capture.py index 5b40cd4b9..10293a49c 100644 --- a/codeflash/languages/python/instrument_codeflash_capture.py +++ b/codeflash/languages/python/instrument_codeflash_capture.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING from codeflash.code_utils.code_utils import get_run_tmp_file from codeflash.code_utils.formatter import sort_imports +from codeflash.languages.python.context.code_context_extractor import _ATTRS_DECORATOR_NAMES, _ATTRS_NAMESPACES if TYPE_CHECKING: from codeflash.discovery.functions_to_optimize import FunctionToOptimize @@ -226,11 +227,15 @@ class InitDecorator(ast.NodeTransformer): dec_name = self._expr_name(dec) if dec_name is not None: parts = dec_name.split(".") - if ( - len(parts) >= 2 - and parts[-2] in {"attrs", "attr"} - and parts[-1] in {"define", "mutable", "frozen", "s", "attrs"} - ): + if len(parts) >= 2 and parts[-2] in _ATTRS_NAMESPACES and parts[-1] in _ATTRS_DECORATOR_NAMES: + if isinstance(dec, ast.Call): + for kw in dec.keywords: + if ( + kw.arg == "init" + and isinstance(kw.value, ast.Constant) + and kw.value.value is False + ): + return node self._attrs_classes_to_patch[node.name] = decorator self.inserted_decorator = True return node diff --git a/tests/test_instrument_codeflash_capture.py b/tests/test_instrument_codeflash_capture.py index 62625cfe2..4ad0fade1 100644 --- a/tests/test_instrument_codeflash_capture.py +++ b/tests/test_instrument_codeflash_capture.py @@ -2,8 +2,8 @@ from pathlib import Path from codeflash.code_utils.code_utils import get_run_tmp_file from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import FunctionParent from codeflash.languages.python.instrument_codeflash_capture import instrument_codeflash_capture +from codeflash.models.models import FunctionParent def test_add_codeflash_capture(): @@ -502,7 +502,8 @@ class MyTuple(typing.NamedTuple): def test_attrs_define_patched_via_module_wrapper(): """@attrs.define classes must NOT get a synthetic body __init__; instead a module-level monkey-patch block is emitted after the class to avoid the __class__ cell TypeError - that arises when attrs.define(slots=True) replaces the original class object.""" + that arises when attrs.define(slots=True) replaces the original class object. + """ original_code = """ import attrs from attrs.validators import instance_of @@ -639,6 +640,43 @@ MyAttrClass.__init__ = codeflash_capture(function_name='MyAttrClass.__init__', t test_path.unlink(missing_ok=True) +def test_attrs_define_init_false_skipped(): + """@attrs.define(init=False) should NOT be monkey-patched because attrs won't generate an __init__.""" + original_code = """ +import attrs + +@attrs.define(init=False) +class ManualInit: + x: int = attrs.field() + + def compute(self): + return self.x +""" + expected = """import attrs + + +@attrs.define(init=False) +class ManualInit: + x: int = attrs.field() + + def compute(self): + return self.x +""" + test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve() + test_path.write_text(original_code) + + function = FunctionToOptimize( + function_name="compute", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="ManualInit")] + ) + + try: + instrument_codeflash_capture(function, {}, test_path.parent) + modified_code = test_path.read_text() + assert modified_code.strip() == expected.strip() + finally: + test_path.unlink(missing_ok=True) + + def test_dataclass_with_explicit_init_still_instrumented(): """A dataclass that defines its own __init__ should still be instrumented normally.""" original_code = """