codeflash/tests/test_formatter.py
Kevin Turcios 5cd94cdf64 round 1
2024-10-12 17:29:15 -05:00

197 lines
4.4 KiB
Python

import os
import tempfile
from pathlib import Path
import pytest
from codeflash.code_utils.config_parser import parse_config_file
from codeflash.code_utils.formatter import format_code, sort_imports
def test_remove_duplicate_imports():
"""Test that duplicate imports are removed when should_sort_imports is True."""
original_code = "import os\nimport os\n"
new_code = sort_imports(original_code)
assert new_code == "import os\n"
def test_remove_multiple_duplicate_imports():
"""Test that multiple duplicate imports are removed when should_sort_imports is True."""
original_code = "import sys\nimport os\nimport sys\n"
new_code = sort_imports(original_code)
assert new_code == "import os\nimport sys\n"
def test_sorting_imports():
"""Test that imports are sorted when should_sort_imports is True."""
original_code = "import sys\nimport unittest\nimport os\n"
new_code = sort_imports(original_code)
assert new_code == "import os\nimport sys\nimport unittest\n"
def test_sort_imports_without_formatting():
"""Test that imports are sorted when formatting is disabled and should_sort_imports is True."""
with tempfile.NamedTemporaryFile() as tmp:
tmp.write(b"import sys\nimport unittest\nimport os\n")
tmp.flush()
tmp_path = Path(tmp.name)
new_code = format_code(
formatter_cmds=["disabled"],
path=tmp_path,
)
assert new_code is not None
new_code = sort_imports(new_code)
assert new_code == "import os\nimport sys\nimport unittest\n"
def test_dedup_and_sort_imports_deduplicates():
original_code = """
import os
import sys
def foo():
return os.path.join(sys.path[0], 'bar')
"""
expected = """
import os
import sys
def foo():
return os.path.join(sys.path[0], 'bar')
"""
actual = sort_imports(original_code)
assert actual == expected
def test_dedup_and_sort_imports_sorts_and_deduplicates():
original_code = """
import os
import sys
import json
import os
def foo():
return os.path.join(sys.path[0], 'bar')
"""
expected = """
import json
import os
import sys
def foo():
return os.path.join(sys.path[0], 'bar')
"""
actual = sort_imports(original_code)
assert actual == expected
def test_formatter_cmds_non_existent():
"""Test that default formatter-cmds is used when it doesn't exist in the toml."""
config_data = """
[tool.codeflash]
module-root = "src"
tests-root = "tests"
test-framework = "pytest"
ignore-paths = []
"""
with tempfile.NamedTemporaryFile(suffix=".toml", delete=False) as tmp:
tmp.write(config_data.encode())
tmp.flush()
tmp_path = Path(tmp.name)
try:
config, _ = parse_config_file(tmp_path)
assert config["formatter_cmds"] == ["black $file"]
finally:
os.remove(tmp_path)
original_code = b"""
import os
import sys
def foo():
return os.path.join(sys.path[0], 'bar')"""
expected = """import os
import sys
def foo():
return os.path.join(sys.path[0], "bar")
"""
with tempfile.NamedTemporaryFile() as tmp:
tmp.write(original_code)
tmp.flush()
tmp_path = tmp.name
actual = format_code(
formatter_cmds=["black $file"],
path=Path(tmp_path),
)
assert actual == expected
def test_formatter_black():
original_code = b"""
import os
import sys
def foo():
return os.path.join(sys.path[0], 'bar')"""
expected = """import os
import sys
def foo():
return os.path.join(sys.path[0], "bar")
"""
with tempfile.NamedTemporaryFile() as tmp:
tmp.write(original_code)
tmp.flush()
tmp_path = tmp.name
actual = format_code(
formatter_cmds=["black $file"],
path=Path(tmp_path),
)
assert actual == expected
def test_formatter_ruff():
try:
import ruff
except ImportError:
pytest.skip("ruff is not installed")
original_code = b"""
import os
import sys
def foo():
return os.path.join(sys.path[0], 'bar')"""
expected = """import os
import sys
def foo():
return os.path.join(sys.path[0], "bar")
"""
with tempfile.NamedTemporaryFile(suffix=".py") as tmp:
tmp.write(original_code)
tmp.flush()
tmp_path = tmp.name
actual = format_code(
formatter_cmds=["ruff check --exit-zero --fix $file", "ruff format $file"],
path=Path(tmp_path),
)
assert actual == expected