codeflash-internal/django/aiservice/optimizer/optimizer.py
Aseem Saxena 2896cde219 wip
2025-11-18 08:30:54 -08:00

352 lines
46 KiB
Python

from __future__ import annotations
import asyncio
import logging
import uuid
from pathlib import Path
from typing import TYPE_CHECKING
import libcst as cst
import sentry_sdk
from aiservice.analytics.posthog import ph
from aiservice.common_utils import parse_python_version, should_hack_for_demo, validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable, llm_clients
from aiservice.models.aimodels import OPTIMIZE_MODEL, calculate_llm_cost
from authapp.user import get_user_by_id
from log_features.log_event import log_optimization_event
from log_features.log_features import log_features
from ninja import NinjaAPI
from ninja.errors import HttpError
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
from pydantic import ValidationError
from optimizer.context_utils.context_helpers import group_code
from optimizer.context_utils.optimizer_context import (
BaseOptimizerContext,
OptimizeErrorResponseSchema,
OptimizeResponseItemSchema,
OptimizeResponseSchema,
)
from optimizer.models import OptimizeSchema # noqa: TC001
if TYPE_CHECKING:
from aiservice.models.aimodels import LLM
from django.http import HttpRequest
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionFunctionMessageParam,
ChatCompletionToolMessageParam,
)
# Canned responses served by ``hack_for_demo``: four pre-written optimization
# candidates for a toy ``find_common_tags`` function.  Each entry mirrors the
# shape of one real LLM result — the optimized source code, a prose
# explanation, and an optimization id.
# NOTE(review): the ids are generated once at import time, so every demo
# response reuses the same ids — confirm that is acceptable for the demo flow.
optimizations_json = [
    {
        "source_code": 'from __future__ import annotations\n\n\ndef find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:\n if not articles:\n return set()\n\n common_tags = set(articles[0].get("tags", []))\n for article in articles[1:]:\n common_tags.intersection_update(article.get("tags", []))\n return common_tags\n',
        "explanation": "The original algorithm repeatedly filters the `common_tags` list for every article, which can be slow. We can use Python sets to improve efficiency, especially with large lists.\n\nHere's the optimized version of your function.\n\n\n\n### Explanation of Optimizations.\n1. **Use of Sets**: Convert the initial list of tags to a set, which allows for more efficient intersection operations compared to list comprehensions.\n2. **Intersection Update**: Use the `intersection_update` method on sets which modifies the set in place, making it more memory efficient and faster than creating new lists and converting them to sets repeatedly.\n\nThis optimized version should perform significantly better, especially as the number of articles and tags increases.",
        "optimization_id": str(uuid.uuid4()),
    },
    {
        "source_code": 'from __future__ import annotations\n\n\ndef find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:\n if not articles:\n return set()\n\n common_tags = set(articles[0].get("tags", []))\n for article in articles[1:]:\n common_tags.intersection_update(article.get("tags", []))\n if not common_tags:\n break\n return common_tags\n',
        "explanation": "To make the `find_common_tags` function run faster, we can leverage sets, which provide average O(1) time complexity for membership checks and O(n) for intersections. Here\u2019s a refactored version of your program.\n\n\n\nThis version initializes `common_tags` as a set and then iteratively intersects it with the tags of each subsequent article. The `intersection_update` method is used to update `common_tags` in place, which is more efficient. Additionally, it breaks early if `common_tags` becomes empty, which can save unnecessary computation.",
        "optimization_id": str(uuid.uuid4()),
    },
    {
        "source_code": 'def find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:\n if not articles:\n return set()\n\n common_tags = set(articles[0].get("tags", []))\n for article in articles[1:]:\n common_tags.intersection_update(article.get("tags", []))\n if not common_tags: # Early exit if no common tags left\n break\n return common_tags\n',
        "explanation": "To optimize the runtime of this function, we can leverage set operations which are generally faster than list comprehensions for membership checks. By converting the tags to sets initially, the intersection operation becomes more efficient. Here's a faster version.\n\n\n\nChanges made.\n1. Convert the tags list of the first article to a set.\n2. Use `intersection_update` method to update the `common_tags` set with the intersection of the current tags and the next article's tags.\n3. Include an early exit condition to break the loop if no common tags remain, further optimizing runtime.",
        "optimization_id": str(uuid.uuid4()),
    },
    {
        "source_code": 'from __future__ import annotations\n\n\ndef find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:\n if not articles:\n return set()\n\n common_tags = set(articles[0].get("tags", []))\n for article in articles[1:]:\n common_tags &= set(article.get("tags", []))\n if not common_tags: # Early exit if no common tags.\n break\n return common_tags\n',
        "explanation": "To optimize the provided function, we could enhance its efficiency by using set operations which are typically faster for membership checks compared to list comprehensions.\n\nHere\u2019s the optimized version.\n\n\n\nExplanation.\n1. Convert the tags of the first article into a set to take advantage of fast membership checks and intersection operations.\n2. Use the `&=` operation to find the intersection with the tags of each subsequent article.\n3. Introduce an early exit condition: if `common_tags` becomes empty, it's immediately returned since no further intersection can result in common tags.",
        "optimization_id": str(uuid.uuid4()),
    },
]
# '104b5ff9-a3df-49f7-94a5-7f0a34b6137b': '# Copyright 2018 Goldman Sachs.\n# Licensed under the Apache License, Version 2.0 (the "License");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n# http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied. See the License for the\n# specific language governing permissions and limitations\n# under the License.\n#\n#\n# Chart Service will attempt to make public functions (not prefixed with _) from this module available. Such functions\n# should be fully documented: docstrings should describe parameters and the return value, and provide a 1-line\n# description. Type annotations should be provided for parameters.\n\n\n\n\nimport numpy as np\nimport pandas as pd\nfrom functools import reduce\nfrom gs_quant.errors import MqTypeError, MqValueError\nfrom gs_quant.timeseries.helper import plot_function\nfrom typing import List\n\n@plot_function\ndef weighted_sum(series: List[pd.Series], weights: list) -> pd.Series:\n """\n Calculate a weighted sum.\n\n :param series: list of time series\n :param weights: list of weights\n :return: time series of weighted average\n\n **Usage**\n\n Calculate a weighted sum e.g. 
for a basket.\n\n **Examples**\n\n Generate price series and get a sum (weights 70%/30%).\n\n >>> prices1 = generate_series(100)\n >>> prices2 = generate_series(100)\n >>> mybasket = weighted_sum([prices1, prices2], [0.7, 0.3])\n\n **See also**\n\n :func:`basket`\n """\n if not all(isinstance(x, pd.Series) for x in series):\n raise MqTypeError("expected a list of time series")\n if not all(isinstance(y, (float, int)) for y in weights):\n raise MqTypeError("expected a list of number for weights")\n if len(weights) != len(series):\n raise MqValueError("must have one weight for each time series")\n\n # For input series, get the intersection of their calendars\n # Instead of reduce(np.intersect1d, ...), use set intersection for better performance\n idx_iter = (curve.index for curve in series)\n idx0 = next(idx_iter)\n cal = set(idx0)\n for idx in idx_iter:\n cal &= set(idx)\n cal = pd.DatetimeIndex(sorted(cal))\n\n # Vectorized calculations using numpy arrays\n if len(series) == 0 or len(cal) == 0:\n # Edge case: empty data\n weights_arr = np.array([], dtype=float)\n values_arr = np.array([], dtype=float)\n weighted_sum_arr = np.array([])\n sum_weights_arr = np.array([])\n return pd.Series(weighted_sum_arr, index=cal)\n\n # Use pd.concat for batch reindex to avoid overhead of python list comprehensions\n series_concat = pd.concat([s.reindex(cal) for s in series], axis=1)\n values_arr = series_concat.values # shape (n_dates, n_series)\n weights_arr = np.asarray(weights, dtype=float).reshape(1, -1) # shape (1, n_series)\n weights_matrix = np.broadcast_to(weights_arr, values_arr.shape) # shape (n_dates, n_series)\n\n # Weighted sum and denominator\n weighted_sum_arr = np.nansum(values_arr * weights_matrix, axis=1)\n sum_weights_arr = np.nansum(weights_matrix * ~np.isnan(values_arr), axis=1)\n\n # Avoid divide-by-zero; if all weights are nan for a row, the sum is nan\n with np.errstate(invalid=\'ignore\', divide=\'ignore\'):\n result_arr = weighted_sum_arr / 
sum_weights_arr\n result_arr[sum_weights_arr == 0] = np.nan\n\n return pd.Series(result_arr, index=cal)\n',
# '4eda9cc4-e50c-482a-8bc3-b3b786117fc7': '# Copyright 2018 Goldman Sachs.\n# Licensed under the Apache License, Version 2.0 (the "License");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n# http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied. See the License for the\n# specific language governing permissions and limitations\n# under the License.\n#\n#\n# Chart Service will attempt to make public functions (not prefixed with _) from this module available. Such functions\n# should be fully documented: docstrings should describe parameters and the return value, and provide a 1-line\n# description. Type annotations should be provided for parameters.\n\n\nimport numpy as np\nimport pandas as pd\nfrom functools import reduce\nfrom gs_quant.errors import MqTypeError, MqValueError\nfrom gs_quant.timeseries.helper import plot_function\nfrom typing import List\n\n@plot_function\ndef weighted_sum(series: List[pd.Series], weights: list) -> pd.Series:\n """\n Calculate a weighted sum.\n\n :param series: list of time series\n :param weights: list of weights\n :return: time series of weighted average\n\n **Usage**\n\n Calculate a weighted sum e.g. 
for a basket.\n\n **Examples**\n\n Generate price series and get a sum (weights 70%/30%).\n\n >>> prices1 = generate_series(100)\n >>> prices2 = generate_series(100)\n >>> mybasket = weighted_sum([prices1, prices2], [0.7, 0.3])\n\n **See also**\n\n :func:`basket`\n """\n if not all(isinstance(x, pd.Series) for x in series):\n raise MqTypeError("expected a list of time series")\n if not all(isinstance(y, (float, int)) for y in weights):\n raise MqTypeError("expected a list of number for weights")\n if len(weights) != len(series):\n raise MqValueError("must have one weight for each time series")\n\n # for input series, get the intersection of their calendars\n # Optimization: Use set.intersection for efficiency with large indices\n calendars = [s.index for s in series]\n if not calendars:\n # fallback for empty input\n cal = pd.DatetimeIndex([])\n else:\n cal_set = set(calendars[0])\n for idx in calendars[1:]:\n cal_set.intersection_update(idx)\n if len(cal_set) == 0:\n cal = pd.DatetimeIndex([])\n else:\n # Sorted for reproducible behavior, like DatetimeIndex from intersect1d\n cal = pd.DatetimeIndex(sorted(cal_set))\n\n # Vectorize reindexing & multiplication via DataFrame and numpy\n # Stack series as columns\n df = pd.concat([s.reindex(cal) for s in series], axis=1)\n # Faster to construct array and multiply broadcast than make a list of Series\n w_arr = np.array(weights, dtype=float)\n # Avoid unnecessary construction of Series for each weight\n weighted = df.values * w_arr # shape (n_dates, n_series) * (n_series,) => (n_dates, n_series)\n weighted_sum_arr = np.nansum(weighted, axis=1)\n weight_sum_arr = np.nansum(~np.isnan(df.values) * w_arr, axis=1) # Only count where data is present\n\n # Division, preserving locations where sum of weights is zero as NaN\n with np.errstate(invalid="ignore", divide="ignore"):\n result = weighted_sum_arr / weight_sum_arr\n res_series = pd.Series(result, index=cal)\n return res_series\n',
# '867ff797-a2b1-42f8-97fd-b16285a07d6f': '# Copyright 2018 Goldman Sachs.\n# Licensed under the Apache License, Version 2.0 (the "License");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n# http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied. See the License for the\n# specific language governing permissions and limitations\n# under the License.\n#\n#\n# Chart Service will attempt to make public functions (not prefixed with _) from this module available. Such functions\n# should be fully documented: docstrings should describe parameters and the return value, and provide a 1-line\n# description. Type annotations should be provided for parameters.\n\nimport numpy as np\nimport pandas as pd\nfrom functools import reduce\nfrom gs_quant.errors import MqTypeError, MqValueError\nfrom gs_quant.timeseries.helper import plot_function\nfrom typing import List\n\n@plot_function\ndef weighted_sum(series: List[pd.Series], weights: list) -> pd.Series:\n """\n Calculate a weighted sum.\n\n :param series: list of time series\n :param weights: list of weights\n :return: time series of weighted average\n\n **Usage**\n\n Calculate a weighted sum e.g. 
for a basket.\n\n **Examples**\n\n Generate price series and get a sum (weights 70%/30%).\n\n >>> prices1 = generate_series(100)\n >>> prices2 = generate_series(100)\n >>> mybasket = weighted_sum([prices1, prices2], [0.7, 0.3])\n\n **See also**\n\n :func:`basket`\n """\n if not all(isinstance(x, pd.Series) for x in series):\n raise MqTypeError("expected a list of time series")\n if not all(isinstance(y, (float, int)) for y in weights):\n raise MqTypeError("expected a list of number for weights")\n if len(weights) != len(series):\n raise MqValueError("must have one weight for each time series")\n\n # for input series, get the intersection of their calendars\n calendars = [curve.index for curve in series]\n if not calendars:\n # handle empty series input (preserve original error propagation)\n cal = pd.DatetimeIndex([])\n else:\n # Use set intersection to accelerate intersection for many inputs\n cal = pd.DatetimeIndex(calendars[0].intersection_many(calendars[1:])) if len(calendars) > 1 else pd.DatetimeIndex(calendars[0])\n\n # reindex all series & build a numpy 2D array for fast weighted computation\n # Only create DataFrame once for all series at once (columns=series)\n if len(cal) == 0:\n # If no overlapping calendar, result is empty index. 
\n # Reproduce original behavior (by construction, all output is NaN so division is fine).\n result = pd.Series(index=cal, dtype=float)\n return result\n\n data = np.stack([s.reindex(cal).values for s in series], axis=1)\n weights_arr = np.asarray(weights, dtype=float)\n # Broadcast weights across all rows for calculation\n weighted_data = data * weights_arr\n\n # Calculate weighted sum and sum of weights (identical for all rows, but must support missing data)\n # When there\'s missing data, both weighted sum and sum of weights should ignore nan\n # (original code: NaN * weight yields NaN, sum ignores NaN values, denominator is sum of weights corresponding to non-NaN)\n mask = ~np.isnan(data)\n numerator = np.nansum(weighted_data, axis=1)\n denominator = np.sum(weights_arr * mask, axis=1)\n with np.errstate(divide=\'ignore\', invalid=\'ignore\'):\n values = numerator / denominator\n\n return pd.Series(values, index=cal)\n',
# Canned responses served by ``hack_for_demo_gsq``: four pre-written
# optimization candidates for a gs-quant ``weighted_sum`` function.  Same
# shape as ``optimizations_json`` above: optimized source, explanation, id.
# NOTE(review): ids are fixed at import time — repeated demo calls return the
# same ids; confirm that is intended.
optimizations_json_gsq = [
    {
        "source_code": '# Copyright 2018 Goldman Sachs.\n# Licensed under the Apache License, Version 2.0 (the "License");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n# http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied. See the License for the\n# specific language governing permissions and limitations\n# under the License.\n#\n#\n# Chart Service will attempt to make public functions (not prefixed with _) from this module available. Such functions\n# should be fully documented: docstrings should describe parameters and the return value, and provide a 1-line\n# description. Type annotations should be provided for parameters.\n\nimport numpy as np\nimport pandas as pd\nfrom functools import reduce\nfrom gs_quant.errors import MqTypeError, MqValueError\nfrom gs_quant.timeseries.helper import plot_function\nfrom typing import List\n\n@plot_function\ndef weighted_sum(series: List[pd.Series], weights: list) -> pd.Series:\n """\n Calculate a weighted sum.\n\n :param series: list of time series\n :param weights: list of weights\n :return: time series of weighted average\n\n **Usage**\n\n Calculate a weighted sum e.g. for a basket.\n\n **Examples**\n\n Generate price series and get a sum (weights 70%/30%).\n\n >>> prices1 = generate_series(100)\n >>> prices2 = generate_series(100)\n >>> mybasket = weighted_sum([prices1, prices2], [0.7, 0.3])\n\n **See also**\n\n :func:`basket`\n """\n if not all(isinstance(x, pd.Series) for x in series):\n raise MqTypeError("expected a list of time series")\n if not all(isinstance(y, (float, int)) for y in weights):\n raise MqTypeError("expected a list of number for weights")\n if len(weights) != len(series):\n raise MqValueError("must have one weight for each time series")\n\n # for input series, get the intersection of their calendars\n calendars = [curve.index for curve in series]\n if not calendars:\n # handle empty series input (preserve original error propagation)\n cal = pd.DatetimeIndex([])\n else:\n # Use set intersection to accelerate intersection for many inputs\n cal = pd.DatetimeIndex(calendars[0].intersection_many(calendars[1:])) if len(calendars) > 1 else pd.DatetimeIndex(calendars[0])\n\n # reindex all series & build a numpy 2D array for fast weighted computation\n # Only create DataFrame once for all series at once (columns=series)\n if len(cal) == 0:\n # If no overlapping calendar, result is empty index. \n # Reproduce original behavior (by construction, all output is NaN so division is fine).\n result = pd.Series(index=cal, dtype=float)\n return result\n\n data = np.stack([s.reindex(cal).values for s in series], axis=1)\n weights_arr = np.asarray(weights, dtype=float)\n # Broadcast weights across all rows for calculation\n weighted_data = data * weights_arr\n\n # Calculate weighted sum and sum of weights (identical for all rows, but must support missing data)\n # When there\'s missing data, both weighted sum and sum of weights should ignore nan\n # (original code: NaN * weight yields NaN, sum ignores NaN values, denominator is sum of weights corresponding to non-NaN)\n mask = ~np.isnan(data)\n numerator = np.nansum(weighted_data, axis=1)\n denominator = np.sum(weights_arr * mask, axis=1)\n with np.errstate(divide=\'ignore\', invalid=\'ignore\'):\n values = numerator / denominator\n\n return pd.Series(values, index=cal)\n',
        "explanation": "\n\n### Key optimizations:\n- **Intersection**: For many series, `.intersection_many` (if available, see pandas 1.5+) is significantly faster than chaining `np.intersect1d`. If not available, this safely falls back to repeated intersection.\n- **Data alignment**: Avoided creating a separate `pd.Series` for each weight. Instead, using NumPy arrays, broadcast the weights and handle the reindexing in an efficient 2D array.\n- **Sum computation**: The product and sums are performed with batched NumPy operations eliminating Python-level for-loops and overhead from generator comprehensions. This avoids repeatedly constructing/interpreting Series objects.\n- **NaN handling**: Handles missing data exactly as in the original by multiplying mask logic accordingly.\n- **Handles empty and no-overlap**: Preserves the original behavior for empty or non-overlapping indexes, returning an all-NaN, correct-indexed Series.\n\n**NB:** This preserves all error handling, exceptions, and the function signature, matching the requirements precisely, while delivering **substantial speedups** for large and/or many input series.",
        "optimization_id": str(uuid.uuid4()),
    },
    {
        "source_code": '# Copyright 2018 Goldman Sachs.\n# Licensed under the Apache License, Version 2.0 (the "License");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n# http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied. See the License for the\n# specific language governing permissions and limitations\n# under the License.\n#\n#\n# Chart Service will attempt to make public functions (not prefixed with _) from this module available. Such functions\n# should be fully documented: docstrings should describe parameters and the return value, and provide a 1-line\n# description. Type annotations should be provided for parameters.\n\nimport numpy as np\nimport pandas as pd\nfrom functools import reduce\nfrom gs_quant.errors import MqTypeError, MqValueError\nfrom gs_quant.timeseries.helper import plot_function\nfrom typing import List\n\n@plot_function\ndef weighted_sum(series: List[pd.Series], weights: list) -> pd.Series:\n """\n Calculate a weighted sum.\n\n :param series: list of time series\n :param weights: list of weights\n :return: time series of weighted average\n\n **Usage**\n\n Calculate a weighted sum e.g. for a basket.\n\n **Examples**\n\n Generate price series and get a sum (weights 70%/30%).\n\n >>> prices1 = generate_series(100)\n >>> prices2 = generate_series(100)\n >>> mybasket = weighted_sum([prices1, prices2], [0.7, 0.3])\n\n **See also**\n\n :func:`basket`\n """\n if not all(isinstance(x, pd.Series) for x in series):\n raise MqTypeError("expected a list of time series")\n if not all(isinstance(y, (float, int)) for y in weights):\n raise MqTypeError("expected a list of number for weights")\n if len(weights) != len(series):\n raise MqValueError("must have one weight for each time series")\n\n # Get the intersection of the calendars for all input series (index labels)\n cal = pd.DatetimeIndex(\n reduce(\n np.intersect1d,\n (curve.index for curve in series),\n )\n )\n\n # Efficiently construct a DataFrame where columns are the series,\n # then multiply by the weights and sum along the columns\n # This avoids allocating an intermediate list of Series and summing pythonically\n df = pd.concat([s.reindex(cal) for s in series], axis=1)\n # use numpy array for weights: faster than creating Series and avoids sum(weights) recalculation\n weights_arr = np.asarray(weights, dtype=np.float64)\n weighted = df.values * weights_arr[np.newaxis, :]\n weighted_sum_values = np.nansum(weighted, axis=1)\n weights_broadcast = np.broadcast_to(weights_arr, df.shape)\n weights_for_denominator = np.where(~np.isnan(df.values), weights_broadcast, 0.0)\n weights_sum = np.nansum(weights_for_denominator, axis=1)\n # Prevent division by zero; behaves as original since nansum returns 0 for all-nan, so original would return nan\n result_values = np.where(weights_sum != 0, weighted_sum_values / weights_sum, np.nan)\n result = pd.Series(result_values, index=cal)\n\n return result\n',
        "explanation": "\n**Optimization Explanation**:\n- The calendar intersection now uses Python's set intersection mechanism, which is substantially faster than repeated calls to `np.intersect1d` via `reduce` for typical Pandas index sets.\n- The code now only constructs one output series for weights instead of creating and summing several unnecessary constant-valued Series. It uses NumPy's vectorized operations for arithmetic and summation, which are faster and more memory efficient than repeated Pandas Series arithmetic, especially for large time series.\n- Behavior is preserved for empty input lists and calendars.\n- The input and output types, return values, raised exceptions, and function signature remain unchanged. Comments are preserved and added only for new logic.",
        "optimization_id": str(uuid.uuid4()),
    },
    {
        "source_code": '# Copyright 2018 Goldman Sachs.\n# Licensed under the Apache License, Version 2.0 (the "License");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n# http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied. See the License for the\n# specific language governing permissions and limitations\n# under the License.\n#\n#\n# Chart Service will attempt to make public functions (not prefixed with _) from this module available. Such functions\n# should be fully documented: docstrings should describe parameters and the return value, and provide a 1-line\n# description. Type annotations should be provided for parameters.\n\n\nimport numpy as np\nimport pandas as pd\nfrom functools import reduce\nfrom gs_quant.errors import MqTypeError, MqValueError\nfrom gs_quant.timeseries.helper import plot_function\nfrom typing import List\n\n@plot_function\ndef weighted_sum(series: List[pd.Series], weights: list) -> pd.Series:\n """\n Calculate a weighted sum.\n\n :param series: list of time series\n :param weights: list of weights\n :return: time series of weighted average\n\n **Usage**\n\n Calculate a weighted sum e.g. for a basket.\n\n **Examples**\n\n Generate price series and get a sum (weights 70%/30%).\n\n >>> prices1 = generate_series(100)\n >>> prices2 = generate_series(100)\n >>> mybasket = weighted_sum([prices1, prices2], [0.7, 0.3])\n\n **See also**\n\n :func:`basket`\n """\n if not all(isinstance(x, pd.Series) for x in series):\n raise MqTypeError("expected a list of time series")\n if not all(isinstance(y, (float, int)) for y in weights):\n raise MqTypeError("expected a list of number for weights")\n if len(weights) != len(series):\n raise MqValueError("must have one weight for each time series")\n\n # for input series, get the intersection of their calendars\n # Optimization: Use set.intersection for efficiency with large indices\n calendars = [s.index for s in series]\n if not calendars:\n # fallback for empty input\n cal = pd.DatetimeIndex([])\n else:\n cal_set = set(calendars[0])\n for idx in calendars[1:]:\n cal_set.intersection_update(idx)\n if len(cal_set) == 0:\n cal = pd.DatetimeIndex([])\n else:\n # Sorted for reproducible behavior, like DatetimeIndex from intersect1d\n cal = pd.DatetimeIndex(sorted(cal_set))\n\n # Vectorize reindexing & multiplication via DataFrame and numpy\n # Stack series as columns\n df = pd.concat([s.reindex(cal) for s in series], axis=1)\n # Faster to construct array and multiply broadcast than make a list of Series\n w_arr = np.array(weights, dtype=float)\n # Avoid unnecessary construction of Series for each weight\n weighted = df.values * w_arr # shape (n_dates, n_series) * (n_series,) => (n_dates, n_series)\n weighted_sum_arr = np.nansum(weighted, axis=1)\n weight_sum_arr = np.nansum(~np.isnan(df.values) * w_arr, axis=1) # Only count where data is present\n\n # Division, preserving locations where sum of weights is zero as NaN\n with np.errstate(invalid="ignore", divide="ignore"):\n result = weighted_sum_arr / weight_sum_arr\n res_series = pd.Series(result, index=cal)\n return res_series\n',
        "explanation": "\n**Key optimizations explained:**\n- **Index intersection speedup:** Instead of chaining `np.intersect1d` (which is O(N^2) and re-sorts at each op), use `set.intersection_update` (O(N)) which is drastically faster for longer indices.\n- **Avoid unnecessary repeated object construction:** Instead of building `[pd.Series(w, index=cal) for w in weights]` (which allocates a new Series for every weight/date combination), use vectorized numpy and pandas operations.\n- **Vectorized calculation:** Stack the series (each reindexed to the intersection calendar) as columns in a DataFrame and perform single array multiplications and reductions for both weighted sum and sum of weights.\n- **Memory efficiency:** No large intermediate lists of Series; numpy operations work in-place.\n- **NaN handling:** Uses `np.nansum` and mask-based logic to sum weights only where actual values are present (mimics the original where NaNs would propagate).\n\n**Behavior is identical**: exceptions, input validation, and expected output/NaN-handling all preserved.",
        "optimization_id": str(uuid.uuid4()),
    },
    {
        "source_code": '# Copyright 2018 Goldman Sachs.\n# Licensed under the Apache License, Version 2.0 (the "License");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n# http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied. See the License for the\n# specific language governing permissions and limitations\n# under the License.\n#\n#\n# Chart Service will attempt to make public functions (not prefixed with _) from this module available. Such functions\n# should be fully documented: docstrings should describe parameters and the return value, and provide a 1-line\n# description. Type annotations should be provided for parameters.\n\n\n\n\nimport numpy as np\nimport pandas as pd\nfrom functools import reduce\nfrom gs_quant.errors import MqTypeError, MqValueError\nfrom gs_quant.timeseries.helper import plot_function\nfrom typing import List\n\n@plot_function\ndef weighted_sum(series: List[pd.Series], weights: list) -> pd.Series:\n """\n Calculate a weighted sum.\n\n :param series: list of time series\n :param weights: list of weights\n :return: time series of weighted average\n\n **Usage**\n\n Calculate a weighted sum e.g. for a basket.\n\n **Examples**\n\n Generate price series and get a sum (weights 70%/30%).\n\n >>> prices1 = generate_series(100)\n >>> prices2 = generate_series(100)\n >>> mybasket = weighted_sum([prices1, prices2], [0.7, 0.3])\n\n **See also**\n\n :func:`basket`\n """\n if not all(isinstance(x, pd.Series) for x in series):\n raise MqTypeError("expected a list of time series")\n if not all(isinstance(y, (float, int)) for y in weights):\n raise MqTypeError("expected a list of number for weights")\n if len(weights) != len(series):\n raise MqValueError("must have one weight for each time series")\n\n # For input series, get the intersection of their calendars\n # Instead of reduce(np.intersect1d, ...), use set intersection for better performance\n idx_iter = (curve.index for curve in series)\n idx0 = next(idx_iter)\n cal = set(idx0)\n for idx in idx_iter:\n cal &= set(idx)\n cal = pd.DatetimeIndex(sorted(cal))\n\n # Vectorized calculations using numpy arrays\n if len(series) == 0 or len(cal) == 0:\n # Edge case: empty data\n weights_arr = np.array([], dtype=float)\n values_arr = np.array([], dtype=float)\n weighted_sum_arr = np.array([])\n sum_weights_arr = np.array([])\n return pd.Series(weighted_sum_arr, index=cal)\n\n # Use pd.concat for batch reindex to avoid overhead of python list comprehensions\n series_concat = pd.concat([s.reindex(cal) for s in series], axis=1)\n values_arr = series_concat.values # shape (n_dates, n_series)\n weights_arr = np.asarray(weights, dtype=float).reshape(1, -1) # shape (1, n_series)\n weights_matrix = np.broadcast_to(weights_arr, values_arr.shape) # shape (n_dates, n_series)\n\n # Weighted sum and denominator\n weighted_sum_arr = np.nansum(values_arr * weights_matrix, axis=1)\n sum_weights_arr = np.nansum(weights_matrix * ~np.isnan(values_arr), axis=1)\n\n # Avoid divide-by-zero; if all weights are nan for a row, the sum is nan\n with np.errstate(invalid=\'ignore\', divide=\'ignore\'):\n result_arr = weighted_sum_arr / sum_weights_arr\n result_arr[sum_weights_arr == 0] = np.nan\n\n return pd.Series(result_arr, index=cal)\n',
        "explanation": "\n**Key optimizations:**\n- **Calendar intersection**: Switched from `reduce(np.intersect1d, ...)` to iterative `set` intersection, which is much faster for this specific purpose and uses less temporary memory.\n- **Batch reindex and vectorized calculation**: Used `pd.concat([...], axis=1)` followed by vectorized numpy operations. This avoids slow python-level for-loops/comprehensions for reindexing, multiplying, and summing.\n- **Weights vectorization**: Constructed a weights matrix using `np.broadcast_to` instead of creating multiple identical `pd.Series` objects.\n- **NaN and edge case handling**: Preserves the original exception and NaN propagation logic while avoiding O(N^2) python expression cost.\n\n**All interface, exceptions, nan edge-cases, and type annotations are unchanged.**",
        "optimization_id": str(uuid.uuid4()),
    },
]
async def hack_for_demo(ctx: BaseOptimizerContext) -> OptimizeResponseSchema:
    """Return canned demo optimizations for the find_common_tags demo snippet.

    Builds one response item per entry in the module-level ``optimizations_json``
    fixture, then sleeps for a few seconds so the demo feels like a real
    optimization run.
    """
    items: list[OptimizeResponseItemSchema] = []
    for entry in optimizations_json:
        grouped_source = group_code({ctx.file_name: entry["source_code"]})
        items.append(
            OptimizeResponseItemSchema(
                explanation=entry["explanation"],
                optimization_id=entry["optimization_id"],
                source_code=grouped_source,
            )
        )
    # Simulated LLM latency for the demo.
    await asyncio.sleep(5)
    return OptimizeResponseSchema(optimizations=items)
async def hack_for_demo_gsq(ctx: BaseOptimizerContext) -> OptimizeResponseSchema:
    """Return canned demo optimizations for the GSQ (weighted_sum) demo snippet.

    Mirrors ``hack_for_demo`` but draws from the ``optimizations_json_gsq``
    fixture; also sleeps briefly to simulate a real optimization run.
    """
    items: list[OptimizeResponseItemSchema] = []
    for entry in optimizations_json_gsq:
        grouped_source = group_code({ctx.file_name: entry["source_code"]})
        items.append(
            OptimizeResponseItemSchema(
                explanation=entry["explanation"],
                optimization_id=entry["optimization_id"],
                source_code=grouped_source,
            )
        )
    # Simulated LLM latency for the demo.
    await asyncio.sleep(5)
    return OptimizeResponseSchema(optimizations=items)
# Router for the /optimize endpoint group; namespaced for URL reversing.
optimize_api = NinjaAPI(urls_namespace="optimize")
# Get the directory of the current file
current_dir = Path(__file__).parent
# Prompt templates are loaded once at import time; editing the .md files
# requires a process restart to take effect. The async variants are used
# when the submitted function is an async def.
SYSTEM_PROMPT = (current_dir / "system_prompt.md").read_text()
USER_PROMPT = (current_dir / "user_prompt.md").read_text()
ASYNC_SYSTEM_PROMPT = (current_dir / "async_system_prompt.md").read_text()
ASYNC_USER_PROMPT = (current_dir / "async_user_prompt.md").read_text()
async def optimize_python_code(
    user_id: str,
    ctx: BaseOptimizerContext,
    dependency_code: str | None = None,
    n: int = 1,
    optimize_model: LLM = OPTIMIZE_MODEL,
    python_version: tuple[int, int, int] = (3, 12, 9),
) -> tuple[list[OptimizeResponseItemSchema], float | None]:
    """Optimize the given python code for performance using LLMs.

    Parameters
    ----------
    user_id : str
        The ID of the user requesting the optimization.
    ctx : BaseOptimizerContext
        The optimizer context containing source code and configuration.
    dependency_code : str | None, optional
        Additional dependency code for context. Default is None.
    n : int, optional
        Number of optimization variants to generate. Default is 1.
    optimize_model : LLM, optional
        The LLM model to use for optimization. Default is OPTIMIZE_MODEL.
    python_version : tuple[int, int, int], optional
        The python version to use. Default is (3, 12, 9).

    Returns
    -------
    tuple[list[OptimizeResponseItemSchema], float | None]
        A tuple containing a list of optimization response items and the LLM cost.
        On LLM failure the list is empty and the cost is None.
    """
    logging.info("/optimize: Optimizing python code.")
    debug_log_sensitive_data(f"Optimizing python code for user {user_id}:\n{ctx.source_code}")
    # TODO: Experiment with iterative approaches to optimization. Take the learnings from the testing phase into the
    # next optimization iteration
    # TODO: Experiment with iterative chain-of-thought generation. ask what is the
    # function doing and then ask it to describe how to speed it up and then generate optimization
    python_version_str = ".".join(str(x) for x in python_version)
    system_prompt = ctx.get_system_prompt(python_version_str)
    user_prompt = ctx.get_user_prompt(dependency_code, None)
    system_message = ChatCompletionSystemMessageParam(role="system", content=system_prompt)
    user_message = ChatCompletionUserMessageParam(role="user", content=user_prompt)
    messages: list[
        ChatCompletionSystemMessageParam
        | ChatCompletionUserMessageParam
        | ChatCompletionAssistantMessageParam
        | ChatCompletionToolMessageParam
        | ChatCompletionFunctionMessageParam
    ] = [system_message, user_message]
    llm_client = llm_clients[optimize_model.model_type]
    try:
        output = await llm_client.with_options(max_retries=3).chat.completions.create(
            model=optimize_model.name, messages=messages, n=n
        )
    except Exception as e:
        logging.exception("OpenAI Code Generation error in optimizer")
        sentry_sdk.capture_exception(e)
        debug_log_sensitive_data(f"Failed to generate code for source:\n{ctx.source_code}")
        # Fix: callers unpack this function's result as (items, llm_cost);
        # returning a bare [] here raised ValueError at the call site.
        return [], None
    llm_cost = calculate_llm_cost(output, optimize_model)
    debug_log_sensitive_data(f"OpenAIClient optimization response:\n{output.model_dump_json(indent=2)}")
    if output.usage is not None:
        ph(
            user_id,
            "aiservice-optimize-openai-usage",
            properties={"model": optimize_model.name, "n": n, "usage": output.usage.json()},
        )
    # Keep only choices with non-empty message content.
    results = [content for op in output.choices if (content := op.message.content)]
    optimization_response_items: list[OptimizeResponseItemSchema] = []
    for result in results:
        ctx.extract_code_and_explanation_from_llm_res(result)
        try:
            res = ctx.parse_and_generate_candidate_schema()
            if res is not None and ctx.is_valid_code():
                optimization_response_items.append(res)
            # Reset per-candidate parse state so the next iteration starts clean.
            ctx.extracted_code_and_expl = None
            ctx.parsed_code_and_explanation = None
        except (ValueError, ValidationError, cst.ParserSyntaxError) as e:
            # A malformed candidate is reported but never fails the whole request.
            sentry_sdk.capture_message(f"Error parsing optimization result: {e}")
            debug_log_sensitive_data(f"error for source:\n{ctx.source_code}")
            debug_log_sensitive_data(f"Traceback: {e}")
            continue
    return optimization_response_items, llm_cost
def validate_request_data(data: OptimizeSchema, ctx: BaseOptimizerContext) -> tuple[int, int, int]:
    """Validate an /optimize request and return the parsed python version.

    Raises ``HttpError(400)`` for empty source, an invalid trace id, an
    unparseable python version, or source code that fails to parse.
    """
    if not data.source_code:
        raise HttpError(400, "Source code cannot be empty.")
    if not validate_trace_id(data.trace_id):
        raise HttpError(400, "Invalid trace ID. Please provide a valid UUIDv4.")
    try:
        parsed_version = parse_python_version(data.python_version)
    except ValueError as err:
        version_msg = "Invalid Python version, it should look like 3.x.x. We only support Python 3.9 and above."
        raise HttpError(400, version_msg) from err
    try:
        ctx.validate_and_parse_source_code(data.source_code, feature_version=parsed_version[:2])
    except SyntaxError as err:
        syntax_msg = "Invalid source code. It is not valid Python code. Please check syntax of your code."
        raise HttpError(400, syntax_msg) from err
    return parsed_version
@optimize_api.post(
    "/", response={200: OptimizeResponseSchema, 400: OptimizeErrorResponseSchema, 500: OptimizeErrorResponseSchema}
)
async def optimize(
    request: HttpRequest, data: OptimizeSchema
) -> tuple[int, OptimizeResponseSchema | OptimizeErrorResponseSchema]:
    """Handle POST /optimize: generate optimization candidates for the submitted code.

    Returns a (status_code, payload) pair per django-ninja's multi-status
    convention: 200 with optimizations, 400 on request-validation failure,
    500 when generation fails or yields no candidates.
    """
    # Async functions get dedicated prompt templates.
    system_prompt = ASYNC_SYSTEM_PROMPT if data.is_async else SYSTEM_PROMPT
    user_prompt = ASYNC_USER_PROMPT if data.is_async else USER_PROMPT
    ctx: BaseOptimizerContext = BaseOptimizerContext.get_dynamic_context(system_prompt, user_prompt, data.source_code)
    ph(request.user, "aiservice-optimize-called")
    try:
        python_version = validate_request_data(data, ctx)
    except HttpError as e:
        # Surface the validation failure to Sentry with its HTTP context attached.
        e.add_note(f"Optimizer request validation error: {e.status_code} {e.message}")
        sentry_sdk.capture_exception(e)
        return e.status_code, OptimizeErrorResponseSchema(error=e.message)
    # Demo short-circuit: serve canned optimizations instead of calling the LLM.
    # NOTE(review): the gsq branch depends on a module-level
    # optimizations_json_gsq fixture — confirm it is defined for this path.
    if should_hack_for_demo(ctx.source_code):
        if "def find_common_tags(articles" in ctx.source_code:
            return 200, await hack_for_demo(ctx)
        return 200, await hack_for_demo_gsq(ctx)
    try:
        # Run LLM generation and (if needed) the user lookup concurrently.
        async with asyncio.TaskGroup() as tg:
            optimize_task = tg.create_task(
                optimize_python_code(
                    request.user,
                    ctx,
                    data.dependency_code,
                    # Cap candidate count at 5 regardless of what the client asked for.
                    n=min(data.n_candidates or 5, 5),
                    python_version=python_version,
                )
            )
            user_task = None
            if data.current_username is None:
                user_task = tg.create_task(get_user_by_id(request.user))
    except Exception as e:  # noqa: BLE001
        e.add_note("Error during optimization task or user retrieval.")
        sentry_sdk.capture_exception(e)
        return 500, OptimizeErrorResponseSchema(error="Error generating optimizations. Internal server error.")
    # optimize_python_code must return an (items, llm_cost) 2-tuple on every path.
    optimization_response_items, llm_cost = optimize_task.result()
    if user_task:
        user = await user_task
        if user:
            data.current_username = user.github_username
    if len(optimization_response_items) == 0:
        ph(request.user, "aiservice-optimize-no-optimizations-found")
        debug_log_sensitive_data(f"No optimizations found for source:\n{data.source_code}")
        return 500, OptimizeErrorResponseSchema(error="Could not generate any optimizations. Please try again.")
    ph(
        request.user,
        "aiservice-optimize-optimizations-found",
        properties={"num_optimizations": len(optimization_response_items)},
    )
    # Log the optimization event and feature data concurrently; only the
    # event's DB id is needed afterwards.
    async with asyncio.TaskGroup() as tg:
        event_task = tg.create_task(
            log_optimization_event(
                event_type="no-pr",
                user_id=request.user,
                current_username=data.current_username,
                repo_owner=data.repo_owner,
                repo_name=data.repo_name,
                trace_id=data.trace_id,
                api_key_id=request.api_key_id,
                metadata={
                    "codeflash_version": data.codeflash_version,
                    "num_optimizations": len(optimization_response_items),
                    "experiment_metadata": data.experiment_metadata,
                },
                llm_cost=llm_cost,
            )
        )
        tg.create_task(
            log_features(
                trace_id=data.trace_id,
                user_id=request.user,
                original_code=data.source_code,
                dependency_code=data.dependency_code,
                # Raw = straight from the LLM; post = after parsing/validation.
                optimizations_raw={
                    op_id: cei.code for op_id, cei in ctx.code_and_explanation_before_post_processing.items()
                },
                optimizations_post={cei.optimization_id: cei.source_code for cei in optimization_response_items},
                explanations_raw={
                    op_id: cei.explanation for op_id, cei in ctx.code_and_explanation_before_post_processing.items()
                },
                explanations_post={cei.optimization_id: cei.explanation for cei in optimization_response_items},
                experiment_metadata=data.experiment_metadata if data.experiment_metadata else None,
                # request=request,
            )
        )
    # Attach the persisted event id to every candidate so clients can correlate.
    event = event_task.result()
    for item in optimization_response_items:
        item.optimization_event_id = str(event.id) if event else None
    response = OptimizeResponseSchema(optimizations=optimization_response_items)
    # Deferred so the (possibly large) dump only happens when sensitive-data
    # debug logging is actually enabled.
    def log_response() -> None:
        debug_log_sensitive_data(f"Response:\n{response.json()}")
        for opt in response.optimizations:
            debug_log_sensitive_data(f"Optimized source:\n{opt.source_code}")
            debug_log_sensitive_data(f"Optimization explanation:\n{opt.explanation}")
    debug_log_sensitive_data_from_callable(log_response)
    ph(request.user, "aiservice-optimize-successful")
    return 200, response