From c93b80e87bacf6fa8407c7cc56f0ada3692e2ee2 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Fri, 6 Jun 2025 13:11:07 -0700 Subject: [PATCH] Ready to review --- codeflash/cli_cmds/cli.py | 2 +- codeflash/cli_cmds/cmd_init.py | 2 +- codeflash/code_utils/code_replacer.py | 2 - tests/test_code_replacement.py | 208 ++++++++++++++++++++------ 4 files changed, 166 insertions(+), 48 deletions(-) diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index d677deed9..5edff57a0 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -123,6 +123,7 @@ def process_pyproject_config(args: Namespace) -> Namespace: "disable_telemetry", "disable_imports_sorting", "git_remote", + "override_fixtures", ] for key in supported_keys: if key in pyproject_config and ( @@ -130,7 +131,6 @@ def process_pyproject_config(args: Namespace) -> Namespace: or not hasattr(args, key.replace("-", "_")) ): setattr(args, key.replace("-", "_"), pyproject_config[key]) - args.override_fixtures = pyproject_config.get("override_fixtures", False) assert args.module_root is not None, "--module-root must be specified" assert Path(args.module_root).is_dir(), f"--module-root {args.module_root} must be a valid directory" assert args.tests_root is not None, "--tests-root must be specified" diff --git a/codeflash/cli_cmds/cmd_init.py b/codeflash/cli_cmds/cmd_init.py index 059a9abe5..bfe600fa4 100644 --- a/codeflash/cli_cmds/cmd_init.py +++ b/codeflash/cli_cmds/cmd_init.py @@ -723,11 +723,11 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None: formatter_cmds.append("disabled") check_formatter_installed(formatter_cmds, exit_on_failure=False) codeflash_section["formatter-cmds"] = formatter_cmds - codeflash_section["override-fixtures"] = False # don't override fixtures by default, let the user decide # Add the 'codeflash' section, ensuring 'tool' section exists tool_section = pyproject_data.get("tool", tomlkit.table()) tool_section["codeflash"] = codeflash_section pyproject_data["tool"] = tool_section + with toml_path.open("w", encoding="utf8") as pyproject_file: pyproject_file.write(tomlkit.dumps(pyproject_data)) click.echo(f"✅ Added Codeflash configuration to {toml_path}") diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index ce4dcc9d9..932053fc6 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -54,8 +54,6 @@ class PytestMarkAdder(cst.CSTTransformer): for import_alias in stmt.names: if isinstance(import_alias, cst.ImportAlias) and import_alias.name.value == "pytest": self.has_pytest_import = True - elif isinstance(stmt, cst.ImportFrom) and stmt.module and stmt.module.value == "pytest": - self.has_pytest_import = True def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002 """Add pytest import if not present.""" diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 0d1940798..e848e4525 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -2180,17 +2180,25 @@ def my_fixture(request): setup_code() yield "value" cleanup_code() +''' + expected_code = ''' +from pytest import fixture + +@fixture(autouse=True) +def my_fixture(request): + if request.node.get_closest_marker("codeflash_no_autouse"): + yield + else: + setup_code() + yield "value" + cleanup_code() ''' module = cst.parse_module(source_code) modifier = AutouseFixtureModifier() modified_module = module.visit(modifier) # Check that the if statement was added - assert 'if request.node.get_closest_marker("codeflash_no_autouse"):' in modified_module.code - assert "yield" in modified_module.code - assert "else:" in modified_module.code - assert "setup_code()" in modified_module.code - assert "cleanup_code()" in modified_module.code + assert modified_module.code.strip() == expected_code.strip() def test_ignores_non_autouse_fixture(self): """Test that non-autouse fixtures are not modified.""" @@ -2241,6 +2249,23 @@ def fixture_one(request): @pytest.fixture(autouse=True) def fixture_two(request): yield "two" +''' + expected_code = ''' +import pytest + +@pytest.fixture(autouse=True) +def fixture_one(request): + if request.node.get_closest_marker("codeflash_no_autouse"): + yield + else: + yield "one" + +@pytest.fixture(autouse=True) +def fixture_two(request): + if request.node.get_closest_marker("codeflash_no_autouse"): + yield + else: + yield "two" ''' module = cst.parse_module(source_code) modifier = AutouseFixtureModifier() @@ -2248,8 +2273,7 @@ def fixture_two(request): # Both fixtures should be modified code = modified_module.code - assert code.count('if request.node.get_closest_marker("codeflash_no_autouse"):') == 2 - assert code.count("else:") == 2 + assert code==expected_code def test_preserves_fixture_with_complex_body(self): """Test that fixtures with complex bodies are handled correctly.""" @@ -2258,24 +2282,39 @@ import pytest @pytest.fixture(autouse=True) def complex_fixture(request): - try: - setup_database() - configure_logging() - yield get_test_client() - finally: - cleanup_database() - reset_logging() + if request.node.get_closest_marker("codeflash_no_autouse"): + yield + else: + try: + setup_database() + configure_logging() + yield get_test_client() + finally: + cleanup_database() + reset_logging() +''' + expected_code = ''' +import pytest + +@pytest.fixture(autouse=True) +def complex_fixture(request): + if request.node.get_closest_marker("codeflash_no_autouse"): + yield + else: + try: + setup_database() + configure_logging() + yield get_test_client() + finally: + cleanup_database() + reset_logging() ''' module = cst.parse_module(source_code) modifier = AutouseFixtureModifier() modified_module = module.visit(modifier) code = modified_module.code - assert 'if request.node.get_closest_marker("codeflash_no_autouse"):' in code - assert "try:" in code - assert "setup_database()" in code - assert "finally:" in code - assert "cleanup_database()" in code + assert code==expected_code class TestPytestMarkAdder: @@ -2284,6 +2323,12 @@ class TestPytestMarkAdder: def test_adds_pytest_import_when_missing(self): """Test that pytest import is added when not present.""" source_code = ''' +def test_something(): + assert True +''' + expected_code = ''' +import pytest +@pytest.mark.codeflash_no_autouse def test_something(): assert True ''' @@ -2292,14 +2337,20 @@ def test_something(): modified_module = module.visit(mark_adder) code = modified_module.code - assert "import pytest" in code - assert "@pytest.mark.codeflash_no_autouse" in code + assert code==expected_code def test_skips_pytest_import_when_present(self): """Test that pytest import is not duplicated when already present.""" source_code = ''' import pytest +def test_something(): + assert True +''' + expected_code = ''' +import pytest + +@pytest.mark.codeflash_no_autouse def test_something(): assert True ''' @@ -2309,8 +2360,7 @@ def test_something(): code = modified_module.code # Should only have one import pytest line - assert code.count("import pytest") == 1 - assert "@pytest.mark.codeflash_no_autouse" in code + assert code==expected_code def test_handles_from_pytest_import(self): """Test that existing 'from pytest import ...' is recognized.""" @@ -2320,15 +2370,21 @@ from pytest import fixture def test_something(): assert True ''' + expected_code = ''' +import pytest +from pytest import fixture + +@pytest.mark.codeflash_no_autouse +def test_something(): + assert True + ''' module = cst.parse_module(source_code) mark_adder = PytestMarkAdder("codeflash_no_autouse") modified_module = module.visit(mark_adder) code = modified_module.code # Should not add import pytest since pytest is already imported - assert "import pytest" not in code - assert "from pytest import fixture" in code - assert "@pytest.mark.codeflash_no_autouse" in code + assert code.strip()==expected_code.strip() def test_adds_mark_to_all_functions(self): """Test that marks are added to all functions in the module.""" @@ -2341,6 +2397,21 @@ def test_first(): def test_second(): assert False +def helper_function(): + return "not a test" +''' + expected_code = ''' +import pytest + +@pytest.mark.codeflash_no_autouse +def test_first(): + assert True + +@pytest.mark.codeflash_no_autouse +def test_second(): + assert False + +@pytest.mark.codeflash_no_autouse def helper_function(): return "not a test" ''' @@ -2350,7 +2421,7 @@ def helper_function(): code = modified_module.code # All functions should get the mark - assert code.count("@pytest.mark.codeflash_no_autouse") == 3 + assert code==expected_code def test_skips_existing_mark(self): """Test that existing marks are not duplicated.""" @@ -2361,6 +2432,17 @@ import pytest def test_already_marked(): assert True +def test_needs_mark(): + assert True +''' + expected_code = ''' +import pytest + +@pytest.mark.codeflash_no_autouse +def test_already_marked(): + assert True + +@pytest.mark.codeflash_no_autouse def test_needs_mark(): assert True ''' @@ -2370,13 +2452,20 @@ def test_needs_mark(): code = modified_module.code # Should have exactly 2 marks total (one existing, one added) - assert code.count("@pytest.mark.codeflash_no_autouse") == 2 + assert code==expected_code def test_handles_different_mark_names(self): """Test that different mark names work correctly.""" source_code = ''' import pytest +def test_something(): + assert True +''' + expected_code = ''' +import pytest + +@pytest.mark.slow def test_something(): assert True ''' @@ -2385,8 +2474,7 @@ def test_something(): modified_module = module.visit(mark_adder) code = modified_module.code - assert "@pytest.mark.slow" in code - assert "codeflash_no_autouse" not in code + assert code==expected_code def test_preserves_existing_decorators(self): """Test that existing decorators are preserved.""" @@ -2395,6 +2483,15 @@ import pytest @pytest.mark.parametrize("value", [1, 2, 3]) @pytest.fixture +def test_with_decorators(): + assert True +''' + expected_code = ''' +import pytest + +@pytest.mark.parametrize("value", [1, 2, 3]) +@pytest.fixture +@pytest.mark.codeflash_no_autouse def test_with_decorators(): assert True ''' @@ -2403,9 +2500,7 @@ def test_with_decorators(): modified_module = module.visit(mark_adder) code = modified_module.code - assert "@pytest.mark.parametrize" in code - assert "@pytest.fixture" in code - assert "@pytest.mark.codeflash_no_autouse" in code + assert code==expected_code def test_handles_call_style_existing_marks(self): """Test recognition of existing marks in call style (with parentheses).""" @@ -2416,6 +2511,17 @@ import pytest def test_with_call_mark(): assert True +def test_needs_mark(): + assert True +''' + expected_code = ''' +import pytest + +@pytest.mark.codeflash_no_autouse() +def test_with_call_mark(): + assert True + +@pytest.mark.codeflash_no_autouse def test_needs_mark(): assert True ''' @@ -2425,8 +2531,7 @@ def test_needs_mark(): code = modified_module.code # Should recognize the existing call-style mark and not duplicate - lines_with_mark = [line for line in code.split('\n') if 'codeflash_no_autouse' in line] - assert len(lines_with_mark) == 2 # One existing, one added + assert code==expected_code def test_empty_module(self): """Test handling of empty module.""" @@ -2437,7 +2542,7 @@ def test_needs_mark(): # Should just add the import code = modified_module.code - assert "import pytest" in code + assert code =='import pytest' def test_module_with_only_imports(self): """Test handling of module with only imports.""" @@ -2445,16 +2550,19 @@ def test_needs_mark(): import os import sys from pathlib import Path +''' + expected_code = ''' +import pytest +import os +import sys +from pathlib import Path ''' module = cst.parse_module(source_code) mark_adder = PytestMarkAdder("codeflash_no_autouse") modified_module = module.visit(mark_adder) code = modified_module.code - assert "import pytest" in code - assert "import os" in code - assert "import sys" in code - assert "from pathlib import Path" in code + assert code==expected_code class TestIntegration: @@ -2469,6 +2577,21 @@ import pytest def my_fixture(request): yield "value" +def test_something(): + assert True +''' + expected_code = ''' +import pytest + +@pytest.fixture(autouse=True) +@pytest.mark.codeflash_no_autouse +def my_fixture(request): + if request.node.get_closest_marker("codeflash_no_autouse"): + yield + else: + yield "value" + +@pytest.mark.codeflash_no_autouse def test_something(): assert True ''' @@ -2483,7 +2606,4 @@ def test_something(): code = final_module.code # Should have both modifications - assert 'if request.node.get_closest_marker("codeflash_no_autouse"):' in code - assert "@pytest.mark.codeflash_no_autouse" in code - # Mark should be added to both functions - assert code.count("@pytest.mark.codeflash_no_autouse") == 2 + assert code==expected_code