feat: per-function test review + repair endpoints
Add POST /ai/testgen_review and POST /ai/testgen_repair endpoints. Review accepts per-test data with pre-flagged behavioral failures, AI reviews passing functions for unrealistic patterns, returns per-function verdicts. Repair takes flagged functions, LLM rewrites them, re-instruments, returns repaired test source. Python-only gate.
This commit is contained in:
parent
9d6799a87f
commit
87ab144d40
9 changed files with 445 additions and 0 deletions
|
|
@ -30,6 +30,8 @@ from core.languages.python.optimizer.refinement import refinement_api
|
|||
from core.log_features.log_features import features_api
|
||||
from core.shared.optimizer_router import optimize_api
|
||||
from core.shared.ranker.ranker import ranker_api
|
||||
from core.shared.testgen_review.repair import testgen_repair_api
|
||||
from core.shared.testgen_review.review import testgen_review_api
|
||||
from core.shared.testgen_router import testgen_api
|
||||
from core.shared.workflow_gen.workflow_gen import workflow_gen_api
|
||||
|
||||
|
|
@ -37,6 +39,8 @@ urlpatterns = [
|
|||
path("ai/optimize", optimize_api.urls),
|
||||
path("ai/optimize-line-profiler", optimize_line_profiler_api.urls),
|
||||
path("ai/testgen", testgen_api.urls),
|
||||
path("ai/testgen_review", testgen_review_api.urls),
|
||||
path("ai/testgen_repair", testgen_repair_api.urls),
|
||||
path("ai/log_features", features_api.urls),
|
||||
path("ai/refinement", refinement_api.urls),
|
||||
path("ai/explain", explanations_api.urls),
|
||||
|
|
|
|||
0
django/aiservice/core/shared/testgen_review/__init__.py
Normal file
0
django/aiservice/core/shared/testgen_review/__init__.py
Normal file
78
django/aiservice/core/shared/testgen_review/models.py
Normal file
78
django/aiservice/core/shared/testgen_review/models.py
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
"""Request/response schemas for the testgen review and repair endpoints."""
|
||||
|
||||
from ninja import Schema
|
||||
|
||||
from aiservice.models.functions_to_optimize import FunctionToOptimize
|
||||
|
||||
|
||||
# ── Review ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSourceWithFailures(Schema):
|
||||
test_source: str
|
||||
test_index: int
|
||||
failed_test_functions: list[str] = [] # functions that failed behavioral tests
|
||||
|
||||
|
||||
class TestgenReviewSchema(Schema):
|
||||
tests: list[TestSourceWithFailures]
|
||||
function_source_code: str
|
||||
function_name: str
|
||||
trace_id: str
|
||||
language: str = "python"
|
||||
codeflash_version: str | None = None
|
||||
call_sequence: int | None = None
|
||||
|
||||
|
||||
class FunctionVerdict(Schema):
|
||||
function_name: str
|
||||
verdict: str # "pass" or "repair"
|
||||
reason: str = ""
|
||||
|
||||
|
||||
class TestReview(Schema):
|
||||
test_index: int
|
||||
functions: list[FunctionVerdict]
|
||||
|
||||
|
||||
class TestgenReviewResponseSchema(Schema):
|
||||
reviews: list[TestReview]
|
||||
|
||||
|
||||
class TestgenReviewErrorSchema(Schema):
|
||||
error: str
|
||||
|
||||
|
||||
# ── Repair ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class FunctionToRepair(Schema):
|
||||
function_name: str
|
||||
reason: str
|
||||
|
||||
|
||||
class TestRepairSchema(Schema):
|
||||
test_source: str
|
||||
functions_to_repair: list[FunctionToRepair]
|
||||
function_source_code: str
|
||||
function_to_optimize: FunctionToOptimize
|
||||
helper_function_names: list[str] | None = None
|
||||
module_path: str
|
||||
test_module_path: str
|
||||
test_framework: str
|
||||
test_timeout: int
|
||||
trace_id: str
|
||||
python_version: str | None = None
|
||||
language: str = "python"
|
||||
codeflash_version: str | None = None
|
||||
call_sequence: int | None = None
|
||||
|
||||
|
||||
class TestRepairResponseSchema(Schema):
|
||||
generated_tests: str
|
||||
instrumented_behavior_tests: str
|
||||
instrumented_perf_tests: str
|
||||
|
||||
|
||||
class TestRepairErrorSchema(Schema):
|
||||
error: str
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
You are a test repair assistant for an AI code optimizer.
|
||||
|
||||
You will receive a test file with specific test functions flagged as problematic, along with the reason each was flagged. Rewrite ONLY the flagged functions. All other code in the file must remain exactly as-is — do not modify imports, class definitions, helper functions, or any unflagged test functions.
|
||||
|
||||
Constraints for repaired functions:
|
||||
- Use diverse, realistic inputs that reflect production usage patterns
|
||||
- Only use the function's public API — do not access internal implementation details
|
||||
- Exercise meaningful code paths
|
||||
- Keep the same function name so test discovery still works
|
||||
- Maintain the same testing framework conventions (pytest, unittest, etc.)
|
||||
- Do NOT add comments, docstrings, or type annotations that weren't in the original
|
||||
- The reason provided for each function describes the specific problem — fix that problem
|
||||
|
||||
Return the complete file in a single Python code block:
|
||||
```python
|
||||
# complete test file here
|
||||
```
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
<function_under_test name="{{ function_name }}">
|
||||
```python
|
||||
{{ function_source_code }}
|
||||
```
|
||||
</function_under_test>
|
||||
|
||||
<test_file>
|
||||
```python
|
||||
{{ test_source }}
|
||||
```
|
||||
</test_file>
|
||||
|
||||
<functions_to_repair>
|
||||
{% for fn in functions_to_repair -%}
|
||||
- `{{ fn.function_name }}`: {{ fn.reason }}
|
||||
{% endfor %}
|
||||
</functions_to_repair>
|
||||
|
||||
Rewrite only the functions listed above. Return the complete file.
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
You are a test quality reviewer for an AI code optimizer. You review generated test functions and flag ones that would produce misleading performance benchmarks.
|
||||
|
||||
Flag a test function if it has any of these problems:
|
||||
|
||||
1. **Identical-input repetition**: Calling the function with the same arguments multiple times. This inflates cache hit rates and produces fake speedups that vanish in production.
|
||||
|
||||
2. **Internal state manipulation**: Directly accessing internal implementation details instead of using the function's public API. This tests implementation details, not behavior.
|
||||
|
||||
3. **Trivial / no-op inputs**: Inputs that cause the function to return immediately without exercising meaningful code paths.
|
||||
|
||||
4. **Unrealistic data patterns**: Only a single type or size of input, missing the diversity of real-world usage.
|
||||
|
||||
Only flag clear problems. If a test function looks reasonable, omit it from the output.
|
||||
|
||||
Respond with a JSON code block containing only the functions that need repair:
|
||||
|
||||
```json
|
||||
{
|
||||
"functions": [
|
||||
{"function_name": "test_example", "verdict": "repair", "reason": "One sentence explaining the specific problem"}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
If all functions are acceptable:
|
||||
|
||||
```json
|
||||
{"functions": []}
|
||||
```
|
||||
|
||||
Keep reasons to one sentence — they are passed directly to the repair LLM as instructions.
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
<function_under_test name="{{ function_name }}">
|
||||
```python
|
||||
{{ function_source_code }}
|
||||
```
|
||||
</function_under_test>
|
||||
|
||||
<generated_test>
|
||||
```python
|
||||
{{ test_source }}
|
||||
```
|
||||
</generated_test>
|
||||
{% if failed_note %}
|
||||
|
||||
<note>
|
||||
{{ failed_note }}
|
||||
</note>
|
||||
{% endif %}
|
||||
|
||||
Review each test function in the generated test. Return your verdict as JSON.
|
||||
118
django/aiservice/core/shared/testgen_review/repair.py
Normal file
118
django/aiservice/core/shared/testgen_review/repair.py
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
"""Test repair endpoint.
|
||||
|
||||
Takes a generated test file with specific flagged functions and returns a
|
||||
repaired version with those functions rewritten.
|
||||
|
||||
Currently enabled for Python only.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import sentry_sdk
|
||||
from jinja2 import Environment, FileSystemLoader, StrictUndefined
|
||||
from ninja import NinjaAPI
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
|
||||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common.markdown_utils import extract_code_block_with_context
|
||||
from aiservice.llm import EXECUTE_MODEL, calculate_llm_cost, call_llm
|
||||
from authapp.auth import AuthenticatedRequest
|
||||
from core.shared.testgen_review.models import TestRepairErrorSchema, TestRepairResponseSchema, TestRepairSchema
|
||||
|
||||
testgen_repair_api = NinjaAPI(urls_namespace="testgen_repair")
|
||||
|
||||
_prompts_dir = Path(__file__).parent / "prompts"
|
||||
_jinja_env = Environment(loader=FileSystemLoader(_prompts_dir), keep_trailing_newline=True, undefined=StrictUndefined) # noqa: S701
|
||||
|
||||
|
||||
@testgen_repair_api.post(
|
||||
"/", response={200: TestRepairResponseSchema, 400: TestRepairErrorSchema, 500: TestRepairErrorSchema}
|
||||
)
|
||||
async def testgen_repair(
|
||||
request: AuthenticatedRequest, data: TestRepairSchema
|
||||
) -> tuple[int, TestRepairResponseSchema | TestRepairErrorSchema]:
|
||||
if data.language != "python":
|
||||
return 400, TestRepairErrorSchema(error="Test repair is only supported for Python")
|
||||
|
||||
await asyncio.to_thread(ph, request.user, "aiservice-testgen-repair-called")
|
||||
|
||||
try:
|
||||
system_prompt = _jinja_env.get_template("testgen_repair_system_prompt.md.j2").render()
|
||||
user_prompt = _jinja_env.get_template("testgen_repair_user_prompt.md.j2").render(
|
||||
function_name=data.function_to_optimize.function_name,
|
||||
function_source_code=data.function_source_code,
|
||||
test_source=data.test_source,
|
||||
functions_to_repair=data.functions_to_repair,
|
||||
)
|
||||
|
||||
messages: list[ChatCompletionMessageParam] = [
|
||||
ChatCompletionSystemMessageParam(role="system", content=system_prompt),
|
||||
ChatCompletionUserMessageParam(role="user", content=user_prompt),
|
||||
]
|
||||
|
||||
obs_context: dict[str, Any] = {}
|
||||
if data.call_sequence is not None:
|
||||
obs_context["call_sequence"] = data.call_sequence
|
||||
|
||||
response = await call_llm(
|
||||
llm=EXECUTE_MODEL,
|
||||
messages=messages,
|
||||
call_type="testgen_repair",
|
||||
trace_id=data.trace_id,
|
||||
user_id=request.user,
|
||||
context=obs_context,
|
||||
)
|
||||
|
||||
cost = calculate_llm_cost(response.raw_response, EXECUTE_MODEL)
|
||||
logging.debug(f"testgen_repair LLM cost: {cost}")
|
||||
|
||||
repair_text = response.content.strip()
|
||||
if result := extract_code_block_with_context(repair_text, language="python"):
|
||||
_before, repaired_code, _after = result
|
||||
else:
|
||||
return 500, TestRepairErrorSchema(error="Could not extract repaired code from LLM response")
|
||||
|
||||
# Re-instrument the repaired tests
|
||||
from core.languages.python.testgen.validate import instrument_tests, validate_request_data # noqa: PLC0415
|
||||
from core.shared.testgen_models import TestGenSchema # noqa: PLC0415
|
||||
|
||||
testgen_data = TestGenSchema(
|
||||
source_code_being_tested=data.function_source_code,
|
||||
function_to_optimize=data.function_to_optimize,
|
||||
helper_function_names=data.helper_function_names,
|
||||
module_path=data.module_path,
|
||||
test_module_path=data.test_module_path,
|
||||
test_framework=data.test_framework,
|
||||
test_timeout=data.test_timeout,
|
||||
trace_id=data.trace_id,
|
||||
python_version=data.python_version,
|
||||
language=data.language,
|
||||
codeflash_version=data.codeflash_version,
|
||||
)
|
||||
|
||||
python_version, _ctx = validate_request_data(testgen_data)
|
||||
instrumented_behavior, instrumented_perf = instrument_tests(repaired_code, testgen_data, python_version)
|
||||
|
||||
if instrumented_behavior is None or instrumented_perf is None:
|
||||
return 500, TestRepairErrorSchema(error="Failed to instrument repaired tests")
|
||||
|
||||
await asyncio.to_thread(ph, request.user, "aiservice-testgen-repair-completed")
|
||||
return 200, TestRepairResponseSchema(
|
||||
generated_tests=repaired_code,
|
||||
instrumented_behavior_tests=instrumented_behavior,
|
||||
instrumented_perf_tests=instrumented_perf,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.exception("Error in testgen_repair")
|
||||
sentry_sdk.capture_exception(e)
|
||||
return 500, TestRepairErrorSchema(error="Internal server error")
|
||||
159
django/aiservice/core/shared/testgen_review/review.py
Normal file
159
django/aiservice/core/shared/testgen_review/review.py
Normal file
|
|
@ -0,0 +1,159 @@
|
|||
"""Test quality review endpoint.
|
||||
|
||||
Reviews AI-generated test sources for unrealistic patterns (cache warm-up,
|
||||
internal state manipulation, identical inputs, etc.) and returns per-function
|
||||
verdicts of "pass" or "repair" with reasons.
|
||||
|
||||
Currently enabled for Python only — other languages return all-pass immediately.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import sentry_sdk
|
||||
from jinja2 import Environment, FileSystemLoader, StrictUndefined
|
||||
from ninja import NinjaAPI
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
|
||||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common.markdown_utils import extract_code_block_with_context
|
||||
from aiservice.llm import HAIKU_MODEL, calculate_llm_cost, call_llm
|
||||
from authapp.auth import AuthenticatedRequest
|
||||
from core.shared.testgen_review.models import (
|
||||
FunctionVerdict,
|
||||
TestgenReviewErrorSchema,
|
||||
TestgenReviewResponseSchema,
|
||||
TestgenReviewSchema,
|
||||
TestReview,
|
||||
)
|
||||
|
||||
testgen_review_api = NinjaAPI(urls_namespace="testgen_review")
|
||||
|
||||
_prompts_dir = Path(__file__).parent / "prompts"
|
||||
_jinja_env = Environment(loader=FileSystemLoader(_prompts_dir), keep_trailing_newline=True, undefined=StrictUndefined) # noqa: S701
|
||||
|
||||
|
||||
@testgen_review_api.post(
|
||||
"/", response={200: TestgenReviewResponseSchema, 400: TestgenReviewErrorSchema, 500: TestgenReviewErrorSchema}
|
||||
)
|
||||
async def testgen_review(
|
||||
request: AuthenticatedRequest, data: TestgenReviewSchema
|
||||
) -> tuple[int, TestgenReviewResponseSchema | TestgenReviewErrorSchema]:
|
||||
# Only enabled for Python for now
|
||||
if data.language != "python":
|
||||
return 200, TestgenReviewResponseSchema(reviews=[])
|
||||
|
||||
await asyncio.to_thread(ph, request.user, "aiservice-testgen-review-called")
|
||||
|
||||
try:
|
||||
reviews: list[TestReview] = []
|
||||
for test_entry in data.tests:
|
||||
review = await _review_single_test(
|
||||
test_source=test_entry.test_source,
|
||||
test_index=test_entry.test_index,
|
||||
failed_test_functions=test_entry.failed_test_functions,
|
||||
function_source_code=data.function_source_code,
|
||||
function_name=data.function_name,
|
||||
trace_id=data.trace_id,
|
||||
user_id=request.user,
|
||||
call_sequence=data.call_sequence,
|
||||
)
|
||||
reviews.append(review)
|
||||
|
||||
await asyncio.to_thread(
|
||||
ph, request.user, "aiservice-testgen-review-completed", properties={"trace_id": data.trace_id}
|
||||
)
|
||||
return 200, TestgenReviewResponseSchema(reviews=reviews)
|
||||
|
||||
except Exception as e:
|
||||
logging.exception("Error in testgen_review")
|
||||
sentry_sdk.capture_exception(e)
|
||||
return 500, TestgenReviewErrorSchema(error="Internal server error")
|
||||
|
||||
|
||||
async def _review_single_test(
|
||||
test_source: str,
|
||||
test_index: int,
|
||||
failed_test_functions: list[str],
|
||||
function_source_code: str,
|
||||
function_name: str,
|
||||
trace_id: str,
|
||||
user_id: str,
|
||||
call_sequence: int | None,
|
||||
) -> TestReview:
|
||||
# Functions that failed behavioral tests are definitively bad — no AI review needed
|
||||
failed_verdicts = [
|
||||
FunctionVerdict(function_name=fn, verdict="repair", reason="Failed behavioral test against original code")
|
||||
for fn in failed_test_functions
|
||||
]
|
||||
|
||||
# Build prompt for AI review of passing functions
|
||||
failed_note = ""
|
||||
if failed_test_functions:
|
||||
failed_list = ", ".join(f"`{fn}`" for fn in failed_test_functions)
|
||||
failed_note = (
|
||||
f"Note: The following functions already failed behavioral tests and will be repaired separately: "
|
||||
f"{failed_list}. Do NOT include them in your review — only review the remaining functions."
|
||||
)
|
||||
|
||||
system_prompt = _jinja_env.get_template("testgen_review_system_prompt.md.j2").render()
|
||||
user_prompt = _jinja_env.get_template("testgen_review_user_prompt.md.j2").render(
|
||||
function_name=function_name,
|
||||
function_source_code=function_source_code,
|
||||
test_source=test_source,
|
||||
failed_note=failed_note,
|
||||
)
|
||||
|
||||
messages: list[ChatCompletionMessageParam] = [
|
||||
ChatCompletionSystemMessageParam(role="system", content=system_prompt),
|
||||
ChatCompletionUserMessageParam(role="user", content=user_prompt),
|
||||
]
|
||||
|
||||
obs_context: dict[str, Any] = {}
|
||||
if call_sequence is not None:
|
||||
obs_context["call_sequence"] = call_sequence
|
||||
|
||||
response = await call_llm(
|
||||
llm=HAIKU_MODEL,
|
||||
messages=messages,
|
||||
call_type="testgen_review",
|
||||
trace_id=trace_id,
|
||||
user_id=user_id,
|
||||
context=obs_context,
|
||||
)
|
||||
|
||||
cost = calculate_llm_cost(response.raw_response, HAIKU_MODEL)
|
||||
logging.debug(f"testgen_review LLM cost: {cost}")
|
||||
|
||||
ai_verdicts = _parse_review_response(response.content.strip())
|
||||
return TestReview(test_index=test_index, functions=failed_verdicts + ai_verdicts)
|
||||
|
||||
|
||||
def _parse_review_response(review_text: str) -> list[FunctionVerdict]:
|
||||
if result := extract_code_block_with_context(review_text, language="json"):
|
||||
_before, json_content, _after = result
|
||||
try:
|
||||
data = json.loads(json_content.strip())
|
||||
# Support both formats: {"functions": [...]} or [...]
|
||||
fn_list = data if isinstance(data, list) else data.get("functions", [])
|
||||
return [
|
||||
FunctionVerdict(
|
||||
function_name=fn.get("function_name", ""),
|
||||
verdict=fn.get("verdict", "pass"),
|
||||
reason=fn.get("reason", ""),
|
||||
)
|
||||
for fn in fn_list
|
||||
if fn.get("verdict") == "repair"
|
||||
]
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
logging.warning("testgen_review: failed to parse AI response JSON")
|
||||
return []
|
||||
Loading…
Reference in a new issue