codeflash-agent/packages/codeflash-api/tests/test_repair.py
Kevin Turcios 3ee9c22c8e
fix: resolve all ruff lint errors across repo (#38)
* fix: resolve all ruff lint errors across repo

Auto-fixed 31 errors (unused imports, formatting, simplifications).
Manually fixed 14 remaining:
- EXE001: removed shebangs from non-executable bench scripts
- C417: replaced map(lambda) with generator expression
- C901/PLR0915: extracted _write_and_instrument_tests from generate_ai_tests
- C901/PLR0912: extracted _parse_toml_addopts and _ini_section_name from modify_addopts
- RUF001/RUF002: replaced ambiguous Unicode chars (en dash, multiplication sign)
- FBT002: made boolean params keyword-only in report functions
- E402: moved `import re` to top of file in security reports

* fix: resolve pre-existing mypy errors across packages

- _testgen.py: annotate `generated` as `str` to avoid no-any-return
- _test_runner.py: use str() for TimeoutExpired stdout/stderr (bytes|str),
  remove unused type: ignore on proc.kill()
- _candidate_eval.py: annotate `speedup` as `float` to avoid no-any-return
  from lazy-loaded performance_gain
2026-04-23 10:22:42 -05:00

396 lines
12 KiB
Python

from __future__ import annotations
import uuid
import pytest
from codeflash_api.repair._context import (
apply_patches_to_optimized_code,
build_test_details,
build_user_prompt,
is_valid_repair,
)
from codeflash_api.repair.schemas import (
BehaviorDiff,
BehaviorDiffScope,
CodeRepairRequest,
)
# -------------------------------------------------------------------
# Schemas
# -------------------------------------------------------------------
class TestCodeRepairRequest:
    """Tests for CodeRepairRequest schema."""

    def test_minimal_valid_request(self):
        """
        A request with required fields only deserializes.

        The omitted optional fields take their defaults: ``language``
        is ``"python"`` and ``rerun_trace_id`` is None.
        """
        req = CodeRepairRequest(
            trace_id=str(uuid.uuid4()),
            optimization_id=str(uuid.uuid4()),
            original_source_code="def f(): pass",
            modified_source_code="def f(): return 1",
            test_diffs=[],
        )
        assert "python" == req.language
        assert req.rerun_trace_id is None

    def test_full_request(self):
        """
        A request with all fields deserializes.

        Checks that the explicitly supplied optional fields and the
        nested diff are kept, not just the length of ``test_diffs``.
        """
        diff = BehaviorDiff(
            scope=BehaviorDiffScope.RETURN_VALUE,
            original_value="1",
            candidate_value="2",
            original_pass=True,
            candidate_pass=False,
            test_src_code="assert f() == 1",
        )
        rerun_trace_id = str(uuid.uuid4())
        req = CodeRepairRequest(
            trace_id=str(uuid.uuid4()),
            optimization_id=str(uuid.uuid4()),
            original_source_code="def f(): return 1",
            modified_source_code="def f(): return 2",
            test_diffs=[diff],
            language="python",
            rerun_trace_id=rerun_trace_id,
        )
        assert 1 == len(req.test_diffs)
        assert BehaviorDiffScope.RETURN_VALUE == req.test_diffs[0].scope
        assert "assert f() == 1" == req.test_diffs[0].test_src_code
        assert "def f(): return 1" == req.original_source_code
        assert "def f(): return 2" == req.modified_source_code
        assert "python" == req.language
        # Compare via str() so the check holds whether the schema keeps
        # the id as a str or coerces it to uuid.UUID (whose str() is the
        # same canonical form that str(uuid.uuid4()) produced).
        assert rerun_trace_id == str(req.rerun_trace_id)
class TestBehaviorDiff:
    """Tests for BehaviorDiff schema."""

    @pytest.mark.parametrize("scope", list(BehaviorDiffScope))
    def test_all_scopes(self, scope):
        """
        Every scope enum value can be used.

        Parametrized so each scope is reported as its own test case;
        a single failing scope no longer hides the others.
        """
        diff = BehaviorDiff(
            scope=scope,
            original_pass=True,
            candidate_pass=False,
        )
        assert scope == diff.scope

    def test_none_values(self):
        """
        Optional fields default to None.
        """
        diff = BehaviorDiff(
            scope=BehaviorDiffScope.DID_PASS,
            original_pass=True,
            candidate_pass=False,
        )
        assert diff.original_value is None
        assert diff.candidate_value is None
        assert diff.test_src_code is None
        assert diff.candidate_pytest_error is None
        assert diff.original_pytest_error is None
# -------------------------------------------------------------------
# build_test_details
# -------------------------------------------------------------------
class TestBuildTestDetails:
    """Tests for build_test_details."""

    def test_empty_diffs(self):
        """
        Empty test_diffs produces empty string.
        """
        assert "" == build_test_details([])

    def test_return_value_diff(self):
        """
        Return value diffs show Expected/Got lines.
        """
        diff = BehaviorDiff(
            scope=BehaviorDiffScope.RETURN_VALUE,
            original_value="hello",
            candidate_value="world",
            original_pass=True,
            candidate_pass=False,
            test_src_code="assert f() == 'hello'",
        )
        result = build_test_details([diff])
        assert "Expected: 'hello'" in result
        assert "Got: 'world'" in result
        assert "assert f() == 'hello'" in result

    def test_did_pass_diff(self):
        """
        DID_PASS diffs show pass/fail status.
        """
        diff = BehaviorDiff(
            scope=BehaviorDiffScope.DID_PASS,
            original_pass=True,
            candidate_pass=False,
            test_src_code="test_something",
        )
        result = build_test_details([diff])
        assert "Passed" in result
        assert "Failed" in result

    def test_pytest_errors_included(self):
        """
        Pytest errors appear in the output when present.

        Checks both the label and the error text itself; checking only
        the label would not catch the message being dropped.
        """
        diff = BehaviorDiff(
            scope=BehaviorDiffScope.RETURN_VALUE,
            original_value=1,
            candidate_value=2,
            original_pass=True,
            candidate_pass=False,
            test_src_code="test_func",
            candidate_pytest_error="AssertionError: 1 != 2",
            original_pytest_error="",
        )
        result = build_test_details([diff])
        assert "Pytest error (optimized)" in result
        assert "AssertionError: 1 != 2" in result

    def test_multiple_diffs_same_test(self):
        """
        Multiple diffs for the same test source share a header.
        """
        diffs = [
            BehaviorDiff(
                scope=BehaviorDiffScope.RETURN_VALUE,
                original_value=1,
                candidate_value=2,
                original_pass=True,
                candidate_pass=False,
                test_src_code="test_func",
            ),
            BehaviorDiff(
                scope=BehaviorDiffScope.STDOUT,
                original_value="out",
                candidate_value="err",
                original_pass=True,
                candidate_pass=False,
                test_src_code="test_func",
            ),
        ]
        result = build_test_details(diffs)
        assert result.count("**Test source:**") == 1

    def test_multiple_diffs_different_tests(self):
        """
        Different test sources get separate sections.
        """
        diffs = [
            BehaviorDiff(
                scope=BehaviorDiffScope.RETURN_VALUE,
                original_value=1,
                candidate_value=2,
                original_pass=True,
                candidate_pass=False,
                test_src_code="test_a",
            ),
            BehaviorDiff(
                scope=BehaviorDiffScope.RETURN_VALUE,
                original_value=3,
                candidate_value=4,
                original_pass=True,
                candidate_pass=False,
                test_src_code="test_b",
            ),
        ]
        result = build_test_details(diffs)
        assert "---" in result
        assert result.count("**Test source:**") == 2

    def test_non_python_language(self):
        """
        Non-python language uses 'Test error' label.

        Also checks that the Python-specific 'Pytest error' label is
        absent and that the error text is still reported.
        """
        diff = BehaviorDiff(
            scope=BehaviorDiffScope.RETURN_VALUE,
            original_value=1,
            candidate_value=2,
            original_pass=True,
            candidate_pass=False,
            test_src_code="test_func",
            candidate_pytest_error="Error in test",
        )
        result = build_test_details([diff], language="javascript")
        assert "Test error (optimized)" in result
        assert "Pytest error" not in result
        assert "Error in test" in result

    def test_no_test_source(self):
        """
        Missing test source shows 'Not available'.
        """
        diff = BehaviorDiff(
            scope=BehaviorDiffScope.RETURN_VALUE,
            original_value=1,
            candidate_value=2,
            original_pass=True,
            candidate_pass=False,
        )
        result = build_test_details([diff])
        assert "Not available" in result
# -------------------------------------------------------------------
# build_user_prompt
# -------------------------------------------------------------------
class TestBuildUserPrompt:
    """Tests for build_user_prompt."""

    def test_placeholders_filled(self):
        """
        All placeholders in the template are filled.

        Checks that the substituted values appear and that no raw
        ``{placeholder}`` text is left behind in the rendered prompt.
        """
        template = (
            "Original: {original_source_code}\n"
            "Modified: {modified_source_code}\n"
            "Tests: {test_details}"
        )
        diff = BehaviorDiff(
            scope=BehaviorDiffScope.RETURN_VALUE,
            original_value=1,
            candidate_value=2,
            original_pass=True,
            candidate_pass=False,
        )
        result = build_user_prompt(
            template,
            original_source_code="def f(): return 1",
            modified_source_code="def f(): return 2",
            test_diffs=[diff],
        )
        assert "def f(): return 1" in result
        assert "def f(): return 2" in result
        assert "Expected: 1" in result
        for placeholder in (
            "{original_source_code}",
            "{modified_source_code}",
            "{test_details}",
        ):
            assert placeholder not in result

    def test_empty_diffs(self):
        """
        Empty test_diffs fills test_details as empty.
        """
        template = "Tests: [{test_details}]"
        result = build_user_prompt(
            template,
            original_source_code="",
            modified_source_code="",
            test_diffs=[],
        )
        assert "Tests: []" == result
# -------------------------------------------------------------------
# apply_patches_to_optimized_code
# -------------------------------------------------------------------
class TestApplyPatches:
    """Tests for apply_patches_to_optimized_code."""

    def test_simple_patch(self):
        """
        A simple SEARCH/REPLACE block patches the code.

        The searched text must be replaced, not merely joined by the new
        text, so the old ``return 2`` line must be gone.
        """
        modified = "def f():\n return 2\n"
        llm_response = (
            "<replace_in_file>\n"
            "<path>file.py</path>\n"
            "<diff>\n"
            "<<<<<<< SEARCH\n"
            " return 2\n"
            "=======\n"
            " return 1\n"
            ">>>>>>> REPLACE\n"
            "</diff>\n"
            "</replace_in_file>"
        )
        result = apply_patches_to_optimized_code(llm_response, modified)
        assert "return 1" in result
        assert "return 2" not in result

    def test_no_patches_returns_original(self):
        """
        When LLM response has no patches, original code is returned.
        """
        modified = "def f(): return 2"
        result = apply_patches_to_optimized_code("No patches here", modified)
        assert "def f(): return 2" in result

    def test_markdown_input(self):
        """
        Markdown-wrapped source code is handled correctly.

        The patch is applied inside the fenced block, replacing the
        searched ``return 2`` line rather than adding next to it.
        """
        modified = "```python:main.py\ndef f():\n return 2\n```"
        llm_response = (
            "<replace_in_file>\n"
            "<path>main.py</path>\n"
            "<diff>\n"
            "<<<<<<< SEARCH\n"
            " return 2\n"
            "=======\n"
            " return 1\n"
            ">>>>>>> REPLACE\n"
            "</diff>\n"
            "</replace_in_file>"
        )
        result = apply_patches_to_optimized_code(llm_response, modified)
        assert "return 1" in result
        assert "return 2" not in result
# -------------------------------------------------------------------
# is_valid_repair
# -------------------------------------------------------------------
class TestIsValidRepair:
    """Tests for is_valid_repair."""

    def test_valid_python(self):
        """
        Source that parses as Python is accepted.
        """
        source = "def f():\n return 1\n"
        assert is_valid_repair(source, "")

    def test_invalid_python(self):
        """
        Source with a syntax error is rejected.
        """
        broken = "def f(\n"
        assert not is_valid_repair(broken, "")

    def test_empty_code(self):
        """
        Empty and whitespace-only source are both rejected.
        """
        for blank in ("", " "):
            assert not is_valid_repair(blank, "")

    def test_markdown_valid(self):
        """
        A fenced markdown block compared against itself is accepted.
        """
        fenced = "```python:main.py\ndef f():\n return 1\n```"
        assert is_valid_repair(fenced, fenced)

    def test_markdown_structure_changed(self):
        """
        A fence whose file path differs from the previous version is rejected.
        """
        before = "```python:main.py\ndef f(): pass\n```"
        after = "```python:other.py\ndef f(): pass\n```"
        assert not is_valid_repair(after, before)

    def test_non_python_skips_cst(self):
        """
        For a non-Python language the libcst check is skipped, so JS source
        is accepted.
        """
        js_source = "function f() { return 1; }"
        assert is_valid_repair(js_source, "", language="javascript")