[Chore] Save LLM Cost in DB (#1768)

Co-authored-by: Sarthak Agarwal <sarthak.saga@gmail.com>
HeshamHM28 2025-09-01 22:04:34 +03:00 committed by GitHub
parent 588ba02640
commit ee55e78add
11 changed files with 118 additions and 15 deletions

View file

@ -1,13 +1,23 @@
import os
from pydantic.dataclasses import dataclass
from typing import Optional, Any
# The following pricing information is based on public OpenAI and Anthropic
# documentation as of August 2025. Prices can change, so always check the
# official pricing pages:
# https://docs.anthropic.com/en/docs/about-claude/pricing
# https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
# All prices are in USD per 1M tokens.
# Some prices are placeholders taken from OpenAI: https://platform.openai.com/docs/pricing?latest-pricing=flex.
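# Worked example of this scheme (illustrative numbers): a call with
# 200,000 prompt tokens and 10,000 completion tokens on a model priced at
# $2.50/1M input and $10.00/1M output costs
# (200_000 / 1_000_000) * 2.50 + (10_000 / 1_000_000) * 10.00 = $0.60.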
@dataclass
class LLM:
name: str # On Azure OpenAI Service, this is the deployment name
max_tokens: int
api_version: str = ""
# Pricing in USD per 1M tokens
input_cost: Optional[float] = None
output_cost: Optional[float] = None
# name of the model deployment on Azure OpenAI Service
@ -15,42 +25,55 @@ class LLM:
class GPT_4_OMNI(LLM):
name: str = "gpt-4o-2" if os.environ.get("OPENAI_API_TYPE") == "azure" else "gpt-4o"
max_tokens: int = 128000
input_cost: float = 2.50
output_cost: float = 10.00
@dataclass
class GPT_4_128k(LLM):
name: str = "gpt-4-1106-preview"
max_tokens: int = 128000
input_cost: float = 10.00
output_cost: float = 30.00
@dataclass
class GPT_4_32k(LLM):
name: str = "gpt4-32k"
max_tokens: int = 32768
input_cost: float = 60.00
output_cost: float = 120.00
@dataclass
class GPT_4(LLM):
name: str = "gpt-4-0613"
max_tokens: int = 8192
input_cost: float = 30.00
output_cost: float = 60.00
@dataclass
class GPT_3_5_Turbo_16k(LLM):
name: str = "gpt-3.5-turbo-16k"
max_tokens: int = 16384
input_cost: float = 3.00
output_cost: float = 4.00
@dataclass
class GPT_3_5_Turbo(LLM):
name: str = "gpt-3.5-turbo"
max_tokens: int = 4096
input_cost: float = 0.50
output_cost: float = 1.50
@dataclass
class Antropic_Claude_3_7(LLM):
name: str = "claude-3-7-sonnet-20250219"
max_tokens: int = 100000
input_cost: float = 3.00
output_cost: float = 15.00
@dataclass
@ -58,7 +81,8 @@ class Antropic_Claude_4(LLM):
name: str = "claude-sonnet-4-20250514"
max_tokens: int = 100000
api_version: str = ""
input_cost: float = 3.00
output_cost: float = 15.00
@dataclass
class OpenAI_GPT_4_1(LLM):
@ -66,6 +90,8 @@ class OpenAI_GPT_4_1(LLM):
name: str = "gpt-4.1"
max_tokens: int = 100000
api_version: str = "2024-12-01-preview"
input_cost: float = 2.00
output_cost: float = 8.00
@dataclass
@ -79,6 +105,8 @@ class OpenAI_GPT_O_3(LLM):
name: str = "azure/o3"
max_tokens: int = 100000
api_version: str = "2025-01-01-preview"
input_cost: float = 2.00
output_cost: float = 8.00
@dataclass
@ -86,12 +114,41 @@ class OpenAI_GPT_O_4_MINI(LLM):
name: str = "azure/o4-mini"
max_tokens: int = 100000
api_version: str = "2024-12-01-preview"
input_cost: float = 1.10
output_cost: float = 4.40
@dataclass
class GPT_5(LLM):
name: str = "gpt-5"
max_tokens: int = 100000
input_cost: float = 1.25
output_cost: float = 10.00
def calculate_llm_cost(response: Any, llm: LLM) -> float | None:
"""
Calculates the cost of an OpenAI API chat completion call.
Args:
response (dict): The JSON response from the OpenAI API call.
Returns:
float: The total cost in USD, or None if the cost cannot be calculated.
"""
try:
usage = response.usage
prompt_tokens = usage.prompt_tokens
completion_tokens = usage.completion_tokens
prompt_cost = (prompt_tokens / 1_000_000) * llm.input_cost
completion_cost = (completion_tokens / 1_000_000) * llm.output_cost
total_cost: float = prompt_cost + completion_cost
return total_cost
except Exception as e:
print(f"An error occurred: {e}")
return None
EXPLAIN_MODEL: LLM = OpenAI_GPT_4_1()
PLAN_MODEL: LLM = OpenAI_GPT_4_1()
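
As a quick sanity check of the cost math above, here is a minimal usage sketch of calculate_llm_cost; the stub below only mimics the usage shape of a real chat-completion response, and the token counts are illustrative:

from types import SimpleNamespace

# Stub that mimics response.usage on a real ChatCompletion (illustrative numbers).
response = SimpleNamespace(usage=SimpleNamespace(prompt_tokens=200_000, completion_tokens=10_000))
cost = calculate_llm_cost(response, GPT_4_OMNI())
print(cost)  # (0.2 * 2.50) + (0.01 * 10.00) = 0.6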

View file

@ -10,8 +10,9 @@ from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUs
from aiservice.analytics.posthog import ph
from aiservice.common_utils import validate_trace_id
from aiservice.env_specific import create_claude_client, debug_log_sensitive_data
from aiservice.models.aimodels import EXPLAINATIONS_MODEL
from aiservice.models.aimodels import EXPLAINATIONS_MODEL, calculate_llm_cost
from authapp.auth import AuthBearer
from log_features.log_event import update_optimization_cost
from log_features.log_features import log_features
if TYPE_CHECKING:
@ -135,6 +136,7 @@ async def explain_optimizations(
original_explanation: str,
read_only_dependency_code: str | None = None,
explanations_model: LLM = EXPLAINATIONS_MODEL,
trace_id: str = "",
) -> ExplanationsResponseSchema | ExplanationsErrorResponseSchema:
"""Optimize the given python code for performance using the Claude 4 model.
@ -178,6 +180,7 @@ async def explain_optimizations(
output = await claude_client.with_options(max_retries=2).chat.completions.create(
model=explanations_model.name, messages=messages, n=1
)
await update_optimization_cost(trace_id=trace_id, cost=calculate_llm_cost(output, explanations_model))
except OpenAIError as e:
debug_log_sensitive_data(f"Failed to generate new explanation, Error message: {e}")
return ExplanationsErrorResponseSchema(error=str(e))
@ -240,6 +243,7 @@ async def explain(
data.annotated_tests,
data.original_explanation,
data.dependency_code,
trace_id=data.trace_id,
)
if isinstance(explanation_response, ExplanationsErrorResponseSchema):
ph(request.user, "Explanation not generated, revert to old explanation")

View file

@ -1,5 +1,7 @@
from datetime import UTC, datetime
from django.db.models import F
from django.db.models.functions import Coalesce
from log_features.models import OptimizationEvents, Repositories
@ -20,6 +22,7 @@ def log_optimization_event(
api_key_id=None,
metadata=None,
current_username=None,
llm_cost=None,
):
return OptimizationEvents.objects.acreate(
event_type=event_type,
@ -30,4 +33,9 @@ def log_optimization_event(
metadata=metadata,
created_at=datetime.now(UTC),
current_username=current_username,
llm_cost=llm_cost,
)
async def update_optimization_cost(trace_id: str, cost: float | None) -> None:
"""Atomically increment llm_cost for the given trace_id."""
if cost is None:
return
await OptimizationEvents.objects.filter(trace_id=trace_id).aupdate(
llm_cost=Coalesce(F("llm_cost"), 0.0) + cost)
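
A short sketch of why the increment happens in SQL via an F expression, assuming two coroutines report costs for the same trace concurrently; a Python-side read-modify-write could lose one of the updates, while each call here compiles to a single atomic UPDATE:

import asyncio

async def record_two_costs(trace_id: str) -> None:
    # Both updates run concurrently; each becomes
    # UPDATE ... SET llm_cost = COALESCE(llm_cost, 0.0) + %s,
    # so the database serializes them and neither increment is lost.
    await asyncio.gather(
        update_optimization_cost(trace_id, 0.25),
        update_optimization_cost(trace_id, 0.75),
    )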

View file

@ -51,6 +51,7 @@ class OptimizationEvents(models.Model):
trace_id = models.CharField(max_length=36, null=True, blank=True)
pr_id = models.CharField(max_length=255, null=True, blank=True, unique=True)
api_key_id = models.IntegerField(null=True, blank=True)
llm_cost = models.FloatField(null=True, blank=True)
metadata = models.JSONField(null=True, blank=True)
created_at = models.DateTimeField(auto_now_add=True)
current_username = models.CharField(
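
For completeness, the Django migration matching this new field would look roughly like the sketch below; the app label and migration dependency are assumptions, not taken from the commit:

from django.db import migrations, models

class Migration(migrations.Migration):
    dependencies = [
        ("log_features", "0001_initial"),  # assumed previous migration
    ]
    operations = [
        migrations.AddField(
            model_name="optimizationevents",
            name="llm_cost",
            field=models.FloatField(null=True, blank=True),
        ),
    ]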

View file

@ -16,7 +16,7 @@ from aiservice.env_specific import (
debug_log_sensitive_data,
debug_log_sensitive_data_from_callable,
)
from aiservice.models.aimodels import OPTIMIZE_MODEL
from aiservice.models.aimodels import OPTIMIZE_MODEL, calculate_llm_cost
from authapp.user import get_user_by_id
from log_features.log_event import get_repository, log_optimization_event
from log_features.log_features import log_features
@ -51,7 +51,7 @@ async def optimize_python_code(
n: int = 1,
optimize_model: LLM = OPTIMIZE_MODEL,
python_version: tuple[int, int, int] = (3, 12, 9),
) -> list[OptimizeResponseItemSchema]:
) -> tuple[list[OptimizeResponseItemSchema], float | None]:
"""Optimize the given python code for performance using OpenAI's GPT-4 model.
Parameters
@ -83,7 +83,7 @@ async def optimize_python_code(
| ChatCompletionAssistantMessageParam
| ChatCompletionToolMessageParam
| ChatCompletionFunctionMessageParam
] = [system_message, user_message]
async with create_openai_client() as openai_client:
# TODO: Verify if the context window length is within the model capability
try:
@ -95,7 +95,7 @@ async def optimize_python_code(
print(e)
debug_log_sensitive_data(f"Failed to generate code for source:\n{ctx.source_code}")
return [], None
llm_cost = calculate_llm_cost(output, optimize_model)
debug_log_sensitive_data(f"OpenAIClient optimization response:\n{output.model_dump_json(indent=2)}")
if output.usage is not None:
@ -119,8 +119,7 @@ async def optimize_python_code(
debug_log_sensitive_data(f"error for source:\n{ctx.source_code}")
debug_log_sensitive_data(f"Traceback: {e}")
continue
return optimization_response_items
return optimization_response_items, llm_cost
class OptimizeSchema(Schema):
@ -155,7 +154,7 @@ async def optimize(request, data: OptimizeSchema) -> tuple[int, OptimizeResponse
)
if not validate_trace_id(data.trace_id):
return 400, OptimizeErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")
optimization_response_items = await optimize_python_code(
optimization_response_items, llm_cost = await optimize_python_code(
request.user, ctx, data.dependency_code, n=5, python_version=python_version
)
if len(optimization_response_items) == 0:
@ -187,6 +186,7 @@ async def optimize(request, data: OptimizeSchema) -> tuple[int, OptimizeResponse
"num_optimizations": len(optimization_response_items),
"experiment_metadata": data.experiment_metadata,
},
llm_cost=llm_cost,
)
if hasattr(request, "should_log_features") and request.should_log_features:
await log_features(

View file

@ -14,7 +14,8 @@ from aiservice.env_specific import (
debug_log_sensitive_data,
debug_log_sensitive_data_from_callable,
)
from aiservice.models.aimodels import OPTIMIZE_MODEL
from aiservice.models.aimodels import OPTIMIZE_MODEL, calculate_llm_cost
from log_features.log_event import update_optimization_cost
from log_features.log_features import log_features
from optimizer.context_utils.optimizer_context import (
BaseOptimizerContext,
@ -44,6 +45,7 @@ USER_PROMPT = (current_dir / "user_prompt.md").read_text()
async def optimize_python_code_line_profiler(
user_id: str,
trace_id: str,
line_profiler_results: str,
ctx: BaseOptimizerContext,
dependency_code: str | None = None,
@ -94,6 +96,7 @@ async def optimize_python_code_line_profiler(
output = await openai_client.with_options(max_retries=3).chat.completions.create(
model=optimize_model.name, messages=messages, n=n
)
await update_optimization_cost(trace_id=trace_id, cost=calculate_llm_cost(output, optimize_model))
except OpenAIError as e:
print("OpenAI Code Generation error ...")
print(e)
@ -155,6 +158,7 @@ async def optimize(request, data: OptimizeSchemaLP) -> tuple[int, OptimizeRespon
return 400, OptimizeErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")
optimization_response_items = await optimize_python_code_line_profiler(
user_id=request.user,
trace_id=data.trace_id,
ctx=ctx,
dependency_code=data.dependency_code,
line_profiler_results=data.line_profiler_results,

View file

@ -14,7 +14,8 @@ from pydantic import ValidationError
from aiservice.analytics.posthog import ph
from aiservice.common_utils import validate_trace_id
from aiservice.env_specific import create_claude_client, debug_log_sensitive_data
from aiservice.models.aimodels import REFINEMENT_MODEL
from aiservice.models.aimodels import REFINEMENT_MODEL, calculate_llm_cost
from log_features.log_event import update_optimization_cost
from log_features.log_features import log_features
from optimizer.context_utils.refiner_context import BaseRefinerContext, RefinementContextData
@ -218,6 +219,7 @@ async def refinement(
output = await claude_client.with_options(max_retries=2).chat.completions.create(
model=optimize_model.name, messages=messages, n=1
)
llm_cost = calculate_llm_cost(output, optimize_model) or 0.0  # schema field llm_cost is non-optional
except OpenAIError as e:
print("refinement api: Claude Code Generation error ...")
print(e)
@ -249,6 +251,7 @@ async def refinement(
source_code=refined_optimization,
explanation=refined_explanation,
original_explanation=ctx.data.optimized_explanation,
llm_cost=llm_cost,
)
@ -276,6 +279,7 @@ class RefinementIntermediateResponseItemschema(Schema):
optimization_id: str
source_code: str
original_explanation: str
llm_cost: float
class RefinementResponseItemschema(Schema):
@ -326,9 +330,11 @@ async def refine(
# simple filtering mechanism, remove empty strings and remove duplicates after removing trailing and leading whitespaces, validate with libcst
filtered_refined_optimizations = []
source_code_set = set()
total_llm_cost = 0.0
for elem in refinement_data:
if isinstance(elem, OptimizeErrorResponseSchema):
continue
total_llm_cost += elem.llm_cost
try:
ctx.validate_python_module(elem.source_code)
except cst.ParserSyntaxError as e:
@ -367,6 +373,7 @@ async def refine(
cei.optimization_id[:-4] + "refi": cei.explanation for cei in filtered_refined_optimizations
},
)
await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
return 200, Refinementschema(
refinements=[
RefinementResponseItemschema(
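
The filtering step mentioned above (drop empties, dedupe after stripping whitespace, validate with libcst) can be read in isolation as the sketch below; the helper name is illustrative:

import libcst as cst

def filter_refined_sources(sources: list[str]) -> list[str]:
    # Keep only non-empty, unique (after stripping), parseable modules.
    seen: set[str] = set()
    kept: list[str] = []
    for src in sources:
        key = src.strip()
        if not key or key in seen:
            continue
        try:
            cst.parse_module(src)  # raises ParserSyntaxError on invalid code
        except cst.ParserSyntaxError:
            continue
        seen.add(key)
        kept.append(src)
    return kept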

View file

@ -10,8 +10,9 @@ from typing import SupportsIndex
import isort
from aiservice.common_utils import parse_python_version
from aiservice.env_specific import create_openai_client, debug_log_sensitive_data
from aiservice.models.aimodels import EXECUTE_MODEL, EXPLAIN_MODEL, LLM, PLAN_MODEL
from aiservice.models.aimodels import EXECUTE_MODEL, EXPLAIN_MODEL, LLM, PLAN_MODEL, calculate_llm_cost
from aiservice.models.functions_to_optimize import FunctionToOptimize
from log_features.log_event import update_optimization_cost
from log_features.log_features import log_features
from ninja import NinjaAPI, Schema
from pydantic import model_validator
@ -127,6 +128,7 @@ async def generate_regression_tests_from_function(
plan_model: LLM = PLAN_MODEL, # model used to generate text plans in steps 2 and 2b
execute_model: LLM = EXECUTE_MODEL, # model used to generate code in step 3
temperature: float = 0.4, # temperature = 0 can sometimes get stuck in repetitive loops, so we use 0.4
trace_id: str = "",
) -> str:
"""Returns a unit test for a given Python function, using a 3-step GPT prompt."""
# Step 1: Generate an explanation of the function
@ -138,12 +140,14 @@ async def generate_regression_tests_from_function(
"content": EXPLAIN_USER_PROMPT.format(function_name=function_name, function_code=function_code),
}
explain_messages = [explain_system_message, explain_user_message]
total_llm_cost = 0.0
if print_text:
print_messages(explain_messages)
try:
explanation_response = await openai_client.with_options(max_retries=2).chat.completions.create(
model=explain_model.name, messages=explain_messages, temperature=temperature
)
total_llm_cost += calculate_llm_cost(explanation_response, explain_model) or 0.0
except Exception as e:
logging.exception("OpenAI client error in explain step")
raise TestGenerationFailedException(e) from e
@ -168,6 +172,7 @@ async def generate_regression_tests_from_function(
fetch_data_response = await openai_client.with_options(max_retries=2).chat.completions.create(
model=execute_model.name, messages=fetch_data_messages, temperature=temperature
)
total_llm_cost += calculate_llm_cost(fetch_data_response, execute_model) or 0.0
except Exception as e:
logging.exception("OpenAI client error in explain step")
raise TestGenerationFailedException(e) from e
@ -212,6 +217,7 @@ To help unit test the function above, list diverse scenarios that the function s
plan_response = await openai_client.with_options(max_retries=2).chat.completions.create(
model=plan_model.name, messages=plan_messages, temperature=temperature
)
total_llm_cost += calculate_llm_cost(plan_response, plan_model) or 0.0
except Exception as e:
logging.exception("OpenAI client error in plan step")
raise TestGenerationFailedException(e) from e
@ -249,6 +255,7 @@ To help unit test the function above, list diverse scenarios that the function s
elaboration_response = await openai_client.with_options(max_retries=2).chat.completions.create(
model=plan_model.name, messages=elaboration_messages, temperature=temperature
)
total_llm_cost += calculate_llm_cost(elaboration_response, plan_model) or 0.0
except Exception as e:
logging.exception("OpenAI client error in elaboration step")
raise TestGenerationFailedException(e) from e
@ -297,6 +304,7 @@ To help unit test the function above, list diverse scenarios that the function s
execute_response = await openai_client.with_options(max_retries=2).chat.completions.create(
model=execute_model.name, messages=execute_messages, temperature=temperature
)
total_llm_cost += calculate_llm_cost(execute_response, execute_model) or 0.0
except Exception as e:
logging.exception("OpenAI client error in execute step")
raise TestGenerationFailedException(e) from e
@ -328,6 +336,7 @@ To help unit test the function above, list diverse scenarios that the function s
if tries == 0:
raise TestGenerationFailedException("Failed to generate test code after 2 tries.")
await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
# return the unit test as a string
return code
@ -407,6 +416,7 @@ async def testgen(request, data: TestGenSchema) -> tuple[int, TestGenResponseSch
unit_test_package=data.test_framework,
approx_min_cases_to_cover=10,
python_version=python_version,
trace_id=data.trace_id,
)
print("/testgen: Instrumenting tests...")
instrumented_test_source = isort.code(

View file

@ -15,9 +15,10 @@ from pydantic import model_validator
from aiservice.analytics.posthog import ph
from aiservice.common_utils import parse_python_version, validate_trace_id
from aiservice.env_specific import create_openai_client, debug_log_sensitive_data
from aiservice.models.aimodels import EXECUTE_MODEL, EXPLAIN_MODEL, LLM, PLAN_MODEL
from aiservice.models.aimodels import EXECUTE_MODEL, EXPLAIN_MODEL, LLM, PLAN_MODEL, calculate_llm_cost
from aiservice.models.functions_to_optimize import FunctionToOptimize
from authapp.auth import AuthBearer
from log_features.log_event import update_optimization_cost
from log_features.log_features import log_features
from testgen.instrumentation.edit_generated_test import parse_module_to_cst, replace_definition_with_import
from testgen.instrumentation.instrument_new_tests import instrument_test_source
@ -106,7 +107,8 @@ async def generate_regression_tests_from_function(
explain_model: LLM = EXPLAIN_MODEL, # model used to generate text plans in step 1
plan_model: LLM = PLAN_MODEL, # model used to generate text plans in steps 2 and 2b
execute_model: LLM = EXECUTE_MODEL, # model used to generate code in step 3
temperature: float = 0.4,  # temperature = 0 can sometimes get stuck in repetitive loops, so we use 0.4
trace_id: str = "",
) -> str:
"""Returns a unit test for a given Python function, using a 3-step GPT prompt."""
openai_client = create_openai_client()
@ -151,12 +153,15 @@ To help unit test the function above, list diverse scenarios that the function s
print_messages([execute_system_message, execute_user_message])
print_messages([note_message])
# TODO: Implement a fallback if the code is too long, implement a straightforward way to write the tests rather than the iterative approach
total_llm_cost = 0.0
tries = 2
while tries > 0:
try:
execute_response = await openai_client.with_options(max_retries=2).chat.completions.create(
model=execute_model.name, messages=execute_messages, temperature=temperature
)
total_llm_cost += calculate_llm_cost(execute_response, execute_model) or 0.0
except Exception as e:
logging.exception("OpenAI client error in execute step")
raise TestGenerationFailedError(e) from e
@ -186,9 +191,12 @@ To help unit test the function above, list diverse scenarios that the function s
logging.warning(f"Error: {e}")
logging.warning(f"Generated code: {code}")
continue
await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
if tries == 0:
raise TestGenerationFailedError("Failed to generate test code after 2 tries.")
# return the unit test as a string
return code
@ -281,6 +289,7 @@ async def testgen(
unit_test_package=data.test_framework,
approx_min_cases_to_cover=10,
python_version=python_version,
trace_id=data.trace_id,
)
generated_test_source = validate_testgen_code(
og_generated_test_source, python_version[:2], max_lines_to_remove=40

View file

@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "optimization_events" ADD COLUMN "llm_cost" DOUBLE PRECISION;

View file

@ -188,6 +188,7 @@ model optimization_events {
repository repositories? @relation(fields: [repository_id], references: [id], onDelete: SetNull)
api_key cf_api_keys? @relation(fields: [api_key_id], references: [id], onDelete: SetNull)
comments comments[]
llm_cost Float?
@@index([event_type, created_at])
@@index([repository_id, event_type])