# codeflash-internal/django/aiservice/testgen/testgen.py
# Derived from https://github.com/openai/openai-cookbook/blob/main/examples/Unit_test_writing_using_a_multi-step_prompt.ipynb
from __future__ import annotations

import ast
import asyncio
import logging
import re
from pathlib import Path
from typing import TYPE_CHECKING, TypedDict

import sentry_sdk
import stamina
from ninja import NinjaAPI
from ninja.errors import HttpError
from openai import OpenAIError

from aiservice.analytics.posthog import ph
from aiservice.common_utils import parse_python_version, safe_isort, should_hack_for_demo, validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.llm import EXECUTE_MODEL, HAIKU_MODEL, OPENAI_MODEL, calculate_llm_cost, call_llm
from aiservice.models.functions_to_optimize import FunctionToOptimize
from authapp.auth import AuthenticatedRequest
from log_features.log_event import update_optimization_cost
from log_features.log_features import log_features
from optimizer.context_utils.context_helpers import split_markdown_code
from testgen.instrumentation.edit_generated_test import parse_module_to_cst, replace_definition_with_import
from testgen.instrumentation.instrument_new_tests import instrument_test_source
from testgen.models import (
    TestGenerationFailedError,
    TestGenErrorResponseSchema,
    TestGenResponseSchema,
    TestGenSchema,
    TestingMode,
)
from testgen.postprocessing.add_missing_imports import add_missing_imports_from_source
from testgen.postprocessing.code_validator import has_test_functions, validate_testgen_code
from testgen.postprocessing.postprocess_pipeline import postprocessing_testgen_pipeline
from testgen.testgen_context import BaseTestGenContext, TestGenContextData
from testgen.testgen_javascript import testgen_javascript

if TYPE_CHECKING:
    from openai.types.chat import ChatCompletionMessageParam

    from aiservice.llm import LLM


class InstrumentTestSourceArgs(TypedDict):
    test_source: str
    function_to_optimize: FunctionToOptimize
    helper_function_names: list[str]
    module_path: str
    test_module_path: str
    test_framework: str
    test_timeout: int
    python_version: tuple[int, int, int]


testgen_api = NinjaAPI(urls_namespace="testgen")

# Get the directory of the current file
current_dir = Path(__file__).parent
EXPLAIN_SYSTEM_PROMPT = (current_dir / "explain_system_prompt.md").read_text()
EXPLAIN_USER_PROMPT = (current_dir / "explain_user_prompt.md").read_text()

EXECUTE_SYSTEM_PROMPT = (current_dir / "execute_system_prompt.md").read_text()
EXECUTE_USER_PROMPT = (current_dir / "execute_user_prompt.md").read_text()

EXECUTE_ASYNC_SYSTEM_PROMPT = (current_dir / "execute_async_system_prompt.md").read_text()
EXECUTE_ASYNC_USER_PROMPT = (current_dir / "execute_async_user_prompt.md").read_text()

pattern = re.compile(r"^```python\s*\n(.*?)\n```", re.MULTILINE | re.DOTALL)
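
# Example of what `pattern` extracts (illustrative; the response text below is hypothetical):
#   response = "Here are the tests:\n```python\nimport pytest\n\ndef test_add():\n    assert 1 + 1 == 2\n```"
#   pattern.search(response).group(1) == "import pytest\n\ndef test_add():\n    assert 1 + 1 == 2"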


def build_prompt(
    ctx: BaseTestGenContext, function_name: str, unit_test_package: str, *, is_async: bool
) -> tuple[list[dict[str, str]], str, str]:
    if is_async:
        execute_system_prompt = EXECUTE_ASYNC_SYSTEM_PROMPT
        execute_user_prompt = EXECUTE_ASYNC_USER_PROMPT
        plan_content = f"""A good unit test suite for an ASYNC function should aim to:
- Test the async function's behavior for a wide range of possible inputs
- Test edge cases that the author may not have foreseen, including async-specific edge cases
- Take advantage of the features of `{unit_test_package}` to make async tests easy to write and maintain
- Be easy to read and understand, with clean async code and descriptive names
- Be deterministic, so that the async tests always pass or fail in the same way
- Have tests sorted by difficulty, from easiest to hardest
- Try not to mock or stub any dependencies, so that the async testing environment is as close to production as possible
- Include concurrent execution test cases to assess the function's async performance and behavior
- Include throughput test cases to measure the function's performance under load and in high-volume scenarios
- Test proper async/await patterns and coroutine handling

To help unit test the ASYNC function above, list diverse scenarios that the async function should be able to handle (and under each scenario, include a few examples as sub-bullets)."""
        posthog_event_suffix = "async-"
        error_context = "async "
    else:
        execute_system_prompt = EXECUTE_SYSTEM_PROMPT
        execute_user_prompt = EXECUTE_USER_PROMPT
        plan_content = f"""A good unit test suite should aim to:
- Test the function's behavior for a wide range of possible inputs
- Test edge cases that the author may not have foreseen
- Take advantage of the features of `{unit_test_package}` to make the tests easy to write and maintain
- Be easy to read and understand, with clean code and descriptive names
- Be deterministic, so that the tests always pass or fail in the same way
- Have tests sorted by difficulty, from easiest to hardest
- Try not to mock or stub any dependencies via `{unit_test_package}`.mock or any other similar mocking or stubbing module, so that the testing environment is as close to the production environment as possible
- Include Large Scale Test Cases to assess the function's performance and scalability with large data samples.

To help unit test the function above, list diverse scenarios that the function should be able to handle (and under each scenario, include a few examples as sub-bullets)."""
        posthog_event_suffix = ""
        error_context = ""

    plan_user_message = {"role": "user", "content": plan_content}

    package_comment = ""
    # if unit_test_package == "pytest":
    #     package_comment = "# below, each test case is represented by a tuple passed to the @pytest.mark.parametrize decorator"

    execute_system_message = {
        "role": "system",
        "content": execute_system_prompt.format(function_name=ctx.data.qualified_name),
    }
    execute_messages = [execute_system_message, plan_user_message]

    all_notes = ctx.generate_notes_markdown()
    note_message = {"role": "user", "content": all_notes}
    execute_messages += [note_message]

    execute_user_message = {
        "role": "user",
        "content": execute_user_prompt.format(
            unit_test_package=unit_test_package,
            function_name=function_name,
            function_code=ctx.data.source_code_being_tested,
            package_comment=package_comment,
        ),
    }
    execute_messages += [execute_user_message]

    return execute_messages, posthog_event_suffix, error_context
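
# Illustrative return shape (message contents abbreviated; values shown are hypothetical):
#   messages, suffix, err_ctx = build_prompt(ctx, "my_func", "pytest", is_async=False)
#   messages -> [system prompt, planning prompt, notes, user prompt carrying the function's code]
#   suffix   -> ""   ("async-" when is_async=True, used to tag PostHog events)
#   err_ctx  -> ""   ("async " when is_async=True, spliced into error messages)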


def instrument_tests(
    generated_test_source: str, data: TestGenSchema, python_version: tuple[int, int, int]
) -> tuple[str | None, str | None]:
    common_args: InstrumentTestSourceArgs = {
        "test_source": generated_test_source,
        "function_to_optimize": data.function_to_optimize,
        "helper_function_names": data.helper_function_names or [],
        "module_path": data.module_path,
        "test_module_path": data.test_module_path,
        "test_framework": data.test_framework,
        "test_timeout": data.test_timeout,
        "python_version": python_version,
    }
    # instrument_test_source() already applies isort via format_and_float_to_top()
    # No need to apply isort again here (was causing double formatting overhead)
    behavior_result = instrument_test_source(**common_args, mode=TestingMode.BEHAVIOR)
    perf_result = instrument_test_source(**common_args, mode=TestingMode.PERFORMANCE)
    return behavior_result, perf_result
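
# A minimal usage sketch (hypothetical inputs): both variants are instrumented from the same
# generated source, and either may come back as None if instrumentation fails:
#   behavior_src, perf_src = instrument_tests(test_source, data, (3, 11, 0))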


def parse_and_validate_llm_output(
    response_content: str, ctx: BaseTestGenContext, python_version: tuple[int, int, int], error_context: str
) -> str:
    try:
        if "```python" not in response_content:
            sentry_sdk.capture_message("LLM response did not contain a Python code block:\n" + response_content)
            raise ValueError("LLM response did not contain a Python code block.")
        pattern_res = pattern.search(response_content)
        if not pattern_res:
            raise ValueError("No Python code block found in the LLM response.")
        code = pattern_res.group(1)
        cleaned_code = validate_testgen_code(code, python_version[:2], max_lines_to_remove=120)
        if ctx.did_generate_ellipsis(cleaned_code, python_version):
            msg = f"Ellipsis in generated {error_context}test code, regenerating..."
            raise SyntaxError(msg)
        return cleaned_code  # noqa: TRY300
    except Exception as e:
        sentry_sdk.capture_exception(e)
        raise
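
# Failure modes of the parser above (illustrative summary): a reply with no fenced ```python
# block raises ValueError, and a reply whose cleaned code still contains `...` ellipsis
# placeholders raises SyntaxError; both are captured in Sentry and re-raised so the retrying
# caller can regenerate.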


@stamina.retry(on=(SyntaxError, ValueError, OpenAIError), attempts=2)
async def generate_and_validate_test_code(
    messages: list[ChatCompletionMessageParam],
    model: LLM,
    ctx: BaseTestGenContext,
    python_version: tuple[int, int, int],
    error_context: str,
    execute_model: LLM,
    cost_tracker: list[float],
    user_id: str,
    posthog_event_suffix: str,
    trace_id: str = "",
    call_sequence: int | None = None,
) -> str:
    obs_context: dict | None = {"call_sequence": call_sequence} if call_sequence is not None else None
    response = await call_llm(
        llm=model,
        messages=messages,
        call_type="test_generation",
        trace_id=trace_id,
        user_id=user_id,
        python_version=".".join(str(v) for v in python_version),
        context=obs_context,
    )
    cost = calculate_llm_cost(response.raw_response, execute_model)
    cost_tracker.append(cost)
    debug_log_sensitive_data(
        f"OpenAIClient {error_context}execute response:\n{response.raw_response.model_dump_json(indent=2)}"
    )
    if response.raw_response.usage:
        ph(
            user_id,
            f"aiservice-testgen-{posthog_event_suffix}execute-openai-usage",
            properties={"model": execute_model.name, "usage": response.raw_response.usage.model_dump_json()},
        )
    # Parse and validate
    validated_code = parse_and_validate_llm_output(
        response_content=response.content, ctx=ctx, python_version=python_version, error_context=error_context
    )
    return validated_code
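
# Retry semantics (a note on the decorator above): stamina re-invokes this coroutine on
# SyntaxError, ValueError, or OpenAIError, and because `cost_tracker` is shared across
# attempts, its length after the call reflects how many LLM calls were actually billed.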


@stamina.retry(on=TestGenerationFailedError, attempts=2)
async def generate_regression_tests_from_function(
    ctx: BaseTestGenContext,
    user_id: str,
    function_name: str,
    python_version: tuple[int, int, int],
    data: TestGenSchema,
    unit_test_package: str = "pytest",
    execute_model: LLM = EXECUTE_MODEL,
    is_async: bool = False,  # noqa: FBT001, FBT002
    trace_id: str = "",
    call_sequence: int | None = None,
) -> tuple[str, str | None, str | None]:
    execute_messages, posthog_event_suffix, error_context = build_prompt(
        ctx=ctx, function_name=function_name, unit_test_package=unit_test_package, is_async=is_async
    )
    cost_tracker: list[float] = []
    try:
        validated_code = await generate_and_validate_test_code(
            messages=execute_messages,
            model=execute_model,
            ctx=ctx,
            python_version=python_version,
            error_context=error_context,
            execute_model=execute_model,
            cost_tracker=cost_tracker,
            user_id=user_id,
            posthog_event_suffix=posthog_event_suffix,
            trace_id=trace_id,
            call_sequence=call_sequence,
        )
        total_llm_cost = sum(cost_tracker)
        await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost, user_id=user_id)

        processed_cst = postprocessing_testgen_pipeline(
            parse_module_to_cst(validated_code),
            data.helper_function_names or [],
            data.function_to_optimize,
            data.module_path,
        )
        # Add missing imports for symbols defined in the source module but not imported in the test.
        # This handles cases where the LLM redefines some classes locally but forgets others.
        source_code_blocks = split_markdown_code(data.source_code_being_tested)
        # Combine all source code blocks to find all available symbols
        combined_source = "\n".join(source_code_blocks.values())
        processed_cst = add_missing_imports_from_source(processed_cst, combined_source, data.module_path)
        generated_test_source = replace_definition_with_import(
            processed_cst.code, data.function_to_optimize, data.module_path
        )

        instrumented_behavior_tests, instrumented_perf_tests = instrument_tests(
            generated_test_source, data, python_version
        )
        if instrumented_behavior_tests is None or instrumented_perf_tests is None:
            msg = (
                f"An error was detected in the function to optimize; is it valid Python code? trace_id={trace_id}"
            )
            logging.error(msg)
            raise TestGenerationFailedError(msg)

        sorted_generated_tests = safe_isort(generated_test_source)
        try:
            parse_module_to_cst(sorted_generated_tests)
            generated_test_source = sorted_generated_tests
        except Exception:  # noqa: BLE001
            sentry_sdk.capture_message("isort caused a syntax error in testgen; returning un-sorted code.")

        tree = ast.parse(generated_test_source, feature_version=python_version[:2])
        if not has_test_functions(tree):  # sanity check: after all the processing, no test functions remain
            msg = f"No test functions were found in the generated test code. trace_id={trace_id}"
            logging.error(msg)
            raise TestGenerationFailedError(msg)
        return generated_test_source, instrumented_behavior_tests, instrumented_perf_tests  # noqa: TRY300
    except (SyntaxError, ValueError) as e:
        total_llm_cost = sum(cost_tracker)
        await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost, user_id=user_id)
        msg = f"Failed to generate valid {error_context}test code after {len(cost_tracker)} tries. trace_id={trace_id}"
        logging.error(msg)
        raise TestGenerationFailedError(msg) from e
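
# The happy path above, at a glance (informal summary of the steps, not an API):
#   LLM reply -> extract/validate the code block -> CST post-processing -> add missing imports
#   from the source module -> replace inlined definitions with imports -> instrument (behavior
#   + performance) -> isort (kept only if the result still parses) -> assert test functions exist.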


async def hack_for_demo(data: TestGenSchema, python_version: tuple[int, int, int]) -> TestGenResponseSchema:
    if data.test_index == 0:
        generated_test_source = '# imports\n# function to test\nfrom __future__ import annotations\n\nimport pytest # used for our unit tests\nfrom codeflash.result.common_tags import find_common_tags\n\n# unit tests\n\ndef test_single_article():\n # Single article should return its tags\n articles = [{"tags": ["python", "coding", "tutorial"]}]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_multiple_articles_with_common_tags():\n # Multiple articles with common tags should return the common tags\n articles = [\n {"tags": ["python", "coding"]},\n {"tags": ["python", "data"]},\n {"tags": ["python", "machine learning"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_empty_list_of_articles():\n # Empty list of articles should return an empty set\n articles = []\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_articles_with_no_common_tags():\n # Articles with no common tags should return an empty set\n articles = [\n {"tags": ["python"]},\n {"tags": ["java"]},\n {"tags": ["c++"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_articles_with_empty_tag_lists():\n # Articles with some empty tag lists should return an empty set\n articles = [\n {"tags": []},\n {"tags": ["python"]},\n {"tags": ["python", "java"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_all_articles_with_empty_tag_lists():\n # All articles with empty tag lists should return an empty set\n articles = [\n {"tags": []},\n {"tags": []},\n {"tags": []}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_tags_with_special_characters():\n # Tags with special characters should be handled correctly\n articles = [\n {"tags": ["python!", "coding"]},\n {"tags": ["python!", "data"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_case_sensitivity():\n # Tags with different cases should not be considered the same\n articles = [\n {"tags": ["Python", "coding"]},\n {"tags": ["python", "data"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_large_number_of_articles():\n # Large number of articles with a common tag should return that tag\n articles = [{"tags": ["common_tag", f"tag{i}"]} for i in range(1000)]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_large_number_of_tags():\n # Large number of tags with some common tags should return the common tags\n articles = [\n {"tags": [f"tag{i}" for i in range(1000)]},\n {"tags": [f"tag{i}" for i in range(500, 1500)]}\n ]\n expected = {f"tag{i}" for i in range(500, 1000)}\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_mixed_length_of_tag_lists():\n # Articles with mixed length of tag lists should return the common tags\n articles = [\n {"tags": ["python", "coding"]},\n {"tags": ["python"]},\n {"tags": ["python", "coding", "tutorial"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_tags_with_different_data_types():\n # Tags with different data types should only consider strings\n articles = [\n {"tags": ["python",
    else:
        generated_test_source = '# imports\n# function to test\nfrom __future__ import annotations\n\nimport pytest # used for our unit tests\nfrom codeflash.result.common_tags import find_common_tags\n\n# unit tests\n\ndef test_empty_input_list():\n # Test with an empty list\n codeflash_output = find_common_tags([])\n # Outputs were verified to be equal to the original implementation\n\ndef test_single_article():\n # Test with a single article with tags\n codeflash_output = find_common_tags([{"tags": ["python", "coding", "development"]}])\n # Test with a single article with no tags\n codeflash_output = find_common_tags([{"tags": []}])\n # Outputs were verified to be equal to the original implementation\n\ndef test_multiple_articles_some_common_tags():\n # Test with multiple articles having some common tags\n articles = [\n {"tags": ["python", "coding", "development"]},\n {"tags": ["python", "development", "tutorial"]},\n {"tags": ["python", "development", "guide"]}\n ]\n codeflash_output = find_common_tags(articles)\n\n articles = [\n {"tags": ["tech", "news"]},\n {"tags": ["tech", "gadgets"]},\n {"tags": ["tech", "reviews"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_multiple_articles_no_common_tags():\n # Test with multiple articles having no common tags\n articles = [\n {"tags": ["python", "coding"]},\n {"tags": ["development", "tutorial"]},\n {"tags": ["guide", "learning"]}\n ]\n codeflash_output = find_common_tags(articles)\n\n articles = [\n {"tags": ["apple", "banana"]},\n {"tags": ["orange", "grape"]},\n {"tags": ["melon", "kiwi"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_articles_with_duplicate_tags():\n # Test with articles having duplicate tags\n articles = [\n {"tags": ["python", "python", "coding"]},\n {"tags": ["python", "development", "python"]},\n {"tags": ["python", "guide", "python"]}\n ]\n codeflash_output = find_common_tags(articles)\n\n articles = [\n {"tags": ["tech", "tech", "news"]},\n {"tags": ["tech", "tech", "gadgets"]},\n {"tags": ["tech", "tech", "reviews"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_articles_with_mixed_case_tags():\n # Test with articles having mixed case tags\n articles = [\n {"tags": ["Python", "Coding"]},\n {"tags": ["python", "Development"]},\n {"tags": ["PYTHON", "Guide"]}\n ]\n codeflash_output = find_common_tags(articles) # Assuming case sensitivity\n\n articles = [\n {"tags": ["Tech", "News"]},\n {"tags": ["tech", "Gadgets"]},\n {"tags": ["TECH", "Reviews"]}\n ]\n codeflash_output = find_common_tags(articles) # Assuming case sensitivity\n # Outputs were verified to be equal to the original implementation\n\ndef test_articles_with_non_string_tags():\n # Test with articles having non-string tags\n articles = [\n {"tags": ["python", 123, "coding"]},\n {"tags": ["python", "development", 123]},\n {"tags": ["python", "guide", 123]}\n ]\n codeflash_output = find_common_tags(articles)\n\n articles = [\n {"tags": [None, "news"]},\n {"tags": ["tech", None]},\n {"tags": [None, "reviews"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_large_scale_test_cases():\n # Test with large scale input where all tags should be common\n articles = [\n {"tags": ["tag" + str(i) for i in range(1000)]} for _ in range(100)\n ]\n expected_output = {"tag" + str(i) for i in range(1000)}\n codeflash_output = find_common_tags(articles)\n\n # Test with large scale input wh
    instrumented_behavior_tests, instrumented_perf_tests = instrument_tests(generated_test_source, data, python_version)
    await asyncio.sleep(5)
    return TestGenResponseSchema(
        generated_tests=generated_test_source,
        instrumented_behavior_tests=instrumented_behavior_tests or "",
        instrumented_perf_tests=instrumented_perf_tests or "",
    )


async def hack_for_demo_gsq(data: TestGenSchema, python_version: tuple[int, int, int]) -> TestGenResponseSchema:
    if data.test_index == 0:
        generated_test_source = "from datetime import datetime, timedelta\n# function to test\nfrom functools import reduce\n\nimport numpy as np\nimport pandas as pd\n# imports\nimport pytest # used for our unit tests\nfrom gs_quant.timeseries.algebra import weighted_sum\n\n\nclass MqTypeError(TypeError): pass\nclass MqValueError(ValueError): pass\nfrom gs_quant.timeseries.algebra import weighted_sum\n\n# unit tests\n\n# --- Basic Test Cases ---\n\ndef test_weighted_sum_simple_two_series():\n # Simple case: two series, same index, weights sum to 1\n idx = pd.date_range('2024-01-01', periods=3)\n s1 = pd.Series([1, 2, 3], index=idx)\n s2 = pd.Series([4, 5, 6], index=idx)\n weights = [0.7, 0.3]\n codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n expected = (s1 * 0.7 + s2 * 0.3) / (0.7 + 0.3)\n\ndef test_weighted_sum_three_series_equal_weights():\n # Three series, equal weights\n idx = pd.date_range('2024-01-01', periods=3)\n s1 = pd.Series([1, 2, 3], index=idx)\n s2 = pd.Series([4, 5, 6], index=idx)\n s3 = pd.Series([7, 8, 9], index=idx)\n weights = [1, 1, 1]\n codeflash_output = weighted_sum([s1, s2, s3], weights); result = codeflash_output\n expected = (s1 + s2 + s3) / 3\n\ndef test_weighted_sum_weights_not_normalized():\n # Weights do not sum to 1\n idx = pd.date_range('2024-01-01', periods=2)\n s1 = pd.Series([10, 20], index=idx)\n s2 = pd.Series([30, 40], index=idx)\n weights = [2, 3]\n codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n expected = (s1 * 2 + s2 * 3) / (2 + 3)\n\n# --- Edge Test Cases ---\n\ndef test_weighted_sum_empty_series_list():\n # Empty input series list\n with pytest.raises(MqValueError):\n weighted_sum([], [1, 2])\n\ndef test_weighted_sum_empty_weights_list():\n # Empty weights list\n idx = pd.date_range('2024-01-01', periods=2)\n s1 = pd.Series([1, 2], index=idx)\n with pytest.raises(MqValueError):\n weighted_sum([s1], [])\n\ndef test_weighted_sum_weights_length_mismatch():\n # Weights and series length mismatch\n idx = pd.date_range('2024-01-01', periods=2)\n s1 = pd.Series([1, 2], index=idx)\n s2 = pd.Series([3, 4], index=idx)\n with pytest.raises(MqValueError):\n weighted_sum([s1, s2], [1])\n\ndef test_weighted_sum_non_series_input():\n # Non-pandas Series input\n idx = pd.date_range('2024-01-01', periods=2)\n s1 = pd.Series([1, 2], index=idx)\n with pytest.raises(MqTypeError):\n weighted_sum([s1, [3, 4]], [0.5, 0.5])\n\ndef test_weighted_sum_non_numeric_weights():\n # Non-numeric weights\n idx = pd.date_range('2024-01-01', periods=2)\n s1 = pd.Series([1, 2], index=idx)\n s2 = pd.Series([3, 4], index=idx)\n with pytest.raises(MqTypeError):\n weighted_sum([s1, s2], [0.5, \"a\"])\n\ndef test_weighted_sum_disjoint_calendars():\n # Series with disjoint calendars (no overlap)\n idx1 = pd.date_range('2024-01-01', periods=2)\n idx2 = pd.date_range('2024-02-01', periods=2)\n s1 = pd.Series([1, 2], index=idx1)\n s2 = pd.Series([3, 4], index=idx2)\n weights = [0.5, 0.5]\n codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n\ndef test_weighted_sum_partial_overlap_calendars():\n # Series with partial overlap in calendar\n idx1 = pd.date_range('2024-01-01', periods=3)\n idx2 = pd.date_range('2024-01-02', periods=3)\n s1 = pd.Series([1, 2, 3], index=idx1)\n s2 = pd.Series([4, 5, 6], index=idx2)\n weights = [1, 1]\n codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n # Only dates present in both should be included\n expected_idx = idx1.intersection(idx2)\n expected = (s1.reindex(expected_idx) + s2.reindex(expected_idx)) / 2\n\ndef test_weighted_sum_zero_weights():\n # Zero weights (should return NaN for all dates)\n idx = pd.date_range('2024-01-01', periods=2)\n s1 = pd.Series([1, 2], index=idx)\n s2 = pd.Series([3, 4], index=idx)\n weights = [0, 0]\n codeflash_out
    else:
        generated_test_source = "from datetime import datetime, timedelta\n\nimport pandas as pd\n# imports\nimport pytest # used for our unit tests\nfrom gs_quant.timeseries.algebra import weighted_sum\n\n# ---------- Unit Tests ----------\n\n# Helper to create pd.Series with consecutive dates\ndef make_series(start, count, value_func=lambda i: i):\n dates = [datetime(2020, 1, 1) + timedelta(days=i) for i in range(count)]\n values = [value_func(i) for i in range(count)]\n return pd.Series(values, index=dates)\n\n# 1. Basic Test Cases\n\ndef test_weighted_sum_simple_two_series_equal_weights():\n # Two series, equal weights\n s1 = make_series(0, 5, lambda i: 2)\n s2 = make_series(0, 5, lambda i: 4)\n codeflash_output = weighted_sum([s1, s2], [1, 1]); result = codeflash_output\n expected = pd.Series([3.0]*5, index=s1.index)\n\ndef test_weighted_sum_simple_two_series_unequal_weights():\n # Two series, weights 2:1\n s1 = make_series(0, 5, lambda i: 2)\n s2 = make_series(0, 5, lambda i: 8)\n codeflash_output = weighted_sum([s1, s2], [2, 1]); result = codeflash_output\n expected = pd.Series([(2*2+1*8)/3]*5, index=s1.index)\n\ndef test_weighted_sum_three_series_varied_weights():\n # Three series, varied weights\n s1 = make_series(0, 3, lambda i: 1)\n s2 = make_series(0, 3, lambda i: 3)\n s3 = make_series(0, 3, lambda i: 5)\n codeflash_output = weighted_sum([s1, s2, s3], [0.2, 0.3, 0.5]); result = codeflash_output\n expected = pd.Series([1*0.2+3*0.3+5*0.5], index=s1.index)\n expected = pd.Series([1*0.2+3*0.3+5*0.5]*3, index=s1.index)\n\ndef test_weighted_sum_one_series():\n # Single series, should return the series itself\n s1 = make_series(0, 4, lambda i: i+5)\n codeflash_output = weighted_sum([s1], [1]); result = codeflash_output\n\n# 2. Edge Test Cases\n\ndef test_weighted_sum_empty_series_list():\n # Empty series list should raise ValueError\n with pytest.raises(ValueError):\n weighted_sum([], [])\n\ndef test_weighted_sum_weights_length_mismatch():\n # Series and weights length mismatch\n s1 = make_series(0, 3)\n s2 = make_series(0, 3)\n with pytest.raises(ValueError):\n weighted_sum([s1, s2], [1])\n\ndef test_weighted_sum_non_series_input():\n # Non-Series in series list\n s1 = make_series(0, 3)\n with pytest.raises(TypeError):\n weighted_sum([s1, [1, 2, 3]], [1, 1])\n\ndef test_weighted_sum_non_numeric_weights():\n # Non-numeric in weights\n s1 = make_series(0, 3)\n s2 = make_series(0, 3)\n with pytest.raises(TypeError):\n weighted_sum([s1, s2], [1, 'a'])\n\ndef test_weighted_sum_zero_weights():\n # All weights zero should return NaN series\n s1 = make_series(0, 3, lambda i: 1)\n s2 = make_series(0, 3, lambda i: 2)\n codeflash_output = weighted_sum([s1, s2], [0, 0]); result = codeflash_output\n\ndef test_weighted_sum_negative_weights():\n # Negative weights\n s1 = make_series(0, 3, lambda i: 2)\n s2 = make_series(0, 3, lambda i: 4)\n codeflash_output = weighted_sum([s1, s2], [-1, 2]); result = codeflash_output\n expected = pd.Series([(2*-1+4*2)/(-1+2)]*3, index=s1.index)\n\ndef test_weighted_sum_partial_overlap_indices():\n # Series with partially overlapping indices\n s1 = make_series(0, 5, lambda i: i)\n s2 = make_series(2, 5, lambda i: i+10)\n # Overlap is on days 2,3,4\n overlap_dates = s1.index.intersection(s2.index)\n codeflash_output = weighted_sum([s1, s2], [1, 1]); result = codeflash_output\n expected = pd.Series([(s1[d]+s2[d])/2 for d in overlap_dates], index=overlap_dates)\n\ndef test_weighted_sum_series_with_nans():\n # Series with NaNs, intersection should drop NaN dates\n s1 = make_series(0, 4, lambda i: float('nan') if i==2 else i)\n s2 = make_series(0, 4, lambda i: i)\n # Only dates where both have non-NaN should be used\n codeflash_output = weighted_sum([s1, s2], [1, 1]); result = codeflash_output\n\ndef test_weighted_sum_series_with_different_freq():\n # Series with different frequencies\n s1 = make_se
    instrumented_behavior_tests, instrumented_perf_tests = instrument_tests(generated_test_source, data, python_version)
    await asyncio.sleep(2)
    return TestGenResponseSchema(
        generated_tests=generated_test_source,
        instrumented_behavior_tests=instrumented_behavior_tests or "",
        instrumented_perf_tests=instrumented_perf_tests or "",
    )


def validate_request_data(data: TestGenSchema) -> tuple[tuple[int, int, int], BaseTestGenContext]:
    if data.test_framework not in ["unittest", "pytest"]:
        raise HttpError(400, "Invalid test framework. We only support unittest and pytest.")
    if not data.function_to_optimize:
        raise HttpError(400, "Invalid function to optimize. It is empty.")
    if not validate_trace_id(data.trace_id):
        raise HttpError(400, "Invalid trace ID. Please provide a valid UUIDv4.")
    try:
        python_version = parse_python_version(data.python_version)
    except ValueError:
        raise HttpError(400, "Invalid Python version, it should look like 3.x.x. We only support Python 3.9 and above.")  # noqa: B904
    try:
        ctx = BaseTestGenContext.get_dynamic_context(
            TestGenContextData(
                source_code_being_tested=data.source_code_being_tested,
                qualified_name=data.function_to_optimize.qualified_name,
            )
        )
        ctx.validate_python_module(feature_version=python_version[:2])
    except SyntaxError:
        raise HttpError(400, "Invalid source code. It is not valid Python code.")  # noqa: B904
    return python_version, ctx
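
# Validation at a glance (illustrative): an unsupported framework, an empty function, a
# non-UUIDv4 trace ID, a malformed Python version, or unparsable source all fail fast with
# HTTP 400 before any LLM call; on success the parsed version tuple, e.g. (3, 11, 4), and
# the testgen context are handed to the generation path below.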


@testgen_api.post(
    "/", response={200: TestGenResponseSchema, 400: TestGenErrorResponseSchema, 500: TestGenErrorResponseSchema}
)
async def testgen(
    request: AuthenticatedRequest, data: TestGenSchema
) -> tuple[int, TestGenResponseSchema | TestGenErrorResponseSchema]:
    # Route based on language
    if data.language in ("javascript", "typescript"):
        return await testgen_javascript(request, data)
    # Default: Python test generation
    return await testgen_python(request, data)
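
# Illustrative request body for the endpoint above (field names as used in this module;
# values are hypothetical and the schema may carry more fields than shown):
#   POST /testgen/
#   {
#     "language": "python",
#     "test_framework": "pytest",
#     "python_version": "3.11.4",
#     "trace_id": "<uuid4>",
#     "test_index": 0,
#     "is_async": false,
#     "source_code_being_tested": "def add(a, b):\n    return a + b\n"
#   }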


async def testgen_python(
    request: AuthenticatedRequest, data: TestGenSchema
) -> tuple[int, TestGenResponseSchema | TestGenErrorResponseSchema]:
    """Generate Python tests using LLMs."""
    ph(request.user, "aiservice-testgen-called", properties={"language": "python"})
    try:
        python_version, ctx = validate_request_data(data)
    except HttpError as e:
        e.add_note(f"Testgen request validation error: {e.status_code} {e.message}")
        sentry_sdk.capture_exception(e)
        return e.status_code, TestGenErrorResponseSchema(error=e.message)

    if should_hack_for_demo(data.source_code_being_tested):
        if "find_common_tags" in data.source_code_being_tested:
            return 200, await hack_for_demo(data, python_version)
        if "weighted_sum" in data.source_code_being_tested:
            return 200, await hack_for_demo_gsq(data, python_version)

    logging.info("/testgen: Generating tests...")
    try:
        debug_log_sensitive_data(f"Generating tests for function {data.function_to_optimize.function_name}")

        # Use different LLMs for different test_index values to get more diverse tests
        test_index = data.test_index if data.test_index is not None else 0
        if test_index % 2 == 0:
            execute_model = OPENAI_MODEL
            model_source = "OpenAI"
        else:
            execute_model = HAIKU_MODEL
            model_source = "Anthropic"

        logging.info(
            f"Using {model_source} model ({execute_model.name}) for test_index {test_index} to generate diverse tests"
        )
        (
            generated_test_source,
            instrumented_behavior_tests,
            instrumented_perf_tests,
        ) = await generate_regression_tests_from_function(
            ctx=ctx,
            user_id=request.user,
            function_name=data.function_to_optimize.qualified_name,
            python_version=python_version,
            data=data,
            unit_test_package=data.test_framework,
            is_async=data.is_async,
            trace_id=data.trace_id,
            call_sequence=data.call_sequence,
            execute_model=execute_model,
        )
        ph(request.user, "aiservice-testgen-tests-generated")

        if hasattr(request, "should_log_features") and request.should_log_features:
            await log_features(
                trace_id=data.trace_id,
                user_id=request.user,
                generated_tests=[generated_test_source],
                instrumented_generated_tests=[instrumented_behavior_tests],
                test_framework=data.test_framework,
                metadata={
                    "test_timeout": data.test_timeout,
                    "function_to_optimize": data.function_to_optimize.function_name,
                },
            )
        return 200, TestGenResponseSchema(
            generated_tests=generated_test_source,
            instrumented_behavior_tests=instrumented_behavior_tests,
            instrumented_perf_tests=instrumented_perf_tests,
        )
    except Exception as e:
        logging.exception(f"Test generation failed. trace_id={data.trace_id}")
        sentry_sdk.capture_exception(e)
        return 500, TestGenErrorResponseSchema(error="Error generating tests. Internal server error.")