simplify logic

This commit is contained in:
Kevin Turcios 2025-10-17 22:19:11 -07:00
parent 72a356de8c
commit ac986b1b86

View file

@@ -177,17 +177,6 @@ def instrument_tests(
return instrumented_behavior, instrumented_perf
async def invoke_openai_testgen_model(messages: list[dict[str, str]], model: LLM, temperature: float) -> ChatCompletion:
    """Send a chat-completion request to OpenAI for test generation.

    Args:
        messages: Chat messages (role/content dicts) forming the prompt.
        model: The LLM whose ``name`` is passed as the OpenAI model id.
        temperature: Sampling temperature for the completion.

    Returns:
        The raw ``ChatCompletion`` response from the OpenAI client.

    Raises:
        TestGenerationFailedError: Wrapping any error raised by the client,
            after logging it and reporting it to Sentry.
    """
    client = open_ai_client.with_options(max_retries=2)
    try:
        return await client.chat.completions.create(
            model=model.name,
            messages=messages,
            temperature=temperature,
        )
    except Exception as exc:
        # Boundary handler: log with traceback, report, then re-raise as a
        # domain-specific error so callers catch one exception type.
        logging.exception("OpenAI client error in execute step")
        sentry_sdk.capture_exception(exc)
        raise TestGenerationFailedError(exc) from exc
def parse_and_validate_llm_output(
response_content: str, ctx: BaseTestGenContext, python_version: tuple[int, int, int], error_context: str
) -> str:
@@ -212,6 +201,42 @@ def parse_and_validate_llm_output(
raise
@stamina.retry(on=(SyntaxError, ValueError), attempts=2)
async def generate_and_validate_test_code(
    messages: list[dict[str, str]],
    model: LLM,
    temperature: float,
    ctx: BaseTestGenContext,
    python_version: tuple[int, int, int],
    error_context: str,
    execute_model: LLM,
    cost_tracker: list[float],
    user_id: str,
    posthog_event_suffix: str,
) -> str:
    """Request test code from the model and validate the returned output.

    Retried up to 2 times (via ``stamina``) when validation raises
    ``SyntaxError`` or ``ValueError``.

    Args:
        messages: Chat messages forming the generation prompt.
        model: LLM used for the chat-completion request.
        temperature: Sampling temperature.
        ctx: Test-generation context forwarded to validation.
        python_version: Target Python version for validation.
        error_context: Prefix used in log/telemetry messages.
        execute_model: LLM used for cost calculation and usage telemetry.
        cost_tracker: Mutable list; this call appends its LLM cost so the
            caller can total spend across retries.
        user_id: User id for PostHog telemetry.
        posthog_event_suffix: Suffix composing the PostHog event name.

    Returns:
        The validated test code as a string.

    Raises:
        SyntaxError, ValueError: When the model output fails validation
            (triggers the stamina retry).
    """
    client = open_ai_client.with_options(max_retries=2)
    response = await client.chat.completions.create(
        model=model.name, messages=messages, temperature=temperature
    )

    # Record this attempt's spend before validation, so a failed attempt
    # that triggers a retry is still counted by the caller.
    cost_tracker.append(calculate_llm_cost(response, execute_model) or 0.0)

    debug_log_sensitive_data(f"OpenAIClient {error_context}execute response:\n{response.model_dump_json(indent=2)}")

    usage = response.usage
    if usage:
        ph(
            user_id,
            f"aiservice-testgen-{posthog_event_suffix}execute-openai-usage",
            properties={"model": execute_model.name, "usage": usage.model_dump_json()},
        )

    content = response.choices[0].message.content
    return parse_and_validate_llm_output(
        response_content=content,
        ctx=ctx,
        python_version=python_version,
        error_context=error_context,
    )
async def generate_regression_tests_from_function(
ctx: BaseTestGenContext,
user_id: str,
@@ -227,39 +252,28 @@ async def generate_regression_tests_from_function(
ctx=ctx, function_name=function_name, unit_test_package=unit_test_package, is_async=is_async
)
total_llm_cost = 0.0
max_tries = 2
for attempt in range(max_tries):
response = await invoke_openai_testgen_model(
messages=execute_messages, model=execute_model, temperature=temperature
cost_tracker = []
try:
validated_code = await generate_and_validate_test_code(
messages=execute_messages,
model=execute_model,
temperature=temperature,
ctx=ctx,
python_version=python_version,
error_context=error_context,
execute_model=execute_model,
cost_tracker=cost_tracker,
user_id=user_id,
posthog_event_suffix=posthog_event_suffix,
)
total_llm_cost += calculate_llm_cost(response, execute_model) or 0.0
debug_log_sensitive_data(f"OpenAIClient {error_context}execute response:\n{response.model_dump_json(indent=2)}")
if response.usage:
ph(
user_id,
f"aiservice-testgen-{posthog_event_suffix}execute-openai-usage",
properties={"model": execute_model.name, "usage": response.usage.model_dump_json()},
)
try:
validated_code = parse_and_validate_llm_output(
response_content=response.choices[0].message.content,
ctx=ctx,
python_version=python_version,
error_context=error_context,
)
await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
return validated_code # noqa: TRY300
except (SyntaxError, ValueError) as e:
logging.warning("Attempt %d/%d failed: %s", attempt + 1, max_tries, e)
if attempt + 1 >= max_tries:
msg = f"Failed to generate valid {error_context}test code after {max_tries} tries."
raise TestGenerationFailedError(msg) from e
raise TestGenerationFailedError("Exited generation loop unexpectedly.")
total_llm_cost = sum(cost_tracker)
await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
return validated_code
except (SyntaxError, ValueError) as e:
total_llm_cost = sum(cost_tracker)
await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
msg = f"Failed to generate valid {error_context}test code after {len(cost_tracker)} tries."
raise TestGenerationFailedError(msg) from e
async def hack_for_demo(data: TestGenSchema, python_version: tuple[int, int, int]) -> TestGenResponseSchema: