# Derived from https://github.com/openai/openai-cookbook/blob/main/examples/Unit_test_writing_using_a_multi-step_prompt.ipynb
from __future__ import annotations
import ast
import logging
import os
from pathlib import Path
import isort
from ninja import NinjaAPI, Schema
from pydantic import model_validator
from aiservice.analytics.posthog import ph
from aiservice.common_utils import parse_python_version
from aiservice.env_specific import create_openai_client, debug_log_sensitive_data
from aiservice.models.aimodels import EXECUTE_MODEL, EXPLAIN_MODEL, LLM, PLAN_MODEL
from aiservice.models.functions_to_optimize import FunctionToOptimize
from authapp.auth import AuthBearer
from log_features.log_features import log_features
from testgen.instrumentation.edit_generated_test import parse_module_to_cst, replace_definition_with_import
from testgen.instrumentation.instrument_new_tests import instrument_test_source

testgen_api = NinjaAPI(auth=AuthBearer(), urls_namespace="testgen")
openai_client = create_openai_client()

# Get the directory of the current file
current_dir = Path(__file__).parent

EXPLAIN_SYSTEM_PROMPT = (current_dir / "sqlalchemy_explain_system_prompt.md").read_text()
EXPLAIN_USER_PROMPT = (current_dir / "sqlalchemy_explain_user_prompt.md").read_text()
EXECUTE_SYSTEM_PROMPT = (current_dir / "sqlalchemy_execute_system_prompt.md").read_text()
EXECUTE_USER_PROMPT = (current_dir / "sqlalchemy_execute_user_prompt.md").read_text()
FETCH_DATA_SYSTEM_PROMPT = (current_dir / "sqlalchemy_fetch_data_system_prompt.md").read_text()
FETCH_DATA_USER_PROMPT = (current_dir / "sqlalchemy_fetch_data_user_prompt.md").read_text()
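
# The prompt templates live in markdown files alongside this module; several of them
# take str.format() arguments (function_name, function_code, orm_code) filled in per request.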

class TestGenerationFailedException(Exception):
    """Raised when any step of the test generation pipeline fails; the endpoint maps it to a 500."""

color_prefix_by_role = {
    "system": "\033[0m",  # gray
    "user": "\033[0m",  # gray
    "assistant": "\033[92m",  # green
}

def ellipsis_in_ast_not_types(module: ast.AST) -> bool:
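    """Return True if the module contains a bare `...` outside of type annotations.

    A bare ellipsis in generated test code usually means the model punted on writing
    concrete test cases, so callers treat it as a signal to regenerate.
    """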
    # Add a parent attribute to nodes for easier traversal
    for node in ast.walk(module):
        for child in ast.iter_child_nodes(node):
            child.parent = node
    for node in ast.walk(module):
        if isinstance(node, ast.Constant) and node.value is Ellipsis:
            # Check if the ellipsis is part of a type annotation
            if isinstance(node.parent, (ast.Subscript, ast.Index, ast.Tuple)):
                continue
            return True
    return False

def any_ellipsis_in_ast(module: ast.AST) -> bool:
    """Return True if the module contains any `...` constant, including inside type annotations."""
    return any(isinstance(node, ast.Constant) and node.value is Ellipsis for node in ast.walk(module))

def print_messages(
    messages: list[dict[str, str]],
    color_prefix_by_role: dict[str, str] = color_prefix_by_role,
) -> None:
    """Prints messages sent to or from GPT."""
    message: dict[str, str]
    for message in messages:
        role: str = message["role"]
        color_prefix: str = color_prefix_by_role[role]
        content: str = message["content"]
        print(f"{color_prefix}\n[{role}]\n{content}")

def print_message_delta(delta, color_prefix_by_role=color_prefix_by_role) -> None:
    """Prints a chunk of messages streamed back from GPT."""
    if "role" in delta:
        role = delta["role"]
        color_prefix = color_prefix_by_role[role]
        print(f"{color_prefix}\n[{role}]\n", end="")
    elif "content" in delta:
        content = delta["content"]
        print(content, end="")
    else:
        pass

def write_fetch_data_function(fetch_data_function, connection_string, function_code):
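    """Assemble a standalone script that defines the ORM base, the function under test,
    and the generated fetch_data() helper, then calls fetch_data() against the database
    at connection_string.
    """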
    fetch_data_code = f"""from sqlalchemy import Boolean, Column, ForeignKey, Integer, Text
from sqlalchemy.engine import Engine, create_engine
from sqlalchemy.orm import DeclarativeBase, Session, relationship, sessionmaker
from sqlalchemy.orm.relationships import Relationship

POSTGRES_CONNECTION_STRING = "{connection_string}"

catalog_engine: Engine = create_engine(POSTGRES_CONNECTION_STRING, echo=True)
session: Session = sessionmaker(bind=catalog_engine)()


class Base(DeclarativeBase):
    pass


{function_code}

{fetch_data_function}

data = fetch_data(session)
"""
    return fetch_data_code

async def generate_regression_tests_from_function(
    user_id: str,
    function_code: str,  # Python function to test, as a string
    function_name: str,  # the function to test
    python_version: tuple[int, int, int],  # Python version to use for ast parsing
    unit_test_package: str = "pytest",  # unit testing package; use the name as it appears in the import statement
    approx_min_cases_to_cover: int = 7,  # minimum number of test case categories to cover (approximate)
    print_text: bool = os.environ.get("ENVIRONMENT") != "PRODUCTION",
    # optionally prints text; helpful for understanding the function & debugging
    explain_model: LLM = EXPLAIN_MODEL,  # model used to generate text plans in step 1
    plan_model: LLM = PLAN_MODEL,  # model used to generate text plans in steps 2 and 2b
    execute_model: LLM = EXECUTE_MODEL,  # model used to generate code in step 3
    temperature: float = 0.4,  # temperature = 0 can sometimes get stuck in repetitive loops, so we use 0.4
) -> str:
    """Returns a unit test for a given Python function, using a 3-step GPT prompt."""
    # Step 1: Generate an explanation of the function
    # create a markdown-formatted message that asks GPT to explain the function, formatted as a bullet list
    explain_system_message = {"role": "system", "content": EXPLAIN_SYSTEM_PROMPT}
    explain_user_message = {
        "role": "user",
        "content": EXPLAIN_USER_PROMPT.format(function_name=function_name, function_code=function_code),
    }
    explain_messages = [explain_system_message, explain_user_message]
    if print_text:
        print_messages(explain_messages)
    try:
        explanation_response = await openai_client.with_options(max_retries=2).chat.completions.create(
            model=explain_model.name, messages=explain_messages, temperature=temperature
        )
    except Exception as e:
        logging.exception("OpenAI client error in explain step")
        raise TestGenerationFailedException(e) from e
    debug_log_sensitive_data(f"OpenAIClient explanation response:\n{explanation_response.model_dump_json(indent=2)}")
    if explanation_response.usage is not None:
        ph(
            user_id,
            "aiservice-testgen-explain-openai-usage",
            properties={"model": explain_model.name, "usage": explanation_response.usage.json()},
        )
    explanation = explanation_response.choices[0].message.content
    explain_assistant_message = {"role": "assistant", "content": explanation}
    # Step 1b: Fetch relevant data from the database to use as inputs, based on the function explanation
    fetch_data_user_message = {"role": "user", "content": FETCH_DATA_USER_PROMPT.format(orm_code=function_code)}
    fetch_data_system_message = {"role": "system", "content": FETCH_DATA_SYSTEM_PROMPT.format(orm_code=function_code)}
    fetch_data_messages = [fetch_data_system_message, fetch_data_user_message]
    if print_text:
        print_messages(fetch_data_messages)
    try:
        fetch_data_response = await openai_client.with_options(max_retries=2).chat.completions.create(
            model=execute_model.name, messages=fetch_data_messages, temperature=temperature
        )
    except Exception as e:
        logging.exception("OpenAI client error in fetch data step")
        raise TestGenerationFailedException(e) from e
    fetch_data_function = fetch_data_response.choices[0].message.content
    fetch_data_function = fetch_data_function.split("```python")[1].split("```")[0].strip()
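    # fetch_data_function now holds only the code from the model's ```python fence;
    # step 1c below is responsible for writing it to a file and running it.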
    # Step 1c: Run the function to get the data
    # Put the function in a file and run it
    connection_string = os.environ.get("POSTGRES_CONNECTION_STRING")
    if connection_string is None:
        raise ValueError("POSTGRES_CONNECTION_STRING environment variable is not set")
    # Step 2: Generate a plan to write a unit test
    # Asks GPT to plan out the cases the unit tests should cover, formatted as a bullet list
    plan_user_message = {
        "role": "user",
        "content": f"""You want to write tests to test the following function: {function_code}.
Imagine you have the data to use given as the following:
```python
data_to_use = fetch_data(session)
```
where fetch_data(session) returns the data to use as inputs to the function.
You should come up with a few different combinations of this data to use as inputs to the function in order to test it. These combinations should cover a wide range of inputs to constitute a good test suite. A good unit test suite should aim to:
- Test the function's behavior for a wide range of possible inputs
- Test edge cases that the author may not have foreseen
- Take advantage of the features of `{unit_test_package}` to make the tests easy to write and maintain
- Be easy to read and understand, with clean code and descriptive names
- Be deterministic, so that the tests always pass or fail in the same way
- Have tests sorted by difficulty, from easiest to hardest
To help unit test the function above, list diverse scenarios that the function should be able to handle (and under each scenario, include a few examples as sub-bullets). Based on our function input, you should use different combinations of the data given to come up with these different scenarios.""",
    }
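    # Reuse the explain exchange as conversation context so the plan stays grounded in the explanation.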
    plan_messages = [explain_system_message, explain_user_message, explain_assistant_message, plan_user_message]
    if print_text:
        print_messages([plan_user_message])
    try:
        plan_response = await openai_client.with_options(max_retries=2).chat.completions.create(
            model=plan_model.name, messages=plan_messages, temperature=temperature
        )
    except Exception as e:
        logging.exception("OpenAI client error in plan step")
        raise TestGenerationFailedException(e) from e
    debug_log_sensitive_data(f"OpenAIClient plan response:\n{plan_response.model_dump_json(indent=2)}")
    if plan_response.usage is not None:
        ph(
            user_id,
            "aiservice-testgen-plan-openai-usage",
            properties={"model": plan_model.name, "usage": plan_response.usage.json()},
        )
    plan = plan_response.choices[0].message.content
    plan_assistant_message = {"role": "assistant", "content": plan}
    # Step 2b: If the plan is short, ask GPT to elaborate further
    # this counts top-level bullets (e.g., categories), but not sub-bullets (e.g., test cases)
    num_bullets = max(plan.count("\n-"), plan.count("\n*"))
    elaboration_needed = num_bullets < approx_min_cases_to_cover
    if elaboration_needed:
        elaboration_user_message = {
            "role": "user",
            "content": """In addition to those scenarios above, list a few rare or unexpected edge cases (and as before, under each edge case, include a few examples as sub-bullets).""",
        }
        elaboration_messages = [
            explain_system_message,
            explain_user_message,
            explain_assistant_message,
            plan_user_message,
            plan_assistant_message,
            elaboration_user_message,
        ]
        if print_text:
            print_messages([elaboration_user_message])
        try:
            elaboration_response = await openai_client.with_options(max_retries=2).chat.completions.create(
                model=plan_model.name, messages=elaboration_messages, temperature=temperature
            )
        except Exception as e:
            logging.exception("OpenAI client error in elaboration step")
            raise TestGenerationFailedException(e) from e
        debug_log_sensitive_data(
            f"OpenAIClient elaboration response:\n{elaboration_response.model_dump_json(indent=2)}"
        )
        elaboration = elaboration_response.choices[0].message.content
        elaboration_assistant_message = {"role": "assistant", "content": elaboration}
    # Step 3: Generate the unit test
    # create a markdown-formatted prompt that asks GPT to complete a unit test
    package_comment = ""
    # if unit_test_package == "pytest":
    #     package_comment = "# below, each test case is represented by a tuple passed to the @pytest.mark.parametrize decorator"
    execute_system_message = {"role": "system", "content": EXECUTE_SYSTEM_PROMPT.format(function_name=function_name)}
    execute_messages = [
        execute_system_message,
        explain_user_message,
        explain_assistant_message,
        plan_user_message,
        plan_assistant_message,
    ]
    if elaboration_needed:
        execute_messages += [elaboration_user_message, elaboration_assistant_message]
    execute_user_message = {
        "role": "user",
        "content": EXECUTE_USER_PROMPT.format(
            unit_test_package=unit_test_package,
            function_name=function_name,
            function_code=function_code,
            fetch_data_function_code=fetch_data_function,
            package_comment=package_comment,
        ),
    }
    execute_messages += [execute_user_message]
    if print_text:
        print_messages([execute_system_message, execute_user_message])
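    # The execute step runs in a short retry loop: the output is parsed with ast and
    # rejected if it fails to parse or punts with placeholder ellipses.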
    # TODO: Implement a fallback if the code is too long, implement a straightforward way to write the tests rather than the iterative approach
    tries = 2
    while tries > 0:
        try:
            execute_response = await openai_client.with_options(max_retries=2).chat.completions.create(
                model=execute_model.name, messages=execute_messages, temperature=temperature
            )
        except Exception as e:
            logging.exception("OpenAI client error in execute step")
            raise TestGenerationFailedException(e) from e
        debug_log_sensitive_data(f"OpenAIClient execute response:\n{execute_response.model_dump_json(indent=2)}")
        if execute_response.usage is not None:
            ph(
                user_id,
                "aiservice-testgen-execute-openai-usage",
                properties={"model": execute_model.name, "usage": execute_response.usage.json()},
            )
        execution_output = execute_response.choices[0].message.content
        # check the output for errors
        code = execution_output.split("```python")[1].split("```")[0].strip()
        try:
            module = ast.parse(code, feature_version=python_version[:2])
            original_function = ast.parse(function_code, feature_version=python_version[:2])
            if not any_ellipsis_in_ast(original_function) and ellipsis_in_ast_not_types(module):
                # If the test generator is generating ellipsis, it is punting on generating
                # the concrete test cases and we should re-generate
                raise SyntaxError("Ellipsis in generated test code, regenerating...")
            break
        except SyntaxError as e:
            tries -= 1
            logging.warning("Syntax error in generated code. Trying again.")
            logging.warning(f"Error: {e}")
            logging.warning(f"Generated code: {code}")
            continue
    if tries == 0:
        raise TestGenerationFailedException("Failed to generate test code after 2 tries.")
    # return the unit test as a string
    return code
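

# A minimal sketch of calling the generator directly (hypothetical values; assumes the
# OpenAI client and database environment variables are configured):
#
#     generated = await generate_regression_tests_from_function(
#         user_id="user-123",
#         function_code=source_of_function_under_test,
#         function_name="get_active_items",
#         python_version=(3, 11, 0),
#     )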

class TestGenSchema(Schema):
    source_code_being_tested: str
    function_to_optimize: FunctionToOptimize
    helper_function_names: list[str] | None = None  # This is the only one we should use
    dependent_function_names: list[str] | None = None  # Only for backwards compatibility
    module_path: str
    test_module_path: str
    test_framework: str
    test_timeout: int
    trace_id: str
    python_version: str

    @model_validator(mode="after")
    def helper_function_names_validator(self):
        # To maintain backwards compatibility, accept either field but normalize onto
        # helper_function_names so the rest of the code only deals with one name.
        if self.dependent_function_names is None and self.helper_function_names is None:
            raise ValueError("either field 'helper_function_names' or 'dependent_function_names' is required")
        if self.helper_function_names is not None:
            return self
        print("self.dependent_function_names", self.dependent_function_names)
        self.helper_function_names = self.dependent_function_names
        self.dependent_function_names = None
        return self

class TestGenResponseSchema(Schema):
    generated_tests: str
    instrumented_tests: str


class TestGenErrorResponseSchema(Schema):
    error: str

@testgen_api.post(
    "/", response={200: TestGenResponseSchema, 400: TestGenErrorResponseSchema, 500: TestGenErrorResponseSchema}
)
async def testgen(
    request: AuthBearer, data: TestGenSchema
) -> tuple[int, TestGenResponseSchema | TestGenErrorResponseSchema]:
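    # django-ninja maps a returned (status_code, payload) tuple onto the response schemas declared above.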
    ph(request.user, "aiservice-testgen-called")
    if data.test_framework not in ["unittest", "pytest"]:
        return 400, TestGenErrorResponseSchema(error="Invalid test framework. We only support unittest and pytest.")
    if not data.function_to_optimize:
        # TODO: Add a validation check here to see if the function_name is actually present in
        # the source_code_being_tested. Parse the ast.
        return 400, TestGenErrorResponseSchema(error="Invalid function to optimize. It is empty.")
    if data.source_code_being_tested == "":
        return 400, TestGenErrorResponseSchema(error="Invalid source code. It is empty.")
    try:
        python_version: tuple[int, int, int] = parse_python_version(data.python_version)
    except Exception:
        return 400, TestGenErrorResponseSchema(
            error="Invalid Python version, it should look like 3.x.x. We only support Python 3.9 and above."
        )
    print("data.helper_function_names", data.helper_function_names)
    try:
        # feature_version asks ast.parse to use (best-effort) the client's (major, minor) grammar
        ast.parse(data.source_code_being_tested, feature_version=python_version[:2])
        compile(data.source_code_being_tested, "data.source_code_being_tested", "exec")
    except SyntaxError:
        return 400, TestGenErrorResponseSchema(
            error="Invalid source code. It is not valid Python code. Please check the syntax of your code."
        )
    try:
        print("/testgen: Generating tests...")
        debug_log_sensitive_data(f"Generating tests for function {data.function_to_optimize.function_name}")
        debug_log_sensitive_data(f"Source code being tested: {data.source_code_being_tested}")
        generated_test_source = await generate_regression_tests_from_function(
            user_id=request.user,
            function_code=data.source_code_being_tested,
            function_name=data.function_to_optimize.function_name,
            unit_test_package=data.test_framework,
            approx_min_cases_to_cover=10,
            python_version=python_version,
        )
        print("/testgen: Instrumenting tests...")
        instrumented_test_source = isort.code(
            instrument_test_source(
                test_source=generated_test_source,
                function_to_optimize=data.function_to_optimize,
                helper_function_names=data.helper_function_names,
                module_path=data.module_path,
                test_module_path=data.test_module_path,
                test_framework=data.test_framework,
                test_timeout=data.test_timeout,
                python_version=python_version,
            ),
            float_to_top=True,
        )
        generated_test_source = replace_definition_with_import(
            generated_test_source, data.function_to_optimize, data.module_path
        )
        # Use isort to sort and deduplicate the imports in the generated test code
        sorted_imports_test_source = isort.code(generated_test_source)
        try:
            parse_module_to_cst(sorted_imports_test_source)
        except Exception as e:
            logging.exception(f"Failed to parse generated test code: {e}")
            ph(
                request.user,
                "aiservice-testgen-invalid-isort-code",
                properties={"error": str(e), "sorted_imports_test_source": sorted_imports_test_source},
            )
        else:
            generated_test_source = sorted_imports_test_source
        ph(request.user, "aiservice-testgen-tests-generated")
    except TestGenerationFailedException as e:
        logging.exception("Test generation failed. Skipping test generation.")
        logging.exception(e)
        ph(request.user, "aiservice-testgen-test-generation-failed", properties={"error": str(e)})
        return 500, TestGenErrorResponseSchema(error="Error generating tests. Internal server error.")
    if request.tier is None:
        await log_features(
            trace_id=data.trace_id,
            user_id=request.user,
            generated_tests=[generated_test_source],
            test_framework=data.test_framework,
            metadata={
                "test_timeout": data.test_timeout,
                "function_to_optimize": data.function_to_optimize.function_name,
            },
        )
    return 200, TestGenResponseSchema(
        generated_tests=generated_test_source, instrumented_tests=instrumented_test_source
    )