codeflash-internal/django/aiservice/optimizer/optimizer.py
mohammed 81bc6a0bb5
refactoring
Signed-off-by: mohammed <mohammed18200118@gmail.com>
2025-08-03 12:46:25 +03:00

248 lines
10 KiB
Python

from __future__ import annotations

import ast
import uuid
from pathlib import Path
from typing import TYPE_CHECKING

import sentry_sdk  # was referenced in optimize_python_code's error handler but never imported (NameError)
from ninja import NinjaAPI, Schema
from openai import OpenAIError
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
from pydantic import ValidationError

from aiservice.analytics.posthog import ph
from aiservice.common_utils import parse_python_version, validate_trace_id
from aiservice.env_specific import (
    create_openai_client,
    debug_log_sensitive_data_from_callable,
)
from aiservice.models.aimodels import OPTIMIZE_MODEL
from authapp.user import get_user_by_id
from log_features.log_event import get_repository, log_optimization_event
from log_features.log_features import log_features
from optimizer.context_utils.constants import MULTI_CONTEXT_SPLITTER_PREFIX
from optimizer.context_utils.context import *  # noqa: F403 -- provides BaseOptimizerContext and, presumably, debug_log_sensitive_data
from optimizer.context_utils.multi_context import MultiOptimizerContext
from optimizer.context_utils.single_context import SingleOptimizerContext
from optimizer.models import CodeAndExplanation, CodeExplanationAndID
from optimizer.postprocess import optimizations_postprocessing_pipeline

if TYPE_CHECKING:
    from aiservice.models.aimodels import LLM
    from openai.types.chat import (
        ChatCompletionAssistantMessageParam,
        ChatCompletionFunctionMessageParam,
        ChatCompletionToolMessageParam,
    )
# Namespaced Ninja API instance for the /optimize endpoints.
optimize_api = NinjaAPI(urls_namespace="optimize")
# Get the directory of the current file
current_dir = Path(__file__).parent
# Prompt templates live next to this module and are loaded once at import time;
# a missing file fails fast here rather than on the first request.
SYSTEM_PROMPT = (current_dir / "system_prompt.md").read_text()
USER_PROMPT = (current_dir / "user_prompt.md").read_text()
async def optimize_python_code(
    user_id: str,
    source_code: str,
    dependency_code: str | None = None,
    n: int = 1,
    optimize_model: LLM = OPTIMIZE_MODEL,
    python_version: tuple[int, int, int] = (3, 12, 9),
) -> list[CodeAndExplanation]:
    """Optimize the given python code for performance using an LLM chat model.

    Parameters
    ----------
    user_id : str
        Id of the requesting user; used for analytics and sensitive-data logs.
    source_code : str
        The flattened python code to optimize (uses splitters
        @MULTI_CONTEXT_SPLITTER_PREFIX for multi-context payloads).
    dependency_code : str | None
        Optional dependency code included in the user prompt.
    n : int
        Number of optimization variants to generate. Default is 1.
    optimize_model : LLM
        The chat model to call. Defaults to OPTIMIZE_MODEL.
    python_version : tuple[int, int, int]
        The python version to target. Default is (3, 12, 9).

    Returns
    -------
    list[CodeAndExplanation]
        Successfully parsed optimization candidates; empty on LLM failure.
    """
    print("/optimize: Optimizing python code.")
    debug_log_sensitive_data(f"Optimizing python code for user {user_id}:\n{source_code}")
    # TODO: Experiment with iterative approaches to optimization. Take the learnings from the testing phase into the
    # next optimization iteration
    # TODO: Experiment with iterative chain-of-thought generation. ask what is the
    # function doing and then ask it to describe how to speed it up and then generate optimization
    python_version_str = ".".join(str(x) for x in python_version)
    # Older CLI versions send a single-context payload without splitter markers.
    multi_context = MULTI_CONTEXT_SPLITTER_PREFIX in source_code
    ctx: BaseOptimizerContext = (
        MultiOptimizerContext(SYSTEM_PROMPT, USER_PROMPT, source_code)
        if multi_context
        else SingleOptimizerContext(SYSTEM_PROMPT, USER_PROMPT, source_code)
    )
    system_prompt = ctx.get_system_prompt(python_version_str)
    user_prompt = ctx.get_user_prompt(dependency_code, None)
    system_message = ChatCompletionSystemMessageParam(role="system", content=system_prompt)
    user_message = ChatCompletionUserMessageParam(role="user", content=user_prompt)
    messages: list[
        ChatCompletionSystemMessageParam
        | ChatCompletionUserMessageParam
        | ChatCompletionAssistantMessageParam
        | ChatCompletionToolMessageParam
        | ChatCompletionFunctionMessageParam
    ] = [system_message, user_message]
    async with create_openai_client() as openai_client:
        # TODO: Verify if the context window length is within the model capability
        try:
            output = await openai_client.with_options(max_retries=3).chat.completions.create(
                model=optimize_model.name, messages=messages, n=n
            )
        except OpenAIError as e:
            # Best-effort: an LLM failure yields an empty result, not a 500 from here.
            print("OpenAI Code Generation error ...")
            print(e)
            debug_log_sensitive_data(f"Failed to generate code for source:\n{source_code}")
            return []
    debug_log_sensitive_data(f"OpenAIClient optimization response:\n{output.model_dump_json(indent=2)}")
    if output.usage is not None:
        ph(
            user_id,
            "aiservice-optimize-openai-usage",
            properties={"model": optimize_model.name, "n": n, "usage": output.usage.json()},
        )
    # Keep only choices that actually carry message content.
    results = [content for op in output.choices if (content := op.message.content)]
    optimized_code_and_explanations: list[CodeAndExplanation] = []
    for result in results:
        ctx.extract_code_and_explanation_from_llm_res(result)
        try:
            # BUGFIX: parse_code_and_explanation() is the call that can raise
            # ValueError / pydantic ValidationError; it previously ran OUTSIDE
            # this try, making the except clause unreachable (it only guarded a
            # plain list.append) and letting one bad candidate crash the batch.
            module_and_explanation = ctx.parse_code_and_explanation()
            if ctx.is_valid_code():
                optimized_code_and_explanations.append(module_and_explanation)
        except (ValueError, ValidationError) as exc:
            # Another one bites the Pydantic validation dust
            sentry_sdk.capture_exception(exc)
            debug_log_sensitive_data(f"{type(exc).__name__} for source:\n{ctx.extracted_code_and_expl.code}")
            debug_log_sensitive_data(f"Traceback: {exc}")
            # Reset parse state so the failed candidate doesn't leak into the next iteration.
            ctx.extracted_code_and_expl = None
            ctx.parsed_code_and_explanation = None
    return optimized_code_and_explanations
class OptimizeSchema(Schema):
    """Request payload for POST /optimize."""

    source_code: str  # flattened python source to optimize
    dependency_code: str | None  # optional dependency code passed to the prompt
    trace_id: str  # client-supplied UUIDv4; validated by validate_trace_id
    python_version: str  # e.g. "3.12.9"; validated by parse_python_version
    experiment_metadata: dict[str, str] | None = None
    codeflash_version: str | None = None
    current_username: str | None = None  # resolved from the user record when absent
    # PEP 8 fix: removed the stray space before ':' on the two repo fields.
    repo_owner: str | None = None
    repo_name: str | None = None
class OptimizeResponseItemSchema(Schema):
    """One optimization candidate returned to the client."""

    source_code: str  # optimized source after the post-processing pipeline
    explanation: str  # LLM-produced explanation of the optimization
    optimization_id: str  # server-generated UUID for this candidate
    optimization_event_id: str | None = None  # id of the logged optimization event, if one was recorded
class OptimizeResponseSchema(Schema):
    """Successful (200) response body: all post-processed optimization candidates."""

    optimizations: list[OptimizeResponseItemSchema]
class OptimizeErrorResponseSchema(Schema):
    """Error (400/500) response body carrying a human-readable message."""

    error: str
@optimize_api.post(
    "/", response={200: OptimizeResponseSchema, 400: OptimizeErrorResponseSchema, 500: OptimizeErrorResponseSchema}
)
async def optimize(request, data: OptimizeSchema) -> tuple[int, OptimizeResponseSchema | OptimizeErrorResponseSchema]:
    """Generate optimized variants of the submitted Python source.

    Validates the python version string, the source syntax and the trace id,
    requests 5 candidate optimizations from the LLM, post-processes them, logs
    the optimization event (and, when enabled, feature data), then returns the
    processed candidates.

    Returns a ``(status_code, schema)`` tuple: 200 with optimizations, 400 for
    invalid input, 500 when no optimizations could be generated.
    """
    ph(request.user, "aiservice-optimize-called")
    try:
        python_version: tuple[int, int, int] = parse_python_version(data.python_version)
    except Exception:  # was a bare `except:` -- never swallow SystemExit/KeyboardInterrupt
        return 400, OptimizeErrorResponseSchema(
            error="Invalid Python version, it should look like 3.x.x. We only support Python 3.9 and above."
        )
    try:
        parsed = ast.parse(data.source_code, feature_version=python_version[:2])
        # NOTE(review): compile() presumably double-checks syntax beyond what
        # ast.parse accepts with a feature_version downgrade -- confirm intent.
        compile(data.source_code, "data.source_code", "exec")
        if not parsed.body:
            # An empty module parses fine but has nothing to optimize; reuse the 400 path.
            raise SyntaxError
    except SyntaxError:
        return 400, OptimizeErrorResponseSchema(
            error="Invalid source code. It is not valid Python code. Please check syntax of your code."
        )
    if not validate_trace_id(data.trace_id):
        return 400, OptimizeErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")
    optimized_code_and_explanations = await optimize_python_code(
        request.user, data.source_code, data.dependency_code, n=5, python_version=python_version
    )
    if len(optimized_code_and_explanations) == 0:
        ph(request.user, "aiservice-optimize-no-optimizations-found")
        debug_log_sensitive_data(f"No optimizations found for source:\n{data.source_code}")
        return 500, OptimizeErrorResponseSchema(error="Error generating optimizations. Internal server error.")
    ph(
        request.user,
        "aiservice-optimize-optimizations-found",
        properties={"num_optimizations": len(optimized_code_and_explanations)},
    )
    # Assign a fresh UUID to each candidate so raw and post-processed variants
    # can be correlated in the feature logs below.
    traced_optimizations = [
        CodeExplanationAndID(cst_module=ce.cst_module, explanation=ce.explanation, id=str(uuid.uuid4()))
        for ce in optimized_code_and_explanations
    ]
    processed_optimizations: list[CodeExplanationAndID] = optimizations_postprocessing_pipeline(
        data.source_code, traced_optimizations
    )
    # Best-effort repository lookup: a logging failure must not fail the request.
    try:
        repository = await get_repository(data.repo_owner, data.repo_name)
    except Exception:
        repository = None
    if data.current_username is None:
        user = await get_user_by_id(request.user)
        data.current_username = user.github_username
    event = await log_optimization_event(
        event_type="no-pr",
        user_id=request.user,
        current_username=data.current_username,
        repository_id=repository.id if repository else None,
        trace_id=data.trace_id,
        api_key_id=request.api_key_id,
        metadata={
            "codeflash_version": data.codeflash_version,
            "num_optimizations": len(optimized_code_and_explanations),
            "experiment_metadata": data.experiment_metadata,
        },
    )
    if hasattr(request, "should_log_features") and request.should_log_features:
        await log_features(
            trace_id=data.trace_id,
            user_id=request.user,
            original_code=data.source_code,
            dependency_code=data.dependency_code,
            optimizations_raw={cei.id: cei.cst_module.code for cei in traced_optimizations},
            optimizations_post={cei.id: cei.cst_module.code for cei in processed_optimizations},
            explanations_raw={cei.id: cei.explanation for cei in traced_optimizations},
            explanations_post={cei.id: cei.explanation for cei in processed_optimizations},
            experiment_metadata=data.experiment_metadata if data.experiment_metadata else None,
        )
    response = OptimizeResponseSchema(
        optimizations=[
            OptimizeResponseItemSchema(
                source_code=ce.cst_module.code,
                explanation=ce.explanation,
                optimization_id=ce.id,
                optimization_event_id=str(event.id) if event else None,
            )
            for ce in processed_optimizations
        ]
    )

    def log_response():
        # Wrapped in a callable, presumably so the serialization work only runs
        # when sensitive-data debug logging is actually enabled.
        debug_log_sensitive_data(f"Response:\n{response.json()}")
        for opt in response.optimizations:
            debug_log_sensitive_data(f"Optimized source:\n{opt.source_code}")
            debug_log_sensitive_data(f"Optimization explanation:\n{opt.explanation}")

    debug_log_sensitive_data_from_callable(log_response)
    ph(request.user, "aiservice-optimize-successful")
    return 200, response