86 lines
3.2 KiB
Python
86 lines
3.2 KiB
Python
import os
|
|
import re
|
|
from typing import List, Tuple
|
|
|
|
from dotenv import load_dotenv
|
|
from ninja import NinjaAPI, Schema
|
|
from openai import AsyncOpenAI, APIError
|
|
|
|
from authapp.auth import AuthBearer
|
|
|
|
optimize_api = NinjaAPI(
    auth=AuthBearer(),  # default authentication for routes registered on this API
    urls_namespace="optimize",
)

# Outside production, pull variables (such as OPENAI_API_KEY below) from a local .env
# file; in production the process environment is used exactly as provided.
if os.getenv("ENVIRONMENT") != "PRODUCTION":
    load_dotenv()

openai_client = AsyncOpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
)
|
|
|
|
|
|
# Splits a model reply around a ```python fenced block. The leading group is greedy,
# so when a reply holds several such blocks the last one is selected. (re.M has no
# effect here because the pattern contains no ^ or $ anchors.)
_PYTHON_FENCE_RE = re.compile(r"(.*)```python\n(.*?)```(.*)", re.S | re.M)


def _split_code_and_explanation(reply) -> Tuple[str, str]:
    """
    Split a single model reply into a ``(code, explanation)`` pair.

    Parameters:
    - reply (str or None): The message content of one completion choice.

    Returns:
    - Tuple[str, str]: The code inside the ```python fenced block and the reply text
      before and after that block joined together, or ``("", "")`` when the reply is
      empty, ``None``, or contains no such block.
    """
    # The OpenAI SDK types message.content as optional; treat a missing reply like a
    # reply without a code block instead of letting the regex raise TypeError.
    if not reply:
        return "", ""
    match = _PYTHON_FENCE_RE.match(reply)
    if match is None:
        return "", ""
    return match.group(2), match.group(1) + match.group(3)


async def optimize_python_code(source_code: str, n: int = 1) -> List[Tuple[str, str]]:
    """
    Optimize the given python code for performance using OpenAI's GPT-4 model.

    Parameters:
    - source_code (str): The python code to optimize.
    - n (int): Number of optimization variants to request. Default is 1.

    Returns:
    - List[Tuple[str, str]]: One ``(optimized_code, explanation)`` tuple per completion
      choice returned by the API. A choice whose reply has no ```python fenced block
      (or no text at all) yields ``("", "")``. If the API call raises ``APIError``, a
      single ``("", "")`` tuple is returned.
    """
    print(f"/optimize: Optimizing python code:\n{source_code}")
    # TODO: Experiment with iterative approaches to optimization. Take the learnings from the testing phase into the next optimization iteration
    # TODO: Experiment with iterative chain-of-thought generation. ask what is the function doing and then ask it to describe how to speed it up and then generate optimization
    system_message = {
        "role": "system",
        "content": (
            "You are a computer programmer who writes really fast programs and is an expert in optimizing "
            "runtime and memory requirements of a program by rewriting them. You don't rename function names "
            "or change function signatures. The function return value should be exactly the same as before."
        ),
    }
    user_message = {
        "role": "user",
        "content": f"Rewrite this python program to run faster.\n```python\n{source_code}\n```",
    }
    messages = [system_message, user_message]

    # TODO: Verify if the context window length is within the model capability
    try:
        output = await openai_client.with_options(max_retries=3).chat.completions.create(
            model="gpt-4", messages=messages, n=n
        )
    except APIError as e:
        # Retries are delegated to the client via with_options(max_retries=3); once an
        # APIError reaches this point this function gives up rather than retrying.
        print("OpenAI Code Generation error, returning empty result")
        print(e)
        return [("", "")]

    return [_split_code_and_explanation(choice.message.content) for choice in output.choices]
|
|
|
|
|
|
class OptimizeSchema(Schema):
    # Request body accepted by the `optimize` endpoint.
    # The Python source code to be optimized.
    source_code: str
|
|
|
|
|
|
class OptimizeResponseSchema(Schema):
    # One optimization variant in the list returned by the `optimize` endpoint.
    # Code extracted from the model reply's ```python fenced block.
    source_code: str

    # The remaining text of the model reply (before and after the code block).
    explanation: str
|
|
|
|
|
|
@optimize_api.post(
    "/",
    # A (status, body) tuple may only use status codes declared here; without the
    # 500 entry django-ninja raises ConfigError instead of sending the error body.
    response={200: List[OptimizeResponseSchema], 500: dict},
)
async def optimize(request, data: OptimizeSchema):
    """Return optimized variants of ``data.source_code``.

    Requests 10 variants from ``optimize_python_code`` and returns every variant
    from which code could be extracted. Responds with HTTP 500 and
    ``{"detail": "Internal Server Error"}`` when no variant produced code.
    """
    optimizations = await optimize_python_code(data.source_code, n=10)
    # Failed variants come back as ("", ""); keep only those that produced code so
    # one bad choice neither hides good ones nor leaks empty entries to the client.
    usable = [(code, explanation) for code, explanation in optimizations if code]
    if not usable:
        return 500, {"detail": "Internal Server Error"}
    return [
        OptimizeResponseSchema(source_code=code, explanation=explanation)
        for code, explanation in usable
    ]
|