codeflash-internal/django/aiservice/optimizer/refinement.py

403 lines
20 KiB
Python
Raw Normal View History

2025-07-15 05:17:51 +00:00
from __future__ import annotations
2025-07-15 05:50:29 +00:00
import asyncio
2025-07-15 05:17:51 +00:00
import re
2025-07-22 20:29:00 +00:00
import time
2025-07-15 05:17:51 +00:00
from pathlib import Path
from typing import TYPE_CHECKING
2025-07-22 20:29:00 +00:00
from uuid import uuid4
2025-07-15 05:17:51 +00:00
import sentry_sdk
from aiservice.analytics.posthog import ph
from aiservice.common_utils import parse_python_version, validate_trace_id
from aiservice.env_specific import create_openai_client, debug_log_sensitive_data, create_claude_client, IS_PRODUCTION
from aiservice.models.aimodels import REFINEMENT_MODEL
2025-07-15 05:17:51 +00:00
from ninja import NinjaAPI, Schema
from openai import OpenAIError
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
2025-07-18 15:41:22 +00:00
from log_features.log_refinement import log_refinement
2025-07-17 15:18:09 +00:00
from optimizer.diff_patches_utils.patches_v2 import apply_patches
2025-07-15 05:17:51 +00:00
from pydantic import ValidationError
if TYPE_CHECKING:
from aiservice.models.aimodels import LLM
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionFunctionMessageParam,
ChatCompletionToolMessageParam,
)
# Folder that holds this module on disk.
current_dir = Path(__file__).parent

# Ninja router serving the refinement endpoint, registered under the "refinement" URL namespace.
refinement_api = NinjaAPI(urls_namespace="refinement")
2025-07-25 20:22:35 +00:00
# System prompt for the refinement model. It explains the refinement goal and the
# <replace_in_file> tool-call format that refinement() parses from the reply.
# It contains no "{" / "}" so it is used verbatim (never passed through str.format).
SYSTEM_PROMPT = """You are an expert software engineer who writes really fast programs and is an expert in optimizing runtime and memory requirements of a program by rewriting it. You have deep expertise in modern programming best practices and clean code principles.
You are part of a code optimization system called Codeflash. Codeflash analyzes user code to optimize it for performance. It does so by first establishing the original code baseline by generating regression tests and discovering existing test cases, which finds the runtime of the original code and the behavior. Then, Codeflash generates multiple candidate optimizations which are applied to the codebase and run to find the new runtime and behavior. Codeflash discards any optimizations that are either slower or have different behavior, in order to find the real optimizations.
Even though the optimizations may be technically correct, we want the quality of the optimizations to be really high, so your task is to refine the code of the optimizations so that they are expert quality.
The goal of refining the quality is to make the optimizations more precise. You want to reduce the number of lines and characters different between the optimizations and the original code to deliver very similar optimizations. These precise optimizations are highly preferred by the user and make it easier to accept the optimizations.
The refinement process should NEVER change the behavior of the optimization and we should try to preserve the main optimizations.
You are provided the following information to succeed in the quality refinement process -
- original_source_code - This is the original code being optimized
- optimized_source_code - This is the previously found optimization source code.
- original_line_profiler_results - The results after running line_profiler on the original_source_code
- optimized_line_profiler_results - The results after running line_profiler on the optimized_source_code
- optimization_speedup_results - The runtime for the original_source_code and optimized_source_code over a series of tests
- optimization_explanation - The explanation generated by codeflash earlier for the optimization in the optimized_source_code
- read_only_dependency_code - The READ ONLY dependencies for the code provided, to help you better understand the code being provided. Do not modify the code here, it is only provided for your reference.
Rules to follow while refining the quality of the optimized code -
- Analyze the original code and the optimized code and look at the line profiler info and the explanation to understand how the optimization works
- Figure out the code difference between the original_source_code and the optimized_source_code to see what part of the optimized_source_code is not contributing to the optimization. In such a case, we want to revert that part of the optimized_source_code to the original_source_code. It is okay to revert parts of the changes that aren't faster by at least 1%.
- If there are any changes in the optimized code that make that code section slower than the original then we want to revert such a change to the original.
- Revert the new comments in the optimized_source_code that are different from the original_source_code unless the new code is complex and requires additional context.
- If the variable names for the same logical variable in the optimized_source_code are different from the original_source_code, then prefer the variable names in the original_source_code.

TOOL USE
You have access to a set of tools that don't need any approval to run and that you are required to use to succeed with the task.
Tool use is formatted using XML-style tags. The tool name is enclosed in opening and closing tags, and each parameter is similarly enclosed within its own set of tags. Here's the structure:
<tool_name>
<parameter1_name>value1</parameter1_name>
<parameter2_name>value2</parameter2_name>
...
</tool_name>
For example:
<replace_in_file>
<path>src/main.py</path>
<diff>
<<<<<<< SEARCH
a = 2
=======
a = 3
>>>>>>> REPLACE
</diff>
</replace_in_file>
Always adhere to this format for tool use to ensure proper parsing and execution.

# Tools
## replace_in_file
Description: Request to replace sections of content in an existing file using SEARCH/REPLACE blocks that define exact changes to specific parts of the file. This tool should be used when you need to make targeted changes to specific parts of a file.
Parameters:
- path: (required) The path of the file to modify
- diff: (required) One or more SEARCH/REPLACE blocks following this exact format:
```
<<<<<<< SEARCH
[exact content to find]
=======
[new content to replace with]
>>>>>>> REPLACE
```
Critical rules:
1. SEARCH content must match the associated file section to find EXACTLY:
* Match character-for-character including whitespace, indentation, line endings
* Include all comments, docstrings, etc.
2. SEARCH/REPLACE blocks will ONLY replace the first match occurrence.
* Include multiple unique SEARCH/REPLACE blocks if you need to make multiple changes.
* Include *just* enough lines in each SEARCH section to uniquely match each set of lines that need to change.
* When using multiple SEARCH/REPLACE blocks, list them in the order they appear in the file.
3. Keep SEARCH/REPLACE blocks concise:
* Break large SEARCH/REPLACE blocks into a series of smaller blocks that each change a small portion of the file.
* Include just the changing lines, and a few surrounding lines if needed for uniqueness.
* Do not include long runs of unchanging lines in SEARCH/REPLACE blocks.
* Each line must be complete. Never truncate lines mid-way through as this can cause matching failures.
4. Special operations:
* To move code: Use two SEARCH/REPLACE blocks (one to delete from original + one to insert at new location)
* To delete code: Use empty REPLACE section
Usage:
<replace_in_file>
<path>File path here</path>
<diff>
Search and replace blocks here
</diff>
</replace_in_file>
Please edit the optimized_source_code to implement the refinement process to improve the quality of the optimization.
"""
# User prompt template for the refinement model, filled via str.format in refinement().
# Placeholders: original_source_code, original_line_profiler_results, optimized_source_code,
# optimized_explanation, optimized_line_profiler_results, original_code_runtime,
# optimized_code_runtime, speedup, read_only_dependency_code.
USER_PROMPT = """
Please edit the optimized_source_code to implement the refinement process to improve the quality of the optimization given the following information.

Here is the original_source_code
```python
{original_source_code}
```

Here is the line profiler information for the original_source_code

{original_line_profiler_results}

Here is the optimized_source_code
```python
{optimized_source_code}
```

Here is the explanation generated by Codeflash for the optimized_source_code -

{optimized_explanation}

Here is the line profiler information for the optimized_source_code

{optimized_line_profiler_results}

The original_source_code takes {original_code_runtime} to run and the optimized_source_code takes {optimized_code_runtime} to run, making it {speedup} faster.

Here is the read_only_dependency_code

{read_only_dependency_code}
"""
# Matches one <replace_in_file> tool call in the model reply: group(1) is the <path> value and
# group(2) the raw SEARCH/REPLACE diff text. The diff group is non-greedy so that, when a reply
# contains several tool calls, each match ends at its own </diff> instead of swallowing the XML
# of the following calls; use finditer() to visit every call.
replace_in_file_regex = re.compile(
    r"<replace_in_file>\s+<path>(.*?)<\/path>\s+<diff>(.*?)<\/diff>\s+<\/replace_in_file>", re.S
)
2025-07-15 05:17:51 +00:00
async def refinement(
    user_id: str,
    original_source_code: str,
    original_line_profiler_results: str,
    original_code_runtime: str,
    read_only_dependency_code: str,
    optimized_source_code: str,
    original_optimization_id: str,  # this is the optimization id generated for the /optimize not the refinement one
    optimized_line_profiler_results: str,
    optimized_code_runtime: str,
    speedup: str,
    n: int = 1,
    optimize_model: LLM = REFINEMENT_MODEL,
    lsp_mode: bool = False,
    optimized_explanation: str = "",
    multi_file_context: bool = False,  # TODO: support sending multiple files for refinement
) -> RefinementItem:
    """Ask the refinement model to bring an existing optimization closer to the original code.

    The model receives SYSTEM_PROMPT plus USER_PROMPT filled with the inputs below and is
    expected to answer with ``<replace_in_file>`` tool calls. The SEARCH/REPLACE diffs found in
    the reply are applied to ``optimized_source_code`` with ``apply_patches``.

    Parameters
    ----------
    user_id : str
        Identifier reported with the ``refinement-usage`` analytics event.
    original_source_code, original_line_profiler_results, original_code_runtime : str
        The baseline code and its measurements (empty profiler results are replaced by a
        placeholder in the prompt).
    read_only_dependency_code : str
        Context code shown to the model; empty values are replaced by a placeholder.
    optimized_source_code, optimized_line_profiler_results, optimized_code_runtime : str
        The candidate optimization to refine and its measurements.
    original_optimization_id : str
        Id of the candidate produced by /optimize (not a refinement id).
    speedup : str
        Speedup text inserted into the prompt ("making it {speedup} faster").
    n : int
        Unused; exactly one completion is requested.
    optimize_model : LLM
        Model whose ``name`` is sent to the chat-completions API.
    lsp_mode, multi_file_context : bool
        Unused.
    optimized_explanation : str
        Explanation previously generated for the candidate.

    Returns
    -------
    RefinementItem
        ``optimization_id`` is ``original_optimization_id`` with its last four characters
        replaced by ``"refi"``. ``refined_optimization`` holds the patched code, or ``""`` when
        the model call failed, the reply had no text or no diff, the diff could not be applied,
        or applying it left the code unchanged (ignoring surrounding whitespace).
    """
    # TODO: Experiment with iterative approaches to refinement. Take the learnings from the
    # testing phase into the next iteration.
    print("/refinement")
    user_prompt = USER_PROMPT.format(
        original_source_code=original_source_code,
        original_line_profiler_results=original_line_profiler_results or "[No profiler results available]",
        optimized_source_code=optimized_source_code,
        optimized_line_profiler_results=optimized_line_profiler_results or "[No profiler results available]",
        optimized_explanation=optimized_explanation,
        original_code_runtime=original_code_runtime,
        optimized_code_runtime=optimized_code_runtime,
        speedup=speedup,
        read_only_dependency_code=read_only_dependency_code or "[No read only code present]",
    )
    messages: list[
        ChatCompletionSystemMessageParam
        | ChatCompletionUserMessageParam
        | ChatCompletionAssistantMessageParam
        | ChatCompletionToolMessageParam
        | ChatCompletionFunctionMessageParam
    ] = [
        ChatCompletionSystemMessageParam(role="system", content=SYSTEM_PROMPT),
        ChatCompletionUserMessageParam(role="user", content=user_prompt),
    ]
    debug_log_sensitive_data(f"This was the user prompt\n {user_prompt}\n")
    # The "refi" suffix identifies the candidate as a refined one.
    new_optimization_id = original_optimization_id[:-4] + "refi"

    async with create_claude_client() as llm_client:
        try:
            output = await llm_client.with_options(max_retries=2).chat.completions.create(
                model=optimize_model.name, messages=messages, n=1
            )
        except OpenAIError as e:
            print("refinement api: OpenAI Code Generation error ...")
            print(e)
            debug_log_sensitive_data(f"Failed to generate code for source:\n{original_source_code}")
            return RefinementItem(
                optimization_id=new_optimization_id,
                original_optimization_id=original_optimization_id,
                raw_refinement_response=f"[Error] OpenAI code generation error.\n: {e}",
                optimized_source_code=optimized_source_code,
            )

    debug_log_sensitive_data(f"OpenAIClient optimization response:\n{output.model_dump_json(indent=2)}")
    if output.usage is not None:
        ph(
            user_id,
            "refinement-usage",
            properties={"model": optimize_model.name, "usage": output.usage.model_dump_json()},
        )

    # Only one completion is requested; use the first choice that actually carries text.
    raw_response = next((choice.message.content for choice in output.choices if choice.message.content), None)
    if raw_response is None:
        # Nothing to apply: report an empty refinement instead of raising.
        debug_log_sensitive_data(f"Empty refinement response for source:\n{optimized_source_code}")
        return RefinementItem(
            optimization_id=new_optimization_id,
            original_optimization_id=original_optimization_id,
            optimized_source_code=optimized_source_code,
        )

    # The reply may contain several <replace_in_file> tool calls; collect all of their diffs.
    search_replace_block = "\n".join(match.group(2) for match in replace_in_file_regex.finditer(raw_response))

    refined_optimization = optimized_source_code
    if search_replace_block.strip():
        try:
            refined_optimization = apply_patches(search_replace_block, optimized_source_code)
        except (ValueError, ValidationError) as exc:
            sentry_sdk.capture_exception(exc)
            debug_log_sensitive_data(f"{type(exc).__name__} for source:\n{optimized_source_code}")
            debug_log_sensitive_data(f"Traceback: {exc}")

    if not search_replace_block.strip() or refined_optimization.strip() == optimized_source_code.strip():
        # Empty, failed, or no-op diffs are not refinements; signal that with an empty string.
        refined_optimization = ""
    return RefinementItem(
        optimization_id=new_optimization_id,
        original_optimization_id=original_optimization_id,
        raw_refinement_response=raw_response,
        diff_patches=search_replace_block,
        refined_optimization=refined_optimization,
        optimized_source_code=optimized_source_code,
    )
2025-07-15 05:17:51 +00:00
class RefinementRequestSchema(Schema):
    # One candidate to refine; the /refinement endpoint (refine) accepts a list of these.
    # refine() reads trace_id, python_version, function_to_optimize, repo_owner and repo_name
    # from the first item of the list only.
    trace_id: str  # checked with validate_trace_id
    optimization_id: str  # id of the /optimize candidate; refinement() derives the "...refi" id from it
    function_to_optimize: str  # forwarded to log_refinement
    python_version: str  # validated with parse_python_version
    experiment_metadata: dict[str, str] | None = None  # not read by this module
    codeflash_version: str | None = None  # not read by this module
    original_source_code: str
    original_line_profiler_results: str = ""  # empty -> placeholder text in the prompt
    read_only_dependency_code: str  # empty -> placeholder text in the prompt
    optimized_source_code: str
    optimized_line_profiler_results: str = ""  # empty -> placeholder text in the prompt
    optimized_explanation: str  # also echoed back as the explanation of the refined candidate
    original_code_runtime: str = ""
    optimized_code_runtime: str = ""
    speedup: str = ""
    repo_owner: str | None = None  # forwarded to log_refinement
    repo_name: str | None = None  # forwarded to log_refinement
2025-07-15 05:17:51 +00:00
2025-07-22 20:29:00 +00:00
class RefinementItem(Schema):
    # Result of refinement() for one candidate; refine() also serializes it with .dict()
    # for log_refinement.
    optimization_id: str = ""  # ends with refi (original id with its last 4 chars replaced)
    original_optimization_id: str = ""
    raw_refinement_response: str = ""  # raw model text, or an "[Error] ..." message if the call failed
    diff_patches: str = ""  # SEARCH/REPLACE text extracted from the <replace_in_file> tool call(s)
    refined_optimization: str = ""  # patched code; "" when no usable refinement was produced
    optimized_source_code: str = ""  # the unrefined candidate the patches were applied to
2025-07-15 05:17:51 +00:00
class OptimizeErrorResponseSchema(Schema):
    # Error payload returned by refine() for invalid requests (bad Python version or trace id).
    error: str  # human-readable description of what was wrong with the request
class RefinementResponseItemschema(Schema):
    # One refined candidate returned to the client.
    explanation: str  # the explanation sent in with the original (unrefined) candidate
    optimization_id: str  # the /optimize id of the candidate that was refined
    source_code: str  # the refined code
2025-07-15 05:17:51 +00:00
class Refinementschema(Schema):
    # Success body of the /refinement endpoint: a list of refined candidates from which
    # refine() has removed empty refinements and duplicates.
    refinements: list[RefinementResponseItemschema]
2025-07-15 05:17:51 +00:00
# Strong references to fire-and-forget background tasks (feature logging). The event loop only
# keeps weak references to tasks, so an unreferenced task may be garbage-collected before it
# finishes; each task removes itself from this set when done.
_background_tasks: set[asyncio.Task] = set()


@refinement_api.post(
    "/",
    response={200: Refinementschema, 400: OptimizeErrorResponseSchema, 500: OptimizeErrorResponseSchema},
)
async def refine(request, data: list[RefinementRequestSchema]) -> tuple[int, Refinementschema | OptimizeErrorResponseSchema]:
    """Refine a batch of previously found optimizations.

    Request-level fields (``trace_id``, ``python_version``, ``function_to_optimize``,
    ``repo_owner``, ``repo_name``) are taken from the first item only. Every item is refined
    concurrently with :func:`refinement`. Refinements that are empty, or that duplicate an
    earlier one after stripping surrounding whitespace, are dropped from the response.

    Returns
    -------
    tuple[int, Refinementschema | OptimizeErrorResponseSchema]
        ``(200, Refinementschema)`` on success; ``(400, OptimizeErrorResponseSchema)`` when the
        batch is empty or the first item's Python version or trace id is invalid.
    """
    ph(request.user, "aiservice-refinement-called")
    if not data:
        return 400, OptimizeErrorResponseSchema(error="No refinement requests were provided.")
    first = data[0]
    try:
        # Parsed only to validate the format; the parsed value is not needed here.
        parse_python_version(first.python_version)
    except Exception:
        return 400, OptimizeErrorResponseSchema(
            error="Invalid Python version, it should look like 3.x.x. We only support Python 3.9 and above."
        )
    trace_id = first.trace_id
    repo_owner = first.repo_owner
    repo_name = first.repo_name
    function_to_optimize = first.function_to_optimize
    if not validate_trace_id(trace_id):
        return 400, OptimizeErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")

    refinement_tasks = []
    for item in data:
        refinement_tasks.append(
            asyncio.create_task(
                refinement(
                    user_id=request.user,
                    original_optimization_id=item.optimization_id,
                    original_source_code=item.original_source_code,
                    original_line_profiler_results=item.original_line_profiler_results,
                    original_code_runtime=item.original_code_runtime,
                    optimized_source_code=item.optimized_source_code,
                    optimized_line_profiler_results=item.optimized_line_profiler_results,
                    optimized_explanation=item.optimized_explanation,
                    optimized_code_runtime=item.optimized_code_runtime,
                    speedup=item.speedup,
                    read_only_dependency_code=item.read_only_dependency_code,
                )
            )
        )
        # Yield to the event loop so the task just created can start its model request before
        # the next one is scheduled. This short stagger is the rate-limit workaround; unlike
        # time.sleep it does not block the event loop. TODO remove later
        await asyncio.sleep(0.01)
    refinement_data: list[RefinementItem] = await asyncio.gather(*refinement_tasks)

    if getattr(request, "should_log_features", False):
        log_task = asyncio.create_task(
            log_refinement(
                trace_id=trace_id,
                user_id=request.user,
                repo_owner=repo_owner,
                repo_name=repo_name,
                function_to_optimize=function_to_optimize,
                refinement_responses=[r.dict() for r in refinement_data],  # make it serializable
            )
        )
        _background_tasks.add(log_task)
        log_task.add_done_callback(_background_tasks.discard)

    # TODO: introduce the filtering/post-processing pipeline used in optimizer.py; need to think
    # carefully about which of its steps apply to refinements.
    refined_optimizations = [
        RefinementResponseItemschema(
            explanation=item.optimized_explanation,
            optimization_id=result.original_optimization_id,
            source_code=result.refined_optimization,
        )
        for item, result in zip(data, refinement_data)
    ]

    # Drop empty (or whitespace-only) refinements and duplicates, comparing stripped code.
    filtered_refined_optimizations = []
    seen_source_codes = set()
    for elem in refined_optimizations:
        key = elem.source_code.strip()
        if key and key not in seen_source_codes:
            seen_source_codes.add(key)
            filtered_refined_optimizations.append(elem)
    debug_log_sensitive_data(f"Refined optimizations returned:\n{filtered_refined_optimizations}")
    return 200, Refinementschema(refinements=filtered_refined_optimizations)