# Derived from https://github.com/openai/openai-cookbook/blob/main/examples/Unit_test_writing_using_a_multi-step_prompt.ipynb
from __future__ import annotations

import ast
import asyncio
import logging
import re
from pathlib import Path
from typing import TYPE_CHECKING, TypedDict

import sentry_sdk
import stamina
from ninja import NinjaAPI
from ninja.errors import HttpError
from openai import OpenAIError

from aiservice.analytics.posthog import ph
from aiservice.common_utils import parse_python_version, safe_isort, should_hack_for_demo, validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.llm import EXECUTE_MODEL, HAIKU_MODEL, OPENAI_MODEL, calculate_llm_cost, call_llm
from aiservice.models.functions_to_optimize import FunctionToOptimize
from authapp.auth import AuthenticatedRequest
from log_features.log_event import update_optimization_cost
from log_features.log_features import log_features
from optimizer.context_utils.context_helpers import split_markdown_code
from testgen.instrumentation.edit_generated_test import parse_module_to_cst, replace_definition_with_import
from testgen.instrumentation.instrument_new_tests import instrument_test_source
from testgen.models import (
    TestGenerationFailedError,
    TestGenErrorResponseSchema,
    TestGenResponseSchema,
    TestGenSchema,
    TestingMode,
)
from testgen.postprocessing.add_missing_imports import add_missing_imports_from_source
from testgen.postprocessing.code_validator import has_test_functions, validate_testgen_code
from testgen.postprocessing.postprocess_pipeline import postprocessing_testgen_pipeline
from testgen.testgen_context import BaseTestGenContext, TestGenContextData
from testgen.testgen_javascript import testgen_javascript

if TYPE_CHECKING:
    from openai.types.chat import ChatCompletionMessageParam

    from aiservice.llm import LLM


class InstrumentTestSourceArgs(TypedDict):
    test_source: str
    function_to_optimize: FunctionToOptimize
    helper_function_names: list[str]
    module_path: str
    test_module_path: str
    test_framework: str
    test_timeout: int
    python_version: tuple[int, int, int]


testgen_api = NinjaAPI(urls_namespace="testgen")


# Get the directory of the current file
current_dir = Path(__file__).parent
EXPLAIN_SYSTEM_PROMPT = (current_dir / "explain_system_prompt.md").read_text()
EXPLAIN_USER_PROMPT = (current_dir / "explain_user_prompt.md").read_text()

EXECUTE_SYSTEM_PROMPT = (current_dir / "execute_system_prompt.md").read_text()
EXECUTE_USER_PROMPT = (current_dir / "execute_user_prompt.md").read_text()

EXECUTE_ASYNC_SYSTEM_PROMPT = (current_dir / "execute_async_system_prompt.md").read_text()
EXECUTE_ASYNC_USER_PROMPT = (current_dir / "execute_async_user_prompt.md").read_text()

pattern = re.compile(r"^```python\s*\n(.*?)\n```", re.MULTILINE | re.DOTALL)

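# For example, a response such as:
#     Here are the tests:
#     ```python
#     def test_add():
#         assert add(1, 2) == 3
#     ```
# yields the fenced block's body in pattern.search(response).group(1); `add` here is a
# hypothetical function used purely for illustration.
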
def build_prompt(
    ctx: BaseTestGenContext, function_name: str, unit_test_package: str, *, is_async: bool
) -> tuple[list[dict[str, str]], str, str]:
    if is_async:
        execute_system_prompt = EXECUTE_ASYNC_SYSTEM_PROMPT
        execute_user_prompt = EXECUTE_ASYNC_USER_PROMPT
        plan_content = f"""A good unit test suite for an ASYNC function should aim to:
- Test the async function's behavior for a wide range of possible inputs
- Test edge cases that the author may not have foreseen, including async-specific edge cases
- Take advantage of the features of `{unit_test_package}` to make async tests easy to write and maintain
- Be easy to read and understand, with clean async code and descriptive names
- Be deterministic, so that the async tests always pass or fail in the same way
- Have tests sorted by difficulty, from easiest to hardest
- Try not to mock or stub any dependencies, so that the async testing environment is as close to production as possible
- Include concurrent execution test cases to assess the function's async performance and behavior
- Include throughput test cases to measure the function's performance under load and high-volume scenarios
- Test proper async/await patterns and coroutine handling
To help unit test the ASYNC function above, list diverse scenarios that the async function should be able to handle (and under each scenario, include a few examples as sub-bullets)."""
        posthog_event_suffix = "async-"
        error_context = "async "
    else:
        execute_system_prompt = EXECUTE_SYSTEM_PROMPT
        execute_user_prompt = EXECUTE_USER_PROMPT
        plan_content = f"""A good unit test suite should aim to:
- Test the function's behavior for a wide range of possible inputs
- Test edge cases that the author may not have foreseen
- Take advantage of the features of `{unit_test_package}` to make the tests easy to write and maintain
- Be easy to read and understand, with clean code and descriptive names
- Be deterministic, so that the tests always pass or fail in the same way
- Have tests sorted by difficulty, from easiest to hardest
- Try not to mock or stub any dependencies by using `{unit_test_package}`.mock or any other similar mocking or stubbing module, so that the testing environment is as close to the production environment as possible
- Include Large Scale Test Cases to assess the function's performance and scalability with large data samples.

To help unit test the function above, list diverse scenarios that the function should be able to handle (and under each scenario, include a few examples as sub-bullets)."""
        posthog_event_suffix = ""
        error_context = ""
    plan_user_message = {"role": "user", "content": plan_content}

    package_comment = ""
    # if unit_test_package == "pytest":
    #     package_comment = "# below, each test case is represented by a tuple passed to the @pytest.mark.parametrize decorator"
    execute_system_message = {
        "role": "system",
        "content": execute_system_prompt.format(function_name=ctx.data.qualified_name),
    }

    execute_messages = [execute_system_message, plan_user_message]

    all_notes = ctx.generate_notes_markdown()
    note_message = {"role": "user", "content": all_notes}
    execute_messages += [note_message]

    execute_user_message = {
        "role": "user",
        "content": execute_user_prompt.format(
            unit_test_package=unit_test_package,
            function_name=function_name,
            function_code=ctx.data.source_code_being_tested,
            package_comment=package_comment,
        ),
    }

    execute_messages += [execute_user_message]
    return execute_messages, posthog_event_suffix, error_context

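# Illustrative shape of the messages list returned by build_prompt() above (actual contents
# come from the prompt .md files, the context notes, and the function source):
# [
#     {"role": "system", "content": <execute system prompt with the qualified function name>},
#     {"role": "user", "content": <plan_content bullet list>},
#     {"role": "user", "content": <ctx.generate_notes_markdown()>},
#     {"role": "user", "content": <execute user prompt with the function code>},
# ]
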
def instrument_tests(
    generated_test_source: str, data: TestGenSchema, python_version: tuple[int, int, int]
) -> tuple[str | None, str | None]:
    common_args: InstrumentTestSourceArgs = {
        "test_source": generated_test_source,
        "function_to_optimize": data.function_to_optimize,
        "helper_function_names": data.helper_function_names or [],
        "module_path": data.module_path,
        "test_module_path": data.test_module_path,
        "test_framework": data.test_framework,
        "test_timeout": data.test_timeout,
        "python_version": python_version,
    }
    # instrument_test_source() already applies isort via format_and_float_to_top()
    # No need to apply isort again here (was causing double formatting overhead)
    behavior_result = instrument_test_source(**common_args, mode=TestingMode.BEHAVIOR)
    perf_result = instrument_test_source(**common_args, mode=TestingMode.PERFORMANCE)
    return behavior_result, perf_result

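# Note: instrument_test_source() returns None when instrumentation fails, so either result
# from instrument_tests() may be None; callers treat a None result as a failed generation attempt.
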
def parse_and_validate_llm_output(
    response_content: str, ctx: BaseTestGenContext, python_version: tuple[int, int, int], error_context: str
) -> str:
    try:
        if "```python" not in response_content:
            sentry_sdk.capture_message("LLM response did not contain a Python code block:\n" + response_content)
            raise ValueError("LLM response did not contain a Python code block.")
        pattern_res = pattern.search(response_content)
        if not pattern_res:
            raise ValueError("No Python code block found in the LLM response.")

        code = pattern_res.group(1)
        cleaned_code = validate_testgen_code(code, python_version[:2], max_lines_to_remove=120)

        if ctx.did_generate_ellipsis(cleaned_code, python_version):
            msg = f"Ellipsis in generated {error_context}test code, regenerating..."
            raise SyntaxError(msg)
        return cleaned_code  # noqa: TRY300
    except Exception as e:
        sentry_sdk.capture_exception(e)
        raise

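# stamina retries the coroutine below on SyntaxError, ValueError, and OpenAIError, for at
# most 2 attempts. cost_tracker is appended to on every attempt, so the caller can sum it
# for billing and report len(cost_tracker) as the number of LLM calls actually made.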
@stamina.retry(on=(SyntaxError, ValueError, OpenAIError), attempts=2)
async def generate_and_validate_test_code(
    messages: list[ChatCompletionMessageParam],
    model: LLM,
    ctx: BaseTestGenContext,
    python_version: tuple[int, int, int],
    error_context: str,
    execute_model: LLM,
    cost_tracker: list[float],
    user_id: str,
    posthog_event_suffix: str,
    trace_id: str = "",
    call_sequence: int | None = None,
) -> str:
    obs_context: dict | None = {"call_sequence": call_sequence} if call_sequence is not None else None
    response = await call_llm(
        llm=model,
        messages=messages,
        call_type="test_generation",
        trace_id=trace_id,
        user_id=user_id,
        python_version=".".join(str(v) for v in python_version),
        context=obs_context,
    )

    cost = calculate_llm_cost(response.raw_response, execute_model)
    cost_tracker.append(cost)

    debug_log_sensitive_data(
        f"OpenAIClient {error_context}execute response:\n{response.raw_response.model_dump_json(indent=2)}"
    )

    if response.raw_response.usage:
        ph(
            user_id,
            f"aiservice-testgen-{posthog_event_suffix}execute-openai-usage",
            properties={"model": execute_model.name, "usage": response.raw_response.usage.model_dump_json()},
        )
    # Parse and validate
    validated_code = parse_and_validate_llm_output(
        response_content=response.content, ctx=ctx, python_version=python_version, error_context=error_context
    )
    return validated_code

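# End-to-end flow for generate_regression_tests_from_function() below: build the prompt,
# generate and validate test code via the LLM, postprocess the CST, add any imports the LLM
# forgot, replace a locally redefined target function with an import, produce instrumented
# behavior/perf variants, isort the result (with a syntax-safety fallback), and finally
# sanity-check that test functions survived all the processing.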
@stamina.retry(on=TestGenerationFailedError, attempts=2)
async def generate_regression_tests_from_function(
    ctx: BaseTestGenContext,
    user_id: str,
    function_name: str,
    python_version: tuple[int, int, int],
    data: TestGenSchema,
    unit_test_package: str = "pytest",
    execute_model: LLM = EXECUTE_MODEL,
    is_async: bool = False,  # noqa: FBT001, FBT002
    trace_id: str = "",
    call_sequence: int | None = None,
) -> tuple[str, str | None, str | None]:
    execute_messages, posthog_event_suffix, error_context = build_prompt(
        ctx=ctx, function_name=function_name, unit_test_package=unit_test_package, is_async=is_async
    )
    cost_tracker = []
    try:
        validated_code = await generate_and_validate_test_code(
            messages=execute_messages,
            model=execute_model,
            ctx=ctx,
            python_version=python_version,
            error_context=error_context,
            execute_model=execute_model,
            cost_tracker=cost_tracker,
            user_id=user_id,
            posthog_event_suffix=posthog_event_suffix,
            trace_id=trace_id,
            call_sequence=call_sequence,
        )
        total_llm_cost = sum(cost_tracker)
        await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost, user_id=user_id)
        processed_cst = postprocessing_testgen_pipeline(
            parse_module_to_cst(validated_code),
            data.helper_function_names or [],
            data.function_to_optimize,
            data.module_path,
        )
        # Add missing imports for symbols defined in the source module but not imported in the test.
        # This handles cases where the LLM redefines some classes locally but forgets others.
        source_code_blocks = split_markdown_code(data.source_code_being_tested)
        # Combine all source code blocks to find all available symbols
        combined_source = "\n".join(source_code_blocks.values())
        processed_cst = add_missing_imports_from_source(processed_cst, combined_source, data.module_path)
        generated_test_source = replace_definition_with_import(
            processed_cst.code, data.function_to_optimize, data.module_path
        )
        instrumented_behavior_tests, instrumented_perf_tests = instrument_tests(
            generated_test_source, data, python_version
        )
        if instrumented_behavior_tests is None or instrumented_perf_tests is None:
            msg = (
                f"An error was detected in the function to optimize; is it valid Python code? trace_id={trace_id}"
            )
            logging.error(msg)
            raise TestGenerationFailedError(msg)
        sorted_generated_tests = safe_isort(generated_test_source)
        try:
            parse_module_to_cst(sorted_generated_tests)
            generated_test_source = sorted_generated_tests
        except Exception:  # noqa: BLE001
            sentry_sdk.capture_message("isort caused a syntax error in testgen; returning un-sorted code.")
        tree = ast.parse(generated_test_source, feature_version=python_version[:2])
        if not has_test_functions(tree):  # sanity check: after all the processing, somehow no test functions remain
            msg = f"No test functions were found in the generated test code. trace_id={trace_id}"
            logging.error(msg)
            raise TestGenerationFailedError(msg)
        return generated_test_source, instrumented_behavior_tests, instrumented_perf_tests  # noqa: TRY300
    except (SyntaxError, ValueError) as e:
        total_llm_cost = sum(cost_tracker)
        await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost, user_id=user_id)
        msg = f"Failed to generate valid {error_context}test code after {len(cost_tracker)} tries. trace_id={trace_id}"
        logging.error(msg)
        raise TestGenerationFailedError(msg) from e

async def hack_for_demo(data: TestGenSchema, python_version: tuple[int, int, int]) -> TestGenResponseSchema:
    if data.test_index == 0:
        generated_test_source = '# imports\n# function to test\nfrom __future__ import annotations\n\nimport pytest  # used for our unit tests\nfrom codeflash.result.common_tags import find_common_tags\n\n# unit tests\n\ndef test_single_article():\n    # Single article should return its tags\n    articles = [{"tags": ["python", "coding", "tutorial"]}]\n    codeflash_output = find_common_tags(articles)\n    # Outputs were verified to be equal to the original implementation\n\ndef test_multiple_articles_with_common_tags():\n    # Multiple articles with common tags should return the common tags\n    articles = [\n        {"tags": ["python", "coding"]},\n        {"tags": ["python", "data"]},\n        {"tags": ["python", "machine learning"]}\n    ]\n    codeflash_output = find_common_tags(articles)\n    # Outputs were verified to be equal to the original implementation\n\ndef test_empty_list_of_articles():\n    # Empty list of articles should return an empty set\n    articles = []\n    codeflash_output = find_common_tags(articles)\n    # Outputs were verified to be equal to the original implementation\n\ndef test_articles_with_no_common_tags():\n    # Articles with no common tags should return an empty set\n    articles = [\n        {"tags": ["python"]},\n        {"tags": ["java"]},\n        {"tags": ["c++"]}\n    ]\n    codeflash_output = find_common_tags(articles)\n    # Outputs were verified to be equal to the original implementation\n\ndef test_articles_with_empty_tag_lists():\n    # Articles with some empty tag lists should return an empty set\n    articles = [\n        {"tags": []},\n        {"tags": ["python"]},\n        {"tags": ["python", "java"]}\n    ]\n    codeflash_output = find_common_tags(articles)\n    # Outputs were verified to be equal to the original implementation\n\ndef test_all_articles_with_empty_tag_lists():\n    # All articles with empty tag lists should return an empty set\n    articles = [\n        {"tags": []},\n        {"tags": []},\n        {"tags": []}\n    ]\n    codeflash_output = find_common_tags(articles)\n    # Outputs were verified to be equal to the original implementation\n\ndef test_tags_with_special_characters():\n    # Tags with special characters should be handled correctly\n    articles = [\n        {"tags": ["python!", "coding"]},\n        {"tags": ["python!", "data"]}\n    ]\n    codeflash_output = find_common_tags(articles)\n    # Outputs were verified to be equal to the original implementation\n\ndef test_case_sensitivity():\n    # Tags with different cases should not be considered the same\n    articles = [\n        {"tags": ["Python", "coding"]},\n        {"tags": ["python", "data"]}\n    ]\n    codeflash_output = find_common_tags(articles)\n    # Outputs were verified to be equal to the original implementation\n\ndef test_large_number_of_articles():\n    # Large number of articles with a common tag should return that tag\n    articles = [{"tags": ["common_tag", f"tag{i}"]} for i in range(1000)]\n    codeflash_output = find_common_tags(articles)\n    # Outputs were verified to be equal to the original implementation\n\ndef test_large_number_of_tags():\n    # Large number of tags with some common tags should return the common tags\n    articles = [\n        {"tags": [f"tag{i}" for i in range(1000)]},\n        {"tags": [f"tag{i}" for i in range(500, 1500)]}\n    ]\n    expected = {f"tag{i}" for i in range(500, 1000)}\n    codeflash_output = find_common_tags(articles)\n    # Outputs were verified to be equal to the original implementation\n\ndef test_mixed_length_of_tag_lists():\n    # Articles with mixed length of tag lists should return the common tags\n    articles = [\n        {"tags": ["python", "coding"]},\n        {"tags": ["python"]},\n        {"tags": ["python", "coding", "tutorial"]}\n    ]\n    codeflash_output = find_common_tags(articles)\n    # Outputs were verified to be equal to the original implementation\n\ndef test_tags_with_different_data_types():\n    # Tags with different data types should only consider strings\n    articles = [\n        {"tags": ["python",'
    else:
        generated_test_source = '# imports\n# function to test\nfrom __future__ import annotations\n\nimport pytest  # used for our unit tests\nfrom codeflash.result.common_tags import find_common_tags\n\n# unit tests\n\ndef test_empty_input_list():\n    # Test with an empty list\n    codeflash_output = find_common_tags([])\n    # Outputs were verified to be equal to the original implementation\n\ndef test_single_article():\n    # Test with a single article with tags\n    codeflash_output = find_common_tags([{"tags": ["python", "coding", "development"]}])\n    # Test with a single article with no tags\n    codeflash_output = find_common_tags([{"tags": []}])\n    # Outputs were verified to be equal to the original implementation\n\ndef test_multiple_articles_some_common_tags():\n    # Test with multiple articles having some common tags\n    articles = [\n        {"tags": ["python", "coding", "development"]},\n        {"tags": ["python", "development", "tutorial"]},\n        {"tags": ["python", "development", "guide"]}\n    ]\n    codeflash_output = find_common_tags(articles)\n\n    articles = [\n        {"tags": ["tech", "news"]},\n        {"tags": ["tech", "gadgets"]},\n        {"tags": ["tech", "reviews"]}\n    ]\n    codeflash_output = find_common_tags(articles)\n    # Outputs were verified to be equal to the original implementation\n\ndef test_multiple_articles_no_common_tags():\n    # Test with multiple articles having no common tags\n    articles = [\n        {"tags": ["python", "coding"]},\n        {"tags": ["development", "tutorial"]},\n        {"tags": ["guide", "learning"]}\n    ]\n    codeflash_output = find_common_tags(articles)\n\n    articles = [\n        {"tags": ["apple", "banana"]},\n        {"tags": ["orange", "grape"]},\n        {"tags": ["melon", "kiwi"]}\n    ]\n    codeflash_output = find_common_tags(articles)\n    # Outputs were verified to be equal to the original implementation\n\ndef test_articles_with_duplicate_tags():\n    # Test with articles having duplicate tags\n    articles = [\n        {"tags": ["python", "python", "coding"]},\n        {"tags": ["python", "development", "python"]},\n        {"tags": ["python", "guide", "python"]}\n    ]\n    codeflash_output = find_common_tags(articles)\n\n    articles = [\n        {"tags": ["tech", "tech", "news"]},\n        {"tags": ["tech", "tech", "gadgets"]},\n        {"tags": ["tech", "tech", "reviews"]}\n    ]\n    codeflash_output = find_common_tags(articles)\n    # Outputs were verified to be equal to the original implementation\n\ndef test_articles_with_mixed_case_tags():\n    # Test with articles having mixed case tags\n    articles = [\n        {"tags": ["Python", "Coding"]},\n        {"tags": ["python", "Development"]},\n        {"tags": ["PYTHON", "Guide"]}\n    ]\n    codeflash_output = find_common_tags(articles)  # Assuming case sensitivity\n\n    articles = [\n        {"tags": ["Tech", "News"]},\n        {"tags": ["tech", "Gadgets"]},\n        {"tags": ["TECH", "Reviews"]}\n    ]\n    codeflash_output = find_common_tags(articles)  # Assuming case sensitivity\n    # Outputs were verified to be equal to the original implementation\n\ndef test_articles_with_non_string_tags():\n    # Test with articles having non-string tags\n    articles = [\n        {"tags": ["python", 123, "coding"]},\n        {"tags": ["python", "development", 123]},\n        {"tags": ["python", "guide", 123]}\n    ]\n    codeflash_output = find_common_tags(articles)\n\n    articles = [\n        {"tags": [None, "news"]},\n        {"tags": ["tech", None]},\n        {"tags": [None, "reviews"]}\n    ]\n    codeflash_output = find_common_tags(articles)\n    # Outputs were verified to be equal to the original implementation\n\ndef test_large_scale_test_cases():\n    # Test with large scale input where all tags should be common\n    articles = [\n        {"tags": ["tag" + str(i) for i in range(1000)]} for _ in range(100)\n    ]\n    expected_output = {"tag" + str(i) for i in range(1000)}\n    codeflash_output = find_common_tags(articles)\n\n    # Test with large scale input wh'
    instrumented_behavior_tests, instrumented_perf_tests = instrument_tests(generated_test_source, data, python_version)
    await asyncio.sleep(5)
    return TestGenResponseSchema(
        generated_tests=generated_test_source,
        instrumented_behavior_tests=instrumented_behavior_tests or "",
        instrumented_perf_tests=instrumented_perf_tests or "",
    )

async def hack_for_demo_gsq(data: TestGenSchema, python_version: tuple[int, int, int]) -> TestGenResponseSchema:
    if data.test_index == 0:
        generated_test_source = "from datetime import datetime, timedelta\n# function to test\nfrom functools import reduce\n\nimport numpy as np\nimport pandas as pd\n# imports\nimport pytest  # used for our unit tests\nfrom gs_quant.timeseries.algebra import weighted_sum\n\n\nclass MqTypeError(TypeError): pass\nclass MqValueError(ValueError): pass\nfrom gs_quant.timeseries.algebra import weighted_sum\n\n# unit tests\n\n# --- Basic Test Cases ---\n\ndef test_weighted_sum_simple_two_series():\n    # Simple case: two series, same index, weights sum to 1\n    idx = pd.date_range('2024-01-01', periods=3)\n    s1 = pd.Series([1, 2, 3], index=idx)\n    s2 = pd.Series([4, 5, 6], index=idx)\n    weights = [0.7, 0.3]\n    codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n    expected = (s1 * 0.7 + s2 * 0.3) / (0.7 + 0.3)\n\ndef test_weighted_sum_three_series_equal_weights():\n    # Three series, equal weights\n    idx = pd.date_range('2024-01-01', periods=3)\n    s1 = pd.Series([1, 2, 3], index=idx)\n    s2 = pd.Series([4, 5, 6], index=idx)\n    s3 = pd.Series([7, 8, 9], index=idx)\n    weights = [1, 1, 1]\n    codeflash_output = weighted_sum([s1, s2, s3], weights); result = codeflash_output\n    expected = (s1 + s2 + s3) / 3\n\ndef test_weighted_sum_weights_not_normalized():\n    # Weights do not sum to 1\n    idx = pd.date_range('2024-01-01', periods=2)\n    s1 = pd.Series([10, 20], index=idx)\n    s2 = pd.Series([30, 40], index=idx)\n    weights = [2, 3]\n    codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n    expected = (s1 * 2 + s2 * 3) / (2 + 3)\n\n# --- Edge Test Cases ---\n\ndef test_weighted_sum_empty_series_list():\n    # Empty input series list\n    with pytest.raises(MqValueError):\n        weighted_sum([], [1, 2])\n\ndef test_weighted_sum_empty_weights_list():\n    # Empty weights list\n    idx = pd.date_range('2024-01-01', periods=2)\n    s1 = pd.Series([1, 2], index=idx)\n    with pytest.raises(MqValueError):\n        weighted_sum([s1], [])\n\ndef test_weighted_sum_weights_length_mismatch():\n    # Weights and series length mismatch\n    idx = pd.date_range('2024-01-01', periods=2)\n    s1 = pd.Series([1, 2], index=idx)\n    s2 = pd.Series([3, 4], index=idx)\n    with pytest.raises(MqValueError):\n        weighted_sum([s1, s2], [1])\n\ndef test_weighted_sum_non_series_input():\n    # Non-pandas Series input\n    idx = pd.date_range('2024-01-01', periods=2)\n    s1 = pd.Series([1, 2], index=idx)\n    with pytest.raises(MqTypeError):\n        weighted_sum([s1, [3, 4]], [0.5, 0.5])\n\ndef test_weighted_sum_non_numeric_weights():\n    # Non-numeric weights\n    idx = pd.date_range('2024-01-01', periods=2)\n    s1 = pd.Series([1, 2], index=idx)\n    s2 = pd.Series([3, 4], index=idx)\n    with pytest.raises(MqTypeError):\n        weighted_sum([s1, s2], [0.5, \"a\"])\n\ndef test_weighted_sum_disjoint_calendars():\n    # Series with disjoint calendars (no overlap)\n    idx1 = pd.date_range('2024-01-01', periods=2)\n    idx2 = pd.date_range('2024-02-01', periods=2)\n    s1 = pd.Series([1, 2], index=idx1)\n    s2 = pd.Series([3, 4], index=idx2)\n    weights = [0.5, 0.5]\n    codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n\ndef test_weighted_sum_partial_overlap_calendars():\n    # Series with partial overlap in calendar\n    idx1 = pd.date_range('2024-01-01', periods=3)\n    idx2 = pd.date_range('2024-01-02', periods=3)\n    s1 = pd.Series([1, 2, 3], index=idx1)\n    s2 = pd.Series([4, 5, 6], index=idx2)\n    weights = [1, 1]\n    codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n    # Only dates present in both should be included\n    expected_idx = idx1.intersection(idx2)\n    expected = (s1.reindex(expected_idx) + s2.reindex(expected_idx)) / 2\n\ndef test_weighted_sum_zero_weights():\n    # Zero weights (should return NaN for all dates)\n    idx = pd.date_range('2024-01-01', periods=2)\n    s1 = pd.Series([1, 2], index=idx)\n    s2 = pd.Series([3, 4], index=idx)\n    weights = [0, 0]\n    codeflash_out"
    else:
        generated_test_source = "from datetime import datetime, timedelta\n\nimport pandas as pd\n# imports\nimport pytest  # used for our unit tests\nfrom gs_quant.timeseries.algebra import weighted_sum\n\n# ---------- Unit Tests ----------\n\n# Helper to create pd.Series with consecutive dates\ndef make_series(start, count, value_func=lambda i: i):\n    dates = [datetime(2020, 1, 1) + timedelta(days=i) for i in range(count)]\n    values = [value_func(i) for i in range(count)]\n    return pd.Series(values, index=dates)\n\n# 1. Basic Test Cases\n\ndef test_weighted_sum_simple_two_series_equal_weights():\n    # Two series, equal weights\n    s1 = make_series(0, 5, lambda i: 2)\n    s2 = make_series(0, 5, lambda i: 4)\n    codeflash_output = weighted_sum([s1, s2], [1, 1]); result = codeflash_output\n    expected = pd.Series([3.0]*5, index=s1.index)\n\ndef test_weighted_sum_simple_two_series_unequal_weights():\n    # Two series, weights 2:1\n    s1 = make_series(0, 5, lambda i: 2)\n    s2 = make_series(0, 5, lambda i: 8)\n    codeflash_output = weighted_sum([s1, s2], [2, 1]); result = codeflash_output\n    expected = pd.Series([(2*2+1*8)/3]*5, index=s1.index)\n\ndef test_weighted_sum_three_series_varied_weights():\n    # Three series, varied weights\n    s1 = make_series(0, 3, lambda i: 1)\n    s2 = make_series(0, 3, lambda i: 3)\n    s3 = make_series(0, 3, lambda i: 5)\n    codeflash_output = weighted_sum([s1, s2, s3], [0.2, 0.3, 0.5]); result = codeflash_output\n    expected = pd.Series([1*0.2+3*0.3+5*0.5], index=s1.index)\n    expected = pd.Series([1*0.2+3*0.3+5*0.5]*3, index=s1.index)\n\ndef test_weighted_sum_one_series():\n    # Single series, should return the series itself\n    s1 = make_series(0, 4, lambda i: i+5)\n    codeflash_output = weighted_sum([s1], [1]); result = codeflash_output\n\n# 2. Edge Test Cases\n\ndef test_weighted_sum_empty_series_list():\n    # Empty series list should raise ValueError\n    with pytest.raises(ValueError):\n        weighted_sum([], [])\n\ndef test_weighted_sum_weights_length_mismatch():\n    # Series and weights length mismatch\n    s1 = make_series(0, 3)\n    s2 = make_series(0, 3)\n    with pytest.raises(ValueError):\n        weighted_sum([s1, s2], [1])\n\ndef test_weighted_sum_non_series_input():\n    # Non-Series in series list\n    s1 = make_series(0, 3)\n    with pytest.raises(TypeError):\n        weighted_sum([s1, [1, 2, 3]], [1, 1])\n\ndef test_weighted_sum_non_numeric_weights():\n    # Non-numeric in weights\n    s1 = make_series(0, 3)\n    s2 = make_series(0, 3)\n    with pytest.raises(TypeError):\n        weighted_sum([s1, s2], [1, 'a'])\n\ndef test_weighted_sum_zero_weights():\n    # All weights zero should return NaN series\n    s1 = make_series(0, 3, lambda i: 1)\n    s2 = make_series(0, 3, lambda i: 2)\n    codeflash_output = weighted_sum([s1, s2], [0, 0]); result = codeflash_output\n\ndef test_weighted_sum_negative_weights():\n    # Negative weights\n    s1 = make_series(0, 3, lambda i: 2)\n    s2 = make_series(0, 3, lambda i: 4)\n    codeflash_output = weighted_sum([s1, s2], [-1, 2]); result = codeflash_output\n    expected = pd.Series([(2*-1+4*2)/(-1+2)]*3, index=s1.index)\n\ndef test_weighted_sum_partial_overlap_indices():\n    # Series with partially overlapping indices\n    s1 = make_series(0, 5, lambda i: i)\n    s2 = make_series(2, 5, lambda i: i+10)\n    # Overlap is on days 2,3,4\n    overlap_dates = s1.index.intersection(s2.index)\n    codeflash_output = weighted_sum([s1, s2], [1, 1]); result = codeflash_output\n    expected = pd.Series([(s1[d]+s2[d])/2 for d in overlap_dates], index=overlap_dates)\n\ndef test_weighted_sum_series_with_nans():\n    # Series with NaNs, intersection should drop NaN dates\n    s1 = make_series(0, 4, lambda i: float('nan') if i==2 else i)\n    s2 = make_series(0, 4, lambda i: i)\n    # Only dates where both have non-NaN should be used\n    codeflash_output = weighted_sum([s1, s2], [1, 1]); result = codeflash_output\n\ndef test_weighted_sum_series_with_different_freq():\n    # Series with different frequencies\n    s1 = make_se"
    instrumented_behavior_tests, instrumented_perf_tests = instrument_tests(generated_test_source, data, python_version)
    await asyncio.sleep(2)
    return TestGenResponseSchema(
        generated_tests=generated_test_source,
        instrumented_behavior_tests=instrumented_behavior_tests or "",
        instrumented_perf_tests=instrumented_perf_tests or "",
    )

def validate_request_data(data: TestGenSchema) -> tuple[tuple[int, int, int], BaseTestGenContext]:
    if data.test_framework not in ["unittest", "pytest"]:
        raise HttpError(400, "Invalid test framework. We only support unittest and pytest.")
    if not data.function_to_optimize:
        raise HttpError(400, "Invalid function to optimize. It is empty.")
    if not validate_trace_id(data.trace_id):
        raise HttpError(400, "Invalid trace ID. Please provide a valid UUIDv4.")
    try:
        python_version = parse_python_version(data.python_version)
    except ValueError:
        raise HttpError(400, "Invalid Python version, it should look like 3.x.x. We only support Python 3.9 and above.")  # noqa: B904
    try:
        ctx = BaseTestGenContext.get_dynamic_context(
            TestGenContextData(
                source_code_being_tested=data.source_code_being_tested,
                qualified_name=data.function_to_optimize.qualified_name,
            )
        )
        ctx.validate_python_module(feature_version=python_version[:2])
    except SyntaxError:
        raise HttpError(400, "Invalid source code. It is not valid Python code.")  # noqa: B904
    return python_version, ctx

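# Hedged sketch of a request body the endpoint below accepts (field names taken from the
# TestGenSchema accesses in this module; values are illustrative only):
#   POST /testgen/
#   {
#       "language": "python",
#       "test_framework": "pytest",
#       "python_version": "3.11.4",
#       "trace_id": "<uuid4>",
#       "source_code_being_tested": "<module source containing the function>",
#       "function_to_optimize": {"qualified_name": "...", "function_name": "..."},
#       "module_path": "pkg/mod.py",
#       "test_module_path": "tests/test_mod.py",
#       "test_timeout": 15,
#       "test_index": 0,
#       "is_async": false,
#       "call_sequence": 1
#   }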
@testgen_api.post(
    "/", response={200: TestGenResponseSchema, 400: TestGenErrorResponseSchema, 500: TestGenErrorResponseSchema}
)
async def testgen(
    request: AuthenticatedRequest, data: TestGenSchema
) -> tuple[int, TestGenResponseSchema | TestGenErrorResponseSchema]:
    # Route based on language
    if data.language in ("javascript", "typescript"):
        return await testgen_javascript(request, data)
    # Default: Python test generation
    return await testgen_python(request, data)


async def testgen_python(
    request: AuthenticatedRequest, data: TestGenSchema
) -> tuple[int, TestGenResponseSchema | TestGenErrorResponseSchema]:
    """Generate Python tests using LLMs."""
    ph(request.user, "aiservice-testgen-called", properties={"language": "python"})

    try:
        python_version, ctx = validate_request_data(data)
    except HttpError as e:
        e.add_note(f"Testgen request validation error: {e.status_code} {e.message}")
        sentry_sdk.capture_exception(e)
        return e.status_code, TestGenErrorResponseSchema(error=e.message)

    if should_hack_for_demo(data.source_code_being_tested):
        # Only short-circuit when one of the known demo functions is present; otherwise
        # fall through to real generation instead of returning an unset response.
        if "find_common_tags" in data.source_code_being_tested:
            return 200, await hack_for_demo(data, python_version)
        if "weighted_sum" in data.source_code_being_tested:
            return 200, await hack_for_demo_gsq(data, python_version)

    logging.info("/testgen: Generating tests...")
    try:
        debug_log_sensitive_data(f"Generating tests for function {data.function_to_optimize.function_name}")

        # Using different LLMs for different test_index values to get more diverse tests
        test_index = data.test_index if data.test_index is not None else 0
        if test_index % 2 == 0:
            execute_model = OPENAI_MODEL
            model_source = "OpenAI"
        else:
            execute_model = HAIKU_MODEL
            model_source = "Anthropic"

        logging.info(
            f"Using {model_source} model ({execute_model.name}) for test_index {test_index} to generate diverse tests"
        )

        (
            generated_test_source,
            instrumented_behavior_tests,
            instrumented_perf_tests,
        ) = await generate_regression_tests_from_function(
            ctx=ctx,
            user_id=request.user,
            function_name=data.function_to_optimize.qualified_name,
            python_version=python_version,
            data=data,
            unit_test_package=data.test_framework,
            is_async=data.is_async,
            trace_id=data.trace_id,
            call_sequence=data.call_sequence,
            execute_model=execute_model,
        )
        ph(request.user, "aiservice-testgen-tests-generated")

        if hasattr(request, "should_log_features") and request.should_log_features:
            await log_features(
                trace_id=data.trace_id,
                user_id=request.user,
                generated_tests=[generated_test_source],
                instrumented_generated_tests=[instrumented_behavior_tests],
                test_framework=data.test_framework,
                metadata={
                    "test_timeout": data.test_timeout,
                    "function_to_optimize": data.function_to_optimize.function_name,
                },
            )
        return 200, TestGenResponseSchema(
            generated_tests=generated_test_source,
            instrumented_behavior_tests=instrumented_behavior_tests,
            instrumented_perf_tests=instrumented_perf_tests,
        )
    except Exception as e:
        logging.exception(f"Test generation failed. trace_id={data.trace_id}")
        sentry_sdk.capture_exception(e)
        return 500, TestGenErrorResponseSchema(error="Error generating tests. Internal server error.")