feat: improve testgen review & repair quality (#2473)
## Summary

- Pass coverage details (unexecuted lines, threshold) to the review and repair prompts so the LLM can identify low-coverage tests
- Accept previous repair errors in the repair endpoint and include them in the prompt for retry cycles
- Parallelize per-test review LLM calls with `asyncio.TaskGroup`
- Conditionally include codeflash env var context (`CODEFLASH_TRACER_DISABLE`, etc.) in repair prompts when the function under test references them

## Test plan

- [x] Tested locally with the codeflash CLI against `Tracer.__enter__` — review, repair, and retry cycles all work
- [x] Coverage details and previous errors appear correctly in prompts
- [x] Review parallelization reduces latency from sequential (~60s per test) to concurrent execution
This commit is contained in:
parent
14c0b3acca
commit
434fb7df77
6 changed files with 128 additions and 17 deletions
|
|
@ -14,6 +14,23 @@ class TestSourceWithFailures(Schema):
|
|||
failure_messages: dict[str, str] = {} # function_name → error message from test run
|
||||
|
||||
|
||||
class CoverageFunctionDetail(Schema):
    """Per-function line/branch coverage reported by a coverage run.

    NOTE(review): the `[]` defaults rely on Schema (pydantic-style) copying
    field defaults per instance rather than sharing them — confirm Schema
    behaves this way.
    """

    # Name of the function this coverage detail describes.
    name: str
    # Coverage for this function as a percentage, e.g. 87.5.
    coverage: float
    # Line numbers executed at least once during the test run.
    # (Presumably file-absolute line numbers — TODO confirm against producer.)
    executed_lines: list[int] = []
    # Line numbers never executed during the test run.
    unexecuted_lines: list[int] = []
    # Branches taken, as [line, line] pairs — presumably coverage.py-style
    # arcs (from_line, to_line); verify against the coverage producer.
    executed_branches: list[list[int]] = []
    # Branches never taken, same pair shape as executed_branches.
    unexecuted_branches: list[list[int]] = []
|
||||
|
||||
|
||||
class CoverageDetails(Schema):
    """Aggregate coverage information passed to review/repair prompts."""

    # Overall coverage of the function under test, as a percentage.
    coverage_percentage: float
    # Minimum acceptable coverage percentage; below this, tests should be
    # flagged for repair.
    threshold_percentage: float
    # Line number where the function under test starts in its source file.
    # (Presumably used to align unexecuted-line numbers with the extracted
    # function source — TODO confirm; see _build_coverage_context.)
    function_start_line: int = 1
    # Coverage detail for the function under test itself.
    main_function: CoverageFunctionDetail
    # Coverage detail for an optional dependent/helper function, if any.
    dependent_function: CoverageFunctionDetail | None = None
|
||||
|
||||
|
||||
class TestgenReviewSchema(Schema):
|
||||
tests: list[TestSourceWithFailures]
|
||||
function_source_code: str
|
||||
|
|
@ -23,6 +40,7 @@ class TestgenReviewSchema(Schema):
|
|||
codeflash_version: str | None = None
|
||||
call_sequence: int | None = None
|
||||
coverage_summary: str = ""
|
||||
coverage_details: CoverageDetails | None = None
|
||||
|
||||
|
||||
class FunctionVerdict(Schema):
|
||||
|
|
@ -67,6 +85,8 @@ class TestRepairSchema(Schema):
|
|||
language: str = "python"
|
||||
codeflash_version: str | None = None
|
||||
call_sequence: int | None = None
|
||||
coverage_details: CoverageDetails | None = None
|
||||
previous_repair_errors: dict[str, str] = {}
|
||||
|
||||
|
||||
class TestRepairResponseSchema(Schema):
|
||||
|
|
|
|||
|
|
@ -2,6 +2,14 @@ You are a test repair assistant for an AI code optimizer.
|
|||
|
||||
You will receive a test file with specific test functions flagged as problematic, along with the reason each was flagged. Rewrite ONLY the flagged functions. All other code in the file must remain exactly as-is — do not modify imports, class definitions, helper functions, or any unflagged test functions.
|
||||
|
||||
{% if has_codeflash_env_vars %}
|
||||
Tests run in a controlled environment with these environment variables set:
|
||||
- `CODEFLASH_TRACER_DISABLE=1` — disables the Codeflash profiler during test execution
|
||||
- `CODEFLASH_LOOP_INDEX` — benchmarking loop iteration index
|
||||
- `CODEFLASH_TEST_ITERATION` — test iteration counter
|
||||
|
||||
The function under test reads some of these variables. Repaired tests must account for their presence — use `monkeypatch.delenv()` or `monkeypatch.setenv()` to control them when testing code paths that depend on these variables.
|
||||
{% endif %}
|
||||
Constraints for repaired functions:
|
||||
- Use diverse, realistic inputs that reflect production usage patterns
|
||||
- Only use the function's public API — do not access internal implementation details
|
||||
|
|
|
|||
|
|
@ -15,5 +15,22 @@
|
|||
- `{{ fn.function_name }}`: {{ fn.reason }}
|
||||
{% endfor %}
|
||||
</functions_to_repair>
|
||||
{% if previous_repair_errors %}
|
||||
|
||||
<previous_repair_errors>
|
||||
A previous repair attempt for these functions was tried but the rewritten tests still failed at runtime. Use these error messages to understand what went wrong and avoid the same mistakes:
|
||||
{% for fn_name, error_msg in previous_repair_errors.items() -%}
|
||||
- `{{ fn_name }}`: {{ error_msg }}
|
||||
{% endfor %}
|
||||
</previous_repair_errors>
|
||||
{% endif %}
|
||||
{% if coverage_context %}
|
||||
|
||||
<coverage_details>
|
||||
{{ coverage_context }}
|
||||
|
||||
When rewriting flagged functions, ensure they exercise the unexecuted lines listed above. Tests that only trigger early-return or disabled code paths do not provide useful coverage.
|
||||
</coverage_details>
|
||||
{% endif %}
|
||||
|
||||
Rewrite only the functions listed above. Return the complete file.
|
||||
|
|
|
|||
|
|
@ -9,7 +9,12 @@
|
|||
{{ test_source }}
|
||||
```
|
||||
</generated_test>
|
||||
{% if coverage_summary %}
|
||||
{% if coverage_context %}
|
||||
|
||||
<coverage_details>
|
||||
{{ coverage_context }}
|
||||
</coverage_details>
|
||||
{% elif coverage_summary %}
|
||||
|
||||
<coverage>
|
||||
Test coverage of {{ function_name }}: {{ coverage_summary }}
|
||||
|
|
|
|||
|
|
@ -47,12 +47,24 @@ async def testgen_repair(
|
|||
await asyncio.to_thread(ph, request.user, "aiservice-testgen-repair-called")
|
||||
|
||||
try:
|
||||
system_prompt = _jinja_env.get_template("testgen_repair_system_prompt.md.j2").render()
|
||||
from core.shared.testgen_review.review import _build_coverage_context # noqa: PLC0415
|
||||
|
||||
coverage_context = (
|
||||
_build_coverage_context(data.coverage_details, data.function_source_code) if data.coverage_details else ""
|
||||
)
|
||||
|
||||
_codeflash_env_vars = ("CODEFLASH_TRACER_DISABLE", "CODEFLASH_LOOP_INDEX", "CODEFLASH_TEST_ITERATION")
|
||||
has_codeflash_env_vars = any(var in data.function_source_code for var in _codeflash_env_vars)
|
||||
system_prompt = _jinja_env.get_template("testgen_repair_system_prompt.md.j2").render(
|
||||
has_codeflash_env_vars=has_codeflash_env_vars
|
||||
)
|
||||
user_prompt = _jinja_env.get_template("testgen_repair_user_prompt.md.j2").render(
|
||||
function_name=data.function_to_optimize.function_name,
|
||||
function_source_code=data.function_source_code,
|
||||
test_source=data.test_source,
|
||||
functions_to_repair=data.functions_to_repair,
|
||||
coverage_context=coverage_context,
|
||||
previous_repair_errors=data.previous_repair_errors,
|
||||
)
|
||||
|
||||
messages: list[ChatCompletionMessageParam] = [
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from aiservice.common.markdown_utils import extract_code_block_with_context
|
|||
from aiservice.llm import HAIKU_MODEL, calculate_llm_cost, call_llm
|
||||
from authapp.auth import AuthenticatedRequest
|
||||
from core.shared.testgen_review.models import (
|
||||
CoverageDetails,
|
||||
FunctionVerdict,
|
||||
TestgenReviewErrorSchema,
|
||||
TestgenReviewResponseSchema,
|
||||
|
|
@ -55,21 +56,30 @@ async def testgen_review(
|
|||
await asyncio.to_thread(ph, request.user, "aiservice-testgen-review-called")
|
||||
|
||||
try:
|
||||
reviews: list[TestReview] = []
|
||||
for test_entry in data.tests:
|
||||
review = await _review_single_test(
|
||||
test_source=test_entry.test_source,
|
||||
test_index=test_entry.test_index,
|
||||
failed_test_functions=test_entry.failed_test_functions,
|
||||
failure_messages=test_entry.failure_messages,
|
||||
function_source_code=data.function_source_code,
|
||||
function_name=data.function_name,
|
||||
trace_id=data.trace_id,
|
||||
user_id=request.user,
|
||||
call_sequence=data.call_sequence,
|
||||
coverage_summary=data.coverage_summary,
|
||||
)
|
||||
reviews.append(review)
|
||||
coverage_context = (
|
||||
_build_coverage_context(data.coverage_details, data.function_source_code) if data.coverage_details else ""
|
||||
)
|
||||
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
tasks = [
|
||||
tg.create_task(
|
||||
_review_single_test(
|
||||
test_source=test_entry.test_source,
|
||||
test_index=test_entry.test_index,
|
||||
failed_test_functions=test_entry.failed_test_functions,
|
||||
failure_messages=test_entry.failure_messages,
|
||||
function_source_code=data.function_source_code,
|
||||
function_name=data.function_name,
|
||||
trace_id=data.trace_id,
|
||||
user_id=request.user,
|
||||
call_sequence=data.call_sequence,
|
||||
coverage_summary=data.coverage_summary,
|
||||
coverage_context=coverage_context,
|
||||
)
|
||||
)
|
||||
for test_entry in data.tests
|
||||
]
|
||||
reviews = [task.result() for task in tasks]
|
||||
|
||||
await asyncio.to_thread(
|
||||
ph, request.user, "aiservice-testgen-review-completed", properties={"trace_id": data.trace_id}
|
||||
|
|
@ -93,6 +103,7 @@ async def _review_single_test(
|
|||
user_id: str,
|
||||
call_sequence: int | None,
|
||||
coverage_summary: str = "",
|
||||
coverage_context: str = "",
|
||||
) -> TestReview:
|
||||
# Functions that failed behavioral tests are definitively bad — include error details in reason
|
||||
failed_verdicts = [
|
||||
|
|
@ -126,6 +137,7 @@ async def _review_single_test(
|
|||
test_source=test_source,
|
||||
failed_note=failed_note,
|
||||
coverage_summary=coverage_summary,
|
||||
coverage_context=coverage_context,
|
||||
)
|
||||
|
||||
messages: list[ChatCompletionMessageParam] = [
|
||||
|
|
@ -172,3 +184,40 @@ def _parse_review_response(review_text: str) -> list[FunctionVerdict]:
|
|||
except (json.JSONDecodeError, AttributeError):
|
||||
logging.warning("testgen_review: failed to parse AI response JSON")
|
||||
return []
|
||||
|
||||
|
||||
def _build_coverage_context(details: CoverageDetails, function_source_code: str = "") -> str:
|
||||
source_lines: dict[int, str] = {}
|
||||
if function_source_code:
|
||||
for i, line in enumerate(function_source_code.splitlines(), start=1):
|
||||
source_lines[i] = line
|
||||
|
||||
lines = [f"Coverage: {details.coverage_percentage:.1f}% (threshold: {details.threshold_percentage:.0f}%)"]
|
||||
below_threshold = details.coverage_percentage < details.threshold_percentage
|
||||
if below_threshold:
|
||||
lines.append(
|
||||
"⚠ Coverage is BELOW the required threshold — tests that only hit trivial/early-return paths should be flagged for repair."
|
||||
)
|
||||
|
||||
mc = details.main_function
|
||||
lines.append(f"\n{mc.name}: {mc.coverage:.1f}% coverage")
|
||||
if mc.unexecuted_lines and source_lines:
|
||||
lines.append(" Unexecuted lines:")
|
||||
for ln in mc.unexecuted_lines:
|
||||
code = source_lines.get(ln, "")
|
||||
lines.append(f" L{ln}: {code}")
|
||||
elif mc.unexecuted_lines:
|
||||
lines.append(f" Unexecuted lines: {mc.unexecuted_lines}")
|
||||
|
||||
if details.dependent_function:
|
||||
dc = details.dependent_function
|
||||
lines.append(f"\n{dc.name}: {dc.coverage:.1f}% coverage")
|
||||
if dc.unexecuted_lines and source_lines:
|
||||
lines.append(" Unexecuted lines:")
|
||||
for ln in dc.unexecuted_lines:
|
||||
code = source_lines.get(ln, "")
|
||||
lines.append(f" L{ln}: {code}")
|
||||
elif dc.unexecuted_lines:
|
||||
lines.append(f" Unexecuted lines: {dc.unexecuted_lines}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
|
|
|||
Loading…
Reference in a new issue