from __future__ import annotations

import re
from typing import TYPE_CHECKING

import sentry_sdk
from aiservice.analytics.posthog import ph
from aiservice.common_utils import validate_trace_id
from aiservice.env_specific import create_llm_client, debug_log_sensitive_data, llm_clients
from aiservice.models.aimodels import LLM, RANKING_MODEL, calculate_llm_cost
from log_features.log_event import update_optimization_cost
from log_features.log_features import log_features
from ninja import NinjaAPI, Schema
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam

if TYPE_CHECKING:
    from openai.types.chat import (
        ChatCompletionAssistantMessageParam,
        ChatCompletionFunctionMessageParam,
        ChatCompletionToolMessageParam,
    )

# from google import genai
# from pydantic import BaseModel
#
# class RankerResponseSchema(BaseModel):
#     ranking: list[int]
#     explanation: str


ranker_api = NinjaAPI(urls_namespace="ranker")
rank_regex_pattern = re.compile(r"<rank>(.*)</rank>", re.DOTALL | re.IGNORECASE)
explain_regex_pattern = re.compile(r"<explain>(.*)</explain>", re.DOTALL | re.IGNORECASE)
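
# Illustrative example (response text invented): for a model response like
#   "<rank>2, 1, 3</rank>\n<explain>Candidate 2 is the smallest diff ...</explain>"
# rank_regex_pattern captures "2, 1, 3" and explain_regex_pattern captures the
# explanation body. re.DOTALL lets the capture span newlines, and re.IGNORECASE
# tolerates tag variations like <RANK>.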

SYSTEM_PROMPT = """You are an expert code reviewer who understands why programs run fast.

You are provided with a list of optimization candidates, each with its code diff with respect to the baseline code and its speedup ratio. Your task is to rank the candidates in decreasing order of their viability as a pull request. Your goal is to improve the probability of acceptance of the optimization by an expert engineer.

You are also provided with the following information.
- python_version - The version of python the code would be executed on.
- function_references - Python markdown blocks with filename and references of some functions which call the function being optimized. The filenames and/or references could indicate if the function being optimized is in a hot path. The reference could show the function being called from a place that is important, for example in a loop, which means the effect of the optimization might be significant.

Rules to follow while ranking optimization candidates -
- Prefer optimizations with higher speedup ratios. If a higher speedup comes from strange hacks, micro-optimizations, or something an expert would not write, prefer it less.
- Prefer optimizations which contain precise diffs unless the speedup provided is very high. Larger pull requests are typically harder to accept than more precise, smaller pull requests.
- Introduction of the `global` and `nonlocal` keywords in optimizations is **HIGHLY DISCOURAGED** as it reduces code clarity and maintainability, introduces hidden dependencies, can cause subtle bugs and breaks modularity. **DO NOT** prefer such optimizations.
- If the only optimizations are micro-optimizations like inlining a function call, or localizing variables or methods (not being used in a loop), especially with python_version older than 3.11, do not prefer the optimizations.
- The optimization candidate should not impact the code readability unless the speedup provided is very high.

Sometimes these criteria may be in conflict with each other. In such cases, remember that the goal is acceptance of the pull request, so make a judgement on which optimization candidate would be most likely to be accepted.

Please provide your response in the following format:

<rank>
Comma separated list of candidate indices in decreasing order of their viability as a pull request candidate.
</rank>
<explain>
A brief explanation of why the particular ranking was made.
</explain>
"""

USER_PROMPT = """Here is a numbered list of optimization candidates' code diffs and their speedup ratios.

{ranking_context}

Here is the python version
{python_version}

Here are the function references
{function_references}
"""


async def rank_optimizations(  # noqa: D417
    user_id: str, data: RankInputSchema, rank_model: LLM = RANKING_MODEL
) -> RankResponseSchema | RankErrorResponseSchema:
    """Rank optimization candidates by their viability as a pull request using an LLM.

    Parameters
    ----------
    - user_id (str): The id of the user requesting the ranking.
    - data (RankInputSchema): Candidate speedups, diffs and optimization_ids, plus
      optional python_version and function_references context.
    - rank_model (LLM): The model used to produce the ranking. Defaults to RANKING_MODEL.

    Returns
    -------
    - RankResponseSchema with the 1-indexed ranking and an explanation, or
      RankErrorResponseSchema if the LLM call fails or no valid ranking could be parsed.

    """
    debug_log_sensitive_data(f"Generating a ranking for {user_id}")
    # TODO add logging instead of print(optimization_ids)
    ranking_context = ""
    for i, (diff, speedup) in enumerate(zip(data.diffs, data.speedups, strict=False)):
        ranking_context += f"{i + 1}. Diff:\n```diff\n{diff}\n```\nSpeedup: {speedup:.3f}\n"

    user_prompt = USER_PROMPT.format(
        ranking_context=ranking_context,
        python_version=data.python_version or "Not available",
        function_references=data.function_references or "Not available",
    )
    system_message = ChatCompletionSystemMessageParam(role="system", content=SYSTEM_PROMPT)
    user_message = ChatCompletionUserMessageParam(role="user", content=user_prompt)
    debug_log_sensitive_data(f"{SYSTEM_PROMPT}{user_prompt}")
    messages: list[
        ChatCompletionSystemMessageParam
        | ChatCompletionUserMessageParam
        | ChatCompletionAssistantMessageParam
        | ChatCompletionToolMessageParam
        | ChatCompletionFunctionMessageParam
    ] = [system_message, user_message]
    llm_client = llm_clients[rank_model.model_type]
    try:
        output = await llm_client.with_options(max_retries=2).chat.completions.create(
            model=rank_model.name, messages=messages, n=1
        )
        await update_optimization_cost(trace_id=data.trace_id, cost=calculate_llm_cost(output, rank_model))
    except Exception as e:  # noqa: BLE001
        debug_log_sensitive_data(f"Failed to generate a ranking, Error message: {e}")
        sentry_sdk.capture_exception(e)
        return RankErrorResponseSchema(error=str(e))
    debug_log_sensitive_data(f"AIClient ranking response:\n{output}")
    if output.usage is not None:
        ph(
            user_id,
            "aiservice-optimize-openai-usage",
            properties={"model": rank_model.name, "n": 1, "usage": output.usage.model_dump_json()},
        )
    # parse xml tags for explanation and ranking
    try:
        explanation_match = explain_regex_pattern.search(output.choices[0].message.content)
        explanation = explanation_match.group(1)
    except:  # noqa: E722
        # TODO add logging instead of print("No explanation found")
        # Keep going instead of returning because the ranking is what matters.
        explanation = ""
    try:
        ranking_match = rank_regex_pattern.search(output.choices[0].message.content)
        # TODO better parsing, could be only comma separated, need to handle all edge cases
        ranking = list(map(int, ranking_match.group(1).strip().split(",")))
    except:  # noqa: E722
        # TODO add logging instead of print("No ranking found")
        return RankErrorResponseSchema(error="No ranking found")
    if sorted(ranking) != list(range(1, len(data.diffs) + 1)):
        # TODO need to handle all edge cases
        # TODO add logging instead of print("Invalid ranking")
        return RankErrorResponseSchema(error="Invalid ranking")
    return RankResponseSchema(ranking=ranking, explanation=explanation)
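
# Illustrative parse (response invented): "<rank> 2, 1, 3 </rank>" yields
# ranking == [2, 1, 3]; with three candidates, sorted(ranking) == [1, 2, 3]
# passes the permutation check above, while "<rank>2, 2, 3</rank>" is rejected.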


class RankInputSchema(Schema):
    trace_id: str
    speedups: list[float]
    diffs: list[str]
    optimization_ids: list[str]  # which diff corresponded to which opt candidate
    python_version: str | None = None
    function_references: str | None = None
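
# Illustrative request payload for POST /ranker/ (all values invented):
# {
#     "trace_id": "123e4567-e89b-42d3-a456-426614174000",
#     "speedups": [1.85, 1.12],
#     "diffs": ["--- a/foo.py\n+++ b/foo.py\n@@ ...", "..."],
#     "optimization_ids": ["candidate-a", "candidate-b"],
#     "python_version": "3.12.9",
#     "function_references": "```python\n# callers of the optimized function\n```"
# }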


class RankResponseSchema(Schema):
    ranking: list[int]
    explanation: str


class RankErrorResponseSchema(Schema):
    error: str


@ranker_api.post("/", response={200: RankResponseSchema, 400: RankErrorResponseSchema, 500: RankErrorResponseSchema})
async def rank(request, data: RankInputSchema) -> tuple[int, RankResponseSchema | RankErrorResponseSchema]:
    ph(request.user, "aiservice-rank-called")
    if not validate_trace_id(data.trace_id):
        return 400, RankErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")
    ranking_response = await rank_optimizations(request.user, data)
    if isinstance(ranking_response, RankErrorResponseSchema):
        ph(request.user, "Invalid Ranking, fallback to default")
        debug_log_sensitive_data("No valid ranking was generated")
        return 500, RankErrorResponseSchema(error="Error generating ranking. Internal server error.")
    ph(request.user, "ranking generated", properties={"ranking": ranking_response})
    ranking_with_0_idx = [x - 1 for x in ranking_response.ranking]
    if hasattr(request, "should_log_features") and request.should_log_features:
        ranked_opt_ids = [data.optimization_ids[i] for i in ranking_with_0_idx]
        await log_features(
            trace_id=data.trace_id,
            user_id=request.user,
            ranking={"ranking": ranked_opt_ids, "explanation": ranking_response.explanation},
        )
    response = RankResponseSchema(explanation=ranking_response.explanation, ranking=ranking_with_0_idx)
    return 200, response