"""LLM client setup and API call functions."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import time
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
import sentry_sdk
|
|
import stamina
|
|
from anthropic import APIConnectionError as AnthropicConnectionError
|
|
from anthropic import APITimeoutError as AnthropicTimeoutError
|
|
from anthropic import AsyncAnthropicBedrock
|
|
from anthropic import InternalServerError as AnthropicServerError
|
|
from anthropic import RateLimitError as AnthropicRateLimitError
|
|
from openai import APIConnectionError as OpenAIConnectionError
|
|
from openai import APITimeoutError as OpenAITimeoutError
|
|
from openai import AsyncAzureOpenAI
|
|
from openai import InternalServerError as OpenAIServerError
|
|
from openai import RateLimitError as OpenAIRateLimitError
|
|
|
|
from aiservice.llm_models import has_anthropic, has_openai
|
|
|
|
if TYPE_CHECKING:
|
|
from anthropic.types import Message as AnthropicMessage
|
|
from openai.types.chat import ChatCompletion, ChatCompletionMessageParam
|
|
|
|
from aiservice.llm_models import LLM
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
_ANTHROPIC_MAX_INPUT_TOKENS = 195_000
|
|
_CHARS_PER_TOKEN_ESTIMATE = 4
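# Both values above are heuristics: ~4 characters per token is a common rule of
# thumb for English text, and 195_000 leaves headroom under the 200k-token
# context window of recent Claude models; neither is an exact tokenizer count.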

_TRANSIENT_LLM_ERRORS = (
    AnthropicConnectionError,
    AnthropicTimeoutError,
    AnthropicServerError,
    AnthropicRateLimitError,
    OpenAIConnectionError,
    OpenAITimeoutError,
    OpenAIServerError,
    OpenAIRateLimitError,
)
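# stamina retries these transient failures with exponential backoff (see the
# @stamina.retry decorators on call_anthropic/call_openai below; attempts=2
# means at most one retry after the initial failure).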


class LLMOutputUnparseable(Exception):
    """Raised when the LLM responds but its output cannot be parsed into the expected format."""

    def __init__(self, message: str = "LLM output could not be parsed", *, cost: float = 0.0) -> None:
        super().__init__(message)
        self.cost = cost


@dataclass
class LLMUsage:
    input_tokens: int
    output_tokens: int


@dataclass
class LLMResponse:
    content: str
    usage: LLMUsage
    raw_response: ChatCompletion | AnthropicMessage
    cost: float = 0.0


class LLMClient:
    def __init__(self) -> None:
        self.openai_client: AsyncAzureOpenAI | None = None
        self.anthropic_client: AsyncAnthropicBedrock | None = None
        self.client_loop: asyncio.AbstractEventLoop | None = None
        self.background_tasks: set[asyncio.Task[Any]] = set()

    async def call(
        self,
        llm: LLM,
        messages: list[ChatCompletionMessageParam],
        call_type: str = "",
        trace_id: str = "",
        max_tokens: int = 16384,
        user_id: str | None = None,
        python_version: str | None = None,
        context: dict[str, Any] | None = None,
    ) -> LLMResponse:
        from aiservice.observability.database import record_llm_call  # noqa: PLC0415

        # Recreate provider clients when the event loop changes (stale connections)
        loop = asyncio.get_running_loop()
        if loop is not self.client_loop:
            # Close old clients to prevent connection leaks and event loop closure errors
            # Ignore errors if the client is already closed or the transport is in a bad state
            if self.openai_client is not None:
                try:
                    await self.openai_client.close()
                except Exception as e:
                    logger.debug(
                        "Failed to close OpenAI client (already closed or transport error): %s", type(e).__name__
                    )
            if self.anthropic_client is not None:
                try:
                    await self.anthropic_client.close()
                except Exception as e:
                    logger.debug(
                        "Failed to close Anthropic client (already closed or transport error): %s", type(e).__name__
                    )

            self.client_loop = loop
            self.background_tasks = set()
            self.openai_client = AsyncAzureOpenAI() if has_openai else None
            self.anthropic_client = (
                AsyncAnthropicBedrock(
                    aws_access_key=os.environ.get("AWS_ACCESS_KEY_ID", ""),
                    aws_secret_key=os.environ.get("AWS_SECRET_ACCESS_KEY", ""),
                    aws_region=os.environ.get("AWS_REGION", "us-east-1"),
                )
                if has_anthropic
                else None
            )
        start_time = time.time()
        error: Exception | None = None
        result: LLMResponse | None = None

        try:
            if llm.model_type == "anthropic":
                if self.anthropic_client is None:
                    raise ValueError("Anthropic client is not available")
                result = await self.call_anthropic(llm, messages, max_tokens)
            elif llm.model_type == "openai":
                if self.openai_client is None:
                    raise ValueError("OpenAI client is not available")
                result = await self.call_openai(llm, messages, max_tokens)
            else:
                msg = f"Unsupported model type: {llm.model_type}"
                raise ValueError(msg)
            result.cost = calculate_llm_cost(result.raw_response, llm)
            return result

        except Exception as e:
            error = e
            logger.exception(
                "LLM call failed: model=%s, provider=%s, call_type=%s, trace_id=%s, error=%s",
                llm.name,
                llm.model_type,
                call_type,
                trace_id,
                type(e).__name__,
            )
            sentry_sdk.capture_exception(e)
            raise

        finally:
            latency_ms = int((time.time() - start_time) * 1000)
            task = asyncio.create_task(
                record_llm_call(
                    trace_id=trace_id,
                    call_type=call_type,
                    model_name=llm.name,
                    messages=messages,
                    user_id=user_id,
                    python_version=python_version,
                    context=context,
                    result=result,
                    error=error,
                    llm_cost=result.cost if result else None,
                    latency_ms=latency_ms,
                )
            )
            # Keep a strong reference: the event loop holds only weak references
            # to tasks, so an unreferenced task could be garbage-collected mid-flight
            self.background_tasks.add(task)

            def on_done(t: asyncio.Task[Any]) -> None:
                self.background_tasks.discard(t)
                # Task.exception() raises CancelledError on a cancelled task, so guard first
                if not t.cancelled() and (exc := t.exception()):
                    logger.warning("Tracing: Failed to record LLM call: %s", exc)

            task.add_done_callback(on_done)

    @stamina.retry(on=_TRANSIENT_LLM_ERRORS, attempts=2)
    async def call_anthropic(
        self, llm: LLM, messages: list[ChatCompletionMessageParam], max_tokens: int
    ) -> LLMResponse:
        # Cheap pre-flight size check; ~4 chars/token is only an estimate
        estimated_tokens = sum(len(str(m["content"])) for m in messages) // _CHARS_PER_TOKEN_ESTIMATE
        if estimated_tokens > _ANTHROPIC_MAX_INPUT_TOKENS:
            msg = f"Prompt too large (~{estimated_tokens} tokens estimated, limit {_ANTHROPIC_MAX_INPUT_TOKENS})"
            raise ValueError(msg)

        # Anthropic's Messages API takes the system prompt as a top-level `system`
        # parameter rather than a message with role "system"
        system_prompt = next((m["content"] for m in messages if m["role"] == "system"), None)
        non_system = [{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"]

        kwargs: dict[str, Any] = {"model": llm.name, "messages": non_system, "max_tokens": max_tokens}
        if system_prompt:
            kwargs["system"] = system_prompt

        response = await self.anthropic_client.messages.create(**kwargs)  # type: ignore[union-attr]
        # Concatenate text blocks only; non-text blocks (e.g. tool use) have no .text
        content = "".join(block.text for block in response.content if hasattr(block, "text"))

        return LLMResponse(
            content=content,
            usage=LLMUsage(input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens),
            raw_response=response,
        )

    @stamina.retry(on=_TRANSIENT_LLM_ERRORS, attempts=2)
    async def call_openai(self, llm: LLM, messages: list[ChatCompletionMessageParam], max_tokens: int) -> LLMResponse:
        # gpt-5-mini only accepts max_completion_tokens, older models use max_tokens
        if llm.name == "gpt-5-mini":
            response = await self.openai_client.chat.completions.create(  # type: ignore[union-attr]
                model=llm.name, messages=messages, max_completion_tokens=max_tokens
            )
        else:
            response = await self.openai_client.chat.completions.create(  # type: ignore[union-attr]
                model=llm.name, messages=messages, max_tokens=max_tokens
            )

        return LLMResponse(
            content=response.choices[0].message.content or "",
            usage=LLMUsage(
                input_tokens=response.usage.prompt_tokens if response.usage else 0,
                output_tokens=response.usage.completion_tokens if response.usage else 0,
            ),
            raw_response=response,
        )


def calculate_llm_cost(response: ChatCompletion | AnthropicMessage, llm: LLM) -> float:
    if response.usage is None:
        return 0.0

    usage = response.usage

    # OpenAI: prompt_tokens is total (cached is subset), Anthropic: counts are additive
    if llm.model_type == "anthropic":
        cache_read = getattr(usage, "cache_read_input_tokens", 0) or 0
        non_cached = usage.input_tokens + (getattr(usage, "cache_creation_input_tokens", 0) or 0)  # type: ignore[union-attr]
        output = usage.output_tokens  # type: ignore[union-attr]
    else:
        details = getattr(usage, "prompt_tokens_details", None)
        cache_read = (getattr(details, "cached_tokens", 0) or 0) if details else 0
        non_cached = usage.prompt_tokens - cache_read  # type: ignore[union-attr]
        output = usage.completion_tokens  # type: ignore[union-attr]

    input_rate = llm.input_cost or 0.0
    cached_rate = llm.cached_input_cost if llm.cached_input_cost is not None else input_rate
    output_rate = llm.output_cost or 0.0

    return (non_cached * input_rate + cache_read * cached_rate + output * output_rate) / 1_000_000
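
# Worked example for calculate_llm_cost (illustrative rates, not real pricing):
# with input_cost=3.0, cached_input_cost=0.3, output_cost=15.0 (USD per million
# tokens), a call with 10_000 non-cached input tokens, 90_000 cached reads, and
# 2_000 output tokens costs
# (10_000 * 3.0 + 90_000 * 0.3 + 2_000 * 15.0) / 1_000_000 = 0.087 USD.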


llm_client = LLMClient()
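
# Minimal usage sketch (hypothetical call site: `llm` stands in for an LLM
# instance from aiservice.llm_models, and provider credentials are assumed to
# be configured in the environment):
#
#     async def summarize(text: str) -> str:
#         response = await llm_client.call(
#             llm,
#             messages=[{"role": "user", "content": f"Summarize:\n{text}"}],
#             call_type="summarize",
#             trace_id="example-trace",
#         )
#         return response.content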