feat: improve testgen review & repair quality (#2473)

## Summary

- Pass coverage details (unexecuted lines, threshold) to review and
repair prompts so the LLM can identify low-coverage tests
- Accept previous repair errors in the repair endpoint and include them
in the prompt for retry cycles
- Parallelize per-test review LLM calls with `asyncio.TaskGroup`
- Conditionally include codeflash env var context
(`CODEFLASH_TRACER_DISABLE`, etc.) in repair prompts when the function
under test references them

## Test plan

- [x] Tested locally with codeflash CLI against `Tracer.__enter__` —
review, repair, and retry cycles all work
- [x] Coverage details and previous errors appear correctly in prompts
- [x] Review parallelization cuts total latency: per-test review LLM calls
(~60s each) that previously ran sequentially now run concurrently
This commit is contained in:
Kevin Turcios 2026-03-06 10:23:55 +00:00 committed by GitHub
parent 14c0b3acca
commit 434fb7df77
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 128 additions and 17 deletions

View file

@ -14,6 +14,23 @@ class TestSourceWithFailures(Schema):
failure_messages: dict[str, str] = {} # function_name → error message from test run
class CoverageFunctionDetail(Schema):
    """Line/branch coverage measurements for a single function."""

    name: str  # name of the measured function
    coverage: float  # coverage percentage for this function (0-100)
    executed_lines: list[int] = []  # line numbers executed by the tests
    unexecuted_lines: list[int] = []  # line numbers never executed by the tests
    executed_branches: list[list[int]] = []  # branch arcs taken; presumably [from_line, to_line] pairs — TODO confirm
    unexecuted_branches: list[list[int]] = []  # branch arcs never taken; same pair format assumed
class CoverageDetails(Schema):
    """Coverage report attached to testgen review/repair requests."""

    coverage_percentage: float  # overall coverage achieved by the generated tests (0-100)
    threshold_percentage: float  # minimum coverage required for tests to be considered adequate
    function_start_line: int = 1  # presumably the file line where the function source snippet begins, for mapping line numbers onto source text — TODO confirm against callers
    main_function: CoverageFunctionDetail  # coverage for the function under test
    dependent_function: CoverageFunctionDetail | None = None  # optional coverage for a dependent/helper function
class TestgenReviewSchema(Schema):
tests: list[TestSourceWithFailures]
function_source_code: str
@ -23,6 +40,7 @@ class TestgenReviewSchema(Schema):
codeflash_version: str | None = None
call_sequence: int | None = None
coverage_summary: str = ""
coverage_details: CoverageDetails | None = None
class FunctionVerdict(Schema):
@ -67,6 +85,8 @@ class TestRepairSchema(Schema):
language: str = "python"
codeflash_version: str | None = None
call_sequence: int | None = None
coverage_details: CoverageDetails | None = None
previous_repair_errors: dict[str, str] = {}
class TestRepairResponseSchema(Schema):

View file

@ -2,6 +2,14 @@ You are a test repair assistant for an AI code optimizer.
You will receive a test file with specific test functions flagged as problematic, along with the reason each was flagged. Rewrite ONLY the flagged functions. All other code in the file must remain exactly as-is — do not modify imports, class definitions, helper functions, or any unflagged test functions.
{% if has_codeflash_env_vars %}
Tests run in a controlled environment with these environment variables set:
- `CODEFLASH_TRACER_DISABLE=1` — disables the Codeflash profiler during test execution
- `CODEFLASH_LOOP_INDEX` — benchmarking loop iteration index
- `CODEFLASH_TEST_ITERATION` — test iteration counter
The function under test reads some of these variables. Repaired tests must account for their presence — use `monkeypatch.delenv()` or `monkeypatch.setenv()` to control them when testing code paths that depend on these variables.
{% endif %}
Constraints for repaired functions:
- Use diverse, realistic inputs that reflect production usage patterns
- Only use the function's public API — do not access internal implementation details

View file

@ -15,5 +15,22 @@
- `{{ fn.function_name }}`: {{ fn.reason }}
{% endfor %}
</functions_to_repair>
{% if previous_repair_errors %}
<previous_repair_errors>
A previous repair was attempted for these functions, but the rewritten tests still failed at runtime. Use these error messages to understand what went wrong and avoid repeating the same mistakes:
{% for fn_name, error_msg in previous_repair_errors.items() -%}
- `{{ fn_name }}`: {{ error_msg }}
{% endfor %}
</previous_repair_errors>
{% endif %}
{% if coverage_context %}
<coverage_details>
{{ coverage_context }}
When rewriting flagged functions, ensure they exercise the unexecuted lines listed above. Tests that only trigger early-return or disabled code paths do not provide useful coverage.
</coverage_details>
{% endif %}
Rewrite only the functions listed above. Return the complete file.

View file

@ -9,7 +9,12 @@
{{ test_source }}
```
</generated_test>
{% if coverage_summary %}
{% if coverage_context %}
<coverage_details>
{{ coverage_context }}
</coverage_details>
{% elif coverage_summary %}
<coverage>
Test coverage of {{ function_name }}: {{ coverage_summary }}

View file

@ -47,12 +47,24 @@ async def testgen_repair(
await asyncio.to_thread(ph, request.user, "aiservice-testgen-repair-called")
try:
system_prompt = _jinja_env.get_template("testgen_repair_system_prompt.md.j2").render()
from core.shared.testgen_review.review import _build_coverage_context # noqa: PLC0415
coverage_context = (
_build_coverage_context(data.coverage_details, data.function_source_code) if data.coverage_details else ""
)
_codeflash_env_vars = ("CODEFLASH_TRACER_DISABLE", "CODEFLASH_LOOP_INDEX", "CODEFLASH_TEST_ITERATION")
has_codeflash_env_vars = any(var in data.function_source_code for var in _codeflash_env_vars)
system_prompt = _jinja_env.get_template("testgen_repair_system_prompt.md.j2").render(
has_codeflash_env_vars=has_codeflash_env_vars
)
user_prompt = _jinja_env.get_template("testgen_repair_user_prompt.md.j2").render(
function_name=data.function_to_optimize.function_name,
function_source_code=data.function_source_code,
test_source=data.test_source,
functions_to_repair=data.functions_to_repair,
coverage_context=coverage_context,
previous_repair_errors=data.previous_repair_errors,
)
messages: list[ChatCompletionMessageParam] = [

View file

@ -29,6 +29,7 @@ from aiservice.common.markdown_utils import extract_code_block_with_context
from aiservice.llm import HAIKU_MODEL, calculate_llm_cost, call_llm
from authapp.auth import AuthenticatedRequest
from core.shared.testgen_review.models import (
CoverageDetails,
FunctionVerdict,
TestgenReviewErrorSchema,
TestgenReviewResponseSchema,
@ -55,21 +56,30 @@ async def testgen_review(
await asyncio.to_thread(ph, request.user, "aiservice-testgen-review-called")
try:
reviews: list[TestReview] = []
for test_entry in data.tests:
review = await _review_single_test(
test_source=test_entry.test_source,
test_index=test_entry.test_index,
failed_test_functions=test_entry.failed_test_functions,
failure_messages=test_entry.failure_messages,
function_source_code=data.function_source_code,
function_name=data.function_name,
trace_id=data.trace_id,
user_id=request.user,
call_sequence=data.call_sequence,
coverage_summary=data.coverage_summary,
)
reviews.append(review)
coverage_context = (
_build_coverage_context(data.coverage_details, data.function_source_code) if data.coverage_details else ""
)
async with asyncio.TaskGroup() as tg:
tasks = [
tg.create_task(
_review_single_test(
test_source=test_entry.test_source,
test_index=test_entry.test_index,
failed_test_functions=test_entry.failed_test_functions,
failure_messages=test_entry.failure_messages,
function_source_code=data.function_source_code,
function_name=data.function_name,
trace_id=data.trace_id,
user_id=request.user,
call_sequence=data.call_sequence,
coverage_summary=data.coverage_summary,
coverage_context=coverage_context,
)
)
for test_entry in data.tests
]
reviews = [task.result() for task in tasks]
await asyncio.to_thread(
ph, request.user, "aiservice-testgen-review-completed", properties={"trace_id": data.trace_id}
@ -93,6 +103,7 @@ async def _review_single_test(
user_id: str,
call_sequence: int | None,
coverage_summary: str = "",
coverage_context: str = "",
) -> TestReview:
# Functions that failed behavioral tests are definitively bad — include error details in reason
failed_verdicts = [
@ -126,6 +137,7 @@ async def _review_single_test(
test_source=test_source,
failed_note=failed_note,
coverage_summary=coverage_summary,
coverage_context=coverage_context,
)
messages: list[ChatCompletionMessageParam] = [
@ -172,3 +184,40 @@ def _parse_review_response(review_text: str) -> list[FunctionVerdict]:
except (json.JSONDecodeError, AttributeError):
logging.warning("testgen_review: failed to parse AI response JSON")
return []
def _build_coverage_context(details: CoverageDetails, function_source_code: str = "") -> str:
source_lines: dict[int, str] = {}
if function_source_code:
for i, line in enumerate(function_source_code.splitlines(), start=1):
source_lines[i] = line
lines = [f"Coverage: {details.coverage_percentage:.1f}% (threshold: {details.threshold_percentage:.0f}%)"]
below_threshold = details.coverage_percentage < details.threshold_percentage
if below_threshold:
lines.append(
"⚠ Coverage is BELOW the required threshold — tests that only hit trivial/early-return paths should be flagged for repair."
)
mc = details.main_function
lines.append(f"\n{mc.name}: {mc.coverage:.1f}% coverage")
if mc.unexecuted_lines and source_lines:
lines.append(" Unexecuted lines:")
for ln in mc.unexecuted_lines:
code = source_lines.get(ln, "")
lines.append(f" L{ln}: {code}")
elif mc.unexecuted_lines:
lines.append(f" Unexecuted lines: {mc.unexecuted_lines}")
if details.dependent_function:
dc = details.dependent_function
lines.append(f"\n{dc.name}: {dc.coverage:.1f}% coverage")
if dc.unexecuted_lines and source_lines:
lines.append(" Unexecuted lines:")
for ln in dc.unexecuted_lines:
code = source_lines.get(ln, "")
lines.append(f" L{ln}: {code}")
elif dc.unexecuted_lines:
lines.append(f" Unexecuted lines: {dc.unexecuted_lines}")
return "\n".join(lines)