Merge pull request #2342 from codeflash-ai/java-junit4-support
feat: add JUnit 4 test generation support for Java
This commit is contained in:
commit
ea84d3a944
1 changed files with 138 additions and 14 deletions
|
|
@ -39,8 +39,8 @@ JAVA_PROMPTS_DIR = current_dir / "prompts" / "testgen"
|
|||
# Ensure prompts directory exists
|
||||
JAVA_PROMPTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Java system prompt for test generation
|
||||
JAVA_SYSTEM_PROMPT = """You are an expert Java developer specializing in writing comprehensive JUnit 5 tests.
|
||||
# Java system prompt for test generation - JUnit 5 (default)
|
||||
JAVA_SYSTEM_PROMPT_JUNIT5 = """You are an expert Java developer specializing in writing comprehensive JUnit 5 tests.
|
||||
Your task is to generate high-quality unit tests for the given Java function.
|
||||
|
||||
Guidelines:
|
||||
|
|
@ -56,7 +56,25 @@ Guidelines:
|
|||
Function to test: {function_name}
|
||||
"""
|
||||
|
||||
JAVA_USER_PROMPT = """Generate JUnit 5 tests for the following Java function.
|
||||
# Java system prompt for JUnit 4
|
||||
JAVA_SYSTEM_PROMPT_JUNIT4 = """You are an expert Java developer specializing in writing comprehensive JUnit 4 tests.
|
||||
Your task is to generate high-quality unit tests for the given Java function.
|
||||
|
||||
Guidelines:
|
||||
1. Use JUnit 4 test annotations (@Test, @Before, etc.) - NOT JUnit 5
|
||||
2. Include imports for org.junit.* (NOT org.junit.jupiter.api.*)
|
||||
3. Use static imports: import static org.junit.Assert.*
|
||||
4. Use assertEquals, assertTrue, assertFalse as appropriate. For exceptions, use @Test(expected=ExceptionType.class)
|
||||
5. Test edge cases, boundary conditions, and typical use cases
|
||||
6. Use descriptive test method names following the pattern: test<Scenario>_<ExpectedBehavior>
|
||||
7. Include a test class with the pattern: <ClassName>Test
|
||||
8. Do NOT use mocks unless absolutely necessary
|
||||
9. Keep tests simple and focused on one assertion per test when possible
|
||||
|
||||
Function to test: {function_name}
|
||||
"""
|
||||
|
||||
JAVA_USER_PROMPT_JUNIT5 = """Generate JUnit 5 tests for the following Java function.
|
||||
|
||||
Function name: {function_name}
|
||||
Class name: {class_name}
|
||||
|
|
@ -106,6 +124,58 @@ public class {class_name}Test {{
|
|||
```
|
||||
"""
|
||||
|
||||
JAVA_USER_PROMPT_JUNIT4 = """Generate JUnit 4 tests for the following Java function.
|
||||
|
||||
Function name: {function_name}
|
||||
Class name: {class_name}
|
||||
Full qualified name: {module_path}
|
||||
Package: {package_name}
|
||||
|
||||
Source code:
|
||||
```java
|
||||
{function_code}
|
||||
```
|
||||
|
||||
Generate comprehensive unit tests that cover:
|
||||
1. Basic functionality with typical inputs
|
||||
2. Edge cases (empty inputs, null values, boundary conditions)
|
||||
3. Error conditions (if the function can throw exceptions)
|
||||
4. Large-scale inputs for performance verification
|
||||
|
||||
IMPORTANT REQUIREMENTS:
|
||||
1. The package declaration MUST be: package {package_name};
|
||||
2. You MUST import the class under test: import {module_path};
|
||||
3. The test class MUST be named {class_name}Test
|
||||
4. Use JUnit 4 annotations and imports (NOT JUnit 5/Jupiter)
|
||||
5. Create an instance of {class_name} in @Before or directly in test methods
|
||||
|
||||
Wrap your response in a Java code block (```java ... ```).
|
||||
|
||||
Example structure:
|
||||
```java
|
||||
package {package_name};
|
||||
|
||||
import org.junit.Test;
|
||||
import org.junit.Before;
|
||||
import static org.junit.Assert.*;
|
||||
import {module_path};
|
||||
|
||||
public class {class_name}Test {{
|
||||
private {class_name} instance;
|
||||
|
||||
@Before
|
||||
public void setUp() {{
|
||||
instance = new {class_name}();
|
||||
}}
|
||||
|
||||
@Test
|
||||
public void testBasicFunctionality() {{
|
||||
// Test code here
|
||||
}}
|
||||
}}
|
||||
```
|
||||
"""
|
||||
|
||||
# Pattern to extract Java code blocks
|
||||
JAVA_PATTERN = re.compile(r"^```(?:java)?\s*\n(.*?)\n```", re.MULTILINE | re.DOTALL)
|
||||
|
||||
|
|
@ -116,6 +186,7 @@ def build_java_prompt(
|
|||
module_path: str,
|
||||
class_name: str,
|
||||
package_name: str,
|
||||
test_framework: str = "junit5",
|
||||
) -> tuple[list[ChatCompletionMessageParam], str]:
|
||||
"""Build the prompt messages for Java test generation.
|
||||
|
||||
|
|
@ -125,19 +196,28 @@ def build_java_prompt(
|
|||
module_path: Import path for the module
|
||||
class_name: Name of the class containing the function
|
||||
package_name: Package name for the test class
|
||||
test_framework: Test framework to use ("junit5" or "junit4")
|
||||
|
||||
Returns:
|
||||
Tuple of (messages, posthog_event_suffix)
|
||||
|
||||
"""
|
||||
# Select prompts based on test framework
|
||||
if test_framework == "junit4":
|
||||
system_prompt = JAVA_SYSTEM_PROMPT_JUNIT4
|
||||
user_prompt = JAVA_USER_PROMPT_JUNIT4
|
||||
else:
|
||||
system_prompt = JAVA_SYSTEM_PROMPT_JUNIT5
|
||||
user_prompt = JAVA_USER_PROMPT_JUNIT5
|
||||
|
||||
system_message: ChatCompletionMessageParam = {
|
||||
"role": "system",
|
||||
"content": JAVA_SYSTEM_PROMPT.format(function_name=function_name),
|
||||
"content": system_prompt.format(function_name=function_name),
|
||||
}
|
||||
|
||||
user_message: ChatCompletionMessageParam = {
|
||||
"role": "user",
|
||||
"content": JAVA_USER_PROMPT.format(
|
||||
"content": user_prompt.format(
|
||||
function_name=function_name,
|
||||
function_code=function_code,
|
||||
module_path=module_path,
|
||||
|
|
@ -147,7 +227,7 @@ def build_java_prompt(
|
|||
}
|
||||
|
||||
messages: list[ChatCompletionMessageParam] = [system_message, user_message]
|
||||
return messages, "java-"
|
||||
return messages, f"java-{test_framework}-"
|
||||
|
||||
|
||||
def parse_and_validate_java_output(response_content: str) -> str:
|
||||
|
|
@ -281,10 +361,11 @@ def _extract_package_from_source(source_code: str) -> str | None:
|
|||
Package name (e.g., "com.example"), or None if not found
|
||||
|
||||
"""
|
||||
# First try: package declaration in source
|
||||
# First try: package declaration in source (most reliable)
|
||||
package_pattern = re.compile(r'^\s*package\s+([\w.]+)\s*;', re.MULTILINE)
|
||||
match = package_pattern.search(source_code)
|
||||
if match:
|
||||
logging.debug(f"Extracted package from declaration: {match.group(1)}")
|
||||
return match.group(1)
|
||||
|
||||
# Second try: extract from markdown code block header (e.g., "```java:src/main/java/com/example/Algorithms.java")
|
||||
|
|
@ -294,17 +375,39 @@ def _extract_package_from_source(source_code: str) -> str | None:
|
|||
file_path = markdown_match.group(1).strip()
|
||||
package = _extract_package_from_path(file_path)
|
||||
if package:
|
||||
logging.debug(f"Extracted package from markdown header: {package}")
|
||||
return package
|
||||
|
||||
# Third try: extract from file path comment (e.g., "// file: src/main/java/com/example/Algorithms.java")
|
||||
# Also handle "// file: src/com/example/Algorithms.java" (non-standard Maven)
|
||||
file_comment_pattern = re.compile(r'//\s*file:\s*([^\n]+\.java)', re.IGNORECASE)
|
||||
file_match = file_comment_pattern.search(source_code)
|
||||
if file_match:
|
||||
file_path = file_match.group(1).strip()
|
||||
logging.debug(f"Found file comment: {file_path}")
|
||||
package = _extract_package_from_path(file_path)
|
||||
if package:
|
||||
logging.debug(f"Extracted package from file comment: {package}")
|
||||
return package
|
||||
|
||||
# Fourth try: infer package from import statements (last resort)
|
||||
# Look for imports that might indicate the package structure
|
||||
import_pattern = re.compile(r'^\s*import\s+([\w.]+)\.[\w*]+\s*;', re.MULTILINE)
|
||||
imports = import_pattern.findall(source_code)
|
||||
if imports:
|
||||
# Find common prefix among imports that look like internal packages
|
||||
# Exclude common library packages
|
||||
internal_imports = [
|
||||
imp for imp in imports
|
||||
if not imp.startswith(('java.', 'javax.', 'org.junit', 'org.apache', 'com.google'))
|
||||
]
|
||||
if internal_imports:
|
||||
# Use the shortest import path as a hint
|
||||
shortest = min(internal_imports, key=len)
|
||||
logging.debug(f"Inferred package from imports: {shortest}")
|
||||
return shortest
|
||||
|
||||
logging.warning("Could not extract package name from source code")
|
||||
return None
|
||||
|
||||
|
||||
|
|
@ -312,25 +415,42 @@ def _extract_package_from_path(file_path: str) -> str | None:
|
|||
"""Extract Java package from a file path.
|
||||
|
||||
Args:
|
||||
file_path: e.g., "src/main/java/com/example/Algorithms.java"
|
||||
file_path: e.g., "src/main/java/com/example/Algorithms.java" or "src/com/example/Algorithms.java"
|
||||
|
||||
Returns:
|
||||
Package name (e.g., "com.example"), or None if not found
|
||||
|
||||
"""
|
||||
# Look for "src/main/java/" or "src/test/java/" patterns
|
||||
# Normalize slashes
|
||||
file_path = file_path.replace('\\', '/')
|
||||
|
||||
# Standard Maven patterns (highest priority)
|
||||
java_src_patterns = ['/src/main/java/', '/src/test/java/', 'src/main/java/', 'src/test/java/']
|
||||
for pattern in java_src_patterns:
|
||||
if pattern in file_path:
|
||||
# Get the part after src/main/java/
|
||||
idx = file_path.find(pattern)
|
||||
remaining = file_path[idx + len(pattern):]
|
||||
# Convert path to package (e.g., "com/example/Algorithms.java" -> "com.example")
|
||||
parts = remaining.split('/')
|
||||
if len(parts) > 1:
|
||||
package_parts = parts[:-1] # Remove the filename
|
||||
return '.'.join(package_parts)
|
||||
break
|
||||
|
||||
# Non-standard paths: look for "src/" followed by what looks like a package structure
|
||||
# e.g., "src/com/aerospike/client/util/Crypto.java" -> "com.aerospike.client.util"
|
||||
# or "client/src/com/aerospike/client/util/Crypto.java"
|
||||
src_patterns = ['/src/', 'src/']
|
||||
for pattern in src_patterns:
|
||||
if pattern in file_path:
|
||||
idx = file_path.find(pattern)
|
||||
remaining = file_path[idx + len(pattern):]
|
||||
parts = remaining.split('/')
|
||||
if len(parts) > 1:
|
||||
# Check if first part looks like a package (lowercase, not 'main', 'test', 'java', 'resources')
|
||||
first_part = parts[0].lower()
|
||||
if first_part not in ('main', 'test', 'java', 'resources') and first_part.isalpha():
|
||||
package_parts = parts[:-1] # Remove the filename
|
||||
return '.'.join(package_parts)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
|
|
@ -405,8 +525,11 @@ async def testgen_java(
|
|||
else:
|
||||
module_path = class_name
|
||||
|
||||
logging.info(f"Java testgen: package={package_name}, class={class_name}, module_path={module_path}")
|
||||
debug_log_sensitive_data(f"Extracted: package={package_name}, class={class_name}, module_path={module_path}")
|
||||
# Determine test framework (default to junit5 if not specified)
|
||||
test_framework = data.test_framework if data.test_framework in ("junit4", "junit5") else "junit5"
|
||||
|
||||
logging.info(f"Java testgen: package={package_name}, class={class_name}, module_path={module_path}, framework={test_framework}")
|
||||
debug_log_sensitive_data(f"Extracted: package={package_name}, class={class_name}, module_path={module_path}, framework={test_framework}")
|
||||
|
||||
# Build prompt
|
||||
messages, posthog_event_suffix = build_java_prompt(
|
||||
|
|
@ -415,6 +538,7 @@ async def testgen_java(
|
|||
module_path=module_path,
|
||||
class_name=class_name,
|
||||
package_name=package_name,
|
||||
test_framework=test_framework,
|
||||
)
|
||||
|
||||
# Track costs
|
||||
|
|
|
|||
Loading…
Reference in a new issue