Ready to review
This commit is contained in:
parent
a65569cfa0
commit
c93b80e87b
4 changed files with 166 additions and 48 deletions
|
|
@ -123,6 +123,7 @@ def process_pyproject_config(args: Namespace) -> Namespace:
|
|||
"disable_telemetry",
|
||||
"disable_imports_sorting",
|
||||
"git_remote",
|
||||
"override_fixtures",
|
||||
]
|
||||
for key in supported_keys:
|
||||
if key in pyproject_config and (
|
||||
|
|
@ -130,7 +131,6 @@ def process_pyproject_config(args: Namespace) -> Namespace:
|
|||
or not hasattr(args, key.replace("-", "_"))
|
||||
):
|
||||
setattr(args, key.replace("-", "_"), pyproject_config[key])
|
||||
args.override_fixtures = pyproject_config.get("override_fixtures", False)
|
||||
assert args.module_root is not None, "--module-root must be specified"
|
||||
assert Path(args.module_root).is_dir(), f"--module-root {args.module_root} must be a valid directory"
|
||||
assert args.tests_root is not None, "--tests-root must be specified"
|
||||
|
|
|
|||
|
|
@ -723,11 +723,11 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
|
|||
formatter_cmds.append("disabled")
|
||||
check_formatter_installed(formatter_cmds, exit_on_failure=False)
|
||||
codeflash_section["formatter-cmds"] = formatter_cmds
|
||||
codeflash_section["override-fixtures"] = False # don't override fixtures by default, let the user decide
|
||||
# Add the 'codeflash' section, ensuring 'tool' section exists
|
||||
tool_section = pyproject_data.get("tool", tomlkit.table())
|
||||
tool_section["codeflash"] = codeflash_section
|
||||
pyproject_data["tool"] = tool_section
|
||||
|
||||
with toml_path.open("w", encoding="utf8") as pyproject_file:
|
||||
pyproject_file.write(tomlkit.dumps(pyproject_data))
|
||||
click.echo(f"✅ Added Codeflash configuration to {toml_path}")
|
||||
|
|
|
|||
|
|
@ -54,8 +54,6 @@ class PytestMarkAdder(cst.CSTTransformer):
|
|||
for import_alias in stmt.names:
|
||||
if isinstance(import_alias, cst.ImportAlias) and import_alias.name.value == "pytest":
|
||||
self.has_pytest_import = True
|
||||
elif isinstance(stmt, cst.ImportFrom) and stmt.module and stmt.module.value == "pytest":
|
||||
self.has_pytest_import = True
|
||||
|
||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
|
||||
"""Add pytest import if not present."""
|
||||
|
|
|
|||
|
|
@ -2180,17 +2180,25 @@ def my_fixture(request):
|
|||
setup_code()
|
||||
yield "value"
|
||||
cleanup_code()
|
||||
'''
|
||||
expected_code = '''
|
||||
from pytest import fixture
|
||||
|
||||
@fixture(autouse=True)
|
||||
def my_fixture(request):
|
||||
if request.node.get_closest_marker("codeflash_no_autouse"):
|
||||
yield
|
||||
else:
|
||||
setup_code()
|
||||
yield "value"
|
||||
cleanup_code()
|
||||
'''
|
||||
module = cst.parse_module(source_code)
|
||||
modifier = AutouseFixtureModifier()
|
||||
modified_module = module.visit(modifier)
|
||||
|
||||
# Check that the if statement was added
|
||||
assert 'if request.node.get_closest_marker("codeflash_no_autouse"):' in modified_module.code
|
||||
assert "yield" in modified_module.code
|
||||
assert "else:" in modified_module.code
|
||||
assert "setup_code()" in modified_module.code
|
||||
assert "cleanup_code()" in modified_module.code
|
||||
assert modified_module.code.strip() == expected_code.strip()
|
||||
|
||||
def test_ignores_non_autouse_fixture(self):
|
||||
"""Test that non-autouse fixtures are not modified."""
|
||||
|
|
@ -2241,6 +2249,23 @@ def fixture_one(request):
|
|||
@pytest.fixture(autouse=True)
|
||||
def fixture_two(request):
|
||||
yield "two"
|
||||
'''
|
||||
expected_code = '''
|
||||
import pytest
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def fixture_one(request):
|
||||
if request.node.get_closest_marker("codeflash_no_autouse"):
|
||||
yield
|
||||
else:
|
||||
yield "one"
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def fixture_two(request):
|
||||
if request.node.get_closest_marker("codeflash_no_autouse"):
|
||||
yield
|
||||
else:
|
||||
yield "two"
|
||||
'''
|
||||
module = cst.parse_module(source_code)
|
||||
modifier = AutouseFixtureModifier()
|
||||
|
|
@ -2248,8 +2273,7 @@ def fixture_two(request):
|
|||
|
||||
# Both fixtures should be modified
|
||||
code = modified_module.code
|
||||
assert code.count('if request.node.get_closest_marker("codeflash_no_autouse"):') == 2
|
||||
assert code.count("else:") == 2
|
||||
assert code==expected_code
|
||||
|
||||
def test_preserves_fixture_with_complex_body(self):
|
||||
"""Test that fixtures with complex bodies are handled correctly."""
|
||||
|
|
@ -2258,24 +2282,39 @@ import pytest
|
|||
|
||||
@pytest.fixture(autouse=True)
|
||||
def complex_fixture(request):
|
||||
try:
|
||||
setup_database()
|
||||
configure_logging()
|
||||
yield get_test_client()
|
||||
finally:
|
||||
cleanup_database()
|
||||
reset_logging()
|
||||
if request.node.get_closest_marker("codeflash_no_autouse"):
|
||||
yield
|
||||
else:
|
||||
try:
|
||||
setup_database()
|
||||
configure_logging()
|
||||
yield get_test_client()
|
||||
finally:
|
||||
cleanup_database()
|
||||
reset_logging()
|
||||
'''
|
||||
expected_code = '''
|
||||
import pytest
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def complex_fixture(request):
|
||||
if request.node.get_closest_marker("codeflash_no_autouse"):
|
||||
yield
|
||||
else:
|
||||
try:
|
||||
setup_database()
|
||||
configure_logging()
|
||||
yield get_test_client()
|
||||
finally:
|
||||
cleanup_database()
|
||||
reset_logging()
|
||||
'''
|
||||
module = cst.parse_module(source_code)
|
||||
modifier = AutouseFixtureModifier()
|
||||
modified_module = module.visit(modifier)
|
||||
|
||||
code = modified_module.code
|
||||
assert 'if request.node.get_closest_marker("codeflash_no_autouse"):' in code
|
||||
assert "try:" in code
|
||||
assert "setup_database()" in code
|
||||
assert "finally:" in code
|
||||
assert "cleanup_database()" in code
|
||||
assert code==expected_code
|
||||
|
||||
|
||||
class TestPytestMarkAdder:
|
||||
|
|
@ -2284,6 +2323,12 @@ class TestPytestMarkAdder:
|
|||
def test_adds_pytest_import_when_missing(self):
|
||||
"""Test that pytest import is added when not present."""
|
||||
source_code = '''
|
||||
def test_something():
|
||||
assert True
|
||||
'''
|
||||
expected_code = '''
|
||||
import pytest
|
||||
@pytest.mark.codeflash_no_autouse
|
||||
def test_something():
|
||||
assert True
|
||||
'''
|
||||
|
|
@ -2292,14 +2337,20 @@ def test_something():
|
|||
modified_module = module.visit(mark_adder)
|
||||
|
||||
code = modified_module.code
|
||||
assert "import pytest" in code
|
||||
assert "@pytest.mark.codeflash_no_autouse" in code
|
||||
assert code==expected_code
|
||||
|
||||
def test_skips_pytest_import_when_present(self):
|
||||
"""Test that pytest import is not duplicated when already present."""
|
||||
source_code = '''
|
||||
import pytest
|
||||
|
||||
def test_something():
|
||||
assert True
|
||||
'''
|
||||
expected_code = '''
|
||||
import pytest
|
||||
|
||||
@pytest.mark.codeflash_no_autouse
|
||||
def test_something():
|
||||
assert True
|
||||
'''
|
||||
|
|
@ -2309,8 +2360,7 @@ def test_something():
|
|||
|
||||
code = modified_module.code
|
||||
# Should only have one import pytest line
|
||||
assert code.count("import pytest") == 1
|
||||
assert "@pytest.mark.codeflash_no_autouse" in code
|
||||
assert code==expected_code
|
||||
|
||||
def test_handles_from_pytest_import(self):
|
||||
"""Test that existing 'from pytest import ...' is recognized."""
|
||||
|
|
@ -2320,15 +2370,21 @@ from pytest import fixture
|
|||
def test_something():
|
||||
assert True
|
||||
'''
|
||||
expected_code = '''
|
||||
import pytest
|
||||
from pytest import fixture
|
||||
|
||||
@pytest.mark.codeflash_no_autouse
|
||||
def test_something():
|
||||
assert True
|
||||
'''
|
||||
module = cst.parse_module(source_code)
|
||||
mark_adder = PytestMarkAdder("codeflash_no_autouse")
|
||||
modified_module = module.visit(mark_adder)
|
||||
|
||||
code = modified_module.code
|
||||
# Should not add import pytest since pytest is already imported
|
||||
assert "import pytest" not in code
|
||||
assert "from pytest import fixture" in code
|
||||
assert "@pytest.mark.codeflash_no_autouse" in code
|
||||
assert code.strip()==expected_code.strip()
|
||||
|
||||
def test_adds_mark_to_all_functions(self):
|
||||
"""Test that marks are added to all functions in the module."""
|
||||
|
|
@ -2341,6 +2397,21 @@ def test_first():
|
|||
def test_second():
|
||||
assert False
|
||||
|
||||
def helper_function():
|
||||
return "not a test"
|
||||
'''
|
||||
expected_code = '''
|
||||
import pytest
|
||||
|
||||
@pytest.mark.codeflash_no_autouse
|
||||
def test_first():
|
||||
assert True
|
||||
|
||||
@pytest.mark.codeflash_no_autouse
|
||||
def test_second():
|
||||
assert False
|
||||
|
||||
@pytest.mark.codeflash_no_autouse
|
||||
def helper_function():
|
||||
return "not a test"
|
||||
'''
|
||||
|
|
@ -2350,7 +2421,7 @@ def helper_function():
|
|||
|
||||
code = modified_module.code
|
||||
# All functions should get the mark
|
||||
assert code.count("@pytest.mark.codeflash_no_autouse") == 3
|
||||
assert code==expected_code
|
||||
|
||||
def test_skips_existing_mark(self):
|
||||
"""Test that existing marks are not duplicated."""
|
||||
|
|
@ -2361,6 +2432,17 @@ import pytest
|
|||
def test_already_marked():
|
||||
assert True
|
||||
|
||||
def test_needs_mark():
|
||||
assert True
|
||||
'''
|
||||
expected_code = '''
|
||||
import pytest
|
||||
|
||||
@pytest.mark.codeflash_no_autouse
|
||||
def test_already_marked():
|
||||
assert True
|
||||
|
||||
@pytest.mark.codeflash_no_autouse
|
||||
def test_needs_mark():
|
||||
assert True
|
||||
'''
|
||||
|
|
@ -2370,13 +2452,20 @@ def test_needs_mark():
|
|||
|
||||
code = modified_module.code
|
||||
# Should have exactly 2 marks total (one existing, one added)
|
||||
assert code.count("@pytest.mark.codeflash_no_autouse") == 2
|
||||
assert code==expected_code
|
||||
|
||||
def test_handles_different_mark_names(self):
|
||||
"""Test that different mark names work correctly."""
|
||||
source_code = '''
|
||||
import pytest
|
||||
|
||||
def test_something():
|
||||
assert True
|
||||
'''
|
||||
expected_code = '''
|
||||
import pytest
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_something():
|
||||
assert True
|
||||
'''
|
||||
|
|
@ -2385,8 +2474,7 @@ def test_something():
|
|||
modified_module = module.visit(mark_adder)
|
||||
|
||||
code = modified_module.code
|
||||
assert "@pytest.mark.slow" in code
|
||||
assert "codeflash_no_autouse" not in code
|
||||
assert code==expected_code
|
||||
|
||||
def test_preserves_existing_decorators(self):
|
||||
"""Test that existing decorators are preserved."""
|
||||
|
|
@ -2395,6 +2483,15 @@ import pytest
|
|||
|
||||
@pytest.mark.parametrize("value", [1, 2, 3])
|
||||
@pytest.fixture
|
||||
def test_with_decorators():
|
||||
assert True
|
||||
'''
|
||||
expected_code = '''
|
||||
import pytest
|
||||
|
||||
@pytest.mark.parametrize("value", [1, 2, 3])
|
||||
@pytest.fixture
|
||||
@pytest.mark.codeflash_no_autouse
|
||||
def test_with_decorators():
|
||||
assert True
|
||||
'''
|
||||
|
|
@ -2403,9 +2500,7 @@ def test_with_decorators():
|
|||
modified_module = module.visit(mark_adder)
|
||||
|
||||
code = modified_module.code
|
||||
assert "@pytest.mark.parametrize" in code
|
||||
assert "@pytest.fixture" in code
|
||||
assert "@pytest.mark.codeflash_no_autouse" in code
|
||||
assert code==expected_code
|
||||
|
||||
def test_handles_call_style_existing_marks(self):
|
||||
"""Test recognition of existing marks in call style (with parentheses)."""
|
||||
|
|
@ -2416,6 +2511,17 @@ import pytest
|
|||
def test_with_call_mark():
|
||||
assert True
|
||||
|
||||
def test_needs_mark():
|
||||
assert True
|
||||
'''
|
||||
expected_code = '''
|
||||
import pytest
|
||||
|
||||
@pytest.mark.codeflash_no_autouse()
|
||||
def test_with_call_mark():
|
||||
assert True
|
||||
|
||||
@pytest.mark.codeflash_no_autouse
|
||||
def test_needs_mark():
|
||||
assert True
|
||||
'''
|
||||
|
|
@ -2425,8 +2531,7 @@ def test_needs_mark():
|
|||
|
||||
code = modified_module.code
|
||||
# Should recognize the existing call-style mark and not duplicate
|
||||
lines_with_mark = [line for line in code.split('\n') if 'codeflash_no_autouse' in line]
|
||||
assert len(lines_with_mark) == 2 # One existing, one added
|
||||
assert code==expected_code
|
||||
|
||||
def test_empty_module(self):
|
||||
"""Test handling of empty module."""
|
||||
|
|
@ -2437,7 +2542,7 @@ def test_needs_mark():
|
|||
|
||||
# Should just add the import
|
||||
code = modified_module.code
|
||||
assert "import pytest" in code
|
||||
assert code =='import pytest'
|
||||
|
||||
def test_module_with_only_imports(self):
|
||||
"""Test handling of module with only imports."""
|
||||
|
|
@ -2445,16 +2550,19 @@ def test_needs_mark():
|
|||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
'''
|
||||
expected_code = '''
|
||||
import pytest
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
'''
|
||||
module = cst.parse_module(source_code)
|
||||
mark_adder = PytestMarkAdder("codeflash_no_autouse")
|
||||
modified_module = module.visit(mark_adder)
|
||||
|
||||
code = modified_module.code
|
||||
assert "import pytest" in code
|
||||
assert "import os" in code
|
||||
assert "import sys" in code
|
||||
assert "from pathlib import Path" in code
|
||||
assert code==expected_code
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
|
|
@ -2469,6 +2577,21 @@ import pytest
|
|||
def my_fixture(request):
|
||||
yield "value"
|
||||
|
||||
def test_something():
|
||||
assert True
|
||||
'''
|
||||
expected_code = '''
|
||||
import pytest
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@pytest.mark.codeflash_no_autouse
|
||||
def my_fixture(request):
|
||||
if request.node.get_closest_marker("codeflash_no_autouse"):
|
||||
yield
|
||||
else:
|
||||
yield "value"
|
||||
|
||||
@pytest.mark.codeflash_no_autouse
|
||||
def test_something():
|
||||
assert True
|
||||
'''
|
||||
|
|
@ -2483,7 +2606,4 @@ def test_something():
|
|||
|
||||
code = final_module.code
|
||||
# Should have both modifications
|
||||
assert 'if request.node.get_closest_marker("codeflash_no_autouse"):' in code
|
||||
assert "@pytest.mark.codeflash_no_autouse" in code
|
||||
# Mark should be added to both functions
|
||||
assert code.count("@pytest.mark.codeflash_no_autouse") == 2
|
||||
assert code==expected_code
|
||||
|
|
|
|||
Loading…
Reference in a new issue