209 lines
8.1 KiB
Python
209 lines
8.1 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
import uuid
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING
|
|
|
|
import libcst as cst
|
|
import sentry_sdk
|
|
from ninja import NinjaAPI, Schema
|
|
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
|
|
from pydantic import ValidationError
|
|
|
|
from aiservice.analytics.posthog import ph
|
|
from aiservice.common_utils import validate_trace_id
|
|
from aiservice.env_specific import debug_log_sensitive_data
|
|
from aiservice.llm import CODE_REPAIR_MODEL, calculate_llm_cost, call_llm
|
|
from log_features.log_event import update_optimization_cost
|
|
from log_features.log_features import log_features
|
|
from optimizer.models import OptimizedCandidateSource
|
|
|
|
from .code_repair_context import ( # noqa: TC001 (don't move CodeRepairRequestSchema to type-checking because it's the schema definition)
|
|
CodeRepairContext,
|
|
CodeRepairContextData,
|
|
CodeRepairRequestSchema,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from django.http import HttpRequest as Request
|
|
from openai.types.chat import (
|
|
ChatCompletionAssistantMessageParam,
|
|
ChatCompletionFunctionMessageParam,
|
|
ChatCompletionToolMessageParam,
|
|
)
|
|
|
|
from aiservice.llm import LLM
|
|
|
|
code_repair_api = NinjaAPI(urls_namespace="code_repair")

# Prompt templates live next to this module and are read once, at import time.
current_dir = Path(__file__).parent
# Decode explicitly as UTF-8 so the markdown prompts load identically regardless of the
# host's locale/preferred encoding (read_text() without `encoding` uses the locale default).
SYSTEM_PROMPT = (current_dir / "CODE_REPAIR_SYSTEM_PROMPT.md").read_text(encoding="utf-8")

USER_PROMPT = (current_dir / "CODE_REPAIR_USER_PROMPT.md").read_text(encoding="utf-8")
|
|
|
|
|
|
async def code_repair(  # noqa: D417
    user_id: str, optimization_id: str, ctx: CodeRepairContext, optimize_model: LLM = CODE_REPAIR_MODEL
) -> CodeRepairIntermediateResponseItemschema | CodeRepairErrorResponseSchema:
    """Ask the LLM to repair an optimized candidate so it behaves like the original code.

    Parameters
    ----------
    user_id
        Identifier of the requesting user; used for the PostHog usage event.
    optimization_id
        Id of the candidate being repaired; it becomes ``parent_id`` of the result.
    ctx
        Repair context. It builds the system/user prompts, extracts diff patches from the
        LLM reply, applies them to the optimized code and validates the outcome.
        ``ctx.data`` carries the original source code, the optimized (modified) source
        code and the behaviour test diffs.
    optimize_model
        LLM used for the repair call. Defaults to ``CODE_REPAIR_MODEL``.

    Returns
    -------
    CodeRepairIntermediateResponseItemschema
        When the LLM call succeeded. ``source_code`` is ``""`` if the patches could not be
        extracted/applied or the patched code was rejected by ``ctx.is_valid``.
    CodeRepairErrorResponseSchema
        When the LLM call or its cost calculation raised.

    """
    system_text = ctx.get_system_prompt()
    user_text = ctx.get_user_prompt()
    repaired_id = str(uuid.uuid4())

    conversation: list[
        ChatCompletionSystemMessageParam
        | ChatCompletionUserMessageParam
        | ChatCompletionAssistantMessageParam
        | ChatCompletionToolMessageParam
        | ChatCompletionFunctionMessageParam
    ] = [
        ChatCompletionSystemMessageParam(role="system", content=system_text),
        ChatCompletionUserMessageParam(role="user", content=user_text),
    ]
    debug_log_sensitive_data(f"This was the user prompt\n {user_text}\n")

    try:
        llm_output = await call_llm(
            model_name=optimize_model.name, model_type=optimize_model.model_type, messages=conversation
        )
        repair_cost = calculate_llm_cost(llm_output.raw_response, optimize_model)
    except Exception as llm_error:
        # Boundary around the external LLM call: report the failure and hand back an
        # error payload instead of propagating the exception.
        logging.exception("Claude Code Generation error in code_repair")
        sentry_sdk.capture_exception(llm_error)
        debug_log_sensitive_data(f"Failed to generate code for source:\n{ctx.data.original_source_code}")
        return CodeRepairErrorResponseSchema(error=str(llm_error))

    debug_log_sensitive_data(f"ClaudeClient optimization response:\n{llm_output.content}")

    token_usage = llm_output.usage
    if token_usage is not None:
        usage_props = {
            "model": optimize_model.name,
            "usage": {"input_tokens": token_usage.input_tokens, "output_tokens": token_usage.output_tokens},
        }
        ph(user_id, "code_repair-usage", properties=usage_props)

    # The whole reply doubles as the explanation: isolating the text outside the
    # search/replace blocks is not implemented yet.
    explanation = llm_output.content

    try:
        patches = ctx.extract_diff_patches_from_llm_res(llm_output.content)
        candidate_code = ctx.apply_patches_to_optimized_code(patches)
    except (ValueError, ValidationError) as patch_error:
        # Malformed or inapplicable patches are not fatal; they just yield an empty candidate.
        sentry_sdk.capture_exception(patch_error)
        debug_log_sensitive_data(f"{type(patch_error).__name__} for source:\n{ctx.data.modified_source_code}")
        debug_log_sensitive_data(f"Traceback: {patch_error}")
        candidate_code = ""

    # ctx.is_valid has the final say, even on the empty fallback.
    if not ctx.is_valid(candidate_code):
        candidate_code = ""

    return CodeRepairIntermediateResponseItemschema(
        optimization_id=repaired_id,
        parent_id=optimization_id,
        source_code=candidate_code,
        llm_cost=repair_cost,
        explanation=explanation,
    )
|
|
|
|
|
|
class CodeRepairErrorResponseSchema(Schema):
    # Error payload: returned by code_repair() when the LLM call fails, and used as the
    # 400/500 response body of the repair endpoint.
    error: str  # human-readable message (str() of the underlying exception, or a validation message)
|
|
|
|
|
|
class CodeRepairIntermediateResponseItemschema(Schema):
    # Internal result of code_repair(). The endpoint uses llm_cost for cost tracking and
    # returns the remaining fields to the client as CodeRepairResponseItemschema.
    optimization_id: str  # freshly generated uuid4 string identifying the repaired candidate
    parent_id: str  # id of the candidate that was repaired
    source_code: str  # repaired code; "" when patching failed or the result was rejected as invalid
    llm_cost: float  # cost of the repair LLM call, as computed by calculate_llm_cost
    explanation: str  # full content of the LLM reply
|
|
|
|
|
|
class CodeRepairResponseItemschema(Schema):
    # Successful (200) response body of the repair endpoint.
    optimization_id: str  # id of the repaired candidate
    parent_id: str  # id of the candidate that was repaired
    source_code: str  # repaired source code; may be "" when no valid repair was produced
    explanation: str  # LLM reply explaining the repair
|
|
|
|
|
|
@code_repair_api.post(
    "/",
    response={
        200: CodeRepairResponseItemschema,
        400: CodeRepairErrorResponseSchema,
        500: CodeRepairErrorResponseSchema,
    },
)
async def repair(
    request: Request, data: CodeRepairRequestSchema
) -> tuple[int, CodeRepairResponseItemschema | CodeRepairErrorResponseSchema]:
    """Repair an optimized candidate so that it matches the behaviour of the original code."""
    ph(request.user, "aiservice-code_repair-called")

    # Reject bad trace ids before doing any work.
    trace_id = data.trace_id
    if not validate_trace_id(trace_id):
        return 400, CodeRepairErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")

    ctx_data = CodeRepairContextData(
        original_source_code=data.original_source_code,
        modified_source_code=data.modified_source_code,
        test_diffs=data.test_diffs,
    )
    ctx = CodeRepairContext(ctx_data=ctx_data, base_system_prompt=SYSTEM_PROMPT, base_user_prompt=USER_PROMPT)

    code_repair_data = await code_repair(user_id=request.user, optimization_id=data.optimization_id, ctx=ctx)
    if isinstance(code_repair_data, CodeRepairErrorResponseSchema):
        # code_repair only returns this when the LLM call (or its cost calculation) raised,
        # so there is no computed cost to record.
        return 500, code_repair_data
    total_llm_cost = code_repair_data.llm_cost

    try:
        ctx.validate_python_module()
    except (cst.ParserSyntaxError, ValueError, ValidationError) as exc:
        sentry_sdk.capture_exception(exc)
        debug_log_sensitive_data(f"{type(exc).__name__} for source:\n{code_repair_data.source_code}")
        debug_log_sensitive_data(f"Traceback: {exc}")
        # The repair LLM call has already been paid for; record its cost even though
        # the candidate is rejected.
        await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
        return 500, CodeRepairErrorResponseSchema(error=str(exc))

    # Feature logging is opt-in per request; the attribute may be absent entirely.
    if getattr(request, "should_log_features", False):
        await log_features(
            trace_id=trace_id,
            user_id=request.user,
            optimizations_raw={code_repair_data.optimization_id: code_repair_data.source_code},
            explanations_raw={code_repair_data.optimization_id: code_repair_data.explanation},
            optimizations_origin={
                code_repair_data.optimization_id: {
                    "source": OptimizedCandidateSource.REPAIR,
                    "parent": code_repair_data.parent_id,
                }
            },
        )

    await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
    return 200, CodeRepairResponseItemschema(
        source_code=code_repair_data.source_code,
        optimization_id=code_repair_data.optimization_id,
        parent_id=code_repair_data.parent_id,
        explanation=code_repair_data.explanation,
    )
|