Ready to review
This commit is contained in:
parent
a65569cfa0
commit
c93b80e87b
4 changed files with 166 additions and 48 deletions
|
|
@ -123,6 +123,7 @@ def process_pyproject_config(args: Namespace) -> Namespace:
|
||||||
"disable_telemetry",
|
"disable_telemetry",
|
||||||
"disable_imports_sorting",
|
"disable_imports_sorting",
|
||||||
"git_remote",
|
"git_remote",
|
||||||
|
"override_fixtures",
|
||||||
]
|
]
|
||||||
for key in supported_keys:
|
for key in supported_keys:
|
||||||
if key in pyproject_config and (
|
if key in pyproject_config and (
|
||||||
|
|
@ -130,7 +131,6 @@ def process_pyproject_config(args: Namespace) -> Namespace:
|
||||||
or not hasattr(args, key.replace("-", "_"))
|
or not hasattr(args, key.replace("-", "_"))
|
||||||
):
|
):
|
||||||
setattr(args, key.replace("-", "_"), pyproject_config[key])
|
setattr(args, key.replace("-", "_"), pyproject_config[key])
|
||||||
args.override_fixtures = pyproject_config.get("override_fixtures", False)
|
|
||||||
assert args.module_root is not None, "--module-root must be specified"
|
assert args.module_root is not None, "--module-root must be specified"
|
||||||
assert Path(args.module_root).is_dir(), f"--module-root {args.module_root} must be a valid directory"
|
assert Path(args.module_root).is_dir(), f"--module-root {args.module_root} must be a valid directory"
|
||||||
assert args.tests_root is not None, "--tests-root must be specified"
|
assert args.tests_root is not None, "--tests-root must be specified"
|
||||||
|
|
|
||||||
|
|
@ -723,11 +723,11 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
|
||||||
formatter_cmds.append("disabled")
|
formatter_cmds.append("disabled")
|
||||||
check_formatter_installed(formatter_cmds, exit_on_failure=False)
|
check_formatter_installed(formatter_cmds, exit_on_failure=False)
|
||||||
codeflash_section["formatter-cmds"] = formatter_cmds
|
codeflash_section["formatter-cmds"] = formatter_cmds
|
||||||
codeflash_section["override-fixtures"] = False # don't override fixtures by default, let the user decide
|
|
||||||
# Add the 'codeflash' section, ensuring 'tool' section exists
|
# Add the 'codeflash' section, ensuring 'tool' section exists
|
||||||
tool_section = pyproject_data.get("tool", tomlkit.table())
|
tool_section = pyproject_data.get("tool", tomlkit.table())
|
||||||
tool_section["codeflash"] = codeflash_section
|
tool_section["codeflash"] = codeflash_section
|
||||||
pyproject_data["tool"] = tool_section
|
pyproject_data["tool"] = tool_section
|
||||||
|
|
||||||
with toml_path.open("w", encoding="utf8") as pyproject_file:
|
with toml_path.open("w", encoding="utf8") as pyproject_file:
|
||||||
pyproject_file.write(tomlkit.dumps(pyproject_data))
|
pyproject_file.write(tomlkit.dumps(pyproject_data))
|
||||||
click.echo(f"✅ Added Codeflash configuration to {toml_path}")
|
click.echo(f"✅ Added Codeflash configuration to {toml_path}")
|
||||||
|
|
|
||||||
|
|
@ -54,8 +54,6 @@ class PytestMarkAdder(cst.CSTTransformer):
|
||||||
for import_alias in stmt.names:
|
for import_alias in stmt.names:
|
||||||
if isinstance(import_alias, cst.ImportAlias) and import_alias.name.value == "pytest":
|
if isinstance(import_alias, cst.ImportAlias) and import_alias.name.value == "pytest":
|
||||||
self.has_pytest_import = True
|
self.has_pytest_import = True
|
||||||
elif isinstance(stmt, cst.ImportFrom) and stmt.module and stmt.module.value == "pytest":
|
|
||||||
self.has_pytest_import = True
|
|
||||||
|
|
||||||
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
|
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
|
||||||
"""Add pytest import if not present."""
|
"""Add pytest import if not present."""
|
||||||
|
|
|
||||||
|
|
@ -2180,17 +2180,25 @@ def my_fixture(request):
|
||||||
setup_code()
|
setup_code()
|
||||||
yield "value"
|
yield "value"
|
||||||
cleanup_code()
|
cleanup_code()
|
||||||
|
'''
|
||||||
|
expected_code = '''
|
||||||
|
from pytest import fixture
|
||||||
|
|
||||||
|
@fixture(autouse=True)
|
||||||
|
def my_fixture(request):
|
||||||
|
if request.node.get_closest_marker("codeflash_no_autouse"):
|
||||||
|
yield
|
||||||
|
else:
|
||||||
|
setup_code()
|
||||||
|
yield "value"
|
||||||
|
cleanup_code()
|
||||||
'''
|
'''
|
||||||
module = cst.parse_module(source_code)
|
module = cst.parse_module(source_code)
|
||||||
modifier = AutouseFixtureModifier()
|
modifier = AutouseFixtureModifier()
|
||||||
modified_module = module.visit(modifier)
|
modified_module = module.visit(modifier)
|
||||||
|
|
||||||
# Check that the if statement was added
|
# Check that the if statement was added
|
||||||
assert 'if request.node.get_closest_marker("codeflash_no_autouse"):' in modified_module.code
|
assert modified_module.code.strip() == expected_code.strip()
|
||||||
assert "yield" in modified_module.code
|
|
||||||
assert "else:" in modified_module.code
|
|
||||||
assert "setup_code()" in modified_module.code
|
|
||||||
assert "cleanup_code()" in modified_module.code
|
|
||||||
|
|
||||||
def test_ignores_non_autouse_fixture(self):
|
def test_ignores_non_autouse_fixture(self):
|
||||||
"""Test that non-autouse fixtures are not modified."""
|
"""Test that non-autouse fixtures are not modified."""
|
||||||
|
|
@ -2241,6 +2249,23 @@ def fixture_one(request):
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def fixture_two(request):
|
def fixture_two(request):
|
||||||
yield "two"
|
yield "two"
|
||||||
|
'''
|
||||||
|
expected_code = '''
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def fixture_one(request):
|
||||||
|
if request.node.get_closest_marker("codeflash_no_autouse"):
|
||||||
|
yield
|
||||||
|
else:
|
||||||
|
yield "one"
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def fixture_two(request):
|
||||||
|
if request.node.get_closest_marker("codeflash_no_autouse"):
|
||||||
|
yield
|
||||||
|
else:
|
||||||
|
yield "two"
|
||||||
'''
|
'''
|
||||||
module = cst.parse_module(source_code)
|
module = cst.parse_module(source_code)
|
||||||
modifier = AutouseFixtureModifier()
|
modifier = AutouseFixtureModifier()
|
||||||
|
|
@ -2248,8 +2273,7 @@ def fixture_two(request):
|
||||||
|
|
||||||
# Both fixtures should be modified
|
# Both fixtures should be modified
|
||||||
code = modified_module.code
|
code = modified_module.code
|
||||||
assert code.count('if request.node.get_closest_marker("codeflash_no_autouse"):') == 2
|
assert code==expected_code
|
||||||
assert code.count("else:") == 2
|
|
||||||
|
|
||||||
def test_preserves_fixture_with_complex_body(self):
|
def test_preserves_fixture_with_complex_body(self):
|
||||||
"""Test that fixtures with complex bodies are handled correctly."""
|
"""Test that fixtures with complex bodies are handled correctly."""
|
||||||
|
|
@ -2258,24 +2282,39 @@ import pytest
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def complex_fixture(request):
|
def complex_fixture(request):
|
||||||
try:
|
if request.node.get_closest_marker("codeflash_no_autouse"):
|
||||||
setup_database()
|
yield
|
||||||
configure_logging()
|
else:
|
||||||
yield get_test_client()
|
try:
|
||||||
finally:
|
setup_database()
|
||||||
cleanup_database()
|
configure_logging()
|
||||||
reset_logging()
|
yield get_test_client()
|
||||||
|
finally:
|
||||||
|
cleanup_database()
|
||||||
|
reset_logging()
|
||||||
|
'''
|
||||||
|
expected_code = '''
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def complex_fixture(request):
|
||||||
|
if request.node.get_closest_marker("codeflash_no_autouse"):
|
||||||
|
yield
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
setup_database()
|
||||||
|
configure_logging()
|
||||||
|
yield get_test_client()
|
||||||
|
finally:
|
||||||
|
cleanup_database()
|
||||||
|
reset_logging()
|
||||||
'''
|
'''
|
||||||
module = cst.parse_module(source_code)
|
module = cst.parse_module(source_code)
|
||||||
modifier = AutouseFixtureModifier()
|
modifier = AutouseFixtureModifier()
|
||||||
modified_module = module.visit(modifier)
|
modified_module = module.visit(modifier)
|
||||||
|
|
||||||
code = modified_module.code
|
code = modified_module.code
|
||||||
assert 'if request.node.get_closest_marker("codeflash_no_autouse"):' in code
|
assert code==expected_code
|
||||||
assert "try:" in code
|
|
||||||
assert "setup_database()" in code
|
|
||||||
assert "finally:" in code
|
|
||||||
assert "cleanup_database()" in code
|
|
||||||
|
|
||||||
|
|
||||||
class TestPytestMarkAdder:
|
class TestPytestMarkAdder:
|
||||||
|
|
@ -2284,6 +2323,12 @@ class TestPytestMarkAdder:
|
||||||
def test_adds_pytest_import_when_missing(self):
|
def test_adds_pytest_import_when_missing(self):
|
||||||
"""Test that pytest import is added when not present."""
|
"""Test that pytest import is added when not present."""
|
||||||
source_code = '''
|
source_code = '''
|
||||||
|
def test_something():
|
||||||
|
assert True
|
||||||
|
'''
|
||||||
|
expected_code = '''
|
||||||
|
import pytest
|
||||||
|
@pytest.mark.codeflash_no_autouse
|
||||||
def test_something():
|
def test_something():
|
||||||
assert True
|
assert True
|
||||||
'''
|
'''
|
||||||
|
|
@ -2292,14 +2337,20 @@ def test_something():
|
||||||
modified_module = module.visit(mark_adder)
|
modified_module = module.visit(mark_adder)
|
||||||
|
|
||||||
code = modified_module.code
|
code = modified_module.code
|
||||||
assert "import pytest" in code
|
assert code==expected_code
|
||||||
assert "@pytest.mark.codeflash_no_autouse" in code
|
|
||||||
|
|
||||||
def test_skips_pytest_import_when_present(self):
|
def test_skips_pytest_import_when_present(self):
|
||||||
"""Test that pytest import is not duplicated when already present."""
|
"""Test that pytest import is not duplicated when already present."""
|
||||||
source_code = '''
|
source_code = '''
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
def test_something():
|
||||||
|
assert True
|
||||||
|
'''
|
||||||
|
expected_code = '''
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
@pytest.mark.codeflash_no_autouse
|
||||||
def test_something():
|
def test_something():
|
||||||
assert True
|
assert True
|
||||||
'''
|
'''
|
||||||
|
|
@ -2309,8 +2360,7 @@ def test_something():
|
||||||
|
|
||||||
code = modified_module.code
|
code = modified_module.code
|
||||||
# Should only have one import pytest line
|
# Should only have one import pytest line
|
||||||
assert code.count("import pytest") == 1
|
assert code==expected_code
|
||||||
assert "@pytest.mark.codeflash_no_autouse" in code
|
|
||||||
|
|
||||||
def test_handles_from_pytest_import(self):
|
def test_handles_from_pytest_import(self):
|
||||||
"""Test that existing 'from pytest import ...' is recognized."""
|
"""Test that existing 'from pytest import ...' is recognized."""
|
||||||
|
|
@ -2320,15 +2370,21 @@ from pytest import fixture
|
||||||
def test_something():
|
def test_something():
|
||||||
assert True
|
assert True
|
||||||
'''
|
'''
|
||||||
|
expected_code = '''
|
||||||
|
import pytest
|
||||||
|
from pytest import fixture
|
||||||
|
|
||||||
|
@pytest.mark.codeflash_no_autouse
|
||||||
|
def test_something():
|
||||||
|
assert True
|
||||||
|
'''
|
||||||
module = cst.parse_module(source_code)
|
module = cst.parse_module(source_code)
|
||||||
mark_adder = PytestMarkAdder("codeflash_no_autouse")
|
mark_adder = PytestMarkAdder("codeflash_no_autouse")
|
||||||
modified_module = module.visit(mark_adder)
|
modified_module = module.visit(mark_adder)
|
||||||
|
|
||||||
code = modified_module.code
|
code = modified_module.code
|
||||||
# Should not add import pytest since pytest is already imported
|
# Should not add import pytest since pytest is already imported
|
||||||
assert "import pytest" not in code
|
assert code.strip()==expected_code.strip()
|
||||||
assert "from pytest import fixture" in code
|
|
||||||
assert "@pytest.mark.codeflash_no_autouse" in code
|
|
||||||
|
|
||||||
def test_adds_mark_to_all_functions(self):
|
def test_adds_mark_to_all_functions(self):
|
||||||
"""Test that marks are added to all functions in the module."""
|
"""Test that marks are added to all functions in the module."""
|
||||||
|
|
@ -2341,6 +2397,21 @@ def test_first():
|
||||||
def test_second():
|
def test_second():
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
|
def helper_function():
|
||||||
|
return "not a test"
|
||||||
|
'''
|
||||||
|
expected_code = '''
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
@pytest.mark.codeflash_no_autouse
|
||||||
|
def test_first():
|
||||||
|
assert True
|
||||||
|
|
||||||
|
@pytest.mark.codeflash_no_autouse
|
||||||
|
def test_second():
|
||||||
|
assert False
|
||||||
|
|
||||||
|
@pytest.mark.codeflash_no_autouse
|
||||||
def helper_function():
|
def helper_function():
|
||||||
return "not a test"
|
return "not a test"
|
||||||
'''
|
'''
|
||||||
|
|
@ -2350,7 +2421,7 @@ def helper_function():
|
||||||
|
|
||||||
code = modified_module.code
|
code = modified_module.code
|
||||||
# All functions should get the mark
|
# All functions should get the mark
|
||||||
assert code.count("@pytest.mark.codeflash_no_autouse") == 3
|
assert code==expected_code
|
||||||
|
|
||||||
def test_skips_existing_mark(self):
|
def test_skips_existing_mark(self):
|
||||||
"""Test that existing marks are not duplicated."""
|
"""Test that existing marks are not duplicated."""
|
||||||
|
|
@ -2361,6 +2432,17 @@ import pytest
|
||||||
def test_already_marked():
|
def test_already_marked():
|
||||||
assert True
|
assert True
|
||||||
|
|
||||||
|
def test_needs_mark():
|
||||||
|
assert True
|
||||||
|
'''
|
||||||
|
expected_code = '''
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
@pytest.mark.codeflash_no_autouse
|
||||||
|
def test_already_marked():
|
||||||
|
assert True
|
||||||
|
|
||||||
|
@pytest.mark.codeflash_no_autouse
|
||||||
def test_needs_mark():
|
def test_needs_mark():
|
||||||
assert True
|
assert True
|
||||||
'''
|
'''
|
||||||
|
|
@ -2370,13 +2452,20 @@ def test_needs_mark():
|
||||||
|
|
||||||
code = modified_module.code
|
code = modified_module.code
|
||||||
# Should have exactly 2 marks total (one existing, one added)
|
# Should have exactly 2 marks total (one existing, one added)
|
||||||
assert code.count("@pytest.mark.codeflash_no_autouse") == 2
|
assert code==expected_code
|
||||||
|
|
||||||
def test_handles_different_mark_names(self):
|
def test_handles_different_mark_names(self):
|
||||||
"""Test that different mark names work correctly."""
|
"""Test that different mark names work correctly."""
|
||||||
source_code = '''
|
source_code = '''
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
def test_something():
|
||||||
|
assert True
|
||||||
|
'''
|
||||||
|
expected_code = '''
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
def test_something():
|
def test_something():
|
||||||
assert True
|
assert True
|
||||||
'''
|
'''
|
||||||
|
|
@ -2385,8 +2474,7 @@ def test_something():
|
||||||
modified_module = module.visit(mark_adder)
|
modified_module = module.visit(mark_adder)
|
||||||
|
|
||||||
code = modified_module.code
|
code = modified_module.code
|
||||||
assert "@pytest.mark.slow" in code
|
assert code==expected_code
|
||||||
assert "codeflash_no_autouse" not in code
|
|
||||||
|
|
||||||
def test_preserves_existing_decorators(self):
|
def test_preserves_existing_decorators(self):
|
||||||
"""Test that existing decorators are preserved."""
|
"""Test that existing decorators are preserved."""
|
||||||
|
|
@ -2395,6 +2483,15 @@ import pytest
|
||||||
|
|
||||||
@pytest.mark.parametrize("value", [1, 2, 3])
|
@pytest.mark.parametrize("value", [1, 2, 3])
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
def test_with_decorators():
|
||||||
|
assert True
|
||||||
|
'''
|
||||||
|
expected_code = '''
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("value", [1, 2, 3])
|
||||||
|
@pytest.fixture
|
||||||
|
@pytest.mark.codeflash_no_autouse
|
||||||
def test_with_decorators():
|
def test_with_decorators():
|
||||||
assert True
|
assert True
|
||||||
'''
|
'''
|
||||||
|
|
@ -2403,9 +2500,7 @@ def test_with_decorators():
|
||||||
modified_module = module.visit(mark_adder)
|
modified_module = module.visit(mark_adder)
|
||||||
|
|
||||||
code = modified_module.code
|
code = modified_module.code
|
||||||
assert "@pytest.mark.parametrize" in code
|
assert code==expected_code
|
||||||
assert "@pytest.fixture" in code
|
|
||||||
assert "@pytest.mark.codeflash_no_autouse" in code
|
|
||||||
|
|
||||||
def test_handles_call_style_existing_marks(self):
|
def test_handles_call_style_existing_marks(self):
|
||||||
"""Test recognition of existing marks in call style (with parentheses)."""
|
"""Test recognition of existing marks in call style (with parentheses)."""
|
||||||
|
|
@ -2416,6 +2511,17 @@ import pytest
|
||||||
def test_with_call_mark():
|
def test_with_call_mark():
|
||||||
assert True
|
assert True
|
||||||
|
|
||||||
|
def test_needs_mark():
|
||||||
|
assert True
|
||||||
|
'''
|
||||||
|
expected_code = '''
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
@pytest.mark.codeflash_no_autouse()
|
||||||
|
def test_with_call_mark():
|
||||||
|
assert True
|
||||||
|
|
||||||
|
@pytest.mark.codeflash_no_autouse
|
||||||
def test_needs_mark():
|
def test_needs_mark():
|
||||||
assert True
|
assert True
|
||||||
'''
|
'''
|
||||||
|
|
@ -2425,8 +2531,7 @@ def test_needs_mark():
|
||||||
|
|
||||||
code = modified_module.code
|
code = modified_module.code
|
||||||
# Should recognize the existing call-style mark and not duplicate
|
# Should recognize the existing call-style mark and not duplicate
|
||||||
lines_with_mark = [line for line in code.split('\n') if 'codeflash_no_autouse' in line]
|
assert code==expected_code
|
||||||
assert len(lines_with_mark) == 2 # One existing, one added
|
|
||||||
|
|
||||||
def test_empty_module(self):
|
def test_empty_module(self):
|
||||||
"""Test handling of empty module."""
|
"""Test handling of empty module."""
|
||||||
|
|
@ -2437,7 +2542,7 @@ def test_needs_mark():
|
||||||
|
|
||||||
# Should just add the import
|
# Should just add the import
|
||||||
code = modified_module.code
|
code = modified_module.code
|
||||||
assert "import pytest" in code
|
assert code =='import pytest'
|
||||||
|
|
||||||
def test_module_with_only_imports(self):
|
def test_module_with_only_imports(self):
|
||||||
"""Test handling of module with only imports."""
|
"""Test handling of module with only imports."""
|
||||||
|
|
@ -2445,16 +2550,19 @@ def test_needs_mark():
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
'''
|
||||||
|
expected_code = '''
|
||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
'''
|
'''
|
||||||
module = cst.parse_module(source_code)
|
module = cst.parse_module(source_code)
|
||||||
mark_adder = PytestMarkAdder("codeflash_no_autouse")
|
mark_adder = PytestMarkAdder("codeflash_no_autouse")
|
||||||
modified_module = module.visit(mark_adder)
|
modified_module = module.visit(mark_adder)
|
||||||
|
|
||||||
code = modified_module.code
|
code = modified_module.code
|
||||||
assert "import pytest" in code
|
assert code==expected_code
|
||||||
assert "import os" in code
|
|
||||||
assert "import sys" in code
|
|
||||||
assert "from pathlib import Path" in code
|
|
||||||
|
|
||||||
|
|
||||||
class TestIntegration:
|
class TestIntegration:
|
||||||
|
|
@ -2469,6 +2577,21 @@ import pytest
|
||||||
def my_fixture(request):
|
def my_fixture(request):
|
||||||
yield "value"
|
yield "value"
|
||||||
|
|
||||||
|
def test_something():
|
||||||
|
assert True
|
||||||
|
'''
|
||||||
|
expected_code = '''
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
@pytest.mark.codeflash_no_autouse
|
||||||
|
def my_fixture(request):
|
||||||
|
if request.node.get_closest_marker("codeflash_no_autouse"):
|
||||||
|
yield
|
||||||
|
else:
|
||||||
|
yield "value"
|
||||||
|
|
||||||
|
@pytest.mark.codeflash_no_autouse
|
||||||
def test_something():
|
def test_something():
|
||||||
assert True
|
assert True
|
||||||
'''
|
'''
|
||||||
|
|
@ -2483,7 +2606,4 @@ def test_something():
|
||||||
|
|
||||||
code = final_module.code
|
code = final_module.code
|
||||||
# Should have both modifications
|
# Should have both modifications
|
||||||
assert 'if request.node.get_closest_marker("codeflash_no_autouse"):' in code
|
assert code==expected_code
|
||||||
assert "@pytest.mark.codeflash_no_autouse" in code
|
|
||||||
# Mark should be added to both functions
|
|
||||||
assert code.count("@pytest.mark.codeflash_no_autouse") == 2
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue