2025-09-03 23:26:09 +00:00
from __future__ import annotations
import re
from typing import TYPE_CHECKING
from aiservice . analytics . posthog import ph
from aiservice . common_utils import validate_trace_id
2025-09-11 19:46:26 +00:00
from aiservice . env_specific import create_openai_client , debug_log_sensitive_data
2025-09-04 23:10:35 +00:00
from aiservice . models . aimodels import RANKING_MODEL , calculate_llm_cost
2025-09-03 23:26:09 +00:00
from log_features . log_event import update_optimization_cost
from log_features . log_features import log_features
2025-09-11 19:46:26 +00:00
from ninja import NinjaAPI , Schema
from openai import OpenAIError
from openai . types . chat import ChatCompletionSystemMessageParam , ChatCompletionUserMessageParam
2025-09-03 23:26:09 +00:00
if TYPE_CHECKING :
2025-09-11 19:46:26 +00:00
from aiservice . models . aimodels import LLM
2025-09-03 23:26:09 +00:00
from openai . types . chat import (
ChatCompletionAssistantMessageParam ,
ChatCompletionFunctionMessageParam ,
ChatCompletionToolMessageParam ,
)
2025-09-04 23:10:35 +00:00
# from google import genai
# from pydantic import BaseModel
#
# class RankerResponseSchema(BaseModel):
# ranking: list[int]
# explanation: str
2025-09-11 19:46:26 +00:00
# Namespaced router so the ranker endpoints don't collide with other NinjaAPI instances.
ranker_api = NinjaAPI(urls_namespace="ranker")

# Pre-compiled patterns for extracting the <rank>/<explain> sections from the LLM
# response. DOTALL lets the captured payload span multiple lines; IGNORECASE
# tolerates tag-casing drift in model output. (The previous `\/` escape was an
# unnecessary regex escape of a plain `/`.)
rank_regex_pattern = re.compile(r"<rank>(.*)</rank>", re.DOTALL | re.IGNORECASE)
explain_regex_pattern = re.compile(r"<explain>(.*)</explain>", re.DOTALL | re.IGNORECASE)

# System prompt for the ranking model. Note: it contains no format placeholders,
# so it must not be passed through str.format().
SYSTEM_PROMPT = """You are an expert code reviewer who understands why programs run fast.

You are provided with a list of optimization candidates with their code diff with respect to the baseline code and speedup ratio information. Your task is to rank the candidates in decreasing order of their viability as a pull request. Your goal is to improve the probability of acceptance of the optimization by an expert engineer.

Rules to follow while ranking optimization candidates -
- Prefer optimizations with higher speedup ratios. If the higher speedups happen due to strange hacks or micro-optimizations or something an expert won't write then prefer it less.
- Prefer optimizations which contain precise diffs unless the speedup provided is very high. Larger pull requests are typically harder to accept than more precise smaller pull requests.
- The optimization candidate should not impact the code readability unless the speedup provided is very high.

Sometimes, these criteria may be in conflict with each other. In such cases you have to remember that the goal is acceptance of the pull request, so make a judgement on what optimization candidate would be most likely to be accepted.

Please provide your response in the following format:
<rank>
Comma separated list of candidate indices in decreasing order of their viability as a pull request candidate.
</rank>
<explain>
A brief explanation of why the particular ranking was made.
</explain>
"""

# User prompt template; {ranking_context} is filled with the numbered candidate list.
USER_PROMPT = """Here is a numbered list of optimization candidates and their speedup ratios.

{ranking_context}
"""
async def rank_optimizations(
    user_id: str,
    speedups: list[float],
    diffs: list[str],
    optimization_ids: list[str],
    python_version: str = "3.12.9",
    rank_model: LLM = RANKING_MODEL,
    trace_id: str = "",
) -> RankResponseSchema | RankErrorResponseSchema:
    """Rank optimization candidates by their viability as a pull request.

    Builds a numbered context of candidate diffs and speedups, asks the ranking
    LLM for an ordering, and parses the ``<rank>``/``<explain>`` sections from
    the response.

    Parameters
    ----------
    user_id : str
        Identifier of the requesting user (analytics/logging only).
    speedups : list[float]
        Speedup ratio per candidate, parallel to ``diffs``.
    diffs : list[str]
        Code diff per candidate with respect to the baseline code.
    optimization_ids : list[str]
        Candidate ids parallel to ``diffs``. Currently unused here; callers use
        them to map the returned ranking back to candidates.
    python_version : str
        Target python version. Currently not fed into the prompt; kept for
        interface compatibility.
    rank_model : LLM
        Model used for ranking.
    trace_id : str
        Trace id used to attribute LLM cost.

    Returns
    -------
    RankResponseSchema | RankErrorResponseSchema
        The 1-indexed ranking (a permutation of ``1..len(diffs)``) plus an
        explanation on success, or an error schema when the LLM call fails or
        no valid ranking can be parsed.
    """
    debug_log_sensitive_data(f"Generating a ranking for {user_id}")
    # TODO add logging instead of print(optimization_ids)
    # NOTE: the previous no-op `SYSTEM_PROMPT.format(python_version=...)` was
    # removed — its result was discarded and SYSTEM_PROMPT has no placeholder.

    # Number candidates from 1 so they match the indices the model must emit.
    # NOTE(review): strict=False silently truncates when speedups and diffs
    # differ in length — confirm mismatched inputs are expected upstream.
    ranking_context = "".join(
        f"{idx}. Diff:\n```diff\n{diff}\n```\nSpeedup: {speedup:.3f}\n"
        for idx, (diff, speedup) in enumerate(zip(diffs, speedups, strict=False), start=1)
    )

    user_prompt = USER_PROMPT.format(ranking_context=ranking_context)

    system_message = ChatCompletionSystemMessageParam(role="system", content=SYSTEM_PROMPT)
    user_message = ChatCompletionUserMessageParam(role="user", content=user_prompt)
    debug_log_sensitive_data(f"{SYSTEM_PROMPT} {user_prompt}")
    messages: list[
        ChatCompletionSystemMessageParam
        | ChatCompletionUserMessageParam
        | ChatCompletionAssistantMessageParam
        | ChatCompletionToolMessageParam
        | ChatCompletionFunctionMessageParam
    ] = [system_message, user_message]

    async with create_openai_client() as claude_client:
        try:
            output = await claude_client.with_options(max_retries=2).chat.completions.create(
                model=rank_model.name, messages=messages, n=1
            )
            await update_optimization_cost(trace_id=trace_id, cost=calculate_llm_cost(output, rank_model))
        except OpenAIError as e:
            debug_log_sensitive_data(f"Failed to generate new explanation, Error message: {e}")
            return RankErrorResponseSchema(error=str(e))

    debug_log_sensitive_data(f"AIClient optimization response:\n{output}")
    if output.usage is not None:
        ph(
            user_id,
            "aiservice-optimize-openai-usage",
            properties={"model": rank_model.name, "n": 1, "usage": output.usage.model_dump_json()},
        )

    # Guard against a None message body before running the regexes.
    content = output.choices[0].message.content or ""

    # A missing explanation is tolerated because the ranking itself is the
    # essential output; the previous bare `except:` is replaced by an explicit
    # None check on the match.
    explanation_match = explain_regex_pattern.search(content)
    explanation = explanation_match.group(1) if explanation_match else ""
    if not explanation:
        # TODO add logging instead of print("No explanation found")
        pass

    # Without a parseable <rank> section the response is unusable.
    ranking_match = rank_regex_pattern.search(content)
    if ranking_match is None:
        # TODO add logging instead of print("No ranking found")
        return RankErrorResponseSchema(error="No ranking found")
    try:
        # TODO better parsing, could be only comma separated, need to handle all edge cases
        ranking = [int(part) for part in ranking_match.group(1).strip().split(",")]
    except ValueError:
        # TODO add logging instead of print("No ranking found")
        return RankErrorResponseSchema(error="No ranking found")

    # A valid ranking must be exactly a permutation of 1..len(diffs).
    if sorted(ranking) != list(range(1, len(diffs) + 1)):
        # TODO need to handle all edge cases
        # TODO add logging instead of print("Invalid ranking")
        return RankErrorResponseSchema(error="Invalid ranking")

    return RankResponseSchema(ranking=ranking, explanation=explanation)
2025-09-03 23:26:09 +00:00
2025-09-04 23:10:35 +00:00
class RankInputSchema(Schema):
    """Request payload for the ranking endpoint."""

    trace_id: str  # correlates this request across services; validated as UUIDv4 by the endpoint
    speedups: list[float]  # speedup ratio per candidate, parallel to diffs
    diffs: list[str]  # code diff per candidate with respect to the baseline
    optimization_ids: list[str]  # which diff corresponded to which opt candidate
    python_version: str  # target python version, e.g. "3.12.9"
2025-09-03 23:26:09 +00:00
2025-09-04 23:10:35 +00:00
class RankResponseSchema(Schema):
    """Successful ranking result."""

    ranking: list[int]  # candidate indices in decreasing order of PR viability
    explanation: str  # model's explanation of why this ordering was chosen
2025-09-04 23:10:35 +00:00
class RankErrorResponseSchema(Schema):
    """Error payload returned when ranking fails."""

    error: str  # human-readable failure reason
2025-09-11 19:46:26 +00:00
@ranker_api.post("/", response={200: RankResponseSchema, 400: RankErrorResponseSchema, 500: RankErrorResponseSchema})
async def rank(request, data: RankInputSchema) -> tuple[int, RankResponseSchema | RankErrorResponseSchema]:
    """Rank optimization candidates for the requesting user.

    Validates the trace id, delegates to ``rank_optimizations``, optionally
    logs the ranked optimization ids, and returns the ranking converted to
    0-indexed positions.

    Returns (status_code, schema): 200 with the ranking, 400 on a bad trace
    id, 500 when no valid ranking was generated.
    """
    ph(request.user, "aiservice-rank-called")
    if not validate_trace_id(data.trace_id):
        return 400, RankErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")
    ranking_response = await rank_optimizations(
        request.user,
        data.speedups,
        data.diffs,
        data.optimization_ids,
        python_version=data.python_version,
        trace_id=data.trace_id,
    )
    if isinstance(ranking_response, RankErrorResponseSchema):
        ph(request.user, "Invalid Ranking, fallback to default")
        debug_log_sensitive_data("No valid ranking was generated")
        return 500, RankErrorResponseSchema(error="Error generating ranking. Internal server error.")
    ph(request.user, "ranking generated", properties={"ranking": ranking_response})
    # The LLM ranking is 1-indexed; convert to 0-indexed candidate positions.
    ranking_with_0_idx = [x - 1 for x in ranking_response.ranking]
    if hasattr(request, "should_log_features") and request.should_log_features:
        # Map ranked positions back to optimization ids for feature logging.
        ranked_opt_ids = [data.optimization_ids[i] for i in ranking_with_0_idx]
        await log_features(trace_id=data.trace_id, user_id=request.user, ranking={'ranking': ranked_opt_ids, 'explanation': ranking_response.explanation})
    response = RankResponseSchema(explanation=ranking_response.explanation, ranking=ranking_with_0_idx)
    return 200, response