Port all P1 endpoints from the Django aiservice to FastAPI: - repair: 2-attempt LLM retry, SearchAndReplaceDiff patch application - refinement: parallel LLM calls via asyncio.gather, single/multi-file context dispatch, XML explanation extraction, deduplication - adaptive: single LLM call with previous candidate history - explain: conditional throughput/concurrency/acceptance prompt sections, XML <explain> tag extraction - review: 4-dimension scoring, JSON code block extraction, 2-attempt retry - ranking: 4-dimension weighted scoring, JSON extraction with 3 fallbacks (direct parse, markdown block, brace matching), legacy XML fallback - jit: reuses optimize pipeline with JIT-specific prompts - workflow: 3-tier regex YAML extraction, LLM-generated CI steps - testgen: stub returning 501 (language-specific logic deferred) - log_features: trace_id validation, DB write stubbed Also adds: - Task-specific model assignments in llm/_models.py - XML tag extraction utility in languages/python/_xml.py - All 11 routers registered in _app.py 348 tests passing, all lint clean.
459 lines
15 KiB
Python
459 lines
15 KiB
Python
from __future__ import annotations
|
|
|
|
import uuid
|
|
|
|
import libcst as cst
|
|
import pytest
|
|
|
|
from codeflash_api.refinement._context import (
|
|
BaseRefinerContext,
|
|
MultiRefinerContext,
|
|
RefinementContextData,
|
|
SingleRefinerContext,
|
|
)
|
|
from codeflash_api.refinement.schemas import (
|
|
RefinementErrorResponse,
|
|
RefinementRequestItem,
|
|
RefinementResponse,
|
|
RefinementResponseItem,
|
|
)
|
|
|
|
|
|
def _make_ctx_data(**overrides: object) -> RefinementContextData:
    """
    Construct a RefinementContextData populated with test-friendly values.

    Every keyword passed in ``overrides`` replaces the baseline value of the
    field with the same name before the object is built.
    """
    baseline: dict[str, object] = {
        "original_source_code": "def f():\n return 1\n",
        "original_line_profiler_results": None,
        "original_code_runtime": "100ms",
        "read_only_dependency_code": None,
        "optimized_source_code": "def f():\n return 2\n",
        "optimized_line_profiler_results": None,
        "optimized_code_runtime": "50ms",
        "speedup": "2x",
        "optimized_explanation": "Changed return value",
        "python_version": "3.12.0",
        "function_references": None,
        "language": "python",
        "language_version": None,
    }
    # The right-hand mapping wins, so caller-supplied keys take precedence.
    merged = {**baseline, **overrides}
    return RefinementContextData(**merged)
|
|
|
|
|
|
# ── Schemas ─────────────────────────────────────────────────────
|
|
|
|
|
|
class TestRefinementRequestItem:
    """Tests for RefinementRequestItem schema."""

    def test_minimal_fields(self) -> None:
        """
        Only the required fields are needed.
        """
        item = RefinementRequestItem(
            trace_id=str(uuid.uuid4()),
            optimization_id=str(uuid.uuid4()),
            original_source_code="def f(): pass",
            optimized_source_code="def f(): return 1",
        )
        # Optional fields that were not supplied fall back to their defaults.
        assert item.language == "python"
        assert item.speedup == ""

    def test_all_fields(self) -> None:
        """
        All optional fields can be set and are stored on the model.

        Every supplied value is read back. An optional field that the
        constructor accepts but does not keep therefore fails this test
        instead of passing silently.
        """
        trace_id = str(uuid.uuid4())
        optimization_id = str(uuid.uuid4())
        rerun_trace_id = str(uuid.uuid4())
        item = RefinementRequestItem(
            trace_id=trace_id,
            optimization_id=optimization_id,
            original_source_code="def f(): pass",
            optimized_source_code="def f(): return 1",
            original_line_profiler_results="profiler data",
            read_only_dependency_code="import os",
            optimized_line_profiler_results="opt profiler",
            optimized_explanation="faster",
            original_code_runtime="100ms",
            optimized_code_runtime="50ms",
            speedup="2x",
            python_version="3.12.0",
            function_references="called in loop",
            call_sequence=1,
            language="python",
            language_version="3.12.0",
            rerun_trace_id=rerun_trace_id,
        )
        # IDs are compared via str() so the check holds whether the schema
        # keeps them as str or coerces them to uuid.UUID.
        assert str(item.trace_id) == trace_id
        assert str(item.optimization_id) == optimization_id
        assert str(item.rerun_trace_id) == rerun_trace_id
        assert item.original_source_code == "def f(): pass"
        assert item.optimized_source_code == "def f(): return 1"
        assert item.original_line_profiler_results == "profiler data"
        assert item.read_only_dependency_code == "import os"
        assert item.optimized_line_profiler_results == "opt profiler"
        assert item.optimized_explanation == "faster"
        assert item.original_code_runtime == "100ms"
        assert item.optimized_code_runtime == "50ms"
        assert item.speedup == "2x"
        assert item.python_version == "3.12.0"
        assert item.function_references == "called in loop"
        assert item.call_sequence == 1
        assert item.language == "python"
        assert item.language_version == "3.12.0"
|
|
|
|
|
|
class TestRefinementResponseItem:
    """Tests for RefinementResponseItem schema."""

    def test_required_fields(self) -> None:
        """
        Constructing with only source_code, explanation and optimization_id works.
        """
        response_item = RefinementResponseItem(
            optimization_id=str(uuid.uuid4()),
            source_code="def f(): return 1",
            explanation="optimized",
        )
        # parent_id was not supplied, so it stays unset.
        assert response_item.parent_id is None

    def test_with_parent_id(self) -> None:
        """
        An explicit parent_id is kept on the item.
        """
        parent = str(uuid.uuid4())
        response_item = RefinementResponseItem(
            optimization_id=str(uuid.uuid4()),
            source_code="def f(): return 1",
            explanation="optimized",
            parent_id=parent,
        )
        assert response_item.parent_id == parent
|
|
|
|
|
|
class TestRefinementResponse:
    """Tests for RefinementResponse schema."""

    def test_empty_list(self) -> None:
        """
        A response with no refinements is accepted.
        """
        response = RefinementResponse(refinements=[])
        assert response.refinements == []

    def test_with_items(self) -> None:
        """
        A response carries every refinement item it is given.
        """
        candidates = [
            RefinementResponseItem(
                optimization_id=str(uuid.uuid4()),
                source_code=f"def f(): return {n}",
                explanation=f"opt {n}",
            )
            for n in range(3)
        ]
        response = RefinementResponse(refinements=candidates)
        assert len(response.refinements) == 3
|
|
|
|
|
|
class TestRefinementErrorResponse:
    """Tests for RefinementErrorResponse schema."""

    def test_error_field(self) -> None:
        """
        The supplied error message is kept on the response.
        """
        message = "something broke"
        assert RefinementErrorResponse(error=message).error == message
|
|
|
|
|
|
# ── Context Data ────────────────────────────────────────────────
|
|
|
|
|
|
class TestRefinementContextData:
    """Tests for RefinementContextData attrs class."""

    def test_frozen(self) -> None:
        """
        Assigning to a field of an existing RefinementContextData raises.
        """
        snapshot = _make_ctx_data()
        with pytest.raises(AttributeError):
            snapshot.speedup = "3x"  # type: ignore[misc]

    def test_defaults(self) -> None:
        """
        Fields omitted at construction take their declared defaults.
        """
        snapshot = RefinementContextData(
            original_source_code="def f(): pass",
            optimized_source_code="def f(): return 1",
            original_line_profiler_results=None,
            optimized_line_profiler_results=None,
            original_code_runtime="100ms",
            optimized_code_runtime="50ms",
            read_only_dependency_code=None,
            speedup="2x",
        )
        assert snapshot.optimized_explanation == ""
        assert snapshot.language == "python"
        assert snapshot.language_version is None
|
|
|
|
|
|
# ── Context Factory ─────────────────────────────────────────────
|
|
|
|
|
|
class TestBaseRefinerContextFactory:
    """Tests for BaseRefinerContext.get_dynamic_context."""

    def test_single_file_returns_single_context(self) -> None:
        """
        A plain (unfenced) Python source dispatches to SingleRefinerContext.
        """
        plain_source = "def f():\n return 2\n"
        context = BaseRefinerContext.get_dynamic_context(
            base_user_prompt="user",
            base_system_prompt="system",
            ctx_data=_make_ctx_data(optimized_source_code=plain_source),
        )
        assert isinstance(context, SingleRefinerContext)

    def test_multi_file_returns_multi_context(self) -> None:
        """
        Markdown-fenced, path-tagged sources dispatch to MultiRefinerContext.
        """
        fenced_source = (
            "```python:foo.py\ndef f():\n return 2\n```\n"
            "```python:bar.py\ndef g():\n return 3\n```"
        )
        context = BaseRefinerContext.get_dynamic_context(
            base_user_prompt="user",
            base_system_prompt="system",
            ctx_data=_make_ctx_data(optimized_source_code=fenced_source),
        )
        assert isinstance(context, MultiRefinerContext)
|
|
|
|
|
|
# ── User Prompt ─────────────────────────────────────────────────
|
|
|
|
|
|
class TestSingleRefinerContextPrompt:
    """Tests for SingleRefinerContext user prompt formatting."""

    # User-prompt template containing all eleven refinement placeholders.
    # It is shared by the tests below so the field list is maintained in one
    # place rather than copied into every test.
    _ALL_FIELDS_TEMPLATE = (
        "{original_source_code} "
        "{optimized_source_code} "
        "{original_line_profiler_results} "
        "{optimized_line_profiler_results} "
        "{optimized_explanation} "
        "{original_code_runtime} "
        "{optimized_code_runtime} "
        "{speedup} "
        "{read_only_dependency_code} "
        "{python_version} "
        "{function_references}"
    )

    @classmethod
    def _render(cls, **overrides: object) -> str:
        """
        Render the shared template for context data built with ``overrides``.
        """
        ctx = SingleRefinerContext(
            ctx_data=_make_ctx_data(**overrides),
            base_system_prompt="sys",
            base_user_prompt=cls._ALL_FIELDS_TEMPLATE,
        )
        return ctx.get_user_prompt()

    def test_user_prompt_contains_source(self) -> None:
        """
        The user prompt includes the original source code.
        """
        prompt = self._render(original_source_code="def hello(): pass")
        assert "def hello(): pass" in prompt

    def test_none_profiler_results_default(self) -> None:
        """
        None profiler results show a placeholder.
        """
        prompt = self._render(
            original_line_profiler_results=None,
            optimized_line_profiler_results=None,
        )
        assert "[No profiler results available]" in prompt
|
|
|
|
|
|
# ── Validation ──────────────────────────────────────────────────
|
|
|
|
|
|
class TestSingleRefinerContextValidation:
    """Tests for SingleRefinerContext validation methods."""

    @staticmethod
    def _context_for(**overrides: object) -> SingleRefinerContext:
        """
        Build a SingleRefinerContext with placeholder prompts around test data.
        """
        return SingleRefinerContext(
            ctx_data=_make_ctx_data(**overrides),
            base_system_prompt="sys",
            base_user_prompt="user",
        )

    def test_valid_refinement_different_code(self) -> None:
        """
        Parseable code that differs from the optimized source is accepted.
        """
        context = self._context_for(optimized_source_code="def f():\n return 1\n")
        assert context.is_valid_refinement("def f():\n return 42\n")

    def test_invalid_refinement_empty(self) -> None:
        """
        An empty candidate is rejected.
        """
        context = self._context_for()
        assert not context.is_valid_refinement("")

    def test_invalid_refinement_same_as_optimized(self) -> None:
        """
        A candidate identical to the optimized source is rejected.
        """
        unchanged = "def f():\n return 2\n"
        context = self._context_for(optimized_source_code=unchanged)
        assert not context.is_valid_refinement(unchanged)

    def test_invalid_refinement_syntax_error(self) -> None:
        """
        A candidate that fails to parse is rejected.
        """
        context = self._context_for()
        assert not context.is_valid_refinement("def f(\n")

    def test_validate_code_syntax_valid(self) -> None:
        """
        Well-formed Python passes syntax validation.
        """
        context = self._context_for()
        # Success is signalled by the call returning without raising.
        context.validate_code_syntax("def f():\n return 1\n")

    def test_validate_code_syntax_invalid(self) -> None:
        """
        Malformed Python raises a libcst parser error.
        """
        # Built outside pytest.raises so a construction failure is not
        # mistaken for the expected parser error.
        context = self._context_for()
        with pytest.raises(cst.ParserSyntaxError):
            context.validate_code_syntax("def f(\n")
|
|
|
|
|
|
# ── Patch Application ───────────────────────────────────────────
|
|
|
|
|
|
class TestSingleRefinerContextPatches:
    """Tests for SingleRefinerContext.apply_patches_to_optimized_code."""

    def test_apply_simple_patch(self) -> None:
        """
        A single SEARCH/REPLACE hunk rewrites the matching optimized line.
        """
        context = SingleRefinerContext(
            ctx_data=_make_ctx_data(optimized_source_code="def f():\n return 2\n"),
            base_system_prompt="sys",
            base_user_prompt="user",
        )
        # The SEARCH text must match the optimized source above exactly,
        # including its one-space indentation.
        patch_text = (
            "<replace_in_file>\n"
            "<path>file</path>\n"
            "<diff>\n"
            "<<<<<<< SEARCH\n"
            " return 2\n"
            "=======\n"
            " return 42\n"
            ">>>>>>> REPLACE\n"
            "</diff>\n"
            "</replace_in_file>"
        )
        patched = context.apply_patches_to_optimized_code(patch_text)
        assert "return 42" in patched
        assert "return 2" not in patched
|
|
|
|
|
|
class TestMultiRefinerContextValidation:
    """Tests for MultiRefinerContext validation."""

    # Two-file optimized source (foo.py and bar.py). It was previously
    # duplicated in every test; each test now proposes a different
    # candidate refinement against this single baseline.
    _OPTIMIZED_SOURCE = (
        "```python:foo.py\ndef f():\n return 1\n```\n"
        "```python:bar.py\ndef g():\n return 2\n```"
    )

    @classmethod
    def _make_ctx(cls) -> MultiRefinerContext:
        """
        Build a MultiRefinerContext whose optimized source is ``_OPTIMIZED_SOURCE``.
        """
        return MultiRefinerContext(
            ctx_data=_make_ctx_data(optimized_source_code=cls._OPTIMIZED_SOURCE),
            base_system_prompt="sys",
            base_user_prompt="user",
        )

    def test_valid_multi_file_refinement(self) -> None:
        """
        Multi-file code with same structure and valid syntax is valid.
        """
        new_code = (
            "```python:foo.py\ndef f():\n return 42\n```\n"
            "```python:bar.py\ndef g():\n return 2\n```"
        )
        assert self._make_ctx().is_valid_refinement(new_code)

    def test_invalid_structure_change(self) -> None:
        """
        Changing the file structure makes the refinement invalid.
        """
        # bar.py is replaced by baz.py, so the set of files no longer matches.
        new_code = (
            "```python:foo.py\ndef f():\n return 42\n```\n"
            "```python:baz.py\ndef h():\n return 3\n```"
        )
        assert not self._make_ctx().is_valid_refinement(new_code)

    def test_invalid_syntax_in_one_file(self) -> None:
        """
        Syntax error in any file makes the refinement invalid.
        """
        # foo.py is broken while bar.py is left untouched and valid.
        new_code = (
            "```python:foo.py\ndef f(:\n```\n"
            "```python:bar.py\ndef g():\n return 2\n```"
        )
        assert not self._make_ctx().is_valid_refinement(new_code)
|