fix: use greedy code extraction and retry on syntax errors in repair (#2475)

## Summary
- Switch `extract_code_block_with_context` (non-greedy `.*?`) →
`extract_code_block` (greedy `.*`) for repair code extraction — the
non-greedy regex matched the first closing fence, truncating code when
the LLM included explanatory snippets before the full file (root cause
of 82% of repair failures)
- Add `ast.parse` validation before CST parsing for fast syntax checking
- Retry the LLM once with the specific syntax error appended to the
conversation when validation fails

## Test plan
- [x] Existing tests pass
- [ ] Run end-to-end optimization to verify repairs succeed
This commit is contained in:
Kevin Turcios 2026-03-06 11:24:31 +00:00 committed by GitHub
parent 07edfaa0bd
commit 0ca3a2ab07
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -8,6 +8,7 @@ Currently enabled for Python only.
from __future__ import annotations
import ast
import asyncio
import logging
from pathlib import Path
@ -18,13 +19,14 @@ from jinja2 import Environment, FileSystemLoader, StrictUndefined
from ninja import NinjaAPI
from ninja.errors import HttpError
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam,
ChatCompletionUserMessageParam,
)
from aiservice.analytics.posthog import ph
from aiservice.common.markdown_utils import extract_code_block_with_context
from aiservice.common.markdown_utils import extract_code_block
from aiservice.llm import HAIKU_MODEL, calculate_llm_cost, call_llm
from authapp.auth import AuthenticatedRequest
from core.shared.testgen_review.models import TestRepairErrorSchema, TestRepairResponseSchema, TestRepairSchema
@ -35,6 +37,36 @@ _prompts_dir = Path(__file__).parent / "prompts"
_jinja_env = Environment(loader=FileSystemLoader(_prompts_dir), keep_trailing_newline=True, undefined=StrictUndefined) # noqa: S701
def _extract_and_validate(llm_output: str) -> tuple[str | None, Any]:
    """Extract a Python code block from LLM output and validate its syntax.

    Validation is two-staged: a cheap ``ast.parse`` check first, then the
    CST parse whose result is handed back to the caller.

    Args:
        llm_output: Raw text returned by the LLM, expected to contain a
            fenced code block.

    Returns:
        ``(code_string, parsed_cst)`` on success, ``(code_string, None)`` if
        extraction succeeded but the code could not be parsed, or
        ``(None, None)`` if no code block was found.
    """
    from core.languages.python.cst_utils import parse_module_to_cst  # noqa: PLC0415

    code = extract_code_block(llm_output)
    if code is None:
        return None, None

    try:
        ast.parse(code)
    except (SyntaxError, ValueError, RecursionError):
        # ast.parse raises ValueError for null bytes on Python < 3.12 and
        # RecursionError for extremely nested input. Both mean "unusable
        # code" here and must not escape as an unhandled exception.
        return code, None

    try:
        return code, parse_module_to_cst(code)
    except Exception:  # noqa: BLE001
        # Any CST parser failure is treated the same as invalid syntax.
        return code, None
def _get_syntax_error(code: str) -> str:
try:
ast.parse(code)
except SyntaxError as e:
return f"{e.msg} (line {e.lineno})"
return "Unknown syntax error"
@testgen_repair_api.post(
"/", response={200: TestRepairResponseSchema, 400: TestRepairErrorSchema, 500: TestRepairErrorSchema}
)
@ -86,27 +118,46 @@ async def testgen_repair(
)
cost = calculate_llm_cost(response.raw_response, HAIKU_MODEL)
logging.debug(f"testgen_repair LLM cost: {cost}")
logging.debug("testgen_repair LLM cost: %s", cost)
repair_text = response.content.strip()
if result := extract_code_block_with_context(repair_text, language="python"):
_before, repaired_code, _after = result
else:
return 500, TestRepairErrorSchema(error="Could not extract repaired code from LLM response")
# Splice only the flagged functions from the LLM output into the original test source,
# keeping all unflagged functions untouched
from core.languages.python.cst_utils import parse_module_to_cst # noqa: PLC0415
from core.languages.python.testgen.instrumentation.edit_generated_test import find_and_replace_function
from core.languages.python.testgen.postprocessing.postprocess_pipeline import postprocessing_testgen_pipeline
from core.languages.python.testgen.instrumentation.edit_generated_test import ( # noqa: PLC0415
find_and_replace_function,
)
from core.languages.python.testgen.postprocessing.postprocess_pipeline import ( # noqa: PLC0415
postprocessing_testgen_pipeline,
)
from core.languages.python.testgen.validate import instrument_tests, validate_request_data # noqa: PLC0415
from core.shared.testgen_models import TestGenSchema # noqa: PLC0415
try:
repaired_cst = parse_module_to_cst(repaired_code)
except Exception:
logging.warning("LLM returned syntactically invalid repaired code, falling back to original")
repaired_code, repaired_cst = _extract_and_validate(response.content.strip())
if repaired_cst is None:
syntax_error = _get_syntax_error(repaired_code) if repaired_code else "No code block found"
logging.info("Repair attempt produced invalid syntax, retrying: %s", syntax_error)
messages.append(ChatCompletionAssistantMessageParam(role="assistant", content=response.content))
messages.append(
ChatCompletionUserMessageParam(
role="user",
content=f"The code you returned has a syntax error:\n```\n{syntax_error}\n```\n"
"Please return the complete corrected file in a single ```python code block.",
)
)
retry_response = await call_llm(
llm=HAIKU_MODEL,
messages=messages,
call_type="testgen_repair_retry",
trace_id=data.trace_id,
user_id=request.user,
context=obs_context,
)
cost += calculate_llm_cost(retry_response.raw_response, HAIKU_MODEL)
repaired_code, repaired_cst = _extract_and_validate(retry_response.content.strip())
if repaired_cst is None:
logging.warning("LLM returned syntactically invalid code after retry")
return 500, TestRepairErrorSchema(error="LLM returned syntactically invalid code")
from core.languages.python.cst_utils import parse_module_to_cst # noqa: PLC0415
original_cst = parse_module_to_cst(data.test_source)
# Extract repaired function nodes by name