from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import uuid
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING
|
|
|
|
import libcst as cst
|
|
import sentry_sdk
|
|
from ninja import NinjaAPI
|
|
from ninja.errors import HttpError
|
|
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
|
|
from pydantic import ValidationError
|
|
|
|
from aiservice.analytics.posthog import ph
|
|
from aiservice.common_utils import parse_python_version, should_hack_for_demo, validate_trace_id
|
|
from aiservice.env_specific import (
|
|
create_openai_client,
|
|
debug_log_sensitive_data,
|
|
debug_log_sensitive_data_from_callable,
|
|
)
|
|
from aiservice.models.aimodels import OPTIMIZE_MODEL, calculate_llm_cost
|
|
from authapp.user import get_user_by_id
|
|
from log_features.log_event import log_optimization_event
|
|
from log_features.log_features import log_features_optimized
|
|
from optimizer.context_utils.context_helpers import group_code
|
|
from optimizer.context_utils.optimizer_context import (
|
|
BaseOptimizerContext,
|
|
OptimizeErrorResponseSchema,
|
|
OptimizeResponseItemSchema,
|
|
OptimizeResponseSchema,
|
|
)
|
|
from optimizer.models import OptimizeSchema # noqa: TC001
|
|
|
|
if TYPE_CHECKING:
|
|
from django.http import HttpRequest
|
|
from openai.types.chat import (
|
|
ChatCompletionAssistantMessageParam,
|
|
ChatCompletionFunctionMessageParam,
|
|
ChatCompletionToolMessageParam,
|
|
)
|
|
|
|
from aiservice.models.aimodels import LLM
|
|
|
|
|
|
# Canned optimization candidates served by hack_for_demo() in place of a live
# LLM call. Each entry mirrors the shape of a real candidate: the optimized
# source, an LLM-style explanation, and a UUID minted once at import time.
optimizations_json = [
    {
        "source_code": 'from __future__ import annotations\n\n\ndef find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:\n    if not articles:\n        return set()\n\n    common_tags = set(articles[0].get("tags", []))\n    for article in articles[1:]:\n        common_tags.intersection_update(article.get("tags", []))\n    return common_tags\n',
        "explanation": "The original algorithm repeatedly filters the `common_tags` list for every article, which can be slow. We can use Python sets to improve efficiency, especially with large lists.\n\nHere's the optimized version of your function.\n\n\n\n### Explanation of Optimizations.\n1. **Use of Sets**: Convert the initial list of tags to a set, which allows for more efficient intersection operations compared to list comprehensions.\n2. **Intersection Update**: Use the `intersection_update` method on sets which modifies the set in place, making it more memory efficient and faster than creating new lists and converting them to sets repeatedly.\n\nThis optimized version should perform significantly better, especially as the number of articles and tags increases.",
        "optimization_id": str(uuid.uuid4()),
    },
    {
        "source_code": 'from __future__ import annotations\n\n\ndef find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:\n    if not articles:\n        return set()\n\n    common_tags = set(articles[0].get("tags", []))\n    for article in articles[1:]:\n        common_tags.intersection_update(article.get("tags", []))\n        if not common_tags:\n            break\n    return common_tags\n',
        "explanation": "To make the `find_common_tags` function run faster, we can leverage sets, which provide average O(1) time complexity for membership checks and O(n) for intersections. Here\u2019s a refactored version of your program.\n\n\n\nThis version initializes `common_tags` as a set and then iteratively intersects it with the tags of each subsequent article. The `intersection_update` method is used to update `common_tags` in place, which is more efficient. Additionally, it breaks early if `common_tags` becomes empty, which can save unnecessary computation.",
        "optimization_id": str(uuid.uuid4()),
    },
    {
        "source_code": 'def find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:\n    if not articles:\n        return set()\n\n    common_tags = set(articles[0].get("tags", []))\n    for article in articles[1:]:\n        common_tags.intersection_update(article.get("tags", []))\n        if not common_tags:  # Early exit if no common tags left\n            break\n    return common_tags\n',
        "explanation": "To optimize the runtime of this function, we can leverage set operations which are generally faster than list comprehensions for membership checks. By converting the tags to sets initially, the intersection operation becomes more efficient. Here's a faster version.\n\n\n\nChanges made.\n1. Convert the tags list of the first article to a set.\n2. Use `intersection_update` method to update the `common_tags` set with the intersection of the current tags and the next article's tags.\n3. Include an early exit condition to break the loop if no common tags remain, further optimizing runtime.",
        "optimization_id": str(uuid.uuid4()),
    },
    {
        "source_code": 'from __future__ import annotations\n\n\ndef find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:\n    if not articles:\n        return set()\n\n    common_tags = set(articles[0].get("tags", []))\n    for article in articles[1:]:\n        common_tags &= set(article.get("tags", []))\n        if not common_tags:  # Early exit if no common tags.\n            break\n    return common_tags\n',
        "explanation": "To optimize the provided function, we could enhance its efficiency by using set operations which are typically faster for membership checks compared to list comprehensions.\n\nHere\u2019s the optimized version.\n\n\n\nExplanation.\n1. Convert the tags of the first article into a set to take advantage of fast membership checks and intersection operations.\n2. Use the `&=` operation to find the intersection with the tags of each subsequent article.\n3. Introduce an early exit condition: if `common_tags` becomes empty, it's immediately returned since no further intersection can result in common tags.",
        "optimization_id": str(uuid.uuid4()),
    },
]
|
|
|
|
|
|
async def hack_for_demo(ctx: BaseOptimizerContext) -> OptimizeResponseSchema:
    """Return the canned demo optimizations instead of calling the LLM.

    Parameters
    ----------
    ctx : BaseOptimizerContext
        Context providing the file name used to group the canned source code.

    Returns
    -------
    OptimizeResponseSchema
        The pre-baked candidates from ``optimizations_json``.

    """
    # Consistency fix: the rest of this module logs via `logging`, so demo
    # traffic should appear in the same log stream instead of bare stdout.
    logging.info("Hacking for demo baby!!")
    response_list: list[OptimizeResponseItemSchema] = [
        OptimizeResponseItemSchema(
            explanation=optimization["explanation"],
            optimization_id=optimization["optimization_id"],
            source_code=group_code({ctx.file_name: optimization["source_code"]}),
        )
        for optimization in optimizations_json
    ]
    # Simulate realistic LLM latency so the demo feels like a live request.
    await asyncio.sleep(5)
    return OptimizeResponseSchema(optimizations=response_list)
|
|
|
|
|
|
# Dedicated Ninja API instance for the /optimize endpoint family.
optimize_api = NinjaAPI(urls_namespace="optimize")

# Get the directory of the current file
current_dir = Path(__file__).parent
# Prompt templates are read once at import time, so a missing template file
# fails fast at startup rather than on the first request.
SYSTEM_PROMPT = (current_dir / "system_prompt.md").read_text()
USER_PROMPT = (current_dir / "user_prompt.md").read_text()
ASYNC_SYSTEM_PROMPT = (current_dir / "async_system_prompt.md").read_text()
ASYNC_USER_PROMPT = (current_dir / "async_user_prompt.md").read_text()
|
|
|
|
|
|
async def optimize_python_code(
    user_id: str,
    ctx: BaseOptimizerContext,
    dependency_code: str | None = None,
    n: int = 1,
    optimize_model: LLM = OPTIMIZE_MODEL,
    python_version: tuple[int, int, int] = (3, 12, 9),
) -> tuple[list[OptimizeResponseItemSchema], float | None]:
    """Optimize the given python code for performance using LLMs.

    Parameters
    ----------
    user_id : str
        The ID of the user requesting the optimization.
    ctx : BaseOptimizerContext
        The optimizer context containing source code and configuration.
    dependency_code : str | None, optional
        Additional dependency code for context. Default is None.
    n : int, optional
        Number of optimization variants to generate. Default is 1.
    optimize_model : LLM, optional
        The LLM model to use for optimization. Default is OPTIMIZE_MODEL.
    python_version : tuple[int, int, int], optional
        The python version to use. Default is (3, 12, 9).

    Returns
    -------
    tuple[list[OptimizeResponseItemSchema], float | None]
        A tuple containing a list of optimization response items and the LLM cost.
        On LLM failure the list is empty and the cost is None.

    """
    logging.info("/optimize: Optimizing python code.")
    debug_log_sensitive_data(f"Optimizing python code for user {user_id}:\n{ctx.source_code}")
    # TODO: Experiment with iterative approaches to optimization. Take the learnings from the testing phase into the
    # next optimization iteration
    # TODO: Experiment with iterative chain-of-thought generation. ask what is the
    # function doing and then ask it to describe how to speed it up and then generate optimization
    python_version_str = ".".join(str(x) for x in python_version)

    system_prompt = ctx.get_system_prompt(python_version_str)
    user_prompt = ctx.get_user_prompt(dependency_code, None)

    system_message = ChatCompletionSystemMessageParam(role="system", content=system_prompt)
    user_message = ChatCompletionUserMessageParam(role="user", content=user_prompt)
    messages: list[
        ChatCompletionSystemMessageParam
        | ChatCompletionUserMessageParam
        | ChatCompletionAssistantMessageParam
        | ChatCompletionToolMessageParam
        | ChatCompletionFunctionMessageParam
    ] = [system_message, user_message]
    async with create_openai_client() as openai_client:
        try:
            output = await openai_client.with_options(max_retries=3).chat.completions.create(
                model=optimize_model.name, messages=messages, n=n
            )
        except Exception as e:
            logging.exception("OpenAI Code Generation error in optimizer")
            sentry_sdk.capture_exception(e)
            debug_log_sensitive_data(f"Failed to generate code for source:\n{ctx.source_code}")
            # Bug fix: this previously returned a bare `[]`, which broke callers
            # that unpack the declared two-tuple (`items, llm_cost = ...`) with
            # a ValueError on the error path. Return the full tuple instead.
            return [], None
    llm_cost = calculate_llm_cost(output, optimize_model)
    debug_log_sensitive_data(f"OpenAIClient optimization response:\n{output.model_dump_json(indent=2)}")

    if output.usage is not None:
        ph(
            user_id,
            "aiservice-optimize-openai-usage",
            properties={"model": optimize_model.name, "n": n, "usage": output.usage.json()},
        )
    # Keep only choices that actually carry message content.
    results = [content for op in output.choices if (content := op.message.content)]
    optimization_response_items: list[OptimizeResponseItemSchema] = []
    for result in results:
        ctx.extract_code_and_explanation_from_llm_res(result)
        try:
            res = ctx.parse_and_generate_candidate_schema()
            if res is not None and ctx.is_valid_code():
                optimization_response_items.append(res)

            # Reset per-candidate scratch state so one candidate's parse
            # artifacts never leak into the next iteration.
            ctx.extracted_code_and_expl = None
            ctx.parsed_code_and_explanation = None
        except (ValueError, ValidationError, cst.ParserSyntaxError) as e:
            # A malformed candidate is reported but must not abort the batch.
            sentry_sdk.capture_message(f"Error parsing optimization result: {e}")
            debug_log_sensitive_data(f"error for source:\n{ctx.source_code}")
            debug_log_sensitive_data(f"Traceback: {e}")
            continue
    return optimization_response_items, llm_cost
|
|
|
|
|
|
def validate_request_data(data: OptimizeSchema, ctx: BaseOptimizerContext) -> tuple[int, int, int]:
    """Validate an /optimize request payload.

    Checks the source code, trace ID, and python version string, then parses
    the submitted source into ``ctx``. Raises ``HttpError(400)`` with a
    user-facing message on the first failed check; otherwise returns the
    requested python version as a ``(major, minor, micro)`` tuple.
    """
    # Cheap structural checks first: reject empty payloads and malformed
    # trace IDs before doing any parsing work.
    if not data.source_code:
        raise HttpError(400, "Source code cannot be empty.")
    if not validate_trace_id(data.trace_id):
        raise HttpError(400, "Invalid trace ID. Please provide a valid UUIDv4.")

    try:
        version_tuple = parse_python_version(data.python_version)
    except ValueError as err:
        raise HttpError(
            400, "Invalid Python version, it should look like 3.x.x. We only support Python 3.9 and above."
        ) from err

    # Parse the user's code with the requested (major, minor) feature version;
    # a SyntaxError here means the submission is not valid Python.
    try:
        ctx.validate_and_parse_source_code(data.source_code, feature_version=version_tuple[:2])
    except SyntaxError as err:
        raise HttpError(
            400, "Invalid source code. It is not valid Python code. Please check syntax of your code."
        ) from err

    return version_tuple
|
|
|
|
|
|
@optimize_api.post(
    "/", response={200: OptimizeResponseSchema, 400: OptimizeErrorResponseSchema, 500: OptimizeErrorResponseSchema}
)
async def optimize(
    request: HttpRequest, data: OptimizeSchema
) -> tuple[int, OptimizeResponseSchema | OptimizeErrorResponseSchema]:
    """Generate optimization candidates for the submitted source code.

    Validates the request, runs the LLM optimizer (concurrently with a user
    lookup when needed), logs the outcome, and returns a (status, schema)
    tuple that Ninja serializes via the declared response map.
    """
    # Select the prompt pair based on whether the target function is async.
    system_prompt = ASYNC_SYSTEM_PROMPT if data.is_async else SYSTEM_PROMPT
    user_prompt = ASYNC_USER_PROMPT if data.is_async else USER_PROMPT

    ctx: BaseOptimizerContext = BaseOptimizerContext.get_dynamic_context(system_prompt, user_prompt, data.source_code)
    ph(request.user, "aiservice-optimize-called")

    # Validation errors are reported to Sentry but surfaced to the client as
    # a 400 with the original message.
    try:
        python_version = validate_request_data(data, ctx)
    except HttpError as e:
        e.add_note(f"Optimizer request validation error: {e.status_code} {e.message}")
        sentry_sdk.capture_exception(e)
        return e.status_code, OptimizeErrorResponseSchema(error=e.message)

    # Short-circuit with canned results for recognized demo input.
    if should_hack_for_demo(ctx.source_code):
        return 200, await hack_for_demo(ctx)

    # Run the LLM optimization and (when the caller did not supply a
    # username) the user lookup concurrently; TaskGroup cancels both and
    # re-raises if either fails.
    try:
        async with asyncio.TaskGroup() as tg:
            optimize_task = tg.create_task(
                optimize_python_code(
                    request.user,
                    ctx,
                    data.dependency_code,
                    # Cap candidate count at 5 regardless of what was requested.
                    n=min(data.n_candidates or 5, 5),
                    python_version=python_version,
                )
            )
            user_task = None
            if data.current_username is None:
                user_task = tg.create_task(get_user_by_id(request.user))
    except Exception as e:  # noqa: BLE001
        e.add_note("Error during optimization task or user retrieval.")
        sentry_sdk.capture_exception(e)
        return 500, OptimizeErrorResponseSchema(error="Error generating optimizations. Internal server error.")

    # Both tasks are done here (TaskGroup exit awaited them).
    optimization_response_items, llm_cost = optimize_task.result()
    if user_task:
        user = await user_task
        if user:
            # Backfill the username from the looked-up user record.
            data.current_username = user.github_username

    if len(optimization_response_items) == 0:
        ph(request.user, "aiservice-optimize-no-optimizations-found")
        debug_log_sensitive_data(f"No optimizations found for source:\n{data.source_code}")
        return 500, OptimizeErrorResponseSchema(error="Could not generate any optimizations. Please try again.")
    ph(
        request.user,
        "aiservice-optimize-optimizations-found",
        properties={"num_optimizations": len(optimization_response_items)},
    )

    # Persist the optimization event and feature log concurrently; only the
    # event's result is needed afterwards (for its id).
    async with asyncio.TaskGroup() as tg:
        event_task = tg.create_task(
            log_optimization_event(
                event_type="no-pr",
                user_id=request.user,
                current_username=data.current_username,
                repo_owner=data.repo_owner,
                repo_name=data.repo_name,
                trace_id=data.trace_id,
                api_key_id=request.api_key_id,
                metadata={
                    "codeflash_version": data.codeflash_version,
                    "num_optimizations": len(optimization_response_items),
                    "experiment_metadata": data.experiment_metadata,
                },
                llm_cost=llm_cost,
            )
        )

        tg.create_task(
            log_features_optimized(
                trace_id=data.trace_id,
                user_id=request.user,
                original_code=data.source_code,
                dependency_code=data.dependency_code,
                optimizations_raw={
                    op_id: cei.code for op_id, cei in ctx.code_and_explanation_before_post_processing.items()
                },
                optimizations_post={cei.optimization_id: cei.source_code for cei in optimization_response_items},
                explanations_raw={
                    op_id: cei.explanation for op_id, cei in ctx.code_and_explanation_before_post_processing.items()
                },
                explanations_post={cei.optimization_id: cei.explanation for cei in optimization_response_items},
                experiment_metadata=data.experiment_metadata if data.experiment_metadata else None,
                request=request,
            )
        )

    event = event_task.result()

    # Tag every candidate with the stored event id so the client can report
    # back which optimization it used.
    for item in optimization_response_items:
        item.optimization_event_id = str(event.id) if event else None

    response = OptimizeResponseSchema(optimizations=optimization_response_items)

    def log_response() -> None:
        # Deferred so the (potentially expensive) serialization only happens
        # when sensitive-data debug logging is actually enabled.
        debug_log_sensitive_data(f"Response:\n{response.json()}")
        for opt in response.optimizations:
            debug_log_sensitive_data(f"Optimized source:\n{opt.source_code}")
            debug_log_sensitive_data(f"Optimization explanation:\n{opt.explanation}")

    debug_log_sensitive_data_from_callable(log_response)
    ph(request.user, "aiservice-optimize-successful")
    return 200, response
|