Merge pull request #2342 from codeflash-ai/java-junit4-support

feat: add JUnit 4 test generation support for Java
This commit is contained in:
Saurabh Misra 2026-01-31 02:00:55 -08:00 committed by GitHub
commit ea84d3a944
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -39,8 +39,8 @@ JAVA_PROMPTS_DIR = current_dir / "prompts" / "testgen"
# Ensure prompts directory exists
JAVA_PROMPTS_DIR.mkdir(parents=True, exist_ok=True)
# Java system prompt for test generation
JAVA_SYSTEM_PROMPT = """You are an expert Java developer specializing in writing comprehensive JUnit 5 tests.
# Java system prompt for test generation - JUnit 5 (default)
JAVA_SYSTEM_PROMPT_JUNIT5 = """You are an expert Java developer specializing in writing comprehensive JUnit 5 tests.
Your task is to generate high-quality unit tests for the given Java function.
Guidelines:
@ -56,7 +56,25 @@ Guidelines:
Function to test: {function_name}
"""
JAVA_USER_PROMPT = """Generate JUnit 5 tests for the following Java function.
# Java system prompt for JUnit 4.
# build_java_prompt selects this template when test_framework == "junit4" and
# fills it with str.format; {function_name} is its only placeholder.
# The guidelines explicitly steer the model away from JUnit 5 (Jupiter)
# annotations and imports.
JAVA_SYSTEM_PROMPT_JUNIT4 = """You are an expert Java developer specializing in writing comprehensive JUnit 4 tests.
Your task is to generate high-quality unit tests for the given Java function.
Guidelines:
1. Use JUnit 4 test annotations (@Test, @Before, etc.) - NOT JUnit 5
2. Include imports for org.junit.* (NOT org.junit.jupiter.api.*)
3. Use static imports: import static org.junit.Assert.*
4. Use assertEquals, assertTrue, assertFalse as appropriate. For exceptions, use @Test(expected=ExceptionType.class)
5. Test edge cases, boundary conditions, and typical use cases
6. Use descriptive test method names following the pattern: test<Scenario>_<ExpectedBehavior>
7. Include a test class with the pattern: <ClassName>Test
8. Do NOT use mocks unless absolutely necessary
9. Keep tests simple and focused on one assertion per test when possible
Function to test: {function_name}
"""
JAVA_USER_PROMPT_JUNIT5 = """Generate JUnit 5 tests for the following Java function.
Function name: {function_name}
Class name: {class_name}
@ -106,6 +124,58 @@ public class {class_name}Test {{
```
"""
# Java user prompt for JUnit 4.
# build_java_prompt selects this template when test_framework == "junit4" and
# fills it with str.format. Placeholders used in the text: {function_name},
# {class_name}, {module_path}, {package_name} and {function_code}.
# Doubled braces ({{ and }}) are str.format escapes: they render as the literal
# single braces of the Java example, so they must stay doubled.
JAVA_USER_PROMPT_JUNIT4 = """Generate JUnit 4 tests for the following Java function.
Function name: {function_name}
Class name: {class_name}
Full qualified name: {module_path}
Package: {package_name}
Source code:
```java
{function_code}
```
Generate comprehensive unit tests that cover:
1. Basic functionality with typical inputs
2. Edge cases (empty inputs, null values, boundary conditions)
3. Error conditions (if the function can throw exceptions)
4. Large-scale inputs for performance verification
IMPORTANT REQUIREMENTS:
1. The package declaration MUST be: package {package_name};
2. You MUST import the class under test: import {module_path};
3. The test class MUST be named {class_name}Test
4. Use JUnit 4 annotations and imports (NOT JUnit 5/Jupiter)
5. Create an instance of {class_name} in @Before or directly in test methods
Wrap your response in a Java code block (```java ... ```).
Example structure:
```java
package {package_name};
import org.junit.Test;
import org.junit.Before;
import static org.junit.Assert.*;
import {module_path};
public class {class_name}Test {{
private {class_name} instance;
@Before
public void setUp() {{
instance = new {class_name}();
}}
@Test
public void testBasicFunctionality() {{
// Test code here
}}
}}
```
"""
# Regex that captures the body of a fenced Markdown code block.
# - The opening fence must start a line (MULTILINE) and may carry a "java" tag.
# - The body is matched lazily across newlines (DOTALL), so it stops at the
#   first newline that is immediately followed by a closing fence.
JAVA_PATTERN = re.compile(
    r"^```(?:java)?\s*\n(.*?)\n```",
    flags=re.DOTALL | re.MULTILINE,
)
@ -116,6 +186,7 @@ def build_java_prompt(
module_path: str,
class_name: str,
package_name: str,
test_framework: str = "junit5",
) -> tuple[list[ChatCompletionMessageParam], str]:
"""Build the prompt messages for Java test generation.
@ -125,19 +196,28 @@ def build_java_prompt(
module_path: Import path for the module
class_name: Name of the class containing the function
package_name: Package name for the test class
test_framework: Test framework to use ("junit5" or "junit4")
Returns:
Tuple of (messages, posthog_event_suffix)
"""
# Select prompts based on test framework
if test_framework == "junit4":
system_prompt = JAVA_SYSTEM_PROMPT_JUNIT4
user_prompt = JAVA_USER_PROMPT_JUNIT4
else:
system_prompt = JAVA_SYSTEM_PROMPT_JUNIT5
user_prompt = JAVA_USER_PROMPT_JUNIT5
system_message: ChatCompletionMessageParam = {
"role": "system",
"content": JAVA_SYSTEM_PROMPT.format(function_name=function_name),
"content": system_prompt.format(function_name=function_name),
}
user_message: ChatCompletionMessageParam = {
"role": "user",
"content": JAVA_USER_PROMPT.format(
"content": user_prompt.format(
function_name=function_name,
function_code=function_code,
module_path=module_path,
@ -147,7 +227,7 @@ def build_java_prompt(
}
messages: list[ChatCompletionMessageParam] = [system_message, user_message]
return messages, "java-"
return messages, f"java-{test_framework}-"
def parse_and_validate_java_output(response_content: str) -> str:
@ -281,10 +361,11 @@ def _extract_package_from_source(source_code: str) -> str | None:
Package name (e.g., "com.example"), or None if not found
"""
# First try: package declaration in source
# First try: package declaration in source (most reliable)
package_pattern = re.compile(r'^\s*package\s+([\w.]+)\s*;', re.MULTILINE)
match = package_pattern.search(source_code)
if match:
logging.debug(f"Extracted package from declaration: {match.group(1)}")
return match.group(1)
# Second try: extract from markdown code block header (e.g., "```java:src/main/java/com/example/Algorithms.java")
@ -294,17 +375,39 @@ def _extract_package_from_source(source_code: str) -> str | None:
file_path = markdown_match.group(1).strip()
package = _extract_package_from_path(file_path)
if package:
logging.debug(f"Extracted package from markdown header: {package}")
return package
# Third try: extract from file path comment (e.g., "// file: src/main/java/com/example/Algorithms.java")
# Also handle "// file: src/com/example/Algorithms.java" (non-standard Maven)
file_comment_pattern = re.compile(r'//\s*file:\s*([^\n]+\.java)', re.IGNORECASE)
file_match = file_comment_pattern.search(source_code)
if file_match:
file_path = file_match.group(1).strip()
logging.debug(f"Found file comment: {file_path}")
package = _extract_package_from_path(file_path)
if package:
logging.debug(f"Extracted package from file comment: {package}")
return package
# Fourth try: infer package from import statements (last resort)
# Look for imports that might indicate the package structure
import_pattern = re.compile(r'^\s*import\s+([\w.]+)\.[\w*]+\s*;', re.MULTILINE)
imports = import_pattern.findall(source_code)
if imports:
# Find common prefix among imports that look like internal packages
# Exclude common library packages
internal_imports = [
imp for imp in imports
if not imp.startswith(('java.', 'javax.', 'org.junit', 'org.apache', 'com.google'))
]
if internal_imports:
# Use the shortest import path as a hint
shortest = min(internal_imports, key=len)
logging.debug(f"Inferred package from imports: {shortest}")
return shortest
logging.warning("Could not extract package name from source code")
return None
def _directories_below(path: str, start: int) -> list[str]:
    """Return the non-empty directory components of ``path[start:]``, without the filename."""
    # Filtering out '' keeps doubled slashes from producing packages such as
    # ".com.example" or "com..example".
    return [part for part in path[start:].split('/')[:-1] if part]


def _extract_package_from_path(file_path: str) -> str | None:
    """Extract Java package from a file path.

    Two layouts are recognised, in priority order:

    1. Standard Maven roots: everything below ``src/main/java/`` or
       ``src/test/java/`` is taken as the package path.
    2. Non-standard roots: everything below a ``src/`` directory, provided the
       first directory under it looks like a package segment rather than a
       Maven folder (``main``, ``test``, ``java``, ``resources``).

    Args:
        file_path: e.g., "src/main/java/com/example/Algorithms.java" or
            "src/com/example/Algorithms.java". Backslash separators are accepted.

    Returns:
        Package name (e.g., "com.example"), or None if not found. A file that
        sits directly in a Maven source root is in the default package and
        yields None.
    """
    # Normalize slashes so Windows-style paths match the same patterns.
    normalized = file_path.replace('\\', '/')

    # Standard Maven patterns (highest priority). Matching without a leading
    # slash also covers paths that are relative to the project root.
    for root in ('src/main/java/', 'src/test/java/'):
        idx = normalized.find(root)
        if idx != -1:
            package_parts = _directories_below(normalized, idx + len(root))
            # Decide here and do not fall through: the looser "src/" heuristic
            # below could otherwise invent a package from an unrelated, earlier
            # "src" directory in the path.
            return '.'.join(package_parts) if package_parts else None

    # Non-standard paths: look for "src/" followed by what looks like a package structure
    # e.g., "src/com/aerospike/client/util/Crypto.java" -> "com.aerospike.client.util"
    # or "client/src/com/aerospike/client/util/Crypto.java"
    # "/src/" is tried first so a whole directory named "src" wins over a
    # directory that merely ends in "src".
    for marker in ('/src/', 'src/'):
        idx = normalized.find(marker)
        if idx == -1:
            continue
        package_parts = _directories_below(normalized, idx + len(marker))
        if not package_parts:
            continue
        # The first directory must look like a package segment, not a Maven folder.
        first_part = package_parts[0].lower()
        if first_part not in ('main', 'test', 'java', 'resources') and first_part.isalpha():
            return '.'.join(package_parts)

    return None
@ -405,8 +525,11 @@ async def testgen_java(
else:
module_path = class_name
logging.info(f"Java testgen: package={package_name}, class={class_name}, module_path={module_path}")
debug_log_sensitive_data(f"Extracted: package={package_name}, class={class_name}, module_path={module_path}")
# Determine test framework (default to junit5 if not specified)
test_framework = data.test_framework if data.test_framework in ("junit4", "junit5") else "junit5"
logging.info(f"Java testgen: package={package_name}, class={class_name}, module_path={module_path}, framework={test_framework}")
debug_log_sensitive_data(f"Extracted: package={package_name}, class={class_name}, module_path={module_path}, framework={test_framework}")
# Build prompt
messages, posthog_event_suffix = build_java_prompt(
@ -415,6 +538,7 @@ async def testgen_java(
module_path=module_path,
class_name=class_name,
package_name=package_name,
test_framework=test_framework,
)
# Track costs