# codeflash-internal/experiments/optimizer_mistral.py
import os
from dotenv import load_dotenv
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage
from openai import OpenAI
load_dotenv()
def optimize_python_code(source_code: str, n: int = 1) -> tuple[list[str | None], list[str | list[str]]]:
    """Ask GPT-4 and Mistral-Large to rewrite *source_code* to run faster.

    Sends the same system/user prompt pair to both providers and returns
    their raw completion texts without any post-processing.

    Args:
        source_code: The Python program to be optimized, passed verbatim
            into the prompt inside a fenced code block.
        n: Number of completions requested from OpenAI.
            NOTE(review): `n` is only forwarded to the OpenAI call; the
            Mistral request always yields its default number of choices —
            confirm whether that asymmetry is intentional.

    Returns:
        A pair of lists: (OpenAI choice contents, Mistral choice contents).

    Raises:
        openai.OpenAIError / mistralai exceptions on API failure; both
        clients are configured with max_retries=3 before giving up.
    """
    system_message = {
        "role": "system",
        "content": (
            "You are a computer programmer who writes really fast programs and is an expert in optimizing "
            "runtime and memory requirements of a program by rewriting them. You don't rename function names "
            "or change function signatures. The function return value should be exactly the same as before."
        ),
    }
    user_message = {
        "role": "user",
        "content": f"Rewrite this python program to run faster.\n```python\n{source_code}\n```",
    }

    # OpenAI path: plain dict messages are accepted directly.
    openai_messages = [system_message, user_message]
    model_openai = "gpt-4-1106-preview"
    openai_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"), max_retries=3)
    openai_response = openai_client.chat.completions.create(model=model_openai, messages=openai_messages, n=n)

    # Mistral path: the client requires ChatMessage objects, so the same
    # prompt content is wrapped rather than duplicated.
    mistral_messages = [
        ChatMessage(role=system_message["role"], content=system_message["content"]),
        ChatMessage(role=user_message["role"], content=user_message["content"]),
    ]
    model_mistral = "mistral-large-latest"
    mistral_client = MistralClient(api_key=os.environ.get("MISTRAL_API_KEY"), max_retries=3)
    mistral_response = mistral_client.chat(model=model_mistral, messages=mistral_messages)

    return [c.message.content for c in openai_response.choices], [c.message.content for c in mistral_response.choices]