codeflash-internal/django/aiservice/optimizer/optimizer.py
2023-12-29 18:37:49 -08:00

86 lines
3.2 KiB
Python

import os
import re
from typing import List, Tuple
from dotenv import load_dotenv
from ninja import NinjaAPI, Schema
from openai import AsyncOpenAI, APIError
from authapp.auth import AuthBearer
# Ninja sub-API for the optimize endpoints; every route requires bearer-token auth.
optimize_api = NinjaAPI(auth=AuthBearer(), urls_namespace="optimize")
# Outside production, load OPENAI_API_KEY (and other secrets) from a local .env file.
if os.environ.get("ENVIRONMENT") != "PRODUCTION":
    load_dotenv()
# Single module-level async client, shared by all requests served by this process.
openai_client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
async def optimize_python_code(source_code: str, n: int = 1) -> List[Tuple[str, str]]:
    """
    Optimize the given python code for performance using OpenAI's GPT-4 model.

    Parameters:
    - source_code (str): The python code to optimize.
    - n (int): Number of optimization variants to generate. Default is 1.

    Returns:
    - List[Tuple[str, str]]: One (optimized_code, explanation) tuple per
      variant. A variant that failed to generate or parse is ("", "").
    """
    print(f"/optimize: Optimizing python code:\n{source_code}")
    # TODO: Experiment with iterative approaches to optimization. Take the learnings from the testing phase into the next optimization iteration
    # TODO: Experiment with iterative chain-of-thought generation. ask what is the function doing and then ask it to describe how to speed it up and then generate optimization
    system_message = {
        "role": "system",
        "content": (
            "You are a computer programmer who writes really fast programs and is an expert in optimizing "
            "runtime and memory requirements of a program by rewriting them. You don't rename function names "
            "or change function signatures. The function return value should be exactly the same as before."
        ),
    }
    user_message = {
        "role": "user",
        "content": f"Rewrite this python program to run faster.\n```python\n{source_code}\n```",
    }
    messages = [system_message, user_message]
    # TODO: Verify if the context window length is within the model capability
    try:
        # with_options(max_retries=3): the SDK itself retries transient failures.
        output = await openai_client.with_options(max_retries=3).chat.completions.create(
            model="gpt-4", messages=messages, n=n
        )
    except APIError as e:
        # SDK-level retries are already exhausted here, so the request is abandoned.
        print("OpenAI Code Generation error, giving up.")
        print(e)
        return [("", "")]
    return_val = []
    for choice in output.choices:
        # Non-greedy leading group so we capture the FIRST fenced python block;
        # a greedy `.*` under re.S would skip ahead and match the LAST fence
        # when the model emits several code blocks. (re.M dropped: the pattern
        # has no ^/$ anchors, so it had no effect.)
        match = re.match(r"(.*?)```python\n(.*?)```(.*)", choice.message.content, re.S)
        if match:
            # The explanation is whatever text surrounds the code fence.
            return_val.append((match.group(2), match.group(1) + match.group(3)))
        else:
            # No parseable code block in this variant; keep list length == n.
            return_val.append(("", ""))
    return return_val
class OptimizeSchema(Schema):
    """Request body for the optimize endpoint."""

    # Raw python source code submitted by the client for optimization.
    source_code: str
class OptimizeResponseSchema(Schema):
    """One optimization variant returned to the client."""

    # The rewritten (optimized) python source code.
    source_code: str
    # The model's surrounding prose explaining the optimization.
    explanation: str
@optimize_api.post("/")
async def optimize(request, data: OptimizeSchema):
    """
    Generate up to 10 optimized variants of the submitted python code.

    Returns a list of OptimizeResponseSchema on success, or a 500 with a
    detail message when no variant could be generated.
    """
    optimizations = await optimize_python_code(data.source_code, n=10)
    # Drop variants that failed to generate or parse (marked as ("", "")).
    # The original check only inspected optimizations[0], which both 500'd
    # when later variants were fine and leaked empty variants to the client.
    valid = [(code, explanation) for code, explanation in optimizations if code]
    if not valid:
        return 500, {"detail": "Internal Server Error"}
    return [
        OptimizeResponseSchema(source_code=code, explanation=explanation)
        for code, explanation in valid
    ]