# File-viewer metadata captured with this source: 38 lines, 1.6 KiB, Python.
import os
from dotenv import load_dotenv
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
from openai import OpenAI
# Load variables from a .env file into os.environ so that optimize_python_code
# can read OPENAI_API_KEY and MISTRAL_API_KEY; override=False (the library
# default) leaves variables already present in the environment untouched.
load_dotenv(override=False)


def _build_optimization_messages(source_code: str) -> list[dict[str, str]]:
    """Return the ``[system, user]`` chat messages asking for a faster rewrite of ``source_code``."""
    system_message = {
        "role": "system",
        "content": (
            "You are a computer programmer who writes really fast programs and is an expert in optimizing "
            "runtime and memory requirements of a program by rewriting them. You don't rename function names "
            "or change function signatures. The function return value should be exactly the same as before."
        ),
    }
    user_message = {
        "role": "user",
        # NOTE(review): a "```" inside source_code would terminate this fence early.
        "content": f"Rewrite this python program to run faster.\n```python\n{source_code}\n```",
    }
    return [system_message, user_message]


def _query_openai(messages: list[dict[str, str]], n: int, model: str) -> list[str | None]:
    """Send one OpenAI chat-completion request and return the content of every choice."""
    # Using the client as a context manager closes its HTTP connection pool as soon
    # as the (non-streaming, fully read) response is back, instead of relying on GC.
    with OpenAI(api_key=os.environ.get("OPENAI_API_KEY"), max_retries=3) as client:
        response = client.chat.completions.create(model=model, messages=messages, n=n)
    return [choice.message.content for choice in response.choices]


def _query_mistral(messages: list[dict[str, str]], model: str) -> list[str | list[str]]:
    """Send one Mistral chat request and return the content of every choice."""
    client = MistralClient(api_key=os.environ.get("MISTRAL_API_KEY"), max_retries=3)
    # The Mistral SDK expects ChatMessage objects rather than plain dicts.
    chat_messages = [ChatMessage(role=message["role"], content=message["content"]) for message in messages]
    response = client.chat(model=model, messages=chat_messages)
    return [choice.message.content for choice in response.choices]


def optimize_python_code(
    source_code: str,
    n: int = 1,
    *,
    openai_model: str = "gpt-4-1106-preview",
    mistral_model: str = "mistral-large-latest",
) -> tuple[list[str | None], list[str | list[str]]]:
    """Ask an OpenAI and a Mistral chat model for faster rewrites of ``source_code``.

    Both providers receive the same system + user prompt. The system prompt tells the
    model to keep function names, signatures and return values unchanged.

    Args:
        source_code: Python source to optimize; embedded verbatim in a fenced code block.
        n: Number of completions requested from OpenAI (forwarded as ``n``). It is not
            sent to Mistral, whose response choices are returned as-is.
        openai_model: OpenAI chat model name.
        mistral_model: Mistral chat model name.

    Returns:
        ``(openai_contents, mistral_contents)``: the ``message.content`` of each choice
        in the OpenAI response and in the Mistral response, respectively.

    Raises:
        Whatever the ``openai`` / ``mistralai`` clients raise (auth, network, bad request).
        The OpenAI request is sent first, so if it fails, Mistral is never called.
    """
    messages = _build_optimization_messages(source_code)
    openai_contents = _query_openai(messages, n, openai_model)
    mistral_contents = _query_mistral(messages, mistral_model)
    return openai_contents, mistral_contents