fix: address PR review — recurse _expr_name for call-style decorators, guard empty-set superset

Add tests for remove_test_functions qualified name support and
module-qualified dataclass decorator handling.
This commit is contained in:
Kevin Turcios 2026-03-07 03:03:23 -05:00
parent fc55aedee7
commit 9fd5a3d93f
2 changed files with 191 additions and 0 deletions

View file

@ -442,6 +442,35 @@ class MyTuple(NamedTuple):
test_path.unlink(missing_ok=True) test_path.unlink(missing_ok=True)
def test_module_qualified_dataclass_with_call_syntax_skipped():
"""@dataclasses.dataclass(frozen=True) — module-qualified call-style decorator — should be skipped."""
original_code = """
import dataclasses
@dataclasses.dataclass(frozen=True)
class FrozenPoint:
x: int
y: int
def magnitude(self):
return (self.x ** 2 + self.y ** 2) ** 0.5
"""
test_path = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/test_file.py").resolve()
test_path.write_text(original_code)
function = FunctionToOptimize(
function_name="magnitude", file_path=test_path, parents=[FunctionParent(type="ClassDef", name="FrozenPoint")]
)
try:
instrument_codeflash_capture(function, {}, test_path.parent)
modified_code = test_path.read_text()
assert "super().__init__" not in modified_code
assert "codeflash_capture" not in modified_code
finally:
test_path.unlink(missing_ok=True)
def test_dataclass_with_explicit_init_still_instrumented(): def test_dataclass_with_explicit_init_still_instrumented():
"""A dataclass that defines its own __init__ should still be instrumented normally.""" """A dataclass that defines its own __init__ should still be instrumented normally."""
original_code = """ original_code = """

View file

@ -0,0 +1,162 @@
from codeflash.languages.python.support import PythonSupport
def test_remove_bare_function():
src = """
def test_foo():
pass
def test_bar():
pass
def test_baz():
pass
"""
result = PythonSupport().remove_test_functions(src, ["test_bar"])
assert result == """
def test_foo():
pass
def test_baz():
pass
"""
def test_remove_qualified_method():
src = """
class TestSuite:
def test_alpha(self):
pass
def test_beta(self):
pass
def test_gamma(self):
pass
"""
result = PythonSupport().remove_test_functions(src, ["TestSuite.test_beta"])
assert result == """
class TestSuite:
def test_alpha(self):
pass
def test_gamma(self):
pass
"""
def test_remove_all_methods_removes_class():
src = """
class TestSuite:
def test_alpha(self):
pass
def test_beta(self):
pass
"""
result = PythonSupport().remove_test_functions(
src, ["TestSuite.test_alpha", "TestSuite.test_beta"]
)
assert result == "\n"
def test_remove_all_methods_from_class_with_docstring():
src = """
class TestSuite:
\"\"\"Suite docstring.\"\"\"
def test_only(self):
pass
"""
result = PythonSupport().remove_test_functions(src, ["TestSuite.test_only"])
assert result == "\n"
def test_mixed_bare_and_qualified():
src = """
def test_standalone():
pass
class TestSuite:
def test_method(self):
pass
"""
result = PythonSupport().remove_test_functions(
src, ["test_standalone", "TestSuite.test_method"]
)
assert result == "\n"
def test_bare_name_does_not_match_class_method():
src = """
class TestSuite:
def test_method(self):
pass
def test_method():
pass
"""
result = PythonSupport().remove_test_functions(src, ["test_method"])
assert result == """
class TestSuite:
def test_method(self):
pass
"""
def test_class_kept_when_non_test_methods_remain():
src = """
class TestSuite:
def setUp(self):
self.x = 1
def test_alpha(self):
pass
def test_beta(self):
pass
"""
result = PythonSupport().remove_test_functions(
src, ["TestSuite.test_alpha", "TestSuite.test_beta"]
)
assert result == """
class TestSuite:
def setUp(self):
self.x = 1
"""
def test_qualified_name_wrong_class_no_removal():
src = """
class TestA:
def test_method(self):
pass
class TestB:
def test_method(self):
pass
"""
result = PythonSupport().remove_test_functions(src, ["TestA.test_method"])
assert result == """
class TestB:
def test_method(self):
pass
"""
def test_no_functions_to_remove_returns_unchanged():
src = """
def test_foo():
pass
"""
result = PythonSupport().remove_test_functions(src, [])
assert result == """
def test_foo():
pass
"""
def test_invalid_syntax_returns_original():
src = "def test_foo(:\n pass"
result = PythonSupport().remove_test_functions(src, ["test_foo"])
assert result == src