Ready to review

This commit is contained in:
aseembits93 2025-06-06 13:11:07 -07:00
parent a65569cfa0
commit c93b80e87b
4 changed files with 166 additions and 48 deletions

View file

@ -123,6 +123,7 @@ def process_pyproject_config(args: Namespace) -> Namespace:
"disable_telemetry", "disable_telemetry",
"disable_imports_sorting", "disable_imports_sorting",
"git_remote", "git_remote",
"override_fixtures",
] ]
for key in supported_keys: for key in supported_keys:
if key in pyproject_config and ( if key in pyproject_config and (
@ -130,7 +131,6 @@ def process_pyproject_config(args: Namespace) -> Namespace:
or not hasattr(args, key.replace("-", "_")) or not hasattr(args, key.replace("-", "_"))
): ):
setattr(args, key.replace("-", "_"), pyproject_config[key]) setattr(args, key.replace("-", "_"), pyproject_config[key])
args.override_fixtures = pyproject_config.get("override_fixtures", False)
assert args.module_root is not None, "--module-root must be specified" assert args.module_root is not None, "--module-root must be specified"
assert Path(args.module_root).is_dir(), f"--module-root {args.module_root} must be a valid directory" assert Path(args.module_root).is_dir(), f"--module-root {args.module_root} must be a valid directory"
assert args.tests_root is not None, "--tests-root must be specified" assert args.tests_root is not None, "--tests-root must be specified"

View file

@ -723,11 +723,11 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
formatter_cmds.append("disabled") formatter_cmds.append("disabled")
check_formatter_installed(formatter_cmds, exit_on_failure=False) check_formatter_installed(formatter_cmds, exit_on_failure=False)
codeflash_section["formatter-cmds"] = formatter_cmds codeflash_section["formatter-cmds"] = formatter_cmds
codeflash_section["override-fixtures"] = False # don't override fixtures by default, let the user decide
# Add the 'codeflash' section, ensuring 'tool' section exists # Add the 'codeflash' section, ensuring 'tool' section exists
tool_section = pyproject_data.get("tool", tomlkit.table()) tool_section = pyproject_data.get("tool", tomlkit.table())
tool_section["codeflash"] = codeflash_section tool_section["codeflash"] = codeflash_section
pyproject_data["tool"] = tool_section pyproject_data["tool"] = tool_section
with toml_path.open("w", encoding="utf8") as pyproject_file: with toml_path.open("w", encoding="utf8") as pyproject_file:
pyproject_file.write(tomlkit.dumps(pyproject_data)) pyproject_file.write(tomlkit.dumps(pyproject_data))
click.echo(f"✅ Added Codeflash configuration to {toml_path}") click.echo(f"✅ Added Codeflash configuration to {toml_path}")

View file

@ -54,8 +54,6 @@ class PytestMarkAdder(cst.CSTTransformer):
for import_alias in stmt.names: for import_alias in stmt.names:
if isinstance(import_alias, cst.ImportAlias) and import_alias.name.value == "pytest": if isinstance(import_alias, cst.ImportAlias) and import_alias.name.value == "pytest":
self.has_pytest_import = True self.has_pytest_import = True
elif isinstance(stmt, cst.ImportFrom) and stmt.module and stmt.module.value == "pytest":
self.has_pytest_import = True
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002 def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
"""Add pytest import if not present.""" """Add pytest import if not present."""

View file

@ -2180,17 +2180,25 @@ def my_fixture(request):
setup_code() setup_code()
yield "value" yield "value"
cleanup_code() cleanup_code()
'''
expected_code = '''
from pytest import fixture
@fixture(autouse=True)
def my_fixture(request):
if request.node.get_closest_marker("codeflash_no_autouse"):
yield
else:
setup_code()
yield "value"
cleanup_code()
''' '''
module = cst.parse_module(source_code) module = cst.parse_module(source_code)
modifier = AutouseFixtureModifier() modifier = AutouseFixtureModifier()
modified_module = module.visit(modifier) modified_module = module.visit(modifier)
# Check that the if statement was added # Check that the if statement was added
assert 'if request.node.get_closest_marker("codeflash_no_autouse"):' in modified_module.code assert modified_module.code.strip() == expected_code.strip()
assert "yield" in modified_module.code
assert "else:" in modified_module.code
assert "setup_code()" in modified_module.code
assert "cleanup_code()" in modified_module.code
def test_ignores_non_autouse_fixture(self): def test_ignores_non_autouse_fixture(self):
"""Test that non-autouse fixtures are not modified.""" """Test that non-autouse fixtures are not modified."""
@ -2241,6 +2249,23 @@ def fixture_one(request):
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def fixture_two(request): def fixture_two(request):
yield "two" yield "two"
'''
expected_code = '''
import pytest
@pytest.fixture(autouse=True)
def fixture_one(request):
if request.node.get_closest_marker("codeflash_no_autouse"):
yield
else:
yield "one"
@pytest.fixture(autouse=True)
def fixture_two(request):
if request.node.get_closest_marker("codeflash_no_autouse"):
yield
else:
yield "two"
''' '''
module = cst.parse_module(source_code) module = cst.parse_module(source_code)
modifier = AutouseFixtureModifier() modifier = AutouseFixtureModifier()
@ -2248,8 +2273,7 @@ def fixture_two(request):
# Both fixtures should be modified # Both fixtures should be modified
code = modified_module.code code = modified_module.code
assert code.count('if request.node.get_closest_marker("codeflash_no_autouse"):') == 2 assert code==expected_code
assert code.count("else:") == 2
def test_preserves_fixture_with_complex_body(self): def test_preserves_fixture_with_complex_body(self):
"""Test that fixtures with complex bodies are handled correctly.""" """Test that fixtures with complex bodies are handled correctly."""
@ -2258,24 +2282,39 @@ import pytest
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def complex_fixture(request): def complex_fixture(request):
try: if request.node.get_closest_marker("codeflash_no_autouse"):
setup_database() yield
configure_logging() else:
yield get_test_client() try:
finally: setup_database()
cleanup_database() configure_logging()
reset_logging() yield get_test_client()
finally:
cleanup_database()
reset_logging()
'''
expected_code = '''
import pytest
@pytest.fixture(autouse=True)
def complex_fixture(request):
if request.node.get_closest_marker("codeflash_no_autouse"):
yield
else:
try:
setup_database()
configure_logging()
yield get_test_client()
finally:
cleanup_database()
reset_logging()
''' '''
module = cst.parse_module(source_code) module = cst.parse_module(source_code)
modifier = AutouseFixtureModifier() modifier = AutouseFixtureModifier()
modified_module = module.visit(modifier) modified_module = module.visit(modifier)
code = modified_module.code code = modified_module.code
assert 'if request.node.get_closest_marker("codeflash_no_autouse"):' in code assert code==expected_code
assert "try:" in code
assert "setup_database()" in code
assert "finally:" in code
assert "cleanup_database()" in code
class TestPytestMarkAdder: class TestPytestMarkAdder:
@ -2284,6 +2323,12 @@ class TestPytestMarkAdder:
def test_adds_pytest_import_when_missing(self): def test_adds_pytest_import_when_missing(self):
"""Test that pytest import is added when not present.""" """Test that pytest import is added when not present."""
source_code = ''' source_code = '''
def test_something():
assert True
'''
expected_code = '''
import pytest
@pytest.mark.codeflash_no_autouse
def test_something(): def test_something():
assert True assert True
''' '''
@ -2292,14 +2337,20 @@ def test_something():
modified_module = module.visit(mark_adder) modified_module = module.visit(mark_adder)
code = modified_module.code code = modified_module.code
assert "import pytest" in code assert code==expected_code
assert "@pytest.mark.codeflash_no_autouse" in code
def test_skips_pytest_import_when_present(self): def test_skips_pytest_import_when_present(self):
"""Test that pytest import is not duplicated when already present.""" """Test that pytest import is not duplicated when already present."""
source_code = ''' source_code = '''
import pytest import pytest
def test_something():
assert True
'''
expected_code = '''
import pytest
@pytest.mark.codeflash_no_autouse
def test_something(): def test_something():
assert True assert True
''' '''
@ -2309,8 +2360,7 @@ def test_something():
code = modified_module.code code = modified_module.code
# Should only have one import pytest line # Should only have one import pytest line
assert code.count("import pytest") == 1 assert code==expected_code
assert "@pytest.mark.codeflash_no_autouse" in code
def test_handles_from_pytest_import(self): def test_handles_from_pytest_import(self):
"""Test that existing 'from pytest import ...' is recognized.""" """Test that existing 'from pytest import ...' is recognized."""
@ -2320,15 +2370,21 @@ from pytest import fixture
def test_something(): def test_something():
assert True assert True
''' '''
expected_code = '''
import pytest
from pytest import fixture
@pytest.mark.codeflash_no_autouse
def test_something():
assert True
'''
module = cst.parse_module(source_code) module = cst.parse_module(source_code)
mark_adder = PytestMarkAdder("codeflash_no_autouse") mark_adder = PytestMarkAdder("codeflash_no_autouse")
modified_module = module.visit(mark_adder) modified_module = module.visit(mark_adder)
code = modified_module.code code = modified_module.code
# Should not add import pytest since pytest is already imported # Should not add import pytest since pytest is already imported
assert "import pytest" not in code assert code.strip()==expected_code.strip()
assert "from pytest import fixture" in code
assert "@pytest.mark.codeflash_no_autouse" in code
def test_adds_mark_to_all_functions(self): def test_adds_mark_to_all_functions(self):
"""Test that marks are added to all functions in the module.""" """Test that marks are added to all functions in the module."""
@ -2341,6 +2397,21 @@ def test_first():
def test_second(): def test_second():
assert False assert False
def helper_function():
return "not a test"
'''
expected_code = '''
import pytest
@pytest.mark.codeflash_no_autouse
def test_first():
assert True
@pytest.mark.codeflash_no_autouse
def test_second():
assert False
@pytest.mark.codeflash_no_autouse
def helper_function(): def helper_function():
return "not a test" return "not a test"
''' '''
@ -2350,7 +2421,7 @@ def helper_function():
code = modified_module.code code = modified_module.code
# All functions should get the mark # All functions should get the mark
assert code.count("@pytest.mark.codeflash_no_autouse") == 3 assert code==expected_code
def test_skips_existing_mark(self): def test_skips_existing_mark(self):
"""Test that existing marks are not duplicated.""" """Test that existing marks are not duplicated."""
@ -2361,6 +2432,17 @@ import pytest
def test_already_marked(): def test_already_marked():
assert True assert True
def test_needs_mark():
assert True
'''
expected_code = '''
import pytest
@pytest.mark.codeflash_no_autouse
def test_already_marked():
assert True
@pytest.mark.codeflash_no_autouse
def test_needs_mark(): def test_needs_mark():
assert True assert True
''' '''
@ -2370,13 +2452,20 @@ def test_needs_mark():
code = modified_module.code code = modified_module.code
# Should have exactly 2 marks total (one existing, one added) # Should have exactly 2 marks total (one existing, one added)
assert code.count("@pytest.mark.codeflash_no_autouse") == 2 assert code==expected_code
def test_handles_different_mark_names(self): def test_handles_different_mark_names(self):
"""Test that different mark names work correctly.""" """Test that different mark names work correctly."""
source_code = ''' source_code = '''
import pytest import pytest
def test_something():
assert True
'''
expected_code = '''
import pytest
@pytest.mark.slow
def test_something(): def test_something():
assert True assert True
''' '''
@ -2385,8 +2474,7 @@ def test_something():
modified_module = module.visit(mark_adder) modified_module = module.visit(mark_adder)
code = modified_module.code code = modified_module.code
assert "@pytest.mark.slow" in code assert code==expected_code
assert "codeflash_no_autouse" not in code
def test_preserves_existing_decorators(self): def test_preserves_existing_decorators(self):
"""Test that existing decorators are preserved.""" """Test that existing decorators are preserved."""
@ -2395,6 +2483,15 @@ import pytest
@pytest.mark.parametrize("value", [1, 2, 3]) @pytest.mark.parametrize("value", [1, 2, 3])
@pytest.fixture @pytest.fixture
def test_with_decorators():
assert True
'''
expected_code = '''
import pytest
@pytest.mark.parametrize("value", [1, 2, 3])
@pytest.fixture
@pytest.mark.codeflash_no_autouse
def test_with_decorators(): def test_with_decorators():
assert True assert True
''' '''
@ -2403,9 +2500,7 @@ def test_with_decorators():
modified_module = module.visit(mark_adder) modified_module = module.visit(mark_adder)
code = modified_module.code code = modified_module.code
assert "@pytest.mark.parametrize" in code assert code==expected_code
assert "@pytest.fixture" in code
assert "@pytest.mark.codeflash_no_autouse" in code
def test_handles_call_style_existing_marks(self): def test_handles_call_style_existing_marks(self):
"""Test recognition of existing marks in call style (with parentheses).""" """Test recognition of existing marks in call style (with parentheses)."""
@ -2416,6 +2511,17 @@ import pytest
def test_with_call_mark(): def test_with_call_mark():
assert True assert True
def test_needs_mark():
assert True
'''
expected_code = '''
import pytest
@pytest.mark.codeflash_no_autouse()
def test_with_call_mark():
assert True
@pytest.mark.codeflash_no_autouse
def test_needs_mark(): def test_needs_mark():
assert True assert True
''' '''
@ -2425,8 +2531,7 @@ def test_needs_mark():
code = modified_module.code code = modified_module.code
# Should recognize the existing call-style mark and not duplicate # Should recognize the existing call-style mark and not duplicate
lines_with_mark = [line for line in code.split('\n') if 'codeflash_no_autouse' in line] assert code==expected_code
assert len(lines_with_mark) == 2 # One existing, one added
def test_empty_module(self): def test_empty_module(self):
"""Test handling of empty module.""" """Test handling of empty module."""
@ -2437,7 +2542,7 @@ def test_needs_mark():
# Should just add the import # Should just add the import
code = modified_module.code code = modified_module.code
assert "import pytest" in code assert code =='import pytest'
def test_module_with_only_imports(self): def test_module_with_only_imports(self):
"""Test handling of module with only imports.""" """Test handling of module with only imports."""
@ -2445,16 +2550,19 @@ def test_needs_mark():
import os import os
import sys import sys
from pathlib import Path from pathlib import Path
'''
expected_code = '''
import pytest
import os
import sys
from pathlib import Path
''' '''
module = cst.parse_module(source_code) module = cst.parse_module(source_code)
mark_adder = PytestMarkAdder("codeflash_no_autouse") mark_adder = PytestMarkAdder("codeflash_no_autouse")
modified_module = module.visit(mark_adder) modified_module = module.visit(mark_adder)
code = modified_module.code code = modified_module.code
assert "import pytest" in code assert code==expected_code
assert "import os" in code
assert "import sys" in code
assert "from pathlib import Path" in code
class TestIntegration: class TestIntegration:
@ -2469,6 +2577,21 @@ import pytest
def my_fixture(request): def my_fixture(request):
yield "value" yield "value"
def test_something():
assert True
'''
expected_code = '''
import pytest
@pytest.fixture(autouse=True)
@pytest.mark.codeflash_no_autouse
def my_fixture(request):
if request.node.get_closest_marker("codeflash_no_autouse"):
yield
else:
yield "value"
@pytest.mark.codeflash_no_autouse
def test_something(): def test_something():
assert True assert True
''' '''
@ -2483,7 +2606,4 @@ def test_something():
code = final_module.code code = final_module.code
# Should have both modifications # Should have both modifications
assert 'if request.node.get_closest_marker("codeflash_no_autouse"):' in code assert code==expected_code
assert "@pytest.mark.codeflash_no_autouse" in code
# Mark should be added to both functions
assert code.count("@pytest.mark.codeflash_no_autouse") == 2