From a65569cfa04002b15d9fa690dbd4fba51b9e64b4 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Fri, 6 Jun 2025 12:30:30 -0700 Subject: [PATCH] tests --- tests/test_code_replacement.py | 371 ++++++++++++++++++++++++++++++++- 1 file changed, 363 insertions(+), 8 deletions(-) diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 2e8c2f6fd..0d1940798 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -1,5 +1,6 @@ from __future__ import annotations - +import libcst as cst +from codeflash.code_utils.code_replacer import AutouseFixtureModifier, PytestMarkAdder import dataclasses import os from collections import defaultdict @@ -1139,8 +1140,8 @@ class TestResults(BaseModel): ) assert ( - new_code - == """from __future__ import annotations + new_code + == """from __future__ import annotations import sys from codeflash.verification.comparator import comparator from enum import Enum @@ -1345,8 +1346,8 @@ def cosine_similarity_top_k( project_root_path=Path(__file__).parent.parent.resolve(), ) assert ( - new_code - == '''import numpy as np + new_code + == '''import numpy as np from pydantic.dataclasses import dataclass from typing import List, Optional, Tuple, Union @dataclass(config=dict(arbitrary_types_allowed=True)) @@ -1404,8 +1405,8 @@ def cosine_similarity_top_k( ) assert ( - new_helper_code - == '''import numpy as np + new_helper_code + == '''import numpy as np from pydantic.dataclasses import dataclass from typing import List, Optional, Tuple, Union @dataclass(config=dict(arbitrary_types_allowed=True)) @@ -1662,6 +1663,7 @@ print("Hello world") ) assert new_code == original_code + def test_global_reassignment() -> None: original_code = """a=1 print("Hello world") @@ -2131,4 +2133,357 @@ a = 6 ) new_code = code_path.read_text(encoding="utf-8") code_path.unlink(missing_ok=True) - assert new_code.rstrip() == expected_code.rstrip() \ No newline at end of file + assert new_code.rstrip() == expected_code.rstrip() + + +class TestAutouseFixtureModifier: + """Test cases for AutouseFixtureModifier class.""" + + def test_modifies_autouse_fixture_with_pytest_decorator(self): + """Test that autouse fixture with @pytest.fixture is modified correctly.""" + source_code = ''' +import pytest + +@pytest.fixture(autouse=True) +def my_fixture(request): + print("setup") + yield + print("teardown") +''' + expected_code = ''' +import pytest + +@pytest.fixture(autouse=True) +def my_fixture(request): + if request.node.get_closest_marker("codeflash_no_autouse"): + yield + else: + print("setup") + yield + print("teardown") +''' + module = cst.parse_module(source_code) + modifier = AutouseFixtureModifier() + modified_module = module.visit(modifier) + + # Parse expected to normalize formatting + expected_module = cst.parse_module(expected_code) + assert modified_module.code.strip() == expected_module.code.strip() + + def test_modifies_autouse_fixture_with_fixture_decorator(self): + """Test that autouse fixture with @fixture is modified correctly.""" + source_code = ''' +from pytest import fixture + +@fixture(autouse=True) +def my_fixture(request): + setup_code() + yield "value" + cleanup_code() +''' + module = cst.parse_module(source_code) + modifier = AutouseFixtureModifier() + modified_module = module.visit(modifier) + + # Check that the if statement was added + assert 'if request.node.get_closest_marker("codeflash_no_autouse"):' in modified_module.code + assert "yield" in modified_module.code + assert "else:" in modified_module.code + assert "setup_code()" in modified_module.code + assert "cleanup_code()" in modified_module.code + + def test_ignores_non_autouse_fixture(self): + """Test that non-autouse fixtures are not modified.""" + source_code = ''' +import pytest + +@pytest.fixture +def my_fixture(request): + return "test_value" + +@pytest.fixture(scope="session") +def session_fixture(): + return "session_value" +''' + module = cst.parse_module(source_code) + modifier = AutouseFixtureModifier() + modified_module = module.visit(modifier) + + # Code should remain unchanged + assert modified_module.code == source_code + + def test_ignores_regular_functions(self): + """Test that regular functions are not modified.""" + source_code = ''' +def regular_function(): + return "not a fixture" + +@some_other_decorator +def decorated_function(): + return "also not a fixture" +''' + module = cst.parse_module(source_code) + modifier = AutouseFixtureModifier() + modified_module = module.visit(modifier) + + # Code should remain unchanged + assert modified_module.code == source_code + + def test_handles_multiple_autouse_fixtures(self): + """Test that multiple autouse fixtures in the same file are all modified.""" + source_code = ''' +import pytest + +@pytest.fixture(autouse=True) +def fixture_one(request): + yield "one" + +@pytest.fixture(autouse=True) +def fixture_two(request): + yield "two" +''' + module = cst.parse_module(source_code) + modifier = AutouseFixtureModifier() + modified_module = module.visit(modifier) + + # Both fixtures should be modified + code = modified_module.code + assert code.count('if request.node.get_closest_marker("codeflash_no_autouse"):') == 2 + assert code.count("else:") == 2 + + def test_preserves_fixture_with_complex_body(self): + """Test that fixtures with complex bodies are handled correctly.""" + source_code = ''' +import pytest + +@pytest.fixture(autouse=True) +def complex_fixture(request): + try: + setup_database() + configure_logging() + yield get_test_client() + finally: + cleanup_database() + reset_logging() +''' + module = cst.parse_module(source_code) + modifier = AutouseFixtureModifier() + modified_module = module.visit(modifier) + + code = modified_module.code + assert 'if request.node.get_closest_marker("codeflash_no_autouse"):' in code + assert "try:" in code + assert "setup_database()" in code + assert "finally:" in code + assert "cleanup_database()" in code + + +class TestPytestMarkAdder: + """Test cases for PytestMarkAdder class.""" + + def test_adds_pytest_import_when_missing(self): + """Test that pytest import is added when not present.""" + source_code = ''' +def test_something(): + assert True +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + code = modified_module.code + assert "import pytest" in code + assert "@pytest.mark.codeflash_no_autouse" in code + + def test_skips_pytest_import_when_present(self): + """Test that pytest import is not duplicated when already present.""" + source_code = ''' +import pytest + +def test_something(): + assert True +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + code = modified_module.code + # Should only have one import pytest line + assert code.count("import pytest") == 1 + assert "@pytest.mark.codeflash_no_autouse" in code + + def test_handles_from_pytest_import(self): + """Test that existing 'from pytest import ...' is recognized.""" + source_code = ''' +from pytest import fixture + +def test_something(): + assert True +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + code = modified_module.code + # Should not add import pytest since pytest is already imported + assert "import pytest" not in code + assert "from pytest import fixture" in code + assert "@pytest.mark.codeflash_no_autouse" in code + + def test_adds_mark_to_all_functions(self): + """Test that marks are added to all functions in the module.""" + source_code = ''' +import pytest + +def test_first(): + assert True + +def test_second(): + assert False + +def helper_function(): + return "not a test" +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + code = modified_module.code + # All functions should get the mark + assert code.count("@pytest.mark.codeflash_no_autouse") == 3 + + def test_skips_existing_mark(self): + """Test that existing marks are not duplicated.""" + source_code = ''' +import pytest + +@pytest.mark.codeflash_no_autouse +def test_already_marked(): + assert True + +def test_needs_mark(): + assert True +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + code = modified_module.code + # Should have exactly 2 marks total (one existing, one added) + assert code.count("@pytest.mark.codeflash_no_autouse") == 2 + + def test_handles_different_mark_names(self): + """Test that different mark names work correctly.""" + source_code = ''' +import pytest + +def test_something(): + assert True +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("slow") + modified_module = module.visit(mark_adder) + + code = modified_module.code + assert "@pytest.mark.slow" in code + assert "codeflash_no_autouse" not in code + + def test_preserves_existing_decorators(self): + """Test that existing decorators are preserved.""" + source_code = ''' +import pytest + +@pytest.mark.parametrize("value", [1, 2, 3]) +@pytest.fixture +def test_with_decorators(): + assert True +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + code = modified_module.code + assert "@pytest.mark.parametrize" in code + assert "@pytest.fixture" in code + assert "@pytest.mark.codeflash_no_autouse" in code + + def test_handles_call_style_existing_marks(self): + """Test recognition of existing marks in call style (with parentheses).""" + source_code = ''' +import pytest + +@pytest.mark.codeflash_no_autouse() +def test_with_call_mark(): + assert True + +def test_needs_mark(): + assert True +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + code = modified_module.code + # Should recognize the existing call-style mark and not duplicate + lines_with_mark = [line for line in code.split('\n') if 'codeflash_no_autouse' in line] + assert len(lines_with_mark) == 2 # One existing, one added + + def test_empty_module(self): + """Test handling of empty module.""" + source_code = '' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + # Should just add the import + code = modified_module.code + assert "import pytest" in code + + def test_module_with_only_imports(self): + """Test handling of module with only imports.""" + source_code = ''' +import os +import sys +from pathlib import Path +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + code = modified_module.code + assert "import pytest" in code + assert "import os" in code + assert "import sys" in code + assert "from pathlib import Path" in code + + +class TestIntegration: + """Integration tests for both transformers working together.""" + + def test_both_transformers_together(self): + """Test that both transformers can work on the same code.""" + source_code = ''' +import pytest + +@pytest.fixture(autouse=True) +def my_fixture(request): + yield "value" + +def test_something(): + assert True +''' + # First apply AutouseFixtureModifier + module = cst.parse_module(source_code) + autouse_modifier = AutouseFixtureModifier() + modified_module = module.visit(autouse_modifier) + + # Then apply PytestMarkAdder + mark_adder = PytestMarkAdder("codeflash_no_autouse") + final_module = modified_module.visit(mark_adder) + + code = final_module.code + # Should have both modifications + assert 'if request.node.get_closest_marker("codeflash_no_autouse"):' in code + assert "@pytest.mark.codeflash_no_autouse" in code + # Mark should be added to both functions + assert code.count("@pytest.mark.codeflash_no_autouse") == 2