feat: implement Java assertion removal transformer

Add a robust Java assert removal transformer to convert generated unit
tests into regression tests. This removes assertion statements while
preserving function calls, enabling behavioral verification by comparing
outputs between original and optimized code.

Key features:
- Support for JUnit 5 assertions (assertEquals, assertTrue, assertThrows, etc.)
- Support for JUnit 4 assertions (org.junit.Assert.*)
- Support for AssertJ fluent assertions (assertThat().isEqualTo())
- Support for TestNG and Hamcrest assertions
- Framework auto-detection from imports
- Handles assertAll grouped assertions
- Preserves non-assertion code (setup, Mockito mocks, etc.)
- 57 comprehensive tests with exact string equality assertions

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Saurabh Misra 2026-02-03 09:02:25 +00:00
parent c587c47521
commit 31c90f0391
4 changed files with 1815 additions and 100 deletions

View file

@ -21,10 +21,7 @@ from codeflash.languages.java.build_tools import (
install_codeflash_runtime,
run_maven_tests,
)
from codeflash.languages.java.comparator import (
compare_invocations_directly,
compare_test_results,
)
from codeflash.languages.java.comparator import compare_invocations_directly, compare_test_results
from codeflash.languages.java.config import (
JavaProjectConfig,
detect_java_project,
@ -46,12 +43,7 @@ from codeflash.languages.java.discovery import (
get_class_methods,
get_method_by_name,
)
from codeflash.languages.java.formatter import (
JavaFormatter,
format_java_code,
format_java_file,
normalize_java_code,
)
from codeflash.languages.java.formatter import JavaFormatter, format_java_code, format_java_file, normalize_java_code
from codeflash.languages.java.import_resolver import (
JavaImportResolver,
ResolvedImport,
@ -63,6 +55,7 @@ from codeflash.languages.java.instrumentation import (
instrument_existing_test,
instrument_for_behavior,
instrument_for_benchmarking,
instrument_generated_java_test,
remove_instrumentation,
)
from codeflash.languages.java.parser import (
@ -73,6 +66,11 @@ from codeflash.languages.java.parser import (
JavaMethodNode,
get_java_analyzer,
)
from codeflash.languages.java.remove_asserts import (
JavaAssertTransformer,
remove_assertions_from_test,
transform_java_assertions,
)
from codeflash.languages.java.replacement import (
add_runtime_comments,
insert_method,
@ -81,10 +79,7 @@ from codeflash.languages.java.replacement import (
replace_function,
replace_method_body,
)
from codeflash.languages.java.support import (
JavaSupport,
get_java_support,
)
from codeflash.languages.java.support import JavaSupport, get_java_support
from codeflash.languages.java.test_discovery import (
build_test_mapping_for_project,
discover_all_tests,
@ -106,90 +101,95 @@ from codeflash.languages.java.test_runner import (
)
__all__ = [
# Parser
"JavaAnalyzer",
"JavaClassNode",
"JavaFieldInfo",
"JavaImportInfo",
"JavaMethodNode",
"get_java_analyzer",
# Build tools
"BuildTool",
# Parser
"JavaAnalyzer",
# Assertion removal
"JavaAssertTransformer",
"JavaClassNode",
"JavaFieldInfo",
# Formatter
"JavaFormatter",
"JavaImportInfo",
# Import resolver
"JavaImportResolver",
"JavaMethodNode",
# Config
"JavaProjectConfig",
"JavaProjectInfo",
# Support
"JavaSupport",
# Test runner
"JavaTestRunResult",
"MavenTestResult",
"ResolvedImport",
"add_codeflash_dependency_to_pom",
"compile_maven_project",
"detect_build_tool",
"find_gradle_executable",
"find_maven_executable",
"find_source_root",
"find_test_root",
"get_classpath",
"get_project_info",
"install_codeflash_runtime",
"run_maven_tests",
# Replacement
"add_runtime_comments",
# Test discovery
"build_test_mapping_for_project",
# Comparator
"compare_invocations_directly",
"compare_test_results",
# Config
"JavaProjectConfig",
"compile_maven_project",
# Instrumentation
"create_benchmark_test",
"detect_build_tool",
"detect_java_project",
"get_test_class_pattern",
"get_test_file_pattern",
"is_java_project",
"discover_all_tests",
# Discovery
"discover_functions",
"discover_functions_from_source",
"discover_test_methods",
"discover_tests",
# Context
"extract_class_context",
"extract_code_context",
"extract_function_source",
"extract_read_only_context",
"find_gradle_executable",
"find_helper_files",
"find_helper_functions",
# Discovery
"discover_functions",
"discover_functions_from_source",
"discover_test_methods",
"get_class_methods",
"get_method_by_name",
# Formatter
"JavaFormatter",
"find_maven_executable",
"find_source_root",
"find_test_root",
"find_tests_for_function",
"format_java_code",
"format_java_file",
"normalize_java_code",
# Import resolver
"JavaImportResolver",
"ResolvedImport",
"find_helper_files",
"resolve_imports_for_file",
# Instrumentation
"create_benchmark_test",
"get_class_methods",
"get_classpath",
"get_java_analyzer",
"get_java_support",
"get_method_by_name",
"get_project_info",
"get_test_class_for_source_class",
"get_test_class_pattern",
"get_test_file_pattern",
"get_test_file_suffix",
"get_test_methods_for_class",
"get_test_run_command",
"insert_method",
"install_codeflash_runtime",
"instrument_existing_test",
"instrument_for_behavior",
"instrument_for_benchmarking",
"instrument_generated_java_test",
"is_java_project",
"is_test_file",
"normalize_java_code",
"parse_surefire_results",
"parse_test_results",
"remove_assertions_from_test",
"remove_instrumentation",
# Replacement
"add_runtime_comments",
"insert_method",
"remove_method",
"remove_test_functions",
"replace_function",
"replace_method_body",
# Support
"JavaSupport",
"get_java_support",
# Test discovery
"build_test_mapping_for_project",
"discover_all_tests",
"discover_tests",
"find_tests_for_function",
"get_test_class_for_source_class",
"get_test_file_suffix",
"get_test_methods_for_class",
"is_test_file",
# Test runner
"JavaTestRunResult",
"get_test_run_command",
"parse_surefire_results",
"parse_test_results",
"resolve_imports_for_file",
"run_behavioral_tests",
"run_benchmarking_tests",
"run_maven_tests",
"run_tests",
"transform_java_assertions",
]

View file

@ -55,9 +55,7 @@ def _get_qualified_name(func: Any) -> str:
def instrument_for_behavior(
source: str,
functions: Sequence[FunctionToOptimize],
analyzer: JavaAnalyzer | None = None,
source: str, functions: Sequence[FunctionToOptimize], analyzer: JavaAnalyzer | None = None
) -> str:
"""Add behavior instrumentation to capture inputs/outputs.
@ -83,9 +81,7 @@ def instrument_for_behavior(
def instrument_for_benchmarking(
test_source: str,
target_function: FunctionToOptimize,
analyzer: JavaAnalyzer | None = None,
test_source: str, target_function: FunctionToOptimize, analyzer: JavaAnalyzer | None = None
) -> str:
"""Add timing instrumentation to test code.
@ -168,19 +164,9 @@ def instrument_existing_test(
)
else:
# Behavior mode: add timing instrumentation that also writes to SQLite
modified_source = _add_behavior_instrumentation(
modified_source,
original_class_name,
func_name,
)
modified_source = _add_behavior_instrumentation(modified_source, original_class_name, func_name)
logger.debug(
"Java %s testing for %s: renamed class %s -> %s",
mode,
func_name,
original_class_name,
new_class_name,
)
logger.debug("Java %s testing for %s: renamed class %s -> %s", mode, func_name, original_class_name, new_class_name)
return True, modified_source
@ -325,8 +311,7 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
# - new ClassName(args)
# - this
method_call_pattern = re.compile(
rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)",
re.MULTILINE
rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*({re.escape(func_name)})\s*\(([^)]*)\)", re.MULTILINE
)
for body_line in body_lines:
@ -346,7 +331,7 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
full_call = match.group(0) # e.g., "new StringUtils().reverse(\"hello\")"
# Replace this occurrence with the variable
new_line = new_line[:match.start()] + var_name + new_line[match.end():]
new_line = new_line[: match.start()] + var_name + new_line[match.end() :]
# Insert capture line
capture_line = f"{line_indent_str}Object {var_name} = {full_call};"
@ -573,10 +558,7 @@ def _add_timing_instrumentation(source: str, class_name: str, func_name: str) ->
def create_benchmark_test(
target_function: FunctionToOptimize,
test_setup_code: str,
invocation_code: str,
iterations: int = 1000,
target_function: FunctionToOptimize, test_setup_code: str, invocation_code: str, iterations: int = 1000
) -> str:
"""Create a benchmark test for a function.
@ -654,6 +636,11 @@ def instrument_generated_java_test(
) -> str:
"""Instrument a generated Java test for behavior or performance testing.
For generated tests (AI-generated), this function:
1. Removes assertions and captures function return values (for regression testing)
2. Renames the class to include mode suffix
3. Adds timing instrumentation for performance mode
Args:
test_code: The generated test source code.
function_name: Name of the function being tested.
@ -664,6 +651,13 @@ def instrument_generated_java_test(
Instrumented test source code.
"""
from codeflash.languages.java.remove_asserts import transform_java_assertions
# For behavior mode, remove assertions and capture function return values
# This converts the generated test into a regression test that captures outputs
if mode == "behavior":
test_code = transform_java_assertions(test_code, function_name, qualified_name)
# Extract class name from the test code
# Use pattern that starts at beginning of line to avoid matching words in comments
class_match = re.search(r"^(?:public\s+)?class\s+(\w+)", test_code, re.MULTILINE)
@ -681,9 +675,7 @@ def instrument_generated_java_test(
# Rename the class in the source
modified_code = re.sub(
rf"\b(public\s+)?class\s+{re.escape(original_class_name)}\b",
rf"\1class {new_class_name}",
test_code,
rf"\b(public\s+)?class\s+{re.escape(original_class_name)}\b", rf"\1class {new_class_name}", test_code
)
# For performance mode, add timing instrumentation

View file

@ -0,0 +1,759 @@
"""Java assertion removal transformer for converting tests to regression tests.
This module removes assertion statements from Java test code while preserving
function calls, enabling behavioral verification by comparing outputs between
original and optimized code.
Supported frameworks:
- JUnit 5 (Jupiter): assertEquals, assertTrue, assertThrows, etc.
- JUnit 4: org.junit.Assert.*
- AssertJ: assertThat(...).isEqualTo(...)
- TestNG: org.testng.Assert.*
- Hamcrest: assertThat(actual, is(expected))
- Truth: assertThat(actual).isEqualTo(expected)
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from codeflash.languages.java.parser import get_java_analyzer
if TYPE_CHECKING:
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.java.parser import JavaAnalyzer
logger = logging.getLogger(__name__)
# JUnit 5 assertion methods that take (expected, actual, ...) or (actual, ...)
JUNIT5_VALUE_ASSERTIONS = frozenset(
{
"assertEquals",
"assertNotEquals",
"assertSame",
"assertNotSame",
"assertArrayEquals",
"assertIterableEquals",
"assertLinesMatch",
}
)
# JUnit 5 assertions that take a single boolean/object argument
JUNIT5_CONDITION_ASSERTIONS = frozenset({"assertTrue", "assertFalse", "assertNull", "assertNotNull"})
# JUnit 5 assertions that handle exceptions (need special treatment)
JUNIT5_EXCEPTION_ASSERTIONS = frozenset({"assertThrows", "assertDoesNotThrow"})
# JUnit 5 timeout assertions
JUNIT5_TIMEOUT_ASSERTIONS = frozenset({"assertTimeout", "assertTimeoutPreemptively"})
# JUnit 5 grouping assertion
JUNIT5_GROUP_ASSERTIONS = frozenset({"assertAll"})
# All JUnit 5 assertions
JUNIT5_ALL_ASSERTIONS = (
JUNIT5_VALUE_ASSERTIONS
| JUNIT5_CONDITION_ASSERTIONS
| JUNIT5_EXCEPTION_ASSERTIONS
| JUNIT5_TIMEOUT_ASSERTIONS
| JUNIT5_GROUP_ASSERTIONS
)
# AssertJ terminal assertions (methods that end the chain)
ASSERTJ_TERMINAL_METHODS = frozenset(
{
"isEqualTo",
"isNotEqualTo",
"isSameAs",
"isNotSameAs",
"isNull",
"isNotNull",
"isTrue",
"isFalse",
"isEmpty",
"isNotEmpty",
"isBlank",
"isNotBlank",
"contains",
"containsOnly",
"containsExactly",
"containsExactlyInAnyOrder",
"doesNotContain",
"startsWith",
"endsWith",
"matches",
"hasSize",
"hasSizeBetween",
"hasSizeGreaterThan",
"hasSizeLessThan",
"isGreaterThan",
"isGreaterThanOrEqualTo",
"isLessThan",
"isLessThanOrEqualTo",
"isBetween",
"isCloseTo",
"isPositive",
"isNegative",
"isZero",
"isNotZero",
"isInstanceOf",
"isNotInstanceOf",
"isIn",
"isNotIn",
"containsKey",
"containsKeys",
"containsValue",
"containsValues",
"containsEntry",
"hasFieldOrPropertyWithValue",
"extracting",
"satisfies",
"doesNotThrow",
}
)
# Hamcrest matcher methods
HAMCREST_MATCHERS = frozenset(
{
"is",
"equalTo",
"not",
"nullValue",
"notNullValue",
"hasItem",
"hasItems",
"hasSize",
"containsString",
"startsWith",
"endsWith",
"greaterThan",
"lessThan",
"closeTo",
"instanceOf",
"anything",
"allOf",
"anyOf",
}
)
@dataclass
class TargetCall:
"""Represents a method call that should be captured."""
receiver: str | None # 'calc', 'algorithms' (None for static)
method_name: str
arguments: str
full_call: str # 'calc.fibonacci(10)'
start_pos: int
end_pos: int
@dataclass
class AssertionMatch:
"""Represents a matched assertion statement."""
start_pos: int
end_pos: int
statement_type: str # 'junit5', 'assertj', 'junit4', 'testng', 'hamcrest'
assertion_method: str
target_calls: list[TargetCall] = field(default_factory=list)
leading_whitespace: str = ""
original_text: str = ""
is_exception_assertion: bool = False
lambda_body: str | None = None # For assertThrows lambda content
class JavaAssertTransformer:
"""Transforms Java test code by removing assertions and preserving function calls.
This class uses tree-sitter for AST-based analysis and regex for text manipulation.
It handles various Java testing frameworks including JUnit 5, JUnit 4, AssertJ,
TestNG, Hamcrest, and Truth.
"""
def __init__(
self, function_name: str, qualified_name: str | None = None, analyzer: JavaAnalyzer | None = None
) -> None:
self.analyzer = analyzer or get_java_analyzer()
self.func_name = function_name
self.qualified_name = qualified_name or function_name
self.invocation_counter = 0
self._detected_framework: str | None = None
def transform(self, source: str) -> str:
"""Remove assertions from source code, preserving target function calls.
Args:
source: Java source code containing test assertions.
Returns:
Transformed source with assertions replaced by captured function calls.
"""
if not source or not source.strip():
return source
# Detect framework from imports
self._detected_framework = self._detect_framework(source)
# Find all assertion statements
assertions = self._find_assertions(source)
if not assertions:
return source
# Filter to only assertions that contain target calls
assertions_with_targets = [a for a in assertions if a.target_calls or a.is_exception_assertion]
if not assertions_with_targets:
return source
# Sort by position (forward order) to assign counter numbers in source order
assertions_with_targets.sort(key=lambda a: a.start_pos)
# Filter out nested assertions (e.g., assertEquals inside assertAll)
# An assertion is nested if it's completely contained within another assertion
non_nested: list[AssertionMatch] = []
for i, assertion in enumerate(assertions_with_targets):
is_nested = False
for j, other in enumerate(assertions_with_targets):
if i != j:
# Check if 'assertion' is nested inside 'other'
if other.start_pos <= assertion.start_pos and assertion.end_pos <= other.end_pos:
is_nested = True
break
if not is_nested:
non_nested.append(assertion)
assertions_with_targets = non_nested
# Pre-compute all replacements with correct counter values
replacements: list[tuple[int, int, str]] = []
for assertion in assertions_with_targets:
replacement = self._generate_replacement(assertion)
replacements.append((assertion.start_pos, assertion.end_pos, replacement))
# Apply replacements in reverse order to preserve positions
result = source
for start_pos, end_pos, replacement in reversed(replacements):
result = result[:start_pos] + replacement + result[end_pos:]
return result
def _detect_framework(self, source: str) -> str:
"""Detect which testing framework is being used from imports.
Checks more specific frameworks first (AssertJ, Hamcrest) before
falling back to generic JUnit.
"""
imports = self.analyzer.find_imports(source)
# First pass: check for specific assertion libraries
for imp in imports:
path = imp.import_path.lower()
if "org.assertj" in path:
return "assertj"
if "org.hamcrest" in path:
return "hamcrest"
if "com.google.common.truth" in path:
return "truth"
if "org.testng" in path:
return "testng"
# Second pass: check for JUnit versions
for imp in imports:
path = imp.import_path.lower()
if "org.junit.jupiter" in path or "junit.jupiter" in path:
return "junit5"
if "org.junit" in path:
return "junit4"
# Default to JUnit 5 if no specific imports found
return "junit5"
def _find_assertions(self, source: str) -> list[AssertionMatch]:
"""Find all assertion statements in the source code."""
assertions: list[AssertionMatch] = []
# Find JUnit-style assertions
assertions.extend(self._find_junit_assertions(source))
# Find AssertJ/Truth-style fluent assertions
assertions.extend(self._find_fluent_assertions(source))
# Find Hamcrest assertions
assertions.extend(self._find_hamcrest_assertions(source))
return assertions
def _find_junit_assertions(self, source: str) -> list[AssertionMatch]:
"""Find JUnit 4/5 and TestNG style assertions."""
assertions: list[AssertionMatch] = []
# Pattern for JUnit assertions: (Assert.|Assertions.)?assertXxx(...)
# This handles both static imports and qualified calls:
# - assertEquals (static import)
# - Assert.assertEquals (JUnit 4)
# - Assertions.assertEquals (JUnit 5)
all_assertions = "|".join(JUNIT5_ALL_ASSERTIONS)
pattern = re.compile(rf"(\s*)((?:Assert(?:ions)?\.)?({all_assertions}))\s*\(", re.MULTILINE)
for match in pattern.finditer(source):
leading_ws = match.group(1)
full_method = match.group(2)
assertion_method = match.group(3)
# Find the complete assertion statement (balanced parens)
start_pos = match.start()
paren_start = match.end() - 1 # Position of opening paren
args_content, end_pos = self._find_balanced_parens(source, paren_start)
if args_content is None:
continue
# Check for semicolon after closing paren
while end_pos < len(source) and source[end_pos] in " \t\n\r":
end_pos += 1
if end_pos < len(source) and source[end_pos] == ";":
end_pos += 1
# Extract target calls from the arguments
target_calls = self._extract_target_calls(args_content, match.end())
is_exception = assertion_method in JUNIT5_EXCEPTION_ASSERTIONS
# For assertThrows, extract the lambda body
lambda_body = None
if is_exception and assertion_method == "assertThrows":
lambda_body = self._extract_lambda_body(args_content)
original_text = source[start_pos:end_pos]
# Determine statement type based on detected framework
detected = self._detected_framework or "junit5"
if "jupiter" in detected or detected == "junit5":
stmt_type = "junit5"
else:
stmt_type = detected
assertions.append(
AssertionMatch(
start_pos=start_pos,
end_pos=end_pos,
statement_type=stmt_type,
assertion_method=assertion_method,
target_calls=target_calls,
leading_whitespace=leading_ws,
original_text=original_text,
is_exception_assertion=is_exception,
lambda_body=lambda_body,
)
)
return assertions
def _find_fluent_assertions(self, source: str) -> list[AssertionMatch]:
"""Find AssertJ and Truth style fluent assertions (assertThat chains)."""
assertions: list[AssertionMatch] = []
# Pattern for fluent assertions: assertThat(...).<chain>
# Handles both org.assertj and com.google.common.truth
pattern = re.compile(r"(\s*)((?:Assertions?\.)?assertThat)\s*\(", re.MULTILINE)
for match in pattern.finditer(source):
leading_ws = match.group(1)
start_pos = match.start()
paren_start = match.end() - 1
# Find assertThat(...) content
args_content, after_paren = self._find_balanced_parens(source, paren_start)
if args_content is None:
continue
# Find the assertion chain (e.g., .isEqualTo(5).hasSize(3))
chain_end = self._find_fluent_chain_end(source, after_paren)
if chain_end == after_paren:
# No chain found, skip
continue
# Check for semicolon
end_pos = chain_end
while end_pos < len(source) and source[end_pos] in " \t\n\r":
end_pos += 1
if end_pos < len(source) and source[end_pos] == ";":
end_pos += 1
# Extract target calls from assertThat argument
target_calls = self._extract_target_calls(args_content, match.end())
original_text = source[start_pos:end_pos]
# Determine statement type based on detected framework
detected = self._detected_framework or "assertj"
stmt_type = "assertj" if "assertj" in detected else "truth"
assertions.append(
AssertionMatch(
start_pos=start_pos,
end_pos=end_pos,
statement_type=stmt_type,
assertion_method="assertThat",
target_calls=target_calls,
leading_whitespace=leading_ws,
original_text=original_text,
)
)
return assertions
def _find_hamcrest_assertions(self, source: str) -> list[AssertionMatch]:
"""Find Hamcrest style assertions: assertThat(actual, matcher)."""
assertions: list[AssertionMatch] = []
if self._detected_framework != "hamcrest":
return assertions
# Pattern for Hamcrest: assertThat(actual, is(...)) or assertThat(reason, actual, matcher)
pattern = re.compile(r"(\s*)((?:MatcherAssert\.)?assertThat)\s*\(", re.MULTILINE)
for match in pattern.finditer(source):
leading_ws = match.group(1)
start_pos = match.start()
paren_start = match.end() - 1
args_content, end_pos = self._find_balanced_parens(source, paren_start)
if args_content is None:
continue
# Check for semicolon
while end_pos < len(source) and source[end_pos] in " \t\n\r":
end_pos += 1
if end_pos < len(source) and source[end_pos] == ";":
end_pos += 1
# For Hamcrest, the first arg (or second if reason given) is the actual value
target_calls = self._extract_target_calls(args_content, match.end())
original_text = source[start_pos:end_pos]
assertions.append(
AssertionMatch(
start_pos=start_pos,
end_pos=end_pos,
statement_type="hamcrest",
assertion_method="assertThat",
target_calls=target_calls,
leading_whitespace=leading_ws,
original_text=original_text,
)
)
return assertions
def _find_fluent_chain_end(self, source: str, start_pos: int) -> int:
"""Find the end of a fluent assertion chain."""
pos = start_pos
while pos < len(source):
# Skip whitespace
while pos < len(source) and source[pos] in " \t\n\r":
pos += 1
if pos >= len(source) or source[pos] != ".":
break
pos += 1 # Skip dot
# Skip whitespace after dot
while pos < len(source) and source[pos] in " \t\n\r":
pos += 1
# Read method name
method_start = pos
while pos < len(source) and (source[pos].isalnum() or source[pos] == "_"):
pos += 1
if pos == method_start:
break
method_name = source[method_start:pos]
# Skip whitespace before potential parens
while pos < len(source) and source[pos] in " \t\n\r":
pos += 1
# Check for parentheses
if pos < len(source) and source[pos] == "(":
_, new_pos = self._find_balanced_parens(source, pos)
if new_pos == -1:
break
pos = new_pos
# Check if this is a terminal assertion method
if method_name in ASSERTJ_TERMINAL_METHODS:
# Continue looking for chained assertions
continue
return pos
def _extract_target_calls(self, content: str, base_offset: int) -> list[TargetCall]:
"""Extract calls to the target function from assertion arguments."""
target_calls: list[TargetCall] = []
# Pattern to match method calls: (receiver.)?func_name(args)
# Handles: obj.method(args), ClassName.staticMethod(args), method(args)
pattern = re.compile(rf"((?:[a-zA-Z_]\w*\.)*)?({re.escape(self.func_name)})\s*\(", re.MULTILINE)
for match in pattern.finditer(content):
receiver_prefix = match.group(1) or ""
receiver = receiver_prefix.rstrip(".") if receiver_prefix else None
method_name = match.group(2)
# Find the arguments
paren_pos = match.end() - 1
args_content, end_pos = self._find_balanced_parens(content, paren_pos)
if args_content is None:
continue
full_call = content[match.start() : end_pos]
target_calls.append(
TargetCall(
receiver=receiver,
method_name=method_name,
arguments=args_content,
full_call=full_call,
start_pos=base_offset + match.start(),
end_pos=base_offset + end_pos,
)
)
return target_calls
def _extract_lambda_body(self, content: str) -> str | None:
"""Extract the body of a lambda expression from assertThrows arguments.
For assertThrows(Exception.class, () -> code()), we want to extract 'code()'.
For assertThrows(Exception.class, () -> { code(); }), we want 'code();'.
"""
# Look for lambda: () -> expr or () -> { block }
lambda_match = re.search(r"\(\s*\)\s*->\s*", content)
if not lambda_match:
return None
body_start = lambda_match.end()
remaining = content[body_start:].strip()
if remaining.startswith("{"):
# Block lambda: () -> { code }
_, block_end = self._find_balanced_braces(content, body_start + content[body_start:].index("{"))
if block_end != -1:
# Extract content inside braces
brace_content = content[body_start + content[body_start:].index("{") + 1 : block_end - 1]
return brace_content.strip()
else:
# Expression lambda: () -> expr
# Find the end (before the closing paren of assertThrows)
depth = 0
end = body_start
for i, ch in enumerate(content[body_start:]):
if ch == "(":
depth += 1
elif ch == ")":
if depth == 0:
end = body_start + i
break
depth -= 1
return content[body_start:end].strip()
return None
def _find_balanced_parens(self, code: str, open_paren_pos: int) -> tuple[str | None, int]:
"""Find content within balanced parentheses.
Args:
code: The source code.
open_paren_pos: Position of the opening parenthesis.
Returns:
Tuple of (content inside parens, position after closing paren) or (None, -1).
"""
if open_paren_pos >= len(code) or code[open_paren_pos] != "(":
return None, -1
depth = 1
pos = open_paren_pos + 1
in_string = False
string_char = None
in_char = False
while pos < len(code) and depth > 0:
char = code[pos]
prev_char = code[pos - 1] if pos > 0 else ""
# Handle character literals
if char == "'" and not in_string and prev_char != "\\":
in_char = not in_char
# Handle string literals (double quotes)
elif char == '"' and not in_char and prev_char != "\\":
if not in_string:
in_string = True
string_char = char
elif char == string_char:
in_string = False
string_char = None
elif not in_string and not in_char:
if char == "(":
depth += 1
elif char == ")":
depth -= 1
pos += 1
if depth != 0:
return None, -1
return code[open_paren_pos + 1 : pos - 1], pos
def _find_balanced_braces(self, code: str, open_brace_pos: int) -> tuple[str | None, int]:
"""Find content within balanced braces."""
if open_brace_pos >= len(code) or code[open_brace_pos] != "{":
return None, -1
depth = 1
pos = open_brace_pos + 1
in_string = False
string_char = None
in_char = False
while pos < len(code) and depth > 0:
char = code[pos]
prev_char = code[pos - 1] if pos > 0 else ""
if char == "'" and not in_string and prev_char != "\\":
in_char = not in_char
elif char == '"' and not in_char and prev_char != "\\":
if not in_string:
in_string = True
string_char = char
elif char == string_char:
in_string = False
string_char = None
elif not in_string and not in_char:
if char == "{":
depth += 1
elif char == "}":
depth -= 1
pos += 1
if depth != 0:
return None, -1
return code[open_brace_pos + 1 : pos - 1], pos
def _generate_replacement(self, assertion: AssertionMatch) -> str:
"""Generate replacement code for an assertion.
The replacement captures target function return values and removes assertions.
Args:
assertion: The assertion to replace.
Returns:
Replacement code string.
"""
if assertion.is_exception_assertion:
return self._generate_exception_replacement(assertion)
if not assertion.target_calls:
# No target calls found, just comment out the assertion
return f"{assertion.leading_whitespace}// Removed assertion: no target calls found"
# Generate capture statements for each target call
replacements = []
# For the first replacement, use the full leading whitespace
# For subsequent ones, strip leading newlines to avoid extra blank lines
base_indent = assertion.leading_whitespace.lstrip("\n\r")
for i, call in enumerate(assertion.target_calls):
self.invocation_counter += 1
var_name = f"_cf_result{self.invocation_counter}"
if i == 0:
replacements.append(f"{assertion.leading_whitespace}Object {var_name} = {call.full_call};")
else:
replacements.append(f"{base_indent}Object {var_name} = {call.full_call};")
return "\n".join(replacements)
def _generate_exception_replacement(self, assertion: AssertionMatch) -> str:
"""Generate replacement for assertThrows/assertDoesNotThrow.
Transforms:
assertThrows(Exception.class, () -> calculator.divide(1, 0));
To:
try { calculator.divide(1, 0); } catch (Exception _cf_ignored1) {}
"""
self.invocation_counter += 1
if assertion.lambda_body:
# Extract the actual code from the lambda
code_to_run = assertion.lambda_body
if not code_to_run.endswith(";"):
code_to_run += ";"
return (
f"{assertion.leading_whitespace}try {{ {code_to_run} }} "
f"catch (Exception _cf_ignored{self.invocation_counter}) {{}}"
)
# If no lambda body found, try to extract from target calls
if assertion.target_calls:
call = assertion.target_calls[0]
return (
f"{assertion.leading_whitespace}try {{ {call.full_call}; }} "
f"catch (Exception _cf_ignored{self.invocation_counter}) {{}}"
)
# Fallback: comment out the assertion
return f"{assertion.leading_whitespace}// Removed assertThrows: could not extract callable"
def transform_java_assertions(source: str, function_name: str, qualified_name: str | None = None) -> str:
"""Transform Java test code by removing assertions and capturing function calls.
This is the main entry point for Java assertion transformation.
Args:
source: The Java test source code.
function_name: Name of the function being tested.
qualified_name: Optional fully qualified name of the function.
Returns:
Transformed source code with assertions replaced by capture statements.
"""
transformer = JavaAssertTransformer(function_name=function_name, qualified_name=qualified_name)
return transformer.transform(source)
def remove_assertions_from_test(source: str, target_function: FunctionToOptimize) -> str:
"""Remove assertions from test code for the given target function.
This is a convenience wrapper around transform_java_assertions that
takes a FunctionToOptimize object.
Args:
source: The Java test source code.
target_function: The function being optimized.
Returns:
Transformed source code.
"""
return transform_java_assertions(
source=source, function_name=target_function.function_name, qualified_name=target_function.qualified_name
)

View file

@ -0,0 +1,964 @@
"""Tests for Java assertion removal transformer.
This test suite covers the transformation of Java test assertions into
regression test code that captures function return values.
All tests assert for full string equality, no substring matching.
"""
from codeflash.languages.java.remove_asserts import (
JavaAssertTransformer,
transform_java_assertions,
)
class TestBasicJUnit5Assertions:
"""Tests for basic JUnit 5 assertion transformations."""
def test_assert_equals_basic(self):
source = """\
@Test
void testFibonacci() {
assertEquals(55, calculator.fibonacci(10));
}"""
expected = """\
@Test
void testFibonacci() {
Object _cf_result1 = calculator.fibonacci(10);
}"""
result = transform_java_assertions(source, "fibonacci")
assert result == expected
def test_assert_equals_with_message(self):
source = """\
@Test
void testFibonacci() {
assertEquals(55, calculator.fibonacci(10), "Fibonacci of 10 should be 55");
}"""
expected = """\
@Test
void testFibonacci() {
Object _cf_result1 = calculator.fibonacci(10);
}"""
result = transform_java_assertions(source, "fibonacci")
assert result == expected
def test_assert_true(self):
source = """\
@Test
void testIsValid() {
assertTrue(validator.isValid("test"));
}"""
expected = """\
@Test
void testIsValid() {
Object _cf_result1 = validator.isValid("test");
}"""
result = transform_java_assertions(source, "isValid")
assert result == expected
def test_assert_false(self):
source = """\
@Test
void testIsInvalid() {
assertFalse(validator.isValid(""));
}"""
expected = """\
@Test
void testIsInvalid() {
Object _cf_result1 = validator.isValid("");
}"""
result = transform_java_assertions(source, "isValid")
assert result == expected
def test_assert_null(self):
source = """\
@Test
void testGetNull() {
assertNull(processor.getValue(null));
}"""
expected = """\
@Test
void testGetNull() {
Object _cf_result1 = processor.getValue(null);
}"""
result = transform_java_assertions(source, "getValue")
assert result == expected
def test_assert_not_null(self):
source = """\
@Test
void testGetValue() {
assertNotNull(processor.getValue("key"));
}"""
expected = """\
@Test
void testGetValue() {
Object _cf_result1 = processor.getValue("key");
}"""
result = transform_java_assertions(source, "getValue")
assert result == expected
def test_assert_not_equals(self):
source = """\
@Test
void testDifferent() {
assertNotEquals(0, calculator.add(1, 2));
}"""
expected = """\
@Test
void testDifferent() {
Object _cf_result1 = calculator.add(1, 2);
}"""
result = transform_java_assertions(source, "add")
assert result == expected
def test_assert_same(self):
source = """\
@Test
void testSame() {
assertSame(expected, factory.getInstance());
}"""
expected = """\
@Test
void testSame() {
Object _cf_result1 = factory.getInstance();
}"""
result = transform_java_assertions(source, "getInstance")
assert result == expected
def test_assert_array_equals(self):
source = """\
@Test
void testSort() {
assertArrayEquals(expected, sorter.sort(input));
}"""
expected = """\
@Test
void testSort() {
Object _cf_result1 = sorter.sort(input);
}"""
result = transform_java_assertions(source, "sort")
assert result == expected
class TestJUnit5PrefixedAssertions:
"""Tests for JUnit 5 assertions with Assertions. prefix."""
def test_assertions_prefix(self):
source = """\
@Test
void testFibonacci() {
Assertions.assertEquals(55, calculator.fibonacci(10));
}"""
expected = """\
@Test
void testFibonacci() {
Object _cf_result1 = calculator.fibonacci(10);
}"""
result = transform_java_assertions(source, "fibonacci")
assert result == expected
def test_assert_prefix(self):
source = """\
@Test
void testAdd() {
Assert.assertEquals(5, calculator.add(2, 3));
}"""
expected = """\
@Test
void testAdd() {
Object _cf_result1 = calculator.add(2, 3);
}"""
result = transform_java_assertions(source, "add")
assert result == expected
class TestJUnit5ExceptionAssertions:
"""Tests for JUnit 5 exception assertions."""
def test_assert_throws_lambda(self):
source = """\
@Test
void testDivideByZero() {
assertThrows(IllegalArgumentException.class, () -> calculator.divide(1, 0));
}"""
expected = """\
@Test
void testDivideByZero() {
try { calculator.divide(1, 0); } catch (Exception _cf_ignored1) {}
}"""
result = transform_java_assertions(source, "divide")
assert result == expected
def test_assert_throws_block_lambda(self):
source = """\
@Test
void testDivideByZero() {
assertThrows(ArithmeticException.class, () -> {
calculator.divide(1, 0);
});
}"""
expected = """\
@Test
void testDivideByZero() {
try { calculator.divide(1, 0); } catch (Exception _cf_ignored1) {}
}"""
result = transform_java_assertions(source, "divide")
assert result == expected
def test_assert_does_not_throw(self):
source = """\
@Test
void testValidDivision() {
assertDoesNotThrow(() -> calculator.divide(10, 2));
}"""
expected = """\
@Test
void testValidDivision() {
try { calculator.divide(10, 2); } catch (Exception _cf_ignored1) {}
}"""
result = transform_java_assertions(source, "divide")
assert result == expected
class TestStaticMethodCalls:
"""Tests for static method call handling."""
def test_static_method_call(self):
source = """\
@Test
void testQuickAdd() {
assertEquals(15.0, Calculator.quickAdd(10.0, 5.0));
}"""
expected = """\
@Test
void testQuickAdd() {
Object _cf_result1 = Calculator.quickAdd(10.0, 5.0);
}"""
result = transform_java_assertions(source, "quickAdd")
assert result == expected
def test_static_method_fully_qualified(self):
source = """\
@Test
void testReverse() {
assertEquals("olleh", com.example.StringUtils.reverse("hello"));
}"""
expected = """\
@Test
void testReverse() {
Object _cf_result1 = com.example.StringUtils.reverse("hello");
}"""
result = transform_java_assertions(source, "reverse")
assert result == expected
class TestMultipleAssertions:
"""Tests for multiple assertions in a single test method."""
def test_multiple_assertions_same_function(self):
source = """\
@Test
void testFibonacciSequence() {
assertEquals(0, calculator.fibonacci(0));
assertEquals(1, calculator.fibonacci(1));
assertEquals(55, calculator.fibonacci(10));
}"""
expected = """\
@Test
void testFibonacciSequence() {
Object _cf_result1 = calculator.fibonacci(0);
Object _cf_result2 = calculator.fibonacci(1);
Object _cf_result3 = calculator.fibonacci(10);
}"""
result = transform_java_assertions(source, "fibonacci")
assert result == expected
def test_multiple_assertions_different_functions(self):
source = """\
@Test
void testCalculator() {
assertEquals(5, calculator.add(2, 3));
assertEquals(6, calculator.multiply(2, 3));
}"""
expected = """\
@Test
void testCalculator() {
Object _cf_result1 = calculator.add(2, 3);
assertEquals(6, calculator.multiply(2, 3));
}"""
result = transform_java_assertions(source, "add")
assert result == expected
class TestAssertJFluentAssertions:
"""Tests for AssertJ fluent assertion transformations."""
def test_assertj_basic(self):
source = """\
import static org.assertj.core.api.Assertions.assertThat;
@Test
void testFibonacci() {
assertThat(calculator.fibonacci(10)).isEqualTo(55);
}"""
expected = """\
import static org.assertj.core.api.Assertions.assertThat;
@Test
void testFibonacci() {
Object _cf_result1 = calculator.fibonacci(10);
}"""
result = transform_java_assertions(source, "fibonacci")
assert result == expected
def test_assertj_chained(self):
source = """\
import static org.assertj.core.api.Assertions.assertThat;
@Test
void testGetList() {
assertThat(processor.getList()).hasSize(5).contains("a", "b");
}"""
expected = """\
import static org.assertj.core.api.Assertions.assertThat;
@Test
void testGetList() {
Object _cf_result1 = processor.getList();
}"""
result = transform_java_assertions(source, "getList")
assert result == expected
def test_assertj_is_null(self):
source = """\
import static org.assertj.core.api.Assertions.assertThat;
@Test
void testGetNull() {
assertThat(processor.getValue(null)).isNull();
}"""
expected = """\
import static org.assertj.core.api.Assertions.assertThat;
@Test
void testGetNull() {
Object _cf_result1 = processor.getValue(null);
}"""
result = transform_java_assertions(source, "getValue")
assert result == expected
def test_assertj_is_not_empty(self):
source = """\
import static org.assertj.core.api.Assertions.assertThat;
@Test
void testGetList() {
assertThat(processor.getList()).isNotEmpty();
}"""
expected = """\
import static org.assertj.core.api.Assertions.assertThat;
@Test
void testGetList() {
Object _cf_result1 = processor.getList();
}"""
result = transform_java_assertions(source, "getList")
assert result == expected
class TestNestedMethodCalls:
"""Tests for nested method calls in assertions."""
def test_nested_call_in_expected(self):
source = """\
@Test
void testCompare() {
assertEquals(helper.getExpected(), calculator.compute(5));
}"""
expected = """\
@Test
void testCompare() {
Object _cf_result1 = calculator.compute(5);
}"""
result = transform_java_assertions(source, "compute")
assert result == expected
def test_nested_call_as_argument(self):
source = """\
@Test
void testProcess() {
assertEquals(expected, processor.process(helper.getData()));
}"""
expected = """\
@Test
void testProcess() {
Object _cf_result1 = processor.process(helper.getData());
}"""
result = transform_java_assertions(source, "process")
assert result == expected
def test_deeply_nested(self):
source = """\
@Test
void testDeep() {
assertEquals(expected, outer.process(inner.compute(calculator.fibonacci(5))));
}"""
expected = """\
@Test
void testDeep() {
Object _cf_result1 = calculator.fibonacci(5);
}"""
result = transform_java_assertions(source, "fibonacci")
assert result == expected
class TestWhitespacePreservation:
"""Tests for whitespace and indentation preservation."""
def test_preserves_indentation(self):
source = """\
@Test
void testFibonacci() {
assertEquals(55, calculator.fibonacci(10));
}"""
expected = """\
@Test
void testFibonacci() {
Object _cf_result1 = calculator.fibonacci(10);
}"""
result = transform_java_assertions(source, "fibonacci")
assert result == expected
def test_multiline_assertion(self):
source = """\
@Test
void testLongAssertion() {
assertEquals(
expectedValue,
calculator.computeComplexResult(
arg1,
arg2,
arg3
)
);
}"""
expected = """\
@Test
void testLongAssertion() {
Object _cf_result1 = calculator.computeComplexResult(
arg1,
arg2,
arg3
);
}"""
result = transform_java_assertions(source, "computeComplexResult")
assert result == expected
class TestStringsWithSpecialCharacters:
"""Tests for strings containing special characters."""
def test_string_with_parentheses(self):
source = """\
@Test
void testFormat() {
assertEquals("hello (world)", formatter.format("hello", "world"));
}"""
expected = """\
@Test
void testFormat() {
Object _cf_result1 = formatter.format("hello", "world");
}"""
result = transform_java_assertions(source, "format")
assert result == expected
def test_string_with_quotes(self):
source = """\
@Test
void testEscape() {
assertEquals("hello \\"world\\"", formatter.escape("hello \\"world\\""));
}"""
expected = """\
@Test
void testEscape() {
Object _cf_result1 = formatter.escape("hello \\"world\\"");
}"""
result = transform_java_assertions(source, "escape")
assert result == expected
def test_string_with_newlines(self):
source = """\
@Test
void testMultiline() {
assertEquals("line1\\nline2", processor.join("line1", "line2"));
}"""
expected = """\
@Test
void testMultiline() {
Object _cf_result1 = processor.join("line1", "line2");
}"""
result = transform_java_assertions(source, "join")
assert result == expected
class TestNonAssertionCodePreservation:
"""Tests that non-assertion code is preserved unchanged."""
def test_setup_code_preserved(self):
source = """\
@Test
void testWithSetup() {
Calculator calc = new Calculator(2);
int input = 10;
assertEquals(55, calc.fibonacci(input));
}"""
expected = """\
@Test
void testWithSetup() {
Calculator calc = new Calculator(2);
int input = 10;
Object _cf_result1 = calc.fibonacci(input);
}"""
result = transform_java_assertions(source, "fibonacci")
assert result == expected
def test_other_method_calls_preserved(self):
source = """\
@Test
void testWithHelper() {
helper.setup();
assertEquals(55, calculator.fibonacci(10));
helper.cleanup();
}"""
expected = """\
@Test
void testWithHelper() {
helper.setup();
Object _cf_result1 = calculator.fibonacci(10);
helper.cleanup();
}"""
result = transform_java_assertions(source, "fibonacci")
assert result == expected
def test_variable_declarations_preserved(self):
source = """\
@Test
void testWithVariables() {
int expected = 55;
int actual = calculator.fibonacci(10);
assertEquals(expected, actual);
}"""
# fibonacci is assigned to 'actual', not in the assertion - no transformation
expected = source
result = transform_java_assertions(source, "fibonacci")
assert result == expected
class TestParameterizedTests:
"""Tests for parameterized test handling."""
def test_parameterized_test(self):
source = """\
@ParameterizedTest
@CsvSource({
"0, 0",
"1, 1",
"10, 55"
})
void testFibonacciSequence(int n, long expected) {
assertEquals(expected, calculator.fibonacci(n));
}"""
expected = """\
@ParameterizedTest
@CsvSource({
"0, 0",
"1, 1",
"10, 55"
})
void testFibonacciSequence(int n, long expected) {
Object _cf_result1 = calculator.fibonacci(n);
}"""
result = transform_java_assertions(source, "fibonacci")
assert result == expected
class TestNestedTestClasses:
"""Tests for nested test class handling."""
def test_nested_class(self):
source = """\
@Nested
@DisplayName("Fibonacci Tests")
class FibonacciTests {
@Test
void testBasic() {
assertEquals(55, calculator.fibonacci(10));
}
}"""
expected = """\
@Nested
@DisplayName("Fibonacci Tests")
class FibonacciTests {
@Test
void testBasic() {
Object _cf_result1 = calculator.fibonacci(10);
}
}"""
result = transform_java_assertions(source, "fibonacci")
assert result == expected
class TestMockitoPreservation:
"""Tests that Mockito code is not modified."""
def test_mockito_when_preserved(self):
source = """\
@Test
void testWithMock() {
when(mockService.getData()).thenReturn("test");
assertEquals("test", processor.process(mockService));
}"""
expected = """\
@Test
void testWithMock() {
when(mockService.getData()).thenReturn("test");
Object _cf_result1 = processor.process(mockService);
}"""
result = transform_java_assertions(source, "process")
assert result == expected
def test_mockito_verify_preserved(self):
source = """\
@Test
void testWithVerify() {
processor.process(mockService);
verify(mockService).getData();
}"""
# No assertions to transform, source unchanged
expected = source
result = transform_java_assertions(source, "process")
assert result == expected
class TestEdgeCases:
"""Tests for edge cases and boundary conditions."""
def test_empty_source(self):
result = transform_java_assertions("", "fibonacci")
assert result == ""
def test_whitespace_only(self):
source = " \n\t "
result = transform_java_assertions(source, "fibonacci")
assert result == source
def test_no_assertions(self):
source = """\
@Test
void testNoAssertions() {
calculator.fibonacci(10);
}"""
expected = source
result = transform_java_assertions(source, "fibonacci")
assert result == expected
def test_assertion_without_target_function(self):
source = """\
@Test
void testOther() {
assertEquals(5, helper.compute(3));
}"""
# No transformation since target function is not in the assertion
expected = source
result = transform_java_assertions(source, "fibonacci")
assert result == expected
def test_function_name_in_string(self):
source = """\
@Test
void testWithStringContainingFunctionName() {
assertEquals("fibonacci(10) = 55", formatter.format("fibonacci", 10, 55));
}"""
expected = """\
@Test
void testWithStringContainingFunctionName() {
Object _cf_result1 = formatter.format("fibonacci", 10, 55);
}"""
result = transform_java_assertions(source, "format")
assert result == expected
class TestJUnit4Compatibility:
"""Tests for JUnit 4 style assertions."""
def test_junit4_assert_equals(self):
source = """\
import static org.junit.Assert.*;
@Test
public void testFibonacci() {
assertEquals(55, calculator.fibonacci(10));
}"""
expected = """\
import static org.junit.Assert.*;
@Test
public void testFibonacci() {
Object _cf_result1 = calculator.fibonacci(10);
}"""
result = transform_java_assertions(source, "fibonacci")
assert result == expected
def test_junit4_with_message_first(self):
source = """\
@Test
public void testFibonacci() {
assertEquals("Should be 55", 55, calculator.fibonacci(10));
}"""
expected = """\
@Test
public void testFibonacci() {
Object _cf_result1 = calculator.fibonacci(10);
}"""
result = transform_java_assertions(source, "fibonacci")
assert result == expected
class TestAssertAll:
"""Tests for assertAll grouped assertions."""
def test_assert_all_basic(self):
source = """\
@Test
void testMultiple() {
assertAll(
() -> assertEquals(0, calculator.fibonacci(0)),
() -> assertEquals(1, calculator.fibonacci(1)),
() -> assertEquals(55, calculator.fibonacci(10))
);
}"""
expected = """\
@Test
void testMultiple() {
Object _cf_result1 = calculator.fibonacci(0);
Object _cf_result2 = calculator.fibonacci(1);
Object _cf_result3 = calculator.fibonacci(10);
}"""
result = transform_java_assertions(source, "fibonacci")
assert result == expected
class TestTransformerClass:
"""Tests for the JavaAssertTransformer class directly."""
def test_invocation_counter_increments(self):
transformer = JavaAssertTransformer("fibonacci")
source = """\
@Test
void test() {
assertEquals(0, calc.fibonacci(0));
assertEquals(1, calc.fibonacci(1));
}"""
expected = """\
@Test
void test() {
Object _cf_result1 = calc.fibonacci(0);
Object _cf_result2 = calc.fibonacci(1);
}"""
result = transformer.transform(source)
assert result == expected
assert transformer.invocation_counter == 2
def test_qualified_name_support(self):
transformer = JavaAssertTransformer(
function_name="fibonacci",
qualified_name="com.example.Calculator.fibonacci",
)
assert transformer.qualified_name == "com.example.Calculator.fibonacci"
def test_custom_analyzer(self):
from codeflash.languages.java.parser import get_java_analyzer
analyzer = get_java_analyzer()
transformer = JavaAssertTransformer("fibonacci", analyzer=analyzer)
assert transformer.analyzer is analyzer
class TestImportDetection:
"""Tests for framework detection from imports."""
def test_detect_junit5(self):
source = """\
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;"""
transformer = JavaAssertTransformer("test")
transformer._detected_framework = transformer._detect_framework(source)
assert transformer._detected_framework == "junit5"
def test_detect_assertj(self):
source = """\
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;"""
transformer = JavaAssertTransformer("test")
transformer._detected_framework = transformer._detect_framework(source)
assert transformer._detected_framework == "assertj"
def test_detect_testng(self):
source = """\
import org.testng.Assert;
import org.testng.annotations.Test;"""
transformer = JavaAssertTransformer("test")
transformer._detected_framework = transformer._detect_framework(source)
assert transformer._detected_framework == "testng"
def test_detect_hamcrest(self):
source = """\
import org.junit.Test;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.*;"""
transformer = JavaAssertTransformer("test")
transformer._detected_framework = transformer._detect_framework(source)
assert transformer._detected_framework == "hamcrest"
class TestInstrumentGeneratedJavaTest:
"""Tests for the instrument_generated_java_test integration."""
def test_behavior_mode_removes_assertions(self):
from codeflash.languages.java.instrumentation import instrument_generated_java_test
test_code = """\
package com.example;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
public class FibonacciTest {
@Test
void testFibonacci() {
Calculator calc = new Calculator();
assertEquals(55, calc.fibonacci(10));
}
}"""
expected = """\
package com.example;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
public class FibonacciTest__perfinstrumented {
@Test
void testFibonacci() {
Calculator calc = new Calculator();
Object _cf_result1 = calc.fibonacci(10);
}
}"""
result = instrument_generated_java_test(
test_code=test_code,
function_name="fibonacci",
qualified_name="com.example.Calculator.fibonacci",
mode="behavior",
)
assert result == expected
def test_behavior_mode_with_assertj(self):
from codeflash.languages.java.instrumentation import instrument_generated_java_test
test_code = """\
package com.example;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
public class StringUtilsTest {
@Test
void testReverse() {
assertThat(StringUtils.reverse("hello")).isEqualTo("olleh");
}
}"""
expected = """\
package com.example;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
public class StringUtilsTest__perfinstrumented {
@Test
void testReverse() {
Object _cf_result1 = StringUtils.reverse("hello");
}
}"""
result = instrument_generated_java_test(
test_code=test_code,
function_name="reverse",
qualified_name="com.example.StringUtils.reverse",
mode="behavior",
)
assert result == expected
class TestComplexRealWorldExamples:
"""Tests based on real-world test patterns."""
def test_calculator_test_pattern(self):
source = """\
@Test
@DisplayName("should calculate compound interest for basic case")
void testBasicCompoundInterest() {
String result = calculator.calculateCompoundInterest(1000.0, 0.05, 1, 12);
assertNotNull(result);
assertTrue(result.contains("."));
}"""
# assertNotNull(result) and assertTrue(result.contains(".")) don't contain the target function
# so they remain unchanged, and the variable assignment is also preserved
expected = source
result = transform_java_assertions(source, "calculateCompoundInterest")
assert result == expected
def test_string_utils_pattern(self):
source = """\
@Test
@DisplayName("should reverse a simple string")
void testReverseSimple() {
assertEquals("olleh", StringUtils.reverse("hello"));
assertEquals("dlrow", StringUtils.reverse("world"));
}"""
expected = """\
@Test
@DisplayName("should reverse a simple string")
void testReverseSimple() {
Object _cf_result1 = StringUtils.reverse("hello");
Object _cf_result2 = StringUtils.reverse("world");
}"""
result = transform_java_assertions(source, "reverse")
assert result == expected
def test_with_before_each_setup(self):
source = """\
private Calculator calculator;
@BeforeEach
void setUp() {
calculator = new Calculator(2);
}
@Test
void testFibonacci() {
assertEquals(55, calculator.fibonacci(10));
}"""
expected = """\
private Calculator calculator;
@BeforeEach
void setUp() {
calculator = new Calculator(2);
}
@Test
void testFibonacci() {
Object _cf_result1 = calculator.fibonacci(10);
}"""
result = transform_java_assertions(source, "fibonacci")
assert result == expected