feat: per-function test review + repair endpoints

Add POST /ai/testgen_review and POST /ai/testgen_repair endpoints.
The review endpoint accepts per-test data with pre-flagged behavioral
failures, has the AI review the passing functions for unrealistic
patterns, and returns per-function verdicts. The repair endpoint takes
the flagged functions, has the LLM rewrite them, re-instruments the
result, and returns the repaired test source. Both endpoints are gated
to Python only.
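
A minimal sketch of a review call, for reference. The host, auth header,
and payload values below are illustrative only; the request body mirrors
TestgenReviewSchema.

```python
# Illustrative client call; any HTTP client works. Host, token, and values are placeholders.
import requests

payload = {
    "tests": [
        {
            "test_source": "def test_add():\n    assert add(2, 3) == 5\n",
            "test_index": 0,
            "failed_test_functions": [],  # behavioral failures pre-flagged by the caller
        }
    ],
    "function_source_code": "def add(a, b):\n    return a + b\n",
    "function_name": "add",
    "trace_id": "trace-123",
    "language": "python",
}

resp = requests.post(
    "https://aiservice.example.com/ai/testgen_review/",  # placeholder host
    json=payload,
    headers={"Authorization": "Bearer <token>"},  # placeholder credentials
    timeout=60,
)
resp.raise_for_status()
print(resp.json())  # {"reviews": [{"test_index": 0, "functions": [...]}]}
```
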
Kevin Turcios 2026-03-02 06:28:56 -05:00
parent 9d6799a87f
commit 87ab144d40
9 changed files with 445 additions and 0 deletions

@@ -30,6 +30,8 @@ from core.languages.python.optimizer.refinement import refinement_api
from core.log_features.log_features import features_api
from core.shared.optimizer_router import optimize_api
from core.shared.ranker.ranker import ranker_api
from core.shared.testgen_review.repair import testgen_repair_api
from core.shared.testgen_review.review import testgen_review_api
from core.shared.testgen_router import testgen_api
from core.shared.workflow_gen.workflow_gen import workflow_gen_api
@@ -37,6 +39,8 @@ urlpatterns = [
path("ai/optimize", optimize_api.urls),
path("ai/optimize-line-profiler", optimize_line_profiler_api.urls),
path("ai/testgen", testgen_api.urls),
path("ai/testgen_review", testgen_review_api.urls),
path("ai/testgen_repair", testgen_repair_api.urls),
path("ai/log_features", features_api.urls),
path("ai/refinement", refinement_api.urls),
path("ai/explain", explanations_api.urls),

@@ -0,0 +1,78 @@
"""Request/response schemas for the testgen review and repair endpoints."""
from ninja import Schema
from aiservice.models.functions_to_optimize import FunctionToOptimize
# ── Review ──────────────────────────────────────────────────────────
class TestSourceWithFailures(Schema):
test_source: str
test_index: int
failed_test_functions: list[str] = [] # functions that failed behavioral tests
class TestgenReviewSchema(Schema):
tests: list[TestSourceWithFailures]
function_source_code: str
function_name: str
trace_id: str
language: str = "python"
codeflash_version: str | None = None
call_sequence: int | None = None
class FunctionVerdict(Schema):
function_name: str
verdict: str # "pass" or "repair"
reason: str = ""
class TestReview(Schema):
test_index: int
functions: list[FunctionVerdict]
class TestgenReviewResponseSchema(Schema):
reviews: list[TestReview]
class TestgenReviewErrorSchema(Schema):
error: str
# ── Repair ──────────────────────────────────────────────────────────
class FunctionToRepair(Schema):
function_name: str
reason: str
class TestRepairSchema(Schema):
test_source: str
functions_to_repair: list[FunctionToRepair]
function_source_code: str
function_to_optimize: FunctionToOptimize
helper_function_names: list[str] | None = None
module_path: str
test_module_path: str
test_framework: str
test_timeout: int
trace_id: str
python_version: str | None = None
language: str = "python"
codeflash_version: str | None = None
call_sequence: int | None = None
class TestRepairResponseSchema(Schema):
generated_tests: str
instrumented_behavior_tests: str
instrumented_perf_tests: str
class TestRepairErrorSchema(Schema):
error: str
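
Not part of the diff: a hypothetical review result assembled from these
schemas, showing one TestReview per submitted test with only the flagged
functions listed. Field values are illustrative.

```python
# Hypothetical example of the review response shape (values are illustrative).
from core.shared.testgen_review.models import (
    FunctionVerdict,
    TestgenReviewResponseSchema,
    TestReview,
)

response = TestgenReviewResponseSchema(
    reviews=[
        TestReview(
            test_index=0,
            functions=[
                FunctionVerdict(
                    function_name="test_add_repeated_identical_calls",
                    verdict="repair",
                    reason="Calls the function with identical arguments in a loop, inflating cache hits",
                ),
            ],
        ),
    ]
)
print(response.dict())  # or .model_dump() on pydantic v2
```
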

@@ -0,0 +1,17 @@
You are a test repair assistant for an AI code optimizer.
You will receive a test file with specific test functions flagged as problematic, along with the reason each was flagged. Rewrite ONLY the flagged functions. All other code in the file must remain exactly as-is — do not modify imports, class definitions, helper functions, or any unflagged test functions.
Constraints for repaired functions:
- Use diverse, realistic inputs that reflect production usage patterns
- Only use the function's public API — do not access internal implementation details
- Exercise meaningful code paths
- Keep the same function name so test discovery still works
- Maintain the same testing framework conventions (pytest, unittest, etc.)
- Do NOT add comments, docstrings, or type annotations that weren't in the original
- The reason provided for each function describes the specific problem — fix that problem
Return the complete file in a single Python code block:
```python
# complete test file here
```
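
For illustration only (not part of the prompt file): a hypothetical
flagged function and a plausible repaired form under the constraints
above, assuming an imagined fib function under test.

```python
# Hypothetical before/after. Flagged: identical-input repetition dominated by caching.
def test_fib_repeated():
    for _ in range(100):
        assert fib(20) == 6765

# Repaired: same name, public API only, diverse realistic inputs, meaningful code paths.
def test_fib_repeated():
    expected = {0: 0, 1: 1, 5: 5, 10: 55, 20: 6765, 30: 832040}
    for n, value in expected.items():
        assert fib(n) == value
```
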

@@ -0,0 +1,19 @@
<function_under_test name="{{ function_name }}">
```python
{{ function_source_code }}
```
</function_under_test>
<test_file>
```python
{{ test_source }}
```
</test_file>
<functions_to_repair>
{% for fn in functions_to_repair -%}
- `{{ fn.function_name }}`: {{ fn.reason }}
{% endfor %}
</functions_to_repair>
Rewrite only the functions listed above. Return the complete file.
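
A quick sketch of rendering this template the way the repair endpoint
does (jinja2 with StrictUndefined and keep_trailing_newline); the
function and reason values are illustrative.

```python
from jinja2 import Environment, FileSystemLoader, StrictUndefined

env = Environment(
    loader=FileSystemLoader("prompts"),  # directory holding this template
    keep_trailing_newline=True,
    undefined=StrictUndefined,
)
prompt = env.get_template("testgen_repair_user_prompt.md.j2").render(
    function_name="add",
    function_source_code="def add(a, b):\n    return a + b\n",
    test_source="def test_add():\n    assert add(1, 1) == 2\n",
    functions_to_repair=[
        {"function_name": "test_add", "reason": "Only exercises a single trivial input"}
    ],
)
print(prompt)
```
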

@@ -0,0 +1,31 @@
You are a test quality reviewer for an AI code optimizer. You review generated test functions and flag ones that would produce misleading performance benchmarks.
Flag a test function if it has any of these problems:
1. **Identical-input repetition**: Calling the function with the same arguments multiple times. This inflates cache hit rates and produces fake speedups that vanish in production.
2. **Internal state manipulation**: Directly accessing internal implementation details instead of using the function's public API. This tests implementation details, not behavior.
3. **Trivial / no-op inputs**: Inputs that cause the function to return immediately without exercising meaningful code paths.
4. **Unrealistic data patterns**: Only a single type or size of input, missing the diversity of real-world usage.
Only flag clear problems. If a test function looks reasonable, omit it from the output.
Respond with a JSON code block containing only the functions that need repair:
```json
{
"functions": [
{"function_name": "test_example", "verdict": "repair", "reason": "One sentence explaining the specific problem"}
]
}
```
If all functions are acceptable:
```json
{"functions": []}
```
Keep reasons to one sentence — they are passed directly to the repair LLM as instructions.
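
Hypothetical examples of the first two patterns (not part of the prompt
file), using an imagined parse_config function.

```python
# Pattern 1: identical-input repetition, so timing is dominated by warm caches.
def test_parse_config_repeated():
    for _ in range(1000):
        assert parse_config("host=localhost") == {"host": "localhost"}

# Pattern 2: internal state manipulation, reaching past the public API.
def test_parse_config_cache_internals():
    parse_config._cache.clear()  # touches a private attribute of the implementation
    assert parse_config("host=localhost") == {"host": "localhost"}
```
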

@@ -0,0 +1,19 @@
<function_under_test name="{{ function_name }}">
```python
{{ function_source_code }}
```
</function_under_test>
<generated_test>
```python
{{ test_source }}
```
</generated_test>
{% if failed_note %}
<note>
{{ failed_note }}
</note>
{% endif %}
Review each test function in the generated test. Return your verdict as JSON.

@@ -0,0 +1,118 @@
"""Test repair endpoint.
Takes a generated test file with specific flagged functions and returns a
repaired version with those functions rewritten.
Currently enabled for Python only.
"""
from __future__ import annotations
import asyncio
import logging
from pathlib import Path
from typing import Any
import sentry_sdk
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from ninja import NinjaAPI
from openai.types.chat import (
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionUserMessageParam,
)
from aiservice.analytics.posthog import ph
from aiservice.common.markdown_utils import extract_code_block_with_context
from aiservice.llm import EXECUTE_MODEL, calculate_llm_cost, call_llm
from authapp.auth import AuthenticatedRequest
from core.shared.testgen_review.models import TestRepairErrorSchema, TestRepairResponseSchema, TestRepairSchema
testgen_repair_api = NinjaAPI(urls_namespace="testgen_repair")
_prompts_dir = Path(__file__).parent / "prompts"
_jinja_env = Environment(loader=FileSystemLoader(_prompts_dir), keep_trailing_newline=True, undefined=StrictUndefined) # noqa: S701
@testgen_repair_api.post(
"/", response={200: TestRepairResponseSchema, 400: TestRepairErrorSchema, 500: TestRepairErrorSchema}
)
async def testgen_repair(
request: AuthenticatedRequest, data: TestRepairSchema
) -> tuple[int, TestRepairResponseSchema | TestRepairErrorSchema]:
if data.language != "python":
return 400, TestRepairErrorSchema(error="Test repair is only supported for Python")
await asyncio.to_thread(ph, request.user, "aiservice-testgen-repair-called")
try:
system_prompt = _jinja_env.get_template("testgen_repair_system_prompt.md.j2").render()
user_prompt = _jinja_env.get_template("testgen_repair_user_prompt.md.j2").render(
function_name=data.function_to_optimize.function_name,
function_source_code=data.function_source_code,
test_source=data.test_source,
functions_to_repair=data.functions_to_repair,
)
messages: list[ChatCompletionMessageParam] = [
ChatCompletionSystemMessageParam(role="system", content=system_prompt),
ChatCompletionUserMessageParam(role="user", content=user_prompt),
]
obs_context: dict[str, Any] = {}
if data.call_sequence is not None:
obs_context["call_sequence"] = data.call_sequence
response = await call_llm(
llm=EXECUTE_MODEL,
messages=messages,
call_type="testgen_repair",
trace_id=data.trace_id,
user_id=request.user,
context=obs_context,
)
cost = calculate_llm_cost(response.raw_response, EXECUTE_MODEL)
logging.debug(f"testgen_repair LLM cost: {cost}")
repair_text = response.content.strip()
if result := extract_code_block_with_context(repair_text, language="python"):
_before, repaired_code, _after = result
else:
return 500, TestRepairErrorSchema(error="Could not extract repaired code from LLM response")
# Re-instrument the repaired tests
from core.languages.python.testgen.validate import instrument_tests, validate_request_data # noqa: PLC0415
from core.shared.testgen_models import TestGenSchema # noqa: PLC0415
testgen_data = TestGenSchema(
source_code_being_tested=data.function_source_code,
function_to_optimize=data.function_to_optimize,
helper_function_names=data.helper_function_names,
module_path=data.module_path,
test_module_path=data.test_module_path,
test_framework=data.test_framework,
test_timeout=data.test_timeout,
trace_id=data.trace_id,
python_version=data.python_version,
language=data.language,
codeflash_version=data.codeflash_version,
)
python_version, _ctx = validate_request_data(testgen_data)
instrumented_behavior, instrumented_perf = instrument_tests(repaired_code, testgen_data, python_version)
if instrumented_behavior is None or instrumented_perf is None:
return 500, TestRepairErrorSchema(error="Failed to instrument repaired tests")
await asyncio.to_thread(ph, request.user, "aiservice-testgen-repair-completed")
return 200, TestRepairResponseSchema(
generated_tests=repaired_code,
instrumented_behavior_tests=instrumented_behavior,
instrumented_perf_tests=instrumented_perf,
)
except Exception as e:
logging.exception("Error in testgen_repair")
sentry_sdk.capture_exception(e)
return 500, TestRepairErrorSchema(error="Internal server error")

@@ -0,0 +1,159 @@
"""Test quality review endpoint.
Reviews AI-generated test sources for unrealistic patterns (cache warm-up,
internal state manipulation, identical inputs, etc.) and returns per-function
verdicts of "pass" or "repair" with reasons.
Currently enabled for Python only; other languages return all-pass immediately.
"""
from __future__ import annotations
import asyncio
import json
import logging
from pathlib import Path
from typing import Any
import sentry_sdk
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from ninja import NinjaAPI
from openai.types.chat import (
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionUserMessageParam,
)
from aiservice.analytics.posthog import ph
from aiservice.common.markdown_utils import extract_code_block_with_context
from aiservice.llm import HAIKU_MODEL, calculate_llm_cost, call_llm
from authapp.auth import AuthenticatedRequest
from core.shared.testgen_review.models import (
FunctionVerdict,
TestgenReviewErrorSchema,
TestgenReviewResponseSchema,
TestgenReviewSchema,
TestReview,
)
testgen_review_api = NinjaAPI(urls_namespace="testgen_review")
_prompts_dir = Path(__file__).parent / "prompts"
_jinja_env = Environment(loader=FileSystemLoader(_prompts_dir), keep_trailing_newline=True, undefined=StrictUndefined) # noqa: S701
@testgen_review_api.post(
"/", response={200: TestgenReviewResponseSchema, 400: TestgenReviewErrorSchema, 500: TestgenReviewErrorSchema}
)
async def testgen_review(
request: AuthenticatedRequest, data: TestgenReviewSchema
) -> tuple[int, TestgenReviewResponseSchema | TestgenReviewErrorSchema]:
# Only enabled for Python for now
if data.language != "python":
return 200, TestgenReviewResponseSchema(reviews=[])
await asyncio.to_thread(ph, request.user, "aiservice-testgen-review-called")
try:
reviews: list[TestReview] = []
for test_entry in data.tests:
review = await _review_single_test(
test_source=test_entry.test_source,
test_index=test_entry.test_index,
failed_test_functions=test_entry.failed_test_functions,
function_source_code=data.function_source_code,
function_name=data.function_name,
trace_id=data.trace_id,
user_id=request.user,
call_sequence=data.call_sequence,
)
reviews.append(review)
await asyncio.to_thread(
ph, request.user, "aiservice-testgen-review-completed", properties={"trace_id": data.trace_id}
)
return 200, TestgenReviewResponseSchema(reviews=reviews)
except Exception as e:
logging.exception("Error in testgen_review")
sentry_sdk.capture_exception(e)
return 500, TestgenReviewErrorSchema(error="Internal server error")
async def _review_single_test(
test_source: str,
test_index: int,
failed_test_functions: list[str],
function_source_code: str,
function_name: str,
trace_id: str,
user_id: str,
call_sequence: int | None,
) -> TestReview:
# Functions that failed behavioral tests are definitively bad — no AI review needed
failed_verdicts = [
FunctionVerdict(function_name=fn, verdict="repair", reason="Failed behavioral test against original code")
for fn in failed_test_functions
]
# Build prompt for AI review of passing functions
failed_note = ""
if failed_test_functions:
failed_list = ", ".join(f"`{fn}`" for fn in failed_test_functions)
failed_note = (
f"Note: The following functions already failed behavioral tests and will be repaired separately: "
f"{failed_list}. Do NOT include them in your review — only review the remaining functions."
)
system_prompt = _jinja_env.get_template("testgen_review_system_prompt.md.j2").render()
user_prompt = _jinja_env.get_template("testgen_review_user_prompt.md.j2").render(
function_name=function_name,
function_source_code=function_source_code,
test_source=test_source,
failed_note=failed_note,
)
messages: list[ChatCompletionMessageParam] = [
ChatCompletionSystemMessageParam(role="system", content=system_prompt),
ChatCompletionUserMessageParam(role="user", content=user_prompt),
]
obs_context: dict[str, Any] = {}
if call_sequence is not None:
obs_context["call_sequence"] = call_sequence
response = await call_llm(
llm=HAIKU_MODEL,
messages=messages,
call_type="testgen_review",
trace_id=trace_id,
user_id=user_id,
context=obs_context,
)
cost = calculate_llm_cost(response.raw_response, HAIKU_MODEL)
logging.debug(f"testgen_review LLM cost: {cost}")
ai_verdicts = _parse_review_response(response.content.strip())
return TestReview(test_index=test_index, functions=failed_verdicts + ai_verdicts)
def _parse_review_response(review_text: str) -> list[FunctionVerdict]:
if result := extract_code_block_with_context(review_text, language="json"):
_before, json_content, _after = result
try:
data = json.loads(json_content.strip())
# Support both formats: {"functions": [...]} or [...]
fn_list = data if isinstance(data, list) else data.get("functions", [])
return [
FunctionVerdict(
function_name=fn.get("function_name", ""),
verdict=fn.get("verdict", "pass"),
reason=fn.get("reason", ""),
)
for fn in fn_list
if fn.get("verdict") == "repair"
]
except (json.JSONDecodeError, AttributeError):
logging.warning("testgen_review: failed to parse AI response JSON")
return []
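
As a quick sanity check of the two response shapes this parser accepts
(a sketch; it assumes extract_code_block_with_context pulls the fenced
JSON block out of the text, as the endpoint relies on):

```python
from core.shared.testgen_review.review import _parse_review_response

wrapped = (
    "```json\n"
    '{"functions": [{"function_name": "test_x", "verdict": "repair", "reason": "Identical inputs"}]}\n'
    "```"
)
bare = (
    "```json\n"
    '[{"function_name": "test_y", "verdict": "pass"}]\n'
    "```"
)

print(_parse_review_response(wrapped))  # one FunctionVerdict flagged for repair
print(_parse_review_response(bare))     # [] since "pass" verdicts are filtered out
```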