simplify logic
This commit is contained in:
parent
72a356de8c
commit
ac986b1b86
1 changed file with 57 additions and 43 deletions
|
|
@ -177,17 +177,6 @@ def instrument_tests(
|
|||
return instrumented_behavior, instrumented_perf
|
||||
|
||||
|
||||
async def invoke_openai_testgen_model(messages: list[dict[str, str]], model: LLM, temperature: float) -> ChatCompletion:
    """Send one chat-completion request for test generation.

    The request goes through ``open_ai_client`` configured with
    ``max_retries=2``.

    Args:
        messages: Chat messages forwarded unchanged to the completions API.
        model: Model descriptor; only its ``name`` is sent to the API.
        temperature: Sampling temperature forwarded to the API.

    Returns:
        The completion object returned by the client.

    Raises:
        TestGenerationFailedError: If anything in the request path raises.
            The original exception is logged, reported to Sentry, and
            chained as the cause.
    """
    try:
        retrying_client = open_ai_client.with_options(max_retries=2)
        completion = await retrying_client.chat.completions.create(
            model=model.name,
            messages=messages,
            temperature=temperature,
        )
    except Exception as err:
        # Single boundary for client failures: log, report, then convert to
        # the domain error type while keeping the original as the cause.
        logging.exception("OpenAI client error in execute step")
        sentry_sdk.capture_exception(err)
        raise TestGenerationFailedError(err) from err
    else:
        return completion
|
||||
|
||||
|
||||
def parse_and_validate_llm_output(
|
||||
response_content: str, ctx: BaseTestGenContext, python_version: tuple[int, int, int], error_context: str
|
||||
) -> str:
|
||||
|
|
@ -212,6 +201,42 @@ def parse_and_validate_llm_output(
|
|||
raise
|
||||
|
||||
|
||||
@stamina.retry(on=(SyntaxError, ValueError), attempts=2)
async def generate_and_validate_test_code(
    messages: list[dict[str, str]],
    model: LLM,
    temperature: float,
    ctx: BaseTestGenContext,
    python_version: tuple[int, int, int],
    error_context: str,
    execute_model: LLM,
    cost_tracker: list[float],
    user_id: str,
    posthog_event_suffix: str,
) -> str:
    """Request test code from the chat-completions API and validate it.

    The whole call is wrapped in ``stamina.retry``: a ``SyntaxError`` or
    ``ValueError`` raised anywhere in the body triggers a second attempt
    (``attempts=2``), which issues a fresh API request.

    Args:
        messages: Chat messages forwarded unchanged to the completions API.
        model: Model whose ``name`` is sent in the API request.
        temperature: Sampling temperature forwarded to the API.
        ctx: Test-generation context passed through to the validator.
        python_version: Target Python version passed through to the validator.
        error_context: Label interpolated into log and error messages.
        execute_model: Model used for cost calculation and for the ``model``
            property of the usage event.
        cost_tracker: Output list; one cost value is appended per API call,
            including calls whose output later fails validation.
        user_id: User identifier attached to the usage event.
        posthog_event_suffix: Fragment interpolated into the usage event name.

    Returns:
        The value returned by ``parse_and_validate_llm_output`` for the
        first choice of the completion.

    Raises:
        ValueError: If the completion carries no message content.
        SyntaxError, ValueError: Per the retry decorator these are expected
            from validation; they propagate once retry attempts are exhausted.
    """
    response = await open_ai_client.with_options(max_retries=2).chat.completions.create(
        model=model.name, messages=messages, temperature=temperature
    )
    # Record the cost before validating so that a request whose output is
    # rejected (and retried) is still accounted for.
    cost = calculate_llm_cost(response, execute_model) or 0.0
    cost_tracker.append(cost)

    debug_log_sensitive_data(f"OpenAIClient {error_context}execute response:\n{response.model_dump_json(indent=2)}")

    if response.usage:
        ph(
            user_id,
            f"aiservice-testgen-{posthog_event_suffix}execute-openai-usage",
            properties={"model": execute_model.name, "usage": response.usage.model_dump_json()},
        )

    response_content = response.choices[0].message.content
    if response_content is None:
        # The SDK types message content as optional. Raising ValueError keeps
        # this case on the same retry/error path as any other invalid output
        # instead of handing None to a parser that expects str.
        msg = f"OpenAI returned no message content for {error_context}test generation."
        raise ValueError(msg)

    return parse_and_validate_llm_output(
        response_content=response_content,
        ctx=ctx,
        python_version=python_version,
        error_context=error_context,
    )
|
||||
|
||||
|
||||
async def generate_regression_tests_from_function(
|
||||
ctx: BaseTestGenContext,
|
||||
user_id: str,
|
||||
|
|
@ -227,39 +252,28 @@ async def generate_regression_tests_from_function(
|
|||
ctx=ctx, function_name=function_name, unit_test_package=unit_test_package, is_async=is_async
|
||||
)
|
||||
|
||||
total_llm_cost = 0.0
|
||||
max_tries = 2
|
||||
for attempt in range(max_tries):
|
||||
response = await invoke_openai_testgen_model(
|
||||
messages=execute_messages, model=execute_model, temperature=temperature
|
||||
cost_tracker = []
|
||||
try:
|
||||
validated_code = await generate_and_validate_test_code(
|
||||
messages=execute_messages,
|
||||
model=execute_model,
|
||||
temperature=temperature,
|
||||
ctx=ctx,
|
||||
python_version=python_version,
|
||||
error_context=error_context,
|
||||
execute_model=execute_model,
|
||||
cost_tracker=cost_tracker,
|
||||
user_id=user_id,
|
||||
posthog_event_suffix=posthog_event_suffix,
|
||||
)
|
||||
total_llm_cost += calculate_llm_cost(response, execute_model) or 0.0
|
||||
debug_log_sensitive_data(f"OpenAIClient {error_context}execute response:\n{response.model_dump_json(indent=2)}")
|
||||
|
||||
if response.usage:
|
||||
ph(
|
||||
user_id,
|
||||
f"aiservice-testgen-{posthog_event_suffix}execute-openai-usage",
|
||||
properties={"model": execute_model.name, "usage": response.usage.model_dump_json()},
|
||||
)
|
||||
|
||||
try:
|
||||
validated_code = parse_and_validate_llm_output(
|
||||
response_content=response.choices[0].message.content,
|
||||
ctx=ctx,
|
||||
python_version=python_version,
|
||||
error_context=error_context,
|
||||
)
|
||||
|
||||
await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
|
||||
return validated_code # noqa: TRY300
|
||||
except (SyntaxError, ValueError) as e:
|
||||
logging.warning("Attempt %d/%d failed: %s", attempt + 1, max_tries, e)
|
||||
if attempt + 1 >= max_tries:
|
||||
msg = f"Failed to generate valid {error_context}test code after {max_tries} tries."
|
||||
raise TestGenerationFailedError(msg) from e
|
||||
|
||||
raise TestGenerationFailedError("Exited generation loop unexpectedly.")
|
||||
total_llm_cost = sum(cost_tracker)
|
||||
await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
|
||||
return validated_code
|
||||
except (SyntaxError, ValueError) as e:
|
||||
total_llm_cost = sum(cost_tracker)
|
||||
await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
|
||||
msg = f"Failed to generate valid {error_context}test code after {len(cost_tracker)} tries."
|
||||
raise TestGenerationFailedError(msg) from e
|
||||
|
||||
|
||||
async def hack_for_demo(data: TestGenSchema, python_version: tuple[int, int, int]) -> TestGenResponseSchema:
|
||||
|
|
|
|||
Loading…
Reference in a new issue