fix: use greedy code extraction and retry on syntax errors in repair (#2475)
## Summary - Switch `extract_code_block_with_context` (non-greedy `.*?`) → `extract_code_block` (greedy `.*`) for repair code extraction — the non-greedy regex matched the first closing fence, truncating code when the LLM included explanatory snippets before the full file (root cause of 82% of repair failures) - Add `ast.parse` validation before CST parsing for fast syntax checking - Retry the LLM once with the specific syntax error appended to the conversation when validation fails ## Test plan - [x] Existing tests pass - [ ] Run end-to-end optimization to verify repairs succeed
This commit is contained in:
parent
07edfaa0bd
commit
0ca3a2ab07
1 changed files with 68 additions and 17 deletions
|
|
@ -8,6 +8,7 @@ Currently enabled for Python only.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
|
@ -18,13 +19,14 @@ from jinja2 import Environment, FileSystemLoader, StrictUndefined
|
|||
from ninja import NinjaAPI
|
||||
from ninja.errors import HttpError
|
||||
from openai.types.chat import (
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionSystemMessageParam,
|
||||
ChatCompletionUserMessageParam,
|
||||
)
|
||||
|
||||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common.markdown_utils import extract_code_block_with_context
|
||||
from aiservice.common.markdown_utils import extract_code_block
|
||||
from aiservice.llm import HAIKU_MODEL, calculate_llm_cost, call_llm
|
||||
from authapp.auth import AuthenticatedRequest
|
||||
from core.shared.testgen_review.models import TestRepairErrorSchema, TestRepairResponseSchema, TestRepairSchema
|
||||
|
|
@ -35,6 +37,36 @@ _prompts_dir = Path(__file__).parent / "prompts"
|
|||
_jinja_env = Environment(loader=FileSystemLoader(_prompts_dir), keep_trailing_newline=True, undefined=StrictUndefined) # noqa: S701
|
||||
|
||||
|
||||
def _extract_and_validate(llm_output: str) -> tuple[str | None, Any]:
|
||||
"""Extract a Python code block from LLM output and validate its syntax.
|
||||
|
||||
Returns (code_string, parsed_cst) on success, or (code_string, None) if
|
||||
extraction succeeded but syntax is invalid, or (None, None) if no code
|
||||
block was found.
|
||||
"""
|
||||
from core.languages.python.cst_utils import parse_module_to_cst # noqa: PLC0415
|
||||
|
||||
code = extract_code_block(llm_output)
|
||||
if code is None:
|
||||
return None, None
|
||||
try:
|
||||
ast.parse(code)
|
||||
except SyntaxError:
|
||||
return code, None
|
||||
try:
|
||||
return code, parse_module_to_cst(code)
|
||||
except Exception: # noqa: BLE001
|
||||
return code, None
|
||||
|
||||
|
||||
def _get_syntax_error(code: str) -> str:
|
||||
try:
|
||||
ast.parse(code)
|
||||
except SyntaxError as e:
|
||||
return f"{e.msg} (line {e.lineno})"
|
||||
return "Unknown syntax error"
|
||||
|
||||
|
||||
@testgen_repair_api.post(
|
||||
"/", response={200: TestRepairResponseSchema, 400: TestRepairErrorSchema, 500: TestRepairErrorSchema}
|
||||
)
|
||||
|
|
@ -86,27 +118,46 @@ async def testgen_repair(
|
|||
)
|
||||
|
||||
cost = calculate_llm_cost(response.raw_response, HAIKU_MODEL)
|
||||
logging.debug(f"testgen_repair LLM cost: {cost}")
|
||||
logging.debug("testgen_repair LLM cost: %s", cost)
|
||||
|
||||
repair_text = response.content.strip()
|
||||
if result := extract_code_block_with_context(repair_text, language="python"):
|
||||
_before, repaired_code, _after = result
|
||||
else:
|
||||
return 500, TestRepairErrorSchema(error="Could not extract repaired code from LLM response")
|
||||
|
||||
# Splice only the flagged functions from the LLM output into the original test source,
|
||||
# keeping all unflagged functions untouched
|
||||
from core.languages.python.cst_utils import parse_module_to_cst # noqa: PLC0415
|
||||
from core.languages.python.testgen.instrumentation.edit_generated_test import find_and_replace_function
|
||||
from core.languages.python.testgen.postprocessing.postprocess_pipeline import postprocessing_testgen_pipeline
|
||||
from core.languages.python.testgen.instrumentation.edit_generated_test import ( # noqa: PLC0415
|
||||
find_and_replace_function,
|
||||
)
|
||||
from core.languages.python.testgen.postprocessing.postprocess_pipeline import ( # noqa: PLC0415
|
||||
postprocessing_testgen_pipeline,
|
||||
)
|
||||
from core.languages.python.testgen.validate import instrument_tests, validate_request_data # noqa: PLC0415
|
||||
from core.shared.testgen_models import TestGenSchema # noqa: PLC0415
|
||||
|
||||
try:
|
||||
repaired_cst = parse_module_to_cst(repaired_code)
|
||||
except Exception:
|
||||
logging.warning("LLM returned syntactically invalid repaired code, falling back to original")
|
||||
repaired_code, repaired_cst = _extract_and_validate(response.content.strip())
|
||||
|
||||
if repaired_cst is None:
|
||||
syntax_error = _get_syntax_error(repaired_code) if repaired_code else "No code block found"
|
||||
logging.info("Repair attempt produced invalid syntax, retrying: %s", syntax_error)
|
||||
messages.append(ChatCompletionAssistantMessageParam(role="assistant", content=response.content))
|
||||
messages.append(
|
||||
ChatCompletionUserMessageParam(
|
||||
role="user",
|
||||
content=f"The code you returned has a syntax error:\n```\n{syntax_error}\n```\n"
|
||||
"Please return the complete corrected file in a single ```python code block.",
|
||||
)
|
||||
)
|
||||
retry_response = await call_llm(
|
||||
llm=HAIKU_MODEL,
|
||||
messages=messages,
|
||||
call_type="testgen_repair_retry",
|
||||
trace_id=data.trace_id,
|
||||
user_id=request.user,
|
||||
context=obs_context,
|
||||
)
|
||||
cost += calculate_llm_cost(retry_response.raw_response, HAIKU_MODEL)
|
||||
repaired_code, repaired_cst = _extract_and_validate(retry_response.content.strip())
|
||||
|
||||
if repaired_cst is None:
|
||||
logging.warning("LLM returned syntactically invalid code after retry")
|
||||
return 500, TestRepairErrorSchema(error="LLM returned syntactically invalid code")
|
||||
from core.languages.python.cst_utils import parse_module_to_cst # noqa: PLC0415
|
||||
|
||||
original_cst = parse_module_to_cst(data.test_source)
|
||||
|
||||
# Extract repaired function nodes by name
|
||||
|
|
|
|||
Loading…
Reference in a new issue