# codeflash-internal/django/aiservice/aiservice/llm.py

"""LLM client setup and API call functions."""
from __future__ import annotations
import asyncio
import logging
import os
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import sentry_sdk
import stamina
from anthropic import APIConnectionError as AnthropicConnectionError
from anthropic import APITimeoutError as AnthropicTimeoutError
from anthropic import AsyncAnthropicBedrock
from anthropic import InternalServerError as AnthropicServerError
from anthropic import RateLimitError as AnthropicRateLimitError
from openai import APIConnectionError as OpenAIConnectionError
from openai import APITimeoutError as OpenAITimeoutError
from openai import AsyncAzureOpenAI
from openai import InternalServerError as OpenAIServerError
from openai import RateLimitError as OpenAIRateLimitError
from aiservice.llm_models import has_anthropic, has_openai
if TYPE_CHECKING:
from anthropic.types import Message as AnthropicMessage
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
from aiservice.llm_models import LLM
logger = logging.getLogger(__name__)
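
# Prompt-size guardrails: stay below Claude's 200K-token context window, and
# estimate tokens from character count before sending anything to the API.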
_ANTHROPIC_MAX_INPUT_TOKENS = 195_000
_CHARS_PER_TOKEN_ESTIMATE = 4
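
# Provider errors that are safe to retry; used by the stamina.retry decorators below.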
_TRANSIENT_LLM_ERRORS = (
    AnthropicConnectionError,
    AnthropicTimeoutError,
    AnthropicServerError,
    AnthropicRateLimitError,
    OpenAIConnectionError,
    OpenAITimeoutError,
    OpenAIServerError,
    OpenAIRateLimitError,
)


class LLMOutputUnparseable(Exception):
    """Raised when the LLM responds but its output cannot be parsed into the expected format."""

    def __init__(self, message: str = "LLM output could not be parsed", *, cost: float = 0.0) -> None:
        super().__init__(message)
        self.cost = cost


@dataclass
class LLMUsage:
    """Token counts reported by the provider for a single call."""

    input_tokens: int
    output_tokens: int


@dataclass
class LLMResponse:
    """Normalized LLM result: text content, usage, the raw provider response, and the computed cost."""

    content: str
    usage: LLMUsage
    raw_response: ChatCompletion | AnthropicMessage
    cost: float = 0.0


class LLMClient:
    """Async provider clients (Azure OpenAI, Anthropic on Bedrock) that are rebuilt when the event loop changes."""

    def __init__(self) -> None:
        self.openai_client: AsyncAzureOpenAI | None = None
        self.anthropic_client: AsyncAnthropicBedrock | None = None
        self.client_loop: asyncio.AbstractEventLoop | None = None
        self.background_tasks: set[asyncio.Task[Any]] = set()

    async def call(
        self,
        llm: LLM,
        messages: list[ChatCompletionMessageParam],
        call_type: str = "",
        trace_id: str = "",
        max_tokens: int = 16384,
        user_id: str | None = None,
        python_version: str | None = None,
        context: dict[str, Any] | None = None,
    ) -> LLMResponse:
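        """Route the request to the model's provider, attach the computed cost, and record the call.

        Telemetry is written via a fire-and-forget background task in the finally
        block, so a recording failure never fails the caller; provider errors are
        logged, sent to Sentry, and re-raised.
        """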
        from aiservice.observability.database import record_llm_call  # noqa: PLC0415

        # Recreate provider clients when the event loop changes (stale connections)
        loop = asyncio.get_running_loop()
        if loop is not self.client_loop:
            # Close old clients to prevent connection leaks and event loop closure errors.
            # Ignore errors if the client is already closed or the transport is in a bad state.
            if self.openai_client is not None:
                try:
                    await self.openai_client.close()
                except Exception as e:
                    logger.debug(
                        "Failed to close OpenAI client (already closed or transport error): %s", type(e).__name__
                    )
            if self.anthropic_client is not None:
                try:
                    await self.anthropic_client.close()
                except Exception as e:
                    logger.debug(
                        "Failed to close Anthropic client (already closed or transport error): %s", type(e).__name__
                    )
            self.client_loop = loop
            self.background_tasks = set()
            self.openai_client = AsyncAzureOpenAI() if has_openai else None
            self.anthropic_client = (
                AsyncAnthropicBedrock(
                    aws_access_key=os.environ.get("AWS_ACCESS_KEY_ID", ""),
                    aws_secret_key=os.environ.get("AWS_SECRET_ACCESS_KEY", ""),
                    aws_region=os.environ.get("AWS_REGION", "us-east-1"),
                )
                if has_anthropic
                else None
            )

        start_time = time.time()
        error: Exception | None = None
        result: LLMResponse | None = None
        try:
            if llm.model_type == "anthropic":
                if self.anthropic_client is None:
                    msg = "Anthropic client is not available"
                    raise ValueError(msg)
                result = await self.call_anthropic(llm, messages, max_tokens)
            elif llm.model_type == "openai":
                if self.openai_client is None:
                    msg = "OpenAI client is not available"
                    raise ValueError(msg)
                result = await self.call_openai(llm, messages, max_tokens)
            else:
                msg = f"Unsupported model type: {llm.model_type}"
                raise ValueError(msg)
            result.cost = calculate_llm_cost(result.raw_response, llm)
            return result
        except Exception as e:
            error = e
            logger.exception(
                "LLM call failed: model=%s, provider=%s, call_type=%s, trace_id=%s, error=%s",
                llm.name,
                llm.model_type,
                call_type,
                trace_id,
                type(e).__name__,
            )
            sentry_sdk.capture_exception(e)
            raise
        finally:
            latency_ms = int((time.time() - start_time) * 1000)
            task = asyncio.create_task(
                record_llm_call(
                    trace_id=trace_id,
                    call_type=call_type,
                    model_name=llm.name,
                    messages=messages,
                    user_id=user_id,
                    python_version=python_version,
                    context=context,
                    result=result,
                    error=error,
                    llm_cost=result.cost if result else None,
                    latency_ms=latency_ms,
                )
            )
            self.background_tasks.add(task)

            # Keep a strong reference until the task completes so it isn't garbage-collected mid-flight.
            def on_done(t: asyncio.Task[Any]) -> None:
                self.background_tasks.discard(t)
                # Calling .exception() on a cancelled task raises CancelledError, so guard first.
                if not t.cancelled() and (exc := t.exception()):
                    logger.warning("Tracing: Failed to record LLM call: %s", exc)

            task.add_done_callback(on_done)

    @stamina.retry(on=_TRANSIENT_LLM_ERRORS, attempts=2)
    async def call_anthropic(
        self, llm: LLM, messages: list[ChatCompletionMessageParam], max_tokens: int
    ) -> LLMResponse:
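        """Call Claude on Bedrock, lifting any system message into the top-level `system` parameter."""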
        estimated_tokens = sum(len(str(m["content"])) for m in messages) // _CHARS_PER_TOKEN_ESTIMATE
        if estimated_tokens > _ANTHROPIC_MAX_INPUT_TOKENS:
            msg = f"Prompt too large (~{estimated_tokens} tokens estimated, limit {_ANTHROPIC_MAX_INPUT_TOKENS})"
            raise ValueError(msg)
        system_prompt = next((m["content"] for m in messages if m["role"] == "system"), None)
        non_system = [{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"]
        kwargs: dict[str, Any] = {"model": llm.name, "messages": non_system, "max_tokens": max_tokens}
        if system_prompt:
            kwargs["system"] = system_prompt
        response = await self.anthropic_client.messages.create(**kwargs)  # type: ignore[union-attr]
        content = "".join(block.text for block in response.content if hasattr(block, "text"))
        return LLMResponse(
            content=content,
            usage=LLMUsage(input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens),
            raw_response=response,
        )

    @stamina.retry(on=_TRANSIENT_LLM_ERRORS, attempts=2)
    async def call_openai(self, llm: LLM, messages: list[ChatCompletionMessageParam], max_tokens: int) -> LLMResponse:
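        """Call Azure OpenAI chat completions, choosing the token-limit parameter the model accepts."""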
        # gpt-5-mini only accepts max_completion_tokens; older models use max_tokens
        if llm.name == "gpt-5-mini":
            response = await self.openai_client.chat.completions.create(  # type: ignore[union-attr]
                model=llm.name, messages=messages, max_completion_tokens=max_tokens
            )
        else:
            response = await self.openai_client.chat.completions.create(  # type: ignore[union-attr]
                model=llm.name, messages=messages, max_tokens=max_tokens
            )
        return LLMResponse(
            content=response.choices[0].message.content or "",
            usage=LLMUsage(
                input_tokens=response.usage.prompt_tokens if response.usage else 0,
                output_tokens=response.usage.completion_tokens if response.usage else 0,
            ),
            raw_response=response,
        )


def calculate_llm_cost(response: ChatCompletion | AnthropicMessage, llm: LLM) -> float:
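    """Price a response from the per-million-token rates on the LLM definition.

    Cached input tokens are billed at the cached rate when one is configured.
    Illustrative arithmetic (rates here are made up, not real pricing): 10,000
    non-cached input tokens at $3/M, 90,000 cached at $0.30/M, and 2,000 output
    tokens at $15/M give 0.03 + 0.027 + 0.03 = $0.087.
    """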
    if response.usage is None:
        return 0.0
    usage = response.usage
    # OpenAI reports prompt_tokens as a total (cached tokens are a subset); Anthropic's counts are additive
    if llm.model_type == "anthropic":
        cache_read = getattr(usage, "cache_read_input_tokens", 0) or 0
        non_cached = usage.input_tokens + (getattr(usage, "cache_creation_input_tokens", 0) or 0)  # type: ignore[union-attr]
        output = usage.output_tokens  # type: ignore[union-attr]
    else:
        details = getattr(usage, "prompt_tokens_details", None)
        cache_read = (getattr(details, "cached_tokens", 0) or 0) if details else 0
        non_cached = usage.prompt_tokens - cache_read  # type: ignore[union-attr]
        output = usage.completion_tokens  # type: ignore[union-attr]
    input_rate = llm.input_cost or 0.0
    cached_rate = llm.cached_input_cost if llm.cached_input_cost is not None else input_rate
    output_rate = llm.output_cost or 0.0
    return (non_cached * input_rate + cache_read * cached_rate + output * output_rate) / 1_000_000


# Module-level shared client instance
llm_client = LLMClient()
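
# Illustrative usage from an async caller (the `llm` and `messages` values and the
# call_type string below are assumptions for the sketch, not app code):
#
#     response = await llm_client.call(llm, messages, call_type="explain", trace_id="abc123")
#     logger.info("got %d chars for $%.4f", len(response.content), response.cost)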