# codeflash-internal/django/aiservice/testgen/sqlalchemy/sqlalchemy_testgen.py
# Derived from https://github.com/openai/openai-cookbook/blob/main/examples/Unit_test_writing_using_a_multi-step_prompt.ipynb
from __future__ import annotations

import ast
import logging
import os
from pathlib import Path

import isort
import sentry_sdk
from aiservice.common_utils import parse_python_version
from aiservice.env_specific import create_openai_client, debug_log_sensitive_data
from aiservice.models.aimodels import EXECUTE_MODEL, EXPLAIN_MODEL, LLM, PLAN_MODEL, calculate_llm_cost
from aiservice.models.functions_to_optimize import FunctionToOptimize
from log_features.log_event import update_optimization_cost
from log_features.log_features import log_features
from ninja import NinjaAPI, Schema
from pydantic import model_validator
from testgen.instrumentation.edit_generated_test import parse_module_to_cst, replace_definition_with_import
from testgen.instrumentation.instrument_new_tests import instrument_test_source

testgen_api = NinjaAPI(urls_namespace="testgen")

openai_client = create_openai_client()
# Get the directory of the current file
current_dir = Path(__file__).parent
EXPLAIN_SYSTEM_PROMPT = (current_dir / "sqlalchemy_explain_system_prompt.md").read_text()
EXPLAIN_USER_PROMPT = (current_dir / "sqlalchemy_explain_user_prompt.md").read_text()
EXECUTE_SYSTEM_PROMPT = (current_dir / "sqlalchemy_execute_system_prompt.md").read_text()
EXECUTE_USER_PROMPT = (current_dir / "sqlalchemy_execute_user_prompt.md").read_text()
FETCH_DATA_SYSTEM_PROMPT = (current_dir / "sqlalchemy_fetch_data_system_prompt.md").read_text()
FETCH_DATA_USER_PROMPT = (current_dir / "sqlalchemy_fetch_data_user_prompt.md").read_text()


class TestGenerationFailedException(Exception):
    pass


color_prefix_by_role = {
    "system": "\033[0m",  # default (ANSI reset)
    "user": "\033[0m",  # default (ANSI reset)
    "assistant": "\033[92m",  # green
}
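# For example, an assistant message prints as "\033[92m\n[assistant]\n<content>"
# (green), while system and user messages print in the terminal's default color.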


def ellipsis_in_ast_not_types(module: ast.AST) -> bool:
    # Add a parent attribute to nodes for easier traversal
    for node in ast.walk(module):
        for child in ast.iter_child_nodes(node):
            child.parent = node
    for node in ast.walk(module):
        if isinstance(node, ast.Constant) and node.value is Ellipsis:
            # Skip ellipses that are part of a type annotation, e.g. Tuple[int, ...]
            if isinstance(node.parent, (ast.Subscript, ast.Index, ast.Tuple)):
                continue
            return True
    return False


def any_ellipsis_in_ast(module: ast.AST) -> bool:
    return any(isinstance(node, ast.Constant) and node.value is Ellipsis for node in ast.walk(module))
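# A quick illustration of the two helpers (hypothetical inputs, for reference):
#   ellipsis_in_ast_not_types(ast.parse("def stub():\n    ..."))       -> True   (placeholder body)
#   ellipsis_in_ast_not_types(ast.parse("x: Tuple[int, ...] = (1,)"))  -> False  (annotation-only ellipsis)
#   any_ellipsis_in_ast(ast.parse("x: Tuple[int, ...] = (1,)"))        -> True   (counts every ellipsis)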


def print_messages(
    messages: list[dict[str, str]],
    color_prefix_by_role: dict[str, str] = color_prefix_by_role,
) -> None:
    """Prints messages sent to or from GPT."""
    message: dict[str, str]
    for message in messages:
        role: str = message["role"]
        color_prefix: str = color_prefix_by_role[role]
        content: str = message["content"]
        print(f"{color_prefix}\n[{role}]\n{content}")


def print_message_delta(delta, color_prefix_by_role=color_prefix_by_role) -> None:
    """Prints a chunk of a message streamed back from GPT."""
    if "role" in delta:
        role = delta["role"]
        color_prefix = color_prefix_by_role[role]
        print(f"{color_prefix}\n[{role}]\n", end="")
    elif "content" in delta:
        content = delta["content"]
        print(content, end="")


def write_fetch_data_function(fetch_data_function, connection_string, function_code):
    fetch_data_code = f"""from sqlalchemy import Boolean, Column, ForeignKey, Integer, Text
from sqlalchemy.engine import Engine, create_engine
from sqlalchemy.orm import DeclarativeBase, Session, relationship, sessionmaker
from sqlalchemy.orm.relationships import Relationship

POSTGRES_CONNECTION_STRING = "{connection_string}"
catalog_engine: Engine = create_engine(POSTGRES_CONNECTION_STRING, echo=True)
session: Session = sessionmaker(bind=catalog_engine)()


class Base(DeclarativeBase):
    pass


{function_code}

{fetch_data_function}

data = fetch_data(session)
"""
    return fetch_data_code
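# A sketch of how this generated script might be executed (hypothetical; "Step 1c"
# below does not actually run it in this module, and subprocess/sys are not imported):
#   script = write_fetch_data_function(fetch_data_function, connection_string, function_code)
#   path = Path("/tmp/fetch_data_script.py"); path.write_text(script)
#   subprocess.run([sys.executable, str(path)], check=True, timeout=60)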


async def generate_regression_tests_from_function(
    user_id: str,
    function_code: str,  # Python function to test, as a string
    function_name: str,  # the function to test
    python_version: tuple[int, int, int],  # Python version to use for ast parsing
    unit_test_package: str = "pytest",  # unit testing package; use the name as it appears in the import statement
    approx_min_cases_to_cover: int = 7,  # minimum number of test case categories to cover (approximate)
    print_text: bool = os.environ.get("ENVIRONMENT") != "PRODUCTION",
    # optionally prints text; helpful for understanding the function & debugging
    explain_model: LLM = EXPLAIN_MODEL,  # model used to generate text plans in step 1
    plan_model: LLM = PLAN_MODEL,  # model used to generate text plans in steps 2 and 2b
    execute_model: LLM = EXECUTE_MODEL,  # model used to generate code in step 3
    temperature: float = 0.4,  # temperature = 0 can sometimes get stuck in repetitive loops, so we use 0.4
    trace_id: str = "",
) -> str:
    """Returns a unit test for a given Python function, using a multi-step GPT prompt."""
    # Step 1: Generate an explanation of the function
    # Create a markdown-formatted message that asks GPT to explain the function as a bullet list
    explain_system_message = {"role": "system", "content": EXPLAIN_SYSTEM_PROMPT}
    explain_user_message = {
        "role": "user",
        "content": EXPLAIN_USER_PROMPT.format(function_name=function_name, function_code=function_code),
    }
    explain_messages = [explain_system_message, explain_user_message]
    total_llm_cost = 0.0
    if print_text:
        print_messages(explain_messages)
    try:
        explanation_response = await openai_client.with_options(max_retries=2).chat.completions.create(
            model=explain_model.name, messages=explain_messages, temperature=temperature
        )
        total_llm_cost += calculate_llm_cost(explanation_response, explain_model) or 0.0
    except Exception as e:
        logging.exception("OpenAI client error in explain step")
        sentry_sdk.capture_exception(e)
        raise TestGenerationFailedException(e) from e
    debug_log_sensitive_data(f"OpenAIClient explanation response:\n{explanation_response.model_dump_json(indent=2)}")
    if explanation_response.usage is not None:
        ph(
            user_id,
            "aiservice-testgen-explain-openai-usage",
            properties={"model": explain_model.name, "usage": explanation_response.usage.json()},
        )
    explanation = explanation_response.choices[0].message.content
    explain_assistant_message = {"role": "assistant", "content": explanation}
    # Step 1b: Fetch relevant data from the database to use as inputs, based on the function explanation
    fetch_data_user_message = {"role": "user", "content": FETCH_DATA_USER_PROMPT.format(orm_code=function_code)}
    fetch_data_system_message = {"role": "system", "content": FETCH_DATA_SYSTEM_PROMPT.format(orm_code=function_code)}
    fetch_data_messages = [fetch_data_system_message, fetch_data_user_message]
    if print_text:
        print_messages(fetch_data_messages)
    try:
        fetch_data_response = await openai_client.with_options(max_retries=2).chat.completions.create(
            model=execute_model.name, messages=fetch_data_messages, temperature=temperature
        )
        total_llm_cost += calculate_llm_cost(fetch_data_response, execute_model) or 0.0
    except Exception as e:
        logging.exception("OpenAI client error in fetch data step")
        sentry_sdk.capture_exception(e)
        raise TestGenerationFailedException(e) from e
    fetch_data_function = fetch_data_response.choices[0].message.content
    fetch_data_function = fetch_data_function.split("```python")[1].split("```")[0].strip()
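    # NOTE: this grabs the first ```python fenced block and raises IndexError if the
    # model returned no fenced code; e.g. "text\n```python\ncode\n```\nmore" -> "code".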
    # Step 1c: Run the function to get the data
    # Put the function in a file and run it
    connection_string = os.environ.get("POSTGRES_CONNECTION_STRING")
    if connection_string is None:
        raise ValueError("POSTGRES_CONNECTION_STRING environment variable is not set")
    # Step 2: Generate a plan to write a unit test
    # Ask GPT to plan out the cases the unit tests should cover, formatted as a bullet list
    plan_user_message = {
        "role": "user",
        "content": f"""You want to write tests to test the following function: {function_code}.
Imagine you have the data to use given as the following:
```python
data_to_use = fetch_data(session)
```
where fetch_data(session) returns the data to use as inputs to the function.
You should come up with a few different combinations of this data to use as inputs to the function in order to test it. These combinations should cover a wide range of inputs to constitute a good test suite. A good unit test suite should aim to:
- Test the function's behavior for a wide range of possible inputs
- Test edge cases that the author may not have foreseen
- Take advantage of the features of `{unit_test_package}` to make the tests easy to write and maintain
- Be easy to read and understand, with clean code and descriptive names
- Be deterministic, so that the tests always pass or fail in the same way
- Have tests sorted by difficulty, from easiest to hardest
To help unit test the function above, list diverse scenarios that the function should be able to handle (and under each scenario, include a few examples as sub-bullets). Based on our function input, you should use different combinations of the data given to come up with these different scenarios.""",
    }
    plan_messages = [explain_system_message, explain_user_message, explain_assistant_message, plan_user_message]
    if print_text:
        print_messages([plan_user_message])
    try:
        plan_response = await openai_client.with_options(max_retries=2).chat.completions.create(
            model=plan_model.name, messages=plan_messages, temperature=temperature
        )
        total_llm_cost += calculate_llm_cost(plan_response, plan_model) or 0.0
    except Exception as e:
        logging.exception("OpenAI client error in plan step")
        sentry_sdk.capture_exception(e)
        raise TestGenerationFailedException(e) from e
    debug_log_sensitive_data(f"OpenAIClient plan response:\n{plan_response.model_dump_json(indent=2)}")
    if plan_response.usage is not None:
        ph(
            user_id,
            "aiservice-testgen-plan-openai-usage",
            properties={"model": plan_model.name, "usage": plan_response.usage.json()},
        )
    plan = plan_response.choices[0].message.content
    plan_assistant_message = {"role": "assistant", "content": plan}
    # Step 2b: If the plan is short, ask GPT to elaborate further
    # This counts top-level bullets (e.g., categories), but not sub-bullets (e.g., test cases)
    num_bullets = max(plan.count("\n-"), plan.count("\n*"))
    elaboration_needed = num_bullets < approx_min_cases_to_cover
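    # e.g. a plan of "\n- happy path\n  - single row\n- edge cases" has two top-level
    # bullets (sub-bullets are indented and don't match "\n-"), so with
    # approx_min_cases_to_cover=7 it would trigger the elaboration step below.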
    if elaboration_needed:
        elaboration_user_message = {
            "role": "user",
            "content": """In addition to those scenarios above, list a few rare or unexpected edge cases (and as before, under each edge case, include a few examples as sub-bullets).""",
        }
        elaboration_messages = [
            explain_system_message,
            explain_user_message,
            explain_assistant_message,
            plan_user_message,
            plan_assistant_message,
            elaboration_user_message,
        ]
        if print_text:
            print_messages([elaboration_user_message])
        try:
            elaboration_response = await openai_client.with_options(max_retries=2).chat.completions.create(
                model=plan_model.name, messages=elaboration_messages, temperature=temperature
            )
            total_llm_cost += calculate_llm_cost(elaboration_response, plan_model) or 0.0
        except Exception as e:
            logging.exception("OpenAI client error in elaboration step")
            sentry_sdk.capture_exception(e)
            raise TestGenerationFailedException(e) from e
        debug_log_sensitive_data(
            f"OpenAIClient elaboration response:\n{elaboration_response.model_dump_json(indent=2)}"
        )
        elaboration = elaboration_response.choices[0].message.content
        elaboration_assistant_message = {"role": "assistant", "content": elaboration}
    # Step 3: Generate the unit test
    # Create a markdown-formatted prompt that asks GPT to complete a unit test
    package_comment = ""
    # if unit_test_package == "pytest":
    #     package_comment = "# below, each test case is represented by a tuple passed to the @pytest.mark.parametrize decorator"
    execute_system_message = {"role": "system", "content": EXECUTE_SYSTEM_PROMPT.format(function_name=function_name)}
    execute_messages = [
        execute_system_message,
        explain_user_message,
        explain_assistant_message,
        plan_user_message,
        plan_assistant_message,
    ]
    if elaboration_needed:
        execute_messages += [elaboration_user_message, elaboration_assistant_message]
    execute_user_message = {
        "role": "user",
        "content": EXECUTE_USER_PROMPT.format(
            unit_test_package=unit_test_package,
            function_name=function_name,
            function_code=function_code,
            fetch_data_function_code=fetch_data_function,
            package_comment=package_comment,
        ),
    }
    execute_messages += [execute_user_message]
    if print_text:
        print_messages([execute_system_message, execute_user_message])
    # TODO: Implement a fallback if the code is too long; implement a straightforward way to write the tests rather than the iterative approach
    tries = 2
    while tries > 0:
        try:
            execute_response = await openai_client.with_options(max_retries=2).chat.completions.create(
                model=execute_model.name, messages=execute_messages, temperature=temperature
            )
            total_llm_cost += calculate_llm_cost(execute_response, execute_model) or 0.0
        except Exception as e:
            logging.exception("OpenAI client error in execute step")
            sentry_sdk.capture_exception(e)
            raise TestGenerationFailedException(e) from e
        debug_log_sensitive_data(f"OpenAIClient execute response:\n{execute_response.model_dump_json(indent=2)}")
        if execute_response.usage is not None:
            ph(
                user_id,
                "aiservice-testgen-execute-openai-usage",
                properties={"model": execute_model.name, "usage": execute_response.usage.json()},
            )
        execution_output = execute_response.choices[0].message.content
        # Check the output for errors
        code = execution_output.split("```python")[1].split("```")[0].strip()
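        # NOTE: like the fetch-data step, this assumes a ```python fenced block is
        # present; a missing fence raises IndexError, which the SyntaxError retry
        # handler below does not catch.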
        try:
            module = ast.parse(code, feature_version=python_version[:2])
            original_function = ast.parse(function_code, feature_version=python_version[:2])
            if not any_ellipsis_in_ast(original_function) and ellipsis_in_ast_not_types(module):
                # If the test generator is emitting ellipses, it is punting on generating
                # the concrete test cases, so we should regenerate
                raise SyntaxError("Ellipsis in generated test code, regenerating...")
            break
        except SyntaxError as e:
            tries -= 1
            logging.warning("Syntax error in generated code. Trying again.")
            logging.warning(f"Error: {e}")
            logging.warning(f"Generated code: {code}")
            continue
    if tries == 0:
        raise TestGenerationFailedException("Failed to generate test code after 2 tries.")
    await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
    # Return the unit test as a string
    return code


class TestGenSchema(Schema):
    source_code_being_tested: str
    function_to_optimize: FunctionToOptimize
    helper_function_names: list[str] | None = None  # This is the only one we should use
    dependent_function_names: list[str] | None = None  # Only for backwards compatibility
    module_path: str
    test_module_path: str
    test_framework: str
    test_timeout: int
    trace_id: str
    python_version: str

    @model_validator(mode="after")
    def helper_function_names_validator(self):
        # To maintain backwards compatibility
        if self.dependent_function_names is None and self.helper_function_names is None:
            raise ValueError("either field 'helper_function_names' or 'dependent_function_names' is required")
        if self.helper_function_names is not None:
            return self
        logging.debug(f"dependent_function_names: {self.dependent_function_names}")
        self.helper_function_names = self.dependent_function_names
        self.dependent_function_names = None
        return self
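# Example request body for this schema (hypothetical values, for illustration only;
# the shape of "function_to_optimize" is defined by FunctionToOptimize and elided here):
# {
#   "source_code_being_tested": "def add(a, b):\n    return a + b",
#   "function_to_optimize": {...},
#   "helper_function_names": [],
#   "module_path": "pkg/maths.py",
#   "test_module_path": "tests/test_maths.py",
#   "test_framework": "pytest",
#   "test_timeout": 60,
#   "trace_id": "abc123",
#   "python_version": "3.11.4"
# }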


class TestGenResponseSchema(Schema):
    generated_tests: str
    instrumented_tests: str


class TestGenErrorResponseSchema(Schema):
    error: str


# Imported late, likely to avoid a circular import at module load time
from aiservice.analytics.posthog import ph


@testgen_api.post(
    "/", response={200: TestGenResponseSchema, 400: TestGenErrorResponseSchema, 500: TestGenErrorResponseSchema}
)
async def testgen(request, data: TestGenSchema) -> tuple[int, TestGenResponseSchema | TestGenErrorResponseSchema]:
    ph(request.user, "aiservice-testgen-called")
    if data.test_framework not in ["unittest", "pytest"]:
        return 400, TestGenErrorResponseSchema(error="Invalid test framework. We only support unittest and pytest.")
    if not data.function_to_optimize:
        # TODO: Add a validation check here to see if the function_name is actually present in
        # the source_code_being_tested. Parse the ast.
        return 400, TestGenErrorResponseSchema(error="Invalid function to optimize. It is empty.")
    if data.source_code_being_tested == "":
        return 400, TestGenErrorResponseSchema(error="Invalid source code. It is empty.")
    try:
        python_version: tuple[int, int, int] = parse_python_version(data.python_version)
    except Exception:
        return 400, TestGenErrorResponseSchema(
            error="Invalid Python version, it should look like 3.x.x. We only support Python 3.9 and above."
        )
    logging.debug(f"helper_function_names: {data.helper_function_names}")
    try:
        ast.parse(data.source_code_being_tested, feature_version=python_version[:2])
        compile(data.source_code_being_tested, "data.source_code_being_tested", "exec")
    except SyntaxError:
        return 400, TestGenErrorResponseSchema(
            error="Invalid source code. It is not valid Python code. Please check the syntax of your code."
        )
    try:
        print("/testgen: Generating tests...")
        debug_log_sensitive_data(f"Generating tests for function {data.function_to_optimize.function_name}")
        debug_log_sensitive_data(f"Source code being tested: {data.source_code_being_tested}")
        generated_test_source = await generate_regression_tests_from_function(
            user_id=request.user,
            function_code=data.source_code_being_tested,
            function_name=data.function_to_optimize.function_name,
            unit_test_package=data.test_framework,
            approx_min_cases_to_cover=10,
            python_version=python_version,
            trace_id=data.trace_id,
        )
        print("/testgen: Instrumenting tests...")
        instrumented_test_source = isort.code(
            instrument_test_source(
                test_source=generated_test_source,
                function_to_optimize=data.function_to_optimize,
                helper_function_names=data.helper_function_names,
                module_path=data.module_path,
                test_module_path=data.test_module_path,
                test_framework=data.test_framework,
                test_timeout=data.test_timeout,
                python_version=python_version,
            ),
            float_to_top=True,
        )
        generated_test_source = replace_definition_with_import(
            generated_test_source, data.function_to_optimize, data.module_path
        )
        # Use isort to sort and deduplicate the imports in the generated test code
        sorted_imports_test_source = isort.code(generated_test_source)
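        # e.g. isort.code("import os\nimport ast\nimport os\n") comes back with the
        # imports sorted and the duplicate "import os" collapsed (illustrative input only).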
        try:
            parse_module_to_cst(sorted_imports_test_source)
        except Exception as e:
            logging.exception("Failed to parse generated test code")
            sentry_sdk.capture_exception(e)
        else:
            generated_test_source = sorted_imports_test_source
        ph(request.user, "aiservice-testgen-tests-generated")
    except TestGenerationFailedException as e:
        logging.exception("Test generation failed. Skipping test generation.")
        sentry_sdk.capture_exception(e)
        return 500, TestGenErrorResponseSchema(error="Error generating tests. Internal server error.")
    if hasattr(request, "should_log_features") and request.should_log_features:
        await log_features(
            trace_id=data.trace_id,
            user_id=request.user,
            generated_tests=[generated_test_source],
            test_framework=data.test_framework,
            metadata={
                "test_timeout": data.test_timeout,
                "function_to_optimize": data.function_to_optimize.function_name,
            },
        )
    return 200, TestGenResponseSchema(
        generated_tests=generated_test_source, instrumented_tests=instrumented_test_source
    )