mirror of
https://github.com/codeflash-ai/codeflash-agent.git
synced 2026-05-04 18:25:19 +00:00
* fix: resolve all ruff lint errors across repo Auto-fixed 31 errors (unused imports, formatting, simplifications). Manually fixed 14 remaining: - EXE001: removed shebangs from non-executable bench scripts - C417: replaced map(lambda) with generator expression - C901/PLR0915: extracted _write_and_instrument_tests from generate_ai_tests - C901/PLR0912: extracted _parse_toml_addopts and _ini_section_name from modify_addopts - RUF001/RUF002: replaced ambiguous Unicode chars (en dash, multiplication sign) - FBT002: made boolean params keyword-only in report functions - E402: moved `import re` to top of file in security reports * fix: resolve pre-existing mypy errors across packages - _testgen.py: annotate `generated` as `str` to avoid no-any-return - _test_runner.py: use str() for TimeoutExpired stdout/stderr (bytes|str), remove unused type: ignore on proc.kill() - _candidate_eval.py: annotate `speedup` as `float` to avoid no-any-return from lazy-loaded performance_gain
396 lines
12 KiB
Python
396 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import uuid
|
|
|
|
import pytest
|
|
|
|
from codeflash_api.repair._context import (
|
|
apply_patches_to_optimized_code,
|
|
build_test_details,
|
|
build_user_prompt,
|
|
is_valid_repair,
|
|
)
|
|
from codeflash_api.repair.schemas import (
|
|
BehaviorDiff,
|
|
BehaviorDiffScope,
|
|
CodeRepairRequest,
|
|
)
|
|
|
|
# -------------------------------------------------------------------
|
|
# Schemas
|
|
# -------------------------------------------------------------------
|
|
|
|
|
|
class TestCodeRepairRequest:
    """Tests for CodeRepairRequest schema."""

    def test_minimal_valid_request(self):
        """
        A request with required fields only deserializes.

        Omitted optional fields fall back to their defaults: the language
        is "python" and there is no rerun trace id.
        """
        req = CodeRepairRequest(
            trace_id=str(uuid.uuid4()),
            optimization_id=str(uuid.uuid4()),
            original_source_code="def f(): pass",
            modified_source_code="def f(): return 1",
            test_diffs=[],
        )
        assert "python" == req.language
        assert req.rerun_trace_id is None

    def test_full_request(self):
        """
        A request with all fields deserializes and keeps every value given.
        """
        diff = BehaviorDiff(
            scope=BehaviorDiffScope.RETURN_VALUE,
            original_value="1",
            candidate_value="2",
            original_pass=True,
            candidate_pass=False,
            test_src_code="assert f() == 1",
        )
        rerun_trace_id = str(uuid.uuid4())
        req = CodeRepairRequest(
            trace_id=str(uuid.uuid4()),
            optimization_id=str(uuid.uuid4()),
            original_source_code="def f(): return 1",
            modified_source_code="def f(): return 2",
            test_diffs=[diff],
            language="python",
            rerun_trace_id=rerun_trace_id,
        )
        assert 1 == len(req.test_diffs)
        assert BehaviorDiffScope.RETURN_VALUE == req.test_diffs[0].scope
        assert "def f(): return 1" == req.original_source_code
        assert "def f(): return 2" == req.modified_source_code
        assert "python" == req.language
        # Compare via str() so the check holds whether the schema keeps the
        # id as a plain string or coerces it to a UUID object.
        assert rerun_trace_id == str(req.rerun_trace_id)
|
|
|
|
|
|
class TestBehaviorDiff:
    """Tests for the BehaviorDiff schema."""

    def test_all_scopes(self):
        """
        Each member of BehaviorDiffScope is accepted and stored unchanged.
        """
        for member in BehaviorDiffScope:
            built = BehaviorDiff(
                scope=member,
                original_pass=True,
                candidate_pass=False,
            )
            assert built.scope == member

    def test_none_values(self):
        """
        Every optional field is None when it is not supplied.
        """
        built = BehaviorDiff(
            scope=BehaviorDiffScope.DID_PASS,
            original_pass=True,
            candidate_pass=False,
        )
        optional_fields = (
            "original_value",
            "candidate_value",
            "test_src_code",
            "candidate_pytest_error",
            "original_pytest_error",
        )
        for field_name in optional_fields:
            assert getattr(built, field_name) is None, field_name
|
|
|
|
|
|
# -------------------------------------------------------------------
|
|
# build_test_details
|
|
# -------------------------------------------------------------------
|
|
|
|
|
|
class TestBuildTestDetails:
    """Tests for build_test_details."""

    @staticmethod
    def _failing_diff(scope, **fields):
        """
        Build a BehaviorDiff whose original run passed and candidate failed.

        Only the keyword arguments actually given are forwarded, so any
        field left out keeps the schema's own default.
        """
        return BehaviorDiff(
            scope=scope,
            original_pass=True,
            candidate_pass=False,
            **fields,
        )

    def test_empty_diffs(self):
        """
        No diffs at all renders as an empty string.
        """
        assert build_test_details([]) == ""

    def test_return_value_diff(self):
        """
        A return-value diff renders Expected/Got lines plus the test source.
        """
        rendered = build_test_details(
            [
                self._failing_diff(
                    BehaviorDiffScope.RETURN_VALUE,
                    original_value="hello",
                    candidate_value="world",
                    test_src_code="assert f() == 'hello'",
                )
            ]
        )
        for fragment in (
            "Expected: 'hello'",
            "Got: 'world'",
            "assert f() == 'hello'",
        ):
            assert fragment in rendered

    def test_did_pass_diff(self):
        """
        A DID_PASS diff reports the pass/fail status of both runs.
        """
        rendered = build_test_details(
            [
                self._failing_diff(
                    BehaviorDiffScope.DID_PASS,
                    test_src_code="test_something",
                )
            ]
        )
        assert "Passed" in rendered
        assert "Failed" in rendered

    def test_pytest_errors_included(self):
        """
        A candidate pytest error is surfaced in the rendered output.
        """
        diff = self._failing_diff(
            BehaviorDiffScope.RETURN_VALUE,
            original_value=1,
            candidate_value=2,
            test_src_code="test_func",
            candidate_pytest_error="AssertionError: 1 != 2",
            original_pytest_error="",
        )
        assert "Pytest error (optimized)" in build_test_details([diff])

    def test_multiple_diffs_same_test(self):
        """
        Diffs that share one test source are grouped under a single header.
        """
        shared_source = "test_func"
        diffs = [
            self._failing_diff(
                BehaviorDiffScope.RETURN_VALUE,
                original_value=1,
                candidate_value=2,
                test_src_code=shared_source,
            ),
            self._failing_diff(
                BehaviorDiffScope.STDOUT,
                original_value="out",
                candidate_value="err",
                test_src_code=shared_source,
            ),
        ]
        assert build_test_details(diffs).count("**Test source:**") == 1

    def test_multiple_diffs_different_tests(self):
        """
        Each distinct test source gets its own separated section.
        """
        diffs = [
            self._failing_diff(
                BehaviorDiffScope.RETURN_VALUE,
                original_value=before,
                candidate_value=after,
                test_src_code=source,
            )
            for source, before, after in (("test_a", 1, 2), ("test_b", 3, 4))
        ]
        rendered = build_test_details(diffs)
        assert "---" in rendered
        assert rendered.count("**Test source:**") == 2

    def test_non_python_language(self):
        """
        For a non-python language the error label reads 'Test error'.
        """
        diff = self._failing_diff(
            BehaviorDiffScope.RETURN_VALUE,
            original_value=1,
            candidate_value=2,
            test_src_code="test_func",
            candidate_pytest_error="Error in test",
        )
        rendered = build_test_details([diff], language="javascript")
        assert "Test error (optimized)" in rendered

    def test_no_test_source(self):
        """
        A diff without test source code is rendered as 'Not available'.
        """
        diff = self._failing_diff(
            BehaviorDiffScope.RETURN_VALUE,
            original_value=1,
            candidate_value=2,
        )
        assert "Not available" in build_test_details([diff])
|
|
|
|
|
|
# -------------------------------------------------------------------
|
|
# build_user_prompt
|
|
# -------------------------------------------------------------------
|
|
|
|
|
|
class TestBuildUserPrompt:
    """Tests for build_user_prompt."""

    def test_placeholders_filled(self):
        """
        All placeholders in the template are filled.

        Check both halves of that claim: each substituted value appears in
        the output, and none of the raw ``{placeholder}`` tokens survive.
        """
        template = (
            "Original: {original_source_code}\n"
            "Modified: {modified_source_code}\n"
            "Tests: {test_details}"
        )
        diff = BehaviorDiff(
            scope=BehaviorDiffScope.RETURN_VALUE,
            original_value=1,
            candidate_value=2,
            original_pass=True,
            candidate_pass=False,
        )
        result = build_user_prompt(
            template,
            original_source_code="def f(): return 1",
            modified_source_code="def f(): return 2",
            test_diffs=[diff],
        )
        assert "def f(): return 1" in result
        assert "def f(): return 2" in result
        assert "Expected: 1" in result
        # A template token left in the output means it was never filled.
        for placeholder in (
            "{original_source_code}",
            "{modified_source_code}",
            "{test_details}",
        ):
            assert placeholder not in result

    def test_empty_diffs(self):
        """
        Empty test_diffs fills test_details as empty.
        """
        template = "Tests: [{test_details}]"
        result = build_user_prompt(
            template,
            original_source_code="",
            modified_source_code="",
            test_diffs=[],
        )
        assert "Tests: []" == result
|
|
|
|
|
|
# -------------------------------------------------------------------
|
|
# apply_patches_to_optimized_code
|
|
# -------------------------------------------------------------------
|
|
|
|
|
|
class TestApplyPatches:
    """Tests for apply_patches_to_optimized_code."""

    @staticmethod
    def _patch_response(path):
        """
        Build an LLM response holding one SEARCH/REPLACE block for *path*.

        The block replaces the line ``" return 2"`` with ``" return 1"``.
        """
        return (
            "<replace_in_file>\n"
            f"<path>{path}</path>\n"
            "<diff>\n"
            "<<<<<<< SEARCH\n"
            " return 2\n"
            "=======\n"
            " return 1\n"
            ">>>>>>> REPLACE\n"
            "</diff>\n"
            "</replace_in_file>"
        )

    def test_simple_patch(self):
        """
        A simple SEARCH/REPLACE block patches the code.

        The searched text must be replaced, not merely supplemented, so the
        old line has to be gone from the result.
        """
        modified = "def f():\n return 2\n"
        result = apply_patches_to_optimized_code(
            self._patch_response("file.py"), modified
        )
        assert "return 1" in result
        assert "return 2" not in result

    def test_no_patches_returns_original(self):
        """
        When LLM response has no patches, original code is returned.
        """
        modified = "def f(): return 2"
        result = apply_patches_to_optimized_code("No patches here", modified)
        assert "def f(): return 2" in result

    def test_markdown_input(self):
        """
        Markdown-wrapped source code is handled correctly.

        The patch lands inside the fenced block and replaces the old line.
        """
        modified = "```python:main.py\ndef f():\n return 2\n```"
        result = apply_patches_to_optimized_code(
            self._patch_response("main.py"), modified
        )
        assert "return 1" in result
        assert "return 2" not in result
|
|
|
|
|
|
# -------------------------------------------------------------------
|
|
# is_valid_repair
|
|
# -------------------------------------------------------------------
|
|
|
|
|
|
class TestIsValidRepair:
    """Tests for is_valid_repair."""

    def test_valid_python(self):
        """
        Syntactically valid Python is accepted.
        """
        source = "def f():\n return 1\n"
        assert is_valid_repair(source, "")

    def test_invalid_python(self):
        """
        Code containing a syntax error is rejected.
        """
        broken = "def f(\n"
        assert not is_valid_repair(broken, "")

    def test_empty_code(self):
        """
        Empty or whitespace-only code is rejected.
        """
        for blank in ("", " "):
            assert not is_valid_repair(blank, "")

    def test_markdown_valid(self):
        """
        Well-formed markdown-wrapped code is accepted.
        """
        wrapped = "```python:main.py\ndef f():\n return 1\n```"
        assert is_valid_repair(wrapped, wrapped)

    def test_markdown_structure_changed(self):
        """
        A repair that renames the wrapped file path is rejected.
        """
        before = "```python:main.py\ndef f(): pass\n```"
        after = "```python:other.py\ndef f(): pass\n```"
        assert not is_valid_repair(after, before)

    def test_non_python_skips_cst(self):
        """
        A non-python language bypasses the libcst syntax check.
        """
        js_source = "function f() { return 1; }"
        assert is_valid_repair(js_source, "", language="javascript")
|