Signed-off-by: Saurabh Misra <misra.saurabh1@gmail.com> Co-authored-by: saga4 <saga4@codeflashs-MacBook-Air.local> Co-authored-by: Sarthak Agarwal <sarthak.saga@gmail.com> Co-authored-by: Mohamed Ashraf <mohamedashrraf222@gmail.com> Co-authored-by: Aseem Saxena <aseem.bits@gmail.com>
291 lines
12 KiB
Python
291 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import re
|
|
from typing import TYPE_CHECKING
|
|
|
|
import sentry_sdk
|
|
from aiservice.analytics.posthog import ph
|
|
from aiservice.common_utils import validate_trace_id
|
|
from aiservice.env_specific import create_llm_client, debug_log_sensitive_data, llm_clients
|
|
from aiservice.models.aimodels import EXPLANATIONS_MODEL, LLM, calculate_llm_cost
|
|
from log_features.log_event import update_optimization_cost
|
|
from log_features.log_features import log_features
|
|
from ninja import NinjaAPI, Schema
|
|
from openai import OpenAIError
|
|
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
|
|
from packaging import version
|
|
|
|
if TYPE_CHECKING:
|
|
from aiservice.models.aimodels import LLM
|
|
from openai.types.chat import (
|
|
ChatCompletionAssistantMessageParam,
|
|
ChatCompletionFunctionMessageParam,
|
|
ChatCompletionToolMessageParam,
|
|
)
|
|
|
|
# Namespaced ninja sub-API; mounted by the project's URL configuration.
explanations_api = NinjaAPI(urls_namespace="explanations")
# Extracts the model's answer from between <explain>...</explain> tags.
# DOTALL lets the explanation span multiple lines; IGNORECASE tolerates tag casing.
explain_regex_pattern = re.compile(r"<explain>(.*)<\/explain>", re.DOTALL | re.IGNORECASE)
|
|
|
|
SYSTEM_PROMPT = """You are an expert software engineer who understands why programs run fast. You have deep expertise in data structures and algorithms.
|
|
|
|
Your goal is to explain why a piece of code is more performant than a baseline code, to make it easier for a developer to accept and merge the optimized version.
|
|
|
|
You are provided the following information to succeed in the explanation process -
|
|
|
|
- original_source_code: The baseline implementation of the code being optimized
|
|
- original_line_profiler_results - The results after running line_profiler on the original_source_code
|
|
- original_code_runtime - The runtime for the original_source_code
|
|
- optimized_source_code - This is the suggested optimized version of the original_source_code that you should explain.
|
|
- optimized_line_profiler_results - The results after running line_profiler on the optimized_source_code
|
|
- optimized_code_runtime - The runtime for the optimized_source_code
|
|
- speedup - The relative gain in runtime for the optimized_source_code
|
|
- annotated_tests - The regression tests that were run to test for performance and correctness, with runtime results annotated next to the respective test case.
|
|
- read_only_dependency_code - The READ ONLY dependencies for the code provided, to help you better understand the code being provided.
|
|
- original_explanation - The original explanation generated for the optimized_source_code. Note that the original_explanation may be out of sync as some of the micro-optimizations and irrelevant changes might have been reverted in the optimized_source_code.
|
|
- python_version - The version of python the code would be executed on.
|
|
- function_references - Python markdown blocks with filename and references of some functions which call the function being optimized. The filenames and/or references could indicate if the function being optimized is in a hot path. The reference could have the function being called from a place that is important, for example in a loop, which means the effect of optimization might be important.
|
|
|
|
Keep explanations **developer-focused and concise**. Focus on:
|
|
- **What** specific optimizations were applied.
|
|
- **Key changes** that affect behavior or dependencies.
|
|
- **Why** the specific optimization leads to a speedup based on your knowledge of performance in Python code.
|
|
- **How** the optimization could potentially impact existing workloads based on function_references which can help determine whether the function being optimized is called in a hot path or not, and if the context where the function is called may benefit from the optimization.
|
|
- What kind of test cases are the specific optimizations good for based on the annotated_tests results.
|
|
- Avoid mentioning obvious preservation details (file structure, imports, signatures) unless they were specifically modified.
|
|
|
|
Please provide your explanation in the following format:
|
|
|
|
<explain>
|
|
Your *Brief* explanation of why and how optimized_source_code is faster than original_source_code.
|
|
</explain>
|
|
"""
|
|
|
|
BASE_USER_PROMPT = """The original_source_code is as follows
|
|
|
|
<original_source_code>
|
|
```python
|
|
{original_source_code}
|
|
```
|
|
</original_source_code>
|
|
|
|
The optimized_source_code is as follows
|
|
|
|
<optimized_source_code>
|
|
```python
|
|
{optimized_source_code}
|
|
```
|
|
</optimized_source_code>
|
|
|
|
Here is the line profiler information for the original_source_code
|
|
|
|
<original_line_profiler_results>
|
|
{original_line_profiler_results}
|
|
</original_line_profiler_results>
|
|
|
|
Here is the line profiler information for the optimized_source_code
|
|
|
|
<optimized_line_profiler_results>
|
|
{optimized_line_profiler_results}
|
|
</optimized_line_profiler_results>
|
|
|
|
Here is the original_code_runtime
|
|
<original_code_runtime>
|
|
{original_code_runtime}
|
|
</original_code_runtime>
|
|
|
|
Here is the optimized_code_runtime
|
|
<optimized_code_runtime>
|
|
{optimized_code_runtime}
|
|
</optimized_code_runtime>
|
|
|
|
Here is the speedup
|
|
<speedup>
|
|
{speedup}
|
|
</speedup>
|
|
|
|
Here is the test function code with runtime results annotated next to the respective test case.
|
|
|
|
<annotated_tests>
|
|
{annotated_tests}
|
|
</annotated_tests>
|
|
|
|
Here is the read_only_dependency_code
|
|
|
|
<read_only_dependency_code>
|
|
{read_only_dependency_code}
|
|
</read_only_dependency_code>
|
|
|
|
Here is the original_explanation
|
|
<original_explanation>
|
|
{original_explanation}
|
|
</original_explanation>
|
|
|
|
Here is the python_version
|
|
<python_version>
|
|
{python_version}
|
|
</python_version>
|
|
|
|
Here is the function_references
|
|
<function_references>
|
|
{function_references}
|
|
</function_references>
|
|
|
|
"""
|
|
|
|
THROUGHPUT_PROMPT_SECTION = """Here is the original_throughput (operations per second)
|
|
<original_throughput>
|
|
{original_throughput}
|
|
</original_throughput>
|
|
|
|
Here is the optimized_throughput (operations per second)
|
|
<optimized_throughput>
|
|
{optimized_throughput}
|
|
</optimized_throughput>
|
|
|
|
Here is the throughput_improvement
|
|
<throughput_improvement>
|
|
{throughput_improvement}
|
|
</throughput_improvement>
|
|
|
|
"""
|
|
|
|
THROUGHPUT_SYSTEM_SECTION = """Additional throughput data is provided:
|
|
- original_throughput - The throughput (operations per second) for the original_source_code
|
|
- optimized_throughput - The throughput (operations per second) for the optimized_source_code
|
|
- throughput_improvement - The percentage improvement in throughput
|
|
|
|
When explaining optimizations:
|
|
- **Throughput improvements** - explain how the optimization affects the rate of operations/processing.
|
|
- When both runtime and throughput data are provided (for async functions), explain both metrics and how they relate to the optimization.
|
|
"""
|
|
|
|
|
|
async def explain_optimizations(  # noqa: D417
    user_id: str, data: ExplanationsSchema, explanations_model: LLM = EXPLANATIONS_MODEL
) -> ExplanationsResponseSchema | ExplanationsErrorResponseSchema:
    """Ask the explanations LLM why the optimized code outperforms the original.

    Builds a system/user prompt pair from the optimization artifacts in ``data``
    (source code, profiler results, runtimes, annotated tests, dependency code,
    ...), appends throughput sections when throughput data is present, and calls
    the configured chat-completions client with up to 2 retries.

    Parameters
    ----------
    user_id : str
        Identifier used for analytics (PostHog) events and sensitive-data logs.
    data : ExplanationsSchema
        Request payload carrying every artifact referenced by the prompt
        templates. NOTE: ``data.annotated_tests`` may be mutated in place
        (wrapped in a code fence) for older client versions.
    explanations_model : LLM
        Model descriptor to query; defaults to ``EXPLANATIONS_MODEL``.

    Returns
    -------
    ExplanationsResponseSchema
        On success, carrying the raw model output (including the
        ``<explain>`` tags — the caller extracts the tagged section).
    ExplanationsErrorResponseSchema
        When the LLM call raises ``OpenAIError`` (also reported to Sentry).
    """
    debug_log_sensitive_data(f"Generating an explanation for {user_id}:\n{data.optimized_code}")
    # Older codeflash clients (<= 0.18.2) sent annotated tests without a
    # markdown code fence; wrap them here so prompt formatting stays uniform.
    if version.parse(data.codeflash_version) <= version.parse("0.18.2") and data.annotated_tests:
        data.annotated_tests = f"```python\n{data.annotated_tests}\n```"
    user_prompt = BASE_USER_PROMPT.format(
        original_source_code=data.source_code,
        original_line_profiler_results=data.original_line_profiler_results or "[No profiler results available]",
        optimized_source_code=data.optimized_code,
        optimized_line_profiler_results=data.optimized_line_profiler_results or "[No profiler results available]",
        original_code_runtime=data.original_code_runtime,
        optimized_code_runtime=data.optimized_code_runtime,
        speedup=data.speedup,
        annotated_tests=data.annotated_tests,
        read_only_dependency_code=data.dependency_code or "[No read only code present]",
        original_explanation=data.original_explanation,
        python_version=data.python_version or "Not Available",
        function_references=data.function_references or "Not Available",
    )

    system_prompt = SYSTEM_PROMPT
    # Throughput is only measured for some optimizations (async functions, per
    # THROUGHPUT_SYSTEM_SECTION); extend both prompts when both readings exist.
    if data.original_throughput is not None and data.optimized_throughput is not None:
        user_prompt += THROUGHPUT_PROMPT_SECTION.format(
            original_throughput=data.original_throughput,
            optimized_throughput=data.optimized_throughput,
            throughput_improvement=data.throughput_improvement or "[Unable to calculate throughput improvement]",
        )
        system_prompt += "\n" + THROUGHPUT_SYSTEM_SECTION

    system_message = ChatCompletionSystemMessageParam(role="system", content=system_prompt)
    user_message = ChatCompletionUserMessageParam(role="user", content=user_prompt)
    debug_log_sensitive_data(f"{system_prompt}{user_prompt}")
    messages: list[
        ChatCompletionSystemMessageParam
        | ChatCompletionUserMessageParam
        | ChatCompletionAssistantMessageParam
        | ChatCompletionToolMessageParam
        | ChatCompletionFunctionMessageParam
    ] = [system_message, user_message]
    # Pick the concrete client (e.g. OpenAI-compatible) for this model type.
    llm_client = llm_clients[explanations_model.model_type]
    try:
        output = await llm_client.with_options(max_retries=2).chat.completions.create(
            model=explanations_model.name, messages=messages, n=1
        )
        # Record the spend for this trace as soon as we have token usage.
        await update_optimization_cost(trace_id=data.trace_id, cost=calculate_llm_cost(output, explanations_model))
    except OpenAIError as e:
        sentry_sdk.capture_exception(e)
        debug_log_sensitive_data(f"Failed to generate new explanation, Error message: {e}")
        return ExplanationsErrorResponseSchema(error=str(e))
    debug_log_sensitive_data(f"AIClient optimization response:\n{output}")
    if output.usage is not None:
        ph(
            user_id,
            "aiservice-optimize-openai-usage",
            properties={"model": explanations_model.name, "n": 1, "usage": output.usage.json()},
        )
    return ExplanationsResponseSchema(explanation=output.choices[0].message.content)
|
|
|
|
|
|
class ExplanationsSchema(Schema):
    """Request payload for the explanations endpoint.

    Carries every artifact needed to (re)generate a human-readable explanation
    of why an optimized piece of code outperforms the original; fields map
    one-to-one onto the placeholders in BASE_USER_PROMPT and
    THROUGHPUT_PROMPT_SECTION.
    """

    trace_id: str  # must be a valid UUIDv4 (checked by validate_trace_id in explain())
    source_code: str  # baseline implementation being optimized
    optimized_code: str  # candidate optimized implementation to explain
    original_line_profiler_results: str  # line_profiler output for source_code (may be empty)
    optimized_line_profiler_results: str  # line_profiler output for optimized_code (may be empty)
    original_code_runtime: str
    optimized_code_runtime: str
    speedup: str  # relative runtime gain of the optimized version
    annotated_tests: str  # regression tests annotated with per-test runtimes
    dependency_code: str | None  # READ ONLY dependency code for extra context
    optimization_id: str
    original_explanation: str  # previously generated explanation (may be out of sync)
    original_throughput: str | None = None  # ops/sec for source_code, when measured
    optimized_throughput: str | None = None  # ops/sec for optimized_code, when measured
    throughput_improvement: str | None = None  # percentage throughput gain, when computable
    python_version: str | None = None
    function_references: str | None = None  # markdown blocks of call sites of the optimized function
    codeflash_version: str = "0.18.2"  # client version; <= 0.18.2 gets annotated_tests code-fenced
|
|
|
|
|
|
class ExplanationsResponseSchema(Schema):
    """Success response: the generated explanation text."""

    explanation: str
|
|
|
|
|
|
class ExplanationsErrorResponseSchema(Schema):
    """Error response: a human-readable description of the failure."""

    error: str
|
|
|
|
|
|
@explanations_api.post(
    "/",
    response={
        200: ExplanationsResponseSchema,
        400: ExplanationsErrorResponseSchema,
        500: ExplanationsErrorResponseSchema,
    },
)
async def explain(
    request, data: ExplanationsSchema
) -> tuple[int, ExplanationsResponseSchema | ExplanationsErrorResponseSchema]:
    """Generate an explanation for an optimized piece of code.

    Validates the trace id, delegates to ``explain_optimizations``, and extracts
    the ``<explain>...</explain>`` section from the raw model output.

    Returns a ``(status_code, schema)`` tuple as expected by django-ninja:
    400 for an invalid trace id, 500 when the LLM call failed, 200 otherwise.
    """
    ph(request.user, "aiservice-explain-called")
    if not validate_trace_id(data.trace_id):
        return 400, ExplanationsErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")
    explanation_response = await explain_optimizations(request.user, data)
    if isinstance(explanation_response, ExplanationsErrorResponseSchema):
        ph(request.user, "Explanation not generated, revert to old explanation")
        debug_log_sensitive_data("No explanation was generated")
        return 500, ExplanationsErrorResponseSchema(error="Error generating optimizations. Internal server error.")
    ph(request.user, "explanation generated", properties={"explanation": explanation_response})
    # Parse the <explain> tag using the module-level precompiled pattern instead
    # of recompiling an identical regex here. If the model failed to wrap its
    # answer in <explain> tags, fall back to the full response text rather than
    # crashing — previously a missing match raised AttributeError on
    # `match.group(1)` and produced an unhandled server error.
    match = explain_regex_pattern.search(explanation_response.explanation)
    explanation = match.group(1) if match else explanation_response.explanation
    if hasattr(request, "should_log_features") and request.should_log_features:
        await log_features(trace_id=data.trace_id, user_id=request.user, final_explanation=explanation)
    return 200, ExplanationsResponseSchema(explanation=explanation)
|