Merge pull request #2465 from codeflash-ai/testgen-review-repair
feat: per-function test review + repair endpoints
This commit is contained in:
commit
8d1dfd9bdb
14 changed files with 556 additions and 30 deletions
|
|
@ -30,6 +30,8 @@ from core.languages.python.optimizer.refinement import refinement_api
|
|||
from core.log_features.log_features import features_api
|
||||
from core.shared.optimizer_router import optimize_api
|
||||
from core.shared.ranker.ranker import ranker_api
|
||||
from core.shared.testgen_review.repair import testgen_repair_api
|
||||
from core.shared.testgen_review.review import testgen_review_api
|
||||
from core.shared.testgen_router import testgen_api
|
||||
from core.shared.workflow_gen.workflow_gen import workflow_gen_api
|
||||
|
||||
|
|
@ -37,6 +39,8 @@ urlpatterns = [
|
|||
path("ai/optimize", optimize_api.urls),
|
||||
path("ai/optimize-line-profiler", optimize_line_profiler_api.urls),
|
||||
path("ai/testgen", testgen_api.urls),
|
||||
path("ai/testgen_review", testgen_review_api.urls),
|
||||
path("ai/testgen_repair", testgen_repair_api.urls),
|
||||
path("ai/log_features", features_api.urls),
|
||||
path("ai/refinement", refinement_api.urls),
|
||||
path("ai/explain", explanations_api.urls),
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import ast
|
|||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import libcst as cst
|
||||
import sentry_sdk
|
||||
|
|
@ -71,7 +71,7 @@ def build_prompt(
|
|||
is_async: bool,
|
||||
is_numerical_code: bool | None = None,
|
||||
model_type: str = "openai",
|
||||
) -> tuple[list[dict[str, str]], str, str]:
|
||||
) -> tuple[list[ChatCompletionMessageParam], str, str]:
|
||||
system_template = "generate_async_system.md.j2" if is_async else "generate_system.md.j2"
|
||||
|
||||
system_prompt = _jinja_env.get_template(system_template).render(
|
||||
|
|
@ -89,7 +89,10 @@ def build_prompt(
|
|||
|
||||
posthog_event_suffix = "async-" if is_async else ""
|
||||
error_context = "async " if is_async else ""
|
||||
execute_messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
|
||||
execute_messages: list[ChatCompletionMessageParam] = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
|
||||
return execute_messages, posthog_event_suffix, error_context
|
||||
|
||||
|
|
@ -163,7 +166,7 @@ async def generate_and_validate_test_code(
|
|||
is_async: bool = False,
|
||||
test_index: int | None = None,
|
||||
) -> tuple[str, str]:
|
||||
obs_context: dict | None = (
|
||||
obs_context: dict[str, Any] | None = (
|
||||
{
|
||||
"call_sequence": call_sequence,
|
||||
"test_index": test_index,
|
||||
|
|
@ -237,7 +240,7 @@ async def generate_regression_tests_from_function(
|
|||
call_sequence: int | None = None,
|
||||
test_index: int | None = None,
|
||||
model_type: str = "openai",
|
||||
) -> tuple[str, str | None, str | None, str]:
|
||||
) -> tuple[str, str, str, str]:
|
||||
execute_messages, posthog_event_suffix, error_context = build_prompt(
|
||||
qualified_name=qualified_name,
|
||||
source_code=source_code,
|
||||
|
|
@ -251,7 +254,7 @@ async def generate_regression_tests_from_function(
|
|||
|
||||
cost_tracker = CostTracker()
|
||||
try:
|
||||
validated_code, raw_llm_content = await generate_and_validate_test_code(
|
||||
validated_code, _raw_llm_content = await generate_and_validate_test_code(
|
||||
messages=execute_messages,
|
||||
model=execute_model,
|
||||
source_code=source_code,
|
||||
|
|
@ -270,7 +273,7 @@ async def generate_regression_tests_from_function(
|
|||
test_index=test_index,
|
||||
)
|
||||
|
||||
processed_cst = postprocessing_testgen_pipeline(
|
||||
processed_cst, display_cst = postprocessing_testgen_pipeline(
|
||||
parse_module_to_cst(validated_code),
|
||||
data.helper_function_names or [],
|
||||
data.function_to_optimize,
|
||||
|
|
@ -279,6 +282,7 @@ async def generate_regression_tests_from_function(
|
|||
)
|
||||
|
||||
generated_test_source = processed_cst.code
|
||||
raw_display_source = display_cst.code
|
||||
|
||||
instrumented_behavior_tests, instrumented_perf_tests = instrument_tests(
|
||||
generated_test_source, data, python_version
|
||||
|
|
@ -317,7 +321,7 @@ async def generate_regression_tests_from_function(
|
|||
"validation_error": "No test functions found after postprocessing",
|
||||
},
|
||||
)
|
||||
return generated_test_source, instrumented_behavior_tests, instrumented_perf_tests, raw_llm_content # noqa: TRY300
|
||||
return generated_test_source, instrumented_behavior_tests, instrumented_perf_tests, raw_display_source # noqa: TRY300
|
||||
except CodeValidationError as e:
|
||||
msg = f"Failed to generate valid {error_context}test code after {cost_tracker.calls} tries. trace_id={trace_id}"
|
||||
logging.exception(msg)
|
||||
|
|
@ -380,7 +384,7 @@ async def testgen_python(
|
|||
generated_test_source,
|
||||
instrumented_behavior_tests,
|
||||
instrumented_perf_tests,
|
||||
raw_llm_content,
|
||||
raw_display_source,
|
||||
) = await generate_regression_tests_from_function(
|
||||
source_code=data.source_code_being_tested,
|
||||
qualified_name=data.function_to_optimize.qualified_name,
|
||||
|
|
@ -389,7 +393,7 @@ async def testgen_python(
|
|||
python_version=python_version,
|
||||
data=data,
|
||||
unit_test_package=data.test_framework,
|
||||
is_async=data.function_to_optimize.is_async or data.is_async,
|
||||
is_async=data.function_to_optimize.is_async or data.is_async or False,
|
||||
trace_id=data.trace_id,
|
||||
call_sequence=data.call_sequence,
|
||||
execute_model=execute_model,
|
||||
|
|
@ -403,7 +407,7 @@ async def testgen_python(
|
|||
await log_features(
|
||||
trace_id=data.trace_id,
|
||||
user_id=request.user,
|
||||
generated_tests=[raw_llm_content],
|
||||
generated_tests=[raw_display_source],
|
||||
instrumented_generated_tests=[instrumented_behavior_tests],
|
||||
instrumented_perf_tests=[instrumented_perf_tests],
|
||||
test_index=test_index,
|
||||
|
|
@ -419,6 +423,7 @@ async def testgen_python(
|
|||
generated_tests=generated_test_source,
|
||||
instrumented_behavior_tests=instrumented_behavior_tests,
|
||||
instrumented_perf_tests=instrumented_perf_tests,
|
||||
raw_generated_tests=raw_display_source,
|
||||
)
|
||||
|
||||
except TestGenerationFailedError as e:
|
||||
|
|
|
|||
|
|
@ -31,16 +31,19 @@ def postprocessing_testgen_pipeline(
|
|||
function_to_optimize: FunctionToOptimize,
|
||||
module_path: str,
|
||||
source_code_being_tested: str,
|
||||
) -> Module:
|
||||
) -> tuple[Module, Module]:
|
||||
"""Full postprocessing pipeline for generated test code.
|
||||
|
||||
Applies all CST transformations in sequence:
|
||||
1. Clean up definitions (remove helper functions, unused definitions)
|
||||
2. Modify constructs (large loops, tensors)
|
||||
3. Remove asserts
|
||||
4. Add missing imports
|
||||
5. Replace function definition with import
|
||||
3. Add missing imports, replace function definition with import
|
||||
4. Remove asserts (for instrumentation only)
|
||||
|
||||
Returns (processed_module_with_asserts_removed, processed_module_with_asserts_kept).
|
||||
The second module is the "display" version — fully cleaned but with original asserts intact.
|
||||
"""
|
||||
source_cst: Module | dict[str, Module]
|
||||
if is_multi_context(source_code_being_tested):
|
||||
source_cst = {
|
||||
path: cst.parse_module(code) for path, code in split_markdown_code(source_code_being_tested).items()
|
||||
|
|
@ -48,24 +51,27 @@ def postprocessing_testgen_pipeline(
|
|||
else:
|
||||
source_cst = cst.parse_module(source_code_being_tested)
|
||||
|
||||
pipeline: list[tuple[Callable[..., Module], dict[str, Any]]] = [
|
||||
cleanup_pipeline: list[tuple[Callable[..., Module], dict[str, Any]]] = [
|
||||
(delete_top_def_nodes, {"deletable_list": helper_function_names}),
|
||||
(remove_unused_definitions_from_pytest_file, {}),
|
||||
(modify_large_loops, {}),
|
||||
(modify_tensors, {}),
|
||||
(
|
||||
remove_asserts_from_test,
|
||||
{
|
||||
"function_to_optimize": function_to_optimize,
|
||||
"helper_function_names": helper_function_names,
|
||||
"module_path": module_path,
|
||||
},
|
||||
),
|
||||
(add_missing_imports, {"source_cst": source_cst, "module_path": module_path}),
|
||||
(replace_definition_with_import, {"function": function_to_optimize, "module_path": module_path}),
|
||||
]
|
||||
|
||||
for func, kwargs in pipeline:
|
||||
for func, kwargs in cleanup_pipeline:
|
||||
module = func(module, **kwargs)
|
||||
|
||||
return module
|
||||
# Capture the display version (with asserts) before removing them
|
||||
display_module = module
|
||||
|
||||
# Remove asserts for the instrumentation version
|
||||
module = remove_asserts_from_test(
|
||||
module,
|
||||
function_to_optimize=function_to_optimize,
|
||||
helper_function_names=helper_function_names,
|
||||
module_path=module_path,
|
||||
)
|
||||
|
||||
return module, display_module
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ class TestGenResponseSchema(Schema):
|
|||
generated_tests: str
|
||||
instrumented_behavior_tests: str
|
||||
instrumented_perf_tests: str
|
||||
raw_generated_tests: str | None = None
|
||||
|
||||
|
||||
class TestGenDebugInfo(Schema):
|
||||
|
|
|
|||
0
django/aiservice/core/shared/testgen_review/__init__.py
Normal file
0
django/aiservice/core/shared/testgen_review/__init__.py
Normal file
79
django/aiservice/core/shared/testgen_review/models.py
Normal file
79
django/aiservice/core/shared/testgen_review/models.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
"""Request/response schemas for the testgen review and repair endpoints."""
|
||||
|
||||
from ninja import Schema
|
||||
|
||||
from aiservice.models.functions_to_optimize import FunctionToOptimize
|
||||
|
||||
# ── Review ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSourceWithFailures(Schema):
|
||||
test_source: str
|
||||
test_index: int
|
||||
failed_test_functions: list[str] = [] # functions that failed behavioral tests
|
||||
failure_messages: dict[str, str] = {} # function_name → error message from test run
|
||||
|
||||
|
||||
class TestgenReviewSchema(Schema):
|
||||
tests: list[TestSourceWithFailures]
|
||||
function_source_code: str
|
||||
function_name: str
|
||||
trace_id: str
|
||||
language: str = "python"
|
||||
codeflash_version: str | None = None
|
||||
call_sequence: int | None = None
|
||||
coverage_summary: str = ""
|
||||
|
||||
|
||||
class FunctionVerdict(Schema):
|
||||
function_name: str
|
||||
verdict: str # "pass" or "repair"
|
||||
reason: str = ""
|
||||
|
||||
|
||||
class TestReview(Schema):
|
||||
test_index: int
|
||||
functions: list[FunctionVerdict]
|
||||
|
||||
|
||||
class TestgenReviewResponseSchema(Schema):
|
||||
reviews: list[TestReview]
|
||||
|
||||
|
||||
class TestgenReviewErrorSchema(Schema):
|
||||
error: str
|
||||
|
||||
|
||||
# ── Repair ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class FunctionToRepair(Schema):
|
||||
function_name: str
|
||||
reason: str
|
||||
|
||||
|
||||
class TestRepairSchema(Schema):
|
||||
test_source: str
|
||||
functions_to_repair: list[FunctionToRepair]
|
||||
function_source_code: str
|
||||
function_to_optimize: FunctionToOptimize
|
||||
helper_function_names: list[str] | None = None
|
||||
module_path: str
|
||||
test_module_path: str
|
||||
test_framework: str
|
||||
test_timeout: int
|
||||
trace_id: str
|
||||
python_version: str | None = None
|
||||
language: str = "python"
|
||||
codeflash_version: str | None = None
|
||||
call_sequence: int | None = None
|
||||
|
||||
|
||||
class TestRepairResponseSchema(Schema):
|
||||
generated_tests: str
|
||||
instrumented_behavior_tests: str
|
||||
instrumented_perf_tests: str
|
||||
|
||||
|
||||
class TestRepairErrorSchema(Schema):
|
||||
error: str
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
You are a test repair assistant for an AI code optimizer.
|
||||
|
||||
You will receive a test file with specific test functions flagged as problematic, along with the reason each was flagged. Rewrite ONLY the flagged functions. All other code in the file must remain exactly as-is — do not modify imports, class definitions, helper functions, or any unflagged test functions.
|
||||
|
||||
Constraints for repaired functions:
|
||||
- Use diverse, realistic inputs that reflect production usage patterns
|
||||
- Only use the function's public API — do not access internal implementation details
|
||||
- Exercise meaningful code paths
|
||||
- Keep the same function name so test discovery still works
|
||||
- Maintain the same testing framework conventions (pytest, unittest, etc.)
|
||||
- Do NOT add comments, docstrings, or type annotations that weren't in the original
|
||||
- The reason provided for each function describes the specific problem — fix that problem
|
||||
|
||||
Return the complete file in a single Python code block:
|
||||
```python
|
||||
# complete test file here
|
||||
```
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
<function_under_test name="{{ function_name }}">
|
||||
```python
|
||||
{{ function_source_code }}
|
||||
```
|
||||
</function_under_test>
|
||||
|
||||
<test_file>
|
||||
```python
|
||||
{{ test_source }}
|
||||
```
|
||||
</test_file>
|
||||
|
||||
<functions_to_repair>
|
||||
{% for fn in functions_to_repair -%}
|
||||
- `{{ fn.function_name }}`: {{ fn.reason }}
|
||||
{% endfor %}
|
||||
</functions_to_repair>
|
||||
|
||||
Rewrite only the functions listed above. Return the complete file.
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
You are a test quality reviewer for CodeFlash, an AI code optimizer. CodeFlash uses generated tests to:
|
||||
|
||||
1. **Verify correctness** — run the same tests on both original and optimized code, comparing outputs to catch regressions
|
||||
2. **Benchmark performance** — time the function under test to measure speedup
|
||||
|
||||
Review each test function and flag ones that would compromise correctness verification or produce misleading benchmarks.
|
||||
|
||||
Flag a test function if it has any of these problems:
|
||||
|
||||
1. **Doesn't exercise the function**: The test never calls the target function, or only triggers trivial/no-op code paths (empty input, immediate error return). If coverage information is provided, use it to identify tests that miss the function entirely.
|
||||
|
||||
2. **Non-deterministic behavior**: The test produces different outputs across runs due to randomness, timestamps, or external state. This causes false correctness failures when comparing original vs optimized outputs.
|
||||
|
||||
3. **Internal state coupling**: The test asserts on internal implementation details (private attributes, internal data structures) rather than observable outputs. Optimized code may change internals while preserving behavior, causing false failures.
|
||||
|
||||
4. **Identical-input repetition**: Calling the function with the same arguments multiple times inflates cache hit rates and produces speedups that vanish in production.
|
||||
|
||||
5. **Unrealistic data patterns**: Only a single type or size of input, missing the diversity of real-world usage that would reveal performance regressions.
|
||||
|
||||
Only flag clear problems. If a test function looks reasonable, omit it from the output.
|
||||
|
||||
Respond with a JSON code block containing only the functions that need repair:
|
||||
|
||||
```json
|
||||
{
|
||||
"functions": [
|
||||
{"function_name": "test_example", "verdict": "repair", "reason": "One sentence explaining the specific problem"}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
If all functions are acceptable:
|
||||
|
||||
```json
|
||||
{"functions": []}
|
||||
```
|
||||
|
||||
Keep reasons to one sentence — they are passed directly to the repair LLM as instructions.
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
<function_under_test name="{{ function_name }}">
|
||||
```python
|
||||
{{ function_source_code }}
|
||||
```
|
||||
</function_under_test>
|
||||
|
||||
<generated_test>
|
||||
```python
|
||||
{{ test_source }}
|
||||
```
|
||||
</generated_test>
|
||||
{% if coverage_summary %}
|
||||
|
||||
<coverage>
|
||||
Test coverage of {{ function_name }}: {{ coverage_summary }}
|
||||
</coverage>
|
||||
{% endif %}
|
||||
{% if failed_note %}
|
||||
|
||||
<note>
|
||||
{{ failed_note }}
|
||||
</note>
|
||||
{% endif %}
|
||||
|
||||
Review each test function in the generated test. Return your verdict as JSON.
|
||||
158
django/aiservice/core/shared/testgen_review/repair.py
Normal file
158
django/aiservice/core/shared/testgen_review/repair.py
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
"""Test repair endpoint.
|
||||
|
||||
Takes a generated test file with specific flagged functions and returns a
|
||||
repaired version with those functions rewritten.
|
||||
|
||||
Currently enabled for Python only.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import sentry_sdk
|
||||
from jinja2 import Environment, FileSystemLoader, StrictUndefined
|
||||
from ninja import NinjaAPI
|
||||
from ninja.errors import HttpError
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
|
||||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common.markdown_utils import extract_code_block_with_context
|
||||
from aiservice.llm import EXECUTE_MODEL, calculate_llm_cost, call_llm
|
||||
from authapp.auth import AuthenticatedRequest
|
||||
from core.shared.testgen_review.models import TestRepairErrorSchema, TestRepairResponseSchema, TestRepairSchema
|
||||
|
||||
testgen_repair_api = NinjaAPI(urls_namespace="testgen_repair")
|
||||
|
||||
_prompts_dir = Path(__file__).parent / "prompts"
|
||||
_jinja_env = Environment(loader=FileSystemLoader(_prompts_dir), keep_trailing_newline=True, undefined=StrictUndefined) # noqa: S701
|
||||
|
||||
|
||||
@testgen_repair_api.post(
|
||||
"/", response={200: TestRepairResponseSchema, 400: TestRepairErrorSchema, 500: TestRepairErrorSchema}
|
||||
)
|
||||
async def testgen_repair(
|
||||
request: AuthenticatedRequest, data: TestRepairSchema
|
||||
) -> tuple[int, TestRepairResponseSchema | TestRepairErrorSchema]:
|
||||
if data.language != "python":
|
||||
return 400, TestRepairErrorSchema(error="Test repair is only supported for Python")
|
||||
|
||||
await asyncio.to_thread(ph, request.user, "aiservice-testgen-repair-called")
|
||||
|
||||
try:
|
||||
system_prompt = _jinja_env.get_template("testgen_repair_system_prompt.md.j2").render()
|
||||
user_prompt = _jinja_env.get_template("testgen_repair_user_prompt.md.j2").render(
|
||||
function_name=data.function_to_optimize.function_name,
|
||||
function_source_code=data.function_source_code,
|
||||
test_source=data.test_source,
|
||||
functions_to_repair=data.functions_to_repair,
|
||||
)
|
||||
|
||||
messages: list[ChatCompletionMessageParam] = [
|
||||
ChatCompletionSystemMessageParam(role="system", content=system_prompt),
|
||||
ChatCompletionUserMessageParam(role="user", content=user_prompt),
|
||||
]
|
||||
|
||||
obs_context: dict[str, Any] = {}
|
||||
if data.call_sequence is not None:
|
||||
obs_context["call_sequence"] = data.call_sequence
|
||||
|
||||
response = await call_llm(
|
||||
llm=EXECUTE_MODEL,
|
||||
messages=messages,
|
||||
call_type="testgen_repair",
|
||||
trace_id=data.trace_id,
|
||||
user_id=request.user,
|
||||
context=obs_context,
|
||||
)
|
||||
|
||||
cost = calculate_llm_cost(response.raw_response, EXECUTE_MODEL)
|
||||
logging.debug(f"testgen_repair LLM cost: {cost}")
|
||||
|
||||
repair_text = response.content.strip()
|
||||
if result := extract_code_block_with_context(repair_text, language="python"):
|
||||
_before, repaired_code, _after = result
|
||||
else:
|
||||
return 500, TestRepairErrorSchema(error="Could not extract repaired code from LLM response")
|
||||
|
||||
# Splice only the flagged functions from the LLM output into the original test source,
|
||||
# keeping all unflagged functions untouched
|
||||
from core.languages.python.cst_utils import parse_module_to_cst # noqa: PLC0415
|
||||
from core.languages.python.testgen.instrumentation.edit_generated_test import find_and_replace_function
|
||||
from core.languages.python.testgen.postprocessing.postprocess_pipeline import postprocessing_testgen_pipeline
|
||||
from core.languages.python.testgen.validate import instrument_tests, validate_request_data # noqa: PLC0415
|
||||
from core.shared.testgen_models import TestGenSchema # noqa: PLC0415
|
||||
|
||||
repaired_cst = parse_module_to_cst(repaired_code)
|
||||
original_cst = parse_module_to_cst(data.test_source)
|
||||
|
||||
# Extract repaired function nodes by name
|
||||
repaired_functions: dict[str, Any] = {}
|
||||
for node in repaired_cst.body:
|
||||
if hasattr(node, "body") and hasattr(node.body, "__iter__"):
|
||||
for child in node.body:
|
||||
if hasattr(child, "name") and hasattr(child.name, "value"):
|
||||
repaired_functions[child.name.value] = child
|
||||
if hasattr(node, "name") and hasattr(node.name, "value"):
|
||||
repaired_functions[node.name.value] = node
|
||||
|
||||
# Replace only flagged functions in the original
|
||||
merged_cst = original_cst
|
||||
for func_to_repair in data.functions_to_repair:
|
||||
if func_to_repair.function_name in repaired_functions:
|
||||
merged_cst = find_and_replace_function(
|
||||
merged_cst, func_to_repair.function_name, repaired_functions[func_to_repair.function_name]
|
||||
)
|
||||
|
||||
# Run postprocessing and instrumentation on the merged result
|
||||
testgen_data = TestGenSchema(
|
||||
source_code_being_tested=data.function_source_code,
|
||||
function_to_optimize=data.function_to_optimize,
|
||||
helper_function_names=data.helper_function_names or [],
|
||||
module_path=data.module_path,
|
||||
test_module_path=data.test_module_path,
|
||||
test_framework=data.test_framework,
|
||||
test_timeout=data.test_timeout,
|
||||
trace_id=data.trace_id,
|
||||
python_version=data.python_version,
|
||||
language=data.language,
|
||||
codeflash_version=data.codeflash_version,
|
||||
)
|
||||
|
||||
python_version = validate_request_data(testgen_data)
|
||||
|
||||
processed_cst, display_cst = postprocessing_testgen_pipeline(
|
||||
merged_cst,
|
||||
data.helper_function_names or [],
|
||||
data.function_to_optimize,
|
||||
data.module_path,
|
||||
data.function_source_code,
|
||||
)
|
||||
processed_code = processed_cst.code
|
||||
display_code = display_cst.code
|
||||
|
||||
instrumented_behavior, instrumented_perf = instrument_tests(processed_code, testgen_data, python_version)
|
||||
|
||||
if instrumented_behavior is None or instrumented_perf is None:
|
||||
return 500, TestRepairErrorSchema(error="Failed to instrument repaired tests")
|
||||
|
||||
await asyncio.to_thread(ph, request.user, "aiservice-testgen-repair-completed")
|
||||
return 200, TestRepairResponseSchema(
|
||||
generated_tests=display_code,
|
||||
instrumented_behavior_tests=instrumented_behavior,
|
||||
instrumented_perf_tests=instrumented_perf,
|
||||
)
|
||||
|
||||
except HttpError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logging.exception("Error in testgen_repair")
|
||||
sentry_sdk.capture_exception(e)
|
||||
return 500, TestRepairErrorSchema(error="Internal server error")
|
||||
174
django/aiservice/core/shared/testgen_review/review.py
Normal file
174
django/aiservice/core/shared/testgen_review/review.py
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
"""Test quality review endpoint.
|
||||
|
||||
Reviews AI-generated test sources for unrealistic patterns (cache warm-up,
|
||||
internal state manipulation, identical inputs, etc.) and returns per-function
|
||||
verdicts of "pass" or "repair" with reasons.
|
||||
|
||||
Currently enabled for Python only — other languages return all-pass immediately.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import sentry_sdk
|
||||
from jinja2 import Environment, FileSystemLoader, StrictUndefined
|
||||
from ninja import NinjaAPI
|
||||
from openai.types.chat import (
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
|
||||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common.markdown_utils import extract_code_block_with_context
|
||||
from aiservice.llm import HAIKU_MODEL, calculate_llm_cost, call_llm
|
||||
from authapp.auth import AuthenticatedRequest
|
||||
from core.shared.testgen_review.models import (
|
||||
FunctionVerdict,
|
||||
TestgenReviewErrorSchema,
|
||||
TestgenReviewResponseSchema,
|
||||
TestgenReviewSchema,
|
||||
TestReview,
|
||||
)
|
||||
|
||||
testgen_review_api = NinjaAPI(urls_namespace="testgen_review")
|
||||
|
||||
_prompts_dir = Path(__file__).parent / "prompts"
|
||||
_jinja_env = Environment(loader=FileSystemLoader(_prompts_dir), keep_trailing_newline=True, undefined=StrictUndefined) # noqa: S701
|
||||
|
||||
|
||||
@testgen_review_api.post(
|
||||
"/", response={200: TestgenReviewResponseSchema, 400: TestgenReviewErrorSchema, 500: TestgenReviewErrorSchema}
|
||||
)
|
||||
async def testgen_review(
|
||||
request: AuthenticatedRequest, data: TestgenReviewSchema
|
||||
) -> tuple[int, TestgenReviewResponseSchema | TestgenReviewErrorSchema]:
|
||||
# Only enabled for Python for now
|
||||
if data.language != "python":
|
||||
return 200, TestgenReviewResponseSchema(reviews=[])
|
||||
|
||||
await asyncio.to_thread(ph, request.user, "aiservice-testgen-review-called")
|
||||
|
||||
try:
|
||||
reviews: list[TestReview] = []
|
||||
for test_entry in data.tests:
|
||||
review = await _review_single_test(
|
||||
test_source=test_entry.test_source,
|
||||
test_index=test_entry.test_index,
|
||||
failed_test_functions=test_entry.failed_test_functions,
|
||||
failure_messages=test_entry.failure_messages,
|
||||
function_source_code=data.function_source_code,
|
||||
function_name=data.function_name,
|
||||
trace_id=data.trace_id,
|
||||
user_id=request.user,
|
||||
call_sequence=data.call_sequence,
|
||||
coverage_summary=data.coverage_summary,
|
||||
)
|
||||
reviews.append(review)
|
||||
|
||||
await asyncio.to_thread(
|
||||
ph, request.user, "aiservice-testgen-review-completed", properties={"trace_id": data.trace_id}
|
||||
)
|
||||
return 200, TestgenReviewResponseSchema(reviews=reviews)
|
||||
|
||||
except Exception as e:
|
||||
logging.exception("Error in testgen_review")
|
||||
sentry_sdk.capture_exception(e)
|
||||
return 500, TestgenReviewErrorSchema(error="Internal server error")
|
||||
|
||||
|
||||
async def _review_single_test(
|
||||
test_source: str,
|
||||
test_index: int,
|
||||
failed_test_functions: list[str],
|
||||
failure_messages: dict[str, str],
|
||||
function_source_code: str,
|
||||
function_name: str,
|
||||
trace_id: str,
|
||||
user_id: str,
|
||||
call_sequence: int | None,
|
||||
coverage_summary: str = "",
|
||||
) -> TestReview:
|
||||
# Functions that failed behavioral tests are definitively bad — include error details in reason
|
||||
failed_verdicts = [
|
||||
FunctionVerdict(
|
||||
function_name=fn,
|
||||
verdict="repair",
|
||||
reason=failure_messages.get(fn, "Failed behavioral test against original code"),
|
||||
)
|
||||
for fn in failed_test_functions
|
||||
]
|
||||
|
||||
# Build prompt for AI review of passing functions — include error context so the AI can see patterns
|
||||
failed_note = ""
|
||||
if failed_test_functions:
|
||||
failed_lines = []
|
||||
for fn in failed_test_functions:
|
||||
error_msg = failure_messages.get(fn, "")
|
||||
if error_msg:
|
||||
failed_lines.append(f"- `{fn}`: {error_msg}")
|
||||
else:
|
||||
failed_lines.append(f"- `{fn}`")
|
||||
failed_note = (
|
||||
"Note: The following functions failed behavioral tests and are already flagged for repair. "
|
||||
"Do NOT include them in your review — only review the remaining functions.\n" + "\n".join(failed_lines)
|
||||
)
|
||||
|
||||
system_prompt = _jinja_env.get_template("testgen_review_system_prompt.md.j2").render()
|
||||
user_prompt = _jinja_env.get_template("testgen_review_user_prompt.md.j2").render(
|
||||
function_name=function_name,
|
||||
function_source_code=function_source_code,
|
||||
test_source=test_source,
|
||||
failed_note=failed_note,
|
||||
coverage_summary=coverage_summary,
|
||||
)
|
||||
|
||||
messages: list[ChatCompletionMessageParam] = [
|
||||
ChatCompletionSystemMessageParam(role="system", content=system_prompt),
|
||||
ChatCompletionUserMessageParam(role="user", content=user_prompt),
|
||||
]
|
||||
|
||||
obs_context: dict[str, Any] = {}
|
||||
if call_sequence is not None:
|
||||
obs_context["call_sequence"] = call_sequence
|
||||
|
||||
response = await call_llm(
|
||||
llm=HAIKU_MODEL,
|
||||
messages=messages,
|
||||
call_type="testgen_review",
|
||||
trace_id=trace_id,
|
||||
user_id=user_id,
|
||||
context=obs_context,
|
||||
)
|
||||
|
||||
cost = calculate_llm_cost(response.raw_response, HAIKU_MODEL)
|
||||
logging.debug(f"testgen_review LLM cost: {cost}")
|
||||
|
||||
ai_verdicts = _parse_review_response(response.content.strip())
|
||||
return TestReview(test_index=test_index, functions=failed_verdicts + ai_verdicts)
|
||||
|
||||
|
||||
def _parse_review_response(review_text: str) -> list[FunctionVerdict]:
|
||||
if result := extract_code_block_with_context(review_text, language="json"):
|
||||
_before, json_content, _after = result
|
||||
try:
|
||||
data = json.loads(json_content.strip())
|
||||
# Support both formats: {"functions": [...]} or [...]
|
||||
fn_list = data if isinstance(data, list) else data.get("functions", [])
|
||||
return [
|
||||
FunctionVerdict(
|
||||
function_name=fn.get("function_name", ""),
|
||||
verdict=fn.get("verdict", "pass"),
|
||||
reason=fn.get("reason", ""),
|
||||
)
|
||||
for fn in fn_list
|
||||
if fn.get("verdict") == "repair"
|
||||
]
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
logging.warning("testgen_review: failed to parse AI response JSON")
|
||||
return []
|
||||
|
|
@ -43,7 +43,7 @@ def test_list_comparison():
|
|||
)
|
||||
source_code = "def comparator(orig, new):\n return orig == new"
|
||||
|
||||
result = postprocessing_testgen_pipeline(
|
||||
result, _ = postprocessing_testgen_pipeline(
|
||||
module=cst.parse_module(generated_code),
|
||||
helper_function_names=[],
|
||||
function_to_optimize=function_to_optimize,
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ def extract_input_variables(nodes):
|
|||
ending_line=None,
|
||||
)
|
||||
module_path = "test_validate_pipeline"
|
||||
result = postprocessing_testgen_pipeline(
|
||||
result, _ = postprocessing_testgen_pipeline(
|
||||
module, ["function_to_remove"], function_to_optimize, module_path, source_code_being_tested
|
||||
)
|
||||
|
||||
|
|
@ -262,7 +262,7 @@ def test_document_to_element_list_single_element():
|
|||
|
||||
# Step 2: postprocessing_testgen_pipeline (includes add_missing_imports and replace_definition_with_import)
|
||||
source_code_being_tested = group_code(source_code_blocks)
|
||||
processed_module = postprocessing_testgen_pipeline(
|
||||
processed_module, _ = postprocessing_testgen_pipeline(
|
||||
parse_module_to_cst(validated_code), [], function_to_optimize, module_path, source_code_being_tested
|
||||
)
|
||||
|
||||
|
|
@ -570,7 +570,7 @@ def test_list_items_are_inferred():
|
|||
|
||||
# Step 2: postprocessing_testgen_pipeline (includes add_missing_imports and replace_definition_with_import)
|
||||
source_code_being_tested = group_code(source_code_blocks)
|
||||
processed_module = postprocessing_testgen_pipeline(
|
||||
processed_module, _ = postprocessing_testgen_pipeline(
|
||||
parse_module_to_cst(validated_code), [], function_to_optimize, module_path, source_code_being_tested
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue