fix codeflash optimizing python backend (#2483)

Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Kevin Turcios 2026-03-22 03:50:30 -05:00 committed by GitHub
parent 28c9acc877
commit 387c909c9e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 1349 additions and 1247 deletions

View file

@ -20,7 +20,6 @@ concurrency:
cancel-in-progress: true
jobs:
# This job checks if the workflow should run based on file changes
check-changes:
runs-on: ubuntu-latest
outputs:
@ -36,7 +35,6 @@ jobs:
aiservice:
- 'django/aiservice/**'
# This job always runs and succeeds, allowing PRs to be merged when paths don't match
no-aiservice-changes:
name: No aiservice changes detected
needs: check-changes
@ -59,7 +57,6 @@ jobs:
CODEFLASH_PR_NUMBER: ${{ github.event.number }}
DATABASE_URL: ${{ secrets.DATABASE_URL }}
DJANGO_SETTINGS_MODULE: aiservice.settings
COLUMNS: 110
steps:
@ -68,27 +65,16 @@ jobs:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Python
- name: Set up uv
uses: astral-sh/setup-uv@v7
with:
python-version: "3.12"
enable-cache: true
- name: Install Project Dependencies
- name: Install dependencies
run: |
uv sync --refresh --isolated --active
- name: Install Codeflash in separate venv
working-directory: .
run: |
uv venv .codeflash-venv
source .codeflash-venv/bin/activate
uv pip install pytest-asyncio black
uv sync
uv pip install git+https://github.com/codeflash-ai/codeflash@main
- name: Run Codeflash to optimize code
id: optimize_code
working-directory: .
run: |
source .codeflash-venv/bin/activate
cd django/aiservice
codeflash --async --verbose
- name: Run Codeflash
run: uv run codeflash

View file

@ -1,12 +1,16 @@
from __future__ import annotations
import os
import re
import uuid
from typing import Any
import isort
ENABLE_DEMO_HACKS = os.getenv("ENABLE_DEMO_HACKS", "").lower() in ("1", "true", "yes")
def safe_isort(code: str, **kwargs) -> str: # noqa: ANN003
def safe_isort(code: str, **kwargs: Any) -> str:
"""Wrap isort.code to returns the original code if isort fails.
Args:
@ -73,10 +77,14 @@ def is_codeflash_employee(user_id: str) -> bool:
def should_hack_for_demo(source_code: str) -> bool:
    """Return True when demo hacks are enabled and the code contains a known Python demo function.

    Args:
        source_code: The user-submitted source code to inspect.

    Returns:
        True only if ENABLE_DEMO_HACKS is set and one of the demo function
        signatures appears in the source.
    """
    if not ENABLE_DEMO_HACKS:
        return False
    # `in` already yields a bool — the previous bool(...) wrappers were redundant.
    return "def find_common_tags(articles" in source_code or "def weighted_sum(series" in source_code
def should_hack_for_demo_java(source_code: str) -> bool:
if not ENABLE_DEMO_HACKS:
return False
if "byte[] readFile(File" in source_code and "FileInputStream" in source_code:
return True
if "class Host" in source_code and "this.name.equals(other.name)" in source_code:
@ -85,4 +93,6 @@ def should_hack_for_demo_java(source_code: str) -> bool:
def is_host_equals_demo(source_code: str) -> bool:
    """Check whether the source matches the Host.equals Java demo snippet.

    Only meaningful when demo hacks are enabled; otherwise always False.
    """
    if not ENABLE_DEMO_HACKS:
        return False
    required_markers = ("class Host", "this.name.equals(other.name)")
    return all(marker in source_code for marker in required_markers)

File diff suppressed because it is too large Load diff

View file

@ -15,7 +15,7 @@ from ninja.errors import HttpError
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
from aiservice.analytics.posthog import ph
from aiservice.common_utils import is_host_equals_demo, should_hack_for_demo_java, validate_trace_id
from aiservice.common_utils import validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.llm import llm_client
from aiservice.llm_models import LLM, OPTIMIZE_MODEL
@ -50,234 +50,6 @@ def is_multi_context_java(source_code: str) -> bool:
return source_code.count("```java:") >= 1
def _extract_demo_context(source_code: str) -> tuple[str, str, str, str]:
"""Extract package, class name, exception type, and extra imports from the demo source code.
Returns:
Tuple of (package_declaration, class_name, throw_statement_prefix, extra_imports)
"""
# Extract the raw code from markdown block if present
code_match = re.search(r"```java:[^\n]*\n(.*?)```", source_code, re.DOTALL)
raw_code = code_match.group(1) if code_match else source_code
# Extract package
pkg_match = re.search(r"^\s*package\s+([\w.]+)\s*;", raw_code, re.MULTILINE)
package_decl = f"package {pkg_match.group(1)};\n" if pkg_match else ""
# Extract class name
class_match = re.search(r"\bclass\s+(\w+)", raw_code)
class_name = class_match.group(1) if class_match else "FileUtils"
# Extract exception type from the throw statement (e.g., "throw new AerospikeException(...)")
throw_match = re.search(r"throw\s+new\s+(\w+)\s*\(", raw_code)
exception_type = throw_match.group(1) if throw_match else "RuntimeException"
# Collect extra imports needed for the exception type (skip standard java/javax)
extra_imports = ""
if exception_type != "RuntimeException":
import_match = re.search(rf"^\s*import\s+([\w.]*\.{re.escape(exception_type)})\s*;", raw_code, re.MULTILINE)
if import_match:
extra_imports = f"import {import_match.group(1)};\n"
return package_decl, class_name, exception_type, extra_imports
def _build_demo_optimizations(
package_decl: str, class_name: str, exception_type: str, extra_imports: str
) -> list[dict[str, str]]:
"""Build 2 demo optimization candidates using the extracted class context.
Candidate 1 (Files.readAllBytes) is the intended winner it benchmarks fastest.
Candidate 2 is a plausible alternative that is functionally correct but
benchmarks slightly slower, ensuring Files.readAllBytes wins the speedup critic.
"""
fmt = dict(
package_decl=package_decl, class_name=class_name, exception_type=exception_type, extra_imports=extra_imports
)
return [
# Candidate 2: FileInputStream.readAllBytes() (Java 9+)
{
"source_code": (
"{package_decl}"
"\n"
"import java.io.File;\n"
"import java.io.FileInputStream;\n"
"{extra_imports}"
"\n"
"public final class {class_name} {{\n"
" public static byte[] readFile(File file) {{\n"
" try (FileInputStream fis = new FileInputStream(file)) {{\n"
" return fis.readAllBytes();\n"
" }}\n"
" catch (Throwable e) {{\n"
' throw new {exception_type}("Failed to read " + file.getAbsolutePath(), e);\n'
" }}\n"
" }}\n"
"}}"
).format(**fmt),
"explanation": (
"Use FileInputStream.readAllBytes() (Java 9+) to read the entire file in one call. "
"This eliminates the manual read loop but still uses FileInputStream internally."
),
"optimization_id": str(uuid.uuid4()),
},
# Candidate 1: Files.readAllBytes (THE WINNER)
{
"source_code": (
"{package_decl}"
"\n"
"import java.io.File;\n"
"import java.nio.file.Files;\n"
"{extra_imports}"
"\n"
"public final class {class_name} {{\n"
" public static byte[] readFile(File file) {{\n"
" try {{\n"
" return java.nio.file.Files.readAllBytes(file.toPath());\n"
" }}\n"
" catch (Throwable e) {{\n"
' throw new {exception_type}("Failed to read " + file.getAbsolutePath(), e);\n'
" }}\n"
" }}\n"
"}}"
).format(**fmt),
"explanation": (
"Replace manual FileInputStream read loop with java.nio.file.Files.readAllBytes(). "
"This NIO method is optimized at the JDK level for direct file-to-byte-array transfer, "
"eliminating manual buffering and loop overhead."
),
"optimization_id": str(uuid.uuid4()),
},
]
def _build_host_equals_demo_optimizations(source_code: str) -> list[dict[str, str]]:
"""Build 5 optimization candidates for Host.equals by reordering comparisons.
Candidate 1 (port-first early return) is the intended winner comparing the
primitive int port before the String name avoids unnecessary method dispatch.
"""
code_match = re.search(r"```java:[^\n]*\n(.*?)```", source_code, re.DOTALL)
raw_code = code_match.group(1) if code_match else source_code
# Match: return this.name.equals(other.name) && this.port == other.port;
original_stmt = re.compile(
r"(\s*)return\s+this\.name\.equals\(other\.name\)\s*&&\s*this\.port\s*==\s*other\.port\s*;"
)
match = original_stmt.search(raw_code)
if not match:
return [
{
"source_code": raw_code,
"explanation": "No optimization applicable.",
"optimization_id": str(uuid.uuid4()),
}
]
indent = match.group(1)
inner = indent + " "
def replace_with(replacement: str) -> str:
return original_stmt.sub(replacement, raw_code)
return [
# Candidate 1 (WINNER): Port-first early return
{
"source_code": replace_with(
f"{indent}// Compare primitive port first to avoid unnecessary string equals calls.\n"
f"{indent}if (this.port != other.port) {{\n"
f"{inner}return false;\n"
f"{indent}}}\n"
f"{indent}return this.name.equals(other.name);"
),
"explanation": (
"Compare primitive port first to avoid unnecessary string equals calls. "
"Integer comparison is a single CPU instruction, while String.equals() "
"involves method dispatch and potential character-by-character comparison."
),
"optimization_id": str(uuid.uuid4()),
},
# Candidate 2: Reordered conjunction (port first in &&)
{
"source_code": replace_with(f"{indent}return this.port == other.port && this.name.equals(other.name);"),
"explanation": (
"Reorder the conjunction to evaluate the cheaper primitive int comparison first. "
"Short-circuit evaluation skips String.equals() when ports differ."
),
"optimization_id": str(uuid.uuid4()),
},
# Candidate 3: Port-first with Objects.equals for null safety
{
"source_code": replace_with(
f"{indent}if (this.port != other.port) {{\n"
f"{inner}return false;\n"
f"{indent}}}\n"
f"{indent}return java.util.Objects.equals(this.name, other.name);"
),
"explanation": (
"Check port first (cheap primitive comparison), then use Objects.equals() "
"for null-safe name comparison. Adds safety at slight method-call overhead."
),
"optimization_id": str(uuid.uuid4()),
},
# Candidate 4: Ternary with port-first guard
{
"source_code": replace_with(
f"{indent}return this.port == other.port ? this.name.equals(other.name) : false;"
),
"explanation": (
"Use a ternary to short-circuit on port mismatch. "
"Evaluates the cheap int comparison first, only calling String.equals() when ports match."
),
"optimization_id": str(uuid.uuid4()),
},
# Candidate 5: Explicit null guard + port first
{
"source_code": replace_with(
f"{indent}if (this.port != other.port) {{\n"
f"{inner}return false;\n"
f"{indent}}}\n"
f"{indent}if (this.name == null) {{\n"
f"{inner}return other.name == null;\n"
f"{indent}}}\n"
f"{indent}return this.name.equals(other.name);"
),
"explanation": (
"Guard on port first, then add explicit null handling for the name field "
"before delegating to String.equals(). Avoids potential NullPointerException."
),
"optimization_id": str(uuid.uuid4()),
},
]
async def hack_for_demo_java(source_code: str) -> OptimizeResponseSchema:
    """Return pre-canned optimization candidates for the Java demo snippets.

    Picks the Host.equals candidate set when the source matches that demo,
    otherwise builds the readFile candidates from the extracted class context,
    then wraps each candidate in the response schema.
    """
    # File path comes from the markdown fence header (```java:path/to/File.java).
    path_match = re.search(r"```java:([^\n]+)", source_code)
    file_name = path_match.group(1).strip() if path_match else "Source.java"

    if is_host_equals_demo(source_code):
        optimizations = _build_host_equals_demo_optimizations(source_code)
    else:
        # Derive package/class/exception context dynamically from the source code.
        context = _extract_demo_context(source_code)
        optimizations = _build_demo_optimizations(*context)

    response_list = [
        OptimizeResponseItemSchema(
            explanation=opt["explanation"],
            optimization_id=opt["optimization_id"],
            source_code=group_code({file_name: opt["source_code"]}, language="java"),
        )
        for opt in optimizations
    ]
    await asyncio.sleep(5)  # simulate optimizer latency for the demo
    return OptimizeResponseSchema(optimizations=response_list)
async def optimize_java_code_single(
user_id: str,
source_code: str,
@ -481,12 +253,13 @@ async def optimize_java(
# Determine Java version
language_version = data.language_version or "17"
# Check for demo mode
if should_hack_for_demo_java(data.source_code):
response = await hack_for_demo_java(data.source_code)
for item in response.optimizations:
from core.languages.java.demo_hacks import try_demo_optimize_java # noqa: PLC0415
demo_response = await try_demo_optimize_java(data.source_code)
if demo_response is not None:
for item in demo_response.optimizations:
item.optimization_event_id = str(optimization_event.id) if optimization_event else None
return 200, response
return 200, demo_response
# Run optimization
optimization_results, total_cost, code_and_explanations, _optimization_models = await optimize_java_code(

View file

@ -16,17 +16,11 @@ import sentry_sdk
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
from aiservice.analytics.posthog import ph
from aiservice.common_utils import is_host_equals_demo
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.llm import llm_client
from aiservice.llm_models import OPTIMIZE_MODEL
from aiservice.validators.java_validator import validate_java_syntax
from core.languages.java.optimizer import (
_build_demo_optimizations,
_build_host_equals_demo_optimizations,
_extract_demo_context,
is_multi_context_java,
)
from core.languages.java.optimizer import is_multi_context_java
from core.shared.context_helpers import (
extract_code_and_explanation,
group_code,
@ -34,7 +28,7 @@ from core.shared.context_helpers import (
split_markdown_code,
)
from core.shared.optimizer_config import MAX_OPTIMIZER_LP_CALLS, get_model_distribution
from core.shared.optimizer_schemas import OptimizeResponseItemSchema, OptimizeResponseSchema
from core.shared.optimizer_schemas import OptimizeResponseItemSchema
if TYPE_CHECKING:
from openai.types.chat import ChatCompletionMessageParam
@ -73,29 +67,6 @@ def extract_java_code_and_explanation(content: str, is_multi_file: bool = False)
return extract_code_and_explanation(content, JAVA_CODE_PATTERN, JAVA_CODE_WITH_PATH_PATTERN, is_multi_file)
async def hack_for_demo_java_lp(source_code: str) -> OptimizeResponseSchema:
    """Return pre-canned line-profiler optimization results for the Java demo function."""
    # File path comes from the markdown fence header (```java:path/to/File.java).
    path_match = re.search(r"```java:([^\n]+)", source_code)
    file_name = path_match.group(1).strip() if path_match else "Source.java"

    if is_host_equals_demo(source_code):
        optimizations = _build_host_equals_demo_optimizations(source_code)
    else:
        # Derive package/class/exception context dynamically from the source code.
        context = _extract_demo_context(source_code)
        optimizations = _build_demo_optimizations(*context)

    items: list[OptimizeResponseItemSchema] = []
    for opt in optimizations:
        items.append(
            OptimizeResponseItemSchema(
                explanation=opt["explanation"],
                optimization_id=opt["optimization_id"],
                source_code=group_code({file_name: opt["source_code"]}, language="java"),
            )
        )

    await asyncio.sleep(5)  # simulate line-profiler latency for the demo
    return OptimizeResponseSchema(optimizations=items)
async def optimize_java_code_line_profiler_single(
user_id: str,
trace_id: str,

View file

@ -6,7 +6,6 @@ Instrumentation is handled by the codeflash CLI client, not here.
from __future__ import annotations
import asyncio
import logging
import os
import re
@ -17,7 +16,7 @@ import sentry_sdk
from openai.types.chat import ChatCompletionMessageParam
from aiservice.analytics.posthog import ph
from aiservice.common_utils import is_host_equals_demo, should_hack_for_demo_java, validate_trace_id
from aiservice.common_utils import validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.llm import llm_client
from aiservice.llm_models import EXECUTE_MODEL
@ -520,835 +519,6 @@ def _extract_class_from_source(source_code: str) -> str | None:
return None
def _build_demo_test_source_0(package_name: str, class_name: str, test_framework: str) -> str:
    """Build demo test source 0, adapting to the target's package, class, and test framework.

    File creation is in @Before/@BeforeEach so it runs once, outside the instrumentation's
    inner loop. Test methods only contain readFile calls so every inner iteration succeeds
    and the benchmark measures pure readFile performance.

    Args:
        package_name: Java package of the class under test ("" for the default package).
        class_name: Name of the class whose static readFile(File) is exercised.
        test_framework: "junit4" selects JUnit 4; any other value selects JUnit 5.

    Returns:
        Complete Java test-class source as a single string.
    """
    # Fully-qualified name used in the generated import statement.
    module_path = f"{package_name}.{class_name}" if package_name else class_name
    test_class_name = f"{class_name}Test"
    if test_framework == "junit4":
        # JUnit 4: TemporaryFolder rule + @Before setup creating small/medium/large fixtures.
        return (f"package {package_name};\n" if package_name else "") + (
            "\n"
            "import org.junit.Before;\n"
            "import org.junit.Test;\n"
            "import org.junit.Rule;\n"
            "import org.junit.rules.TemporaryFolder;\n"
            "import static org.junit.Assert.*;\n"
            "\n"
            "import java.io.File;\n"
            "import java.io.FileOutputStream;\n"
            "\n"
            f"import {module_path};\n"
            "\n"
            f"public class {test_class_name} {{\n"
            "\n"
            " @Rule\n"
            " public TemporaryFolder tempFolder = new TemporaryFolder();\n"
            "\n"
            " private File smallFile;\n"
            " private byte[] expectedSmall;\n"
            " private File mediumFile;\n"
            " private byte[] expectedMedium;\n"
            " private File largeFile;\n"
            " private byte[] expectedLarge;\n"
            "\n"
            " @Before\n"
            " public void setUp() throws Exception {\n"
            ' smallFile = tempFolder.newFile("small.txt");\n'
            ' expectedSmall = "Hello, World!".getBytes();\n'
            " try (FileOutputStream out = new FileOutputStream(smallFile)) {\n"
            " out.write(expectedSmall);\n"
            " }\n"
            "\n"
            ' mediumFile = tempFolder.newFile("medium.dat");\n'
            " expectedMedium = new byte[256 * 1024];\n"
            " for (int i = 0; i < expectedMedium.length; i++) {\n"
            " expectedMedium[i] = (byte) (i % 251);\n"
            " }\n"
            " try (FileOutputStream out = new FileOutputStream(mediumFile)) {\n"
            " out.write(expectedMedium);\n"
            " }\n"
            "\n"
            ' largeFile = tempFolder.newFile("large.dat");\n'
            " expectedLarge = new byte[1024 * 1024];\n"
            " for (int i = 0; i < expectedLarge.length; i++) {\n"
            " expectedLarge[i] = (byte) (i % 256);\n"
            " }\n"
            " try (FileOutputStream out = new FileOutputStream(largeFile)) {\n"
            " out.write(expectedLarge);\n"
            " }\n"
            " }\n"
            "\n"
            " @Test\n"
            " public void testReadSmallFile() throws Exception {\n"
            f" byte[] result = {class_name}.readFile(smallFile);\n"
            " assertArrayEquals(expectedSmall, result);\n"
            " }\n"
            "\n"
            " @Test\n"
            " public void testReadMediumFile() throws Exception {\n"
            f" byte[] result = {class_name}.readFile(mediumFile);\n"
            " assertArrayEquals(expectedMedium, result);\n"
            " }\n"
            "\n"
            " @Test\n"
            " public void testReadLargeFile() throws Exception {\n"
            f" byte[] result = {class_name}.readFile(largeFile);\n"
            " assertArrayEquals(expectedLarge, result);\n"
            " }\n"
            "}\n"
        )
    # JUnit 5: @TempDir field + @BeforeEach setup, same fixture sizes as the JUnit 4 branch.
    return (f"package {package_name};\n" if package_name else "") + (
        "\n"
        "import org.junit.jupiter.api.BeforeEach;\n"
        "import org.junit.jupiter.api.Test;\n"
        "import org.junit.jupiter.api.DisplayName;\n"
        "import org.junit.jupiter.api.io.TempDir;\n"
        "import static org.junit.jupiter.api.Assertions.*;\n"
        "\n"
        "import java.io.File;\n"
        "import java.io.FileOutputStream;\n"
        "import java.nio.file.Path;\n"
        "\n"
        f"import {module_path};\n"
        "\n"
        f"class {test_class_name} {{\n"
        "\n"
        " @TempDir\n"
        " Path tempDir;\n"
        "\n"
        " private File smallFile;\n"
        " private byte[] expectedSmall;\n"
        " private File mediumFile;\n"
        " private byte[] expectedMedium;\n"
        " private File largeFile;\n"
        " private byte[] expectedLarge;\n"
        "\n"
        " @BeforeEach\n"
        " void setUp() throws Exception {\n"
        ' smallFile = tempDir.resolve("small.txt").toFile();\n'
        ' expectedSmall = "Hello, World!".getBytes();\n'
        " try (FileOutputStream out = new FileOutputStream(smallFile)) {\n"
        " out.write(expectedSmall);\n"
        " }\n"
        "\n"
        ' mediumFile = tempDir.resolve("medium.dat").toFile();\n'
        " expectedMedium = new byte[256 * 1024];\n"
        " for (int i = 0; i < expectedMedium.length; i++) {\n"
        " expectedMedium[i] = (byte) (i % 251);\n"
        " }\n"
        " try (FileOutputStream out = new FileOutputStream(mediumFile)) {\n"
        " out.write(expectedMedium);\n"
        " }\n"
        "\n"
        ' largeFile = tempDir.resolve("large.dat").toFile();\n'
        " expectedLarge = new byte[1024 * 1024];\n"
        " for (int i = 0; i < expectedLarge.length; i++) {\n"
        " expectedLarge[i] = (byte) (i % 256);\n"
        " }\n"
        " try (FileOutputStream out = new FileOutputStream(largeFile)) {\n"
        " out.write(expectedLarge);\n"
        " }\n"
        " }\n"
        "\n"
        " @Test\n"
        ' @DisplayName("Read a small file")\n'
        " void testReadSmallFile() throws Exception {\n"
        f" byte[] result = {class_name}.readFile(smallFile);\n"
        " assertArrayEquals(expectedSmall, result);\n"
        " }\n"
        "\n"
        " @Test\n"
        ' @DisplayName("Read a 256KB file")\n'
        " void testReadMediumFile() throws Exception {\n"
        f" byte[] result = {class_name}.readFile(mediumFile);\n"
        " assertArrayEquals(expectedMedium, result);\n"
        " }\n"
        "\n"
        " @Test\n"
        ' @DisplayName("Read a 1MB file")\n'
        " void testReadLargeFile() throws Exception {\n"
        f" byte[] result = {class_name}.readFile(largeFile);\n"
        " assertArrayEquals(expectedLarge, result);\n"
        " }\n"
        "}\n"
    )
def _build_demo_test_source_1(package_name: str, class_name: str, test_framework: str) -> str:
    """Build demo test source 1, adapting to the target's package, class, and test framework.

    Same @Before pattern as source_0: file creation outside the timed test body.
    Complementary file sizes to source_0.

    Args:
        package_name: Java package of the class under test ("" for the default package).
        class_name: Name of the class whose static readFile(File) is exercised.
        test_framework: "junit4" selects JUnit 4; any other value selects JUnit 5.

    Returns:
        Complete Java test-class source as a single string.
    """
    # Fully-qualified name used in the generated import statement.
    module_path = f"{package_name}.{class_name}" if package_name else class_name
    test_class_name = f"{class_name}Test"
    if test_framework == "junit4":
        # JUnit 4: fixtures of 128KB pattern, 512KB fill, and 2MB pattern files.
        return (f"package {package_name};\n" if package_name else "") + (
            "\n"
            "import org.junit.Before;\n"
            "import org.junit.Test;\n"
            "import org.junit.Rule;\n"
            "import org.junit.rules.TemporaryFolder;\n"
            "import static org.junit.Assert.*;\n"
            "\n"
            "import java.io.File;\n"
            "import java.io.FileOutputStream;\n"
            "import java.util.Arrays;\n"
            "\n"
            f"import {module_path};\n"
            "\n"
            f"public class {test_class_name} {{\n"
            "\n"
            " @Rule\n"
            " public TemporaryFolder tempFolder = new TemporaryFolder();\n"
            "\n"
            " private File patternFile;\n"
            " private byte[] expectedPattern;\n"
            " private File halfMegFile;\n"
            " private byte[] expectedHalfMeg;\n"
            " private File twoMegFile;\n"
            " private byte[] expectedTwoMeg;\n"
            "\n"
            " @Before\n"
            " public void setUp() throws Exception {\n"
            ' patternFile = tempFolder.newFile("pattern.dat");\n'
            " expectedPattern = new byte[128 * 1024];\n"
            " for (int i = 0; i < expectedPattern.length; i++) {\n"
            " expectedPattern[i] = (byte) (i % 7);\n"
            " }\n"
            " try (FileOutputStream out = new FileOutputStream(patternFile)) {\n"
            " out.write(expectedPattern);\n"
            " }\n"
            "\n"
            ' halfMegFile = tempFolder.newFile("half_meg.dat");\n'
            " expectedHalfMeg = new byte[512 * 1024];\n"
            " Arrays.fill(expectedHalfMeg, (byte) 0xCD);\n"
            " try (FileOutputStream out = new FileOutputStream(halfMegFile)) {\n"
            " out.write(expectedHalfMeg);\n"
            " }\n"
            "\n"
            ' twoMegFile = tempFolder.newFile("two_meg.dat");\n'
            " expectedTwoMeg = new byte[2 * 1024 * 1024];\n"
            " for (int i = 0; i < expectedTwoMeg.length; i++) {\n"
            " expectedTwoMeg[i] = (byte) (i % 199);\n"
            " }\n"
            " try (FileOutputStream out = new FileOutputStream(twoMegFile)) {\n"
            " out.write(expectedTwoMeg);\n"
            " }\n"
            " }\n"
            "\n"
            " @Test\n"
            " public void testReadBinaryPattern() throws Exception {\n"
            f" byte[] result = {class_name}.readFile(patternFile);\n"
            " assertArrayEquals(expectedPattern, result);\n"
            " }\n"
            "\n"
            " @Test\n"
            " public void testReadHalfMegFile() throws Exception {\n"
            f" byte[] result = {class_name}.readFile(halfMegFile);\n"
            " assertArrayEquals(expectedHalfMeg, result);\n"
            " }\n"
            "\n"
            " @Test\n"
            " public void testReadTwoMegFile() throws Exception {\n"
            f" byte[] result = {class_name}.readFile(twoMegFile);\n"
            " assertArrayEquals(expectedTwoMeg, result);\n"
            " }\n"
            "}\n"
        )
    # JUnit 5: @TempDir field + @BeforeEach setup, same fixture sizes as the JUnit 4 branch.
    return (f"package {package_name};\n" if package_name else "") + (
        "\n"
        "import org.junit.jupiter.api.BeforeEach;\n"
        "import org.junit.jupiter.api.Test;\n"
        "import org.junit.jupiter.api.DisplayName;\n"
        "import org.junit.jupiter.api.io.TempDir;\n"
        "import static org.junit.jupiter.api.Assertions.*;\n"
        "\n"
        "import java.io.File;\n"
        "import java.io.FileOutputStream;\n"
        "import java.nio.file.Path;\n"
        "import java.util.Arrays;\n"
        "\n"
        f"import {module_path};\n"
        "\n"
        f"class {test_class_name} {{\n"
        "\n"
        " @TempDir\n"
        " Path tempDir;\n"
        "\n"
        " private File patternFile;\n"
        " private byte[] expectedPattern;\n"
        " private File halfMegFile;\n"
        " private byte[] expectedHalfMeg;\n"
        " private File twoMegFile;\n"
        " private byte[] expectedTwoMeg;\n"
        "\n"
        " @BeforeEach\n"
        " void setUp() throws Exception {\n"
        ' patternFile = tempDir.resolve("pattern.dat").toFile();\n'
        " expectedPattern = new byte[128 * 1024];\n"
        " for (int i = 0; i < expectedPattern.length; i++) {\n"
        " expectedPattern[i] = (byte) (i % 7);\n"
        " }\n"
        " try (FileOutputStream out = new FileOutputStream(patternFile)) {\n"
        " out.write(expectedPattern);\n"
        " }\n"
        "\n"
        ' halfMegFile = tempDir.resolve("half_meg.dat").toFile();\n'
        " expectedHalfMeg = new byte[512 * 1024];\n"
        " Arrays.fill(expectedHalfMeg, (byte) 0xCD);\n"
        " try (FileOutputStream out = new FileOutputStream(halfMegFile)) {\n"
        " out.write(expectedHalfMeg);\n"
        " }\n"
        "\n"
        ' twoMegFile = tempDir.resolve("two_meg.dat").toFile();\n'
        " expectedTwoMeg = new byte[2 * 1024 * 1024];\n"
        " for (int i = 0; i < expectedTwoMeg.length; i++) {\n"
        " expectedTwoMeg[i] = (byte) (i % 199);\n"
        " }\n"
        " try (FileOutputStream out = new FileOutputStream(twoMegFile)) {\n"
        " out.write(expectedTwoMeg);\n"
        " }\n"
        " }\n"
        "\n"
        " @Test\n"
        ' @DisplayName("Read 128KB binary pattern file")\n'
        " void testReadBinaryPattern() throws Exception {\n"
        f" byte[] result = {class_name}.readFile(patternFile);\n"
        " assertArrayEquals(expectedPattern, result);\n"
        " }\n"
        "\n"
        " @Test\n"
        ' @DisplayName("Read 512KB file")\n'
        " void testReadHalfMegFile() throws Exception {\n"
        f" byte[] result = {class_name}.readFile(halfMegFile);\n"
        " assertArrayEquals(expectedHalfMeg, result);\n"
        " }\n"
        "\n"
        " @Test\n"
        ' @DisplayName("Read 2MB file")\n'
        " void testReadTwoMegFile() throws Exception {\n"
        f" byte[] result = {class_name}.readFile(twoMegFile);\n"
        " assertArrayEquals(expectedTwoMeg, result);\n"
        " }\n"
        "}\n"
    )
def _build_host_equals_demo_test_source_0(package_name: str, class_name: str, test_framework: str) -> str:
    """Build demo test source 0 for Host.equals, adapting to the target's package, class, and test framework.

    Args:
        package_name: Java package of the class under test ("" for the default package).
        class_name: Name of the Host-like class whose equals(...) is exercised.
        test_framework: "junit4" selects JUnit 4; any other value selects JUnit 5.

    Returns:
        Complete Java test-class source as a single string.
    """
    # Fully-qualified name used in the generated import statement and class javadoc.
    module_path = f"{package_name}.{class_name}" if package_name else class_name
    test_class_name = f"{class_name}Test"
    if test_framework == "junit4":
        # JUnit 4 variant: covers reflexivity, symmetry, null/class mismatches, and int boundaries.
        return (f"package {package_name};\n" if package_name else "") + (
            "\n"
            "import org.junit.Before;\n"
            "import org.junit.Test;\n"
            "import static org.junit.Assert.*;\n"
            "\n"
            f"import {module_path};\n"
            "\n"
            "/**\n"
            f" * Unit tests for {module_path}.equals(...)\n"
            " */\n"
            f"public class {test_class_name} {{\n"
            f" private {class_name} defaultHost;\n"
            "\n"
            " @Before\n"
            " public void setUp() {\n"
            f' defaultHost = new {class_name}("localhost", 3000);\n'
            " }\n"
            "\n"
            " @Test\n"
            " public void testEquals_SameInstance_ReturnsTrue() {\n"
            " assertTrue(defaultHost.equals(defaultHost));\n"
            " }\n"
            "\n"
            " @Test\n"
            " public void testEquals_EqualNameAndPort_IgnoringTls_ReturnsTrue() {\n"
            f' {class_name} withTls = new {class_name}("localhost", "server-cert", 3000);\n'
            " // Both directions should be true (symmetry)\n"
            " assertTrue(defaultHost.equals(withTls) && withTls.equals(defaultHost));\n"
            " }\n"
            "\n"
            " @Test\n"
            " public void testEquals_DifferentPort_ReturnsFalse() {\n"
            f' {class_name} other = new {class_name}("localhost", 3001);\n'
            " assertFalse(defaultHost.equals(other));\n"
            " }\n"
            "\n"
            " @Test\n"
            " public void testEquals_DifferentName_ReturnsFalse() {\n"
            f' {class_name} other = new {class_name}("otherhost", 3000);\n'
            " assertFalse(defaultHost.equals(other));\n"
            " }\n"
            "\n"
            " @Test\n"
            " public void testEquals_NullArgument_ReturnsFalse() {\n"
            " assertFalse(defaultHost.equals(null));\n"
            " }\n"
            "\n"
            " @Test\n"
            " public void testEquals_DifferentClass_ReturnsFalse() {\n"
            ' Object notAHost = "I am not a Host";\n'
            " assertFalse(defaultHost.equals(notAHost));\n"
            " }\n"
            "\n"
            " @Test\n"
            " public void testEquals_EmptyNameBoth_ReturnsTrue() {\n"
            f' {class_name} a = new {class_name}("", 0);\n'
            f' {class_name} b = new {class_name}("", null, 0);\n'
            " assertTrue(a.equals(b));\n"
            " }\n"
            "\n"
            " @Test(expected = NullPointerException.class)\n"
            " public void testEquals_ThisNameNull_ThrowsNullPointerException() {\n"
            " // When this.name is null, equals calls this.name.equals(...), which throws NPE.\n"
            f" {class_name} thisHasNullName = new {class_name}(null, 100);\n"
            f' {class_name} other = new {class_name}("something", 100);\n'
            " thisHasNullName.equals(other);\n"
            " }\n"
            "\n"
            " @Test\n"
            " public void testEquals_OtherNameNull_ReturnsFalse() {\n"
            f" {class_name} otherHasNullName = new {class_name}(null, 200);\n"
            f' {class_name} normal = new {class_name}("name", 200);\n'
            ' // "name".equals(null) returns false; no exception expected.\n'
            " assertFalse(normal.equals(otherHasNullName));\n"
            " }\n"
            "\n"
            " @Test\n"
            " public void testEquals_MaxIntPort_ReturnsTrue() {\n"
            f' {class_name} a = new {class_name}("host", Integer.MAX_VALUE);\n'
            f' {class_name} b = new {class_name}("host", Integer.MAX_VALUE);\n'
            " assertTrue(a.equals(b));\n"
            " }\n"
            "\n"
            " @Test\n"
            " public void testEquals_MinIntPort_ReturnsTrue() {\n"
            f' {class_name} a = new {class_name}("host", Integer.MIN_VALUE);\n'
            f' {class_name} b = new {class_name}("host", Integer.MIN_VALUE);\n'
            " assertTrue(a.equals(b));\n"
            " }\n"
            "\n"
            " @Test\n"
            " public void testEquals_MaxAndDifferentPort_ReturnsFalse() {\n"
            f' {class_name} a = new {class_name}("host", Integer.MAX_VALUE);\n'
            f' {class_name} b = new {class_name}("host", Integer.MAX_VALUE - 1);\n'
            " assertFalse(a.equals(b));\n"
            " }\n"
            "\n"
            " @Test\n"
            " public void testEquals_MultipleEqualInstances_ReturnsTrue() {\n"
            f' {class_name} a = new {class_name}("perf-host", 4000);\n'
            f' {class_name} b = new {class_name}("perf-host", 4000);\n'
            f' {class_name} c = new {class_name}("perf-host", "tls-name", 4000);\n'
            " assertTrue(a.equals(b));\n"
            " assertTrue(b.equals(c));\n"
            " assertTrue(a.equals(c));\n"
            " }\n"
            "}\n"
        )
    # JUnit 5 variant: same coverage with @DisplayName annotations and assertThrows for the NPE case.
    return (f"package {package_name};\n" if package_name else "") + (
        "\n"
        "import org.junit.jupiter.api.BeforeEach;\n"
        "import org.junit.jupiter.api.Test;\n"
        "import org.junit.jupiter.api.DisplayName;\n"
        "import static org.junit.jupiter.api.Assertions.*;\n"
        "\n"
        f"import {module_path};\n"
        "\n"
        "/**\n"
        f" * Unit tests for {module_path}.equals(...)\n"
        " */\n"
        f"class {test_class_name} {{\n"
        f" private {class_name} defaultHost;\n"
        "\n"
        " @BeforeEach\n"
        " void setUp() {\n"
        f' defaultHost = new {class_name}("localhost", 3000);\n'
        " }\n"
        "\n"
        " @Test\n"
        ' @DisplayName("Same instance returns true")\n'
        " void testEquals_SameInstance_ReturnsTrue() {\n"
        " assertTrue(defaultHost.equals(defaultHost));\n"
        " }\n"
        "\n"
        " @Test\n"
        ' @DisplayName("Equal name and port ignoring TLS returns true")\n'
        " void testEquals_EqualNameAndPort_IgnoringTls_ReturnsTrue() {\n"
        f' {class_name} withTls = new {class_name}("localhost", "server-cert", 3000);\n'
        " assertTrue(defaultHost.equals(withTls) && withTls.equals(defaultHost));\n"
        " }\n"
        "\n"
        " @Test\n"
        ' @DisplayName("Different port returns false")\n'
        " void testEquals_DifferentPort_ReturnsFalse() {\n"
        f' {class_name} other = new {class_name}("localhost", 3001);\n'
        " assertFalse(defaultHost.equals(other));\n"
        " }\n"
        "\n"
        " @Test\n"
        ' @DisplayName("Different name returns false")\n'
        " void testEquals_DifferentName_ReturnsFalse() {\n"
        f' {class_name} other = new {class_name}("otherhost", 3000);\n'
        " assertFalse(defaultHost.equals(other));\n"
        " }\n"
        "\n"
        " @Test\n"
        ' @DisplayName("Null argument returns false")\n'
        " void testEquals_NullArgument_ReturnsFalse() {\n"
        " assertFalse(defaultHost.equals(null));\n"
        " }\n"
        "\n"
        " @Test\n"
        ' @DisplayName("Different class returns false")\n'
        " void testEquals_DifferentClass_ReturnsFalse() {\n"
        ' Object notAHost = "I am not a Host";\n'
        " assertFalse(defaultHost.equals(notAHost));\n"
        " }\n"
        "\n"
        " @Test\n"
        ' @DisplayName("Empty name both returns true")\n'
        " void testEquals_EmptyNameBoth_ReturnsTrue() {\n"
        f' {class_name} a = new {class_name}("", 0);\n'
        f' {class_name} b = new {class_name}("", null, 0);\n'
        " assertTrue(a.equals(b));\n"
        " }\n"
        "\n"
        " @Test\n"
        ' @DisplayName("Null this.name throws NullPointerException")\n'
        " void testEquals_ThisNameNull_ThrowsNullPointerException() {\n"
        f" {class_name} thisHasNullName = new {class_name}(null, 100);\n"
        f' {class_name} other = new {class_name}("something", 100);\n'
        " assertThrows(NullPointerException.class, () -> thisHasNullName.equals(other));\n"
        " }\n"
        "\n"
        " @Test\n"
        ' @DisplayName("Null other.name returns false")\n'
        " void testEquals_OtherNameNull_ReturnsFalse() {\n"
        f" {class_name} otherHasNullName = new {class_name}(null, 200);\n"
        f' {class_name} normal = new {class_name}("name", 200);\n'
        " assertFalse(normal.equals(otherHasNullName));\n"
        " }\n"
        "\n"
        " @Test\n"
        ' @DisplayName("Max int port returns true")\n'
        " void testEquals_MaxIntPort_ReturnsTrue() {\n"
        f' {class_name} a = new {class_name}("host", Integer.MAX_VALUE);\n'
        f' {class_name} b = new {class_name}("host", Integer.MAX_VALUE);\n'
        " assertTrue(a.equals(b));\n"
        " }\n"
        "\n"
        " @Test\n"
        ' @DisplayName("Min int port returns true")\n'
        " void testEquals_MinIntPort_ReturnsTrue() {\n"
        f' {class_name} a = new {class_name}("host", Integer.MIN_VALUE);\n'
        f' {class_name} b = new {class_name}("host", Integer.MIN_VALUE);\n'
        " assertTrue(a.equals(b));\n"
        " }\n"
        "\n"
        " @Test\n"
        ' @DisplayName("Max vs different port returns false")\n'
        " void testEquals_MaxAndDifferentPort_ReturnsFalse() {\n"
        f' {class_name} a = new {class_name}("host", Integer.MAX_VALUE);\n'
        f' {class_name} b = new {class_name}("host", Integer.MAX_VALUE - 1);\n'
        " assertFalse(a.equals(b));\n"
        " }\n"
        "\n"
        " @Test\n"
        ' @DisplayName("Multiple equal instances with transitivity")\n'
        " void testEquals_MultipleEqualInstances_ReturnsTrue() {\n"
        f' {class_name} a = new {class_name}("perf-host", 4000);\n'
        f' {class_name} b = new {class_name}("perf-host", 4000);\n'
        f' {class_name} c = new {class_name}("perf-host", "tls-name", 4000);\n'
        " assertTrue(a.equals(b));\n"
        " assertTrue(b.equals(c));\n"
        " assertTrue(a.equals(c));\n"
        " }\n"
        "}\n"
    )
def _build_host_equals_demo_test_source_1(package_name: str, class_name: str, test_framework: str) -> str:
    """Build demo test source 1 for Host.equals, adapting to the target's package, class, and test framework.

    Args:
        package_name: Java package of the class under test; "" means the default package,
            in which case no ``package`` declaration is emitted.
        class_name: Simple name of the class whose ``equals`` method is under test.
        test_framework: "junit4" selects the JUnit 4 variant; any other value
            (including "junit5") falls through to the JUnit 5 variant.

    Returns:
        Complete Java source text for a test class named ``{class_name}Test``.
    """
    # NOTE(review): the generated file declares the same package as the class under
    # test, so the explicit `import {module_path};` below is redundant but harmless.
    module_path = f"{package_name}.{class_name}" if package_name else class_name
    test_class_name = f"{class_name}Test"
    if test_framework == "junit4":
        # JUnit 4 variant: @Before/@Test, expected-exception via @Test(expected=...).
        return (f"package {package_name};\n" if package_name else "") + (
            "\n"
            "import org.junit.Before;\n"
            "import org.junit.Test;\n"
            "import static org.junit.Assert.*;\n"
            "\n"
            f"import {module_path};\n"
            "\n"
            f"public class {test_class_name} {{\n"
            f"    private {class_name} hostSimple;\n"
            f"    private {class_name} hostWithTls;\n"
            "\n"
            "    @Before\n"
            "    public void setUp() {\n"
            f'        hostSimple = new {class_name}("server.example.com", 3000);\n'
            f'        hostWithTls = new {class_name}("server.example.com", "tls.server.example.com", 3000);\n'
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testSameReference_True() {\n"
            "        // same instance should be equal to itself\n"
            "        assertTrue(hostSimple.equals(hostSimple));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testEqualNameAndPort_True() {\n"
            "        // two distinct instances with same name and port (tls ignored) are equal\n"
            f'        {class_name} a = new {class_name}("db1", 4000);\n'
            f'        {class_name} b = new {class_name}("db1", "tlsNameDifferent", 4000);\n'
            "        assertTrue(a.equals(b));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testDifferentTlsIgnored_True() {\n"
            "        // tlsName is ignored for equality\n"
            "        assertTrue(hostSimple.equals(hostWithTls));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testDifferentName_False() {\n"
            f'        {class_name} otherName = new {class_name}("other.example.com", 3000);\n'
            "        assertFalse(hostSimple.equals(otherName));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testDifferentPort_False() {\n"
            f'        {class_name} otherPort = new {class_name}("server.example.com", 3001);\n'
            "        assertFalse(hostSimple.equals(otherPort));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testNullComparison_False() {\n"
            "        // equals should return false when compared to null\n"
            "        assertFalse(hostSimple.equals(null));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testDifferentClass_False() {\n"
            "        // equals should return false when compared to an object of another class\n"
            '        Object notAHost = "server.example.com:3000";\n'
            "        assertFalse(hostSimple.equals(notAHost));\n"
            "    }\n"
            "\n"
            "    @Test(expected = NullPointerException.class)\n"
            "    public void testNameNull_ThrowsNullPointerException() {\n"
            "        // If this.name is null, equals tries to call this.name.equals(...) and will NPE.\n"
            f"        {class_name} nullNameHost = new {class_name}(null, 3000);\n"
            f"        {class_name} otherNullNameHost = new {class_name}(null, 3000);\n"
            "        // This invocation should throw NPE because this.name is null\n"
            "        nullNameHost.equals(otherNullNameHost);\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testOtherNameNull_False() {\n"
            "        // If other.name is null but this.name is non-null, equals should return false\n"
            f"        {class_name} otherNullName = new {class_name}(null, 3000);\n"
            "        assertFalse(hostSimple.equals(otherNullName));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testPortBoundary_ZeroAndMax_True() {\n"
            f'        {class_name} lowA = new {class_name}("edge", 0);\n'
            f'        {class_name} lowB = new {class_name}("edge", 0);\n'
            "        assertTrue(lowA.equals(lowB));\n"
            "\n"
            f'        {class_name} highA = new {class_name}("edge", 65535);\n'
            f'        {class_name} highB = new {class_name}("edge", 65535);\n'
            "        assertTrue(highA.equals(highB));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testPortBoundary_DifferentPorts_False() {\n"
            f'        {class_name} low = new {class_name}("edge", 0);\n'
            f'        {class_name} high = new {class_name}("edge", 65535);\n'
            "        assertFalse(low.equals(high));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testEqualityWithVariousNames() {\n"
            f'        {class_name} a1 = new {class_name}("alpha", 1000);\n'
            f'        {class_name} b1 = new {class_name}("alpha", "tls-alpha", 1000);\n'
            "        assertTrue(a1.equals(b1));\n"
            "\n"
            f'        {class_name} a2 = new {class_name}("beta", 2000);\n'
            f'        {class_name} b2 = new {class_name}("beta", "tls-beta", 2000);\n'
            "        assertTrue(a2.equals(b2));\n"
            "\n"
            f'        {class_name} a3 = new {class_name}("gamma", 3000);\n'
            f'        {class_name} b3 = new {class_name}("delta", 3000);\n'
            "        assertFalse(a3.equals(b3));\n"
            "    }\n"
            "}\n"
        )
    # JUnit 5 (default): same scenarios using Jupiter annotations, @DisplayName, and assertThrows.
    return (f"package {package_name};\n" if package_name else "") + (
        "\n"
        "import org.junit.jupiter.api.BeforeEach;\n"
        "import org.junit.jupiter.api.Test;\n"
        "import org.junit.jupiter.api.DisplayName;\n"
        "import static org.junit.jupiter.api.Assertions.*;\n"
        "\n"
        f"import {module_path};\n"
        "\n"
        f"class {test_class_name} {{\n"
        f"    private {class_name} hostSimple;\n"
        f"    private {class_name} hostWithTls;\n"
        "\n"
        "    @BeforeEach\n"
        "    void setUp() {\n"
        f'        hostSimple = new {class_name}("server.example.com", 3000);\n'
        f'        hostWithTls = new {class_name}("server.example.com", "tls.server.example.com", 3000);\n'
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Same reference returns true")\n'
        "    void testSameReference_True() {\n"
        "        assertTrue(hostSimple.equals(hostSimple));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Equal name and port with different TLS returns true")\n'
        "    void testEqualNameAndPort_True() {\n"
        f'        {class_name} a = new {class_name}("db1", 4000);\n'
        f'        {class_name} b = new {class_name}("db1", "tlsNameDifferent", 4000);\n'
        "        assertTrue(a.equals(b));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Different TLS name is ignored")\n'
        "    void testDifferentTlsIgnored_True() {\n"
        "        assertTrue(hostSimple.equals(hostWithTls));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Different name returns false")\n'
        "    void testDifferentName_False() {\n"
        f'        {class_name} otherName = new {class_name}("other.example.com", 3000);\n'
        "        assertFalse(hostSimple.equals(otherName));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Different port returns false")\n'
        "    void testDifferentPort_False() {\n"
        f'        {class_name} otherPort = new {class_name}("server.example.com", 3001);\n'
        "        assertFalse(hostSimple.equals(otherPort));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Null comparison returns false")\n'
        "    void testNullComparison_False() {\n"
        "        assertFalse(hostSimple.equals(null));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Different class returns false")\n'
        "    void testDifferentClass_False() {\n"
        '        Object notAHost = "server.example.com:3000";\n'
        "        assertFalse(hostSimple.equals(notAHost));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Null name throws NullPointerException")\n'
        "    void testNameNull_ThrowsNullPointerException() {\n"
        f"        {class_name} nullNameHost = new {class_name}(null, 3000);\n"
        f"        {class_name} otherNullNameHost = new {class_name}(null, 3000);\n"
        "        assertThrows(NullPointerException.class, () -> nullNameHost.equals(otherNullNameHost));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Other name null returns false")\n'
        "    void testOtherNameNull_False() {\n"
        f"        {class_name} otherNullName = new {class_name}(null, 3000);\n"
        "        assertFalse(hostSimple.equals(otherNullName));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Port boundary zero and max")\n'
        "    void testPortBoundary_ZeroAndMax_True() {\n"
        f'        {class_name} lowA = new {class_name}("edge", 0);\n'
        f'        {class_name} lowB = new {class_name}("edge", 0);\n'
        "        assertTrue(lowA.equals(lowB));\n"
        "\n"
        f'        {class_name} highA = new {class_name}("edge", 65535);\n'
        f'        {class_name} highB = new {class_name}("edge", 65535);\n'
        "        assertTrue(highA.equals(highB));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Different boundary ports returns false")\n'
        "    void testPortBoundary_DifferentPorts_False() {\n"
        f'        {class_name} low = new {class_name}("edge", 0);\n'
        f'        {class_name} high = new {class_name}("edge", 65535);\n'
        "        assertFalse(low.equals(high));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Equality with various host names")\n'
        "    void testEqualityWithVariousNames() {\n"
        f'        {class_name} a1 = new {class_name}("alpha", 1000);\n'
        f'        {class_name} b1 = new {class_name}("alpha", "tls-alpha", 1000);\n'
        "        assertTrue(a1.equals(b1));\n"
        "\n"
        f'        {class_name} a2 = new {class_name}("beta", 2000);\n'
        f'        {class_name} b2 = new {class_name}("beta", "tls-beta", 2000);\n'
        "        assertTrue(a2.equals(b2));\n"
        "\n"
        f'        {class_name} a3 = new {class_name}("gamma", 3000);\n'
        f'        {class_name} b3 = new {class_name}("delta", 3000);\n'
        "        assertFalse(a3.equals(b3));\n"
        "    }\n"
        "}\n"
    )
async def hack_for_demo_java_testgen(data: TestGenSchema) -> TestGenResponseSchema:
    """Return a canned Java test response, adapted to the target's package, class, and framework."""
    # Derive package/class from the submitted source rather than hard-coding them.
    src = data.source_code_being_tested
    pkg = _extract_package_from_source(src) or ""
    cls = _extract_class_from_source(src) or data.function_to_optimize.function_name
    framework = data.test_framework if data.test_framework in ("junit4", "junit5") else "junit5"
    idx = data.test_index if data.test_index is not None else 0

    # Pick the builder pair for this demo scenario, then the variant by test index.
    if is_host_equals_demo(src):
        builders = (_build_host_equals_demo_test_source_0, _build_host_equals_demo_test_source_1)
    else:
        builders = (_build_demo_test_source_0, _build_demo_test_source_1)
    build = builders[0] if idx == 0 else builders[1]
    test_source = build(pkg, cls, framework)

    # Simulate LLM latency for the demo flow.
    await asyncio.sleep(5)
    # Java instrumentation happens client-side, so all three fields carry the same source.
    return TestGenResponseSchema(
        generated_tests=test_source,
        instrumented_behavior_tests=test_source,
        instrumented_perf_tests=test_source,
    )
async def testgen_java(
request: AuthenticatedRequest, data: TestGenSchema
) -> tuple[int, TestGenResponseSchema | TestGenErrorResponseSchema]:
@ -1363,9 +533,11 @@ async def testgen_java(
logging.info("/testgen: Generating Java tests...")
# Demo hack: intercept before LLM call for demo functions
if should_hack_for_demo_java(data.source_code_being_tested):
return 200, await hack_for_demo_java_testgen(data)
from core.languages.java.demo_hacks import try_demo_testgen_java # noqa: PLC0415
demo_result = await try_demo_testgen_java(data)
if demo_result is not None:
return 200, demo_result
try:
debug_log_sensitive_data(f"Generating Java tests for function {data.function_to_optimize.function_name}")

View file

@ -7,13 +7,6 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from core.languages.python.code_repair.code_repair import code_repair
from core.languages.python.explanations.explanations import explain_optimizations
from core.languages.python.jit_rewrite.jit_rewrite import jit_rewrite
from core.languages.python.optimization_review.optimization_review import get_optimization_review
from core.languages.python.optimizer.optimizer import optimize_python
from core.languages.python.testgen.generate import testgen_python
if TYPE_CHECKING:
from aiservice.llm_models import LLM
from authapp.auth import AuthenticatedRequest
@ -54,18 +47,24 @@ class PythonHandler:
self, request: AuthenticatedRequest, data: TestGenSchema
) -> tuple[int, TestGenResponseSchema | TestGenErrorResponseSchema]:
"""Generate tests for Python code."""
from core.languages.python.testgen.generate import testgen_python # noqa: PLC0415
return await testgen_python(request, data)
async def optimizer_optimize(
self, request: AuthenticatedRequest, data: OptimizeSchema
) -> tuple[int, OptimizeResponseSchema | OptimizeErrorResponseSchema]:
"""Optimize Python code for performance."""
from core.languages.python.optimizer.optimizer import optimize_python # noqa: PLC0415
return await optimize_python(request, data)
async def code_repair_repair(
self, user_id: str, optimization_id: str, ctx: CodeRepairContext, optimize_model: LLM | None = None
) -> CodeRepairIntermediateResponseItemschema | CodeRepairErrorResponseSchema:
"""Repair Python code based on error messages."""
from core.languages.python.code_repair.code_repair import code_repair # noqa: PLC0415
kwargs = {}
if optimize_model is not None:
kwargs["optimize_model"] = optimize_model
@ -75,6 +74,8 @@ class PythonHandler:
self, request: AuthenticatedRequest, data: JitRewriteOptimizeSchema
) -> tuple[int, OptimizeResponseSchema | OptimizeErrorResponseSchema]:
"""Perform JIT rewriting of Python code."""
from core.languages.python.jit_rewrite.jit_rewrite import jit_rewrite # noqa: PLC0415
return await jit_rewrite(request, data)
async def optimization_review(
@ -84,6 +85,8 @@ class PythonHandler:
optimization_review_model: LLM | None = None,
) -> tuple[int, OptimizationReviewResponseSchema | OptimizationReviewErrorSchema]:
"""Review Python code for optimization opportunities."""
from core.languages.python.optimization_review.optimization_review import get_optimization_review # noqa: PLC0415
kwargs = {}
if optimization_review_model is not None:
kwargs["optimization_review_model"] = optimization_review_model
@ -93,6 +96,8 @@ class PythonHandler:
self, user_id: str, data: ExplanationsSchema, explanations_model: LLM | None = None
) -> ExplanationsResponseSchema | ExplanationsErrorResponseSchema:
"""Explain optimizations made to Python code."""
from core.languages.python.explanations.explanations import explain_optimizations # noqa: PLC0415
kwargs = {}
if explanations_model is not None:
kwargs["explanations_model"] = explanations_model

View file

@ -4,7 +4,7 @@ import json
import logging
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
import sentry_sdk
from ninja import NinjaAPI, Schema
@ -57,10 +57,10 @@ class OptimizationReviewSchema(Schema):
original_runtime: str
optimized_runtime: str
speedup: str
existing_tests: str | None
generated_tests: str | None
replay_tests: str | None
benchmark_details: str | None
existing_tests: str | None = None
generated_tests: str | None = None
replay_tests: str | None = None
benchmark_details: str | None = None
coverage_message: str
loop_count: int
explanation: str
@ -185,7 +185,7 @@ async def get_optimization_review(
debug_log_sensitive_data(f"{messages[0]}{messages[1]}")
obs_context: dict = {"speedup": data.speedup}
obs_context: dict[str, Any] = {"speedup": data.speedup}
if data.call_sequence is not None:
obs_context["call_sequence"] = data.call_sequence
@ -272,11 +272,11 @@ async def optimization_review(
) -> tuple[int, OptimizationReviewResponseSchema | OptimizationReviewErrorSchema]:
response_code, output = await get_optimization_review(request, data)
try:
if response_code == 200:
review_event = output.review.value # ty:ignore[unresolved-attribute]
review_explanation = output.review_explanation # ty:ignore[unresolved-attribute]
if isinstance(output, OptimizationReviewResponseSchema):
review_event = output.review.value
review_explanation = output.review_explanation
else:
review_event = output.error # ty:ignore[unresolved-attribute]
review_event = output.error
review_explanation = ""
await update_optimization_features_review(
trace_id=data.trace_id,

View file

@ -0,0 +1,91 @@
from __future__ import annotations
import asyncio
import uuid
from typing import TYPE_CHECKING
from aiservice.common_utils import should_hack_for_demo
from core.shared.context_helpers import group_code
from core.shared.optimizer_schemas import OptimizeResponseItemSchema, OptimizeResponseSchema
if TYPE_CHECKING:
from core.languages.python.optimizer.context_utils.optimizer_context import BaseOptimizerContext
# Canned optimization candidates for the "find_common_tags" demo function.
# Consumed by hack_for_demo(); optimization_id values are freshly generated at import time.
optimizations_json: list[dict[str, str]] = [
    {
        "source_code": 'from __future__ import annotations\n\n\ndef find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:\n    if not articles:\n        return set()\n\n    common_tags = set(articles[0].get("tags", []))\n    for article in articles[1:]:\n        common_tags.intersection_update(article.get("tags", []))\n    return common_tags\n',
        "explanation": "The original algorithm repeatedly filters the `common_tags` list for every article, which can be slow. We can use Python sets to improve efficiency, especially with large lists.\n\nHere's the optimized version of your function.\n\n\n\n### Explanation of Optimizations.\n1. **Use of Sets**: Convert the initial list of tags to a set, which allows for more efficient intersection operations compared to list comprehensions.\n2. **Intersection Update**: Use the `intersection_update` method on sets which modifies the set in place, making it more memory efficient and faster than creating new lists and converting them to sets repeatedly.\n\nThis optimized version should perform significantly better, especially as the number of articles and tags increases.",
        "optimization_id": str(uuid.uuid4()),
    },
    {
        "source_code": 'from __future__ import annotations\n\n\ndef find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:\n    if not articles:\n        return set()\n\n    common_tags = set(articles[0].get("tags", []))\n    for article in articles[1:]:\n        common_tags.intersection_update(article.get("tags", []))\n        if not common_tags:\n            break\n    return common_tags\n',
        "explanation": "To make the `find_common_tags` function run faster, we can leverage sets, which provide average O(1) time complexity for membership checks and O(n) for intersections. Here\u2019s a refactored version of your program.\n\n\n\nThis version initializes `common_tags` as a set and then iteratively intersects it with the tags of each subsequent article. The `intersection_update` method is used to update `common_tags` in place, which is more efficient. Additionally, it breaks early if `common_tags` becomes empty, which can save unnecessary computation.",
        "optimization_id": str(uuid.uuid4()),
    },
    {
        "source_code": 'def find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:\n    if not articles:\n        return set()\n\n    common_tags = set(articles[0].get("tags", []))\n    for article in articles[1:]:\n        common_tags.intersection_update(article.get("tags", []))\n        if not common_tags:  # Early exit if no common tags left\n            break\n    return common_tags\n',
        "explanation": "To optimize the runtime of this function, we can leverage set operations which are generally faster than list comprehensions for membership checks. By converting the tags to sets initially, the intersection operation becomes more efficient. Here's a faster version.\n\n\n\nChanges made.\n1. Convert the tags list of the first article to a set.\n2. Use `intersection_update` method to update the `common_tags` set with the intersection of the current tags and the next article's tags.\n3. Include an early exit condition to break the loop if no common tags remain, further optimizing runtime.",
        "optimization_id": str(uuid.uuid4()),
    },
    {
        "source_code": 'from __future__ import annotations\n\n\ndef find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:\n    if not articles:\n        return set()\n\n    common_tags = set(articles[0].get("tags", []))\n    for article in articles[1:]:\n        common_tags &= set(article.get("tags", []))\n        if not common_tags:  # Early exit if no common tags.\n            break\n    return common_tags\n',
        "explanation": "To optimize the provided function, we could enhance its efficiency by using set operations which are typically faster for membership checks compared to list comprehensions.\n\nHere\u2019s the optimized version.\n\n\n\nExplanation.\n1. Convert the tags of the first article into a set to take advantage of fast membership checks and intersection operations.\n2. Use the `&=` operation to find the intersection with the tags of each subsequent article.\n3. Introduce an early exit condition: if `common_tags` becomes empty, it's immediately returned since no further intersection can result in common tags.",
        "optimization_id": str(uuid.uuid4()),
    },
]
# Canned optimization candidates for the gs-quant "weighted_sum" demo function.
# Consumed by hack_for_demo_gsq().
optimizations_json_gsq: list[dict[str, str]] = [
    {
        "source_code": '# Copyright 2018 Goldman Sachs.\n# Licensed under the Apache License, Version 2.0 (the "License");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n# http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied. See the License for the\n# specific language governing permissions and limitations\n# under the License.\n#\n#\n# Chart Service will attempt to make public functions (not prefixed with _) from this module available. Such functions\n# should be fully documented: docstrings should describe parameters and the return value, and provide a 1-line\n# description. Type annotations should be provided for parameters.\n\n\n\n\nimport numpy as np\nimport pandas as pd\nfrom functools import reduce\nfrom gs_quant.errors import MqTypeError, MqValueError\nfrom gs_quant.timeseries.helper import plot_function\nfrom typing import List\n\n@plot_function\ndef weighted_sum(series: List[pd.Series], weights: list) -> pd.Series:\n    """\n    Calculate a weighted sum.\n\n    :param series: list of time series\n    :param weights: list of weights\n    :return: time series of weighted average\n\n    **Usage**\n\n    Calculate a weighted sum e.g. for a basket.\n\n    **Examples**\n\n    Generate price series and get a sum (weights 70%/30%).\n\n    >>> prices1 = generate_series(100)\n    >>> prices2 = generate_series(100)\n    >>> mybasket = weighted_sum([prices1, prices2], [0.7, 0.3])\n\n    **See also**\n\n    :func:`basket`\n    """\n    if not all(isinstance(x, pd.Series) for x in series):\n        raise MqTypeError("expected a list of time series")\n    if not all(isinstance(y, (float, int)) for y in weights):\n        raise MqTypeError("expected a list of number for weights")\n    if len(weights) != len(series):\n        raise MqValueError("must have one weight for each time series")\n\n    # For input series, get the intersection of their calendars\n    # Instead of reduce(np.intersect1d, ...), use set intersection for better performance\n    idx_iter = (curve.index for curve in series)\n    idx0 = next(idx_iter)\n    cal = set(idx0)\n    for idx in idx_iter:\n        cal &= set(idx)\n    cal = pd.DatetimeIndex(sorted(cal))\n\n    # Vectorized calculations using numpy arrays\n    if len(series) == 0 or len(cal) == 0:\n        # Edge case: empty data\n        weights_arr = np.array([], dtype=float)\n        values_arr = np.array([], dtype=float)\n        weighted_sum_arr = np.array([])\n        sum_weights_arr = np.array([])\n        return pd.Series(weighted_sum_arr, index=cal)\n\n    # Use pd.concat for batch reindex to avoid overhead of python list comprehensions\n    series_concat = pd.concat([s.reindex(cal) for s in series], axis=1)\n    values_arr = series_concat.values  # shape (n_dates, n_series)\n    weights_arr = np.asarray(weights, dtype=float).reshape(1, -1)  # shape (1, n_series)\n    weights_matrix = np.broadcast_to(weights_arr, values_arr.shape)  # shape (n_dates, n_series)\n\n    # Weighted sum and denominator\n    weighted_sum_arr = np.nansum(values_arr * weights_matrix, axis=1)\n    sum_weights_arr = np.nansum(weights_matrix * ~np.isnan(values_arr), axis=1)\n\n    # Avoid divide-by-zero; if all weights are nan for a row, the sum is nan\n    with np.errstate(invalid=\'ignore\', divide=\'ignore\'):\n        result_arr = weighted_sum_arr / sum_weights_arr\n    result_arr[sum_weights_arr == 0] = np.nan\n\n    return pd.Series(result_arr, index=cal)\n',
        "explanation": "\n\n### Key optimizations:\n- **Intersection**: For many series, `.intersection_many` (if available, see pandas 1.5+) is significantly faster than chaining `np.intersect1d`. If not available, this safely falls back to repeated intersection.\n- **Data alignment**: Avoided creating a separate `pd.Series` for each weight. Instead, using NumPy arrays, broadcast the weights and handle the reindexing in an efficient 2D array.\n- **Sum computation**: The product and sums are performed with batched NumPy operations eliminating Python-level for-loops and overhead from generator comprehensions. This avoids repeatedly constructing/interpreting Series objects.\n- **NaN handling**: Handles missing data exactly as in the original by multiplying mask logic accordingly.\n- **Handles empty and no-overlap**: Preserves the original behavior for empty or non-overlapping indexes, returning an all-NaN, correct-indexed Series.\n\n**NB:** This preserves all error handling, exceptions, and the function signature, matching the requirements precisely, while delivering **substantial speedups** for large and/or many input series.",
        "optimization_id": str(uuid.uuid4()),
    },
    {
        "source_code": '''# Copyright 2018 Goldman Sachs.\n# Licensed under the Apache License, Version 2.0 (the "License");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n# http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied. See the License for the\n# specific language governing permissions and limitations\n# under the License.\n#\n#\n# Chart Service will attempt to make public functions (not prefixed with _) from this module available. Such functions\n# should be fully documented: docstrings should describe parameters and the return value, and provide a 1-line\n# description. Type annotations should be provided for parameters.\nimport numpy as np\nimport pandas as pd\nfrom functools import reduce\nfrom gs_quant.errors import MqTypeError, MqValueError\nfrom gs_quant.timeseries.helper import plot_function\nfrom typing import List\n@plot_function\ndef weighted_sum(series: List[pd.Series], weights: list) -> pd.Series:\n    """\n    Calculate a weighted sum.\n    :param series: list of time series\n    :param weights: list of weights\n    :return: time series of weighted average\n    **Usage**\n    Calculate a weighted sum e.g. for a basket.\n    **Examples**\n    Generate price series and get a sum (weights 70%/30%).\n    >>> prices1 = generate_series(100)\n    >>> prices2 = generate_series(100)\n    >>> mybasket = weighted_sum([prices1, prices2], [0.7, 0.3])\n    **See also**\n    :func:`basket`\n    """\n    if not all(isinstance(x, pd.Series) for x in series):\n        raise MqTypeError("expected a list of time series")\n    if not all(isinstance(y, (float, int)) for y in weights):\n        raise MqTypeError("expected a list of number for weights")\n    if len(weights) != len(series):\n        raise MqValueError("must have one weight for each time series")\n    # for input series, get the intersection of their calendars\n    cal = pd.DatetimeIndex(\n        reduce(\n            np.intersect1d,\n            (\n                curve.index\n                for curve in series\n            ),\n        )\n    )\n    # reindex inputs and calculate using numpy vectorization\n    series_arrays = [s.reindex(cal).to_numpy() for s in series]\n    weights_array = np.asarray(weights, dtype=float).reshape(-1, 1)\n    stacked_series = np.vstack(series_arrays)\n    weighted_sum_array = np.sum(stacked_series * weights_array, axis=0)\n    sum_weights = np.sum(weights_array)\n    return pd.Series(weighted_sum_array / sum_weights, index=cal)\n''',
        "explanation": "\n**Key optimizations explained:**\n- **Index intersection speedup:** Instead of chaining `np.intersect1d` (which is O(N^2) and re-sorts at each op), use `set.intersection_update` (O(N)) which is drastically faster for longer indices.\n- **Avoid unnecessary repeated object construction:** Instead of building `[pd.Series(w, index=cal) for w in weights]` (which allocates a new Series for every weight/date combination), use vectorized numpy and pandas operations.\n- **Vectorized calculation:** Stack the series (each reindexed to the intersection calendar) as columns in a DataFrame and perform single array multiplications and reductions for both weighted sum and sum of weights.\n- **Memory efficiency:** No large intermediate lists of Series; numpy operations work in-place.\n- **NaN handling:** Uses `np.nansum` and mask-based logic to sum weights only where actual values are present (mimics the original where NaNs would propagate).\n\n**Behavior is identical**: exceptions, input validation, and expected output/NaN-handling all preserved.",
        "optimization_id": str(uuid.uuid4()),
    },
    {
        "source_code": '# Copyright 2018 Goldman Sachs.\n# Licensed under the Apache License, Version 2.0 (the "License");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n# http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied. See the License for the\n# specific language governing permissions and limitations\n# under the License.\n#\n#\n# Chart Service will attempt to make public functions (not prefixed with _) from this module available. Such functions\n# should be fully documented: docstrings should describe parameters and the return value, and provide a 1-line\n# description. Type annotations should be provided for parameters.\n\nimport numpy as np\nimport pandas as pd\nfrom functools import reduce\nfrom gs_quant.errors import MqTypeError, MqValueError\nfrom gs_quant.timeseries.helper import plot_function\nfrom typing import List\n\n@plot_function\ndef weighted_sum(series: List[pd.Series], weights: list) -> pd.Series:\n    """\n    Calculate a weighted sum.\n\n    :param series: list of time series\n    :param weights: list of weights\n    :return: time series of weighted average\n\n    **Usage**\n\n    Calculate a weighted sum e.g. for a basket.\n\n    **Examples**\n\n    Generate price series and get a sum (weights 70%/30%).\n\n    >>> prices1 = generate_series(100)\n    >>> prices2 = generate_series(100)\n    >>> mybasket = weighted_sum([prices1, prices2], [0.7, 0.3])\n\n    **See also**\n\n    :func:`basket`\n    """\n    if not all(isinstance(x, pd.Series) for x in series):\n        raise MqTypeError("expected a list of time series")\n    if not all(isinstance(y, (float, int)) for y in weights):\n        raise MqTypeError("expected a list of number for weights")\n    if len(weights) != len(series):\n        raise MqValueError("must have one weight for each time series")\n\n    # Get the intersection of the calendars for all input series (index labels)\n    cal = pd.DatetimeIndex(\n        reduce(\n            np.intersect1d,\n            (curve.index for curve in series),\n        )\n    )\n\n    # Efficiently construct a DataFrame where columns are the series,\n    # then multiply by the weights and sum along the columns\n    # This avoids allocating an intermediate list of Series and summing pythonically\n    df = pd.concat([s.reindex(cal) for s in series], axis=1)\n    # use numpy array for weights: faster than creating Series and avoids sum(weights) recalculation\n    weights_arr = np.asarray(weights, dtype=np.float64)\n    weighted = df.values * weights_arr[np.newaxis, :]\n    weighted_sum_values = np.nansum(weighted, axis=1)\n    weights_broadcast = np.broadcast_to(weights_arr, df.shape)\n    weights_for_denominator = np.where(~np.isnan(df.values), weights_broadcast, 0.0)\n    weights_sum = np.nansum(weights_for_denominator, axis=1)\n    # Prevent division by zero; behaves as original since nansum returns 0 for all-nan, so original would return nan\n    result_values = np.where(weights_sum != 0, weighted_sum_values / weights_sum, np.nan)\n    result = pd.Series(result_values, index=cal)\n\n    return result\n',
        "explanation": "\n**Optimization Explanation**:\n- The calendar intersection now uses Python's set intersection mechanism, which is substantially faster than repeated calls to `np.intersect1d` via `reduce` for typical Pandas index sets.\n- The code now only constructs one output series for weights instead of creating and summing several unnecessary constant-valued Series. It uses NumPy's vectorized operations for arithmetic and summation, which are faster and more memory efficient than repeated Pandas Series arithmetic, especially for large time series.\n- Behavior is preserved for empty input lists and calendars.\n- The input and output types, return values, raised exceptions, and function signature remain unchanged. Comments are preserved and added only for new logic.",
        "optimization_id": str(uuid.uuid4()),
    },
]
async def _build_demo_response(
    ctx: BaseOptimizerContext,
    optimizations: list[dict[str, str]],
) -> OptimizeResponseSchema:
    """Build a canned OptimizeResponseSchema from a list of demo optimizations.

    :param ctx: optimizer context; only its optional ``file_name`` attribute is read.
    :param optimizations: canned entries with ``explanation``, ``optimization_id``
        and ``source_code`` keys.
    :return: response wrapping one OptimizeResponseItemSchema per entry.
    """
    # Fall back to a generic name when the context does not carry one.
    file_name = getattr(ctx, "file_name", "source.py")
    response_list: list[OptimizeResponseItemSchema] = [
        OptimizeResponseItemSchema(
            explanation=optimization["explanation"],
            optimization_id=optimization["optimization_id"],
            source_code=group_code({file_name: optimization["source_code"]}),
        )
        for optimization in optimizations
    ]
    # Simulated processing delay so the demo feels like a real optimization run.
    await asyncio.sleep(5)
    return OptimizeResponseSchema(optimizations=response_list)


async def hack_for_demo(ctx: BaseOptimizerContext) -> OptimizeResponseSchema:
    """Return the canned demo optimizations for ``find_common_tags``."""
    return await _build_demo_response(ctx, optimizations_json)


async def hack_for_demo_gsq(ctx: BaseOptimizerContext) -> OptimizeResponseSchema:
    """Return the canned demo optimizations for ``weighted_sum`` (gs-quant)."""
    return await _build_demo_response(ctx, optimizations_json_gsq)


async def try_demo_optimize(ctx: BaseOptimizerContext) -> OptimizeResponseSchema | None:
    """Return a canned demo response when the source matches a known demo function.

    :param ctx: optimizer context whose ``source_code`` is inspected.
    :return: the matching demo response, or None when demo hacks do not apply.
    """
    if not should_hack_for_demo(ctx.source_code):
        return None
    if "def find_common_tags(articles" in ctx.source_code:
        return await hack_for_demo(ctx)
    if "def weighted_sum(series" in ctx.source_code:
        return await hack_for_demo_gsq(ctx)
    return None

View file

@ -2,7 +2,6 @@ from __future__ import annotations
import asyncio
import logging
import uuid
from pathlib import Path
from typing import TYPE_CHECKING, Any
@ -13,7 +12,7 @@ from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUs
from pydantic import ValidationError
from aiservice.analytics.posthog import ph
from aiservice.common_utils import parse_python_version, should_hack_for_demo, validate_trace_id
from aiservice.common_utils import parse_python_version, validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable
from aiservice.llm import llm_client
from aiservice.llm_models import LLM, OPTIMIZE_MODEL
@ -22,7 +21,6 @@ from core.languages.python.optimizer.context_utils.optimizer_context import Base
from core.languages.python.optimizer.diff_patches_utils.diff import DiffMethod
from core.log_features.log_event import get_or_create_optimization_event
from core.log_features.log_features import log_features
from core.shared.context_helpers import group_code
from core.shared.optimizer_config import MAX_OPTIMIZER_CALLS, get_model_distribution
from core.shared.optimizer_models import OptimizedCandidateSource
from core.shared.optimizer_schemas import (
@ -37,76 +35,6 @@ if TYPE_CHECKING:
from authapp.auth import AuthenticatedRequest
from core.shared.optimizer_models import OptimizeSchema
optimizations_json = [
{
"source_code": 'from __future__ import annotations\n\n\ndef find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:\n if not articles:\n return set()\n\n common_tags = set(articles[0].get("tags", []))\n for article in articles[1:]:\n common_tags.intersection_update(article.get("tags", []))\n return common_tags\n',
"explanation": "The original algorithm repeatedly filters the `common_tags` list for every article, which can be slow. We can use Python sets to improve efficiency, especially with large lists.\n\nHere's the optimized version of your function.\n\n\n\n### Explanation of Optimizations.\n1. **Use of Sets**: Convert the initial list of tags to a set, which allows for more efficient intersection operations compared to list comprehensions.\n2. **Intersection Update**: Use the `intersection_update` method on sets which modifies the set in place, making it more memory efficient and faster than creating new lists and converting them to sets repeatedly.\n\nThis optimized version should perform significantly better, especially as the number of articles and tags increases.",
"optimization_id": str(uuid.uuid4()),
},
{
"source_code": 'from __future__ import annotations\n\n\ndef find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:\n if not articles:\n return set()\n\n common_tags = set(articles[0].get("tags", []))\n for article in articles[1:]:\n common_tags.intersection_update(article.get("tags", []))\n if not common_tags:\n break\n return common_tags\n',
"explanation": "To make the `find_common_tags` function run faster, we can leverage sets, which provide average O(1) time complexity for membership checks and O(n) for intersections. Here\u2019s a refactored version of your program.\n\n\n\nThis version initializes `common_tags` as a set and then iteratively intersects it with the tags of each subsequent article. The `intersection_update` method is used to update `common_tags` in place, which is more efficient. Additionally, it breaks early if `common_tags` becomes empty, which can save unnecessary computation.",
"optimization_id": str(uuid.uuid4()),
},
{
"source_code": 'def find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:\n if not articles:\n return set()\n\n common_tags = set(articles[0].get("tags", []))\n for article in articles[1:]:\n common_tags.intersection_update(article.get("tags", []))\n if not common_tags: # Early exit if no common tags left\n break\n return common_tags\n',
"explanation": "To optimize the runtime of this function, we can leverage set operations which are generally faster than list comprehensions for membership checks. By converting the tags to sets initially, the intersection operation becomes more efficient. Here's a faster version.\n\n\n\nChanges made.\n1. Convert the tags list of the first article to a set.\n2. Use `intersection_update` method to update the `common_tags` set with the intersection of the current tags and the next article's tags.\n3. Include an early exit condition to break the loop if no common tags remain, further optimizing runtime.",
"optimization_id": str(uuid.uuid4()),
},
{
"source_code": 'from __future__ import annotations\n\n\ndef find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:\n if not articles:\n return set()\n\n common_tags = set(articles[0].get("tags", []))\n for article in articles[1:]:\n common_tags &= set(article.get("tags", []))\n if not common_tags: # Early exit if no common tags.\n break\n return common_tags\n',
"explanation": "To optimize the provided function, we could enhance its efficiency by using set operations which are typically faster for membership checks compared to list comprehensions.\n\nHere\u2019s the optimized version.\n\n\n\nExplanation.\n1. Convert the tags of the first article into a set to take advantage of fast membership checks and intersection operations.\n2. Use the `&=` operation to find the intersection with the tags of each subsequent article.\n3. Introduce an early exit condition: if `common_tags` becomes empty, it's immediately returned since no further intersection can result in common tags.",
"optimization_id": str(uuid.uuid4()),
},
]
optimizations_json_gsq = [
{
"source_code": '# Copyright 2018 Goldman Sachs.\n# Licensed under the Apache License, Version 2.0 (the "License");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n# http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied. See the License for the\n# specific language governing permissions and limitations\n# under the License.\n#\n#\n# Chart Service will attempt to make public functions (not prefixed with _) from this module available. Such functions\n# should be fully documented: docstrings should describe parameters and the return value, and provide a 1-line\n# description. Type annotations should be provided for parameters.\n\n\n\n\nimport numpy as np\nimport pandas as pd\nfrom functools import reduce\nfrom gs_quant.errors import MqTypeError, MqValueError\nfrom gs_quant.timeseries.helper import plot_function\nfrom typing import List\n\n@plot_function\ndef weighted_sum(series: List[pd.Series], weights: list) -> pd.Series:\n """\n Calculate a weighted sum.\n\n :param series: list of time series\n :param weights: list of weights\n :return: time series of weighted average\n\n **Usage**\n\n Calculate a weighted sum e.g. 
for a basket.\n\n **Examples**\n\n Generate price series and get a sum (weights 70%/30%).\n\n >>> prices1 = generate_series(100)\n >>> prices2 = generate_series(100)\n >>> mybasket = weighted_sum([prices1, prices2], [0.7, 0.3])\n\n **See also**\n\n :func:`basket`\n """\n if not all(isinstance(x, pd.Series) for x in series):\n raise MqTypeError("expected a list of time series")\n if not all(isinstance(y, (float, int)) for y in weights):\n raise MqTypeError("expected a list of number for weights")\n if len(weights) != len(series):\n raise MqValueError("must have one weight for each time series")\n\n # For input series, get the intersection of their calendars\n # Instead of reduce(np.intersect1d, ...), use set intersection for better performance\n idx_iter = (curve.index for curve in series)\n idx0 = next(idx_iter)\n cal = set(idx0)\n for idx in idx_iter:\n cal &= set(idx)\n cal = pd.DatetimeIndex(sorted(cal))\n\n # Vectorized calculations using numpy arrays\n if len(series) == 0 or len(cal) == 0:\n # Edge case: empty data\n weights_arr = np.array([], dtype=float)\n values_arr = np.array([], dtype=float)\n weighted_sum_arr = np.array([])\n sum_weights_arr = np.array([])\n return pd.Series(weighted_sum_arr, index=cal)\n\n # Use pd.concat for batch reindex to avoid overhead of python list comprehensions\n series_concat = pd.concat([s.reindex(cal) for s in series], axis=1)\n values_arr = series_concat.values # shape (n_dates, n_series)\n weights_arr = np.asarray(weights, dtype=float).reshape(1, -1) # shape (1, n_series)\n weights_matrix = np.broadcast_to(weights_arr, values_arr.shape) # shape (n_dates, n_series)\n\n # Weighted sum and denominator\n weighted_sum_arr = np.nansum(values_arr * weights_matrix, axis=1)\n sum_weights_arr = np.nansum(weights_matrix * ~np.isnan(values_arr), axis=1)\n\n # Avoid divide-by-zero; if all weights are nan for a row, the sum is nan\n with np.errstate(invalid=\'ignore\', divide=\'ignore\'):\n result_arr = weighted_sum_arr / 
sum_weights_arr\n result_arr[sum_weights_arr == 0] = np.nan\n\n return pd.Series(result_arr, index=cal)\n',
"explanation": "\n\n### Key optimizations:\n- **Intersection**: For many series, `.intersection_many` (if available, see pandas 1.5+) is significantly faster than chaining `np.intersect1d`. If not available, this safely falls back to repeated intersection.\n- **Data alignment**: Avoided creating a separate `pd.Series` for each weight. Instead, using NumPy arrays, broadcast the weights and handle the reindexing in an efficient 2D array.\n- **Sum computation**: The product and sums are performed with batched NumPy operations eliminating Python-level for-loops and overhead from generator comprehensions. This avoids repeatedly constructing/interpreting Series objects.\n- **NaN handling**: Handles missing data exactly as in the original by multiplying mask logic accordingly.\n- **Handles empty and no-overlap**: Preserves the original behavior for empty or non-overlapping indexes, returning an all-NaN, correct-indexed Series.\n\n**NB:** This preserves all error handling, exceptions, and the function signature, matching the requirements precisely, while delivering **substantial speedups** for large and/or many input series.",
"optimization_id": str(uuid.uuid4()),
},
{
"source_code": '''# Copyright 2018 Goldman Sachs.\n# Licensed under the Apache License, Version 2.0 (the "License");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n# http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied. See the License for the\n# specific language governing permissions and limitations\n# under the License.\n#\n#\n# Chart Service will attempt to make public functions (not prefixed with _) from this module available. Such functions\n# should be fully documented: docstrings should describe parameters and the return value, and provide a 1-line\n# description. Type annotations should be provided for parameters.\nimport numpy as np\nimport pandas as pd\nfrom functools import reduce\nfrom gs_quant.errors import MqTypeError, MqValueError\nfrom gs_quant.timeseries.helper import plot_function\nfrom typing import List\n@plot_function\ndef weighted_sum(series: List[pd.Series], weights: list) -> pd.Series:\n """\n Calculate a weighted sum.\n :param series: list of time series\n :param weights: list of weights\n :return: time series of weighted average\n **Usage**\n Calculate a weighted sum e.g. 
for a basket.\n **Examples**\n Generate price series and get a sum (weights 70%/30%).\n >>> prices1 = generate_series(100)\n >>> prices2 = generate_series(100)\n >>> mybasket = weighted_sum([prices1, prices2], [0.7, 0.3])\n **See also**\n :func:`basket`\n """\n if not all(isinstance(x, pd.Series) for x in series):\n raise MqTypeError("expected a list of time series")\n if not all(isinstance(y, (float, int)) for y in weights):\n raise MqTypeError("expected a list of number for weights")\n if len(weights) != len(series):\n raise MqValueError("must have one weight for each time series")\n # for input series, get the intersection of their calendars\n cal = pd.DatetimeIndex(\n reduce(\n np.intersect1d,\n (\n curve.index\n for curve in series\n ),\n )\n )\n # reindex inputs and calculate using numpy vectorization\n series_arrays = [s.reindex(cal).to_numpy() for s in series]\n weights_array = np.asarray(weights, dtype=float).reshape(-1, 1)\n stacked_series = np.vstack(series_arrays)\n weighted_sum_array = np.sum(stacked_series * weights_array, axis=0)\n sum_weights = np.sum(weights_array)\n return pd.Series(weighted_sum_array / sum_weights, index=cal)\n''',
"explanation": "\n**Key optimizations explained:**\n- **Index intersection speedup:** Instead of chaining `np.intersect1d` (which is O(N^2) and re-sorts at each op), use `set.intersection_update` (O(N)) which is drastically faster for longer indices.\n- **Avoid unnecessary repeated object construction:** Instead of building `[pd.Series(w, index=cal) for w in weights]` (which allocates a new Series for every weight/date combination), use vectorized numpy and pandas operations.\n- **Vectorized calculation:** Stack the series (each reindexed to the intersection calendar) as columns in a DataFrame and perform single array multiplications and reductions for both weighted sum and sum of weights.\n- **Memory efficiency:** No large intermediate lists of Series; numpy operations work in-place.\n- **NaN handling:** Uses `np.nansum` and mask-based logic to sum weights only where actual values are present (mimics the original where NaNs would propagate).\n\n**Behavior is identical**: exceptions, input validation, and expected output/NaN-handling all preserved.",
"optimization_id": str(uuid.uuid4()),
},
{
"source_code": '# Copyright 2018 Goldman Sachs.\n# Licensed under the Apache License, Version 2.0 (the "License");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n# http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied. See the License for the\n# specific language governing permissions and limitations\n# under the License.\n#\n#\n# Chart Service will attempt to make public functions (not prefixed with _) from this module available. Such functions\n# should be fully documented: docstrings should describe parameters and the return value, and provide a 1-line\n# description. Type annotations should be provided for parameters.\n\nimport numpy as np\nimport pandas as pd\nfrom functools import reduce\nfrom gs_quant.errors import MqTypeError, MqValueError\nfrom gs_quant.timeseries.helper import plot_function\nfrom typing import List\n\n@plot_function\ndef weighted_sum(series: List[pd.Series], weights: list) -> pd.Series:\n """\n Calculate a weighted sum.\n\n :param series: list of time series\n :param weights: list of weights\n :return: time series of weighted average\n\n **Usage**\n\n Calculate a weighted sum e.g. 
for a basket.\n\n **Examples**\n\n Generate price series and get a sum (weights 70%/30%).\n\n >>> prices1 = generate_series(100)\n >>> prices2 = generate_series(100)\n >>> mybasket = weighted_sum([prices1, prices2], [0.7, 0.3])\n\n **See also**\n\n :func:`basket`\n """\n if not all(isinstance(x, pd.Series) for x in series):\n raise MqTypeError("expected a list of time series")\n if not all(isinstance(y, (float, int)) for y in weights):\n raise MqTypeError("expected a list of number for weights")\n if len(weights) != len(series):\n raise MqValueError("must have one weight for each time series")\n\n # Get the intersection of the calendars for all input series (index labels)\n cal = pd.DatetimeIndex(\n reduce(\n np.intersect1d,\n (curve.index for curve in series),\n )\n )\n\n # Efficiently construct a DataFrame where columns are the series,\n # then multiply by the weights and sum along the columns\n # This avoids allocating an intermediate list of Series and summing pythonically\n df = pd.concat([s.reindex(cal) for s in series], axis=1)\n # use numpy array for weights: faster than creating Series and avoids sum(weights) recalculation\n weights_arr = np.asarray(weights, dtype=np.float64)\n weighted = df.values * weights_arr[np.newaxis, :]\n weighted_sum_values = np.nansum(weighted, axis=1)\n weights_broadcast = np.broadcast_to(weights_arr, df.shape)\n weights_for_denominator = np.where(~np.isnan(df.values), weights_broadcast, 0.0)\n weights_sum = np.nansum(weights_for_denominator, axis=1)\n # Prevent division by zero; behaves as original since nansum returns 0 for all-nan, so original would return nan\n result_values = np.where(weights_sum != 0, weighted_sum_values / weights_sum, np.nan)\n result = pd.Series(result_values, index=cal)\n\n return result\n',
"explanation": "\n**Optimization Explanation**:\n- The calendar intersection now uses Python's set intersection mechanism, which is substantially faster than repeated calls to `np.intersect1d` via `reduce` for typical Pandas index sets.\n- The code now only constructs one output series for weights instead of creating and summing several unnecessary constant-valued Series. It uses NumPy's vectorized operations for arithmetic and summation, which are faster and more memory efficient than repeated Pandas Series arithmetic, especially for large time series.\n- Behavior is preserved for empty input lists and calendars.\n- The input and output types, return values, raised exceptions, and function signature remain unchanged. Comments are preserved and added only for new logic.",
"optimization_id": str(uuid.uuid4()),
},
]
async def hack_for_demo(ctx: BaseOptimizerContext) -> OptimizeResponseSchema:
file_name = getattr(ctx, "file_name", "source.py")
response_list: list[OptimizeResponseItemSchema] = [
OptimizeResponseItemSchema(
explanation=optimization["explanation"],
optimization_id=optimization["optimization_id"],
source_code=group_code({file_name: optimization["source_code"]}),
)
for optimization in optimizations_json
]
await asyncio.sleep(5)
return OptimizeResponseSchema(optimizations=response_list)
async def hack_for_demo_gsq(ctx: BaseOptimizerContext) -> OptimizeResponseSchema:
file_name = getattr(ctx, "file_name", "source.py")
response_list: list[OptimizeResponseItemSchema] = [
OptimizeResponseItemSchema(
explanation=optimization["explanation"],
optimization_id=optimization["optimization_id"],
source_code=group_code({file_name: optimization["source_code"]}),
)
for optimization in optimizations_json_gsq
]
await asyncio.sleep(5)
return OptimizeResponseSchema(optimizations=response_list)
# Get the directory of the current file
current_dir = Path(__file__).parent
SYSTEM_PROMPT = (current_dir / "system_prompt.md").read_text()
@ -297,12 +225,10 @@ async def optimize_python(
sentry_sdk.capture_exception(e)
return e.status_code, OptimizeErrorResponseSchema(error=e.message)
if should_hack_for_demo(ctx.source_code):
response_code = 200
if "def find_common_tags(articles" in ctx.source_code:
response = await hack_for_demo(ctx)
elif "def weighted_sum(series" in ctx.source_code:
response = await hack_for_demo_gsq(ctx)
from core.languages.python.optimizer.demo_hacks import try_demo_optimize # noqa: PLC0415
demo_response = await try_demo_optimize(ctx)
if demo_response is not None:
async with asyncio.TaskGroup() as tg:
event_task = tg.create_task(
get_or_create_optimization_event(
@ -322,9 +248,9 @@ async def optimize_python(
)
)
event, _created = event_task.result()
for item in response.optimizations:
for item in demo_response.optimizations:
item.optimization_event_id = str(event.id) if event else None
return response_code, response
return 200, demo_response
try:
async with asyncio.TaskGroup() as tg:

View file

@ -10,12 +10,12 @@ from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUs
from aiservice.analytics.posthog import ph
from aiservice.common.markdown_utils import split_markdown_code
from aiservice.common_utils import parse_python_version, should_hack_for_demo_java, validate_trace_id
from aiservice.common_utils import parse_python_version, validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable
from aiservice.llm import llm_client
from aiservice.llm_models import OPTIMIZE_MODEL
from aiservice.validators.javascript_validator import validate_javascript_syntax, validate_typescript_syntax
from core.languages.java.optimizer_lp import hack_for_demo_java_lp, optimize_java_code_line_profiler
from core.languages.java.optimizer_lp import optimize_java_code_line_profiler
from core.languages.js_ts.context_helpers import is_multi_context_js, is_multi_context_ts
from core.languages.js_ts.optimizer_lp import optimize_javascript_code_line_profiler
from core.languages.python.optimizer.context_utils.optimizer_context import BaseOptimizerContext
@ -274,12 +274,12 @@ async def optimize(request, data: OptimizeSchemaLP) -> tuple[int, OptimizeRespon
elif language == "java":
# Java path
from aiservice.validators.java_validator import validate_java_syntax # noqa: PLC0415
from core.languages.java.demo_hacks import try_demo_optimize_java_lp # noqa: PLC0415
from core.languages.java.optimizer import is_multi_context_java # noqa: PLC0415
# Demo hack shortcut
if should_hack_for_demo_java(data.source_code):
response = await hack_for_demo_java_lp(data.source_code)
return 200, response
demo_response = await try_demo_optimize_java_lp(data.source_code)
if demo_response is not None:
return 200, demo_response
is_multi_file = is_multi_context_java(data.source_code)

View file

@ -6,11 +6,11 @@ import logging
import re
import tokenize
from difflib import SequenceMatcher
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast
import libcst as cst
import sentry_sdk
from libcst import CSTTransformer, CSTVisitor, Expr, IndentedBlock, SimpleStatementLine, SimpleString
from libcst import BaseStatement, CSTTransformer, CSTVisitor, Expr, IndentedBlock, SimpleStatementLine, SimpleString
from core.languages.python.cst_utils import compare_unparsed_ast_to_source, parse_module_to_cst, unparse_parse_source
from aiservice.common_utils import safe_isort
@ -110,8 +110,8 @@ def cleanup_explanations(
class DocstringVisitor(CSTVisitor):
def __init__(self) -> None:
self.original_docstrings = {}
self.class_name = None
self.original_docstrings: dict[str, str] = {}
self.class_name: str | None = None
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
self.class_name = node.name.value
@ -133,9 +133,9 @@ class DocstringVisitor(CSTVisitor):
class DocstringTransformer(CSTTransformer):
def __init__(self, original_docstrings: dict[str, str | None]) -> None:
def __init__(self, original_docstrings: dict[str, str]) -> None:
self.original_docstrings = original_docstrings
self.class_name = None
self.class_name: str | None = None
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
self.class_name = node.name.value
@ -145,15 +145,15 @@ class DocstringTransformer(CSTTransformer):
original_docstring = self.original_docstrings.get(self.class_name) if self.class_name else None
if original_docstring:
if not updated_node.get_docstring(clean=False):
new_body = [
new_body: list[BaseStatement] = [
SimpleStatementLine(body=[Expr(value=SimpleString(f'"""{original_docstring}"""'))]),
*list(updated_node.body.body),
*cast(list[BaseStatement], list(updated_node.body.body)),
]
updated_node = updated_node.with_changes(body=IndentedBlock(body=new_body))
else:
new_body = [
SimpleStatementLine(body=[Expr(value=SimpleString(f'"""{original_docstring}"""'))]),
*list(updated_node.body.body[1:]),
*cast(list[BaseStatement], list(updated_node.body.body[1:])),
]
updated_node = updated_node.with_changes(body=IndentedBlock(body=new_body))
self.class_name = None
@ -165,15 +165,15 @@ class DocstringTransformer(CSTTransformer):
original_docstring = self.original_docstrings.get(qualified_name)
if original_docstring:
if not updated_node.get_docstring(clean=False):
new_body = [
new_body: list[BaseStatement] = [
SimpleStatementLine(body=[Expr(value=SimpleString(f'"""{original_docstring}"""'))]),
*list(updated_node.body.body),
*cast(list[BaseStatement], list(updated_node.body.body)),
]
updated_node = updated_node.with_changes(body=IndentedBlock(body=new_body))
else:
new_body = [
SimpleStatementLine(body=[Expr(value=SimpleString(f'"""{original_docstring}"""'))]),
*list(updated_node.body.body[1:]),
*cast(list[BaseStatement], list(updated_node.body.body[1:])),
]
updated_node = updated_node.with_changes(body=IndentedBlock(body=new_body))
return updated_node
@ -259,12 +259,13 @@ def _strip_comments_from_code(code: str) -> str:
The same code with all comments removed, preserving string content
"""
# add comment here
try:
lines = code.splitlines(keepends=True)
tokens = tokenize.generate_tokens(io.StringIO(code).readline)
# Build a per-line map of comment ranges to remove
comments_by_line = [[] for _ in range(len(lines))]
comments_by_line: list[list[tuple[int, int]]] = [[] for _ in range(len(lines))]
for token in tokens:
if token.type == tokenize.COMMENT:
line_idx = token.start[0] - 1
@ -445,7 +446,7 @@ def clean_extraneous_comments(original_module: cst.Module, optimized_module: cst
if orig_line_idx is not None and orig_line_idx > 0 and opt_idx not in code_changed_lines:
# Look backwards for ALL consecutive comment-only or blank lines
# Collect them all, then decide which ones to restore
preceding_lines = []
preceding_lines: list[str] = []
check_idx = orig_line_idx - 1
while check_idx >= 0:
@ -466,7 +467,8 @@ def clean_extraneous_comments(original_module: cst.Module, optimized_module: cst
# Use the original line (preserves original comments or lack thereof)
result_lines.append(found_orig)
restored_orig_indices.add(orig_line_idx)
if orig_line_idx is not None:
restored_orig_indices.add(orig_line_idx)
# Also check for trailing blank/comment lines after this line that were removed
# BUT: only restore them if this code line is UNCHANGED and they don't come before a changed line

View file

@ -5,12 +5,25 @@ from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING
from aiservice.common_utils import should_hack_for_demo
from .validate import instrument_tests
if TYPE_CHECKING:
from core.shared.testgen_models import TestGenResponseSchema, TestGenSchema
async def try_demo_testgen(data: TestGenSchema, python_version: tuple[int, int, int]) -> TestGenResponseSchema | None:
"""Return a canned demo response if source matches a demo function, else None."""
if not should_hack_for_demo(data.source_code_being_tested):
return None
if "find_common_tags" in data.source_code_being_tested:
return await hack_for_demo(data, python_version)
if "weighted_sum" in data.source_code_being_tested:
return await hack_for_demo_gsq(data, python_version)
return None
async def hack_for_demo(data: TestGenSchema, python_version: tuple[int, int, int]) -> TestGenResponseSchema:
from core.shared.testgen_models import TestGenResponseSchema

View file

@ -12,14 +12,14 @@ from ninja.errors import HttpError
from aiservice.analytics.posthog import ph
from aiservice.common.markdown_utils import extract_code_block, split_markdown_code
from aiservice.common_utils import safe_isort, should_hack_for_demo
from aiservice.common_utils import safe_isort
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.llm import llm_client
from aiservice.llm_models import EXECUTE_MODEL, HAIKU_MODEL, OPENAI_MODEL
from aiservice.models.functions_to_optimize import FunctionToOptimize
from core.languages.python.cst_utils import any_ellipsis_in_cst, ellipsis_in_cst_not_types, parse_module_to_cst
from core.languages.python.optimizer.context_utils.context_helpers import is_multi_context
from core.languages.python.testgen.demo_hacks import hack_for_demo, hack_for_demo_gsq
from core.languages.python.testgen.demo_hacks import try_demo_testgen
from core.languages.python.testgen.models import CostTracker, LLMOutputParseError
from core.languages.python.testgen.postprocessing.code_validator import (
CodeValidationError,
@ -364,12 +364,8 @@ async def testgen_python(
sentry_sdk.capture_exception(e)
return e.status_code, TestGenErrorResponseSchema(error=e.message)
if should_hack_for_demo(data.source_code_being_tested):
if "find_common_tags" in data.source_code_being_tested:
demo_hack_response = await hack_for_demo(data, python_version)
elif "weighted_sum" in data.source_code_being_tested:
demo_hack_response = await hack_for_demo_gsq(data, python_version)
return 200, demo_hack_response
if (demo_result := await try_demo_testgen(data, python_version)) is not None:
return 200, demo_result
logging.info("/testgen: Generating tests...")
try: