Mirror of https://github.com/codeflash-ai/codeflash-internal.git, synced 2026-05-04 18:25:18 +00:00
refactor: aiservice deep dive — LLM client, dedup, async, cleanup (#2482)
## Summary

Comprehensive refactoring of the aiservice Django backend focusing on code quality, deduplication, and correctness:

- **LLM client extraction**: Extract `LLMClient` class with lazy client init, centralized error handling, and event loop detection
- **Centralize retry logic**: `@stamina.retry` on `call_anthropic`/`call_openai` for transient errors (rate limits, timeouts, 500s), removing scattered retry decorators from testgen files
- **Deduplicate helpers**: Consolidate `extract_code_and_explanation` into shared `context_helpers.py`, unify `normalize_*_code` into `normalize_c_style_code`
- **Eliminate double DB queries**: Auth middleware does `afirst()` then `aupdate()` by PK; middleware caches org/subscription
- **Parallelize Java optimizer**: Use `asyncio.TaskGroup` for independent LLM calls
- **Lazy logging**: Convert all f-string logging to lazy `%s` formatting across 11 files
- **Cleanup**: Remove unused `PipelineError`/`ValidationError`, fix `seach_and_replace.py` typo, replace `print()` with `logging.debug()` in middleware
- **Sentry**: Reduce sampling 1.0 → 0.1/0.01, fix auth `settings.DEBUG` check, sanitize ranker errors

## Test plan

- [x] All existing pytest tests pass (`uv run pytest`)
- [x] Ruff lint/format clean
- [x] No behavioral changes: pure refactoring

---------

Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
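A minimal sketch of the centralized retry pattern, assuming only the `stamina` and `anthropic` packages; the `Client` shell is a stand-in for the `LLMClient` in the diff below, and only a subset of the transient-error tuple is shown:

```python
import stamina
from anthropic import APIConnectionError, APITimeoutError, RateLimitError

# Subset of the transient-error tuple from aiservice/llm.py.
_TRANSIENT_LLM_ERRORS = (APIConnectionError, APITimeoutError, RateLimitError)


class Client:
    """Stand-in for LLMClient: only the retry placement is shown."""

    def __init__(self, anthropic_client) -> None:
        self.anthropic_client = anthropic_client

    # Transient failures (connection drops, timeouts, rate limits) are retried
    # once with backoff; any other exception propagates to the caller unchanged.
    @stamina.retry(on=_TRANSIENT_LLM_ERRORS, attempts=2)
    async def call_anthropic(self, model: str, messages: list, max_tokens: int):
        return await self.anthropic_client.messages.create(
            model=model, messages=messages, max_tokens=max_tokens
        )
```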
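The double-query elimination in the auth layer boils down to one indexed fetch followed by an update keyed on the primary key, instead of updating by API key hash and then re-reading the row. A sketch using the models the diff imports from `authapp.models`; the `authenticate_key` wrapper is illustrative only:

```python
from authapp.models import CFAPIKeys  # project model, as imported in the diff
from django.db.models.functions import Now
from ninja.errors import HttpError


async def authenticate_key(hashed_token: str) -> CFAPIKeys:
    # One fetch on the indexed key column.
    api_key_instance = await CFAPIKeys.objects.filter(key=hashed_token).afirst()
    if api_key_instance is None:
        raise HttpError(403, "Invalid API key")
    # Targeted update by primary key; no second lookup of the same row.
    await CFAPIKeys.objects.filter(id=api_key_instance.id).aupdate(last_used=Now())
    return api_key_instance
```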
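The Java optimizer parallelization uses `asyncio.TaskGroup` (Python 3.11+) so that independent awaits overlap their I/O. A self-contained sketch where the two coroutines stand in for the service's `get_user_by_id` and `get_or_create_optimization_event`:

```python
import asyncio


async def fetch_user(user_id: str) -> dict:
    await asyncio.sleep(0.1)  # stand-in for get_user_by_id
    return {"id": user_id}


async def log_event(trace_id: str, user_id: str) -> tuple[dict, bool]:
    await asyncio.sleep(0.1)  # stand-in for get_or_create_optimization_event
    return {"trace_id": trace_id, "user_id": user_id}, True


async def handle(user_id: str, trace_id: str) -> None:
    # Both calls are independent, so the TaskGroup runs them concurrently;
    # if either raises, the group cancels the other task and re-raises.
    async with asyncio.TaskGroup() as tg:
        user_task = tg.create_task(fetch_user(user_id))
        event_task = tg.create_task(log_event(trace_id, user_id))
    user = user_task.result()
    event, _created = event_task.result()
    print(user, event, _created)


asyncio.run(handle("u_123", "trace-abc"))
```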
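Lazy `%s` logging defers interpolation to the logging framework, so records filtered out by level never pay for string formatting and aggregators see a stable message template. The before/after shape of the conversion:

```python
import logging

logger = logging.getLogger(__name__)
user_id = "u_123"

# Before: the f-string is built even when WARNING is filtered out.
logger.warning(f"No subscription found for user {user_id}. Creating backup subscription.")

# After: arguments are interpolated only if the record is actually emitted.
logger.warning("No subscription found for user %s. Creating backup subscription.", user_id)
```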
parent c5e8b56c6f
commit 28c9acc877
38 changed files with 1127 additions and 1296 deletions
.gitignore (vendored): 3 changes

@ -260,3 +260,6 @@ fabric.properties
/cli/experiments/js-serialization-experiment/node_modules/*
/cli/packages/codeflash/.npmrc

# Tessl auto-generated skills
**/skills/tessl__*

@ -1,4 +1,4 @@
|
|||
"""Unified LLM module for all model definitions, clients, and API calls."""
|
||||
"""LLM client setup and API call functions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -7,278 +7,202 @@ import logging
|
|||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from anthropic import AsyncAnthropicBedrock
|
||||
from openai import AsyncAzureOpenAI
|
||||
from pydantic.dataclasses import dataclass as pydantic_dataclass
|
||||
import stamina
|
||||
import sentry_sdk
|
||||
from anthropic import (
|
||||
APIConnectionError as AnthropicConnectionError,
|
||||
APITimeoutError as AnthropicTimeoutError,
|
||||
AsyncAnthropicBedrock,
|
||||
InternalServerError as AnthropicServerError,
|
||||
RateLimitError as AnthropicRateLimitError,
|
||||
)
|
||||
from openai import (
|
||||
APIConnectionError as OpenAIConnectionError,
|
||||
APITimeoutError as OpenAITimeoutError,
|
||||
AsyncAzureOpenAI,
|
||||
InternalServerError as OpenAIServerError,
|
||||
RateLimitError as OpenAIRateLimitError,
|
||||
)
|
||||
|
||||
from aiservice.llm_models import has_anthropic, has_openai
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from anthropic.types import Message as AnthropicMessage
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
|
||||
|
||||
from aiservice.llm_models import LLM
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Model Definitions
|
||||
# =============================================================================
|
||||
|
||||
# Pricing is in USD per 1M tokens. See:
|
||||
# https://docs.anthropic.com/en/docs/about-claude/pricing
|
||||
# https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
|
||||
|
||||
|
||||
@pydantic_dataclass
|
||||
class LLM:
|
||||
"""Base LLM configuration with pricing info."""
|
||||
|
||||
name: str # On Azure OpenAI Service, this is the deployment name
|
||||
max_tokens: int
|
||||
model_type: Literal["openai", "anthropic", "google"]
|
||||
input_cost: float | None = None # USD per 1M tokens
|
||||
cached_input_cost: float | None = None # USD per 1M tokens (cached input)
|
||||
output_cost: float | None = None # USD per 1M tokens
|
||||
|
||||
|
||||
@pydantic_dataclass
|
||||
class OpenAI_GPT_4_1(LLM):
|
||||
"""OpenAI GPT-4.1 model."""
|
||||
|
||||
name: str = "gpt-4.1"
|
||||
model_type: Literal["openai", "anthropic", "google"] = "openai"
|
||||
max_tokens: int = 100000
|
||||
input_cost: float = 2.00
|
||||
cached_input_cost: float = 0.50
|
||||
output_cost: float = 8.00
|
||||
|
||||
|
||||
@pydantic_dataclass
|
||||
class OpenAI_GPT_5_Mini(LLM):
|
||||
"""OpenAI GPT-5-mini model."""
|
||||
|
||||
name: str = "gpt-5-mini"
|
||||
model_type: Literal["openai", "anthropic", "google"] = "openai"
|
||||
max_tokens: int = 200000
|
||||
input_cost: float = 0.25
|
||||
cached_input_cost: float = 0.03
|
||||
output_cost: float = 2.00
|
||||
|
||||
|
||||
@pydantic_dataclass
|
||||
class Anthropic_Claude_Sonnet_4_5(LLM):
|
||||
"""Anthropic Claude 4.5 Sonnet via AWS Bedrock."""
|
||||
|
||||
name: str = "us.anthropic.claude-sonnet-4-5-20250929-v1:0"
|
||||
model_type: Literal["openai", "anthropic", "google"] = "anthropic"
|
||||
max_tokens: int = 200000
|
||||
input_cost: float = 3.00
|
||||
output_cost: float = 15.00
|
||||
|
||||
|
||||
@pydantic_dataclass
|
||||
class Anthropic_Claude_Haiku_4_5(LLM):
|
||||
"""Anthropic Claude 4.5 Haiku via AWS Bedrock."""
|
||||
|
||||
name: str = "us.anthropic.claude-haiku-4-5-20251001-v1:0"
|
||||
model_type: Literal["openai", "anthropic", "google"] = "anthropic"
|
||||
max_tokens: int = 200000
|
||||
input_cost: float = 1.00
|
||||
output_cost: float = 5.00
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# LLM Client Setup
|
||||
# =============================================================================
|
||||
|
||||
# Read environment variables once at module load
|
||||
_AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
|
||||
_AWS_ACCESS_KEY = os.environ.get("AWS_ACCESS_KEY_ID")
|
||||
_AWS_SECRET_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY")
|
||||
_AWS_REGION = os.environ.get("AWS_REGION", "us-east-1")
|
||||
|
||||
|
||||
def _create_openai_client() -> AsyncAzureOpenAI | None:
|
||||
if _AZURE_OPENAI_API_KEY:
|
||||
return AsyncAzureOpenAI() # SDK auto-reads AZURE_OPENAI_* and OPENAI_API_VERSION
|
||||
return None
|
||||
|
||||
|
||||
def _create_anthropic_client() -> AsyncAnthropicBedrock | None:
|
||||
if _AWS_ACCESS_KEY and _AWS_SECRET_KEY:
|
||||
return AsyncAnthropicBedrock(
|
||||
aws_access_key=_AWS_ACCESS_KEY, aws_secret_key=_AWS_SECRET_KEY, aws_region=_AWS_REGION
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def get_llm_client(model_type: str) -> AsyncAzureOpenAI | AsyncAnthropicBedrock | None:
|
||||
"""Get a fresh LLM client for the request.
|
||||
|
||||
Creates a new client for each request to avoid event loop issues
|
||||
with Django dev server.
|
||||
"""
|
||||
if model_type == "openai":
|
||||
return _create_openai_client()
|
||||
if model_type == "anthropic":
|
||||
return _create_anthropic_client()
|
||||
return None
|
||||
|
||||
|
||||
# Keep module-level clients for backwards compatibility
|
||||
_openai_client = _create_openai_client()
|
||||
_anthropic_client = _create_anthropic_client()
|
||||
|
||||
llm_clients: dict[str, AsyncAzureOpenAI | AsyncAnthropicBedrock | None] = {
|
||||
"openai": _openai_client,
|
||||
"anthropic": _anthropic_client,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Response Types
|
||||
# =============================================================================
|
||||
_TRANSIENT_LLM_ERRORS = (
|
||||
AnthropicConnectionError,
|
||||
AnthropicTimeoutError,
|
||||
AnthropicServerError,
|
||||
AnthropicRateLimitError,
|
||||
OpenAIConnectionError,
|
||||
OpenAITimeoutError,
|
||||
OpenAIServerError,
|
||||
OpenAIRateLimitError,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMUsage:
|
||||
"""Unified usage stats for both OpenAI and Anthropic responses."""
|
||||
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
"""Unified response wrapper for both OpenAI and Anthropic API responses."""
|
||||
|
||||
content: str
|
||||
usage: LLMUsage
|
||||
raw_response: ChatCompletion | AnthropicMessage
|
||||
cost: float = 0.0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# LLM API Call
|
||||
# =============================================================================
|
||||
class LLMClient:
|
||||
def __init__(self) -> None:
|
||||
self.openai_client: AsyncAzureOpenAI | None = None
|
||||
self.anthropic_client: AsyncAnthropicBedrock | None = None
|
||||
self.client_loop: asyncio.AbstractEventLoop | None = None
|
||||
self.background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
async def call(
|
||||
self,
|
||||
llm: LLM,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
call_type: str = "",
|
||||
trace_id: str = "",
|
||||
max_tokens: int = 16384,
|
||||
user_id: str | None = None,
|
||||
python_version: str | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
) -> LLMResponse:
|
||||
from aiservice.observability.database import record_llm_call # noqa: PLC0415
|
||||
|
||||
async def call_llm(
|
||||
llm: LLM,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
call_type: str = "",
|
||||
trace_id: str = "",
|
||||
max_tokens: int = 16384,
|
||||
user_id: str | None = None,
|
||||
python_version: str | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Call LLM with OpenAI or Anthropic client."""
|
||||
from aiservice.observability.database import record_llm_call # noqa: PLC0415
|
||||
# Recreate provider clients when the event loop changes (stale connections)
|
||||
loop = asyncio.get_running_loop()
|
||||
if loop is not self.client_loop:
|
||||
self.client_loop = loop
|
||||
self.background_tasks = set()
|
||||
self.openai_client = AsyncAzureOpenAI() if has_openai else None
|
||||
self.anthropic_client = (
|
||||
AsyncAnthropicBedrock(
|
||||
aws_access_key=os.environ.get("AWS_ACCESS_KEY_ID", ""),
|
||||
aws_secret_key=os.environ.get("AWS_SECRET_ACCESS_KEY", ""),
|
||||
aws_region=os.environ.get("AWS_REGION", "us-east-1"),
|
||||
)
|
||||
if has_anthropic
|
||||
else None
|
||||
)
|
||||
start_time = time.time()
|
||||
error: Exception | None = None
|
||||
result: LLMResponse | None = None
|
||||
|
||||
# Create a fresh client for each request to avoid event loop issues with Django dev server
|
||||
client = get_llm_client(llm.model_type)
|
||||
if client is None:
|
||||
msg = f"LLM client for model type '{llm.model_type}' is not available"
|
||||
raise ValueError(msg)
|
||||
try:
|
||||
if llm.model_type == "anthropic":
|
||||
if self.anthropic_client is None:
|
||||
raise ValueError("Anthropic client is not available")
|
||||
result = await self.call_anthropic(llm, messages, max_tokens)
|
||||
elif llm.model_type == "openai":
|
||||
if self.openai_client is None:
|
||||
raise ValueError("OpenAI client is not available")
|
||||
result = await self.call_openai(llm, messages, max_tokens)
|
||||
else:
|
||||
msg = f"Unsupported model type: {llm.model_type}"
|
||||
raise ValueError(msg)
|
||||
result.cost = calculate_llm_cost(result.raw_response, llm)
|
||||
return result
|
||||
|
||||
start_time = time.time()
|
||||
error: Exception | None = None
|
||||
result: LLMResponse | None = None
|
||||
except Exception as e:
|
||||
error = e
|
||||
logger.exception(
|
||||
"LLM call failed: model=%s, provider=%s, call_type=%s, trace_id=%s, error=%s",
|
||||
llm.name,
|
||||
llm.model_type,
|
||||
call_type,
|
||||
trace_id,
|
||||
type(e).__name__,
|
||||
)
|
||||
sentry_sdk.capture_exception(e)
|
||||
raise
|
||||
|
||||
try:
|
||||
if llm.model_type == "anthropic":
|
||||
assert isinstance(client, AsyncAnthropicBedrock)
|
||||
system_prompt_content = next((m["content"] for m in messages if m["role"] == "system"), None)
|
||||
anthropic_messages = [
|
||||
{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"
|
||||
]
|
||||
finally:
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
task = asyncio.create_task(
|
||||
record_llm_call(
|
||||
trace_id=trace_id,
|
||||
call_type=call_type,
|
||||
model_name=llm.name,
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
python_version=python_version,
|
||||
context=context,
|
||||
result=result,
|
||||
error=error,
|
||||
llm_cost=result.cost if result else None,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
)
|
||||
self.background_tasks.add(task)
|
||||
|
||||
kwargs: dict[str, Any] = {"model": llm.name, "messages": anthropic_messages, "max_tokens": max_tokens}
|
||||
if system_prompt_content:
|
||||
kwargs["system"] = system_prompt_content
|
||||
def on_done(t: asyncio.Task[str]) -> None:
|
||||
self.background_tasks.discard(t)
|
||||
if exc := t.exception():
|
||||
logger.warning("Tracing: Failed to record LLM call: %s", exc)
|
||||
|
||||
response = await client.messages.create(**kwargs)
|
||||
content = "".join(block.text for block in response.content if hasattr(block, "text"))
|
||||
task.add_done_callback(on_done)
|
||||
|
||||
result = LLMResponse(
|
||||
content=content,
|
||||
usage=LLMUsage(input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens),
|
||||
raw_response=response,
|
||||
@stamina.retry(on=_TRANSIENT_LLM_ERRORS, attempts=2)
|
||||
async def call_anthropic(
|
||||
self, llm: LLM, messages: list[ChatCompletionMessageParam], max_tokens: int
|
||||
) -> LLMResponse:
|
||||
system_prompt = next((m["content"] for m in messages if m["role"] == "system"), None)
|
||||
non_system = [{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"]
|
||||
|
||||
kwargs: dict[str, Any] = {"model": llm.name, "messages": non_system, "max_tokens": max_tokens}
|
||||
if system_prompt:
|
||||
kwargs["system"] = system_prompt
|
||||
|
||||
response = await self.anthropic_client.messages.create(**kwargs) # type: ignore[union-attr]
|
||||
content = "".join(block.text for block in response.content if hasattr(block, "text"))
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
usage=LLMUsage(input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens),
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
@stamina.retry(on=_TRANSIENT_LLM_ERRORS, attempts=2)
|
||||
async def call_openai(self, llm: LLM, messages: list[ChatCompletionMessageParam], max_tokens: int) -> LLMResponse:
|
||||
# gpt-5-mini only accepts max_completion_tokens, older models use max_tokens
|
||||
if llm.name == "gpt-5-mini":
|
||||
response = await self.openai_client.chat.completions.create( # type: ignore[union-attr]
|
||||
model=llm.name, messages=messages, max_completion_tokens=max_tokens
|
||||
)
|
||||
else:
|
||||
# Azure OpenAI
|
||||
assert isinstance(client, AsyncAzureOpenAI)
|
||||
# gpt-5-mini only accepts max_completion_tokens, older models use max_tokens
|
||||
if llm.name == "gpt-5-mini":
|
||||
response = await client.chat.completions.create(
|
||||
model=llm.name, messages=messages, max_completion_tokens=max_tokens
|
||||
)
|
||||
else:
|
||||
response = await client.chat.completions.create(
|
||||
model=llm.name, messages=messages, max_tokens=max_tokens
|
||||
)
|
||||
|
||||
result = LLMResponse(
|
||||
content=response.choices[0].message.content or "",
|
||||
usage=LLMUsage(
|
||||
input_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
output_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
),
|
||||
raw_response=response,
|
||||
response = await self.openai_client.chat.completions.create( # type: ignore[union-attr]
|
||||
model=llm.name, messages=messages, max_tokens=max_tokens
|
||||
)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
error = e
|
||||
logger.error(
|
||||
"LLM call failed: model=%s, provider=%s, call_type=%s, trace_id=%s, error=%s: %s",
|
||||
llm.name,
|
||||
llm.model_type,
|
||||
call_type,
|
||||
trace_id,
|
||||
type(e).__name__,
|
||||
e,
|
||||
return LLMResponse(
|
||||
content=response.choices[0].message.content or "",
|
||||
usage=LLMUsage(
|
||||
input_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
output_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
),
|
||||
raw_response=response,
|
||||
)
|
||||
raise
|
||||
|
||||
finally:
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
task = asyncio.create_task(
|
||||
record_llm_call(
|
||||
trace_id=trace_id,
|
||||
call_type=call_type,
|
||||
model_name=llm.name,
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
python_version=python_version,
|
||||
context=context,
|
||||
result=result,
|
||||
error=error,
|
||||
llm_cost=calculate_llm_cost(result.raw_response, llm) if result else None,
|
||||
latency_ms=latency_ms,
|
||||
)
|
||||
)
|
||||
|
||||
def _log_record_failure(t: asyncio.Task[str]) -> None:
|
||||
if exc := t.exception():
|
||||
logger.warning(f"Tracing: Failed to record LLM call: {exc}")
|
||||
|
||||
task.add_done_callback(_log_record_failure)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Cost Calculation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def calculate_llm_cost(response: ChatCompletion | AnthropicMessage, llm: LLM) -> float:
|
||||
"""Calculate the cost of an LLM API call based on token usage."""
|
||||
if response.usage is None:
|
||||
return 0.0
|
||||
|
||||
usage = response.usage
|
||||
|
||||
# Extract token counts per provider
|
||||
# OpenAI: prompt_tokens is total (cached is subset), Anthropic: counts are additive
|
||||
if llm.model_type == "anthropic":
|
||||
cache_read = getattr(usage, "cache_read_input_tokens", 0) or 0
|
||||
|
|
@ -297,28 +221,4 @@ def calculate_llm_cost(response: ChatCompletion | AnthropicMessage, llm: LLM) ->
|
|||
return (non_cached * input_rate + cache_read * cached_rate + output * output_rate) / 1_000_000
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Model Selection (based on available clients)
|
||||
# =============================================================================
|
||||
|
||||
# Prefer OpenAI: use OpenAI if available, fall back to Anthropic
|
||||
OPENAI_MODEL: LLM = OpenAI_GPT_5_Mini() if _openai_client else Anthropic_Claude_Sonnet_4_5()
|
||||
|
||||
# Prefer Anthropic: use Anthropic (AWS Bedrock) if available, fall back to OpenAI
|
||||
ANTHROPIC_MODEL: LLM = Anthropic_Claude_Sonnet_4_5() if _anthropic_client else OpenAI_GPT_5_Mini()
|
||||
|
||||
# Haiku model for cost-effective tasks (testgen diversity)
|
||||
HAIKU_MODEL: LLM = Anthropic_Claude_Haiku_4_5() if _anthropic_client else OpenAI_GPT_5_Mini()
|
||||
|
||||
# Model assignments
|
||||
EXPLAIN_MODEL: LLM = OPENAI_MODEL
|
||||
PLAN_MODEL: LLM = OPENAI_MODEL
|
||||
EXECUTE_MODEL: LLM = OPENAI_MODEL
|
||||
OPTIMIZE_MODEL: LLM = OPENAI_MODEL
|
||||
RANKING_MODEL: LLM = OPENAI_MODEL
|
||||
|
||||
REFINEMENT_MODEL: LLM = ANTHROPIC_MODEL
|
||||
EXPLANATIONS_MODEL: LLM = ANTHROPIC_MODEL
|
||||
OPTIMIZATION_REVIEW_MODEL: LLM = ANTHROPIC_MODEL
|
||||
CODE_REPAIR_MODEL: LLM = ANTHROPIC_MODEL
|
||||
ADAPTIVE_OPTIMIZE_MODEL: LLM = ANTHROPIC_MODEL
|
||||
llm_client = LLMClient()
|
||||
|
|
|
|||
django/aiservice/aiservice/llm_models.py (new file): 97 additions

@ -0,0 +1,97 @@
|
|||
"""LLM model definitions and provider configurations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Literal
|
||||
|
||||
import sentry_sdk
|
||||
from pydantic.dataclasses import dataclass as pydantic_dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pydantic_dataclass
|
||||
class LLM:
|
||||
name: str # On Azure OpenAI Service, this is the deployment name
|
||||
model_type: Literal["openai", "anthropic", "google"]
|
||||
input_cost: float | None = None # USD per 1M tokens
|
||||
cached_input_cost: float | None = None
|
||||
output_cost: float | None = None
|
||||
|
||||
|
||||
@pydantic_dataclass
|
||||
class OpenAI_GPT_4_1(LLM):
|
||||
name: str = "gpt-4.1"
|
||||
model_type: Literal["openai", "anthropic", "google"] = "openai"
|
||||
input_cost: float = 2.00
|
||||
cached_input_cost: float = 0.50
|
||||
output_cost: float = 8.00
|
||||
|
||||
|
||||
@pydantic_dataclass
|
||||
class OpenAI_GPT_5_Mini(LLM):
|
||||
name: str = "gpt-5-mini"
|
||||
model_type: Literal["openai", "anthropic", "google"] = "openai"
|
||||
input_cost: float = 0.25
|
||||
cached_input_cost: float = 0.03
|
||||
output_cost: float = 2.00
|
||||
|
||||
|
||||
@pydantic_dataclass
|
||||
class Anthropic_Claude_Sonnet_4_5(LLM):
|
||||
name: str = "us.anthropic.claude-sonnet-4-5-20250929-v1:0"
|
||||
model_type: Literal["openai", "anthropic", "google"] = "anthropic"
|
||||
input_cost: float = 3.00
|
||||
output_cost: float = 15.00
|
||||
|
||||
|
||||
@pydantic_dataclass
|
||||
class Anthropic_Claude_Haiku_4_5(LLM):
|
||||
name: str = "us.anthropic.claude-haiku-4-5-20251001-v1:0"
|
||||
model_type: Literal["openai", "anthropic", "google"] = "anthropic"
|
||||
input_cost: float = 1.00
|
||||
output_cost: float = 5.00
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Model Selection (based on available providers)
|
||||
# =============================================================================
|
||||
|
||||
has_openai = bool(os.environ.get("AZURE_OPENAI_API_KEY"))
|
||||
has_anthropic = bool(os.environ.get("AWS_ACCESS_KEY_ID") and os.environ.get("AWS_SECRET_ACCESS_KEY"))
|
||||
|
||||
if has_openai:
|
||||
OPENAI_MODEL: LLM = OpenAI_GPT_5_Mini()
|
||||
else:
|
||||
logger.warning("AZURE_OPENAI_API_KEY not set, falling back to Anthropic for OpenAI-preferred tasks")
|
||||
sentry_sdk.capture_message("AZURE_OPENAI_API_KEY not set, falling back to Anthropic", level="warning")
|
||||
OPENAI_MODEL = Anthropic_Claude_Sonnet_4_5()
|
||||
|
||||
if has_anthropic:
|
||||
ANTHROPIC_MODEL: LLM = Anthropic_Claude_Sonnet_4_5()
|
||||
HAIKU_MODEL: LLM = Anthropic_Claude_Haiku_4_5()
|
||||
else:
|
||||
logger.warning("AWS credentials not set, falling back to OpenAI for Anthropic-preferred tasks")
|
||||
sentry_sdk.capture_message("AWS credentials not set, falling back to OpenAI", level="warning")
|
||||
ANTHROPIC_MODEL = OpenAI_GPT_5_Mini()
|
||||
HAIKU_MODEL = OpenAI_GPT_5_Mini()
|
||||
|
||||
TASK_MODELS: dict[str, LLM] = {
|
||||
"EXECUTE_MODEL": OPENAI_MODEL,
|
||||
"OPTIMIZE_MODEL": OPENAI_MODEL,
|
||||
"RANKING_MODEL": OPENAI_MODEL,
|
||||
"REFINEMENT_MODEL": ANTHROPIC_MODEL,
|
||||
"EXPLANATIONS_MODEL": ANTHROPIC_MODEL,
|
||||
"OPTIMIZATION_REVIEW_MODEL": ANTHROPIC_MODEL,
|
||||
"CODE_REPAIR_MODEL": ANTHROPIC_MODEL,
|
||||
"ADAPTIVE_OPTIMIZE_MODEL": ANTHROPIC_MODEL,
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> LLM:
|
||||
if name in TASK_MODELS:
|
||||
return TASK_MODELS[name]
|
||||
msg = f"module {__name__!r} has no attribute {name!r}"
|
||||
raise AttributeError(msg)
|
||||
|
|
@ -1,4 +1,7 @@
|
|||
import logging
|
||||
|
||||
from asgiref.sync import iscoroutinefunction, markcoroutinefunction
|
||||
from django.conf import settings
|
||||
from django.http import JsonResponse
|
||||
from django.utils.decorators import async_only_middleware
|
||||
from ninja.errors import HttpError
|
||||
|
|
@ -14,7 +17,7 @@ class AuthMiddleware:
|
|||
|
||||
# Check if the `get_response` function is a coroutine (async)
|
||||
if iscoroutinefunction(self.get_response):
|
||||
print("AuthMiddleware is async.")
|
||||
logging.debug("AuthMiddleware is async.")
|
||||
# Mark this middleware as async if get_response is async
|
||||
markcoroutinefunction(self) # type: ignore[arg-type]
|
||||
|
||||
|
|
@ -33,8 +36,7 @@ class AuthMiddleware:
|
|||
# HttpError from auth layer - invalid API key
|
||||
return JsonResponse({"error": e.message}, status=e.status_code)
|
||||
except Exception as e:
|
||||
# Database errors, network issues during authentication
|
||||
response_error = str(e) if request.META.get("DEBUG") else "Authentication failed"
|
||||
response_error = str(e) if settings.DEBUG else "Authentication failed"
|
||||
return JsonResponse({"error": response_error}, status=500)
|
||||
|
||||
# Process the request - let downstream exceptions propagate naturally
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
import os
|
||||
|
||||
import sentry_sdk
|
||||
|
|
@ -29,7 +30,7 @@ class RateLimitMiddleware:
|
|||
def __init__(self, get_response) -> None:
|
||||
self.get_response = get_response
|
||||
if iscoroutinefunction(self.get_response):
|
||||
print("RateLimitMiddleware is async.")
|
||||
logging.debug("RateLimitMiddleware is async.")
|
||||
markcoroutinefunction(self) # type: ignore[arg-type]
|
||||
self.restricted_paths = [
|
||||
"/" + str(pattern.pattern)
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ class TrackUsageMiddleware:
|
|||
def __init__(self, get_response) -> None:
|
||||
self.get_response = get_response
|
||||
if iscoroutinefunction(self.get_response):
|
||||
print("TrackUsageMiddleware is async.")
|
||||
logging.debug("TrackUsageMiddleware is async.")
|
||||
markcoroutinefunction(self) # ty:ignore[invalid-argument-type]
|
||||
|
||||
async def __call__(self, request):
|
||||
|
|
@ -117,7 +117,7 @@ class TrackUsageMiddleware:
|
|||
if not subscription:
|
||||
# Subscription is now created during login in cf-webapp
|
||||
# As a backup, creating one here (this should be rare only now and for old users)
|
||||
logging.warning(f"No subscription found for user {user_id}. Creating backup subscription.")
|
||||
logging.warning("No subscription found for user %s. Creating backup subscription.", user_id)
|
||||
try:
|
||||
subscription = await Subscriptions.objects.acreate(
|
||||
user_id=user_id,
|
||||
|
|
@ -126,7 +126,7 @@ class TrackUsageMiddleware:
|
|||
subscription_status="active",
|
||||
optimizations_used=cost, # Charge for this request
|
||||
)
|
||||
logging.info(f"Created backup subscription for user {user_id}")
|
||||
logging.info("Created backup subscription for user %s", user_id)
|
||||
request.subscription_info = {
|
||||
"userId": user_id,
|
||||
"tier": subscription.plan_type,
|
||||
|
|
@ -137,7 +137,7 @@ class TrackUsageMiddleware:
|
|||
except IntegrityError:
|
||||
subscription = await Subscriptions.objects.filter(user_id=user_id).afirst()
|
||||
if not subscription:
|
||||
logging.exception(f"Failed to create or fetch subscription for user {user_id}")
|
||||
logging.exception("Failed to create or fetch subscription for user %s", user_id)
|
||||
return JsonResponse({"error": "Failed to initialize user subscription"}, status=500)
|
||||
|
||||
if subscription.subscription_status != "active":
|
||||
|
|
@ -172,8 +172,11 @@ class TrackUsageMiddleware:
|
|||
new_used = current_used + cost
|
||||
|
||||
logging.debug(
|
||||
f"track_usage_middleware.py|__call__|Atomic update completed: "
|
||||
f"user_id={user_id}, endpoint={endpoint}, cost={cost}, new_used={new_used}"
|
||||
"track_usage_middleware.py|__call__|Atomic update completed: user_id=%s, endpoint=%s, cost=%s, new_used=%s",
|
||||
user_id,
|
||||
endpoint,
|
||||
cost,
|
||||
new_used,
|
||||
)
|
||||
|
||||
# Attach subscription info to request
|
||||
|
|
@ -190,15 +193,15 @@ class TrackUsageMiddleware:
|
|||
# Handling database constraint violations more specifically
|
||||
if "duplicate key value violates unique constraint" in str(e):
|
||||
if "subscriptions_pkey" in str(e):
|
||||
logging.warning(f"Subscription creation race condition for user {user_id}: {e}")
|
||||
logging.warning("Subscription creation race condition for user %s: %s", user_id, e)
|
||||
return JsonResponse({"error": "Subscription initialization conflict. Please retry."}, status=429)
|
||||
if "subscriptions_user_id_key" in str(e):
|
||||
logging.warning(f"User already has subscription for user {user_id}: {e}")
|
||||
logging.warning("User already has subscription for user %s: %s", user_id, e)
|
||||
return JsonResponse({"error": "User subscription already exists. Please refresh."}, status=409)
|
||||
sentry_sdk.capture_exception(e)
|
||||
logging.exception(f"Database integrity error for user {user_id}")
|
||||
logging.exception("Database integrity error for user %s", user_id)
|
||||
return JsonResponse({"error": "Database constraint error. Please contact support."}, status=500)
|
||||
except Exception as e:
|
||||
sentry_sdk.capture_exception(e)
|
||||
logging.exception(f"Unexpected error tracking usage for user {user_id}")
|
||||
logging.exception("Unexpected error tracking usage for user %s", user_id)
|
||||
return JsonResponse({"error": "Service temporarily unavailable. Please try again."}, status=500)
|
||||
|
|
|
|||
|
|
@ -119,12 +119,7 @@ DEFAULT_AUTO_FIELD: str = "django.db.models.BigAutoField"
|
|||
if os.environ.get("ENVIRONMENT", default="") == "PRODUCTION":
|
||||
sentry_sdk.init(
|
||||
dsn="https://8a857cbf974ca889a46c1b39173db44b@o4506833230561280.ingest.sentry.io/4506833234493440",
|
||||
# Set traces_sample_rate to 1.0 to capture 100%
|
||||
# of transactions for performance monitoring.
|
||||
traces_sample_rate=1.0,
|
||||
# Set profiles_sample_rate to 1.0 to profile 100%
|
||||
# of sampled transactions.
|
||||
# We recommend adjusting this value in production.
|
||||
profiles_sample_rate=1.0,
|
||||
traces_sample_rate=0.1,
|
||||
profiles_sample_rate=0.01,
|
||||
enable_logs=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from django.db.models.functions import Now
|
|||
from ninja.errors import HttpError
|
||||
from ninja.security import HttpBearer
|
||||
|
||||
from authapp.auth_utils import hash_api_key, instance_for_api_key
|
||||
from authapp.auth_utils import hash_api_key
|
||||
from authapp.models import CFAPIKeys, Organizations, Subscriptions
|
||||
|
||||
|
||||
|
|
@ -27,24 +27,26 @@ async def check_subscription_status(request, user_id, tier, organization_id=None
|
|||
Attaches fetched organization and subscription objects to the request
|
||||
so downstream middleware can reuse them without re-querying.
|
||||
"""
|
||||
if tier is not None:
|
||||
return False
|
||||
|
||||
try:
|
||||
org = None
|
||||
if organization_id:
|
||||
org = await Organizations.objects.filter(id=organization_id).afirst()
|
||||
request.organization = org
|
||||
if org and org.name == "codeflash-ai":
|
||||
return True
|
||||
if org and org.subscription:
|
||||
return False
|
||||
|
||||
subscription = await Subscriptions.objects.filter(user_id=user_id).afirst()
|
||||
request.subscription = subscription
|
||||
|
||||
if tier is not None:
|
||||
return False
|
||||
|
||||
if org and org.name == "codeflash-ai":
|
||||
return True
|
||||
if org and org.subscription:
|
||||
return False
|
||||
|
||||
if subscription and subscription.plan_type.lower() in ["pro", "enterprise"]:
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Error checking subscription: {e!s}")
|
||||
sentry_sdk.capture_exception(e)
|
||||
return False
|
||||
return True
|
||||
|
|
@ -53,25 +55,15 @@ async def check_subscription_status(request, user_id, tier, organization_id=None
|
|||
class AuthBearer(HttpBearer):
|
||||
async def authenticate(self, request, token):
|
||||
hashed_token = hash_api_key(token)
|
||||
try:
|
||||
num_users = await CFAPIKeys.objects.filter(key=hashed_token).aupdate(last_used=Now())
|
||||
if num_users == 0:
|
||||
raise HttpError(403, "Invalid API key")
|
||||
if num_users == 1:
|
||||
api_key_instance = await instance_for_api_key(hashed_token)
|
||||
if not api_key_instance:
|
||||
print(f"Instance not found for api key {token}. Returning 403")
|
||||
raise HttpError(403, "Invalid API key")
|
||||
request.user = api_key_instance.user_id
|
||||
request.tier = api_key_instance.tier
|
||||
request.api_key_id = api_key_instance.id
|
||||
request.organization_id = api_key_instance.organization_id
|
||||
request.should_log_features = await check_subscription_status(
|
||||
request, user_id=request.user, tier=request.tier, organization_id=request.organization_id
|
||||
)
|
||||
return token
|
||||
|
||||
print("THIS SHOULD NOT HAPPEN! More than one users found in the db with the same api key!")
|
||||
raise HttpError(403, "Invalid API key")
|
||||
except CFAPIKeys.DoesNotExist:
|
||||
api_key_instance = await CFAPIKeys.objects.filter(key=hashed_token).afirst()
|
||||
if api_key_instance is None:
|
||||
raise HttpError(403, "Invalid API key")
|
||||
await CFAPIKeys.objects.filter(id=api_key_instance.id).aupdate(last_used=Now())
|
||||
request.user = api_key_instance.user_id
|
||||
request.tier = api_key_instance.tier
|
||||
request.api_key_id = api_key_instance.id
|
||||
request.organization_id = api_key_instance.organization_id
|
||||
request.should_log_features = await check_subscription_status(
|
||||
request, user_id=request.user, tier=request.tier, organization_id=request.organization_id
|
||||
)
|
||||
return token
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""Core infrastructure for multi-language support."""
|
||||
|
||||
from .errors import HandlerError, HandlerNotImplementedError, LanguageNotFoundError, PipelineError, ValidationError
|
||||
from .errors import HandlerError, HandlerNotImplementedError, LanguageNotFoundError
|
||||
from .pipeline import PipelineContext
|
||||
from .protocols import CodeRepairProtocol, LanguageHandler, OptimizerProtocol, TestGenProtocol
|
||||
from .registry import register_handler, registry
|
||||
|
|
@ -13,9 +13,7 @@ __all__ = [
|
|||
"LanguageNotFoundError",
|
||||
"OptimizerProtocol",
|
||||
"PipelineContext",
|
||||
"PipelineError",
|
||||
"TestGenProtocol",
|
||||
"ValidationError",
|
||||
"register_handler",
|
||||
"registry",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -29,22 +29,3 @@ class HandlerNotImplementedError(HandlerError):
|
|||
f"Check handler.capabilities['{feature}']."
|
||||
)
|
||||
super().__init__(message, {"language_id": language_id, "feature": feature, "capability": capability})
|
||||
|
||||
|
||||
class PipelineError(HandlerError):
|
||||
"""Raised when a pipeline stage fails."""
|
||||
|
||||
def __init__(self, stage: str, message: str, cause: Exception | None = None) -> None:
|
||||
full_message = f"Pipeline error in {stage}: {message}"
|
||||
if cause:
|
||||
full_message += f" (caused by {type(cause).__name__}: {cause!s})"
|
||||
super().__init__(full_message, {"stage": stage, "cause": cause})
|
||||
self.__cause__ = cause
|
||||
|
||||
|
||||
class ValidationError(HandlerError):
|
||||
"""Raised when input validation fails."""
|
||||
|
||||
def __init__(self, field: str, message: str, value: object = None) -> None:
|
||||
full_message = f"Validation failed for '{field}': {message}"
|
||||
super().__init__(full_message, {"field": field, "value": value})
|
||||
|
|
|
|||
|
|
@ -11,20 +11,20 @@ import re
|
|||
import uuid
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import sentry_sdk
|
||||
from ninja.errors import HttpError
|
||||
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
|
||||
|
||||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common_utils import is_host_equals_demo, should_hack_for_demo_java, validate_trace_id
|
||||
from aiservice.env_specific import debug_log_sensitive_data
|
||||
from aiservice.llm import LLM, OPTIMIZE_MODEL, calculate_llm_cost, call_llm
|
||||
from aiservice.llm import llm_client
|
||||
from aiservice.llm_models import LLM, OPTIMIZE_MODEL
|
||||
from authapp.auth import AuthenticatedRequest
|
||||
from authapp.user import get_user_by_id
|
||||
from core.languages.java.prompts.optimizer import get_system_prompt, get_user_prompt
|
||||
from core.log_features.log_event import get_or_create_optimization_event
|
||||
from core.log_features.log_features import log_features
|
||||
from core.shared.context_helpers import group_code, split_markdown_code
|
||||
from core.shared.context_helpers import extract_code_and_explanation, group_code, split_markdown_code
|
||||
from core.shared.optimizer_config import MAX_OPTIMIZER_CALLS, get_model_distribution
|
||||
from core.shared.optimizer_models import OptimizeSchema
|
||||
from core.shared.optimizer_schemas import (
|
||||
|
|
@ -50,47 +50,6 @@ def is_multi_context_java(source_code: str) -> bool:
|
|||
return source_code.count("```java:") >= 1
|
||||
|
||||
|
||||
def extract_code_and_explanation(content: str, is_multi_file: bool = False) -> tuple[str | dict[str, str], str]:
|
||||
"""Extract code and explanation from LLM response.
|
||||
|
||||
Args:
|
||||
content: The raw LLM response content
|
||||
is_multi_file: Whether to expect multi-file format
|
||||
|
||||
Returns:
|
||||
Tuple of (code, explanation) where code is a string for single file
|
||||
or dict[str, str] for multi-file
|
||||
|
||||
"""
|
||||
if is_multi_file:
|
||||
# Extract all code blocks with file paths
|
||||
matches = JAVA_CODE_WITH_PATH_PATTERN.findall(content)
|
||||
if matches:
|
||||
file_to_code: dict[str, str] = {}
|
||||
first_match_pos = content.find("```")
|
||||
explanation = content[:first_match_pos].strip() if first_match_pos > 0 else ""
|
||||
|
||||
for file_path, code in matches:
|
||||
file_to_code[file_path.strip()] = code.strip()
|
||||
|
||||
return file_to_code, explanation
|
||||
|
||||
# Fall back to single file extraction
|
||||
return extract_code_and_explanation(content, is_multi_file=False)
|
||||
|
||||
# Single file extraction
|
||||
match = JAVA_CODE_PATTERN.search(content)
|
||||
if match:
|
||||
code = match.group(1).strip()
|
||||
# Explanation is everything before the code block
|
||||
explanation_end = match.start()
|
||||
explanation = content[:explanation_end].strip()
|
||||
return code, explanation
|
||||
|
||||
# No code block found, return empty code
|
||||
return "", content
|
||||
|
||||
|
||||
def _extract_demo_context(source_code: str) -> tuple[str, str, str, str]:
|
||||
"""Extract package, class name, exception type, and extra imports from the demo source code.
|
||||
|
||||
|
|
@ -379,7 +338,7 @@ async def optimize_java_code_single(
|
|||
]
|
||||
|
||||
try:
|
||||
output = await call_llm(
|
||||
output = await llm_client.call(
|
||||
llm=optimize_model,
|
||||
messages=messages,
|
||||
call_type="optimization",
|
||||
|
|
@ -388,26 +347,25 @@ async def optimize_java_code_single(
|
|||
python_version="N/A", # Not applicable for Java
|
||||
context=obs_context,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.exception("LLM Code Generation error in Java optimizer")
|
||||
sentry_sdk.capture_exception(e)
|
||||
except Exception:
|
||||
debug_log_sensitive_data(f"Failed to generate code for Java source:\n{source_code}")
|
||||
return None, None, optimize_model.name
|
||||
|
||||
llm_cost = calculate_llm_cost(output.raw_response, optimize_model)
|
||||
llm_cost = output.cost
|
||||
|
||||
debug_log_sensitive_data(f"LLM optimization response:\n{output.raw_response.model_dump_json(indent=2)}")
|
||||
|
||||
if output.raw_response.usage is not None:
|
||||
await asyncio.to_thread(
|
||||
ph,
|
||||
ph(
|
||||
user_id,
|
||||
"aiservice-optimize-openai-usage",
|
||||
properties={"model": optimize_model.name, "usage": output.raw_response.usage.json(), "language": "java"},
|
||||
)
|
||||
|
||||
# Extract code and explanation from response
|
||||
code, explanation = extract_code_and_explanation(output.content, is_multi_file)
|
||||
code, explanation = extract_code_and_explanation(
|
||||
output.content, JAVA_CODE_PATTERN, JAVA_CODE_WITH_PATH_PATTERN, is_multi_file
|
||||
)
|
||||
|
||||
if not code:
|
||||
logging.warning("No valid Java code extracted from LLM response")
|
||||
|
|
@ -417,7 +375,7 @@ async def optimize_java_code_single(
|
|||
code_to_validate = code if isinstance(code, str) else "\n".join(code.values())
|
||||
is_valid, error = validate_java_syntax(code_to_validate)
|
||||
if not is_valid:
|
||||
logging.warning(f"Java code failed syntax validation: {error}")
|
||||
logging.warning("Java code failed syntax validation: %s", error)
|
||||
return None, llm_cost, optimize_model.name
|
||||
|
||||
# Format the response
|
||||
|
|
@ -511,14 +469,15 @@ async def optimize_java(
|
|||
return 400, OptimizeErrorResponseSchema(error="Invalid trace_id")
|
||||
|
||||
user_id = request.user
|
||||
user = await get_user_by_id(user_id)
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
user_task = tg.create_task(get_user_by_id(user_id))
|
||||
event_task = tg.create_task(
|
||||
get_or_create_optimization_event(trace_id=data.trace_id, event_type="no-pr", user_id=user_id)
|
||||
)
|
||||
user = user_task.result()
|
||||
if user is None:
|
||||
raise HttpError(401, "User not found")
|
||||
|
||||
# Log the event
|
||||
optimization_event, _created = await get_or_create_optimization_event(
|
||||
trace_id=data.trace_id, event_type="no-pr", user_id=user_id
|
||||
)
|
||||
optimization_event, _created = event_task.result()
|
||||
# Determine Java version
|
||||
language_version = data.language_version or "17"
|
||||
|
||||
|
|
@ -540,8 +499,7 @@ async def optimize_java(
|
|||
)
|
||||
|
||||
# Track analytics
|
||||
await asyncio.to_thread(
|
||||
ph,
|
||||
ph(
|
||||
user_id,
|
||||
"aiservice-optimize-java",
|
||||
properties={
|
||||
|
|
|
|||
|
|
@ -18,7 +18,8 @@ from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUs
|
|||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common_utils import is_host_equals_demo
|
||||
from aiservice.env_specific import debug_log_sensitive_data
|
||||
from aiservice.llm import OPTIMIZE_MODEL, calculate_llm_cost, call_llm
|
||||
from aiservice.llm import llm_client
|
||||
from aiservice.llm_models import OPTIMIZE_MODEL
|
||||
from aiservice.validators.java_validator import validate_java_syntax
|
||||
from core.languages.java.optimizer import (
|
||||
_build_demo_optimizations,
|
||||
|
|
@ -26,14 +27,19 @@ from core.languages.java.optimizer import (
|
|||
_extract_demo_context,
|
||||
is_multi_context_java,
|
||||
)
|
||||
from core.shared.context_helpers import group_code, split_markdown_code
|
||||
from core.shared.context_helpers import (
|
||||
extract_code_and_explanation,
|
||||
group_code,
|
||||
normalize_c_style_code,
|
||||
split_markdown_code,
|
||||
)
|
||||
from core.shared.optimizer_config import MAX_OPTIMIZER_LP_CALLS, get_model_distribution
|
||||
from core.shared.optimizer_schemas import OptimizeResponseItemSchema, OptimizeResponseSchema
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
from aiservice.llm import LLM
|
||||
from aiservice.llm_models import LLM
|
||||
|
||||
|
||||
# Get the prompts directory
|
||||
|
|
@ -64,53 +70,7 @@ Use this data to identify performance bottlenecks and focus your optimization on
|
|||
|
||||
|
||||
def extract_java_code_and_explanation(content: str, is_multi_file: bool = False) -> tuple[str | dict[str, str], str]:
|
||||
"""Extract Java code and explanation from LLM response.
|
||||
|
||||
Args:
|
||||
content: The raw LLM response content
|
||||
is_multi_file: Whether to expect multi-file format
|
||||
|
||||
Returns:
|
||||
Tuple of (code, explanation) where code is a string for single file
|
||||
or dict[str, str] for multi-file
|
||||
|
||||
"""
|
||||
if is_multi_file:
|
||||
# Extract all code blocks with file paths
|
||||
matches = JAVA_CODE_WITH_PATH_PATTERN.findall(content)
|
||||
if matches:
|
||||
file_to_code: dict[str, str] = {}
|
||||
first_match_pos = content.find("```")
|
||||
explanation = content[:first_match_pos].strip() if first_match_pos > 0 else ""
|
||||
|
||||
for file_path, code in matches:
|
||||
file_to_code[file_path.strip()] = code.strip()
|
||||
|
||||
return file_to_code, explanation
|
||||
|
||||
# Fall back to single file extraction
|
||||
return extract_java_code_and_explanation(content, is_multi_file=False)
|
||||
|
||||
# Single file extraction
|
||||
match = JAVA_CODE_PATTERN.search(content)
|
||||
if match:
|
||||
code = match.group(1).strip()
|
||||
explanation_end = match.start()
|
||||
explanation = content[:explanation_end].strip()
|
||||
return code, explanation
|
||||
|
||||
return "", content
|
||||
|
||||
|
||||
def normalize_java_code(code: str) -> str:
|
||||
"""Normalize Java code for comparison."""
|
||||
# Remove single-line comments
|
||||
code = re.sub(r"//.*$", "", code, flags=re.MULTILINE)
|
||||
# Remove multi-line comments
|
||||
code = re.sub(r"/\*.*?\*/", "", code, flags=re.DOTALL)
|
||||
# Normalize whitespace
|
||||
code = " ".join(code.split())
|
||||
return code
|
||||
return extract_code_and_explanation(content, JAVA_CODE_PATTERN, JAVA_CODE_WITH_PATH_PATTERN, is_multi_file)
|
||||
|
||||
|
||||
async def hack_for_demo_java_lp(source_code: str) -> OptimizeResponseSchema:
|
||||
|
|
@ -157,7 +117,9 @@ async def optimize_java_code_line_profiler_single(
|
|||
if is_multi_file:
|
||||
original_file_to_code = split_markdown_code(source_code, "java")
|
||||
logging.info(
|
||||
f"Multi-file context detected with {len(original_file_to_code)} files: {list(original_file_to_code.keys())}"
|
||||
"Multi-file context detected with %d files: %s",
|
||||
len(original_file_to_code),
|
||||
list(original_file_to_code.keys()),
|
||||
)
|
||||
|
||||
# Format system prompt with language version
|
||||
|
|
@ -224,7 +186,7 @@ Here is the code to optimize:
|
|||
]
|
||||
|
||||
try:
|
||||
output = await call_llm(
|
||||
output = await llm_client.call(
|
||||
llm=optimize_model,
|
||||
messages=messages,
|
||||
call_type="line_profiler",
|
||||
|
|
@ -233,13 +195,11 @@ Here is the code to optimize:
|
|||
python_version=f"Java {language_version}",
|
||||
context=obs_context,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.exception("LLM Code Generation error in Java line profiler optimizer")
|
||||
sentry_sdk.capture_exception(e)
|
||||
except Exception:
|
||||
debug_log_sensitive_data(f"Failed to generate code for source:\n{source_code}")
|
||||
return None, None, optimize_model.name
|
||||
|
||||
llm_cost = calculate_llm_cost(output.raw_response, optimize_model)
|
||||
llm_cost = output.cost
|
||||
|
||||
debug_log_sensitive_data(f"LLM optimization response:\n{output.raw_response.model_dump_json(indent=2)}")
|
||||
|
||||
|
|
@ -279,7 +239,7 @@ Here is the code to optimize:
|
|||
merged_file_to_code[file_path] = original_code
|
||||
else:
|
||||
merged_file_to_code[file_path] = new_code
|
||||
if normalize_java_code(new_code) != normalize_java_code(original_code):
|
||||
if normalize_c_style_code(new_code) != normalize_c_style_code(original_code):
|
||||
has_changes = True
|
||||
else:
|
||||
# File not in response, keep original
|
||||
|
|
@ -312,7 +272,7 @@ Here is the code to optimize:
|
|||
return None, llm_cost, optimize_model.name
|
||||
|
||||
# Check that the code is actually different from the original
|
||||
if normalize_java_code(optimized_code) == normalize_java_code(source_code):
|
||||
if normalize_c_style_code(optimized_code) == normalize_c_style_code(source_code):
|
||||
debug_log_sensitive_data("Generated code identical to original")
|
||||
return None, llm_cost, optimize_model.name
|
||||
|
||||
|
|
|
|||
|
|
@ -14,14 +14,13 @@ from pathlib import Path
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
import sentry_sdk
|
||||
import stamina
|
||||
from openai import OpenAIError
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common_utils import is_host_equals_demo, should_hack_for_demo_java, validate_trace_id
|
||||
from aiservice.env_specific import debug_log_sensitive_data
|
||||
from aiservice.llm import EXECUTE_MODEL, calculate_llm_cost, call_llm
|
||||
from aiservice.llm import llm_client
|
||||
from aiservice.llm_models import EXECUTE_MODEL
|
||||
from authapp.auth import AuthenticatedRequest
|
||||
from core.log_features.log_event import update_optimization_cost
|
||||
from core.shared.testgen_models import (
|
||||
|
|
@ -32,7 +31,7 @@ from core.shared.testgen_models import (
|
|||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aiservice.llm import LLM
|
||||
from aiservice.llm_models import LLM
|
||||
|
||||
from aiservice.validators.java_validator import validate_java_syntax
|
||||
|
||||
|
|
@ -314,7 +313,6 @@ def _has_test_functions(code: str) -> bool:
|
|||
return _TEST_FUNC_RE.search(code) is not None
|
||||
|
||||
|
||||
@stamina.retry(on=(SyntaxError, ValueError, OpenAIError), attempts=2)
|
||||
async def generate_and_validate_java_test_code(
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
model: LLM,
|
||||
|
|
@ -348,22 +346,17 @@ async def generate_and_validate_java_test_code(
|
|||
obs_context: dict | None = (
|
||||
{"call_sequence": call_sequence, "test_index": test_index} if call_sequence is not None else None
|
||||
)
|
||||
try:
|
||||
output = await call_llm(
|
||||
llm=model,
|
||||
messages=messages,
|
||||
call_type="testgen",
|
||||
trace_id=trace_id,
|
||||
user_id=user_id,
|
||||
python_version="N/A", # Not applicable for Java
|
||||
context=obs_context,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.exception("LLM Code Generation error")
|
||||
sentry_sdk.capture_exception(e)
|
||||
raise
|
||||
output = await llm_client.call(
|
||||
llm=model,
|
||||
messages=messages,
|
||||
call_type="testgen",
|
||||
trace_id=trace_id,
|
||||
user_id=user_id,
|
||||
python_version="N/A", # Not applicable for Java
|
||||
context=obs_context,
|
||||
)
|
||||
|
||||
llm_cost = calculate_llm_cost(output.raw_response, model)
|
||||
llm_cost = output.cost
|
||||
cost_tracker.append(llm_cost)
|
||||
|
||||
debug_log_sensitive_data(f"LLM testgen response:\n{output.content}")
|
||||
|
|
@ -401,7 +394,7 @@ def _extract_package_from_source(source_code: str) -> str | None:
|
|||
package_pattern = re.compile(r"^\s*package\s+([\w.]+)\s*;", re.MULTILINE)
|
||||
match = package_pattern.search(source_code)
|
||||
if match:
|
||||
logging.debug(f"Extracted package from declaration: {match.group(1)}")
|
||||
logging.debug("Extracted package from declaration: %s", match.group(1))
|
||||
return match.group(1)
|
||||
|
||||
# Second try: extract from markdown code block header (e.g., "```java:src/main/java/com/example/Algorithms.java")
|
||||
|
|
@ -411,7 +404,7 @@ def _extract_package_from_source(source_code: str) -> str | None:
|
|||
file_path = markdown_match.group(1).strip()
|
||||
package = _extract_package_from_path(file_path)
|
||||
if package:
|
||||
logging.debug(f"Extracted package from markdown header: {package}")
|
||||
logging.debug("Extracted package from markdown header: %s", package)
|
||||
return package
|
||||
|
||||
# Third try: extract from file path comment (e.g., "// file: src/main/java/com/example/Algorithms.java")
|
||||
|
|
@ -420,10 +413,10 @@ def _extract_package_from_source(source_code: str) -> str | None:
|
|||
file_match = file_comment_pattern.search(source_code)
|
||||
if file_match:
|
||||
file_path = file_match.group(1).strip()
|
||||
logging.debug(f"Found file comment: {file_path}")
|
||||
logging.debug("Found file comment: %s", file_path)
|
||||
package = _extract_package_from_path(file_path)
|
||||
if package:
|
||||
logging.debug(f"Extracted package from file comment: {package}")
|
||||
logging.debug("Extracted package from file comment: %s", package)
|
||||
return package
|
||||
|
||||
# Fourth try: infer package from import statements (last resort)
|
||||
|
|
@ -439,7 +432,7 @@ def _extract_package_from_source(source_code: str) -> str | None:
|
|||
if internal_imports:
|
||||
# Use the shortest import path as a hint
|
||||
internal_imports.sort(key=len)
|
||||
logging.debug(f"Inferred package from imports: {internal_imports[0]}")
|
||||
logging.debug("Inferred package from imports: %s", internal_imports[0])
|
||||
return internal_imports[0]
|
||||
|
||||
logging.warning("Could not extract package name from source code")
|
||||
|
|
@ -1360,7 +1353,7 @@ async def testgen_java(
|
|||
request: AuthenticatedRequest, data: TestGenSchema
|
||||
) -> tuple[int, TestGenResponseSchema | TestGenErrorResponseSchema]:
|
||||
"""Generate Java tests using LLMs."""
|
||||
await asyncio.to_thread(ph, request.user, "aiservice-testgen-java-called")
|
||||
ph(request.user, "aiservice-testgen-java-called")
|
||||
|
||||
# Validate request
|
||||
if not data.function_to_optimize:
|
||||
|
|
@ -1376,7 +1369,7 @@ async def testgen_java(
|
|||
|
||||
try:
|
||||
debug_log_sensitive_data(f"Generating Java tests for function {data.function_to_optimize.function_name}")
|
||||
logging.info(f"Generating Java tests for function {data.function_to_optimize.function_name}")
|
||||
logging.info("Generating Java tests for function %s", data.function_to_optimize.function_name)
|
||||
|
||||
# Extract class and package info from source code (more reliable than qualified_name)
|
||||
source_code = data.source_code_being_tested
|
||||
|
|
@ -1395,7 +1388,11 @@ async def testgen_java(
|
|||
test_framework = data.test_framework if data.test_framework in ("junit4", "junit5") else "junit5"
|
||||
|
||||
logging.info(
|
||||
f"Java testgen: package={package_name}, class={class_name}, module_path={module_path}, framework={test_framework}"
|
||||
"Java testgen: package=%s, class=%s, module_path=%s, framework=%s",
|
||||
package_name,
|
||||
class_name,
|
||||
module_path,
|
||||
test_framework,
|
||||
)
|
||||
debug_log_sensitive_data(
|
||||
f"Extracted: package={package_name}, class={class_name}, module_path={module_path}, framework={test_framework}"
|
||||
|
|
@ -1430,8 +1427,7 @@ async def testgen_java(
|
|||
|
||||
# Track analytics
|
||||
total_cost = sum(cost_tracker)
|
||||
await asyncio.to_thread(
|
||||
ph,
|
||||
ph(
|
||||
request.user,
|
||||
f"aiservice-testgen-{posthog_event_suffix}success",
|
||||
properties={
|
||||
|
|
@ -1453,11 +1449,11 @@ async def testgen_java(
|
|||
)
|
||||
|
||||
except TestGenerationFailedError as e:
|
||||
logging.warning(f"Java test generation failed: {e}")
|
||||
logging.warning("Java test generation failed: %s", e)
|
||||
sentry_sdk.capture_exception(e)
|
||||
return 400, TestGenErrorResponseSchema(error=str(e))
|
||||
except (ValueError, SyntaxError) as e:
|
||||
logging.warning(f"Java test generation error: {e}")
|
||||
logging.warning("Java test generation error: %s", e)
|
||||
sentry_sdk.capture_exception(e)
|
||||
return 400, TestGenErrorResponseSchema(error=f"Failed to generate valid tests: {e}")
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@@ -19,7 +19,8 @@ from aiservice.analytics.posthog import ph
from aiservice.common.markdown_utils import split_markdown_code
from aiservice.common_utils import validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable
from aiservice.llm import LLM, OPTIMIZE_MODEL, calculate_llm_cost, call_llm
from aiservice.llm import llm_client
from aiservice.llm_models import LLM, OPTIMIZE_MODEL
from aiservice.validators.javascript_validator import validate_javascript_syntax, validate_typescript_syntax
from authapp.auth import AuthenticatedRequest
from authapp.user import get_user_by_id
@@ -27,7 +28,7 @@ from core.languages.js_ts.context_helpers import is_multi_context_js, is_multi_c
from core.languages.js_ts.prompts.optimizer import get_system_prompt, get_user_prompt
from core.log_features.log_event import get_or_create_optimization_event
from core.log_features.log_features import log_features
from core.shared.context_helpers import group_code
from core.shared.context_helpers import extract_code_and_explanation, group_code, normalize_c_style_code
from core.shared.optimizer_config import MAX_OPTIMIZER_CALLS, get_model_distribution
from core.shared.optimizer_models import OptimizedCandidateSource, OptimizeSchema
from core.shared.optimizer_schemas import (
@@ -49,50 +50,6 @@ JS_CODE_WITH_PATH_PATTERN = re.compile(
)
def extract_code_and_explanation(
content: str, is_multi_file: bool = False, language: str = "javascript"
) -> tuple[str | dict[str, str], str]:
"""Extract code and explanation from LLM response.
Args:
content: The raw LLM response content
is_multi_file: Whether to expect multi-file format
language: The language (javascript or typescript)
Returns:
Tuple of (code, explanation) where code is a string for single file
or dict[str, str] for multi-file
"""
if is_multi_file:
# Extract all code blocks with file paths
matches = JS_CODE_WITH_PATH_PATTERN.findall(content)
if matches:
file_to_code: dict[str, str] = {}
first_match_pos = content.find("```")
explanation = content[:first_match_pos].strip() if first_match_pos > 0 else ""
for file_path, code in matches:
file_to_code[file_path.strip()] = code.strip()
return file_to_code, explanation
# Fall back to single file extraction
return extract_code_and_explanation(content, is_multi_file=False, language=language)
# Single file extraction
match = JS_CODE_PATTERN.search(content)
if match:
code = match.group(1).strip()
# Explanation is everything before the code block
explanation_end = match.start()
explanation = content[:explanation_end].strip()
return code, explanation
# No code block found, return empty code
return "", content
async def optimize_javascript_code_single(
user_id: str,
source_code: str,
@@ -123,7 +80,7 @@ async def optimize_javascript_code_single(
"""
lang_name = "TypeScript" if language == "typescript" else "JavaScript"
code_block_tag = "typescript" if language == "typescript" else "javascript"
logging.info(f"/optimize: Optimizing {lang_name} code.")
logging.info("/optimize: Optimizing %s code.", lang_name)
debug_log_sensitive_data(f"Optimizing {lang_name} code for user {user_id}:\n{source_code}")
# Check if source code is multi-file format
@@ -133,7 +90,9 @@ async def optimize_javascript_code_single(
if is_multi_file:
original_file_to_code = split_markdown_code(source_code, language)
logging.info(
f"Multi-file context detected with {len(original_file_to_code)} files: {list(original_file_to_code.keys())}"
"Multi-file context detected with %d files: %s",
len(original_file_to_code),
list(original_file_to_code.keys()),
)
# Get language-appropriate prompts (TypeScript uses same prompts as JavaScript)
@@ -187,7 +146,7 @@ You MUST output the target file. You may also output helper files if you optimiz
]
try:
output = await call_llm(
output = await llm_client.call(
llm=optimize_model,
messages=messages,
call_type="optimization",
@@ -196,13 +155,11 @@ You MUST output the target file. You may also output helper files if you optimiz
python_version=language_version,  # Reusing python_version field for language version
context=obs_context,
)
except Exception as e:
logging.exception("LLM Code Generation error in JavaScript optimizer")
sentry_sdk.capture_exception(e)
except Exception:
debug_log_sensitive_data(f"Failed to generate code for source:\n{source_code}")
return None, None, optimize_model.name
llm_cost = calculate_llm_cost(output.raw_response, optimize_model)
llm_cost = output.cost
debug_log_sensitive_data(f"LLM optimization response:\n{output.raw_response.model_dump_json(indent=2)}")
@@ -215,7 +172,7 @@ You MUST output the target file. You may also output helper files if you optimiz
# Extract code and explanation from response
extracted_code, explanation = extract_code_and_explanation(
output.content, is_multi_file=is_multi_file, language=language
output.content, JS_CODE_PATTERN, JS_CODE_WITH_PATH_PATTERN, is_multi_file
)
if not extracted_code:
@@ -248,7 +205,7 @@ You MUST output the target file. You may also output helper files if you optimiz
merged_file_to_code[file_path] = original_code
else:
merged_file_to_code[file_path] = new_code
if _normalize_code(new_code) != _normalize_code(original_code):
if normalize_c_style_code(new_code) != normalize_c_style_code(original_code):
has_changes = True
else:
# File not in response, keep original
@@ -284,7 +241,7 @@ You MUST output the target file. You may also output helper files if you optimiz
return None, llm_cost, optimize_model.name
# Check that the code is actually different from the original
if _normalize_code(optimized_code) == _normalize_code(source_code):
if normalize_c_style_code(optimized_code) == normalize_c_style_code(source_code):
debug_log_sensitive_data("Generated code identical to original")
return None, llm_cost, optimize_model.name
@@ -302,17 +259,6 @@ You MUST output the target file. You may also output helper files if you optimiz
return result, llm_cost, optimize_model.name
def _normalize_code(code: str) -> str:
"""Normalize code for comparison (remove comments and whitespace)."""
# Remove single-line comments
code = re.sub(r"//.*$", "", code, flags=re.MULTILINE)
# Remove multi-line comments
code = re.sub(r"/\*.*?\*/", "", code, flags=re.DOTALL)
# Normalize whitespace
code = " ".join(code.split())
return code
async def optimize_javascript_code(
user_id: str,
source_code: str,
@@ -376,7 +322,7 @@ async def optimize_javascript_code(
total_cost += cost
if result is not None:
# Deduplicate by normalized code
normalized = _normalize_code(result.source_code)
normalized = normalize_c_style_code(result.source_code)
if normalized not in seen_code:
seen_code.add(normalized)
optimization_results.append(result)
@@ -454,7 +400,7 @@ async def optimize_javascript(
validate_javascript_request_data(data)
except HttpError as e:
e.add_note(f"JavaScript optimizer request validation error: {e.status_code} {e.message}")
logging.exception(f"JavaScript optimizer request validation error: {e.message}. trace_id={data.trace_id}")
logging.exception("JavaScript optimizer request validation error: %s. trace_id=%s", e.message, data.trace_id)
sentry_sdk.capture_exception(e)
return e.status_code, OptimizeErrorResponseSchema(error=e.message)
@@ -476,7 +422,7 @@ async def optimize_javascript(
if data.current_username is None:
user_task = tg.create_task(get_user_by_id(request.user))
except Exception as e:
logging.exception(f"Error during JavaScript optimization task. trace_id={data.trace_id}")
logging.exception("Error during JavaScript optimization task. trace_id=%s", data.trace_id)
sentry_sdk.capture_exception(e)
return 500, OptimizeErrorResponseSchema(error="Error generating optimizations. Internal server error.")
@@ -489,7 +435,7 @@ async def optimize_javascript(
if len(optimization_response_items) == 0:
ph(request.user, "aiservice-optimize-no-optimizations-found", properties={"language": language})
debug_log_sensitive_data(f"No JavaScript optimizations found for source:\n{data.source_code}")
logging.error(f"Could not generate any JavaScript optimizations. trace_id={data.trace_id}")
logging.error("Could not generate any JavaScript optimizations. trace_id=%s", data.trace_id)
return 500, OptimizeErrorResponseSchema(error="Could not generate any optimizations. Please try again.")
ph(
@@ -525,7 +471,7 @@ async def optimize_javascript(
dependency_code=data.dependency_code,
optimizations_post={opt.optimization_id: opt.source_code for opt in optimization_response_items},
explanations_post={opt.optimization_id: opt.explanation for opt in optimization_response_items},
experiment_metadata=data.experiment_metadata if data.experiment_metadata else None,
experiment_metadata=data.experiment_metadata or None,
optimizations_origin={
opt.optimization_id: {
"source": OptimizedCandidateSource.OPTIMIZE,

@@ -18,17 +18,18 @@ from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUs
from aiservice.analytics.posthog import ph
from aiservice.common.markdown_utils import split_markdown_code
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.llm import OPTIMIZE_MODEL, calculate_llm_cost, call_llm
from aiservice.llm import llm_client
from aiservice.llm_models import OPTIMIZE_MODEL
from aiservice.validators.javascript_validator import validate_javascript_syntax, validate_typescript_syntax
from core.languages.js_ts.context_helpers import is_multi_context_js, is_multi_context_ts
from core.shared.context_helpers import group_code
from core.shared.context_helpers import extract_code_and_explanation, group_code, normalize_c_style_code
from core.shared.optimizer_config import MAX_OPTIMIZER_LP_CALLS, get_model_distribution
from core.shared.optimizer_schemas import OptimizeResponseItemSchema
if TYPE_CHECKING:
from openai.types.chat import ChatCompletionMessageParam
from aiservice.llm import LLM
from aiservice.llm_models import LLM
# Get the prompts directory
@@ -61,53 +62,7 @@ Use this data to identify performance bottlenecks and focus your optimization on
def extract_js_code_and_explanation(content: str, is_multi_file: bool = False) -> tuple[str | dict[str, str], str]:
"""Extract JavaScript code and explanation from LLM response.
Args:
content: The raw LLM response content
is_multi_file: Whether to expect multi-file format
Returns:
Tuple of (code, explanation) where code is a string for single file
or dict[str, str] for multi-file
"""
if is_multi_file:
# Extract all code blocks with file paths
matches = JS_CODE_WITH_PATH_PATTERN.findall(content)
if matches:
file_to_code: dict[str, str] = {}
first_match_pos = content.find("```")
explanation = content[:first_match_pos].strip() if first_match_pos > 0 else ""
for file_path, code in matches:
file_to_code[file_path.strip()] = code.strip()
return file_to_code, explanation
# Fall back to single file extraction
return extract_js_code_and_explanation(content, is_multi_file=False)
# Single file extraction
match = JS_CODE_PATTERN.search(content)
if match:
code = match.group(1).strip()
explanation_end = match.start()
explanation = content[:explanation_end].strip()
return code, explanation
return "", content
def normalize_js_code(code: str) -> str:
"""Normalize JavaScript code for comparison."""
# Remove single-line comments
code = re.sub(r"//.*$", "", code, flags=re.MULTILINE)
# Remove multi-line comments
code = re.sub(r"/\*.*?\*/", "", code, flags=re.DOTALL)
# Normalize whitespace
code = " ".join(code.split())
return code
return extract_code_and_explanation(content, JS_CODE_PATTERN, JS_CODE_WITH_PATH_PATTERN, is_multi_file)
async def optimize_javascript_code_line_profiler_single(
@@ -124,7 +79,7 @@ async def optimize_javascript_code_line_profiler_single(
"""Optimize JavaScript/TypeScript code using LLMs with line profiler guidance."""
lang_name = "TypeScript" if language == "typescript" else "JavaScript"
code_block_tag = "typescript" if language == "typescript" else "javascript"
logging.info(f"/optimize-line-profiler: Optimizing {lang_name} code.")
logging.info("/optimize-line-profiler: Optimizing %s code.", lang_name)
debug_log_sensitive_data(f"Optimizing {lang_name} code for user {user_id}:\n{source_code}")
# Check if source code is multi-file format
@@ -134,7 +89,9 @@ async def optimize_javascript_code_line_profiler_single(
if is_multi_file:
original_file_to_code = split_markdown_code(source_code, language)
logging.info(
f"Multi-file context detected with {len(original_file_to_code)} files: {list(original_file_to_code.keys())}"
"Multi-file context detected with %d files: %s",
len(original_file_to_code),
list(original_file_to_code.keys()),
)
# Format system prompt with language version
@@ -201,7 +158,7 @@ Here is the code to optimize:
]
try:
output = await call_llm(
output = await llm_client.call(
llm=optimize_model,
messages=messages,
call_type="line_profiler",
@@ -210,13 +167,11 @@ Here is the code to optimize:
python_version=language_version,  # Reusing python_version field for language version
context=obs_context,
)
except Exception as e:
logging.exception(f"LLM Code Generation error in {lang_name} line profiler optimizer")
sentry_sdk.capture_exception(e)
except Exception:
debug_log_sensitive_data(f"Failed to generate code for source:\n{source_code}")
return None, None, optimize_model.name
llm_cost = calculate_llm_cost(output.raw_response, optimize_model)
llm_cost = output.cost
debug_log_sensitive_data(f"LLM optimization response:\n{output.raw_response.model_dump_json(indent=2)}")
@@ -260,7 +215,7 @@ Here is the code to optimize:
merged_file_to_code[file_path] = original_code
else:
merged_file_to_code[file_path] = new_code
if normalize_js_code(new_code) != normalize_js_code(original_code):
if normalize_c_style_code(new_code) != normalize_c_style_code(original_code):
has_changes = True
else:
# File not in response, keep original
@@ -296,7 +251,7 @@ Here is the code to optimize:
return None, llm_cost, optimize_model.name
# Check that the code is actually different from the original
if normalize_js_code(optimized_code) == normalize_js_code(source_code):
if normalize_c_style_code(optimized_code) == normalize_c_style_code(source_code):
debug_log_sensitive_data("Generated code identical to original")
return None, llm_cost, optimize_model.name

@@ -12,15 +12,14 @@ from pathlib import Path
from typing import TYPE_CHECKING
import sentry_sdk
import stamina
from ninja.errors import HttpError
from openai import OpenAIError
from openai.types.chat import ChatCompletionMessageParam
from aiservice.analytics.posthog import ph
from aiservice.common_utils import validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.llm import EXECUTE_MODEL, HAIKU_MODEL, OPENAI_MODEL, calculate_llm_cost, call_llm
from aiservice.llm import llm_client
from aiservice.llm_models import EXECUTE_MODEL, HAIKU_MODEL, OPENAI_MODEL
from aiservice.validators.javascript_validator import validate_javascript_syntax, validate_typescript_syntax
from authapp.auth import AuthenticatedRequest
from core.log_features.log_event import update_optimization_cost
@@ -34,7 +33,7 @@ from core.shared.testgen_models import (
)
if TYPE_CHECKING:
from aiservice.llm import LLM
from aiservice.llm_models import LLM
_TEST_FUNC_RE = re.compile(r"(?:test|it)\s*\(\s*['\"]")
@@ -292,7 +291,6 @@ def _has_test_functions(code: str) -> bool:
return _TEST_FUNC_RE.search(code) is not None
@stamina.retry(on=(SyntaxError, ValueError, OpenAIError), attempts=2)
async def generate_and_validate_js_test_code(
messages: list[ChatCompletionMessageParam],
model: LLM,
@@ -329,7 +327,7 @@ async def generate_and_validate_js_test_code(
{"call_sequence": call_sequence, "test_index": test_index} if call_sequence is not None else None
)
response = await call_llm(
response = await llm_client.call(
llm=model,
messages=messages,
call_type="test_generation",
@@ -339,7 +337,7 @@ async def generate_and_validate_js_test_code(
context=obs_context,
)
cost = calculate_llm_cost(response.raw_response, model)
cost = response.cost
cost_tracker.append(cost)
debug_log_sensitive_data(
@@ -358,7 +356,6 @@ async def generate_and_validate_js_test_code(
return validated_code
@stamina.retry(on=TestGenerationFailedError, attempts=2)
async def generate_javascript_tests_from_function(
user_id: str,
function_name: str,
@@ -504,7 +501,7 @@ async def testgen_javascript(
execute_model = HAIKU_MODEL
model_source = "Anthropic"
logging.info(f"Using {model_source} model ({execute_model.name}) for JavaScript test_index {test_index}")
logging.info("Using %s model (%s) for JavaScript test_index %s", model_source, execute_model.name, test_index)
(
generated_test_source,
@@ -554,6 +551,6 @@ async def testgen_javascript(
)
except Exception as e:
logging.exception(f"JavaScript test generation failed. trace_id={data.trace_id}")
logging.exception("JavaScript test generation failed. trace_id=%s", data.trace_id)
sentry_sdk.capture_exception(e)
return 500, TestGenErrorResponseSchema(error="Error generating JavaScript tests. Internal server error.")

@@ -1,7 +1,6 @@
from __future__ import annotations
import asyncio
import logging
from pathlib import Path
from typing import TYPE_CHECKING
@@ -14,7 +13,8 @@ from pydantic import ValidationError
from aiservice.analytics.posthog import ph
from aiservice.common_utils import validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.llm import ADAPTIVE_OPTIMIZE_MODEL, calculate_llm_cost, call_llm
from aiservice.llm import llm_client
from aiservice.llm_models import ADAPTIVE_OPTIMIZE_MODEL
from authapp.auth import AuthenticatedRequest
from core.log_features.log_event import update_optimization_cost
from core.log_features.log_features import log_features
@@ -26,7 +26,7 @@ from .adaptive_optimizer_context import AdaptiveOptContext, AdaptiveOptContextDa
if TYPE_CHECKING:
from openai.types.chat import ChatCompletionMessageParam
from aiservice.llm import LLM
from aiservice.llm_models import LLM
class AdaptiveOptErrorResponseSchema(Schema):
@@ -60,19 +60,16 @@ async def perform_adaptive_optimize(
messages: list[ChatCompletionMessageParam] = [system_message, user_message]
debug_log_sensitive_data(f"This was the user prompt\n {user_prompt}\n")
try:
output = await call_llm(
output = await llm_client.call(
llm=optimize_model, messages=messages, call_type="adaptive_optimize", trace_id=trace_id, user_id=user_id
)
llm_cost = calculate_llm_cost(output.raw_response, optimize_model)
llm_cost = output.cost
except Exception as e:
logging.exception("Claude Code Generation error in adaptive_optimize")
sentry_sdk.capture_exception(e)
debug_log_sensitive_data(f"Failed to generate code for source:\n{ctx.data.original_source_code}")
return None, None, AdaptiveOptErrorResponseSchema(error=str(e))
debug_log_sensitive_data(f"ClaudeClient optimization response:\n{output.content}")
if output.usage is not None:
await asyncio.to_thread(
ph,
ph(
user_id,
"adaptive_optimize-usage",
properties={
@@ -113,7 +110,7 @@ async def perform_adaptive_optimize(
async def adaptive_optimize(
request: AuthenticatedRequest, data: AdaptiveOptRequestSchema
) -> tuple[int, OptimizeResponseItemSchema | AdaptiveOptErrorResponseSchema]:
await asyncio.to_thread(ph, request.user, "aiservice-adaptive_optimize-called")
ph(request.user, "aiservice-adaptive_optimize-called")
ctx_data = AdaptiveOptContextData(
original_source_code=data.original_source_code, attempts=data.candidates, python_version_str="3.12"
)
@@ -133,22 +130,25 @@ async def adaptive_optimize(
if llm_cost is not None:
total_llm_cost += llm_cost
if hasattr(request, "should_log_features") and request.should_log_features:
await log_features(
trace_id=data.trace_id,
user_id=request.user,
optimizations_raw={
adaptive_optimization_candidate.optimization_id: adaptive_optimization_candidate.source_code
},
explanations_raw={
adaptive_optimization_candidate.optimization_id: adaptive_optimization_candidate.explanation
},
optimizations_origin={
adaptive_optimization_candidate.optimization_id: {
"source": OptimizedCandidateSource.ADAPTIVE,
"parent": adaptive_optimization_candidate.parent_id,
}
},
)
await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost, user_id=request.user)
async with asyncio.TaskGroup() as tg:
tg.create_task(update_optimization_cost(trace_id=trace_id, cost=total_llm_cost, user_id=request.user))
if hasattr(request, "should_log_features") and request.should_log_features:
tg.create_task(
log_features(
trace_id=data.trace_id,
user_id=request.user,
optimizations_raw={
adaptive_optimization_candidate.optimization_id: adaptive_optimization_candidate.source_code
},
explanations_raw={
adaptive_optimization_candidate.optimization_id: adaptive_optimization_candidate.explanation
},
optimizations_origin={
adaptive_optimization_candidate.optimization_id: {
"source": OptimizedCandidateSource.ADAPTIVE,
"parent": adaptive_optimization_candidate.parent_id,
}
},
)
)
return 200, adaptive_optimization_candidate

@@ -1,7 +1,6 @@
from __future__ import annotations
import asyncio
import logging
import uuid
from pathlib import Path
from typing import TYPE_CHECKING
@@ -15,7 +14,8 @@ from pydantic import ValidationError
from aiservice.analytics.posthog import ph
from aiservice.common_utils import validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.llm import CODE_REPAIR_MODEL, calculate_llm_cost, call_llm
from aiservice.llm import llm_client
from aiservice.llm_models import CODE_REPAIR_MODEL
from authapp.auth import AuthenticatedRequest
from core.languages.python.code_repair.code_repair_context import (
CodeRepairContext,
@@ -30,7 +30,7 @@ from core.shared.optimizer_schemas import OptimizeResponseItemSchema
if TYPE_CHECKING:
from openai.types.chat import ChatCompletionMessageParam
from aiservice.llm import LLM
from aiservice.llm_models import LLM
code_repair_api = NinjaAPI(urls_namespace="code_repair")
@@ -70,17 +70,14 @@ async def code_repair( # noqa: D417
messages: list[ChatCompletionMessageParam] = [system_message, user_message]
debug_log_sensitive_data(f"This was the user prompt\n {user_prompt}\n")
try:
output = await call_llm(llm=optimize_model, messages=messages)
llm_cost = calculate_llm_cost(output.raw_response, optimize_model)
output = await llm_client.call(llm=optimize_model, messages=messages)
llm_cost = output.cost
except Exception as e:
logging.exception("Claude Code Generation error in code_repair")
sentry_sdk.capture_exception(e)
debug_log_sensitive_data(f"Failed to generate code for source:\n{ctx.data.original_source_code}")
return CodeRepairErrorResponseSchema(error=str(e))
debug_log_sensitive_data(f"ClaudeClient optimization response:\n{output.content}")
if output.usage is not None:
await asyncio.to_thread(
ph,
ph(
user_id,
"code_repair-usage",
properties={
@@ -133,7 +130,7 @@ class CodeRepairIntermediateResponseItemschema(Schema):
async def repair(
request: AuthenticatedRequest, data: CodeRepairRequestSchema
) -> tuple[int, OptimizeResponseItemSchema | CodeRepairErrorResponseSchema]:
await asyncio.to_thread(ph, request.user, "aiservice-code_repair-called")
ph(request.user, "aiservice-code_repair-called")
ctx_data = CodeRepairContextData(
original_source_code=data.original_source_code,
modified_source_code=data.modified_source_code,
@@ -165,22 +162,23 @@ async def repair(
debug_log_sensitive_data(f"Traceback: {exc}")
return 500, CodeRepairErrorResponseSchema(error=str(exc))
if hasattr(request, "should_log_features") and request.should_log_features:
await log_features(
trace_id=data.trace_id,
user_id=request.user,
optimizations_raw={code_repair_data.optimization_id: code_repair_data.source_code},
explanations_raw={code_repair_data.optimization_id: code_repair_data.explanation},
# explanations_post={cei.optimization_id: cei.explanation for cei in optimization_response_items},
# optimizations_post={cei.optimization_id: cei.source_code for cei in optimization_response_items},
optimizations_origin={
code_repair_data.optimization_id: {
"source": OptimizedCandidateSource.REPAIR,
"parent": code_repair_data.parent_id,
}
},
)
await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost, user_id=request.user)
async with asyncio.TaskGroup() as tg:
tg.create_task(update_optimization_cost(trace_id=trace_id, cost=total_llm_cost, user_id=request.user))
if hasattr(request, "should_log_features") and request.should_log_features:
tg.create_task(
log_features(
trace_id=data.trace_id,
user_id=request.user,
optimizations_raw={code_repair_data.optimization_id: code_repair_data.source_code},
explanations_raw={code_repair_data.optimization_id: code_repair_data.explanation},
optimizations_origin={
code_repair_data.optimization_id: {
"source": OptimizedCandidateSource.REPAIR,
"parent": code_repair_data.parent_id,
}
},
)
)
return 200, OptimizeResponseItemSchema(
source_code=code_repair_data.source_code,
optimization_id=code_repair_data.optimization_id,

@@ -11,7 +11,7 @@ from pydantic import ValidationError
from core.languages.python.cst_utils import parse_module_to_cst
from aiservice.common.markdown_utils import split_markdown_code
from core.languages.python.optimizer.diff_patches_utils.seach_and_replace import SearchAndReplaceDiff
from core.languages.python.optimizer.diff_patches_utils.search_and_replace import SearchAndReplaceDiff
from core.languages.python.optimizer.models import CodeAndExplanation
from core.shared.context_helpers import group_code, is_markdown_structure_changed

@@ -1,10 +1,7 @@
from __future__ import annotations
import asyncio
import logging
from pathlib import Path
import sentry_sdk
from ninja import NinjaAPI
from openai.types.chat import (
ChatCompletionMessageParam,
@@ -18,7 +15,8 @@ from aiservice.common.markdown_utils import wrap_code_in_markdown
from aiservice.common.xml_utils import extract_xml_tag
from aiservice.common_utils import validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.llm import EXPLANATIONS_MODEL, LLM, calculate_llm_cost, call_llm
from aiservice.llm import llm_client
from aiservice.llm_models import EXPLANATIONS_MODEL, LLM
from core.languages.python.explanations.models import (
ExplanationsErrorResponseSchema,
ExplanationsResponseSchema,
@@ -91,7 +89,7 @@ async def explain_optimizations(
obs_context["call_sequence"] = data.call_sequence
try:
output = await call_llm(
output = await llm_client.call(
llm=explanations_model,
messages=messages,
call_type="explanation",
@@ -99,17 +97,12 @@ async def explain_optimizations(
user_id=user_id,
context=obs_context,
)
await update_optimization_cost(
trace_id=data.trace_id, cost=calculate_llm_cost(output.raw_response, explanations_model), user_id=user_id
)
await update_optimization_cost(trace_id=data.trace_id, cost=output.cost, user_id=user_id)
except Exception as e:
logging.exception("Failed to generate explanation")
sentry_sdk.capture_exception(e)
return ExplanationsErrorResponseSchema(error=str(e))
debug_log_sensitive_data(f"AIClient optimization response:\n{output.content}")
if output.usage is not None:
await asyncio.to_thread(
ph,
ph(
user_id,
"aiservice-optimize-openai-usage",
properties={
@@ -133,15 +126,15 @@ async def explain(
request,  # noqa: ANN001
data: ExplanationsSchema,
) -> tuple[int, ExplanationsResponseSchema | ExplanationsErrorResponseSchema]:
await asyncio.to_thread(ph, request.user, "aiservice-explain-called")
ph(request.user, "aiservice-explain-called")
if not validate_trace_id(data.trace_id):
return 400, ExplanationsErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")
explanation_response = await explain_optimizations(request.user, data)
if isinstance(explanation_response, ExplanationsErrorResponseSchema):
await asyncio.to_thread(ph, request.user, "Explanation not generated, revert to old explanation")
ph(request.user, "Explanation not generated, revert to old explanation")
debug_log_sensitive_data("No explanation was generated")
return 500, ExplanationsErrorResponseSchema(error="Error generating optimizations. Internal server error.")
await asyncio.to_thread(ph, request.user, "explanation generated", properties={"explanation": explanation_response})
ph(request.user, "explanation generated", properties={"explanation": explanation_response})
# parse xml tag for explanation
explanation = extract_xml_tag(explanation_response.explanation, "explain")
if not explanation:

@@ -15,7 +15,7 @@ from core.languages.python.optimizer.optimizer import optimize_python
from core.languages.python.testgen.generate import testgen_python
if TYPE_CHECKING:
from aiservice.llm import LLM
from aiservice.llm_models import LLM
from authapp.auth import AuthenticatedRequest
from core.languages.python.code_repair.code_repair import (
CodeRepairErrorResponseSchema,

@@ -15,7 +15,8 @@ from pydantic import ValidationError
from aiservice.analytics.posthog import ph
from aiservice.common_utils import parse_python_version, validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable
from aiservice.llm import LLM, OPTIMIZE_MODEL, calculate_llm_cost, call_llm
from aiservice.llm import llm_client
from aiservice.llm_models import LLM, OPTIMIZE_MODEL
from authapp.auth import AuthenticatedRequest
from authapp.user import get_user_by_id
from core.languages.python.optimizer.context_utils.optimizer_context import BaseOptimizerContext
@@ -63,7 +64,7 @@ async def jit_rewrite_python_code_single(
ChatCompletionUserMessageParam(role="user", content=user_prompt),
]
try:
output = await call_llm(
output = await llm_client.call(
llm=jit_rewrite_model,
messages=messages,
call_type="optimization",
@@ -72,16 +73,13 @@ async def jit_rewrite_python_code_single(
python_version=python_version_str,
context=obs_context,
)
except Exception as e:
logging.exception("OpenAI Code Generation error during jit rewrite.")
sentry_sdk.capture_exception(e)
except Exception:
debug_log_sensitive_data(f"Failed to generate code for source:\n{ctx.source_code}")
return None, None, jit_rewrite_model.name
llm_cost = calculate_llm_cost(output.raw_response, jit_rewrite_model)
llm_cost = output.cost
debug_log_sensitive_data(f"OpenAIClient jit rewrite response:\n{output.raw_response.model_dump_json(indent=2)}")
if output.raw_response.usage is not None:
await asyncio.to_thread(
ph,
ph(
user_id,
"aiservice-jit_rewrite-openai-usage",
properties={"model": jit_rewrite_model.name, "usage": output.raw_response.usage.json()},
@@ -199,7 +197,7 @@ async def jit_rewrite(
ctx: BaseOptimizerContext = BaseOptimizerContext.get_dynamic_context(
system_prompt, user_prompt, data.source_code, DiffMethod.NO_DIFF
)
await asyncio.to_thread(ph, request.user, "aiservice-jit-rewrite-called")
ph(request.user, "aiservice-jit-rewrite-called")
try:
python_version = validate_request_data(data, ctx)
except HttpError as e:
@@ -234,7 +232,7 @@ async def jit_rewrite(
data.current_username = str(user.github_username)
if len(jit_rewrite_response_items) == 0:
await asyncio.to_thread(ph, request.user, "aiservice-jit-rewrite-no-optimizations-found")
ph(request.user, "aiservice-jit-rewrite-no-optimizations-found")
debug_log_sensitive_data(f"No optimizations found for source:\n{data.source_code}")
logging.error(
"Could not generate any optimizations (jit_rewrite). trace_id=%s, repo=%s/%s, n_candidates=%d, source_len=%d",
@@ -245,8 +243,7 @@ async def jit_rewrite(
len(data.source_code) if data.source_code else 0,
)
return 500, OptimizeErrorResponseSchema(error="Could not generate any optimizations. Please try again.")
await asyncio.to_thread(
ph,
ph(
request.user,
"aiservice-jit-rewrite-optimizations-found",
properties={"num_optimizations": len(jit_rewrite_response_items)},
@@ -279,7 +276,7 @@ async def jit_rewrite(
optimizations_post={cei.optimization_id: cei.source_code for cei in jit_rewrite_response_items},
explanations_raw={op_id: cei["explanation"] for op_id, cei in code_and_explanations.items()},
explanations_post={cei.optimization_id: cei.explanation for cei in jit_rewrite_response_items},
experiment_metadata=data.experiment_metadata if data.experiment_metadata else None,
experiment_metadata=data.experiment_metadata or None,
optimizations_origin={
cei.optimization_id: {
"source": OptimizedCandidateSource.JIT_REWRITE,
@@ -302,5 +299,5 @@ async def jit_rewrite(
debug_log_sensitive_data(f"Optimization explanation:\n{opt.explanation}")
debug_log_sensitive_data_from_callable(log_response)
await asyncio.to_thread(ph, request.user, "aiservice-jit-rewrite-successful")
ph(request.user, "aiservice-jit-rewrite-successful")
return 200, response

@@ -1,6 +1,5 @@
from __future__ import annotations
import asyncio
import json
import logging
from enum import Enum
@@ -15,14 +14,15 @@ from packaging import version
from aiservice.analytics.posthog import ph
from aiservice.common.markdown_utils import extract_code_block_with_context, wrap_code_in_markdown
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.llm import OPTIMIZATION_REVIEW_MODEL, calculate_llm_cost, call_llm
from aiservice.llm import llm_client
from aiservice.llm_models import OPTIMIZATION_REVIEW_MODEL
from authapp.auth import AuthenticatedRequest
from core.log_features.log_event import update_optimization_cost, update_optimization_features_review
if TYPE_CHECKING:
from openai.types.chat import ChatCompletionMessageParam
from aiservice.llm import LLM
from aiservice.llm_models import LLM
optimization_review_api = NinjaAPI(urls_namespace="optimization_review")
@@ -178,7 +178,7 @@ async def get_optimization_review(
optimization_review_model: LLM = OPTIMIZATION_REVIEW_MODEL,
) -> tuple[int, OptimizationReviewResponseSchema | OptimizationReviewErrorSchema]:
"""Compute optimization review via Claude."""
await asyncio.to_thread(ph, request.user, "aiservice-optimization-review-called")
ph(request.user, "aiservice-optimization-review-called")
try:
messages = _build_optimization_review_messages(data)
@@ -189,7 +189,7 @@ async def get_optimization_review(
if data.call_sequence is not None:
obs_context["call_sequence"] = data.call_sequence
response = await call_llm(
response = await llm_client.call(
llm=optimization_review_model,
messages=messages,
call_type="optimization_review",
@@ -198,7 +198,7 @@ async def get_optimization_review(
context=obs_context,
)
cost = calculate_llm_cost(response.raw_response, optimization_review_model)
cost = response.cost
await update_optimization_cost(data.trace_id, cost, user_id=request.user)
review_text = response.content.strip()
@@ -249,7 +249,7 @@ async def get_optimization_review(
debug_log_sensitive_data(f"Invalid response : {e}")
return 500, OptimizationReviewErrorSchema(error="Invalid response")
else:
await asyncio.to_thread(ph, request.user, "aiservice-optimization-review-successful")
ph(request.user, "aiservice-optimization-review-successful")
return 200, review
else:
return 500, OptimizationReviewErrorSchema(error="Invalid response")
@@ -273,10 +273,10 @@ async def optimization_review(
response_code, output = await get_optimization_review(request, data)
try:
if response_code == 200:
review_event = output.review.value  # ty:ignore[possibly-missing-attribute]
review_explanation = output.review_explanation  # ty:ignore[possibly-missing-attribute]
review_event = output.review.value  # ty:ignore[unresolved-attribute]
review_explanation = output.review_explanation  # ty:ignore[unresolved-attribute]
else:
review_event = output.error  # ty:ignore[possibly-missing-attribute]
review_event = output.error  # ty:ignore[unresolved-attribute]
review_explanation = ""
await update_optimization_features_review(
trace_id=data.trace_id,

@@ -24,7 +24,7 @@ from core.languages.python.optimizer.diff_patches_utils.diff import (
V4A_DIFF_FORMAT_PROMPT,
DiffMethod,
)
from core.languages.python.optimizer.diff_patches_utils.seach_and_replace import SearchAndReplaceDiff
from core.languages.python.optimizer.diff_patches_utils.search_and_replace import SearchAndReplaceDiff
from core.languages.python.optimizer.diff_patches_utils.v4a_diff import V4ADiff
from core.languages.python.optimizer.models import CodeExplanationAndID
from core.languages.python.optimizer.postprocess import optimizations_postprocessing_pipeline

@@ -10,7 +10,7 @@ from core.languages.python.cst_utils import parse_module_to_cst
from aiservice.common.markdown_utils import split_markdown_code, wrap_code_in_markdown
from aiservice.validators.javascript_validator import validate_javascript_syntax, validate_typescript_syntax
from core.languages.python.optimizer.context_utils.context_helpers import is_multi_context
from core.languages.python.optimizer.diff_patches_utils.seach_and_replace import SearchAndReplaceDiff
from core.languages.python.optimizer.diff_patches_utils.search_and_replace import SearchAndReplaceDiff
from core.languages.python.optimizer.models import CodeAndExplanation
from core.shared.context_helpers import group_code, is_markdown_structure_changed

@@ -15,7 +15,8 @@ from pydantic import ValidationError
from aiservice.analytics.posthog import ph
from aiservice.common_utils import parse_python_version, should_hack_for_demo, validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable
from aiservice.llm import LLM, OPTIMIZE_MODEL, calculate_llm_cost, call_llm
from aiservice.llm import llm_client
from aiservice.llm_models import LLM, OPTIMIZE_MODEL
from authapp.user import get_user_by_id
from core.languages.python.optimizer.context_utils.optimizer_context import BaseOptimizerContext
from core.languages.python.optimizer.diff_patches_utils.diff import DiffMethod
@@ -143,7 +144,7 @@ async def generate_optimization_candidate(
ChatCompletionUserMessageParam(role="user", content=user_prompt),
]
try:
output = await call_llm(
output = await llm_client.call(
llm=optimize_model,
messages=messages,
call_type="optimization",
@@ -152,19 +153,16 @@ async def generate_optimization_candidate(
python_version=python_version_str,
context=obs_context,
)
except Exception as e:
logging.exception("LLM code generation error in optimizer (model=%s)", optimize_model.name)
sentry_sdk.capture_exception(e)
except Exception:
debug_log_sensitive_data(f"Failed to generate code for source:\n{ctx.source_code}")
return None, None, optimize_model.name
llm_cost = calculate_llm_cost(output.raw_response, optimize_model)
llm_cost = output.cost
debug_log_sensitive_data(f"LLM optimization response:\n{output.raw_response.model_dump_json(indent=2)}")
if output.raw_response.usage is not None:
await asyncio.to_thread(
ph,
ph(
user_id,
"aiservice-optimize-openai-usage",
properties={"model": optimize_model.name, "usage": output.raw_response.usage.json()},
@@ -289,13 +287,13 @@ async def optimize_python(
ctx: BaseOptimizerContext = BaseOptimizerContext.get_dynamic_context(
system_prompt, user_prompt, data.source_code, DiffMethod.NO_DIFF
)
await asyncio.to_thread(ph, request.user, "aiservice-optimize-called")
ph(request.user, "aiservice-optimize-called")
try:
python_version = validate_request_data(data, ctx)
except HttpError as e:
e.add_note(f"Optimizer request validation error: {e.status_code} {e.message}")
logging.exception(f"Optimizer request validation error: {e.message}. trace_id={data.trace_id}")
logging.exception("Optimizer request validation error: %s. trace_id=%s", e.message, data.trace_id)
sentry_sdk.capture_exception(e)
return e.status_code, OptimizeErrorResponseSchema(error=e.message)
@@ -345,7 +343,7 @@ async def optimize_python(
if data.current_username is None:
user_task = tg.create_task(get_user_by_id(request.user))
except Exception as e:
logging.exception(f"Error during optimization task or user retrieval. trace_id={data.trace_id}")
logging.exception("Error during optimization task or user retrieval. trace_id=%s", data.trace_id)
sentry_sdk.capture_exception(e)
return 500, OptimizeErrorResponseSchema(error="Error generating optimizations. Internal server error.")
@@ -356,7 +354,7 @@ async def optimize_python(
data.current_username = str(user.github_username)
if len(optimization_response_items) == 0:
await asyncio.to_thread(ph, request.user, "aiservice-optimize-no-optimizations-found")
ph(request.user, "aiservice-optimize-no-optimizations-found")
debug_log_sensitive_data(f"No optimizations found for source:\n{data.source_code}")
logging.error(
"Could not generate any optimizations. trace_id=%s, repo=%s/%s, n_candidates=%d, source_len=%d",
@@ -367,8 +365,7 @@ async def optimize_python(
len(data.source_code) if data.source_code else 0,
)
return 500, OptimizeErrorResponseSchema(error="Could not generate any optimizations. Please try again.")
await asyncio.to_thread(
ph,
ph(
request.user,
"aiservice-optimize-optimizations-found",
properties={"num_optimizations": len(optimization_response_items)},
@@ -402,7 +399,7 @@ async def optimize_python(
optimizations_post={cei.optimization_id: cei.source_code for cei in optimization_response_items},
explanations_raw={op_id: cei["explanation"] for op_id, cei in code_and_explanations.items()},
explanations_post={cei.optimization_id: cei.explanation for cei in optimization_response_items},
experiment_metadata=data.experiment_metadata if data.experiment_metadata else None,
experiment_metadata=data.experiment_metadata or None,
optimizations_origin={
cei.optimization_id: {
"source": OptimizedCandidateSource.OPTIMIZE,
@@ -428,5 +425,5 @@ async def optimize_python(
debug_log_sensitive_data(f"Optimization explanation:\n{opt.explanation}")
debug_log_sensitive_data_from_callable(log_response)
await asyncio.to_thread(ph, request.user, "aiservice-optimize-successful")
ph(request.user, "aiservice-optimize-successful")
return 200, response

@@ -5,7 +5,6 @@ import logging
from pathlib import Path
from typing import TYPE_CHECKING
import sentry_sdk
from ninja import NinjaAPI
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
@@ -13,7 +12,8 @@ from aiservice.analytics.posthog import ph
from aiservice.common.markdown_utils import split_markdown_code
from aiservice.common_utils import parse_python_version, should_hack_for_demo_java, validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable
from aiservice.llm import OPTIMIZE_MODEL, calculate_llm_cost, call_llm
from aiservice.llm import llm_client
from aiservice.llm_models import OPTIMIZE_MODEL
from aiservice.validators.javascript_validator import validate_javascript_syntax, validate_typescript_syntax
from core.languages.java.optimizer_lp import hack_for_demo_java_lp, optimize_java_code_line_profiler
from core.languages.js_ts.context_helpers import is_multi_context_js, is_multi_context_ts
@@ -33,7 +33,7 @@ from core.shared.optimizer_schemas import (
if TYPE_CHECKING:
from openai.types.chat import ChatCompletionMessageParam
from aiservice.llm import LLM
from aiservice.llm_models import LLM
optimize_line_profiler_api = NinjaAPI(urls_namespace="optimize-line-profiler")
@@ -79,7 +79,7 @@ async def optimize_python_code_line_profiler_single(
obs_context["call_sequence"] = call_sequence
try:
output = await call_llm(
output = await llm_client.call(
llm=optimize_model,
messages=messages,
call_type="line_profiler",
@@ -88,19 +88,16 @@ async def optimize_python_code_line_profiler_single(
python_version=python_version_str,
context=obs_context,
)
except Exception as e:
logging.exception("OpenAI Code Generation error in optimizer-line-profiler")
sentry_sdk.capture_exception(e)
except Exception:
debug_log_sensitive_data(f"Failed to generate code for source:\n{ctx.source_code}")
return None, None, optimize_model.name
llm_cost = calculate_llm_cost(output.raw_response, optimize_model)
llm_cost = output.cost
debug_log_sensitive_data(f"OpenAIClient optimization response:\n{output.raw_response.model_dump_json(indent=2)}")
if output.raw_response.usage is not None:
await asyncio.to_thread(
ph,
ph(
user_id,
"aiservice-optimize-line-profiler-openai-usage",
properties={"model": optimize_model.name, "usage": output.raw_response.usage.json()},
@@ -191,7 +188,7 @@ async def optimize_python_code_line_profiler(
"/", response={200: OptimizeResponseSchema, 400: OptimizeErrorResponseSchema, 500: OptimizeErrorResponseSchema}
)
async def optimize(request, data: OptimizeSchemaLP) -> tuple[int, OptimizeResponseSchema | OptimizeErrorResponseSchema]:  # noqa: ANN001
await asyncio.to_thread(ph, request.user, "aiservice-optimize-called")
ph(request.user, "aiservice-optimize-called")
ctx: BaseOptimizerContext = BaseOptimizerContext.get_dynamic_context(
SYSTEM_PROMPT, USER_PROMPT, data.source_code, DiffMethod.NO_DIFF
)
@@ -355,11 +352,9 @@ async def optimize(request, data: OptimizeSchemaLP) -> tuple[int, OptimizeRespon
python_version=python_version,
)
# Update total cost
await update_optimization_cost(trace_id=data.trace_id, cost=llm_cost, user_id=request.user)
if len(optimization_response_items) == 0:
await asyncio.to_thread(ph, request.user, "aiservice-optimize-no-optimizations-found")
await update_optimization_cost(trace_id=data.trace_id, cost=llm_cost, user_id=request.user)
ph(request.user, "aiservice-optimize-no-optimizations-found")
debug_log_sensitive_data(f"No optimizations found for source:\n{data.source_code}")
logging.error(
"Could not generate any optimizations (line profiler). trace_id=%s, n_candidates=%d, source_len=%d, has_line_profiler=%s",
@@ -369,34 +364,37 @@ async def optimize(request, data: OptimizeSchemaLP) -> tuple[int, OptimizeRespon
bool(data.line_profiler_results),
)
return 500, OptimizeErrorResponseSchema(error="Error generating optimizations. Internal server error.")
await asyncio.to_thread(
ph,
ph(
request.user,
"aiservice-optimize-optimizations-found",
properties={"num_optimizations": len(optimization_response_items), "language": language},
)
if hasattr(request, "should_log_features") and request.should_log_features:
await log_features(
trace_id=data.trace_id,
user_id=request.user,
original_code=data.source_code,
dependency_code=data.dependency_code,
line_profiler_results=data.line_profiler_results,
optimizations_raw={op_id: cei["code"] for op_id, cei in code_and_explanations.items()},
optimizations_post={cei.optimization_id: cei.source_code for cei in optimization_response_items},
explanations_raw={op_id: cei["explanation"] for op_id, cei in code_and_explanations.items()},
explanations_post={cei.optimization_id: cei.explanation for cei in optimization_response_items},
experiment_metadata=data.experiment_metadata if data.experiment_metadata else None,
optimizations_origin={
cei.optimization_id: {
"source": OptimizedCandidateSource.OPTIMIZE_LP,
"parent": None,
"model": optimization_models.get(cei.optimization_id, "unknown"),
}
for cei in optimization_response_items
},
)
async with asyncio.TaskGroup() as tg:
tg.create_task(update_optimization_cost(trace_id=data.trace_id, cost=llm_cost, user_id=request.user))
if hasattr(request, "should_log_features") and request.should_log_features:
tg.create_task(
log_features(
trace_id=data.trace_id,
user_id=request.user,
original_code=data.source_code,
dependency_code=data.dependency_code,
line_profiler_results=data.line_profiler_results,
optimizations_raw={op_id: cei["code"] for op_id, cei in code_and_explanations.items()},
optimizations_post={cei.optimization_id: cei.source_code for cei in optimization_response_items},
explanations_raw={op_id: cei["explanation"] for op_id, cei in code_and_explanations.items()},
explanations_post={cei.optimization_id: cei.explanation for cei in optimization_response_items},
experiment_metadata=data.experiment_metadata or None,
optimizations_origin={
cei.optimization_id: {
"source": OptimizedCandidateSource.OPTIMIZE_LP,
"parent": None,
"model": optimization_models.get(cei.optimization_id, "unknown"),
}
for cei in optimization_response_items
},
)
)
response = OptimizeResponseSchema(optimizations=optimization_response_items)
@@ -407,5 +405,5 @@ async def optimize(request, data: OptimizeSchemaLP) -> tuple[int, OptimizeRespon
debug_log_sensitive_data(f"Optimization explanation:\n{opt.explanation}")
debug_log_sensitive_data_from_callable(log_response)
await asyncio.to_thread(ph, request.user, "aiservice-optimize-successful")
ph(request.user, "aiservice-optimize-successful")
return 200, response

@@ -1,7 +1,6 @@
from __future__ import annotations

import asyncio
import logging
import uuid
from pathlib import Path
from typing import TYPE_CHECKING

@@ -16,7 +15,8 @@ from aiservice.analytics.posthog import ph
from aiservice.common.xml_utils import extract_xml_tag
from aiservice.common_utils import validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.llm import REFINEMENT_MODEL, calculate_llm_cost, call_llm
from aiservice.llm import llm_client
from aiservice.llm_models import REFINEMENT_MODEL
from authapp.auth import AuthenticatedRequest
from core.languages.python.optimizer.context_utils.refiner_context import BaseRefinerContext, RefinementContextData
from core.log_features.log_event import update_optimization_cost

@@ -27,7 +27,7 @@ from core.shared.optimizer_schemas import OptimizeResponseItemSchema
if TYPE_CHECKING:
from openai.types.chat import ChatCompletionMessageParam

from aiservice.llm import LLM
from aiservice.llm_models import LLM


refinement_api = NinjaAPI(urls_namespace="refinement")

@@ -86,7 +86,7 @@ async def refinement(  # noqa: D417
obs_context["call_sequence"] = call_sequence

try:
output = await call_llm(
output = await llm_client.call(
llm=optimize_model,
messages=messages,
call_type="refinement",

@@ -94,16 +94,13 @@ async def refinement(  # noqa: D417
user_id=user_id,
context=obs_context,
)
llm_cost = calculate_llm_cost(output.raw_response, optimize_model)
llm_cost = output.cost
except Exception as e:
logging.exception("Claude Code Generation error in refinement")
sentry_sdk.capture_exception(e)
debug_log_sensitive_data(f"Failed to generate code for source:\n{ctx.data.original_source_code}")
return OptimizeErrorResponseSchema(error=str(e))
debug_log_sensitive_data(f"ClaudeClient optimization response:\n{output.content}")
if output.usage is not None:
await asyncio.to_thread(
ph,
ph(
user_id,
"refinement-usage",
properties={

@@ -177,38 +174,32 @@ class Refinementschema(Schema):
async def refine(
request: AuthenticatedRequest, data: list[RefinementRequestSchema]
) -> tuple[int, Refinementschema | OptimizeErrorResponseSchema]:
await asyncio.to_thread(ph, request.user, "aiservice-refinement-called")
ctx_data_list = [
RefinementContextData(
original_source_code=opt.original_source_code,
original_line_profiler_results=opt.original_line_profiler_results,
original_code_runtime=opt.original_code_runtime,
optimized_source_code=opt.optimized_source_code,
read_only_dependency_code=opt.read_only_dependency_code,
optimized_line_profiler_results=opt.optimized_line_profiler_results,
optimized_code_runtime=opt.optimized_code_runtime,
speedup=opt.speedup,
optimized_explanation=opt.optimized_explanation,
python_version=opt.python_version,
function_references=opt.function_references,
language=opt.language,
language_version=opt.language_version,
)
for opt in data
]
ctx = BaseRefinerContext.get_dynamic_context(
ctx_data=ctx_data_list[0], base_system_prompt=SYSTEM_PROMPT, base_user_prompt=USER_PROMPT
)
ph(request.user, "aiservice-refinement-called")
trace_id = data[0].trace_id
if not validate_trace_id(trace_id):
return 400, OptimizeErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")
optimized_source_code_and_explanations_futures = []
refinement_coros = []
for i, item in enumerate(data):
if i != 0:
ctx.data = ctx_data_list[i]
# Auto-assign call_sequence if not provided (1-indexed)
ctx_data = RefinementContextData(
original_source_code=item.original_source_code,
original_line_profiler_results=item.original_line_profiler_results,
original_code_runtime=item.original_code_runtime,
optimized_source_code=item.optimized_source_code,
read_only_dependency_code=item.read_only_dependency_code,
optimized_line_profiler_results=item.optimized_line_profiler_results,
optimized_code_runtime=item.optimized_code_runtime,
speedup=item.speedup,
optimized_explanation=item.optimized_explanation,
python_version=item.python_version,
function_references=item.function_references,
language=item.language,
language_version=item.language_version,
)
ctx = BaseRefinerContext.get_dynamic_context(
ctx_data=ctx_data, base_system_prompt=SYSTEM_PROMPT, base_user_prompt=USER_PROMPT
)
sequence = item.call_sequence if item.call_sequence is not None else i + 1
optimized_source_code_and_explanations_futures.append(
refinement_coros.append(
refinement(
user_id=request.user,
optimization_id=item.optimization_id,

@@ -217,7 +208,7 @@ async def refine(
call_sequence=sequence,
)
)
refinement_data = await asyncio.gather(*optimized_source_code_and_explanations_futures)
refinement_data = await asyncio.gather(*refinement_coros)
# simple filtering mechanism, remove empty strings and remove duplicates after removing trailing and leading whitespaces, validate with libcst
filtered_refined_optimizations = []
source_code_set = set()

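The comment above describes a dedup-and-validate pass over the gathered refinement results; the rest of that loop is not shown in this hunk. A minimal sketch under stated assumptions: the names `filtered_refined_optimizations` and `source_code_set` come from the diff, while the item attributes and loop body are guesses, not the project's actual code.

```python
# Sketch of the filtering step described by the comment above (assumptions noted).
import libcst as cst

filtered_refined_optimizations = []
source_code_set: set[str] = set()
for item in refinement_data:  # refinement_data comes from the asyncio.gather above
    code = (item.source_code or "").strip()
    if not code or code in source_code_set:
        continue  # drop empty outputs and exact duplicates
    try:
        cst.parse_module(code)  # keep only code that libcst can parse
    except cst.ParserSyntaxError:
        continue
    source_code_set.add(code)
    filtered_refined_optimizations.append(item)
```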
@@ -2,22 +2,20 @@
from __future__ import annotations

import ast
import asyncio
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any

import libcst as cst
import sentry_sdk
import stamina
from ninja.errors import HttpError
from openai import OpenAIError

from aiservice.analytics.posthog import ph
from aiservice.common.markdown_utils import extract_code_block, split_markdown_code
from aiservice.common_utils import safe_isort, should_hack_for_demo
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.llm import EXECUTE_MODEL, HAIKU_MODEL, OPENAI_MODEL, calculate_llm_cost, call_llm
from aiservice.llm import llm_client
from aiservice.llm_models import EXECUTE_MODEL, HAIKU_MODEL, OPENAI_MODEL
from aiservice.models.functions_to_optimize import FunctionToOptimize
from core.languages.python.cst_utils import any_ellipsis_in_cst, ellipsis_in_cst_not_types, parse_module_to_cst
from core.languages.python.optimizer.context_utils.context_helpers import is_multi_context

@@ -44,7 +42,7 @@ from core.shared.testgen_models import (
if TYPE_CHECKING:
from openai.types.chat import ChatCompletionMessageParam

from aiservice.llm import LLM
from aiservice.llm_models import LLM
from authapp.auth import AuthenticatedRequest
from core.shared.testgen_models import TestGenSchema


@@ -147,7 +145,6 @@ def parse_and_validate_llm_output(
raise


@stamina.retry(on=(SyntaxError, ValueError, OpenAIError, CodeValidationError, LLMOutputParseError), attempts=2)
async def generate_and_validate_test_code(
messages: list[ChatCompletionMessageParam],
model: LLM,
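The `@stamina.retry` decorator removed above previously retried on parse and validation errors at the testgen level; per the PR description, transient-error retries now live on the low-level `call_anthropic`/`call_openai` functions inside the LLM client. A hedged sketch of that centralized pattern follows; the function signature, attempt count, and client construction are illustrative assumptions, not the project's actual values.

```python
# Hedged sketch of the centralized retry assumed to live in aiservice/llm.py.
import stamina
from openai import (
    APIConnectionError as OpenAIConnectionError,
    APITimeoutError as OpenAITimeoutError,
    AsyncAzureOpenAI,
    InternalServerError as OpenAIServerError,
    RateLimitError as OpenAIRateLimitError,
)
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam


@stamina.retry(
    on=(OpenAIRateLimitError, OpenAITimeoutError, OpenAIConnectionError, OpenAIServerError),
    attempts=3,  # illustrative; the real attempt count is set in the client module
)
async def call_openai(messages: list[ChatCompletionMessageParam], model: str) -> ChatCompletion:
    # Endpoint, key, and api_version come from env/settings; the real client is lazily cached.
    client = AsyncAzureOpenAI()
    return await client.chat.completions.create(model=model, messages=messages)
```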
@@ -187,7 +184,7 @@ async def generate_and_validate_test_code(
if call_sequence is not None
else None
)
response = await call_llm(
response = await llm_client.call(
llm=model,
messages=messages,
call_type="test_generation",

@@ -197,14 +194,13 @@ async def generate_and_validate_test_code(
context=obs_context,
)

cost = calculate_llm_cost(response.raw_response, model)
cost = response.cost
cost_tracker.add(cost)

debug_log_sensitive_data(f"LLM {error_context}execute response:\n{response.raw_response.model_dump_json(indent=2)}")

if response.raw_response.usage:
await asyncio.to_thread(
ph,
ph(
user_id,
f"aiservice-testgen-{posthog_event_suffix}execute-openai-usage",
properties={"model": model.name, "usage": response.raw_response.usage.model_dump_json()},

@@ -225,7 +221,6 @@ async def generate_and_validate_test_code(
return validated_code, raw_llm_content


@stamina.retry(on=TestGenerationFailedError, attempts=2)
async def generate_regression_tests_from_function(
source_code: str,
qualified_name: str,

@@ -360,7 +355,7 @@ async def testgen_python(
request: AuthenticatedRequest, data: TestGenSchema
) -> tuple[int, TestGenResponseSchema | TestGenErrorResponseSchema]:
"""Generate Python tests using LLMs."""
await asyncio.to_thread(ph, request.user, "aiservice-testgen-called")
ph(request.user, "aiservice-testgen-called")

try:
python_version = validate_request_data(data)

@@ -408,7 +403,7 @@ async def testgen_python(
model_type=execute_model.model_type,
)

await asyncio.to_thread(ph, request.user, "aiservice-testgen-tests-generated")
ph(request.user, "aiservice-testgen-tests-generated")

if hasattr(request, "should_log_features") and request.should_log_features:
await log_features(

@@ -1,8 +1,53 @@
"""Shared context helpers used across language handlers."""

import re

from aiservice.common.markdown_utils import split_markdown_code

__all__ = ["group_code", "is_markdown_structure_changed", "split_markdown_code"]
__all__ = [
"extract_code_and_explanation",
"group_code",
"is_markdown_structure_changed",
"normalize_c_style_code",
"split_markdown_code",
]


def normalize_c_style_code(code: str) -> str:
"""Normalize C-style code for comparison by stripping comments and whitespace.

Works for JavaScript, TypeScript, and Java.
"""
code = re.sub(r"//.*$", "", code, flags=re.MULTILINE)
code = re.sub(r"/\*.*?\*/", "", code, flags=re.DOTALL)
return " ".join(code.split())

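Since `normalize_c_style_code` is shown in full above, a quick usage sketch of its behavior: two differently formatted snippets of the same JavaScript collapse to one normalized string (the sample snippets are made up for illustration).

```python
# Usage sketch for normalize_c_style_code; sample inputs are illustrative.
a = "function add(a, b) { // sum\n  return a + b;\n}"
b = "/* sum */ function add(a, b) { return a + b; }"
assert normalize_c_style_code(a) == normalize_c_style_code(b) == "function add(a, b) { return a + b; }"
```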
def extract_code_and_explanation(
content: str, code_pattern: re.Pattern[str], code_with_path_pattern: re.Pattern[str], is_multi_file: bool = False
) -> tuple[str | dict[str, str], str]:
"""Extract code and explanation from LLM response.

Works for any language by accepting compiled regex patterns.
"""
if is_multi_file:
matches = code_with_path_pattern.findall(content)
if matches:
file_to_code: dict[str, str] = {}
first_match_pos = content.find("```")
explanation = content[:first_match_pos].strip() if first_match_pos > 0 else ""
for file_path, code in matches:
file_to_code[file_path.strip()] = code.strip()
return file_to_code, explanation
return extract_code_and_explanation(content, code_pattern, code_with_path_pattern, is_multi_file=False)

match = code_pattern.search(content)
if match:
code = match.group(1).strip()
explanation = content[: match.start()].strip()
return code, explanation

return "", content

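A short usage sketch of the single-file path of `extract_code_and_explanation`: the regex patterns below are hypothetical, not the ones the real callers compile, and the sample LLM response is made up.

```python
# Illustration of the single-file extraction path; patterns here are hypothetical.
import re

from core.languages.python.optimizer.context_utils.context_helpers import extract_code_and_explanation

fence = "`" * 3  # avoid literal triple backticks inside this example
response = f"Use a set for O(1) membership checks.\n{fence}python\nseen = set()\n{fence}"
code_pattern = re.compile(rf"{fence}(?:python)?\n(.*?){fence}", re.DOTALL)
path_pattern = re.compile(rf"{fence}\S*\s+(\S+)\n(.*?){fence}", re.DOTALL)  # unused in single-file mode

code, explanation = extract_code_and_explanation(response, code_pattern, path_pattern)
assert code == "seen = set()"
assert explanation == "Use a set for O(1) membership checks."
```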
def is_markdown_structure_changed(old: str, new: str, language: str = "python") -> bool:

@@ -1,6 +1,6 @@
"""Shared optimizer configuration for model distributions."""

from aiservice.llm import ANTHROPIC_MODEL, LLM, OPENAI_MODEL
from aiservice.llm_models import ANTHROPIC_MODEL, LLM, OPENAI_MODEL

MAX_OPTIMIZER_CALLS = 6
MAX_OPTIMIZER_LP_CALLS = 7

@@ -1,20 +1,19 @@
from __future__ import annotations

import asyncio
import json
import logging
import re
from dataclasses import dataclass
from typing import TYPE_CHECKING

import sentry_sdk
from ninja import NinjaAPI, Schema
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam

from aiservice.analytics.posthog import ph
from aiservice.common_utils import validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.llm import LLM, RANKING_MODEL, calculate_llm_cost, call_llm
from aiservice.llm import llm_client
from aiservice.llm_models import LLM, RANKING_MODEL
from authapp.auth import AuthenticatedRequest
from core.log_features.log_event import update_optimization_cost
from core.log_features.log_features import log_features

@@ -214,7 +213,7 @@ def _parse_json_response(content: str, num_candidates: int) -> ParsedRankingResp

# Validate ranking
if sorted(ranking) != list(range(1, num_candidates + 1)):
logging.warning(f"Invalid ranking in JSON response: {ranking}")
logging.warning("Invalid ranking in JSON response: %s", ranking)
return None

# Parse scores

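For clarity, the check above accepts only a 1-indexed permutation of the candidate indices; a tiny illustration with assumed sample values:

```python
# Illustration of the permutation check used in _parse_json_response.
num_candidates = 3
assert sorted([2, 3, 1]) == list(range(1, num_candidates + 1))  # valid ranking
assert sorted([1, 1, 3]) != list(range(1, num_candidates + 1))  # duplicate rank, rejected
assert sorted([1, 2]) != list(range(1, num_candidates + 1))     # missing candidate, rejected
```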
@@ -247,7 +246,7 @@ def _parse_json_response(content: str, num_candidates: int) -> ParsedRankingResp
return ParsedRankingResponse(ranking=ranking, scores=scores, explanation=explanation)

except (KeyError, TypeError, ValueError) as e:
logging.warning(f"Failed to parse JSON response: {e}")
logging.warning("Failed to parse JSON response: %s", e)
return None


@@ -349,7 +348,7 @@ async def rank_optimizations(  # noqa: D417
messages: list[ChatCompletionMessageParam] = [system_message, user_message]

try:
output = await call_llm(
output = await llm_client.call(
llm=rank_model,
messages=messages,
call_type="ranking",

@@ -361,18 +360,14 @@ async def rank_optimizations(  # noqa: D417
"python_version": data.python_version,
},
)
await update_optimization_cost(
trace_id=data.trace_id, cost=calculate_llm_cost(output.raw_response, rank_model), user_id=user_id
)
except Exception as e:
logging.exception("Failed to generate ranking")
sentry_sdk.capture_exception(e)
return RankErrorResponseSchema(error=str(e))
await update_optimization_cost(trace_id=data.trace_id, cost=output.cost, user_id=user_id)
except Exception:
logging.exception("Ranking failed for trace_id=%s", data.trace_id)
return RankErrorResponseSchema(error="Failed to rank optimizations. Please try again.")

debug_log_sensitive_data(f"AIClient optimization response:\n{output}")
if output.raw_response.usage is not None:
await asyncio.to_thread(
ph,
ph(
user_id,
"aiservice-optimize-openai-usage",
properties={"model": rank_model.name, "n": 1, "usage": output.raw_response.usage.model_dump_json()},

@@ -457,15 +452,15 @@ class RankErrorResponseSchema(Schema):
async def rank(
request: AuthenticatedRequest, data: RankInputSchema
) -> tuple[int, RankResponseSchema | RankErrorResponseSchema]:
await asyncio.to_thread(ph, request.user, "aiservice-rank-called")
ph(request.user, "aiservice-rank-called")
if not validate_trace_id(data.trace_id):
return 400, RankErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")
ranking_response = await rank_optimizations(request.user, data)
if isinstance(ranking_response, RankErrorResponseSchema):
await asyncio.to_thread(ph, request.user, "Invalid Ranking, fallback to default")
ph(request.user, "Invalid Ranking, fallback to default")
debug_log_sensitive_data("No valid ranking was generated")
return 500, RankErrorResponseSchema(error="Error generating ranking. Internal server error.")
await asyncio.to_thread(ph, request.user, "ranking generated", properties={"ranking": ranking_response})
ph(request.user, "ranking generated", properties={"ranking": ranking_response})
ranking_0_idx = [x - 1 for x in ranking_response.ranking]
if hasattr(request, "should_log_features") and request.should_log_features:
ranked_opt_ids = [data.optimization_ids[i] for i in ranking_0_idx]

@@ -9,7 +9,6 @@ Currently enabled for Python only.
from __future__ import annotations

import ast
import asyncio
import logging
from pathlib import Path
from typing import Any

@@ -27,7 +26,8 @@ from openai.types.chat import (

from aiservice.analytics.posthog import ph
from aiservice.common.markdown_utils import extract_code_block
from aiservice.llm import HAIKU_MODEL, calculate_llm_cost, call_llm
from aiservice.llm import llm_client
from aiservice.llm_models import HAIKU_MODEL
from authapp.auth import AuthenticatedRequest
from core.shared.testgen_review.models import TestRepairErrorSchema, TestRepairResponseSchema, TestRepairSchema


@@ -76,7 +76,7 @@ async def testgen_repair(
if data.language != "python":
return 400, TestRepairErrorSchema(error="Test repair is only supported for Python")

await asyncio.to_thread(ph, request.user, "aiservice-testgen-repair-called")
ph(request.user, "aiservice-testgen-repair-called")

try:
from core.shared.testgen_review.review import _build_coverage_context  # noqa: PLC0415

@@ -111,7 +111,7 @@ async def testgen_repair(
if data.call_sequence is not None:
obs_context["call_sequence"] = data.call_sequence

response = await call_llm(
response = await llm_client.call(
llm=HAIKU_MODEL,
messages=messages,
call_type="testgen_repair",

@@ -120,7 +120,7 @@ async def testgen_repair(
context=obs_context,
)

cost = calculate_llm_cost(response.raw_response, HAIKU_MODEL)
cost = response.cost
logging.debug("testgen_repair LLM cost: %s", cost)

from core.languages.python.testgen.instrumentation.edit_generated_test import (  # noqa: PLC0415

@@ -145,7 +145,7 @@ async def testgen_repair(
"Please return the complete corrected file in a single ```python code block.",
)
)
retry_response = await call_llm(
retry_response = await llm_client.call(
llm=HAIKU_MODEL,
messages=messages,
call_type="testgen_repair_retry",

@@ -153,7 +153,7 @@ async def testgen_repair(
user_id=request.user,
context=obs_context,
)
cost += calculate_llm_cost(retry_response.raw_response, HAIKU_MODEL)
cost += retry_response.cost
repaired_code, repaired_cst = _extract_and_validate(retry_response.content.strip())

if repaired_cst is None:

@@ -209,7 +209,7 @@ async def testgen_repair(
if instrumented_behavior is None or instrumented_perf is None:
return 500, TestRepairErrorSchema(error="Failed to instrument repaired tests")

await asyncio.to_thread(ph, request.user, "aiservice-testgen-repair-completed")
ph(request.user, "aiservice-testgen-repair-completed")
return 200, TestRepairResponseSchema(
generated_tests=display_code,
instrumented_behavior_tests=instrumented_behavior,

@@ -26,7 +26,8 @@ from openai.types.chat import (

from aiservice.analytics.posthog import ph
from aiservice.common.markdown_utils import extract_code_block_with_context
from aiservice.llm import HAIKU_MODEL, calculate_llm_cost, call_llm
from aiservice.llm import llm_client
from aiservice.llm_models import HAIKU_MODEL
from authapp.auth import AuthenticatedRequest
from core.shared.testgen_review.models import (
CoverageDetails,

@@ -53,7 +54,7 @@ async def testgen_review(
if data.language != "python":
return 200, TestgenReviewResponseSchema(reviews=[])

await asyncio.to_thread(ph, request.user, "aiservice-testgen-review-called")
ph(request.user, "aiservice-testgen-review-called")

try:
coverage_context = (

@@ -81,9 +82,7 @@ async def testgen_review(
]
reviews = [task.result() for task in tasks]

await asyncio.to_thread(
ph, request.user, "aiservice-testgen-review-completed", properties={"trace_id": data.trace_id}
)
ph(request.user, "aiservice-testgen-review-completed", properties={"trace_id": data.trace_id})
return 200, TestgenReviewResponseSchema(reviews=reviews)

except Exception as e:

@@ -149,7 +148,7 @@ async def _review_single_test(
if call_sequence is not None:
obs_context["call_sequence"] = call_sequence

response = await call_llm(
response = await llm_client.call(
llm=HAIKU_MODEL,
messages=messages,
call_type="testgen_review",

@@ -158,8 +157,8 @@ async def _review_single_test(
context=obs_context,
)

cost = calculate_llm_cost(response.raw_response, HAIKU_MODEL)
logging.debug(f"testgen_review LLM cost: {cost}")
cost = response.cost
logging.debug("testgen_review LLM cost: %s", cost)

ai_verdicts = _parse_review_response(response.content.strip())
return TestReview(test_index=test_index, functions=failed_verdicts + ai_verdicts)

@@ -1,6 +1,5 @@
from __future__ import annotations

import asyncio
import json
import logging
import re

@@ -13,7 +12,8 @@ from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUs

from aiservice.analytics.posthog import ph
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.llm import EXECUTE_MODEL, call_llm
from aiservice.llm import llm_client
from aiservice.llm_models import EXECUTE_MODEL
from authapp.auth import AuthenticatedRequest

_steps_section_re = re.compile(r"steps:\s*\n((?:[^\n]+\n)*)", re.MULTILINE)
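The `_steps_section_re` pattern above, presumably used by `_extract_yaml_steps`, captures the contiguous block of lines that immediately follows a `steps:` key and stops at the first blank line. A small illustration; the sample YAML is made up.

```python
# Illustration of what _steps_section_re captures; the workflow snippet is hypothetical.
import re

_steps_section_re = re.compile(r"steps:\s*\n((?:[^\n]+\n)*)", re.MULTILINE)

sample = (
    "jobs:\n"
    "  optimize:\n"
    "    steps:\n"
    "      - uses: actions/checkout@v4\n"
    "      - run: pip install codeflash\n"
    "\n"
    "  other-job:\n"
)
match = _steps_section_re.search(sample)
assert match is not None
# The capture ends at the first blank line after the steps block.
assert match.group(1) == "      - uses: actions/checkout@v4\n      - run: pip install codeflash\n"
```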
@@ -93,7 +93,7 @@ async def generate_workflow_steps_llm(
debug_log_sensitive_data(f"Generating workflow steps with prompt length: {len(user_prompt)}")

try:
response = await call_llm(
response = await llm_client.call(
llm=EXECUTE_MODEL,
messages=[system_message, user_message],
call_type="workflow_generation",

@@ -111,15 +111,13 @@ async def generate_workflow_steps_llm(
# Extract YAML steps
steps_yaml = _extract_yaml_steps(response_text)
if steps_yaml:
logger.info(f"Successfully generated workflow steps ({len(steps_yaml)} chars)")
logger.info("Successfully generated workflow steps (%d chars)", len(steps_yaml))
return steps_yaml

logger.warning(f"Could not extract valid YAML steps from LLM response: {response_text[:200]}")
logger.warning("Could not extract valid YAML steps from LLM response: %s", response_text[:200])
return None

except Exception as e:
logger.error(f"Error generating workflow steps: {e}")
sentry_sdk.capture_exception(e)
except Exception:
return None


@@ -146,7 +144,7 @@ async def generate_workflow_steps(
request: AuthenticatedRequest, data: WorkflowGenInputSchema
) -> tuple[int, WorkflowGenResponseSchema | WorkflowGenErrorResponseSchema]:
"""Generate GitHub Actions workflow steps based on repository analysis."""
await asyncio.to_thread(ph, request.user, "aiservice-workflow-gen-called")
ph(request.user, "aiservice-workflow-gen-called")

try:
# Validate input

@@ -170,9 +168,7 @@ async def generate_workflow_steps(
error="Failed to generate workflow steps. Please try again or use the static template."
)

await asyncio.to_thread(
ph, request.user, "aiservice-workflow-gen-success", properties={"steps_length": len(workflow_steps)}
)
ph(request.user, "aiservice-workflow-gen-success", properties={"steps_length": len(workflow_steps)})
return 200, WorkflowGenResponseSchema(workflow_steps=workflow_steps)

except Exception:

@@ -1,4 +1,4 @@
from core.languages.python.optimizer.diff_patches_utils.seach_and_replace import apply_patches
from core.languages.python.optimizer.diff_patches_utils.search_and_replace import apply_patches


def test_patches() -> None:

File diff suppressed because it is too large