codeflash-internal/django/aiservice/core/languages/java/optimizer.py

586 lines
23 KiB
Python
Raw Normal View History

codeflash-omni-java (#2335) # Pull Request Checklist ## Description - [ ] **Breaking Changes**: Document any breaking changes (if applicable) - [ ] **Description of PR**: Clear and concise description of what this PR accomplishes - [ ] **Related Issues**: Link to any related issues or tickets ## Testing - [ ] **Test cases Attached**: All relevant test cases have been added/updated - [ ] **Manual Testing**: Manual testing completed for the changes ## Monitoring & Debugging - [ ] **Logging in place**: Appropriate logging has been added for debugging user issues - [ ] **Sentry will be able to catch errors**: Error handling ensures Sentry can capture and report errors - [ ] **Avoid Dev based/Prisma logging**: No development-only or Prisma-specific logging in production code ## Configuration - [ ] **Env variables newly added**: Any new environment variables are documented in .env.example file or mentioned in description --- ## Additional Notes <!-- Add any additional context, screenshots, or notes for reviewers here --> --------- Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com> Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com> Co-authored-by: HeshamHM28 <HeshamMohamedFathy@outlook.com> Co-authored-by: Ubuntu <ubuntu@ip-172-31-39-200.ec2.internal> Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com> Co-authored-by: Kevin Turcios <turcioskevinr@gmail.com> Co-authored-by: Kevin Turcios <106575910+KRRT7@users.noreply.github.com>
2026-02-13 17:56:55 +00:00
"""Java code optimizer module.
This module handles optimization requests for Java code.
"""
from __future__ import annotations
import asyncio
import logging
import re
import uuid
from typing import TYPE_CHECKING, Any
import sentry_sdk
from ninja.errors import HttpError
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
from aiservice.analytics.posthog import ph
from aiservice.common_utils import is_host_equals_demo, should_hack_for_demo_java, validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.llm import LLM, OPTIMIZE_MODEL, calculate_llm_cost, call_llm
from authapp.auth import AuthenticatedRequest
from authapp.user import get_user_by_id
from core.languages.java.prompts.optimizer import get_system_prompt, get_user_prompt
from core.shared.context_helpers import group_code, split_markdown_code
from core.shared.optimizer_config import MAX_OPTIMIZER_CALLS, get_model_distribution
from core.shared.optimizer_models import OptimizeSchema
from core.shared.optimizer_schemas import (
OptimizeErrorResponseSchema,
OptimizeResponseItemSchema,
OptimizeResponseSchema,
)
from log_features.log_event import get_or_create_optimization_event
from log_features.log_features import log_features
if TYPE_CHECKING:
from openai.types.chat import ChatCompletionMessageParam
from aiservice.validators.java_validator import validate_java_syntax
# Pattern to extract code blocks from LLM response (handles both ```java and ```java:filename.java).
# group(1) is the code body.
JAVA_CODE_PATTERN = re.compile(r"```(?:java)(?::[^\n]*)?\s*\n(.*?)```", re.MULTILINE | re.DOTALL)
# Pattern to extract code blocks with file paths (multi-file context).
# group(1) is the file path (possibly with trailing whitespace), group(2) is the code body.
JAVA_CODE_WITH_PATH_PATTERN = re.compile(r"```(?:java):([^\n]+)\n(.*?)```", re.MULTILINE | re.DOTALL)


def is_multi_context_java(source_code: str) -> bool:
    """Return True if *source_code* uses the path-annotated markdown format.

    The format is one or more fenced blocks opened with ```java:path/to/File.java.
    A single such block is enough: the check answers "are file paths available in
    the markdown?", not "are there several files?".
    """
    return "```java:" in source_code


def extract_code_and_explanation(content: str, is_multi_file: bool = False) -> tuple[str | dict[str, str], str]:
    """Extract code and explanation from an LLM response.

    Args:
        content: The raw LLM response content.
        is_multi_file: Whether to expect ```java:path blocks. When none are found,
            single-file extraction is used instead.

    Returns:
        Tuple of (code, explanation). ``code`` is a dict mapping file path to code for
        a multi-file response, or a string for a single-file response. When no Java
        code block is found, ``code`` is "" and the explanation is the whole content.
        Otherwise the explanation is the stripped text preceding the first Java block.
    """
    if is_multi_file:
        matches = list(JAVA_CODE_WITH_PATH_PATTERN.finditer(content))
        if matches:
            # Anchor the explanation on the first *Java* block (as the single-file branch
            # does) so an unrelated fence earlier in the response cannot truncate it.
            explanation = content[: matches[0].start()].strip()
            # A later block for the same path overwrites an earlier one.
            file_to_code = {match.group(1).strip(): match.group(2).strip() for match in matches}
            return file_to_code, explanation
        # No path-annotated blocks: fall through to single-file extraction.

    match = JAVA_CODE_PATTERN.search(content)
    if match is None:
        # No code block found, return empty code.
        return "", content
    # Explanation is everything before the code block.
    return match.group(1).strip(), content[: match.start()].strip()
def _extract_demo_context(source_code: str) -> tuple[str, str, str, str]:
    """Extract package, class name, exception type, and extra imports from the demo source code.

    Args:
        source_code: Java source, optionally wrapped in a ```java:path markdown block.

    Returns:
        Tuple of (package_declaration, class_name, exception_type, extra_imports):
        - package_declaration: "package x.y;" plus a newline, or "" when none is declared.
        - class_name: the first declared class name outside comments, default "FileUtils".
        - exception_type: the type of the first ``throw new X(`` outside comments,
          default "RuntimeException".
        - extra_imports: the source's own "import ...X;" line (plus newline) for that
          exception type, or "" when the source does not import it explicitly.
    """
    # Extract the raw code from markdown block if present.
    code_match = re.search(r"```java:[^\n]*\n(.*?)```", source_code, re.DOTALL)
    raw_code = code_match.group(1) if code_match else source_code
    # Strip comments before searching for declarations: otherwise Javadoc such as
    # "/** Utility class for reading files. */" would yield the class name "for".
    code_without_comments = re.sub(r"/\*.*?\*/|//[^\n]*", "", raw_code, flags=re.DOTALL)

    # Extract package (anchored at line start, so commented-out lines cannot match).
    pkg_match = re.search(r"^\s*package\s+([\w.]+)\s*;", raw_code, re.MULTILINE)
    package_decl = f"package {pkg_match.group(1)};\n" if pkg_match else ""

    # Extract class name.
    class_match = re.search(r"\bclass\s+(\w+)", code_without_comments)
    class_name = class_match.group(1) if class_match else "FileUtils"

    # Extract exception type from the throw statement (e.g., "throw new AerospikeException(...)").
    throw_match = re.search(r"throw\s+new\s+(\w+)\s*\(", code_without_comments)
    exception_type = throw_match.group(1) if throw_match else "RuntimeException"

    # Re-emit the source's own import for the exception type, if it has one. Types the
    # source does not import explicitly (same package, java.lang) get no import line.
    extra_imports = ""
    if exception_type != "RuntimeException":
        import_match = re.search(rf"^\s*import\s+([\w.]*\.{re.escape(exception_type)})\s*;", raw_code, re.MULTILINE)
        if import_match:
            extra_imports = f"import {import_match.group(1)};\n"
    return package_decl, class_name, exception_type, extra_imports
def _build_demo_optimizations(
    package_decl: str, class_name: str, exception_type: str, extra_imports: str
) -> list[dict[str, str]]:
    """Build the 2 canned readFile() optimization candidates for the demo.

    Each candidate is a complete Java file that keeps the user's package declaration,
    class name, exception type and exception import, as extracted from their source.

    The returned list holds, in this order:
    1. ``FileInputStream.readAllBytes()`` (Java 9+): a plausible, functionally correct
       alternative that is meant to benchmark slightly slower.
    2. ``java.nio.file.Files.readAllBytes()``: the intended winner, meant to benchmark
       fastest so that it wins the speedup critic.

    Args:
        package_decl: Package declaration line (with trailing newline) or "".
        class_name: Name of the generated class.
        exception_type: Exception type thrown when reading fails.
        extra_imports: Import line(s) for ``exception_type`` (with trailing newline) or "".

    Returns:
        Two dicts with "source_code", "explanation" and "optimization_id" (a fresh UUID4).
    """
    # The templates below use str.format, so literal Java braces are doubled.
    fmt = {
        "package_decl": package_decl,
        "class_name": class_name,
        "exception_type": exception_type,
        "extra_imports": extra_imports,
    }
    return [
        # First entry: FileInputStream.readAllBytes() (Java 9+) - the slower alternative.
        {
            "source_code": (
                "{package_decl}"
                "\n"
                "import java.io.File;\n"
                "import java.io.FileInputStream;\n"
                "{extra_imports}"
                "\n"
                "public final class {class_name} {{\n"
                " public static byte[] readFile(File file) {{\n"
                " try (FileInputStream fis = new FileInputStream(file)) {{\n"
                " return fis.readAllBytes();\n"
                " }}\n"
                " catch (Throwable e) {{\n"
                ' throw new {exception_type}("Failed to read " + file.getAbsolutePath(), e);\n'
                " }}\n"
                " }}\n"
                "}}"
            ).format(**fmt),
            "explanation": (
                "Use FileInputStream.readAllBytes() (Java 9+) to read the entire file in one call. "
                "This eliminates the manual read loop but still uses FileInputStream internally."
            ),
            "optimization_id": str(uuid.uuid4()),
        },
        # Second entry: Files.readAllBytes() - the intended winner.
        {
            "source_code": (
                "{package_decl}"
                "\n"
                "import java.io.File;\n"
                "import java.nio.file.Files;\n"
                "{extra_imports}"
                "\n"
                "public final class {class_name} {{\n"
                " public static byte[] readFile(File file) {{\n"
                " try {{\n"
                " return java.nio.file.Files.readAllBytes(file.toPath());\n"
                " }}\n"
                " catch (Throwable e) {{\n"
                ' throw new {exception_type}("Failed to read " + file.getAbsolutePath(), e);\n'
                " }}\n"
                " }}\n"
                "}}"
            ).format(**fmt),
            "explanation": (
                "Replace manual FileInputStream read loop with java.nio.file.Files.readAllBytes(). "
                "This NIO method is optimized at the JDK level for direct file-to-byte-array transfer, "
                "eliminating manual buffering and loop overhead."
            ),
            "optimization_id": str(uuid.uuid4()),
        },
    ]
def _build_host_equals_demo_optimizations(source_code: str) -> list[dict[str, str]]:
    """Build 5 optimization candidates for Host.equals by reordering comparisons.

    Candidate 1 (port-first early return) is the intended winner: comparing the
    primitive int port before the String name skips the String.equals() call
    whenever the ports differ.

    Args:
        source_code: Java source, optionally wrapped in a ```java:path markdown block.

    Returns:
        Five dicts with "source_code", "explanation" and "optimization_id" keys, each
        rewriting the ``return this.name.equals(other.name) && this.port == other.port;``
        statement in place. If that statement is not found, a single dict echoing the
        unmodified code with the explanation "No optimization applicable.".
    """
    code_match = re.search(r"```java:[^\n]*\n(.*?)```", source_code, re.DOTALL)
    raw_code = code_match.group(1) if code_match else source_code
    # Match: return this.name.equals(other.name) && this.port == other.port;
    # The indent group takes only spaces/tabs. A greedy whitespace group would also
    # capture the newline ending the previous line, and repeating that "indent" in
    # front of every generated line would insert a blank line between each of them.
    original_stmt = re.compile(
        r"([ \t]*)return\s+this\.name\.equals\(other\.name\)\s*&&\s*this\.port\s*==\s*other\.port\s*;"
    )
    match = original_stmt.search(raw_code)
    if not match:
        return [
            {
                "source_code": raw_code,
                "explanation": "No optimization applicable.",
                "optimization_id": str(uuid.uuid4()),
            }
        ]
    indent = match.group(1)
    inner = indent + " "
    before, after = raw_code[: match.start()], raw_code[match.end() :]

    def replace_with(replacement: str) -> str:
        # Splice the replacement in literally at the matched statement. re.sub would
        # interpret backslashes/group references in it as a template, and would also
        # rewrite any other occurrence whose indentation may differ from `indent`.
        return before + replacement + after

    return [
        # Candidate 1 (WINNER): Port-first early return
        {
            "source_code": replace_with(
                f"{indent}// Compare primitive port first to avoid unnecessary string equals calls.\n"
                f"{indent}if (this.port != other.port) {{\n"
                f"{inner}return false;\n"
                f"{indent}}}\n"
                f"{indent}return this.name.equals(other.name);"
            ),
            "explanation": (
                "Compare primitive port first to avoid unnecessary string equals calls. "
                "Integer comparison is a single CPU instruction, while String.equals() "
                "involves method dispatch and potential character-by-character comparison."
            ),
            "optimization_id": str(uuid.uuid4()),
        },
        # Candidate 2: Reordered conjunction (port first in &&)
        {
            "source_code": replace_with(f"{indent}return this.port == other.port && this.name.equals(other.name);"),
            "explanation": (
                "Reorder the conjunction to evaluate the cheaper primitive int comparison first. "
                "Short-circuit evaluation skips String.equals() when ports differ."
            ),
            "optimization_id": str(uuid.uuid4()),
        },
        # Candidate 3: Port-first with Objects.equals for null safety
        {
            "source_code": replace_with(
                f"{indent}if (this.port != other.port) {{\n"
                f"{inner}return false;\n"
                f"{indent}}}\n"
                f"{indent}return java.util.Objects.equals(this.name, other.name);"
            ),
            "explanation": (
                "Check port first (cheap primitive comparison), then use Objects.equals() "
                "for null-safe name comparison. Adds safety at slight method-call overhead."
            ),
            "optimization_id": str(uuid.uuid4()),
        },
        # Candidate 4: Ternary with port-first guard
        {
            "source_code": replace_with(
                f"{indent}return this.port == other.port ? this.name.equals(other.name) : false;"
            ),
            "explanation": (
                "Use a ternary to short-circuit on port mismatch. "
                "Evaluates the cheap int comparison first, only calling String.equals() when ports match."
            ),
            "optimization_id": str(uuid.uuid4()),
        },
        # Candidate 5: Explicit null guard + port first
        {
            "source_code": replace_with(
                f"{indent}if (this.port != other.port) {{\n"
                f"{inner}return false;\n"
                f"{indent}}}\n"
                f"{indent}if (this.name == null) {{\n"
                f"{inner}return other.name == null;\n"
                f"{indent}}}\n"
                f"{indent}return this.name.equals(other.name);"
            ),
            "explanation": (
                "Guard on port first, then add explicit null handling for the name field "
                "before delegating to String.equals(). Avoids potential NullPointerException."
            ),
            "optimization_id": str(uuid.uuid4()),
        },
    ]
async def hack_for_demo_java(source_code: str) -> OptimizeResponseSchema:
    """Return canned Java optimization candidates for demo requests.

    The Host.equals demo gets comparison-reordering candidates; any other demo source
    gets readFile() candidates rebuilt around its own package, class name and exception
    type. Every candidate is wrapped in a markdown block under the path taken from the
    ```java:path fence ("Source.java" when there is none). The response is returned
    after a fixed 5-second delay.
    """
    if is_host_equals_demo(source_code):
        candidates = _build_host_equals_demo_optimizations(source_code)
    else:
        # Reuse the user's package / class / exception so the canned file fits their code.
        candidates = _build_demo_optimizations(*_extract_demo_context(source_code))

    # File path from the markdown header (```java:path/to/File.java).
    path_match = re.search(r"```java:([^\n]+)", source_code)
    file_name = "Source.java" if path_match is None else path_match.group(1).strip()

    items: list[OptimizeResponseItemSchema] = []
    for candidate in candidates:
        wrapped_code = group_code({file_name: candidate["source_code"]}, language="java")
        items.append(
            OptimizeResponseItemSchema(
                explanation=candidate["explanation"],
                optimization_id=candidate["optimization_id"],
                source_code=wrapped_code,
            )
        )

    # Fixed pause before answering (presumably to mimic real LLM latency).
    await asyncio.sleep(5)
    return OptimizeResponseSchema(optimizations=items)
async def optimize_java_code_single(
    user_id: str,
    source_code: str,
    trace_id: str,
    dependency_code: str | None = None,
    optimize_model: LLM = OPTIMIZE_MODEL,
    language_version: str = "17",
    call_sequence: int | None = None,
) -> tuple[OptimizeResponseItemSchema | None, float | None, str]:
    """Generate one optimization candidate for Java code with a single LLM call.

    Args:
        user_id: The user ID making the request.
        source_code: The source code to optimize (plain code or ```java:path markdown).
        trace_id: The trace ID for logging.
        dependency_code: Optional read-only context appended to the user prompt.
        optimize_model: The LLM model to use.
        language_version: Target Java version (e.g., "11", "17", "21").
        call_sequence: Call sequence number, passed to the LLM call as observability context.

    Returns:
        Tuple of (optimization_result, llm_cost, model_name). ``optimization_result`` is
        None when the LLM call fails, the response contains no usable Java code, or the
        code fails syntax validation. ``llm_cost`` is None only when the LLM call failed.
    """
    logging.info("/optimize: Optimizing Java code.")
    debug_log_sensitive_data(f"Optimizing Java code for user {user_id}:\n{source_code}")

    # Path-annotated input: keep the original file names so that a single-file reply
    # can be attributed to the first of them.
    is_multi_file = is_multi_context_java(source_code)
    original_file_to_code: dict[str, str] = {}
    if is_multi_file:
        original_file_to_code = split_markdown_code(source_code, "java")
        logging.info(
            "Multi-file context detected with %d files: %s",
            len(original_file_to_code),
            list(original_file_to_code.keys()),
        )

    system_prompt = get_system_prompt(is_async=False).format(language_version=f"Java {language_version}")
    # Single- and multi-file input are formatted identically: for multi-file input the
    # ```java:path fences inside source_code already carry the file paths.
    user_prompt = get_user_prompt(is_async=False).format(source_code=source_code)
    if dependency_code:
        user_prompt += f"\n\n**Context (read-only, do not modify):**\n{dependency_code}"

    obs_context: dict[str, Any] | None = {"call_sequence": call_sequence} if call_sequence is not None else None
    messages: list[ChatCompletionMessageParam] = [
        ChatCompletionSystemMessageParam(role="system", content=system_prompt),
        ChatCompletionUserMessageParam(role="user", content=user_prompt),
    ]

    try:
        output = await call_llm(
            llm=optimize_model,
            messages=messages,
            call_type="optimization",
            trace_id=trace_id,
            user_id=user_id,
            python_version="N/A",  # Not applicable for Java
            context=obs_context,
        )
    except Exception as e:
        # LLM provider boundary: report the failure and return "no candidate" instead of raising.
        logging.exception("LLM Code Generation error in Java optimizer")
        sentry_sdk.capture_exception(e)
        debug_log_sensitive_data(f"Failed to generate code for Java source:\n{source_code}")
        return None, None, optimize_model.name

    llm_cost = calculate_llm_cost(output.raw_response, optimize_model)
    debug_log_sensitive_data(f"LLM optimization response:\n{output.raw_response.model_dump_json(indent=2)}")
    if output.raw_response.usage is not None:
        # model_dump_json() matches the raw_response dump above; the pydantic-v1 style
        # .json() is deprecated under pydantic v2.
        usage_json = output.raw_response.usage.model_dump_json()
        await asyncio.to_thread(
            ph,
            user_id,
            "aiservice-optimize-openai-usage",
            properties={"model": optimize_model.name, "usage": usage_json, "language": "java"},
        )

    # Extract code and explanation from response
    code, explanation = extract_code_and_explanation(output.content, is_multi_file)
    # A multi-file reply whose blocks are all empty carries no code either.
    if not code or (isinstance(code, dict) and not any(code.values())):
        logging.warning("No valid Java code extracted from LLM response")
        return None, llm_cost, optimize_model.name

    # Validate the code (all files of a multi-file reply are checked together).
    code_to_validate = code if isinstance(code, str) else "\n".join(code.values())
    is_valid, error = validate_java_syntax(code_to_validate)
    if not is_valid:
        logging.warning("Java code failed syntax validation: %s", error)
        return None, llm_cost, optimize_model.name

    if isinstance(code, dict):
        # Multi-file response
        formatted_code = group_code(code, language="java")
    elif is_multi_file and original_file_to_code:
        # Single-file reply to multi-file input: attribute it to the first original file.
        file_name = next(iter(original_file_to_code))
        formatted_code = group_code({file_name: code}, language="java")
    else:
        # Default file name
        formatted_code = group_code({"Source.java": code}, language="java")

    result = OptimizeResponseItemSchema(
        explanation=explanation, optimization_id=str(uuid.uuid4()), source_code=formatted_code
    )
    return result, llm_cost, optimize_model.name
async def optimize_java_code(
    user_id: str,
    source_code: str,
    trace_id: str,
    dependency_code: str | None = None,
    language_version: str = "17",
    n_candidates: int = 5,
) -> tuple[list[OptimizeResponseItemSchema], float, dict[str, dict[str, str]], dict[str, str]]:
    """Run parallel optimizations with multiple models based on the distribution config.

    Args:
        user_id: The user ID making the request.
        source_code: The source code to optimize (plain code or ```java:path markdown).
        trace_id: The trace ID for logging.
        dependency_code: Optional read-only dependency context.
        language_version: Target Java version (e.g., "11", "17", "21").
        n_candidates: Requested number of candidates; passed with MAX_OPTIMIZER_CALLS to
            get_model_distribution to decide how many calls each model gets. Zero or
            negative values return an empty result without any LLM call.

    Returns:
        tuple containing:
        - list of optimization results (in call order; failed calls are skipped)
        - total LLM cost
        - dict of raw code/explanations keyed by optimization_id
        - dict mapping optimization_id to model name
    """
    # Guard before fanning out: a non-positive request yields no LLM calls.
    if n_candidates <= 0:
        return [], 0.0, {}, {}

    tasks: list[asyncio.Task[tuple[OptimizeResponseItemSchema | None, float | None, str]]] = []
    call_sequence = 1
    # Exceptions from call_llm are handled inside optimize_java_code_single; any other
    # exception propagates out of the TaskGroup and cancels the remaining calls.
    async with asyncio.TaskGroup() as tg:
        for model, num_calls in get_model_distribution(n_candidates, MAX_OPTIMIZER_CALLS):
            for _ in range(num_calls):
                tasks.append(
                    tg.create_task(
                        optimize_java_code_single(
                            user_id=user_id,
                            source_code=source_code,
                            trace_id=trace_id,
                            dependency_code=dependency_code,
                            optimize_model=model,
                            language_version=language_version,
                            call_sequence=call_sequence,
                        )
                    )
                )
                call_sequence += 1

    # Collect results
    optimization_results: list[OptimizeResponseItemSchema] = []
    total_cost = 0.0
    code_and_explanations: dict[str, dict[str, str]] = {}
    optimization_models: dict[str, str] = {}
    for task in tasks:
        result, cost, model_name = task.result()
        if cost:  # None (failed LLM call) or 0.0 contributes nothing
            total_cost += cost
        if result is None:
            continue
        optimization_results.append(result)
        code_and_explanations[result.optimization_id] = {
            "code": result.source_code,
            "explanation": result.explanation,
        }
        optimization_models[result.optimization_id] = model_name
    return optimization_results, total_cost, code_and_explanations, optimization_models
async def optimize_java(
    request: AuthenticatedRequest, data: OptimizeSchema
) -> tuple[int, OptimizeResponseSchema | OptimizeErrorResponseSchema]:
    """Optimize Java code for performance using LLMs.

    Returns:
        A (status_code, body) pair: 400 with an error body when trace_id is invalid,
        otherwise 200 with the candidates (canned ones when should_hack_for_demo_java
        matches the source). On both 200 paths every candidate carries the optimization
        event id, or None when no event was obtained.

    Raises:
        HttpError: 401 when the authenticated user cannot be found.
    """
    # Validate trace_id
    if not validate_trace_id(data.trace_id):
        return 400, OptimizeErrorResponseSchema(error="Invalid trace_id")

    user_id = request.user
    user = await get_user_by_id(user_id)
    if user is None:
        raise HttpError(401, "User not found")

    # Resolve the target Java version once; it is both logged and passed to the optimizer.
    language_version = data.language_version or "17"

    # Log the event
    optimization_event, _created = await get_or_create_optimization_event(
        trace_id=data.trace_id, event_type="no-pr", user_id=user_id
    )
    optimization_event_id = str(optimization_event.id) if optimization_event else None
    if optimization_event is not None:
        await asyncio.to_thread(
            log_features,
            data.source_code[:1000],
            optimization_event,
            "optimize_request",
            {
                "source_code_length": len(data.source_code),
                "dependency_code_length": len(data.dependency_code) if data.dependency_code else 0,
                "n_candidates": data.n_candidates,
                "language": "java",
                "language_version": language_version,
            },
        )

    # Demo mode: return canned candidates instead of calling the LLMs.
    if should_hack_for_demo_java(data.source_code):
        response = await hack_for_demo_java(data.source_code)
        for item in response.optimizations:
            item.optimization_event_id = optimization_event_id
        return 200, response

    # Run optimization
    optimization_results, total_cost, code_and_explanations, optimization_models = await optimize_java_code(
        user_id=user_id,
        source_code=data.source_code,
        trace_id=data.trace_id,
        dependency_code=data.dependency_code,
        language_version=language_version,
        n_candidates=data.n_candidates,
    )

    # Track analytics
    await asyncio.to_thread(
        ph,
        user_id,
        "aiservice-optimize-java",
        properties={
            "trace_id": data.trace_id,
            "n_candidates_requested": data.n_candidates,
            "n_candidates_returned": len(optimization_results),
            "total_cost": total_cost,
            "language_version": language_version,
        },
    )

    # Log the response
    if optimization_event is not None:
        await asyncio.to_thread(
            log_features,
            str(code_and_explanations)[:1000],
            optimization_event,
            "optimize_response",
            {
                "n_candidates": len(optimization_results),
                "total_cost": total_cost,
                "models": list(optimization_models.values()),
            },
        )

    # Link the real candidates to the optimization event, exactly as the demo path does.
    for item in optimization_results:
        item.optimization_event_id = optimization_event_id
    return 200, OptimizeResponseSchema(optimizations=optimization_results)