Ready to review

This commit is contained in:
aseembits93 2025-06-06 13:11:07 -07:00
parent a65569cfa0
commit c93b80e87b
4 changed files with 166 additions and 48 deletions

View file

@ -123,6 +123,7 @@ def process_pyproject_config(args: Namespace) -> Namespace:
"disable_telemetry",
"disable_imports_sorting",
"git_remote",
"override_fixtures",
]
for key in supported_keys:
if key in pyproject_config and (
@ -130,7 +131,6 @@ def process_pyproject_config(args: Namespace) -> Namespace:
or not hasattr(args, key.replace("-", "_"))
):
setattr(args, key.replace("-", "_"), pyproject_config[key])
args.override_fixtures = pyproject_config.get("override_fixtures", False)
assert args.module_root is not None, "--module-root must be specified"
assert Path(args.module_root).is_dir(), f"--module-root {args.module_root} must be a valid directory"
assert args.tests_root is not None, "--tests-root must be specified"

View file

@ -723,11 +723,11 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
formatter_cmds.append("disabled")
check_formatter_installed(formatter_cmds, exit_on_failure=False)
codeflash_section["formatter-cmds"] = formatter_cmds
codeflash_section["override-fixtures"] = False # don't override fixtures by default, let the user decide
# Add the 'codeflash' section, ensuring 'tool' section exists
tool_section = pyproject_data.get("tool", tomlkit.table())
tool_section["codeflash"] = codeflash_section
pyproject_data["tool"] = tool_section
with toml_path.open("w", encoding="utf8") as pyproject_file:
pyproject_file.write(tomlkit.dumps(pyproject_data))
click.echo(f"✅ Added Codeflash configuration to {toml_path}")

View file

@ -54,8 +54,6 @@ class PytestMarkAdder(cst.CSTTransformer):
for import_alias in stmt.names:
if isinstance(import_alias, cst.ImportAlias) and import_alias.name.value == "pytest":
self.has_pytest_import = True
elif isinstance(stmt, cst.ImportFrom) and stmt.module and stmt.module.value == "pytest":
self.has_pytest_import = True
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
"""Add pytest import if not present."""

View file

@ -2180,17 +2180,25 @@ def my_fixture(request):
setup_code()
yield "value"
cleanup_code()
'''
expected_code = '''
from pytest import fixture
@fixture(autouse=True)
def my_fixture(request):
if request.node.get_closest_marker("codeflash_no_autouse"):
yield
else:
setup_code()
yield "value"
cleanup_code()
'''
module = cst.parse_module(source_code)
modifier = AutouseFixtureModifier()
modified_module = module.visit(modifier)
# Check that the if statement was added
assert 'if request.node.get_closest_marker("codeflash_no_autouse"):' in modified_module.code
assert "yield" in modified_module.code
assert "else:" in modified_module.code
assert "setup_code()" in modified_module.code
assert "cleanup_code()" in modified_module.code
assert modified_module.code.strip() == expected_code.strip()
def test_ignores_non_autouse_fixture(self):
"""Test that non-autouse fixtures are not modified."""
@ -2241,6 +2249,23 @@ def fixture_one(request):
@pytest.fixture(autouse=True)
def fixture_two(request):
yield "two"
'''
expected_code = '''
import pytest
@pytest.fixture(autouse=True)
def fixture_one(request):
if request.node.get_closest_marker("codeflash_no_autouse"):
yield
else:
yield "one"
@pytest.fixture(autouse=True)
def fixture_two(request):
if request.node.get_closest_marker("codeflash_no_autouse"):
yield
else:
yield "two"
'''
module = cst.parse_module(source_code)
modifier = AutouseFixtureModifier()
@ -2248,8 +2273,7 @@ def fixture_two(request):
# Both fixtures should be modified
code = modified_module.code
assert code.count('if request.node.get_closest_marker("codeflash_no_autouse"):') == 2
assert code.count("else:") == 2
assert code==expected_code
def test_preserves_fixture_with_complex_body(self):
"""Test that fixtures with complex bodies are handled correctly."""
@ -2258,24 +2282,39 @@ import pytest
@pytest.fixture(autouse=True)
def complex_fixture(request):
try:
setup_database()
configure_logging()
yield get_test_client()
finally:
cleanup_database()
reset_logging()
if request.node.get_closest_marker("codeflash_no_autouse"):
yield
else:
try:
setup_database()
configure_logging()
yield get_test_client()
finally:
cleanup_database()
reset_logging()
'''
expected_code = '''
import pytest
@pytest.fixture(autouse=True)
def complex_fixture(request):
if request.node.get_closest_marker("codeflash_no_autouse"):
yield
else:
try:
setup_database()
configure_logging()
yield get_test_client()
finally:
cleanup_database()
reset_logging()
'''
module = cst.parse_module(source_code)
modifier = AutouseFixtureModifier()
modified_module = module.visit(modifier)
code = modified_module.code
assert 'if request.node.get_closest_marker("codeflash_no_autouse"):' in code
assert "try:" in code
assert "setup_database()" in code
assert "finally:" in code
assert "cleanup_database()" in code
assert code==expected_code
class TestPytestMarkAdder:
@ -2284,6 +2323,12 @@ class TestPytestMarkAdder:
def test_adds_pytest_import_when_missing(self):
"""Test that pytest import is added when not present."""
source_code = '''
def test_something():
assert True
'''
expected_code = '''
import pytest
@pytest.mark.codeflash_no_autouse
def test_something():
assert True
'''
@ -2292,14 +2337,20 @@ def test_something():
modified_module = module.visit(mark_adder)
code = modified_module.code
assert "import pytest" in code
assert "@pytest.mark.codeflash_no_autouse" in code
assert code==expected_code
def test_skips_pytest_import_when_present(self):
"""Test that pytest import is not duplicated when already present."""
source_code = '''
import pytest
def test_something():
assert True
'''
expected_code = '''
import pytest
@pytest.mark.codeflash_no_autouse
def test_something():
assert True
'''
@ -2309,8 +2360,7 @@ def test_something():
code = modified_module.code
# Should only have one import pytest line
assert code.count("import pytest") == 1
assert "@pytest.mark.codeflash_no_autouse" in code
assert code==expected_code
def test_handles_from_pytest_import(self):
"""Test that existing 'from pytest import ...' is recognized."""
@ -2320,15 +2370,21 @@ from pytest import fixture
def test_something():
assert True
'''
expected_code = '''
import pytest
from pytest import fixture
@pytest.mark.codeflash_no_autouse
def test_something():
assert True
'''
module = cst.parse_module(source_code)
mark_adder = PytestMarkAdder("codeflash_no_autouse")
modified_module = module.visit(mark_adder)
code = modified_module.code
# Should not add import pytest since pytest is already imported
assert "import pytest" not in code
assert "from pytest import fixture" in code
assert "@pytest.mark.codeflash_no_autouse" in code
assert code.strip()==expected_code.strip()
def test_adds_mark_to_all_functions(self):
"""Test that marks are added to all functions in the module."""
@ -2341,6 +2397,21 @@ def test_first():
def test_second():
assert False
def helper_function():
return "not a test"
'''
expected_code = '''
import pytest
@pytest.mark.codeflash_no_autouse
def test_first():
assert True
@pytest.mark.codeflash_no_autouse
def test_second():
assert False
@pytest.mark.codeflash_no_autouse
def helper_function():
return "not a test"
'''
@ -2350,7 +2421,7 @@ def helper_function():
code = modified_module.code
# All functions should get the mark
assert code.count("@pytest.mark.codeflash_no_autouse") == 3
assert code==expected_code
def test_skips_existing_mark(self):
"""Test that existing marks are not duplicated."""
@ -2361,6 +2432,17 @@ import pytest
def test_already_marked():
assert True
def test_needs_mark():
assert True
'''
expected_code = '''
import pytest
@pytest.mark.codeflash_no_autouse
def test_already_marked():
assert True
@pytest.mark.codeflash_no_autouse
def test_needs_mark():
assert True
'''
@ -2370,13 +2452,20 @@ def test_needs_mark():
code = modified_module.code
# Should have exactly 2 marks total (one existing, one added)
assert code.count("@pytest.mark.codeflash_no_autouse") == 2
assert code==expected_code
def test_handles_different_mark_names(self):
"""Test that different mark names work correctly."""
source_code = '''
import pytest
def test_something():
assert True
'''
expected_code = '''
import pytest
@pytest.mark.slow
def test_something():
assert True
'''
@ -2385,8 +2474,7 @@ def test_something():
modified_module = module.visit(mark_adder)
code = modified_module.code
assert "@pytest.mark.slow" in code
assert "codeflash_no_autouse" not in code
assert code==expected_code
def test_preserves_existing_decorators(self):
"""Test that existing decorators are preserved."""
@ -2395,6 +2483,15 @@ import pytest
@pytest.mark.parametrize("value", [1, 2, 3])
@pytest.fixture
def test_with_decorators():
assert True
'''
expected_code = '''
import pytest
@pytest.mark.parametrize("value", [1, 2, 3])
@pytest.fixture
@pytest.mark.codeflash_no_autouse
def test_with_decorators():
assert True
'''
@ -2403,9 +2500,7 @@ def test_with_decorators():
modified_module = module.visit(mark_adder)
code = modified_module.code
assert "@pytest.mark.parametrize" in code
assert "@pytest.fixture" in code
assert "@pytest.mark.codeflash_no_autouse" in code
assert code==expected_code
def test_handles_call_style_existing_marks(self):
"""Test recognition of existing marks in call style (with parentheses)."""
@ -2416,6 +2511,17 @@ import pytest
def test_with_call_mark():
assert True
def test_needs_mark():
assert True
'''
expected_code = '''
import pytest
@pytest.mark.codeflash_no_autouse()
def test_with_call_mark():
assert True
@pytest.mark.codeflash_no_autouse
def test_needs_mark():
assert True
'''
@ -2425,8 +2531,7 @@ def test_needs_mark():
code = modified_module.code
# Should recognize the existing call-style mark and not duplicate
lines_with_mark = [line for line in code.split('\n') if 'codeflash_no_autouse' in line]
assert len(lines_with_mark) == 2 # One existing, one added
assert code==expected_code
def test_empty_module(self):
"""Test handling of empty module."""
@ -2437,7 +2542,7 @@ def test_needs_mark():
# Should just add the import
code = modified_module.code
assert "import pytest" in code
assert code =='import pytest'
def test_module_with_only_imports(self):
"""Test handling of module with only imports."""
@ -2445,16 +2550,19 @@ def test_needs_mark():
import os
import sys
from pathlib import Path
'''
expected_code = '''
import pytest
import os
import sys
from pathlib import Path
'''
module = cst.parse_module(source_code)
mark_adder = PytestMarkAdder("codeflash_no_autouse")
modified_module = module.visit(mark_adder)
code = modified_module.code
assert "import pytest" in code
assert "import os" in code
assert "import sys" in code
assert "from pathlib import Path" in code
assert code==expected_code
class TestIntegration:
@ -2469,6 +2577,21 @@ import pytest
def my_fixture(request):
yield "value"
def test_something():
assert True
'''
expected_code = '''
import pytest
@pytest.fixture(autouse=True)
@pytest.mark.codeflash_no_autouse
def my_fixture(request):
if request.node.get_closest_marker("codeflash_no_autouse"):
yield
else:
yield "value"
@pytest.mark.codeflash_no_autouse
def test_something():
assert True
'''
@ -2483,7 +2606,4 @@ def test_something():
code = final_module.code
# Should have both modifications
assert 'if request.node.get_closest_marker("codeflash_no_autouse"):' in code
assert "@pytest.mark.codeflash_no_autouse" in code
# Mark should be added to both functions
assert code.count("@pytest.mark.codeflash_no_autouse") == 2
assert code==expected_code