codeflash-internal/django/aiservice/testgen/testgen.py
Kevin Turcios 273edff3ab unify
2025-12-22 23:51:05 -05:00

454 lines
41 KiB
Python

# Derived from https://github.com/openai/openai-cookbook/blob/main/examples/Unit_test_writing_using_a_multi-step_prompt.ipynb
from __future__ import annotations
import ast
import asyncio
import logging
import re
from pathlib import Path
from typing import TYPE_CHECKING, SupportsIndex
import sentry_sdk
import stamina
from ninja import NinjaAPI
from ninja.errors import HttpError
from openai import OpenAIError
from aiservice.analytics.posthog import ph
from aiservice.common_utils import parse_python_version, safe_isort, should_hack_for_demo, validate_trace_id
from aiservice.env_specific import IS_PRODUCTION, debug_log_sensitive_data
from aiservice.llm import EXECUTE_MODEL, LLMResponse, calculate_llm_cost, call_llm
from aiservice.observability.decorators import observe_llm_call
from log_features.log_event import update_optimization_cost
from log_features.log_features import log_features
from testgen.instrumentation.edit_generated_test import parse_module_to_cst, replace_definition_with_import
from testgen.instrumentation.instrument_new_tests import instrument_test_source
from testgen.models import (
TestGenerationFailedError,
TestGenErrorResponseSchema,
TestGenResponseSchema,
TestGenSchema, # noqa: TC001
TestingMode,
)
from testgen.postprocessing.code_validator import has_test_functions, validate_testgen_code
from testgen.postprocessing.postprocess_pipeline import postprocessing_testgen_pipeline
from testgen.testgen_context import BaseTestGenContext, TestGenContextData
if TYPE_CHECKING:
from aiservice.llm import LLM
from authapp.auth import AuthBearer
testgen_api = NinjaAPI(urls_namespace="testgen")

# Prompt templates live next to this module. They are decoded explicitly as UTF-8:
# Path.read_text() without an encoding uses the host locale's preferred encoding,
# which would make prompt decoding depend on the deployment environment.
current_dir = Path(__file__).parent
EXPLAIN_SYSTEM_PROMPT = (current_dir / "explain_system_prompt.md").read_text(encoding="utf-8")
EXPLAIN_USER_PROMPT = (current_dir / "explain_user_prompt.md").read_text(encoding="utf-8")
EXECUTE_SYSTEM_PROMPT = (current_dir / "execute_system_prompt.md").read_text(encoding="utf-8")
EXECUTE_USER_PROMPT = (current_dir / "execute_user_prompt.md").read_text(encoding="utf-8")
EXECUTE_ASYNC_SYSTEM_PROMPT = (current_dir / "execute_async_system_prompt.md").read_text(encoding="utf-8")
EXECUTE_ASYNC_USER_PROMPT = (current_dir / "execute_async_user_prompt.md").read_text(encoding="utf-8")

# Captures the body of a ```python fenced block whose opening fence starts a line
# (MULTILINE makes ^ match at line starts). The lazy group spans newlines (DOTALL)
# and stops at the first closing ``` that directly follows a newline.
pattern = re.compile(r"^```python\s*\n(.*?)\n```", re.MULTILINE | re.DOTALL)

# ANSI escape prefixes used by print_messages when echoing chat messages.
color_prefix_by_role = {
    "system": "\033[0m",  # ANSI reset: terminal default colour
    "user": "\033[0m",  # ANSI reset: terminal default colour
    "assistant": "\033[92m",  # bright green
}
def print_messages(
    messages: list[dict[str, str]],
    color_prefix_by_role: dict[str, str] = color_prefix_by_role,
) -> None:
    """Print chat messages sent to or from the LLM, colour-coded by role (debug helper).

    Args:
        messages: Chat messages, each a dict with ``"role"`` and ``"content"`` keys.
        color_prefix_by_role: ANSI escape prefix printed before each role's message.
            Roles missing from this mapping are printed in the terminal's default
            colour instead of raising ``KeyError``.

    """
    reset = "\033[0m"
    for message in messages:
        role = message["role"]
        prefix = color_prefix_by_role.get(role, reset)
        # Reset after each message so a coloured role (e.g. green "assistant") does not
        # bleed into whatever the terminal prints next.
        print(f"{prefix}\n[{role}]\n{message['content']}{reset}")
def build_prompt(
    ctx: BaseTestGenContext, function_name: str, unit_test_package: str, *, is_async: bool
) -> tuple[list[dict[str, str]], str, str]:
    """Assemble the chat messages asking the LLM to write unit tests for a function.

    The conversation, in order, is:
      1. the execute system prompt (sync or async variant), formatted with
         ``ctx.data.qualified_name``;
      2. a "plan" user message listing what a good test suite should aim for;
      3. a user message with the notes from ``ctx.generate_notes_markdown()``;
      4. the execute user prompt carrying the source of the function under test.

    Args:
        ctx: Test-generation context providing the function source, its qualified
            name, and the notes markdown.
        function_name: Name substituted into the execute user prompt.
        unit_test_package: Test framework name (e.g. ``"pytest"``) mentioned in the prompts.
        is_async: Use the async-specific templates and plan when true.

    Returns:
        ``(messages, posthog_event_suffix, error_context)``. The suffix/context strings
        are ``"async-"`` / ``"async "`` for async functions and empty otherwise.

    """
    if not is_async:
        system_template = EXECUTE_SYSTEM_PROMPT
        user_template = EXECUTE_USER_PROMPT
        plan = f"""A good unit test suite should aim to:
- Test the function's behavior for a wide range of possible inputs
- Test edge cases that the author may not have foreseen
- Take advantage of the features of `{unit_test_package}` to make the tests easy to write and maintain
- Be easy to read and understand, with clean code and descriptive names
- Be deterministic, so that the tests always pass or fail in the same way
- Have tests sorted by difficulty, from easiest to hardest
- Should try not to mock or stub any dependencies by using `{unit_test_package}`.mock or any other similar mocking or stubbing module, so that the testing environment is as close to the production environment as possible
- Include Large Scale Test Cases to assess the function's performance and scalability with large data samples.
To help unit test the function above, list diverse scenarios that the function should be able to handle (and under each scenario, include a few examples as sub-bullets)."""
        event_suffix = ""
        error_label = ""
    else:
        system_template = EXECUTE_ASYNC_SYSTEM_PROMPT
        user_template = EXECUTE_ASYNC_USER_PROMPT
        plan = f"""A good unit test suite for an ASYNC function should aim to:
- Test the async function's behavior for a wide range of possible inputs
- Test edge cases that the author may not have foreseen, including async-specific edge cases
- Take advantage of the features of `{unit_test_package}` to make async tests easy to write and maintain
- Be easy to read and understand, with clean async code and descriptive names
- Be deterministic, so that the async tests always pass or fail in the same way
- Have tests sorted by difficulty, from easiest to hardest
- Should try not to mock or stub any dependencies, so that the async testing environment is as close to production
- Include concurrent execution test cases to assess the function's async performance and behavior
- Include throughput test cases to measure the function's performance under load and high-volume scenarios
- Test proper async/await patterns and coroutine handling
To help unit test the ASYNC function above, list diverse scenarios that the async function should be able to handle (and under each scenario, include a few examples as sub-bullets)."""
        event_suffix = "async-"
        error_label = "async "

    # List elements are evaluated left to right, so ctx is queried in the same order
    # as before: qualified name, then notes, then the source under test.
    messages = [
        {"role": "system", "content": system_template.format(function_name=ctx.data.qualified_name)},
        {"role": "user", "content": plan},
        {"role": "user", "content": ctx.generate_notes_markdown()},
        {
            "role": "user",
            "content": user_template.format(
                unit_test_package=unit_test_package,
                function_name=function_name,
                function_code=ctx.data.source_code_being_tested,
                # The template has a {package_comment} placeholder; it is currently always empty.
                package_comment="",
            ),
        },
    ]

    if IS_PRODUCTION is False:
        print_messages(messages)
    return messages, event_suffix, error_label
def instrument_tests(
    generated_test_source: str, data: TestGenSchema, python_version: tuple[int, int, int]
) -> tuple[str | None, str | None]:
    """Instrument generated tests once for behavior checking and once for performance.

    Args:
        generated_test_source: Post-processed test module source.
        data: Request whose function/module/framework/timeout fields drive instrumentation.
        python_version: Target Python version forwarded to the instrumenter.

    Returns:
        ``(behavior_source, performance_source)``. Either element may be ``None``;
        callers treat ``None`` as an instrumentation failure.

    """

    def _instrument(mode: TestingMode) -> str | None:
        # instrument_test_source() already runs isort via format_and_float_to_top(),
        # so the output is deliberately not sorted a second time here.
        return instrument_test_source(
            test_source=generated_test_source,
            function_to_optimize=data.function_to_optimize,
            helper_function_names=data.helper_function_names,
            module_path=data.module_path,
            test_module_path=data.test_module_path,
            test_framework=data.test_framework,
            test_timeout=data.test_timeout,
            python_version=python_version,
            mode=mode,
        )

    # Tuple elements evaluate left to right: behavior first, then performance.
    return _instrument(TestingMode.BEHAVIOR), _instrument(TestingMode.PERFORMANCE)
def parse_and_validate_llm_output(
    response_content: str, ctx: BaseTestGenContext, python_version: tuple[int, int, int], error_context: str
) -> str:
    """Extract the first ```python block from an LLM reply and return the validated code.

    Every exception raised while extracting or validating is reported to Sentry and then
    re-raised unchanged, so the caller's retry policy can react to it.

    Args:
        response_content: Raw text of the LLM reply.
        ctx: Context used to detect leftover ``...`` placeholders in the code.
        python_version: Target Python version; its ``(major, minor)`` prefix is used
            for validation.
        error_context: Prefix (e.g. ``"async "``) interpolated into error messages.

    Raises:
        ValueError: the reply has no ```python fence, or no fence opening at a line start.
        SyntaxError: the cleaned code still contains an ellipsis placeholder.

    """
    try:
        if "```python" not in response_content:
            # Ship the full reply to Sentry so the malformed output can be inspected.
            sentry_sdk.capture_message("LLM response did not contain a Python code block:\n" + response_content)
            raise ValueError("LLM response did not contain a Python code block.")
        match = pattern.search(response_content)
        if match is None:
            raise ValueError("No Python code block found in the LLM response.")
        validated = validate_testgen_code(match.group(1), python_version[:2], max_lines_to_remove=120)
        if ctx.did_generate_ellipsis(validated, python_version):
            raise SyntaxError(f"Ellipsis in generated {error_context}test code, regenerating...")
    except Exception as exc:
        sentry_sdk.capture_exception(exc)
        raise
    return validated
@observe_llm_call("test_generation")
async def call_testgen_llm(
    trace_id: str,
    model: LLM,
    messages: list[dict[str, str]],
    temperature: float,
    user_id: str | None = None,
    python_version: str | None = None,
) -> LLMResponse:
    """Send one test-generation chat request to ``model`` and return its response.

    The body uses only ``model``, ``messages`` and ``temperature``. ``trace_id``,
    ``user_id`` and ``python_version`` are accepted for the ``@observe_llm_call``
    decorator. Per the original author's notes, the decorator records call start and
    completion, timing, token usage and errors in the background without blocking
    the LLM call.
    """
    response = await call_llm(
        model_name=model.name,
        model_type=model.model_type,
        messages=messages,
        temperature=temperature,
    )
    return response
@stamina.retry(on=(SyntaxError, ValueError, OpenAIError), attempts=2)
async def generate_and_validate_test_code(
    messages: list[dict[str, str]],
    model: LLM,
    temperature: float,
    ctx: BaseTestGenContext,
    python_version: tuple[int, int, int],
    error_context: str,
    execute_model: LLM,
    cost_tracker: list[float],
    user_id: str,
    posthog_event_suffix: str,
    trace_id: str = "",
) -> str:
    """Run one LLM test-generation round-trip: call, record cost, emit telemetry, validate.

    stamina retries the whole round-trip (2 attempts in total) when the LLM call raises
    ``OpenAIError`` or the reply fails parsing/validation (``ValueError``/``SyntaxError``).

    Args:
        messages: Chat messages to send.
        model: Model the request is sent to.
        temperature: Sampling temperature.
        ctx: Context used during output validation.
        python_version: Target Python version (also sent to observability as "X.Y.Z").
        error_context: Prefix (e.g. ``"async "``) used in log and error messages.
        execute_model: Model used for cost calculation and the PostHog ``model`` property.
        cost_tracker: Caller-owned list. The cost of every completed call is appended
            *before* validation, so rejected attempts are still accounted for.
        user_id: User the PostHog usage event is attributed to.
        posthog_event_suffix: Inserted into the PostHog event name.
        trace_id: Trace identifier forwarded to the observed LLM call.

    Returns:
        The validated test code.

    """
    version_label = ".".join(map(str, python_version))
    llm_response = await call_testgen_llm(
        trace_id=trace_id,
        model=model,
        messages=messages,
        temperature=temperature,
        user_id=user_id,
        python_version=version_label,
    )
    raw = llm_response.raw_response

    cost_tracker.append(calculate_llm_cost(raw, execute_model) or 0.0)
    debug_log_sensitive_data(f"OpenAIClient {error_context}execute response:\n{raw.model_dump_json(indent=2)}")

    usage = raw.usage
    if usage:
        ph(
            user_id,
            f"aiservice-testgen-{posthog_event_suffix}execute-openai-usage",
            properties={"model": execute_model.name, "usage": usage.model_dump_json()},
        )

    return parse_and_validate_llm_output(
        response_content=llm_response.content, ctx=ctx, python_version=python_version, error_context=error_context
    )
def _postprocess_generated_tests(
    validated_code: str, data: TestGenSchema, python_version: tuple[int, int, int]
) -> tuple[str, str, str]:
    """Turn validated LLM test code into the final test source plus its instrumented variants.

    The steps are:
      1. run the CST post-processing pipeline;
      2. call ``replace_definition_with_import`` for the function under test;
      3. instrument the result;
      4. apply isort only when the sorted code still parses.

    Returns:
        ``(generated_test_source, instrumented_behavior_tests, instrumented_perf_tests)``.

    Raises:
        TestGenerationFailedError: instrumentation returned ``None`` or no test functions remain.
        Any exception from post-processing or ``ast.parse`` (e.g. ``SyntaxError``) propagates
        to the caller.

    """
    processed_cst = postprocessing_testgen_pipeline(
        parse_module_to_cst(validated_code), data.helper_function_names, data.function_to_optimize, data.module_path
    )
    generated_test_source = replace_definition_with_import(
        processed_cst.code, data.function_to_optimize, data.module_path
    )
    instrumented_behavior_tests, instrumented_perf_tests = instrument_tests(
        generated_test_source, data, python_version
    )
    if instrumented_behavior_tests is None or instrumented_perf_tests is None:
        msg = "There was an error detected in the function to optimize, is it valid Python code?"
        raise TestGenerationFailedError(msg)

    # isort can occasionally emit code that no longer parses; keep the sorted version
    # only if it still parses, otherwise fall back to the unsorted source.
    sorted_generated_tests = safe_isort(generated_test_source)
    try:
        parse_module_to_cst(sorted_generated_tests)
    except Exception:  # noqa: BLE001
        sentry_sdk.capture_message("isort caused a syntax error in testgen; returning un-sorted code.")
    else:
        generated_test_source = sorted_generated_tests

    # Sanity check: post-processing must not have stripped every test function.
    tree = ast.parse(generated_test_source, feature_version=python_version[:2])
    if not has_test_functions(tree):
        msg = "No test functions were found in the generated test code."
        raise TestGenerationFailedError(msg)
    return generated_test_source, instrumented_behavior_tests, instrumented_perf_tests


@stamina.retry(on=TestGenerationFailedError, attempts=2)
async def generate_regression_tests_from_function(
    ctx: BaseTestGenContext,
    user_id: str,
    function_name: str,
    python_version: tuple[int, int, int],
    data: TestGenSchema,
    unit_test_package: str = "pytest",
    execute_model: LLM = EXECUTE_MODEL,
    temperature: float = 0.4,
    is_async: bool = False,  # noqa: FBT001, FBT002
    trace_id: str = "",
) -> tuple[str, str | None, str | None]:
    """Generate, post-process and instrument regression tests for one function.

    stamina retries the whole function (2 attempts in total) on ``TestGenerationFailedError``.
    The LLM spend of each attempt is reported via ``update_optimization_cost`` exactly once,
    whether the LLM step succeeds, exhausts its own retries, or ends in an API error.

    Returns:
        ``(generated_tests, instrumented_behavior_tests, instrumented_perf_tests)``.

    Raises:
        TestGenerationFailedError: the LLM output could not be validated, post-processing
            raised ``SyntaxError``/``ValueError``, instrumentation failed, or no test
            functions remain.
        OpenAIError: propagated from the LLM step once its retries are exhausted.

    """
    execute_messages, posthog_event_suffix, error_context = build_prompt(
        ctx=ctx, function_name=function_name, unit_test_package=unit_test_package, is_async=is_async
    )
    # generate_and_validate_test_code appends one entry per completed LLM call,
    # including calls whose output was rejected and retried.
    cost_tracker: list[float] = []
    try:
        try:
            validated_code = await generate_and_validate_test_code(
                messages=execute_messages,
                model=execute_model,
                temperature=temperature,
                ctx=ctx,
                python_version=python_version,
                error_context=error_context,
                execute_model=execute_model,
                cost_tracker=cost_tracker,
                user_id=user_id,
                posthog_event_suffix=posthog_event_suffix,
                trace_id=trace_id,
            )
        finally:
            # Report the spend once, right after the LLM step, whatever its outcome.
            # Previously a post-processing SyntaxError/ValueError reported it a second time,
            # and an OpenAIError never reported it at all.
            await update_optimization_cost(trace_id=trace_id, cost=sum(cost_tracker))
        generated_tests = _postprocess_generated_tests(validated_code, data, python_version)
    except (SyntaxError, ValueError) as e:
        msg = f"Failed to generate valid {error_context}test code after {len(cost_tracker)} tries."
        raise TestGenerationFailedError(msg) from e
    return generated_tests
async def hack_for_demo(data: TestGenSchema, python_version: tuple[int, int, int]) -> TestGenResponseSchema:
    """Return a hard-coded ``find_common_tags`` test suite instead of generating one with the LLM.

    ``data.test_index == 0`` selects the first canned suite; any other index selects the
    second. The chosen source is instrumented with ``instrument_tests``. Unlike the
    regular path, no LLM call, validation or post-processing pipeline is run.

    Args:
        data: Test-generation request; only ``test_index`` and the fields read by
            ``instrument_tests`` are used.
        python_version: Target Python version forwarded to ``instrument_tests``.

    Returns:
        Response holding the canned tests and their behavior/performance instrumented variants.

    """
    # Canned suite #1 (test_index 0).
    if data.test_index == 0:
        generated_test_source = '# imports\n# function to test\nfrom __future__ import annotations\n\nimport pytest # used for our unit tests\nfrom codeflash.result.common_tags import find_common_tags\n\n# unit tests\n\ndef test_single_article():\n # Single article should return its tags\n articles = [{"tags": ["python", "coding", "tutorial"]}]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_multiple_articles_with_common_tags():\n # Multiple articles with common tags should return the common tags\n articles = [\n {"tags": ["python", "coding"]},\n {"tags": ["python", "data"]},\n {"tags": ["python", "machine learning"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_empty_list_of_articles():\n # Empty list of articles should return an empty set\n articles = []\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_articles_with_no_common_tags():\n # Articles with no common tags should return an empty set\n articles = [\n {"tags": ["python"]},\n {"tags": ["java"]},\n {"tags": ["c++"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_articles_with_empty_tag_lists():\n # Articles with some empty tag lists should return an empty set\n articles = [\n {"tags": []},\n {"tags": ["python"]},\n {"tags": ["python", "java"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_all_articles_with_empty_tag_lists():\n # All articles with empty tag lists should return an empty set\n articles = [\n {"tags": []},\n {"tags": []},\n {"tags": []}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef 
test_tags_with_special_characters():\n # Tags with special characters should be handled correctly\n articles = [\n {"tags": ["python!", "coding"]},\n {"tags": ["python!", "data"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_case_sensitivity():\n # Tags with different cases should not be considered the same\n articles = [\n {"tags": ["Python", "coding"]},\n {"tags": ["python", "data"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_large_number_of_articles():\n # Large number of articles with a common tag should return that tag\n articles = [{"tags": ["common_tag", f"tag{i}"]} for i in range(1000)]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_large_number_of_tags():\n # Large number of tags with some common tags should return the common tags\n articles = [\n {"tags": [f"tag{i}" for i in range(1000)]},\n {"tags": [f"tag{i}" for i in range(500, 1500)]}\n ]\n expected = {f"tag{i}" for i in range(500, 1000)}\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_mixed_length_of_tag_lists():\n # Articles with mixed length of tag lists should return the common tags\n articles = [\n {"tags": ["python", "coding"]},\n {"tags": ["python"]},\n {"tags": ["python", "coding", "tutorial"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_tags_with_different_data_types():\n # Tags with different data types should only consider strings\n articles = [\n {"tags": ["python", 123]},\n {"tags": ["python", "123"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_performance_with_large_data():\n # 
Performance with large data should return the common tag\n articles = [{"tags": ["common_tag", f"tag{i}"]} for i in range(10000)]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_scalability_with_increasing_tags():\n # Scalability with increasing tags should return the common tag\n articles = [{"tags": ["common_tag"] + [f"tag{i}" for i in range(j)]} for j in range(1, 1001)]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation'
    # Canned suite #2 (any other test_index).
    else:
        generated_test_source = '# imports\n# function to test\nfrom __future__ import annotations\n\nimport pytest # used for our unit tests\nfrom codeflash.result.common_tags import find_common_tags\n\n# unit tests\n\ndef test_empty_input_list():\n # Test with an empty list\n codeflash_output = find_common_tags([])\n # Outputs were verified to be equal to the original implementation\n\ndef test_single_article():\n # Test with a single article with tags\n codeflash_output = find_common_tags([{"tags": ["python", "coding", "development"]}])\n # Test with a single article with no tags\n codeflash_output = find_common_tags([{"tags": []}])\n # Outputs were verified to be equal to the original implementation\n\ndef test_multiple_articles_some_common_tags():\n # Test with multiple articles having some common tags\n articles = [\n {"tags": ["python", "coding", "development"]},\n {"tags": ["python", "development", "tutorial"]},\n {"tags": ["python", "development", "guide"]}\n ]\n codeflash_output = find_common_tags(articles)\n\n articles = [\n {"tags": ["tech", "news"]},\n {"tags": ["tech", "gadgets"]},\n {"tags": ["tech", "reviews"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_multiple_articles_no_common_tags():\n # Test with multiple articles having no common tags\n articles = [\n {"tags": ["python", "coding"]},\n {"tags": ["development", "tutorial"]},\n {"tags": ["guide", "learning"]}\n ]\n codeflash_output = find_common_tags(articles)\n\n articles = [\n {"tags": ["apple", "banana"]},\n {"tags": ["orange", "grape"]},\n {"tags": ["melon", "kiwi"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_articles_with_duplicate_tags():\n # Test with articles having duplicate tags\n articles = [\n {"tags": ["python", "python", "coding"]},\n {"tags": ["python", "development", "python"]},\n {"tags": ["python", "guide", 
"python"]}\n ]\n codeflash_output = find_common_tags(articles)\n\n articles = [\n {"tags": ["tech", "tech", "news"]},\n {"tags": ["tech", "tech", "gadgets"]},\n {"tags": ["tech", "tech", "reviews"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_articles_with_mixed_case_tags():\n # Test with articles having mixed case tags\n articles = [\n {"tags": ["Python", "Coding"]},\n {"tags": ["python", "Development"]},\n {"tags": ["PYTHON", "Guide"]}\n ]\n codeflash_output = find_common_tags(articles) # Assuming case sensitivity\n\n articles = [\n {"tags": ["Tech", "News"]},\n {"tags": ["tech", "Gadgets"]},\n {"tags": ["TECH", "Reviews"]}\n ]\n codeflash_output = find_common_tags(articles) # Assuming case sensitivity\n # Outputs were verified to be equal to the original implementation\n\ndef test_articles_with_non_string_tags():\n # Test with articles having non-string tags\n articles = [\n {"tags": ["python", 123, "coding"]},\n {"tags": ["python", "development", 123]},\n {"tags": ["python", "guide", 123]}\n ]\n codeflash_output = find_common_tags(articles)\n\n articles = [\n {"tags": [None, "news"]},\n {"tags": ["tech", None]},\n {"tags": [None, "reviews"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_large_scale_test_cases():\n # Test with large scale input where all tags should be common\n articles = [\n {"tags": ["tag" + str(i) for i in range(1000)]} for _ in range(100)\n ]\n expected_output = {"tag" + str(i) for i in range(1000)}\n codeflash_output = find_common_tags(articles)\n\n # Test with large scale input where no tags should be common\n articles = [\n {"tags": ["tag" + str(i) for i in range(1000)]} for _ in range(50)\n ] + [{"tags": ["unique_tag"]}]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation'
    # Instrument the canned source exactly like LLM-generated tests.
    instrumented_behavior_tests, instrumented_perf_tests = instrument_tests(generated_test_source, data, python_version)
    # NOTE(review): fixed artificial delay; presumably mimics LLM latency for the demo — confirm.
    await asyncio.sleep(5)
    return TestGenResponseSchema(
        generated_tests=generated_test_source,
        instrumented_behavior_tests=instrumented_behavior_tests,
        instrumented_perf_tests=instrumented_perf_tests,
    )
async def hack_for_demo_gsq(data: TestGenSchema, python_version: tuple[int, int, int]) -> TestGenResponseSchema:
    """Return a canned test-generation response for the ``weighted_sum`` (gs_quant) demo.

    The test source is hard-coded rather than generated: one of two fixed pytest suites
    for ``gs_quant.timeseries.algebra.weighted_sum`` is selected by ``data.test_index``.
    Index 0 gets the first suite; every other index gets the second. The chosen source is
    passed through ``instrument_tests``, and the response is returned after a fixed
    2-second delay.

    Args:
        data: The test-generation request. Only ``test_index`` is read here; the whole
            object is forwarded to ``instrument_tests``.
        python_version: Parsed target Python version, forwarded to ``instrument_tests``.

    Returns:
        A ``TestGenResponseSchema`` holding the canned source and the behavior- and
        perf-instrumented variants returned by ``instrument_tests``.
    """
    if data.test_index == 0:
        # First canned suite (test_index == 0).
        # NOTE(review): test_weighted_sum_large_partial_overlap calls
        # pd.date_range('2024-01-500', ...), which does not look like a valid date — confirm
        # whether that canned test is expected to error.
        generated_test_source = "from datetime import datetime, timedelta\n# function to test\nfrom functools import reduce\n\nimport numpy as np\nimport pandas as pd\n# imports\nimport pytest # used for our unit tests\nfrom gs_quant.timeseries.algebra import weighted_sum\n\n\nclass MqTypeError(TypeError): pass\nclass MqValueError(ValueError): pass\nfrom gs_quant.timeseries.algebra import weighted_sum\n\n# unit tests\n\n# --- Basic Test Cases ---\n\ndef test_weighted_sum_simple_two_series():\n # Simple case: two series, same index, weights sum to 1\n idx = pd.date_range('2024-01-01', periods=3)\n s1 = pd.Series([1, 2, 3], index=idx)\n s2 = pd.Series([4, 5, 6], index=idx)\n weights = [0.7, 0.3]\n codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n expected = (s1 * 0.7 + s2 * 0.3) / (0.7 + 0.3)\n\ndef test_weighted_sum_three_series_equal_weights():\n # Three series, equal weights\n idx = pd.date_range('2024-01-01', periods=3)\n s1 = pd.Series([1, 2, 3], index=idx)\n s2 = pd.Series([4, 5, 6], index=idx)\n s3 = pd.Series([7, 8, 9], index=idx)\n weights = [1, 1, 1]\n codeflash_output = weighted_sum([s1, s2, s3], weights); result = codeflash_output\n expected = (s1 + s2 + s3) / 3\n\ndef test_weighted_sum_weights_not_normalized():\n # Weights do not sum to 1\n idx = pd.date_range('2024-01-01', periods=2)\n s1 = pd.Series([10, 20], index=idx)\n s2 = pd.Series([30, 40], index=idx)\n weights = [2, 3]\n codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n expected = (s1 * 2 + s2 * 3) / (2 + 3)\n\n# --- Edge Test Cases ---\n\ndef test_weighted_sum_empty_series_list():\n # Empty input series list\n with pytest.raises(MqValueError):\n weighted_sum([], [1, 2])\n\ndef test_weighted_sum_empty_weights_list():\n # Empty weights list\n idx = pd.date_range('2024-01-01', periods=2)\n s1 = pd.Series([1, 2], index=idx)\n with pytest.raises(MqValueError):\n weighted_sum([s1], [])\n\ndef test_weighted_sum_weights_length_mismatch():\n # Weights and series length mismatch\n idx = pd.date_range('2024-01-01', periods=2)\n s1 = pd.Series([1, 2], index=idx)\n s2 = pd.Series([3, 4], index=idx)\n with pytest.raises(MqValueError):\n weighted_sum([s1, s2], [1])\n\ndef test_weighted_sum_non_series_input():\n # Non-pandas Series input\n idx = pd.date_range('2024-01-01', periods=2)\n s1 = pd.Series([1, 2], index=idx)\n with pytest.raises(MqTypeError):\n weighted_sum([s1, [3, 4]], [0.5, 0.5])\n\ndef test_weighted_sum_non_numeric_weights():\n # Non-numeric weights\n idx = pd.date_range('2024-01-01', periods=2)\n s1 = pd.Series([1, 2], index=idx)\n s2 = pd.Series([3, 4], index=idx)\n with pytest.raises(MqTypeError):\n weighted_sum([s1, s2], [0.5, \"a\"])\n\ndef test_weighted_sum_disjoint_calendars():\n # Series with disjoint calendars (no overlap)\n idx1 = pd.date_range('2024-01-01', periods=2)\n idx2 = pd.date_range('2024-02-01', periods=2)\n s1 = pd.Series([1, 2], index=idx1)\n s2 = pd.Series([3, 4], index=idx2)\n weights = [0.5, 0.5]\n codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n\ndef test_weighted_sum_partial_overlap_calendars():\n # Series with partial overlap in calendar\n idx1 = pd.date_range('2024-01-01', periods=3)\n idx2 = pd.date_range('2024-01-02', periods=3)\n s1 = pd.Series([1, 2, 3], index=idx1)\n s2 = pd.Series([4, 5, 6], index=idx2)\n weights = [1, 1]\n codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n # Only dates present in both should be included\n expected_idx = idx1.intersection(idx2)\n expected = (s1.reindex(expected_idx) + s2.reindex(expected_idx)) / 2\n\ndef test_weighted_sum_zero_weights():\n # Zero weights (should return NaN for all dates)\n idx = pd.date_range('2024-01-01', periods=2)\n s1 = pd.Series([1, 2], index=idx)\n s2 = pd.Series([3, 4], index=idx)\n weights = [0, 0]\n codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n\ndef test_weighted_sum_negative_weights():\n # Negative weights\n idx = pd.date_range('2024-01-01', periods=2)\n s1 = pd.Series([10, 20], index=idx)\n s2 = pd.Series([30, 40], index=idx)\n weights = [1, -1]\n codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n expected = (s1 * 1 + s2 * -1) / (1 + -1)\n\ndef test_weighted_sum_nan_values_in_series():\n # Series containing NaN values\n idx = pd.date_range('2024-01-01', periods=3)\n s1 = pd.Series([1, None, 3], index=idx)\n s2 = pd.Series([4, 5, None], index=idx)\n weights = [1, 1]\n codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n # NaN propagation: result should be NaN where any input is NaN\n expected = (s1 * 1 + s2 * 1) / 2\n\ndef test_weighted_sum_mixed_int_float_weights():\n # Mixed int and float weights\n idx = pd.date_range('2024-01-01', periods=2)\n s1 = pd.Series([1, 2], index=idx)\n s2 = pd.Series([3, 4], index=idx)\n weights = [1, 0.5]\n codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n expected = (s1 * 1 + s2 * 0.5) / (1 + 0.5)\n\n# --- Large Scale Test Cases ---\n\ndef test_weighted_sum_large_number_of_series():\n # Large number of series (up to 1000)\n n = 1000\n idx = pd.date_range('2024-01-01', periods=10)\n series_list = [pd.Series([i]*10, index=idx) for i in range(n)]\n weights = [1 for _ in range(n)]\n codeflash_output = weighted_sum(series_list, weights); result = codeflash_output\n # Each date should be the average of 0...999, i.e. (sum(range(n)))/n\n expected_value = sum(range(n)) / n\n\ndef test_weighted_sum_large_series_length():\n # Large series length (up to 1000 dates)\n idx = pd.date_range('2024-01-01', periods=1000)\n s1 = pd.Series(range(1000), index=idx)\n s2 = pd.Series(range(1000, 0, -1), index=idx)\n weights = [0.25, 0.75]\n codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n expected = (s1 * 0.25 + s2 * 0.75) / (0.25 + 0.75)\n\ndef test_weighted_sum_large_partial_overlap():\n # Large series with partial overlap\n idx1 = pd.date_range('2024-01-01', periods=1000)\n idx2 = pd.date_range('2024-01-500', periods=500)\n s1 = pd.Series(range(1000), index=idx1)\n s2 = pd.Series(range(500), index=idx2)\n weights = [1, 2]\n codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n expected_idx = idx1.intersection(idx2)\n expected = (s1.reindex(expected_idx) * 1 + s2.reindex(expected_idx) * 2) / (1 + 2)\n# codeflash_output is used to check that the output of the original code is the same as that of the optimized code."
    else:
        # Second canned suite (any non-zero test_index).
        # NOTE(review): in this suite, make_series ignores its `start` argument, and
        # test_weighted_sum_large_sparse_overlap builds a Series of 1000 values on a
        # 500-entry index — confirm these canned tests behave as intended.
        generated_test_source = "from datetime import datetime, timedelta\n\nimport pandas as pd\n# imports\nimport pytest # used for our unit tests\nfrom gs_quant.timeseries.algebra import weighted_sum\n\n# ---------- Unit Tests ----------\n\n# Helper to create pd.Series with consecutive dates\ndef make_series(start, count, value_func=lambda i: i):\n dates = [datetime(2020, 1, 1) + timedelta(days=i) for i in range(count)]\n values = [value_func(i) for i in range(count)]\n return pd.Series(values, index=dates)\n\n# 1. Basic Test Cases\n\ndef test_weighted_sum_simple_two_series_equal_weights():\n # Two series, equal weights\n s1 = make_series(0, 5, lambda i: 2)\n s2 = make_series(0, 5, lambda i: 4)\n codeflash_output = weighted_sum([s1, s2], [1, 1]); result = codeflash_output\n expected = pd.Series([3.0]*5, index=s1.index)\n\ndef test_weighted_sum_simple_two_series_unequal_weights():\n # Two series, weights 2:1\n s1 = make_series(0, 5, lambda i: 2)\n s2 = make_series(0, 5, lambda i: 8)\n codeflash_output = weighted_sum([s1, s2], [2, 1]); result = codeflash_output\n expected = pd.Series([(2*2+1*8)/3]*5, index=s1.index)\n\ndef test_weighted_sum_three_series_varied_weights():\n # Three series, varied weights\n s1 = make_series(0, 3, lambda i: 1)\n s2 = make_series(0, 3, lambda i: 3)\n s3 = make_series(0, 3, lambda i: 5)\n codeflash_output = weighted_sum([s1, s2, s3], [0.2, 0.3, 0.5]); result = codeflash_output\n expected = pd.Series([1*0.2+3*0.3+5*0.5], index=s1.index)\n expected = pd.Series([1*0.2+3*0.3+5*0.5]*3, index=s1.index)\n\ndef test_weighted_sum_one_series():\n # Single series, should return the series itself\n s1 = make_series(0, 4, lambda i: i+5)\n codeflash_output = weighted_sum([s1], [1]); result = codeflash_output\n\n# 2. Edge Test Cases\n\ndef test_weighted_sum_empty_series_list():\n # Empty series list should raise ValueError\n with pytest.raises(ValueError):\n weighted_sum([], [])\n\ndef test_weighted_sum_weights_length_mismatch():\n # Series and weights length mismatch\n s1 = make_series(0, 3)\n s2 = make_series(0, 3)\n with pytest.raises(ValueError):\n weighted_sum([s1, s2], [1])\n\ndef test_weighted_sum_non_series_input():\n # Non-Series in series list\n s1 = make_series(0, 3)\n with pytest.raises(TypeError):\n weighted_sum([s1, [1, 2, 3]], [1, 1])\n\ndef test_weighted_sum_non_numeric_weights():\n # Non-numeric in weights\n s1 = make_series(0, 3)\n s2 = make_series(0, 3)\n with pytest.raises(TypeError):\n weighted_sum([s1, s2], [1, 'a'])\n\ndef test_weighted_sum_zero_weights():\n # All weights zero should return NaN series\n s1 = make_series(0, 3, lambda i: 1)\n s2 = make_series(0, 3, lambda i: 2)\n codeflash_output = weighted_sum([s1, s2], [0, 0]); result = codeflash_output\n\ndef test_weighted_sum_negative_weights():\n # Negative weights\n s1 = make_series(0, 3, lambda i: 2)\n s2 = make_series(0, 3, lambda i: 4)\n codeflash_output = weighted_sum([s1, s2], [-1, 2]); result = codeflash_output\n expected = pd.Series([(2*-1+4*2)/(-1+2)]*3, index=s1.index)\n\ndef test_weighted_sum_partial_overlap_indices():\n # Series with partially overlapping indices\n s1 = make_series(0, 5, lambda i: i)\n s2 = make_series(2, 5, lambda i: i+10)\n # Overlap is on days 2,3,4\n overlap_dates = s1.index.intersection(s2.index)\n codeflash_output = weighted_sum([s1, s2], [1, 1]); result = codeflash_output\n expected = pd.Series([(s1[d]+s2[d])/2 for d in overlap_dates], index=overlap_dates)\n\ndef test_weighted_sum_series_with_nans():\n # Series with NaNs, intersection should drop NaN dates\n s1 = make_series(0, 4, lambda i: float('nan') if i==2 else i)\n s2 = make_series(0, 4, lambda i: i)\n # Only dates where both have non-NaN should be used\n codeflash_output = weighted_sum([s1, s2], [1, 1]); result = codeflash_output\n\ndef test_weighted_sum_series_with_different_freq():\n # Series with different frequencies\n s1 = make_series(0, 5, lambda i: i)\n s2 = pd.Series([10, 20], index=[datetime(2020, 1, 2), datetime(2020, 1, 4)])\n # Only dates 2020-01-02 and 2020-01-04 are in both\n overlap_dates = s1.index.intersection(s2.index)\n codeflash_output = weighted_sum([s1, s2], [1, 3]); result = codeflash_output\n expected = pd.Series([(s1[d]*1+s2[d]*3)/4 for d in overlap_dates], index=overlap_dates)\n\ndef test_weighted_sum_all_nan_series():\n # All series are NaN\n dates = [datetime(2020, 1, 1) + timedelta(days=i) for i in range(3)]\n s1 = pd.Series([float('nan')] * 3, index=dates)\n s2 = pd.Series([float('nan')] * 3, index=dates)\n codeflash_output = weighted_sum([s1, s2], [1, 1]); result = codeflash_output\n\n# 3. Large Scale Test Cases\n\ndef test_weighted_sum_large_number_of_series():\n # 100 series, each with 100 points\n N = 100\n M = 100\n series_list = [make_series(0, M, lambda i: i+j) for j in range(N)]\n weights = [j+1 for j in range(N)]\n codeflash_output = weighted_sum(series_list, weights); result = codeflash_output\n # For each date, expected value is weighted sum of values across series\n dates = series_list[0].index\n expected_values = []\n total_weight = sum(weights)\n for i in range(M):\n # For date i, value in series j is i+j\n weighted_sum_val = sum((i+j)*weights[j] for j in range(N)) / total_weight\n expected_values.append(weighted_sum_val)\n expected = pd.Series(expected_values, index=dates)\n\ndef test_weighted_sum_large_series_length():\n # Two series, each with 1000 points\n M = 1000\n s1 = make_series(0, M, lambda i: i)\n s2 = make_series(0, M, lambda i: 2*i)\n codeflash_output = weighted_sum([s1, s2], [1, 2]); result = codeflash_output\n expected = pd.Series([(i*1+2*i*2)/3 for i in range(M)], index=s1.index)\n\ndef test_weighted_sum_large_sparse_overlap():\n # Series with sparse overlap\n s1 = pd.Series([1]*1000, index=[datetime(2020, 1, 1)+timedelta(days=i) for i in range(0, 1000, 2)])\n s2 = pd.Series([2]*500, index=[datetime(2020, 1, 1)+timedelta(days=i) for i in range(1, 1000, 2)])\n # There should be no overlap, so result should be empty\n codeflash_output = weighted_sum([s1, s2], [1, 1]); result = codeflash_output\n\ndef test_weighted_sum_large_all_nan():\n # Large series, all NaN\n M = 1000\n dates = [datetime(2020, 1, 1) + timedelta(days=i) for i in range(M)]\n s1 = pd.Series([float('nan')]*M, index=dates)\n s2 = pd.Series([float('nan')]*M, index=dates)\n codeflash_output = weighted_sum([s1, s2], [1, 2]); result = codeflash_output\n# codeflash_output is used to check that the output of the original code is the same as that of the optimized code."
    # Unpacked as (behavior-instrumented source, perf-instrumented source).
    instrumented_behavior_tests, instrumented_perf_tests = instrument_tests(generated_test_source, data, python_version)
    # Fixed artificial delay; presumably so the canned demo response mimics the latency
    # of real generation — confirm.
    await asyncio.sleep(2)
    return TestGenResponseSchema(
        generated_tests=generated_test_source,
        instrumented_behavior_tests=instrumented_behavior_tests,
        instrumented_perf_tests=instrumented_perf_tests,
    )
def validate_request_data(data: TestGenSchema) -> tuple[tuple[int, int, int], BaseTestGenContext]:
    """Validate a test-generation request and build the context for the code under test.

    Args:
        data: The incoming test-generation payload.

    Returns:
        A ``(python_version, ctx)`` pair:
        - ``python_version`` is ``data.python_version`` parsed by ``parse_python_version``.
        - ``ctx`` is the dynamic context built from ``data.source_code_being_tested`` for
          ``data.function_to_optimize.qualified_name``.

    Raises:
        HttpError: Raised with status 400 in any of these cases:
            - the test framework is not ``unittest`` or ``pytest``;
            - ``function_to_optimize`` is empty;
            - the trace ID fails ``validate_trace_id``;
            - the Python version cannot be parsed;
            - the source fails ``validate_python_module`` with a ``SyntaxError``.
            In the last two cases, the original ``ValueError``/``SyntaxError`` is chained
            as ``__cause__``, so it appears in tracebacks and Sentry reports.
    """
    if data.test_framework not in ("unittest", "pytest"):
        raise HttpError(400, "Invalid test framework. We only support unittest and pytest.")
    if not data.function_to_optimize:
        raise HttpError(400, "Invalid function to optimize. It is empty.")
    if not validate_trace_id(data.trace_id):
        raise HttpError(400, "Invalid trace ID. Please provide a valid UUIDv4.")

    try:
        python_version = parse_python_version(data.python_version)
    except ValueError as err:
        # Chain explicitly so the parse failure is reported as the direct cause (B904).
        raise HttpError(
            400, "Invalid Python version, it should look like 3.x.x. We only support Python 3.9 and above."
        ) from err

    try:
        ctx = BaseTestGenContext.get_dynamic_context(
            TestGenContextData(
                source_code_being_tested=data.source_code_being_tested,
                qualified_name=data.function_to_optimize.qualified_name,
            )
        )
        # Validate the module against the requested (major, minor) version only.
        ctx.validate_python_module(feature_version=python_version[:2])
    except SyntaxError as err:
        raise HttpError(400, "Invalid source code. It is not valid Python code.") from err

    return python_version, ctx
@testgen_api.post(
    "/", response={200: TestGenResponseSchema, 400: TestGenErrorResponseSchema, 500: TestGenErrorResponseSchema}
)
async def testgen(
    request: AuthBearer, data: TestGenSchema
) -> tuple[int, TestGenResponseSchema | TestGenErrorResponseSchema]:
    """Generate regression tests for the function described in ``data``.

    Steps:
      1. Validate the payload with ``validate_request_data``. A failure is reported to
         Sentry and returned as the ``HttpError``'s status code with its message.
      2. For demo inputs (``should_hack_for_demo``) whose source mentions
         ``find_common_tags`` or ``weighted_sum``, return the matching canned response.
      3. Otherwise, call ``generate_regression_tests_from_function``, optionally log
         features, and return the generated and instrumented test sources.

    Args:
        request: The authenticated request.
            - ``request.user`` is passed to ``ph`` analytics and used as ``user_id``.
            - ``request.should_log_features``, if present and truthy, enables feature logging.
        data: The test-generation payload.

    Returns:
        One of:
        - ``(200, TestGenResponseSchema)`` on success;
        - the validation error's status with a ``TestGenErrorResponseSchema`` for invalid
          requests;
        - ``(500, TestGenErrorResponseSchema)`` when generation fails.
    """
    ph(request.user, "aiservice-testgen-called")
    try:
        python_version, ctx = validate_request_data(data)
    except HttpError as e:
        e.add_note(f"Testgen request validation error: {e.status_code} {e.message}")
        sentry_sdk.capture_exception(e)
        return e.status_code, TestGenErrorResponseSchema(error=e.message)

    source_code = data.source_code_being_tested
    if should_hack_for_demo(source_code):
        # Serve the canned demo response matching the source. If neither demo function is
        # present, fall through to real generation. Returning unconditionally here would
        # reference a never-assigned response and raise UnboundLocalError outside the
        # try/except below.
        if "find_common_tags" in source_code:
            return 200, await hack_for_demo(data, python_version)
        if "weighted_sum" in source_code:
            return 200, await hack_for_demo_gsq(data, python_version)

    logging.info("/testgen: Generating tests...")
    try:
        debug_log_sensitive_data(f"Generating tests for function {data.function_to_optimize.function_name}")
        (
            generated_test_source,
            instrumented_behavior_tests,
            instrumented_perf_tests,
        ) = await generate_regression_tests_from_function(
            ctx=ctx,
            user_id=request.user,
            function_name=data.function_to_optimize.qualified_name,
            python_version=python_version,
            data=data,
            unit_test_package=data.test_framework,
            is_async=data.is_async,
            trace_id=data.trace_id,
        )
        ph(request.user, "aiservice-testgen-tests-generated")
        # The attribute is optional on the request object; a missing attribute means "don't log".
        if getattr(request, "should_log_features", False):
            # NOTE(review): an exception from log_features is caught below and turns an
            # otherwise successful generation into a 500 — confirm this is intended.
            await log_features(
                trace_id=data.trace_id,
                user_id=request.user,
                generated_tests=[generated_test_source],
                instrumented_generated_tests=[instrumented_behavior_tests],
                test_framework=data.test_framework,
                metadata={
                    "test_timeout": data.test_timeout,
                    "function_to_optimize": data.function_to_optimize.function_name,
                },
            )
        return 200, TestGenResponseSchema(
            generated_tests=generated_test_source,
            instrumented_behavior_tests=instrumented_behavior_tests,
            instrumented_perf_tests=instrumented_perf_tests,
        )
    except Exception as e:
        # API boundary:
        # - log the failure with its traceback;
        # - report it to Sentry;
        # - return a generic message so internal details are not sent to the client.
        logging.exception("Test generation failed")
        sentry_sdk.capture_exception(e)
        return 500, TestGenErrorResponseSchema(error="Error generating tests. Internal server error.")