# Derived from https://github.com/openai/openai-cookbook/blob/main/examples/Unit_test_writing_using_a_multi-step_prompt.ipynb from __future__ import annotations import ast import asyncio import logging from pathlib import Path from typing import TYPE_CHECKING, TypedDict import sentry_sdk import stamina from ninja import NinjaAPI from ninja.errors import HttpError from openai import OpenAIError from aiservice.analytics.posthog import ph from aiservice.common.cst_utils import parse_module_to_cst from aiservice.common.markdown_utils import extract_code_block from aiservice.common_utils import parse_python_version, safe_isort, should_hack_for_demo, validate_trace_id from aiservice.env_specific import debug_log_sensitive_data from aiservice.llm import EXECUTE_MODEL, HAIKU_MODEL, OPENAI_MODEL, calculate_llm_cost, call_llm from aiservice.models.functions_to_optimize import FunctionToOptimize from authapp.auth import AuthenticatedRequest from log_features.log_event import update_optimization_cost from log_features.log_features import log_features from libcst import parse_module from testgen.instrumentation.edit_generated_test import replace_definition_with_import from testgen.instrumentation.instrument_new_tests import instrument_test_source from testgen.models import ( TestGenDebugInfo, TestGenerationFailedError, TestGenErrorResponseSchema, TestGenResponseSchema, TestGenSchema, TestingMode, ) from testgen.postprocessing.code_validator import CodeValidationError, has_test_functions, validate_testgen_code from testgen.postprocessing.postprocess_pipeline import postprocessing_testgen_pipeline from testgen.testgen_context import BaseTestGenContext, TestGenContextData from languages.js_ts.testgen import testgen_javascript if TYPE_CHECKING: from openai.types.chat import ChatCompletionMessageParam from aiservice.llm import LLM class InstrumentTestSourceArgs(TypedDict): test_source: str function_to_optimize: FunctionToOptimize helper_function_names: list[str] module_path: str test_module_path: str test_framework: str test_timeout: int python_version: tuple[int, int, int] testgen_api = NinjaAPI(urls_namespace="testgen") # Get the directory of the current file current_dir = Path(__file__).parent EXPLAIN_SYSTEM_PROMPT = (current_dir / "explain_system_prompt.md").read_text() EXPLAIN_USER_PROMPT = (current_dir / "explain_user_prompt.md").read_text() EXECUTE_SYSTEM_PROMPT = (current_dir / "execute_system_prompt.md").read_text() EXECUTE_USER_PROMPT = (current_dir / "execute_user_prompt.md").read_text() EXECUTE_ASYNC_SYSTEM_PROMPT = (current_dir / "execute_async_system_prompt.md").read_text() EXECUTE_ASYNC_USER_PROMPT = (current_dir / "execute_async_user_prompt.md").read_text() JIT_INSTRUCTIONS = (current_dir / "jit_system_prompt.md").read_text() def build_prompt( ctx: BaseTestGenContext, function_name: str, unit_test_package: str, *, is_async: bool, is_numerical_code: bool | None = None, ) -> tuple[list[dict[str, str]], str, str]: if is_async: execute_system_prompt = EXECUTE_ASYNC_SYSTEM_PROMPT execute_user_prompt = EXECUTE_ASYNC_USER_PROMPT plan_content = f"""A good unit test suite for an ASYNC function should aim to: - Test the async function's behavior for a wide range of possible inputs - Test edge cases that the author may not have foreseen, including async-specific edge cases - Take advantage of the features of `{unit_test_package}` to make async tests easy to write and maintain - Be easy to read and understand, with clean async code and descriptive names - Be deterministic, so that the async tests always pass or fail in the same way - Have tests sorted by difficulty, from easiest to hardest - Should try not to mock or stub any dependencies, so that the async testing environment is as close to production - Include concurrent execution test cases to assess the function's async performance and behavior - Include throughput test cases to measure the function's performance under load and high-volume scenarios - Test proper async/await patterns and coroutine handling To help unit test the ASYNC function above, list diverse scenarios that the async function should be able to handle (and under each scenario, include a few examples as sub-bullets).""" posthog_event_suffix = "async-" error_context = "async " else: execute_system_prompt = EXECUTE_SYSTEM_PROMPT execute_user_prompt = EXECUTE_USER_PROMPT plan_content = f"""A good unit test suite should aim to: - Test the function's behavior for a wide range of possible inputs - Test edge cases that the author may not have foreseen - Take advantage of the features of `{unit_test_package}` to make the tests easy to write and maintain - Be easy to read and understand, with clean code and descriptive names - Be deterministic, so that the tests always pass or fail in the same way - Have tests sorted by difficulty, from easiest to hardest - Should try not to mock or stub any dependencies by using `{unit_test_package}`.mock or any other similar mocking or stubbing module, so that the testing environment is as close to the production environment as possible - Include Large Scale Test Cases to assess the function's performance and scalability with large data samples. To help unit test the function above, list diverse scenarios that the function should be able to handle (and under each scenario, include a few examples as sub-bullets).""" posthog_event_suffix = "" error_context = "" plan_user_message = {"role": "user", "content": plan_content} package_comment = "" # if unit_test_package == "pytest": # package_comment = "# below, each test case is represented by a tuple passed to the @pytest.mark.parametrize decorator" system_prompt = execute_system_prompt.format(function_name=ctx.data.qualified_name) if is_numerical_code: system_prompt += f"\n{JIT_INSTRUCTIONS}\n" execute_system_message = {"role": "system", "content": system_prompt} execute_messages = [execute_system_message, plan_user_message] all_notes = ctx.generate_notes_markdown() note_message = {"role": "user", "content": all_notes} execute_messages += [note_message] execute_user_message = { "role": "user", "content": execute_user_prompt.format( unit_test_package=unit_test_package, function_name=function_name, function_code=ctx.data.source_code_being_tested, package_comment=package_comment, ), } execute_messages += [execute_user_message] return execute_messages, posthog_event_suffix, error_context def instrument_tests( generated_test_source: str, data: TestGenSchema, python_version: tuple[int, int, int] ) -> tuple[str | None, str | None]: common_args: InstrumentTestSourceArgs = { "test_source": generated_test_source, "function_to_optimize": data.function_to_optimize, "helper_function_names": data.helper_function_names or [], "module_path": data.module_path, "test_module_path": data.test_module_path, "test_framework": data.test_framework, "test_timeout": data.test_timeout, "python_version": python_version, } # instrument_test_source() already applies isort via format_and_float_to_top() # No need to apply isort again here (was causing double formatting overhead) behavior_result = instrument_test_source(**common_args, mode=TestingMode.BEHAVIOR) perf_result = instrument_test_source(**common_args, mode=TestingMode.PERFORMANCE) return behavior_result, perf_result class LLMOutputParseError(Exception): """Exception for LLM output parsing failures with raw output context.""" def __init__(self, message: str, raw_llm_output: str, code: str | None = None) -> None: super().__init__(message) self.raw_llm_output = raw_llm_output self.code = code # Code extracted from LLM output, if any def parse_and_validate_llm_output( response_content: str, ctx: BaseTestGenContext, python_version: tuple[int, int, int], error_context: str, function_to_optimize: FunctionToOptimize | None = None, module_path: str | None = None, ) -> str: code: str | None = None try: if "```python" not in response_content: sentry_sdk.capture_message("LLM response did not contain a Python code block:\n" + response_content) raise LLMOutputParseError( "LLM response did not contain a Python code block.", raw_llm_output=response_content ) code = extract_code_block(response_content) if code is None: raise LLMOutputParseError( "No Python code block found in the LLM response.", raw_llm_output=response_content ) if function_to_optimize is not None and module_path is not None: try: code = replace_definition_with_import(parse_module(code), function_to_optimize, module_path).code except Exception: # noqa: BLE001 # If replacement fails (e.g., parsing error), continue with original code logging.warning("replace_definition_with_import failed, continuing with original code") cleaned_code = validate_testgen_code(code, python_version[:2]) if ctx.did_generate_ellipsis(cleaned_code, python_version): msg = f"Ellipsis in generated {error_context}test code, regenerating..." raise SyntaxError(msg) return cleaned_code # noqa: TRY300 except CodeValidationError as e: # Re-raise with raw LLM output added e.raw_llm_output = response_content # type: ignore[attr-defined] sentry_sdk.capture_exception(e) raise except LLMOutputParseError: raise except Exception as e: sentry_sdk.capture_exception(e) raise @stamina.retry(on=(SyntaxError, ValueError, OpenAIError, CodeValidationError, LLMOutputParseError), attempts=2) async def generate_and_validate_test_code( messages: list[ChatCompletionMessageParam], model: LLM, ctx: BaseTestGenContext, python_version: tuple[int, int, int], error_context: str, execute_model: LLM, cost_tracker: list[float], user_id: str, posthog_event_suffix: str, trace_id: str = "", call_sequence: int | None = None, function_to_optimize: FunctionToOptimize | None = None, module_path: str | None = None, ) -> str: obs_context: dict | None = {"call_sequence": call_sequence} if call_sequence is not None else None response = await call_llm( llm=model, messages=messages, call_type="test_generation", trace_id=trace_id, user_id=user_id, python_version=".".join(str(v) for v in python_version), context=obs_context, ) cost = calculate_llm_cost(response.raw_response, execute_model) cost_tracker.append(cost) debug_log_sensitive_data( f"OpenAIClient {error_context}execute response:\n{response.raw_response.model_dump_json(indent=2)}" ) if response.raw_response.usage: await asyncio.to_thread( ph, user_id, f"aiservice-testgen-{posthog_event_suffix}execute-openai-usage", properties={"model": execute_model.name, "usage": response.raw_response.usage.model_dump_json()}, ) # Parse and validate validated_code = parse_and_validate_llm_output( response_content=response.content, ctx=ctx, python_version=python_version, error_context=error_context, function_to_optimize=function_to_optimize, module_path=module_path, ) return validated_code @stamina.retry(on=TestGenerationFailedError, attempts=2) async def generate_regression_tests_from_function( ctx: BaseTestGenContext, user_id: str, function_name: str, python_version: tuple[int, int, int], data: TestGenSchema, unit_test_package: str = "pytest", execute_model: LLM = EXECUTE_MODEL, is_async: bool = False, # noqa: FBT001, FBT002 trace_id: str = "", call_sequence: int | None = None, ) -> tuple[str, str | None, str | None]: execute_messages, posthog_event_suffix, error_context = build_prompt( ctx=ctx, function_name=function_name, unit_test_package=unit_test_package, is_async=is_async, is_numerical_code=data.is_numerical_code, ) cost_tracker = [] try: validated_code = await generate_and_validate_test_code( messages=execute_messages, model=execute_model, ctx=ctx, python_version=python_version, error_context=error_context, execute_model=execute_model, cost_tracker=cost_tracker, user_id=user_id, posthog_event_suffix=posthog_event_suffix, trace_id=trace_id, call_sequence=call_sequence, function_to_optimize=data.function_to_optimize, module_path=data.module_path, ) total_llm_cost = sum(cost_tracker) await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost, user_id=user_id) processed_cst = postprocessing_testgen_pipeline( parse_module_to_cst(validated_code), data.helper_function_names or [], data.function_to_optimize, data.module_path, data.source_code_being_tested, ) generated_test_source = processed_cst.code instrumented_behavior_tests, instrumented_perf_tests = instrument_tests( generated_test_source, data, python_version ) if instrumented_behavior_tests is None or instrumented_perf_tests is None: msg = ( f"There was an error detected in the function to optimize, is it valid Python code? trace_id={trace_id}" ) logging.error(msg) raise TestGenerationFailedError( msg, debug_info={ "stage": "instrumentation", "final_code": generated_test_source, "validation_error": "Instrumentation failed - function may not be valid Python", }, ) sorted_generated_tests = safe_isort(generated_test_source) try: parse_module_to_cst(sorted_generated_tests) generated_test_source = sorted_generated_tests except Exception: # noqa: BLE001 sentry_sdk.capture_message("isort caused a syntax error in testgen; returning un-sorted code.") tree = ast.parse(generated_test_source, feature_version=python_version[:2]) if not has_test_functions(tree): # sanity check, after all the processing somehow no test functions msg = f"No test functions were found in the generated test code. trace_id={trace_id}" logging.error(msg) raise TestGenerationFailedError( msg, debug_info={ "stage": "postprocessing", "final_code": generated_test_source, "validation_error": "No test functions found after postprocessing", }, ) return generated_test_source, instrumented_behavior_tests, instrumented_perf_tests # noqa: TRY300 except CodeValidationError as e: total_llm_cost = sum(cost_tracker) await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost, user_id=user_id) msg = f"Failed to generate valid {error_context}test code after {len(cost_tracker)} tries. trace_id={trace_id}" logging.exception(msg) debug_info = e.to_debug_dict() debug_info["raw_llm_output"] = getattr(e, "raw_llm_output", None) raise TestGenerationFailedError(msg, debug_info=debug_info) from e except LLMOutputParseError as e: total_llm_cost = sum(cost_tracker) await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost, user_id=user_id) msg = f"Failed to parse LLM output for {error_context}test code after {len(cost_tracker)} tries. trace_id={trace_id}" logging.exception(msg) raise TestGenerationFailedError( msg, debug_info={ "stage": "llm_generation", "raw_llm_output": e.raw_llm_output, "initial_code": e.code, "validation_error": str(e), }, ) from e except (SyntaxError, ValueError) as e: total_llm_cost = sum(cost_tracker) await update_optimization_cost(trace_id=trace_id, cost=total_llm_cost, user_id=user_id) msg = f"Failed to generate valid {error_context}test code after {len(cost_tracker)} tries. trace_id={trace_id}" logging.exception(msg) raise TestGenerationFailedError(msg, debug_info={"stage": "unknown", "validation_error": str(e)}) from e async def hack_for_demo(data: TestGenSchema, python_version: tuple[int, int, int]) -> TestGenResponseSchema: if data.test_index == 0: generated_test_source = '# imports\n# function to test\nfrom __future__ import annotations\n\nimport pytest # used for our unit tests\nfrom codeflash.result.common_tags import find_common_tags\n\n# unit tests\n\ndef test_single_article():\n # Single article should return its tags\n articles = [{"tags": ["python", "coding", "tutorial"]}]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_multiple_articles_with_common_tags():\n # Multiple articles with common tags should return the common tags\n articles = [\n {"tags": ["python", "coding"]},\n {"tags": ["python", "data"]},\n {"tags": ["python", "machine learning"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_empty_list_of_articles():\n # Empty list of articles should return an empty set\n articles = []\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_articles_with_no_common_tags():\n # Articles with no common tags should return an empty set\n articles = [\n {"tags": ["python"]},\n {"tags": ["java"]},\n {"tags": ["c++"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_articles_with_empty_tag_lists():\n # Articles with some empty tag lists should return an empty set\n articles = [\n {"tags": []},\n {"tags": ["python"]},\n {"tags": ["python", "java"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_all_articles_with_empty_tag_lists():\n # All articles with empty tag lists should return an empty set\n articles = [\n {"tags": []},\n {"tags": []},\n {"tags": []}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_tags_with_special_characters():\n # Tags with special characters should be handled correctly\n articles = [\n {"tags": ["python!", "coding"]},\n {"tags": ["python!", "data"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_case_sensitivity():\n # Tags with different cases should not be considered the same\n articles = [\n {"tags": ["Python", "coding"]},\n {"tags": ["python", "data"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_large_number_of_articles():\n # Large number of articles with a common tag should return that tag\n articles = [{"tags": ["common_tag", f"tag{i}"]} for i in range(1000)]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_large_number_of_tags():\n # Large number of tags with some common tags should return the common tags\n articles = [\n {"tags": [f"tag{i}" for i in range(1000)]},\n {"tags": [f"tag{i}" for i in range(500, 1500)]}\n ]\n expected = {f"tag{i}" for i in range(500, 1000)}\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_mixed_length_of_tag_lists():\n # Articles with mixed length of tag lists should return the common tags\n articles = [\n {"tags": ["python", "coding"]},\n {"tags": ["python"]},\n {"tags": ["python", "coding", "tutorial"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_tags_with_different_data_types():\n # Tags with different data types should only consider strings\n articles = [\n {"tags": ["python", 123]},\n {"tags": ["python", "123"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_performance_with_large_data():\n # Performance with large data should return the common tag\n articles = [{"tags": ["common_tag", f"tag{i}"]} for i in range(10000)]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_scalability_with_increasing_tags():\n # Scalability with increasing tags should return the common tag\n articles = [{"tags": ["common_tag"] + [f"tag{i}" for i in range(j)]} for j in range(1, 1001)]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation' else: generated_test_source = '# imports\n# function to test\nfrom __future__ import annotations\n\nimport pytest # used for our unit tests\nfrom codeflash.result.common_tags import find_common_tags\n\n# unit tests\n\ndef test_empty_input_list():\n # Test with an empty list\n codeflash_output = find_common_tags([])\n # Outputs were verified to be equal to the original implementation\n\ndef test_single_article():\n # Test with a single article with tags\n codeflash_output = find_common_tags([{"tags": ["python", "coding", "development"]}])\n # Test with a single article with no tags\n codeflash_output = find_common_tags([{"tags": []}])\n # Outputs were verified to be equal to the original implementation\n\ndef test_multiple_articles_some_common_tags():\n # Test with multiple articles having some common tags\n articles = [\n {"tags": ["python", "coding", "development"]},\n {"tags": ["python", "development", "tutorial"]},\n {"tags": ["python", "development", "guide"]}\n ]\n codeflash_output = find_common_tags(articles)\n\n articles = [\n {"tags": ["tech", "news"]},\n {"tags": ["tech", "gadgets"]},\n {"tags": ["tech", "reviews"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_multiple_articles_no_common_tags():\n # Test with multiple articles having no common tags\n articles = [\n {"tags": ["python", "coding"]},\n {"tags": ["development", "tutorial"]},\n {"tags": ["guide", "learning"]}\n ]\n codeflash_output = find_common_tags(articles)\n\n articles = [\n {"tags": ["apple", "banana"]},\n {"tags": ["orange", "grape"]},\n {"tags": ["melon", "kiwi"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_articles_with_duplicate_tags():\n # Test with articles having duplicate tags\n articles = [\n {"tags": ["python", "python", "coding"]},\n {"tags": ["python", "development", "python"]},\n {"tags": ["python", "guide", "python"]}\n ]\n codeflash_output = find_common_tags(articles)\n\n articles = [\n {"tags": ["tech", "tech", "news"]},\n {"tags": ["tech", "tech", "gadgets"]},\n {"tags": ["tech", "tech", "reviews"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_articles_with_mixed_case_tags():\n # Test with articles having mixed case tags\n articles = [\n {"tags": ["Python", "Coding"]},\n {"tags": ["python", "Development"]},\n {"tags": ["PYTHON", "Guide"]}\n ]\n codeflash_output = find_common_tags(articles) # Assuming case sensitivity\n\n articles = [\n {"tags": ["Tech", "News"]},\n {"tags": ["tech", "Gadgets"]},\n {"tags": ["TECH", "Reviews"]}\n ]\n codeflash_output = find_common_tags(articles) # Assuming case sensitivity\n # Outputs were verified to be equal to the original implementation\n\ndef test_articles_with_non_string_tags():\n # Test with articles having non-string tags\n articles = [\n {"tags": ["python", 123, "coding"]},\n {"tags": ["python", "development", 123]},\n {"tags": ["python", "guide", 123]}\n ]\n codeflash_output = find_common_tags(articles)\n\n articles = [\n {"tags": [None, "news"]},\n {"tags": ["tech", None]},\n {"tags": [None, "reviews"]}\n ]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation\n\ndef test_large_scale_test_cases():\n # Test with large scale input where all tags should be common\n articles = [\n {"tags": ["tag" + str(i) for i in range(1000)]} for _ in range(100)\n ]\n expected_output = {"tag" + str(i) for i in range(1000)}\n codeflash_output = find_common_tags(articles)\n\n # Test with large scale input where no tags should be common\n articles = [\n {"tags": ["tag" + str(i) for i in range(1000)]} for _ in range(50)\n ] + [{"tags": ["unique_tag"]}]\n codeflash_output = find_common_tags(articles)\n # Outputs were verified to be equal to the original implementation' instrumented_behavior_tests, instrumented_perf_tests = instrument_tests(generated_test_source, data, python_version) await asyncio.sleep(5) return TestGenResponseSchema( generated_tests=generated_test_source, instrumented_behavior_tests=instrumented_behavior_tests or "", instrumented_perf_tests=instrumented_perf_tests or "", ) async def hack_for_demo_gsq(data: TestGenSchema, python_version: tuple[int, int, int]) -> TestGenResponseSchema: if data.test_index == 0: generated_test_source = "from datetime import datetime, timedelta\n# function to test\nfrom functools import reduce\n\nimport numpy as np\nimport pandas as pd\n# imports\nimport pytest # used for our unit tests\nfrom gs_quant.timeseries.algebra import weighted_sum\n\n\nclass MqTypeError(TypeError): pass\nclass MqValueError(ValueError): pass\nfrom gs_quant.timeseries.algebra import weighted_sum\n\n# unit tests\n\n# --- Basic Test Cases ---\n\ndef test_weighted_sum_simple_two_series():\n # Simple case: two series, same index, weights sum to 1\n idx = pd.date_range('2024-01-01', periods=3)\n s1 = pd.Series([1, 2, 3], index=idx)\n s2 = pd.Series([4, 5, 6], index=idx)\n weights = [0.7, 0.3]\n codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n expected = (s1 * 0.7 + s2 * 0.3) / (0.7 + 0.3)\n\ndef test_weighted_sum_three_series_equal_weights():\n # Three series, equal weights\n idx = pd.date_range('2024-01-01', periods=3)\n s1 = pd.Series([1, 2, 3], index=idx)\n s2 = pd.Series([4, 5, 6], index=idx)\n s3 = pd.Series([7, 8, 9], index=idx)\n weights = [1, 1, 1]\n codeflash_output = weighted_sum([s1, s2, s3], weights); result = codeflash_output\n expected = (s1 + s2 + s3) / 3\n\ndef test_weighted_sum_weights_not_normalized():\n # Weights do not sum to 1\n idx = pd.date_range('2024-01-01', periods=2)\n s1 = pd.Series([10, 20], index=idx)\n s2 = pd.Series([30, 40], index=idx)\n weights = [2, 3]\n codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n expected = (s1 * 2 + s2 * 3) / (2 + 3)\n\n# --- Edge Test Cases ---\n\ndef test_weighted_sum_empty_series_list():\n # Empty input series list\n with pytest.raises(MqValueError):\n weighted_sum([], [1, 2])\n\ndef test_weighted_sum_empty_weights_list():\n # Empty weights list\n idx = pd.date_range('2024-01-01', periods=2)\n s1 = pd.Series([1, 2], index=idx)\n with pytest.raises(MqValueError):\n weighted_sum([s1], [])\n\ndef test_weighted_sum_weights_length_mismatch():\n # Weights and series length mismatch\n idx = pd.date_range('2024-01-01', periods=2)\n s1 = pd.Series([1, 2], index=idx)\n s2 = pd.Series([3, 4], index=idx)\n with pytest.raises(MqValueError):\n weighted_sum([s1, s2], [1])\n\ndef test_weighted_sum_non_series_input():\n # Non-pandas Series input\n idx = pd.date_range('2024-01-01', periods=2)\n s1 = pd.Series([1, 2], index=idx)\n with pytest.raises(MqTypeError):\n weighted_sum([s1, [3, 4]], [0.5, 0.5])\n\ndef test_weighted_sum_non_numeric_weights():\n # Non-numeric weights\n idx = pd.date_range('2024-01-01', periods=2)\n s1 = pd.Series([1, 2], index=idx)\n s2 = pd.Series([3, 4], index=idx)\n with pytest.raises(MqTypeError):\n weighted_sum([s1, s2], [0.5, \"a\"])\n\ndef test_weighted_sum_disjoint_calendars():\n # Series with disjoint calendars (no overlap)\n idx1 = pd.date_range('2024-01-01', periods=2)\n idx2 = pd.date_range('2024-02-01', periods=2)\n s1 = pd.Series([1, 2], index=idx1)\n s2 = pd.Series([3, 4], index=idx2)\n weights = [0.5, 0.5]\n codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n\ndef test_weighted_sum_partial_overlap_calendars():\n # Series with partial overlap in calendar\n idx1 = pd.date_range('2024-01-01', periods=3)\n idx2 = pd.date_range('2024-01-02', periods=3)\n s1 = pd.Series([1, 2, 3], index=idx1)\n s2 = pd.Series([4, 5, 6], index=idx2)\n weights = [1, 1]\n codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n # Only dates present in both should be included\n expected_idx = idx1.intersection(idx2)\n expected = (s1.reindex(expected_idx) + s2.reindex(expected_idx)) / 2\n\ndef test_weighted_sum_zero_weights():\n # Zero weights (should return NaN for all dates)\n idx = pd.date_range('2024-01-01', periods=2)\n s1 = pd.Series([1, 2], index=idx)\n s2 = pd.Series([3, 4], index=idx)\n weights = [0, 0]\n codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n\ndef test_weighted_sum_negative_weights():\n # Negative weights\n idx = pd.date_range('2024-01-01', periods=2)\n s1 = pd.Series([10, 20], index=idx)\n s2 = pd.Series([30, 40], index=idx)\n weights = [1, -1]\n codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n expected = (s1 * 1 + s2 * -1) / (1 + -1)\n\ndef test_weighted_sum_nan_values_in_series():\n # Series containing NaN values\n idx = pd.date_range('2024-01-01', periods=3)\n s1 = pd.Series([1, None, 3], index=idx)\n s2 = pd.Series([4, 5, None], index=idx)\n weights = [1, 1]\n codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n # NaN propagation: result should be NaN where any input is NaN\n expected = (s1 * 1 + s2 * 1) / 2\n\ndef test_weighted_sum_mixed_int_float_weights():\n # Mixed int and float weights\n idx = pd.date_range('2024-01-01', periods=2)\n s1 = pd.Series([1, 2], index=idx)\n s2 = pd.Series([3, 4], index=idx)\n weights = [1, 0.5]\n codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n expected = (s1 * 1 + s2 * 0.5) / (1 + 0.5)\n\n# --- Large Scale Test Cases ---\n\ndef test_weighted_sum_large_number_of_series():\n # Large number of series (up to 1000)\n n = 1000\n idx = pd.date_range('2024-01-01', periods=10)\n series_list = [pd.Series([i]*10, index=idx) for i in range(n)]\n weights = [1 for _ in range(n)]\n codeflash_output = weighted_sum(series_list, weights); result = codeflash_output\n # Each date should be the average of 0...999, i.e. (sum(range(n)))/n\n expected_value = sum(range(n)) / n\n\ndef test_weighted_sum_large_series_length():\n # Large series length (up to 1000 dates)\n idx = pd.date_range('2024-01-01', periods=1000)\n s1 = pd.Series(range(1000), index=idx)\n s2 = pd.Series(range(1000, 0, -1), index=idx)\n weights = [0.25, 0.75]\n codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n expected = (s1 * 0.25 + s2 * 0.75) / (0.25 + 0.75)\n\ndef test_weighted_sum_large_partial_overlap():\n # Large series with partial overlap\n idx1 = pd.date_range('2024-01-01', periods=1000)\n idx2 = pd.date_range('2024-01-500', periods=500)\n s1 = pd.Series(range(1000), index=idx1)\n s2 = pd.Series(range(500), index=idx2)\n weights = [1, 2]\n codeflash_output = weighted_sum([s1, s2], weights); result = codeflash_output\n expected_idx = idx1.intersection(idx2)\n expected = (s1.reindex(expected_idx) * 1 + s2.reindex(expected_idx) * 2) / (1 + 2)\n# codeflash_output is used to check that the output of the original code is the same as that of the optimized code." else: generated_test_source = "from datetime import datetime, timedelta\n\nimport pandas as pd\n# imports\nimport pytest # used for our unit tests\nfrom gs_quant.timeseries.algebra import weighted_sum\n\n# ---------- Unit Tests ----------\n\n# Helper to create pd.Series with consecutive dates\ndef make_series(start, count, value_func=lambda i: i):\n dates = [datetime(2020, 1, 1) + timedelta(days=i) for i in range(count)]\n values = [value_func(i) for i in range(count)]\n return pd.Series(values, index=dates)\n\n# 1. Basic Test Cases\n\ndef test_weighted_sum_simple_two_series_equal_weights():\n # Two series, equal weights\n s1 = make_series(0, 5, lambda i: 2)\n s2 = make_series(0, 5, lambda i: 4)\n codeflash_output = weighted_sum([s1, s2], [1, 1]); result = codeflash_output\n expected = pd.Series([3.0]*5, index=s1.index)\n\ndef test_weighted_sum_simple_two_series_unequal_weights():\n # Two series, weights 2:1\n s1 = make_series(0, 5, lambda i: 2)\n s2 = make_series(0, 5, lambda i: 8)\n codeflash_output = weighted_sum([s1, s2], [2, 1]); result = codeflash_output\n expected = pd.Series([(2*2+1*8)/3]*5, index=s1.index)\n\ndef test_weighted_sum_three_series_varied_weights():\n # Three series, varied weights\n s1 = make_series(0, 3, lambda i: 1)\n s2 = make_series(0, 3, lambda i: 3)\n s3 = make_series(0, 3, lambda i: 5)\n codeflash_output = weighted_sum([s1, s2, s3], [0.2, 0.3, 0.5]); result = codeflash_output\n expected = pd.Series([1*0.2+3*0.3+5*0.5], index=s1.index)\n expected = pd.Series([1*0.2+3*0.3+5*0.5]*3, index=s1.index)\n\ndef test_weighted_sum_one_series():\n # Single series, should return the series itself\n s1 = make_series(0, 4, lambda i: i+5)\n codeflash_output = weighted_sum([s1], [1]); result = codeflash_output\n\n# 2. Edge Test Cases\n\ndef test_weighted_sum_empty_series_list():\n # Empty series list should raise ValueError\n with pytest.raises(ValueError):\n weighted_sum([], [])\n\ndef test_weighted_sum_weights_length_mismatch():\n # Series and weights length mismatch\n s1 = make_series(0, 3)\n s2 = make_series(0, 3)\n with pytest.raises(ValueError):\n weighted_sum([s1, s2], [1])\n\ndef test_weighted_sum_non_series_input():\n # Non-Series in series list\n s1 = make_series(0, 3)\n with pytest.raises(TypeError):\n weighted_sum([s1, [1, 2, 3]], [1, 1])\n\ndef test_weighted_sum_non_numeric_weights():\n # Non-numeric in weights\n s1 = make_series(0, 3)\n s2 = make_series(0, 3)\n with pytest.raises(TypeError):\n weighted_sum([s1, s2], [1, 'a'])\n\ndef test_weighted_sum_zero_weights():\n # All weights zero should return NaN series\n s1 = make_series(0, 3, lambda i: 1)\n s2 = make_series(0, 3, lambda i: 2)\n codeflash_output = weighted_sum([s1, s2], [0, 0]); result = codeflash_output\n\ndef test_weighted_sum_negative_weights():\n # Negative weights\n s1 = make_series(0, 3, lambda i: 2)\n s2 = make_series(0, 3, lambda i: 4)\n codeflash_output = weighted_sum([s1, s2], [-1, 2]); result = codeflash_output\n expected = pd.Series([(2*-1+4*2)/(-1+2)]*3, index=s1.index)\n\ndef test_weighted_sum_partial_overlap_indices():\n # Series with partially overlapping indices\n s1 = make_series(0, 5, lambda i: i)\n s2 = make_series(2, 5, lambda i: i+10)\n # Overlap is on days 2,3,4\n overlap_dates = s1.index.intersection(s2.index)\n codeflash_output = weighted_sum([s1, s2], [1, 1]); result = codeflash_output\n expected = pd.Series([(s1[d]+s2[d])/2 for d in overlap_dates], index=overlap_dates)\n\ndef test_weighted_sum_series_with_nans():\n # Series with NaNs, intersection should drop NaN dates\n s1 = make_series(0, 4, lambda i: float('nan') if i==2 else i)\n s2 = make_series(0, 4, lambda i: i)\n # Only dates where both have non-NaN should be used\n codeflash_output = weighted_sum([s1, s2], [1, 1]); result = codeflash_output\n\ndef test_weighted_sum_series_with_different_freq():\n # Series with different frequencies\n s1 = make_series(0, 5, lambda i: i)\n s2 = pd.Series([10, 20], index=[datetime(2020, 1, 2), datetime(2020, 1, 4)])\n # Only dates 2020-01-02 and 2020-01-04 are in both\n overlap_dates = s1.index.intersection(s2.index)\n codeflash_output = weighted_sum([s1, s2], [1, 3]); result = codeflash_output\n expected = pd.Series([(s1[d]*1+s2[d]*3)/4 for d in overlap_dates], index=overlap_dates)\n\ndef test_weighted_sum_all_nan_series():\n # All series are NaN\n dates = [datetime(2020, 1, 1) + timedelta(days=i) for i in range(3)]\n s1 = pd.Series([float('nan')] * 3, index=dates)\n s2 = pd.Series([float('nan')] * 3, index=dates)\n codeflash_output = weighted_sum([s1, s2], [1, 1]); result = codeflash_output\n\n# 3. Large Scale Test Cases\n\ndef test_weighted_sum_large_number_of_series():\n # 100 series, each with 100 points\n N = 100\n M = 100\n series_list = [make_series(0, M, lambda i: i+j) for j in range(N)]\n weights = [j+1 for j in range(N)]\n codeflash_output = weighted_sum(series_list, weights); result = codeflash_output\n # For each date, expected value is weighted sum of values across series\n dates = series_list[0].index\n expected_values = []\n total_weight = sum(weights)\n for i in range(M):\n # For date i, value in series j is i+j\n weighted_sum_val = sum((i+j)*weights[j] for j in range(N)) / total_weight\n expected_values.append(weighted_sum_val)\n expected = pd.Series(expected_values, index=dates)\n\ndef test_weighted_sum_large_series_length():\n # Two series, each with 1000 points\n M = 1000\n s1 = make_series(0, M, lambda i: i)\n s2 = make_series(0, M, lambda i: 2*i)\n codeflash_output = weighted_sum([s1, s2], [1, 2]); result = codeflash_output\n expected = pd.Series([(i*1+2*i*2)/3 for i in range(M)], index=s1.index)\n\ndef test_weighted_sum_large_sparse_overlap():\n # Series with sparse overlap\n s1 = pd.Series([1]*1000, index=[datetime(2020, 1, 1)+timedelta(days=i) for i in range(0, 1000, 2)])\n s2 = pd.Series([2]*500, index=[datetime(2020, 1, 1)+timedelta(days=i) for i in range(1, 1000, 2)])\n # There should be no overlap, so result should be empty\n codeflash_output = weighted_sum([s1, s2], [1, 1]); result = codeflash_output\n\ndef test_weighted_sum_large_all_nan():\n # Large series, all NaN\n M = 1000\n dates = [datetime(2020, 1, 1) + timedelta(days=i) for i in range(M)]\n s1 = pd.Series([float('nan')]*M, index=dates)\n s2 = pd.Series([float('nan')]*M, index=dates)\n codeflash_output = weighted_sum([s1, s2], [1, 2]); result = codeflash_output\n# codeflash_output is used to check that the output of the original code is the same as that of the optimized code." instrumented_behavior_tests, instrumented_perf_tests = instrument_tests(generated_test_source, data, python_version) await asyncio.sleep(2) return TestGenResponseSchema( generated_tests=generated_test_source, instrumented_behavior_tests=instrumented_behavior_tests or "", instrumented_perf_tests=instrumented_perf_tests or "", ) def validate_request_data(data: TestGenSchema) -> tuple[tuple[int, int, int], BaseTestGenContext]: if data.test_framework not in ["unittest", "pytest"]: raise HttpError(400, "Invalid test framework. We only support unittest and pytest.") if not data.function_to_optimize: raise HttpError(400, "Invalid function to optimize. It is empty.") if not validate_trace_id(data.trace_id): raise HttpError(400, "Invalid trace ID. Please provide a valid UUIDv4.") if not data.python_version: raise HttpError(400, "Python version is required.") try: python_version = parse_python_version(data.python_version) except ValueError: raise HttpError(400, "Invalid Python version, it should look like 3.x.x. We only support Python 3.9 and above.") # noqa: B904 try: ctx = BaseTestGenContext.get_dynamic_context( TestGenContextData( source_code_being_tested=data.source_code_being_tested, qualified_name=data.function_to_optimize.qualified_name, ) ) ctx.validate_python_module(feature_version=python_version[:2]) except SyntaxError: raise HttpError(400, "Invalid source code. It is not valid Python code.") # noqa: B904 return python_version, ctx @testgen_api.post( "/", response={200: TestGenResponseSchema, 400: TestGenErrorResponseSchema, 500: TestGenErrorResponseSchema} ) async def testgen( request: AuthenticatedRequest, data: TestGenSchema ) -> tuple[int, TestGenResponseSchema | TestGenErrorResponseSchema]: # Route based on language if data.language in ("javascript", "typescript"): return await testgen_javascript(request, data) # Default: Python test generation return await testgen_python(request, data) async def testgen_python( request: AuthenticatedRequest, data: TestGenSchema ) -> tuple[int, TestGenResponseSchema | TestGenErrorResponseSchema]: """Generate Python tests using LLMs.""" await asyncio.to_thread(ph, request.user, "aiservice-testgen-called") try: python_version, ctx = validate_request_data(data) except HttpError as e: e.add_note(f"Testgen request validation error: {e.status_code} {e.message}") sentry_sdk.capture_exception(e) return e.status_code, TestGenErrorResponseSchema(error=e.message) if should_hack_for_demo(data.source_code_being_tested): if "find_common_tags" in data.source_code_being_tested: demo_hack_response = await hack_for_demo(data, python_version) elif "weighted_sum" in data.source_code_being_tested: demo_hack_response = await hack_for_demo_gsq(data, python_version) return 200, demo_hack_response logging.info("/testgen: Generating tests...") try: debug_log_sensitive_data(f"Generating tests for function {data.function_to_optimize.function_name}") # Using different LLMs for different test_index values to get more diverse tests test_index = data.test_index if data.test_index is not None else 0 if test_index % 2 == 0: execute_model = OPENAI_MODEL model_source = "OpenAI" else: execute_model = HAIKU_MODEL model_source = "Anthropic" logging.info( f"Using {model_source} model ({execute_model.name}) for test_index {test_index} to generate diverse tests" ) ( generated_test_source, instrumented_behavior_tests, instrumented_perf_tests, ) = await generate_regression_tests_from_function( ctx=ctx, user_id=request.user, function_name=data.function_to_optimize.qualified_name, python_version=python_version, data=data, unit_test_package=data.test_framework, is_async=data.function_to_optimize.is_async or data.is_async, trace_id=data.trace_id, call_sequence=data.call_sequence, execute_model=execute_model, ) await asyncio.to_thread(ph, request.user, "aiservice-testgen-tests-generated") if hasattr(request, "should_log_features") and request.should_log_features: await log_features( trace_id=data.trace_id, user_id=request.user, generated_tests=[generated_test_source], instrumented_generated_tests=[instrumented_behavior_tests], test_framework=data.test_framework, metadata={ "test_timeout": data.test_timeout, "function_to_optimize": data.function_to_optimize.function_name, }, ) return 200, TestGenResponseSchema( generated_tests=generated_test_source, instrumented_behavior_tests=instrumented_behavior_tests, instrumented_perf_tests=instrumented_perf_tests, ) except TestGenerationFailedError as e: logging.exception(f"Test generation failed. trace_id={data.trace_id}") sentry_sdk.capture_exception(e) # Return detailed debug info for self-healing debug_info = None if e.debug_info: lines_removed = e.debug_info.get("lines_removed") debug_info = TestGenDebugInfo( stage=str(e.debug_info.get("stage", "unknown")), raw_llm_output=str(v) if (v := e.debug_info.get("raw_llm_output")) else None, initial_code=str(v) if (v := e.debug_info.get("initial_code")) else None, fixed_code=str(v) if (v := e.debug_info.get("fixed_code")) else None, final_code=str(v) if (v := e.debug_info.get("final_code")) else None, lines_removed=int(lines_removed) if lines_removed is not None else None, validation_error=str(v) if (v := e.debug_info.get("validation_error")) else None, ) return 500, TestGenErrorResponseSchema(error=str(e), trace_id=data.trace_id, debug_info=debug_info) except Exception as e: logging.exception(f"Test generation failed. trace_id={data.trace_id}") sentry_sdk.capture_exception(e) return 500, TestGenErrorResponseSchema( error="Error generating tests. Internal server error.", trace_id=data.trace_id )