Add LLM layer: client abstraction, cost calculation, retry policy
Dual-provider client (Azure OpenAI + Anthropic Bedrock) behind a common async interface with cache-aware cost calculation and event-loop-safe client lifecycle.
parent d20b82762a
commit fcaac3a9f2
5 changed files with 866 additions and 0 deletions
packages/codeflash-api/src/codeflash_api/llm/_client.py (new file, +254)
@@ -0,0 +1,254 @@
from __future__ import annotations

import asyncio
import logging
import os
import time
from typing import TYPE_CHECKING, Any

import attrs

from codeflash_api.llm._cost import calculate_llm_cost
from codeflash_api.llm._retry import (
    ANTHROPIC_MAX_INPUT_TOKENS,
    CHARS_PER_TOKEN_ESTIMATE,
)

if TYPE_CHECKING:
    from anthropic import AsyncAnthropicBedrock
    from openai import AsyncAzureOpenAI

    from codeflash_api.llm._models import LLM

log = logging.getLogger(__name__)


class LLMOutputUnparseableError(Exception):
    """Raised when LLM output cannot be parsed."""

    def __init__(
        self,
        message: str = "LLM output could not be parsed",
        *,
        cost: float = 0.0,
    ) -> None:
        super().__init__(message)
        self.cost = cost


@attrs.frozen
class LLMUsage:
    """Token counts from an LLM call."""

    input_tokens: int
    output_tokens: int


@attrs.frozen
class LLMResponse:
    """Response from an LLM call."""

    content: str
    usage: LLMUsage
    raw_response: Any
    cost: float = 0.0


@attrs.define
class LLMClient:
    """Async LLM client supporting OpenAI and Anthropic."""

    _openai_client: AsyncAzureOpenAI | None = None
    _anthropic_client: AsyncAnthropicBedrock | None = None
    _client_loop: asyncio.AbstractEventLoop | None = None
    _background_tasks: set[asyncio.Task[Any]] = attrs.Factory(set)

    def _ensure_clients(self) -> None:
        """Recreate clients if the event loop changed."""
        loop = asyncio.get_running_loop()
        if loop is self._client_loop:
            return

        self._client_loop = loop
        self._background_tasks = set()

        if os.environ.get("AZURE_OPENAI_API_KEY"):
            from openai import AsyncAzureOpenAI

            self._openai_client = AsyncAzureOpenAI()
        else:
            self._openai_client = None

        if os.environ.get("AWS_ACCESS_KEY_ID") and os.environ.get(
            "AWS_SECRET_ACCESS_KEY"
        ):
            from anthropic import AsyncAnthropicBedrock

            self._anthropic_client = AsyncAnthropicBedrock(
                aws_access_key=os.environ.get("AWS_ACCESS_KEY_ID", ""),
                aws_secret_key=os.environ.get("AWS_SECRET_ACCESS_KEY", ""),
                aws_region=os.environ.get("AWS_REGION", "us-east-1"),
            )
        else:
            self._anthropic_client = None

    async def call(
        self,
        llm: LLM,
        messages: list[dict[str, Any]],
        *,
        max_tokens: int = 16384,
        call_type: str = "",
        trace_id: str = "",
    ) -> LLMResponse:
        """Route to the correct provider and return an LLMResponse."""
        self._ensure_clients()
        start = time.monotonic()

        try:
            if llm.model_type == "anthropic":
                result = await self._call_anthropic(llm, messages, max_tokens)
            elif llm.model_type == "openai":
                result = await self._call_openai(llm, messages, max_tokens)
            else:
                msg = f"Unsupported model type: {llm.model_type}"
                raise ValueError(msg)

            result = LLMResponse(
                content=result.content,
                usage=result.usage,
                raw_response=result.raw_response,
                cost=calculate_llm_cost(result.raw_response, llm),
            )

        except Exception:
            latency_ms = int((time.monotonic() - start) * 1000)
            log.exception(
                "LLM call failed: type=%s trace=%s latency=%dms",
                call_type,
                trace_id,
                latency_ms,
            )
            raise

        return result

    async def _call_anthropic(
        self,
        llm: LLM,
        messages: list[dict[str, Any]],
        max_tokens: int,
    ) -> LLMResponse:
        """Call Anthropic via Bedrock."""
        if self._anthropic_client is None:
            msg = "Anthropic client not configured"
            raise RuntimeError(msg)

        estimated_tokens = (
            sum(len(str(m.get("content", ""))) for m in messages)
            // CHARS_PER_TOKEN_ESTIMATE
        )

        if estimated_tokens > ANTHROPIC_MAX_INPUT_TOKENS:
            msg = (
                f"Prompt too large (~{estimated_tokens} tokens estimated,"
                f" limit {ANTHROPIC_MAX_INPUT_TOKENS})"
            )
            raise ValueError(msg)

        system_prompt = next(
            (m["content"] for m in messages if m["role"] == "system"),
            None,
        )
        non_system = [
            {"role": m["role"], "content": m["content"]}
            for m in messages
            if m["role"] != "system"
        ]

        kwargs: dict[str, Any] = {
            "model": llm.name,
            "messages": non_system,
            "max_tokens": max_tokens,
        }
        if system_prompt:
            kwargs["system"] = system_prompt

        response = await self._anthropic_client.messages.create(**kwargs)
        content = "".join(
            block.text for block in response.content if hasattr(block, "text")
        )
        if not content:
            cost = calculate_llm_cost(response, llm)
            raise LLMOutputUnparseableError(cost=cost)

        return LLMResponse(
            content=content,
            usage=LLMUsage(
                input_tokens=response.usage.input_tokens,
                output_tokens=response.usage.output_tokens,
            ),
            raw_response=response,
        )

    async def _call_openai(
        self,
        llm: LLM,
        messages: list[dict[str, Any]],
        max_tokens: int,
    ) -> LLMResponse:
        """Call Azure OpenAI."""
        if self._openai_client is None:
            msg = "OpenAI client not configured"
            raise RuntimeError(msg)

        if llm.name == "gpt-5-mini":
            response = await self._openai_client.chat.completions.create(
                model=llm.name,
                messages=messages,
                max_completion_tokens=max_tokens,
            )
        else:
            response = await self._openai_client.chat.completions.create(
                model=llm.name,
                messages=messages,
                max_tokens=max_tokens,
            )

        content = (
            response.choices[0].message.content if response.choices else None
        )
        if not content:
            cost = calculate_llm_cost(response, llm)
            raise LLMOutputUnparseableError(cost=cost)

        return LLMResponse(
            content=content,
            usage=LLMUsage(
                input_tokens=(
                    response.usage.prompt_tokens if response.usage else 0
                ),
                output_tokens=(
                    response.usage.completion_tokens if response.usage else 0
                ),
            ),
            raw_response=response,
        )
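Usage sketch (not part of this commit): a minimal caller for the client above, assuming the relevant provider credentials are present in the environment. The entrypoint and message content are illustrative only; the call() signature and response fields come from _client.py. Recreating the SDK clients whenever the running event loop changes appears to be what lets a single LLMClient be reused safely across separate asyncio.run() invocations.

import asyncio

from codeflash_api.llm._client import LLMClient
from codeflash_api.llm._models import ANTHROPIC_CLAUDE_SONNET_4_5


async def main() -> None:
    client = LLMClient()  # provider clients are created lazily, per event loop
    response = await client.call(
        ANTHROPIC_CLAUDE_SONNET_4_5,
        [
            {"role": "system", "content": "You are terse."},
            {"role": "user", "content": "Say hello."},
        ],
        max_tokens=256,
        call_type="demo",       # free-form label, used only in failure logs
        trace_id="local-test",  # correlates log lines for this call
    )
    print(response.content)
    print(f"cost: ${response.cost:.6f}")


if __name__ == "__main__":
    # Requires AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY (and optionally
    # AWS_REGION); otherwise _call_anthropic raises RuntimeError.
    asyncio.run(main())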
packages/codeflash-api/src/codeflash_api/llm/_cost.py (new file, +45)
@@ -0,0 +1,45 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from codeflash_api.llm._models import LLM


def calculate_llm_cost(
    response: Any,
    llm: LLM,
) -> float:
    """Calculate cost in USD from a provider response."""
    usage = getattr(response, "usage", None)
    if usage is None:
        return 0.0

    input_rate = llm.input_cost or 0.0
    cached_rate = (
        llm.cached_input_cost
        if llm.cached_input_cost is not None
        else input_rate
    )
    output_rate = llm.output_cost or 0.0

    if llm.model_type == "anthropic":
        cache_read = getattr(usage, "cache_read_input_tokens", 0) or 0
        cache_creation = getattr(usage, "cache_creation_input_tokens", 0) or 0
        non_cached = usage.input_tokens + cache_creation
        output = usage.output_tokens
    else:
        details = getattr(usage, "prompt_tokens_details", None)
        cache_read = (
            (getattr(details, "cached_tokens", 0) or 0) if details else 0
        )
        non_cached = usage.prompt_tokens - cache_read
        output = usage.completion_tokens

    return (
        non_cached * input_rate
        + cache_read * cached_rate
        + output * output_rate
    ) / 1_000_000
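Worked example (illustrative, mirroring the cached-token test in test_llm.py below): with GPT-4.1 rates of $2.00 / $0.50 / $8.00 per million input / cached-input / output tokens, a call with 1,000 prompt tokens of which 400 were cache hits, plus 200 completion tokens, costs (600 * 2.00 + 400 * 0.50 + 200 * 8.00) / 1_000_000 = $0.003. The SimpleNamespace objects below are stand-ins shaped like the provider usage payloads, not real SDK responses.

from types import SimpleNamespace

from codeflash_api.llm._cost import calculate_llm_cost
from codeflash_api.llm._models import OPENAI_GPT_4_1

usage = SimpleNamespace(
    prompt_tokens=1_000,
    completion_tokens=200,
    prompt_tokens_details=SimpleNamespace(cached_tokens=400),
)
cost = calculate_llm_cost(SimpleNamespace(usage=usage), OPENAI_GPT_4_1)
# (600 * 2.00 + 400 * 0.50 + 200 * 8.00) / 1_000_000 == 0.003
assert abs(cost - 0.003) < 1e-12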
packages/codeflash-api/src/codeflash_api/llm/_models.py (new file, +52)
@@ -0,0 +1,52 @@
from __future__ import annotations

import logging
from typing import Literal

import attrs

log = logging.getLogger(__name__)


@attrs.frozen
class LLM:
    """A model definition with cost rates (USD per million tokens)."""

    name: str
    model_type: Literal["openai", "anthropic"]
    input_cost: float | None = None
    cached_input_cost: float | None = None
    output_cost: float | None = None


OPENAI_GPT_4_1 = LLM(
    name="gpt-4.1",
    model_type="openai",
    input_cost=2.00,
    cached_input_cost=0.50,
    output_cost=8.00,
)

OPENAI_GPT_5_MINI = LLM(
    name="gpt-5-mini",
    model_type="openai",
    input_cost=0.25,
    cached_input_cost=0.03,
    output_cost=2.00,
)

ANTHROPIC_CLAUDE_SONNET_4_5 = LLM(
    name="us.anthropic.claude-sonnet-4-5-20250929-v1:0",
    model_type="anthropic",
    input_cost=3.00,
    output_cost=15.00,
)

ANTHROPIC_CLAUDE_HAIKU_4_5 = LLM(
    name="us.anthropic.claude-haiku-4-5-20251001-v1:0",
    model_type="anthropic",
    input_cost=1.00,
    output_cost=5.00,
)
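One plausible way for application code to choose among these constants is a small alias table; the alias strings and the resolve_model helper below are hypothetical, not part of this commit.

from codeflash_api.llm._models import (
    ANTHROPIC_CLAUDE_HAIKU_4_5,
    ANTHROPIC_CLAUDE_SONNET_4_5,
    LLM,
    OPENAI_GPT_4_1,
    OPENAI_GPT_5_MINI,
)

# Hypothetical alias registry; names chosen for illustration only.
MODEL_ALIASES: dict[str, LLM] = {
    "gpt-4.1": OPENAI_GPT_4_1,
    "gpt-5-mini": OPENAI_GPT_5_MINI,
    "sonnet-4.5": ANTHROPIC_CLAUDE_SONNET_4_5,
    "haiku-4.5": ANTHROPIC_CLAUDE_HAIKU_4_5,
}


def resolve_model(alias: str) -> LLM:
    """Look up a model by alias, raising KeyError for unknown names."""
    return MODEL_ALIASES[alias]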
packages/codeflash-api/src/codeflash_api/llm/_retry.py (new file, +40)
@@ -0,0 +1,40 @@
from __future__ import annotations

from anthropic import (
    APIConnectionError as AnthropicConnectionError,
)
from anthropic import (
    APITimeoutError as AnthropicTimeoutError,
)
from anthropic import (
    InternalServerError as AnthropicServerError,
)
from anthropic import (
    RateLimitError as AnthropicRateLimitError,
)
from openai import (
    APIConnectionError as OpenAIConnectionError,
)
from openai import (
    APITimeoutError as OpenAITimeoutError,
)
from openai import (
    InternalServerError as OpenAIServerError,
)
from openai import (
    RateLimitError as OpenAIRateLimitError,
)

TRANSIENT_LLM_ERRORS: tuple[type[Exception], ...] = (
    AnthropicConnectionError,
    AnthropicTimeoutError,
    AnthropicServerError,
    AnthropicRateLimitError,
    OpenAIConnectionError,
    OpenAITimeoutError,
    OpenAIServerError,
    OpenAIRateLimitError,
)

ANTHROPIC_MAX_INPUT_TOKENS = 195_000
CHARS_PER_TOKEN_ESTIMATE = 4
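The commit message mentions a retry policy, but _retry.py only exports the error classification and the token-budget constants; how TRANSIENT_LLM_ERRORS is consumed is not shown in this diff. A plausible sketch using tenacity (an assumed dependency, not introduced here) would retry only the transient provider errors with jittered exponential backoff:

from typing import Any

from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)

from codeflash_api.llm._client import LLMClient, LLMResponse
from codeflash_api.llm._models import LLM
from codeflash_api.llm._retry import TRANSIENT_LLM_ERRORS


# Hypothetical wrapper: retry transient provider errors only, back off with
# jitter, give up after five attempts, and re-raise the last error.
@retry(
    retry=retry_if_exception_type(TRANSIENT_LLM_ERRORS),
    wait=wait_random_exponential(multiplier=1, max=30),
    stop=stop_after_attempt(5),
    reraise=True,
)
async def call_with_retry(
    client: LLMClient,
    llm: LLM,
    messages: list[dict[str, Any]],
    **kwargs: Any,
) -> LLMResponse:
    return await client.call(llm, messages, **kwargs)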
packages/codeflash-api/tests/test_llm.py (new file, +475)
@@ -0,0 +1,475 @@
from __future__ import annotations

from types import SimpleNamespace
from unittest.mock import AsyncMock, patch

import pytest

from codeflash_api.llm._client import (
    LLMClient,
    LLMOutputUnparseableError,
    LLMResponse,
    LLMUsage,
)
from codeflash_api.llm._cost import calculate_llm_cost
from codeflash_api.llm._models import (
    ANTHROPIC_CLAUDE_HAIKU_4_5,
    ANTHROPIC_CLAUDE_SONNET_4_5,
    LLM,
    OPENAI_GPT_4_1,
    OPENAI_GPT_5_MINI,
)
from codeflash_api.llm._retry import (
    ANTHROPIC_MAX_INPUT_TOKENS,
    CHARS_PER_TOKEN_ESTIMATE,
    TRANSIENT_LLM_ERRORS,
)


class TestLLMModel:
    """Tests for LLM model definitions."""

    def test_frozen(self) -> None:
        """LLM instances are immutable."""
        model = LLM(name="test", model_type="openai")
        with pytest.raises(AttributeError):
            model.name = "changed"

    def test_cost_defaults(self) -> None:
        """Cost rates default to None."""
        model = LLM(name="test", model_type="openai")
        assert model.input_cost is None
        assert model.cached_input_cost is None
        assert model.output_cost is None

    def test_openai_gpt_4_1_rates(self) -> None:
        """GPT-4.1 has correct cost rates."""
        assert 2.00 == OPENAI_GPT_4_1.input_cost
        assert 0.50 == OPENAI_GPT_4_1.cached_input_cost
        assert 8.00 == OPENAI_GPT_4_1.output_cost

    def test_openai_gpt_5_mini_rates(self) -> None:
        """GPT-5-mini has correct cost rates."""
        assert 0.25 == OPENAI_GPT_5_MINI.input_cost
        assert 0.03 == OPENAI_GPT_5_MINI.cached_input_cost
        assert 2.00 == OPENAI_GPT_5_MINI.output_cost

    def test_anthropic_sonnet_rates(self) -> None:
        """Claude Sonnet 4.5 has correct cost rates."""
        assert 3.00 == ANTHROPIC_CLAUDE_SONNET_4_5.input_cost
        assert ANTHROPIC_CLAUDE_SONNET_4_5.cached_input_cost is None
        assert 15.00 == ANTHROPIC_CLAUDE_SONNET_4_5.output_cost

    def test_anthropic_haiku_rates(self) -> None:
        """Claude Haiku 4.5 has correct cost rates."""
        assert 1.00 == ANTHROPIC_CLAUDE_HAIKU_4_5.input_cost
        assert 5.00 == ANTHROPIC_CLAUDE_HAIKU_4_5.output_cost

    def test_all_models_have_names(self) -> None:
        """Every predefined model has a non-empty name."""
        for model in (
            OPENAI_GPT_4_1,
            OPENAI_GPT_5_MINI,
            ANTHROPIC_CLAUDE_SONNET_4_5,
            ANTHROPIC_CLAUDE_HAIKU_4_5,
        ):
            assert model.name


class TestCostCalculation:
    """Tests for calculate_llm_cost."""

    def test_no_usage_returns_zero(self) -> None:
        """Response without usage attribute returns 0.0."""
        assert 0.0 == calculate_llm_cost(object(), OPENAI_GPT_4_1)

    def test_openai_basic(self) -> None:
        """OpenAI cost calculated from prompt/completion tokens."""
        usage = SimpleNamespace(
            prompt_tokens=1000,
            completion_tokens=500,
            prompt_tokens_details=None,
        )
        response = SimpleNamespace(usage=usage)
        cost = calculate_llm_cost(response, OPENAI_GPT_4_1)
        expected = (1000 * 2.00 + 500 * 8.00) / 1_000_000
        assert pytest.approx(expected) == cost

    def test_openai_with_cache(self) -> None:
        """OpenAI cached tokens use cached_input_cost rate."""
        details = SimpleNamespace(cached_tokens=400)
        usage = SimpleNamespace(
            prompt_tokens=1000,
            completion_tokens=200,
            prompt_tokens_details=details,
        )
        response = SimpleNamespace(usage=usage)
        cost = calculate_llm_cost(response, OPENAI_GPT_4_1)
        expected = (600 * 2.00 + 400 * 0.50 + 200 * 8.00) / 1_000_000
        assert pytest.approx(expected) == cost

    def test_anthropic_basic(self) -> None:
        """Anthropic cost uses input_tokens + output_tokens."""
        usage = SimpleNamespace(
            input_tokens=1000,
            output_tokens=500,
            cache_read_input_tokens=0,
            cache_creation_input_tokens=0,
        )
        response = SimpleNamespace(usage=usage)
        cost = calculate_llm_cost(response, ANTHROPIC_CLAUDE_SONNET_4_5)
        expected = (1000 * 3.00 + 500 * 15.00) / 1_000_000
        assert pytest.approx(expected) == cost

    def test_anthropic_with_cache_creation(self) -> None:
        """Cache creation tokens count toward input cost."""
        usage = SimpleNamespace(
            input_tokens=800,
            output_tokens=200,
            cache_read_input_tokens=100,
            cache_creation_input_tokens=300,
        )
        response = SimpleNamespace(usage=usage)
        cost = calculate_llm_cost(response, ANTHROPIC_CLAUDE_SONNET_4_5)
        non_cached = 800 + 300
        expected = (non_cached * 3.00 + 100 * 3.00 + 200 * 15.00) / 1_000_000
        assert pytest.approx(expected) == cost

    def test_model_without_costs(self) -> None:
        """Model with None costs treats them as 0."""
        model = LLM(name="free", model_type="openai")
        usage = SimpleNamespace(
            prompt_tokens=1000,
            completion_tokens=500,
            prompt_tokens_details=None,
        )
        response = SimpleNamespace(usage=usage)
        assert 0.0 == calculate_llm_cost(response, model)


class TestRetryConstants:
    """Tests for retry module constants."""

    def test_transient_errors_count(self) -> None:
        """Eight transient error types are defined."""
        assert 8 == len(TRANSIENT_LLM_ERRORS)

    def test_all_are_exception_subclasses(self) -> None:
        """Every transient error is an Exception subclass."""
        for exc_type in TRANSIENT_LLM_ERRORS:
            assert issubclass(exc_type, Exception)

    def test_constants(self) -> None:
        """Token limit and char estimate are set."""
        assert 195_000 == ANTHROPIC_MAX_INPUT_TOKENS
        assert 4 == CHARS_PER_TOKEN_ESTIMATE


class TestLLMUsage:
    """Tests for LLMUsage."""

    def test_frozen(self) -> None:
        """LLMUsage is immutable."""
        usage = LLMUsage(input_tokens=10, output_tokens=20)
        with pytest.raises(AttributeError):
            usage.input_tokens = 99


class TestLLMResponse:
    """Tests for LLMResponse."""

    def test_cost_defaults_to_zero(self) -> None:
        """Cost defaults to 0.0."""
        resp = LLMResponse(
            content="hello",
            usage=LLMUsage(input_tokens=1, output_tokens=1),
            raw_response=None,
        )
        assert 0.0 == resp.cost

    def test_frozen(self) -> None:
        """LLMResponse is immutable."""
        resp = LLMResponse(
            content="hello",
            usage=LLMUsage(input_tokens=1, output_tokens=1),
            raw_response=None,
        )
        with pytest.raises(AttributeError):
            resp.content = "changed"


class TestLLMOutputUnparseableError:
    """Tests for LLMOutputUnparseableError."""

    def test_default_message(self) -> None:
        """Default message is set."""
        err = LLMOutputUnparseableError()
        assert "could not be parsed" in str(err)
        assert 0.0 == err.cost

    def test_custom_cost(self) -> None:
        """Cost is preserved on the exception."""
        err = LLMOutputUnparseableError(cost=1.23)
        assert 1.23 == err.cost


class TestLLMClient:
    """Tests for LLMClient."""

    async def test_unsupported_model_type(self) -> None:
        """Unknown model_type raises ValueError."""
        client = LLMClient()
        bad_model = LLM.__new__(LLM)
        object.__setattr__(bad_model, "name", "bad")
        object.__setattr__(bad_model, "model_type", "unknown")
        object.__setattr__(bad_model, "input_cost", None)
        object.__setattr__(bad_model, "cached_input_cost", None)
        object.__setattr__(bad_model, "output_cost", None)

        with (
            patch.object(LLMClient, "_ensure_clients"),
            pytest.raises(ValueError, match="Unsupported"),
        ):
            await client.call(bad_model, [])

    async def test_anthropic_not_configured(self) -> None:
        """Calling Anthropic without configuration raises RuntimeError."""
        client = LLMClient()
        with pytest.raises(RuntimeError, match="not configured"):
            await client._call_anthropic(ANTHROPIC_CLAUDE_SONNET_4_5, [], 1024)

    async def test_openai_not_configured(self) -> None:
        """Calling OpenAI without configuration raises RuntimeError."""
        client = LLMClient()
        with pytest.raises(RuntimeError, match="not configured"):
            await client._call_openai(OPENAI_GPT_4_1, [], 1024)

    async def test_anthropic_prompt_too_large(self) -> None:
        """Oversized prompt raises ValueError."""
        client = LLMClient()
        client._anthropic_client = AsyncMock()
        huge = "x" * (
            ANTHROPIC_MAX_INPUT_TOKENS * CHARS_PER_TOKEN_ESTIMATE + 100
        )
        messages = [{"role": "user", "content": huge}]

        with pytest.raises(ValueError, match="too large"):
            await client._call_anthropic(
                ANTHROPIC_CLAUDE_SONNET_4_5, messages, 1024
            )

    async def test_anthropic_empty_response(self) -> None:
        """Empty Anthropic response raises LLMOutputUnparseableError."""
        client = LLMClient()
        mock_client = AsyncMock()
        response = SimpleNamespace(
            content=[],
            usage=SimpleNamespace(
                input_tokens=10,
                output_tokens=0,
                cache_read_input_tokens=0,
                cache_creation_input_tokens=0,
            ),
        )
        mock_client.messages.create = AsyncMock(return_value=response)
        client._anthropic_client = mock_client

        with pytest.raises(LLMOutputUnparseableError):
            await client._call_anthropic(
                ANTHROPIC_CLAUDE_SONNET_4_5,
                [{"role": "user", "content": "hi"}],
                1024,
            )

    async def test_anthropic_success(self) -> None:
        """Successful Anthropic call returns LLMResponse."""
        client = LLMClient()
        mock_client = AsyncMock()
        block = SimpleNamespace(text="hello world")
        response = SimpleNamespace(
            content=[block],
            usage=SimpleNamespace(
                input_tokens=10,
                output_tokens=5,
                cache_read_input_tokens=0,
                cache_creation_input_tokens=0,
            ),
        )
        mock_client.messages.create = AsyncMock(return_value=response)
        client._anthropic_client = mock_client

        result = await client._call_anthropic(
            ANTHROPIC_CLAUDE_SONNET_4_5,
            [{"role": "user", "content": "hi"}],
            1024,
        )
        assert "hello world" == result.content
        assert 10 == result.usage.input_tokens
        assert 5 == result.usage.output_tokens

    async def test_anthropic_separates_system_prompt(self) -> None:
        """System messages are extracted into the system kwarg."""
        client = LLMClient()
        mock_client = AsyncMock()
        block = SimpleNamespace(text="ok")
        response = SimpleNamespace(
            content=[block],
            usage=SimpleNamespace(
                input_tokens=10,
                output_tokens=5,
                cache_read_input_tokens=0,
                cache_creation_input_tokens=0,
            ),
        )
        mock_client.messages.create = AsyncMock(return_value=response)
        client._anthropic_client = mock_client

        await client._call_anthropic(
            ANTHROPIC_CLAUDE_SONNET_4_5,
            [
                {"role": "system", "content": "be helpful"},
                {"role": "user", "content": "hi"},
            ],
            1024,
        )

        call_kwargs = mock_client.messages.create.call_args[1]
        assert "be helpful" == call_kwargs["system"]
        assert all(m["role"] != "system" for m in call_kwargs["messages"])

    async def test_openai_empty_response(self) -> None:
        """Empty OpenAI response raises LLMOutputUnparseableError."""
        client = LLMClient()
        mock_client = AsyncMock()
        response = SimpleNamespace(
            choices=[],
            usage=SimpleNamespace(
                prompt_tokens=10,
                completion_tokens=0,
                prompt_tokens_details=None,
            ),
        )
        mock_client.chat.completions.create = AsyncMock(return_value=response)
        client._openai_client = mock_client

        with pytest.raises(LLMOutputUnparseableError):
            await client._call_openai(
                OPENAI_GPT_4_1,
                [{"role": "user", "content": "hi"}],
                1024,
            )

    async def test_openai_success(self) -> None:
        """Successful OpenAI call returns LLMResponse."""
        client = LLMClient()
        mock_client = AsyncMock()
        response = SimpleNamespace(
            choices=[
                SimpleNamespace(message=SimpleNamespace(content="hello"))
            ],
            usage=SimpleNamespace(
                prompt_tokens=10,
                completion_tokens=5,
                prompt_tokens_details=None,
            ),
        )
        mock_client.chat.completions.create = AsyncMock(return_value=response)
        client._openai_client = mock_client

        result = await client._call_openai(
            OPENAI_GPT_4_1,
            [{"role": "user", "content": "hi"}],
            1024,
        )
        assert "hello" == result.content
        assert 10 == result.usage.input_tokens
        assert 5 == result.usage.output_tokens

    async def test_openai_gpt5_mini_uses_max_completion_tokens(
        self,
    ) -> None:
        """GPT-5-mini uses max_completion_tokens instead of max_tokens."""
        client = LLMClient()
        mock_client = AsyncMock()
        response = SimpleNamespace(
            choices=[SimpleNamespace(message=SimpleNamespace(content="ok"))],
            usage=SimpleNamespace(
                prompt_tokens=10,
                completion_tokens=5,
                prompt_tokens_details=None,
            ),
        )
        mock_client.chat.completions.create = AsyncMock(return_value=response)
        client._openai_client = mock_client

        await client._call_openai(
            OPENAI_GPT_5_MINI,
            [{"role": "user", "content": "hi"}],
            2048,
        )

        call_kwargs = mock_client.chat.completions.create.call_args[1]
        assert 2048 == call_kwargs["max_completion_tokens"]
        assert "max_tokens" not in call_kwargs
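Note on running the suite: the async test methods carry no asyncio markers, which suggests an async pytest plugin such as pytest-asyncio running in auto mode (asyncio_mode = "auto"); that configuration is not part of this diff, so treat it as an assumption. The tests themselves never touch the network: provider SDK clients are replaced with AsyncMock and responses with SimpleNamespace stand-ins, so pytest packages/codeflash-api/tests/test_llm.py should run offline.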