mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
When Azure OpenAI or Anthropic returns null/empty content (content filter, truncation, transient failure), call_openai/call_anthropic now raise LLMOutputUnparseable instead of returning an empty string that silently flows through the pipeline and produces 422 "Could not generate any optimizations." All optimizer callers catch LLMOutputUnparseable to preserve cost tracking while returning None.
303 lines
11 KiB
Python
303 lines
11 KiB
Python
"""Java code optimizer module.
|
|
|
|
This module handles optimization requests for Java code.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import re
|
|
import uuid
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
from ninja.errors import HttpError
|
|
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
|
|
|
|
from aiservice.analytics.posthog import ph
|
|
from aiservice.common_utils import validate_trace_id
|
|
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable
|
|
from aiservice.llm import LLMOutputUnparseable, llm_client
|
|
from aiservice.llm_models import LLM, OPTIMIZE_MODEL
|
|
from authapp.auth import AuthenticatedRequest
|
|
from authapp.user import get_user_by_id
|
|
from core.languages.java.prompts.optimizer import get_system_prompt, get_user_prompt
|
|
from core.log_features.log_event import get_or_create_optimization_event
|
|
from core.log_features.log_features import safe_log_features
|
|
from core.shared.context_helpers import extract_code_and_explanation, group_code, split_markdown_code
|
|
from core.shared.optimizer_config import MAX_OPTIMIZER_CALLS, get_model_distribution
|
|
from core.shared.optimizer_models import OptimizeSchema
|
|
from core.shared.optimizer_schemas import (
|
|
OptimizeErrorResponseSchema,
|
|
OptimizeResponseItemSchema,
|
|
OptimizeResponseSchema,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from openai.types.chat import ChatCompletionMessageParam
|
|
|
|
from aiservice.validators.java_validator import validate_java_syntax
|
|
|
|
# Pattern to extract code blocks from LLM response (handles both ```java and ```java:filename.java)
# re.DOTALL lets `.` cross newlines so group 1 captures the entire fenced body;
# the optional `(?::[^\n]*)?` swallows a trailing ":filename.java" on the fence line.
JAVA_CODE_PATTERN = re.compile(r"```(?:java)(?::[^\n]*)?\s*\n(.*?)```", re.MULTILINE | re.DOTALL)

# Pattern to extract code blocks with file paths (multi-file context)
# Group 1 is the file path that follows "```java:", group 2 is the code body.
JAVA_CODE_WITH_PATH_PATTERN = re.compile(r"```(?:java):([^\n]+)\n(.*?)```", re.MULTILINE | re.DOTALL)
|
|
|
|
|
|
def is_multi_context_java(source_code: str) -> bool:
    """Check if source code contains multiple Java file blocks.

    Multi-file sources mark each file with a fenced block of the form
    ```java:path/To/File.java, so the ":" suffix on the fence is the signal.
    """
    # Membership test short-circuits on the first hit, unlike count() >= 1
    # which scans the whole string.
    return "```java:" in source_code
|
|
|
|
|
|
async def optimize_java_code_single(
    user_id: str,
    source_code: str,
    trace_id: str,
    dependency_code: str | None = None,
    optimize_model: LLM = OPTIMIZE_MODEL,
    language_version: str = "17",
    call_sequence: int | None = None,
) -> tuple[OptimizeResponseItemSchema | None, float | None, str]:
    """Optimize Java code using LLMs.

    Makes a single LLM call, extracts the optimized code from the markdown
    response, validates its syntax, and packages it as a response item.

    Args:
        user_id: The user ID making the request
        source_code: The source code to optimize (can be multi-file markdown format)
        trace_id: The trace ID for logging
        dependency_code: Optional dependency code for context
        optimize_model: The LLM model to use
        language_version: Target Java version (e.g., "11", "17", "21")
        call_sequence: Call sequence number for tracking

    Returns:
        Tuple of (optimization_result, llm_cost, model_name). The result is
        None when the LLM call fails, returns no extractable code, or the
        extracted code fails syntax validation; llm_cost is still reported
        when known so callers can track spend.

    """
    logging.info("/optimize: Optimizing Java code.")
    debug_log_sensitive_data(f"Optimizing Java code for user {user_id}:\n{source_code}")

    # Check if source code is multi-file format (```java:path fences).
    is_multi_file = is_multi_context_java(source_code)
    original_file_to_code: dict[str, str] = {}

    if is_multi_file:
        original_file_to_code = split_markdown_code(source_code, "java")
        logging.info(
            f"Multi-file context detected with {len(original_file_to_code)} files: {list(original_file_to_code.keys())}"
        )

    # Get Java-specific prompts and substitute the target Java version / code.
    system_prompt = get_system_prompt(is_async=False)
    user_prompt = get_user_prompt(is_async=False)

    system_prompt = system_prompt.format(language_version=f"Java {language_version}")
    # Single- and multi-file sources use the same substitution (the multi-file
    # markdown is passed through verbatim), so no branch is needed here.
    user_prompt = user_prompt.format(source_code=source_code)

    if dependency_code:
        user_prompt += f"\n\n**Context (read-only, do not modify):**\n{dependency_code}"

    # Observability context: tag the call with its ordinal when provided.
    obs_context: dict[str, Any] | None = {"call_sequence": call_sequence} if call_sequence is not None else None

    messages: list[ChatCompletionMessageParam] = [
        ChatCompletionSystemMessageParam(role="system", content=system_prompt),
        ChatCompletionUserMessageParam(role="user", content=user_prompt),
    ]

    try:
        output = await llm_client.call(
            llm=optimize_model,
            messages=messages,
            call_type="optimization",
            trace_id=trace_id,
            user_id=user_id,
            python_version="N/A",  # Not applicable for Java
            context=obs_context,
        )
    except LLMOutputUnparseable as e:
        # Null/empty LLM content (content filter, truncation, ...): preserve
        # the incurred cost while signaling "no candidate" to the caller.
        debug_log_sensitive_data(f"Empty LLM response for Java source:\n{source_code}")
        return None, e.cost, optimize_model.name
    except Exception:
        # Any other failure (network, provider error): no cost information.
        debug_log_sensitive_data(f"Failed to generate code for Java source:\n{source_code}")
        return None, None, optimize_model.name

    llm_cost = output.cost

    debug_log_sensitive_data_from_callable(
        lambda: f"LLM optimization response:\n{output.raw_response.model_dump_json(indent=2)}"
    )

    # Report token usage to analytics when the provider included it.
    if output.raw_response.usage is not None:
        ph(
            user_id,
            "aiservice-optimize-openai-usage",
            properties={"model": optimize_model.name, "usage": output.raw_response.usage.json(), "language": "java"},
        )

    # Extract code and explanation from response. `code` is a dict of
    # path -> code for multi-file responses, or a plain string otherwise.
    code, explanation = extract_code_and_explanation(
        output.content, JAVA_CODE_PATTERN, JAVA_CODE_WITH_PATH_PATTERN, is_multi_file
    )

    if not code:
        logging.warning("No valid Java code extracted from LLM response")
        return None, llm_cost, optimize_model.name

    # Validate the code (multi-file dicts are concatenated for the check).
    code_to_validate = code if isinstance(code, str) else "\n".join(code.values())
    is_valid, error = validate_java_syntax(code_to_validate)
    if not is_valid:
        logging.warning("Java code failed syntax validation: %s", error)
        return None, llm_cost, optimize_model.name

    # Format the response back into fenced-markdown form.
    if isinstance(code, dict):
        # Multi-file response
        formatted_code = group_code(code, language="java")
    # Single file response - try to get file name from original
    elif is_multi_file and original_file_to_code:
        file_name = next(iter(original_file_to_code.keys()))
        formatted_code = group_code({file_name: code}, language="java")
    else:
        # Default file name
        formatted_code = group_code({"Source.java": code}, language="java")

    optimization_id = str(uuid.uuid4())
    result = OptimizeResponseItemSchema(
        explanation=explanation, optimization_id=optimization_id, source_code=formatted_code
    )

    return result, llm_cost, optimize_model.name
|
|
|
|
|
|
async def optimize_java_code(
    user_id: str,
    source_code: str,
    trace_id: str,
    dependency_code: str | None = None,
    language_version: str = "17",
    n_candidates: int = 5,
) -> tuple[list[OptimizeResponseItemSchema], float, dict[str, dict[str, str]], dict[str, str]]:
    """Run parallel optimizations with multiple models based on the distribution config.

    Args:
        user_id: The user ID making the request
        source_code: The source code to optimize
        trace_id: The trace ID for logging
        dependency_code: Optional dependency code for context
        language_version: Target Java version passed to each call
        n_candidates: Number of optimization candidates to request

    Returns:
        tuple containing:
        - list of optimization results
        - total LLM cost
        - dict of raw code/explanations keyed by optimization_id
        - dict mapping optimization_id to model name

    """
    # Nothing requested: skip the distribution/TaskGroup machinery entirely.
    if n_candidates == 0:
        return [], 0.0, {}, {}

    # One task per (model, call) pair; call_sequence gives each call a unique
    # ordinal for observability.
    tasks: list[asyncio.Task[tuple[OptimizeResponseItemSchema | None, float | None, str]]] = []
    call_sequence = 1

    async with asyncio.TaskGroup() as tg:
        for model, num_calls in get_model_distribution(n_candidates, MAX_OPTIMIZER_CALLS):
            for _ in range(num_calls):
                tasks.append(
                    tg.create_task(
                        optimize_java_code_single(
                            user_id=user_id,
                            source_code=source_code,
                            trace_id=trace_id,
                            dependency_code=dependency_code,
                            optimize_model=model,
                            language_version=language_version,
                            call_sequence=call_sequence,
                        )
                    )
                )
                call_sequence += 1

    # Collect results. Failed calls yield (None, cost, model) and contribute
    # only their cost, so spend tracking survives individual failures.
    optimization_results: list[OptimizeResponseItemSchema] = []
    total_cost = 0.0
    code_and_explanations: dict[str, dict[str, str]] = {}
    optimization_models: dict[str, str] = {}

    for task in tasks:
        result, cost, model_name = task.result()
        if cost:
            total_cost += cost
        if result is not None:
            optimization_results.append(result)
            code_and_explanations[result.optimization_id] = {
                "code": result.source_code,
                "explanation": result.explanation,
            }
            optimization_models[result.optimization_id] = model_name

    return optimization_results, total_cost, code_and_explanations, optimization_models
|
|
|
|
|
|
async def optimize_java(
    request: AuthenticatedRequest, data: OptimizeSchema
) -> tuple[int, OptimizeResponseSchema | OptimizeErrorResponseSchema]:
    """Optimize Java code for performance using LLMs.

    Endpoint handler: validates the trace id, resolves the user and
    optimization event concurrently, short-circuits through the demo path
    when it matches, otherwise fans out LLM optimization calls and returns
    the collected candidates.

    Returns:
        Tuple of (HTTP status code, response schema). 400 on an invalid
        trace_id, 200 otherwise.

    Raises:
        HttpError: 401 when the authenticated user id resolves to no user.

    """
    # Validate trace_id
    if not validate_trace_id(data.trace_id):
        return 400, OptimizeErrorResponseSchema(error="Invalid trace_id")

    user_id = request.user
    # User lookup and event get-or-create are independent; run them concurrently.
    async with asyncio.TaskGroup() as tg:
        user_task = tg.create_task(get_user_by_id(user_id))
        event_task = tg.create_task(
            get_or_create_optimization_event(trace_id=data.trace_id, event_type="no-pr", user_id=user_id)
        )
    user = user_task.result()
    if user is None:
        raise HttpError(401, "User not found")
    optimization_event, _created = event_task.result()
    # Determine Java version (defaults to 17 when the client omits it)
    language_version = data.language_version or "17"

    # Local import — presumably to avoid an import cycle or keep the demo
    # hack out of module load; TODO(review) confirm.
    from core.languages.java.demo_hacks import try_demo_optimize_java  # noqa: PLC0415

    # Demo shortcut: a canned response bypasses the LLM pipeline entirely.
    demo_response = await try_demo_optimize_java(data.source_code)
    if demo_response is not None:
        for item in demo_response.optimizations:
            item.optimization_event_id = str(optimization_event.id) if optimization_event else None
        return 200, demo_response

    # Run optimization
    optimization_results, total_cost, code_and_explanations, _optimization_models = await optimize_java_code(
        user_id=user_id,
        source_code=data.source_code,
        trace_id=data.trace_id,
        dependency_code=data.dependency_code,
        language_version=language_version,
        n_candidates=data.n_candidates,
    )

    # Track analytics
    ph(
        user_id,
        "aiservice-optimize-java",
        properties={
            "trace_id": data.trace_id,
            "n_candidates_requested": data.n_candidates,
            "n_candidates_returned": len(optimization_results),
            "total_cost": total_cost,
            "language_version": language_version,
        },
    )

    # hasattr guard: not every request object carries the feature-logging flag.
    if hasattr(request, "should_log_features") and request.should_log_features:
        await safe_log_features(
            trace_id=data.trace_id,
            user_id=request.user,
            original_code=data.source_code,
            dependency_code=data.dependency_code,
            optimizations_raw={op_id: cei["code"] for op_id, cei in code_and_explanations.items()},
            optimizations_post={r.optimization_id: r.source_code for r in optimization_results},
            explanations_raw={op_id: cei["explanation"] for op_id, cei in code_and_explanations.items()},
        )

    return 200, OptimizeResponseSchema(optimizations=optimization_results)
|