mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
# Pull Request Checklist ## Description - [ ] **Breaking Changes**: Document any breaking changes (if applicable) - [ ] **Description of PR**: Clear and concise description of what this PR accomplishes - [ ] **Related Issues**: Link to any related issues or tickets ## Testing - [ ] **Test cases Attached**: All relevant test cases have been added/updated - [ ] **Manual Testing**: Manual testing completed for the changes ## Monitoring & Debugging - [ ] **Logging in place**: Appropriate logging has been added for debugging user issues - [ ] **Sentry will be able to catch errors**: Error handling ensures Sentry can capture and report errors - [ ] **Avoid Dev based/Prisma logging**: No development-only or Prisma-specific logging in production code ## Configuration - [ ] **Env variables newly added**: Any new environment variables are documented in .env.example file or mentioned in description --- ## Additional Notes <!-- Add any additional context, screenshots, or notes for reviewers here --> --------- Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com> Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com> Co-authored-by: HeshamHM28 <HeshamMohamedFathy@outlook.com> Co-authored-by: Ubuntu <ubuntu@ip-172-31-39-200.ec2.internal> Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com> Co-authored-by: Kevin Turcios <turcioskevinr@gmail.com> Co-authored-by: Kevin Turcios <106575910+KRRT7@users.noreply.github.com>
585 lines
23 KiB
Python
585 lines
23 KiB
Python
"""Java code optimizer module.
|
|
|
|
This module handles optimization requests for Java code.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import re
|
|
import uuid
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
import sentry_sdk
|
|
from ninja.errors import HttpError
|
|
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
|
|
|
|
from aiservice.analytics.posthog import ph
|
|
from aiservice.common_utils import is_host_equals_demo, should_hack_for_demo_java, validate_trace_id
|
|
from aiservice.env_specific import debug_log_sensitive_data
|
|
from aiservice.llm import LLM, OPTIMIZE_MODEL, calculate_llm_cost, call_llm
|
|
from authapp.auth import AuthenticatedRequest
|
|
from authapp.user import get_user_by_id
|
|
from core.languages.java.prompts.optimizer import get_system_prompt, get_user_prompt
|
|
from core.shared.context_helpers import group_code, split_markdown_code
|
|
from core.shared.optimizer_config import MAX_OPTIMIZER_CALLS, get_model_distribution
|
|
from core.shared.optimizer_models import OptimizeSchema
|
|
from core.shared.optimizer_schemas import (
|
|
OptimizeErrorResponseSchema,
|
|
OptimizeResponseItemSchema,
|
|
OptimizeResponseSchema,
|
|
)
|
|
from log_features.log_event import get_or_create_optimization_event
|
|
from log_features.log_features import log_features
|
|
|
|
if TYPE_CHECKING:
|
|
from openai.types.chat import ChatCompletionMessageParam
|
|
|
|
from aiservice.validators.java_validator import validate_java_syntax
|
|
|
|
# Matches one fenced Java block and captures its body. The opening fence may be
# a bare ```java or carry a file path (```java:path/To.java); the path itself is
# not captured here.
JAVA_CODE_PATTERN = re.compile(r"```(?:java)(?::[^\n]*)?\s*\n(.*?)```", re.DOTALL | re.MULTILINE)

# Matches a path-annotated Java block (```java:path/To.java) and captures the
# pair (file path, body); used when extracting multi-file LLM responses.
JAVA_CODE_WITH_PATH_PATTERN = re.compile(r"```(?:java):([^\n]+)\n(.*?)```", re.DOTALL | re.MULTILINE)
|
|
|
|
|
|
def is_multi_context_java(source_code: str) -> bool:
    """Return True if *source_code* contains at least one path-annotated Java block.

    Despite the name, a single ```java:<path> fence is enough: any path-annotated
    markdown is treated as the multi-file format by callers.

    Args:
        source_code: Raw or markdown-wrapped Java source.

    Returns:
        True when the text contains the substring "```java:".
    """
    # Substring membership short-circuits on the first hit instead of counting all.
    return "```java:" in source_code
|
|
|
|
|
|
def extract_code_and_explanation(content: str, is_multi_file: bool = False) -> tuple[str | dict[str, str], str]:
    """Extract code and explanation from LLM response.

    Args:
        content: The raw LLM response content.
        is_multi_file: Whether to first look for path-annotated (```java:<path>) blocks.

    Returns:
        Tuple of (code, explanation). ``code`` is a ``dict[path, code]`` when
        ``is_multi_file`` is set and at least one path-annotated block is found
        (a later block with the same path overwrites an earlier one). Otherwise
        it is the body of the first Java block as a string, or "" when there is
        no Java block, in which case the whole content is returned as the
        explanation. The explanation is the stripped text preceding the first
        extracted block.
    """
    if is_multi_file:
        path_matches = list(JAVA_CODE_WITH_PATH_PATTERN.finditer(content))
        if path_matches:
            # Measure the explanation up to the first *matched* file block (as the
            # single-file branch does), not up to the first ``` fence anywhere,
            # which may belong to an unrelated snippet inside the explanation.
            explanation = content[: path_matches[0].start()].strip()
            file_to_code = {m.group(1).strip(): m.group(2).strip() for m in path_matches}
            return file_to_code, explanation
        # No path-annotated blocks: fall through to single-file extraction.

    match = JAVA_CODE_PATTERN.search(content)
    if match is None:
        # No code block found, return empty code
        return "", content

    # Explanation is everything before the code block
    return match.group(1).strip(), content[: match.start()].strip()
|
|
|
|
|
|
def _extract_demo_context(source_code: str) -> tuple[str, str, str, str]:
|
|
"""Extract package, class name, exception type, and extra imports from the demo source code.
|
|
|
|
Returns:
|
|
Tuple of (package_declaration, class_name, throw_statement_prefix, extra_imports)
|
|
|
|
"""
|
|
# Extract the raw code from markdown block if present
|
|
code_match = re.search(r"```java:[^\n]*\n(.*?)```", source_code, re.DOTALL)
|
|
raw_code = code_match.group(1) if code_match else source_code
|
|
|
|
# Extract package
|
|
pkg_match = re.search(r"^\s*package\s+([\w.]+)\s*;", raw_code, re.MULTILINE)
|
|
package_decl = f"package {pkg_match.group(1)};\n" if pkg_match else ""
|
|
|
|
# Extract class name
|
|
class_match = re.search(r"\bclass\s+(\w+)", raw_code)
|
|
class_name = class_match.group(1) if class_match else "FileUtils"
|
|
|
|
# Extract exception type from the throw statement (e.g., "throw new AerospikeException(...)")
|
|
throw_match = re.search(r"throw\s+new\s+(\w+)\s*\(", raw_code)
|
|
exception_type = throw_match.group(1) if throw_match else "RuntimeException"
|
|
|
|
# Collect extra imports needed for the exception type (skip standard java/javax)
|
|
extra_imports = ""
|
|
if exception_type != "RuntimeException":
|
|
import_match = re.search(rf"^\s*import\s+([\w.]*\.{re.escape(exception_type)})\s*;", raw_code, re.MULTILINE)
|
|
if import_match:
|
|
extra_imports = f"import {import_match.group(1)};\n"
|
|
|
|
return package_decl, class_name, exception_type, extra_imports
|
|
|
|
|
|
def _build_demo_optimizations(
|
|
package_decl: str, class_name: str, exception_type: str, extra_imports: str
|
|
) -> list[dict[str, str]]:
|
|
"""Build 2 demo optimization candidates using the extracted class context.
|
|
|
|
Candidate 1 (Files.readAllBytes) is the intended winner — it benchmarks fastest.
|
|
Candidate 2 is a plausible alternative that is functionally correct but
|
|
benchmarks slightly slower, ensuring Files.readAllBytes wins the speedup critic.
|
|
"""
|
|
fmt = dict(
|
|
package_decl=package_decl, class_name=class_name, exception_type=exception_type, extra_imports=extra_imports
|
|
)
|
|
|
|
return [
|
|
# Candidate 2: FileInputStream.readAllBytes() (Java 9+)
|
|
{
|
|
"source_code": (
|
|
"{package_decl}"
|
|
"\n"
|
|
"import java.io.File;\n"
|
|
"import java.io.FileInputStream;\n"
|
|
"{extra_imports}"
|
|
"\n"
|
|
"public final class {class_name} {{\n"
|
|
" public static byte[] readFile(File file) {{\n"
|
|
" try (FileInputStream fis = new FileInputStream(file)) {{\n"
|
|
" return fis.readAllBytes();\n"
|
|
" }}\n"
|
|
" catch (Throwable e) {{\n"
|
|
' throw new {exception_type}("Failed to read " + file.getAbsolutePath(), e);\n'
|
|
" }}\n"
|
|
" }}\n"
|
|
"}}"
|
|
).format(**fmt),
|
|
"explanation": (
|
|
"Use FileInputStream.readAllBytes() (Java 9+) to read the entire file in one call. "
|
|
"This eliminates the manual read loop but still uses FileInputStream internally."
|
|
),
|
|
"optimization_id": str(uuid.uuid4()),
|
|
},
|
|
# Candidate 1: Files.readAllBytes (THE WINNER)
|
|
{
|
|
"source_code": (
|
|
"{package_decl}"
|
|
"\n"
|
|
"import java.io.File;\n"
|
|
"import java.nio.file.Files;\n"
|
|
"{extra_imports}"
|
|
"\n"
|
|
"public final class {class_name} {{\n"
|
|
" public static byte[] readFile(File file) {{\n"
|
|
" try {{\n"
|
|
" return java.nio.file.Files.readAllBytes(file.toPath());\n"
|
|
" }}\n"
|
|
" catch (Throwable e) {{\n"
|
|
' throw new {exception_type}("Failed to read " + file.getAbsolutePath(), e);\n'
|
|
" }}\n"
|
|
" }}\n"
|
|
"}}"
|
|
).format(**fmt),
|
|
"explanation": (
|
|
"Replace manual FileInputStream read loop with java.nio.file.Files.readAllBytes(). "
|
|
"This NIO method is optimized at the JDK level for direct file-to-byte-array transfer, "
|
|
"eliminating manual buffering and loop overhead."
|
|
),
|
|
"optimization_id": str(uuid.uuid4()),
|
|
},
|
|
]
|
|
|
|
|
|
def _build_host_equals_demo_optimizations(source_code: str) -> list[dict[str, str]]:
|
|
"""Build 5 optimization candidates for Host.equals by reordering comparisons.
|
|
|
|
Candidate 1 (port-first early return) is the intended winner — comparing the
|
|
primitive int port before the String name avoids unnecessary method dispatch.
|
|
"""
|
|
code_match = re.search(r"```java:[^\n]*\n(.*?)```", source_code, re.DOTALL)
|
|
raw_code = code_match.group(1) if code_match else source_code
|
|
|
|
# Match: return this.name.equals(other.name) && this.port == other.port;
|
|
original_stmt = re.compile(
|
|
r"(\s*)return\s+this\.name\.equals\(other\.name\)\s*&&\s*this\.port\s*==\s*other\.port\s*;"
|
|
)
|
|
|
|
match = original_stmt.search(raw_code)
|
|
if not match:
|
|
return [
|
|
{
|
|
"source_code": raw_code,
|
|
"explanation": "No optimization applicable.",
|
|
"optimization_id": str(uuid.uuid4()),
|
|
}
|
|
]
|
|
|
|
indent = match.group(1)
|
|
inner = indent + " "
|
|
|
|
def replace_with(replacement: str) -> str:
|
|
return original_stmt.sub(replacement, raw_code)
|
|
|
|
return [
|
|
# Candidate 1 (WINNER): Port-first early return
|
|
{
|
|
"source_code": replace_with(
|
|
f"{indent}// Compare primitive port first to avoid unnecessary string equals calls.\n"
|
|
f"{indent}if (this.port != other.port) {{\n"
|
|
f"{inner}return false;\n"
|
|
f"{indent}}}\n"
|
|
f"{indent}return this.name.equals(other.name);"
|
|
),
|
|
"explanation": (
|
|
"Compare primitive port first to avoid unnecessary string equals calls. "
|
|
"Integer comparison is a single CPU instruction, while String.equals() "
|
|
"involves method dispatch and potential character-by-character comparison."
|
|
),
|
|
"optimization_id": str(uuid.uuid4()),
|
|
},
|
|
# Candidate 2: Reordered conjunction (port first in &&)
|
|
{
|
|
"source_code": replace_with(f"{indent}return this.port == other.port && this.name.equals(other.name);"),
|
|
"explanation": (
|
|
"Reorder the conjunction to evaluate the cheaper primitive int comparison first. "
|
|
"Short-circuit evaluation skips String.equals() when ports differ."
|
|
),
|
|
"optimization_id": str(uuid.uuid4()),
|
|
},
|
|
# Candidate 3: Port-first with Objects.equals for null safety
|
|
{
|
|
"source_code": replace_with(
|
|
f"{indent}if (this.port != other.port) {{\n"
|
|
f"{inner}return false;\n"
|
|
f"{indent}}}\n"
|
|
f"{indent}return java.util.Objects.equals(this.name, other.name);"
|
|
),
|
|
"explanation": (
|
|
"Check port first (cheap primitive comparison), then use Objects.equals() "
|
|
"for null-safe name comparison. Adds safety at slight method-call overhead."
|
|
),
|
|
"optimization_id": str(uuid.uuid4()),
|
|
},
|
|
# Candidate 4: Ternary with port-first guard
|
|
{
|
|
"source_code": replace_with(
|
|
f"{indent}return this.port == other.port ? this.name.equals(other.name) : false;"
|
|
),
|
|
"explanation": (
|
|
"Use a ternary to short-circuit on port mismatch. "
|
|
"Evaluates the cheap int comparison first, only calling String.equals() when ports match."
|
|
),
|
|
"optimization_id": str(uuid.uuid4()),
|
|
},
|
|
# Candidate 5: Explicit null guard + port first
|
|
{
|
|
"source_code": replace_with(
|
|
f"{indent}if (this.port != other.port) {{\n"
|
|
f"{inner}return false;\n"
|
|
f"{indent}}}\n"
|
|
f"{indent}if (this.name == null) {{\n"
|
|
f"{inner}return other.name == null;\n"
|
|
f"{indent}}}\n"
|
|
f"{indent}return this.name.equals(other.name);"
|
|
),
|
|
"explanation": (
|
|
"Guard on port first, then add explicit null handling for the name field "
|
|
"before delegating to String.equals(). Avoids potential NullPointerException."
|
|
),
|
|
"optimization_id": str(uuid.uuid4()),
|
|
},
|
|
]
|
|
|
|
|
|
async def hack_for_demo_java(source_code: str) -> OptimizeResponseSchema:
    """Return canned optimization candidates for the Java demo.

    The file path comes from the first ```java:<path> fence (falling back to
    "Source.java"), and every candidate is re-wrapped under that path. Host.equals
    demos get the comparison-reordering candidates. Any other input gets the
    file-reading candidates built from the extracted class context. A fixed
    5-second sleep precedes the response (presumably to mimic LLM latency).
    """
    # Extract file path from markdown source (```java:path/to/File.java)
    path_match = re.search(r"```java:([^\n]+)", source_code)
    file_name = "Source.java" if path_match is None else path_match.group(1).strip()

    if not is_host_equals_demo(source_code):
        # Extract class context dynamically from the source code
        candidates = _build_demo_optimizations(*_extract_demo_context(source_code))
    else:
        candidates = _build_host_equals_demo_optimizations(source_code)

    response_list: list[OptimizeResponseItemSchema] = [
        OptimizeResponseItemSchema(
            explanation=candidate["explanation"],
            optimization_id=candidate["optimization_id"],
            source_code=group_code({file_name: candidate["source_code"]}, language="java"),
        )
        for candidate in candidates
    ]

    await asyncio.sleep(5)
    return OptimizeResponseSchema(optimizations=response_list)
|
|
|
|
|
|
async def optimize_java_code_single(
    user_id: str,
    source_code: str,
    trace_id: str,
    dependency_code: str | None = None,
    optimize_model: LLM = OPTIMIZE_MODEL,
    language_version: str = "17",
    call_sequence: int | None = None,
) -> tuple[OptimizeResponseItemSchema | None, float | None, str]:
    """Optimize Java code using LLMs.

    Args:
        user_id: The user ID making the request.
        source_code: The source code to optimize (can be multi-file markdown format).
        trace_id: The trace ID for logging.
        dependency_code: Optional dependency code appended to the prompt as read-only context.
        optimize_model: The LLM model to use.
        language_version: Target Java version (e.g., "11", "17", "21").
        call_sequence: Call sequence number for tracking.

    Returns:
        Tuple of (optimization_result, llm_cost, model_name).
        ``optimization_result`` is None when the LLM call fails, no Java code is
        extracted, or any extracted file fails syntax validation.
        ``llm_cost`` is None only when the LLM call itself failed.
    """
    logging.info("/optimize: Optimizing Java code.")
    debug_log_sensitive_data(f"Optimizing Java code for user {user_id}:\n{source_code}")

    # Path-annotated markdown (```java:<path>) is handled as the multi-file format.
    is_multi_file = is_multi_context_java(source_code)
    original_file_to_code: dict[str, str] = {}
    if is_multi_file:
        original_file_to_code = split_markdown_code(source_code, "java")
        logging.info(
            "Multi-file context detected with %d files: %s", len(original_file_to_code), list(original_file_to_code)
        )

    # Java-specific prompts; the user template is the same for single- and multi-file input.
    system_prompt = get_system_prompt(is_async=False).format(language_version=f"Java {language_version}")
    user_prompt = get_user_prompt(is_async=False).format(source_code=source_code)
    if dependency_code:
        user_prompt += f"\n\n**Context (read-only, do not modify):**\n{dependency_code}"

    obs_context: dict[str, Any] | None = {"call_sequence": call_sequence} if call_sequence is not None else None

    messages: list[ChatCompletionMessageParam] = [
        ChatCompletionSystemMessageParam(role="system", content=system_prompt),
        ChatCompletionUserMessageParam(role="user", content=user_prompt),
    ]

    try:
        output = await call_llm(
            llm=optimize_model,
            messages=messages,
            call_type="optimization",
            trace_id=trace_id,
            user_id=user_id,
            python_version="N/A",  # Not applicable for Java
            context=obs_context,
        )
    except Exception as e:
        # Boundary for a single candidate: report the failure and let sibling candidates proceed.
        logging.exception("LLM Code Generation error in Java optimizer")
        sentry_sdk.capture_exception(e)
        debug_log_sensitive_data(f"Failed to generate code for Java source:\n{source_code}")
        return None, None, optimize_model.name

    llm_cost = calculate_llm_cost(output.raw_response, optimize_model)
    debug_log_sensitive_data(f"LLM optimization response:\n{output.raw_response.model_dump_json(indent=2)}")

    if output.raw_response.usage is not None:
        await asyncio.to_thread(
            ph,
            user_id,
            "aiservice-optimize-openai-usage",
            properties={
                "model": optimize_model.name,
                # model_dump_json() is the pydantic v2 API (``.json()`` is deprecated).
                "usage": output.raw_response.usage.model_dump_json(),
                "language": "java",
            },
        )

    # Extract code and explanation from response
    code, explanation = extract_code_and_explanation(output.content, is_multi_file)
    if not code:
        logging.warning("No valid Java code extracted from LLM response")
        return None, llm_cost, optimize_model.name

    # Validate each file on its own: concatenating files would put several
    # package/import headers into one compilation unit.
    files_to_validate = list(code.values()) if isinstance(code, dict) else [code]
    for file_code in files_to_validate:
        is_valid, error = validate_java_syntax(file_code)
        if not is_valid:
            logging.warning("Java code failed syntax validation: %s", error)
            return None, llm_cost, optimize_model.name

    if isinstance(code, dict):
        # Multi-file response
        formatted_code = group_code(code, language="java")
    else:
        # Single-file response: reuse the first original file name when the input
        # was path-annotated, otherwise fall back to a default name.
        file_name = next(iter(original_file_to_code), "Source.java")
        formatted_code = group_code({file_name: code}, language="java")

    result = OptimizeResponseItemSchema(
        explanation=explanation, optimization_id=str(uuid.uuid4()), source_code=formatted_code
    )
    return result, llm_cost, optimize_model.name
|
|
|
|
|
|
async def optimize_java_code(
    user_id: str,
    source_code: str,
    trace_id: str,
    dependency_code: str | None = None,
    language_version: str = "17",
    n_candidates: int = 5,
) -> tuple[list[OptimizeResponseItemSchema], float, dict[str, dict[str, str]], dict[str, str]]:
    """Run parallel optimizations with multiple models based on the distribution config.

    One ``optimize_java_code_single`` task is started per call in
    ``get_model_distribution(n_candidates, MAX_OPTIMIZER_CALLS)``, all inside a
    single ``asyncio.TaskGroup``. Each task gets a 1-based ``call_sequence``.

    Returns:
        tuple containing:
            - list of optimization results (failed candidates are omitted)
            - total LLM cost
            - dict of raw code/explanations keyed by optimization_id
            - dict mapping optimization_id to model name
        A non-positive ``n_candidates`` returns ``([], 0.0, {}, {})`` without calling any model.
    """
    if n_candidates <= 0:
        return [], 0.0, {}, {}

    tasks: list[asyncio.Task[tuple[OptimizeResponseItemSchema | None, float | None, str]]] = []
    call_sequence = 1
    async with asyncio.TaskGroup() as tg:
        for model, num_calls in get_model_distribution(n_candidates, MAX_OPTIMIZER_CALLS):
            for _ in range(num_calls):
                tasks.append(
                    tg.create_task(
                        optimize_java_code_single(
                            user_id=user_id,
                            source_code=source_code,
                            trace_id=trace_id,
                            dependency_code=dependency_code,
                            optimize_model=model,
                            language_version=language_version,
                            call_sequence=call_sequence,
                        )
                    )
                )
                call_sequence += 1

    # Collect results (all tasks have completed once the TaskGroup exits).
    optimization_results: list[OptimizeResponseItemSchema] = []
    total_cost = 0.0
    code_and_explanations: dict[str, dict[str, str]] = {}
    optimization_models: dict[str, str] = {}

    for task in tasks:
        result, cost, model_name = task.result()
        if cost:
            total_cost += cost
        if result is None:
            continue
        optimization_results.append(result)
        code_and_explanations[result.optimization_id] = {
            "code": result.source_code,
            "explanation": result.explanation,
        }
        optimization_models[result.optimization_id] = model_name

    return optimization_results, total_cost, code_and_explanations, optimization_models
|
|
|
|
|
|
async def optimize_java(
    request: AuthenticatedRequest, data: OptimizeSchema
) -> tuple[int, OptimizeResponseSchema | OptimizeErrorResponseSchema]:
    """Optimize Java code for performance using LLMs.

    Steps:
        1. Validate the trace id and the user.
        2. Log the request features.
        3. Either return the canned demo candidates (when ``should_hack_for_demo_java``
           matches) or fan out to ``optimize_java_code``, then record analytics and
           log the response.

    Returns:
        ``(status_code, body)``: ``(400, error)`` for an invalid trace id, otherwise ``(200, response)``.

    Raises:
        HttpError: 401 when the authenticated user cannot be found.
    """
    # Validate trace_id
    if not validate_trace_id(data.trace_id):
        return 400, OptimizeErrorResponseSchema(error="Invalid trace_id")

    user_id = request.user
    user = await get_user_by_id(user_id)
    if user is None:
        raise HttpError(401, "User not found")

    # Resolve the target Java version once; it feeds request logging, prompting and analytics.
    language_version = data.language_version or "17"

    # Log the event
    optimization_event, _created = await get_or_create_optimization_event(
        trace_id=data.trace_id, event_type="no-pr", user_id=user_id
    )
    if optimization_event is not None:
        await asyncio.to_thread(
            log_features,
            data.source_code[:1000],
            optimization_event,
            "optimize_request",
            {
                "source_code_length": len(data.source_code),
                "dependency_code_length": len(data.dependency_code) if data.dependency_code else 0,
                "n_candidates": data.n_candidates,
                "language": "java",
                "language_version": language_version,
            },
        )

    # Check for demo mode
    if should_hack_for_demo_java(data.source_code):
        response = await hack_for_demo_java(data.source_code)
        event_id = str(optimization_event.id) if optimization_event else None
        for item in response.optimizations:
            item.optimization_event_id = event_id
        return 200, response

    # Run optimization
    # NOTE(review): unlike the demo path, results here never get
    # ``optimization_event_id`` set; confirm whether clients rely on it.
    optimization_results, total_cost, code_and_explanations, optimization_models = await optimize_java_code(
        user_id=user_id,
        source_code=data.source_code,
        trace_id=data.trace_id,
        dependency_code=data.dependency_code,
        language_version=language_version,
        n_candidates=data.n_candidates,
    )

    # Track analytics
    await asyncio.to_thread(
        ph,
        user_id,
        "aiservice-optimize-java",
        properties={
            "trace_id": data.trace_id,
            "n_candidates_requested": data.n_candidates,
            "n_candidates_returned": len(optimization_results),
            "total_cost": total_cost,
            "language_version": language_version,
        },
    )

    # Log the response
    if optimization_event is not None:
        await asyncio.to_thread(
            log_features,
            str(code_and_explanations)[:1000],
            optimization_event,
            "optimize_response",
            {
                "n_candidates": len(optimization_results),
                "total_cost": total_cost,
                "models": list(optimization_models.values()),
            },
        )

    return 200, OptimizeResponseSchema(optimizations=optimization_results)
|