mirror of
https://github.com/codeflash-ai/codeflash-internal.git
synced 2026-05-04 18:25:18 +00:00
# Pull Request Checklist ## Description - [ ] **Breaking Changes**: Document any breaking changes (if applicable) - [ ] **Description of PR**: Clear and concise description of what this PR accomplishes - [ ] **Related Issues**: Link to any related issues or tickets ## Testing - [ ] **Test cases Attached**: All relevant test cases have been added/updated - [ ] **Manual Testing**: Manual testing completed for the changes ## Monitoring & Debugging - [ ] **Logging in place**: Appropriate logging has been added for debugging user issues - [ ] **Sentry will be able to catch errors**: Error handling ensures Sentry can capture and report errors - [ ] **Avoid Dev based/Prisma logging**: No development-only or Prisma-specific logging in production code ## Configuration - [ ] **Env variables newly added**: Any new environment variables are documented in .env.example file or mentioned in description --- ## Additional Notes <!-- Add any additional context, screenshots, or notes for reviewers here --> --------- Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com> Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com> Co-authored-by: HeshamHM28 <HeshamMohamedFathy@outlook.com> Co-authored-by: Ubuntu <ubuntu@ip-172-31-39-200.ec2.internal> Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com> Co-authored-by: Kevin Turcios <turcioskevinr@gmail.com> Co-authored-by: Kevin Turcios <106575910+KRRT7@users.noreply.github.com>
1482 lines
62 KiB
Python
1482 lines
62 KiB
Python
"""Java test generation module.
|
|
|
|
This module generates JUnit 5 tests for Java functions.
|
|
Instrumentation is handled by the codeflash CLI client, not here.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import re
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING
|
|
|
|
import sentry_sdk
|
|
import stamina
|
|
from openai import OpenAIError
|
|
from openai.types.chat import ChatCompletionMessageParam
|
|
|
|
from aiservice.analytics.posthog import ph
|
|
from aiservice.common_utils import is_host_equals_demo, should_hack_for_demo_java, validate_trace_id
|
|
from aiservice.env_specific import debug_log_sensitive_data
|
|
from aiservice.llm import EXECUTE_MODEL, calculate_llm_cost, call_llm
|
|
from authapp.auth import AuthenticatedRequest
|
|
from core.shared.testgen_models import (
|
|
TestGenerationFailedError,
|
|
TestGenErrorResponseSchema,
|
|
TestGenResponseSchema,
|
|
TestGenSchema,
|
|
)
|
|
from log_features.log_event import update_optimization_cost
|
|
|
|
if TYPE_CHECKING:
|
|
from aiservice.llm import LLM
|
|
|
|
from aiservice.validators.java_validator import validate_java_syntax
|
|
|
|
_TEST_FUNC_RE = re.compile(r"@Test\s*\n\s*(?:public\s+)?void\s+\w+")
|
|
|
|
# Directory containing this module; the prompts directory is resolved relative to it.
current_dir = Path(__file__).parent
JAVA_PROMPTS_DIR = current_dir / "prompts" / "testgen"

# Ensure the prompts directory exists. This runs at import time; parents=True creates
# missing intermediate directories and exist_ok=True makes it a no-op when already present.
JAVA_PROMPTS_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Java system prompt for test generation - JUnit 5 (default)
|
|
JAVA_SYSTEM_PROMPT_JUNIT5 = """You are an expert Java developer specializing in writing comprehensive JUnit 5 tests.
|
|
Your task is to generate high-quality unit tests for the given Java function.
|
|
|
|
Guidelines:
|
|
1. Use JUnit 5 (Jupiter) test annotations (@Test, @BeforeEach, etc.)
|
|
2. Include imports for org.junit.jupiter.api.*
|
|
3. Use assertEquals, assertTrue, assertFalse, assertThrows as appropriate
|
|
4. Test edge cases, boundary conditions, and typical use cases
|
|
5. Use descriptive test method names following the pattern: test<Scenario>_<ExpectedBehavior>
|
|
6. Include a test class with the pattern: <ClassName>Test
|
|
7. Do NOT use mocks unless absolutely necessary
|
|
8. Keep tests simple and focused on one assertion per test when possible
|
|
|
|
CRITICAL - Java Syntax Rules:
|
|
- Java does NOT support import aliasing (e.g., "import X as Y" is INVALID)
|
|
- Use fully qualified names or regular imports only
|
|
- Example: Use "java.util.Base64.getEncoder()" or "import java.util.Base64;" NOT "import java.util.Base64 as Foo;"
|
|
|
|
CRITICAL - Handling Complex Parameter Types:
|
|
- If a function parameter is an abstract class or interface (e.g., Value, Key), use the REAL factory methods or concrete implementations from the library, NOT custom mock classes
|
|
- NEVER create a custom class that extends an abstract class unless you implement ALL abstract methods
|
|
- Prefer using static factory methods like Value.get("string"), Value.get(123), etc.
|
|
- If the real class has a builder or factory pattern, use it
|
|
- Check if there are concrete implementations you can instantiate directly
|
|
|
|
Function to test: {function_name}
|
|
"""
|
|
|
|
# Java system prompt for JUnit 4
|
|
JAVA_SYSTEM_PROMPT_JUNIT4 = """You are an expert Java developer specializing in writing comprehensive JUnit 4 tests.
|
|
Your task is to generate high-quality unit tests for the given Java function.
|
|
|
|
Guidelines:
|
|
1. Use JUnit 4 test annotations (@Test, @Before, etc.) - NOT JUnit 5
|
|
2. Include imports for org.junit.* (NOT org.junit.jupiter.api.*)
|
|
3. Use static imports: import static org.junit.Assert.*
|
|
4. Use assertEquals, assertTrue, assertFalse as appropriate. For exceptions, use @Test(expected=ExceptionType.class)
|
|
5. Test edge cases, boundary conditions, and typical use cases
|
|
6. Use descriptive test method names following the pattern: test<Scenario>_<ExpectedBehavior>
|
|
7. Include a test class with the pattern: <ClassName>Test
|
|
8. Do NOT use mocks unless absolutely necessary
|
|
9. Keep tests simple and focused on one assertion per test when possible
|
|
|
|
CRITICAL - Java Syntax Rules:
|
|
- Java does NOT support import aliasing (e.g., "import X as Y" is INVALID)
|
|
- Use fully qualified names or regular imports only
|
|
- Example: Use "java.util.Base64.getEncoder()" or "import java.util.Base64;" NOT "import java.util.Base64 as Foo;"
|
|
|
|
CRITICAL - Handling Complex Parameter Types:
|
|
- If a function parameter is an abstract class or interface (e.g., Value, Key), use the REAL factory methods or concrete implementations from the library, NOT custom mock classes
|
|
- NEVER create a custom class that extends an abstract class unless you implement ALL abstract methods
|
|
- Prefer using static factory methods like Value.get("string"), Value.get(123), etc.
|
|
- If the real class has a builder or factory pattern, use it
|
|
- Check if there are concrete implementations you can instantiate directly
|
|
|
|
Function to test: {function_name}
|
|
"""
|
|
|
|
JAVA_USER_PROMPT_JUNIT5 = """Generate JUnit 5 tests for the following Java function.
|
|
|
|
Function name: {function_name}
|
|
Class name: {class_name}
|
|
Full qualified name: {module_path}
|
|
Package: {package_name}
|
|
|
|
Source code:
|
|
```java
|
|
{function_code}
|
|
```
|
|
|
|
Generate comprehensive unit tests that cover:
|
|
1. Basic functionality with typical inputs
|
|
2. Edge cases (empty inputs, null values, boundary conditions)
|
|
3. Error conditions (if the function can throw exceptions)
|
|
4. Large-scale inputs for performance verification
|
|
|
|
IMPORTANT REQUIREMENTS:
|
|
1. The package declaration MUST be: package {package_name};
|
|
2. You MUST import the class under test: import {module_path};
|
|
3. The test class MUST be named {class_name}Test
|
|
4. Create an instance of {class_name} in @BeforeEach or directly in test methods
|
|
5. CRITICAL: Do NOT create custom classes that extend abstract classes or interfaces.
|
|
Instead, use factory methods or concrete implementations from the library.
|
|
Example: For com.aerospike.client.Value, use Value.get("string"), Value.get(123), etc.
|
|
|
|
Wrap your response in a Java code block (```java ... ```).
|
|
|
|
Example structure:
|
|
```java
|
|
package {package_name};
|
|
|
|
import org.junit.jupiter.api.*;
|
|
import static org.junit.jupiter.api.Assertions.*;
|
|
import {module_path};
|
|
|
|
public class {class_name}Test {{
|
|
private {class_name} instance;
|
|
|
|
@BeforeEach
|
|
void setUp() {{
|
|
instance = new {class_name}();
|
|
}}
|
|
|
|
@Test
|
|
void testBasicFunctionality() {{
|
|
// Test code here
|
|
}}
|
|
}}
|
|
```
|
|
"""
|
|
|
|
JAVA_USER_PROMPT_JUNIT4 = """Generate JUnit 4 tests for the following Java function.
|
|
|
|
Function name: {function_name}
|
|
Class name: {class_name}
|
|
Full qualified name: {module_path}
|
|
Package: {package_name}
|
|
|
|
Source code:
|
|
```java
|
|
{function_code}
|
|
```
|
|
|
|
Generate comprehensive unit tests that cover:
|
|
1. Basic functionality with typical inputs
|
|
2. Edge cases (empty inputs, null values, boundary conditions)
|
|
3. Error conditions (if the function can throw exceptions)
|
|
4. Large-scale inputs for performance verification
|
|
|
|
IMPORTANT REQUIREMENTS:
|
|
1. The package declaration MUST be: package {package_name};
|
|
2. You MUST import the class under test: import {module_path};
|
|
3. The test class MUST be named {class_name}Test
|
|
4. Use JUnit 4 annotations and imports (NOT JUnit 5/Jupiter)
|
|
5. Create an instance of {class_name} in @Before or directly in test methods
|
|
6. CRITICAL: Do NOT create custom classes that extend abstract classes or interfaces.
|
|
Instead, use factory methods or concrete implementations from the library.
|
|
Example: For com.aerospike.client.Value, use Value.get("string"), Value.get(123), etc.
|
|
|
|
Wrap your response in a Java code block (```java ... ```).
|
|
|
|
Example structure:
|
|
```java
|
|
package {package_name};
|
|
|
|
import org.junit.Test;
|
|
import org.junit.Before;
|
|
import static org.junit.Assert.*;
|
|
import {module_path};
|
|
|
|
public class {class_name}Test {{
|
|
private {class_name} instance;
|
|
|
|
@Before
|
|
public void setUp() {{
|
|
instance = new {class_name}();
|
|
}}
|
|
|
|
@Test
|
|
public void testBasicFunctionality() {{
|
|
// Test code here
|
|
}}
|
|
}}
|
|
```
|
|
"""
|
|
|
|
# Pattern to extract Java code blocks from an LLM response.
# The opening fence must start a line (MULTILINE) and may be tagged "java" or left untagged.
# Group 1 is the block body; DOTALL lets it span lines, and the non-greedy match stops at
# the nearest following newline + closing fence.
JAVA_PATTERN = re.compile(r"^```(?:java)?\s*\n(.*?)\n```", re.MULTILINE | re.DOTALL)
|
|
|
|
|
|
def build_java_prompt(
    function_name: str,
    function_code: str,
    module_path: str,
    class_name: str,
    package_name: str,
    test_framework: str = "junit5",
) -> tuple[list[ChatCompletionMessageParam], str]:
    """Assemble the system and user chat messages for Java test generation.

    Args:
        function_name: Name of the function under test.
        function_code: Java source of the function; inserted into the user prompt.
        module_path: Fully qualified name of the class, used in the prompt's import line.
        class_name: Name of the class that contains the function.
        package_name: Package the generated test class must declare.
        test_framework: "junit4" selects the JUnit 4 prompts; any other value
            selects the JUnit 5 prompts.

    Returns:
        Tuple of (messages, posthog_event_suffix). The suffix is
        "java-<test_framework>-", built from the value exactly as passed in.

    """
    use_junit4 = test_framework == "junit4"
    system_template = JAVA_SYSTEM_PROMPT_JUNIT4 if use_junit4 else JAVA_SYSTEM_PROMPT_JUNIT5
    user_template = JAVA_USER_PROMPT_JUNIT4 if use_junit4 else JAVA_USER_PROMPT_JUNIT5

    # The system prompt only interpolates the function name; the user prompt takes every field.
    system_text = system_template.format(function_name=function_name)
    user_text = user_template.format(
        function_name=function_name,
        function_code=function_code,
        module_path=module_path,
        class_name=class_name,
        package_name=package_name,
    )

    conversation: list[ChatCompletionMessageParam] = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user_text},
    ]
    return conversation, f"java-{test_framework}-"
|
|
|
|
|
|
def parse_and_validate_java_output(response_content: str) -> str:
    """Extract the Java code block from an LLM response and validate it.

    Args:
        response_content: Raw text returned by the LLM.

    Returns:
        The stripped contents of the first matching fenced code block.

    Raises:
        ValueError: If the response has no code fence, no block matching
            JAVA_PATTERN, or the code has no @Test-annotated method.
        SyntaxError: If validate_java_syntax reports the code as invalid.

    """
    # Cheap pre-check; the first 500 characters of fence-less responses go to Sentry.
    if "```" not in response_content:
        sentry_sdk.capture_message("LLM response did not contain a code block:\n" + response_content[:500])
        raise ValueError("LLM response did not contain a code block.")

    fenced_block = JAVA_PATTERN.search(response_content)
    if fenced_block is None:
        raise ValueError("No Java code block found in the LLM response.")

    java_code = fenced_block.group(1).strip()

    # Syntax validation is delegated to validate_java_syntax.
    syntax_ok, syntax_error = validate_java_syntax(java_code)
    if not syntax_ok:
        raise SyntaxError(f"Invalid Java code: {syntax_error}")

    # The code is only useful if it defines at least one test method.
    if _has_test_functions(java_code):
        return java_code
    raise ValueError("Generated code does not contain any @Test annotated methods.")
|
|
|
|
|
|
def _has_test_functions(code: str) -> bool:
    """Return True when ``code`` contains at least one match of _TEST_FUNC_RE."""
    return bool(_TEST_FUNC_RE.search(code))
|
|
|
|
|
|
@stamina.retry(on=(SyntaxError, ValueError, OpenAIError), attempts=2)
async def generate_and_validate_java_test_code(
    messages: list[ChatCompletionMessageParam],
    model: LLM,
    cost_tracker: list[float],
    user_id: str,
    posthog_event_suffix: str,
    trace_id: str = "",
) -> str:
    """Ask the LLM for Java test code and validate what comes back.

    The stamina decorator allows two attempts when a SyntaxError, ValueError or
    OpenAIError is raised.

    Args:
        messages: Prompt messages for the LLM.
        model: LLM model to use.
        cost_tracker: List that receives the cost of each LLM call; the cost is
            appended before validation, so failed attempts are recorded too.
        user_id: User ID forwarded to the LLM call.
        posthog_event_suffix: Suffix for analytics events (not read in this function).
        trace_id: Trace ID forwarded to the LLM call.

    Returns:
        Validated Java test code.

    Raises:
        ValueError: If the response has no usable code block or no test methods.
        SyntaxError: If the generated code has syntax errors.

    """
    try:
        llm_output = await call_llm(
            llm=model,
            messages=messages,
            call_type="testgen",
            trace_id=trace_id,
            user_id=user_id,
            python_version="N/A",  # Not applicable for Java
        )
    except Exception as exc:
        # Log and report, then re-raise so the retry decorator and callers see the error.
        logging.exception("LLM Code Generation error")
        sentry_sdk.capture_exception(exc)
        raise

    cost_tracker.append(calculate_llm_cost(llm_output.raw_response, model))

    generated_text = llm_output.content
    debug_log_sensitive_data(f"LLM testgen response:\n{generated_text}")

    return parse_and_validate_java_output(generated_text)
|
|
|
|
|
|
def _extract_class_and_package(module_path: str) -> tuple[str, str]:
|
|
"""Extract class name and package from module path.
|
|
|
|
Args:
|
|
module_path: e.g., "com.example.Algorithms"
|
|
|
|
Returns:
|
|
Tuple of (class_name, package_name)
|
|
|
|
"""
|
|
parts = module_path.rsplit(".", 1)
|
|
if len(parts) == 2:
|
|
return parts[1], parts[0] # class_name, package_name
|
|
return parts[0], "" # class_name only, no package
|
|
|
|
|
|
def _extract_package_from_source(source_code: str) -> str | None:
|
|
"""Extract package name from Java source code.
|
|
|
|
Args:
|
|
source_code: Java source code
|
|
|
|
Returns:
|
|
Package name (e.g., "com.example"), or None if not found
|
|
|
|
"""
|
|
# First try: package declaration in source (most reliable)
|
|
package_pattern = re.compile(r"^\s*package\s+([\w.]+)\s*;", re.MULTILINE)
|
|
match = package_pattern.search(source_code)
|
|
if match:
|
|
logging.debug(f"Extracted package from declaration: {match.group(1)}")
|
|
return match.group(1)
|
|
|
|
# Second try: extract from markdown code block header (e.g., "```java:src/main/java/com/example/Algorithms.java")
|
|
markdown_pattern = re.compile(r"```java:([^\n`]+\.java)", re.IGNORECASE)
|
|
markdown_match = markdown_pattern.search(source_code)
|
|
if markdown_match:
|
|
file_path = markdown_match.group(1).strip()
|
|
package = _extract_package_from_path(file_path)
|
|
if package:
|
|
logging.debug(f"Extracted package from markdown header: {package}")
|
|
return package
|
|
|
|
# Third try: extract from file path comment (e.g., "// file: src/main/java/com/example/Algorithms.java")
|
|
# Also handle "// file: src/com/example/Algorithms.java" (non-standard Maven)
|
|
file_comment_pattern = re.compile(r"//\s*file:\s*([^\n]+\.java)", re.IGNORECASE)
|
|
file_match = file_comment_pattern.search(source_code)
|
|
if file_match:
|
|
file_path = file_match.group(1).strip()
|
|
logging.debug(f"Found file comment: {file_path}")
|
|
package = _extract_package_from_path(file_path)
|
|
if package:
|
|
logging.debug(f"Extracted package from file comment: {package}")
|
|
return package
|
|
|
|
# Fourth try: infer package from import statements (last resort)
|
|
# Look for imports that might indicate the package structure
|
|
import_pattern = re.compile(r"^\s*import\s+([\w.]+)\.[\w*]+\s*;", re.MULTILINE)
|
|
imports = import_pattern.findall(source_code)
|
|
if imports:
|
|
# Find common prefix among imports that look like internal packages
|
|
# Exclude common library packages
|
|
internal_imports: list[str] = [
|
|
imp for imp in imports if not imp.startswith(("java.", "javax.", "org.junit", "org.apache", "com.google"))
|
|
]
|
|
if internal_imports:
|
|
# Use the shortest import path as a hint
|
|
internal_imports.sort(key=len)
|
|
logging.debug(f"Inferred package from imports: {internal_imports[0]}")
|
|
return internal_imports[0]
|
|
|
|
logging.warning("Could not extract package name from source code")
|
|
return None
|
|
|
|
|
|
def _extract_package_from_path(file_path: str) -> str | None:
|
|
"""Extract Java package from a file path.
|
|
|
|
Args:
|
|
file_path: e.g., "src/main/java/com/example/Algorithms.java" or "src/com/example/Algorithms.java"
|
|
|
|
Returns:
|
|
Package name (e.g., "com.example"), or None if not found
|
|
|
|
"""
|
|
# Normalize slashes
|
|
file_path = file_path.replace("\\", "/")
|
|
|
|
# Standard Maven patterns (highest priority)
|
|
java_src_patterns = ["/src/main/java/", "/src/test/java/", "src/main/java/", "src/test/java/"]
|
|
for pattern in java_src_patterns:
|
|
if pattern in file_path:
|
|
idx = file_path.find(pattern)
|
|
remaining = file_path[idx + len(pattern) :]
|
|
parts = remaining.split("/")
|
|
if len(parts) > 1:
|
|
package_parts = parts[:-1] # Remove the filename
|
|
return ".".join(package_parts)
|
|
|
|
# Non-standard paths: look for "src/" followed by what looks like a package structure
|
|
# e.g., "src/com/aerospike/client/util/Crypto.java" -> "com.aerospike.client.util"
|
|
# or "client/src/com/aerospike/client/util/Crypto.java"
|
|
src_patterns = ["/src/", "src/"]
|
|
for pattern in src_patterns:
|
|
if pattern in file_path:
|
|
idx = file_path.find(pattern)
|
|
remaining = file_path[idx + len(pattern) :]
|
|
parts = remaining.split("/")
|
|
if len(parts) > 1:
|
|
# Check if first part looks like a package (lowercase, not 'main', 'test', 'java', 'resources')
|
|
first_part = parts[0].lower()
|
|
if first_part not in ("main", "test", "java", "resources") and first_part.isalpha():
|
|
package_parts = parts[:-1] # Remove the filename
|
|
return ".".join(package_parts)
|
|
|
|
return None
|
|
|
|
|
|
def _extract_class_from_source(source_code: str) -> str | None:
|
|
"""Extract class name from Java source code.
|
|
|
|
Args:
|
|
source_code: Java source code
|
|
|
|
Returns:
|
|
Class name, or None if not found
|
|
|
|
"""
|
|
# First try: class declaration in source
|
|
class_pattern = re.compile(r"\bclass\s+(\w+)")
|
|
match = class_pattern.search(source_code)
|
|
if match:
|
|
return match.group(1)
|
|
|
|
# Second try: extract from markdown code block header (e.g., "```java:src/main/java/com/example/Algorithms.java")
|
|
markdown_pattern = re.compile(r"```java:([^\n`]+\.java)", re.IGNORECASE)
|
|
markdown_match = markdown_pattern.search(source_code)
|
|
if markdown_match:
|
|
file_path = markdown_match.group(1).strip()
|
|
filename = os.path.basename(file_path)
|
|
if filename.endswith(".java"):
|
|
return filename[:-5] # Remove .java extension
|
|
|
|
# Third try: extract from file path comment (e.g., "// file: src/main/java/com/example/Algorithms.java")
|
|
file_comment_pattern = re.compile(r"//\s*file:\s*([^\n]+\.java)", re.IGNORECASE)
|
|
file_match = file_comment_pattern.search(source_code)
|
|
if file_match:
|
|
file_path = file_match.group(1).strip()
|
|
# Extract class name from file path (e.g., "Algorithms.java" -> "Algorithms")
|
|
filename = os.path.basename(file_path)
|
|
if filename.endswith(".java"):
|
|
return filename[:-5] # Remove .java extension
|
|
|
|
return None
|
|
|
|
|
|
def _build_demo_test_source_0(package_name: str, class_name: str, test_framework: str) -> str:
|
|
"""Build demo test source 0, adapting to the target's package, class, and test framework.
|
|
|
|
File creation is in @Before/@BeforeEach so it runs once, outside the instrumentation's
|
|
inner loop. Test methods only contain readFile calls so every inner iteration succeeds
|
|
and the benchmark measures pure readFile performance.
|
|
"""
|
|
module_path = f"{package_name}.{class_name}" if package_name else class_name
|
|
test_class_name = f"{class_name}Test"
|
|
|
|
if test_framework == "junit4":
|
|
return (f"package {package_name};\n" if package_name else "") + (
|
|
"\n"
|
|
"import org.junit.Before;\n"
|
|
"import org.junit.Test;\n"
|
|
"import org.junit.Rule;\n"
|
|
"import org.junit.rules.TemporaryFolder;\n"
|
|
"import static org.junit.Assert.*;\n"
|
|
"\n"
|
|
"import java.io.File;\n"
|
|
"import java.io.FileOutputStream;\n"
|
|
"\n"
|
|
f"import {module_path};\n"
|
|
"\n"
|
|
f"public class {test_class_name} {{\n"
|
|
"\n"
|
|
" @Rule\n"
|
|
" public TemporaryFolder tempFolder = new TemporaryFolder();\n"
|
|
"\n"
|
|
" private File smallFile;\n"
|
|
" private byte[] expectedSmall;\n"
|
|
" private File mediumFile;\n"
|
|
" private byte[] expectedMedium;\n"
|
|
" private File largeFile;\n"
|
|
" private byte[] expectedLarge;\n"
|
|
"\n"
|
|
" @Before\n"
|
|
" public void setUp() throws Exception {\n"
|
|
' smallFile = tempFolder.newFile("small.txt");\n'
|
|
' expectedSmall = "Hello, World!".getBytes();\n'
|
|
" try (FileOutputStream out = new FileOutputStream(smallFile)) {\n"
|
|
" out.write(expectedSmall);\n"
|
|
" }\n"
|
|
"\n"
|
|
' mediumFile = tempFolder.newFile("medium.dat");\n'
|
|
" expectedMedium = new byte[256 * 1024];\n"
|
|
" for (int i = 0; i < expectedMedium.length; i++) {\n"
|
|
" expectedMedium[i] = (byte) (i % 251);\n"
|
|
" }\n"
|
|
" try (FileOutputStream out = new FileOutputStream(mediumFile)) {\n"
|
|
" out.write(expectedMedium);\n"
|
|
" }\n"
|
|
"\n"
|
|
' largeFile = tempFolder.newFile("large.dat");\n'
|
|
" expectedLarge = new byte[1024 * 1024];\n"
|
|
" for (int i = 0; i < expectedLarge.length; i++) {\n"
|
|
" expectedLarge[i] = (byte) (i % 256);\n"
|
|
" }\n"
|
|
" try (FileOutputStream out = new FileOutputStream(largeFile)) {\n"
|
|
" out.write(expectedLarge);\n"
|
|
" }\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testReadSmallFile() throws Exception {\n"
|
|
f" byte[] result = {class_name}.readFile(smallFile);\n"
|
|
" assertArrayEquals(expectedSmall, result);\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testReadMediumFileRepeated() throws Exception {\n"
|
|
f" assertArrayEquals(expectedMedium, {class_name}.readFile(mediumFile));\n"
|
|
" for (int i = 0; i < 300; i++) {\n"
|
|
f" {class_name}.readFile(mediumFile);\n"
|
|
" }\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testReadLargeFileRepeated() throws Exception {\n"
|
|
f" assertArrayEquals(expectedLarge, {class_name}.readFile(largeFile));\n"
|
|
" for (int i = 0; i < 2000; i++) {\n"
|
|
f" {class_name}.readFile(largeFile);\n"
|
|
" }\n"
|
|
" }\n"
|
|
"}\n"
|
|
)
|
|
else:
|
|
# JUnit 5
|
|
return (
|
|
f"package {package_name};\n" if package_name else ""
|
|
) + (
|
|
"\n"
|
|
"import org.junit.jupiter.api.BeforeEach;\n"
|
|
"import org.junit.jupiter.api.Test;\n"
|
|
"import org.junit.jupiter.api.DisplayName;\n"
|
|
"import org.junit.jupiter.api.io.TempDir;\n"
|
|
"import static org.junit.jupiter.api.Assertions.*;\n"
|
|
"\n"
|
|
"import java.io.File;\n"
|
|
"import java.io.FileOutputStream;\n"
|
|
"import java.nio.file.Path;\n"
|
|
"\n"
|
|
f"import {module_path};\n"
|
|
"\n"
|
|
f"class {test_class_name} {{\n"
|
|
"\n"
|
|
" @TempDir\n"
|
|
" Path tempDir;\n"
|
|
"\n"
|
|
" private File smallFile;\n"
|
|
" private byte[] expectedSmall;\n"
|
|
" private File mediumFile;\n"
|
|
" private byte[] expectedMedium;\n"
|
|
" private File largeFile;\n"
|
|
" private byte[] expectedLarge;\n"
|
|
"\n"
|
|
" @BeforeEach\n"
|
|
" void setUp() throws Exception {\n"
|
|
' smallFile = tempDir.resolve("small.txt").toFile();\n'
|
|
' expectedSmall = "Hello, World!".getBytes();\n'
|
|
" try (FileOutputStream out = new FileOutputStream(smallFile)) {\n"
|
|
" out.write(expectedSmall);\n"
|
|
" }\n"
|
|
"\n"
|
|
' mediumFile = tempDir.resolve("medium.dat").toFile();\n'
|
|
" expectedMedium = new byte[256 * 1024];\n"
|
|
" for (int i = 0; i < expectedMedium.length; i++) {\n"
|
|
" expectedMedium[i] = (byte) (i % 251);\n"
|
|
" }\n"
|
|
" try (FileOutputStream out = new FileOutputStream(mediumFile)) {\n"
|
|
" out.write(expectedMedium);\n"
|
|
" }\n"
|
|
"\n"
|
|
' largeFile = tempDir.resolve("large.dat").toFile();\n'
|
|
" expectedLarge = new byte[1024 * 1024];\n"
|
|
" for (int i = 0; i < expectedLarge.length; i++) {\n"
|
|
" expectedLarge[i] = (byte) (i % 256);\n"
|
|
" }\n"
|
|
" try (FileOutputStream out = new FileOutputStream(largeFile)) {\n"
|
|
" out.write(expectedLarge);\n"
|
|
" }\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Read a small file")\n'
|
|
" void testReadSmallFile() throws Exception {\n"
|
|
f" byte[] result = {class_name}.readFile(smallFile);\n"
|
|
" assertArrayEquals(expectedSmall, result);\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Read 256KB file 301 times")\n'
|
|
" void testReadMediumFileRepeated() throws Exception {\n"
|
|
f" assertArrayEquals(expectedMedium, {class_name}.readFile(mediumFile));\n"
|
|
" for (int i = 0; i < 300; i++) {\n"
|
|
f" {class_name}.readFile(mediumFile);\n"
|
|
" }\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Read 1MB file 501 times")\n'
|
|
" void testReadLargeFileRepeated() throws Exception {\n"
|
|
f" assertArrayEquals(expectedLarge, {class_name}.readFile(largeFile));\n"
|
|
" for (int i = 0; i < 2000; i++) {\n"
|
|
f" {class_name}.readFile(largeFile);\n"
|
|
" }\n"
|
|
" }\n"
|
|
"}\n"
|
|
)
|
|
|
|
|
|
def _build_demo_test_source_1(package_name: str, class_name: str, test_framework: str) -> str:
    """Build demo test source 1, adapting to the target's package, class, and test framework.

    Same @Before pattern as source_0: file creation outside the timed test body.
    Complementary file sizes to source_0.

    Args:
        package_name: Package for the generated test; when empty, no package line is
            emitted and the import uses the bare class name.
        class_name: Class under test; the test class is named "<class_name>Test" and
            the tests call the static "<class_name>.readFile(File)".
        test_framework: "junit4" selects the JUnit 4 template; any other value selects
            the JUnit 5 template.

    Returns:
        Java source text of the test class (128KB, 512KB and 2MB fixture files).

    """
    # NOTE(review): with an empty package this yields "import <ClassName>;", which javac
    # is expected to reject — confirm callers never pass an empty package here.
    module_path = f"{package_name}.{class_name}" if package_name else class_name
    test_class_name = f"{class_name}Test"

    if test_framework == "junit4":
        return (f"package {package_name};\n" if package_name else "") + (
            "\n"
            "import org.junit.Before;\n"
            "import org.junit.Test;\n"
            "import org.junit.Rule;\n"
            "import org.junit.rules.TemporaryFolder;\n"
            "import static org.junit.Assert.*;\n"
            "\n"
            "import java.io.File;\n"
            "import java.io.FileOutputStream;\n"
            "import java.util.Arrays;\n"
            "\n"
            f"import {module_path};\n"
            "\n"
            f"public class {test_class_name} {{\n"
            "\n"
            " @Rule\n"
            " public TemporaryFolder tempFolder = new TemporaryFolder();\n"
            "\n"
            " private File patternFile;\n"
            " private byte[] expectedPattern;\n"
            " private File halfMegFile;\n"
            " private byte[] expectedHalfMeg;\n"
            " private File twoMegFile;\n"
            " private byte[] expectedTwoMeg;\n"
            "\n"
            # Fixture files are written in setUp so the test methods only read them.
            " @Before\n"
            " public void setUp() throws Exception {\n"
            ' patternFile = tempFolder.newFile("pattern.dat");\n'
            " expectedPattern = new byte[128 * 1024];\n"
            " for (int i = 0; i < expectedPattern.length; i++) {\n"
            " expectedPattern[i] = (byte) (i % 7);\n"
            " }\n"
            " try (FileOutputStream out = new FileOutputStream(patternFile)) {\n"
            " out.write(expectedPattern);\n"
            " }\n"
            "\n"
            ' halfMegFile = tempFolder.newFile("half_meg.dat");\n'
            " expectedHalfMeg = new byte[512 * 1024];\n"
            " Arrays.fill(expectedHalfMeg, (byte) 0xCD);\n"
            " try (FileOutputStream out = new FileOutputStream(halfMegFile)) {\n"
            " out.write(expectedHalfMeg);\n"
            " }\n"
            "\n"
            ' twoMegFile = tempFolder.newFile("two_meg.dat");\n'
            " expectedTwoMeg = new byte[2 * 1024 * 1024];\n"
            " for (int i = 0; i < expectedTwoMeg.length; i++) {\n"
            " expectedTwoMeg[i] = (byte) (i % 199);\n"
            " }\n"
            " try (FileOutputStream out = new FileOutputStream(twoMegFile)) {\n"
            " out.write(expectedTwoMeg);\n"
            " }\n"
            " }\n"
            "\n"
            # Each test asserts one read for correctness, then repeats the read in a loop.
            " @Test\n"
            " public void testReadBinaryPattern() throws Exception {\n"
            f" assertArrayEquals(expectedPattern, {class_name}.readFile(patternFile));\n"
            " for (int i = 0; i < 400; i++) {\n"
            f" {class_name}.readFile(patternFile);\n"
            " }\n"
            " }\n"
            "\n"
            " @Test\n"
            " public void testReadHalfMegRepeated() throws Exception {\n"
            f" assertArrayEquals(expectedHalfMeg, {class_name}.readFile(halfMegFile));\n"
            " for (int i = 0; i < 300; i++) {\n"
            f" {class_name}.readFile(halfMegFile);\n"
            " }\n"
            " }\n"
            "\n"
            " @Test\n"
            " public void testReadTwoMegRepeated() throws Exception {\n"
            f" assertArrayEquals(expectedTwoMeg, {class_name}.readFile(twoMegFile));\n"
            " for (int i = 0; i < 200; i++) {\n"
            f" {class_name}.readFile(twoMegFile);\n"
            " }\n"
            " }\n"
            "}\n"
        )
    else:
        # JUnit 5
        return (
            f"package {package_name};\n" if package_name else ""
        ) + (
            "\n"
            "import org.junit.jupiter.api.BeforeEach;\n"
            "import org.junit.jupiter.api.Test;\n"
            "import org.junit.jupiter.api.DisplayName;\n"
            "import org.junit.jupiter.api.io.TempDir;\n"
            "import static org.junit.jupiter.api.Assertions.*;\n"
            "\n"
            "import java.io.File;\n"
            "import java.io.FileOutputStream;\n"
            "import java.nio.file.Path;\n"
            "import java.util.Arrays;\n"
            "\n"
            f"import {module_path};\n"
            "\n"
            f"class {test_class_name} {{\n"
            "\n"
            " @TempDir\n"
            " Path tempDir;\n"
            "\n"
            " private File patternFile;\n"
            " private byte[] expectedPattern;\n"
            " private File halfMegFile;\n"
            " private byte[] expectedHalfMeg;\n"
            " private File twoMegFile;\n"
            " private byte[] expectedTwoMeg;\n"
            "\n"
            # Fixture files are written in setUp so the test methods only read them.
            " @BeforeEach\n"
            " void setUp() throws Exception {\n"
            ' patternFile = tempDir.resolve("pattern.dat").toFile();\n'
            " expectedPattern = new byte[128 * 1024];\n"
            " for (int i = 0; i < expectedPattern.length; i++) {\n"
            " expectedPattern[i] = (byte) (i % 7);\n"
            " }\n"
            " try (FileOutputStream out = new FileOutputStream(patternFile)) {\n"
            " out.write(expectedPattern);\n"
            " }\n"
            "\n"
            ' halfMegFile = tempDir.resolve("half_meg.dat").toFile();\n'
            " expectedHalfMeg = new byte[512 * 1024];\n"
            " Arrays.fill(expectedHalfMeg, (byte) 0xCD);\n"
            " try (FileOutputStream out = new FileOutputStream(halfMegFile)) {\n"
            " out.write(expectedHalfMeg);\n"
            " }\n"
            "\n"
            ' twoMegFile = tempDir.resolve("two_meg.dat").toFile();\n'
            " expectedTwoMeg = new byte[2 * 1024 * 1024];\n"
            " for (int i = 0; i < expectedTwoMeg.length; i++) {\n"
            " expectedTwoMeg[i] = (byte) (i % 199);\n"
            " }\n"
            " try (FileOutputStream out = new FileOutputStream(twoMegFile)) {\n"
            " out.write(expectedTwoMeg);\n"
            " }\n"
            " }\n"
            "\n"
            # Display names count the asserted read plus the loop iterations (e.g. 1 + 400 = 401).
            " @Test\n"
            ' @DisplayName("Read 128KB binary pattern 401 times")\n'
            " void testReadBinaryPattern() throws Exception {\n"
            f" assertArrayEquals(expectedPattern, {class_name}.readFile(patternFile));\n"
            " for (int i = 0; i < 400; i++) {\n"
            f" {class_name}.readFile(patternFile);\n"
            " }\n"
            " }\n"
            "\n"
            " @Test\n"
            ' @DisplayName("Read 512KB file 301 times")\n'
            " void testReadHalfMegRepeated() throws Exception {\n"
            f" assertArrayEquals(expectedHalfMeg, {class_name}.readFile(halfMegFile));\n"
            " for (int i = 0; i < 300; i++) {\n"
            f" {class_name}.readFile(halfMegFile);\n"
            " }\n"
            " }\n"
            "\n"
            " @Test\n"
            ' @DisplayName("Read 2MB file 201 times")\n'
            " void testReadTwoMegRepeated() throws Exception {\n"
            f" assertArrayEquals(expectedTwoMeg, {class_name}.readFile(twoMegFile));\n"
            " for (int i = 0; i < 200; i++) {\n"
            f" {class_name}.readFile(twoMegFile);\n"
            " }\n"
            " }\n"
            "}\n"
        )
|
|
|
|
|
|
def _build_host_equals_demo_test_source_0(package_name: str, class_name: str, test_framework: str) -> str:
|
|
"""Build demo test source 0 for Host.equals, adapting to the target's package, class, and test framework."""
|
|
module_path = f"{package_name}.{class_name}" if package_name else class_name
|
|
test_class_name = f"{class_name}Test"
|
|
|
|
if test_framework == "junit4":
|
|
return (f"package {package_name};\n" if package_name else "") + (
|
|
"\n"
|
|
"import org.junit.Before;\n"
|
|
"import org.junit.Test;\n"
|
|
"import static org.junit.Assert.*;\n"
|
|
"\n"
|
|
f"import {module_path};\n"
|
|
"\n"
|
|
"/**\n"
|
|
f" * Unit tests for {module_path}.equals(...)\n"
|
|
" */\n"
|
|
f"public class {test_class_name} {{\n"
|
|
f" private {class_name} defaultHost;\n"
|
|
"\n"
|
|
" @Before\n"
|
|
" public void setUp() {\n"
|
|
f' defaultHost = new {class_name}("localhost", 3000);\n'
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testEquals_SameInstance_ReturnsTrue() {\n"
|
|
" assertTrue(defaultHost.equals(defaultHost));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testEquals_EqualNameAndPort_IgnoringTls_ReturnsTrue() {\n"
|
|
f' {class_name} withTls = new {class_name}("localhost", "server-cert", 3000);\n'
|
|
" // Both directions should be true (symmetry)\n"
|
|
" assertTrue(defaultHost.equals(withTls) && withTls.equals(defaultHost));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testEquals_DifferentPort_ReturnsFalse() {\n"
|
|
f' {class_name} other = new {class_name}("localhost", 3001);\n'
|
|
" assertFalse(defaultHost.equals(other));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testEquals_DifferentName_ReturnsFalse() {\n"
|
|
f' {class_name} other = new {class_name}("otherhost", 3000);\n'
|
|
" assertFalse(defaultHost.equals(other));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testEquals_NullArgument_ReturnsFalse() {\n"
|
|
" assertFalse(defaultHost.equals(null));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testEquals_DifferentClass_ReturnsFalse() {\n"
|
|
' Object notAHost = "I am not a Host";\n'
|
|
" assertFalse(defaultHost.equals(notAHost));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testEquals_EmptyNameBoth_ReturnsTrue() {\n"
|
|
f' {class_name} a = new {class_name}("", 0);\n'
|
|
f' {class_name} b = new {class_name}("", null, 0);\n'
|
|
" assertTrue(a.equals(b));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test(expected = NullPointerException.class)\n"
|
|
" public void testEquals_ThisNameNull_ThrowsNullPointerException() {\n"
|
|
" // When this.name is null, equals calls this.name.equals(...), which throws NPE.\n"
|
|
f" {class_name} thisHasNullName = new {class_name}(null, 100);\n"
|
|
f' {class_name} other = new {class_name}("something", 100);\n'
|
|
" thisHasNullName.equals(other);\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testEquals_OtherNameNull_ReturnsFalse() {\n"
|
|
f" {class_name} otherHasNullName = new {class_name}(null, 200);\n"
|
|
f' {class_name} normal = new {class_name}("name", 200);\n'
|
|
' // "name".equals(null) returns false; no exception expected.\n'
|
|
" assertFalse(normal.equals(otherHasNullName));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testEquals_MaxIntPort_ReturnsTrue() {\n"
|
|
f' {class_name} a = new {class_name}("host", Integer.MAX_VALUE);\n'
|
|
f' {class_name} b = new {class_name}("host", Integer.MAX_VALUE);\n'
|
|
" assertTrue(a.equals(b));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testEquals_MinIntPort_ReturnsTrue() {\n"
|
|
f' {class_name} a = new {class_name}("host", Integer.MIN_VALUE);\n'
|
|
f' {class_name} b = new {class_name}("host", Integer.MIN_VALUE);\n'
|
|
" assertTrue(a.equals(b));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testEquals_MaxAndDifferentPort_ReturnsFalse() {\n"
|
|
f' {class_name} a = new {class_name}("host", Integer.MAX_VALUE);\n'
|
|
f' {class_name} b = new {class_name}("host", Integer.MAX_VALUE - 1);\n'
|
|
" assertFalse(a.equals(b));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testEquals_LargeScale_AllEqualInstances_ReturnsTrue() {\n"
|
|
f' {class_name} reference = new {class_name}("perf-host", 4000);\n'
|
|
" boolean allEqual = true;\n"
|
|
" final int iterations = 10000;\n"
|
|
" for (int i = 0; i < iterations; i++) {\n"
|
|
f' {class_name} h = new {class_name}("perf-host", 4000);\n'
|
|
" if (!reference.equals(h)) {\n"
|
|
" allEqual = false;\n"
|
|
" break;\n"
|
|
" }\n"
|
|
" }\n"
|
|
" assertTrue(allEqual);\n"
|
|
" }\n"
|
|
"}\n"
|
|
)
|
|
else:
|
|
# JUnit 5
|
|
return (f"package {package_name};\n" if package_name else "") + (
|
|
"\n"
|
|
"import org.junit.jupiter.api.BeforeEach;\n"
|
|
"import org.junit.jupiter.api.Test;\n"
|
|
"import org.junit.jupiter.api.DisplayName;\n"
|
|
"import static org.junit.jupiter.api.Assertions.*;\n"
|
|
"\n"
|
|
f"import {module_path};\n"
|
|
"\n"
|
|
"/**\n"
|
|
f" * Unit tests for {module_path}.equals(...)\n"
|
|
" */\n"
|
|
f"class {test_class_name} {{\n"
|
|
f" private {class_name} defaultHost;\n"
|
|
"\n"
|
|
" @BeforeEach\n"
|
|
" void setUp() {\n"
|
|
f' defaultHost = new {class_name}("localhost", 3000);\n'
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Same instance returns true")\n'
|
|
" void testEquals_SameInstance_ReturnsTrue() {\n"
|
|
" assertTrue(defaultHost.equals(defaultHost));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Equal name and port ignoring TLS returns true")\n'
|
|
" void testEquals_EqualNameAndPort_IgnoringTls_ReturnsTrue() {\n"
|
|
f' {class_name} withTls = new {class_name}("localhost", "server-cert", 3000);\n'
|
|
" assertTrue(defaultHost.equals(withTls) && withTls.equals(defaultHost));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Different port returns false")\n'
|
|
" void testEquals_DifferentPort_ReturnsFalse() {\n"
|
|
f' {class_name} other = new {class_name}("localhost", 3001);\n'
|
|
" assertFalse(defaultHost.equals(other));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Different name returns false")\n'
|
|
" void testEquals_DifferentName_ReturnsFalse() {\n"
|
|
f' {class_name} other = new {class_name}("otherhost", 3000);\n'
|
|
" assertFalse(defaultHost.equals(other));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Null argument returns false")\n'
|
|
" void testEquals_NullArgument_ReturnsFalse() {\n"
|
|
" assertFalse(defaultHost.equals(null));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Different class returns false")\n'
|
|
" void testEquals_DifferentClass_ReturnsFalse() {\n"
|
|
' Object notAHost = "I am not a Host";\n'
|
|
" assertFalse(defaultHost.equals(notAHost));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Empty name both returns true")\n'
|
|
" void testEquals_EmptyNameBoth_ReturnsTrue() {\n"
|
|
f' {class_name} a = new {class_name}("", 0);\n'
|
|
f' {class_name} b = new {class_name}("", null, 0);\n'
|
|
" assertTrue(a.equals(b));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Null this.name throws NullPointerException")\n'
|
|
" void testEquals_ThisNameNull_ThrowsNullPointerException() {\n"
|
|
f" {class_name} thisHasNullName = new {class_name}(null, 100);\n"
|
|
f' {class_name} other = new {class_name}("something", 100);\n'
|
|
" assertThrows(NullPointerException.class, () -> thisHasNullName.equals(other));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Null other.name returns false")\n'
|
|
" void testEquals_OtherNameNull_ReturnsFalse() {\n"
|
|
f" {class_name} otherHasNullName = new {class_name}(null, 200);\n"
|
|
f' {class_name} normal = new {class_name}("name", 200);\n'
|
|
" assertFalse(normal.equals(otherHasNullName));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Max int port returns true")\n'
|
|
" void testEquals_MaxIntPort_ReturnsTrue() {\n"
|
|
f' {class_name} a = new {class_name}("host", Integer.MAX_VALUE);\n'
|
|
f' {class_name} b = new {class_name}("host", Integer.MAX_VALUE);\n'
|
|
" assertTrue(a.equals(b));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Min int port returns true")\n'
|
|
" void testEquals_MinIntPort_ReturnsTrue() {\n"
|
|
f' {class_name} a = new {class_name}("host", Integer.MIN_VALUE);\n'
|
|
f' {class_name} b = new {class_name}("host", Integer.MIN_VALUE);\n'
|
|
" assertTrue(a.equals(b));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Max vs different port returns false")\n'
|
|
" void testEquals_MaxAndDifferentPort_ReturnsFalse() {\n"
|
|
f' {class_name} a = new {class_name}("host", Integer.MAX_VALUE);\n'
|
|
f' {class_name} b = new {class_name}("host", Integer.MAX_VALUE - 1);\n'
|
|
" assertFalse(a.equals(b));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Large scale equality check")\n'
|
|
" void testEquals_LargeScale_AllEqualInstances_ReturnsTrue() {\n"
|
|
f' {class_name} reference = new {class_name}("perf-host", 4000);\n'
|
|
" boolean allEqual = true;\n"
|
|
" final int iterations = 10000;\n"
|
|
" for (int i = 0; i < iterations; i++) {\n"
|
|
f' {class_name} h = new {class_name}("perf-host", 4000);\n'
|
|
" if (!reference.equals(h)) {\n"
|
|
" allEqual = false;\n"
|
|
" break;\n"
|
|
" }\n"
|
|
" }\n"
|
|
" assertTrue(allEqual);\n"
|
|
" }\n"
|
|
"}\n"
|
|
)
|
|
|
|
|
|
def _build_host_equals_demo_test_source_1(package_name: str, class_name: str, test_framework: str) -> str:
|
|
"""Build demo test source 1 for Host.equals, adapting to the target's package, class, and test framework."""
|
|
module_path = f"{package_name}.{class_name}" if package_name else class_name
|
|
test_class_name = f"{class_name}Test"
|
|
|
|
if test_framework == "junit4":
|
|
return (f"package {package_name};\n" if package_name else "") + (
|
|
"\n"
|
|
"import org.junit.Before;\n"
|
|
"import org.junit.Test;\n"
|
|
"import static org.junit.Assert.*;\n"
|
|
"\n"
|
|
f"import {module_path};\n"
|
|
"\n"
|
|
f"public class {test_class_name} {{\n"
|
|
f" private {class_name} hostSimple;\n"
|
|
f" private {class_name} hostWithTls;\n"
|
|
"\n"
|
|
" @Before\n"
|
|
" public void setUp() {\n"
|
|
f' hostSimple = new {class_name}("server.example.com", 3000);\n'
|
|
f' hostWithTls = new {class_name}("server.example.com", "tls.server.example.com", 3000);\n'
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testSameReference_True() {\n"
|
|
" // same instance should be equal to itself\n"
|
|
" assertTrue(hostSimple.equals(hostSimple));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testEqualNameAndPort_True() {\n"
|
|
" // two distinct instances with same name and port (tls ignored) are equal\n"
|
|
f' {class_name} a = new {class_name}("db1", 4000);\n'
|
|
f' {class_name} b = new {class_name}("db1", "tlsNameDifferent", 4000);\n'
|
|
" assertTrue(a.equals(b));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testDifferentTlsIgnored_True() {\n"
|
|
" // tlsName is ignored for equality\n"
|
|
" assertTrue(hostSimple.equals(hostWithTls));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testDifferentName_False() {\n"
|
|
f' {class_name} otherName = new {class_name}("other.example.com", 3000);\n'
|
|
" assertFalse(hostSimple.equals(otherName));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testDifferentPort_False() {\n"
|
|
f' {class_name} otherPort = new {class_name}("server.example.com", 3001);\n'
|
|
" assertFalse(hostSimple.equals(otherPort));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testNullComparison_False() {\n"
|
|
" // equals should return false when compared to null\n"
|
|
" assertFalse(hostSimple.equals(null));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testDifferentClass_False() {\n"
|
|
" // equals should return false when compared to an object of another class\n"
|
|
' Object notAHost = "server.example.com:3000";\n'
|
|
" assertFalse(hostSimple.equals(notAHost));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test(expected = NullPointerException.class)\n"
|
|
" public void testNameNull_ThrowsNullPointerException() {\n"
|
|
" // If this.name is null, equals tries to call this.name.equals(...) and will NPE.\n"
|
|
f" {class_name} nullNameHost = new {class_name}(null, 3000);\n"
|
|
f" {class_name} otherNullNameHost = new {class_name}(null, 3000);\n"
|
|
" // This invocation should throw NPE because this.name is null\n"
|
|
" nullNameHost.equals(otherNullNameHost);\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testOtherNameNull_False() {\n"
|
|
" // If other.name is null but this.name is non-null, equals should return false\n"
|
|
f" {class_name} otherNullName = new {class_name}(null, 3000);\n"
|
|
" assertFalse(hostSimple.equals(otherNullName));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testPortBoundary_ZeroAndMax_True() {\n"
|
|
f' {class_name} lowA = new {class_name}("edge", 0);\n'
|
|
f' {class_name} lowB = new {class_name}("edge", 0);\n'
|
|
" assertTrue(lowA.equals(lowB));\n"
|
|
"\n"
|
|
f' {class_name} highA = new {class_name}("edge", 65535);\n'
|
|
f' {class_name} highB = new {class_name}("edge", 65535);\n'
|
|
" assertTrue(highA.equals(highB));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testPortBoundary_DifferentPorts_False() {\n"
|
|
f' {class_name} low = new {class_name}("edge", 0);\n'
|
|
f' {class_name} high = new {class_name}("edge", 65535);\n'
|
|
" assertFalse(low.equals(high));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
" public void testLargeScale_Equality_BulkCompare() {\n"
|
|
" // Create many hosts and verify equality behavior (tls ignored) in a loop.\n"
|
|
" final int COUNT = 10000;\n"
|
|
" for (int i = 0; i < COUNT; i++) {\n"
|
|
' String name = "bulk-" + i;\n'
|
|
" int port = 1000 + (i % 1000);\n"
|
|
f" {class_name} a = new {class_name}(name, port);\n"
|
|
f' {class_name} b = new {class_name}(name, "tls-" + i, port);\n'
|
|
' assertTrue("Failed equality at index " + i, a.equals(b));\n'
|
|
" }\n"
|
|
" }\n"
|
|
"}\n"
|
|
)
|
|
else:
|
|
# JUnit 5
|
|
return (f"package {package_name};\n" if package_name else "") + (
|
|
"\n"
|
|
"import org.junit.jupiter.api.BeforeEach;\n"
|
|
"import org.junit.jupiter.api.Test;\n"
|
|
"import org.junit.jupiter.api.DisplayName;\n"
|
|
"import static org.junit.jupiter.api.Assertions.*;\n"
|
|
"\n"
|
|
f"import {module_path};\n"
|
|
"\n"
|
|
f"class {test_class_name} {{\n"
|
|
f" private {class_name} hostSimple;\n"
|
|
f" private {class_name} hostWithTls;\n"
|
|
"\n"
|
|
" @BeforeEach\n"
|
|
" void setUp() {\n"
|
|
f' hostSimple = new {class_name}("server.example.com", 3000);\n'
|
|
f' hostWithTls = new {class_name}("server.example.com", "tls.server.example.com", 3000);\n'
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Same reference returns true")\n'
|
|
" void testSameReference_True() {\n"
|
|
" assertTrue(hostSimple.equals(hostSimple));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Equal name and port with different TLS returns true")\n'
|
|
" void testEqualNameAndPort_True() {\n"
|
|
f' {class_name} a = new {class_name}("db1", 4000);\n'
|
|
f' {class_name} b = new {class_name}("db1", "tlsNameDifferent", 4000);\n'
|
|
" assertTrue(a.equals(b));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Different TLS name is ignored")\n'
|
|
" void testDifferentTlsIgnored_True() {\n"
|
|
" assertTrue(hostSimple.equals(hostWithTls));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Different name returns false")\n'
|
|
" void testDifferentName_False() {\n"
|
|
f' {class_name} otherName = new {class_name}("other.example.com", 3000);\n'
|
|
" assertFalse(hostSimple.equals(otherName));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Different port returns false")\n'
|
|
" void testDifferentPort_False() {\n"
|
|
f' {class_name} otherPort = new {class_name}("server.example.com", 3001);\n'
|
|
" assertFalse(hostSimple.equals(otherPort));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Null comparison returns false")\n'
|
|
" void testNullComparison_False() {\n"
|
|
" assertFalse(hostSimple.equals(null));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Different class returns false")\n'
|
|
" void testDifferentClass_False() {\n"
|
|
' Object notAHost = "server.example.com:3000";\n'
|
|
" assertFalse(hostSimple.equals(notAHost));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Null name throws NullPointerException")\n'
|
|
" void testNameNull_ThrowsNullPointerException() {\n"
|
|
f" {class_name} nullNameHost = new {class_name}(null, 3000);\n"
|
|
f" {class_name} otherNullNameHost = new {class_name}(null, 3000);\n"
|
|
" assertThrows(NullPointerException.class, () -> nullNameHost.equals(otherNullNameHost));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Other name null returns false")\n'
|
|
" void testOtherNameNull_False() {\n"
|
|
f" {class_name} otherNullName = new {class_name}(null, 3000);\n"
|
|
" assertFalse(hostSimple.equals(otherNullName));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Port boundary zero and max")\n'
|
|
" void testPortBoundary_ZeroAndMax_True() {\n"
|
|
f' {class_name} lowA = new {class_name}("edge", 0);\n'
|
|
f' {class_name} lowB = new {class_name}("edge", 0);\n'
|
|
" assertTrue(lowA.equals(lowB));\n"
|
|
"\n"
|
|
f' {class_name} highA = new {class_name}("edge", 65535);\n'
|
|
f' {class_name} highB = new {class_name}("edge", 65535);\n'
|
|
" assertTrue(highA.equals(highB));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Different boundary ports returns false")\n'
|
|
" void testPortBoundary_DifferentPorts_False() {\n"
|
|
f' {class_name} low = new {class_name}("edge", 0);\n'
|
|
f' {class_name} high = new {class_name}("edge", 65535);\n'
|
|
" assertFalse(low.equals(high));\n"
|
|
" }\n"
|
|
"\n"
|
|
" @Test\n"
|
|
' @DisplayName("Large scale bulk equality comparison")\n'
|
|
" void testLargeScale_Equality_BulkCompare() {\n"
|
|
" final int COUNT = 10000;\n"
|
|
" for (int i = 0; i < COUNT; i++) {\n"
|
|
' String name = "bulk-" + i;\n'
|
|
" int port = 1000 + (i % 1000);\n"
|
|
f" {class_name} a = new {class_name}(name, port);\n"
|
|
f' {class_name} b = new {class_name}(name, "tls-" + i, port);\n'
|
|
' assertTrue(a.equals(b), "Failed equality at index " + i);\n'
|
|
" }\n"
|
|
" }\n"
|
|
"}\n"
|
|
)
|
|
|
|
|
|
async def hack_for_demo_java_testgen(data: TestGenSchema) -> TestGenResponseSchema:
    """Return a canned demo test suite instead of calling the LLM.

    The package and class are extracted from the submitted source so the
    template matches the target. The Host.equals templates are used when
    ``is_host_equals_demo`` recognises the source; otherwise the generic demo
    templates are used. ``data.test_index`` (``None`` counts as 0) picks
    variant 0 or variant 1.

    Args:
        data: The test generation request.

    Returns:
        A response whose three test fields all carry the same generated source,
        since Java instrumentation happens on the client.
    """
    source_code = data.source_code_being_tested

    # Derive naming details from the source itself, with fallbacks.
    package_name = _extract_package_from_source(source_code) or ""
    class_name = _extract_class_from_source(source_code) or data.function_to_optimize.function_name
    if data.test_framework in ("junit4", "junit5"):
        test_framework = data.test_framework
    else:
        test_framework = "junit5"
    test_index = 0 if data.test_index is None else data.test_index

    # Pick the template family first, then the variant within it.
    if is_host_equals_demo(source_code):
        variants = (_build_host_equals_demo_test_source_0, _build_host_equals_demo_test_source_1)
    else:
        variants = (_build_demo_test_source_0, _build_demo_test_source_1)
    build_source = variants[0] if test_index == 0 else variants[1]
    generated_test_source = build_source(package_name, class_name, test_framework)

    # Fixed 5 second pause — presumably to mimic real generation latency; confirm intent.
    await asyncio.sleep(5)

    # For Java, instrumentation is done client-side
    return TestGenResponseSchema(
        generated_tests=generated_test_source,
        instrumented_behavior_tests=generated_test_source,
        instrumented_perf_tests=generated_test_source,
    )
|
|
|
|
|
|
async def testgen_java(
    request: AuthenticatedRequest, data: TestGenSchema
) -> tuple[int, TestGenResponseSchema | TestGenErrorResponseSchema]:
    """Generate Java tests using LLMs.

    Args:
        request: The authenticated request; ``request.user`` is used for
            analytics and cost tracking.
        data: The test generation payload (source code, target function,
            trace id, test framework).

    Returns:
        A ``(status, body)`` pair: 200 with the generated tests on success,
        400 for invalid requests or failed/invalid generation, and 500 for any
        other exception raised during generation.
    """
    await asyncio.to_thread(ph, request.user, "aiservice-testgen-java-called")

    # Reject malformed requests before doing any work.
    if not data.function_to_optimize:
        return 400, TestGenErrorResponseSchema(error="Invalid function to optimize. It is empty.")
    if not validate_trace_id(data.trace_id):
        return 400, TestGenErrorResponseSchema(error="Invalid trace ID. Please provide a valid UUIDv4.")

    logging.info("/testgen: Generating Java tests...")

    # Demo hack: recognised demo sources get a canned response and skip the LLM.
    if should_hack_for_demo_java(data.source_code_being_tested):
        demo_response = await hack_for_demo_java_testgen(data)
        return 200, demo_response

    try:
        target_name = data.function_to_optimize.function_name
        debug_log_sensitive_data(f"Generating Java tests for function {target_name}")
        logging.info(f"Generating Java tests for function {target_name}")

        # Class and package come from the source text (more reliable than qualified_name).
        source_code = data.source_code_being_tested
        debug_log_sensitive_data(f"Source code first 200 chars: {source_code[:200]}")

        package_name = _extract_package_from_source(source_code) or ""
        class_name = _extract_class_from_source(source_code) or target_name

        # Full module path is package.ClassName, or just ClassName without a package.
        module_path = f"{package_name}.{class_name}" if package_name else class_name

        # Any unrecognised framework value falls back to junit5.
        if data.test_framework in ("junit4", "junit5"):
            test_framework = data.test_framework
        else:
            test_framework = "junit5"

        logging.info(
            f"Java testgen: package={package_name}, class={class_name}, module_path={module_path}, framework={test_framework}"
        )
        debug_log_sensitive_data(
            f"Extracted: package={package_name}, class={class_name}, module_path={module_path}, framework={test_framework}"
        )

        messages, posthog_event_suffix = build_java_prompt(
            function_name=target_name,
            function_code=source_code,
            module_path=module_path,
            class_name=class_name,
            package_name=package_name,
            test_framework=test_framework,
        )

        # Passed to the generator and summed afterwards to get the total cost.
        cost_tracker: list[float] = []

        generated_test_code = await generate_and_validate_java_test_code(
            messages=messages,
            model=EXECUTE_MODEL,
            cost_tracker=cost_tracker,
            user_id=request.user,
            posthog_event_suffix=posthog_event_suffix,
            trace_id=data.trace_id,
        )

        # Report the successful generation to analytics.
        total_cost = sum(cost_tracker)
        success_properties = {
            "trace_id": data.trace_id,
            "total_cost": total_cost,
            "test_count": len(_TEST_FUNC_RE.findall(generated_test_code)),
        }
        await asyncio.to_thread(
            ph,
            request.user,
            f"aiservice-testgen-{posthog_event_suffix}success",
            properties=success_properties,
        )

        await update_optimization_cost(data.trace_id, total_cost, request.user)

        # For Java, instrumentation is done client-side, so the same
        # un-instrumented source is returned in all three fields.
        response = TestGenResponseSchema(
            generated_tests=generated_test_code,
            instrumented_behavior_tests=generated_test_code,
            instrumented_perf_tests=generated_test_code,
        )
        return 200, response

    except TestGenerationFailedError as exc:
        logging.warning(f"Java test generation failed: {exc}")
        sentry_sdk.capture_exception(exc)
        return 400, TestGenErrorResponseSchema(error=str(exc))
    except (ValueError, SyntaxError) as exc:
        logging.warning(f"Java test generation error: {exc}")
        sentry_sdk.capture_exception(exc)
        return 400, TestGenErrorResponseSchema(error=f"Failed to generate valid tests: {exc}")
    except Exception as exc:
        logging.exception("Unexpected error in Java test generation")
        sentry_sdk.capture_exception(exc)
        return 500, TestGenErrorResponseSchema(error=f"Internal error: {exc}")
|