codeflash-internal/django/aiservice/core/languages/java/testgen.py
Saurabh Misra 198c0c1a4e
codeflash-omni-java (#2335)
# Pull Request Checklist

## Description
- [ ] **Breaking Changes**: Document any breaking changes (if applicable)
- [ ] **Description of PR**: Clear and concise description of what this PR accomplishes
- [ ] **Related Issues**: Link to any related issues or tickets

## Testing
- [ ] **Test Cases Attached**: All relevant test cases have been added/updated
- [ ] **Manual Testing**: Manual testing has been completed for the changes

## Monitoring & Debugging
- [ ] **Logging in place**: Appropriate logging has been added for debugging user issues
- [ ] **Sentry will be able to catch errors**: Error handling ensures Sentry can capture and report errors
- [ ] **Avoid Dev based/Prisma logging**: No development-only or Prisma-specific logging in production code

## Configuration
- [ ] **Env variables newly added**: Any new environment variables are documented in the .env.example file or mentioned in the description
---

## Additional Notes
<!-- Add any additional context, screenshots, or notes for reviewers
here -->

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
Co-authored-by: HeshamHM28 <HeshamMohamedFathy@outlook.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-39-200.ec2.internal>
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Kevin Turcios <turcioskevinr@gmail.com>
Co-authored-by: Kevin Turcios <106575910+KRRT7@users.noreply.github.com>
2026-02-13 23:26:55 +05:30

1482 lines
62 KiB
Python

"""Java test generation module.
This module generates JUnit 5 tests for Java functions.
Instrumentation is handled by the codeflash CLI client, not here.
"""
from __future__ import annotations
import asyncio
import logging
import os
import re
from pathlib import Path
from typing import TYPE_CHECKING
import sentry_sdk
import stamina
from openai import OpenAIError
from openai.types.chat import ChatCompletionMessageParam
from aiservice.analytics.posthog import ph
from aiservice.common_utils import is_host_equals_demo, should_hack_for_demo_java, validate_trace_id
from aiservice.env_specific import debug_log_sensitive_data
from aiservice.llm import EXECUTE_MODEL, calculate_llm_cost, call_llm
from authapp.auth import AuthenticatedRequest
from core.shared.testgen_models import (
TestGenerationFailedError,
TestGenErrorResponseSchema,
TestGenResponseSchema,
TestGenSchema,
)
from log_features.log_event import update_optimization_cost
if TYPE_CHECKING:
from aiservice.llm import LLM
from aiservice.validators.java_validator import validate_java_syntax
_TEST_FUNC_RE = re.compile(r"@Test\s*\n\s*(?:public\s+)?void\s+\w+")
# Get the directory of the current file
current_dir = Path(__file__).parent
JAVA_PROMPTS_DIR = current_dir / "prompts" / "testgen"
# Ensure prompts directory exists
JAVA_PROMPTS_DIR.mkdir(parents=True, exist_ok=True)
# System prompt for JUnit 5 test generation (the default). build_java_prompt() fills in
# only {function_name} and uses this template for every test_framework except "junit4".
JAVA_SYSTEM_PROMPT_JUNIT5 = """You are an expert Java developer specializing in writing comprehensive JUnit 5 tests.
Your task is to generate high-quality unit tests for the given Java function.
Guidelines:
1. Use JUnit 5 (Jupiter) test annotations (@Test, @BeforeEach, etc.)
2. Include imports for org.junit.jupiter.api.*
3. Use assertEquals, assertTrue, assertFalse, assertThrows as appropriate
4. Test edge cases, boundary conditions, and typical use cases
5. Use descriptive test method names following the pattern: test<Scenario>_<ExpectedBehavior>
6. Include a test class with the pattern: <ClassName>Test
7. Do NOT use mocks unless absolutely necessary
8. Keep tests simple and focused on one assertion per test when possible
CRITICAL - Java Syntax Rules:
- Java does NOT support import aliasing (e.g., "import X as Y" is INVALID)
- Use fully qualified names or regular imports only
- Example: Use "java.util.Base64.getEncoder()" or "import java.util.Base64;" NOT "import java.util.Base64 as Foo;"
CRITICAL - Handling Complex Parameter Types:
- If a function parameter is an abstract class or interface (e.g., Value, Key), use the REAL factory methods or concrete implementations from the library, NOT custom mock classes
- NEVER create a custom class that extends an abstract class unless you implement ALL abstract methods
- Prefer using static factory methods like Value.get("string"), Value.get(123), etc.
- If the real class has a builder or factory pattern, use it
- Check if there are concrete implementations you can instantiate directly
Function to test: {function_name}
"""
# System prompt for JUnit 4 test generation, selected by build_java_prompt() when
# test_framework == "junit4". As with the JUnit 5 variant, only {function_name} is filled in.
JAVA_SYSTEM_PROMPT_JUNIT4 = """You are an expert Java developer specializing in writing comprehensive JUnit 4 tests.
Your task is to generate high-quality unit tests for the given Java function.
Guidelines:
1. Use JUnit 4 test annotations (@Test, @Before, etc.) - NOT JUnit 5
2. Include imports for org.junit.* (NOT org.junit.jupiter.api.*)
3. Use static imports: import static org.junit.Assert.*
4. Use assertEquals, assertTrue, assertFalse as appropriate. For exceptions, use @Test(expected=ExceptionType.class)
5. Test edge cases, boundary conditions, and typical use cases
6. Use descriptive test method names following the pattern: test<Scenario>_<ExpectedBehavior>
7. Include a test class with the pattern: <ClassName>Test
8. Do NOT use mocks unless absolutely necessary
9. Keep tests simple and focused on one assertion per test when possible
CRITICAL - Java Syntax Rules:
- Java does NOT support import aliasing (e.g., "import X as Y" is INVALID)
- Use fully qualified names or regular imports only
- Example: Use "java.util.Base64.getEncoder()" or "import java.util.Base64;" NOT "import java.util.Base64 as Foo;"
CRITICAL - Handling Complex Parameter Types:
- If a function parameter is an abstract class or interface (e.g., Value, Key), use the REAL factory methods or concrete implementations from the library, NOT custom mock classes
- NEVER create a custom class that extends an abstract class unless you implement ALL abstract methods
- Prefer using static factory methods like Value.get("string"), Value.get(123), etc.
- If the real class has a builder or factory pattern, use it
- Check if there are concrete implementations you can instantiate directly
Function to test: {function_name}
"""
# User prompt for JUnit 5. build_java_prompt() formats it with str.format using
# function_name, function_code, module_path, class_name and package_name, so literal
# braces in the example Java class are doubled ({{ / }}).
# NOTE(review): with an empty package_name the "package {package_name};" requirement
# becomes "package ;", which does not compile -- confirm callers always pass a package.
JAVA_USER_PROMPT_JUNIT5 = """Generate JUnit 5 tests for the following Java function.
Function name: {function_name}
Class name: {class_name}
Full qualified name: {module_path}
Package: {package_name}
Source code:
```java
{function_code}
```
Generate comprehensive unit tests that cover:
1. Basic functionality with typical inputs
2. Edge cases (empty inputs, null values, boundary conditions)
3. Error conditions (if the function can throw exceptions)
4. Large-scale inputs for performance verification
IMPORTANT REQUIREMENTS:
1. The package declaration MUST be: package {package_name};
2. You MUST import the class under test: import {module_path};
3. The test class MUST be named {class_name}Test
4. Create an instance of {class_name} in @BeforeEach or directly in test methods
5. CRITICAL: Do NOT create custom classes that extend abstract classes or interfaces.
Instead, use factory methods or concrete implementations from the library.
Example: For com.aerospike.client.Value, use Value.get("string"), Value.get(123), etc.
Wrap your response in a Java code block (```java ... ```).
Example structure:
```java
package {package_name};
import org.junit.jupiter.api.*;
import static org.junit.jupiter.api.Assertions.*;
import {module_path};
public class {class_name}Test {{
private {class_name} instance;
@BeforeEach
void setUp() {{
instance = new {class_name}();
}}
@Test
void testBasicFunctionality() {{
// Test code here
}}
}}
```
"""
# User prompt for JUnit 4. It takes the same str.format placeholders and uses the same
# brace doubling as the JUnit 5 variant.
# NOTE(review): the empty-package_name caveat ("package ;") applies here too.
JAVA_USER_PROMPT_JUNIT4 = """Generate JUnit 4 tests for the following Java function.
Function name: {function_name}
Class name: {class_name}
Full qualified name: {module_path}
Package: {package_name}
Source code:
```java
{function_code}
```
Generate comprehensive unit tests that cover:
1. Basic functionality with typical inputs
2. Edge cases (empty inputs, null values, boundary conditions)
3. Error conditions (if the function can throw exceptions)
4. Large-scale inputs for performance verification
IMPORTANT REQUIREMENTS:
1. The package declaration MUST be: package {package_name};
2. You MUST import the class under test: import {module_path};
3. The test class MUST be named {class_name}Test
4. Use JUnit 4 annotations and imports (NOT JUnit 5/Jupiter)
5. Create an instance of {class_name} in @Before or directly in test methods
6. CRITICAL: Do NOT create custom classes that extend abstract classes or interfaces.
Instead, use factory methods or concrete implementations from the library.
Example: For com.aerospike.client.Value, use Value.get("string"), Value.get(123), etc.
Wrap your response in a Java code block (```java ... ```).
Example structure:
```java
package {package_name};
import org.junit.Test;
import org.junit.Before;
import static org.junit.Assert.*;
import {module_path};
public class {class_name}Test {{
private {class_name} instance;
@Before
public void setUp() {{
instance = new {class_name}();
}}
@Test
public void testBasicFunctionality() {{
// Test code here
}}
}}
```
"""
# Captures the text between a line-initial ``` fence (bare, or tagged "java") and the
# next line-initial ``` fence. DOTALL lets the lazy body span several lines; MULTILINE
# makes ^ anchor at every line start rather than only at the start of the reply.
JAVA_PATTERN = re.compile(
    r"^```(?:java)?\s*\n(.*?)\n```",
    flags=re.DOTALL | re.MULTILINE,
)
def build_java_prompt(
    function_name: str,
    function_code: str,
    module_path: str,
    class_name: str,
    package_name: str,
    test_framework: str = "junit5",
) -> tuple[list[ChatCompletionMessageParam], str]:
    """Assemble the system + user chat messages that ask the LLM for Java tests.

    Args:
        function_name: Name of the function to test.
        function_code: Source code of the function.
        module_path: Import path for the module (the class's qualified name).
        class_name: Name of the class containing the function.
        package_name: Package name for the test class.
        test_framework: "junit4" selects the JUnit 4 templates; any other value
            (including the default "junit5") selects the JUnit 5 templates.

    Returns:
        Tuple of (messages, posthog_event_suffix). The suffix echoes the requested
        test_framework verbatim, e.g. "java-junit5-".
    """
    # Anything that is not exactly "junit4" falls back to JUnit 5.
    if test_framework != "junit4":
        system_template, user_template = JAVA_SYSTEM_PROMPT_JUNIT5, JAVA_USER_PROMPT_JUNIT5
    else:
        system_template, user_template = JAVA_SYSTEM_PROMPT_JUNIT4, JAVA_USER_PROMPT_JUNIT4

    system_content = system_template.format(function_name=function_name)
    user_content = user_template.format(
        function_name=function_name,
        function_code=function_code,
        module_path=module_path,
        class_name=class_name,
        package_name=package_name,
    )

    messages: list[ChatCompletionMessageParam] = [
        {"role": "system", "content": system_content},
        {"role": "user", "content": user_content},
    ]
    return messages, f"java-{test_framework}-"
def parse_and_validate_java_output(response_content: str) -> str:
    """Pull the Java test code out of an LLM reply and check that it is usable.

    Checks run in order: a code fence is present, a Java block can be extracted, the
    block passes ``validate_java_syntax``, and it contains at least one @Test method.

    Args:
        response_content: Raw LLM response.

    Returns:
        The extracted Java code, stripped of surrounding whitespace.

    Raises:
        ValueError: If the reply has no code fence, no extractable Java block, or the
            code has no @Test-annotated method.
        SyntaxError: If ``validate_java_syntax`` rejects the code.
    """
    if "```" not in response_content:
        # Report a truncated copy of fence-less replies so they can be investigated.
        sentry_sdk.capture_message("LLM response did not contain a code block:\n" + response_content[:500])
        raise ValueError("LLM response did not contain a code block.")

    block_match = JAVA_PATTERN.search(response_content)
    if block_match is None:
        raise ValueError("No Java code block found in the LLM response.")
    java_code = block_match.group(1).strip()

    syntax_ok, syntax_error = validate_java_syntax(java_code)
    if not syntax_ok:
        raise SyntaxError(f"Invalid Java code: {syntax_error}")

    if _has_test_functions(java_code):
        return java_code
    raise ValueError("Generated code does not contain any @Test annotated methods.")
def _has_test_functions(code: str) -> bool:
    """Return True when *code* contains at least one method matched by ``_TEST_FUNC_RE``."""
    found = _TEST_FUNC_RE.search(code)
    return bool(found)
@stamina.retry(on=(SyntaxError, ValueError, OpenAIError), attempts=2)
async def generate_and_validate_java_test_code(
    messages: list[ChatCompletionMessageParam],
    model: LLM,
    cost_tracker: list[float],
    user_id: str,
    posthog_event_suffix: str,
    trace_id: str = "",
) -> str:
    """Ask the LLM for Java tests and return the validated code.

    The stamina decorator makes up to two attempts in total when a call raises
    SyntaxError, ValueError or OpenAIError. Every attempt that gets an LLM response
    appends its cost to ``cost_tracker``.

    Args:
        messages: Prompt messages for the LLM.
        model: LLM model to use.
        cost_tracker: List that receives the cost of each LLM call.
        user_id: User ID for analytics.
        posthog_event_suffix: Analytics suffix; accepted for interface parity but not
            used in this function body.
        trace_id: Trace ID for logging.

    Returns:
        Validated Java test code.

    Raises:
        ValueError: If the response has no usable code or no @Test methods.
        SyntaxError: If the generated code has syntax errors.
        Exception: Any error raised by ``call_llm`` is logged, sent to Sentry, and re-raised.
    """
    try:
        llm_output = await call_llm(
            llm=model,
            messages=messages,
            call_type="testgen",
            trace_id=trace_id,
            user_id=user_id,
            python_version="N/A",  # Not applicable for Java
        )
    except Exception as exc:
        logging.exception("LLM Code Generation error")
        sentry_sdk.capture_exception(exc)
        raise

    cost_tracker.append(calculate_llm_cost(llm_output.raw_response, model))
    debug_log_sensitive_data(f"LLM testgen response:\n{llm_output.content}")
    return parse_and_validate_java_output(llm_output.content)
def _extract_class_and_package(module_path: str) -> tuple[str, str]:
"""Extract class name and package from module path.
Args:
module_path: e.g., "com.example.Algorithms"
Returns:
Tuple of (class_name, package_name)
"""
parts = module_path.rsplit(".", 1)
if len(parts) == 2:
return parts[1], parts[0] # class_name, package_name
return parts[0], "" # class_name only, no package
def _extract_package_from_source(source_code: str) -> str | None:
"""Extract package name from Java source code.
Args:
source_code: Java source code
Returns:
Package name (e.g., "com.example"), or None if not found
"""
# First try: package declaration in source (most reliable)
package_pattern = re.compile(r"^\s*package\s+([\w.]+)\s*;", re.MULTILINE)
match = package_pattern.search(source_code)
if match:
logging.debug(f"Extracted package from declaration: {match.group(1)}")
return match.group(1)
# Second try: extract from markdown code block header (e.g., "```java:src/main/java/com/example/Algorithms.java")
markdown_pattern = re.compile(r"```java:([^\n`]+\.java)", re.IGNORECASE)
markdown_match = markdown_pattern.search(source_code)
if markdown_match:
file_path = markdown_match.group(1).strip()
package = _extract_package_from_path(file_path)
if package:
logging.debug(f"Extracted package from markdown header: {package}")
return package
# Third try: extract from file path comment (e.g., "// file: src/main/java/com/example/Algorithms.java")
# Also handle "// file: src/com/example/Algorithms.java" (non-standard Maven)
file_comment_pattern = re.compile(r"//\s*file:\s*([^\n]+\.java)", re.IGNORECASE)
file_match = file_comment_pattern.search(source_code)
if file_match:
file_path = file_match.group(1).strip()
logging.debug(f"Found file comment: {file_path}")
package = _extract_package_from_path(file_path)
if package:
logging.debug(f"Extracted package from file comment: {package}")
return package
# Fourth try: infer package from import statements (last resort)
# Look for imports that might indicate the package structure
import_pattern = re.compile(r"^\s*import\s+([\w.]+)\.[\w*]+\s*;", re.MULTILINE)
imports = import_pattern.findall(source_code)
if imports:
# Find common prefix among imports that look like internal packages
# Exclude common library packages
internal_imports: list[str] = [
imp for imp in imports if not imp.startswith(("java.", "javax.", "org.junit", "org.apache", "com.google"))
]
if internal_imports:
# Use the shortest import path as a hint
internal_imports.sort(key=len)
logging.debug(f"Inferred package from imports: {internal_imports[0]}")
return internal_imports[0]
logging.warning("Could not extract package name from source code")
return None
def _extract_package_from_path(file_path: str) -> str | None:
"""Extract Java package from a file path.
Args:
file_path: e.g., "src/main/java/com/example/Algorithms.java" or "src/com/example/Algorithms.java"
Returns:
Package name (e.g., "com.example"), or None if not found
"""
# Normalize slashes
file_path = file_path.replace("\\", "/")
# Standard Maven patterns (highest priority)
java_src_patterns = ["/src/main/java/", "/src/test/java/", "src/main/java/", "src/test/java/"]
for pattern in java_src_patterns:
if pattern in file_path:
idx = file_path.find(pattern)
remaining = file_path[idx + len(pattern) :]
parts = remaining.split("/")
if len(parts) > 1:
package_parts = parts[:-1] # Remove the filename
return ".".join(package_parts)
# Non-standard paths: look for "src/" followed by what looks like a package structure
# e.g., "src/com/aerospike/client/util/Crypto.java" -> "com.aerospike.client.util"
# or "client/src/com/aerospike/client/util/Crypto.java"
src_patterns = ["/src/", "src/"]
for pattern in src_patterns:
if pattern in file_path:
idx = file_path.find(pattern)
remaining = file_path[idx + len(pattern) :]
parts = remaining.split("/")
if len(parts) > 1:
# Check if first part looks like a package (lowercase, not 'main', 'test', 'java', 'resources')
first_part = parts[0].lower()
if first_part not in ("main", "test", "java", "resources") and first_part.isalpha():
package_parts = parts[:-1] # Remove the filename
return ".".join(package_parts)
return None
def _extract_class_from_source(source_code: str) -> str | None:
"""Extract class name from Java source code.
Args:
source_code: Java source code
Returns:
Class name, or None if not found
"""
# First try: class declaration in source
class_pattern = re.compile(r"\bclass\s+(\w+)")
match = class_pattern.search(source_code)
if match:
return match.group(1)
# Second try: extract from markdown code block header (e.g., "```java:src/main/java/com/example/Algorithms.java")
markdown_pattern = re.compile(r"```java:([^\n`]+\.java)", re.IGNORECASE)
markdown_match = markdown_pattern.search(source_code)
if markdown_match:
file_path = markdown_match.group(1).strip()
filename = os.path.basename(file_path)
if filename.endswith(".java"):
return filename[:-5] # Remove .java extension
# Third try: extract from file path comment (e.g., "// file: src/main/java/com/example/Algorithms.java")
file_comment_pattern = re.compile(r"//\s*file:\s*([^\n]+\.java)", re.IGNORECASE)
file_match = file_comment_pattern.search(source_code)
if file_match:
file_path = file_match.group(1).strip()
# Extract class name from file path (e.g., "Algorithms.java" -> "Algorithms")
filename = os.path.basename(file_path)
if filename.endswith(".java"):
return filename[:-5] # Remove .java extension
return None
def _build_demo_test_source_0(package_name: str, class_name: str, test_framework: str) -> str:
"""Build demo test source 0, adapting to the target's package, class, and test framework.
File creation is in @Before/@BeforeEach so it runs once, outside the instrumentation's
inner loop. Test methods only contain readFile calls so every inner iteration succeeds
and the benchmark measures pure readFile performance.
"""
module_path = f"{package_name}.{class_name}" if package_name else class_name
test_class_name = f"{class_name}Test"
if test_framework == "junit4":
return (f"package {package_name};\n" if package_name else "") + (
"\n"
"import org.junit.Before;\n"
"import org.junit.Test;\n"
"import org.junit.Rule;\n"
"import org.junit.rules.TemporaryFolder;\n"
"import static org.junit.Assert.*;\n"
"\n"
"import java.io.File;\n"
"import java.io.FileOutputStream;\n"
"\n"
f"import {module_path};\n"
"\n"
f"public class {test_class_name} {{\n"
"\n"
" @Rule\n"
" public TemporaryFolder tempFolder = new TemporaryFolder();\n"
"\n"
" private File smallFile;\n"
" private byte[] expectedSmall;\n"
" private File mediumFile;\n"
" private byte[] expectedMedium;\n"
" private File largeFile;\n"
" private byte[] expectedLarge;\n"
"\n"
" @Before\n"
" public void setUp() throws Exception {\n"
' smallFile = tempFolder.newFile("small.txt");\n'
' expectedSmall = "Hello, World!".getBytes();\n'
" try (FileOutputStream out = new FileOutputStream(smallFile)) {\n"
" out.write(expectedSmall);\n"
" }\n"
"\n"
' mediumFile = tempFolder.newFile("medium.dat");\n'
" expectedMedium = new byte[256 * 1024];\n"
" for (int i = 0; i < expectedMedium.length; i++) {\n"
" expectedMedium[i] = (byte) (i % 251);\n"
" }\n"
" try (FileOutputStream out = new FileOutputStream(mediumFile)) {\n"
" out.write(expectedMedium);\n"
" }\n"
"\n"
' largeFile = tempFolder.newFile("large.dat");\n'
" expectedLarge = new byte[1024 * 1024];\n"
" for (int i = 0; i < expectedLarge.length; i++) {\n"
" expectedLarge[i] = (byte) (i % 256);\n"
" }\n"
" try (FileOutputStream out = new FileOutputStream(largeFile)) {\n"
" out.write(expectedLarge);\n"
" }\n"
" }\n"
"\n"
" @Test\n"
" public void testReadSmallFile() throws Exception {\n"
f" byte[] result = {class_name}.readFile(smallFile);\n"
" assertArrayEquals(expectedSmall, result);\n"
" }\n"
"\n"
" @Test\n"
" public void testReadMediumFileRepeated() throws Exception {\n"
f" assertArrayEquals(expectedMedium, {class_name}.readFile(mediumFile));\n"
" for (int i = 0; i < 300; i++) {\n"
f" {class_name}.readFile(mediumFile);\n"
" }\n"
" }\n"
"\n"
" @Test\n"
" public void testReadLargeFileRepeated() throws Exception {\n"
f" assertArrayEquals(expectedLarge, {class_name}.readFile(largeFile));\n"
" for (int i = 0; i < 2000; i++) {\n"
f" {class_name}.readFile(largeFile);\n"
" }\n"
" }\n"
"}\n"
)
else:
# JUnit 5
return (
f"package {package_name};\n" if package_name else ""
) + (
"\n"
"import org.junit.jupiter.api.BeforeEach;\n"
"import org.junit.jupiter.api.Test;\n"
"import org.junit.jupiter.api.DisplayName;\n"
"import org.junit.jupiter.api.io.TempDir;\n"
"import static org.junit.jupiter.api.Assertions.*;\n"
"\n"
"import java.io.File;\n"
"import java.io.FileOutputStream;\n"
"import java.nio.file.Path;\n"
"\n"
f"import {module_path};\n"
"\n"
f"class {test_class_name} {{\n"
"\n"
" @TempDir\n"
" Path tempDir;\n"
"\n"
" private File smallFile;\n"
" private byte[] expectedSmall;\n"
" private File mediumFile;\n"
" private byte[] expectedMedium;\n"
" private File largeFile;\n"
" private byte[] expectedLarge;\n"
"\n"
" @BeforeEach\n"
" void setUp() throws Exception {\n"
' smallFile = tempDir.resolve("small.txt").toFile();\n'
' expectedSmall = "Hello, World!".getBytes();\n'
" try (FileOutputStream out = new FileOutputStream(smallFile)) {\n"
" out.write(expectedSmall);\n"
" }\n"
"\n"
' mediumFile = tempDir.resolve("medium.dat").toFile();\n'
" expectedMedium = new byte[256 * 1024];\n"
" for (int i = 0; i < expectedMedium.length; i++) {\n"
" expectedMedium[i] = (byte) (i % 251);\n"
" }\n"
" try (FileOutputStream out = new FileOutputStream(mediumFile)) {\n"
" out.write(expectedMedium);\n"
" }\n"
"\n"
' largeFile = tempDir.resolve("large.dat").toFile();\n'
" expectedLarge = new byte[1024 * 1024];\n"
" for (int i = 0; i < expectedLarge.length; i++) {\n"
" expectedLarge[i] = (byte) (i % 256);\n"
" }\n"
" try (FileOutputStream out = new FileOutputStream(largeFile)) {\n"
" out.write(expectedLarge);\n"
" }\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Read a small file")\n'
" void testReadSmallFile() throws Exception {\n"
f" byte[] result = {class_name}.readFile(smallFile);\n"
" assertArrayEquals(expectedSmall, result);\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Read 256KB file 301 times")\n'
" void testReadMediumFileRepeated() throws Exception {\n"
f" assertArrayEquals(expectedMedium, {class_name}.readFile(mediumFile));\n"
" for (int i = 0; i < 300; i++) {\n"
f" {class_name}.readFile(mediumFile);\n"
" }\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Read 1MB file 501 times")\n'
" void testReadLargeFileRepeated() throws Exception {\n"
f" assertArrayEquals(expectedLarge, {class_name}.readFile(largeFile));\n"
" for (int i = 0; i < 2000; i++) {\n"
f" {class_name}.readFile(largeFile);\n"
" }\n"
" }\n"
"}\n"
)
def _build_demo_test_source_1(package_name: str, class_name: str, test_framework: str) -> str:
"""Build demo test source 1, adapting to the target's package, class, and test framework.
Same @Before pattern as source_0: file creation outside the timed test body.
Complementary file sizes to source_0.
"""
module_path = f"{package_name}.{class_name}" if package_name else class_name
test_class_name = f"{class_name}Test"
if test_framework == "junit4":
return (f"package {package_name};\n" if package_name else "") + (
"\n"
"import org.junit.Before;\n"
"import org.junit.Test;\n"
"import org.junit.Rule;\n"
"import org.junit.rules.TemporaryFolder;\n"
"import static org.junit.Assert.*;\n"
"\n"
"import java.io.File;\n"
"import java.io.FileOutputStream;\n"
"import java.util.Arrays;\n"
"\n"
f"import {module_path};\n"
"\n"
f"public class {test_class_name} {{\n"
"\n"
" @Rule\n"
" public TemporaryFolder tempFolder = new TemporaryFolder();\n"
"\n"
" private File patternFile;\n"
" private byte[] expectedPattern;\n"
" private File halfMegFile;\n"
" private byte[] expectedHalfMeg;\n"
" private File twoMegFile;\n"
" private byte[] expectedTwoMeg;\n"
"\n"
" @Before\n"
" public void setUp() throws Exception {\n"
' patternFile = tempFolder.newFile("pattern.dat");\n'
" expectedPattern = new byte[128 * 1024];\n"
" for (int i = 0; i < expectedPattern.length; i++) {\n"
" expectedPattern[i] = (byte) (i % 7);\n"
" }\n"
" try (FileOutputStream out = new FileOutputStream(patternFile)) {\n"
" out.write(expectedPattern);\n"
" }\n"
"\n"
' halfMegFile = tempFolder.newFile("half_meg.dat");\n'
" expectedHalfMeg = new byte[512 * 1024];\n"
" Arrays.fill(expectedHalfMeg, (byte) 0xCD);\n"
" try (FileOutputStream out = new FileOutputStream(halfMegFile)) {\n"
" out.write(expectedHalfMeg);\n"
" }\n"
"\n"
' twoMegFile = tempFolder.newFile("two_meg.dat");\n'
" expectedTwoMeg = new byte[2 * 1024 * 1024];\n"
" for (int i = 0; i < expectedTwoMeg.length; i++) {\n"
" expectedTwoMeg[i] = (byte) (i % 199);\n"
" }\n"
" try (FileOutputStream out = new FileOutputStream(twoMegFile)) {\n"
" out.write(expectedTwoMeg);\n"
" }\n"
" }\n"
"\n"
" @Test\n"
" public void testReadBinaryPattern() throws Exception {\n"
f" assertArrayEquals(expectedPattern, {class_name}.readFile(patternFile));\n"
" for (int i = 0; i < 400; i++) {\n"
f" {class_name}.readFile(patternFile);\n"
" }\n"
" }\n"
"\n"
" @Test\n"
" public void testReadHalfMegRepeated() throws Exception {\n"
f" assertArrayEquals(expectedHalfMeg, {class_name}.readFile(halfMegFile));\n"
" for (int i = 0; i < 300; i++) {\n"
f" {class_name}.readFile(halfMegFile);\n"
" }\n"
" }\n"
"\n"
" @Test\n"
" public void testReadTwoMegRepeated() throws Exception {\n"
f" assertArrayEquals(expectedTwoMeg, {class_name}.readFile(twoMegFile));\n"
" for (int i = 0; i < 200; i++) {\n"
f" {class_name}.readFile(twoMegFile);\n"
" }\n"
" }\n"
"}\n"
)
else:
# JUnit 5
return (
f"package {package_name};\n" if package_name else ""
) + (
"\n"
"import org.junit.jupiter.api.BeforeEach;\n"
"import org.junit.jupiter.api.Test;\n"
"import org.junit.jupiter.api.DisplayName;\n"
"import org.junit.jupiter.api.io.TempDir;\n"
"import static org.junit.jupiter.api.Assertions.*;\n"
"\n"
"import java.io.File;\n"
"import java.io.FileOutputStream;\n"
"import java.nio.file.Path;\n"
"import java.util.Arrays;\n"
"\n"
f"import {module_path};\n"
"\n"
f"class {test_class_name} {{\n"
"\n"
" @TempDir\n"
" Path tempDir;\n"
"\n"
" private File patternFile;\n"
" private byte[] expectedPattern;\n"
" private File halfMegFile;\n"
" private byte[] expectedHalfMeg;\n"
" private File twoMegFile;\n"
" private byte[] expectedTwoMeg;\n"
"\n"
" @BeforeEach\n"
" void setUp() throws Exception {\n"
' patternFile = tempDir.resolve("pattern.dat").toFile();\n'
" expectedPattern = new byte[128 * 1024];\n"
" for (int i = 0; i < expectedPattern.length; i++) {\n"
" expectedPattern[i] = (byte) (i % 7);\n"
" }\n"
" try (FileOutputStream out = new FileOutputStream(patternFile)) {\n"
" out.write(expectedPattern);\n"
" }\n"
"\n"
' halfMegFile = tempDir.resolve("half_meg.dat").toFile();\n'
" expectedHalfMeg = new byte[512 * 1024];\n"
" Arrays.fill(expectedHalfMeg, (byte) 0xCD);\n"
" try (FileOutputStream out = new FileOutputStream(halfMegFile)) {\n"
" out.write(expectedHalfMeg);\n"
" }\n"
"\n"
' twoMegFile = tempDir.resolve("two_meg.dat").toFile();\n'
" expectedTwoMeg = new byte[2 * 1024 * 1024];\n"
" for (int i = 0; i < expectedTwoMeg.length; i++) {\n"
" expectedTwoMeg[i] = (byte) (i % 199);\n"
" }\n"
" try (FileOutputStream out = new FileOutputStream(twoMegFile)) {\n"
" out.write(expectedTwoMeg);\n"
" }\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Read 128KB binary pattern 401 times")\n'
" void testReadBinaryPattern() throws Exception {\n"
f" assertArrayEquals(expectedPattern, {class_name}.readFile(patternFile));\n"
" for (int i = 0; i < 400; i++) {\n"
f" {class_name}.readFile(patternFile);\n"
" }\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Read 512KB file 301 times")\n'
" void testReadHalfMegRepeated() throws Exception {\n"
f" assertArrayEquals(expectedHalfMeg, {class_name}.readFile(halfMegFile));\n"
" for (int i = 0; i < 300; i++) {\n"
f" {class_name}.readFile(halfMegFile);\n"
" }\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Read 2MB file 201 times")\n'
" void testReadTwoMegRepeated() throws Exception {\n"
f" assertArrayEquals(expectedTwoMeg, {class_name}.readFile(twoMegFile));\n"
" for (int i = 0; i < 200; i++) {\n"
f" {class_name}.readFile(twoMegFile);\n"
" }\n"
" }\n"
"}\n"
)
def _build_host_equals_demo_test_source_0(package_name: str, class_name: str, test_framework: str) -> str:
"""Build demo test source 0 for Host.equals, adapting to the target's package, class, and test framework."""
module_path = f"{package_name}.{class_name}" if package_name else class_name
test_class_name = f"{class_name}Test"
if test_framework == "junit4":
return (f"package {package_name};\n" if package_name else "") + (
"\n"
"import org.junit.Before;\n"
"import org.junit.Test;\n"
"import static org.junit.Assert.*;\n"
"\n"
f"import {module_path};\n"
"\n"
"/**\n"
f" * Unit tests for {module_path}.equals(...)\n"
" */\n"
f"public class {test_class_name} {{\n"
f" private {class_name} defaultHost;\n"
"\n"
" @Before\n"
" public void setUp() {\n"
f' defaultHost = new {class_name}("localhost", 3000);\n'
" }\n"
"\n"
" @Test\n"
" public void testEquals_SameInstance_ReturnsTrue() {\n"
" assertTrue(defaultHost.equals(defaultHost));\n"
" }\n"
"\n"
" @Test\n"
" public void testEquals_EqualNameAndPort_IgnoringTls_ReturnsTrue() {\n"
f' {class_name} withTls = new {class_name}("localhost", "server-cert", 3000);\n'
" // Both directions should be true (symmetry)\n"
" assertTrue(defaultHost.equals(withTls) && withTls.equals(defaultHost));\n"
" }\n"
"\n"
" @Test\n"
" public void testEquals_DifferentPort_ReturnsFalse() {\n"
f' {class_name} other = new {class_name}("localhost", 3001);\n'
" assertFalse(defaultHost.equals(other));\n"
" }\n"
"\n"
" @Test\n"
" public void testEquals_DifferentName_ReturnsFalse() {\n"
f' {class_name} other = new {class_name}("otherhost", 3000);\n'
" assertFalse(defaultHost.equals(other));\n"
" }\n"
"\n"
" @Test\n"
" public void testEquals_NullArgument_ReturnsFalse() {\n"
" assertFalse(defaultHost.equals(null));\n"
" }\n"
"\n"
" @Test\n"
" public void testEquals_DifferentClass_ReturnsFalse() {\n"
' Object notAHost = "I am not a Host";\n'
" assertFalse(defaultHost.equals(notAHost));\n"
" }\n"
"\n"
" @Test\n"
" public void testEquals_EmptyNameBoth_ReturnsTrue() {\n"
f' {class_name} a = new {class_name}("", 0);\n'
f' {class_name} b = new {class_name}("", null, 0);\n'
" assertTrue(a.equals(b));\n"
" }\n"
"\n"
" @Test(expected = NullPointerException.class)\n"
" public void testEquals_ThisNameNull_ThrowsNullPointerException() {\n"
" // When this.name is null, equals calls this.name.equals(...), which throws NPE.\n"
f" {class_name} thisHasNullName = new {class_name}(null, 100);\n"
f' {class_name} other = new {class_name}("something", 100);\n'
" thisHasNullName.equals(other);\n"
" }\n"
"\n"
" @Test\n"
" public void testEquals_OtherNameNull_ReturnsFalse() {\n"
f" {class_name} otherHasNullName = new {class_name}(null, 200);\n"
f' {class_name} normal = new {class_name}("name", 200);\n'
' // "name".equals(null) returns false; no exception expected.\n'
" assertFalse(normal.equals(otherHasNullName));\n"
" }\n"
"\n"
" @Test\n"
" public void testEquals_MaxIntPort_ReturnsTrue() {\n"
f' {class_name} a = new {class_name}("host", Integer.MAX_VALUE);\n'
f' {class_name} b = new {class_name}("host", Integer.MAX_VALUE);\n'
" assertTrue(a.equals(b));\n"
" }\n"
"\n"
" @Test\n"
" public void testEquals_MinIntPort_ReturnsTrue() {\n"
f' {class_name} a = new {class_name}("host", Integer.MIN_VALUE);\n'
f' {class_name} b = new {class_name}("host", Integer.MIN_VALUE);\n'
" assertTrue(a.equals(b));\n"
" }\n"
"\n"
" @Test\n"
" public void testEquals_MaxAndDifferentPort_ReturnsFalse() {\n"
f' {class_name} a = new {class_name}("host", Integer.MAX_VALUE);\n'
f' {class_name} b = new {class_name}("host", Integer.MAX_VALUE - 1);\n'
" assertFalse(a.equals(b));\n"
" }\n"
"\n"
" @Test\n"
" public void testEquals_LargeScale_AllEqualInstances_ReturnsTrue() {\n"
f' {class_name} reference = new {class_name}("perf-host", 4000);\n'
" boolean allEqual = true;\n"
" final int iterations = 10000;\n"
" for (int i = 0; i < iterations; i++) {\n"
f' {class_name} h = new {class_name}("perf-host", 4000);\n'
" if (!reference.equals(h)) {\n"
" allEqual = false;\n"
" break;\n"
" }\n"
" }\n"
" assertTrue(allEqual);\n"
" }\n"
"}\n"
)
else:
# JUnit 5
return (f"package {package_name};\n" if package_name else "") + (
"\n"
"import org.junit.jupiter.api.BeforeEach;\n"
"import org.junit.jupiter.api.Test;\n"
"import org.junit.jupiter.api.DisplayName;\n"
"import static org.junit.jupiter.api.Assertions.*;\n"
"\n"
f"import {module_path};\n"
"\n"
"/**\n"
f" * Unit tests for {module_path}.equals(...)\n"
" */\n"
f"class {test_class_name} {{\n"
f" private {class_name} defaultHost;\n"
"\n"
" @BeforeEach\n"
" void setUp() {\n"
f' defaultHost = new {class_name}("localhost", 3000);\n'
" }\n"
"\n"
" @Test\n"
' @DisplayName("Same instance returns true")\n'
" void testEquals_SameInstance_ReturnsTrue() {\n"
" assertTrue(defaultHost.equals(defaultHost));\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Equal name and port ignoring TLS returns true")\n'
" void testEquals_EqualNameAndPort_IgnoringTls_ReturnsTrue() {\n"
f' {class_name} withTls = new {class_name}("localhost", "server-cert", 3000);\n'
" assertTrue(defaultHost.equals(withTls) && withTls.equals(defaultHost));\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Different port returns false")\n'
" void testEquals_DifferentPort_ReturnsFalse() {\n"
f' {class_name} other = new {class_name}("localhost", 3001);\n'
" assertFalse(defaultHost.equals(other));\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Different name returns false")\n'
" void testEquals_DifferentName_ReturnsFalse() {\n"
f' {class_name} other = new {class_name}("otherhost", 3000);\n'
" assertFalse(defaultHost.equals(other));\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Null argument returns false")\n'
" void testEquals_NullArgument_ReturnsFalse() {\n"
" assertFalse(defaultHost.equals(null));\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Different class returns false")\n'
" void testEquals_DifferentClass_ReturnsFalse() {\n"
' Object notAHost = "I am not a Host";\n'
" assertFalse(defaultHost.equals(notAHost));\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Empty name both returns true")\n'
" void testEquals_EmptyNameBoth_ReturnsTrue() {\n"
f' {class_name} a = new {class_name}("", 0);\n'
f' {class_name} b = new {class_name}("", null, 0);\n'
" assertTrue(a.equals(b));\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Null this.name throws NullPointerException")\n'
" void testEquals_ThisNameNull_ThrowsNullPointerException() {\n"
f" {class_name} thisHasNullName = new {class_name}(null, 100);\n"
f' {class_name} other = new {class_name}("something", 100);\n'
" assertThrows(NullPointerException.class, () -> thisHasNullName.equals(other));\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Null other.name returns false")\n'
" void testEquals_OtherNameNull_ReturnsFalse() {\n"
f" {class_name} otherHasNullName = new {class_name}(null, 200);\n"
f' {class_name} normal = new {class_name}("name", 200);\n'
" assertFalse(normal.equals(otherHasNullName));\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Max int port returns true")\n'
" void testEquals_MaxIntPort_ReturnsTrue() {\n"
f' {class_name} a = new {class_name}("host", Integer.MAX_VALUE);\n'
f' {class_name} b = new {class_name}("host", Integer.MAX_VALUE);\n'
" assertTrue(a.equals(b));\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Min int port returns true")\n'
" void testEquals_MinIntPort_ReturnsTrue() {\n"
f' {class_name} a = new {class_name}("host", Integer.MIN_VALUE);\n'
f' {class_name} b = new {class_name}("host", Integer.MIN_VALUE);\n'
" assertTrue(a.equals(b));\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Max vs different port returns false")\n'
" void testEquals_MaxAndDifferentPort_ReturnsFalse() {\n"
f' {class_name} a = new {class_name}("host", Integer.MAX_VALUE);\n'
f' {class_name} b = new {class_name}("host", Integer.MAX_VALUE - 1);\n'
" assertFalse(a.equals(b));\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Large scale equality check")\n'
" void testEquals_LargeScale_AllEqualInstances_ReturnsTrue() {\n"
f' {class_name} reference = new {class_name}("perf-host", 4000);\n'
" boolean allEqual = true;\n"
" final int iterations = 10000;\n"
" for (int i = 0; i < iterations; i++) {\n"
f' {class_name} h = new {class_name}("perf-host", 4000);\n'
" if (!reference.equals(h)) {\n"
" allEqual = false;\n"
" break;\n"
" }\n"
" }\n"
" assertTrue(allEqual);\n"
" }\n"
"}\n"
)
def _build_host_equals_demo_test_source_1(package_name: str, class_name: str, test_framework: str) -> str:
"""Build demo test source 1 for Host.equals, adapting to the target's package, class, and test framework."""
module_path = f"{package_name}.{class_name}" if package_name else class_name
test_class_name = f"{class_name}Test"
if test_framework == "junit4":
return (f"package {package_name};\n" if package_name else "") + (
"\n"
"import org.junit.Before;\n"
"import org.junit.Test;\n"
"import static org.junit.Assert.*;\n"
"\n"
f"import {module_path};\n"
"\n"
f"public class {test_class_name} {{\n"
f" private {class_name} hostSimple;\n"
f" private {class_name} hostWithTls;\n"
"\n"
" @Before\n"
" public void setUp() {\n"
f' hostSimple = new {class_name}("server.example.com", 3000);\n'
f' hostWithTls = new {class_name}("server.example.com", "tls.server.example.com", 3000);\n'
" }\n"
"\n"
" @Test\n"
" public void testSameReference_True() {\n"
" // same instance should be equal to itself\n"
" assertTrue(hostSimple.equals(hostSimple));\n"
" }\n"
"\n"
" @Test\n"
" public void testEqualNameAndPort_True() {\n"
" // two distinct instances with same name and port (tls ignored) are equal\n"
f' {class_name} a = new {class_name}("db1", 4000);\n'
f' {class_name} b = new {class_name}("db1", "tlsNameDifferent", 4000);\n'
" assertTrue(a.equals(b));\n"
" }\n"
"\n"
" @Test\n"
" public void testDifferentTlsIgnored_True() {\n"
" // tlsName is ignored for equality\n"
" assertTrue(hostSimple.equals(hostWithTls));\n"
" }\n"
"\n"
" @Test\n"
" public void testDifferentName_False() {\n"
f' {class_name} otherName = new {class_name}("other.example.com", 3000);\n'
" assertFalse(hostSimple.equals(otherName));\n"
" }\n"
"\n"
" @Test\n"
" public void testDifferentPort_False() {\n"
f' {class_name} otherPort = new {class_name}("server.example.com", 3001);\n'
" assertFalse(hostSimple.equals(otherPort));\n"
" }\n"
"\n"
" @Test\n"
" public void testNullComparison_False() {\n"
" // equals should return false when compared to null\n"
" assertFalse(hostSimple.equals(null));\n"
" }\n"
"\n"
" @Test\n"
" public void testDifferentClass_False() {\n"
" // equals should return false when compared to an object of another class\n"
' Object notAHost = "server.example.com:3000";\n'
" assertFalse(hostSimple.equals(notAHost));\n"
" }\n"
"\n"
" @Test(expected = NullPointerException.class)\n"
" public void testNameNull_ThrowsNullPointerException() {\n"
" // If this.name is null, equals tries to call this.name.equals(...) and will NPE.\n"
f" {class_name} nullNameHost = new {class_name}(null, 3000);\n"
f" {class_name} otherNullNameHost = new {class_name}(null, 3000);\n"
" // This invocation should throw NPE because this.name is null\n"
" nullNameHost.equals(otherNullNameHost);\n"
" }\n"
"\n"
" @Test\n"
" public void testOtherNameNull_False() {\n"
" // If other.name is null but this.name is non-null, equals should return false\n"
f" {class_name} otherNullName = new {class_name}(null, 3000);\n"
" assertFalse(hostSimple.equals(otherNullName));\n"
" }\n"
"\n"
" @Test\n"
" public void testPortBoundary_ZeroAndMax_True() {\n"
f' {class_name} lowA = new {class_name}("edge", 0);\n'
f' {class_name} lowB = new {class_name}("edge", 0);\n'
" assertTrue(lowA.equals(lowB));\n"
"\n"
f' {class_name} highA = new {class_name}("edge", 65535);\n'
f' {class_name} highB = new {class_name}("edge", 65535);\n'
" assertTrue(highA.equals(highB));\n"
" }\n"
"\n"
" @Test\n"
" public void testPortBoundary_DifferentPorts_False() {\n"
f' {class_name} low = new {class_name}("edge", 0);\n'
f' {class_name} high = new {class_name}("edge", 65535);\n'
" assertFalse(low.equals(high));\n"
" }\n"
"\n"
" @Test\n"
" public void testLargeScale_Equality_BulkCompare() {\n"
" // Create many hosts and verify equality behavior (tls ignored) in a loop.\n"
" final int COUNT = 10000;\n"
" for (int i = 0; i < COUNT; i++) {\n"
' String name = "bulk-" + i;\n'
" int port = 1000 + (i % 1000);\n"
f" {class_name} a = new {class_name}(name, port);\n"
f' {class_name} b = new {class_name}(name, "tls-" + i, port);\n'
' assertTrue("Failed equality at index " + i, a.equals(b));\n'
" }\n"
" }\n"
"}\n"
)
else:
# JUnit 5
return (f"package {package_name};\n" if package_name else "") + (
"\n"
"import org.junit.jupiter.api.BeforeEach;\n"
"import org.junit.jupiter.api.Test;\n"
"import org.junit.jupiter.api.DisplayName;\n"
"import static org.junit.jupiter.api.Assertions.*;\n"
"\n"
f"import {module_path};\n"
"\n"
f"class {test_class_name} {{\n"
f" private {class_name} hostSimple;\n"
f" private {class_name} hostWithTls;\n"
"\n"
" @BeforeEach\n"
" void setUp() {\n"
f' hostSimple = new {class_name}("server.example.com", 3000);\n'
f' hostWithTls = new {class_name}("server.example.com", "tls.server.example.com", 3000);\n'
" }\n"
"\n"
" @Test\n"
' @DisplayName("Same reference returns true")\n'
" void testSameReference_True() {\n"
" assertTrue(hostSimple.equals(hostSimple));\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Equal name and port with different TLS returns true")\n'
" void testEqualNameAndPort_True() {\n"
f' {class_name} a = new {class_name}("db1", 4000);\n'
f' {class_name} b = new {class_name}("db1", "tlsNameDifferent", 4000);\n'
" assertTrue(a.equals(b));\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Different TLS name is ignored")\n'
" void testDifferentTlsIgnored_True() {\n"
" assertTrue(hostSimple.equals(hostWithTls));\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Different name returns false")\n'
" void testDifferentName_False() {\n"
f' {class_name} otherName = new {class_name}("other.example.com", 3000);\n'
" assertFalse(hostSimple.equals(otherName));\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Different port returns false")\n'
" void testDifferentPort_False() {\n"
f' {class_name} otherPort = new {class_name}("server.example.com", 3001);\n'
" assertFalse(hostSimple.equals(otherPort));\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Null comparison returns false")\n'
" void testNullComparison_False() {\n"
" assertFalse(hostSimple.equals(null));\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Different class returns false")\n'
" void testDifferentClass_False() {\n"
' Object notAHost = "server.example.com:3000";\n'
" assertFalse(hostSimple.equals(notAHost));\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Null name throws NullPointerException")\n'
" void testNameNull_ThrowsNullPointerException() {\n"
f" {class_name} nullNameHost = new {class_name}(null, 3000);\n"
f" {class_name} otherNullNameHost = new {class_name}(null, 3000);\n"
" assertThrows(NullPointerException.class, () -> nullNameHost.equals(otherNullNameHost));\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Other name null returns false")\n'
" void testOtherNameNull_False() {\n"
f" {class_name} otherNullName = new {class_name}(null, 3000);\n"
" assertFalse(hostSimple.equals(otherNullName));\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Port boundary zero and max")\n'
" void testPortBoundary_ZeroAndMax_True() {\n"
f' {class_name} lowA = new {class_name}("edge", 0);\n'
f' {class_name} lowB = new {class_name}("edge", 0);\n'
" assertTrue(lowA.equals(lowB));\n"
"\n"
f' {class_name} highA = new {class_name}("edge", 65535);\n'
f' {class_name} highB = new {class_name}("edge", 65535);\n'
" assertTrue(highA.equals(highB));\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Different boundary ports returns false")\n'
" void testPortBoundary_DifferentPorts_False() {\n"
f' {class_name} low = new {class_name}("edge", 0);\n'
f' {class_name} high = new {class_name}("edge", 65535);\n'
" assertFalse(low.equals(high));\n"
" }\n"
"\n"
" @Test\n"
' @DisplayName("Large scale bulk equality comparison")\n'
" void testLargeScale_Equality_BulkCompare() {\n"
" final int COUNT = 10000;\n"
" for (int i = 0; i < COUNT; i++) {\n"
' String name = "bulk-" + i;\n'
" int port = 1000 + (i % 1000);\n"
f" {class_name} a = new {class_name}(name, port);\n"
f' {class_name} b = new {class_name}(name, "tls-" + i, port);\n'
' assertTrue(a.equals(b), "Failed equality at index " + i);\n'
" }\n"
" }\n"
"}\n"
)
async def hack_for_demo_java_testgen(data: TestGenSchema) -> TestGenResponseSchema:
    """Return canned demo tests for a recognised Java demo target instead of calling an LLM.

    The package and class names are read from the submitted source, falling back to the function
    name when no class is found. A ``None`` ``test_index`` counts as 0 and selects the first
    canned variant; any other index selects the second. A framework other than ``"junit4"`` or
    ``"junit5"`` falls back to ``"junit5"``.
    """
    code = data.source_code_being_tested
    pkg = _extract_package_from_source(code) or ""
    cls = _extract_class_from_source(code) or data.function_to_optimize.function_name
    framework = data.test_framework if data.test_framework in ("junit4", "junit5") else "junit5"
    first_variant = data.test_index is None or data.test_index == 0

    # Two canned families: Host.equals demo vs. the generic demo, each with two variants.
    if is_host_equals_demo(code):
        builder = _build_host_equals_demo_test_source_0 if first_variant else _build_host_equals_demo_test_source_1
    else:
        builder = _build_demo_test_source_0 if first_variant else _build_demo_test_source_1
    test_source = builder(pkg, cls, framework)

    # Deliberate fixed pause before answering (presumably to mimic real LLM latency in demos).
    await asyncio.sleep(5)
    # For Java, instrumentation is done client-side, so every field carries the same source.
    return TestGenResponseSchema(
        generated_tests=test_source,
        instrumented_behavior_tests=test_source,
        instrumented_perf_tests=test_source,
    )
async def testgen_java(
    request: AuthenticatedRequest, data: TestGenSchema
) -> tuple[int, TestGenResponseSchema | TestGenErrorResponseSchema]:
    """Generate Java tests using LLMs.

    Returns an ``(http_status, body)`` pair:

    * 200 with the generated tests on success, or with canned tests for demo targets;
    * 400 for an invalid request, a ``TestGenerationFailedError``, or a ``ValueError``/``SyntaxError``;
    * 500 for any other exception.

    Instrumentation is left to the client, so the generated source is returned unmodified in all
    three response fields.
    """
    await asyncio.to_thread(ph, request.user, "aiservice-testgen-java-called")

    # Reject malformed requests before doing any work.
    if not data.function_to_optimize:
        return 400, TestGenErrorResponseSchema(error="Invalid function to optimize. It is empty.")
    if not validate_trace_id(data.trace_id):
        return 400, TestGenErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")

    logging.info("/testgen: Generating Java tests...")

    # Demo targets short-circuit to canned tests without calling the LLM.
    if should_hack_for_demo_java(data.source_code_being_tested):
        demo_response = await hack_for_demo_java_testgen(data)
        return 200, demo_response

    try:
        fn_name = data.function_to_optimize.function_name
        debug_log_sensitive_data(f"Generating Java tests for function {fn_name}")
        logging.info(f"Generating Java tests for function {fn_name}")

        # Package/class are read from the source text (more reliable than qualified_name).
        code = data.source_code_being_tested
        debug_log_sensitive_data(f"Source code first 200 chars: {code[:200]}")
        pkg = _extract_package_from_source(code) or ""
        cls = _extract_class_from_source(code) or fn_name
        module_path = f"{pkg}.{cls}" if pkg else cls
        # Unknown or missing framework values fall back to JUnit 5.
        framework = data.test_framework if data.test_framework in ("junit4", "junit5") else "junit5"

        summary = f"package={pkg}, class={cls}, module_path={module_path}, framework={framework}"
        logging.info(f"Java testgen: {summary}")
        debug_log_sensitive_data(f"Extracted: {summary}")

        prompt_messages, event_suffix = build_java_prompt(
            function_name=fn_name,
            function_code=code,
            module_path=module_path,
            class_name=cls,
            package_name=pkg,
            test_framework=framework,
        )

        # Handed to the generator as ``cost_tracker``; presumably it appends per-call LLM costs,
        # which are summed below.
        costs: list[float] = []
        test_code = await generate_and_validate_java_test_code(
            messages=prompt_messages,
            model=EXECUTE_MODEL,
            cost_tracker=costs,
            user_id=request.user,
            posthog_event_suffix=event_suffix,
            trace_id=data.trace_id,
        )

        cost_total = sum(costs)
        analytics = {
            "trace_id": data.trace_id,
            "total_cost": cost_total,
            "test_count": len(_TEST_FUNC_RE.findall(test_code)),
        }
        await asyncio.to_thread(ph, request.user, f"aiservice-testgen-{event_suffix}success", properties=analytics)
        await update_optimization_cost(data.trace_id, cost_total, request.user)

        # Instrumentation happens client-side, so the raw tests are returned in every field.
        return 200, TestGenResponseSchema(
            generated_tests=test_code,
            instrumented_behavior_tests=test_code,
            instrumented_perf_tests=test_code,
        )
    except TestGenerationFailedError as exc:
        logging.warning(f"Java test generation failed: {exc}")
        sentry_sdk.capture_exception(exc)
        return 400, TestGenErrorResponseSchema(error=str(exc))
    except (ValueError, SyntaxError) as exc:
        logging.warning(f"Java test generation error: {exc}")
        sentry_sdk.capture_exception(exc)
        return 400, TestGenErrorResponseSchema(error=f"Failed to generate valid tests: {exc}")
    except Exception as exc:
        logging.exception("Unexpected error in Java test generation")
        sentry_sdk.capture_exception(exc)
        return 500, TestGenErrorResponseSchema(error=f"Internal error: {exc}")