# Derived from https://github.com/openai/openai-cookbook/blob/main/examples/Unit_test_writing_using_a_multi-step_prompt.ipynb
from __future__ import annotations

import ast
import logging
import os
from pathlib import Path

from aiservice.common_utils import parse_python_version, safe_isort
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.llm import EXECUTE_MODEL, EXPLAIN_MODEL, LLM, PLAN_MODEL, calculate_llm_cost, call_llm
from aiservice.models.functions_to_optimize import FunctionToOptimize
from log_features.log_event import update_optimization_cost
from log_features.log_features import log_features
from ninja import NinjaAPI, Schema
from pydantic import model_validator
from testgen.instrumentation.edit_generated_test import parse_module_to_cst, replace_definition_with_import
from testgen.instrumentation.instrument_new_tests import instrument_test_source

testgen_api = NinjaAPI(urls_namespace="testgen")

# Get the directory of the current file
current_dir = Path(__file__).parent

EXPLAIN_SYSTEM_PROMPT = (current_dir / "sqlalchemy_explain_system_prompt.md").read_text()
EXPLAIN_USER_PROMPT = (current_dir / "sqlalchemy_explain_user_prompt.md").read_text()
EXECUTE_SYSTEM_PROMPT = (current_dir / "sqlalchemy_execute_system_prompt.md").read_text()
EXECUTE_USER_PROMPT = (current_dir / "sqlalchemy_execute_user_prompt.md").read_text()
FETCH_DATA_SYSTEM_PROMPT = (current_dir / "sqlalchemy_fetch_data_system_prompt.md").read_text()
FETCH_DATA_USER_PROMPT = (current_dir / "sqlalchemy_fetch_data_user_prompt.md").read_text()


class TestGenerationFailedException(Exception):
    pass


color_prefix_by_role = {
    "system": "\033[0m",  # gray
    "user": "\033[0m",  # gray
    "assistant": "\033[92m",  # green
}


def ellipsis_in_ast_not_types(module: ast.AST) -> bool:
    # Add parent attribute to nodes for easier traversal
    for node in ast.walk(module):
        for child in ast.iter_child_nodes(node):
            child.parent = node
    for node in ast.walk(module):
        if isinstance(node, ast.Constant) and node.value is Ellipsis:
            # Check if the ellipsis is part of a type annotation
            if isinstance(node.parent, (ast.Subscript, ast.Index, ast.Tuple)):
                continue
            return True
    return False


def any_ellipsis_in_ast(module: ast.AST) -> bool:
    return any(isinstance(node, ast.Constant) and node.value == ... for node in ast.walk(module))
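# How the two helpers differ, on small hypothetical inputs: a bare `...` statement is flagged by
# ellipsis_in_ast_not_types (the generator is punting on a real test body), while `...` inside a
# type annotation is ignored; any_ellipsis_in_ast flags both.
#
#     ellipsis_in_ast_not_types(ast.parse("def f(): ..."))                    # True
#     ellipsis_in_ast_not_types(ast.parse("pair: tuple[int, ...] = (1, 2)"))  # False
#     any_ellipsis_in_ast(ast.parse("pair: tuple[int, ...] = (1, 2)"))        # True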

def print_messages(
    messages: list[dict[str, str]],
    color_prefix_by_role: dict[str, str] = color_prefix_by_role,
) -> None:
    """Prints messages sent to or from GPT."""
    message: dict[str, str]
    for message in messages:
        role: str = message["role"]
        color_prefix: str = color_prefix_by_role[role]
        content: str = message["content"]
        print(f"{color_prefix}\n[{role}]\n{content}")


def print_message_delta(delta, color_prefix_by_role=color_prefix_by_role) -> None:
    """Prints a chunk of messages streamed back from GPT."""
    if "role" in delta:
        role = delta["role"]
        color_prefix = color_prefix_by_role[role]
        print(f"{color_prefix}\n[{role}]\n", end="")
    elif "content" in delta:
        content = delta["content"]
        print(content, end="")
    else:
        pass


def write_fetch_data_function(fetch_data_function, connection_string, function_code):
    fetch_data_code = f"""from sqlalchemy import Boolean, Column, ForeignKey, Integer, Text
from sqlalchemy.engine import Engine, create_engine
from sqlalchemy.orm import DeclarativeBase, Session, relationship, sessionmaker
from sqlalchemy.orm.relationships import Relationship

POSTGRES_CONNECTION_STRING = "{connection_string}"
catalog_engine: Engine = create_engine(POSTGRES_CONNECTION_STRING, echo=True)
session: Session = sessionmaker(bind=catalog_engine)()


class Base(DeclarativeBase):
    pass


{function_code}

{fetch_data_function}

data = fetch_data(session)
"""
    return fetch_data_code
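# A hypothetical sketch of how the generated script could be used: write_fetch_data_function
# returns a standalone script which, per the Step 1c comment below, is meant to be put in a file
# and run. `orm_model_source` and the fetch_data body here are illustrative placeholders only.
#
#     script = write_fetch_data_function(
#         fetch_data_function="def fetch_data(session):\n    return session.query(Base).all()",
#         connection_string=os.environ["POSTGRES_CONNECTION_STRING"],
#         function_code=orm_model_source,
#     )
#     Path("fetch_data_script.py").write_text(script)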

async def generate_regression_tests_from_function(
    user_id: str,
    function_code: str,  # Python function to test, as a string
    function_name: str,  # the function to test
    python_version: tuple[int, int, int],  # Python version to use for ast parsing
    unit_test_package: str = "pytest",  # unit testing package; use the name as it appears in the import statement
    approx_min_cases_to_cover: int = 7,  # minimum number of test case categories to cover (approximate)
    print_text: bool = os.environ.get("ENVIRONMENT") != "PRODUCTION",  # optionally prints text; helpful for understanding the function & debugging
    explain_model: LLM = EXPLAIN_MODEL,  # model used to generate text plans in step 1
    plan_model: LLM = PLAN_MODEL,  # model used to generate text plans in steps 2 and 2b
    execute_model: LLM = EXECUTE_MODEL,  # model used to generate code in step 3
    temperature: float = 0.4,  # temperature = 0 can sometimes get stuck in repetitive loops, so we use 0.4
    trace_id: str = "",
) -> str:
    """Returns a unit test for a given Python function, using a 3-step GPT prompt."""
    # Step 1: Generate an explanation of the function
    # create a markdown-formatted message that asks GPT to explain the function, formatted as a bullet list
    explain_system_message = {"role": "system", "content": EXPLAIN_SYSTEM_PROMPT}
    explain_user_message = {
        "role": "user",
        "content": EXPLAIN_USER_PROMPT.format(function_name=function_name, function_code=function_code),
    }
    explain_messages = [explain_system_message, explain_user_message]

    total_llm_cost = 0.0
    if print_text:
        print_messages(explain_messages)
    try:
        explanation_response = await call_llm(
            model_name=explain_model.name,
            model_type=explain_model.model_type,
            messages=explain_messages,
            temperature=temperature,
        )
        total_llm_cost += calculate_llm_cost(explanation_response.raw_response, explain_model)
    except Exception as e:
        logging.exception("OpenAI client error in explain step")
        sentry_sdk.capture_exception(e)
        raise TestGenerationFailedException(e) from e
    debug_log_sensitive_data(
        f"OpenAIClient explanation response:\n{explanation_response.raw_response.model_dump_json(indent=2)}"
    )
    if explanation_response.raw_response.usage is not None:
        ph(
            user_id,
            "aiservice-testgen-explain-openai-usage",
            properties={"model": explain_model.name, "usage": explanation_response.raw_response.usage.json()},
        )
    explanation = explanation_response.content
    explain_assistant_message = {"role": "assistant", "content": explanation}

    # Step 1b: Fetch relevant data from the database to use as inputs based on function explanation
    fetch_data_user_message = {"role": "user", "content": FETCH_DATA_USER_PROMPT.format(orm_code=function_code)}
    fetch_data_system_message = {"role": "system", "content": FETCH_DATA_SYSTEM_PROMPT.format(orm_code=function_code)}
    fetch_data_messages = [fetch_data_system_message, fetch_data_user_message]
    if print_text:
        print_messages(fetch_data_messages)
    try:
        fetch_data_response = await call_llm(
            model_name=execute_model.name,
            model_type=execute_model.model_type,
            messages=fetch_data_messages,
            temperature=temperature,
        )
        total_llm_cost += calculate_llm_cost(fetch_data_response.raw_response, execute_model)
    except Exception as e:
        logging.exception("OpenAI client error in fetch data step")
        sentry_sdk.capture_exception(e)
        raise TestGenerationFailedException(e) from e
    fetch_data_function = fetch_data_response.content
    fetch_data_function = fetch_data_function.split("```python")[1].split("```")[0].strip()

    # Step 1c: Run the function to get the data
    # Put function in a file and run it
    connection_string = os.environ.get("POSTGRES_CONNECTION_STRING")
    if connection_string is None:
        raise ValueError("POSTGRES_CONNECTION_STRING environment variable is not set")

    # Step 2: Generate a plan to write a unit test
    # Asks GPT to plan out cases the unit tests should cover, formatted as a bullet list
    plan_user_message = {
        "role": "user",
        "content": f"""You want to write tests to test the following function: {function_code}.
Imagine you have the data to use given as the following:
```python
data_to_use = fetch_data(session)
```
where fetch_data(session) returns the data to use as inputs to the function.
You should come up with a few different combinations of this data to use as inputs to the function in order to test it. These combinations should cover a wide range of inputs to constitute a good test suite.
A good unit test suite should aim to:
- Test the function's behavior for a wide range of possible inputs
- Test edge cases that the author may not have foreseen
- Take advantage of the features of `{unit_test_package}` to make the tests easy to write and maintain
- Be easy to read and understand, with clean code and descriptive names
- Be deterministic, so that the tests always pass or fail in the same way
- Have tests sorted by difficulty, from easiest to hardest

To help unit test the function above, list diverse scenarios that the function should be able to handle (and under each scenario, include a few examples as sub-bullets).
Based on our function input, you should use different combinations of the data given to come up with these different scenarios.""",
    }
    plan_messages = [explain_system_message, explain_user_message, explain_assistant_message, plan_user_message]
    if print_text:
        print_messages([plan_user_message])
    try:
        plan_response = await call_llm(
            model_name=plan_model.name,
            model_type=plan_model.model_type,
            messages=plan_messages,
            temperature=temperature,
        )
        total_llm_cost += calculate_llm_cost(plan_response.raw_response, plan_model)
    except Exception as e:
        logging.exception("OpenAI client error in plan step")
        sentry_sdk.capture_exception(e)
        raise TestGenerationFailedException(e) from e
    debug_log_sensitive_data(f"OpenAIClient plan response:\n{plan_response.raw_response.model_dump_json(indent=2)}")
    if plan_response.raw_response.usage is not None:
        ph(
            user_id,
            "aiservice-testgen-plan-openai-usage",
            properties={"model": plan_model.name, "usage": plan_response.raw_response.usage.json()},
        )
    plan = plan_response.content
    plan_assistant_message = {"role": "assistant", "content": plan}

    # Step 2b: If the plan is short, ask GPT to elaborate further
    # this counts top-level bullets (e.g., categories), but not sub-bullets (e.g., test cases)
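    # Illustration of the counting heuristic on a hypothetical plan string: only "-"/"*" bullets
    # that start a new line at column 0 are counted, so indented sub-bullets are ignored.
    #
    #     plan = "Scenarios:\n- empty result set\n- one matching row\n  - with NULL columns"
    #     plan.count("\n-")  # 2 top-level categories
    #     plan.count("\n*")  # 0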
    num_bullets = max(plan.count("\n-"), plan.count("\n*"))
    elaboration_needed = num_bullets < approx_min_cases_to_cover
    if elaboration_needed:
        elaboration_user_message = {
            "role": "user",
            "content": """In addition to those scenarios above, list a few rare or unexpected edge cases (and as before, under each edge case, include a few examples as sub-bullets).""",
        }
        elaboration_messages = [
            explain_system_message,
            explain_user_message,
            explain_assistant_message,
            plan_user_message,
            plan_assistant_message,
            elaboration_user_message,
        ]
        if print_text:
            print_messages([elaboration_user_message])
        try:
            elaboration_response = await call_llm(
                model_name=plan_model.name,
                model_type=plan_model.model_type,
                messages=elaboration_messages,
                temperature=temperature,
            )
            total_llm_cost += calculate_llm_cost(elaboration_response.raw_response, plan_model)
        except Exception as e:
            logging.exception("OpenAI client error in elaboration step")
            sentry_sdk.capture_exception(e)
            raise TestGenerationFailedException(e) from e
        debug_log_sensitive_data(
            f"OpenAIClient elaboration response:\n{elaboration_response.raw_response.model_dump_json(indent=2)}"
        )
        elaboration = elaboration_response.content
        elaboration_assistant_message = {"role": "assistant", "content": elaboration}

    # Step 3: Generate the unit test
    # create a markdown-formatted prompt that asks GPT to complete a unit test
    package_comment = ""
    # if unit_test_package == "pytest":
    #     package_comment = "# below, each test case is represented by a tuple passed to the @pytest.mark.parametrize decorator"
    execute_system_message = {"role": "system", "content": EXECUTE_SYSTEM_PROMPT.format(function_name=function_name)}
    execute_messages = [
        execute_system_message,
        explain_user_message,
        explain_assistant_message,
        plan_user_message,
        plan_assistant_message,
    ]
    if elaboration_needed:
        execute_messages += [elaboration_user_message, elaboration_assistant_message]

    execute_user_message = {
        "role": "user",
        "content": EXECUTE_USER_PROMPT.format(
            unit_test_package=unit_test_package,
            function_name=function_name,
            function_code=function_code,
            fetch_data_function_code=fetch_data_function,
            package_comment=package_comment,
        ),
    }
    execute_messages += [execute_user_message]
    if print_text:
        print_messages([execute_system_message, execute_user_message])

    # TODO: Implement a fallback if the code is too long; implement a straightforward way to write
    # the tests rather than the iterative approach
    tries = 2
    while tries > 0:
        try:
            execute_response = await call_llm(
                model_name=execute_model.name,
                model_type=execute_model.model_type,
                messages=execute_messages,
                temperature=temperature,
            )
            total_llm_cost += calculate_llm_cost(execute_response.raw_response, execute_model)
        except Exception as e:
            logging.exception("OpenAI client error in execute step")
            sentry_sdk.capture_exception(e)
            raise TestGenerationFailedException(e) from e
        debug_log_sensitive_data(
            f"OpenAIClient execute response:\n{execute_response.raw_response.model_dump_json(indent=2)}"
        )
        if execute_response.raw_response.usage is not None:
            ph(
                user_id,
                "aiservice-testgen-execute-openai-usage",
                properties={"model": execute_model.name, "usage": execute_response.raw_response.usage.json()},
            )
        execution_output = execute_response.content

        # check the output for errors
        code = execution_output.split("```python")[1].split("```")[0].strip()
        try:
            module = ast.parse(code, feature_version=python_version[:2])
            original_function = ast.parse(function_code, feature_version=python_version[:2])
            if not any_ellipsis_in_ast(original_function) and ellipsis_in_ast_not_types(module):
                # If the test generator is generating ellipsis, it is punting on generating
                # the concrete test cases and we should re-generate
                raise SyntaxError("Ellipsis in generated test code, regenerating...")
            break
        except SyntaxError as e:
            tries -= 1
            logging.warning("Syntax error in generated code. Trying again.")
            logging.warning(f"Error: {e}")
            logging.warning(f"Generated code: {code}")
            continue
    if tries == 0:
        raise TestGenerationFailedException("Failed to generate test code after 2 tries.")

    await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost, user_id=user_id)
    # return the unit test as a string
    return code

class TestGenSchema(Schema):
    source_code_being_tested: str
    function_to_optimize: FunctionToOptimize
    helper_function_names: list[str] = None  # This is the only one we should use
    dependent_function_names: list[str] = None  # Only for backwards compatibility
    module_path: str
    test_module_path: str
    test_framework: str
    test_timeout: int
    trace_id: str
    python_version: str

    @model_validator(mode="after")
    def helper_function_names_validator(self):
        # To maintain backwards compatibility
        if self.dependent_function_names is None and self.helper_function_names is None:
            raise ValueError("either field 'helper_function_names' or 'dependent_function_names' is required")
        if self.helper_function_names is not None:
            return self
        print("self.dependent_function_names", self.dependent_function_names)
        self.helper_function_names = self.dependent_function_names
        self.dependent_function_names = None
        return self
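# Backwards-compatibility note, with a hypothetical legacy payload: if a client sends only
# dependent_function_names, validation still succeeds and the value is copied over.
#
#     legacy = TestGenSchema(..., dependent_function_names=["helper"], ...)  # helper_function_names omitted
#     legacy.helper_function_names  # ["helper"]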

class TestGenResponseSchema(Schema):
    generated_tests: str
    instrumented_tests: str


class TestGenErrorResponseSchema(Schema):
    error: str


import sentry_sdk

from aiservice.analytics.posthog import ph


@testgen_api.post(
    "/", response={200: TestGenResponseSchema, 400: TestGenErrorResponseSchema, 500: TestGenErrorResponseSchema}
)
async def testgen(request, data: TestGenSchema) -> tuple[int, TestGenResponseSchema | TestGenErrorResponseSchema]:
    ph(request.user, "aiservice-testgen-called")
    if data.test_framework not in ["unittest", "pytest"]:
        return 400, TestGenErrorResponseSchema(error="Invalid test framework. We only support unittest and pytest.")
    if not data.function_to_optimize:
        # TODO: Add a validation check here to see if the function_name is actually present in
        # the source_code_being_tested (parse the AST).
        return 400, TestGenErrorResponseSchema(error="Invalid function to optimize. It is empty.")
    if data.source_code_being_tested == "":
        return 400, TestGenErrorResponseSchema(error="Invalid source code. It is empty.")
    try:
        python_version: tuple[int, int, int] = parse_python_version(data.python_version)
    except Exception:
        return 400, TestGenErrorResponseSchema(
            error="Invalid Python version, it should look like 3.x.x. We only support Python 3.9 and above."
        )
    print("data.helper_function_names", data.helper_function_names)
    try:
        ast.parse(data.source_code_being_tested, feature_version=python_version[:2])
        compile(data.source_code_being_tested, "data.source_code_being_tested", "exec")
    except SyntaxError:
        return 400, TestGenErrorResponseSchema(
            error="Invalid source code. It is not valid Python code. Please check the syntax of your code."
        )
    try:
        print("/testgen: Generating tests...")
        debug_log_sensitive_data(f"Generating tests for function {data.function_to_optimize.function_name}")
        debug_log_sensitive_data(f"Source code being tested: {data.source_code_being_tested}")
        generated_test_source = await generate_regression_tests_from_function(
            user_id=request.user,
            function_code=data.source_code_being_tested,
            function_name=data.function_to_optimize.function_name,
            unit_test_package=data.test_framework,
            approx_min_cases_to_cover=10,
            python_version=python_version,
            trace_id=data.trace_id,
        )
        print("/testgen: Instrumenting tests...")
        instrumented_test_source = safe_isort(
            instrument_test_source(
                test_source=generated_test_source,
                function_to_optimize=data.function_to_optimize,
                helper_function_names=data.helper_function_names,
                module_path=data.module_path,
                test_module_path=data.test_module_path,
                test_framework=data.test_framework,
                test_timeout=data.test_timeout,
                python_version=python_version,
            ),
            float_to_top=True,
        )
        generated_test_source = replace_definition_with_import(
            generated_test_source, data.function_to_optimize, data.module_path
        )
        # Use isort to sort and deduplicate the imports in the generated test code
        sorted_imports_test_source = safe_isort(generated_test_source)
        try:
            parse_module_to_cst(sorted_imports_test_source)
        except Exception as e:
            logging.exception("Failed to parse generated test code")
            sentry_sdk.capture_exception(e)
        else:
            generated_test_source = sorted_imports_test_source
        ph(request.user, "aiservice-testgen-tests-generated")
    except TestGenerationFailedException as e:
        logging.exception("Test generation failed. Skipping test generation.")
        sentry_sdk.capture_exception(e)
        return 500, TestGenErrorResponseSchema(error="Error generating tests. Internal server error.")
    if hasattr(request, "should_log_features") and request.should_log_features:
        await log_features(
            trace_id=data.trace_id,
            user_id=request.user,
            generated_tests=[generated_test_source],
            test_framework=data.test_framework,
            metadata={
                "test_timeout": data.test_timeout,
                "function_to_optimize": data.function_to_optimize.function_name,
            },
        )
    return 200, TestGenResponseSchema(
        generated_tests=generated_test_source, instrumented_tests=instrumented_test_source
    )
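# Illustrative request body for this endpoint (all values hypothetical; the function_to_optimize
# payload depends on the FunctionToOptimize model and is elided here):
#
#     {
#         "source_code_being_tested": "def add(a, b):\n    return a + b",
#         "function_to_optimize": {...},
#         "helper_function_names": [],
#         "module_path": "pkg/math_utils.py",
#         "test_module_path": "tests/test_math_utils.py",
#         "test_framework": "pytest",
#         "test_timeout": 15,
#         "trace_id": "trace-123",
#         "python_version": "3.11.4"
#     }
#
# A 200 response carries {"generated_tests": "...", "instrumented_tests": "..."}.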