codeflash-internal/django/aiservice/core/languages/java/optimizer.py
Kevin Turcios 9b3cd48048 Raise LLMOutputUnparseable on empty LLM responses instead of silently returning ""
When Azure OpenAI or Anthropic returns null/empty content (content
filter, truncation, transient failure), call_openai/call_anthropic now
raise LLMOutputUnparseable instead of returning an empty string that
silently flows through the pipeline and produces 422 "Could not
generate any optimizations." All optimizer callers catch
LLMOutputUnparseable to preserve cost tracking while returning None.
2026-04-21 05:59:07 -05:00

303 lines
11 KiB
Python

"""Java code optimizer module.
This module handles optimization requests for Java code.
"""
from __future__ import annotations
import asyncio
import logging
import re
import uuid
from typing import TYPE_CHECKING, Any
from ninja.errors import HttpError
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
from aiservice.analytics.posthog import ph
from aiservice.common_utils import validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable
from aiservice.llm import LLMOutputUnparseable, llm_client
from aiservice.llm_models import LLM, OPTIMIZE_MODEL
from authapp.auth import AuthenticatedRequest
from authapp.user import get_user_by_id
from core.languages.java.prompts.optimizer import get_system_prompt, get_user_prompt
from core.log_features.log_event import get_or_create_optimization_event
from core.log_features.log_features import safe_log_features
from core.shared.context_helpers import extract_code_and_explanation, group_code, split_markdown_code
from core.shared.optimizer_config import MAX_OPTIMIZER_CALLS, get_model_distribution
from core.shared.optimizer_models import OptimizeSchema
from core.shared.optimizer_schemas import (
OptimizeErrorResponseSchema,
OptimizeResponseItemSchema,
OptimizeResponseSchema,
)
if TYPE_CHECKING:
from openai.types.chat import ChatCompletionMessageParam
from aiservice.validators.java_validator import validate_java_syntax
# Pattern to extract code blocks from an LLM response. Matches fences opened
# with either ```java or ```java:filename.java and captures the code body
# (group 1) up to the closing fence.
JAVA_CODE_PATTERN = re.compile(r"```(?:java)(?::[^\n]*)?\s*\n(.*?)```", re.MULTILINE | re.DOTALL)
# Pattern to extract code blocks that carry a file path after the language tag
# (multi-file context): group 1 is the path, group 2 the code body.
JAVA_CODE_WITH_PATH_PATTERN = re.compile(r"```(?:java):([^\n]+)\n(.*?)```", re.MULTILINE | re.DOTALL)
def is_multi_context_java(source_code: str) -> bool:
    """Return True when *source_code* is in multi-file markdown format.

    Multi-file input is recognized by the presence of at least one
    path-annotated fence opener (```java:<path>).
    """
    return "```java:" in source_code
async def optimize_java_code_single(
    user_id: str,
    source_code: str,
    trace_id: str,
    dependency_code: str | None = None,
    optimize_model: LLM = OPTIMIZE_MODEL,
    language_version: str = "17",
    call_sequence: int | None = None,
) -> tuple[OptimizeResponseItemSchema | None, float | None, str]:
    """Optimize Java code with a single LLM call.

    Args:
        user_id: The user ID making the request.
        source_code: The source code to optimize (can be multi-file markdown format).
        trace_id: The trace ID for logging.
        dependency_code: Optional dependency code for context.
        optimize_model: The LLM model to use.
        language_version: Target Java version (e.g., "11", "17", "21").
        call_sequence: Call sequence number for tracking.

    Returns:
        Tuple of (optimization_result, llm_cost, model_name). The result is
        None when the LLM call fails, yields no parseable code, or the code
        fails syntax validation; the cost is still returned when known so
        callers can keep cost tracking accurate.
    """
    logging.info("/optimize: Optimizing Java code.")
    debug_log_sensitive_data(f"Optimizing Java code for user {user_id}:\n{source_code}")

    # Check if source code is multi-file format
    is_multi_file = is_multi_context_java(source_code)
    original_file_to_code: dict[str, str] = {}
    if is_multi_file:
        original_file_to_code = split_markdown_code(source_code, "java")
        logging.info(
            f"Multi-file context detected with {len(original_file_to_code)} files: {list(original_file_to_code.keys())}"
        )

    # Get Java-specific prompts and bake in the target Java version.
    system_prompt = get_system_prompt(is_async=False).format(language_version=f"Java {language_version}")
    # The user prompt is formatted identically for single- and multi-file
    # input (the original branched on is_multi_file but both branches were
    # the same): the multi-file markdown passes through verbatim.
    user_prompt = get_user_prompt(is_async=False).format(source_code=source_code)
    if dependency_code:
        user_prompt += f"\n\n**Context (read-only, do not modify):**\n{dependency_code}"

    obs_context: dict[str, Any] | None = {"call_sequence": call_sequence} if call_sequence is not None else None
    messages: list[ChatCompletionMessageParam] = [
        ChatCompletionSystemMessageParam(role="system", content=system_prompt),
        ChatCompletionUserMessageParam(role="user", content=user_prompt),
    ]
    try:
        output = await llm_client.call(
            llm=optimize_model,
            messages=messages,
            call_type="optimization",
            trace_id=trace_id,
            user_id=user_id,
            python_version="N/A",  # Not applicable for Java
            context=obs_context,
        )
    except LLMOutputUnparseable as e:
        # Empty/filtered LLM response: no result, but surface the known cost.
        debug_log_sensitive_data(f"Empty LLM response for Java source:\n{source_code}")
        return None, e.cost, optimize_model.name
    except Exception:
        debug_log_sensitive_data(f"Failed to generate code for Java source:\n{source_code}")
        return None, None, optimize_model.name

    llm_cost = output.cost
    debug_log_sensitive_data_from_callable(
        lambda: f"LLM optimization response:\n{output.raw_response.model_dump_json(indent=2)}"
    )
    if output.raw_response.usage is not None:
        ph(
            user_id,
            "aiservice-optimize-openai-usage",
            properties={"model": optimize_model.name, "usage": output.raw_response.usage.json(), "language": "java"},
        )

    # Extract code and explanation from response
    code, explanation = extract_code_and_explanation(
        output.content, JAVA_CODE_PATTERN, JAVA_CODE_WITH_PATH_PATTERN, is_multi_file
    )
    if not code:
        logging.warning("No valid Java code extracted from LLM response")
        return None, llm_cost, optimize_model.name

    # Validate the code; a multi-file (dict) response is joined before validation.
    code_to_validate = code if isinstance(code, str) else "\n".join(code.values())
    is_valid, error = validate_java_syntax(code_to_validate)
    if not is_valid:
        logging.warning("Java code failed syntax validation: %s", error)
        return None, llm_cost, optimize_model.name

    # Format the response
    if isinstance(code, dict):
        # Multi-file response
        formatted_code = group_code(code, language="java")
    elif is_multi_file and original_file_to_code:
        # Single-file answer for multi-file input: reuse the first original file name.
        file_name = next(iter(original_file_to_code.keys()))
        formatted_code = group_code({file_name: code}, language="java")
    else:
        # Default file name
        formatted_code = group_code({"Source.java": code}, language="java")

    result = OptimizeResponseItemSchema(
        explanation=explanation, optimization_id=str(uuid.uuid4()), source_code=formatted_code
    )
    return result, llm_cost, optimize_model.name
async def optimize_java_code(
    user_id: str,
    source_code: str,
    trace_id: str,
    dependency_code: str | None = None,
    language_version: str = "17",
    n_candidates: int = 5,
) -> tuple[list[OptimizeResponseItemSchema], float, dict[str, dict[str, str]], dict[str, str]]:
    """Run parallel optimizations with multiple models based on the distribution config.

    Args:
        user_id: The user ID making the request.
        source_code: The source code to optimize.
        trace_id: The trace ID for logging.
        dependency_code: Optional dependency code for context.
        language_version: Target Java version.
        n_candidates: Number of optimization candidates to request.

    Returns:
        tuple containing:
            - list of optimization results
            - total LLM cost
            - dict of raw code/explanations keyed by optimization_id
            - dict mapping optimization_id to model name
    """
    # Short-circuit before spawning any tasks.
    if n_candidates == 0:
        return [], 0.0, {}, {}

    # Plain list of tasks (the original paired each task with an unused None).
    tasks: list[asyncio.Task[tuple[OptimizeResponseItemSchema | None, float | None, str]]] = []
    call_sequence = 1
    async with asyncio.TaskGroup() as tg:
        for model, num_calls in get_model_distribution(n_candidates, MAX_OPTIMIZER_CALLS):
            for _ in range(num_calls):
                tasks.append(
                    tg.create_task(
                        optimize_java_code_single(
                            user_id=user_id,
                            source_code=source_code,
                            trace_id=trace_id,
                            dependency_code=dependency_code,
                            optimize_model=model,
                            language_version=language_version,
                            call_sequence=call_sequence,
                        )
                    )
                )
                call_sequence += 1

    # Collect results; cost is accumulated even for calls that produced no result.
    optimization_results: list[OptimizeResponseItemSchema] = []
    total_cost = 0.0
    code_and_explanations: dict[str, dict[str, str]] = {}
    optimization_models: dict[str, str] = {}
    for task in tasks:
        result, cost, model_name = task.result()
        if cost:
            total_cost += cost
        if result is not None:
            optimization_results.append(result)
            code_and_explanations[result.optimization_id] = {
                "code": result.source_code,
                "explanation": result.explanation,
            }
            optimization_models[result.optimization_id] = model_name
    return optimization_results, total_cost, code_and_explanations, optimization_models
async def optimize_java(
    request: AuthenticatedRequest, data: OptimizeSchema
) -> tuple[int, OptimizeResponseSchema | OptimizeErrorResponseSchema]:
    """Optimize Java code for performance using LLMs."""
    # Reject malformed trace ids before doing any work.
    if not validate_trace_id(data.trace_id):
        return 400, OptimizeErrorResponseSchema(error="Invalid trace_id")

    user_id = request.user

    # Look up the user and the optimization event concurrently.
    async with asyncio.TaskGroup() as tg:
        user_future = tg.create_task(get_user_by_id(user_id))
        event_future = tg.create_task(
            get_or_create_optimization_event(trace_id=data.trace_id, event_type="no-pr", user_id=user_id)
        )
    if user_future.result() is None:
        raise HttpError(401, "User not found")
    optimization_event, _created = event_future.result()

    # Fall back to Java 17 when the client did not specify a version.
    java_version = data.language_version or "17"

    from core.languages.java.demo_hacks import try_demo_optimize_java  # noqa: PLC0415

    demo_response = await try_demo_optimize_java(data.source_code)
    if demo_response is not None:
        event_id = str(optimization_event.id) if optimization_event else None
        for optimization in demo_response.optimizations:
            optimization.optimization_event_id = event_id
        return 200, demo_response

    # Run the real optimization fan-out.
    results, cost_total, raw_by_id, _models_by_id = await optimize_java_code(
        user_id=user_id,
        source_code=data.source_code,
        trace_id=data.trace_id,
        dependency_code=data.dependency_code,
        language_version=java_version,
        n_candidates=data.n_candidates,
    )

    # Track analytics
    ph(
        user_id,
        "aiservice-optimize-java",
        properties={
            "trace_id": data.trace_id,
            "n_candidates_requested": data.n_candidates,
            "n_candidates_returned": len(results),
            "total_cost": cost_total,
            "language_version": java_version,
        },
    )

    if getattr(request, "should_log_features", False):
        await safe_log_features(
            trace_id=data.trace_id,
            user_id=request.user,
            original_code=data.source_code,
            dependency_code=data.dependency_code,
            optimizations_raw={oid: entry["code"] for oid, entry in raw_by_id.items()},
            optimizations_post={item.optimization_id: item.source_code for item in results},
            explanations_raw={oid: entry["explanation"] for oid, entry in raw_by_id.items()},
        )
    return 200, OptimizeResponseSchema(optimizations=results)