# codeflash-internal/experiments/gen_inspired_tests.py
# Derived from https://github.com/openai/openai-cookbook/blob/main/examples/Unit_test_writing_using_a_multi-step_prompt.ipynb
# TODO: This file is only here as a temporary reference implementation of how an early version of LLM-inspired tests was written.
# It didn't work very well. This should be improved significantly.
import ast  # used for detecting whether generated Python code is valid
import platform

import openai  # used for calling the OpenAI API (the legacy openai.ChatCompletion interface, openai < 1.0)

from aiservice.llm import EXECUTE_MODEL, EXPLAIN_MODEL, LLM, PLAN_MODEL
from codeflash.code_utils.code_extractor import get_code
from codeflash.code_utils.code_utils import ellipsis_in_ast, get_imports_from_file
from codeflash.models.models import TestsInFile
from codeflash.verification.gen_regression_tests import print_message_delta, print_messages


def regression_tests_from_function_with_inspiration(
    function_to_test: str,  # Python function to test, as a string
    function_name: str,  # name of the function to test
    existing_unit_test_path_and_function: list[
        TestsInFile
    ],  # existing unit tests (file path plus test function name) that already test the function
    unit_test_package: str = "pytest",  # unit testing package; use the name as it appears in the import statement
    approx_min_cases_to_cover: int = 7,  # minimum number of test case categories to cover (approximate)
    print_text: bool = False,  # optionally prints text; helpful for understanding the function & debugging
    explain_model: LLM = EXPLAIN_MODEL,  # model used to generate the text explanation in step 1
    plan_model: LLM = PLAN_MODEL,  # model used to plan and write the tests in step 2
    execute_model: LLM = EXECUTE_MODEL,  # model used to generate code in step 3 (currently unused in this version)
    temperature: float = 0.4,  # temperature = 0 can sometimes get stuck in repetitive loops, so we use 0.4
    reruns_if_fail: int = 1,  # if the output code cannot be parsed, re-run the generation up to N times
    python_version: tuple[int, int, int] = tuple(int(ver) for ver in platform.python_version_tuple()),
) -> str:
"""Returns a unit test for a given Python function, using a 3-step GPT prompt."""
# TODO: This step is exactly the same as the non-inspired test generator. Merge them into one to save on API calls
# Step 1: Generate an explanation of the function
# create a markdown-formatted message that asks GPT to explain the function, formatted as a bullet list
explain_system_message = {
"role": "system",
"content": "You are a world-class Python developer with an eagle eye for unintended bugs and edge cases. You carefully explain code with great detail and accuracy. You organize your explanations in markdown-formatted, bulleted lists.",
}
explain_user_message = {
"role": "user",
"content": f"""Please explain the following Python function '{function_name}'. Review what each element of the function is doing precisely and what the author's intentions may have been. Organize your explanation as a markdown-formatted, bulleted list.
```python
{function_to_test}
```""",
}
explain_messages = [explain_system_message, explain_user_message]
if print_text:
print_messages(explain_messages)
    try:
        explanation_response = openai.ChatCompletion.create(
            model=explain_model.name, messages=explain_messages, temperature=temperature, stream=True
        )
    except Exception as e:
        print(str(e))
        if reruns_if_fail > 0:
            print("Rerunning...")
            return regression_tests_from_function_with_inspiration(
                function_to_test=function_to_test,
                function_name=function_name,
                existing_unit_test_path_and_function=existing_unit_test_path_and_function,
                unit_test_package=unit_test_package,
                approx_min_cases_to_cover=approx_min_cases_to_cover,
                print_text=print_text,
                explain_model=explain_model,  # pass the caller's models through instead of resetting to the defaults
                plan_model=plan_model,
                execute_model=execute_model,
                temperature=temperature,
                reruns_if_fail=reruns_if_fail - 1,  # decrement rerun counter when calling again
                python_version=python_version,
            )
        raise  # out of reruns; explanation_response would be unbound below
    explanation = ""
    for chunk in explanation_response:
        # In the legacy streaming API, each chunk looks like {"choices": [{"delta": {"content": "..."}}]}
        delta = chunk["choices"][0]["delta"]
        if print_text:
            print_message_delta(delta)
        if "content" in delta:
            explanation += delta["content"]
    explain_assistant_message = {"role": "assistant", "content": explanation}
    # Step 2: Ask GPT to plan and write the new unit tests in a single prompt,
    # taking inspiration from the existing tests that already cover this function
    existing_unit_tests = []
    max_inspiration_tests = 3
    tests_in_file: TestsInFile
    inspired_test_imports = []
    test_files = set()
    for i, tests_in_file in enumerate(existing_unit_test_path_and_function):
        if i >= max_inspiration_tests:
            break  # check the cap first so we don't collect imports from test files we won't show
        path = tests_in_file.test_file
        test_files.add(path)
        function = tests_in_file.test_function
        tester_code = get_code(path, function)
        code = f"""```python
{tester_code}
```"""
        existing_unit_tests.append(code)
    for path in test_files:
        imports_ast = get_imports_from_file(file_path=path)
        inspired_test_imports.append(imports_ast)
package_comment = ""
execute_user_message = {
"role": "user",
"content": f"""A good unit test suite should aim to:
- Test the function's behavior for a wide range of possible inputs
- Test edge cases that the author may not have foreseen
- Take advantage of the features of `{unit_test_package}` to make the tests easy to write and maintain
- Be easy to read and understand, with clean code and descriptive names
- Be deterministic, so that the tests always pass or fail in the same way
To test the function above, here {"are a few unit tests" if len(existing_unit_tests) > 1 else "is a unit test"} that already exist:
{f"{chr(10)}".join(existing_unit_tests)}
Using Python and the `{unit_test_package}` package, and taking inspiration from the existing test cases above, write a new suite of unit tests for the function '{function_name}'. Include helpful comments to explain each line. Generate concrete runnable test without using ellipsis. Reply only with code, formatted as follows:
```python
# imports
import {unit_test_package} # used for our unit tests
{{insert other imports as needed}}
# function to test
{function_to_test}
# unit tests
{package_comment}
{{insert unit test code here}}
```
""",
}
plan_messages = [explain_system_message, explain_user_message, explain_assistant_message, execute_user_message]
if print_text:
print_messages([execute_user_message])
    try:
        execute_response = openai.ChatCompletion.create(
            model=plan_model.name, messages=plan_messages, temperature=temperature, stream=True
        )
    except Exception as e:
        print(str(e))
        if reruns_if_fail > 0:
            print("Rerunning...")
            return regression_tests_from_function_with_inspiration(
                function_to_test=function_to_test,
                function_name=function_name,
                existing_unit_test_path_and_function=existing_unit_test_path_and_function,
                unit_test_package=unit_test_package,
                approx_min_cases_to_cover=approx_min_cases_to_cover,
                print_text=print_text,
                explain_model=explain_model,  # pass the caller's models through instead of resetting to the defaults
                plan_model=plan_model,
                execute_model=execute_model,
                temperature=temperature,
                reruns_if_fail=reruns_if_fail - 1,  # decrement rerun counter when calling again
                python_version=python_version,
            )
        raise  # out of reruns; execute_response would be unbound below
    execution = ""
    for chunk in execute_response:
        delta = chunk["choices"][0]["delta"]
        if print_text:
            print_message_delta(delta)
        if "content" in delta:
            execution += delta["content"]
    # Extract the generated code from the reply and check the output for errors.
    # If the reply has no ```python fence, fall back to the raw reply and let the parse check below trigger a rerun.
    code = execution.split("```python")[1].split("```")[0].strip() if "```python" in execution else execution.strip()
    # TODO: This adds a bunch of redundant imports, clean them up
    tests_list = [imp for sublist in inspired_test_imports for imp in sublist]
    # Wrap the import nodes in a Module so ast.unparse gets a documented input type
    code = ast.unparse(ast.Module(body=tests_list, type_ignores=[])) + "\n" + code
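    # One possible cleanup for the TODO above (a sketch only, not wired in): keep the
    # first occurrence of each import, keyed by its unparsed source text:
    #   seen: set[str] = set()
    #   deduped: list[ast.stmt] = []
    #   for node in tests_list:
    #       src = ast.unparse(node)
    #       if src not in seen:
    #           seen.add(src)
    #           deduped.append(node)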
    try:
        module = ast.parse(code, feature_version=python_version[:2])
        if ellipsis_in_ast(module):
            # If the test generator emitted an ellipsis, it punted on writing the
            # concrete test cases, so we should regenerate
            raise SyntaxError("Ellipsis in generated test code, regenerating...")
    except SyntaxError:
        print("Syntax error in generated code.")
        if reruns_if_fail > 0:
            print("Rerunning...")
            return regression_tests_from_function_with_inspiration(
                function_to_test=function_to_test,
                function_name=function_name,
                existing_unit_test_path_and_function=existing_unit_test_path_and_function,
                unit_test_package=unit_test_package,
                approx_min_cases_to_cover=approx_min_cases_to_cover,
                print_text=print_text,
                explain_model=explain_model,  # pass the caller's models through instead of resetting to the defaults
                plan_model=plan_model,
                execute_model=execute_model,
                temperature=temperature,
                reruns_if_fail=reruns_if_fail - 1,  # decrement rerun counter when calling again
                python_version=python_version,
            )
    # return the unit test as a string (even if the retries were exhausted)
    return code
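
# Illustrative call (a sketch; the inputs are assumed, and TestsInFile's field names
# are inferred from the attribute access in the loop above):
#   tests = [TestsInFile(test_file="tests/test_foo.py", test_function="test_foo_basic")]
#   suite = regression_tests_from_function_with_inspiration(
#       function_to_test="def foo(x):\n    return x + 1\n",
#       function_name="foo",
#       existing_unit_test_path_and_function=tests,
#   )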


class ModifyInspiredTests(ast.NodeTransformer):
    """Collect the imports in a generated test module and strip them from the tree."""

    def __init__(self, import_list, test_framework):
        self.import_list = import_list
        self.test_framework = test_framework

    def visit_Import(self, node: ast.Import):
        # Collect the import; returning None makes NodeTransformer drop it from the tree
        self.import_list.append(node)

    def visit_ImportFrom(self, node: ast.ImportFrom):
        # Collect the from-import; returning None makes NodeTransformer drop it from the tree
        self.import_list.append(node)

    def visit_ClassDef(self, node: ast.ClassDef):
        # No unittest-specific transformations needed since we only support pytest
        return node
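
# A minimal sketch of how ModifyInspiredTests might be applied (assumed usage;
# nothing in this file calls it yet). The import visitors return None, so
# NodeTransformer removes those nodes from the tree while we keep a copy:
#   source = "import pytest\n\ndef test_add():\n    assert 1 + 1 == 2\n"
#   collected: list[ast.stmt] = []
#   tree = ModifyInspiredTests(collected, "pytest").visit(ast.parse(source))
#   body_only = ast.unparse(tree)  # the tests, with their imports stripped
#   imports_only = ast.unparse(ast.Module(body=collected, type_ignores=[]))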