mirror of
https://github.com/codeflash-ai/codeflash-agent.git
synced 2026-05-04 18:25:19 +00:00
Add testgen review and repair endpoints
Port /ai/testgen_review and /ai/testgen_repair from Django reference. Review: parallel LLM calls per test source, auto-flags behavioral failures, parses JSON verdicts. Repair: Jinja2 prompt templates, syntax-error retry loop, Python code extraction and validation. Schemas: TestgenReviewRequest/Response, TestRepairRequest/Response, CoverageDetails, FunctionVerdict, TestSourceWithFailures. 23 tests covering: coverage context building, verdict parsing, syntax error detection, endpoint success/error/retry/language paths, and the model validator for python_version resolution.
This commit is contained in:
parent
1d70d65914
commit
6abcc8daa3
9 changed files with 1252 additions and 3 deletions
|
|
@ -96,6 +96,9 @@ def _register_routes(app: FastAPI) -> None:
|
|||
from codeflash_api.review._router import (
|
||||
router as review_router,
|
||||
)
|
||||
from codeflash_api.testgen._review_router import (
|
||||
router as testgen_review_router,
|
||||
)
|
||||
from codeflash_api.testgen._router import (
|
||||
router as testgen_router,
|
||||
)
|
||||
|
|
@ -107,6 +110,7 @@ def _register_routes(app: FastAPI) -> None:
|
|||
app.include_router(adaptive_router)
|
||||
app.include_router(explain_router)
|
||||
app.include_router(testgen_router)
|
||||
app.include_router(testgen_review_router)
|
||||
app.include_router(jit_router)
|
||||
app.include_router(ranking_router)
|
||||
app.include_router(refinement_router)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,419 @@
|
|||
"""POST /ai/testgen_review and POST /ai/testgen_repair endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Annotated, Any
|
||||
|
||||
import jinja2
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
|
||||
from codeflash_api.auth._deps import (
|
||||
check_rate_limit,
|
||||
require_auth,
|
||||
track_usage,
|
||||
)
|
||||
from codeflash_api.auth.models import AuthenticatedUser # noqa: TC001
|
||||
from codeflash_api.languages.python._markdown import (
|
||||
extract_code_block_with_context,
|
||||
)
|
||||
from codeflash_api.llm._models import ANTHROPIC_CLAUDE_HAIKU_4_5
|
||||
from codeflash_api.optimize._context import validate_trace_id
|
||||
from codeflash_api.testgen.schemas import (
|
||||
CoverageDetails,
|
||||
FunctionVerdict,
|
||||
TestGenErrorResponse,
|
||||
TestgenReviewRequest,
|
||||
TestgenReviewResponse,
|
||||
TestRepairRequest,
|
||||
TestRepairResponse,
|
||||
TestReview,
|
||||
TestSourceWithFailures,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash_api.llm._client import LLMClient
|
||||
|
||||
log = logging.getLogger(__name__)

router = APIRouter()

# Prompt templates live in a "prompts" directory next to this module.
_PROMPTS_DIR = Path(__file__).parent / "prompts"
# NOTE(review): autoescape is deliberately off (S701 suppressed) — these
# templates render markdown/LLM prompts, not HTML. keep_trailing_newline
# preserves the final newline of each template in the rendered prompt.
_JINJA_ENV = jinja2.Environment(  # noqa: S701
    loader=jinja2.FileSystemLoader(str(_PROMPTS_DIR)),
    keep_trailing_newline=True,
)

# Compiled once at import time; get_template raises if a file is missing,
# so a bad deploy fails fast rather than at request time.
REVIEW_SYSTEM_TPL = _JINJA_ENV.get_template("review_system.md.j2")
REVIEW_USER_TPL = _JINJA_ENV.get_template("review_user.md.j2")
REPAIR_SYSTEM_TPL = _JINJA_ENV.get_template("repair_system.md.j2")
REPAIR_USER_TPL = _JINJA_ENV.get_template("repair_user.md.j2")

# Both flows currently use the same fast model.
REVIEW_MODEL = ANTHROPIC_CLAUDE_HAIKU_4_5
REPAIR_MODEL = ANTHROPIC_CLAUDE_HAIKU_4_5

# Total LLM attempts per repair request (first call + syntax-error retries).
_MAX_REPAIR_RETRIES = 2
|
||||
|
||||
|
||||
def _build_coverage_context(
    details: CoverageDetails | None,
) -> str:
    """Render structured coverage details as plain text for prompt injection.

    Produces one line per fact: the main function's coverage percentage,
    its unexecuted lines/branches, and (when present) the same for the
    dependent function. Returns "" when no details were supplied.
    """
    if details is None:
        return ""

    main = details.main_function
    lines = [f"Function `{main.name}`: {main.coverage:.0f}% coverage"]
    if main.unexecuted_lines:
        lines.append(f" Unexecuted lines: {main.unexecuted_lines}")
    if main.unexecuted_branches:
        lines.append(f" Unexecuted branches: {main.unexecuted_branches}")

    dep = details.dependent_function
    if dep:
        lines.append(
            f"Dependent function `{dep.name}`: {dep.coverage:.0f}% coverage"
        )
        if dep.unexecuted_lines:
            lines.append(f" Unexecuted lines: {dep.unexecuted_lines}")

    return "\n".join(lines)
|
||||
|
||||
|
||||
def _parse_review_verdicts(
    content: str,
) -> list[FunctionVerdict]:
    """Parse repair verdicts out of an LLM review response.

    Accepts either a JSON object with a ``functions`` array or a bare
    JSON array, optionally wrapped in a ```json markdown code block.
    Only entries whose verdict is ``"repair"`` are returned; anything
    unparseable or mis-shaped yields an empty list instead of raising,
    so a malformed LLM response never fails the whole review.
    """
    result = extract_code_block_with_context(
        content, language="json"
    )
    json_str = result[1] if result else content

    try:
        data = json.loads(json_str)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        log.warning("Failed to parse review JSON: %s", content[:200])
        return []

    if isinstance(data, dict):
        raw_functions = data.get("functions", [])
    elif isinstance(data, list):
        raw_functions = data
    else:
        return []
    # The model may emit `"functions": null` or some other non-array
    # value; treat anything that is not a list as "no verdicts" rather
    # than crashing the iteration below.
    if not isinstance(raw_functions, list):
        return []

    verdicts: list[FunctionVerdict] = []
    for fn in raw_functions:
        if not isinstance(fn, dict):
            continue
        if fn.get("verdict", "") != "repair":
            continue
        verdicts.append(
            FunctionVerdict(
                function_name=fn.get("function_name", ""),
                verdict="repair",
                reason=fn.get("reason", ""),
            )
        )
    return verdicts
|
||||
|
||||
|
||||
async def _review_single_test(
    llm_client: LLMClient,
    test: TestSourceWithFailures,
    function_name: str,
    function_source_code: str,
    trace_id: str,
    coverage_context: str,
    coverage_summary: str,
) -> TestReview:
    """Review one test source with the LLM, merging in auto-flagged failures.

    Functions that already failed behavioral tests are marked for repair
    up front; the prompt tells the LLM not to re-review them, and any
    duplicate verdicts it returns anyway are dropped. An LLM failure is
    logged and treated as "no additional verdicts".
    """
    # Behavioral failures are repair verdicts by definition — no LLM needed.
    auto_flagged: list[FunctionVerdict] = [
        FunctionVerdict(
            function_name=name,
            verdict="repair",
            reason=test.failure_messages.get(name, "Failed behavioral test"),
        )
        for name in test.failed_test_functions
    ]

    failed_note = ""
    if test.failed_test_functions:
        listed = ", ".join(
            f"`{name}`" for name in test.failed_test_functions
        )
        failed_note = (
            f"The following functions failed behavioral tests and are"
            f" already marked for repair: {listed}. Do NOT re-review them."
        )

    messages: list[dict[str, Any]] = [
        {"role": "system", "content": REVIEW_SYSTEM_TPL.render()},
        {
            "role": "user",
            "content": REVIEW_USER_TPL.render(
                function_name=function_name,
                function_source_code=function_source_code,
                test_source=test.test_source,
                coverage_context=coverage_context,
                coverage_summary=coverage_summary,
                failed_note=failed_note,
            ),
        },
    ]

    try:
        output = await llm_client.call(
            llm=REVIEW_MODEL,
            messages=messages,
            call_type="testgen_review",
            trace_id=trace_id,
        )
        ai_verdicts = _parse_review_verdicts(output.content)
    except Exception:
        # Review is best-effort: a failed call just means no AI verdicts.
        log.exception(
            "Review LLM call failed: trace_id=%s test_index=%d",
            trace_id,
            test.test_index,
        )
        ai_verdicts = []

    flagged_names = {v.function_name for v in auto_flagged}
    merged = auto_flagged + [
        v for v in ai_verdicts if v.function_name not in flagged_names
    ]

    return TestReview(test_index=test.test_index, functions=merged)
|
||||
|
||||
|
||||
@router.post(
    "/ai/testgen_review",
    response_model=TestgenReviewResponse,
    responses={
        400: {"model": TestGenErrorResponse},
        422: {"model": TestGenErrorResponse},
    },
)
async def testgen_review(
    request: Request,
    data: TestgenReviewRequest,
    user: Annotated[AuthenticatedUser, Depends(require_auth)],
    _rate: Annotated[None, Depends(check_rate_limit)],
    _usage: Annotated[None, Depends(track_usage)],
) -> TestgenReviewResponse:
    """Review generated tests for quality issues.

    Validates the trace id (400 on failure), short-circuits with an
    empty review list for non-Python requests, and otherwise reviews
    every test source concurrently — one LLM call per source.
    """
    if not validate_trace_id(data.trace_id):
        raise HTTPException(
            status_code=400,
            detail="Invalid trace ID. Please provide a valid UUIDv4.",
        )

    # Only Python test review is supported; other languages are a no-op.
    if data.language != "python":
        return TestgenReviewResponse(reviews=[])

    llm_client: LLMClient = request.app.state.llm_client
    coverage_context = _build_coverage_context(data.coverage_details)

    # Fan out one review per test source; gather preserves input order.
    reviews = await asyncio.gather(
        *(
            _review_single_test(
                llm_client,
                test,
                data.function_name,
                data.function_source_code,
                data.trace_id,
                coverage_context,
                data.coverage_summary,
            )
            for test in data.tests
        )
    )

    return TestgenReviewResponse(reviews=list(reviews))
|
||||
|
||||
|
||||
# ── Repair ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _extract_python_code(content: str) -> str | None:
    """Return the Python code block from LLM output, or None if absent."""
    match = extract_code_block_with_context(content, language="python")
    return None if match is None else match[1]
|
||||
|
||||
|
||||
def _get_syntax_error(code: str) -> str:
    """
    Return a short description of why *code* fails to parse, or "" if valid.

    Besides SyntaxError, ast.parse can raise ValueError for inputs such
    as source containing null bytes (version-dependent); catching it here
    keeps a pathological LLM response from escaping the repair loop as an
    unhandled 500.
    """
    try:
        ast.parse(code)
    except SyntaxError as exc:
        # exc.lineno can be None for some errors; the f-string tolerates it.
        return f"Line {exc.lineno}: {exc.msg}"
    except ValueError as exc:
        return str(exc)
    return ""
|
||||
|
||||
|
||||
@router.post(
    "/ai/testgen_repair",
    response_model=TestRepairResponse,
    responses={
        400: {"model": TestGenErrorResponse},
        422: {"model": TestGenErrorResponse},
        500: {"model": TestGenErrorResponse},
    },
)
async def testgen_repair(
    request: Request,
    data: TestRepairRequest,
    user: Annotated[AuthenticatedUser, Depends(require_auth)],
    _rate: Annotated[None, Depends(check_rate_limit)],
    _usage: Annotated[None, Depends(track_usage)],
) -> TestRepairResponse:
    """
    Repair flagged test functions.

    Renders the repair prompts once, then makes up to _MAX_REPAIR_RETRIES
    LLM calls: a response without a Python code block is simply retried,
    while a response with a syntax error is retried with the broken code
    and the error appended to the conversation. Raises 400 for a bad
    trace id or non-Python language, and 422 when no valid repair is
    produced (including LLM transport failure).
    """
    if not validate_trace_id(data.trace_id):
        raise HTTPException(
            status_code=400,
            detail="Invalid trace ID. Please provide a valid UUIDv4.",
        )

    if data.language != "python":
        raise HTTPException(
            status_code=400,
            detail="Only Python is currently supported.",
        )

    llm_client: LLMClient = request.app.state.llm_client
    coverage_context = _build_coverage_context(
        data.coverage_details,
    )

    system_prompt = REPAIR_SYSTEM_TPL.render()
    user_prompt = REPAIR_USER_TPL.render(
        function_name=data.function_name,
        function_source_code=data.function_source_code,
        test_source=data.test_source,
        functions_to_repair=data.functions_to_repair,
        previous_repair_errors=data.previous_repair_errors,
        coverage_context=coverage_context,
    )

    # Mutated in place across retries: syntax-error feedback is appended
    # below so each subsequent call sees the full conversation.
    messages: list[dict[str, Any]] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    repaired_code: str | None = None

    for attempt in range(_MAX_REPAIR_RETRIES):
        try:
            output = await llm_client.call(
                llm=REPAIR_MODEL,
                messages=messages,
                # Distinguish retries in call accounting/telemetry.
                call_type=(
                    "testgen_repair"
                    if attempt == 0
                    else "testgen_repair_retry"
                ),
                trace_id=data.trace_id,
            )
        except Exception:
            log.exception(
                "Repair LLM call failed: trace_id=%s attempt=%d",
                data.trace_id,
                attempt + 1,
            )
            # A transport-level failure aborts immediately — no retry.
            raise HTTPException(
                status_code=422,
                detail="Could not generate repair. Please try again.",
            ) from None

        repaired_code = _extract_python_code(output.content)
        if repaired_code is None:
            # No code block at all: retry with the unchanged conversation.
            log.warning(
                "No code block in repair response:"
                " trace_id=%s attempt=%d",
                data.trace_id,
                attempt + 1,
            )
            continue

        syntax_err = _get_syntax_error(repaired_code)
        if not syntax_err:
            break  # valid Python — accept this repair

        log.warning(
            "Syntax error in repair: trace_id=%s"
            " attempt=%d error=%s",
            data.trace_id,
            attempt + 1,
            syntax_err,
        )
        # Feed the broken attempt plus the parser error back to the model
        # so the next attempt can correct it.
        messages.append(
            {"role": "assistant", "content": output.content}
        )
        messages.append(
            {
                "role": "user",
                "content": (
                    f"Syntax error:\n```\n{syntax_err}\n```\n"
                    "Please return the complete corrected file"
                    " in a single Python code block."
                ),
            }
        )
        # Invalidate so exhausting the loop on a bad attempt yields 422.
        repaired_code = None

    if repaired_code is None:
        raise HTTPException(
            status_code=422,
            detail=(
                "Could not generate a valid repair"
                " after retries. Please try again."
            ),
        )

    # The client instruments behavior/perf variants itself; all three
    # fields carry the same repaired source here.
    return TestRepairResponse(
        generated_tests=repaired_code,
        instrumented_behavior_tests=repaired_code,
        instrumented_perf_tests=repaired_code,
    )
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
You are a test repair assistant for an AI code optimizer.
|
||||
|
||||
You will receive a test file with specific test functions flagged as problematic, along with the reason each was flagged. Rewrite ONLY the flagged functions. All other code in the file must remain exactly as-is — do not modify imports, class definitions, helper functions, or any unflagged test functions.
|
||||
|
||||
Constraints for repaired functions:
|
||||
- Use diverse, realistic inputs that reflect production usage patterns
|
||||
- Only use the function's public API — do not access internal implementation details
|
||||
- Exercise meaningful code paths
|
||||
- Keep the same function name so test discovery still works
|
||||
- Maintain the same testing framework conventions (pytest, unittest, etc.)
|
||||
- Do NOT add comments, docstrings, or type annotations that weren't in the original
|
||||
- The reason provided for each function describes the specific problem — fix that problem
|
||||
- All string literals must be syntactically valid Python — use triple quotes (`"""..."""`) or escaped newlines (`\n`) for multiline strings. NEVER split a string literal across multiple lines without triple quotes.
|
||||
|
||||
Return the complete file in a single Python code block:
|
||||
```python
|
||||
# complete test file here
|
||||
```
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
<function_under_test name="{{ function_name }}">
|
||||
```python
|
||||
{{ function_source_code }}
|
||||
```
|
||||
</function_under_test>
|
||||
|
||||
<test_file>
|
||||
```python
|
||||
{{ test_source }}
|
||||
```
|
||||
</test_file>
|
||||
|
||||
<functions_to_repair>
|
||||
{% for fn in functions_to_repair -%}
|
||||
- `{{ fn.function_name }}`: {{ fn.reason }}
|
||||
{% endfor %}
|
||||
</functions_to_repair>
|
||||
{% if previous_repair_errors %}
|
||||
|
||||
<previous_repair_errors>
|
||||
A previous repair attempt for these functions was tried but the rewritten tests still failed at runtime. Use these error messages to understand what went wrong and avoid the same mistakes:
|
||||
{% for fn_name, error_msg in previous_repair_errors.items() -%}
|
||||
- `{{ fn_name }}`: {{ error_msg }}
|
||||
{% endfor %}
|
||||
</previous_repair_errors>
|
||||
{% endif %}
|
||||
{% if coverage_context %}
|
||||
|
||||
<coverage_details>
|
||||
{{ coverage_context }}
|
||||
|
||||
When rewriting flagged functions, ensure they exercise the unexecuted lines listed above. Tests that only trigger early-return or disabled code paths do not provide useful coverage.
|
||||
</coverage_details>
|
||||
{% endif %}
|
||||
|
||||
Rewrite only the functions listed above. Return the complete file.
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
You are a test quality reviewer for CodeFlash, an AI code optimizer. CodeFlash uses generated tests to:
|
||||
|
||||
1. **Verify correctness** — run the same tests on both original and optimized code, comparing outputs to catch regressions
|
||||
2. **Benchmark performance** — time the function under test to measure speedup
|
||||
|
||||
Review each test function and flag ones that would compromise correctness verification or produce misleading benchmarks.
|
||||
|
||||
Flag a test function if it has any of these problems:
|
||||
|
||||
1. **Doesn't exercise the function**: The test never calls the target function, or only triggers trivial/no-op code paths (empty input, immediate error return). If coverage information is provided, use it to identify tests that miss the function entirely.
|
||||
|
||||
2. **Non-deterministic behavior**: The test produces different outputs across runs due to randomness, timestamps, or external state. This causes false correctness failures when comparing original vs optimized outputs.
|
||||
|
||||
3. **Internal state coupling**: The test asserts on internal implementation details (private attributes, internal data structures) rather than observable outputs. Optimized code may change internals while preserving behavior, causing false failures.
|
||||
|
||||
4. **Identical-input repetition**: Calling the function with the same arguments multiple times inflates cache hit rates and produces speedups that vanish in production.
|
||||
|
||||
5. **Unrealistic data patterns**: Only a single type or size of input, missing the diversity of real-world usage that would reveal performance regressions.
|
||||
|
||||
Only flag clear problems. If a test function looks reasonable, omit it from the output.
|
||||
|
||||
Respond with a JSON code block containing only the functions that need repair:
|
||||
|
||||
```json
|
||||
{
|
||||
"functions": [
|
||||
{"function_name": "test_example", "verdict": "repair", "reason": "One sentence explaining the specific problem"}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
If all functions are acceptable:
|
||||
|
||||
```json
|
||||
{"functions": []}
|
||||
```
|
||||
|
||||
Keep reasons to one sentence — they are passed directly to the repair LLM as instructions.
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
<function_under_test name="{{ function_name }}">
|
||||
```python
|
||||
{{ function_source_code }}
|
||||
```
|
||||
</function_under_test>
|
||||
|
||||
<generated_test>
|
||||
```python
|
||||
{{ test_source }}
|
||||
```
|
||||
</generated_test>
|
||||
{% if coverage_context %}
|
||||
|
||||
<coverage_details>
|
||||
{{ coverage_context }}
|
||||
</coverage_details>
|
||||
{% elif coverage_summary %}
|
||||
|
||||
<coverage>
|
||||
Test coverage of {{ function_name }}: {{ coverage_summary }}
|
||||
</coverage>
|
||||
{% endif %}
|
||||
{% if failed_note %}
|
||||
|
||||
<note>
|
||||
{{ failed_note }}
|
||||
</note>
|
||||
{% endif %}
|
||||
|
||||
Review each test function in the generated test. Return your verdict as JSON.
|
||||
|
|
@ -1,8 +1,12 @@
|
|||
"""Request and response schemas for POST /ai/testgen."""
|
||||
"""Request and response schemas for testgen endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Self
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
# ── Testgen ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGenRequest(BaseModel):
|
||||
|
|
@ -41,8 +45,149 @@ class TestGenResponse(BaseModel):
|
|||
|
||||
class TestGenErrorResponse(BaseModel):
    """
    Error response from testgen endpoints.
    """

    error: str  # human-readable error message
    trace_id: str | None = None  # echoed back when available
|
||||
|
||||
|
||||
# ── Review ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSourceWithFailures(BaseModel):
    """
    A test source file with optional behavioral test failures.
    """

    test_source: str  # full text of the generated test file
    test_index: int  # position in the request's tests list; echoed in reviews
    # Test function names that failed behavioral runs; auto-flagged for repair.
    failed_test_functions: list[str] = []
    # Failure message per failed function name (Pydantic copies these
    # mutable defaults per instance, so sharing is not a concern).
    failure_messages: dict[str, str] = {}
|
||||
|
||||
|
||||
class CoverageFunctionDetail(BaseModel):
    """
    Per-function coverage detail.
    """

    name: str  # function name as reported by coverage
    coverage: float  # percentage, 0-100
    executed_lines: list[int] = []
    unexecuted_lines: list[int] = []
    # Branches as [source_line, target_line] pairs.
    executed_branches: list[list[int]] = []
    unexecuted_branches: list[list[int]] = []
|
||||
|
||||
|
||||
class CoverageDetails(BaseModel):
    """
    Coverage info for the function under test.
    """

    coverage_percentage: float  # overall coverage achieved, 0-100
    threshold_percentage: float  # target the client considers sufficient
    function_start_line: int = 1  # first line of the function in its module
    main_function: CoverageFunctionDetail  # the function under test
    dependent_function: CoverageFunctionDetail | None = None  # optional helper
|
||||
|
||||
|
||||
class TestgenReviewRequest(BaseModel):
    """
    Request body for POST /ai/testgen_review.
    """

    tests: list[TestSourceWithFailures]  # test sources to review
    function_source_code: str  # source of the function under test
    function_name: str
    trace_id: str  # UUIDv4 correlating this optimization run
    language: str = "python"  # only "python" is reviewed; others no-op
    codeflash_version: str | None = None  # client version string
    call_sequence: int | None = None  # ordinal of this call in the run
    coverage_summary: str = ""  # one-line human-readable coverage summary
    coverage_details: CoverageDetails | None = None  # structured coverage
|
||||
|
||||
|
||||
class FunctionVerdict(BaseModel):
    """
    Verdict for a single test function.
    """

    function_name: str  # name of the test function reviewed
    verdict: str  # currently only "repair" verdicts are surfaced
    reason: str = ""  # one-sentence explanation, fed to the repair prompt
|
||||
|
||||
|
||||
class TestReview(BaseModel):
    """
    Review result for one test source.
    """

    test_index: int  # echoes TestSourceWithFailures.test_index
    functions: list[FunctionVerdict]  # functions flagged for repair
|
||||
|
||||
|
||||
class TestgenReviewResponse(BaseModel):
    """
    Response from POST /ai/testgen_review.
    """

    reviews: list[TestReview]  # one entry per input test source, in order
|
||||
|
||||
|
||||
# ── Repair ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class FunctionToRepair(BaseModel):
    """
    A single function that needs repair with the reason.
    """

    function_name: str  # test function to rewrite
    reason: str  # why it was flagged; passed verbatim to the repair prompt
|
||||
|
||||
|
||||
class TestRepairRequest(BaseModel):
    """
    Request body for POST /ai/testgen_repair.
    """

    test_source: str  # full text of the test file to repair
    functions_to_repair: list[FunctionToRepair]  # flagged functions + reasons
    function_source_code: str  # source of the function under test
    function_name: str
    module_path: str  # path of the module under test
    test_module_path: str  # path of the test file
    test_framework: str  # e.g. "pytest"
    test_timeout: int  # seconds
    trace_id: str  # UUIDv4 correlating this optimization run
    # Deprecated alias pair: language_version fills python_version when the
    # latter is absent (see resolve_python_version below).
    python_version: str | None = None
    language_version: str | None = None
    language: str = "python"  # only "python" is accepted by the endpoint
    codeflash_version: str | None = None  # client version string
    call_sequence: int | None = None  # ordinal of this call in the run
    coverage_details: CoverageDetails | None = None  # structured coverage
    # Runtime errors from a prior repair attempt, keyed by function name.
    previous_repair_errors: dict[str, str] = {}
    module_source_code: str = ""  # optional full module source for context

    @model_validator(mode="after")
    def resolve_python_version(self) -> Self:
        """
        Resolve python_version from language_version.

        Only applies when python_version was not supplied explicitly and
        the request is for Python; other languages keep python_version None.
        """
        if (
            self.python_version is None
            and self.language_version is not None
            and self.language == "python"
        ):
            self.python_version = self.language_version
        return self
|
||||
|
||||
|
||||
class TestRepairResponse(BaseModel):
    """
    Response from POST /ai/testgen_repair.
    """

    # The endpoint currently returns the same repaired source in all three
    # fields; the client performs its own instrumentation.
    generated_tests: str
    instrumented_behavior_tests: str
    instrumented_perf_tests: str
|
||||
|
|
|
|||
556
packages/codeflash-api/tests/test_testgen_review.py
Normal file
556
packages/codeflash-api/tests/test_testgen_review.py
Normal file
|
|
@ -0,0 +1,556 @@
|
|||
"""Tests for testgen review and repair endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
from codeflash_api.testgen._review_router import (
|
||||
_build_coverage_context,
|
||||
_get_syntax_error,
|
||||
_parse_review_verdicts,
|
||||
)
|
||||
from codeflash_api.testgen.schemas import (
|
||||
CoverageDetails,
|
||||
CoverageFunctionDetail,
|
||||
FunctionVerdict,
|
||||
TestgenReviewRequest,
|
||||
TestgenReviewResponse,
|
||||
TestRepairRequest,
|
||||
TestRepairResponse,
|
||||
TestReview,
|
||||
TestSourceWithFailures,
|
||||
)
|
||||
|
||||
|
||||
class TestBuildCoverageContext:
    """Tests for _build_coverage_context."""

    def test_none_returns_empty(self) -> None:
        """
        None coverage details yields empty string.
        """
        assert "" == _build_coverage_context(None)

    def test_formats_main_function(self) -> None:
        """
        Main function coverage is included with unexecuted lines.
        """
        details = CoverageDetails(
            coverage_percentage=80.0,
            threshold_percentage=90.0,
            main_function=CoverageFunctionDetail(
                name="compute",
                coverage=80.0,
                unexecuted_lines=[10, 15, 20],
            ),
        )

        result = _build_coverage_context(details)

        # Coverage renders with :.0f, so 80.0 appears as "80%".
        assert "compute" in result
        assert "80%" in result
        # Unexecuted lines are interpolated via list repr.
        assert "[10, 15, 20]" in result

    def test_includes_dependent_function(self) -> None:
        """
        Dependent function coverage is appended when present.
        """
        details = CoverageDetails(
            coverage_percentage=75.0,
            threshold_percentage=90.0,
            main_function=CoverageFunctionDetail(
                name="main_fn",
                coverage=90.0,
            ),
            dependent_function=CoverageFunctionDetail(
                name="helper_fn",
                coverage=60.0,
                unexecuted_lines=[5, 6],
            ),
        )

        result = _build_coverage_context(details)

        assert "helper_fn" in result
        assert "60%" in result
|
||||
|
||||
|
||||
class TestParseReviewVerdicts:
    """Tests for _parse_review_verdicts."""

    def test_parses_json_object(self) -> None:
        """
        Standard JSON object with functions array is parsed.
        """
        content = json.dumps(
            {
                "functions": [
                    {
                        "function_name": "test_foo",
                        "verdict": "repair",
                        "reason": "bad test",
                    }
                ]
            }
        )

        verdicts = _parse_review_verdicts(content)

        assert 1 == len(verdicts)
        assert "test_foo" == verdicts[0].function_name
        assert "repair" == verdicts[0].verdict

    def test_parses_json_in_code_block(self) -> None:
        """
        JSON inside a markdown code block is extracted.
        """
        # Surrounding prose mimics a chatty LLM response; only the fenced
        # ```json block should be parsed.
        content = (
            'Some text\n```json\n{"functions": ['
            '{"function_name": "test_bar",'
            ' "verdict": "repair", "reason": "x"}'
            "]}\n```\nMore text"
        )

        verdicts = _parse_review_verdicts(content)

        assert 1 == len(verdicts)
        assert "test_bar" == verdicts[0].function_name

    def test_filters_non_repair_verdicts(self) -> None:
        """
        Only repair verdicts are returned.
        """
        content = json.dumps(
            {
                "functions": [
                    {
                        "function_name": "test_ok",
                        "verdict": "pass",
                        "reason": "",
                    },
                    {
                        "function_name": "test_bad",
                        "verdict": "repair",
                        "reason": "broken",
                    },
                ]
            }
        )

        verdicts = _parse_review_verdicts(content)

        assert 1 == len(verdicts)
        assert "test_bad" == verdicts[0].function_name

    def test_empty_functions_returns_empty(self) -> None:
        """
        Empty functions list yields no verdicts.
        """
        content = '{"functions": []}'

        assert [] == _parse_review_verdicts(content)

    def test_invalid_json_returns_empty(self) -> None:
        """
        Unparseable JSON returns empty list.
        """
        # Parse failures are swallowed by design — the review degrades
        # gracefully instead of raising.
        assert [] == _parse_review_verdicts("not json at all")

    def test_parses_array_directly(self) -> None:
        """
        A raw JSON array is accepted as the functions list.
        """
        content = json.dumps(
            [
                {
                    "function_name": "test_x",
                    "verdict": "repair",
                    "reason": "r",
                }
            ]
        )

        verdicts = _parse_review_verdicts(content)

        assert 1 == len(verdicts)
|
||||
|
||||
|
||||
class TestGetSyntaxError:
    """Tests for _get_syntax_error."""

    def test_valid_code(self) -> None:
        """
        Valid Python returns empty string.
        """
        assert "" == _get_syntax_error("x = 1\n")

    def test_syntax_error(self) -> None:
        """
        Invalid Python returns error description.
        """
        result = _get_syntax_error("def f(\n")

        # Messages are formatted as "Line <n>: <msg>".
        assert "Line" in result
|
||||
|
||||
|
||||
class TestTestgenReviewEndpoint:
    """Integration tests for POST /ai/testgen_review.

    The `client` and `mock_llm_client` fixtures come from conftest —
    presumably an httpx test client wired to the app with the LLM client
    mocked on app state; verify against conftest if behavior changes.
    """

    async def test_success(self, client, mock_llm_client) -> None:  # type: ignore[no-untyped-def]
        """
        Successful review returns verdicts.
        """
        mock_llm_client.call = AsyncMock(
            return_value=MagicMock(
                content='{"functions": []}',
                cost=0.01,
            ),
        )

        resp = await client.post(
            "/ai/testgen_review",
            json={
                "tests": [
                    {
                        "test_source": "def test_a(): pass",
                        "test_index": 0,
                    }
                ],
                "function_source_code": "def f(): return 1",
                "function_name": "f",
                "trace_id": "12345678-1234-4000-8000-000000000000",
            },
        )

        assert 200 == resp.status_code
        data = resp.json()
        assert 1 == len(data["reviews"])
        assert 0 == data["reviews"][0]["test_index"]

    async def test_invalid_trace_id(self, client) -> None:  # type: ignore[no-untyped-def]
        """
        Invalid trace_id returns 400.
        """
        resp = await client.post(
            "/ai/testgen_review",
            json={
                "tests": [],
                "function_source_code": "def f(): pass",
                "function_name": "f",
                "trace_id": "bad-id",
            },
        )

        assert 400 == resp.status_code

    async def test_non_python_returns_empty(
        self, client
    ) -> None:  # type: ignore[no-untyped-def]
        """
        Non-Python language returns empty reviews without LLM call.
        """
        # No mock_llm_client fixture here: the endpoint must short-circuit
        # before touching the LLM.
        resp = await client.post(
            "/ai/testgen_review",
            json={
                "tests": [
                    {
                        "test_source": "test code",
                        "test_index": 0,
                    }
                ],
                "function_source_code": "fn f() {}",
                "function_name": "f",
                "trace_id": "12345678-1234-4000-8000-000000000000",
                "language": "rust",
            },
        )

        assert 200 == resp.status_code
        assert [] == resp.json()["reviews"]

    async def test_failed_tests_marked_for_repair(
        self, client, mock_llm_client
    ) -> None:  # type: ignore[no-untyped-def]
        """
        Tests with behavioral failures are auto-flagged for repair.
        """
        # LLM returns no verdicts; the repair flag must come from the
        # failed_test_functions input alone.
        mock_llm_client.call = AsyncMock(
            return_value=MagicMock(
                content='{"functions": []}',
                cost=0.01,
            ),
        )

        resp = await client.post(
            "/ai/testgen_review",
            json={
                "tests": [
                    {
                        "test_source": "def test_a(): assert False",
                        "test_index": 0,
                        "failed_test_functions": ["test_a"],
                        "failure_messages": {
                            "test_a": "AssertionError"
                        },
                    }
                ],
                "function_source_code": "def f(): return 1",
                "function_name": "f",
                "trace_id": "12345678-1234-4000-8000-000000000000",
            },
        )

        assert 200 == resp.status_code
        reviews = resp.json()["reviews"]
        assert 1 == len(reviews)
        fns = reviews[0]["functions"]
        assert 1 == len(fns)
        assert "test_a" == fns[0]["function_name"]
        assert "repair" == fns[0]["verdict"]
|
||||
|
||||
|
||||
class TestTestgenRepairEndpoint:
    """Integration tests for POST /ai/testgen_repair."""

    # Well-formed UUIDv4 accepted by the endpoint's trace_id validation.
    _TRACE_ID = "12345678-1234-4000-8000-000000000000"

    @classmethod
    def _payload(cls, **overrides):  # type: ignore[no-untyped-def]
        """Return a baseline repair request body with *overrides* merged on top."""
        body = {
            "test_source": "def test_a(): pass",
            "functions_to_repair": [
                {"function_name": "test_a", "reason": "x"}
            ],
            "function_source_code": "def f(): return 1",
            "function_name": "f",
            "module_path": "m.py",
            "test_module_path": "t.py",
            "test_framework": "pytest",
            "test_timeout": 60,
            "trace_id": cls._TRACE_ID,
        }
        body.update(overrides)
        return body

    async def test_success(self, client, mock_llm_client) -> None:  # type: ignore[no-untyped-def]
        """A clean LLM response yields the repaired code in generated_tests."""
        repaired = "def test_a(): assert True"
        mock_llm_client.call = AsyncMock(
            return_value=MagicMock(
                content=f"```python\n{repaired}\n```",
                cost=0.01,
            ),
        )

        resp = await client.post(
            "/ai/testgen_repair",
            json=self._payload(
                test_source="def test_a(): assert False",
                functions_to_repair=[
                    {"function_name": "test_a", "reason": "Always fails"}
                ],
                module_path="src/mod.py",
                test_module_path="tests/test_mod.py",
            ),
        )

        assert resp.status_code == 200
        assert repaired in resp.json()["generated_tests"]

    async def test_invalid_trace_id(self, client) -> None:  # type: ignore[no-untyped-def]
        """A malformed trace_id is rejected with HTTP 400."""
        resp = await client.post(
            "/ai/testgen_repair",
            json=self._payload(
                test_source="code",
                functions_to_repair=[],
                function_source_code="def f(): pass",
                trace_id="bad",
            ),
        )

        assert resp.status_code == 400

    async def test_non_python_returns_400(self, client) -> None:  # type: ignore[no-untyped-def]
        """A non-Python language is rejected with HTTP 400."""
        resp = await client.post(
            "/ai/testgen_repair",
            json=self._payload(
                test_source="code",
                functions_to_repair=[],
                function_source_code="fn f() {}",
                module_path="m.rs",
                test_module_path="t.rs",
                test_framework="cargo",
                language="rust",
            ),
        )

        assert resp.status_code == 400

    async def test_syntax_error_triggers_retry(
        self, client, mock_llm_client
    ) -> None:  # type: ignore[no-untyped-def]
        """An unparsable first LLM answer causes exactly one retry call."""
        bad_code = "def test_a(:\n pass"
        good_code = "def test_a():\n pass"
        # First response is a syntax error, second parses cleanly.
        mock_llm_client.call = AsyncMock(
            side_effect=[
                MagicMock(content=f"```python\n{bad_code}\n```", cost=0.01),
                MagicMock(content=f"```python\n{good_code}\n```", cost=0.01),
            ],
        )

        resp = await client.post(
            "/ai/testgen_repair",
            json=self._payload(
                test_source="def test_a(): assert False",
                functions_to_repair=[
                    {"function_name": "test_a", "reason": "broken"}
                ],
            ),
        )

        assert resp.status_code == 200
        assert good_code in resp.json()["generated_tests"]
        assert mock_llm_client.call.await_count == 2

    async def test_all_retries_fail_returns_422(
        self, client, mock_llm_client
    ) -> None:  # type: ignore[no-untyped-def]
        """When every retry yields a syntax error, the endpoint answers 422."""
        bad_code = "def test_a(:\n pass"
        mock_llm_client.call = AsyncMock(
            return_value=MagicMock(
                content=f"```python\n{bad_code}\n```",
                cost=0.01,
            ),
        )

        resp = await client.post("/ai/testgen_repair", json=self._payload())

        assert resp.status_code == 422

    async def test_llm_failure_returns_422(
        self, client, mock_llm_client
    ) -> None:  # type: ignore[no-untyped-def]
        """An exception raised by the LLM client surfaces as HTTP 422."""
        mock_llm_client.call = AsyncMock(
            side_effect=RuntimeError("LLM down"),
        )

        resp = await client.post("/ai/testgen_repair", json=self._payload())

        assert resp.status_code == 422
|
||||
|
||||
|
||||
class TestTestRepairRequestValidator:
    """Tests for TestRepairRequest model validator."""

    @staticmethod
    def _build(**extra):  # type: ignore[no-untyped-def]
        """Construct a minimal TestRepairRequest with *extra* kwargs layered on."""
        kwargs = {
            "test_source": "code",
            "functions_to_repair": [],
            "function_source_code": "def f(): pass",
            "function_name": "f",
            "module_path": "m.py",
            "test_module_path": "t.py",
            "test_framework": "pytest",
            "test_timeout": 60,
            "trace_id": "abc",
        }
        kwargs.update(extra)
        return TestRepairRequest(**kwargs)

    def test_resolves_python_version_from_language_version(
        self,
    ) -> None:
        """When python_version is absent, it is filled in from language_version."""
        req = self._build(language_version="3.12.1")

        assert req.python_version == "3.12.1"

    def test_does_not_override_explicit_python_version(
        self,
    ) -> None:
        """An explicitly supplied python_version wins over language_version."""
        req = self._build(
            python_version="3.11.0",
            language_version="3.12.1",
        )

        assert req.python_version == "3.11.0"
|
||||
|
|
@ -107,6 +107,9 @@ ignore = [
|
|||
"packages/codeflash-api/src/codeflash_api/observability/_recording.py" = [
|
||||
"PLR0913", # recording functions faithfully match Django signatures
|
||||
]
|
||||
"packages/codeflash-api/src/codeflash_api/testgen/_review_router.py" = [
|
||||
"PLR0913", # _review_single_test needs all context params for prompt assembly
|
||||
]
|
||||
"packages/codeflash-api/src/codeflash_api/optimize/_context.py" = [
|
||||
"PLR2004", # magic values in faithfully ported version parsing and humanize_ns
|
||||
]
|
||||
|
|
|
|||
Loading…
Reference in a new issue