diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index bfbf02fc4..367595218 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -861,6 +861,33 @@ def _get_dataclass_config(class_node: ast.ClassDef, import_aliases: dict[str, st return False, False, False +_ATTRS_NAMESPACES = frozenset({"attrs", "attr"}) +_ATTRS_DECORATOR_NAMES = frozenset({"define", "mutable", "frozen", "s", "attrs"}) + + +def _get_attrs_config(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> tuple[bool, bool, bool]: + for decorator in class_node.decorator_list: + expr_name = _get_expr_name(decorator) + if expr_name is None: + continue + parts = expr_name.split(".") + if len(parts) < 2 or parts[-2] not in _ATTRS_NAMESPACES or parts[-1] not in _ATTRS_DECORATOR_NAMES: + continue + init_enabled = True + kw_only = False + if isinstance(decorator, ast.Call): + for keyword in decorator.keywords: + literal_value = _bool_literal(keyword.value) + if literal_value is None: + continue + if keyword.arg == "init": + init_enabled = literal_value + elif keyword.arg == "kw_only": + kw_only = literal_value + return True, init_enabled, kw_only + return False, False, False + + def _is_classvar_annotation(annotation: ast.expr, import_aliases: dict[str, str]) -> bool: annotation_root = annotation.value if isinstance(annotation, ast.Subscript) else annotation return _expr_matches_name(annotation_root, import_aliases, "ClassVar") @@ -885,10 +912,13 @@ def _class_has_explicit_init(class_node: ast.ClassDef) -> bool: def _collect_synthetic_constructor_type_names(class_node: ast.ClassDef, import_aliases: dict[str, str]) -> set[str]: is_dataclass, dataclass_init_enabled, _ = _get_dataclass_config(class_node, import_aliases) - if not _is_namedtuple_class(class_node, import_aliases) and not is_dataclass: + is_attrs, attrs_init_enabled, _ = _get_attrs_config(class_node, import_aliases) + if not _is_namedtuple_class(class_node, import_aliases) and not is_dataclass and not is_attrs: return set() if is_dataclass and not dataclass_init_enabled: return set() + if is_attrs and not attrs_init_enabled: + return set() names = set[str]() for item in class_node.body: @@ -939,9 +969,9 @@ def _extract_synthetic_init_parameters( kw_only = literal_value elif keyword.arg == "default": default_value = _get_node_source(keyword.value, module_source) - elif keyword.arg == "default_factory": - # Default factories still imply an optional constructor parameter, but - # the generated __init__ does not use the field() call directly. + elif keyword.arg in {"default_factory", "factory"}: + # Default factories (dataclass default_factory= / attrs factory=) still imply + # an optional constructor parameter. default_value = "..." else: default_value = _get_node_source(item.value, module_source) @@ -960,13 +990,17 @@ def _build_synthetic_init_stub( ) -> str | None: is_namedtuple = _is_namedtuple_class(class_node, import_aliases) is_dataclass, dataclass_init_enabled, dataclass_kw_only = _get_dataclass_config(class_node, import_aliases) - if not is_namedtuple and not is_dataclass: + is_attrs, attrs_init_enabled, attrs_kw_only = _get_attrs_config(class_node, import_aliases) + if not is_namedtuple and not is_dataclass and not is_attrs: return None if is_dataclass and not dataclass_init_enabled: return None + if is_attrs and not attrs_init_enabled: + return None + kw_only_by_default = dataclass_kw_only or attrs_kw_only parameters = _extract_synthetic_init_parameters( - class_node, module_source, import_aliases, kw_only_by_default=dataclass_kw_only + class_node, module_source, import_aliases, kw_only_by_default=kw_only_by_default ) if not parameters: return None diff --git a/codeflash/languages/python/instrument_codeflash_capture.py b/codeflash/languages/python/instrument_codeflash_capture.py index 3b77a1f53..bfa2f18d6 100644 --- a/codeflash/languages/python/instrument_codeflash_capture.py +++ b/codeflash/languages/python/instrument_codeflash_capture.py @@ -188,6 +188,27 @@ class InitDecorator(ast.NodeTransformer): if base_name is not None and base_name.endswith("NamedTuple"): return node + # Skip attrs classes — their __init__ is auto-generated by the decorator at class creation + # time. With slots=True (the default for @attrs.define), attrs creates a brand-new class + # object, so the __class__ cell baked into the synthesised + # `super().__init__(*args, **kwargs)` still refers to the *original* class while `self` + # is already an instance of the *new* slots class, producing: + # TypeError: super(type, obj): obj (instance of X) is not an instance or subtype of X + # TODO: support by injecting a module-level wrapper after the class definition that + # captures the attrs-generated __init__ and delegates to it, e.g.: + # _orig = ClassName.__init__ + # ClassName.__init__ = codeflash_capture(...)(lambda self, *a, **kw: _orig(self, *a, **kw)) + for dec in node.decorator_list: + dec_name = self._expr_name(dec) + if dec_name is not None: + parts = dec_name.split(".") + if ( + len(parts) >= 2 + and parts[-2] in {"attrs", "attr"} + and parts[-1] in {"define", "mutable", "frozen", "s", "attrs"} + ): + return node + # Create super().__init__(*args, **kwargs) call (use prebuilt AST fragments) super_call = self._super_call_expr # Create the complete function using prebuilt arguments/body but attach the class-specific decorator diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index a2b31eb94..ccfa5410d 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -5013,6 +5013,68 @@ def process(cfg: ChildConfig) -> str: assert "qualified_name: str" in combined +def test_extract_init_stub_attrs_define(tmp_path: Path) -> None: + """extract_init_stub_from_class produces a synthetic __init__ stub for @attrs.define classes.""" + source = """ +import attrs +from attrs.validators import instance_of + +@attrs.define(frozen=True) +class ImportCST: + module: str = attrs.field(converter=str.lower) + name: str = attrs.field(validator=[instance_of(str)]) + as_name: str = attrs.field(validator=[instance_of(str)]) + + def to_str(self) -> str: + return f"from {self.module} import {self.name}" +""" + expected = "class ImportCST:\n def __init__(self, module: str, name: str, as_name: str):\n ..." + tree = ast.parse(source) + stub = extract_init_stub_from_class("ImportCST", source, tree) + assert stub == expected + + +def test_extract_init_stub_attrs_factory_fields(tmp_path: Path) -> None: + """Fields using attrs factory= keyword should appear as optional (= ...) in the stub.""" + source = """ +import attrs + +@attrs.define +class ClassCST: + name: str = attrs.field() + methods: list = attrs.field(factory=list) + imports: set = attrs.field(factory=set) + + def compute(self) -> int: + return len(self.methods) +""" + expected = "class ClassCST:\n def __init__(self, name: str, methods: list = ..., imports: set = ...):\n ..." + tree = ast.parse(source) + stub = extract_init_stub_from_class("ClassCST", source, tree) + assert stub == expected + + +def test_extract_init_stub_attrs_init_disabled(tmp_path: Path) -> None: + """When @attrs.define(init=False) but with explicit __init__, the explicit body is returned.""" + source = """ +import attrs + +@attrs.define(init=False) +class NoAutoInit: + x: int = attrs.field() + + def __init__(self, x: int): + self.x = x * 2 + + def get(self) -> int: + return self.x +""" + expected = "class NoAutoInit:\n def __init__(self, x: int):\n self.x = x * 2" + tree = ast.parse(source) + stub = extract_init_stub_from_class("NoAutoInit", source, tree) + assert stub == expected + + def test_enrich_testgen_context_third_party_uses_stubs(tmp_path: Path) -> None: """Third-party classes should produce compact __init__ stubs, not full class source.""" # Use a real third-party package (pydantic) so jedi can actually resolve it diff --git a/tests/test_instrument_codeflash_capture.py b/tests/test_instrument_codeflash_capture.py index 543d50855..8a7694821 100644 --- a/tests/test_instrument_codeflash_capture.py +++ b/tests/test_instrument_codeflash_capture.py @@ -499,6 +499,125 @@ class MyTuple(typing.NamedTuple): test_path.unlink(missing_ok=True) +def test_attrs_define_no_init_skipped(): + """@attrs.define classes have auto-generated __init__; synthesizing super().__init__() breaks + because attrs.define(slots=True) creates a new class whose instances fail the __class__ cell + check. Instrumentation must skip them.""" + original_code = """ +import attrs +from attrs.validators import instance_of + +@attrs.define +class MyAttrsClass: + x: int = attrs.field(validator=[instance_of(int)]) + y: str = attrs.field(default="hello") + + def compute(self): + return self.x +""" + expected = """import attrs +from attrs.validators import instance_of + + +@attrs.define +class MyAttrsClass: + x: int = attrs.field(validator=[instance_of(int)]) + y: str = attrs.field(default='hello') + + def compute(self): + return self.x +""" + test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve() + test_path.write_text(original_code) + + function = FunctionToOptimize( + function_name="compute", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="MyAttrsClass")] + ) + + try: + instrument_codeflash_capture(function, {}, test_path.parent) + modified_code = test_path.read_text() + assert modified_code.strip() == expected.strip() + finally: + test_path.unlink(missing_ok=True) + + +def test_attrs_define_frozen_no_init_skipped(): + """@attrs.define(frozen=True) should also be skipped.""" + original_code = """ +import attrs + +@attrs.define(frozen=True) +class FrozenPoint: + x: float = attrs.field() + y: float = attrs.field() + + def distance(self): + return (self.x ** 2 + self.y ** 2) ** 0.5 +""" + expected = """import attrs + + +@attrs.define(frozen=True) +class FrozenPoint: + x: float = attrs.field() + y: float = attrs.field() + + def distance(self): + return (self.x ** 2 + self.y ** 2) ** 0.5 +""" + test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve() + test_path.write_text(original_code) + + function = FunctionToOptimize( + function_name="distance", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="FrozenPoint")] + ) + + try: + instrument_codeflash_capture(function, {}, test_path.parent) + modified_code = test_path.read_text() + assert modified_code.strip() == expected.strip() + finally: + test_path.unlink(missing_ok=True) + + +def test_attr_s_no_init_skipped(): + """@attr.s classes should also be skipped.""" + original_code = """ +import attr + +@attr.s +class MyAttrClass: + x: int = attr.ib() + + def display(self): + return self.x +""" + expected = """import attr + + +@attr.s +class MyAttrClass: + x: int = attr.ib() + + def display(self): + return self.x +""" + test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve() + test_path.write_text(original_code) + + function = FunctionToOptimize( + function_name="display", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="MyAttrClass")] + ) + + try: + instrument_codeflash_capture(function, {}, test_path.parent) + modified_code = test_path.read_text() + assert modified_code.strip() == expected.strip() + finally: + test_path.unlink(missing_ok=True) + + def test_dataclass_with_explicit_init_still_instrumented(): """A dataclass that defines its own __init__ should still be instrumented normally.""" original_code = """