codeflash-agent/packages/codeflash-api/tests/test_refinement.py
Kevin Turcios 935c6f229e Add remaining endpoints: repair, refinement, adaptive, explain, review, ranking, jit, workflow, testgen, log_features
Port all P1 endpoints from the Django aiservice to FastAPI:

- repair: 2-attempt LLM retry, SearchAndReplaceDiff patch application
- refinement: parallel LLM calls via asyncio.gather, single/multi-file
  context dispatch, XML explanation extraction, deduplication
- adaptive: single LLM call with previous candidate history
- explain: conditional throughput/concurrency/acceptance prompt sections,
  XML <explain> tag extraction
- review: 4-dimension scoring, JSON code block extraction, 2-attempt retry
- ranking: 4-dimension weighted scoring, JSON extraction with 3 fallbacks
  (direct parse, markdown block, brace matching), legacy XML fallback
- jit: reuses optimize pipeline with JIT-specific prompts
- workflow: 3-tier regex YAML extraction, LLM-generated CI steps
- testgen: stub returning 501 (language-specific logic deferred)
- log_features: trace_id validation, DB write stubbed

Also adds:
- Task-specific model assignments in llm/_models.py
- XML tag extraction utility in languages/python/_xml.py
- All 11 routers registered in _app.py

348 tests passing, all lint clean.
2026-04-21 22:36:31 -05:00

459 lines
15 KiB
Python

from __future__ import annotations
import uuid
import libcst as cst
import pytest
from codeflash_api.refinement._context import (
BaseRefinerContext,
MultiRefinerContext,
RefinementContextData,
SingleRefinerContext,
)
from codeflash_api.refinement.schemas import (
RefinementErrorResponse,
RefinementRequestItem,
RefinementResponse,
RefinementResponseItem,
)
def _make_ctx_data(**overrides: object) -> RefinementContextData:
    """
    Create a RefinementContextData for use in tests.

    Every field starts from a baseline value; any keyword argument
    replaces the baseline entry of the same name before the object is
    constructed.
    """
    baseline: dict[str, object] = dict(
        original_source_code="def f():\n return 1\n",
        original_line_profiler_results=None,
        original_code_runtime="100ms",
        read_only_dependency_code=None,
        optimized_source_code="def f():\n return 2\n",
        optimized_line_profiler_results=None,
        optimized_code_runtime="50ms",
        speedup="2x",
        optimized_explanation="Changed return value",
        python_version="3.12.0",
        function_references=None,
        language="python",
        language_version=None,
    )
    # Later keys win, so caller-supplied overrides shadow the baseline.
    return RefinementContextData(**{**baseline, **overrides})
# ── Schemas ─────────────────────────────────────────────────────
class TestRefinementRequestItem:
    """Schema checks for RefinementRequestItem."""

    def test_minimal_fields(self) -> None:
        """
        The four fields supplied here are enough to build an item;
        language and speedup fall back to their defaults.
        """
        trace_id = str(uuid.uuid4())
        optimization_id = str(uuid.uuid4())
        request = RefinementRequestItem(
            trace_id=trace_id,
            optimization_id=optimization_id,
            original_source_code="def f(): pass",
            optimized_source_code="def f(): return 1",
        )
        assert request.language == "python"
        assert request.speedup == ""

    def test_all_fields(self) -> None:
        """
        Every optional field is accepted alongside the required ones.
        """
        payload = {
            "trace_id": str(uuid.uuid4()),
            "optimization_id": str(uuid.uuid4()),
            "original_source_code": "def f(): pass",
            "optimized_source_code": "def f(): return 1",
            "original_line_profiler_results": "profiler data",
            "read_only_dependency_code": "import os",
            "optimized_line_profiler_results": "opt profiler",
            "optimized_explanation": "faster",
            "original_code_runtime": "100ms",
            "optimized_code_runtime": "50ms",
            "speedup": "2x",
            "python_version": "3.12.0",
            "function_references": "called in loop",
            "call_sequence": 1,
            "language": "python",
            "language_version": "3.12.0",
            "rerun_trace_id": str(uuid.uuid4()),
        }
        request = RefinementRequestItem(**payload)
        assert request.call_sequence == 1
        assert request.language == "python"
class TestRefinementResponseItem:
    """Schema checks for RefinementResponseItem."""

    def test_required_fields(self) -> None:
        """
        An item built from source_code, explanation, and optimization_id
        has no parent_id.
        """
        optimization_id = str(uuid.uuid4())
        response_item = RefinementResponseItem(
            source_code="def f(): return 1",
            explanation="optimized",
            optimization_id=optimization_id,
        )
        assert response_item.parent_id is None

    def test_with_parent_id(self) -> None:
        """
        A parent_id passed at construction is stored unchanged.
        """
        parent_id = str(uuid.uuid4())
        optimization_id = str(uuid.uuid4())
        response_item = RefinementResponseItem(
            source_code="def f(): return 1",
            explanation="optimized",
            optimization_id=optimization_id,
            parent_id=parent_id,
        )
        assert response_item.parent_id == parent_id
class TestRefinementResponse:
    """Schema checks for RefinementResponse."""

    def test_empty_list(self) -> None:
        """
        A response may carry zero refinements.
        """
        response = RefinementResponse(refinements=[])
        assert response.refinements == []

    def test_with_items(self) -> None:
        """
        A response keeps every item it was given.
        """
        refinements = []
        for i in range(3):
            refinements.append(
                RefinementResponseItem(
                    source_code=f"def f(): return {i}",
                    explanation=f"opt {i}",
                    optimization_id=str(uuid.uuid4()),
                )
            )
        response = RefinementResponse(refinements=refinements)
        assert len(response.refinements) == 3
class TestRefinementErrorResponse:
    """Schema checks for RefinementErrorResponse."""

    def test_error_field(self) -> None:
        """
        The error text passed in is readable back from the response.
        """
        message = "something broke"
        response = RefinementErrorResponse(error=message)
        assert response.error == "something broke"
# ── Context Data ────────────────────────────────────────────────
class TestRefinementContextData:
    """Checks for the RefinementContextData container."""

    def test_frozen(self) -> None:
        """
        Assigning to a field after construction raises AttributeError.
        """
        frozen = _make_ctx_data()
        with pytest.raises(AttributeError):
            frozen.speedup = "3x"  # type: ignore[misc]

    def test_defaults(self) -> None:
        """
        Fields omitted at construction take their default values.
        """
        required_only = {
            "original_source_code": "def f(): pass",
            "original_line_profiler_results": None,
            "original_code_runtime": "100ms",
            "read_only_dependency_code": None,
            "optimized_source_code": "def f(): return 1",
            "optimized_line_profiler_results": None,
            "optimized_code_runtime": "50ms",
            "speedup": "2x",
        }
        ctx_data = RefinementContextData(**required_only)
        assert ctx_data.optimized_explanation == ""
        assert ctx_data.language == "python"
        assert ctx_data.language_version is None
# ── Context Factory ─────────────────────────────────────────────
class TestBaseRefinerContextFactory:
    """Checks for the BaseRefinerContext.get_dynamic_context factory."""

    def test_single_file_returns_single_context(self) -> None:
        """
        Unfenced Python source produces a SingleRefinerContext.
        """
        plain_src = "def f():\n return 2\n"
        ctx_data = _make_ctx_data(optimized_source_code=plain_src)
        built = BaseRefinerContext.get_dynamic_context(
            ctx_data=ctx_data,
            base_system_prompt="system",
            base_user_prompt="user",
        )
        assert isinstance(built, SingleRefinerContext)

    def test_multi_file_returns_multi_context(self) -> None:
        """
        Source made of two fenced, path-tagged blocks produces a
        MultiRefinerContext.
        """
        fenced_src = (
            "```python:foo.py\ndef f():\n return 2\n```\n"
            "```python:bar.py\ndef g():\n return 3\n```"
        )
        ctx_data = _make_ctx_data(optimized_source_code=fenced_src)
        built = BaseRefinerContext.get_dynamic_context(
            ctx_data=ctx_data,
            base_system_prompt="system",
            base_user_prompt="user",
        )
        assert isinstance(built, MultiRefinerContext)
# ── User Prompt ─────────────────────────────────────────────────
class TestSingleRefinerContextPrompt:
    """Checks for how SingleRefinerContext renders the user prompt."""

    def test_user_prompt_contains_source(self) -> None:
        """
        The rendered user prompt contains the original source code.
        """
        # Placeholders are joined with single spaces into one template.
        template = " ".join(
            (
                "{original_source_code}",
                "{optimized_source_code}",
                "{original_line_profiler_results}",
                "{optimized_line_profiler_results}",
                "{optimized_explanation}",
                "{original_code_runtime}",
                "{optimized_code_runtime}",
                "{speedup}",
                "{read_only_dependency_code}",
                "{python_version}",
                "{function_references}",
            )
        )
        ctx_data = _make_ctx_data(original_source_code="def hello(): pass")
        refiner = SingleRefinerContext(
            ctx_data=ctx_data,
            base_system_prompt="sys",
            base_user_prompt=template,
        )
        rendered = refiner.get_user_prompt()
        assert "def hello(): pass" in rendered

    def test_none_profiler_results_default(self) -> None:
        """
        With both profiler results set to None, the rendered prompt
        contains the placeholder text.
        """
        # Same placeholders as above, with the profiler fields first.
        template = " ".join(
            (
                "{original_line_profiler_results}",
                "{optimized_line_profiler_results}",
                "{original_source_code}",
                "{optimized_source_code}",
                "{optimized_explanation}",
                "{original_code_runtime}",
                "{optimized_code_runtime}",
                "{speedup}",
                "{read_only_dependency_code}",
                "{python_version}",
                "{function_references}",
            )
        )
        ctx_data = _make_ctx_data(
            original_line_profiler_results=None,
            optimized_line_profiler_results=None,
        )
        refiner = SingleRefinerContext(
            ctx_data=ctx_data,
            base_system_prompt="sys",
            base_user_prompt=template,
        )
        rendered = refiner.get_user_prompt()
        assert "[No profiler results available]" in rendered
# ── Validation ──────────────────────────────────────────────────
class TestSingleRefinerContextValidation:
    """Checks for the validation methods of SingleRefinerContext."""

    @staticmethod
    def _context_for(**overrides: object) -> SingleRefinerContext:
        """
        Build a SingleRefinerContext over default context data, with any
        given fields overridden.
        """
        return SingleRefinerContext(
            ctx_data=_make_ctx_data(**overrides),
            base_system_prompt="sys",
            base_user_prompt="user",
        )

    def test_valid_refinement_different_code(self) -> None:
        """
        Parseable code that differs from the optimized source is accepted.
        """
        refiner = self._context_for(
            optimized_source_code="def f():\n return 1\n"
        )
        assert refiner.is_valid_refinement("def f():\n return 42\n")

    def test_invalid_refinement_empty(self) -> None:
        """
        An empty string is rejected.
        """
        refiner = self._context_for()
        assert not refiner.is_valid_refinement("")

    def test_invalid_refinement_same_as_optimized(self) -> None:
        """
        Code equal to the optimized source is rejected.
        """
        unchanged = "def f():\n return 2\n"
        refiner = self._context_for(optimized_source_code=unchanged)
        assert not refiner.is_valid_refinement(unchanged)

    def test_invalid_refinement_syntax_error(self) -> None:
        """
        Code that does not parse is rejected.
        """
        refiner = self._context_for()
        assert not refiner.is_valid_refinement("def f(\n")

    def test_validate_code_syntax_valid(self) -> None:
        """
        Syntax validation of well-formed code completes without raising.
        """
        refiner = self._context_for()
        refiner.validate_code_syntax("def f():\n return 1\n")

    def test_validate_code_syntax_invalid(self) -> None:
        """
        Syntax validation of malformed code raises cst.ParserSyntaxError.
        """
        refiner = self._context_for()
        with pytest.raises(cst.ParserSyntaxError):
            refiner.validate_code_syntax("def f(\n")
# ── Patch Application ───────────────────────────────────────────
class TestSingleRefinerContextPatches:
    """Checks for SingleRefinerContext.apply_patches_to_optimized_code."""

    def test_apply_simple_patch(self) -> None:
        """
        One SEARCH/REPLACE block rewrites the matching line of the
        optimized code.
        """
        ctx_data = _make_ctx_data(
            optimized_source_code="def f():\n return 2\n"
        )
        refiner = SingleRefinerContext(
            ctx_data=ctx_data,
            base_system_prompt="sys",
            base_user_prompt="user",
        )
        # One line per element; joined with newlines, no trailing newline.
        patch_lines = [
            "<replace_in_file>",
            "<path>file</path>",
            "<diff>",
            "<<<<<<< SEARCH",
            " return 2",
            "=======",
            " return 42",
            ">>>>>>> REPLACE",
            "</diff>",
            "</replace_in_file>",
        ]
        patched = refiner.apply_patches_to_optimized_code(
            "\n".join(patch_lines)
        )
        assert "return 42" in patched
        assert "return 2" not in patched
class TestMultiRefinerContextValidation:
    """Checks for MultiRefinerContext.is_valid_refinement."""

    # Two-file optimized source shared by every test in this class.
    _OPTIMIZED = (
        "```python:foo.py\ndef f():\n return 1\n```\n"
        "```python:bar.py\ndef g():\n return 2\n```"
    )

    def _context(self) -> MultiRefinerContext:
        """
        Build a MultiRefinerContext whose optimized source is _OPTIMIZED.
        """
        return MultiRefinerContext(
            ctx_data=_make_ctx_data(optimized_source_code=self._OPTIMIZED),
            base_system_prompt="sys",
            base_user_prompt="user",
        )

    def test_valid_multi_file_refinement(self) -> None:
        """
        A candidate with the same file paths and parseable code in each
        file is accepted.
        """
        candidate = (
            "```python:foo.py\ndef f():\n return 42\n```\n"
            "```python:bar.py\ndef g():\n return 2\n```"
        )
        assert self._context().is_valid_refinement(candidate)

    def test_invalid_structure_change(self) -> None:
        """
        A candidate that swaps bar.py for baz.py is rejected.
        """
        candidate = (
            "```python:foo.py\ndef f():\n return 42\n```\n"
            "```python:baz.py\ndef h():\n return 3\n```"
        )
        assert not self._context().is_valid_refinement(candidate)

    def test_invalid_syntax_in_one_file(self) -> None:
        """
        A candidate with a syntax error in one of its files is rejected.
        """
        candidate = (
            "```python:foo.py\ndef f(:\n```\n"
            "```python:bar.py\ndef g():\n return 2\n```"
        )
        assert not self._context().is_valid_refinement(candidate)