codeflash-internal/django/aiservice/ranker/ranker.py

186 lines
8.3 KiB
Python
Raw Normal View History

2025-09-03 23:26:09 +00:00
from __future__ import annotations
import re
from typing import TYPE_CHECKING
from aiservice.analytics.posthog import ph
from aiservice.common_utils import validate_trace_id
from aiservice.env_specific import create_openai_client, debug_log_sensitive_data
2025-09-04 23:10:35 +00:00
from aiservice.models.aimodels import RANKING_MODEL, calculate_llm_cost
2025-09-03 23:26:09 +00:00
from log_features.log_event import update_optimization_cost
from log_features.log_features import log_features
from ninja import NinjaAPI, Schema
from openai import OpenAIError
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
2025-09-03 23:26:09 +00:00
if TYPE_CHECKING:
from aiservice.models.aimodels import LLM
2025-09-03 23:26:09 +00:00
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionFunctionMessageParam,
ChatCompletionToolMessageParam,
)
2025-09-04 23:10:35 +00:00
# Ninja API instance for the ranker endpoints (namespaced so its URL names
# don't clash with other NinjaAPI instances in the project).
ranker_api = NinjaAPI(urls_namespace="ranker")

# Pre-compiled patterns that extract the <rank>...</rank> and
# <explain>...</explain> sections from the model's response.
# DOTALL lets the captured payload span multiple lines.
rank_regex_pattern = re.compile(r"<rank>(.*)<\/rank>", re.DOTALL | re.IGNORECASE)
explain_regex_pattern = re.compile(r"<explain>(.*)<\/explain>", re.DOTALL | re.IGNORECASE)
# System prompt for the ranking model. The model must answer with a
# comma-separated ranking inside <rank> tags and a justification inside
# <explain> tags; those sections are parsed by rank_regex_pattern and
# explain_regex_pattern.
SYSTEM_PROMPT = """You are an expert code reviewer who understands why programs run fast.

You are provided with a list of optimization candidates with their code diff with respect to the baseline code and speedup ratio information. Your task is to rank the candidates in decreasing order of their viability as a pull request. Your goal is to improve the probability of acceptance of the optimization by an expert engineer.

Rules to follow while ranking optimization candidates -
- Prefer optimizations with higher speedup ratios. If the higher speedups happen due to strange hacks or micro-optimizations or something an expert won't write then prefer it less.
- Prefer optimizations which contain precise diffs unless the speedup provided is very high. Larger pull requests are typically harder to accept than more precise smaller pull requests.
- The optimization candidate should not impact the code readability unless the speedup provided is very high.
Sometimes, these criteria may be in conflict with each other. In such cases you have to remember that the goal is acceptance of the pull request, so make a judgement on what optimization candidate would be most likely to be accepted.
Please provide your response in the following format:

<rank>
Comma separated list of candidate indices in decreasing order of their viability as a pull request candidate.
</rank>
<explain>
A brief explanation of why the particular ranking was made.
</explain>
"""

# User prompt template; {ranking_context} is filled with the numbered list of
# candidate diffs and their speedup ratios.
USER_PROMPT = """Here is a numbered list of optimization candidates and their speedup ratios.

{ranking_context}

"""
async def rank_optimizations(
    user_id: str,
    speedups: list[float],
    diffs: list[str],
    optimization_ids: list[str],
    python_version: str = "3.12.9",
    rank_model: LLM = RANKING_MODEL,
    trace_id: str = "",
) -> RankResponseSchema | RankErrorResponseSchema:
    """Ask the ranking LLM to order optimization candidates by PR viability.

    Parameters
    ----------
    user_id : str
        Requesting user; used for analytics/logging only.
    speedups : list[float]
        Speedup ratio of each candidate (parallel to ``diffs``).
    diffs : list[str]
        Code diff of each candidate against the baseline.
    optimization_ids : list[str]
        Candidate ids parallel to ``diffs``. Not used inside this function;
        the HTTP caller maps the returned indices back to ids.
    python_version : str, optional
        Target python version. Default ``"3.12.9"``. Informational only —
        the prompt does not currently interpolate it.
    rank_model : LLM, optional
        Model used for ranking; defaults to ``RANKING_MODEL``.
    trace_id : str, optional
        Trace id used to attribute the LLM cost.

    Returns
    -------
    RankResponseSchema
        ``ranking`` is a 1-indexed permutation of ``1..len(diffs)`` in
        decreasing viability order, plus a free-text ``explanation``
        (empty string if the model omitted the <explain> section).
    RankErrorResponseSchema
        On LLM failure or when no valid ranking could be parsed.
    """
    debug_log_sensitive_data(f"Generating a ranking for {user_id}")

    # Build the numbered candidate list interpolated into the user prompt.
    # (join instead of += to avoid quadratic string building.)
    ranking_context = "".join(
        f"{i}. Diff:\n```diff\n{diff}\n```\nSpeedup: {speedup:.3f}\n"
        for i, (diff, speedup) in enumerate(zip(diffs, speedups, strict=False), start=1)
    )
    user_prompt = USER_PROMPT.format(ranking_context=ranking_context)

    system_message = ChatCompletionSystemMessageParam(role="system", content=SYSTEM_PROMPT)
    user_message = ChatCompletionUserMessageParam(role="user", content=user_prompt)
    debug_log_sensitive_data(f"{SYSTEM_PROMPT}{user_prompt}")
    messages: list[
        ChatCompletionSystemMessageParam
        | ChatCompletionUserMessageParam
        | ChatCompletionAssistantMessageParam
        | ChatCompletionToolMessageParam
        | ChatCompletionFunctionMessageParam
    ] = [system_message, user_message]

    async with create_openai_client() as claude_client:
        try:
            output = await claude_client.with_options(max_retries=2).chat.completions.create(
                model=rank_model.name, messages=messages, n=1
            )
            await update_optimization_cost(trace_id=trace_id, cost=calculate_llm_cost(output, rank_model))
        except OpenAIError as e:
            debug_log_sensitive_data(f"Failed to generate new explanation, Error message: {e}")
            return RankErrorResponseSchema(error=str(e))

    debug_log_sensitive_data(f"AIClient optimization response:\n{output}")
    if output.usage is not None:
        ph(
            user_id,
            "aiservice-optimize-openai-usage",
            properties={"model": rank_model.name, "n": 1, "usage": output.usage.model_dump_json()},
        )

    # content can be None (e.g. tool-call style responses); treat as empty
    # rather than letting the regex search raise TypeError.
    content = output.choices[0].message.content or ""

    # Explanation is best-effort: a missing <explain> tag must not abort the
    # request because the ranking is the important part.
    explanation_match = explain_regex_pattern.search(content)
    explanation = explanation_match.group(1) if explanation_match else ""
    # TODO: add logging when no explanation is found.

    # TODO: better parsing — could be other than a plain comma-separated list;
    # need to handle all edge cases.
    ranking_match = rank_regex_pattern.search(content)
    if ranking_match is None:
        # TODO: add logging when no ranking is found.
        return RankErrorResponseSchema(error="No ranking found")
    try:
        ranking = [int(token) for token in ranking_match.group(1).strip().split(",")]
    except ValueError:
        return RankErrorResponseSchema(error="No ranking found")

    # The ranking must be exactly a permutation of 1..N, otherwise it cannot
    # be trusted to index the candidates.
    if sorted(ranking) != list(range(1, len(diffs) + 1)):
        # TODO: add logging for invalid rankings; handle partial rankings.
        return RankErrorResponseSchema(error="No ranking found")
    return RankResponseSchema(ranking=ranking, explanation=explanation)
2025-09-03 23:26:09 +00:00
2025-09-04 23:10:35 +00:00
class RankInputSchema(Schema):
    """Request payload for the ranker endpoint."""

    # UUIDv4 trace id used to attribute LLM cost and feature logs.
    trace_id: str
    # Speedup ratio of each optimization candidate (parallel to `diffs`).
    speedups: list[float]
    # Code diff of each candidate against the baseline (parallel to `speedups`).
    diffs: list[str]
    optimization_ids: list[str]  # which diff corresponded to which opt candidate
    # Target python version string, e.g. "3.12.9".
    python_version: str
2025-09-04 23:10:35 +00:00
class RankResponseSchema(Schema):
    """Successful ranking result."""

    # Candidate indices in decreasing order of PR viability. 1-indexed as
    # produced by rank_optimizations(); the HTTP endpoint converts to 0-indexed
    # before responding.
    ranking: list[int]
    # Model's brief justification for the ordering (may be empty).
    explanation: str
2025-09-04 23:10:35 +00:00
class RankErrorResponseSchema(Schema):
    """Error payload returned when ranking fails (HTTP 400/500)."""

    # Human-readable error message.
    error: str
@ranker_api.post("/", response={200: RankResponseSchema, 400: RankErrorResponseSchema, 500: RankErrorResponseSchema})
async def rank(request, data: RankInputSchema) -> tuple[int, RankResponseSchema | RankErrorResponseSchema]:
    """HTTP entry point: rank a user's optimization candidates by PR viability."""
    ph(request.user, "aiservice-rank-called")

    # Reject malformed trace ids before doing any LLM work.
    if not validate_trace_id(data.trace_id):
        return 400, RankErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")

    result = await rank_optimizations(
        request.user,
        data.speedups,
        data.diffs,
        data.optimization_ids,
        python_version=data.python_version,
        trace_id=data.trace_id,
    )
    if isinstance(result, RankErrorResponseSchema):
        ph(request.user, "Invalid Ranking, fallback to default")
        debug_log_sensitive_data("No valid ranking was generated")
        return 500, RankErrorResponseSchema(error="Error generating ranking. Internal server error.")

    ph(request.user, "ranking generated", properties={"ranking": result})

    # rank_optimizations returns 1-indexed positions; API consumers expect 0-indexed.
    zero_indexed = [position - 1 for position in result.ranking]

    if getattr(request, "should_log_features", False):
        # Translate ranked positions back into optimization ids for feature logging.
        ordered_ids = [data.optimization_ids[idx] for idx in zero_indexed]
        await log_features(
            trace_id=data.trace_id,
            user_id=request.user,
            ranking={"ranking": ordered_ids, "explanation": result.explanation},
        )

    return 200, RankResponseSchema(explanation=result.explanation, ranking=zero_indexed)