codeflash-internal/django/aiservice/code_repair/code_repair.py
Kevin Turcios 273edff3ab unify
2025-12-22 23:51:05 -05:00

209 lines
8.1 KiB
Python

from __future__ import annotations
import logging
import uuid
from pathlib import Path
from typing import TYPE_CHECKING
import libcst as cst
import sentry_sdk
from ninja import NinjaAPI, Schema
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
from pydantic import ValidationError
from aiservice.analytics.posthog import ph
from aiservice.common_utils import validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.llm import CODE_REPAIR_MODEL, calculate_llm_cost, call_llm
from log_features.log_event import update_optimization_cost
from log_features.log_features import log_features
from optimizer.models import OptimizedCandidateSource
from .code_repair_context import ( # noqa: TC001 (don't move CodeRepairRequestSchema to type-checking because it's the schema definition)
CodeRepairContext,
CodeRepairContextData,
CodeRepairRequestSchema,
)
if TYPE_CHECKING:
from django.http import HttpRequest as Request
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionFunctionMessageParam,
ChatCompletionToolMessageParam,
)
from aiservice.llm import LLM
code_repair_api = NinjaAPI(urls_namespace="code_repair")

# Prompt templates live next to this module and are read once, at import time.
current_dir = Path(__file__).parent
# Decode explicitly as UTF-8 so the prompts do not depend on the host's locale encoding.
SYSTEM_PROMPT = (current_dir / "CODE_REPAIR_SYSTEM_PROMPT.md").read_text(encoding="utf-8")
USER_PROMPT = (current_dir / "CODE_REPAIR_USER_PROMPT.md").read_text(encoding="utf-8")
async def code_repair(  # noqa: D417
    user_id: str, optimization_id: str, ctx: CodeRepairContext, optimize_model: LLM = CODE_REPAIR_MODEL
) -> CodeRepairIntermediateResponseItemschema | CodeRepairErrorResponseSchema:
    """Ask the LLM to repair a candidate so it matches the behaviour of the original code.

    Parameters
    ----------
    user_id
        Identifier of the requesting user; used as the id for the ``ph`` usage event.
    optimization_id
        Id of the candidate being repaired; it becomes ``parent_id`` of the result.
    ctx
        The repair context. ``ctx.data`` carries the original code, the optimized
        (modified) code and the behaviour test diffs. ``ctx`` also builds the prompts,
        extracts and applies the diff patches from the LLM response, and validates the result.
    optimize_model
        LLM used to generate the repair. Defaults to ``CODE_REPAIR_MODEL``.

    Returns
    -------
    CodeRepairIntermediateResponseItemschema | CodeRepairErrorResponseSchema
        On a successful LLM call, an item with a freshly generated ``optimization_id``,
        ``parent_id`` set to *optimization_id*, the repaired code, the LLM cost, and the
        raw LLM response as ``explanation``. ``source_code`` is ``""`` when the patches
        could not be extracted or applied, or when ``ctx.is_valid`` rejects the result.
        If the LLM call or the cost calculation raises, an error schema carrying the
        exception message.

    """
    system_prompt = ctx.get_system_prompt()
    user_prompt = ctx.get_user_prompt()
    # Id for the repaired candidate; the incoming optimization_id is recorded as its parent.
    new_op_id = str(uuid.uuid4())

    messages: list[
        ChatCompletionSystemMessageParam
        | ChatCompletionUserMessageParam
        | ChatCompletionAssistantMessageParam
        | ChatCompletionToolMessageParam
        | ChatCompletionFunctionMessageParam
    ] = [
        ChatCompletionSystemMessageParam(role="system", content=system_prompt),
        ChatCompletionUserMessageParam(role="user", content=user_prompt),
    ]
    debug_log_sensitive_data(f"This was the user prompt\n {user_prompt}\n")

    try:
        output = await call_llm(model_name=optimize_model.name, model_type=optimize_model.model_type, messages=messages)
        # The cost is computed inside the try, so a costing failure is reported like an LLM failure.
        llm_cost = calculate_llm_cost(output.raw_response, optimize_model)
    except Exception as e:
        logging.exception("Claude Code Generation error in code_repair")
        sentry_sdk.capture_exception(e)
        debug_log_sensitive_data(f"Failed to generate code for source:\n{ctx.data.original_source_code}")
        return CodeRepairErrorResponseSchema(error=str(e))

    debug_log_sensitive_data(f"ClaudeClient optimization response:\n{output.content}")
    if output.usage is not None:
        ph(
            user_id,
            "code_repair-usage",
            properties={
                "model": optimize_model.name,
                "usage": {"input_tokens": output.usage.input_tokens, "output_tokens": output.usage.output_tokens},
            },
        )

    # TODO: separating the prose from the search/replace blocks is not implemented yet
    # (regex extraction doesn't work), so the whole LLM response is used as the explanation.
    explanation = output.content
    try:
        diff_patches = ctx.extract_diff_patches_from_llm_res(output.content)
        repaired_optimization = ctx.apply_patches_to_optimized_code(diff_patches)
    except (ValueError, ValidationError) as exc:
        # Unusable patches are not fatal: report them and fall through with empty code.
        sentry_sdk.capture_exception(exc)
        debug_log_sensitive_data(f"{type(exc).__name__} for source:\n{ctx.data.modified_source_code}")
        debug_log_sensitive_data(f"Traceback: {exc}")
        repaired_optimization = ""

    # An empty source_code means no usable repair was produced.
    if not ctx.is_valid(repaired_optimization):
        repaired_optimization = ""

    return CodeRepairIntermediateResponseItemschema(
        optimization_id=new_op_id,
        parent_id=optimization_id,
        source_code=repaired_optimization,
        llm_cost=llm_cost,
        explanation=explanation,
    )
# Error payload: the 400/500 body of the code_repair endpoint, and what code_repair()
# returns when the LLM call fails. Documented with comments rather than a docstring
# because a class docstring would be emitted as the schema description.
class CodeRepairErrorResponseSchema(Schema):
    # Human-readable message, e.g. str() of the caught exception.
    error: str
# Internal result of code_repair(). It mirrors CodeRepairResponseItemschema and adds
# llm_cost, which repair() uses for cost accounting and does not send to the client.
class CodeRepairIntermediateResponseItemschema(Schema):
    # Freshly generated uuid4 string identifying the repaired candidate.
    optimization_id: str
    # optimization_id of the candidate that was repaired.
    parent_id: str
    # Repaired code; "" when patching failed or ctx.is_valid() rejected the result.
    source_code: str
    # Cost of the LLM call, as computed by calculate_llm_cost().
    llm_cost: float
    # Raw LLM response text (prose is not yet separated from the patch blocks).
    explanation: str
# 200 response body of the code_repair endpoint: one repaired candidate.
class CodeRepairResponseItemschema(Schema):
    # Id of the repaired candidate.
    optimization_id: str
    # Id of the candidate the repair was derived from.
    parent_id: str
    # Repaired code; may be "" when no usable repair was produced.
    source_code: str
    # LLM response text accompanying the repair.
    explanation: str
@code_repair_api.post(
    "/",
    response={
        200: CodeRepairResponseItemschema,
        400: CodeRepairErrorResponseSchema,
        500: CodeRepairErrorResponseSchema,
    },
)
async def repair(
    request: Request, data: CodeRepairRequestSchema
) -> tuple[int, CodeRepairResponseItemschema | CodeRepairErrorResponseSchema]:
    """Repair a candidate optimization whose behaviour diverged from the original code.

    Returns 400 for an invalid trace id. Returns 500 when the LLM call fails or the
    repair context fails validation. Otherwise returns 200 with the repaired candidate,
    after optionally logging features and recording the LLM cost for the trace.
    """
    ph(request.user, "aiservice-code_repair-called")

    trace_id = data.trace_id
    # Reject malformed trace ids before doing any work on the request.
    if not validate_trace_id(trace_id):
        return 400, CodeRepairErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")

    ctx_data = CodeRepairContextData(
        original_source_code=data.original_source_code,
        modified_source_code=data.modified_source_code,
        test_diffs=data.test_diffs,
    )
    ctx = CodeRepairContext(ctx_data=ctx_data, base_system_prompt=SYSTEM_PROMPT, base_user_prompt=USER_PROMPT)

    code_repair_data = await code_repair(user_id=request.user, optimization_id=data.optimization_id, ctx=ctx)
    if isinstance(code_repair_data, CodeRepairErrorResponseSchema):
        return 500, code_repair_data
    total_llm_cost = code_repair_data.llm_cost

    # NOTE(review): validate_python_module() takes no argument, so it is not obviously
    # checking code_repair_data.source_code (which the logs below print); confirm.
    try:
        ctx.validate_python_module()
    except cst.ParserSyntaxError as e:
        sentry_sdk.capture_exception(e)
        debug_log_sensitive_data(f"ParserSyntaxError for source:\n{code_repair_data.source_code}")
        debug_log_sensitive_data(f"Traceback: {e}")
        return 500, CodeRepairErrorResponseSchema(error=str(e))
    except (ValueError, ValidationError) as exc:
        # Pydantic / value validation failures raised by the context.
        sentry_sdk.capture_exception(exc)
        debug_log_sensitive_data(f"{type(exc).__name__} for source:\n{code_repair_data.source_code}")
        debug_log_sensitive_data(f"Traceback: {exc}")
        return 500, CodeRepairErrorResponseSchema(error=str(exc))

    # should_log_features is set on the request elsewhere (e.g. middleware); absent means "don't log".
    if getattr(request, "should_log_features", False):
        await log_features(
            trace_id=trace_id,
            user_id=request.user,
            optimizations_raw={code_repair_data.optimization_id: code_repair_data.source_code},
            explanations_raw={code_repair_data.optimization_id: code_repair_data.explanation},
            optimizations_origin={
                code_repair_data.optimization_id: {
                    "source": OptimizedCandidateSource.REPAIR,
                    "parent": code_repair_data.parent_id,
                }
            },
        )

    await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
    return 200, CodeRepairResponseItemschema(
        source_code=code_repair_data.source_code,
        optimization_id=code_repair_data.optimization_id,
        parent_id=code_repair_data.parent_id,
        explanation=code_repair_data.explanation,
    )