reduce abstraction
This commit is contained in:
parent
238ed71576
commit
2c46242165
8 changed files with 237 additions and 202 deletions
|
|
@ -2,7 +2,10 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
|
|
@ -14,6 +17,8 @@ if TYPE_CHECKING:
|
|||
from anthropic.types import Message as AnthropicMessage
|
||||
from openai.types.chat import ChatCompletion
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Model Definitions
|
||||
|
|
@ -129,48 +134,203 @@ async def call_llm(
|
|||
messages: list[dict[str, Any]],
|
||||
max_tokens: int = 8192,
|
||||
temperature: float | None = None,
|
||||
*,
|
||||
observe_as: str | None = None,
|
||||
trace_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
python_version: str | None = None,
|
||||
context: dict | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Call LLM with OpenAI or Anthropic client."""
|
||||
"""Call LLM with OpenAI or Anthropic client.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model to use.
|
||||
model_type: Type of model ("openai", "anthropic", "google").
|
||||
messages: List of message dicts with "role" and "content".
|
||||
max_tokens: Maximum tokens in response.
|
||||
temperature: Sampling temperature.
|
||||
observe_as: If set, enables observability recording with this call type
|
||||
(e.g., "optimization", "test_generation").
|
||||
trace_id: Trace ID for observability.
|
||||
user_id: User ID for observability.
|
||||
python_version: Python version for observability context.
|
||||
context: Additional context for observability.
|
||||
|
||||
"""
|
||||
# Set up observability if requested
|
||||
llm_call_id = None
|
||||
llm_recorder = None
|
||||
error_recorder = None
|
||||
|
||||
if observe_as and trace_id:
|
||||
from aiservice.observability.database import ErrorRecorder, LLMCallRecorder
|
||||
|
||||
llm_recorder = LLMCallRecorder()
|
||||
error_recorder = ErrorRecorder()
|
||||
|
||||
# Extract prompts from messages for recording
|
||||
system_prompt = next((m["content"] for m in messages if m["role"] == "system"), "")
|
||||
user_prompt = next((m["content"] for m in messages if m["role"] == "user"), "")
|
||||
|
||||
try:
|
||||
llm_call_id = await llm_recorder.record_llm_call_start(
|
||||
trace_id=trace_id,
|
||||
call_type=observe_as,
|
||||
model_name=model_name,
|
||||
system_prompt=system_prompt,
|
||||
user_prompt=user_prompt,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
n_candidates=1,
|
||||
user_id=user_id,
|
||||
python_version=python_version,
|
||||
context=context,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Observability: Failed to record LLM start ({observe_as}): {e}")
|
||||
|
||||
client = llm_clients[model_type]
|
||||
if client is None:
|
||||
msg = f"LLM client for model type '{model_type}' is not available"
|
||||
raise ValueError(msg)
|
||||
|
||||
if model_type == "anthropic":
|
||||
assert isinstance(client, AsyncAnthropicFoundry)
|
||||
system_prompt = next((m["content"] for m in messages if m["role"] == "system"), None)
|
||||
anthropic_messages = [{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"]
|
||||
start_time = time.time()
|
||||
try:
|
||||
if model_type == "anthropic":
|
||||
assert isinstance(client, AsyncAnthropicFoundry)
|
||||
system_prompt_content = next((m["content"] for m in messages if m["role"] == "system"), None)
|
||||
anthropic_messages = [
|
||||
{"role": m["role"], "content": m["content"]} for m in messages if m["role"] != "system"
|
||||
]
|
||||
|
||||
kwargs: dict[str, Any] = {"model": model_name, "messages": anthropic_messages, "max_tokens": max_tokens}
|
||||
if system_prompt:
|
||||
kwargs["system"] = system_prompt
|
||||
if temperature is not None:
|
||||
kwargs["temperature"] = temperature
|
||||
kwargs: dict[str, Any] = {"model": model_name, "messages": anthropic_messages, "max_tokens": max_tokens}
|
||||
if system_prompt_content:
|
||||
kwargs["system"] = system_prompt_content
|
||||
if temperature is not None:
|
||||
kwargs["temperature"] = temperature
|
||||
|
||||
response = await client.messages.create(**kwargs)
|
||||
content = "".join(block.text for block in response.content if hasattr(block, "text"))
|
||||
response = await client.messages.create(**kwargs)
|
||||
content = "".join(block.text for block in response.content if hasattr(block, "text"))
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
usage=LLMUsage(input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens),
|
||||
raw_response=response,
|
||||
)
|
||||
result = LLMResponse(
|
||||
content=content,
|
||||
usage=LLMUsage(input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens),
|
||||
raw_response=response,
|
||||
)
|
||||
else:
|
||||
# OpenAI / Google (OpenAI-compatible)
|
||||
assert isinstance(client, AsyncOpenAI)
|
||||
openai_kwargs: dict[str, Any] = {"model": model_name, "messages": messages}
|
||||
if temperature is not None:
|
||||
openai_kwargs["temperature"] = temperature
|
||||
response = await client.chat.completions.create(**openai_kwargs)
|
||||
|
||||
# OpenAI / Google (OpenAI-compatible)
|
||||
assert isinstance(client, AsyncOpenAI)
|
||||
openai_kwargs: dict[str, Any] = {"model": model_name, "messages": messages}
|
||||
if temperature is not None:
|
||||
openai_kwargs["temperature"] = temperature
|
||||
response = await client.chat.completions.create(**openai_kwargs)
|
||||
result = LLMResponse(
|
||||
content=response.choices[0].message.content or "",
|
||||
usage=LLMUsage(
|
||||
input_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
output_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
),
|
||||
raw_response=response,
|
||||
)
|
||||
|
||||
return LLMResponse(
|
||||
content=response.choices[0].message.content or "",
|
||||
usage=LLMUsage(
|
||||
input_tokens=response.usage.prompt_tokens if response.usage else 0,
|
||||
output_tokens=response.usage.completion_tokens if response.usage else 0,
|
||||
),
|
||||
raw_response=response,
|
||||
)
|
||||
# Record success if observability is enabled
|
||||
if llm_call_id and llm_recorder:
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
_record_completion_background(llm_recorder, llm_call_id, result, latency_ms, model_name)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
# Record error if observability is enabled
|
||||
if trace_id and observe_as and error_recorder and llm_recorder:
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
_record_error_background(error_recorder, llm_recorder, trace_id, llm_call_id, e, latency_ms, model_name)
|
||||
raise
|
||||
|
||||
|
||||
# Strong references to in-flight observability tasks. The event loop only keeps
# weak references to tasks, so a fire-and-forget task that nothing else holds
# may be garbage-collected before it has finished running.
_BACKGROUND_TASKS: set[asyncio.Task[None]] = set()


def _spawn_background(coro: Any) -> None:
    """Schedule *coro* on the running event loop without awaiting it.

    The created task is kept in ``_BACKGROUND_TASKS`` until it completes, then
    removed by a done-callback.
    """
    task = asyncio.create_task(coro)
    _BACKGROUND_TASKS.add(task)
    task.add_done_callback(_BACKGROUND_TASKS.discard)


def _first_present(obj: Any, *names: str) -> Any:
    """Return the first attribute of *obj* among *names* that is not ``None``.

    Unlike chaining with ``or``, a legitimate value of ``0`` is returned as-is.
    """
    for name in names:
        value = getattr(obj, name, None)
        if value is not None:
            return value
    return None


def _record_completion_background(
    llm_recorder: Any, llm_call_id: str, result: LLMResponse, latency_ms: int, model_name: str
) -> None:
    """Record a successful LLM call in a background task (non-blocking).

    Args:
        llm_recorder: Object exposing an awaitable ``record_llm_call_completion``.
        llm_call_id: Identifier of the call whose completion is being recorded.
        result: The LLM response; details are read from ``result.raw_response``.
        latency_ms: Call latency in milliseconds.
        model_name: Model name, used to look up a cost model.

    Any exception raised while recording is logged as a warning and swallowed,
    so observability failures never reach the caller.
    """

    async def _record() -> None:
        try:
            raw = result.raw_response

            serialized = raw.model_dump_json(indent=2) if hasattr(raw, "model_dump_json") else None

            prompt_tokens = None
            completion_tokens = None
            total_tokens = None
            if hasattr(raw, "usage"):
                usage = raw.usage
                # Two attribute spellings are accepted for each count; take
                # whichever one is actually populated.
                prompt_tokens = _first_present(usage, "prompt_tokens", "input_tokens")
                completion_tokens = _first_present(usage, "completion_tokens", "output_tokens")
                total_tokens = getattr(usage, "total_tokens", None)

            candidates_generated = len(raw.choices) if hasattr(raw, "choices") else None

            # Calculate cost; only build the cost model that is actually needed.
            cost_models = {"gpt-4.1": OpenAI_GPT_4_1, "claude-sonnet-4-5": Anthropic_Claude_Sonnet_4_5}
            llm_cost = None
            cost_model_factory = cost_models.get(model_name)
            if cost_model_factory is not None:
                llm_cost = calculate_llm_cost(raw, cost_model_factory())

            await llm_recorder.record_llm_call_completion(
                llm_call_id=llm_call_id,
                status="success",
                raw_response=serialized,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
                llm_cost=llm_cost,
                latency_ms=latency_ms,
                candidates_generated=candidates_generated,
            )
        except Exception as e:
            # Best-effort: recording must never break the LLM call itself.
            logger.warning("Observability: Failed to record completion: %s", e)

    _spawn_background(_record())


def _record_error_background(
    error_recorder: Any,
    llm_recorder: Any,
    trace_id: str,
    llm_call_id: str | None,
    error: Exception,
    latency_ms: int,
    model_name: str,
) -> None:
    """Record a failed LLM call in a background task (non-blocking).

    Args:
        error_recorder: Object exposing an awaitable ``record_error``.
        llm_recorder: Object exposing an awaitable ``record_llm_call_completion``.
        trace_id: Trace the error is attached to.
        llm_call_id: Identifier of the started call, or ``None`` if the start
            was never recorded; the call is marked failed only when it is set.
        error: The exception raised by the LLM call.
        latency_ms: Time elapsed before the failure, in milliseconds.
        model_name: Model name, stored in the error context.

    Any exception raised while recording is logged as a warning and swallowed.
    """

    async def _record() -> None:
        try:
            await error_recorder.record_error(
                trace_id=trace_id,
                error_type="llm_api",
                error_category="llm_error",
                severity="error",
                error_message=str(error),
                error_code=type(error).__name__,
                context={"model": model_name},
            )
            if llm_call_id:
                await llm_recorder.record_llm_call_completion(
                    llm_call_id=llm_call_id,
                    status="failed",
                    error_type=type(error).__name__,
                    error_message=str(error),
                    latency_ms=latency_ms,
                )
        except Exception as e:
            # Best-effort: recording must never mask the original LLM error.
            logger.warning("Observability: Failed to record error: %s", e)

    _spawn_background(_record())
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
|
|||
|
|
@ -13,8 +13,7 @@ from packaging import version
|
|||
|
||||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.env_specific import debug_log_sensitive_data
|
||||
from aiservice.llm import OPTIMIZATION_REVIEW_MODEL, LLMResponse, calculate_llm_cost, call_llm
|
||||
from aiservice.observability.decorators import observe_llm_call
|
||||
from aiservice.llm import OPTIMIZATION_REVIEW_MODEL, calculate_llm_cost, call_llm
|
||||
from log_features.log_event import update_optimization_cost, update_optimization_features_review
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -144,13 +143,6 @@ Output as a json markdown block with the key named as 'rating' and value being o
|
|||
return [system_message, user_message]
|
||||
|
||||
|
||||
@observe_llm_call("optimization_review")
|
||||
async def call_optimization_review_llm(
|
||||
trace_id: str, model: LLM, messages: list[dict[str, str]], user_id: str | None = None, context: dict | None = None
|
||||
) -> LLMResponse:
|
||||
return await call_llm(model_name=model.name, model_type=model.model_type, messages=messages)
|
||||
|
||||
|
||||
async def get_optimization_review(
|
||||
request, data: OptimizationReviewSchema, optimization_review_model: LLM = OPTIMIZATION_REVIEW_MODEL
|
||||
) -> tuple[int, OptimizationReviewResponseSchema | OptimizationReviewErrorSchema]:
|
||||
|
|
@ -166,10 +158,12 @@ async def get_optimization_review(
|
|||
if data.call_sequence:
|
||||
obs_context["call_sequence"] = data.call_sequence
|
||||
|
||||
response = await call_optimization_review_llm(
|
||||
trace_id=data.trace_id,
|
||||
model=optimization_review_model,
|
||||
response = await call_llm(
|
||||
model_name=optimization_review_model.name,
|
||||
model_type=optimization_review_model.model_type,
|
||||
messages=messages,
|
||||
observe_as="optimization_review",
|
||||
trace_id=data.trace_id,
|
||||
user_id=request.user,
|
||||
context=obs_context,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -15,8 +15,7 @@ from pydantic import ValidationError
|
|||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common_utils import parse_python_version, should_hack_for_demo, validate_trace_id
|
||||
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable
|
||||
from aiservice.llm import LLM, OPTIMIZE_MODEL, LLMResponse, calculate_llm_cost, call_llm
|
||||
from aiservice.observability.decorators import observe_llm_call
|
||||
from aiservice.llm import LLM, OPTIMIZE_MODEL, calculate_llm_cost, call_llm
|
||||
from authapp.user import get_user_by_id
|
||||
from log_features.log_event import get_or_create_optimization_event
|
||||
from log_features.log_features import log_features
|
||||
|
|
@ -111,34 +110,6 @@ ASYNC_SYSTEM_PROMPT = (current_dir / "async_system_prompt.md").read_text()
|
|||
ASYNC_USER_PROMPT = (current_dir / "async_user_prompt.md").read_text()
|
||||
|
||||
|
||||
@observe_llm_call("optimization")
|
||||
async def call_optimization_llm(
|
||||
trace_id: str,
|
||||
model: LLM,
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
user_id: str | None = None,
|
||||
python_version: str | None = None,
|
||||
context: dict | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Call LLM for code optimization with automatic observability.
|
||||
|
||||
This function is decorated with @observe_llm_call which automatically:
|
||||
- Records call start (non-blocking)
|
||||
- Captures timing and token usage
|
||||
- Records completion (non-blocking)
|
||||
- Handles errors automatically
|
||||
|
||||
All observability runs in the background without blocking the LLM call.
|
||||
|
||||
Args:
|
||||
context: Additional context for observability (e.g., call_sequence for multi-model tracking)
|
||||
|
||||
"""
|
||||
messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
|
||||
return await call_llm(model_name=model.name, model_type=model.model_type, messages=messages)
|
||||
|
||||
|
||||
async def optimize_python_code(
|
||||
user_id: str,
|
||||
ctx: BaseOptimizerContext,
|
||||
|
|
@ -163,13 +134,14 @@ async def optimize_python_code(
|
|||
# Build context for observability (includes call_sequence for multi-model tracking)
|
||||
obs_context = {"call_sequence": call_sequence} if call_sequence else None
|
||||
|
||||
# Call LLM with automatic observability (decorator handles everything)
|
||||
messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
|
||||
try:
|
||||
output = await call_optimization_llm(
|
||||
output = await call_llm(
|
||||
model_name=optimize_model.name,
|
||||
model_type=optimize_model.model_type,
|
||||
messages=messages,
|
||||
observe_as="optimization",
|
||||
trace_id=trace_id,
|
||||
model=optimize_model,
|
||||
system_prompt=system_prompt,
|
||||
user_prompt=user_prompt,
|
||||
user_id=user_id,
|
||||
python_version=python_version_str,
|
||||
context=obs_context,
|
||||
|
|
|
|||
|
|
@ -11,8 +11,7 @@ from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUs
|
|||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common_utils import parse_python_version, validate_trace_id
|
||||
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable
|
||||
from aiservice.llm import OPTIMIZE_MODEL, LLMResponse, calculate_llm_cost, call_llm
|
||||
from aiservice.observability.decorators import observe_llm_call
|
||||
from aiservice.llm import OPTIMIZE_MODEL, calculate_llm_cost, call_llm
|
||||
from log_features.log_event import update_optimization_cost
|
||||
from log_features.log_features import log_features
|
||||
from optimizer.context_utils.optimizer_context import (
|
||||
|
|
@ -42,28 +41,6 @@ SYSTEM_PROMPT = (current_dir / "system_prompt.md").read_text()
|
|||
USER_PROMPT = (current_dir / "user_prompt.md").read_text()
|
||||
|
||||
|
||||
@observe_llm_call("line_profiler")
|
||||
async def call_line_profiler_llm(
|
||||
trace_id: str,
|
||||
model: LLM,
|
||||
messages: list[dict[str, str]],
|
||||
user_id: str | None = None,
|
||||
python_version: str | None = None,
|
||||
context: dict | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Call LLM for line profiler optimization with automatic observability.
|
||||
|
||||
This function is decorated with @observe_llm_call which automatically:
|
||||
- Records call start (non-blocking)
|
||||
- Captures timing and token usage
|
||||
- Records completion (non-blocking)
|
||||
- Handles errors automatically
|
||||
|
||||
All observability runs in the background without blocking the LLM call.
|
||||
"""
|
||||
return await call_llm(model_name=model.name, model_type=model.model_type, messages=messages)
|
||||
|
||||
|
||||
async def optimize_python_code_line_profiler( # noqa: D417
|
||||
user_id: str,
|
||||
trace_id: str,
|
||||
|
|
@ -105,12 +82,13 @@ async def optimize_python_code_line_profiler( # noqa: D417
|
|||
if call_sequence:
|
||||
obs_context["call_sequence"] = call_sequence
|
||||
|
||||
# Call LLM with automatic observability (decorator handles everything)
|
||||
try:
|
||||
output = await call_line_profiler_llm(
|
||||
trace_id=trace_id,
|
||||
model=optimize_model,
|
||||
output = await call_llm(
|
||||
model_name=optimize_model.name,
|
||||
model_type=optimize_model.model_type,
|
||||
messages=messages,
|
||||
observe_as="line_profiler",
|
||||
trace_id=trace_id,
|
||||
user_id=user_id,
|
||||
python_version=python_version_str,
|
||||
context=obs_context,
|
||||
|
|
|
|||
|
|
@ -15,8 +15,7 @@ from pydantic import ValidationError
|
|||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common_utils import validate_trace_id
|
||||
from aiservice.env_specific import debug_log_sensitive_data
|
||||
from aiservice.llm import REFINEMENT_MODEL, LLMResponse, calculate_llm_cost, call_llm
|
||||
from aiservice.observability.decorators import observe_llm_call
|
||||
from aiservice.llm import REFINEMENT_MODEL, calculate_llm_cost, call_llm
|
||||
from log_features.log_event import update_optimization_cost
|
||||
from log_features.log_features import log_features
|
||||
from optimizer.context_utils.refiner_context import BaseRefinerContext, RefinementContextData
|
||||
|
|
@ -188,13 +187,6 @@ Here is the function_references
|
|||
"""
|
||||
|
||||
|
||||
@observe_llm_call("refinement")
|
||||
async def call_refinement_llm(
|
||||
trace_id: str, model: LLM, messages: list[dict[str, str]], user_id: str | None = None, context: dict | None = None
|
||||
) -> LLMResponse:
|
||||
return await call_llm(model_name=model.name, model_type=model.model_type, messages=messages)
|
||||
|
||||
|
||||
async def refinement( # noqa: D417
|
||||
user_id: str,
|
||||
optimization_id: str,
|
||||
|
|
@ -249,8 +241,14 @@ async def refinement( # noqa: D417
|
|||
obs_context["call_sequence"] = call_sequence
|
||||
|
||||
try:
|
||||
output = await call_refinement_llm(
|
||||
trace_id=trace_id, model=optimize_model, messages=messages, user_id=user_id, context=obs_context
|
||||
output = await call_llm(
|
||||
model_name=optimize_model.name,
|
||||
model_type=optimize_model.model_type,
|
||||
messages=messages,
|
||||
observe_as="refinement",
|
||||
trace_id=trace_id,
|
||||
user_id=user_id,
|
||||
context=obs_context,
|
||||
)
|
||||
llm_cost = calculate_llm_cost(output.raw_response, optimize_model)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -10,8 +10,7 @@ from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUs
|
|||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common_utils import validate_trace_id
|
||||
from aiservice.env_specific import debug_log_sensitive_data
|
||||
from aiservice.llm import LLM, RANKING_MODEL, LLMResponse, calculate_llm_cost, call_llm
|
||||
from aiservice.observability.decorators import observe_llm_call
|
||||
from aiservice.llm import LLM, RANKING_MODEL, calculate_llm_cost, call_llm
|
||||
from log_features.log_event import update_optimization_cost
|
||||
from log_features.log_features import log_features
|
||||
|
||||
|
|
@ -73,23 +72,6 @@ Here are the function references
|
|||
"""
|
||||
|
||||
|
||||
@observe_llm_call("ranking")
|
||||
async def call_ranker_llm(
|
||||
trace_id: str, model: LLM, messages: list[dict[str, str]], user_id: str | None = None, context: dict | None = None
|
||||
) -> LLMResponse:
|
||||
"""Call LLM for ranking with automatic observability.
|
||||
|
||||
This function is decorated with @observe_llm_call which automatically:
|
||||
- Records call start (non-blocking)
|
||||
- Captures timing and token usage
|
||||
- Records completion (non-blocking)
|
||||
- Handles errors automatically
|
||||
|
||||
All observability runs in the background without blocking the LLM call.
|
||||
"""
|
||||
return await call_llm(model_name=model.name, model_type=model.model_type, messages=messages)
|
||||
|
||||
|
||||
async def rank_optimizations( # noqa: D417
|
||||
user_id: str, data: RankInputSchema, rank_model: LLM = RANKING_MODEL
|
||||
) -> RankResponseSchema | RankErrorResponseSchema:
|
||||
|
|
@ -128,12 +110,13 @@ async def rank_optimizations( # noqa: D417
|
|||
| ChatCompletionFunctionMessageParam
|
||||
] = [system_message, user_message]
|
||||
|
||||
# Call LLM with automatic observability (decorator handles everything)
|
||||
try:
|
||||
output = await call_ranker_llm(
|
||||
trace_id=data.trace_id,
|
||||
model=rank_model,
|
||||
output = await call_llm(
|
||||
model_name=rank_model.name,
|
||||
model_type=rank_model.model_type,
|
||||
messages=messages,
|
||||
observe_as="ranking",
|
||||
trace_id=data.trace_id,
|
||||
user_id=user_id,
|
||||
context={
|
||||
"num_candidates": len(data.diffs),
|
||||
|
|
|
|||
|
|
@ -17,8 +17,7 @@ from openai import OpenAIError
|
|||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common_utils import parse_python_version, safe_isort, should_hack_for_demo, validate_trace_id
|
||||
from aiservice.env_specific import IS_PRODUCTION, debug_log_sensitive_data
|
||||
from aiservice.llm import EXECUTE_MODEL, LLMResponse, calculate_llm_cost, call_llm
|
||||
from aiservice.observability.decorators import observe_llm_call
|
||||
from aiservice.llm import EXECUTE_MODEL, calculate_llm_cost, call_llm
|
||||
from log_features.log_event import update_optimization_cost
|
||||
from log_features.log_features import log_features
|
||||
from testgen.instrumentation.edit_generated_test import parse_module_to_cst, replace_definition_with_import
|
||||
|
|
@ -190,31 +189,6 @@ def parse_and_validate_llm_output(
|
|||
raise
|
||||
|
||||
|
||||
@observe_llm_call("test_generation")
|
||||
async def call_testgen_llm(
|
||||
trace_id: str,
|
||||
model: LLM,
|
||||
messages: list[dict[str, str]],
|
||||
temperature: float,
|
||||
user_id: str | None = None,
|
||||
python_version: str | None = None,
|
||||
context: dict | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Call LLM for test generation with automatic observability.
|
||||
|
||||
This function is decorated with @observe_llm_call which automatically:
|
||||
- Records call start (non-blocking)
|
||||
- Captures timing and token usage
|
||||
- Records completion (non-blocking)
|
||||
- Handles errors automatically
|
||||
|
||||
All observability runs in the background without blocking the LLM call.
|
||||
"""
|
||||
return await call_llm(
|
||||
model_name=model.name, model_type=model.model_type, messages=messages, temperature=temperature
|
||||
)
|
||||
|
||||
|
||||
@stamina.retry(on=(SyntaxError, ValueError, OpenAIError), attempts=2)
|
||||
async def generate_and_validate_test_code(
|
||||
messages: list[dict[str, str]],
|
||||
|
|
@ -231,11 +205,13 @@ async def generate_and_validate_test_code(
|
|||
call_sequence: int | None = None,
|
||||
) -> str:
|
||||
obs_context: dict | None = {"call_sequence": call_sequence} if call_sequence else None
|
||||
response = await call_testgen_llm(
|
||||
trace_id=trace_id,
|
||||
model=model,
|
||||
response = await call_llm(
|
||||
model_name=model.name,
|
||||
model_type=model.model_type,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
observe_as="test_generation",
|
||||
trace_id=trace_id,
|
||||
user_id=user_id,
|
||||
python_version=".".join(str(v) for v in python_version),
|
||||
context=obs_context,
|
||||
|
|
|
|||
|
|
@ -12,8 +12,7 @@ from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUs
|
|||
|
||||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.env_specific import debug_log_sensitive_data
|
||||
from aiservice.llm import EXECUTE_MODEL, LLMResponse, call_llm
|
||||
from aiservice.observability.decorators import observe_llm_call
|
||||
from aiservice.llm import EXECUTE_MODEL, call_llm
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from django.http import HttpRequest
|
||||
|
|
@ -68,31 +67,6 @@ def _extract_yaml_steps(text: str) -> str | None:
|
|||
return None
|
||||
|
||||
|
||||
@observe_llm_call("workflow_generation")
|
||||
async def call_workflow_gen_llm(
|
||||
trace_id: str,
|
||||
model,
|
||||
messages: list[dict[str, str]],
|
||||
temperature: int = 0,
|
||||
n: int = 1,
|
||||
user_id: str | None = None,
|
||||
context: dict | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Call LLM for workflow generation with automatic observability.
|
||||
|
||||
This function is decorated with @observe_llm_call which automatically:
|
||||
- Records call start (non-blocking)
|
||||
- Captures timing and token usage
|
||||
- Records completion (non-blocking)
|
||||
- Handles errors automatically
|
||||
|
||||
All observability runs in the background without blocking the LLM call.
|
||||
"""
|
||||
return await call_llm(
|
||||
model_name=model.name, model_type=model.model_type, messages=messages, n=n, temperature=temperature
|
||||
)
|
||||
|
||||
|
||||
async def generate_workflow_steps_llm(
|
||||
repo_files: dict[str, str],
|
||||
directory_structure: dict[str, Any],
|
||||
|
|
@ -120,13 +94,13 @@ async def generate_workflow_steps_llm(
|
|||
debug_log_sensitive_data(f"Generating workflow steps with prompt length: {len(user_prompt)}")
|
||||
|
||||
try:
|
||||
# Call LLM with automatic observability (decorator handles everything)
|
||||
response = await call_workflow_gen_llm(
|
||||
trace_id=trace_id,
|
||||
model=EXECUTE_MODEL,
|
||||
response = await call_llm(
|
||||
model_name=EXECUTE_MODEL.name,
|
||||
model_type=EXECUTE_MODEL.model_type,
|
||||
messages=[system_message, user_message],
|
||||
temperature=0,
|
||||
n=1,
|
||||
observe_as="workflow_generation",
|
||||
trace_id=trace_id,
|
||||
user_id=user_id,
|
||||
context={"num_files": len(repo_files)},
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue