mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
feat: implement Java assertion removal transformer
Add a robust Java assert removal transformer to convert generated unit tests into regression tests. This removes assertion statements while preserving function calls, enabling behavioral verification by comparing outputs between original and optimized code. Key features: - Support for JUnit 5 assertions (assertEquals, assertTrue, assertThrows, etc.) - Support for JUnit 4 assertions (org.junit.Assert.*) - Support for AssertJ fluent assertions (assertThat().isEqualTo()) - Support for TestNG and Hamcrest assertions - Framework auto-detection from imports - Handles assertAll grouped assertions - Preserves non-assertion code (setup, Mockito mocks, etc.) - 57 comprehensive tests with exact string equality assertions Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
c587c47521
commit
31c90f0391
4 changed files with 1815 additions and 100 deletions
|
|
@ -21,10 +21,7 @@ from codeflash.languages.java.build_tools import (
|
|||
install_codeflash_runtime,
|
||||
run_maven_tests,
|
||||
)
|
||||
from codeflash.languages.java.comparator import (
|
||||
compare_invocations_directly,
|
||||
compare_test_results,
|
||||
)
|
||||
from codeflash.languages.java.comparator import compare_invocations_directly, compare_test_results
|
||||
from codeflash.languages.java.config import (
|
||||
JavaProjectConfig,
|
||||
detect_java_project,
|
||||
|
|
@ -46,12 +43,7 @@ from codeflash.languages.java.discovery import (
|
|||
get_class_methods,
|
||||
get_method_by_name,
|
||||
)
|
||||
from codeflash.languages.java.formatter import (
|
||||
JavaFormatter,
|
||||
format_java_code,
|
||||
format_java_file,
|
||||
normalize_java_code,
|
||||
)
|
||||
from codeflash.languages.java.formatter import JavaFormatter, format_java_code, format_java_file, normalize_java_code
|
||||
from codeflash.languages.java.import_resolver import (
|
||||
JavaImportResolver,
|
||||
ResolvedImport,
|
||||
|
|
@ -63,6 +55,7 @@ from codeflash.languages.java.instrumentation import (
|
|||
instrument_existing_test,
|
||||
instrument_for_behavior,
|
||||
instrument_for_benchmarking,
|
||||
instrument_generated_java_test,
|
||||
remove_instrumentation,
|
||||
)
|
||||
from codeflash.languages.java.parser import (
|
||||
|
|
@ -73,6 +66,11 @@ from codeflash.languages.java.parser import (
|
|||
JavaMethodNode,
|
||||
get_java_analyzer,
|
||||
)
|
||||
from codeflash.languages.java.remove_asserts import (
|
||||
JavaAssertTransformer,
|
||||
remove_assertions_from_test,
|
||||
transform_java_assertions,
|
||||
)
|
||||
from codeflash.languages.java.replacement import (
|
||||
add_runtime_comments,
|
||||
insert_method,
|
||||
|
|
@ -81,10 +79,7 @@ from codeflash.languages.java.replacement import (
|
|||
replace_function,
|
||||
replace_method_body,
|
||||
)
|
||||
from codeflash.languages.java.support import (
|
||||
JavaSupport,
|
||||
get_java_support,
|
||||
)
|
||||
from codeflash.languages.java.support import JavaSupport, get_java_support
|
||||
from codeflash.languages.java.test_discovery import (
|
||||
build_test_mapping_for_project,
|
||||
discover_all_tests,
|
||||
|
|
@ -106,90 +101,95 @@ from codeflash.languages.java.test_runner import (
|
|||
)
|
||||
|
||||
__all__ = [
|
||||
# Parser
|
||||
"JavaAnalyzer",
|
||||
"JavaClassNode",
|
||||
"JavaFieldInfo",
|
||||
"JavaImportInfo",
|
||||
"JavaMethodNode",
|
||||
"get_java_analyzer",
|
||||
# Build tools
|
||||
"BuildTool",
|
||||
# Parser
|
||||
"JavaAnalyzer",
|
||||
# Assertion removal
|
||||
"JavaAssertTransformer",
|
||||
"JavaClassNode",
|
||||
"JavaFieldInfo",
|
||||
# Formatter
|
||||
"JavaFormatter",
|
||||
"JavaImportInfo",
|
||||
# Import resolver
|
||||
"JavaImportResolver",
|
||||
"JavaMethodNode",
|
||||
# Config
|
||||
"JavaProjectConfig",
|
||||
"JavaProjectInfo",
|
||||
# Support
|
||||
"JavaSupport",
|
||||
# Test runner
|
||||
"JavaTestRunResult",
|
||||
"MavenTestResult",
|
||||
"ResolvedImport",
|
||||
"add_codeflash_dependency_to_pom",
|
||||
"compile_maven_project",
|
||||
"detect_build_tool",
|
||||
"find_gradle_executable",
|
||||
"find_maven_executable",
|
||||
"find_source_root",
|
||||
"find_test_root",
|
||||
"get_classpath",
|
||||
"get_project_info",
|
||||
"install_codeflash_runtime",
|
||||
"run_maven_tests",
|
||||
# Replacement
|
||||
"add_runtime_comments",
|
||||
# Test discovery
|
||||
"build_test_mapping_for_project",
|
||||
# Comparator
|
||||
"compare_invocations_directly",
|
||||
"compare_test_results",
|
||||
# Config
|
||||
"JavaProjectConfig",
|
||||
"compile_maven_project",
|
||||
# Instrumentation
|
||||
"create_benchmark_test",
|
||||
"detect_build_tool",
|
||||
"detect_java_project",
|
||||
"get_test_class_pattern",
|
||||
"get_test_file_pattern",
|
||||
"is_java_project",
|
||||
"discover_all_tests",
|
||||
# Discovery
|
||||
"discover_functions",
|
||||
"discover_functions_from_source",
|
||||
"discover_test_methods",
|
||||
"discover_tests",
|
||||
# Context
|
||||
"extract_class_context",
|
||||
"extract_code_context",
|
||||
"extract_function_source",
|
||||
"extract_read_only_context",
|
||||
"find_gradle_executable",
|
||||
"find_helper_files",
|
||||
"find_helper_functions",
|
||||
# Discovery
|
||||
"discover_functions",
|
||||
"discover_functions_from_source",
|
||||
"discover_test_methods",
|
||||
"get_class_methods",
|
||||
"get_method_by_name",
|
||||
# Formatter
|
||||
"JavaFormatter",
|
||||
"find_maven_executable",
|
||||
"find_source_root",
|
||||
"find_test_root",
|
||||
"find_tests_for_function",
|
||||
"format_java_code",
|
||||
"format_java_file",
|
||||
"normalize_java_code",
|
||||
# Import resolver
|
||||
"JavaImportResolver",
|
||||
"ResolvedImport",
|
||||
"find_helper_files",
|
||||
"resolve_imports_for_file",
|
||||
# Instrumentation
|
||||
"create_benchmark_test",
|
||||
"get_class_methods",
|
||||
"get_classpath",
|
||||
"get_java_analyzer",
|
||||
"get_java_support",
|
||||
"get_method_by_name",
|
||||
"get_project_info",
|
||||
"get_test_class_for_source_class",
|
||||
"get_test_class_pattern",
|
||||
"get_test_file_pattern",
|
||||
"get_test_file_suffix",
|
||||
"get_test_methods_for_class",
|
||||
"get_test_run_command",
|
||||
"insert_method",
|
||||
"install_codeflash_runtime",
|
||||
"instrument_existing_test",
|
||||
"instrument_for_behavior",
|
||||
"instrument_for_benchmarking",
|
||||
"instrument_generated_java_test",
|
||||
"is_java_project",
|
||||
"is_test_file",
|
||||
"normalize_java_code",
|
||||
"parse_surefire_results",
|
||||
"parse_test_results",
|
||||
"remove_assertions_from_test",
|
||||
"remove_instrumentation",
|
||||
# Replacement
|
||||
"add_runtime_comments",
|
||||
"insert_method",
|
||||
"remove_method",
|
||||
"remove_test_functions",
|
||||
"replace_function",
|
||||
"replace_method_body",
|
||||
# Support
|
||||
"JavaSupport",
|
||||
"get_java_support",
|
||||
# Test discovery
|
||||
"build_test_mapping_for_project",
|
||||
"discover_all_tests",
|
||||
"discover_tests",
|
||||
"find_tests_for_function",
|
||||
"get_test_class_for_source_class",
|
||||
"get_test_file_suffix",
|
||||
"get_test_methods_for_class",
|
||||
"is_test_file",
|
||||
# Test runner
|
||||
"JavaTestRunResult",
|
||||
"get_test_run_command",
|
||||
"parse_surefire_results",
|
||||
"parse_test_results",
|
||||
"resolve_imports_for_file",
|
||||
"run_behavioral_tests",
|
||||
"run_benchmarking_tests",
|
||||
"run_maven_tests",
|
||||
"run_tests",
|
||||
"transform_java_assertions",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -55,9 +55,7 @@ def _get_qualified_name(func: Any) -> str:
|
|||
|
||||
|
||||
def instrument_for_behavior(
|
||||
source: str,
|
||||
functions: Sequence[FunctionToOptimize],
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
source: str, functions: Sequence[FunctionToOptimize], analyzer: JavaAnalyzer | None = None
|
||||
) -> str:
|
||||
"""Add behavior instrumentation to capture inputs/outputs.
|
||||
|
||||
|
|
@ -83,9 +81,7 @@ def instrument_for_behavior(
|
|||
|
||||
|
||||
def instrument_for_benchmarking(
|
||||
test_source: str,
|
||||
target_function: FunctionToOptimize,
|
||||
analyzer: JavaAnalyzer | None = None,
|
||||
test_source: str, target_function: FunctionToOptimize, analyzer: JavaAnalyzer | None = None
|
||||
) -> str:
|
||||
"""Add timing instrumentation to test code.
|
||||
|
||||
|
|
@ -168,19 +164,9 @@ def instrument_existing_test(
|
|||
)
|
||||
else:
|
||||
# Behavior mode: add timing instrumentation that also writes to SQLite
|
||||
modified_source = _add_behavior_instrumentation(
|
||||
modified_source,
|
||||
original_class_name,
|
||||
func_name,
|
||||
)
|
||||
modified_source = _add_behavior_instrumentation(modified_source, original_class_name, func_name)
|
||||
|
||||
logger.debug(
|
||||
"Java %s testing for %s: renamed class %s -> %s",
|
||||
mode,
|
||||
func_name,
|
||||
original_class_name,
|
||||
new_class_name,
|
||||
)
|
||||
logger.debug("Java %s testing for %s: renamed class %s -> %s", mode, func_name, original_class_name, new_class_name)
|
||||
|
||||
return True, modified_source
|
||||
|
||||
|
|
@ -325,8 +311,7 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
|
|||
# - new ClassName(args)
|
||||
# - this
|
||||
method_call_pattern = re.compile(
|
||||
rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)",
|
||||
re.MULTILINE
|
||||
rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)", re.MULTILINE
|
||||
)
|
||||
|
||||
for body_line in body_lines:
|
||||
|
|
@ -346,7 +331,7 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
|
|||
full_call = match.group(0) # e.g., "new StringUtils().reverse(\"hello\")"
|
||||
|
||||
# Replace this occurrence with the variable
|
||||
new_line = new_line[:match.start()] + var_name + new_line[match.end():]
|
||||
new_line = new_line[: match.start()] + var_name + new_line[match.end() :]
|
||||
|
||||
# Insert capture line
|
||||
capture_line = f"{line_indent_str}Object {var_name} = {full_call};"
|
||||
|
|
@ -573,10 +558,7 @@ def _add_timing_instrumentation(source: str, class_name: str, func_name: str) ->
|
|||
|
||||
|
||||
def create_benchmark_test(
|
||||
target_function: FunctionToOptimize,
|
||||
test_setup_code: str,
|
||||
invocation_code: str,
|
||||
iterations: int = 1000,
|
||||
target_function: FunctionToOptimize, test_setup_code: str, invocation_code: str, iterations: int = 1000
|
||||
) -> str:
|
||||
"""Create a benchmark test for a function.
|
||||
|
||||
|
|
@ -654,6 +636,11 @@ def instrument_generated_java_test(
|
|||
) -> str:
|
||||
"""Instrument a generated Java test for behavior or performance testing.
|
||||
|
||||
For generated tests (AI-generated), this function:
|
||||
1. Removes assertions and captures function return values (for regression testing)
|
||||
2. Renames the class to include mode suffix
|
||||
3. Adds timing instrumentation for performance mode
|
||||
|
||||
Args:
|
||||
test_code: The generated test source code.
|
||||
function_name: Name of the function being tested.
|
||||
|
|
@ -664,6 +651,13 @@ def instrument_generated_java_test(
|
|||
Instrumented test source code.
|
||||
|
||||
"""
|
||||
from codeflash.languages.java.remove_asserts import transform_java_assertions
|
||||
|
||||
# For behavior mode, remove assertions and capture function return values
|
||||
# This converts the generated test into a regression test that captures outputs
|
||||
if mode == "behavior":
|
||||
test_code = transform_java_assertions(test_code, function_name, qualified_name)
|
||||
|
||||
# Extract class name from the test code
|
||||
# Use pattern that starts at beginning of line to avoid matching words in comments
|
||||
class_match = re.search(r"^(?:public\s+)?class\s+(\w+)", test_code, re.MULTILINE)
|
||||
|
|
@ -681,9 +675,7 @@ def instrument_generated_java_test(
|
|||
|
||||
# Rename the class in the source
|
||||
modified_code = re.sub(
|
||||
rf"\b(public\s+)?class\s+{re.escape(original_class_name)}\b",
|
||||
rf"\1class {new_class_name}",
|
||||
test_code,
|
||||
rf"\b(public\s+)?class\s+{re.escape(original_class_name)}\b", rf"\1class {new_class_name}", test_code
|
||||
)
|
||||
|
||||
# For performance mode, add timing instrumentation
|
||||
|
|
|
|||
759
codeflash/languages/java/remove_asserts.py
Normal file
759
codeflash/languages/java/remove_asserts.py
Normal file
|
|
@ -0,0 +1,759 @@
|
|||
"""Java assertion removal transformer for converting tests to regression tests.
|
||||
|
||||
This module removes assertion statements from Java test code while preserving
|
||||
function calls, enabling behavioral verification by comparing outputs between
|
||||
original and optimized code.
|
||||
|
||||
Supported frameworks:
|
||||
- JUnit 5 (Jupiter): assertEquals, assertTrue, assertThrows, etc.
|
||||
- JUnit 4: org.junit.Assert.*
|
||||
- AssertJ: assertThat(...).isEqualTo(...)
|
||||
- TestNG: org.testng.Assert.*
|
||||
- Hamcrest: assertThat(actual, is(expected))
|
||||
- Truth: assertThat(actual).isEqualTo(expected)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from codeflash.languages.java.parser import get_java_analyzer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.java.parser import JavaAnalyzer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# JUnit 5 assertion methods that take (expected, actual, ...) or (actual, ...)
|
||||
JUNIT5_VALUE_ASSERTIONS = frozenset(
|
||||
{
|
||||
"assertEquals",
|
||||
"assertNotEquals",
|
||||
"assertSame",
|
||||
"assertNotSame",
|
||||
"assertArrayEquals",
|
||||
"assertIterableEquals",
|
||||
"assertLinesMatch",
|
||||
}
|
||||
)
|
||||
|
||||
# JUnit 5 assertions that take a single boolean/object argument
|
||||
JUNIT5_CONDITION_ASSERTIONS = frozenset({"assertTrue", "assertFalse", "assertNull", "assertNotNull"})
|
||||
|
||||
# JUnit 5 assertions that handle exceptions (need special treatment)
|
||||
JUNIT5_EXCEPTION_ASSERTIONS = frozenset({"assertThrows", "assertDoesNotThrow"})
|
||||
|
||||
# JUnit 5 timeout assertions
|
||||
JUNIT5_TIMEOUT_ASSERTIONS = frozenset({"assertTimeout", "assertTimeoutPreemptively"})
|
||||
|
||||
# JUnit 5 grouping assertion
|
||||
JUNIT5_GROUP_ASSERTIONS = frozenset({"assertAll"})
|
||||
|
||||
# All JUnit 5 assertions
|
||||
JUNIT5_ALL_ASSERTIONS = (
|
||||
JUNIT5_VALUE_ASSERTIONS
|
||||
| JUNIT5_CONDITION_ASSERTIONS
|
||||
| JUNIT5_EXCEPTION_ASSERTIONS
|
||||
| JUNIT5_TIMEOUT_ASSERTIONS
|
||||
| JUNIT5_GROUP_ASSERTIONS
|
||||
)
|
||||
|
||||
# AssertJ terminal assertions (methods that end the chain)
|
||||
ASSERTJ_TERMINAL_METHODS = frozenset(
|
||||
{
|
||||
"isEqualTo",
|
||||
"isNotEqualTo",
|
||||
"isSameAs",
|
||||
"isNotSameAs",
|
||||
"isNull",
|
||||
"isNotNull",
|
||||
"isTrue",
|
||||
"isFalse",
|
||||
"isEmpty",
|
||||
"isNotEmpty",
|
||||
"isBlank",
|
||||
"isNotBlank",
|
||||
"contains",
|
||||
"containsOnly",
|
||||
"containsExactly",
|
||||
"containsExactlyInAnyOrder",
|
||||
"doesNotContain",
|
||||
"startsWith",
|
||||
"endsWith",
|
||||
"matches",
|
||||
"hasSize",
|
||||
"hasSizeBetween",
|
||||
"hasSizeGreaterThan",
|
||||
"hasSizeLessThan",
|
||||
"isGreaterThan",
|
||||
"isGreaterThanOrEqualTo",
|
||||
"isLessThan",
|
||||
"isLessThanOrEqualTo",
|
||||
"isBetween",
|
||||
"isCloseTo",
|
||||
"isPositive",
|
||||
"isNegative",
|
||||
"isZero",
|
||||
"isNotZero",
|
||||
"isInstanceOf",
|
||||
"isNotInstanceOf",
|
||||
"isIn",
|
||||
"isNotIn",
|
||||
"containsKey",
|
||||
"containsKeys",
|
||||
"containsValue",
|
||||
"containsValues",
|
||||
"containsEntry",
|
||||
"hasFieldOrPropertyWithValue",
|
||||
"extracting",
|
||||
"satisfies",
|
||||
"doesNotThrow",
|
||||
}
|
||||
)
|
||||
|
||||
# Hamcrest matcher methods
|
||||
HAMCREST_MATCHERS = frozenset(
|
||||
{
|
||||
"is",
|
||||
"equalTo",
|
||||
"not",
|
||||
"nullValue",
|
||||
"notNullValue",
|
||||
"hasItem",
|
||||
"hasItems",
|
||||
"hasSize",
|
||||
"containsString",
|
||||
"startsWith",
|
||||
"endsWith",
|
||||
"greaterThan",
|
||||
"lessThan",
|
||||
"closeTo",
|
||||
"instanceOf",
|
||||
"anything",
|
||||
"allOf",
|
||||
"anyOf",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TargetCall:
|
||||
"""Represents a method call that should be captured."""
|
||||
|
||||
receiver: str | None # 'calc', 'algorithms' (None for static)
|
||||
method_name: str
|
||||
arguments: str
|
||||
full_call: str # 'calc.fibonacci(10)'
|
||||
start_pos: int
|
||||
end_pos: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class AssertionMatch:
|
||||
"""Represents a matched assertion statement."""
|
||||
|
||||
start_pos: int
|
||||
end_pos: int
|
||||
statement_type: str # 'junit5', 'assertj', 'junit4', 'testng', 'hamcrest'
|
||||
assertion_method: str
|
||||
target_calls: list[TargetCall] = field(default_factory=list)
|
||||
leading_whitespace: str = ""
|
||||
original_text: str = ""
|
||||
is_exception_assertion: bool = False
|
||||
lambda_body: str | None = None # For assertThrows lambda content
|
||||
|
||||
|
||||
class JavaAssertTransformer:
|
||||
"""Transforms Java test code by removing assertions and preserving function calls.
|
||||
|
||||
This class uses tree-sitter for AST-based analysis and regex for text manipulation.
|
||||
It handles various Java testing frameworks including JUnit 5, JUnit 4, AssertJ,
|
||||
TestNG, Hamcrest, and Truth.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, function_name: str, qualified_name: str | None = None, analyzer: JavaAnalyzer | None = None
|
||||
) -> None:
|
||||
self.analyzer = analyzer or get_java_analyzer()
|
||||
self.func_name = function_name
|
||||
self.qualified_name = qualified_name or function_name
|
||||
self.invocation_counter = 0
|
||||
self._detected_framework: str | None = None
|
||||
|
||||
def transform(self, source: str) -> str:
|
||||
"""Remove assertions from source code, preserving target function calls.
|
||||
|
||||
Args:
|
||||
source: Java source code containing test assertions.
|
||||
|
||||
Returns:
|
||||
Transformed source with assertions replaced by captured function calls.
|
||||
|
||||
"""
|
||||
if not source or not source.strip():
|
||||
return source
|
||||
|
||||
# Detect framework from imports
|
||||
self._detected_framework = self._detect_framework(source)
|
||||
|
||||
# Find all assertion statements
|
||||
assertions = self._find_assertions(source)
|
||||
|
||||
if not assertions:
|
||||
return source
|
||||
|
||||
# Filter to only assertions that contain target calls
|
||||
assertions_with_targets = [a for a in assertions if a.target_calls or a.is_exception_assertion]
|
||||
|
||||
if not assertions_with_targets:
|
||||
return source
|
||||
|
||||
# Sort by position (forward order) to assign counter numbers in source order
|
||||
assertions_with_targets.sort(key=lambda a: a.start_pos)
|
||||
|
||||
# Filter out nested assertions (e.g., assertEquals inside assertAll)
|
||||
# An assertion is nested if it's completely contained within another assertion
|
||||
non_nested: list[AssertionMatch] = []
|
||||
for i, assertion in enumerate(assertions_with_targets):
|
||||
is_nested = False
|
||||
for j, other in enumerate(assertions_with_targets):
|
||||
if i != j:
|
||||
# Check if 'assertion' is nested inside 'other'
|
||||
if other.start_pos <= assertion.start_pos and assertion.end_pos <= other.end_pos:
|
||||
is_nested = True
|
||||
break
|
||||
if not is_nested:
|
||||
non_nested.append(assertion)
|
||||
|
||||
assertions_with_targets = non_nested
|
||||
|
||||
# Pre-compute all replacements with correct counter values
|
||||
replacements: list[tuple[int, int, str]] = []
|
||||
for assertion in assertions_with_targets:
|
||||
replacement = self._generate_replacement(assertion)
|
||||
replacements.append((assertion.start_pos, assertion.end_pos, replacement))
|
||||
|
||||
# Apply replacements in reverse order to preserve positions
|
||||
result = source
|
||||
for start_pos, end_pos, replacement in reversed(replacements):
|
||||
result = result[:start_pos] + replacement + result[end_pos:]
|
||||
|
||||
return result
|
||||
|
||||
def _detect_framework(self, source: str) -> str:
|
||||
"""Detect which testing framework is being used from imports.
|
||||
|
||||
Checks more specific frameworks first (AssertJ, Hamcrest) before
|
||||
falling back to generic JUnit.
|
||||
"""
|
||||
imports = self.analyzer.find_imports(source)
|
||||
|
||||
# First pass: check for specific assertion libraries
|
||||
for imp in imports:
|
||||
path = imp.import_path.lower()
|
||||
if "org.assertj" in path:
|
||||
return "assertj"
|
||||
if "org.hamcrest" in path:
|
||||
return "hamcrest"
|
||||
if "com.google.common.truth" in path:
|
||||
return "truth"
|
||||
if "org.testng" in path:
|
||||
return "testng"
|
||||
|
||||
# Second pass: check for JUnit versions
|
||||
for imp in imports:
|
||||
path = imp.import_path.lower()
|
||||
if "org.junit.jupiter" in path or "junit.jupiter" in path:
|
||||
return "junit5"
|
||||
if "org.junit" in path:
|
||||
return "junit4"
|
||||
|
||||
# Default to JUnit 5 if no specific imports found
|
||||
return "junit5"
|
||||
|
||||
def _find_assertions(self, source: str) -> list[AssertionMatch]:
|
||||
"""Find all assertion statements in the source code."""
|
||||
assertions: list[AssertionMatch] = []
|
||||
|
||||
# Find JUnit-style assertions
|
||||
assertions.extend(self._find_junit_assertions(source))
|
||||
|
||||
# Find AssertJ/Truth-style fluent assertions
|
||||
assertions.extend(self._find_fluent_assertions(source))
|
||||
|
||||
# Find Hamcrest assertions
|
||||
assertions.extend(self._find_hamcrest_assertions(source))
|
||||
|
||||
return assertions
|
||||
|
||||
def _find_junit_assertions(self, source: str) -> list[AssertionMatch]:
|
||||
"""Find JUnit 4/5 and TestNG style assertions."""
|
||||
assertions: list[AssertionMatch] = []
|
||||
|
||||
# Pattern for JUnit assertions: (Assert.|Assertions.)?assertXxx(...)
|
||||
# This handles both static imports and qualified calls:
|
||||
# - assertEquals (static import)
|
||||
# - Assert.assertEquals (JUnit 4)
|
||||
# - Assertions.assertEquals (JUnit 5)
|
||||
all_assertions = "|".join(JUNIT5_ALL_ASSERTIONS)
|
||||
pattern = re.compile(rf"(\s*)((?:Assert(?:ions)?\.)?({all_assertions}))\s*\(", re.MULTILINE)
|
||||
|
||||
for match in pattern.finditer(source):
|
||||
leading_ws = match.group(1)
|
||||
full_method = match.group(2)
|
||||
assertion_method = match.group(3)
|
||||
|
||||
# Find the complete assertion statement (balanced parens)
|
||||
start_pos = match.start()
|
||||
paren_start = match.end() - 1 # Position of opening paren
|
||||
|
||||
args_content, end_pos = self._find_balanced_parens(source, paren_start)
|
||||
if args_content is None:
|
||||
continue
|
||||
|
||||
# Check for semicolon after closing paren
|
||||
while end_pos < len(source) and source[end_pos] in " \t\n\r":
|
||||
end_pos += 1
|
||||
if end_pos < len(source) and source[end_pos] == ";":
|
||||
end_pos += 1
|
||||
|
||||
# Extract target calls from the arguments
|
||||
target_calls = self._extract_target_calls(args_content, match.end())
|
||||
is_exception = assertion_method in JUNIT5_EXCEPTION_ASSERTIONS
|
||||
|
||||
# For assertThrows, extract the lambda body
|
||||
lambda_body = None
|
||||
if is_exception and assertion_method == "assertThrows":
|
||||
lambda_body = self._extract_lambda_body(args_content)
|
||||
|
||||
original_text = source[start_pos:end_pos]
|
||||
|
||||
# Determine statement type based on detected framework
|
||||
detected = self._detected_framework or "junit5"
|
||||
if "jupiter" in detected or detected == "junit5":
|
||||
stmt_type = "junit5"
|
||||
else:
|
||||
stmt_type = detected
|
||||
|
||||
assertions.append(
|
||||
AssertionMatch(
|
||||
start_pos=start_pos,
|
||||
end_pos=end_pos,
|
||||
statement_type=stmt_type,
|
||||
assertion_method=assertion_method,
|
||||
target_calls=target_calls,
|
||||
leading_whitespace=leading_ws,
|
||||
original_text=original_text,
|
||||
is_exception_assertion=is_exception,
|
||||
lambda_body=lambda_body,
|
||||
)
|
||||
)
|
||||
|
||||
return assertions
|
||||
|
||||
def _find_fluent_assertions(self, source: str) -> list[AssertionMatch]:
|
||||
"""Find AssertJ and Truth style fluent assertions (assertThat chains)."""
|
||||
assertions: list[AssertionMatch] = []
|
||||
|
||||
# Pattern for fluent assertions: assertThat(...).<chain>
|
||||
# Handles both org.assertj and com.google.common.truth
|
||||
pattern = re.compile(r"(\s*)((?:Assertions?\.)?assertThat)\s*\(", re.MULTILINE)
|
||||
|
||||
for match in pattern.finditer(source):
|
||||
leading_ws = match.group(1)
|
||||
start_pos = match.start()
|
||||
paren_start = match.end() - 1
|
||||
|
||||
# Find assertThat(...) content
|
||||
args_content, after_paren = self._find_balanced_parens(source, paren_start)
|
||||
if args_content is None:
|
||||
continue
|
||||
|
||||
# Find the assertion chain (e.g., .isEqualTo(5).hasSize(3))
|
||||
chain_end = self._find_fluent_chain_end(source, after_paren)
|
||||
if chain_end == after_paren:
|
||||
# No chain found, skip
|
||||
continue
|
||||
|
||||
# Check for semicolon
|
||||
end_pos = chain_end
|
||||
while end_pos < len(source) and source[end_pos] in " \t\n\r":
|
||||
end_pos += 1
|
||||
if end_pos < len(source) and source[end_pos] == ";":
|
||||
end_pos += 1
|
||||
|
||||
# Extract target calls from assertThat argument
|
||||
target_calls = self._extract_target_calls(args_content, match.end())
|
||||
original_text = source[start_pos:end_pos]
|
||||
|
||||
# Determine statement type based on detected framework
|
||||
detected = self._detected_framework or "assertj"
|
||||
stmt_type = "assertj" if "assertj" in detected else "truth"
|
||||
|
||||
assertions.append(
|
||||
AssertionMatch(
|
||||
start_pos=start_pos,
|
||||
end_pos=end_pos,
|
||||
statement_type=stmt_type,
|
||||
assertion_method="assertThat",
|
||||
target_calls=target_calls,
|
||||
leading_whitespace=leading_ws,
|
||||
original_text=original_text,
|
||||
)
|
||||
)
|
||||
|
||||
return assertions
|
||||
|
||||
def _find_hamcrest_assertions(self, source: str) -> list[AssertionMatch]:
|
||||
"""Find Hamcrest style assertions: assertThat(actual, matcher)."""
|
||||
assertions: list[AssertionMatch] = []
|
||||
|
||||
if self._detected_framework != "hamcrest":
|
||||
return assertions
|
||||
|
||||
# Pattern for Hamcrest: assertThat(actual, is(...)) or assertThat(reason, actual, matcher)
|
||||
pattern = re.compile(r"(\s*)((?:MatcherAssert\.)?assertThat)\s*\(", re.MULTILINE)
|
||||
|
||||
for match in pattern.finditer(source):
|
||||
leading_ws = match.group(1)
|
||||
start_pos = match.start()
|
||||
paren_start = match.end() - 1
|
||||
|
||||
args_content, end_pos = self._find_balanced_parens(source, paren_start)
|
||||
if args_content is None:
|
||||
continue
|
||||
|
||||
# Check for semicolon
|
||||
while end_pos < len(source) and source[end_pos] in " \t\n\r":
|
||||
end_pos += 1
|
||||
if end_pos < len(source) and source[end_pos] == ";":
|
||||
end_pos += 1
|
||||
|
||||
# For Hamcrest, the first arg (or second if reason given) is the actual value
|
||||
target_calls = self._extract_target_calls(args_content, match.end())
|
||||
original_text = source[start_pos:end_pos]
|
||||
|
||||
assertions.append(
|
||||
AssertionMatch(
|
||||
start_pos=start_pos,
|
||||
end_pos=end_pos,
|
||||
statement_type="hamcrest",
|
||||
assertion_method="assertThat",
|
||||
target_calls=target_calls,
|
||||
leading_whitespace=leading_ws,
|
||||
original_text=original_text,
|
||||
)
|
||||
)
|
||||
|
||||
return assertions
|
||||
|
||||
def _find_fluent_chain_end(self, source: str, start_pos: int) -> int:
|
||||
"""Find the end of a fluent assertion chain."""
|
||||
pos = start_pos
|
||||
|
||||
while pos < len(source):
|
||||
# Skip whitespace
|
||||
while pos < len(source) and source[pos] in " \t\n\r":
|
||||
pos += 1
|
||||
|
||||
if pos >= len(source) or source[pos] != ".":
|
||||
break
|
||||
|
||||
pos += 1 # Skip dot
|
||||
|
||||
# Skip whitespace after dot
|
||||
while pos < len(source) and source[pos] in " \t\n\r":
|
||||
pos += 1
|
||||
|
||||
# Read method name
|
||||
method_start = pos
|
||||
while pos < len(source) and (source[pos].isalnum() or source[pos] == "_"):
|
||||
pos += 1
|
||||
|
||||
if pos == method_start:
|
||||
break
|
||||
|
||||
method_name = source[method_start:pos]
|
||||
|
||||
# Skip whitespace before potential parens
|
||||
while pos < len(source) and source[pos] in " \t\n\r":
|
||||
pos += 1
|
||||
|
||||
# Check for parentheses
|
||||
if pos < len(source) and source[pos] == "(":
|
||||
_, new_pos = self._find_balanced_parens(source, pos)
|
||||
if new_pos == -1:
|
||||
break
|
||||
pos = new_pos
|
||||
|
||||
# Check if this is a terminal assertion method
|
||||
if method_name in ASSERTJ_TERMINAL_METHODS:
|
||||
# Continue looking for chained assertions
|
||||
continue
|
||||
|
||||
return pos
|
||||
|
||||
def _extract_target_calls(self, content: str, base_offset: int) -> list[TargetCall]:
|
||||
"""Extract calls to the target function from assertion arguments."""
|
||||
target_calls: list[TargetCall] = []
|
||||
|
||||
# Pattern to match method calls: (receiver.)?func_name(args)
|
||||
# Handles: obj.method(args), ClassName.staticMethod(args), method(args)
|
||||
pattern = re.compile(rf"((?:[a-zA-Z_]\w*\.)*)?({re.escape(self.func_name)})\s*\(", re.MULTILINE)
|
||||
|
||||
for match in pattern.finditer(content):
|
||||
receiver_prefix = match.group(1) or ""
|
||||
receiver = receiver_prefix.rstrip(".") if receiver_prefix else None
|
||||
method_name = match.group(2)
|
||||
|
||||
# Find the arguments
|
||||
paren_pos = match.end() - 1
|
||||
args_content, end_pos = self._find_balanced_parens(content, paren_pos)
|
||||
if args_content is None:
|
||||
continue
|
||||
|
||||
full_call = content[match.start() : end_pos]
|
||||
|
||||
target_calls.append(
|
||||
TargetCall(
|
||||
receiver=receiver,
|
||||
method_name=method_name,
|
||||
arguments=args_content,
|
||||
full_call=full_call,
|
||||
start_pos=base_offset + match.start(),
|
||||
end_pos=base_offset + end_pos,
|
||||
)
|
||||
)
|
||||
|
||||
return target_calls
|
||||
|
||||
def _extract_lambda_body(self, content: str) -> str | None:
|
||||
"""Extract the body of a lambda expression from assertThrows arguments.
|
||||
|
||||
For assertThrows(Exception.class, () -> code()), we want to extract 'code()'.
|
||||
For assertThrows(Exception.class, () -> { code(); }), we want 'code();'.
|
||||
"""
|
||||
# Look for lambda: () -> expr or () -> { block }
|
||||
lambda_match = re.search(r"\(\s*\)\s*->\s*", content)
|
||||
if not lambda_match:
|
||||
return None
|
||||
|
||||
body_start = lambda_match.end()
|
||||
remaining = content[body_start:].strip()
|
||||
|
||||
if remaining.startswith("{"):
|
||||
# Block lambda: () -> { code }
|
||||
_, block_end = self._find_balanced_braces(content, body_start + content[body_start:].index("{"))
|
||||
if block_end != -1:
|
||||
# Extract content inside braces
|
||||
brace_content = content[body_start + content[body_start:].index("{") + 1 : block_end - 1]
|
||||
return brace_content.strip()
|
||||
else:
|
||||
# Expression lambda: () -> expr
|
||||
# Find the end (before the closing paren of assertThrows)
|
||||
depth = 0
|
||||
end = body_start
|
||||
for i, ch in enumerate(content[body_start:]):
|
||||
if ch == "(":
|
||||
depth += 1
|
||||
elif ch == ")":
|
||||
if depth == 0:
|
||||
end = body_start + i
|
||||
break
|
||||
depth -= 1
|
||||
return content[body_start:end].strip()
|
||||
|
||||
return None
|
||||
|
||||
def _find_balanced_parens(self, code: str, open_paren_pos: int) -> tuple[str | None, int]:
|
||||
"""Find content within balanced parentheses.
|
||||
|
||||
Args:
|
||||
code: The source code.
|
||||
open_paren_pos: Position of the opening parenthesis.
|
||||
|
||||
Returns:
|
||||
Tuple of (content inside parens, position after closing paren) or (None, -1).
|
||||
|
||||
"""
|
||||
if open_paren_pos >= len(code) or code[open_paren_pos] != "(":
|
||||
return None, -1
|
||||
|
||||
depth = 1
|
||||
pos = open_paren_pos + 1
|
||||
in_string = False
|
||||
string_char = None
|
||||
in_char = False
|
||||
|
||||
while pos < len(code) and depth > 0:
|
||||
char = code[pos]
|
||||
prev_char = code[pos - 1] if pos > 0 else ""
|
||||
|
||||
# Handle character literals
|
||||
if char == "'" and not in_string and prev_char != "\\":
|
||||
in_char = not in_char
|
||||
# Handle string literals (double quotes)
|
||||
elif char == '"' and not in_char and prev_char != "\\":
|
||||
if not in_string:
|
||||
in_string = True
|
||||
string_char = char
|
||||
elif char == string_char:
|
||||
in_string = False
|
||||
string_char = None
|
||||
elif not in_string and not in_char:
|
||||
if char == "(":
|
||||
depth += 1
|
||||
elif char == ")":
|
||||
depth -= 1
|
||||
|
||||
pos += 1
|
||||
|
||||
if depth != 0:
|
||||
return None, -1
|
||||
|
||||
return code[open_paren_pos + 1 : pos - 1], pos
|
||||
|
||||
def _find_balanced_braces(self, code: str, open_brace_pos: int) -> tuple[str | None, int]:
|
||||
"""Find content within balanced braces."""
|
||||
if open_brace_pos >= len(code) or code[open_brace_pos] != "{":
|
||||
return None, -1
|
||||
|
||||
depth = 1
|
||||
pos = open_brace_pos + 1
|
||||
in_string = False
|
||||
string_char = None
|
||||
in_char = False
|
||||
|
||||
while pos < len(code) and depth > 0:
|
||||
char = code[pos]
|
||||
prev_char = code[pos - 1] if pos > 0 else ""
|
||||
|
||||
if char == "'" and not in_string and prev_char != "\\":
|
||||
in_char = not in_char
|
||||
elif char == '"' and not in_char and prev_char != "\\":
|
||||
if not in_string:
|
||||
in_string = True
|
||||
string_char = char
|
||||
elif char == string_char:
|
||||
in_string = False
|
||||
string_char = None
|
||||
elif not in_string and not in_char:
|
||||
if char == "{":
|
||||
depth += 1
|
||||
elif char == "}":
|
||||
depth -= 1
|
||||
|
||||
pos += 1
|
||||
|
||||
if depth != 0:
|
||||
return None, -1
|
||||
|
||||
return code[open_brace_pos + 1 : pos - 1], pos
|
||||
|
||||
def _generate_replacement(self, assertion: AssertionMatch) -> str:
|
||||
"""Generate replacement code for an assertion.
|
||||
|
||||
The replacement captures target function return values and removes assertions.
|
||||
|
||||
Args:
|
||||
assertion: The assertion to replace.
|
||||
|
||||
Returns:
|
||||
Replacement code string.
|
||||
|
||||
"""
|
||||
if assertion.is_exception_assertion:
|
||||
return self._generate_exception_replacement(assertion)
|
||||
|
||||
if not assertion.target_calls:
|
||||
# No target calls found, just comment out the assertion
|
||||
return f"{assertion.leading_whitespace}// Removed assertion: no target calls found"
|
||||
|
||||
# Generate capture statements for each target call
|
||||
replacements = []
|
||||
# For the first replacement, use the full leading whitespace
|
||||
# For subsequent ones, strip leading newlines to avoid extra blank lines
|
||||
base_indent = assertion.leading_whitespace.lstrip("\n\r")
|
||||
for i, call in enumerate(assertion.target_calls):
|
||||
self.invocation_counter += 1
|
||||
var_name = f"_cf_result{self.invocation_counter}"
|
||||
if i == 0:
|
||||
replacements.append(f"{assertion.leading_whitespace}Object {var_name} = {call.full_call};")
|
||||
else:
|
||||
replacements.append(f"{base_indent}Object {var_name} = {call.full_call};")
|
||||
|
||||
return "\n".join(replacements)
|
||||
|
||||
def _generate_exception_replacement(self, assertion: AssertionMatch) -> str:
|
||||
"""Generate replacement for assertThrows/assertDoesNotThrow.
|
||||
|
||||
Transforms:
|
||||
assertThrows(Exception.class, () -> calculator.divide(1, 0));
|
||||
To:
|
||||
try { calculator.divide(1, 0); } catch (Exception _cf_ignored1) {}
|
||||
|
||||
"""
|
||||
self.invocation_counter += 1
|
||||
|
||||
if assertion.lambda_body:
|
||||
# Extract the actual code from the lambda
|
||||
code_to_run = assertion.lambda_body
|
||||
if not code_to_run.endswith(";"):
|
||||
code_to_run += ";"
|
||||
return (
|
||||
f"{assertion.leading_whitespace}try {{ {code_to_run} }} "
|
||||
f"catch (Exception _cf_ignored{self.invocation_counter}) {{}}"
|
||||
)
|
||||
|
||||
# If no lambda body found, try to extract from target calls
|
||||
if assertion.target_calls:
|
||||
call = assertion.target_calls[0]
|
||||
return (
|
||||
f"{assertion.leading_whitespace}try {{ {call.full_call}; }} "
|
||||
f"catch (Exception _cf_ignored{self.invocation_counter}) {{}}"
|
||||
)
|
||||
|
||||
# Fallback: comment out the assertion
|
||||
return f"{assertion.leading_whitespace}// Removed assertThrows: could not extract callable"
|
||||
|
||||
|
||||
def transform_java_assertions(source: str, function_name: str, qualified_name: str | None = None) -> str:
|
||||
"""Transform Java test code by removing assertions and capturing function calls.
|
||||
|
||||
This is the main entry point for Java assertion transformation.
|
||||
|
||||
Args:
|
||||
source: The Java test source code.
|
||||
function_name: Name of the function being tested.
|
||||
qualified_name: Optional fully qualified name of the function.
|
||||
|
||||
Returns:
|
||||
Transformed source code with assertions replaced by capture statements.
|
||||
|
||||
"""
|
||||
transformer = JavaAssertTransformer(function_name=function_name, qualified_name=qualified_name)
|
||||
return transformer.transform(source)
|
||||
|
||||
|
||||
def remove_assertions_from_test(source: str, target_function: FunctionToOptimize) -> str:
|
||||
"""Remove assertions from test code for the given target function.
|
||||
|
||||
This is a convenience wrapper around transform_java_assertions that
|
||||
takes a FunctionToOptimize object.
|
||||
|
||||
Args:
|
||||
source: The Java test source code.
|
||||
target_function: The function being optimized.
|
||||
|
||||
Returns:
|
||||
Transformed source code.
|
||||
|
||||
"""
|
||||
return transform_java_assertions(
|
||||
source=source, function_name=target_function.function_name, qualified_name=target_function.qualified_name
|
||||
)
|
||||
964
tests/test_java_assertion_removal.py
Normal file
964
tests/test_java_assertion_removal.py
Normal file
|
|
@ -0,0 +1,964 @@
|
|||
"""Tests for Java assertion removal transformer.
|
||||
|
||||
This test suite covers the transformation of Java test assertions into
|
||||
regression test code that captures function return values.
|
||||
|
||||
All tests assert for full string equality, no substring matching.
|
||||
"""
|
||||
|
||||
from codeflash.languages.java.remove_asserts import (
|
||||
JavaAssertTransformer,
|
||||
transform_java_assertions,
|
||||
)
|
||||
|
||||
|
||||
class TestBasicJUnit5Assertions:
|
||||
"""Tests for basic JUnit 5 assertion transformations."""
|
||||
|
||||
def test_assert_equals_basic(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testFibonacci() {
|
||||
assertEquals(55, calculator.fibonacci(10));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testFibonacci() {
|
||||
Object _cf_result1 = calculator.fibonacci(10);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == expected
|
||||
|
||||
def test_assert_equals_with_message(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testFibonacci() {
|
||||
assertEquals(55, calculator.fibonacci(10), "Fibonacci of 10 should be 55");
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testFibonacci() {
|
||||
Object _cf_result1 = calculator.fibonacci(10);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == expected
|
||||
|
||||
def test_assert_true(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testIsValid() {
|
||||
assertTrue(validator.isValid("test"));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testIsValid() {
|
||||
Object _cf_result1 = validator.isValid("test");
|
||||
}"""
|
||||
result = transform_java_assertions(source, "isValid")
|
||||
assert result == expected
|
||||
|
||||
def test_assert_false(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testIsInvalid() {
|
||||
assertFalse(validator.isValid(""));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testIsInvalid() {
|
||||
Object _cf_result1 = validator.isValid("");
|
||||
}"""
|
||||
result = transform_java_assertions(source, "isValid")
|
||||
assert result == expected
|
||||
|
||||
def test_assert_null(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testGetNull() {
|
||||
assertNull(processor.getValue(null));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testGetNull() {
|
||||
Object _cf_result1 = processor.getValue(null);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "getValue")
|
||||
assert result == expected
|
||||
|
||||
def test_assert_not_null(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testGetValue() {
|
||||
assertNotNull(processor.getValue("key"));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testGetValue() {
|
||||
Object _cf_result1 = processor.getValue("key");
|
||||
}"""
|
||||
result = transform_java_assertions(source, "getValue")
|
||||
assert result == expected
|
||||
|
||||
def test_assert_not_equals(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testDifferent() {
|
||||
assertNotEquals(0, calculator.add(1, 2));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testDifferent() {
|
||||
Object _cf_result1 = calculator.add(1, 2);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "add")
|
||||
assert result == expected
|
||||
|
||||
def test_assert_same(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testSame() {
|
||||
assertSame(expected, factory.getInstance());
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testSame() {
|
||||
Object _cf_result1 = factory.getInstance();
|
||||
}"""
|
||||
result = transform_java_assertions(source, "getInstance")
|
||||
assert result == expected
|
||||
|
||||
def test_assert_array_equals(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testSort() {
|
||||
assertArrayEquals(expected, sorter.sort(input));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testSort() {
|
||||
Object _cf_result1 = sorter.sort(input);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "sort")
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestJUnit5PrefixedAssertions:
|
||||
"""Tests for JUnit 5 assertions with Assertions. prefix."""
|
||||
|
||||
def test_assertions_prefix(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testFibonacci() {
|
||||
Assertions.assertEquals(55, calculator.fibonacci(10));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testFibonacci() {
|
||||
Object _cf_result1 = calculator.fibonacci(10);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == expected
|
||||
|
||||
def test_assert_prefix(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testAdd() {
|
||||
Assert.assertEquals(5, calculator.add(2, 3));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testAdd() {
|
||||
Object _cf_result1 = calculator.add(2, 3);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "add")
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestJUnit5ExceptionAssertions:
|
||||
"""Tests for JUnit 5 exception assertions."""
|
||||
|
||||
def test_assert_throws_lambda(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testDivideByZero() {
|
||||
assertThrows(IllegalArgumentException.class, () -> calculator.divide(1, 0));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testDivideByZero() {
|
||||
try { calculator.divide(1, 0); } catch (Exception _cf_ignored1) {}
|
||||
}"""
|
||||
result = transform_java_assertions(source, "divide")
|
||||
assert result == expected
|
||||
|
||||
def test_assert_throws_block_lambda(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testDivideByZero() {
|
||||
assertThrows(ArithmeticException.class, () -> {
|
||||
calculator.divide(1, 0);
|
||||
});
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testDivideByZero() {
|
||||
try { calculator.divide(1, 0); } catch (Exception _cf_ignored1) {}
|
||||
}"""
|
||||
result = transform_java_assertions(source, "divide")
|
||||
assert result == expected
|
||||
|
||||
def test_assert_does_not_throw(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testValidDivision() {
|
||||
assertDoesNotThrow(() -> calculator.divide(10, 2));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testValidDivision() {
|
||||
try { calculator.divide(10, 2); } catch (Exception _cf_ignored1) {}
|
||||
}"""
|
||||
result = transform_java_assertions(source, "divide")
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestStaticMethodCalls:
|
||||
"""Tests for static method call handling."""
|
||||
|
||||
def test_static_method_call(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testQuickAdd() {
|
||||
assertEquals(15.0, Calculator.quickAdd(10.0, 5.0));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testQuickAdd() {
|
||||
Object _cf_result1 = Calculator.quickAdd(10.0, 5.0);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "quickAdd")
|
||||
assert result == expected
|
||||
|
||||
def test_static_method_fully_qualified(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testReverse() {
|
||||
assertEquals("olleh", com.example.StringUtils.reverse("hello"));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testReverse() {
|
||||
Object _cf_result1 = com.example.StringUtils.reverse("hello");
|
||||
}"""
|
||||
result = transform_java_assertions(source, "reverse")
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestMultipleAssertions:
|
||||
"""Tests for multiple assertions in a single test method."""
|
||||
|
||||
def test_multiple_assertions_same_function(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testFibonacciSequence() {
|
||||
assertEquals(0, calculator.fibonacci(0));
|
||||
assertEquals(1, calculator.fibonacci(1));
|
||||
assertEquals(55, calculator.fibonacci(10));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testFibonacciSequence() {
|
||||
Object _cf_result1 = calculator.fibonacci(0);
|
||||
Object _cf_result2 = calculator.fibonacci(1);
|
||||
Object _cf_result3 = calculator.fibonacci(10);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == expected
|
||||
|
||||
def test_multiple_assertions_different_functions(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testCalculator() {
|
||||
assertEquals(5, calculator.add(2, 3));
|
||||
assertEquals(6, calculator.multiply(2, 3));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testCalculator() {
|
||||
Object _cf_result1 = calculator.add(2, 3);
|
||||
assertEquals(6, calculator.multiply(2, 3));
|
||||
}"""
|
||||
result = transform_java_assertions(source, "add")
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestAssertJFluentAssertions:
|
||||
"""Tests for AssertJ fluent assertion transformations."""
|
||||
|
||||
def test_assertj_basic(self):
|
||||
source = """\
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@Test
|
||||
void testFibonacci() {
|
||||
assertThat(calculator.fibonacci(10)).isEqualTo(55);
|
||||
}"""
|
||||
expected = """\
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@Test
|
||||
void testFibonacci() {
|
||||
Object _cf_result1 = calculator.fibonacci(10);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == expected
|
||||
|
||||
def test_assertj_chained(self):
|
||||
source = """\
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@Test
|
||||
void testGetList() {
|
||||
assertThat(processor.getList()).hasSize(5).contains("a", "b");
|
||||
}"""
|
||||
expected = """\
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@Test
|
||||
void testGetList() {
|
||||
Object _cf_result1 = processor.getList();
|
||||
}"""
|
||||
result = transform_java_assertions(source, "getList")
|
||||
assert result == expected
|
||||
|
||||
def test_assertj_is_null(self):
|
||||
source = """\
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@Test
|
||||
void testGetNull() {
|
||||
assertThat(processor.getValue(null)).isNull();
|
||||
}"""
|
||||
expected = """\
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@Test
|
||||
void testGetNull() {
|
||||
Object _cf_result1 = processor.getValue(null);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "getValue")
|
||||
assert result == expected
|
||||
|
||||
def test_assertj_is_not_empty(self):
|
||||
source = """\
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@Test
|
||||
void testGetList() {
|
||||
assertThat(processor.getList()).isNotEmpty();
|
||||
}"""
|
||||
expected = """\
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
@Test
|
||||
void testGetList() {
|
||||
Object _cf_result1 = processor.getList();
|
||||
}"""
|
||||
result = transform_java_assertions(source, "getList")
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestNestedMethodCalls:
|
||||
"""Tests for nested method calls in assertions."""
|
||||
|
||||
def test_nested_call_in_expected(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testCompare() {
|
||||
assertEquals(helper.getExpected(), calculator.compute(5));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testCompare() {
|
||||
Object _cf_result1 = calculator.compute(5);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "compute")
|
||||
assert result == expected
|
||||
|
||||
def test_nested_call_as_argument(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testProcess() {
|
||||
assertEquals(expected, processor.process(helper.getData()));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testProcess() {
|
||||
Object _cf_result1 = processor.process(helper.getData());
|
||||
}"""
|
||||
result = transform_java_assertions(source, "process")
|
||||
assert result == expected
|
||||
|
||||
def test_deeply_nested(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testDeep() {
|
||||
assertEquals(expected, outer.process(inner.compute(calculator.fibonacci(5))));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testDeep() {
|
||||
Object _cf_result1 = calculator.fibonacci(5);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestWhitespacePreservation:
|
||||
"""Tests for whitespace and indentation preservation."""
|
||||
|
||||
def test_preserves_indentation(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testFibonacci() {
|
||||
assertEquals(55, calculator.fibonacci(10));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testFibonacci() {
|
||||
Object _cf_result1 = calculator.fibonacci(10);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == expected
|
||||
|
||||
def test_multiline_assertion(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testLongAssertion() {
|
||||
assertEquals(
|
||||
expectedValue,
|
||||
calculator.computeComplexResult(
|
||||
arg1,
|
||||
arg2,
|
||||
arg3
|
||||
)
|
||||
);
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testLongAssertion() {
|
||||
Object _cf_result1 = calculator.computeComplexResult(
|
||||
arg1,
|
||||
arg2,
|
||||
arg3
|
||||
);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "computeComplexResult")
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestStringsWithSpecialCharacters:
|
||||
"""Tests for strings containing special characters."""
|
||||
|
||||
def test_string_with_parentheses(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testFormat() {
|
||||
assertEquals("hello (world)", formatter.format("hello", "world"));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testFormat() {
|
||||
Object _cf_result1 = formatter.format("hello", "world");
|
||||
}"""
|
||||
result = transform_java_assertions(source, "format")
|
||||
assert result == expected
|
||||
|
||||
def test_string_with_quotes(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testEscape() {
|
||||
assertEquals("hello \\"world\\"", formatter.escape("hello \\"world\\""));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testEscape() {
|
||||
Object _cf_result1 = formatter.escape("hello \\"world\\"");
|
||||
}"""
|
||||
result = transform_java_assertions(source, "escape")
|
||||
assert result == expected
|
||||
|
||||
def test_string_with_newlines(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testMultiline() {
|
||||
assertEquals("line1\\nline2", processor.join("line1", "line2"));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testMultiline() {
|
||||
Object _cf_result1 = processor.join("line1", "line2");
|
||||
}"""
|
||||
result = transform_java_assertions(source, "join")
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestNonAssertionCodePreservation:
|
||||
"""Tests that non-assertion code is preserved unchanged."""
|
||||
|
||||
def test_setup_code_preserved(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testWithSetup() {
|
||||
Calculator calc = new Calculator(2);
|
||||
int input = 10;
|
||||
assertEquals(55, calc.fibonacci(input));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testWithSetup() {
|
||||
Calculator calc = new Calculator(2);
|
||||
int input = 10;
|
||||
Object _cf_result1 = calc.fibonacci(input);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == expected
|
||||
|
||||
def test_other_method_calls_preserved(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testWithHelper() {
|
||||
helper.setup();
|
||||
assertEquals(55, calculator.fibonacci(10));
|
||||
helper.cleanup();
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testWithHelper() {
|
||||
helper.setup();
|
||||
Object _cf_result1 = calculator.fibonacci(10);
|
||||
helper.cleanup();
|
||||
}"""
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == expected
|
||||
|
||||
def test_variable_declarations_preserved(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testWithVariables() {
|
||||
int expected = 55;
|
||||
int actual = calculator.fibonacci(10);
|
||||
assertEquals(expected, actual);
|
||||
}"""
|
||||
# fibonacci is assigned to 'actual', not in the assertion - no transformation
|
||||
expected = source
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestParameterizedTests:
|
||||
"""Tests for parameterized test handling."""
|
||||
|
||||
def test_parameterized_test(self):
|
||||
source = """\
|
||||
@ParameterizedTest
|
||||
@CsvSource({
|
||||
"0, 0",
|
||||
"1, 1",
|
||||
"10, 55"
|
||||
})
|
||||
void testFibonacciSequence(int n, long expected) {
|
||||
assertEquals(expected, calculator.fibonacci(n));
|
||||
}"""
|
||||
expected = """\
|
||||
@ParameterizedTest
|
||||
@CsvSource({
|
||||
"0, 0",
|
||||
"1, 1",
|
||||
"10, 55"
|
||||
})
|
||||
void testFibonacciSequence(int n, long expected) {
|
||||
Object _cf_result1 = calculator.fibonacci(n);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestNestedTestClasses:
|
||||
"""Tests for nested test class handling."""
|
||||
|
||||
def test_nested_class(self):
|
||||
source = """\
|
||||
@Nested
|
||||
@DisplayName("Fibonacci Tests")
|
||||
class FibonacciTests {
|
||||
@Test
|
||||
void testBasic() {
|
||||
assertEquals(55, calculator.fibonacci(10));
|
||||
}
|
||||
}"""
|
||||
expected = """\
|
||||
@Nested
|
||||
@DisplayName("Fibonacci Tests")
|
||||
class FibonacciTests {
|
||||
@Test
|
||||
void testBasic() {
|
||||
Object _cf_result1 = calculator.fibonacci(10);
|
||||
}
|
||||
}"""
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestMockitoPreservation:
|
||||
"""Tests that Mockito code is not modified."""
|
||||
|
||||
def test_mockito_when_preserved(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testWithMock() {
|
||||
when(mockService.getData()).thenReturn("test");
|
||||
assertEquals("test", processor.process(mockService));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testWithMock() {
|
||||
when(mockService.getData()).thenReturn("test");
|
||||
Object _cf_result1 = processor.process(mockService);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "process")
|
||||
assert result == expected
|
||||
|
||||
def test_mockito_verify_preserved(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testWithVerify() {
|
||||
processor.process(mockService);
|
||||
verify(mockService).getData();
|
||||
}"""
|
||||
# No assertions to transform, source unchanged
|
||||
expected = source
|
||||
result = transform_java_assertions(source, "process")
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Tests for edge cases and boundary conditions."""
|
||||
|
||||
def test_empty_source(self):
|
||||
result = transform_java_assertions("", "fibonacci")
|
||||
assert result == ""
|
||||
|
||||
def test_whitespace_only(self):
|
||||
source = " \n\t "
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == source
|
||||
|
||||
def test_no_assertions(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testNoAssertions() {
|
||||
calculator.fibonacci(10);
|
||||
}"""
|
||||
expected = source
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == expected
|
||||
|
||||
def test_assertion_without_target_function(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testOther() {
|
||||
assertEquals(5, helper.compute(3));
|
||||
}"""
|
||||
# No transformation since target function is not in the assertion
|
||||
expected = source
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == expected
|
||||
|
||||
def test_function_name_in_string(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testWithStringContainingFunctionName() {
|
||||
assertEquals("fibonacci(10) = 55", formatter.format("fibonacci", 10, 55));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testWithStringContainingFunctionName() {
|
||||
Object _cf_result1 = formatter.format("fibonacci", 10, 55);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "format")
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestJUnit4Compatibility:
|
||||
"""Tests for JUnit 4 style assertions."""
|
||||
|
||||
def test_junit4_assert_equals(self):
|
||||
source = """\
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
@Test
|
||||
public void testFibonacci() {
|
||||
assertEquals(55, calculator.fibonacci(10));
|
||||
}"""
|
||||
expected = """\
|
||||
import static org.junit.Assert.*;
|
||||
|
||||
@Test
|
||||
public void testFibonacci() {
|
||||
Object _cf_result1 = calculator.fibonacci(10);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == expected
|
||||
|
||||
def test_junit4_with_message_first(self):
|
||||
source = """\
|
||||
@Test
|
||||
public void testFibonacci() {
|
||||
assertEquals("Should be 55", 55, calculator.fibonacci(10));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
public void testFibonacci() {
|
||||
Object _cf_result1 = calculator.fibonacci(10);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestAssertAll:
|
||||
"""Tests for assertAll grouped assertions."""
|
||||
|
||||
def test_assert_all_basic(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testMultiple() {
|
||||
assertAll(
|
||||
() -> assertEquals(0, calculator.fibonacci(0)),
|
||||
() -> assertEquals(1, calculator.fibonacci(1)),
|
||||
() -> assertEquals(55, calculator.fibonacci(10))
|
||||
);
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testMultiple() {
|
||||
Object _cf_result1 = calculator.fibonacci(0);
|
||||
Object _cf_result2 = calculator.fibonacci(1);
|
||||
Object _cf_result3 = calculator.fibonacci(10);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestTransformerClass:
|
||||
"""Tests for the JavaAssertTransformer class directly."""
|
||||
|
||||
def test_invocation_counter_increments(self):
|
||||
transformer = JavaAssertTransformer("fibonacci")
|
||||
source = """\
|
||||
@Test
|
||||
void test() {
|
||||
assertEquals(0, calc.fibonacci(0));
|
||||
assertEquals(1, calc.fibonacci(1));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void test() {
|
||||
Object _cf_result1 = calc.fibonacci(0);
|
||||
Object _cf_result2 = calc.fibonacci(1);
|
||||
}"""
|
||||
result = transformer.transform(source)
|
||||
assert result == expected
|
||||
assert transformer.invocation_counter == 2
|
||||
|
||||
def test_qualified_name_support(self):
|
||||
transformer = JavaAssertTransformer(
|
||||
function_name="fibonacci",
|
||||
qualified_name="com.example.Calculator.fibonacci",
|
||||
)
|
||||
assert transformer.qualified_name == "com.example.Calculator.fibonacci"
|
||||
|
||||
def test_custom_analyzer(self):
|
||||
from codeflash.languages.java.parser import get_java_analyzer
|
||||
|
||||
analyzer = get_java_analyzer()
|
||||
transformer = JavaAssertTransformer("fibonacci", analyzer=analyzer)
|
||||
assert transformer.analyzer is analyzer
|
||||
|
||||
|
||||
class TestImportDetection:
|
||||
"""Tests for framework detection from imports."""
|
||||
|
||||
def test_detect_junit5(self):
|
||||
source = """\
|
||||
import org.junit.jupiter.api.Test;
|
||||
import static org.junit.jupiter.api.Assertions.*;"""
|
||||
transformer = JavaAssertTransformer("test")
|
||||
transformer._detected_framework = transformer._detect_framework(source)
|
||||
assert transformer._detected_framework == "junit5"
|
||||
|
||||
def test_detect_assertj(self):
|
||||
source = """\
|
||||
import org.junit.jupiter.api.Test;
|
||||
import static org.assertj.core.api.Assertions.assertThat;"""
|
||||
transformer = JavaAssertTransformer("test")
|
||||
transformer._detected_framework = transformer._detect_framework(source)
|
||||
assert transformer._detected_framework == "assertj"
|
||||
|
||||
def test_detect_testng(self):
|
||||
source = """\
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;"""
|
||||
transformer = JavaAssertTransformer("test")
|
||||
transformer._detected_framework = transformer._detect_framework(source)
|
||||
assert transformer._detected_framework == "testng"
|
||||
|
||||
def test_detect_hamcrest(self):
|
||||
source = """\
|
||||
import org.junit.Test;
|
||||
import static org.hamcrest.MatcherAssert.assertThat;
|
||||
import static org.hamcrest.Matchers.*;"""
|
||||
transformer = JavaAssertTransformer("test")
|
||||
transformer._detected_framework = transformer._detect_framework(source)
|
||||
assert transformer._detected_framework == "hamcrest"
|
||||
|
||||
|
||||
class TestInstrumentGeneratedJavaTest:
|
||||
"""Tests for the instrument_generated_java_test integration."""
|
||||
|
||||
def test_behavior_mode_removes_assertions(self):
|
||||
from codeflash.languages.java.instrumentation import instrument_generated_java_test
|
||||
|
||||
test_code = """\
|
||||
package com.example;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class FibonacciTest {
|
||||
@Test
|
||||
void testFibonacci() {
|
||||
Calculator calc = new Calculator();
|
||||
assertEquals(55, calc.fibonacci(10));
|
||||
}
|
||||
}"""
|
||||
expected = """\
|
||||
package com.example;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class FibonacciTest__perfinstrumented {
|
||||
@Test
|
||||
void testFibonacci() {
|
||||
Calculator calc = new Calculator();
|
||||
Object _cf_result1 = calc.fibonacci(10);
|
||||
}
|
||||
}"""
|
||||
result = instrument_generated_java_test(
|
||||
test_code=test_code,
|
||||
function_name="fibonacci",
|
||||
qualified_name="com.example.Calculator.fibonacci",
|
||||
mode="behavior",
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
def test_behavior_mode_with_assertj(self):
|
||||
from codeflash.languages.java.instrumentation import instrument_generated_java_test
|
||||
|
||||
test_code = """\
|
||||
package com.example;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
public class StringUtilsTest {
|
||||
@Test
|
||||
void testReverse() {
|
||||
assertThat(StringUtils.reverse("hello")).isEqualTo("olleh");
|
||||
}
|
||||
}"""
|
||||
expected = """\
|
||||
package com.example;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
public class StringUtilsTest__perfinstrumented {
|
||||
@Test
|
||||
void testReverse() {
|
||||
Object _cf_result1 = StringUtils.reverse("hello");
|
||||
}
|
||||
}"""
|
||||
result = instrument_generated_java_test(
|
||||
test_code=test_code,
|
||||
function_name="reverse",
|
||||
qualified_name="com.example.StringUtils.reverse",
|
||||
mode="behavior",
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestComplexRealWorldExamples:
|
||||
"""Tests based on real-world test patterns."""
|
||||
|
||||
def test_calculator_test_pattern(self):
|
||||
source = """\
|
||||
@Test
|
||||
@DisplayName("should calculate compound interest for basic case")
|
||||
void testBasicCompoundInterest() {
|
||||
String result = calculator.calculateCompoundInterest(1000.0, 0.05, 1, 12);
|
||||
assertNotNull(result);
|
||||
assertTrue(result.contains("."));
|
||||
}"""
|
||||
# assertNotNull(result) and assertTrue(result.contains(".")) don't contain the target function
|
||||
# so they remain unchanged, and the variable assignment is also preserved
|
||||
expected = source
|
||||
result = transform_java_assertions(source, "calculateCompoundInterest")
|
||||
assert result == expected
|
||||
|
||||
def test_string_utils_pattern(self):
|
||||
source = """\
|
||||
@Test
|
||||
@DisplayName("should reverse a simple string")
|
||||
void testReverseSimple() {
|
||||
assertEquals("olleh", StringUtils.reverse("hello"));
|
||||
assertEquals("dlrow", StringUtils.reverse("world"));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
@DisplayName("should reverse a simple string")
|
||||
void testReverseSimple() {
|
||||
Object _cf_result1 = StringUtils.reverse("hello");
|
||||
Object _cf_result2 = StringUtils.reverse("world");
|
||||
}"""
|
||||
result = transform_java_assertions(source, "reverse")
|
||||
assert result == expected
|
||||
|
||||
def test_with_before_each_setup(self):
|
||||
source = """\
|
||||
private Calculator calculator;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
calculator = new Calculator(2);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFibonacci() {
|
||||
assertEquals(55, calculator.fibonacci(10));
|
||||
}"""
|
||||
expected = """\
|
||||
private Calculator calculator;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
calculator = new Calculator(2);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testFibonacci() {
|
||||
Object _cf_result1 = calculator.fibonacci(10);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == expected
|
||||
Loading…
Reference in a new issue