710 lines
29 KiB
Python
710 lines
29 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import re
|
|
import uuid
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING
|
|
|
|
import sentry_sdk
|
|
from ninja import NinjaAPI
|
|
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
|
|
|
|
from aiservice.analytics.posthog import ph
|
|
from aiservice.common_utils import parse_python_version, validate_trace_id
|
|
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable
|
|
from aiservice.llm import OPTIMIZE_MODEL, calculate_llm_cost, call_llm
|
|
from aiservice.validators.javascript_validator import validate_javascript_syntax, validate_typescript_syntax
|
|
from log_features.log_event import update_optimization_cost
|
|
from log_features.log_features import log_features
|
|
from optimizer.config import MAX_OPTIMIZER_LP_CALLS, get_model_distribution
|
|
from optimizer.context_utils.context_helpers import (
|
|
group_code,
|
|
is_multi_context_js,
|
|
is_multi_context_ts,
|
|
split_markdown_code,
|
|
)
|
|
from optimizer.context_utils.optimizer_context import (
|
|
BaseOptimizerContext,
|
|
OptimizeErrorResponseSchema,
|
|
OptimizeResponseItemSchema,
|
|
OptimizeResponseSchema,
|
|
)
|
|
from optimizer.diff_patches_utils.diff import DiffMethod
|
|
from optimizer.models import OptimizedCandidateSource, OptimizeSchemaLP
|
|
|
|
if TYPE_CHECKING:
|
|
from openai.types.chat import ChatCompletionMessageParam
|
|
|
|
from aiservice.llm import LLM
|
|
|
|
|
|
# Router for the line-profiler-guided optimization endpoint (own URL namespace).
optimize_line_profiler_api = NinjaAPI(urls_namespace="optimize-line-profiler")


# Get the directory of the current file
current_dir = Path(__file__).parent
# Python prompt templates, loaded once at import time from files next to this module.
SYSTEM_PROMPT = (current_dir / "system_prompt.md").read_text()
USER_PROMPT = (current_dir / "user_prompt.md").read_text()
# Extra system-prompt instructions appended when the request flags numerical (JIT-friendly) code.
JIT_INSTRUCTIONS = (current_dir / "jit_instructions.md").read_text()

# JavaScript/TypeScript prompts
JS_SYSTEM_PROMPT = (current_dir / "prompts" / "javascript" / "system_prompt.md").read_text()

# Pattern to extract code blocks from JavaScript LLM response (single file, no file path)
# Captures the fenced body of ```javascript / ```js / ```typescript / ```ts blocks.
JS_CODE_PATTERN = re.compile(r"```(?:javascript|js|typescript|ts)\s*\n(.*?)```", re.MULTILINE | re.DOTALL)

# Pattern to extract code blocks with file paths (multi-file context)
# Matches fences like ```ts:src/foo.ts and captures (file_path, code_body).
JS_CODE_WITH_PATH_PATTERN = re.compile(
    r"```(?:javascript|js|typescript|ts):([^\n]+)\n(.*?)```", re.MULTILINE | re.DOTALL
)

# Line profiler context prompt for JavaScript
JS_LINE_PROF_CONTEXT = """
Here are the results of the line profiling of the JavaScript/TypeScript code you will be optimizing.
The profiling data shows:
- Line numbers with execution counts (hits)
- Time spent on each line (in milliseconds)
- Percentage of total time per line

Use this data to identify performance bottlenecks and focus your optimization on the hottest code paths.

{line_profiler_results}
"""
|
|
|
|
|
|
def extract_js_code_and_explanation(content: str, is_multi_file: bool = False) -> tuple[str | dict[str, str], str]:
    """Pull the optimized code and the prose explanation out of an LLM reply.

    Args:
        content: The raw LLM response content.
        is_multi_file: When True, look for path-tagged fenced blocks
            (```lang:path) before falling back to the single-file format.

    Returns:
        Tuple of (code, explanation): code is a str for a single-file
        response or a dict mapping file path -> code for multi-file. When
        no fenced block is found at all, code is "" and the whole reply is
        returned as the explanation.

    """
    if is_multi_file:
        path_blocks = JS_CODE_WITH_PATH_PATTERN.findall(content)
        if not path_blocks:
            # No path-tagged fences present; reinterpret as single-file output.
            return extract_js_code_and_explanation(content, is_multi_file=False)

        # Everything before the first fence is treated as the explanation.
        fence_start = content.find("```")
        preamble = content[:fence_start].strip() if fence_start > 0 else ""
        files: dict[str, str] = {path.strip(): body.strip() for path, body in path_blocks}
        return files, preamble

    # Single-file extraction: first fenced JS/TS block wins.
    fenced = JS_CODE_PATTERN.search(content)
    if fenced is None:
        return "", content
    return fenced.group(1).strip(), content[: fenced.start()].strip()
|
|
|
|
|
|
def normalize_js_code(code: str) -> str:
    """Canonicalize JS/TS source so cosmetically different code compares equal.

    Strips `//` line comments and `/* ... */` block comments, then collapses
    every run of whitespace into a single space. NOTE(review): comment markers
    inside string literals are stripped too — this is a heuristic for change
    detection, not a parser.
    """
    without_line_comments = re.sub(r"//.*$", "", code, flags=re.MULTILINE)
    without_block_comments = re.sub(r"/\*.*?\*/", "", without_line_comments, flags=re.DOTALL)
    return " ".join(without_block_comments.split())
|
|
|
|
|
|
async def optimize_python_code_line_profiler_single(
    user_id: str,
    trace_id: str,
    line_profiler_results: str,
    ctx: BaseOptimizerContext,
    dependency_code: str | None = None,
    optimize_model: LLM = OPTIMIZE_MODEL,
    python_version: tuple[int, int, int] = (3, 12, 9),
    call_sequence: int | None = None,
    is_numerical_code: bool | None = None,
) -> tuple[OptimizeResponseItemSchema | None, float | None, str]:
    """Optimize the given python code for performance using LLMs.

    Makes a single LLM call guided by the line profiler results and parses
    the response through the optimizer context (which is mutated in place
    with the extracted code/explanation).

    Args:
        user_id: User identifier for analytics and LLM call attribution.
        trace_id: Trace identifier propagated to the LLM call.
        line_profiler_results: Raw profiler output injected into the user prompt.
        ctx: Optimizer context holding source code and prompt templates.
        dependency_code: Optional read-only dependency source for the prompt.
        optimize_model: LLM to call.
        python_version: Target Python version as (major, minor, micro).
        call_sequence: Optional 1-based index of this call within a batch,
            recorded in the observability context.
        is_numerical_code: When truthy, JIT instructions are appended to the
            system prompt.

    Returns:
        (candidate, llm_cost, model_name). candidate is None when the call
        failed or produced invalid code; llm_cost is None only when the LLM
        call itself raised.

    """
    logging.info("/optimize: Optimizing python code line profile.")
    debug_log_sensitive_data(f"Optimizing python code for user {user_id}:\n{ctx.source_code}")

    python_version_str = ".".join(str(x) for x in python_version)

    # TODO: Experiment with iterative approaches to optimization. Take the learnings from the testing phase into the
    # next optimization iteration
    # TODO: Experiment with iterative chain-of-thought generation. ask what is the
    # function doing and then ask it to describe how to speed it up and then generate optimization
    system_prompt = ctx.get_system_prompt(python_version_str=python_version_str)
    user_prompt = ctx.get_user_prompt(dependency_code or "", line_profiler_results)
    if is_numerical_code:
        # Numerical code gets extra JIT guidance appended to the system prompt.
        system_prompt += f"\n{JIT_INSTRUCTIONS}\n"
    system_message = ChatCompletionSystemMessageParam(role="system", content=system_prompt)
    user_message = ChatCompletionUserMessageParam(role="user", content=user_prompt)
    messages: list[ChatCompletionMessageParam] = [system_message, user_message]
    debug_log_sensitive_data(f"This was the user prompt\n {user_prompt}\n")
    # TODO: Verify if the context window length is within the model capability

    obs_context: dict = {}
    if call_sequence is not None:
        obs_context["call_sequence"] = call_sequence

    try:
        output = await call_llm(
            llm=optimize_model,
            messages=messages,
            call_type="line_profiler",
            trace_id=trace_id,
            user_id=user_id,
            python_version=python_version_str,
            context=obs_context,
        )
    except Exception as e:
        # LLM call failed: report and return with no candidate and no cost.
        logging.exception("OpenAI Code Generation error in optimizer-line-profiler")
        sentry_sdk.capture_exception(e)
        debug_log_sensitive_data(f"Failed to generate code for source:\n{ctx.source_code}")
        return None, None, optimize_model.name

    llm_cost = calculate_llm_cost(output.raw_response, optimize_model)

    debug_log_sensitive_data(f"OpenAIClient optimization response:\n{output.raw_response.model_dump_json(indent=2)}")

    if output.raw_response.usage is not None:
        # ph is a blocking analytics call; run it off the event loop.
        await asyncio.to_thread(
            ph,
            user_id,
            "aiservice-optimize-line-profiler-openai-usage",
            properties={"model": optimize_model.name, "usage": output.raw_response.usage.json()},
        )

    # Parse the LLM output into the context, then validate the candidate.
    ctx.extract_code_and_explanation_from_llm_res(output.content)
    res = ctx.parse_and_generate_candidate_schema()
    if res is not None and ctx.is_valid_code():
        return res, llm_cost, optimize_model.name

    # Parsing/validation failed: cost was still incurred, so report it.
    return None, llm_cost, optimize_model.name
|
|
|
|
|
|
async def optimize_python_code_line_profiler(
    user_id: str,
    trace_id: str,
    line_profiler_results: str,
    ctx: BaseOptimizerContext,
    original_source_code: str,
    dependency_code: str | None = None,
    n_candidates: int = 0,
    python_version: tuple[int, int, int] = (3, 12, 9),
    is_numerical_code: bool | None = None,
) -> tuple[list[OptimizeResponseItemSchema], float, dict[str, dict], dict[str, str]]:
    """Run parallel line profiler optimizations with multiple models.

    Distributes n_candidates across models via get_model_distribution and
    runs one task per call inside a TaskGroup. Each task gets a fresh
    optimizer context (cloned from ctx's prompts) to avoid shared state.

    Returns:
        tuple containing:
            - list of optimization results
            - total LLM cost
            - dict of raw code/explanations keyed by optimization_id
            - dict mapping optimization_id to model name

    """
    if n_candidates == 0:
        return [], 0.0, {}, {}
    model_distribution = get_model_distribution(n_candidates, MAX_OPTIMIZER_LP_CALLS)

    # Create tasks for each model call. Each task is paired with its own
    # context so results can be matched to their raw extractions afterwards.
    tasks: list[
        tuple[asyncio.Task[tuple[OptimizeResponseItemSchema | None, float | None, str]], BaseOptimizerContext]
    ] = []
    call_sequence = 1

    async with asyncio.TaskGroup() as tg:
        for model, num_calls in model_distribution:
            for _ in range(num_calls):
                # Each call needs its own context instance to avoid shared state issues
                # Use original_source_code (markdown format) to preserve file path info
                task_ctx = BaseOptimizerContext.get_dynamic_context(
                    ctx.base_system_prompt, ctx.base_user_prompt, original_source_code
                )
                task = tg.create_task(
                    optimize_python_code_line_profiler_single(
                        user_id=user_id,
                        trace_id=trace_id,
                        line_profiler_results=line_profiler_results,
                        ctx=task_ctx,
                        dependency_code=dependency_code,
                        optimize_model=model,
                        python_version=python_version,
                        call_sequence=call_sequence,
                        is_numerical_code=is_numerical_code,
                    )
                )
                tasks.append((task, task_ctx))
                call_sequence += 1

    # Collect results (all tasks are done once the TaskGroup exits).
    optimization_results: list[OptimizeResponseItemSchema] = []
    total_cost = 0.0
    code_and_explanations: dict[str, dict] = {}
    optimization_models: dict[str, str] = {}

    for task, task_ctx in tasks:
        result, cost, model_name = task.result()
        if cost:
            # Cost is accumulated even for calls that yielded no usable candidate.
            total_cost += cost
        if result is not None:
            optimization_results.append(result)
            optimization_models[result.optimization_id] = model_name
        # Collect raw code/explanations for logging
        for op_id, cei in task_ctx.code_and_explanation_before_post_processing.items():
            code_and_explanations[op_id] = {"code": cei.code, "explanation": cei.explanation}

    return optimization_results, total_cost, code_and_explanations, optimization_models
|
|
|
|
|
|
# ============================================================================
|
|
# JavaScript/TypeScript Line Profiler Optimization
|
|
# ============================================================================
|
|
|
|
|
|
async def optimize_javascript_code_line_profiler_single(
    user_id: str,
    trace_id: str,
    source_code: str,
    line_profiler_results: str,
    dependency_code: str | None = None,
    optimize_model: LLM = OPTIMIZE_MODEL,
    language_version: str = "ES2022",
    language: str = "javascript",
    call_sequence: int | None = None,
) -> tuple[OptimizeResponseItemSchema | None, float | None, str]:
    """Optimize JavaScript/TypeScript code using LLMs with line profiler guidance.

    Handles both single-file and multi-file (path-tagged markdown) inputs:
    multi-file responses are merged file-by-file back over the originals,
    keeping the original code for any file whose generated replacement fails
    syntax validation.

    Args:
        user_id: User identifier for analytics and LLM call attribution.
        trace_id: Trace identifier propagated to the LLM call.
        source_code: The code to optimize (raw, or multi-file markdown).
        line_profiler_results: Profiler output injected into the user prompt.
        dependency_code: Optional read-only dependency source for the prompt.
        optimize_model: LLM to call.
        language_version: Target ECMAScript/TS version (e.g. "ES2022").
        language: "javascript" or "typescript".
        call_sequence: Optional 1-based index of this call within a batch.

    Returns:
        (candidate, llm_cost, model_name). candidate is None when the call
        failed, no code block was found, the code failed validation, or the
        output was identical to the input; llm_cost is None only when the
        LLM call itself raised.

    """
    lang_name = "TypeScript" if language == "typescript" else "JavaScript"
    code_block_tag = "typescript" if language == "typescript" else "javascript"
    logging.info(f"/optimize-line-profiler: Optimizing {lang_name} code.")
    debug_log_sensitive_data(f"Optimizing {lang_name} code for user {user_id}:\n{source_code}")

    # Check if source code is multi-file format
    is_multi_file = is_multi_context_ts(source_code) if language == "typescript" else is_multi_context_js(source_code)
    original_file_to_code: dict[str, str] = {}

    if is_multi_file:
        original_file_to_code = split_markdown_code(source_code, language)
        logging.info(
            f"Multi-file context detected with {len(original_file_to_code)} files: {list(original_file_to_code.keys())}"
        )

    # Format system prompt with language version
    system_prompt = JS_SYSTEM_PROMPT.format(language_version=language_version)

    # Build user prompt with line profiler results
    if is_multi_file:
        # For multi-file, identify the first file as the target and others as helper context
        file_paths = list(original_file_to_code.keys())
        target_file = file_paths[0] if file_paths else "main file"
        helper_files = file_paths[1:] if len(file_paths) > 1 else []

        # Build multi-file instructions
        helper_notice = ""
        if helper_files:
            helper_list = ", ".join(f"`{f}`" for f in helper_files)
            helper_notice = f"""
HELPER FILES: {helper_list}
These files contain helper functions that the target function uses. You may optimize these as well if needed.
"""

        multi_file_instructions = f"""
The code is provided in a multi-file format. Each file is wrapped in a code block with its path.

TARGET FILE: `{target_file}`
{helper_notice}
Output the optimized code for each file that you modify. Wrap each file's code in:
```{code_block_tag}:<file_path>
<optimized code>
```

You MUST output the target file. You may also output helper files if you optimize them.
"""
        system_prompt = system_prompt + "\n" + multi_file_instructions

        user_prompt = f"""Optimize the following {lang_name} code for better performance.

{JS_LINE_PROF_CONTEXT.format(line_profiler_results=line_profiler_results)}

Here is the code to optimize:
{source_code}
"""
    else:
        user_prompt = f"""Optimize the following {lang_name} code for better performance.

{JS_LINE_PROF_CONTEXT.format(line_profiler_results=line_profiler_results)}

Here is the code to optimize:
```{code_block_tag}
{source_code}
```
"""

    if dependency_code:
        user_prompt = f"Dependencies (read-only):\n```{code_block_tag}\n{dependency_code}\n```\n\n{user_prompt}"

    obs_context: dict = {}
    if call_sequence is not None:
        obs_context["call_sequence"] = call_sequence

    messages: list[ChatCompletionMessageParam] = [
        ChatCompletionSystemMessageParam(role="system", content=system_prompt),
        ChatCompletionUserMessageParam(role="user", content=user_prompt),
    ]

    try:
        output = await call_llm(
            llm=optimize_model,
            messages=messages,
            call_type="line_profiler",
            trace_id=trace_id,
            user_id=user_id,
            python_version=language_version,  # Reusing python_version field for language version
            context=obs_context,
        )
    except Exception as e:
        logging.exception(f"LLM Code Generation error in {lang_name} line profiler optimizer")
        sentry_sdk.capture_exception(e)
        debug_log_sensitive_data(f"Failed to generate code for source:\n{source_code}")
        return None, None, optimize_model.name

    llm_cost = calculate_llm_cost(output.raw_response, optimize_model)

    debug_log_sensitive_data(f"LLM optimization response:\n{output.raw_response.model_dump_json(indent=2)}")

    if output.raw_response.usage is not None:
        # ph is a blocking analytics call; run it off the event loop for
        # consistency with the Python path (which uses asyncio.to_thread).
        await asyncio.to_thread(
            ph,
            user_id,
            "aiservice-optimize-line-profiler-openai-usage",
            properties={"model": optimize_model.name, "usage": output.raw_response.usage.json(), "language": language},
        )

    # Extract code and explanation from response
    extracted_code, explanation = extract_js_code_and_explanation(output.content, is_multi_file=is_multi_file)

    if not extracted_code:
        sentry_sdk.capture_message(f"No code block found in {lang_name} line profiler optimization response")
        debug_log_sensitive_data(f"No code found in response for source:\n{source_code}")
        return None, llm_cost, optimize_model.name

    optimization_id = str(uuid.uuid4())

    if is_multi_file and isinstance(extracted_code, dict):
        # Handle multi-file response.
        # LLM can optimize both target and helper files.
        merged_file_to_code: dict[str, str] = {}
        has_changes = False

        for file_path, original_code in original_file_to_code.items():
            if file_path in extracted_code:
                new_code = extracted_code[file_path]

                # Validate the new code before accepting it.
                if language == "typescript":
                    is_valid, error = validate_typescript_syntax(new_code)
                else:
                    is_valid, error = validate_javascript_syntax(new_code)

                if not is_valid:
                    sentry_sdk.capture_message(f"Invalid {lang_name} generated for {file_path}: {error}")
                    debug_log_sensitive_data(f"Invalid code generated for {file_path}:\n{new_code}\nError: {error}")
                    # Keep original code for this file
                    merged_file_to_code[file_path] = original_code
                else:
                    merged_file_to_code[file_path] = new_code
                    # Only count as a change if it differs beyond comments/whitespace.
                    if normalize_js_code(new_code) != normalize_js_code(original_code):
                        has_changes = True
            else:
                # File not in response, keep original
                merged_file_to_code[file_path] = original_code

        if not has_changes:
            debug_log_sensitive_data("Generated code identical to original (multi-file)")
            return None, llm_cost, optimize_model.name

        # Format as multi-file markdown
        wrapped_code = group_code(merged_file_to_code, language=code_block_tag)

        result = OptimizeResponseItemSchema(
            source_code=wrapped_code, explanation=explanation, optimization_id=optimization_id
        )
        return result, llm_cost, optimize_model.name

    # Single file handling
    optimized_code = extracted_code if isinstance(extracted_code, str) else ""

    if not optimized_code:
        return None, llm_cost, optimize_model.name

    # Validate the generated code
    if language == "typescript":
        is_valid, error = validate_typescript_syntax(optimized_code)
    else:
        is_valid, error = validate_javascript_syntax(optimized_code)

    if not is_valid:
        sentry_sdk.capture_message(f"Invalid {lang_name} generated: {error}")
        debug_log_sensitive_data(f"Invalid code generated:\n{optimized_code}\nError: {error}")
        return None, llm_cost, optimize_model.name

    # Check that the code is actually different from the original
    if normalize_js_code(optimized_code) == normalize_js_code(source_code):
        debug_log_sensitive_data("Generated code identical to original")
        return None, llm_cost, optimize_model.name

    # Wrap code in markdown format for CLI parsing (avoid doubling the final newline).
    wrapped_code = (
        f"```{code_block_tag}\n{optimized_code}\n```"
        if not optimized_code.endswith("\n")
        else f"```{code_block_tag}\n{optimized_code}```"
    )
    result = OptimizeResponseItemSchema(
        source_code=wrapped_code, explanation=explanation, optimization_id=optimization_id
    )

    return result, llm_cost, optimize_model.name
|
|
|
|
|
|
async def optimize_javascript_code_line_profiler(
    user_id: str,
    trace_id: str,
    source_code: str,
    line_profiler_results: str,
    dependency_code: str | None = None,
    language_version: str = "ES2022",
    language: str = "javascript",
    n_candidates: int = 0,
) -> tuple[list[OptimizeResponseItemSchema], float, dict[str, str]]:
    """Fan out JS/TS line profiler optimizations across multiple models in parallel.

    Returns:
        (results, total_cost, models_by_id) where models_by_id maps each
        optimization_id to the model name that produced it.

    """
    if n_candidates == 0:
        return [], 0.0, {}

    # Flatten the (model, num_calls) distribution into one model per planned call.
    distribution = get_model_distribution(n_candidates, MAX_OPTIMIZER_LP_CALLS)
    planned_models = [model for model, num_calls in distribution for _ in range(num_calls)]

    pending: list[asyncio.Task[tuple[OptimizeResponseItemSchema | None, float | None, str]]] = []
    async with asyncio.TaskGroup() as tg:
        for seq, model in enumerate(planned_models, start=1):
            pending.append(
                tg.create_task(
                    optimize_javascript_code_line_profiler_single(
                        user_id=user_id,
                        trace_id=trace_id,
                        source_code=source_code,
                        line_profiler_results=line_profiler_results,
                        dependency_code=dependency_code,
                        optimize_model=model,
                        language_version=language_version,
                        language=language,
                        call_sequence=seq,
                    )
                )
            )

    # All tasks are complete once the TaskGroup exits; gather their outputs.
    results: list[OptimizeResponseItemSchema] = []
    models_by_id: dict[str, str] = {}
    cost_total = 0.0

    for finished in pending:
        candidate, cost, model_name = finished.result()
        if cost:
            # Cost accrues even for calls that produced no usable candidate.
            cost_total += cost
        if candidate is not None:
            results.append(candidate)
            models_by_id[candidate.optimization_id] = model_name

    return results, cost_total, models_by_id
|
|
|
|
|
|
@optimize_line_profiler_api.post(
    "/", response={200: OptimizeResponseSchema, 400: OptimizeErrorResponseSchema, 500: OptimizeErrorResponseSchema}
)
async def optimize(request, data: OptimizeSchemaLP) -> tuple[int, OptimizeResponseSchema | OptimizeErrorResponseSchema]:  # noqa: ANN001
    """Generate line-profiler-guided optimization candidates for the submitted code.

    Validates the request, routes it to the Python or JavaScript/TypeScript
    pipeline based on ``data.language``, records cost and analytics, and
    returns the generated candidates.

    Fixes over the previous version: Python-version parsing/validation now
    happens only on the Python path, so JavaScript/TypeScript requests are no
    longer rejected with a spurious "Invalid Python version" 400 when they
    omit ``python_version``. The duplicated context creation and duplicated
    source-code parsing for Python requests were also removed.

    Returns:
        (200, OptimizeResponseSchema) on success,
        (400, OptimizeErrorResponseSchema) for invalid input,
        (500, OptimizeErrorResponseSchema) when no candidate could be generated.

    """
    await asyncio.to_thread(ph, request.user, "aiservice-optimize-called")

    # Route based on language (Python is the default when no language is given).
    language = data.language.lower() if data.language else "python"
    is_javascript = language in ("javascript", "typescript")
    is_python = language == "python"

    # Language-independent validations.
    if is_python and not data.python_version:
        return 400, OptimizeErrorResponseSchema(error="Python version is required.")
    if not validate_trace_id(data.trace_id):
        return 400, OptimizeErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")
    if data.line_profiler_results is None:
        return 400, OptimizeErrorResponseSchema(error="Line profiler results are required.")

    if is_javascript:
        # JavaScript/TypeScript path
        lang_name = "TypeScript" if language == "typescript" else "JavaScript"

        # Check if multi-file context
        is_multi_file = (
            is_multi_context_ts(data.source_code) if language == "typescript" else is_multi_context_js(data.source_code)
        )

        if is_multi_file:
            # Validate each file in the multi-file context
            file_to_code = split_markdown_code(data.source_code, language)
            if not file_to_code:
                return 400, OptimizeErrorResponseSchema(
                    error=f"Invalid source code format. Expected multi-file {lang_name} markdown format."
                )

            for file_path, code in file_to_code.items():
                if language == "typescript":
                    is_valid, error = validate_typescript_syntax(code)
                else:
                    is_valid, error = validate_javascript_syntax(code)

                if not is_valid:
                    return 400, OptimizeErrorResponseSchema(
                        error=f"Invalid source code in {file_path}. It is not valid {lang_name}: {error}"
                    )
        else:
            # Single file validation
            if language == "typescript":
                is_valid, error = validate_typescript_syntax(data.source_code)
            else:
                is_valid, error = validate_javascript_syntax(data.source_code)

            if not is_valid:
                return 400, OptimizeErrorResponseSchema(
                    error=f"Invalid source code. It is not valid {lang_name} code. Error: {error}"
                )

        language_version = data.language_version or "ES2022"

        (optimization_response_items, llm_cost, optimization_models) = await optimize_javascript_code_line_profiler(
            user_id=request.user,
            trace_id=data.trace_id,
            source_code=data.source_code,
            line_profiler_results=data.line_profiler_results,
            dependency_code=data.dependency_code,
            language_version=language_version,
            language=language,
            n_candidates=data.n_candidates,
        )
        # JavaScript path doesn't have code_and_explanations dict like Python
        code_and_explanations: dict[str, dict] = {}

    else:
        # Python path (default for any non-JS language)
        ctx: BaseOptimizerContext = BaseOptimizerContext.get_dynamic_context(
            SYSTEM_PROMPT, USER_PROMPT, data.source_code, DiffMethod.NO_DIFF
        )
        try:
            # Non-Python/non-JS languages fall through here with a default version.
            python_version: tuple[int, int, int] = parse_python_version(data.python_version or "3.12.0")
        except Exception:
            return 400, OptimizeErrorResponseSchema(
                error="Invalid Python version, it should look like 3.x.x. We only support Python 3.9 and above."
            )
        try:
            ctx.validate_and_parse_source_code(code=data.source_code, feature_version=python_version[:2])
        except SyntaxError:
            return 400, OptimizeErrorResponseSchema(
                error="Invalid source code. It is not valid Python code. Please check syntax of your code."
            )

        (
            optimization_response_items,
            llm_cost,
            code_and_explanations,
            optimization_models,
        ) = await optimize_python_code_line_profiler(
            user_id=request.user,
            trace_id=data.trace_id,
            line_profiler_results=data.line_profiler_results,
            ctx=ctx,
            original_source_code=data.source_code,
            dependency_code=data.dependency_code,
            n_candidates=data.n_candidates,
            python_version=python_version,
            is_numerical_code=data.is_numerical_code,
        )

    # Update total cost
    await update_optimization_cost(trace_id=data.trace_id, cost=llm_cost, user_id=request.user)

    if len(optimization_response_items) == 0:
        await asyncio.to_thread(ph, request.user, "aiservice-optimize-no-optimizations-found")
        debug_log_sensitive_data(f"No optimizations found for source:\n{data.source_code}")
        logging.error(
            "Could not generate any optimizations (line profiler). trace_id=%s, n_candidates=%d, source_len=%d, has_line_profiler=%s",
            data.trace_id,
            data.n_candidates,
            len(data.source_code) if data.source_code else 0,
            bool(data.line_profiler_results),
        )
        return 500, OptimizeErrorResponseSchema(error="Error generating optimizations. Internal server error.")
    await asyncio.to_thread(
        ph,
        request.user,
        "aiservice-optimize-optimizations-found",
        properties={"num_optimizations": len(optimization_response_items), "language": language},
    )

    if hasattr(request, "should_log_features") and request.should_log_features:
        await log_features(
            trace_id=data.trace_id,
            user_id=request.user,
            original_code=data.source_code,
            dependency_code=data.dependency_code,
            line_profiler_results=data.line_profiler_results,
            optimizations_raw={op_id: cei["code"] for op_id, cei in code_and_explanations.items()},
            optimizations_post={cei.optimization_id: cei.source_code for cei in optimization_response_items},
            explanations_raw={op_id: cei["explanation"] for op_id, cei in code_and_explanations.items()},
            explanations_post={cei.optimization_id: cei.explanation for cei in optimization_response_items},
            experiment_metadata=data.experiment_metadata if data.experiment_metadata else None,
            optimizations_origin={
                cei.optimization_id: {
                    "source": OptimizedCandidateSource.OPTIMIZE_LP,
                    "parent": None,
                    "model": optimization_models.get(cei.optimization_id, "unknown"),
                }
                for cei in optimization_response_items
            },
        )

    response = OptimizeResponseSchema(optimizations=optimization_response_items)

    def log_response() -> None:
        # Deferred so the (potentially large) dump is only built when sensitive
        # debug logging is actually enabled.
        debug_log_sensitive_data(f"Response:\n{response.model_dump_json()}")
        for opt in response.optimizations:
            debug_log_sensitive_data(f"Optimized source:\n{opt.source_code}")
            debug_log_sensitive_data(f"Optimization explanation:\n{opt.explanation}")

    debug_log_sensitive_data_from_callable(log_response)
    await asyncio.to_thread(ph, request.user, "aiservice-optimize-successful")
    return 200, response
|