Add testgen review and repair endpoints

Port /ai/testgen_review and /ai/testgen_repair from the Django
reference implementation.
Review: parallel LLM calls per test source, auto-flags behavioral
failures, parses JSON verdicts. Repair: Jinja2 prompt templates,
syntax-error retry loop, Python code extraction and validation.

Schemas: TestgenReviewRequest/Response, TestRepairRequest/Response,
CoverageDetails, FunctionVerdict, TestSourceWithFailures.

23 tests covering: coverage context building, verdict parsing,
syntax error detection, endpoint success/error/retry/language paths,
and the model validator for python_version resolution.
Kevin Turcios 2026-04-22 20:35:39 -05:00
parent 1d70d65914
commit 6abcc8daa3
9 changed files with 1252 additions and 3 deletions
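For orientation before the diff: a minimal sketch of how a caller might exercise the new review endpoint once this lands. The base URL, missing auth header, and running server are assumptions; the request fields come from the `TestgenReviewRequest` schema added below.

```python
# Hypothetical client-side sketch -- not part of this commit.
import asyncio

import httpx


async def main() -> None:
    payload = {
        "tests": [
            {"test_source": "def test_f():\n    assert f() == 1\n", "test_index": 0}
        ],
        "function_source_code": "def f():\n    return 1\n",
        "function_name": "f",
        "trace_id": "12345678-1234-4000-8000-000000000000",
    }
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        resp = await client.post("/ai/testgen_review", json=payload)
        resp.raise_for_status()
        # Each review carries the test_index plus any "repair" verdicts.
        for review in resp.json()["reviews"]:
            print(review["test_index"], review["functions"])


asyncio.run(main())
```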

View file

@@ -96,6 +96,9 @@ def _register_routes(app: FastAPI) -> None:
     from codeflash_api.review._router import (
         router as review_router,
     )
+    from codeflash_api.testgen._review_router import (
+        router as testgen_review_router,
+    )
     from codeflash_api.testgen._router import (
         router as testgen_router,
     )
@@ -107,6 +110,7 @@ def _register_routes(app: FastAPI) -> None:
     app.include_router(adaptive_router)
     app.include_router(explain_router)
     app.include_router(testgen_router)
+    app.include_router(testgen_review_router)
     app.include_router(jit_router)
     app.include_router(ranking_router)
     app.include_router(refinement_router)

View file

@@ -0,0 +1,419 @@
"""POST /ai/testgen_review and POST /ai/testgen_repair endpoints."""

from __future__ import annotations

import ast
import asyncio
import json
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any

import jinja2
from fastapi import APIRouter, Depends, HTTPException, Request

from codeflash_api.auth._deps import (
    check_rate_limit,
    require_auth,
    track_usage,
)
from codeflash_api.auth.models import AuthenticatedUser  # noqa: TC001
from codeflash_api.languages.python._markdown import (
    extract_code_block_with_context,
)
from codeflash_api.llm._models import ANTHROPIC_CLAUDE_HAIKU_4_5
from codeflash_api.optimize._context import validate_trace_id
from codeflash_api.testgen.schemas import (
    CoverageDetails,
    FunctionVerdict,
    TestGenErrorResponse,
    TestgenReviewRequest,
    TestgenReviewResponse,
    TestRepairRequest,
    TestRepairResponse,
    TestReview,
    TestSourceWithFailures,
)

if TYPE_CHECKING:
    from codeflash_api.llm._client import LLMClient

log = logging.getLogger(__name__)

router = APIRouter()

_PROMPTS_DIR = Path(__file__).parent / "prompts"
_JINJA_ENV = jinja2.Environment(  # noqa: S701
    loader=jinja2.FileSystemLoader(str(_PROMPTS_DIR)),
    keep_trailing_newline=True,
)
REVIEW_SYSTEM_TPL = _JINJA_ENV.get_template("review_system.md.j2")
REVIEW_USER_TPL = _JINJA_ENV.get_template("review_user.md.j2")
REPAIR_SYSTEM_TPL = _JINJA_ENV.get_template("repair_system.md.j2")
REPAIR_USER_TPL = _JINJA_ENV.get_template("repair_user.md.j2")

REVIEW_MODEL = ANTHROPIC_CLAUDE_HAIKU_4_5
REPAIR_MODEL = ANTHROPIC_CLAUDE_HAIKU_4_5
_MAX_REPAIR_RETRIES = 2


def _build_coverage_context(
    details: CoverageDetails | None,
) -> str:
    """
    Format coverage details for prompt injection.
    """
    if details is None:
        return ""
    parts: list[str] = []
    mf = details.main_function
    parts.append(
        f"Function `{mf.name}`: {mf.coverage:.0f}% coverage"
    )
    if mf.unexecuted_lines:
        parts.append(
            f" Unexecuted lines: {mf.unexecuted_lines}"
        )
    if mf.unexecuted_branches:
        parts.append(
            f" Unexecuted branches: {mf.unexecuted_branches}"
        )
    if details.dependent_function:
        df = details.dependent_function
        parts.append(
            f"Dependent function `{df.name}`:"
            f" {df.coverage:.0f}% coverage"
        )
        if df.unexecuted_lines:
            parts.append(
                f" Unexecuted lines: {df.unexecuted_lines}"
            )
    return "\n".join(parts)


def _parse_review_verdicts(
    content: str,
) -> list[FunctionVerdict]:
    """
    Parse function verdicts from LLM JSON response.
    """
    result = extract_code_block_with_context(
        content, language="json"
    )
    json_str = result[1] if result else content
    try:
        data = json.loads(json_str)
    except (json.JSONDecodeError, ValueError):
        log.warning("Failed to parse review JSON: %s", content[:200])
        return []
    raw_functions: list[dict[str, Any]]
    if isinstance(data, dict):
        raw_functions = data.get("functions", [])
    elif isinstance(data, list):
        raw_functions = data
    else:
        return []
    verdicts: list[FunctionVerdict] = []
    for fn in raw_functions:
        if not isinstance(fn, dict):
            continue
        verdict = fn.get("verdict", "")
        if verdict == "repair":
            verdicts.append(
                FunctionVerdict(
                    function_name=fn.get("function_name", ""),
                    verdict="repair",
                    reason=fn.get("reason", ""),
                )
            )
    return verdicts


async def _review_single_test(
    llm_client: LLMClient,
    test: TestSourceWithFailures,
    function_name: str,
    function_source_code: str,
    trace_id: str,
    coverage_context: str,
    coverage_summary: str,
) -> TestReview:
    """
    Review a single test source and return verdicts.
    """
    failed_verdicts: list[FunctionVerdict] = [
        FunctionVerdict(
            function_name=fn,
            verdict="repair",
            reason=test.failure_messages.get(
                fn, "Failed behavioral test"
            ),
        )
        for fn in test.failed_test_functions
    ]
    failed_note = ""
    if test.failed_test_functions:
        failed_names = ", ".join(
            f"`{fn}`" for fn in test.failed_test_functions
        )
        failed_note = (
            f"The following functions failed behavioral"
            f" tests and are already marked for repair:"
            f" {failed_names}. Do NOT re-review them."
        )
    system_prompt = REVIEW_SYSTEM_TPL.render()
    user_prompt = REVIEW_USER_TPL.render(
        function_name=function_name,
        function_source_code=function_source_code,
        test_source=test.test_source,
        coverage_context=coverage_context,
        coverage_summary=coverage_summary,
        failed_note=failed_note,
    )
    messages: list[dict[str, Any]] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    try:
        output = await llm_client.call(
            llm=REVIEW_MODEL,
            messages=messages,
            call_type="testgen_review",
            trace_id=trace_id,
        )
        ai_verdicts = _parse_review_verdicts(output.content)
    except Exception:
        log.exception(
            "Review LLM call failed: trace_id=%s test_index=%d",
            trace_id,
            test.test_index,
        )
        ai_verdicts = []
    already_failed = {v.function_name for v in failed_verdicts}
    combined = failed_verdicts + [
        v
        for v in ai_verdicts
        if v.function_name not in already_failed
    ]
    return TestReview(
        test_index=test.test_index,
        functions=combined,
    )


@router.post(
    "/ai/testgen_review",
    response_model=TestgenReviewResponse,
    responses={
        400: {"model": TestGenErrorResponse},
        422: {"model": TestGenErrorResponse},
    },
)
async def testgen_review(
    request: Request,
    data: TestgenReviewRequest,
    user: Annotated[AuthenticatedUser, Depends(require_auth)],
    _rate: Annotated[None, Depends(check_rate_limit)],
    _usage: Annotated[None, Depends(track_usage)],
) -> TestgenReviewResponse:
    """
    Review generated tests for quality issues.
    """
    if not validate_trace_id(data.trace_id):
        raise HTTPException(
            status_code=400,
            detail="Invalid trace ID. Please provide a valid UUIDv4.",
        )
    if data.language != "python":
        return TestgenReviewResponse(reviews=[])
    llm_client: LLMClient = request.app.state.llm_client
    coverage_context = _build_coverage_context(
        data.coverage_details,
    )
    coros = [
        _review_single_test(
            llm_client,
            test,
            data.function_name,
            data.function_source_code,
            data.trace_id,
            coverage_context,
            data.coverage_summary,
        )
        for test in data.tests
    ]
    reviews = await asyncio.gather(*coros)
    return TestgenReviewResponse(reviews=list(reviews))


# ── Repair ──────────────────────────────────────────────────────────


def _extract_python_code(content: str) -> str | None:
    """
    Extract a Python code block from LLM output.
    """
    result = extract_code_block_with_context(
        content, language="python"
    )
    if result is None:
        return None
    return result[1]


def _get_syntax_error(code: str) -> str:
    """
    Return the syntax error message, or empty string if valid.
    """
    try:
        ast.parse(code)
    except SyntaxError as exc:
        return f"Line {exc.lineno}: {exc.msg}"
    return ""


@router.post(
    "/ai/testgen_repair",
    response_model=TestRepairResponse,
    responses={
        400: {"model": TestGenErrorResponse},
        422: {"model": TestGenErrorResponse},
        500: {"model": TestGenErrorResponse},
    },
)
async def testgen_repair(
    request: Request,
    data: TestRepairRequest,
    user: Annotated[AuthenticatedUser, Depends(require_auth)],
    _rate: Annotated[None, Depends(check_rate_limit)],
    _usage: Annotated[None, Depends(track_usage)],
) -> TestRepairResponse:
    """
    Repair flagged test functions.
    """
    if not validate_trace_id(data.trace_id):
        raise HTTPException(
            status_code=400,
            detail="Invalid trace ID. Please provide a valid UUIDv4.",
        )
    if data.language != "python":
        raise HTTPException(
            status_code=400,
            detail="Only Python is currently supported.",
        )
    llm_client: LLMClient = request.app.state.llm_client
    coverage_context = _build_coverage_context(
        data.coverage_details,
    )
    system_prompt = REPAIR_SYSTEM_TPL.render()
    user_prompt = REPAIR_USER_TPL.render(
        function_name=data.function_name,
        function_source_code=data.function_source_code,
        test_source=data.test_source,
        functions_to_repair=data.functions_to_repair,
        previous_repair_errors=data.previous_repair_errors,
        coverage_context=coverage_context,
    )
    messages: list[dict[str, Any]] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    repaired_code: str | None = None
    for attempt in range(_MAX_REPAIR_RETRIES):
        try:
            output = await llm_client.call(
                llm=REPAIR_MODEL,
                messages=messages,
                call_type=(
                    "testgen_repair"
                    if attempt == 0
                    else "testgen_repair_retry"
                ),
                trace_id=data.trace_id,
            )
        except Exception:
            log.exception(
                "Repair LLM call failed: trace_id=%s attempt=%d",
                data.trace_id,
                attempt + 1,
            )
            raise HTTPException(
                status_code=422,
                detail="Could not generate repair. Please try again.",
            ) from None
        repaired_code = _extract_python_code(output.content)
        if repaired_code is None:
            log.warning(
                "No code block in repair response:"
                " trace_id=%s attempt=%d",
                data.trace_id,
                attempt + 1,
            )
            continue
        syntax_err = _get_syntax_error(repaired_code)
        if not syntax_err:
            break
        log.warning(
            "Syntax error in repair: trace_id=%s"
            " attempt=%d error=%s",
            data.trace_id,
            attempt + 1,
            syntax_err,
        )
        messages.append(
            {"role": "assistant", "content": output.content}
        )
        messages.append(
            {
                "role": "user",
                "content": (
                    f"Syntax error:\n```\n{syntax_err}\n```\n"
                    "Please return the complete corrected file"
                    " in a single Python code block."
                ),
            }
        )
        repaired_code = None
    if repaired_code is None:
        raise HTTPException(
            status_code=422,
            detail=(
                "Could not generate a valid repair"
                " after retries. Please try again."
            ),
        )
    return TestRepairResponse(
        generated_tests=repaired_code,
        instrumented_behavior_tests=repaired_code,
        instrumented_perf_tests=repaired_code,
    )
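The retry loop above accepts a candidate repair only after it survives an `ast.parse` check. A standalone sketch of that gate follows; the regex extraction here is an assumption standing in for the project-internal `extract_code_block_with_context`, and the fence string is built indirectly so the example nests cleanly in this page.

```python
# Standalone sketch of the repair-validation gate used by the retry loop.
import ast
import re

FENCE = chr(96) * 3  # three backticks


def extract_python_block(content: str) -> str | None:
    """Pull the first fenced python block out of an LLM reply."""
    match = re.search(FENCE + r"python\n(.*?)" + FENCE, content, re.DOTALL)
    return match.group(1) if match else None


def get_syntax_error(code: str) -> str:
    """Mirror of _get_syntax_error: empty string means the code parses."""
    try:
        ast.parse(code)
    except SyntaxError as exc:
        return f"Line {exc.lineno}: {exc.msg}"
    return ""


reply = f"Here you go:\n{FENCE}python\ndef test_a(:\n    pass\n{FENCE}"
code = extract_python_block(reply)
assert code is not None
print(get_syntax_error(code))  # non-empty -> the endpoint appends a retry message
```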

View file

@@ -0,0 +1,18 @@
You are a test repair assistant for an AI code optimizer.
You will receive a test file with specific test functions flagged as problematic, along with the reason each was flagged. Rewrite ONLY the flagged functions. All other code in the file must remain exactly as-is — do not modify imports, class definitions, helper functions, or any unflagged test functions.
Constraints for repaired functions:
- Use diverse, realistic inputs that reflect production usage patterns
- Only use the function's public API — do not access internal implementation details
- Exercise meaningful code paths
- Keep the same function name so test discovery still works
- Maintain the same testing framework conventions (pytest, unittest, etc.)
- Do NOT add comments, docstrings, or type annotations that weren't in the original
- The reason provided for each function describes the specific problem — fix that problem
- All string literals must be syntactically valid Python — use triple quotes (`"""..."""`) or escaped newlines (`\n`) for multiline strings. NEVER split a string literal across multiple lines without triple quotes.
Return the complete file in a single Python code block:
```python
# complete test file here
```

View file

@@ -0,0 +1,36 @@
<function_under_test name="{{ function_name }}">
```python
{{ function_source_code }}
```
</function_under_test>
<test_file>
```python
{{ test_source }}
```
</test_file>
<functions_to_repair>
{% for fn in functions_to_repair -%}
- `{{ fn.function_name }}`: {{ fn.reason }}
{% endfor %}
</functions_to_repair>
{% if previous_repair_errors %}
<previous_repair_errors>
A previous repair attempt for these functions was tried but the rewritten tests still failed at runtime. Use these error messages to understand what went wrong and avoid the same mistakes:
{% for fn_name, error_msg in previous_repair_errors.items() -%}
- `{{ fn_name }}`: {{ error_msg }}
{% endfor %}
</previous_repair_errors>
{% endif %}
{% if coverage_context %}
<coverage_details>
{{ coverage_context }}
When rewriting flagged functions, ensure they exercise the unexecuted lines listed above. Tests that only trigger early-return or disabled code paths do not provide useful coverage.
</coverage_details>
{% endif %}
Rewrite only the functions listed above. Return the complete file.
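To see what the repair model actually receives, here is a toy render of the loops above. The template text is an abbreviated inline copy fed through a `DictLoader`, not the real `repair_user.md.j2` on disk.

```python
# Toy render of the functions_to_repair / previous_repair_errors sections.
import jinja2

template_src = (
    "<functions_to_repair>\n"
    "{% for fn in functions_to_repair -%}\n"
    "- `{{ fn.function_name }}`: {{ fn.reason }}\n"
    "{% endfor %}\n"
    "</functions_to_repair>\n"
    "{% if previous_repair_errors %}\n"
    "<previous_repair_errors>\n"
    "{% for fn_name, error_msg in previous_repair_errors.items() -%}\n"
    "- `{{ fn_name }}`: {{ error_msg }}\n"
    "{% endfor %}\n"
    "</previous_repair_errors>\n"
    "{% endif %}\n"
)

env = jinja2.Environment(loader=jinja2.DictLoader({"repair_user": template_src}))
print(
    env.get_template("repair_user").render(
        functions_to_repair=[
            {"function_name": "test_a", "reason": "Always fails"},
        ],
        previous_repair_errors={"test_a": "AssertionError: expected 2, got 3"},
    )
)
```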

View file

@@ -0,0 +1,38 @@
You are a test quality reviewer for CodeFlash, an AI code optimizer. CodeFlash uses generated tests to:
1. **Verify correctness** — run the same tests on both original and optimized code, comparing outputs to catch regressions
2. **Benchmark performance** — time the function under test to measure speedup
Review each test function and flag ones that would compromise correctness verification or produce misleading benchmarks.
Flag a test function if it has any of these problems:
1. **Doesn't exercise the function**: The test never calls the target function, or only triggers trivial/no-op code paths (empty input, immediate error return). If coverage information is provided, use it to identify tests that miss the function entirely.
2. **Non-deterministic behavior**: The test produces different outputs across runs due to randomness, timestamps, or external state. This causes false correctness failures when comparing original vs optimized outputs.
3. **Internal state coupling**: The test asserts on internal implementation details (private attributes, internal data structures) rather than observable outputs. Optimized code may change internals while preserving behavior, causing false failures.
4. **Identical-input repetition**: Calling the function with the same arguments multiple times inflates cache hit rates and produces speedups that vanish in production.
5. **Unrealistic data patterns**: Only a single type or size of input, missing the diversity of real-world usage that would reveal performance regressions.
Only flag clear problems. If a test function looks reasonable, omit it from the output.
Respond with a JSON code block containing only the functions that need repair:
```json
{
  "functions": [
    {"function_name": "test_example", "verdict": "repair", "reason": "One sentence explaining the specific problem"}
  ]
}
```
If all functions are acceptable:
```json
{"functions": []}
```
Keep reasons to one sentence — they are passed directly to the repair LLM as instructions.
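The `_parse_review_verdicts` helper in the router consumes exactly this shape. A minimal sketch of the same filtering in isolation, without the markdown-extraction fallback:

```python
# Sketch of how a verdict payload of this shape is filtered down to repairs.
import json

reply = json.dumps(
    {
        "functions": [
            {
                "function_name": "test_cache_hits",
                "verdict": "repair",
                "reason": "Repeats identical inputs, inflating cache hit rates",
            },
            {"function_name": "test_happy_path", "verdict": "pass", "reason": ""},
        ]
    }
)

data = json.loads(reply)
to_repair = [
    fn
    for fn in data.get("functions", [])
    if isinstance(fn, dict) and fn.get("verdict") == "repair"
]
print([fn["function_name"] for fn in to_repair])  # ['test_cache_hits']
```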

View file

@@ -0,0 +1,30 @@
<function_under_test name="{{ function_name }}">
```python
{{ function_source_code }}
```
</function_under_test>
<generated_test>
```python
{{ test_source }}
```
</generated_test>
{% if coverage_context %}
<coverage_details>
{{ coverage_context }}
</coverage_details>
{% elif coverage_summary %}
<coverage>
Test coverage of {{ function_name }}: {{ coverage_summary }}
</coverage>
{% endif %}
{% if failed_note %}
<note>
{{ failed_note }}
</note>
{% endif %}
Review each test function in the generated test. Return your verdict as JSON.

View file

@@ -1,8 +1,12 @@
-"""Request and response schemas for POST /ai/testgen."""
+"""Request and response schemas for testgen endpoints."""
 
 from __future__ import annotations
 
-from pydantic import BaseModel
+from typing import Self
+
+from pydantic import BaseModel, model_validator
+
+# ── Testgen ─────────────────────────────────────────────────────────
 
 
 class TestGenRequest(BaseModel):
@@ -41,8 +45,149 @@ class TestGenResponse(BaseModel):
 class TestGenErrorResponse(BaseModel):
     """
-    Error response from POST /ai/testgen.
+    Error response from testgen endpoints.
     """
 
     error: str
     trace_id: str | None = None
+
+
+# ── Review ──────────────────────────────────────────────────────────
+
+
+class TestSourceWithFailures(BaseModel):
+    """
+    A test source file with optional behavioral test failures.
+    """
+
+    test_source: str
+    test_index: int
+    failed_test_functions: list[str] = []
+    failure_messages: dict[str, str] = {}
+
+
+class CoverageFunctionDetail(BaseModel):
+    """
+    Per-function coverage detail.
+    """
+
+    name: str
+    coverage: float
+    executed_lines: list[int] = []
+    unexecuted_lines: list[int] = []
+    executed_branches: list[list[int]] = []
+    unexecuted_branches: list[list[int]] = []
+
+
+class CoverageDetails(BaseModel):
+    """
+    Coverage info for the function under test.
+    """
+
+    coverage_percentage: float
+    threshold_percentage: float
+    function_start_line: int = 1
+    main_function: CoverageFunctionDetail
+    dependent_function: CoverageFunctionDetail | None = None
+
+
+class TestgenReviewRequest(BaseModel):
+    """
+    Request body for POST /ai/testgen_review.
+    """
+
+    tests: list[TestSourceWithFailures]
+    function_source_code: str
+    function_name: str
+    trace_id: str
+    language: str = "python"
+    codeflash_version: str | None = None
+    call_sequence: int | None = None
+    coverage_summary: str = ""
+    coverage_details: CoverageDetails | None = None
+
+
+class FunctionVerdict(BaseModel):
+    """
+    Verdict for a single test function.
+    """
+
+    function_name: str
+    verdict: str
+    reason: str = ""
+
+
+class TestReview(BaseModel):
+    """
+    Review result for one test source.
+    """
+
+    test_index: int
+    functions: list[FunctionVerdict]
+
+
+class TestgenReviewResponse(BaseModel):
+    """
+    Response from POST /ai/testgen_review.
+    """
+
+    reviews: list[TestReview]
+
+
+# ── Repair ──────────────────────────────────────────────────────────
+
+
+class FunctionToRepair(BaseModel):
+    """
+    A single function that needs repair with the reason.
+    """
+
+    function_name: str
+    reason: str
+
+
+class TestRepairRequest(BaseModel):
+    """
+    Request body for POST /ai/testgen_repair.
+    """
+
+    test_source: str
+    functions_to_repair: list[FunctionToRepair]
+    function_source_code: str
+    function_name: str
+    module_path: str
+    test_module_path: str
+    test_framework: str
+    test_timeout: int
+    trace_id: str
+    python_version: str | None = None
+    language_version: str | None = None
+    language: str = "python"
+    codeflash_version: str | None = None
+    call_sequence: int | None = None
+    coverage_details: CoverageDetails | None = None
+    previous_repair_errors: dict[str, str] = {}
+    module_source_code: str = ""
+
+    @model_validator(mode="after")
+    def resolve_python_version(self) -> Self:
+        """
+        Resolve python_version from language_version.
+        """
+        if (
+            self.python_version is None
+            and self.language_version is not None
+            and self.language == "python"
+        ):
+            self.python_version = self.language_version
+        return self
+
+
+class TestRepairResponse(BaseModel):
+    """
+    Response from POST /ai/testgen_repair.
+    """
+
+    generated_tests: str
+    instrumented_behavior_tests: str
+    instrumented_perf_tests: str
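The `resolve_python_version` validator is the one subtle piece here. A standalone reproduction of its behavior, trimmed to just the fields the validator reads (pydantic v2, Python 3.11+ for `Self`):

```python
# Trimmed reproduction of the TestRepairRequest python_version fallback.
from typing import Self

from pydantic import BaseModel, model_validator


class VersionedRequest(BaseModel):
    python_version: str | None = None
    language_version: str | None = None
    language: str = "python"

    @model_validator(mode="after")
    def resolve_python_version(self) -> Self:
        # Fall back to language_version only for Python requests that
        # did not set python_version explicitly.
        if (
            self.python_version is None
            and self.language_version is not None
            and self.language == "python"
        ):
            self.python_version = self.language_version
        return self


print(VersionedRequest(language_version="3.12.1").python_version)  # 3.12.1
print(
    VersionedRequest(
        python_version="3.11.0", language_version="3.12.1"
    ).python_version
)  # 3.11.0 -- explicit value wins
print(VersionedRequest(language="rust", language_version="1.80").python_version)  # None
```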

View file

@@ -0,0 +1,556 @@
"""Tests for testgen review and repair endpoints."""

from __future__ import annotations

import json
from unittest.mock import AsyncMock, MagicMock

import pytest

from codeflash_api.testgen._review_router import (
    _build_coverage_context,
    _get_syntax_error,
    _parse_review_verdicts,
)
from codeflash_api.testgen.schemas import (
    CoverageDetails,
    CoverageFunctionDetail,
    FunctionVerdict,
    TestgenReviewRequest,
    TestgenReviewResponse,
    TestRepairRequest,
    TestRepairResponse,
    TestReview,
    TestSourceWithFailures,
)


class TestBuildCoverageContext:
    """Tests for _build_coverage_context."""

    def test_none_returns_empty(self) -> None:
        """
        None coverage details yields empty string.
        """
        assert "" == _build_coverage_context(None)

    def test_formats_main_function(self) -> None:
        """
        Main function coverage is included with unexecuted lines.
        """
        details = CoverageDetails(
            coverage_percentage=80.0,
            threshold_percentage=90.0,
            main_function=CoverageFunctionDetail(
                name="compute",
                coverage=80.0,
                unexecuted_lines=[10, 15, 20],
            ),
        )
        result = _build_coverage_context(details)
        assert "compute" in result
        assert "80%" in result
        assert "[10, 15, 20]" in result

    def test_includes_dependent_function(self) -> None:
        """
        Dependent function coverage is appended when present.
        """
        details = CoverageDetails(
            coverage_percentage=75.0,
            threshold_percentage=90.0,
            main_function=CoverageFunctionDetail(
                name="main_fn",
                coverage=90.0,
            ),
            dependent_function=CoverageFunctionDetail(
                name="helper_fn",
                coverage=60.0,
                unexecuted_lines=[5, 6],
            ),
        )
        result = _build_coverage_context(details)
        assert "helper_fn" in result
        assert "60%" in result


class TestParseReviewVerdicts:
    """Tests for _parse_review_verdicts."""

    def test_parses_json_object(self) -> None:
        """
        Standard JSON object with functions array is parsed.
        """
        content = json.dumps(
            {
                "functions": [
                    {
                        "function_name": "test_foo",
                        "verdict": "repair",
                        "reason": "bad test",
                    }
                ]
            }
        )
        verdicts = _parse_review_verdicts(content)
        assert 1 == len(verdicts)
        assert "test_foo" == verdicts[0].function_name
        assert "repair" == verdicts[0].verdict

    def test_parses_json_in_code_block(self) -> None:
        """
        JSON inside a markdown code block is extracted.
        """
        content = (
            'Some text\n```json\n{"functions": ['
            '{"function_name": "test_bar",'
            ' "verdict": "repair", "reason": "x"}'
            "]}\n```\nMore text"
        )
        verdicts = _parse_review_verdicts(content)
        assert 1 == len(verdicts)
        assert "test_bar" == verdicts[0].function_name

    def test_filters_non_repair_verdicts(self) -> None:
        """
        Only repair verdicts are returned.
        """
        content = json.dumps(
            {
                "functions": [
                    {
                        "function_name": "test_ok",
                        "verdict": "pass",
                        "reason": "",
                    },
                    {
                        "function_name": "test_bad",
                        "verdict": "repair",
                        "reason": "broken",
                    },
                ]
            }
        )
        verdicts = _parse_review_verdicts(content)
        assert 1 == len(verdicts)
        assert "test_bad" == verdicts[0].function_name

    def test_empty_functions_returns_empty(self) -> None:
        """
        Empty functions list yields no verdicts.
        """
        content = '{"functions": []}'
        assert [] == _parse_review_verdicts(content)

    def test_invalid_json_returns_empty(self) -> None:
        """
        Unparseable JSON returns empty list.
        """
        assert [] == _parse_review_verdicts("not json at all")

    def test_parses_array_directly(self) -> None:
        """
        A raw JSON array is accepted as the functions list.
        """
        content = json.dumps(
            [
                {
                    "function_name": "test_x",
                    "verdict": "repair",
                    "reason": "r",
                }
            ]
        )
        verdicts = _parse_review_verdicts(content)
        assert 1 == len(verdicts)


class TestGetSyntaxError:
    """Tests for _get_syntax_error."""

    def test_valid_code(self) -> None:
        """
        Valid Python returns empty string.
        """
        assert "" == _get_syntax_error("x = 1\n")

    def test_syntax_error(self) -> None:
        """
        Invalid Python returns error description.
        """
        result = _get_syntax_error("def f(\n")
        assert "Line" in result


class TestTestgenReviewEndpoint:
    """Integration tests for POST /ai/testgen_review."""

    async def test_success(self, client, mock_llm_client) -> None:  # type: ignore[no-untyped-def]
        """
        Successful review returns verdicts.
        """
        mock_llm_client.call = AsyncMock(
            return_value=MagicMock(
                content='{"functions": []}',
                cost=0.01,
            ),
        )
        resp = await client.post(
            "/ai/testgen_review",
            json={
                "tests": [
                    {
                        "test_source": "def test_a(): pass",
                        "test_index": 0,
                    }
                ],
                "function_source_code": "def f(): return 1",
                "function_name": "f",
                "trace_id": "12345678-1234-4000-8000-000000000000",
            },
        )
        assert 200 == resp.status_code
        data = resp.json()
        assert 1 == len(data["reviews"])
        assert 0 == data["reviews"][0]["test_index"]

    async def test_invalid_trace_id(self, client) -> None:  # type: ignore[no-untyped-def]
        """
        Invalid trace_id returns 400.
        """
        resp = await client.post(
            "/ai/testgen_review",
            json={
                "tests": [],
                "function_source_code": "def f(): pass",
                "function_name": "f",
                "trace_id": "bad-id",
            },
        )
        assert 400 == resp.status_code

    async def test_non_python_returns_empty(
        self, client
    ) -> None:  # type: ignore[no-untyped-def]
        """
        Non-Python language returns empty reviews without LLM call.
        """
        resp = await client.post(
            "/ai/testgen_review",
            json={
                "tests": [
                    {
                        "test_source": "test code",
                        "test_index": 0,
                    }
                ],
                "function_source_code": "fn f() {}",
                "function_name": "f",
                "trace_id": "12345678-1234-4000-8000-000000000000",
                "language": "rust",
            },
        )
        assert 200 == resp.status_code
        assert [] == resp.json()["reviews"]

    async def test_failed_tests_marked_for_repair(
        self, client, mock_llm_client
    ) -> None:  # type: ignore[no-untyped-def]
        """
        Tests with behavioral failures are auto-flagged for repair.
        """
        mock_llm_client.call = AsyncMock(
            return_value=MagicMock(
                content='{"functions": []}',
                cost=0.01,
            ),
        )
        resp = await client.post(
            "/ai/testgen_review",
            json={
                "tests": [
                    {
                        "test_source": "def test_a(): assert False",
                        "test_index": 0,
                        "failed_test_functions": ["test_a"],
                        "failure_messages": {
                            "test_a": "AssertionError"
                        },
                    }
                ],
                "function_source_code": "def f(): return 1",
                "function_name": "f",
                "trace_id": "12345678-1234-4000-8000-000000000000",
            },
        )
        assert 200 == resp.status_code
        reviews = resp.json()["reviews"]
        assert 1 == len(reviews)
        fns = reviews[0]["functions"]
        assert 1 == len(fns)
        assert "test_a" == fns[0]["function_name"]
        assert "repair" == fns[0]["verdict"]


class TestTestgenRepairEndpoint:
    """Integration tests for POST /ai/testgen_repair."""

    async def test_success(self, client, mock_llm_client) -> None:  # type: ignore[no-untyped-def]
        """
        Successful repair returns repaired code.
        """
        repaired = "def test_a(): assert True"
        mock_llm_client.call = AsyncMock(
            return_value=MagicMock(
                content=f"```python\n{repaired}\n```",
                cost=0.01,
            ),
        )
        resp = await client.post(
            "/ai/testgen_repair",
            json={
                "test_source": "def test_a(): assert False",
                "functions_to_repair": [
                    {
                        "function_name": "test_a",
                        "reason": "Always fails",
                    }
                ],
                "function_source_code": "def f(): return 1",
                "function_name": "f",
                "module_path": "src/mod.py",
                "test_module_path": "tests/test_mod.py",
                "test_framework": "pytest",
                "test_timeout": 60,
                "trace_id": "12345678-1234-4000-8000-000000000000",
            },
        )
        assert 200 == resp.status_code
        data = resp.json()
        assert repaired in data["generated_tests"]

    async def test_invalid_trace_id(self, client) -> None:  # type: ignore[no-untyped-def]
        """
        Invalid trace_id returns 400.
        """
        resp = await client.post(
            "/ai/testgen_repair",
            json={
                "test_source": "code",
                "functions_to_repair": [],
                "function_source_code": "def f(): pass",
                "function_name": "f",
                "module_path": "m.py",
                "test_module_path": "t.py",
                "test_framework": "pytest",
                "test_timeout": 60,
                "trace_id": "bad",
            },
        )
        assert 400 == resp.status_code

    async def test_non_python_returns_400(self, client) -> None:  # type: ignore[no-untyped-def]
        """
        Non-Python language returns 400.
        """
        resp = await client.post(
            "/ai/testgen_repair",
            json={
                "test_source": "code",
                "functions_to_repair": [],
                "function_source_code": "fn f() {}",
                "function_name": "f",
                "module_path": "m.rs",
                "test_module_path": "t.rs",
                "test_framework": "cargo",
                "test_timeout": 60,
                "trace_id": "12345678-1234-4000-8000-000000000000",
                "language": "rust",
            },
        )
        assert 400 == resp.status_code

    async def test_syntax_error_triggers_retry(
        self, client, mock_llm_client
    ) -> None:  # type: ignore[no-untyped-def]
        """
        Syntax error in first response triggers a retry.
        """
        bad_code = "def test_a(:\n pass"
        good_code = "def test_a():\n pass"
        mock_llm_client.call = AsyncMock(
            side_effect=[
                MagicMock(
                    content=f"```python\n{bad_code}\n```",
                    cost=0.01,
                ),
                MagicMock(
                    content=f"```python\n{good_code}\n```",
                    cost=0.01,
                ),
            ],
        )
        resp = await client.post(
            "/ai/testgen_repair",
            json={
                "test_source": "def test_a(): assert False",
                "functions_to_repair": [
                    {
                        "function_name": "test_a",
                        "reason": "broken",
                    }
                ],
                "function_source_code": "def f(): return 1",
                "function_name": "f",
                "module_path": "m.py",
                "test_module_path": "t.py",
                "test_framework": "pytest",
                "test_timeout": 60,
                "trace_id": "12345678-1234-4000-8000-000000000000",
            },
        )
        assert 200 == resp.status_code
        assert good_code in resp.json()["generated_tests"]
        assert 2 == mock_llm_client.call.await_count

    async def test_all_retries_fail_returns_422(
        self, client, mock_llm_client
    ) -> None:  # type: ignore[no-untyped-def]
        """
        If all retries produce syntax errors, return 422.
        """
        bad_code = "def test_a(:\n pass"
        mock_llm_client.call = AsyncMock(
            return_value=MagicMock(
                content=f"```python\n{bad_code}\n```",
                cost=0.01,
            ),
        )
        resp = await client.post(
            "/ai/testgen_repair",
            json={
                "test_source": "def test_a(): pass",
                "functions_to_repair": [
                    {
                        "function_name": "test_a",
                        "reason": "x",
                    }
                ],
                "function_source_code": "def f(): return 1",
                "function_name": "f",
                "module_path": "m.py",
                "test_module_path": "t.py",
                "test_framework": "pytest",
                "test_timeout": 60,
                "trace_id": "12345678-1234-4000-8000-000000000000",
            },
        )
        assert 422 == resp.status_code

    async def test_llm_failure_returns_422(
        self, client, mock_llm_client
    ) -> None:  # type: ignore[no-untyped-def]
        """
        LLM exception returns 422.
        """
        mock_llm_client.call = AsyncMock(
            side_effect=RuntimeError("LLM down"),
        )
        resp = await client.post(
            "/ai/testgen_repair",
            json={
                "test_source": "def test_a(): pass",
                "functions_to_repair": [
                    {
                        "function_name": "test_a",
                        "reason": "x",
                    }
                ],
                "function_source_code": "def f(): return 1",
                "function_name": "f",
                "module_path": "m.py",
                "test_module_path": "t.py",
                "test_framework": "pytest",
                "test_timeout": 60,
                "trace_id": "12345678-1234-4000-8000-000000000000",
            },
        )
        assert 422 == resp.status_code


class TestTestRepairRequestValidator:
    """Tests for TestRepairRequest model validator."""

    def test_resolves_python_version_from_language_version(
        self,
    ) -> None:
        """
        python_version is set from language_version when not provided.
        """
        req = TestRepairRequest(
            test_source="code",
            functions_to_repair=[],
            function_source_code="def f(): pass",
            function_name="f",
            module_path="m.py",
            test_module_path="t.py",
            test_framework="pytest",
            test_timeout=60,
            trace_id="abc",
            language_version="3.12.1",
        )
        assert "3.12.1" == req.python_version

    def test_does_not_override_explicit_python_version(
        self,
    ) -> None:
        """
        Explicit python_version is not overridden by language_version.
        """
        req = TestRepairRequest(
            test_source="code",
            functions_to_repair=[],
            function_source_code="def f(): pass",
            function_name="f",
            module_path="m.py",
            test_module_path="t.py",
            test_framework="pytest",
            test_timeout=60,
            trace_id="abc",
            python_version="3.11.0",
            language_version="3.12.1",
        )
        assert "3.11.0" == req.python_version

View file

@@ -107,6 +107,9 @@ ignore = [
 "packages/codeflash-api/src/codeflash_api/observability/_recording.py" = [
     "PLR0913",  # recording functions faithfully match Django signatures
 ]
+"packages/codeflash-api/src/codeflash_api/testgen/_review_router.py" = [
+    "PLR0913",  # _review_single_test needs all context params for prompt assembly
+]
 "packages/codeflash-api/src/codeflash_api/optimize/_context.py" = [
     "PLR2004",  # magic values in faithfully ported version parsing and humanize_ns
 ]