402 lines
20 KiB
Python
402 lines
20 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import re
|
|
import time
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING
|
|
from uuid import uuid4
|
|
import sentry_sdk
|
|
from aiservice.analytics.posthog import ph
|
|
from aiservice.common_utils import parse_python_version, validate_trace_id
|
|
from aiservice.env_specific import create_openai_client, debug_log_sensitive_data, create_claude_client, IS_PRODUCTION
|
|
from aiservice.models.aimodels import REFINEMENT_MODEL
|
|
from ninja import NinjaAPI, Schema
|
|
from openai import OpenAIError
|
|
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
|
|
|
|
from log_features.log_refinement import log_refinement
|
|
from optimizer.diff_patches_utils.patches_v2 import apply_patches
|
|
from pydantic import ValidationError
|
|
|
|
if TYPE_CHECKING:
|
|
from aiservice.models.aimodels import LLM
|
|
from openai.types.chat import (
|
|
ChatCompletionAssistantMessageParam,
|
|
ChatCompletionFunctionMessageParam,
|
|
ChatCompletionToolMessageParam,
|
|
)
|
|
|
|
|
|
# Router for the /refinement endpoints; namespaced so reverse-URL names don't
# clash with other NinjaAPI instances in the project.
refinement_api = NinjaAPI(urls_namespace="refinement")

# Get the directory of the current file
# NOTE(review): current_dir is not referenced anywhere in this module — looks like
# a leftover; confirm no other module imports it before removing.
current_dir = Path(__file__).parent
|
|
SYSTEM_PROMPT = """You are an expert software engineer who writes really fast programs and is an expert in optimizing runtime and memory requirements of a program by rewriting it. You have deep expertise in modern programming best practices, and clean code principles.
|
|
|
|
You are part of a code optimization system called Codeflash. Codeflash analyzes user code to optimize it for performance. It does so by first establishing the original code baseline by generating regression tests and discovering existing test cases, which finds the runtime of the original code and the behavior. Then, Codeflash generates multiple candidate optimizations which are applied to the codebase and ran to find the new runtime and behavior. Codeflash discards any optimizations that are either slower or have different behavior, in order to find the real optimizations.
|
|
|
|
Even though the optimizations may be technically correct, we want the quality of the optimizations to be really high, so your task is to refine the code of the optimizations so that they are expert quality.
|
|
|
|
The goal of refining the quality is to make the optimizations more precise. You want to reduce the number of lines and characters different between the optimizations and the original code to deliver very similar optimizations. These precise optimizations are highly preferred by the user and makes it is easier to accept the optimizations.
|
|
|
|
The refinement process should NEVER change the behavior of the optimization and we should try to preserve the main optimizations.
|
|
|
|
You are provided the following information to succeed in the quality refinement process -
|
|
|
|
- original_source_code - This is the original code being optimized
|
|
- optimized_source_code - This is the previously found optimization source code.
|
|
- original_line_profiler_results - The results after running line_profiler on the original_source_code
|
|
- optimized_line_profiler_results - The results after running line_profiler on the optimized_source_code
|
|
- optimization_speedup_results - The runtime for the original_source_code and optimized_source_code over a series of tests
|
|
- optimization_explanation - The explanation generated by codeflash earlier for the optimization in the optimized_source_code
|
|
- read_only_dependency_code - The READ ONLY dependencies for the code provided, to help you better understand the code being provided. Do no modify the code here, it is only provided for your reference.
|
|
|
|
Rules to follow while refining the quality of the optimized code -
|
|
- Analyze the original code and the optimized code and look at the line profiler info and the explanation to understand how the optimization works
|
|
- Figure out the code difference between the original_source_code and the optimized_source_code to see what part of the optimized_source_code is not contributing to the optimization. In such a case, we want to revert that part of the optimized_source_code to the original_source_code. It is okay to revert parts of the changes that aren't faster by at least 1%.
|
|
- If there are any changes in the optimized code that make that code section slower than the original then we want to revert such a change to the original.
|
|
- Revert the new comments in the optimized_source_code that are different from the original_source_code unless the new code is complex and requires additional context.
|
|
- The the variables names for the same logical variable in the optimized_source_code is different from the original_source_code then prefer the variable name in the original_source_code.
|
|
|
|
Request to replace sections of content in an existing file using SEARCH/REPLACE blocks that define exact changes to specific parts of the file.
|
|
|
|
TOOL USE
|
|
|
|
You have access to a set of tools that are don't need any approval to run and is required to use to succeed with the task.
|
|
|
|
Tool use is formatted using XML-style tags. The tool name is enclosed in opening and closing tags, and each parameter is similarly enclosed within its own set of tags. Here's the structure:
|
|
|
|
<tool_name>
|
|
<parameter1_name>value1</parameter1_name>
|
|
<parameter2_name>value2</parameter2_name>
|
|
...
|
|
</tool_name>
|
|
|
|
For example:
|
|
|
|
<replace_in_file>
|
|
<path>src/main.py</path>
|
|
<diff>
|
|
<<<<<<< SEARCH
|
|
a = 2
|
|
=======
|
|
a = 3
|
|
>>>>>>> REPLACE
|
|
</diff>
|
|
</replace_in_file>
|
|
|
|
|
|
Always adhere to this format for tool use to ensure proper parsing and execution.
|
|
|
|
# Tools
|
|
|
|
## replace_in_file
|
|
Description: Request to replace sections of content in an existing file using SEARCH/REPLACE blocks that define exact changes to specific parts of the file. This tool should be used when you need to make targeted changes to specific parts of a file.
|
|
Parameters:
|
|
- path: (required) The path of the file to modify
|
|
- diff: (required) One or more SEARCH/REPLACE blocks following this exact format:
|
|
```
|
|
<<<<<<< SEARCH
|
|
[exact content to find]
|
|
=======
|
|
[new content to replace with]
|
|
>>>>>>> REPLACE
|
|
```
|
|
Critical rules:
|
|
1. SEARCH content must match the associated file section to find EXACTLY:
|
|
* Match character-for-character including whitespace, indentation, line endings
|
|
* Include all comments, docstrings, etc.
|
|
2. SEARCH/REPLACE blocks will ONLY replace the first match occurrence.
|
|
* Including multiple unique SEARCH/REPLACE blocks if you need to make multiple changes.
|
|
* Include *just* enough lines in each SEARCH section to uniquely match each set of lines that need to change.
|
|
* When using multiple SEARCH/REPLACE blocks, list them in the order they appear in the file.
|
|
3. Keep SEARCH/REPLACE blocks concise:
|
|
* Break large SEARCH/REPLACE blocks into a series of smaller blocks that each change a small portion of the file.
|
|
* Include just the changing lines, and a few surrounding lines if needed for uniqueness.
|
|
* Do not include long runs of unchanging lines in SEARCH/REPLACE blocks.
|
|
* Each line must be complete. Never truncate lines mid-way through as this can cause matching failures.
|
|
4. Special operations:
|
|
* To move code: Use two SEARCH/REPLACE blocks (one to delete from original + one to insert at new location)
|
|
* To delete code: Use empty REPLACE section
|
|
Usage:
|
|
<replace_in_file>
|
|
<path>File path here</path>
|
|
<diff>
|
|
Search and replace blocks here
|
|
</diff>
|
|
</replace_in_file>
|
|
|
|
Please edit the optimized_source_code to implement the refinement process to improve the quality of the optimization.
|
|
"""
|
|
USER_PROMPT = """
|
|
Please edit the optimized_source_code to implement the refinement process to improve the quality of the optimization given the following information.
|
|
|
|
original_source_code
|
|
```python
|
|
{original_source_code}
|
|
```
|
|
Here is the line profiler information for the original_source_code
|
|
|
|
{original_line_profiler_results}
|
|
|
|
Here is the optimized_source_code
|
|
|
|
```python
|
|
{optimized_source_code}
|
|
```
|
|
|
|
Here is the explanation generated by Codeflash for the optimized_source_code -
|
|
|
|
{optimized_explanation}
|
|
|
|
Here is the line profiler information for the optimized_source_code
|
|
|
|
{optimized_line_profiler_results}
|
|
|
|
The original_source_code takes {original_code_runtime} to run and the optimized_source_code takes {optimized_code_runtime} to run, making it {speedup} faster.
|
|
|
|
Here is the read_only_dependency_code
|
|
|
|
{read_only_dependency_code}
|
|
"""
|
|
|
|
# Extracts the model's tool call: group(1) is the <path> target, group(2) is the raw
# SEARCH/REPLACE diff payload. re.S lets ".*" span newlines inside <diff>...</diff>.
replace_in_file_regex = re.compile(r"<replace_in_file>\s+<path>(.*?)<\/path>\s+<diff>(.*)<\/diff>\s+<\/replace_in_file>", re.S)
|
|
|
|
async def refinement(
    user_id: str,
    original_source_code: str,
    original_line_profiler_results: str,
    original_code_runtime: str,
    read_only_dependency_code: str,
    optimized_source_code: str,
    original_optimization_id: str,  # this is the optimization id generated for the /optimize not the refinement one
    optimized_line_profiler_results: str,
    optimized_code_runtime: str,
    speedup: str,
    n: int = 1,
    optimize_model: LLM = REFINEMENT_MODEL,
    lsp_mode: bool = False,
    optimized_explanation: str = "",
    multi_file_context: bool = False,  # TODO: support sending multiple files for refinement
) -> RefinementItem:
    """Refine a previously generated optimization to minimize its diff against the original code.

    Builds a prompt from the original/optimized source, their line_profiler results and the
    earlier Codeflash explanation, asks the refinement model for ``<replace_in_file>``
    SEARCH/REPLACE blocks, and applies those patches to ``optimized_source_code``.

    Parameters
    ----------
    - original_optimization_id: id from /optimize; the refined result reuses it with the
      last 4 characters replaced by "refi" so refined optimizations are identifiable.
    - n, lsp_mode, multi_file_context: currently unused; kept for interface stability.

    Returns
    -------
    RefinementItem whose ``refined_optimization`` is "" when no usable refinement was
    produced (model error, no tool call, empty diff, or a patch that changed nothing).
    """
    # TODO: Experiment with iterative approaches to optimization. Take the learnings from the
    # testing phase into the next optimization iteration.
    # TODO: Experiment with iterative chain-of-thought generation: ask what the function is
    # doing, then how to speed it up, then generate the optimization.
    print("/refinement")
    # (A dead no-op `SYSTEM_PROMPT.format()` whose result was discarded used to live here.)
    user_prompt = USER_PROMPT.format(
        original_source_code=original_source_code,
        original_line_profiler_results=original_line_profiler_results or "[No profiler results available]",
        optimized_source_code=optimized_source_code,
        optimized_line_profiler_results=optimized_line_profiler_results or "[No profiler results available]",
        optimized_explanation=optimized_explanation,
        original_code_runtime=original_code_runtime,
        optimized_code_runtime=optimized_code_runtime,
        speedup=speedup,
        read_only_dependency_code=read_only_dependency_code or "[No read only code present]",
    )
    system_message = ChatCompletionSystemMessageParam(role="system", content=SYSTEM_PROMPT)
    user_message = ChatCompletionUserMessageParam(role="user", content=user_prompt)
    messages: list[
        ChatCompletionSystemMessageParam
        | ChatCompletionUserMessageParam
        | ChatCompletionAssistantMessageParam
        | ChatCompletionToolMessageParam
        | ChatCompletionFunctionMessageParam
    ] = [system_message, user_message]
    debug_log_sensitive_data(f"This was the user prompt\n {user_prompt}\n")
    # Overwrite the id suffix so downstream consumers can tell a refined optimization apart.
    new_optimization_id = original_optimization_id[:-4] + "refi"  # to identify if the optimization is a refined one or not
    async with create_claude_client() as llm_client:
        try:
            output = await llm_client.with_options(max_retries=2).chat.completions.create(
                model=optimize_model.name, messages=messages, n=1
            )
        except OpenAIError as e:
            print("refinement api: OpenAI Code Generation error ...")
            print(e)
            debug_log_sensitive_data(f"Failed to generate code for source:\n{original_source_code}")
            return RefinementItem(
                optimization_id=new_optimization_id,
                original_optimization_id=original_optimization_id,
                raw_refinement_response=f"[Error] OpenAI code generation error.\n: {e}",
                optimized_source_code=optimized_source_code,
            )

    debug_log_sensitive_data(f"OpenAIClient optimization response:\n{output.model_dump_json(indent=2)}")

    if output.usage is not None:
        ph(user_id, "refinement-usage", properties={"model": optimize_model.name, "usage": output.usage.json()})

    results = [content for op in output.choices if (content := op.message.content)]  # will be of size 1
    if not results:
        # Every choice came back with empty content; indexing results[0] used to raise IndexError here.
        return RefinementItem(
            optimization_id=new_optimization_id,
            original_optimization_id=original_optimization_id,
            raw_refinement_response="[Error] Empty completion from the refinement model.",
            optimized_source_code=optimized_source_code,
        )
    match = replace_in_file_regex.search(results[0])

    refined_optimization: str = optimized_source_code
    search_replace_block = ""
    if match:
        search_replace_block = match.group(2)
        # match.group(1) is the <path> the model targeted; unused until multi-file
        # refinement is supported.
        try:
            refined_optimization = apply_patches(search_replace_block, optimized_source_code)
        except (ValueError, ValidationError) as exc:
            sentry_sdk.capture_exception(exc)
            debug_log_sensitive_data(f"{type(exc).__name__} for source:\n{optimized_source_code}")
            debug_log_sensitive_data(f"Traceback: {exc}")

    if refined_optimization.strip() == optimized_source_code.strip() or search_replace_block.strip() == "":
        # ignore the empty diff patches refinements
        refined_optimization = ""
    return RefinementItem(
        optimization_id=new_optimization_id,
        original_optimization_id=original_optimization_id,
        raw_refinement_response=results[0],
        diff_patches=search_replace_block,
        refined_optimization=refined_optimization,
        optimized_source_code=optimized_source_code,
    )
|
|
|
|
|
|
class RefinementRequestSchema(Schema):
    """One refinement job; the /refinement endpoint accepts a list of these."""

    trace_id: str  # UUIDv4 tying this request to the wider optimization trace (validated in refine())
    optimization_id: str  # id of the /optimize candidate to refine
    function_to_optimize: str  # name of the function being optimized (forwarded to feature logging)
    python_version: str  # client's Python version string; must parse as "3.x.x"
    experiment_metadata: dict[str, str] | None = None  # optional experiment tags; unused in this module
    codeflash_version: str | None = None  # client library version, if reported
    original_source_code: str  # the code being optimized, pre-optimization
    original_line_profiler_results: str = ""  # line_profiler output for the original code, if available
    read_only_dependency_code: str  # dependency context; reference only, never modified
    optimized_source_code: str  # the candidate optimization to refine
    optimized_line_profiler_results: str = ""  # line_profiler output for the optimized code, if available
    optimized_explanation: str  # Codeflash's explanation of the optimization
    original_code_runtime: str = ""  # measured runtime of the original code (pre-formatted string)
    optimized_code_runtime: str = ""  # measured runtime of the optimized code
    speedup: str = ""  # pre-formatted speedup figure, interpolated into the prompt
    repo_owner: str | None = None  # repository owner, forwarded to feature logging
    repo_name: str | None = None  # repository name, forwarded to feature logging
|
|
|
|
class RefinementItem(Schema):
    """Result of a single refinement attempt (produced by refinement())."""

    optimization_id: str = ""  # ends with refi
    original_optimization_id: str = ""  # the /optimize id this refinement derives from
    raw_refinement_response: str = ""  # full model response text, or an "[Error] ..." marker on failure
    diff_patches: str = ""  # raw SEARCH/REPLACE payload extracted from the response, if any
    refined_optimization: str = ""  # patched source code; "" when the refinement was empty or a no-op
    optimized_source_code: str = ""  # the input optimization, echoed back unchanged
|
|
|
|
class OptimizeErrorResponseSchema(Schema):
    """Error payload returned for invalid requests (bad Python version or trace id)."""

    error: str  # human-readable description of what was wrong
|
|
|
|
|
|
class RefinementResponseItemschema(Schema):
    """One refined optimization in the endpoint response.

    NOTE(review): class name should be PascalCase (RefinementResponseItemSchema);
    left as-is because renaming would break external references.
    """

    # the key will be the optimization id and the value will be the actual refined code
    explanation: str  # echoes the request's optimized_explanation (see refine())
    optimization_id: str  # NOTE(review): refine() fills this with the ORIGINAL id, not the "...refi" one — confirm intended
    source_code: str  # the refined source code
|
|
|
|
class Refinementschema(Schema):
    """Top-level 200 response: the surviving (non-empty, de-duplicated) refinements."""

    # the key will be the optimization id and the value will be the actual refined code
    refinements: list[RefinementResponseItemschema]
|
|
|
|
|
|
@refinement_api.post("/", response={200: Refinementschema, 400: OptimizeErrorResponseSchema, 500: OptimizeErrorResponseSchema})
async def refine(request, data: list[RefinementRequestSchema]) -> tuple[int, Refinementschema | OptimizeErrorResponseSchema]:
    """Batch refinement endpoint.

    Refines every submitted optimization concurrently via refinement(), optionally logs the
    results, and returns the non-empty, de-duplicated refinements. Shared fields (trace_id,
    repo info, function name, Python version) are read from the first list element.

    Fixes vs. previous revision: the 400/500 response schemas now match the
    OptimizeErrorResponseSchema actually returned; the event-loop-blocking time.sleep is
    now await asyncio.sleep; the bare `except:` is narrowed; whitespace-only refinements
    no longer slip through the filter; sensitive output goes through debug_log_sensitive_data
    instead of print().
    """
    ph(request.user, "aiservice-refinement-called")
    try:
        # Parsed only to validate the client's Python version; the value itself is unused for now.
        python_version: tuple[int, int, int] = parse_python_version(data[0].python_version)  # noqa: F841
    except Exception:  # was a bare `except:`; Exception lets SystemExit/KeyboardInterrupt propagate
        return 400, OptimizeErrorResponseSchema(
            error="Invalid Python version, it should look like 3.x.x. We only support Python 3.9 and above."
        )
    trace_id = data[0].trace_id
    repo_owner = data[0].repo_owner
    repo_name = data[0].repo_name
    function_to_optimize = data[0].function_to_optimize

    if not validate_trace_id(trace_id):
        return 400, OptimizeErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")
    refinement_futures = []
    for item in data:
        refinement_futures.append(refinement(
            user_id=request.user,
            original_optimization_id=item.optimization_id,
            original_source_code=item.original_source_code,
            original_line_profiler_results=item.original_line_profiler_results,
            original_code_runtime=item.original_code_runtime,
            optimized_source_code=item.optimized_source_code,
            optimized_line_profiler_results=item.optimized_line_profiler_results,
            optimized_explanation=item.optimized_explanation,
            optimized_code_runtime=item.optimized_code_runtime,
            speedup=item.speedup,
            read_only_dependency_code=item.read_only_dependency_code,
        ))
        # Stagger scheduling to ease rate limits; asyncio.sleep instead of the original
        # time.sleep, which blocked the whole event loop for every request on this worker.
        await asyncio.sleep(0.01)  # TODO remove later
    refinement_data = await asyncio.gather(*refinement_futures)
    should_log_features = getattr(request, "should_log_features", False)
    # TODO introduce filtering pipeline present in optimizer.py, need to think carefully what we need
    if should_log_features:
        # NOTE(review): asyncio keeps only a weak reference to tasks; hold a local reference
        # so the logging task is less likely to be garbage-collected before it runs.
        log_task = asyncio.create_task(log_refinement(  # noqa: F841
            trace_id=trace_id,
            user_id=request.user,
            repo_owner=repo_owner,
            repo_name=repo_name,
            function_to_optimize=function_to_optimize,
            refinement_responses=[r.dict() for r in refinement_data],  # make it serializable
        ))
    refined_optimizations = [
        RefinementResponseItemschema(
            explanation=data[i].optimized_explanation,
            optimization_id=elem.original_optimization_id,
            source_code=elem.refined_optimization,
        )
        for i, elem in enumerate(refinement_data)
    ]
    # Simple filtering: drop empty/whitespace-only refinements, then de-duplicate on the
    # stripped source (the old `!= ""` check let whitespace-only strings through).
    filtered_refined_optimizations = []
    seen_source_codes = set()
    for elem in refined_optimizations:
        stripped = elem.source_code.strip()
        if stripped and stripped not in seen_source_codes:
            seen_source_codes.add(stripped)
            filtered_refined_optimizations.append(elem)
    # Refined code is user source code: log via the sensitive-data helper, not stdout.
    debug_log_sensitive_data(f"Filtered refinements: {filtered_refined_optimizations}")
    return 200, Refinementschema(refinements=filtered_refined_optimizations)
|