codeflash-internal/django/aiservice/ranker/ranker.py

194 lines
9.4 KiB
Python
Raw Normal View History

2025-09-03 23:26:09 +00:00
from __future__ import annotations
import re
from typing import TYPE_CHECKING
2025-10-23 04:20:20 +00:00
import sentry_sdk
2025-09-03 23:26:09 +00:00
from aiservice.analytics.posthog import ph
from aiservice.common_utils import validate_trace_id
from aiservice.env_specific import create_llm_client, debug_log_sensitive_data, llm_clients
2025-11-04 08:27:58 +00:00
from aiservice.models.aimodels import LLM, RANKING_MODEL, calculate_llm_cost
2025-09-03 23:26:09 +00:00
from log_features.log_event import update_optimization_cost
from log_features.log_features import log_features
from ninja import NinjaAPI, Schema
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
2025-09-03 23:26:09 +00:00
if TYPE_CHECKING:
from aiservice.models.aimodels import LLM
2025-09-03 23:26:09 +00:00
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionFunctionMessageParam,
ChatCompletionToolMessageParam,
)
2025-09-04 23:10:35 +00:00
# NOTE(review): disabled sketch of a pydantic response model carrying the same two
# fields the ranker returns (ranking, explanation), next to a `google.genai` import —
# presumably groundwork for a structured-output path; confirm whether it should be
# revived or deleted.
# from google import genai
# from pydantic import BaseModel
#
# class RankerResponseSchema(BaseModel):
# ranking: list[int]
# explanation: str

# API object the ranking endpoint in this module is registered on; it is mounted
# under the "ranker" URL namespace.
ranker_api = NinjaAPI(urls_namespace="ranker")
2025-09-03 23:26:09 +00:00
# Patterns that pull the payload out of the <rank>...</rank> and <explain>...</explain>
# tags of the model reply. DOTALL lets a payload span several lines and IGNORECASE
# tolerates tag-case drift. The capture groups are non-greedy so that a reply which
# repeats a tag yields the first complete block rather than everything up to the
# last closing tag (which would make the ranking unparsable).
rank_regex_pattern = re.compile(r"<rank>(.*?)</rank>", re.DOTALL | re.IGNORECASE)
explain_regex_pattern = re.compile(r"<explain>(.*?)</explain>", re.DOTALL | re.IGNORECASE)
# System prompt for the ranking model: states the task, the extra context fields the
# user prompt supplies (python_version, function_references), the ranking rules, and
# the <rank>/<explain> reply format that the response parser in this module expects.
SYSTEM_PROMPT = """You are an expert code reviewer who understands why programs run fast.
You are provided with a list of optimization candidates with their code diff with respect to the baseline code and speedup ratio information. Your task is to rank the candidates in decreasing order of their viability as a pull request. Your goal is to improve the probability of acceptance of the optimization by an expert engineer.
You are also provided with the following information.
- python_version - The version of python the code would be executed on.
- function_references - Python markdown blocks with filename and references of some functions which call the function being optimized. The filenames and/or references could indicate if the function being optimized is in a hot path. The reference could have the function being called from a place that is important, for example in a loop, which means the effect of optimization might be important.
Rules to follow while ranking optimization candidates -
- Prefer optimizations with higher speedup ratios. If the higher speedups happen due to strange hacks or micro-optimizations or something an expert won't write then prefer it less.
- Prefer optimizations which contain precise diffs unless the speedup provided is very high. Larger pull requests are typically harder to accept than more precise smaller pull requests.
- Introduction of the `global` and `nonlocal` keywords in optimizations is **HIGHLY DISCOURAGED** as it reduces code clarity and maintainability, introduces hidden dependencies, can cause subtle bugs and breaks modularity. **DO NOT** prefer such optimizations.
- If the only optimizations are micro-optimizations like inlining a function call, or localizing variables or methods (not being used in a loop), especially with python_version older than 3.11, do not prefer the optimizations.
- The optimization candidate should not impact the code readability unless the speedup provided is very high.
Sometimes, these criteria may be in conflict with each other. In such cases you have to remember that the goal is acceptance of the pull request, so make a judgement on what optimization candidate would be most likely to be accepted.
Please provide your response in the following format:
<rank>
Comma separated list of candidate indices in decreasing order of their viability as a pull request candidate.
</rank>
<explain>
A brief explanation of why the particular ranking was made.
</explain>
"""
# User prompt template. It is rendered with str.format and takes three fields:
# ranking_context (the numbered diffs with speedups), python_version and
# function_references.
USER_PROMPT = """Here is a numbered list of optimization candidates' code diffs and their speedup ratios.
{ranking_context}
Here is the python version
{python_version}
Here are the function references
{function_references}
"""
2025-11-04 05:21:00 +00:00
async def rank_optimizations(  # noqa: D417
    user_id: str, data: RankInputSchema, rank_model: LLM = RANKING_MODEL
) -> RankResponseSchema | RankErrorResponseSchema:
    """Ask the ranking LLM to order optimization candidates by pull-request viability.

    Each diff in ``data.diffs`` is paired with the speedup at the same position in
    ``data.speedups``, numbered from 1, and rendered into the user prompt together
    with ``data.python_version`` and ``data.function_references`` ("Not available"
    when unset). The reply is parsed for a ``<rank>`` block (required) and an
    ``<explain>`` block (optional).

    Parameters
    ----------
    - user_id: identifier passed to the analytics call and the debug log.
    - data (RankInputSchema): candidates plus context; ``data.trace_id`` is used to
      record the cost of the LLM call.
    - rank_model (LLM): model to query; its ``model_type`` selects the client from
      ``llm_clients`` and its ``name`` is sent as the model name.

    Returns
    -------
    - RankResponseSchema: ``ranking`` holds the 1-based candidate indices in the
      order the model chose (validated to be a permutation of ``1..len(data.diffs)``)
      and ``explanation`` holds the text of the ``<explain>`` block, or ``""`` when
      the model did not provide one.
    - RankErrorResponseSchema: when there is nothing valid to rank, the LLM call (or
      its cost bookkeeping) raises, or the reply has no usable ranking. Exceptions
      from the LLM call are reported to Sentry and not re-raised.
    """
    debug_log_sensitive_data(f"Generating a ranking for {user_id}")

    candidate_count = len(data.diffs)
    if candidate_count == 0 or len(data.speedups) < candidate_count:
        # The reply is validated against every diff, so a diff without a speedup (or no
        # diffs at all) can never yield a valid ranking; fail before paying for a call.
        debug_log_sensitive_data(
            f"Cannot rank: {candidate_count} diffs supplied with {len(data.speedups)} speedups"
        )
        return RankErrorResponseSchema(error="No ranking found")

    # strict=False: surplus speedups beyond the last diff are ignored.
    ranking_context = "".join(
        f"{index}. Diff:\n```diff\n{diff}\n```\nSpeedup: {speedup:.3f}\n"
        for index, (diff, speedup) in enumerate(zip(data.diffs, data.speedups, strict=False), start=1)
    )
    user_prompt = USER_PROMPT.format(
        ranking_context=ranking_context,
        python_version=data.python_version or "Not available",
        function_references=data.function_references or "Not available",
    )
    system_message = ChatCompletionSystemMessageParam(role="system", content=SYSTEM_PROMPT)
    user_message = ChatCompletionUserMessageParam(role="user", content=user_prompt)
    debug_log_sensitive_data(f"{SYSTEM_PROMPT}{user_prompt}")
    messages: list[
        ChatCompletionSystemMessageParam
        | ChatCompletionUserMessageParam
        | ChatCompletionAssistantMessageParam
        | ChatCompletionToolMessageParam
        | ChatCompletionFunctionMessageParam
    ] = [system_message, user_message]

    llm_client = llm_clients[rank_model.model_type]
    try:
        output = await llm_client.with_options(max_retries=2).chat.completions.create(
            model=rank_model.name, messages=messages, n=1
        )
        # Cost bookkeeping stays inside the try block: a failure here is reported the
        # same way as a failed LLM call.
        await update_optimization_cost(trace_id=data.trace_id, cost=calculate_llm_cost(output, rank_model))
    except Exception as e:  # noqa: BLE001
        debug_log_sensitive_data(f"Failed to generate ranking, Error message: {e}")
        sentry_sdk.capture_exception(e)
        return RankErrorResponseSchema(error=str(e))

    debug_log_sensitive_data(f"AIClient optimization response:\n{output}")
    if output.usage is not None:
        ph(
            user_id,
            "aiservice-optimize-openai-usage",
            properties={"model": rank_model.name, "n": 1, "usage": output.usage.model_dump_json()},
        )

    # The message content may be absent (no choices, or a None/empty content field).
    content = output.choices[0].message.content if output.choices else None
    if not content:
        debug_log_sensitive_data("Ranking response has no message content")
        return RankErrorResponseSchema(error="No ranking found")

    # A missing explanation is tolerated because the ranking is the important part.
    explanation_match = explain_regex_pattern.search(content)
    if explanation_match is None:
        debug_log_sensitive_data("No explanation found")
        explanation = ""
    else:
        explanation = explanation_match.group(1)

    ranking_match = rank_regex_pattern.search(content)
    if ranking_match is None:
        debug_log_sensitive_data("No ranking found")
        return RankErrorResponseSchema(error="No ranking found")
    try:
        # Empty tokens (e.g. from a trailing comma) are skipped; int() itself tolerates
        # surrounding whitespace and newlines.
        ranking = [int(token) for token in ranking_match.group(1).split(",") if token.strip()]
    except ValueError:
        debug_log_sensitive_data(f"Ranking is not a list of integers: {ranking_match.group(1)!r}")
        return RankErrorResponseSchema(error="No ranking found")

    # Every candidate must appear exactly once: no duplicates, gaps or out-of-range indices.
    if sorted(ranking) != list(range(1, candidate_count + 1)):
        debug_log_sensitive_data(f"Invalid ranking: {ranking}")
        return RankErrorResponseSchema(error="No ranking found")
    return RankResponseSchema(ranking=ranking, explanation=explanation)
2025-09-03 23:26:09 +00:00
2025-09-04 23:10:35 +00:00
class RankInputSchema(Schema):
    """Request payload for the ranking endpoint."""

    # Identifier of the optimization run; the endpoint validates it before ranking.
    trace_id: str
    # Speedup of each candidate, positionally paired with `diffs`.
    speedups: list[float]
    # Diff of each candidate.
    diffs: list[str]
    optimization_ids: list[str]  # which diff corresponded to which opt candidate
    # Optional prompt context; "Not available" is sent to the model when unset.
    python_version: str | None = None
    function_references: str | None = None
2025-09-03 23:26:09 +00:00
2025-09-04 23:10:35 +00:00
class RankResponseSchema(Schema):
    """Successful ranking result: candidate order plus the model's explanation."""

    # Candidate indices, most viable first.
    ranking: list[int]
    # Text of the model's explanation; may be empty.
    explanation: str
2025-09-04 23:10:35 +00:00
class RankErrorResponseSchema(Schema):
    """Error payload returned when no ranking could be produced."""

    # Human-readable description of the failure.
    error: str
@ranker_api.post("/", response={200: RankResponseSchema, 400: RankErrorResponseSchema, 500: RankErrorResponseSchema})
async def rank(request, data: RankInputSchema) -> tuple[int, RankResponseSchema | RankErrorResponseSchema]:
    """Rank optimization candidates and return the order as 0-based indices.

    Returns a ``(status, payload)`` pair:

    - 400 with an error payload when ``data.trace_id`` fails ``validate_trace_id``.
    - 500 with a generic error payload when ``rank_optimizations`` returns an error.
    - 200 with the explanation and the ranking converted from the 1-based indices
      produced by ``rank_optimizations`` to 0-based indices.

    When the request carries a truthy ``should_log_features`` attribute, the ranking is
    also logged via ``log_features`` as the matching ``data.optimization_ids``.
    """
    ph(request.user, "aiservice-rank-called")
    if not validate_trace_id(data.trace_id):
        return 400, RankErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")

    ranking_response = await rank_optimizations(request.user, data)
    if isinstance(ranking_response, RankErrorResponseSchema):
        ph(request.user, "Invalid Ranking, fallback to default")
        debug_log_sensitive_data("No valid ranking was generated")
        # The detailed error is deliberately not forwarded to the client.
        return 500, RankErrorResponseSchema(error="Error generating ranking. Internal server error.")

    # NOTE(review): the schema object itself is passed as an event property — confirm
    # that `ph` can serialize it.
    ph(request.user, "ranking generated", properties={"ranking": ranking_response})
    ranking_with_0_idx = [position - 1 for position in ranking_response.ranking]

    if getattr(request, "should_log_features", False):
        if len(data.optimization_ids) < len(ranking_with_0_idx):
            # Feature logging is auxiliary: a short optimization_ids list must not turn
            # an otherwise successful ranking into an unhandled IndexError.
            debug_log_sensitive_data(
                "Skipping ranking feature logging: optimization_ids does not cover every ranked candidate"
            )
        else:
            ranked_opt_ids = [data.optimization_ids[index] for index in ranking_with_0_idx]
            await log_features(
                trace_id=data.trace_id,
                user_id=request.user,
                ranking={"ranking": ranked_opt_ids, "explanation": ranking_response.explanation},
            )

    response = RankResponseSchema(explanation=ranking_response.explanation, ranking=ranking_with_0_idx)
    return 200, response