[Chore ] Save LLM Cost in DB (#1768)
Co-authored-by: Sarthak Agarwal <sarthak.saga@gmail.com>
This commit is contained in:
parent
588ba02640
commit
ee55e78add
11 changed files with 118 additions and 15 deletions
|
|
@ -1,13 +1,23 @@
|
|||
import os
|
||||
|
||||
from pydantic.dataclasses import dataclass
|
||||
from typing import Optional, Any
|
||||
|
||||
|
||||
# The following pricing information is based on public OpenAI and Claude documentation
|
||||
# as of August 2025. Prices can change, so always check the official:
|
||||
# https://docs.anthropic.com/en/docs/about-claude/pricing
|
||||
# https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
|
||||
# The pricing is in USD per 1M tokens.
|
||||
# Some of the prices are placeholders taken from OpenAI: https://platform.openai.com/docs/pricing?latest-pricing=flex.
|
||||
@dataclass
class LLM:
    """Base configuration for an LLM deployment, including token pricing."""

    name: str  # On Azure OpenAI Service, this is the deployment name
    max_tokens: int  # context-window size in tokens
    api_version: str = ""  # Azure API version; empty when not required
    # Pricing in USD per 1M tokens; None when no price is configured for the model.
    input_cost: Optional[float] = None
    output_cost: Optional[float] = None
|
|
@ -15,42 +25,55 @@ class LLM:
|
|||
class GPT_4_OMNI(LLM):
    """Defaults for gpt-4o; deployment name depends on OPENAI_API_TYPE."""

    # NOTE: OPENAI_API_TYPE is evaluated once, when this module is imported.
    name: str = "gpt-4o-2" if os.environ.get("OPENAI_API_TYPE") == "azure" else "gpt-4o"
    max_tokens: int = 128000
    # USD per 1M tokens
    input_cost: float = 2.50
    output_cost: float = 10.00
|
||||
|
||||
|
||||
@dataclass
class GPT_4_128k(LLM):
    """Defaults for the gpt-4-1106-preview deployment (128k context)."""

    name: str = "gpt-4-1106-preview"
    max_tokens: int = 128000
    # USD per 1M tokens
    input_cost: float = 10.00
    output_cost: float = 30.00
|
||||
|
||||
@dataclass
class GPT_4_32k(LLM):
    """Defaults for the gpt4-32k deployment (32k context)."""

    name: str = "gpt4-32k"
    max_tokens: int = 32768
    # USD per 1M tokens
    input_cost: float = 60.00
    output_cost: float = 120.00
|
||||
|
||||
|
||||
@dataclass
class GPT_4(LLM):
    """Defaults for the gpt-4-0613 deployment (8k context)."""

    name: str = "gpt-4-0613"
    max_tokens: int = 8192
    # USD per 1M tokens
    input_cost: float = 30.00
    output_cost: float = 60.00
|
||||
|
||||
|
||||
@dataclass
class GPT_3_5_Turbo_16k(LLM):
    """Defaults for the gpt-3.5-turbo-16k deployment."""

    name: str = "gpt-3.5-turbo-16k"
    max_tokens: int = 16384
    # USD per 1M tokens
    input_cost: float = 3.00
    output_cost: float = 4.00
|
||||
|
||||
|
||||
@dataclass
class GPT_3_5_Turbo(LLM):
    """Defaults for the gpt-3.5-turbo deployment (4k context)."""

    name: str = "gpt-3.5-turbo"
    max_tokens: int = 4096
    # USD per 1M tokens
    input_cost: float = 0.50
    output_cost: float = 1.50
|
||||
|
||||
|
||||
@dataclass
class Antropic_Claude_3_7(LLM):
    """Defaults for claude-3-7-sonnet-20250219."""

    # NOTE(review): class name misspells "Anthropic"; renaming would break callers.
    name: str = "claude-3-7-sonnet-20250219"
    max_tokens: int = 100000
    # USD per 1M tokens
    input_cost: float = 3.00
    output_cost: float = 15.00
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -58,7 +81,8 @@ class Antropic_Claude_4(LLM):
|
|||
name: str = "claude-sonnet-4-20250514"
|
||||
max_tokens: int = 100000
|
||||
api_version: str = ""
|
||||
|
||||
input_cost: float = 3.00
|
||||
output_cost: float = 15.00
|
||||
|
||||
@dataclass
class OpenAI_GPT_4_1(LLM):
    """Defaults for the gpt-4.1 deployment."""

    name: str = "gpt-4.1"
    max_tokens: int = 100000
    api_version: str = "2024-12-01-preview"
    # USD per 1M tokens
    input_cost: float = 2.00
    output_cost: float = 8.00
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -79,6 +105,8 @@ class OpenAI_GPT_O_3(LLM):
|
|||
name: str = "azure/o3"
|
||||
max_tokens: int = 100000
|
||||
api_version = "2025-01-01-preview"
|
||||
input_cost: float = 2.00
|
||||
output_cost: float = 8.00
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -86,12 +114,41 @@ class OpenAI_GPT_O_4_MINI(LLM):
|
|||
name: str = "azure/o4-mini"
|
||||
max_tokens: int = 100000
|
||||
api_version = "2024-12-01-preview"
|
||||
input_cost: float = 1.10
|
||||
output_cost: float = 4.40
|
||||
|
||||
@dataclass
class GPT_5(LLM):
    """Defaults for the gpt-5 deployment."""

    name: str = "gpt-5"
    max_tokens: int = 100000
    # USD per 1M tokens
    input_cost: float = 1.25
    output_cost: float = 10.00
|
||||
|
||||
def calculate_llm_cost(response: Any, llm: LLM) -> float | None:
|
||||
"""
|
||||
Calculates the cost of an OpenAI API chat completion call.
|
||||
|
||||
Args:
|
||||
response (dict): The JSON response from the OpenAI API call.
|
||||
|
||||
Returns:
|
||||
float: The total cost in USD, or None if the cost cannot be calculated.
|
||||
"""
|
||||
try:
|
||||
usage = response.usage
|
||||
prompt_tokens = usage.prompt_tokens
|
||||
completion_tokens = usage.completion_tokens
|
||||
|
||||
prompt_cost = (prompt_tokens / 1_000_000) * llm.input_cost
|
||||
completion_cost = (completion_tokens / 1_000_000) * llm.output_cost
|
||||
|
||||
total_cost: float = prompt_cost + completion_cost
|
||||
|
||||
return total_cost
|
||||
|
||||
except Exception as e:
|
||||
print(f"An error occurred: {e}")
|
||||
return None
|
||||
|
||||
# Default model assignments used by the explanation and planning steps.
EXPLAIN_MODEL: LLM = OpenAI_GPT_4_1()
PLAN_MODEL: LLM = OpenAI_GPT_4_1()
|
||||
|
|
|
|||
|
|
@ -10,8 +10,9 @@ from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUs
|
|||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common_utils import validate_trace_id
|
||||
from aiservice.env_specific import create_claude_client, debug_log_sensitive_data
|
||||
from aiservice.models.aimodels import EXPLAINATIONS_MODEL
|
||||
from aiservice.models.aimodels import EXPLAINATIONS_MODEL, calculate_llm_cost
|
||||
from authapp.auth import AuthBearer
|
||||
from log_features.log_event import update_optimization_cost
|
||||
from log_features.log_features import log_features
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -135,6 +136,7 @@ async def explain_optimizations(
|
|||
original_explanation: str,
|
||||
read_only_dependency_code: str | None = None,
|
||||
explanations_model: LLM = EXPLAINATIONS_MODEL,
|
||||
trace_id: str = "",
|
||||
) -> ExplanationsResponseSchema | ExplanationsErrorResponseSchema:
|
||||
"""Optimize the given python code for performance using the Claude 4 model.
|
||||
|
||||
|
|
@ -178,6 +180,7 @@ async def explain_optimizations(
|
|||
output = await claude_client.with_options(max_retries=2).chat.completions.create(
|
||||
model=explanations_model.name, messages=messages, n=1
|
||||
)
|
||||
await update_optimization_cost(trace_id=trace_id,cost=calculate_llm_cost(output, explanations_model))
|
||||
except OpenAIError as e:
|
||||
debug_log_sensitive_data(f"Failed to generate new explanation, Error message: {e}")
|
||||
return ExplanationsErrorResponseSchema(error=str(e))
|
||||
|
|
@ -240,6 +243,7 @@ async def explain(
|
|||
data.annotated_tests,
|
||||
data.original_explanation,
|
||||
data.dependency_code,
|
||||
trace_id=data.trace_id,
|
||||
)
|
||||
if isinstance(explanation_response, ExplanationsErrorResponseSchema):
|
||||
ph(request.user, "Explanation not generated, revert to old explanation")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from datetime import UTC, datetime
|
||||
|
||||
from django.db.models import F
|
||||
|
||||
from log_features.models import OptimizationEvents, Repositories
|
||||
|
||||
|
||||
|
|
@ -20,6 +22,7 @@ def log_optimization_event(
|
|||
api_key_id=None,
|
||||
metadata=None,
|
||||
current_username=None,
|
||||
llm_cost=None,
|
||||
):
|
||||
return OptimizationEvents.objects.acreate(
|
||||
event_type=event_type,
|
||||
|
|
@ -30,4 +33,9 @@ def log_optimization_event(
|
|||
metadata=metadata,
|
||||
created_at=datetime.now(UTC),
|
||||
current_username=current_username,
|
||||
llm_cost=llm_cost,
|
||||
)
|
||||
async def update_optimization_cost(trace_id: str, cost: float | None) -> None:
    """Atomically increment ``llm_cost`` for the event with ``trace_id``.

    Args:
        trace_id: Trace identifier of the optimization event to update.
        cost: Cost in USD to add. ``None`` (e.g. when the cost could not
            be computed) is treated as a no-op instead of raising from
            ``float(None)``.
    """
    if cost is None:
        return
    # NOTE(review): ``F("llm_cost") + cost`` evaluates to NULL in SQL when
    # the stored llm_cost is still NULL — confirm rows are created with a
    # numeric llm_cost, or coalesce the column before adding.
    await OptimizationEvents.objects.filter(trace_id=trace_id).aupdate(
        llm_cost=F("llm_cost") + float(cost)
    )
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ class OptimizationEvents(models.Model):
|
|||
trace_id = models.CharField(max_length=36, null=True, blank=True)
|
||||
pr_id = models.CharField(max_length=255, null=True, blank=True, unique=True)
|
||||
api_key_id = models.IntegerField(null=True, blank=True)
|
||||
llm_cost = models.FloatField(null=True, blank=True)
|
||||
metadata = models.JSONField(null=True, blank=True)
|
||||
created_at = models.DateTimeField(auto_now_add=True)
|
||||
current_username = models.CharField(
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from aiservice.env_specific import (
|
|||
debug_log_sensitive_data,
|
||||
debug_log_sensitive_data_from_callable,
|
||||
)
|
||||
from aiservice.models.aimodels import OPTIMIZE_MODEL
|
||||
from aiservice.models.aimodels import OPTIMIZE_MODEL, calculate_llm_cost
|
||||
from authapp.user import get_user_by_id
|
||||
from log_features.log_event import get_repository, log_optimization_event
|
||||
from log_features.log_features import log_features
|
||||
|
|
@ -51,7 +51,7 @@ async def optimize_python_code(
|
|||
n: int = 1,
|
||||
optimize_model: LLM = OPTIMIZE_MODEL,
|
||||
python_version: tuple[int, int, int] = (3, 12, 9),
|
||||
) -> list[OptimizeResponseItemSchema]:
|
||||
) -> tuple[list[OptimizeResponseItemSchema], float | None]:
|
||||
"""Optimize the given python code for performance using OpenAI's GPT-4 model.
|
||||
|
||||
Parameters
|
||||
|
|
@ -83,7 +83,7 @@ async def optimize_python_code(
|
|||
| ChatCompletionAssistantMessageParam
|
||||
| ChatCompletionToolMessageParam
|
||||
| ChatCompletionFunctionMessageParam
|
||||
] = [system_message, user_message]
|
||||
] = [system_message, user_message]
|
||||
async with create_openai_client() as openai_client:
|
||||
# TODO: Verify if the context window length is within the model capability
|
||||
try:
|
||||
|
|
@ -95,7 +95,7 @@ async def optimize_python_code(
|
|||
print(e)
|
||||
debug_log_sensitive_data(f"Failed to generate code for source:\n{ctx.source_code}")
|
||||
return []
|
||||
|
||||
llm_cost = calculate_llm_cost(output, optimize_model)
|
||||
debug_log_sensitive_data(f"OpenAIClient optimization response:\n{output.model_dump_json(indent=2)}")
|
||||
|
||||
if output.usage is not None:
|
||||
|
|
@ -119,8 +119,7 @@ async def optimize_python_code(
|
|||
debug_log_sensitive_data(f"error for source:\n{ctx.source_code}")
|
||||
debug_log_sensitive_data(f"Traceback: {e}")
|
||||
continue
|
||||
|
||||
return optimization_response_items
|
||||
return optimization_response_items, llm_cost
|
||||
|
||||
|
||||
class OptimizeSchema(Schema):
|
||||
|
|
@ -155,7 +154,7 @@ async def optimize(request, data: OptimizeSchema) -> tuple[int, OptimizeResponse
|
|||
)
|
||||
if not validate_trace_id(data.trace_id):
|
||||
return 400, OptimizeErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")
|
||||
optimization_response_items = await optimize_python_code(
|
||||
optimization_response_items, llm_cost = await optimize_python_code(
|
||||
request.user, ctx, data.dependency_code, n=5, python_version=python_version
|
||||
)
|
||||
if len(optimization_response_items) == 0:
|
||||
|
|
@ -187,6 +186,7 @@ async def optimize(request, data: OptimizeSchema) -> tuple[int, OptimizeResponse
|
|||
"num_optimizations": len(optimization_response_items),
|
||||
"experiment_metadata": data.experiment_metadata,
|
||||
},
|
||||
llm_cost=llm_cost
|
||||
)
|
||||
if hasattr(request, "should_log_features") and request.should_log_features:
|
||||
await log_features(
|
||||
|
|
|
|||
|
|
@ -14,7 +14,8 @@ from aiservice.env_specific import (
|
|||
debug_log_sensitive_data,
|
||||
debug_log_sensitive_data_from_callable,
|
||||
)
|
||||
from aiservice.models.aimodels import OPTIMIZE_MODEL
|
||||
from aiservice.models.aimodels import OPTIMIZE_MODEL, calculate_llm_cost
|
||||
from log_features.log_event import update_optimization_cost
|
||||
from log_features.log_features import log_features
|
||||
from optimizer.context_utils.optimizer_context import (
|
||||
BaseOptimizerContext,
|
||||
|
|
@ -44,6 +45,7 @@ USER_PROMPT = (current_dir / "user_prompt.md").read_text()
|
|||
|
||||
async def optimize_python_code_line_profiler(
|
||||
user_id: str,
|
||||
trace_id: str,
|
||||
line_profiler_results: str,
|
||||
ctx: BaseOptimizerContext,
|
||||
dependency_code: str | None = None,
|
||||
|
|
@ -94,6 +96,7 @@ async def optimize_python_code_line_profiler(
|
|||
output = await openai_client.with_options(max_retries=3).chat.completions.create(
|
||||
model=optimize_model.name, messages=messages, n=n
|
||||
)
|
||||
await update_optimization_cost(trace_id=trace_id, cost=calculate_llm_cost(output, optimize_model))
|
||||
except OpenAIError as e:
|
||||
print("OpenAI Code Generation error ...")
|
||||
print(e)
|
||||
|
|
@ -155,6 +158,7 @@ async def optimize(request, data: OptimizeSchemaLP) -> tuple[int, OptimizeRespon
|
|||
return 400, OptimizeErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")
|
||||
optimization_response_items = await optimize_python_code_line_profiler(
|
||||
user_id=request.user,
|
||||
trace_id=data.trace_id,
|
||||
ctx=ctx,
|
||||
dependency_code=data.dependency_code,
|
||||
line_profiler_results=data.line_profiler_results,
|
||||
|
|
|
|||
|
|
@ -14,7 +14,8 @@ from pydantic import ValidationError
|
|||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common_utils import validate_trace_id
|
||||
from aiservice.env_specific import create_claude_client, debug_log_sensitive_data
|
||||
from aiservice.models.aimodels import REFINEMENT_MODEL
|
||||
from aiservice.models.aimodels import REFINEMENT_MODEL, calculate_llm_cost
|
||||
from log_features.log_event import update_optimization_cost
|
||||
from log_features.log_features import log_features
|
||||
from optimizer.context_utils.refiner_context import BaseRefinerContext, RefinementContextData
|
||||
|
||||
|
|
@ -218,6 +219,7 @@ async def refinement(
|
|||
output = await claude_client.with_options(max_retries=2).chat.completions.create(
|
||||
model=optimize_model.name, messages=messages, n=1
|
||||
)
|
||||
llm_cost = calculate_llm_cost(output, optimize_model)
|
||||
except OpenAIError as e:
|
||||
print("refinement api: Claude Code Generation error ...")
|
||||
print(e)
|
||||
|
|
@ -249,6 +251,7 @@ async def refinement(
|
|||
source_code=refined_optimization,
|
||||
explanation=refined_explanation,
|
||||
original_explanation=ctx.data.optimized_explanation,
|
||||
llm_cost=llm_cost,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -276,6 +279,7 @@ class RefinementIntermediateResponseItemschema(Schema):
|
|||
optimization_id: str
|
||||
source_code: str
|
||||
original_explanation: str
|
||||
llm_cost: float
|
||||
|
||||
|
||||
class RefinementResponseItemschema(Schema):
|
||||
|
|
@ -326,9 +330,11 @@ async def refine(
|
|||
# simple filtering mechanism, remove empty strings and remove duplicates after removing trailing and leading whitespaces, validate with libcst
|
||||
filtered_refined_optimizations = []
|
||||
source_code_set = set()
|
||||
total_llm_cost = 0.0
|
||||
for elem in refinement_data:
|
||||
if isinstance(elem, OptimizeErrorResponseSchema):
|
||||
continue
|
||||
total_llm_cost += elem.llm_cost
|
||||
try:
|
||||
ctx.validate_python_module(elem.source_code)
|
||||
except cst.ParserSyntaxError as e:
|
||||
|
|
@ -367,6 +373,7 @@ async def refine(
|
|||
cei.optimization_id[:-4] + "refi": cei.explanation for cei in filtered_refined_optimizations
|
||||
},
|
||||
)
|
||||
await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
|
||||
return 200, Refinementschema(
|
||||
refinements=[
|
||||
RefinementResponseItemschema(
|
||||
|
|
|
|||
|
|
@ -10,8 +10,9 @@ from typing import SupportsIndex
|
|||
import isort
|
||||
from aiservice.common_utils import parse_python_version
|
||||
from aiservice.env_specific import create_openai_client, debug_log_sensitive_data
|
||||
from aiservice.models.aimodels import EXECUTE_MODEL, EXPLAIN_MODEL, LLM, PLAN_MODEL
|
||||
from aiservice.models.aimodels import EXECUTE_MODEL, EXPLAIN_MODEL, LLM, PLAN_MODEL, calculate_llm_cost
|
||||
from aiservice.models.functions_to_optimize import FunctionToOptimize
|
||||
from log_features.log_event import update_optimization_cost
|
||||
from log_features.log_features import log_features
|
||||
from ninja import NinjaAPI, Schema
|
||||
from pydantic import model_validator
|
||||
|
|
@ -127,6 +128,7 @@ async def generate_regression_tests_from_function(
|
|||
plan_model: LLM = PLAN_MODEL, # model used to generate text plans in steps 2 and 2b
|
||||
execute_model: LLM = EXECUTE_MODEL, # model used to generate code in step 3
|
||||
temperature: float = 0.4, # temperature = 0 can sometimes get stuck in repetitive loops, so we use 0.4
|
||||
trace_id: str = "",
|
||||
) -> str:
|
||||
"""Returns a unit test for a given Python function, using a 3-step GPT prompt."""
|
||||
# Step 1: Generate an explanation of the function
|
||||
|
|
@ -138,12 +140,14 @@ async def generate_regression_tests_from_function(
|
|||
"content": EXPLAIN_USER_PROMPT.format(function_name=function_name, function_code=function_code),
|
||||
}
|
||||
explain_messages = [explain_system_message, explain_user_message]
|
||||
total_llm_cost = 0.0
|
||||
if print_text:
|
||||
print_messages(explain_messages)
|
||||
try:
|
||||
explanation_response = await openai_client.with_options(max_retries=2).chat.completions.create(
|
||||
model=explain_model.name, messages=explain_messages, temperature=temperature
|
||||
)
|
||||
total_llm_cost += calculate_llm_cost(explanation_response, explain_model) or 0.0
|
||||
except Exception as e:
|
||||
logging.exception("OpenAI client error in explain step")
|
||||
raise TestGenerationFailedException(e) from e
|
||||
|
|
@ -168,6 +172,7 @@ async def generate_regression_tests_from_function(
|
|||
fetch_data_response = await openai_client.with_options(max_retries=2).chat.completions.create(
|
||||
model=execute_model.name, messages=fetch_data_messages, temperature=temperature
|
||||
)
|
||||
total_llm_cost += calculate_llm_cost(fetch_data_response, execute_model) or 0.0
|
||||
except Exception as e:
|
||||
logging.exception("OpenAI client error in explain step")
|
||||
raise TestGenerationFailedException(e) from e
|
||||
|
|
@ -212,6 +217,7 @@ To help unit test the function above, list diverse scenarios that the function s
|
|||
plan_response = await openai_client.with_options(max_retries=2).chat.completions.create(
|
||||
model=plan_model.name, messages=plan_messages, temperature=temperature
|
||||
)
|
||||
total_llm_cost += calculate_llm_cost(plan_response, plan_model) or 0.0
|
||||
except Exception as e:
|
||||
logging.exception("OpenAI client error in plan step")
|
||||
raise TestGenerationFailedException(e) from e
|
||||
|
|
@ -249,6 +255,7 @@ To help unit test the function above, list diverse scenarios that the function s
|
|||
elaboration_response = await openai_client.with_options(max_retries=2).chat.completions.create(
|
||||
model=plan_model.name, messages=elaboration_messages, temperature=temperature
|
||||
)
|
||||
total_llm_cost += calculate_llm_cost(elaboration_response, plan_model) or 0.0
|
||||
except Exception as e:
|
||||
logging.exception("OpenAI client error in elaboration step")
|
||||
raise TestGenerationFailedException(e) from e
|
||||
|
|
@ -297,6 +304,7 @@ To help unit test the function above, list diverse scenarios that the function s
|
|||
execute_response = await openai_client.with_options(max_retries=2).chat.completions.create(
|
||||
model=execute_model.name, messages=execute_messages, temperature=temperature
|
||||
)
|
||||
total_llm_cost += calculate_llm_cost(execute_response, execute_model) or 0.0
|
||||
except Exception as e:
|
||||
logging.exception("OpenAI client error in execute step")
|
||||
raise TestGenerationFailedException(e) from e
|
||||
|
|
@ -328,6 +336,7 @@ To help unit test the function above, list diverse scenarios that the function s
|
|||
if tries == 0:
|
||||
raise TestGenerationFailedException("Failed to generate test code after 2 tries.")
|
||||
|
||||
await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
|
||||
# return the unit test as a string
|
||||
return code
|
||||
|
||||
|
|
@ -407,6 +416,7 @@ async def testgen(request, data: TestGenSchema) -> tuple[int, TestGenResponseSch
|
|||
unit_test_package=data.test_framework,
|
||||
approx_min_cases_to_cover=10,
|
||||
python_version=python_version,
|
||||
trace_id=data.trace_id,
|
||||
)
|
||||
print("/testgen: Instrumenting tests...")
|
||||
instrumented_test_source = isort.code(
|
||||
|
|
|
|||
|
|
@ -15,9 +15,10 @@ from pydantic import model_validator
|
|||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common_utils import parse_python_version, validate_trace_id
|
||||
from aiservice.env_specific import create_openai_client, debug_log_sensitive_data
|
||||
from aiservice.models.aimodels import EXECUTE_MODEL, EXPLAIN_MODEL, LLM, PLAN_MODEL
|
||||
from aiservice.models.aimodels import EXECUTE_MODEL, EXPLAIN_MODEL, LLM, PLAN_MODEL, calculate_llm_cost
|
||||
from aiservice.models.functions_to_optimize import FunctionToOptimize
|
||||
from authapp.auth import AuthBearer
|
||||
from log_features.log_event import update_optimization_cost
|
||||
from log_features.log_features import log_features
|
||||
from testgen.instrumentation.edit_generated_test import parse_module_to_cst, replace_definition_with_import
|
||||
from testgen.instrumentation.instrument_new_tests import instrument_test_source
|
||||
|
|
@ -106,7 +107,8 @@ async def generate_regression_tests_from_function(
|
|||
explain_model: LLM = EXPLAIN_MODEL, # model used to generate text plans in step 1
|
||||
plan_model: LLM = PLAN_MODEL, # model used to generate text plans in steps 2 and 2b
|
||||
execute_model: LLM = EXECUTE_MODEL, # model used to generate code in step 3
|
||||
temperature: float = 0.4, # temperature = 0 can sometimes get stuck in repetitive loops, so we use 0.4
|
||||
temperature: float = 0.4,# temperature = 0 can sometimes get stuck in repetitive loops, so we use 0.4
|
||||
trace_id: str = ""
|
||||
) -> str:
|
||||
"""Returns a unit test for a given Python function, using a 3-step GPT prompt."""
|
||||
openai_client = create_openai_client()
|
||||
|
|
@ -151,12 +153,15 @@ To help unit test the function above, list diverse scenarios that the function s
|
|||
print_messages([execute_system_message, execute_user_message])
|
||||
print_messages([note_message])
|
||||
# TODO: Implement a fallback if the code is too long, implement a straightforward way to write the tests rather than the iterative approach
|
||||
total_llm_cost = 0.0
|
||||
tries = 2
|
||||
while tries > 0:
|
||||
try:
|
||||
execute_response = await openai_client.with_options(max_retries=2).chat.completions.create(
|
||||
model=execute_model.name, messages=execute_messages, temperature=temperature
|
||||
)
|
||||
total_llm_cost += calculate_llm_cost(execute_response, execute_model) or 0.0
|
||||
|
||||
except Exception as e:
|
||||
logging.exception("OpenAI client error in execute step")
|
||||
raise TestGenerationFailedError(e) from e
|
||||
|
|
@ -186,9 +191,12 @@ To help unit test the function above, list diverse scenarios that the function s
|
|||
logging.warning(f"Error: {e}")
|
||||
logging.warning(f"Generated code: {code}")
|
||||
continue
|
||||
await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
|
||||
|
||||
if tries == 0:
|
||||
raise TestGenerationFailedError("Failed to generate test code after 2 tries.")
|
||||
|
||||
|
||||
# return the unit test as a string
|
||||
return code
|
||||
|
||||
|
|
@ -281,6 +289,7 @@ async def testgen(
|
|||
unit_test_package=data.test_framework,
|
||||
approx_min_cases_to_cover=10,
|
||||
python_version=python_version,
|
||||
trace_id=data.trace_id,
|
||||
)
|
||||
generated_test_source = validate_testgen_code(
|
||||
og_generated_test_source, python_version[:2], max_lines_to_remove=40
|
||||
|
|
|
|||
|
|
@ -0,0 +1,2 @@
|
|||
-- AlterTable
|
||||
ALTER TABLE "optimization_events" ADD COLUMN "llm_cost" DOUBLE PRECISION;
|
||||
|
|
@ -188,6 +188,7 @@ model optimization_events {
|
|||
repository repositories? @relation(fields: [repository_id], references: [id], onDelete: SetNull)
|
||||
api_key cf_api_keys? @relation(fields: [api_key_id], references: [id], onDelete: SetNull)
|
||||
comments comments[]
|
||||
llm_cost Float?
|
||||
|
||||
@@index([event_type, created_at])
|
||||
@@index([repository_id, event_type])
|
||||
|
|
|
|||
Loading…
Reference in a new issue