fix codeflash optimizing python backend (#2483)
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
28c9acc877
commit
387c909c9e
14 changed files with 1349 additions and 1247 deletions
26
.github/workflows/codeflash-aiservice.yaml
vendored
26
.github/workflows/codeflash-aiservice.yaml
vendored
|
|
@ -20,7 +20,6 @@ concurrency:
|
|||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
# This job checks if the workflow should run based on file changes
|
||||
check-changes:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
|
|
@ -36,7 +35,6 @@ jobs:
|
|||
aiservice:
|
||||
- 'django/aiservice/**'
|
||||
|
||||
# This job always runs and succeeds, allowing PRs to be merged when paths don't match
|
||||
no-aiservice-changes:
|
||||
name: No aiservice changes detected
|
||||
needs: check-changes
|
||||
|
|
@ -59,7 +57,6 @@ jobs:
|
|||
CODEFLASH_PR_NUMBER: ${{ github.event.number }}
|
||||
DATABASE_URL: ${{ secrets.DATABASE_URL }}
|
||||
DJANGO_SETTINGS_MODULE: aiservice.settings
|
||||
|
||||
COLUMNS: 110
|
||||
|
||||
steps:
|
||||
|
|
@ -68,27 +65,16 @@ jobs:
|
|||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Set up Python
|
||||
- name: Set up uv
|
||||
uses: astral-sh/setup-uv@v7
|
||||
with:
|
||||
python-version: "3.12"
|
||||
enable-cache: true
|
||||
|
||||
- name: Install Project Dependencies
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv sync --refresh --isolated --active
|
||||
|
||||
- name: Install Codeflash in separate venv
|
||||
working-directory: .
|
||||
run: |
|
||||
uv venv .codeflash-venv
|
||||
source .codeflash-venv/bin/activate
|
||||
uv pip install pytest-asyncio black
|
||||
uv sync
|
||||
uv pip install git+https://github.com/codeflash-ai/codeflash@main
|
||||
|
||||
- name: Run Codeflash to optimize code
|
||||
id: optimize_code
|
||||
working-directory: .
|
||||
run: |
|
||||
source .codeflash-venv/bin/activate
|
||||
cd django/aiservice
|
||||
codeflash --async --verbose
|
||||
- name: Run Codeflash
|
||||
run: uv run codeflash
|
||||
|
|
|
|||
|
|
@ -1,12 +1,16 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import isort
|
||||
|
||||
ENABLE_DEMO_HACKS = os.getenv("ENABLE_DEMO_HACKS", "").lower() in ("1", "true", "yes")
|
||||
|
||||
def safe_isort(code: str, **kwargs) -> str: # noqa: ANN003
|
||||
|
||||
def safe_isort(code: str, **kwargs: Any) -> str:
|
||||
"""Wrap isort.code to returns the original code if isort fails.
|
||||
|
||||
Args:
|
||||
|
|
@ -73,10 +77,14 @@ def is_codeflash_employee(user_id: str) -> bool:
|
|||
|
||||
|
||||
def should_hack_for_demo(source_code: str) -> bool:
|
||||
if not ENABLE_DEMO_HACKS:
|
||||
return False
|
||||
return bool("def find_common_tags(articles" in source_code) or bool("def weighted_sum(series" in source_code)
|
||||
|
||||
|
||||
def should_hack_for_demo_java(source_code: str) -> bool:
|
||||
if not ENABLE_DEMO_HACKS:
|
||||
return False
|
||||
if "byte[] readFile(File" in source_code and "FileInputStream" in source_code:
|
||||
return True
|
||||
if "class Host" in source_code and "this.name.equals(other.name)" in source_code:
|
||||
|
|
@ -85,4 +93,6 @@ def should_hack_for_demo_java(source_code: str) -> bool:
|
|||
|
||||
|
||||
def is_host_equals_demo(source_code: str) -> bool:
|
||||
if not ENABLE_DEMO_HACKS:
|
||||
return False
|
||||
return "class Host" in source_code and "this.name.equals(other.name)" in source_code
|
||||
|
|
|
|||
1157
django/aiservice/core/languages/java/demo_hacks.py
Normal file
1157
django/aiservice/core/languages/java/demo_hacks.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -15,7 +15,7 @@ from ninja.errors import HttpError
|
|||
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
|
||||
|
||||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common_utils import is_host_equals_demo, should_hack_for_demo_java, validate_trace_id
|
||||
from aiservice.common_utils import validate_trace_id
|
||||
from aiservice.env_specific import debug_log_sensitive_data
|
||||
from aiservice.llm import llm_client
|
||||
from aiservice.llm_models import LLM, OPTIMIZE_MODEL
|
||||
|
|
@ -50,234 +50,6 @@ def is_multi_context_java(source_code: str) -> bool:
|
|||
return source_code.count("```java:") >= 1
|
||||
|
||||
|
||||
def _extract_demo_context(source_code: str) -> tuple[str, str, str, str]:
|
||||
"""Extract package, class name, exception type, and extra imports from the demo source code.
|
||||
|
||||
Returns:
|
||||
Tuple of (package_declaration, class_name, throw_statement_prefix, extra_imports)
|
||||
|
||||
"""
|
||||
# Extract the raw code from markdown block if present
|
||||
code_match = re.search(r"```java:[^\n]*\n(.*?)```", source_code, re.DOTALL)
|
||||
raw_code = code_match.group(1) if code_match else source_code
|
||||
|
||||
# Extract package
|
||||
pkg_match = re.search(r"^\s*package\s+([\w.]+)\s*;", raw_code, re.MULTILINE)
|
||||
package_decl = f"package {pkg_match.group(1)};\n" if pkg_match else ""
|
||||
|
||||
# Extract class name
|
||||
class_match = re.search(r"\bclass\s+(\w+)", raw_code)
|
||||
class_name = class_match.group(1) if class_match else "FileUtils"
|
||||
|
||||
# Extract exception type from the throw statement (e.g., "throw new AerospikeException(...)")
|
||||
throw_match = re.search(r"throw\s+new\s+(\w+)\s*\(", raw_code)
|
||||
exception_type = throw_match.group(1) if throw_match else "RuntimeException"
|
||||
|
||||
# Collect extra imports needed for the exception type (skip standard java/javax)
|
||||
extra_imports = ""
|
||||
if exception_type != "RuntimeException":
|
||||
import_match = re.search(rf"^\s*import\s+([\w.]*\.{re.escape(exception_type)})\s*;", raw_code, re.MULTILINE)
|
||||
if import_match:
|
||||
extra_imports = f"import {import_match.group(1)};\n"
|
||||
|
||||
return package_decl, class_name, exception_type, extra_imports
|
||||
|
||||
|
||||
def _build_demo_optimizations(
|
||||
package_decl: str, class_name: str, exception_type: str, extra_imports: str
|
||||
) -> list[dict[str, str]]:
|
||||
"""Build 2 demo optimization candidates using the extracted class context.
|
||||
|
||||
Candidate 1 (Files.readAllBytes) is the intended winner — it benchmarks fastest.
|
||||
Candidate 2 is a plausible alternative that is functionally correct but
|
||||
benchmarks slightly slower, ensuring Files.readAllBytes wins the speedup critic.
|
||||
"""
|
||||
fmt = dict(
|
||||
package_decl=package_decl, class_name=class_name, exception_type=exception_type, extra_imports=extra_imports
|
||||
)
|
||||
|
||||
return [
|
||||
# Candidate 2: FileInputStream.readAllBytes() (Java 9+)
|
||||
{
|
||||
"source_code": (
|
||||
"{package_decl}"
|
||||
"\n"
|
||||
"import java.io.File;\n"
|
||||
"import java.io.FileInputStream;\n"
|
||||
"{extra_imports}"
|
||||
"\n"
|
||||
"public final class {class_name} {{\n"
|
||||
" public static byte[] readFile(File file) {{\n"
|
||||
" try (FileInputStream fis = new FileInputStream(file)) {{\n"
|
||||
" return fis.readAllBytes();\n"
|
||||
" }}\n"
|
||||
" catch (Throwable e) {{\n"
|
||||
' throw new {exception_type}("Failed to read " + file.getAbsolutePath(), e);\n'
|
||||
" }}\n"
|
||||
" }}\n"
|
||||
"}}"
|
||||
).format(**fmt),
|
||||
"explanation": (
|
||||
"Use FileInputStream.readAllBytes() (Java 9+) to read the entire file in one call. "
|
||||
"This eliminates the manual read loop but still uses FileInputStream internally."
|
||||
),
|
||||
"optimization_id": str(uuid.uuid4()),
|
||||
},
|
||||
# Candidate 1: Files.readAllBytes (THE WINNER)
|
||||
{
|
||||
"source_code": (
|
||||
"{package_decl}"
|
||||
"\n"
|
||||
"import java.io.File;\n"
|
||||
"import java.nio.file.Files;\n"
|
||||
"{extra_imports}"
|
||||
"\n"
|
||||
"public final class {class_name} {{\n"
|
||||
" public static byte[] readFile(File file) {{\n"
|
||||
" try {{\n"
|
||||
" return java.nio.file.Files.readAllBytes(file.toPath());\n"
|
||||
" }}\n"
|
||||
" catch (Throwable e) {{\n"
|
||||
' throw new {exception_type}("Failed to read " + file.getAbsolutePath(), e);\n'
|
||||
" }}\n"
|
||||
" }}\n"
|
||||
"}}"
|
||||
).format(**fmt),
|
||||
"explanation": (
|
||||
"Replace manual FileInputStream read loop with java.nio.file.Files.readAllBytes(). "
|
||||
"This NIO method is optimized at the JDK level for direct file-to-byte-array transfer, "
|
||||
"eliminating manual buffering and loop overhead."
|
||||
),
|
||||
"optimization_id": str(uuid.uuid4()),
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def _build_host_equals_demo_optimizations(source_code: str) -> list[dict[str, str]]:
|
||||
"""Build 5 optimization candidates for Host.equals by reordering comparisons.
|
||||
|
||||
Candidate 1 (port-first early return) is the intended winner — comparing the
|
||||
primitive int port before the String name avoids unnecessary method dispatch.
|
||||
"""
|
||||
code_match = re.search(r"```java:[^\n]*\n(.*?)```", source_code, re.DOTALL)
|
||||
raw_code = code_match.group(1) if code_match else source_code
|
||||
|
||||
# Match: return this.name.equals(other.name) && this.port == other.port;
|
||||
original_stmt = re.compile(
|
||||
r"(\s*)return\s+this\.name\.equals\(other\.name\)\s*&&\s*this\.port\s*==\s*other\.port\s*;"
|
||||
)
|
||||
|
||||
match = original_stmt.search(raw_code)
|
||||
if not match:
|
||||
return [
|
||||
{
|
||||
"source_code": raw_code,
|
||||
"explanation": "No optimization applicable.",
|
||||
"optimization_id": str(uuid.uuid4()),
|
||||
}
|
||||
]
|
||||
|
||||
indent = match.group(1)
|
||||
inner = indent + " "
|
||||
|
||||
def replace_with(replacement: str) -> str:
|
||||
return original_stmt.sub(replacement, raw_code)
|
||||
|
||||
return [
|
||||
# Candidate 1 (WINNER): Port-first early return
|
||||
{
|
||||
"source_code": replace_with(
|
||||
f"{indent}// Compare primitive port first to avoid unnecessary string equals calls.\n"
|
||||
f"{indent}if (this.port != other.port) {{\n"
|
||||
f"{inner}return false;\n"
|
||||
f"{indent}}}\n"
|
||||
f"{indent}return this.name.equals(other.name);"
|
||||
),
|
||||
"explanation": (
|
||||
"Compare primitive port first to avoid unnecessary string equals calls. "
|
||||
"Integer comparison is a single CPU instruction, while String.equals() "
|
||||
"involves method dispatch and potential character-by-character comparison."
|
||||
),
|
||||
"optimization_id": str(uuid.uuid4()),
|
||||
},
|
||||
# Candidate 2: Reordered conjunction (port first in &&)
|
||||
{
|
||||
"source_code": replace_with(f"{indent}return this.port == other.port && this.name.equals(other.name);"),
|
||||
"explanation": (
|
||||
"Reorder the conjunction to evaluate the cheaper primitive int comparison first. "
|
||||
"Short-circuit evaluation skips String.equals() when ports differ."
|
||||
),
|
||||
"optimization_id": str(uuid.uuid4()),
|
||||
},
|
||||
# Candidate 3: Port-first with Objects.equals for null safety
|
||||
{
|
||||
"source_code": replace_with(
|
||||
f"{indent}if (this.port != other.port) {{\n"
|
||||
f"{inner}return false;\n"
|
||||
f"{indent}}}\n"
|
||||
f"{indent}return java.util.Objects.equals(this.name, other.name);"
|
||||
),
|
||||
"explanation": (
|
||||
"Check port first (cheap primitive comparison), then use Objects.equals() "
|
||||
"for null-safe name comparison. Adds safety at slight method-call overhead."
|
||||
),
|
||||
"optimization_id": str(uuid.uuid4()),
|
||||
},
|
||||
# Candidate 4: Ternary with port-first guard
|
||||
{
|
||||
"source_code": replace_with(
|
||||
f"{indent}return this.port == other.port ? this.name.equals(other.name) : false;"
|
||||
),
|
||||
"explanation": (
|
||||
"Use a ternary to short-circuit on port mismatch. "
|
||||
"Evaluates the cheap int comparison first, only calling String.equals() when ports match."
|
||||
),
|
||||
"optimization_id": str(uuid.uuid4()),
|
||||
},
|
||||
# Candidate 5: Explicit null guard + port first
|
||||
{
|
||||
"source_code": replace_with(
|
||||
f"{indent}if (this.port != other.port) {{\n"
|
||||
f"{inner}return false;\n"
|
||||
f"{indent}}}\n"
|
||||
f"{indent}if (this.name == null) {{\n"
|
||||
f"{inner}return other.name == null;\n"
|
||||
f"{indent}}}\n"
|
||||
f"{indent}return this.name.equals(other.name);"
|
||||
),
|
||||
"explanation": (
|
||||
"Guard on port first, then add explicit null handling for the name field "
|
||||
"before delegating to String.equals(). Avoids potential NullPointerException."
|
||||
),
|
||||
"optimization_id": str(uuid.uuid4()),
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def hack_for_demo_java(source_code: str) -> OptimizeResponseSchema:
|
||||
# Extract file path from markdown source (```java:path/to/File.java)
|
||||
file_path_match = re.search(r"```java:([^\n]+)", source_code)
|
||||
file_name = file_path_match.group(1).strip() if file_path_match else "Source.java"
|
||||
|
||||
if is_host_equals_demo(source_code):
|
||||
optimizations = _build_host_equals_demo_optimizations(source_code)
|
||||
else:
|
||||
# Extract class context dynamically from the source code
|
||||
package_decl, class_name, exception_type, extra_imports = _extract_demo_context(source_code)
|
||||
optimizations = _build_demo_optimizations(package_decl, class_name, exception_type, extra_imports)
|
||||
|
||||
response_list: list[OptimizeResponseItemSchema] = [
|
||||
OptimizeResponseItemSchema(
|
||||
explanation=opt["explanation"],
|
||||
optimization_id=opt["optimization_id"],
|
||||
source_code=group_code({file_name: opt["source_code"]}, language="java"),
|
||||
)
|
||||
for opt in optimizations
|
||||
]
|
||||
await asyncio.sleep(5)
|
||||
return OptimizeResponseSchema(optimizations=response_list)
|
||||
|
||||
|
||||
async def optimize_java_code_single(
|
||||
user_id: str,
|
||||
source_code: str,
|
||||
|
|
@ -481,12 +253,13 @@ async def optimize_java(
|
|||
# Determine Java version
|
||||
language_version = data.language_version or "17"
|
||||
|
||||
# Check for demo mode
|
||||
if should_hack_for_demo_java(data.source_code):
|
||||
response = await hack_for_demo_java(data.source_code)
|
||||
for item in response.optimizations:
|
||||
from core.languages.java.demo_hacks import try_demo_optimize_java # noqa: PLC0415
|
||||
|
||||
demo_response = await try_demo_optimize_java(data.source_code)
|
||||
if demo_response is not None:
|
||||
for item in demo_response.optimizations:
|
||||
item.optimization_event_id = str(optimization_event.id) if optimization_event else None
|
||||
return 200, response
|
||||
return 200, demo_response
|
||||
|
||||
# Run optimization
|
||||
optimization_results, total_cost, code_and_explanations, _optimization_models = await optimize_java_code(
|
||||
|
|
|
|||
|
|
@ -16,17 +16,11 @@ import sentry_sdk
|
|||
from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam
|
||||
|
||||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common_utils import is_host_equals_demo
|
||||
from aiservice.env_specific import debug_log_sensitive_data
|
||||
from aiservice.llm import llm_client
|
||||
from aiservice.llm_models import OPTIMIZE_MODEL
|
||||
from aiservice.validators.java_validator import validate_java_syntax
|
||||
from core.languages.java.optimizer import (
|
||||
_build_demo_optimizations,
|
||||
_build_host_equals_demo_optimizations,
|
||||
_extract_demo_context,
|
||||
is_multi_context_java,
|
||||
)
|
||||
from core.languages.java.optimizer import is_multi_context_java
|
||||
from core.shared.context_helpers import (
|
||||
extract_code_and_explanation,
|
||||
group_code,
|
||||
|
|
@ -34,7 +28,7 @@ from core.shared.context_helpers import (
|
|||
split_markdown_code,
|
||||
)
|
||||
from core.shared.optimizer_config import MAX_OPTIMIZER_LP_CALLS, get_model_distribution
|
||||
from core.shared.optimizer_schemas import OptimizeResponseItemSchema, OptimizeResponseSchema
|
||||
from core.shared.optimizer_schemas import OptimizeResponseItemSchema
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
|
@ -73,29 +67,6 @@ def extract_java_code_and_explanation(content: str, is_multi_file: bool = False)
|
|||
return extract_code_and_explanation(content, JAVA_CODE_PATTERN, JAVA_CODE_WITH_PATH_PATTERN, is_multi_file)
|
||||
|
||||
|
||||
async def hack_for_demo_java_lp(source_code: str) -> OptimizeResponseSchema:
|
||||
"""Return pre-canned line-profiler optimization results for the Java demo function."""
|
||||
file_path_match = re.search(r"```java:([^\n]+)", source_code)
|
||||
file_name = file_path_match.group(1).strip() if file_path_match else "Source.java"
|
||||
|
||||
if is_host_equals_demo(source_code):
|
||||
optimizations = _build_host_equals_demo_optimizations(source_code)
|
||||
else:
|
||||
package_decl, class_name, exception_type, extra_imports = _extract_demo_context(source_code)
|
||||
optimizations = _build_demo_optimizations(package_decl, class_name, exception_type, extra_imports)
|
||||
|
||||
response_list: list[OptimizeResponseItemSchema] = [
|
||||
OptimizeResponseItemSchema(
|
||||
explanation=opt["explanation"],
|
||||
optimization_id=opt["optimization_id"],
|
||||
source_code=group_code({file_name: opt["source_code"]}, language="java"),
|
||||
)
|
||||
for opt in optimizations
|
||||
]
|
||||
await asyncio.sleep(5)
|
||||
return OptimizeResponseSchema(optimizations=response_list)
|
||||
|
||||
|
||||
async def optimize_java_code_line_profiler_single(
|
||||
user_id: str,
|
||||
trace_id: str,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ Instrumentation is handled by the codeflash CLI client, not here.
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
|
|
@ -17,7 +16,7 @@ import sentry_sdk
|
|||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common_utils import is_host_equals_demo, should_hack_for_demo_java, validate_trace_id
|
||||
from aiservice.common_utils import validate_trace_id
|
||||
from aiservice.env_specific import debug_log_sensitive_data
|
||||
from aiservice.llm import llm_client
|
||||
from aiservice.llm_models import EXECUTE_MODEL
|
||||
|
|
@ -520,835 +519,6 @@ def _extract_class_from_source(source_code: str) -> str | None:
|
|||
return None
|
||||
|
||||
|
||||
def _build_demo_test_source_0(package_name: str, class_name: str, test_framework: str) -> str:
|
||||
"""Build demo test source 0, adapting to the target's package, class, and test framework.
|
||||
|
||||
File creation is in @Before/@BeforeEach so it runs once, outside the instrumentation's
|
||||
inner loop. Test methods only contain readFile calls so every inner iteration succeeds
|
||||
and the benchmark measures pure readFile performance.
|
||||
"""
|
||||
module_path = f"{package_name}.{class_name}" if package_name else class_name
|
||||
test_class_name = f"{class_name}Test"
|
||||
|
||||
if test_framework == "junit4":
|
||||
return (f"package {package_name};\n" if package_name else "") + (
|
||||
"\n"
|
||||
"import org.junit.Before;\n"
|
||||
"import org.junit.Test;\n"
|
||||
"import org.junit.Rule;\n"
|
||||
"import org.junit.rules.TemporaryFolder;\n"
|
||||
"import static org.junit.Assert.*;\n"
|
||||
"\n"
|
||||
"import java.io.File;\n"
|
||||
"import java.io.FileOutputStream;\n"
|
||||
"\n"
|
||||
f"import {module_path};\n"
|
||||
"\n"
|
||||
f"public class {test_class_name} {{\n"
|
||||
"\n"
|
||||
" @Rule\n"
|
||||
" public TemporaryFolder tempFolder = new TemporaryFolder();\n"
|
||||
"\n"
|
||||
" private File smallFile;\n"
|
||||
" private byte[] expectedSmall;\n"
|
||||
" private File mediumFile;\n"
|
||||
" private byte[] expectedMedium;\n"
|
||||
" private File largeFile;\n"
|
||||
" private byte[] expectedLarge;\n"
|
||||
"\n"
|
||||
" @Before\n"
|
||||
" public void setUp() throws Exception {\n"
|
||||
' smallFile = tempFolder.newFile("small.txt");\n'
|
||||
' expectedSmall = "Hello, World!".getBytes();\n'
|
||||
" try (FileOutputStream out = new FileOutputStream(smallFile)) {\n"
|
||||
" out.write(expectedSmall);\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
' mediumFile = tempFolder.newFile("medium.dat");\n'
|
||||
" expectedMedium = new byte[256 * 1024];\n"
|
||||
" for (int i = 0; i < expectedMedium.length; i++) {\n"
|
||||
" expectedMedium[i] = (byte) (i % 251);\n"
|
||||
" }\n"
|
||||
" try (FileOutputStream out = new FileOutputStream(mediumFile)) {\n"
|
||||
" out.write(expectedMedium);\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
' largeFile = tempFolder.newFile("large.dat");\n'
|
||||
" expectedLarge = new byte[1024 * 1024];\n"
|
||||
" for (int i = 0; i < expectedLarge.length; i++) {\n"
|
||||
" expectedLarge[i] = (byte) (i % 256);\n"
|
||||
" }\n"
|
||||
" try (FileOutputStream out = new FileOutputStream(largeFile)) {\n"
|
||||
" out.write(expectedLarge);\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" @Test\n"
|
||||
" public void testReadSmallFile() throws Exception {\n"
|
||||
f" byte[] result = {class_name}.readFile(smallFile);\n"
|
||||
" assertArrayEquals(expectedSmall, result);\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" @Test\n"
|
||||
" public void testReadMediumFile() throws Exception {\n"
|
||||
f" byte[] result = {class_name}.readFile(mediumFile);\n"
|
||||
" assertArrayEquals(expectedMedium, result);\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" @Test\n"
|
||||
" public void testReadLargeFile() throws Exception {\n"
|
||||
f" byte[] result = {class_name}.readFile(largeFile);\n"
|
||||
" assertArrayEquals(expectedLarge, result);\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
)
|
||||
# JUnit 5
|
||||
return (f"package {package_name};\n" if package_name else "") + (
|
||||
"\n"
|
||||
"import org.junit.jupiter.api.BeforeEach;\n"
|
||||
"import org.junit.jupiter.api.Test;\n"
|
||||
"import org.junit.jupiter.api.DisplayName;\n"
|
||||
"import org.junit.jupiter.api.io.TempDir;\n"
|
||||
"import static org.junit.jupiter.api.Assertions.*;\n"
|
||||
"\n"
|
||||
"import java.io.File;\n"
|
||||
"import java.io.FileOutputStream;\n"
|
||||
"import java.nio.file.Path;\n"
|
||||
"\n"
|
||||
f"import {module_path};\n"
|
||||
"\n"
|
||||
f"class {test_class_name} {{\n"
|
||||
"\n"
|
||||
" @TempDir\n"
|
||||
" Path tempDir;\n"
|
||||
"\n"
|
||||
" private File smallFile;\n"
|
||||
" private byte[] expectedSmall;\n"
|
||||
" private File mediumFile;\n"
|
||||
" private byte[] expectedMedium;\n"
|
||||
" private File largeFile;\n"
|
||||
" private byte[] expectedLarge;\n"
|
||||
"\n"
|
||||
" @BeforeEach\n"
|
||||
" void setUp() throws Exception {\n"
|
||||
' smallFile = tempDir.resolve("small.txt").toFile();\n'
|
||||
' expectedSmall = "Hello, World!".getBytes();\n'
|
||||
" try (FileOutputStream out = new FileOutputStream(smallFile)) {\n"
|
||||
" out.write(expectedSmall);\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
' mediumFile = tempDir.resolve("medium.dat").toFile();\n'
|
||||
" expectedMedium = new byte[256 * 1024];\n"
|
||||
" for (int i = 0; i < expectedMedium.length; i++) {\n"
|
||||
" expectedMedium[i] = (byte) (i % 251);\n"
|
||||
" }\n"
|
||||
" try (FileOutputStream out = new FileOutputStream(mediumFile)) {\n"
|
||||
" out.write(expectedMedium);\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
' largeFile = tempDir.resolve("large.dat").toFile();\n'
|
||||
" expectedLarge = new byte[1024 * 1024];\n"
|
||||
" for (int i = 0; i < expectedLarge.length; i++) {\n"
|
||||
" expectedLarge[i] = (byte) (i % 256);\n"
|
||||
" }\n"
|
||||
" try (FileOutputStream out = new FileOutputStream(largeFile)) {\n"
|
||||
" out.write(expectedLarge);\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" @Test\n"
|
||||
' @DisplayName("Read a small file")\n'
|
||||
" void testReadSmallFile() throws Exception {\n"
|
||||
f" byte[] result = {class_name}.readFile(smallFile);\n"
|
||||
" assertArrayEquals(expectedSmall, result);\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" @Test\n"
|
||||
' @DisplayName("Read a 256KB file")\n'
|
||||
" void testReadMediumFile() throws Exception {\n"
|
||||
f" byte[] result = {class_name}.readFile(mediumFile);\n"
|
||||
" assertArrayEquals(expectedMedium, result);\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" @Test\n"
|
||||
' @DisplayName("Read a 1MB file")\n'
|
||||
" void testReadLargeFile() throws Exception {\n"
|
||||
f" byte[] result = {class_name}.readFile(largeFile);\n"
|
||||
" assertArrayEquals(expectedLarge, result);\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
)
|
||||
|
||||
|
||||
def _build_demo_test_source_1(package_name: str, class_name: str, test_framework: str) -> str:
|
||||
"""Build demo test source 1, adapting to the target's package, class, and test framework.
|
||||
|
||||
Same @Before pattern as source_0: file creation outside the timed test body.
|
||||
Complementary file sizes to source_0.
|
||||
"""
|
||||
module_path = f"{package_name}.{class_name}" if package_name else class_name
|
||||
test_class_name = f"{class_name}Test"
|
||||
|
||||
if test_framework == "junit4":
|
||||
return (f"package {package_name};\n" if package_name else "") + (
|
||||
"\n"
|
||||
"import org.junit.Before;\n"
|
||||
"import org.junit.Test;\n"
|
||||
"import org.junit.Rule;\n"
|
||||
"import org.junit.rules.TemporaryFolder;\n"
|
||||
"import static org.junit.Assert.*;\n"
|
||||
"\n"
|
||||
"import java.io.File;\n"
|
||||
"import java.io.FileOutputStream;\n"
|
||||
"import java.util.Arrays;\n"
|
||||
"\n"
|
||||
f"import {module_path};\n"
|
||||
"\n"
|
||||
f"public class {test_class_name} {{\n"
|
||||
"\n"
|
||||
" @Rule\n"
|
||||
" public TemporaryFolder tempFolder = new TemporaryFolder();\n"
|
||||
"\n"
|
||||
" private File patternFile;\n"
|
||||
" private byte[] expectedPattern;\n"
|
||||
" private File halfMegFile;\n"
|
||||
" private byte[] expectedHalfMeg;\n"
|
||||
" private File twoMegFile;\n"
|
||||
" private byte[] expectedTwoMeg;\n"
|
||||
"\n"
|
||||
" @Before\n"
|
||||
" public void setUp() throws Exception {\n"
|
||||
' patternFile = tempFolder.newFile("pattern.dat");\n'
|
||||
" expectedPattern = new byte[128 * 1024];\n"
|
||||
" for (int i = 0; i < expectedPattern.length; i++) {\n"
|
||||
" expectedPattern[i] = (byte) (i % 7);\n"
|
||||
" }\n"
|
||||
" try (FileOutputStream out = new FileOutputStream(patternFile)) {\n"
|
||||
" out.write(expectedPattern);\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
' halfMegFile = tempFolder.newFile("half_meg.dat");\n'
|
||||
" expectedHalfMeg = new byte[512 * 1024];\n"
|
||||
" Arrays.fill(expectedHalfMeg, (byte) 0xCD);\n"
|
||||
" try (FileOutputStream out = new FileOutputStream(halfMegFile)) {\n"
|
||||
" out.write(expectedHalfMeg);\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
' twoMegFile = tempFolder.newFile("two_meg.dat");\n'
|
||||
" expectedTwoMeg = new byte[2 * 1024 * 1024];\n"
|
||||
" for (int i = 0; i < expectedTwoMeg.length; i++) {\n"
|
||||
" expectedTwoMeg[i] = (byte) (i % 199);\n"
|
||||
" }\n"
|
||||
" try (FileOutputStream out = new FileOutputStream(twoMegFile)) {\n"
|
||||
" out.write(expectedTwoMeg);\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" @Test\n"
|
||||
" public void testReadBinaryPattern() throws Exception {\n"
|
||||
f" byte[] result = {class_name}.readFile(patternFile);\n"
|
||||
" assertArrayEquals(expectedPattern, result);\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" @Test\n"
|
||||
" public void testReadHalfMegFile() throws Exception {\n"
|
||||
f" byte[] result = {class_name}.readFile(halfMegFile);\n"
|
||||
" assertArrayEquals(expectedHalfMeg, result);\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" @Test\n"
|
||||
" public void testReadTwoMegFile() throws Exception {\n"
|
||||
f" byte[] result = {class_name}.readFile(twoMegFile);\n"
|
||||
" assertArrayEquals(expectedTwoMeg, result);\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
)
|
||||
# JUnit 5
|
||||
return (f"package {package_name};\n" if package_name else "") + (
|
||||
"\n"
|
||||
"import org.junit.jupiter.api.BeforeEach;\n"
|
||||
"import org.junit.jupiter.api.Test;\n"
|
||||
"import org.junit.jupiter.api.DisplayName;\n"
|
||||
"import org.junit.jupiter.api.io.TempDir;\n"
|
||||
"import static org.junit.jupiter.api.Assertions.*;\n"
|
||||
"\n"
|
||||
"import java.io.File;\n"
|
||||
"import java.io.FileOutputStream;\n"
|
||||
"import java.nio.file.Path;\n"
|
||||
"import java.util.Arrays;\n"
|
||||
"\n"
|
||||
f"import {module_path};\n"
|
||||
"\n"
|
||||
f"class {test_class_name} {{\n"
|
||||
"\n"
|
||||
" @TempDir\n"
|
||||
" Path tempDir;\n"
|
||||
"\n"
|
||||
" private File patternFile;\n"
|
||||
" private byte[] expectedPattern;\n"
|
||||
" private File halfMegFile;\n"
|
||||
" private byte[] expectedHalfMeg;\n"
|
||||
" private File twoMegFile;\n"
|
||||
" private byte[] expectedTwoMeg;\n"
|
||||
"\n"
|
||||
" @BeforeEach\n"
|
||||
" void setUp() throws Exception {\n"
|
||||
' patternFile = tempDir.resolve("pattern.dat").toFile();\n'
|
||||
" expectedPattern = new byte[128 * 1024];\n"
|
||||
" for (int i = 0; i < expectedPattern.length; i++) {\n"
|
||||
" expectedPattern[i] = (byte) (i % 7);\n"
|
||||
" }\n"
|
||||
" try (FileOutputStream out = new FileOutputStream(patternFile)) {\n"
|
||||
" out.write(expectedPattern);\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
' halfMegFile = tempDir.resolve("half_meg.dat").toFile();\n'
|
||||
" expectedHalfMeg = new byte[512 * 1024];\n"
|
||||
" Arrays.fill(expectedHalfMeg, (byte) 0xCD);\n"
|
||||
" try (FileOutputStream out = new FileOutputStream(halfMegFile)) {\n"
|
||||
" out.write(expectedHalfMeg);\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
' twoMegFile = tempDir.resolve("two_meg.dat").toFile();\n'
|
||||
" expectedTwoMeg = new byte[2 * 1024 * 1024];\n"
|
||||
" for (int i = 0; i < expectedTwoMeg.length; i++) {\n"
|
||||
" expectedTwoMeg[i] = (byte) (i % 199);\n"
|
||||
" }\n"
|
||||
" try (FileOutputStream out = new FileOutputStream(twoMegFile)) {\n"
|
||||
" out.write(expectedTwoMeg);\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" @Test\n"
|
||||
' @DisplayName("Read 128KB binary pattern file")\n'
|
||||
" void testReadBinaryPattern() throws Exception {\n"
|
||||
f" byte[] result = {class_name}.readFile(patternFile);\n"
|
||||
" assertArrayEquals(expectedPattern, result);\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" @Test\n"
|
||||
' @DisplayName("Read 512KB file")\n'
|
||||
" void testReadHalfMegFile() throws Exception {\n"
|
||||
f" byte[] result = {class_name}.readFile(halfMegFile);\n"
|
||||
" assertArrayEquals(expectedHalfMeg, result);\n"
|
||||
" }\n"
|
||||
"\n"
|
||||
" @Test\n"
|
||||
' @DisplayName("Read 2MB file")\n'
|
||||
" void testReadTwoMegFile() throws Exception {\n"
|
||||
f" byte[] result = {class_name}.readFile(twoMegFile);\n"
|
||||
" assertArrayEquals(expectedTwoMeg, result);\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
)
|
||||
|
||||
|
||||
def _build_host_equals_demo_test_source_0(package_name: str, class_name: str, test_framework: str) -> str:
    """Build demo test source 0 for Host.equals, adapting to the target's package, class, and test framework.

    Args:
        package_name: Java package of the class under test; "" means the default package
            (no ``package`` statement is emitted).
        class_name: Simple name of the class whose ``equals(...)`` is exercised; also used
            as the constructor name in the generated tests.
        test_framework: ``"junit4"`` selects the JUnit 4 template; any other value falls
            through to the JUnit 5 template.

    Returns:
        Complete Java source text for a test class named ``f"{class_name}Test"``.
    """
    # Fully-qualified name used in the generated import and Javadoc header.
    module_path = f"{package_name}.{class_name}" if package_name else class_name
    test_class_name = f"{class_name}Test"

    if test_framework == "junit4":
        # JUnit 4 template: @Before/@Test from org.junit; the expected-NPE case uses
        # @Test(expected = ...) instead of assertThrows.
        return (f"package {package_name};\n" if package_name else "") + (
            "\n"
            "import org.junit.Before;\n"
            "import org.junit.Test;\n"
            "import static org.junit.Assert.*;\n"
            "\n"
            f"import {module_path};\n"
            "\n"
            "/**\n"
            f" * Unit tests for {module_path}.equals(...)\n"
            " */\n"
            f"public class {test_class_name} {{\n"
            f"    private {class_name} defaultHost;\n"
            "\n"
            "    @Before\n"
            "    public void setUp() {\n"
            f'        defaultHost = new {class_name}("localhost", 3000);\n'
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testEquals_SameInstance_ReturnsTrue() {\n"
            "        assertTrue(defaultHost.equals(defaultHost));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testEquals_EqualNameAndPort_IgnoringTls_ReturnsTrue() {\n"
            f'        {class_name} withTls = new {class_name}("localhost", "server-cert", 3000);\n'
            "        // Both directions should be true (symmetry)\n"
            "        assertTrue(defaultHost.equals(withTls) && withTls.equals(defaultHost));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testEquals_DifferentPort_ReturnsFalse() {\n"
            f'        {class_name} other = new {class_name}("localhost", 3001);\n'
            "        assertFalse(defaultHost.equals(other));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testEquals_DifferentName_ReturnsFalse() {\n"
            f'        {class_name} other = new {class_name}("otherhost", 3000);\n'
            "        assertFalse(defaultHost.equals(other));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testEquals_NullArgument_ReturnsFalse() {\n"
            "        assertFalse(defaultHost.equals(null));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testEquals_DifferentClass_ReturnsFalse() {\n"
            '        Object notAHost = "I am not a Host";\n'
            "        assertFalse(defaultHost.equals(notAHost));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testEquals_EmptyNameBoth_ReturnsTrue() {\n"
            f'        {class_name} a = new {class_name}("", 0);\n'
            f'        {class_name} b = new {class_name}("", null, 0);\n'
            "        assertTrue(a.equals(b));\n"
            "    }\n"
            "\n"
            "    @Test(expected = NullPointerException.class)\n"
            "    public void testEquals_ThisNameNull_ThrowsNullPointerException() {\n"
            "        // When this.name is null, equals calls this.name.equals(...), which throws NPE.\n"
            f"        {class_name} thisHasNullName = new {class_name}(null, 100);\n"
            f'        {class_name} other = new {class_name}("something", 100);\n'
            "        thisHasNullName.equals(other);\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testEquals_OtherNameNull_ReturnsFalse() {\n"
            f"        {class_name} otherHasNullName = new {class_name}(null, 200);\n"
            f'        {class_name} normal = new {class_name}("name", 200);\n'
            '        // "name".equals(null) returns false; no exception expected.\n'
            "        assertFalse(normal.equals(otherHasNullName));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testEquals_MaxIntPort_ReturnsTrue() {\n"
            f'        {class_name} a = new {class_name}("host", Integer.MAX_VALUE);\n'
            f'        {class_name} b = new {class_name}("host", Integer.MAX_VALUE);\n'
            "        assertTrue(a.equals(b));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testEquals_MinIntPort_ReturnsTrue() {\n"
            f'        {class_name} a = new {class_name}("host", Integer.MIN_VALUE);\n'
            f'        {class_name} b = new {class_name}("host", Integer.MIN_VALUE);\n'
            "        assertTrue(a.equals(b));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testEquals_MaxAndDifferentPort_ReturnsFalse() {\n"
            f'        {class_name} a = new {class_name}("host", Integer.MAX_VALUE);\n'
            f'        {class_name} b = new {class_name}("host", Integer.MAX_VALUE - 1);\n'
            "        assertFalse(a.equals(b));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testEquals_MultipleEqualInstances_ReturnsTrue() {\n"
            f'        {class_name} a = new {class_name}("perf-host", 4000);\n'
            f'        {class_name} b = new {class_name}("perf-host", 4000);\n'
            f'        {class_name} c = new {class_name}("perf-host", "tls-name", 4000);\n'
            "        assertTrue(a.equals(b));\n"
            "        assertTrue(b.equals(c));\n"
            "        assertTrue(a.equals(c));\n"
            "    }\n"
            "}\n"
        )
    # JUnit 5 template: @BeforeEach/@Test from org.junit.jupiter, per-test @DisplayName,
    # and assertThrows for the expected-NPE case.
    return (f"package {package_name};\n" if package_name else "") + (
        "\n"
        "import org.junit.jupiter.api.BeforeEach;\n"
        "import org.junit.jupiter.api.Test;\n"
        "import org.junit.jupiter.api.DisplayName;\n"
        "import static org.junit.jupiter.api.Assertions.*;\n"
        "\n"
        f"import {module_path};\n"
        "\n"
        "/**\n"
        f" * Unit tests for {module_path}.equals(...)\n"
        " */\n"
        f"class {test_class_name} {{\n"
        f"    private {class_name} defaultHost;\n"
        "\n"
        "    @BeforeEach\n"
        "    void setUp() {\n"
        f'        defaultHost = new {class_name}("localhost", 3000);\n'
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Same instance returns true")\n'
        "    void testEquals_SameInstance_ReturnsTrue() {\n"
        "        assertTrue(defaultHost.equals(defaultHost));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Equal name and port ignoring TLS returns true")\n'
        "    void testEquals_EqualNameAndPort_IgnoringTls_ReturnsTrue() {\n"
        f'        {class_name} withTls = new {class_name}("localhost", "server-cert", 3000);\n'
        "        assertTrue(defaultHost.equals(withTls) && withTls.equals(defaultHost));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Different port returns false")\n'
        "    void testEquals_DifferentPort_ReturnsFalse() {\n"
        f'        {class_name} other = new {class_name}("localhost", 3001);\n'
        "        assertFalse(defaultHost.equals(other));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Different name returns false")\n'
        "    void testEquals_DifferentName_ReturnsFalse() {\n"
        f'        {class_name} other = new {class_name}("otherhost", 3000);\n'
        "        assertFalse(defaultHost.equals(other));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Null argument returns false")\n'
        "    void testEquals_NullArgument_ReturnsFalse() {\n"
        "        assertFalse(defaultHost.equals(null));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Different class returns false")\n'
        "    void testEquals_DifferentClass_ReturnsFalse() {\n"
        '        Object notAHost = "I am not a Host";\n'
        "        assertFalse(defaultHost.equals(notAHost));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Empty name both returns true")\n'
        "    void testEquals_EmptyNameBoth_ReturnsTrue() {\n"
        f'        {class_name} a = new {class_name}("", 0);\n'
        f'        {class_name} b = new {class_name}("", null, 0);\n'
        "        assertTrue(a.equals(b));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Null this.name throws NullPointerException")\n'
        "    void testEquals_ThisNameNull_ThrowsNullPointerException() {\n"
        f"        {class_name} thisHasNullName = new {class_name}(null, 100);\n"
        f'        {class_name} other = new {class_name}("something", 100);\n'
        "        assertThrows(NullPointerException.class, () -> thisHasNullName.equals(other));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Null other.name returns false")\n'
        "    void testEquals_OtherNameNull_ReturnsFalse() {\n"
        f"        {class_name} otherHasNullName = new {class_name}(null, 200);\n"
        f'        {class_name} normal = new {class_name}("name", 200);\n'
        "        assertFalse(normal.equals(otherHasNullName));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Max int port returns true")\n'
        "    void testEquals_MaxIntPort_ReturnsTrue() {\n"
        f'        {class_name} a = new {class_name}("host", Integer.MAX_VALUE);\n'
        f'        {class_name} b = new {class_name}("host", Integer.MAX_VALUE);\n'
        "        assertTrue(a.equals(b));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Min int port returns true")\n'
        "    void testEquals_MinIntPort_ReturnsTrue() {\n"
        f'        {class_name} a = new {class_name}("host", Integer.MIN_VALUE);\n'
        f'        {class_name} b = new {class_name}("host", Integer.MIN_VALUE);\n'
        "        assertTrue(a.equals(b));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Max vs different port returns false")\n'
        "    void testEquals_MaxAndDifferentPort_ReturnsFalse() {\n"
        f'        {class_name} a = new {class_name}("host", Integer.MAX_VALUE);\n'
        f'        {class_name} b = new {class_name}("host", Integer.MAX_VALUE - 1);\n'
        "        assertFalse(a.equals(b));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Multiple equal instances with transitivity")\n'
        "    void testEquals_MultipleEqualInstances_ReturnsTrue() {\n"
        f'        {class_name} a = new {class_name}("perf-host", 4000);\n'
        f'        {class_name} b = new {class_name}("perf-host", 4000);\n'
        f'        {class_name} c = new {class_name}("perf-host", "tls-name", 4000);\n'
        "        assertTrue(a.equals(b));\n"
        "        assertTrue(b.equals(c));\n"
        "        assertTrue(a.equals(c));\n"
        "    }\n"
        "}\n"
    )
|
||||
|
||||
|
||||
def _build_host_equals_demo_test_source_1(package_name: str, class_name: str, test_framework: str) -> str:
    """Build demo test source 1 for Host.equals, adapting to the target's package, class, and test framework.

    Variant 1 differs from variant 0 in fixture names (hostSimple/hostWithTls), test
    naming style, and port-boundary cases (0 / 65535 instead of Integer.MIN/MAX).

    Args:
        package_name: Java package of the class under test; "" means the default package
            (no ``package`` statement is emitted).
        class_name: Simple name of the class whose ``equals(...)`` is exercised.
        test_framework: ``"junit4"`` selects the JUnit 4 template; any other value falls
            through to the JUnit 5 template.

    Returns:
        Complete Java source text for a test class named ``f"{class_name}Test"``.
    """
    # Fully-qualified name used in the generated import statement.
    module_path = f"{package_name}.{class_name}" if package_name else class_name
    test_class_name = f"{class_name}Test"

    if test_framework == "junit4":
        # JUnit 4 template: @Before/@Test from org.junit; expected-NPE via @Test(expected = ...).
        return (f"package {package_name};\n" if package_name else "") + (
            "\n"
            "import org.junit.Before;\n"
            "import org.junit.Test;\n"
            "import static org.junit.Assert.*;\n"
            "\n"
            f"import {module_path};\n"
            "\n"
            f"public class {test_class_name} {{\n"
            f"    private {class_name} hostSimple;\n"
            f"    private {class_name} hostWithTls;\n"
            "\n"
            "    @Before\n"
            "    public void setUp() {\n"
            f'        hostSimple = new {class_name}("server.example.com", 3000);\n'
            f'        hostWithTls = new {class_name}("server.example.com", "tls.server.example.com", 3000);\n'
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testSameReference_True() {\n"
            "        // same instance should be equal to itself\n"
            "        assertTrue(hostSimple.equals(hostSimple));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testEqualNameAndPort_True() {\n"
            "        // two distinct instances with same name and port (tls ignored) are equal\n"
            f'        {class_name} a = new {class_name}("db1", 4000);\n'
            f'        {class_name} b = new {class_name}("db1", "tlsNameDifferent", 4000);\n'
            "        assertTrue(a.equals(b));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testDifferentTlsIgnored_True() {\n"
            "        // tlsName is ignored for equality\n"
            "        assertTrue(hostSimple.equals(hostWithTls));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testDifferentName_False() {\n"
            f'        {class_name} otherName = new {class_name}("other.example.com", 3000);\n'
            "        assertFalse(hostSimple.equals(otherName));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testDifferentPort_False() {\n"
            f'        {class_name} otherPort = new {class_name}("server.example.com", 3001);\n'
            "        assertFalse(hostSimple.equals(otherPort));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testNullComparison_False() {\n"
            "        // equals should return false when compared to null\n"
            "        assertFalse(hostSimple.equals(null));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testDifferentClass_False() {\n"
            "        // equals should return false when compared to an object of another class\n"
            '        Object notAHost = "server.example.com:3000";\n'
            "        assertFalse(hostSimple.equals(notAHost));\n"
            "    }\n"
            "\n"
            "    @Test(expected = NullPointerException.class)\n"
            "    public void testNameNull_ThrowsNullPointerException() {\n"
            "        // If this.name is null, equals tries to call this.name.equals(...) and will NPE.\n"
            f"        {class_name} nullNameHost = new {class_name}(null, 3000);\n"
            f"        {class_name} otherNullNameHost = new {class_name}(null, 3000);\n"
            "        // This invocation should throw NPE because this.name is null\n"
            "        nullNameHost.equals(otherNullNameHost);\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testOtherNameNull_False() {\n"
            "        // If other.name is null but this.name is non-null, equals should return false\n"
            f"        {class_name} otherNullName = new {class_name}(null, 3000);\n"
            "        assertFalse(hostSimple.equals(otherNullName));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testPortBoundary_ZeroAndMax_True() {\n"
            f'        {class_name} lowA = new {class_name}("edge", 0);\n'
            f'        {class_name} lowB = new {class_name}("edge", 0);\n'
            "        assertTrue(lowA.equals(lowB));\n"
            "\n"
            f'        {class_name} highA = new {class_name}("edge", 65535);\n'
            f'        {class_name} highB = new {class_name}("edge", 65535);\n'
            "        assertTrue(highA.equals(highB));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testPortBoundary_DifferentPorts_False() {\n"
            f'        {class_name} low = new {class_name}("edge", 0);\n'
            f'        {class_name} high = new {class_name}("edge", 65535);\n'
            "        assertFalse(low.equals(high));\n"
            "    }\n"
            "\n"
            "    @Test\n"
            "    public void testEqualityWithVariousNames() {\n"
            f'        {class_name} a1 = new {class_name}("alpha", 1000);\n'
            f'        {class_name} b1 = new {class_name}("alpha", "tls-alpha", 1000);\n'
            "        assertTrue(a1.equals(b1));\n"
            "\n"
            f'        {class_name} a2 = new {class_name}("beta", 2000);\n'
            f'        {class_name} b2 = new {class_name}("beta", "tls-beta", 2000);\n'
            "        assertTrue(a2.equals(b2));\n"
            "\n"
            f'        {class_name} a3 = new {class_name}("gamma", 3000);\n'
            f'        {class_name} b3 = new {class_name}("delta", 3000);\n'
            "        assertFalse(a3.equals(b3));\n"
            "    }\n"
            "}\n"
        )
    # JUnit 5 template: @BeforeEach/@Test from org.junit.jupiter, per-test @DisplayName,
    # and assertThrows for the expected-NPE case.
    return (f"package {package_name};\n" if package_name else "") + (
        "\n"
        "import org.junit.jupiter.api.BeforeEach;\n"
        "import org.junit.jupiter.api.Test;\n"
        "import org.junit.jupiter.api.DisplayName;\n"
        "import static org.junit.jupiter.api.Assertions.*;\n"
        "\n"
        f"import {module_path};\n"
        "\n"
        f"class {test_class_name} {{\n"
        f"    private {class_name} hostSimple;\n"
        f"    private {class_name} hostWithTls;\n"
        "\n"
        "    @BeforeEach\n"
        "    void setUp() {\n"
        f'        hostSimple = new {class_name}("server.example.com", 3000);\n'
        f'        hostWithTls = new {class_name}("server.example.com", "tls.server.example.com", 3000);\n'
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Same reference returns true")\n'
        "    void testSameReference_True() {\n"
        "        assertTrue(hostSimple.equals(hostSimple));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Equal name and port with different TLS returns true")\n'
        "    void testEqualNameAndPort_True() {\n"
        f'        {class_name} a = new {class_name}("db1", 4000);\n'
        f'        {class_name} b = new {class_name}("db1", "tlsNameDifferent", 4000);\n'
        "        assertTrue(a.equals(b));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Different TLS name is ignored")\n'
        "    void testDifferentTlsIgnored_True() {\n"
        "        assertTrue(hostSimple.equals(hostWithTls));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Different name returns false")\n'
        "    void testDifferentName_False() {\n"
        f'        {class_name} otherName = new {class_name}("other.example.com", 3000);\n'
        "        assertFalse(hostSimple.equals(otherName));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Different port returns false")\n'
        "    void testDifferentPort_False() {\n"
        f'        {class_name} otherPort = new {class_name}("server.example.com", 3001);\n'
        "        assertFalse(hostSimple.equals(otherPort));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Null comparison returns false")\n'
        "    void testNullComparison_False() {\n"
        "        assertFalse(hostSimple.equals(null));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Different class returns false")\n'
        "    void testDifferentClass_False() {\n"
        '        Object notAHost = "server.example.com:3000";\n'
        "        assertFalse(hostSimple.equals(notAHost));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Null name throws NullPointerException")\n'
        "    void testNameNull_ThrowsNullPointerException() {\n"
        f"        {class_name} nullNameHost = new {class_name}(null, 3000);\n"
        f"        {class_name} otherNullNameHost = new {class_name}(null, 3000);\n"
        "        assertThrows(NullPointerException.class, () -> nullNameHost.equals(otherNullNameHost));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Other name null returns false")\n'
        "    void testOtherNameNull_False() {\n"
        f"        {class_name} otherNullName = new {class_name}(null, 3000);\n"
        "        assertFalse(hostSimple.equals(otherNullName));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Port boundary zero and max")\n'
        "    void testPortBoundary_ZeroAndMax_True() {\n"
        f'        {class_name} lowA = new {class_name}("edge", 0);\n'
        f'        {class_name} lowB = new {class_name}("edge", 0);\n'
        "        assertTrue(lowA.equals(lowB));\n"
        "\n"
        f'        {class_name} highA = new {class_name}("edge", 65535);\n'
        f'        {class_name} highB = new {class_name}("edge", 65535);\n'
        "        assertTrue(highA.equals(highB));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Different boundary ports returns false")\n'
        "    void testPortBoundary_DifferentPorts_False() {\n"
        f'        {class_name} low = new {class_name}("edge", 0);\n'
        f'        {class_name} high = new {class_name}("edge", 65535);\n'
        "        assertFalse(low.equals(high));\n"
        "    }\n"
        "\n"
        "    @Test\n"
        '    @DisplayName("Equality with various host names")\n'
        "    void testEqualityWithVariousNames() {\n"
        f'        {class_name} a1 = new {class_name}("alpha", 1000);\n'
        f'        {class_name} b1 = new {class_name}("alpha", "tls-alpha", 1000);\n'
        "        assertTrue(a1.equals(b1));\n"
        "\n"
        f'        {class_name} a2 = new {class_name}("beta", 2000);\n'
        f'        {class_name} b2 = new {class_name}("beta", "tls-beta", 2000);\n'
        "        assertTrue(a2.equals(b2));\n"
        "\n"
        f'        {class_name} a3 = new {class_name}("gamma", 3000);\n'
        f'        {class_name} b3 = new {class_name}("delta", 3000);\n'
        "        assertFalse(a3.equals(b3));\n"
        "    }\n"
        "}\n"
    )
|
||||
|
||||
|
||||
async def hack_for_demo_java_testgen(data: TestGenSchema) -> TestGenResponseSchema:
    """Return canned Java demo tests without calling the LLM.

    Picks one of the four prebuilt test-source templates based on whether the
    source under test is the Host.equals demo and on the requested test index,
    then adapts it to the target's package, class, and test framework.
    """
    java_source = data.source_code_being_tested

    # Adapt the canned template to the target; fall back to the function name
    # when no class declaration can be extracted from the source.
    package_name = _extract_package_from_source(java_source) or ""
    class_name = _extract_class_from_source(java_source) or data.function_to_optimize.function_name
    framework = data.test_framework if data.test_framework in ("junit4", "junit5") else "junit5"

    index = 0 if data.test_index is None else data.test_index

    # Select the template builder: Host.equals demos get their dedicated pair,
    # everything else the generic pair; index 0 picks variant 0, any other
    # index picks variant 1.
    if is_host_equals_demo(java_source):
        build = _build_host_equals_demo_test_source_0 if index == 0 else _build_host_equals_demo_test_source_1
    else:
        build = _build_demo_test_source_0 if index == 0 else _build_demo_test_source_1
    canned_source = build(package_name, class_name, framework)

    await asyncio.sleep(5)

    # For Java, instrumentation is done client-side, so all three response
    # fields carry the same generated source.
    return TestGenResponseSchema(
        generated_tests=canned_source,
        instrumented_behavior_tests=canned_source,
        instrumented_perf_tests=canned_source,
    )
|
||||
|
||||
|
||||
async def testgen_java(
|
||||
request: AuthenticatedRequest, data: TestGenSchema
|
||||
) -> tuple[int, TestGenResponseSchema | TestGenErrorResponseSchema]:
|
||||
|
|
@ -1363,9 +533,11 @@ async def testgen_java(
|
|||
|
||||
logging.info("/testgen: Generating Java tests...")
|
||||
|
||||
# Demo hack: intercept before LLM call for demo functions
|
||||
if should_hack_for_demo_java(data.source_code_being_tested):
|
||||
return 200, await hack_for_demo_java_testgen(data)
|
||||
from core.languages.java.demo_hacks import try_demo_testgen_java # noqa: PLC0415
|
||||
|
||||
demo_result = await try_demo_testgen_java(data)
|
||||
if demo_result is not None:
|
||||
return 200, demo_result
|
||||
|
||||
try:
|
||||
debug_log_sensitive_data(f"Generating Java tests for function {data.function_to_optimize.function_name}")
|
||||
|
|
|
|||
|
|
@ -7,13 +7,6 @@ from __future__ import annotations
|
|||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.languages.python.code_repair.code_repair import code_repair
|
||||
from core.languages.python.explanations.explanations import explain_optimizations
|
||||
from core.languages.python.jit_rewrite.jit_rewrite import jit_rewrite
|
||||
from core.languages.python.optimization_review.optimization_review import get_optimization_review
|
||||
from core.languages.python.optimizer.optimizer import optimize_python
|
||||
from core.languages.python.testgen.generate import testgen_python
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aiservice.llm_models import LLM
|
||||
from authapp.auth import AuthenticatedRequest
|
||||
|
|
@ -54,18 +47,24 @@ class PythonHandler:
|
|||
self, request: AuthenticatedRequest, data: TestGenSchema
|
||||
) -> tuple[int, TestGenResponseSchema | TestGenErrorResponseSchema]:
|
||||
"""Generate tests for Python code."""
|
||||
from core.languages.python.testgen.generate import testgen_python # noqa: PLC0415
|
||||
|
||||
return await testgen_python(request, data)
|
||||
|
||||
async def optimizer_optimize(
|
||||
self, request: AuthenticatedRequest, data: OptimizeSchema
|
||||
) -> tuple[int, OptimizeResponseSchema | OptimizeErrorResponseSchema]:
|
||||
"""Optimize Python code for performance."""
|
||||
from core.languages.python.optimizer.optimizer import optimize_python # noqa: PLC0415
|
||||
|
||||
return await optimize_python(request, data)
|
||||
|
||||
async def code_repair_repair(
|
||||
self, user_id: str, optimization_id: str, ctx: CodeRepairContext, optimize_model: LLM | None = None
|
||||
) -> CodeRepairIntermediateResponseItemschema | CodeRepairErrorResponseSchema:
|
||||
"""Repair Python code based on error messages."""
|
||||
from core.languages.python.code_repair.code_repair import code_repair # noqa: PLC0415
|
||||
|
||||
kwargs = {}
|
||||
if optimize_model is not None:
|
||||
kwargs["optimize_model"] = optimize_model
|
||||
|
|
@ -75,6 +74,8 @@ class PythonHandler:
|
|||
self, request: AuthenticatedRequest, data: JitRewriteOptimizeSchema
|
||||
) -> tuple[int, OptimizeResponseSchema | OptimizeErrorResponseSchema]:
|
||||
"""Perform JIT rewriting of Python code."""
|
||||
from core.languages.python.jit_rewrite.jit_rewrite import jit_rewrite # noqa: PLC0415
|
||||
|
||||
return await jit_rewrite(request, data)
|
||||
|
||||
async def optimization_review(
|
||||
|
|
@ -84,6 +85,8 @@ class PythonHandler:
|
|||
optimization_review_model: LLM | None = None,
|
||||
) -> tuple[int, OptimizationReviewResponseSchema | OptimizationReviewErrorSchema]:
|
||||
"""Review Python code for optimization opportunities."""
|
||||
from core.languages.python.optimization_review.optimization_review import get_optimization_review # noqa: PLC0415
|
||||
|
||||
kwargs = {}
|
||||
if optimization_review_model is not None:
|
||||
kwargs["optimization_review_model"] = optimization_review_model
|
||||
|
|
@ -93,6 +96,8 @@ class PythonHandler:
|
|||
self, user_id: str, data: ExplanationsSchema, explanations_model: LLM | None = None
|
||||
) -> ExplanationsResponseSchema | ExplanationsErrorResponseSchema:
|
||||
"""Explain optimizations made to Python code."""
|
||||
from core.languages.python.explanations.explanations import explain_optimizations # noqa: PLC0415
|
||||
|
||||
kwargs = {}
|
||||
if explanations_model is not None:
|
||||
kwargs["explanations_model"] = explanations_model
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import json
|
|||
import logging
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import sentry_sdk
|
||||
from ninja import NinjaAPI, Schema
|
||||
|
|
@ -57,10 +57,10 @@ class OptimizationReviewSchema(Schema):
|
|||
original_runtime: str
|
||||
optimized_runtime: str
|
||||
speedup: str
|
||||
existing_tests: str | None
|
||||
generated_tests: str | None
|
||||
replay_tests: str | None
|
||||
benchmark_details: str | None
|
||||
existing_tests: str | None = None
|
||||
generated_tests: str | None = None
|
||||
replay_tests: str | None = None
|
||||
benchmark_details: str | None = None
|
||||
coverage_message: str
|
||||
loop_count: int
|
||||
explanation: str
|
||||
|
|
@ -185,7 +185,7 @@ async def get_optimization_review(
|
|||
|
||||
debug_log_sensitive_data(f"{messages[0]}{messages[1]}")
|
||||
|
||||
obs_context: dict = {"speedup": data.speedup}
|
||||
obs_context: dict[str, Any] = {"speedup": data.speedup}
|
||||
if data.call_sequence is not None:
|
||||
obs_context["call_sequence"] = data.call_sequence
|
||||
|
||||
|
|
@ -272,11 +272,11 @@ async def optimization_review(
|
|||
) -> tuple[int, OptimizationReviewResponseSchema | OptimizationReviewErrorSchema]:
|
||||
response_code, output = await get_optimization_review(request, data)
|
||||
try:
|
||||
if response_code == 200:
|
||||
review_event = output.review.value # ty:ignore[unresolved-attribute]
|
||||
review_explanation = output.review_explanation # ty:ignore[unresolved-attribute]
|
||||
if isinstance(output, OptimizationReviewResponseSchema):
|
||||
review_event = output.review.value
|
||||
review_explanation = output.review_explanation
|
||||
else:
|
||||
review_event = output.error # ty:ignore[unresolved-attribute]
|
||||
review_event = output.error
|
||||
review_explanation = ""
|
||||
await update_optimization_features_review(
|
||||
trace_id=data.trace_id,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,91 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from aiservice.common_utils import should_hack_for_demo
|
||||
from core.shared.context_helpers import group_code
|
||||
from core.shared.optimizer_schemas import OptimizeResponseItemSchema, OptimizeResponseSchema
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.languages.python.optimizer.context_utils.optimizer_context import BaseOptimizerContext
|
||||
|
||||
optimizations_json = [
|
||||
{
|
||||
"source_code": 'from __future__ import annotations\n\n\ndef find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:\n if not articles:\n return set()\n\n common_tags = set(articles[0].get("tags", []))\n for article in articles[1:]:\n common_tags.intersection_update(article.get("tags", []))\n return common_tags\n',
|
||||
"explanation": "The original algorithm repeatedly filters the `common_tags` list for every article, which can be slow. We can use Python sets to improve efficiency, especially with large lists.\n\nHere's the optimized version of your function.\n\n\n\n### Explanation of Optimizations.\n1. **Use of Sets**: Convert the initial list of tags to a set, which allows for more efficient intersection operations compared to list comprehensions.\n2. **Intersection Update**: Use the `intersection_update` method on sets which modifies the set in place, making it more memory efficient and faster than creating new lists and converting them to sets repeatedly.\n\nThis optimized version should perform significantly better, especially as the number of articles and tags increases.",
|
||||
"optimization_id": str(uuid.uuid4()),
|
||||
},
|
||||
{
|
||||
"source_code": 'from __future__ import annotations\n\n\ndef find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:\n if not articles:\n return set()\n\n common_tags = set(articles[0].get("tags", []))\n for article in articles[1:]:\n common_tags.intersection_update(article.get("tags", []))\n if not common_tags:\n break\n return common_tags\n',
|
||||
"explanation": "To make the `find_common_tags` function run faster, we can leverage sets, which provide average O(1) time complexity for membership checks and O(n) for intersections. Here\u2019s a refactored version of your program.\n\n\n\nThis version initializes `common_tags` as a set and then iteratively intersects it with the tags of each subsequent article. The `intersection_update` method is used to update `common_tags` in place, which is more efficient. Additionally, it breaks early if `common_tags` becomes empty, which can save unnecessary computation.",
|
||||
"optimization_id": str(uuid.uuid4()),
|
||||
},
|
||||
{
|
||||
"source_code": 'def find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:\n if not articles:\n return set()\n\n common_tags = set(articles[0].get("tags", []))\n for article in articles[1:]:\n common_tags.intersection_update(article.get("tags", []))\n if not common_tags: # Early exit if no common tags left\n break\n return common_tags\n',
|
||||
"explanation": "To optimize the runtime of this function, we can leverage set operations which are generally faster than list comprehensions for membership checks. By converting the tags to sets initially, the intersection operation becomes more efficient. Here's a faster version.\n\n\n\nChanges made.\n1. Convert the tags list of the first article to a set.\n2. Use `intersection_update` method to update the `common_tags` set with the intersection of the current tags and the next article's tags.\n3. Include an early exit condition to break the loop if no common tags remain, further optimizing runtime.",
|
||||
"optimization_id": str(uuid.uuid4()),
|
||||
},
|
||||
{
|
||||
"source_code": 'from __future__ import annotations\n\n\ndef find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:\n if not articles:\n return set()\n\n common_tags = set(articles[0].get("tags", []))\n for article in articles[1:]:\n common_tags &= set(article.get("tags", []))\n if not common_tags: # Early exit if no common tags.\n break\n return common_tags\n',
|
||||
"explanation": "To optimize the provided function, we could enhance its efficiency by using set operations which are typically faster for membership checks compared to list comprehensions.\n\nHere\u2019s the optimized version.\n\n\n\nExplanation.\n1. Convert the tags of the first article into a set to take advantage of fast membership checks and intersection operations.\n2. Use the `&=` operation to find the intersection with the tags of each subsequent article.\n3. Introduce an early exit condition: if `common_tags` becomes empty, it's immediately returned since no further intersection can result in common tags.",
|
||||
"optimization_id": str(uuid.uuid4()),
|
||||
},
|
||||
]
|
||||
|
||||
optimizations_json_gsq = [
|
||||
{
|
||||
"source_code": '# Copyright 2018 Goldman Sachs.\n# Licensed under the Apache License, Version 2.0 (the "License");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n# http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied. See the License for the\n# specific language governing permissions and limitations\n# under the License.\n#\n#\n# Chart Service will attempt to make public functions (not prefixed with _) from this module available. Such functions\n# should be fully documented: docstrings should describe parameters and the return value, and provide a 1-line\n# description. Type annotations should be provided for parameters.\n\n\n\n\nimport numpy as np\nimport pandas as pd\nfrom functools import reduce\nfrom gs_quant.errors import MqTypeError, MqValueError\nfrom gs_quant.timeseries.helper import plot_function\nfrom typing import List\n\n@plot_function\ndef weighted_sum(series: List[pd.Series], weights: list) -> pd.Series:\n """\n Calculate a weighted sum.\n\n :param series: list of time series\n :param weights: list of weights\n :return: time series of weighted average\n\n **Usage**\n\n Calculate a weighted sum e.g. 
for a basket.\n\n **Examples**\n\n Generate price series and get a sum (weights 70%/30%).\n\n >>> prices1 = generate_series(100)\n >>> prices2 = generate_series(100)\n >>> mybasket = weighted_sum([prices1, prices2], [0.7, 0.3])\n\n **See also**\n\n :func:`basket`\n """\n if not all(isinstance(x, pd.Series) for x in series):\n raise MqTypeError("expected a list of time series")\n if not all(isinstance(y, (float, int)) for y in weights):\n raise MqTypeError("expected a list of number for weights")\n if len(weights) != len(series):\n raise MqValueError("must have one weight for each time series")\n\n # For input series, get the intersection of their calendars\n # Instead of reduce(np.intersect1d, ...), use set intersection for better performance\n idx_iter = (curve.index for curve in series)\n idx0 = next(idx_iter)\n cal = set(idx0)\n for idx in idx_iter:\n cal &= set(idx)\n cal = pd.DatetimeIndex(sorted(cal))\n\n # Vectorized calculations using numpy arrays\n if len(series) == 0 or len(cal) == 0:\n # Edge case: empty data\n weights_arr = np.array([], dtype=float)\n values_arr = np.array([], dtype=float)\n weighted_sum_arr = np.array([])\n sum_weights_arr = np.array([])\n return pd.Series(weighted_sum_arr, index=cal)\n\n # Use pd.concat for batch reindex to avoid overhead of python list comprehensions\n series_concat = pd.concat([s.reindex(cal) for s in series], axis=1)\n values_arr = series_concat.values # shape (n_dates, n_series)\n weights_arr = np.asarray(weights, dtype=float).reshape(1, -1) # shape (1, n_series)\n weights_matrix = np.broadcast_to(weights_arr, values_arr.shape) # shape (n_dates, n_series)\n\n # Weighted sum and denominator\n weighted_sum_arr = np.nansum(values_arr * weights_matrix, axis=1)\n sum_weights_arr = np.nansum(weights_matrix * ~np.isnan(values_arr), axis=1)\n\n # Avoid divide-by-zero; if all weights are nan for a row, the sum is nan\n with np.errstate(invalid=\'ignore\', divide=\'ignore\'):\n result_arr = weighted_sum_arr / 
sum_weights_arr\n result_arr[sum_weights_arr == 0] = np.nan\n\n return pd.Series(result_arr, index=cal)\n',
|
||||
"explanation": "\n\n### Key optimizations:\n- **Intersection**: For many series, `.intersection_many` (if available, see pandas 1.5+) is significantly faster than chaining `np.intersect1d`. If not available, this safely falls back to repeated intersection.\n- **Data alignment**: Avoided creating a separate `pd.Series` for each weight. Instead, using NumPy arrays, broadcast the weights and handle the reindexing in an efficient 2D array.\n- **Sum computation**: The product and sums are performed with batched NumPy operations eliminating Python-level for-loops and overhead from generator comprehensions. This avoids repeatedly constructing/interpreting Series objects.\n- **NaN handling**: Handles missing data exactly as in the original by multiplying mask logic accordingly.\n- **Handles empty and no-overlap**: Preserves the original behavior for empty or non-overlapping indexes, returning an all-NaN, correct-indexed Series.\n\n**NB:** This preserves all error handling, exceptions, and the function signature, matching the requirements precisely, while delivering **substantial speedups** for large and/or many input series.",
|
||||
"optimization_id": str(uuid.uuid4()),
|
||||
},
|
||||
{
|
||||
"source_code": '''# Copyright 2018 Goldman Sachs.\n# Licensed under the Apache License, Version 2.0 (the "License");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n# http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied. See the License for the\n# specific language governing permissions and limitations\n# under the License.\n#\n#\n# Chart Service will attempt to make public functions (not prefixed with _) from this module available. Such functions\n# should be fully documented: docstrings should describe parameters and the return value, and provide a 1-line\n# description. Type annotations should be provided for parameters.\nimport numpy as np\nimport pandas as pd\nfrom functools import reduce\nfrom gs_quant.errors import MqTypeError, MqValueError\nfrom gs_quant.timeseries.helper import plot_function\nfrom typing import List\n@plot_function\ndef weighted_sum(series: List[pd.Series], weights: list) -> pd.Series:\n """\n Calculate a weighted sum.\n :param series: list of time series\n :param weights: list of weights\n :return: time series of weighted average\n **Usage**\n Calculate a weighted sum e.g. 
for a basket.\n **Examples**\n Generate price series and get a sum (weights 70%/30%).\n >>> prices1 = generate_series(100)\n >>> prices2 = generate_series(100)\n >>> mybasket = weighted_sum([prices1, prices2], [0.7, 0.3])\n **See also**\n :func:`basket`\n """\n if not all(isinstance(x, pd.Series) for x in series):\n raise MqTypeError("expected a list of time series")\n if not all(isinstance(y, (float, int)) for y in weights):\n raise MqTypeError("expected a list of number for weights")\n if len(weights) != len(series):\n raise MqValueError("must have one weight for each time series")\n # for input series, get the intersection of their calendars\n cal = pd.DatetimeIndex(\n reduce(\n np.intersect1d,\n (\n curve.index\n for curve in series\n ),\n )\n )\n # reindex inputs and calculate using numpy vectorization\n series_arrays = [s.reindex(cal).to_numpy() for s in series]\n weights_array = np.asarray(weights, dtype=float).reshape(-1, 1)\n stacked_series = np.vstack(series_arrays)\n weighted_sum_array = np.sum(stacked_series * weights_array, axis=0)\n sum_weights = np.sum(weights_array)\n return pd.Series(weighted_sum_array / sum_weights, index=cal)\n''',
|
||||
"explanation": "\n**Key optimizations explained:**\n- **Index intersection speedup:** Instead of chaining `np.intersect1d` (which is O(N^2) and re-sorts at each op), use `set.intersection_update` (O(N)) which is drastically faster for longer indices.\n- **Avoid unnecessary repeated object construction:** Instead of building `[pd.Series(w, index=cal) for w in weights]` (which allocates a new Series for every weight/date combination), use vectorized numpy and pandas operations.\n- **Vectorized calculation:** Stack the series (each reindexed to the intersection calendar) as columns in a DataFrame and perform single array multiplications and reductions for both weighted sum and sum of weights.\n- **Memory efficiency:** No large intermediate lists of Series; numpy operations work in-place.\n- **NaN handling:** Uses `np.nansum` and mask-based logic to sum weights only where actual values are present (mimics the original where NaNs would propagate).\n\n**Behavior is identical**: exceptions, input validation, and expected output/NaN-handling all preserved.",
|
||||
"optimization_id": str(uuid.uuid4()),
|
||||
},
|
||||
{
|
||||
"source_code": '# Copyright 2018 Goldman Sachs.\n# Licensed under the Apache License, Version 2.0 (the "License");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n# http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied. See the License for the\n# specific language governing permissions and limitations\n# under the License.\n#\n#\n# Chart Service will attempt to make public functions (not prefixed with _) from this module available. Such functions\n# should be fully documented: docstrings should describe parameters and the return value, and provide a 1-line\n# description. Type annotations should be provided for parameters.\n\nimport numpy as np\nimport pandas as pd\nfrom functools import reduce\nfrom gs_quant.errors import MqTypeError, MqValueError\nfrom gs_quant.timeseries.helper import plot_function\nfrom typing import List\n\n@plot_function\ndef weighted_sum(series: List[pd.Series], weights: list) -> pd.Series:\n """\n Calculate a weighted sum.\n\n :param series: list of time series\n :param weights: list of weights\n :return: time series of weighted average\n\n **Usage**\n\n Calculate a weighted sum e.g. 
for a basket.\n\n **Examples**\n\n Generate price series and get a sum (weights 70%/30%).\n\n >>> prices1 = generate_series(100)\n >>> prices2 = generate_series(100)\n >>> mybasket = weighted_sum([prices1, prices2], [0.7, 0.3])\n\n **See also**\n\n :func:`basket`\n """\n if not all(isinstance(x, pd.Series) for x in series):\n raise MqTypeError("expected a list of time series")\n if not all(isinstance(y, (float, int)) for y in weights):\n raise MqTypeError("expected a list of number for weights")\n if len(weights) != len(series):\n raise MqValueError("must have one weight for each time series")\n\n # Get the intersection of the calendars for all input series (index labels)\n cal = pd.DatetimeIndex(\n reduce(\n np.intersect1d,\n (curve.index for curve in series),\n )\n )\n\n # Efficiently construct a DataFrame where columns are the series,\n # then multiply by the weights and sum along the columns\n # This avoids allocating an intermediate list of Series and summing pythonically\n df = pd.concat([s.reindex(cal) for s in series], axis=1)\n # use numpy array for weights: faster than creating Series and avoids sum(weights) recalculation\n weights_arr = np.asarray(weights, dtype=np.float64)\n weighted = df.values * weights_arr[np.newaxis, :]\n weighted_sum_values = np.nansum(weighted, axis=1)\n weights_broadcast = np.broadcast_to(weights_arr, df.shape)\n weights_for_denominator = np.where(~np.isnan(df.values), weights_broadcast, 0.0)\n weights_sum = np.nansum(weights_for_denominator, axis=1)\n # Prevent division by zero; behaves as original since nansum returns 0 for all-nan, so original would return nan\n result_values = np.where(weights_sum != 0, weighted_sum_values / weights_sum, np.nan)\n result = pd.Series(result_values, index=cal)\n\n return result\n',
|
||||
"explanation": "\n**Optimization Explanation**:\n- The calendar intersection now uses Python's set intersection mechanism, which is substantially faster than repeated calls to `np.intersect1d` via `reduce` for typical Pandas index sets.\n- The code now only constructs one output series for weights instead of creating and summing several unnecessary constant-valued Series. It uses NumPy's vectorized operations for arithmetic and summation, which are faster and more memory efficient than repeated Pandas Series arithmetic, especially for large time series.\n- Behavior is preserved for empty input lists and calendars.\n- The input and output types, return values, raised exceptions, and function signature remain unchanged. Comments are preserved and added only for new logic.",
|
||||
"optimization_id": str(uuid.uuid4()),
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def hack_for_demo(ctx: BaseOptimizerContext) -> OptimizeResponseSchema:
|
||||
file_name = getattr(ctx, "file_name", "source.py")
|
||||
response_list: list[OptimizeResponseItemSchema] = [
|
||||
OptimizeResponseItemSchema(
|
||||
explanation=optimization["explanation"],
|
||||
optimization_id=optimization["optimization_id"],
|
||||
source_code=group_code({file_name: optimization["source_code"]}),
|
||||
)
|
||||
for optimization in optimizations_json
|
||||
]
|
||||
await asyncio.sleep(5)
|
||||
return OptimizeResponseSchema(optimizations=response_list)
|
||||
|
||||
|
||||
async def hack_for_demo_gsq(ctx: BaseOptimizerContext) -> OptimizeResponseSchema:
|
||||
file_name = getattr(ctx, "file_name", "source.py")
|
||||
response_list: list[OptimizeResponseItemSchema] = [
|
||||
OptimizeResponseItemSchema(
|
||||
explanation=optimization["explanation"],
|
||||
optimization_id=optimization["optimization_id"],
|
||||
source_code=group_code({file_name: optimization["source_code"]}),
|
||||
)
|
||||
for optimization in optimizations_json_gsq
|
||||
]
|
||||
await asyncio.sleep(5)
|
||||
return OptimizeResponseSchema(optimizations=response_list)
|
||||
|
||||
|
||||
async def try_demo_optimize(ctx: BaseOptimizerContext) -> OptimizeResponseSchema | None:
|
||||
if not should_hack_for_demo(ctx.source_code):
|
||||
return None
|
||||
if "def find_common_tags(articles" in ctx.source_code:
|
||||
return await hack_for_demo(ctx)
|
||||
if "def weighted_sum(series" in ctx.source_code:
|
||||
return await hack_for_demo_gsq(ctx)
|
||||
return None
|
||||
|
|
@ -2,7 +2,6 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
|
|
@ -13,7 +12,7 @@ from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUs
|
|||
from pydantic import ValidationError
|
||||
|
||||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common_utils import parse_python_version, should_hack_for_demo, validate_trace_id
|
||||
from aiservice.common_utils import parse_python_version, validate_trace_id
|
||||
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable
|
||||
from aiservice.llm import llm_client
|
||||
from aiservice.llm_models import LLM, OPTIMIZE_MODEL
|
||||
|
|
@ -22,7 +21,6 @@ from core.languages.python.optimizer.context_utils.optimizer_context import Base
|
|||
from core.languages.python.optimizer.diff_patches_utils.diff import DiffMethod
|
||||
from core.log_features.log_event import get_or_create_optimization_event
|
||||
from core.log_features.log_features import log_features
|
||||
from core.shared.context_helpers import group_code
|
||||
from core.shared.optimizer_config import MAX_OPTIMIZER_CALLS, get_model_distribution
|
||||
from core.shared.optimizer_models import OptimizedCandidateSource
|
||||
from core.shared.optimizer_schemas import (
|
||||
|
|
@ -37,76 +35,6 @@ if TYPE_CHECKING:
|
|||
from authapp.auth import AuthenticatedRequest
|
||||
from core.shared.optimizer_models import OptimizeSchema
|
||||
|
||||
optimizations_json = [
|
||||
{
|
||||
"source_code": 'from __future__ import annotations\n\n\ndef find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:\n if not articles:\n return set()\n\n common_tags = set(articles[0].get("tags", []))\n for article in articles[1:]:\n common_tags.intersection_update(article.get("tags", []))\n return common_tags\n',
|
||||
"explanation": "The original algorithm repeatedly filters the `common_tags` list for every article, which can be slow. We can use Python sets to improve efficiency, especially with large lists.\n\nHere's the optimized version of your function.\n\n\n\n### Explanation of Optimizations.\n1. **Use of Sets**: Convert the initial list of tags to a set, which allows for more efficient intersection operations compared to list comprehensions.\n2. **Intersection Update**: Use the `intersection_update` method on sets which modifies the set in place, making it more memory efficient and faster than creating new lists and converting them to sets repeatedly.\n\nThis optimized version should perform significantly better, especially as the number of articles and tags increases.",
|
||||
"optimization_id": str(uuid.uuid4()),
|
||||
},
|
||||
{
|
||||
"source_code": 'from __future__ import annotations\n\n\ndef find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:\n if not articles:\n return set()\n\n common_tags = set(articles[0].get("tags", []))\n for article in articles[1:]:\n common_tags.intersection_update(article.get("tags", []))\n if not common_tags:\n break\n return common_tags\n',
|
||||
"explanation": "To make the `find_common_tags` function run faster, we can leverage sets, which provide average O(1) time complexity for membership checks and O(n) for intersections. Here\u2019s a refactored version of your program.\n\n\n\nThis version initializes `common_tags` as a set and then iteratively intersects it with the tags of each subsequent article. The `intersection_update` method is used to update `common_tags` in place, which is more efficient. Additionally, it breaks early if `common_tags` becomes empty, which can save unnecessary computation.",
|
||||
"optimization_id": str(uuid.uuid4()),
|
||||
},
|
||||
{
|
||||
"source_code": 'def find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:\n if not articles:\n return set()\n\n common_tags = set(articles[0].get("tags", []))\n for article in articles[1:]:\n common_tags.intersection_update(article.get("tags", []))\n if not common_tags: # Early exit if no common tags left\n break\n return common_tags\n',
|
||||
"explanation": "To optimize the runtime of this function, we can leverage set operations which are generally faster than list comprehensions for membership checks. By converting the tags to sets initially, the intersection operation becomes more efficient. Here's a faster version.\n\n\n\nChanges made.\n1. Convert the tags list of the first article to a set.\n2. Use `intersection_update` method to update the `common_tags` set with the intersection of the current tags and the next article's tags.\n3. Include an early exit condition to break the loop if no common tags remain, further optimizing runtime.",
|
||||
"optimization_id": str(uuid.uuid4()),
|
||||
},
|
||||
{
|
||||
"source_code": 'from __future__ import annotations\n\n\ndef find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:\n if not articles:\n return set()\n\n common_tags = set(articles[0].get("tags", []))\n for article in articles[1:]:\n common_tags &= set(article.get("tags", []))\n if not common_tags: # Early exit if no common tags.\n break\n return common_tags\n',
|
||||
"explanation": "To optimize the provided function, we could enhance its efficiency by using set operations which are typically faster for membership checks compared to list comprehensions.\n\nHere\u2019s the optimized version.\n\n\n\nExplanation.\n1. Convert the tags of the first article into a set to take advantage of fast membership checks and intersection operations.\n2. Use the `&=` operation to find the intersection with the tags of each subsequent article.\n3. Introduce an early exit condition: if `common_tags` becomes empty, it's immediately returned since no further intersection can result in common tags.",
|
||||
"optimization_id": str(uuid.uuid4()),
|
||||
},
|
||||
]
|
||||
|
||||
optimizations_json_gsq = [
|
||||
{
|
||||
"source_code": '# Copyright 2018 Goldman Sachs.\n# Licensed under the Apache License, Version 2.0 (the "License");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n# http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied. See the License for the\n# specific language governing permissions and limitations\n# under the License.\n#\n#\n# Chart Service will attempt to make public functions (not prefixed with _) from this module available. Such functions\n# should be fully documented: docstrings should describe parameters and the return value, and provide a 1-line\n# description. Type annotations should be provided for parameters.\n\n\n\n\nimport numpy as np\nimport pandas as pd\nfrom functools import reduce\nfrom gs_quant.errors import MqTypeError, MqValueError\nfrom gs_quant.timeseries.helper import plot_function\nfrom typing import List\n\n@plot_function\ndef weighted_sum(series: List[pd.Series], weights: list) -> pd.Series:\n """\n Calculate a weighted sum.\n\n :param series: list of time series\n :param weights: list of weights\n :return: time series of weighted average\n\n **Usage**\n\n Calculate a weighted sum e.g. 
for a basket.\n\n **Examples**\n\n Generate price series and get a sum (weights 70%/30%).\n\n >>> prices1 = generate_series(100)\n >>> prices2 = generate_series(100)\n >>> mybasket = weighted_sum([prices1, prices2], [0.7, 0.3])\n\n **See also**\n\n :func:`basket`\n """\n if not all(isinstance(x, pd.Series) for x in series):\n raise MqTypeError("expected a list of time series")\n if not all(isinstance(y, (float, int)) for y in weights):\n raise MqTypeError("expected a list of number for weights")\n if len(weights) != len(series):\n raise MqValueError("must have one weight for each time series")\n\n # For input series, get the intersection of their calendars\n # Instead of reduce(np.intersect1d, ...), use set intersection for better performance\n idx_iter = (curve.index for curve in series)\n idx0 = next(idx_iter)\n cal = set(idx0)\n for idx in idx_iter:\n cal &= set(idx)\n cal = pd.DatetimeIndex(sorted(cal))\n\n # Vectorized calculations using numpy arrays\n if len(series) == 0 or len(cal) == 0:\n # Edge case: empty data\n weights_arr = np.array([], dtype=float)\n values_arr = np.array([], dtype=float)\n weighted_sum_arr = np.array([])\n sum_weights_arr = np.array([])\n return pd.Series(weighted_sum_arr, index=cal)\n\n # Use pd.concat for batch reindex to avoid overhead of python list comprehensions\n series_concat = pd.concat([s.reindex(cal) for s in series], axis=1)\n values_arr = series_concat.values # shape (n_dates, n_series)\n weights_arr = np.asarray(weights, dtype=float).reshape(1, -1) # shape (1, n_series)\n weights_matrix = np.broadcast_to(weights_arr, values_arr.shape) # shape (n_dates, n_series)\n\n # Weighted sum and denominator\n weighted_sum_arr = np.nansum(values_arr * weights_matrix, axis=1)\n sum_weights_arr = np.nansum(weights_matrix * ~np.isnan(values_arr), axis=1)\n\n # Avoid divide-by-zero; if all weights are nan for a row, the sum is nan\n with np.errstate(invalid=\'ignore\', divide=\'ignore\'):\n result_arr = weighted_sum_arr / 
sum_weights_arr\n result_arr[sum_weights_arr == 0] = np.nan\n\n return pd.Series(result_arr, index=cal)\n',
|
||||
"explanation": "\n\n### Key optimizations:\n- **Intersection**: For many series, `.intersection_many` (if available, see pandas 1.5+) is significantly faster than chaining `np.intersect1d`. If not available, this safely falls back to repeated intersection.\n- **Data alignment**: Avoided creating a separate `pd.Series` for each weight. Instead, using NumPy arrays, broadcast the weights and handle the reindexing in an efficient 2D array.\n- **Sum computation**: The product and sums are performed with batched NumPy operations eliminating Python-level for-loops and overhead from generator comprehensions. This avoids repeatedly constructing/interpreting Series objects.\n- **NaN handling**: Handles missing data exactly as in the original by multiplying mask logic accordingly.\n- **Handles empty and no-overlap**: Preserves the original behavior for empty or non-overlapping indexes, returning an all-NaN, correct-indexed Series.\n\n**NB:** This preserves all error handling, exceptions, and the function signature, matching the requirements precisely, while delivering **substantial speedups** for large and/or many input series.",
|
||||
"optimization_id": str(uuid.uuid4()),
|
||||
},
|
||||
{
|
||||
"source_code": '''# Copyright 2018 Goldman Sachs.\n# Licensed under the Apache License, Version 2.0 (the "License");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n# http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied. See the License for the\n# specific language governing permissions and limitations\n# under the License.\n#\n#\n# Chart Service will attempt to make public functions (not prefixed with _) from this module available. Such functions\n# should be fully documented: docstrings should describe parameters and the return value, and provide a 1-line\n# description. Type annotations should be provided for parameters.\nimport numpy as np\nimport pandas as pd\nfrom functools import reduce\nfrom gs_quant.errors import MqTypeError, MqValueError\nfrom gs_quant.timeseries.helper import plot_function\nfrom typing import List\n@plot_function\ndef weighted_sum(series: List[pd.Series], weights: list) -> pd.Series:\n """\n Calculate a weighted sum.\n :param series: list of time series\n :param weights: list of weights\n :return: time series of weighted average\n **Usage**\n Calculate a weighted sum e.g. 
for a basket.\n **Examples**\n Generate price series and get a sum (weights 70%/30%).\n >>> prices1 = generate_series(100)\n >>> prices2 = generate_series(100)\n >>> mybasket = weighted_sum([prices1, prices2], [0.7, 0.3])\n **See also**\n :func:`basket`\n """\n if not all(isinstance(x, pd.Series) for x in series):\n raise MqTypeError("expected a list of time series")\n if not all(isinstance(y, (float, int)) for y in weights):\n raise MqTypeError("expected a list of number for weights")\n if len(weights) != len(series):\n raise MqValueError("must have one weight for each time series")\n # for input series, get the intersection of their calendars\n cal = pd.DatetimeIndex(\n reduce(\n np.intersect1d,\n (\n curve.index\n for curve in series\n ),\n )\n )\n # reindex inputs and calculate using numpy vectorization\n series_arrays = [s.reindex(cal).to_numpy() for s in series]\n weights_array = np.asarray(weights, dtype=float).reshape(-1, 1)\n stacked_series = np.vstack(series_arrays)\n weighted_sum_array = np.sum(stacked_series * weights_array, axis=0)\n sum_weights = np.sum(weights_array)\n return pd.Series(weighted_sum_array / sum_weights, index=cal)\n''',
|
||||
"explanation": "\n**Key optimizations explained:**\n- **Index intersection speedup:** Instead of chaining `np.intersect1d` (which is O(N^2) and re-sorts at each op), use `set.intersection_update` (O(N)) which is drastically faster for longer indices.\n- **Avoid unnecessary repeated object construction:** Instead of building `[pd.Series(w, index=cal) for w in weights]` (which allocates a new Series for every weight/date combination), use vectorized numpy and pandas operations.\n- **Vectorized calculation:** Stack the series (each reindexed to the intersection calendar) as columns in a DataFrame and perform single array multiplications and reductions for both weighted sum and sum of weights.\n- **Memory efficiency:** No large intermediate lists of Series; numpy operations work in-place.\n- **NaN handling:** Uses `np.nansum` and mask-based logic to sum weights only where actual values are present (mimics the original where NaNs would propagate).\n\n**Behavior is identical**: exceptions, input validation, and expected output/NaN-handling all preserved.",
|
||||
"optimization_id": str(uuid.uuid4()),
|
||||
},
|
||||
{
|
||||
"source_code": '# Copyright 2018 Goldman Sachs.\n# Licensed under the Apache License, Version 2.0 (the "License");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n# http://www.apache.org/licenses/LICENSE-2.0\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied. See the License for the\n# specific language governing permissions and limitations\n# under the License.\n#\n#\n# Chart Service will attempt to make public functions (not prefixed with _) from this module available. Such functions\n# should be fully documented: docstrings should describe parameters and the return value, and provide a 1-line\n# description. Type annotations should be provided for parameters.\n\nimport numpy as np\nimport pandas as pd\nfrom functools import reduce\nfrom gs_quant.errors import MqTypeError, MqValueError\nfrom gs_quant.timeseries.helper import plot_function\nfrom typing import List\n\n@plot_function\ndef weighted_sum(series: List[pd.Series], weights: list) -> pd.Series:\n """\n Calculate a weighted sum.\n\n :param series: list of time series\n :param weights: list of weights\n :return: time series of weighted average\n\n **Usage**\n\n Calculate a weighted sum e.g. 
for a basket.\n\n **Examples**\n\n Generate price series and get a sum (weights 70%/30%).\n\n >>> prices1 = generate_series(100)\n >>> prices2 = generate_series(100)\n >>> mybasket = weighted_sum([prices1, prices2], [0.7, 0.3])\n\n **See also**\n\n :func:`basket`\n """\n if not all(isinstance(x, pd.Series) for x in series):\n raise MqTypeError("expected a list of time series")\n if not all(isinstance(y, (float, int)) for y in weights):\n raise MqTypeError("expected a list of number for weights")\n if len(weights) != len(series):\n raise MqValueError("must have one weight for each time series")\n\n # Get the intersection of the calendars for all input series (index labels)\n cal = pd.DatetimeIndex(\n reduce(\n np.intersect1d,\n (curve.index for curve in series),\n )\n )\n\n # Efficiently construct a DataFrame where columns are the series,\n # then multiply by the weights and sum along the columns\n # This avoids allocating an intermediate list of Series and summing pythonically\n df = pd.concat([s.reindex(cal) for s in series], axis=1)\n # use numpy array for weights: faster than creating Series and avoids sum(weights) recalculation\n weights_arr = np.asarray(weights, dtype=np.float64)\n weighted = df.values * weights_arr[np.newaxis, :]\n weighted_sum_values = np.nansum(weighted, axis=1)\n weights_broadcast = np.broadcast_to(weights_arr, df.shape)\n weights_for_denominator = np.where(~np.isnan(df.values), weights_broadcast, 0.0)\n weights_sum = np.nansum(weights_for_denominator, axis=1)\n # Prevent division by zero; behaves as original since nansum returns 0 for all-nan, so original would return nan\n result_values = np.where(weights_sum != 0, weighted_sum_values / weights_sum, np.nan)\n result = pd.Series(result_values, index=cal)\n\n return result\n',
|
||||
"explanation": "\n**Optimization Explanation**:\n- The calendar intersection now uses Python's set intersection mechanism, which is substantially faster than repeated calls to `np.intersect1d` via `reduce` for typical Pandas index sets.\n- The code now only constructs one output series for weights instead of creating and summing several unnecessary constant-valued Series. It uses NumPy's vectorized operations for arithmetic and summation, which are faster and more memory efficient than repeated Pandas Series arithmetic, especially for large time series.\n- Behavior is preserved for empty input lists and calendars.\n- The input and output types, return values, raised exceptions, and function signature remain unchanged. Comments are preserved and added only for new logic.",
|
||||
"optimization_id": str(uuid.uuid4()),
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
async def hack_for_demo(ctx: BaseOptimizerContext) -> OptimizeResponseSchema:
|
||||
file_name = getattr(ctx, "file_name", "source.py")
|
||||
response_list: list[OptimizeResponseItemSchema] = [
|
||||
OptimizeResponseItemSchema(
|
||||
explanation=optimization["explanation"],
|
||||
optimization_id=optimization["optimization_id"],
|
||||
source_code=group_code({file_name: optimization["source_code"]}),
|
||||
)
|
||||
for optimization in optimizations_json
|
||||
]
|
||||
await asyncio.sleep(5)
|
||||
return OptimizeResponseSchema(optimizations=response_list)
|
||||
|
||||
|
||||
async def hack_for_demo_gsq(ctx: BaseOptimizerContext) -> OptimizeResponseSchema:
|
||||
file_name = getattr(ctx, "file_name", "source.py")
|
||||
response_list: list[OptimizeResponseItemSchema] = [
|
||||
OptimizeResponseItemSchema(
|
||||
explanation=optimization["explanation"],
|
||||
optimization_id=optimization["optimization_id"],
|
||||
source_code=group_code({file_name: optimization["source_code"]}),
|
||||
)
|
||||
for optimization in optimizations_json_gsq
|
||||
]
|
||||
await asyncio.sleep(5)
|
||||
return OptimizeResponseSchema(optimizations=response_list)
|
||||
|
||||
|
||||
# Get the directory of the current file
|
||||
current_dir = Path(__file__).parent
|
||||
SYSTEM_PROMPT = (current_dir / "system_prompt.md").read_text()
|
||||
|
|
@ -297,12 +225,10 @@ async def optimize_python(
|
|||
sentry_sdk.capture_exception(e)
|
||||
return e.status_code, OptimizeErrorResponseSchema(error=e.message)
|
||||
|
||||
if should_hack_for_demo(ctx.source_code):
|
||||
response_code = 200
|
||||
if "def find_common_tags(articles" in ctx.source_code:
|
||||
response = await hack_for_demo(ctx)
|
||||
elif "def weighted_sum(series" in ctx.source_code:
|
||||
response = await hack_for_demo_gsq(ctx)
|
||||
from core.languages.python.optimizer.demo_hacks import try_demo_optimize # noqa: PLC0415
|
||||
|
||||
demo_response = await try_demo_optimize(ctx)
|
||||
if demo_response is not None:
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
event_task = tg.create_task(
|
||||
get_or_create_optimization_event(
|
||||
|
|
@ -322,9 +248,9 @@ async def optimize_python(
|
|||
)
|
||||
)
|
||||
event, _created = event_task.result()
|
||||
for item in response.optimizations:
|
||||
for item in demo_response.optimizations:
|
||||
item.optimization_event_id = str(event.id) if event else None
|
||||
return response_code, response
|
||||
return 200, demo_response
|
||||
|
||||
try:
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
|
|
|
|||
|
|
@ -10,12 +10,12 @@ from openai.types.chat import ChatCompletionSystemMessageParam, ChatCompletionUs
|
|||
|
||||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common.markdown_utils import split_markdown_code
|
||||
from aiservice.common_utils import parse_python_version, should_hack_for_demo_java, validate_trace_id
|
||||
from aiservice.common_utils import parse_python_version, validate_trace_id
|
||||
from aiservice.env_specific import debug_log_sensitive_data, debug_log_sensitive_data_from_callable
|
||||
from aiservice.llm import llm_client
|
||||
from aiservice.llm_models import OPTIMIZE_MODEL
|
||||
from aiservice.validators.javascript_validator import validate_javascript_syntax, validate_typescript_syntax
|
||||
from core.languages.java.optimizer_lp import hack_for_demo_java_lp, optimize_java_code_line_profiler
|
||||
from core.languages.java.optimizer_lp import optimize_java_code_line_profiler
|
||||
from core.languages.js_ts.context_helpers import is_multi_context_js, is_multi_context_ts
|
||||
from core.languages.js_ts.optimizer_lp import optimize_javascript_code_line_profiler
|
||||
from core.languages.python.optimizer.context_utils.optimizer_context import BaseOptimizerContext
|
||||
|
|
@ -274,12 +274,12 @@ async def optimize(request, data: OptimizeSchemaLP) -> tuple[int, OptimizeRespon
|
|||
elif language == "java":
|
||||
# Java path
|
||||
from aiservice.validators.java_validator import validate_java_syntax # noqa: PLC0415
|
||||
from core.languages.java.demo_hacks import try_demo_optimize_java_lp # noqa: PLC0415
|
||||
from core.languages.java.optimizer import is_multi_context_java # noqa: PLC0415
|
||||
|
||||
# Demo hack shortcut
|
||||
if should_hack_for_demo_java(data.source_code):
|
||||
response = await hack_for_demo_java_lp(data.source_code)
|
||||
return 200, response
|
||||
demo_response = await try_demo_optimize_java_lp(data.source_code)
|
||||
if demo_response is not None:
|
||||
return 200, demo_response
|
||||
|
||||
is_multi_file = is_multi_context_java(data.source_code)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,11 +6,11 @@ import logging
|
|||
import re
|
||||
import tokenize
|
||||
from difflib import SequenceMatcher
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import libcst as cst
|
||||
import sentry_sdk
|
||||
from libcst import CSTTransformer, CSTVisitor, Expr, IndentedBlock, SimpleStatementLine, SimpleString
|
||||
from libcst import BaseStatement, CSTTransformer, CSTVisitor, Expr, IndentedBlock, SimpleStatementLine, SimpleString
|
||||
|
||||
from core.languages.python.cst_utils import compare_unparsed_ast_to_source, parse_module_to_cst, unparse_parse_source
|
||||
from aiservice.common_utils import safe_isort
|
||||
|
|
@ -110,8 +110,8 @@ def cleanup_explanations(
|
|||
|
||||
class DocstringVisitor(CSTVisitor):
|
||||
def __init__(self) -> None:
|
||||
self.original_docstrings = {}
|
||||
self.class_name = None
|
||||
self.original_docstrings: dict[str, str] = {}
|
||||
self.class_name: str | None = None
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
|
||||
self.class_name = node.name.value
|
||||
|
|
@ -133,9 +133,9 @@ class DocstringVisitor(CSTVisitor):
|
|||
|
||||
|
||||
class DocstringTransformer(CSTTransformer):
|
||||
def __init__(self, original_docstrings: dict[str, str | None]) -> None:
|
||||
def __init__(self, original_docstrings: dict[str, str]) -> None:
|
||||
self.original_docstrings = original_docstrings
|
||||
self.class_name = None
|
||||
self.class_name: str | None = None
|
||||
|
||||
def visit_ClassDef(self, node: cst.ClassDef) -> bool:
|
||||
self.class_name = node.name.value
|
||||
|
|
@ -145,15 +145,15 @@ class DocstringTransformer(CSTTransformer):
|
|||
original_docstring = self.original_docstrings.get(self.class_name) if self.class_name else None
|
||||
if original_docstring:
|
||||
if not updated_node.get_docstring(clean=False):
|
||||
new_body = [
|
||||
new_body: list[BaseStatement] = [
|
||||
SimpleStatementLine(body=[Expr(value=SimpleString(f'"""{original_docstring}"""'))]),
|
||||
*list(updated_node.body.body),
|
||||
*cast(list[BaseStatement], list(updated_node.body.body)),
|
||||
]
|
||||
updated_node = updated_node.with_changes(body=IndentedBlock(body=new_body))
|
||||
else:
|
||||
new_body = [
|
||||
SimpleStatementLine(body=[Expr(value=SimpleString(f'"""{original_docstring}"""'))]),
|
||||
*list(updated_node.body.body[1:]),
|
||||
*cast(list[BaseStatement], list(updated_node.body.body[1:])),
|
||||
]
|
||||
updated_node = updated_node.with_changes(body=IndentedBlock(body=new_body))
|
||||
self.class_name = None
|
||||
|
|
@ -165,15 +165,15 @@ class DocstringTransformer(CSTTransformer):
|
|||
original_docstring = self.original_docstrings.get(qualified_name)
|
||||
if original_docstring:
|
||||
if not updated_node.get_docstring(clean=False):
|
||||
new_body = [
|
||||
new_body: list[BaseStatement] = [
|
||||
SimpleStatementLine(body=[Expr(value=SimpleString(f'"""{original_docstring}"""'))]),
|
||||
*list(updated_node.body.body),
|
||||
*cast(list[BaseStatement], list(updated_node.body.body)),
|
||||
]
|
||||
updated_node = updated_node.with_changes(body=IndentedBlock(body=new_body))
|
||||
else:
|
||||
new_body = [
|
||||
SimpleStatementLine(body=[Expr(value=SimpleString(f'"""{original_docstring}"""'))]),
|
||||
*list(updated_node.body.body[1:]),
|
||||
*cast(list[BaseStatement], list(updated_node.body.body[1:])),
|
||||
]
|
||||
updated_node = updated_node.with_changes(body=IndentedBlock(body=new_body))
|
||||
return updated_node
|
||||
|
|
@ -259,12 +259,13 @@ def _strip_comments_from_code(code: str) -> str:
|
|||
The same code with all comments removed, preserving string content
|
||||
|
||||
"""
|
||||
# add comment here
|
||||
try:
|
||||
lines = code.splitlines(keepends=True)
|
||||
tokens = tokenize.generate_tokens(io.StringIO(code).readline)
|
||||
|
||||
# Build a per-line map of comment ranges to remove
|
||||
comments_by_line = [[] for _ in range(len(lines))]
|
||||
comments_by_line: list[list[tuple[int, int]]] = [[] for _ in range(len(lines))]
|
||||
for token in tokens:
|
||||
if token.type == tokenize.COMMENT:
|
||||
line_idx = token.start[0] - 1
|
||||
|
|
@ -445,7 +446,7 @@ def clean_extraneous_comments(original_module: cst.Module, optimized_module: cst
|
|||
if orig_line_idx is not None and orig_line_idx > 0 and opt_idx not in code_changed_lines:
|
||||
# Look backwards for ALL consecutive comment-only or blank lines
|
||||
# Collect them all, then decide which ones to restore
|
||||
preceding_lines = []
|
||||
preceding_lines: list[str] = []
|
||||
check_idx = orig_line_idx - 1
|
||||
|
||||
while check_idx >= 0:
|
||||
|
|
@ -466,6 +467,7 @@ def clean_extraneous_comments(original_module: cst.Module, optimized_module: cst
|
|||
|
||||
# Use the original line (preserves original comments or lack thereof)
|
||||
result_lines.append(found_orig)
|
||||
if orig_line_idx is not None:
|
||||
restored_orig_indices.add(orig_line_idx)
|
||||
|
||||
# Also check for trailing blank/comment lines after this line that were removed
|
||||
|
|
|
|||
|
|
@ -5,12 +5,25 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from aiservice.common_utils import should_hack_for_demo
|
||||
|
||||
from .validate import instrument_tests
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.shared.testgen_models import TestGenResponseSchema, TestGenSchema
|
||||
|
||||
|
||||
async def try_demo_testgen(data: TestGenSchema, python_version: tuple[int, int, int]) -> TestGenResponseSchema | None:
|
||||
"""Return a canned demo response if source matches a demo function, else None."""
|
||||
if not should_hack_for_demo(data.source_code_being_tested):
|
||||
return None
|
||||
if "find_common_tags" in data.source_code_being_tested:
|
||||
return await hack_for_demo(data, python_version)
|
||||
if "weighted_sum" in data.source_code_being_tested:
|
||||
return await hack_for_demo_gsq(data, python_version)
|
||||
return None
|
||||
|
||||
|
||||
async def hack_for_demo(data: TestGenSchema, python_version: tuple[int, int, int]) -> TestGenResponseSchema:
|
||||
from core.shared.testgen_models import TestGenResponseSchema
|
||||
|
||||
|
|
|
|||
|
|
@ -12,14 +12,14 @@ from ninja.errors import HttpError
|
|||
|
||||
from aiservice.analytics.posthog import ph
|
||||
from aiservice.common.markdown_utils import extract_code_block, split_markdown_code
|
||||
from aiservice.common_utils import safe_isort, should_hack_for_demo
|
||||
from aiservice.common_utils import safe_isort
|
||||
from aiservice.env_specific import debug_log_sensitive_data
|
||||
from aiservice.llm import llm_client
|
||||
from aiservice.llm_models import EXECUTE_MODEL, HAIKU_MODEL, OPENAI_MODEL
|
||||
from aiservice.models.functions_to_optimize import FunctionToOptimize
|
||||
from core.languages.python.cst_utils import any_ellipsis_in_cst, ellipsis_in_cst_not_types, parse_module_to_cst
|
||||
from core.languages.python.optimizer.context_utils.context_helpers import is_multi_context
|
||||
from core.languages.python.testgen.demo_hacks import hack_for_demo, hack_for_demo_gsq
|
||||
from core.languages.python.testgen.demo_hacks import try_demo_testgen
|
||||
from core.languages.python.testgen.models import CostTracker, LLMOutputParseError
|
||||
from core.languages.python.testgen.postprocessing.code_validator import (
|
||||
CodeValidationError,
|
||||
|
|
@ -364,12 +364,8 @@ async def testgen_python(
|
|||
sentry_sdk.capture_exception(e)
|
||||
return e.status_code, TestGenErrorResponseSchema(error=e.message)
|
||||
|
||||
if should_hack_for_demo(data.source_code_being_tested):
|
||||
if "find_common_tags" in data.source_code_being_tested:
|
||||
demo_hack_response = await hack_for_demo(data, python_version)
|
||||
elif "weighted_sum" in data.source_code_being_tested:
|
||||
demo_hack_response = await hack_for_demo_gsq(data, python_version)
|
||||
return 200, demo_hack_response
|
||||
if (demo_result := await try_demo_testgen(data, python_version)) is not None:
|
||||
return 200, demo_result
|
||||
|
||||
logging.info("/testgen: Generating tests...")
|
||||
try:
|
||||
|
|
|
|||
Loading…
Reference in a new issue