# Derived from https://github.com/openai/openai-cookbook/blob/main/examples/Unit_test_writing_using_a_multi-step_prompt.ipynb
from __future__ import annotations

import ast
import asyncio
import logging
import re
from pathlib import Path
from typing import TYPE_CHECKING

import sentry_sdk
import stamina

from aiservice.analytics.posthog import ph
from aiservice.common_utils import parse_python_version, safe_isort, should_hack_for_demo, validate_trace_id
from aiservice.env_specific import IS_PRODUCTION, create_llm_client, debug_log_sensitive_data, llm_clients
from aiservice.models.aimodels import EXECUTE_MODEL, calculate_llm_cost
from log_features.log_event import update_optimization_cost
from log_features.log_features import log_features
from ninja import NinjaAPI
from ninja.errors import HttpError
from openai import OpenAIError

from testgen.instrumentation.edit_generated_test import parse_module_to_cst, replace_definition_with_import
from testgen.instrumentation.instrument_new_tests import instrument_test_source
from testgen.models import (
    TestGenerationFailedError,
    TestGenErrorResponseSchema,
    TestGenResponseSchema,
    TestGenSchema,  # noqa: TC001
    TestingMode,
)
from testgen.postprocessing.code_validator import has_test_functions, validate_testgen_code
from testgen.postprocessing.postprocess_pipeline import postprocessing_testgen_pipeline
from testgen.testgen_context import BaseTestGenContext, TestGenContextData

if TYPE_CHECKING:
    from aiservice.models.aimodels import LLM
    from authapp.auth import AuthBearer

testgen_api = NinjaAPI(urls_namespace="testgen")

# Get the directory of the current file
current_dir = Path(__file__).parent
EXPLAIN_SYSTEM_PROMPT = (current_dir / "explain_system_prompt.md").read_text()
EXPLAIN_USER_PROMPT = (current_dir / "explain_user_prompt.md").read_text()

EXECUTE_SYSTEM_PROMPT = (current_dir / "execute_system_prompt.md").read_text()
EXECUTE_USER_PROMPT = (current_dir / "execute_user_prompt.md").read_text()

EXECUTE_ASYNC_SYSTEM_PROMPT = (current_dir / "execute_async_system_prompt.md").read_text()
EXECUTE_ASYNC_USER_PROMPT = (current_dir / "execute_async_user_prompt.md").read_text()

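# Matches the first fenced ```python code block in an LLM response and captures its body.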
pattern = re.compile(r"^```python\s*\n(.*?)\n```", re.MULTILINE | re.DOTALL)


color_prefix_by_role = {
    "system": "\033[0m",  # gray
    "user": "\033[0m",  # gray
    "assistant": "\033[92m",  # green
}


def print_messages(
    messages: list[dict[str, str]],
    color_prefix_by_role: dict[str, str] = color_prefix_by_role,
) -> None:
    """Print messages sent to or from GPT."""
    message: dict[str, str]
    for message in messages:
        role: str = message["role"]
        color_prefix: str = color_prefix_by_role[role]
        content: str = message["content"]
        print(f"{color_prefix}\n[{role}]\n{content}")


def build_prompt(
    ctx: BaseTestGenContext, function_name: str, unit_test_package: str, *, is_async: bool
) -> tuple[list[dict[str, str]], str, str]:
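    """Assemble the chat messages for the test-generation LLM call.

    Chooses the sync or async prompt templates based on is_async, then returns the message
    list together with the PostHog event suffix and error-context string used by the caller.
    """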
    if is_async:
        execute_system_prompt = EXECUTE_ASYNC_SYSTEM_PROMPT
        execute_user_prompt = EXECUTE_ASYNC_USER_PROMPT
        plan_content = f"""A good unit test suite for an ASYNC function should aim to:
- Test the async function's behavior for a wide range of possible inputs
- Test edge cases that the author may not have foreseen, including async-specific edge cases
- Take advantage of the features of `{unit_test_package}` to make async tests easy to write and maintain
- Be easy to read and understand, with clean async code and descriptive names
- Be deterministic, so that the async tests always pass or fail in the same way
- Have tests sorted by difficulty, from easiest to hardest
- Should try not to mock or stub any dependencies, so that the async testing environment is as close to production as possible
- Include concurrent execution test cases to assess the function's async performance and behavior
- Include throughput test cases to measure the function's performance under load and high-volume scenarios
- Test proper async/await patterns and coroutine handling

To help unit test the ASYNC function above, list diverse scenarios that the async function should be able to handle (and under each scenario, include a few examples as sub-bullets)."""
        posthog_event_suffix = "async-"
        error_context = "async "
    else:
        execute_system_prompt = EXECUTE_SYSTEM_PROMPT
        execute_user_prompt = EXECUTE_USER_PROMPT
        plan_content = f"""A good unit test suite should aim to:
- Test the function's behavior for a wide range of possible inputs
- Test edge cases that the author may not have foreseen
- Take advantage of the features of `{unit_test_package}` to make the tests easy to write and maintain
- Be easy to read and understand, with clean code and descriptive names
- Be deterministic, so that the tests always pass or fail in the same way
- Have tests sorted by difficulty, from easiest to hardest
- Should try not to mock or stub any dependencies by using `{unit_test_package}`.mock or any other similar mocking or stubbing module, so that the testing environment is as close to the production environment as possible
- Include Large Scale Test Cases to assess the function's performance and scalability with large data samples.

To help unit test the function above, list diverse scenarios that the function should be able to handle (and under each scenario, include a few examples as sub-bullets)."""
        posthog_event_suffix = ""
        error_context = ""

    plan_user_message = {"role": "user", "content": plan_content}
    package_comment = ""
    # if unit_test_package == "pytest":
    #     package_comment = "# below, each test case is represented by a tuple passed to the @pytest.mark.parametrize decorator"
    execute_system_message = {
        "role": "system",
        "content": execute_system_prompt.format(function_name=ctx.data.qualified_name),
    }

    execute_messages = [execute_system_message, plan_user_message]

    all_notes = ctx.generate_notes_markdown()
    note_message = {"role": "user", "content": all_notes}

    execute_messages += [note_message]

    execute_user_message = {
        "role": "user",
        "content": execute_user_prompt.format(
            unit_test_package=unit_test_package,
            function_name=function_name,
            function_code=ctx.data.source_code_being_tested,
            package_comment=package_comment,
        ),
    }

    execute_messages += [execute_user_message]
    if IS_PRODUCTION is False:
        print_messages(execute_messages)

    return execute_messages, posthog_event_suffix, error_context


def instrument_tests(
    generated_test_source: str, data: TestGenSchema, python_version: tuple[int, int, int]
) -> tuple[str | None, str | None]:
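    """Instrument the generated test source for behavior and performance runs.

    Returns (behavior_tests, perf_tests); either element may be None if instrumentation fails.
    """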
    common_args = {
        "test_source": generated_test_source,
        "function_to_optimize": data.function_to_optimize,
        "helper_function_names": data.helper_function_names,
        "module_path": data.module_path,
        "test_module_path": data.test_module_path,
        "test_framework": data.test_framework,
        "test_timeout": data.test_timeout,
        "python_version": python_version,
    }

    # instrument_test_source() already applies isort via format_and_float_to_top()
    # No need to apply isort again here (was causing double formatting overhead)
    behavior_result = instrument_test_source(**common_args, mode=TestingMode.BEHAVIOR)
    perf_result = instrument_test_source(**common_args, mode=TestingMode.PERFORMANCE)

    return behavior_result, perf_result


def parse_and_validate_llm_output(
    response_content: str, ctx: BaseTestGenContext, python_version: tuple[int, int, int], error_context: str
) -> str:
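    """Extract the Python code block from the LLM response and validate it.

    Raises ValueError when no ```python block is found and SyntaxError when the generated
    code still contains ellipses; both are retried by the calling function.
    """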
    try:
        if "```python" not in response_content:
            sentry_sdk.capture_message("LLM response did not contain a Python code block:\n" + response_content)
            raise ValueError("LLM response did not contain a Python code block.")
        pattern_res = pattern.search(response_content)
        if not pattern_res:
            raise ValueError("No Python code block found in the LLM response.")

        code = pattern_res.group(1)
        cleaned_code = validate_testgen_code(code, python_version[:2], max_lines_to_remove=120)

        if ctx.did_generate_ellipsis(cleaned_code, python_version):
            msg = f"Ellipsis in generated {error_context}test code, regenerating..."
            raise SyntaxError(msg)

        return cleaned_code  # noqa: TRY300
    except Exception as e:
        sentry_sdk.capture_exception(e)
        raise


@stamina.retry(on=(SyntaxError, ValueError, OpenAIError), attempts=2)
async def generate_and_validate_test_code(
    messages: list[dict[str, str]],
    model: LLM,
    temperature: float,
    ctx: BaseTestGenContext,
    python_version: tuple[int, int, int],
    error_context: str,
    execute_model: LLM,
    cost_tracker: list[float],
    user_id: str,
    posthog_event_suffix: str,
) -> str:
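    """Call the LLM, record the call's cost in cost_tracker, and return the validated test code.

    Retried by stamina on SyntaxError, ValueError, or OpenAIError (two attempts total).
    """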
    llm_client = llm_clients[execute_model.model_type]
    response = await llm_client.with_options(max_retries=2).chat.completions.create(
        model=model.name, messages=messages, temperature=temperature
    )
    cost = calculate_llm_cost(response, execute_model) or 0.0
    cost_tracker.append(cost)

    debug_log_sensitive_data(f"OpenAIClient {error_context}execute response:\n{response.model_dump_json(indent=2)}")

    if response.usage:
        ph(
            user_id,
            f"aiservice-testgen-{posthog_event_suffix}execute-openai-usage",
            properties={"model": execute_model.name, "usage": response.usage.model_dump_json()},
        )

    return parse_and_validate_llm_output(
        response_content=response.choices[0].message.content,
        ctx=ctx,
        python_version=python_version,
        error_context=error_context,
    )


@stamina.retry(on=TestGenerationFailedError, attempts=2)
async def generate_regression_tests_from_function(
    ctx: BaseTestGenContext,
    user_id: str,
    function_name: str,
    python_version: tuple[int, int, int],
    data: TestGenSchema,
    unit_test_package: str = "pytest",
    execute_model: LLM = EXECUTE_MODEL,
    temperature: float = 0.4,
    is_async: bool = False,  # noqa: FBT001, FBT002
    trace_id: str = "",
) -> tuple[str, str | None, str | None]:
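    """Generate regression tests for the target function and instrument them.

    Returns (generated_test_source, instrumented_behavior_tests, instrumented_perf_tests).
    Raises TestGenerationFailedError when generation or instrumentation fails, which stamina
    retries (two attempts total).
    """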
    execute_messages, posthog_event_suffix, error_context = build_prompt(
        ctx=ctx, function_name=function_name, unit_test_package=unit_test_package, is_async=is_async
    )

    cost_tracker = []
    try:
        validated_code = await generate_and_validate_test_code(
            messages=execute_messages,
            model=execute_model,
            temperature=temperature,
            ctx=ctx,
            python_version=python_version,
            error_context=error_context,
            execute_model=execute_model,
            cost_tracker=cost_tracker,
            user_id=user_id,
            posthog_event_suffix=posthog_event_suffix,
        )
        total_llm_cost = sum(cost_tracker)
        await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)

        processed_cst = postprocessing_testgen_pipeline(
            parse_module_to_cst(validated_code), data.helper_function_names, data.function_to_optimize, data.module_path
        )

        generated_test_source = replace_definition_with_import(
            processed_cst.code, data.function_to_optimize, data.module_path
        )

        instrumented_behavior_tests, instrumented_perf_tests = instrument_tests(
            generated_test_source, data, python_version
        )

        if instrumented_behavior_tests is None or instrumented_perf_tests is None:
            msg = "There was an error detected in the function to optimize, is it valid Python code?"
            raise TestGenerationFailedError(msg)

        sorted_generated_tests = safe_isort(generated_test_source)

        try:
            parse_module_to_cst(sorted_generated_tests)
            generated_test_source = sorted_generated_tests
        except Exception:  # noqa: BLE001
            sentry_sdk.capture_message("isort caused a syntax error in testgen; returning un-sorted code.")
        tree = ast.parse(generated_test_source, feature_version=python_version[:2])
        if not has_test_functions(tree):  # sanity check, after all the processing somehow no test functions
            msg = "No test functions were found in the generated test code."
            raise TestGenerationFailedError(msg)
        return generated_test_source, instrumented_behavior_tests, instrumented_perf_tests  # noqa: TRY300
    except (SyntaxError, ValueError) as e:
        total_llm_cost = sum(cost_tracker)
        await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost)
        msg = f"Failed to generate valid {error_context}test code after {len(cost_tracker)} tries."
        raise TestGenerationFailedError(msg) from e


async def hack_for_demo(data: TestGenSchema, python_version: tuple[int, int, int]) -> TestGenResponseSchema:
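    """Return a canned test-generation response for the demo code path.

    Uses hardcoded test sources targeting codeflash.result.common_tags.find_common_tags,
    instruments them, and sleeps before responding.
    """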
    if data.test_index == 0:
generated_test_source = '# imports\n# function to test\nfrom __future__ import annotations\n\nimport pytest # used for our unit tests\nfrom codeflash.result.common_tags import find_common_tags\n\n# unit tests\n\ndef test_single_article():\n # Single article should return its tags\n articles = [{"tags": ["python", "coding", "tutorial"]}]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_multiple_articles_with_common_tags():\n # Multiple articles with common tags should return the common tags\n articles = [\n {"tags": ["python", "coding"]},\n {"tags": ["python", "data"]},\n {"tags": ["python", "machine learning"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_empty_list_of_articles():\n # Empty list of articles should return an empty set\n articles = []\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_articles_with_no_common_tags():\n # Articles with no common tags should return an empty set\n articles = [\n {"tags": ["python"]},\n {"tags": ["java"]},\n {"tags": ["c++"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_articles_with_empty_tag_lists():\n # Articles with some empty tag lists should return an empty set\n articles = [\n {"tags": []},\n {"tags": ["python"]},\n {"tags": ["python", "java"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_all_articles_with_empty_tag_lists():\n # All articles with empty tag lists should return an empty set\n articles = [\n {"tags": []},\n {"tags": []},\n {"tags": []}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_tags_with_special_characters():\n # Tags with special characters should be handled correctly\n articles = [\n {"tags": ["python!", "coding"]},\n {"tags": ["python!", "data"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_case_sensitivity():\n # Tags with different cases should not be considered the same\n articles = [\n {"tags": ["Python", "coding"]},\n {"tags": ["python", "data"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_large_number_of_articles():\n # Large number of articles with a common tag should return that tag\n articles = [{"tags": ["common_tag", f"tag{i}"]} for i in range(1000)]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_large_number_of_tags():\n # Large number of tags with some common tags should return the common tags\n articles = [\n {"tags": [f"tag{i}" for i in range(1000)]},\n {"tags": [f"tag{i}" for i in range(500, 1500)]}\n ]\n expected = {f"tag{i}" for i in range(500, 1000)}\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_mixed_length_of_tag_lists():\n # Articles with mixed length of tag lists should return the common tags\n articles = [\n {"tags": ["python", "coding"]},\n {"tags": ["python"]},\n {"tags": ["python", "coding", "tutorial"]}\n ]\n codeflash_output = find_common_tags(articles)\n # 
Outputs were verified to be equal to the original implementation\n\ndef test_tags_with_different_data_types():\n # Tags with different data types should only consider strings\n articles = [\n {"tags": ["python", 123]},\n {"tags": ["python", "123"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_performance_with_large_data():\n # Performance with large data should return the common tag\n articles = [{"tags": ["common_tag", f"tag{i}"]} for i in range(10000)]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_scalability_with_increasing_tags():\n # Scalability with increasing tags should return the common tag\n articles = [{"tags": ["common_tag"] + [f"tag{i}" for i in range(j)]} for j in range(1, 1001)]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation'
    else:
generated_test_source = '# imports\n# function to test\nfrom __future__ import annotations\n\nimport pytest # used for our unit tests\nfrom codeflash.result.common_tags import find_common_tags\n\n# unit tests\n\ndef test_empty_input_list():\n # Test with an empty list\n codeflash_output = find_common_tags([])\n # Outputs were verified to be equal to the original implementation\n\ndef test_single_article():\n # Test with a single article with tags\n codeflash_output = find_common_tags([{"tags": ["python", "coding", "development"]}])\n # Test with a single article with no tags\n codeflash_output = find_common_tags([{"tags": []}])\n # Outputs were verified to be equal to the original implementation\n\ndef test_multiple_articles_some_common_tags():\n # Test with multiple articles having some common tags\n articles = [\n {"tags": ["python", "coding", "development"]},\n {"tags": ["python", "development", "tutorial"]},\n {"tags": ["python", "development", "guide"]}\n ]\n codeflash_output = find_common_tags(articles)\n\n articles = [\n {"tags": ["tech", "news"]},\n {"tags": ["tech", "gadgets"]},\n {"tags": ["tech", "reviews"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_multiple_articles_no_common_tags():\n # Test with multiple articles having no common tags\n articles = [\n {"tags": ["python", "coding"]},\n {"tags": ["development", "tutorial"]},\n {"tags": ["guide", "learning"]}\n ]\n codeflash_output = find_common_tags(articles)\n\n articles = [\n {"tags": ["apple", "banana"]},\n {"tags": ["orange", "grape"]},\n {"tags": ["melon", "kiwi"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_articles_with_duplicate_tags():\n # Test with articles having duplicate tags\n articles = [\n {"tags": ["python", "python", "coding"]},\n {"tags": ["python", "development", "python"]},\n {"tags": ["python", "guide", "python"]}\n ]\n codeflash_output = find_common_tags(articles)\n\n articles = [\n {"tags": ["tech", "tech", "news"]},\n {"tags": ["tech", "tech", "gadgets"]},\n {"tags": ["tech", "tech", "reviews"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_articles_with_mixed_case_tags():\n # Test with articles having mixed case tags\n articles = [\n {"tags": ["Python", "Coding"]},\n {"tags": ["python", "Development"]},\n {"tags": ["PYTHON", "Guide"]}\n ]\n codeflash_output = find_common_tags(articles) # Assuming case sensitivity\n\n articles = [\n {"tags": ["Tech", "News"]},\n {"tags": ["tech", "Gadgets"]},\n {"tags": ["TECH", "Reviews"]}\n ]\n codeflash_output = find_common_tags(articles) # Assuming case sensitivity\n # Outputs were verified to be equal to the original implementation\n\ndef test_articles_with_non_string_tags():\n # Test with articles having non-string tags\n articles = [\n {"tags": ["python", 123, "coding"]},\n {"tags": ["python", "development", 123]},\n {"tags": ["python", "guide", 123]}\n ]\n codeflash_output = find_common_tags(articles)\n\n articles = [\n {"tags": [None, "news"]},\n {"tags": ["tech", None]},\n {"tags": [None, "reviews"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_large_scale_test_cases():\n # Test with large scale input where all tags should be common\n articles = [\n {"tags": ["tag" + str(i) for i in range(1000)]} for _ in 
range(100)\n ]\n expected_output = {"tag" + str(i) for i in range(1000)}\n codeflash_output = find_common_tags(articles)\n\n # Test with large scale input where no tags should be common\n articles = [\n {"tags": ["tag" + str(i) for i in range(1000)]} for _ in range(50)\n ] + [{"tags": ["unique_tag"]}]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation'

    instrumented_behavior_tests, instrumented_perf_tests = instrument_tests(generated_test_source, data, python_version)
    await asyncio.sleep(9)
    return TestGenResponseSchema(
        generated_tests=generated_test_source,
        instrumented_behavior_tests=instrumented_behavior_tests,
        instrumented_perf_tests=instrumented_perf_tests,
    )


def validate_request_data(data: TestGenSchema) -> tuple[tuple[int, int, int], BaseTestGenContext]:
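    """Validate the incoming testgen request and build the generation context.

    Raises HttpError(400) for an unsupported test framework, an empty function to optimize,
    an invalid trace ID, an unparsable Python version, or source code that is not valid Python.
    """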
    if data.test_framework not in ["unittest", "pytest"]:
        raise HttpError(400, "Invalid test framework. We only support unittest and pytest.")
    if not data.function_to_optimize:
        raise HttpError(400, "Invalid function to optimize. It is empty.")
    if not validate_trace_id(data.trace_id):
        raise HttpError(400, "Invalid trace ID. Please provide a valid UUIDv4.")

    try:
        python_version = parse_python_version(data.python_version)
    except ValueError:
        raise HttpError(400, "Invalid Python version, it should look like 3.x.x. We only support Python 3.9 and above.")  # noqa: B904

    try:
        ctx = BaseTestGenContext.get_dynamic_context(
            TestGenContextData(
                source_code_being_tested=data.source_code_being_tested,
                qualified_name=data.function_to_optimize.qualified_name,
            )
        )
        ctx.validate_python_module(feature_version=python_version[:2])
    except SyntaxError:
        raise HttpError(400, "Invalid source code. It is not valid Python code.")  # noqa: B904

    return python_version, ctx


@testgen_api.post(
    "/", response={200: TestGenResponseSchema, 400: TestGenErrorResponseSchema, 500: TestGenErrorResponseSchema}
)
async def testgen(
    request: AuthBearer, data: TestGenSchema
) -> tuple[int, TestGenResponseSchema | TestGenErrorResponseSchema]:
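    """Handle POST /testgen: validate the request, generate regression tests, and return them.

    Returns 400 on validation errors, 200 with the generated and instrumented tests on success,
    and 500 when test generation fails.
    """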
    ph(request.user, "aiservice-testgen-called")

    try:
        python_version, ctx = validate_request_data(data)
    except HttpError as e:
        e.add_note(f"Testgen request validation error: {e.status_code} {e.message}")
        sentry_sdk.capture_exception(e)
        return e.status_code, TestGenErrorResponseSchema(error=e.message)

    if should_hack_for_demo(data.source_code_being_tested):
        demo_hack_response = await hack_for_demo(data, python_version)
        return 200, demo_hack_response

    logging.info("/testgen: Generating tests...")
    try:
        debug_log_sensitive_data(f"Generating tests for function {data.function_to_optimize.function_name}")

        (
            generated_test_source,
            instrumented_behavior_tests,
            instrumented_perf_tests,
        ) = await generate_regression_tests_from_function(
            ctx=ctx,
            user_id=request.user,
            function_name=data.function_to_optimize.qualified_name,
            python_version=python_version,
            data=data,
            unit_test_package=data.test_framework,
            is_async=data.is_async,
            trace_id=data.trace_id,
        )

        ph(request.user, "aiservice-testgen-tests-generated")

        if hasattr(request, "should_log_features") and request.should_log_features:
            await log_features(
                trace_id=data.trace_id,
                user_id=request.user,
                generated_tests=[generated_test_source],
                instrumented_generated_tests=[instrumented_behavior_tests],
                test_framework=data.test_framework,
                metadata={
                    "test_timeout": data.test_timeout,
                    "function_to_optimize": data.function_to_optimize.function_name,
                },
            )

        return 200, TestGenResponseSchema(
            generated_tests=generated_test_source,
            instrumented_behavior_tests=instrumented_behavior_tests,
            instrumented_perf_tests=instrumented_perf_tests,
        )

    except Exception as e:
        logging.exception("Test generation failed")
        sentry_sdk.capture_exception(e)
        return 500, TestGenErrorResponseSchema(error="Error generating tests. Internal server error.")