# Derived from https://github.com/openai/openai-cookbook/blob/main/examples/Unit_test_writing_using_a_multi-step_prompt.ipynb
from __future__ import annotations

import ast
import logging
import os
from pathlib import Path

import sentry_sdk
from ninja import NinjaAPI, Schema
from pydantic import model_validator

from aiservice.analytics.posthog import ph
from aiservice.common_utils import parse_python_version, safe_isort
from aiservice.env_specific import create_openai_client, debug_log_sensitive_data
from aiservice.models.aimodels import EXECUTE_MODEL, EXPLAIN_MODEL, LLM, PLAN_MODEL, calculate_llm_cost
from aiservice.models.functions_to_optimize import FunctionToOptimize
from log_features.log_event import update_optimization_cost
from log_features.log_features import log_features
from testgen.instrumentation.edit_generated_test import parse_module_to_cst, replace_definition_with_import
from testgen.instrumentation.instrument_new_tests import instrument_test_source

testgen_api = NinjaAPI(urls_namespace="testgen")

openai_client = create_openai_client()

# Get the directory of the current file
current_dir = Path(__file__).parent
EXPLAIN_SYSTEM_PROMPT = (current_dir / "sqlalchemy_explain_system_prompt.md").read_text()
EXPLAIN_USER_PROMPT = (current_dir / "sqlalchemy_explain_user_prompt.md").read_text()

EXECUTE_SYSTEM_PROMPT = (current_dir / "sqlalchemy_execute_system_prompt.md").read_text()
EXECUTE_USER_PROMPT = (current_dir / "sqlalchemy_execute_user_prompt.md").read_text()

FETCH_DATA_SYSTEM_PROMPT = (current_dir / "sqlalchemy_fetch_data_system_prompt.md").read_text()
FETCH_DATA_USER_PROMPT = (current_dir / "sqlalchemy_fetch_data_user_prompt.md").read_text()


class TestGenerationFailedException(Exception):
    pass


color_prefix_by_role = {
    "system": "\033[0m",  # gray
    "user": "\033[0m",  # gray
    "assistant": "\033[92m",  # green
}


def ellipsis_in_ast_not_types(module: ast.AST) -> bool:
    """Return True if the module contains an ellipsis that is not part of a type annotation."""
    # Add parent attribute to nodes for easier traversal
    for node in ast.walk(module):
        for child in ast.iter_child_nodes(node):
            child.parent = node

    for node in ast.walk(module):
        if isinstance(node, ast.Constant) and node.value is Ellipsis:
            # Check if the ellipsis is part of a type annotation
            if isinstance(node.parent, (ast.Subscript, ast.Index, ast.Tuple)):
                continue
            return True
    return False


def any_ellipsis_in_ast(module: ast.AST) -> bool:
    """Return True if the module contains any ellipsis constant, including in type annotations."""
    return any(isinstance(node, ast.Constant) and node.value == ... for node in ast.walk(module))
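
# Illustrative behavior of the helpers above (hypothetical snippets):
#   `def stub() -> Tuple[int, ...]: return ()`  -> the `...` sits inside a type annotation
#   (its parent is a Tuple node inside a Subscript), so ellipsis_in_ast_not_types() skips it.
#   `assert result == ...`  -> a bare `...` outside any annotation, so it returns True,
#   signalling that the generated test punted on a concrete value.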


def print_messages(
    messages: list[dict[str, str]],
    color_prefix_by_role: dict[str, str] = color_prefix_by_role,
) -> None:
    """Prints messages sent to or from GPT."""
    message: dict[str, str]
    for message in messages:
        role: str = message["role"]
        color_prefix: str = color_prefix_by_role[role]
        content: str = message["content"]
        print(f"{color_prefix}\n[{role}]\n{content}")


def print_message_delta(delta, color_prefix_by_role=color_prefix_by_role) -> None:
    """Prints a chunk of messages streamed back from GPT."""
    if "role" in delta:
        role = delta["role"]
        color_prefix = color_prefix_by_role[role]
        print(f"{color_prefix}\n[{role}]\n", end="")
    elif "content" in delta:
        content = delta["content"]
        print(content, end="")


def write_fetch_data_function(fetch_data_function, connection_string, function_code):
    """Assemble a standalone script that defines the ORM models, the generated fetch_data helper, and fetches the input data."""
    fetch_data_code = f"""from sqlalchemy import Boolean, Column, ForeignKey, Integer, Text
from sqlalchemy.engine import Engine, create_engine
from sqlalchemy.orm import DeclarativeBase, Session, relationship, sessionmaker
from sqlalchemy.orm.relationships import Relationship


POSTGRES_CONNECTION_STRING = "{connection_string}"
catalog_engine: Engine = create_engine(POSTGRES_CONNECTION_STRING, echo=True)
session: Session = sessionmaker(bind=catalog_engine)()


class Base(DeclarativeBase):
    pass


{function_code}

{fetch_data_function}

data = fetch_data(session)
"""
    return fetch_data_code
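
# Illustrative note: write_fetch_data_function() assembles a complete script (imports,
# engine/session setup, the ORM models from `function_code`, the generated `fetch_data`
# helper, and a final `data = fetch_data(session)` line). The intent, per Step 1c below,
# is to write that script to a file and run it to materialize the input data.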


async def generate_regression_tests_from_function(
    user_id: str,
    function_code: str,  # Python function to test, as a string
    function_name: str,  # the function to test
    python_version: tuple[int, int, int],  # Python version to use for ast parsing
    unit_test_package: str = "pytest",  # unit testing package; use the name as it appears in the import statement
    approx_min_cases_to_cover: int = 7,  # minimum number of test case categories to cover (approximate)
    print_text: bool = os.environ.get("ENVIRONMENT") != "PRODUCTION",
    # optionally prints text; helpful for understanding the function & debugging
    explain_model: LLM = EXPLAIN_MODEL,  # model used to generate text plans in step 1
    plan_model: LLM = PLAN_MODEL,  # model used to generate text plans in steps 2 and 2b
    execute_model: LLM = EXECUTE_MODEL,  # model used to generate code in step 3
    temperature: float = 0.4,  # temperature = 0 can sometimes get stuck in repetitive loops, so we use 0.4
    trace_id: str = "",
) -> str:
    """Return a unit test for a given Python function, generated with a multi-step GPT prompt."""
    # Step 1: Generate an explanation of the function

    # create a markdown-formatted message that asks GPT to explain the function, formatted as a bullet list
    explain_system_message = {"role": "system", "content": EXPLAIN_SYSTEM_PROMPT}
    explain_user_message = {
        "role": "user",
        "content": EXPLAIN_USER_PROMPT.format(function_name=function_name, function_code=function_code),
    }
    explain_messages = [explain_system_message, explain_user_message]
    total_llm_cost = 0.0
    if print_text:
        print_messages(explain_messages)
    try:
        explanation_response = await openai_client.with_options(max_retries=2).chat.completions.create(
            model=explain_model.name, messages=explain_messages, temperature=temperature
        )
        total_llm_cost += calculate_llm_cost(explanation_response, explain_model) or 0.0
    except Exception as e:
        logging.exception("OpenAI client error in explain step")
        sentry_sdk.capture_exception(e)
        raise TestGenerationFailedException(e) from e
    debug_log_sensitive_data(f"OpenAIClient explanation response:\n{explanation_response.model_dump_json(indent=2)}")
    if explanation_response.usage is not None:
        ph(
            user_id,
            "aiservice-testgen-explain-openai-usage",
            properties={"model": explain_model.name, "usage": explanation_response.usage.json()},
        )
    explanation = explanation_response.choices[0].message.content
    explain_assistant_message = {"role": "assistant", "content": explanation}

    # Step 1b: Fetch relevant data from the database to use as inputs based on function explanation
    fetch_data_user_message = {"role": "user", "content": FETCH_DATA_USER_PROMPT.format(orm_code=function_code)}
    fetch_data_system_message = {"role": "system", "content": FETCH_DATA_SYSTEM_PROMPT.format(orm_code=function_code)}
    fetch_data_messages = [fetch_data_system_message, fetch_data_user_message]
    if print_text:
        print_messages(fetch_data_messages)
    try:
        fetch_data_response = await openai_client.with_options(max_retries=2).chat.completions.create(
            model=execute_model.name, messages=fetch_data_messages, temperature=temperature
        )
        total_llm_cost += calculate_llm_cost(fetch_data_response, execute_model) or 0.0
    except Exception as e:
        logging.exception("OpenAI client error in fetch data step")
        sentry_sdk.capture_exception(e)
        raise TestGenerationFailedException(e) from e

    fetch_data_function = fetch_data_response.choices[0].message.content
    fetch_data_function = fetch_data_function.split("```python")[1].split("```")[0].strip()

    # Step 1c: Run the function to get the data
    # Put function in a file and run it
    connection_string = os.environ.get("POSTGRES_CONNECTION_STRING")
    if connection_string is None:
        raise ValueError("POSTGRES_CONNECTION_STRING environment variable is not set")

    # Step 2: Generate a plan to write a unit test

    # Asks GPT to plan out cases the unit tests should cover, formatted as a bullet list
    plan_user_message = {
        "role": "user",
        "content": f"""You want to write tests to test the following function: {function_code}.

Imagine you have the data to use given as the following:
```python
data_to_use = fetch_data(session)
```
where fetch_data(session) returns the data to use as inputs to the function.

You should come up with a few different combinations of this data to use as inputs to the function in order to test it. These combinations should cover a wide range of inputs to constitute a good test suite. A good unit test suite should aim to:

- Test the function's behavior for a wide range of possible inputs
- Test edge cases that the author may not have foreseen
- Take advantage of the features of `{unit_test_package}` to make the tests easy to write and maintain
- Be easy to read and understand, with clean code and descriptive names
- Be deterministic, so that the tests always pass or fail in the same way
- Have tests sorted by difficulty, from easiest to hardest

To help unit test the function above, list diverse scenarios that the function should be able to handle (and under each scenario, include a few examples as sub-bullets). Based on our function input, you should use different combinations of the data given to come up with these different scenarios.""",
    }
    plan_messages = [explain_system_message, explain_user_message, explain_assistant_message, plan_user_message]
    if print_text:
        print_messages([plan_user_message])
    try:
        plan_response = await openai_client.with_options(max_retries=2).chat.completions.create(
            model=plan_model.name, messages=plan_messages, temperature=temperature
        )
        total_llm_cost += calculate_llm_cost(plan_response, plan_model) or 0.0
    except Exception as e:
        logging.exception("OpenAI client error in plan step")
        sentry_sdk.capture_exception(e)
        raise TestGenerationFailedException(e) from e
    debug_log_sensitive_data(f"OpenAIClient plan response:\n{plan_response.model_dump_json(indent=2)}")
    if plan_response.usage is not None:
        ph(
            user_id,
            "aiservice-testgen-plan-openai-usage",
            properties={"model": plan_model.name, "usage": plan_response.usage.json()},
        )

    plan = plan_response.choices[0].message.content
    plan_assistant_message = {"role": "assistant", "content": plan}

    # Step 2b: If the plan is short, ask GPT to elaborate further
    # this counts top-level bullets (e.g., categories), but not sub-bullets (e.g., test cases)
    num_bullets = max(plan.count("\n-"), plan.count("\n*"))
    elaboration_needed = num_bullets < approx_min_cases_to_cover
    if elaboration_needed:
        elaboration_user_message = {
            "role": "user",
            "content": """In addition to those scenarios above, list a few rare or unexpected edge cases (and as before, under each edge case, include a few examples as sub-bullets).""",
        }
        elaboration_messages = [
            explain_system_message,
            explain_user_message,
            explain_assistant_message,
            plan_user_message,
            plan_assistant_message,
            elaboration_user_message,
        ]
        if print_text:
            print_messages([elaboration_user_message])
        try:
            elaboration_response = await openai_client.with_options(max_retries=2).chat.completions.create(
                model=plan_model.name, messages=elaboration_messages, temperature=temperature
            )
            total_llm_cost += calculate_llm_cost(elaboration_response, plan_model) or 0.0
        except Exception as e:
            logging.exception("OpenAI client error in elaboration step")
            sentry_sdk.capture_exception(e)
            raise TestGenerationFailedException(e) from e

        debug_log_sensitive_data(
            f"OpenAIClient elaboration response:\n{elaboration_response.model_dump_json(indent=2)}"
        )

        elaboration = elaboration_response.choices[0].message.content
        elaboration_assistant_message = {"role": "assistant", "content": elaboration}

    # Step 3: Generate the unit test

    # create a markdown-formatted prompt that asks GPT to complete a unit test
    package_comment = ""
    # if unit_test_package == "pytest":
    #     package_comment = "# below, each test case is represented by a tuple passed to the @pytest.mark.parametrize decorator"
    execute_system_message = {"role": "system", "content": EXECUTE_SYSTEM_PROMPT.format(function_name=function_name)}

    execute_messages = [
        execute_system_message,
        explain_user_message,
        explain_assistant_message,
        plan_user_message,
        plan_assistant_message,
    ]
    if elaboration_needed:
        execute_messages += [elaboration_user_message, elaboration_assistant_message]
    execute_user_message = {
        "role": "user",
        "content": EXECUTE_USER_PROMPT.format(
            unit_test_package=unit_test_package,
            function_name=function_name,
            function_code=function_code,
            fetch_data_function_code=fetch_data_function,
            package_comment=package_comment,
        ),
    }
    execute_messages += [execute_user_message]
    if print_text:
        print_messages([execute_system_message, execute_user_message])
    # TODO: Implement a fallback if the code is too long, and implement a straightforward way to write the tests rather than the iterative approach
    tries = 2
    while tries > 0:
        try:
            execute_response = await openai_client.with_options(max_retries=2).chat.completions.create(
                model=execute_model.name, messages=execute_messages, temperature=temperature
            )
            total_llm_cost += calculate_llm_cost(execute_response, execute_model) or 0.0
        except Exception as e:
            logging.exception("OpenAI client error in execute step")
            sentry_sdk.capture_exception(e)
            raise TestGenerationFailedException(e) from e
        debug_log_sensitive_data(f"OpenAIClient execute response:\n{execute_response.model_dump_json(indent=2)}")
        if execute_response.usage is not None:
            ph(
                user_id,
                "aiservice-testgen-execute-openai-usage",
                properties={"model": execute_model.name, "usage": execute_response.usage.json()},
            )
        execution_output = execute_response.choices[0].message.content

        # check the output for errors
        code = execution_output.split("```python")[1].split("```")[0].strip()
        try:
            module = ast.parse(code, feature_version=python_version[:2])
            original_function = ast.parse(function_code, feature_version=python_version[:2])
            if not any_ellipsis_in_ast(original_function) and ellipsis_in_ast_not_types(module):
                # If the test generator is generating ellipsis, it is punting on generating
                # the concrete test cases and we should re-generate
                raise SyntaxError("Ellipsis in generated test code, regenerating...")
            break
        except SyntaxError as e:
            tries -= 1
            logging.warning("Syntax error in generated code. Trying again.")
            logging.warning(f"Error: {e}")
            logging.warning(f"Generated code: {code}")
            continue
    if tries == 0:
        raise TestGenerationFailedException("Failed to generate test code after 2 tries.")

    await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
    # return the unit test as a string
    return code


class TestGenSchema(Schema):
    source_code_being_tested: str
    function_to_optimize: FunctionToOptimize
    helper_function_names: list[str] | None = None  # This is the only one we should use
    dependent_function_names: list[str] | None = None  # Only for backwards compatibility
    module_path: str
    test_module_path: str
    test_framework: str
    test_timeout: int
    trace_id: str
    python_version: str

    @model_validator(mode="after")
    def helper_function_names_validator(self):
        # To maintain backwards compatibility
        if self.dependent_function_names is None and self.helper_function_names is None:
            raise ValueError("either field 'helper_function_names' or 'dependent_function_names' is required")
        if self.helper_function_names is not None:
            return self
        debug_log_sensitive_data(f"dependent_function_names: {self.dependent_function_names}")
        self.helper_function_names = self.dependent_function_names
        self.dependent_function_names = None
        return self
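
# Backwards-compatibility note (illustrative): a legacy payload that sends only
# dependent_function_names=["some_helper"] is normalized by the validator above into
# helper_function_names=["some_helper"], so downstream code only reads helper_function_names.
# ("some_helper" is a made-up example value.)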


class TestGenResponseSchema(Schema):
    generated_tests: str
    instrumented_tests: str


class TestGenErrorResponseSchema(Schema):
    error: str


@testgen_api.post(
    "/", response={200: TestGenResponseSchema, 400: TestGenErrorResponseSchema, 500: TestGenErrorResponseSchema}
)
async def testgen(request, data: TestGenSchema) -> tuple[int, TestGenResponseSchema | TestGenErrorResponseSchema]:
    ph(request.user, "aiservice-testgen-called")
    if data.test_framework not in ["unittest", "pytest"]:
        return 400, TestGenErrorResponseSchema(error="Invalid test framework. We only support unittest and pytest.")
    if not data.function_to_optimize:
        # TODO: Add a validation check here to see if the function_name is actually present in
        # the source_code_being_tested. Parse the AST.
        return 400, TestGenErrorResponseSchema(error="Invalid function to optimize. It is empty.")
    if data.source_code_being_tested == "":
        return 400, TestGenErrorResponseSchema(error="Invalid source code. It is empty.")
    try:
        python_version: tuple[int, int, int] = parse_python_version(data.python_version)
    except Exception:
        return 400, TestGenErrorResponseSchema(
            error="Invalid Python version, it should look like 3.x.x. We only support Python 3.9 and above."
        )
    debug_log_sensitive_data(f"helper_function_names: {data.helper_function_names}")
    try:
        ast.parse(data.source_code_being_tested, feature_version=python_version[:2])
        compile(data.source_code_being_tested, "data.source_code_being_tested", "exec")
    except SyntaxError:
        return 400, TestGenErrorResponseSchema(
            error="Invalid source code. It is not valid Python code. Please check the syntax of your code."
        )
    try:
        print("/testgen: Generating tests...")
        debug_log_sensitive_data(f"Generating tests for function {data.function_to_optimize.function_name}")
        debug_log_sensitive_data(f"Source code being tested: {data.source_code_being_tested}")
        generated_test_source = await generate_regression_tests_from_function(
            user_id=request.user,
            function_code=data.source_code_being_tested,
            function_name=data.function_to_optimize.function_name,
            unit_test_package=data.test_framework,
            approx_min_cases_to_cover=10,
            python_version=python_version,
            trace_id=data.trace_id,
        )
        print("/testgen: Instrumenting tests...")
        instrumented_test_source = safe_isort(
            instrument_test_source(
                test_source=generated_test_source,
                function_to_optimize=data.function_to_optimize,
                helper_function_names=data.helper_function_names,
                module_path=data.module_path,
                test_module_path=data.test_module_path,
                test_framework=data.test_framework,
                test_timeout=data.test_timeout,
                python_version=python_version,
            ),
            float_to_top=True,
        )
        generated_test_source = replace_definition_with_import(
            generated_test_source, data.function_to_optimize, data.module_path
        )
        # Use isort to sort and deduplicate the imports in the generated test code
        sorted_imports_test_source = safe_isort(generated_test_source)
        try:
            parse_module_to_cst(sorted_imports_test_source)
        except Exception as e:
            logging.exception("Failed to parse generated test code")
            sentry_sdk.capture_exception(e)
        else:
            generated_test_source = sorted_imports_test_source

        ph(request.user, "aiservice-testgen-tests-generated")
    except TestGenerationFailedException as e:
        logging.exception("Test generation failed. Skipping test generation.")
        sentry_sdk.capture_exception(e)
        return 500, TestGenErrorResponseSchema(error="Error generating tests. Internal server error.")
    if hasattr(request, "should_log_features") and request.should_log_features:
        await log_features(
            trace_id=data.trace_id,
            user_id=request.user,
            generated_tests=[generated_test_source],
            test_framework=data.test_framework,
            metadata={
                "test_timeout": data.test_timeout,
                "function_to_optimize": data.function_to_optimize.function_name,
            },
        )
    return 200, TestGenResponseSchema(
        generated_tests=generated_test_source, instrumented_tests=instrumented_test_source
    )
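
# Example request body for POST /testgen (illustrative values only; the exact
# FunctionToOptimize payload depends on that model's definition):
# {
#     "source_code_being_tested": "def add(a, b):\n    return a + b\n",
#     "function_to_optimize": {...},
#     "helper_function_names": [],
#     "module_path": "pkg/math_utils.py",
#     "test_module_path": "tests/test_math_utils.py",
#     "test_framework": "pytest",
#     "test_timeout": 15,
#     "trace_id": "abc123",
#     "python_version": "3.11.4"
# }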