codeflash/codeflash/languages/java/remove_asserts.py
2026-02-16 08:32:55 +02:00

939 lines
34 KiB
Python

"""Java assertion removal transformer for converting tests to regression tests.
This module removes assertion statements from Java test code while preserving
function calls, enabling behavioral verification by comparing outputs between
original and optimized code.
Supported frameworks:
- JUnit 5 (Jupiter): assertEquals, assertTrue, assertThrows, etc.
- JUnit 4: org.junit.Assert.*
- AssertJ: assertThat(...).isEqualTo(...)
- TestNG: org.testng.Assert.*
- Hamcrest: assertThat(actual, is(expected))
- Truth: assertThat(actual).isEqualTo(expected)
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from codeflash.languages.java.parser import get_java_analyzer
if TYPE_CHECKING:
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.java.parser import JavaAnalyzer
logger = logging.getLogger(__name__)

# JUnit 5 value-comparison assertions: (expected, actual, ...) or (actual, ...).
JUNIT5_VALUE_ASSERTIONS = frozenset(
    (
        "assertEquals",
        "assertNotEquals",
        "assertSame",
        "assertNotSame",
        "assertArrayEquals",
        "assertIterableEquals",
        "assertLinesMatch",
    )
)

# JUnit 5 assertions over a single boolean/object argument.
JUNIT5_CONDITION_ASSERTIONS = frozenset(("assertTrue", "assertFalse", "assertNull", "assertNotNull"))

# JUnit 5 assertions about thrown exceptions (the transformer rewrites these as try/catch).
JUNIT5_EXCEPTION_ASSERTIONS = frozenset(("assertThrows", "assertDoesNotThrow"))

# JUnit 5 timeout assertions.
JUNIT5_TIMEOUT_ASSERTIONS = frozenset(("assertTimeout", "assertTimeoutPreemptively"))

# JUnit 5 grouping assertion.
JUNIT5_GROUP_ASSERTIONS = frozenset(("assertAll",))

# Union of every JUnit 5 assertion name listed above.
JUNIT5_ALL_ASSERTIONS = JUNIT5_VALUE_ASSERTIONS.union(
    JUNIT5_CONDITION_ASSERTIONS,
    JUNIT5_EXCEPTION_ASSERTIONS,
    JUNIT5_TIMEOUT_ASSERTIONS,
    JUNIT5_GROUP_ASSERTIONS,
)

# AssertJ terminal assertions (methods that end an assertThat(...) chain).
ASSERTJ_TERMINAL_METHODS = frozenset(
    (
        # equality / identity
        "isEqualTo",
        "isNotEqualTo",
        "isSameAs",
        "isNotSameAs",
        # null / boolean / emptiness
        "isNull",
        "isNotNull",
        "isTrue",
        "isFalse",
        "isEmpty",
        "isNotEmpty",
        "isBlank",
        "isNotBlank",
        # containment / string shape
        "contains",
        "containsOnly",
        "containsExactly",
        "containsExactlyInAnyOrder",
        "doesNotContain",
        "startsWith",
        "endsWith",
        "matches",
        # size
        "hasSize",
        "hasSizeBetween",
        "hasSizeGreaterThan",
        "hasSizeLessThan",
        # ordering / numeric
        "isGreaterThan",
        "isGreaterThanOrEqualTo",
        "isLessThan",
        "isLessThanOrEqualTo",
        "isBetween",
        "isCloseTo",
        "isPositive",
        "isNegative",
        "isZero",
        "isNotZero",
        # type / membership
        "isInstanceOf",
        "isNotInstanceOf",
        "isIn",
        "isNotIn",
        # maps
        "containsKey",
        "containsKeys",
        "containsValue",
        "containsValues",
        "containsEntry",
        # objects / misc
        "hasFieldOrPropertyWithValue",
        "extracting",
        "satisfies",
        "doesNotThrow",
    )
)

# Hamcrest matcher factory methods.
HAMCREST_MATCHERS = frozenset(
    (
        "is",
        "equalTo",
        "not",
        "nullValue",
        "notNullValue",
        "hasItem",
        "hasItems",
        "hasSize",
        "containsString",
        "startsWith",
        "endsWith",
        "greaterThan",
        "lessThan",
        "closeTo",
        "instanceOf",
        "anything",
        "allOf",
        "anyOf",
    )
)
@dataclass
class TargetCall:
    """A call to the target function found inside an assertion's arguments.

    ``start_pos``/``end_pos`` are character offsets into the full test source
    (the argument-relative offset plus the offset of the assertion's argument list).
    """

    receiver: str | None  # Receiver expression text, e.g. 'calc' or 'algorithms'; None when the call has no receiver
    method_name: str  # Name of the called method (the transformer's target function name)
    arguments: str  # Argument text with the surrounding parentheses stripped
    full_call: str  # Complete call text, e.g. 'calc.fibonacci(10)'
    start_pos: int  # Start offset of the call in the source
    end_pos: int  # End offset of the call in the source
@dataclass
class AssertionMatch:
    """A matched assertion statement and the span of source it will replace.

    The span ``[start_pos, end_pos)`` covers the assertion including its leading
    whitespace, any ``Type var =`` assignment prefix, and the trailing semicolon.
    """

    start_pos: int  # Start offset of the replaced span in the source
    end_pos: int  # End offset of the replaced span in the source
    statement_type: str  # 'junit5', 'junit4', 'testng', 'assertj', 'truth', or 'hamcrest'
    assertion_method: str  # Assertion method name, e.g. 'assertEquals' or 'assertThat'
    target_calls: list[TargetCall] = field(default_factory=list)  # Target-function calls found in the arguments
    leading_whitespace: str = ""  # Whitespace preceding the assertion, re-emitted before the replacement
    original_text: str = ""  # Source text of the replaced span
    is_exception_assertion: bool = False  # True for assertThrows / assertDoesNotThrow
    lambda_body: str | None = None  # Body of the `() -> ...` lambda passed to assertThrows / assertDoesNotThrow
    assigned_var_type: str | None = None  # Type of assigned variable (e.g., "IllegalArgumentException")
    assigned_var_name: str | None = None  # Name of assigned variable (e.g., "exception")
    exception_class: str | None = None  # Exception class from assertThrows args (e.g., "IllegalArgumentException")
class JavaAssertTransformer:
    """Transforms Java test code by removing assertions and preserving function calls.

    Assertions are located with regular expressions plus a literal- and comment-aware
    delimiter scanner. Calls to the target function inside assertion arguments are
    located with tree-sitter. Supported styles: JUnit 5, JUnit 4, AssertJ, TestNG,
    Hamcrest and Truth.

    Each assertion that references the target function is replaced by
    ``Object _cf_resultN = <call>;`` captures. ``assertThrows`` and
    ``assertDoesNotThrow`` are rewritten as try/catch blocks.
    """

    # Wrapper template to make assertion argument fragments parseable by tree-sitter.
    # e.g. content "55, obj.fibonacci(10)" becomes "class _D { void _m() { _d(55, obj.fibonacci(10)); } }"
    _TS_WRAPPER_PREFIX = "class _D { void _m() { _d("
    _TS_WRAPPER_SUFFIX = "); } }"
    _TS_WRAPPER_PREFIX_BYTES = _TS_WRAPPER_PREFIX.encode("utf8")
    _TS_WRAPPER_SUFFIX_BYTES = _TS_WRAPPER_SUFFIX.encode("utf8")

    # assertThat(...) with any dotted qualifier (Assertions., Truth., MatcherAssert., softly., ...).
    # The qualifier is part of the match, so a replacement never leaves "Foo." dangling.
    # (?<![\w.]) keeps a match from starting inside an identifier or right after a dot.
    _ASSERT_THAT_RE = re.compile(r"(\s*)(?<![\w.])((?:\w+\.)*assertThat)\s*\(")

    # Line prefix of "Type var = <assertion>". Groups: (indent, declared type, variable name).
    # The type may carry generic arguments (including commas/spaces) and array brackets.
    _ASSIGN_PREFIX_RE = re.compile(
        r"([ \t]*)(?:final\s+)?([\w.$]+(?:\s*<[^=;]*>)?(?:\s*\[\s*\])*)\s+(\w+)\s*=\s*$"
    )

    # Catch types that already cover a trailing "catch (Exception ...)". Emitting that
    # second clause after one of these is a Java compile error ("already been caught").
    _CATCH_ALL_TYPES = frozenset({"Exception", "java.lang.Exception", "Throwable", "java.lang.Throwable"})

    def __init__(
        self, function_name: str, qualified_name: str | None = None, analyzer: JavaAnalyzer | None = None
    ) -> None:
        self.analyzer = analyzer or get_java_analyzer()
        self.func_name = function_name
        self.qualified_name = qualified_name or function_name
        self.invocation_counter = 0
        self._detected_framework: str | None = None
        # Precompiled once; used by _detect_variable_assignment.
        self._assign_re = re.compile(r"(\w+(?:<[^>]+>)?)\s+(\w+)\s*=\s*$")

    def transform(self, source: str) -> str:
        """Remove assertions from source code, preserving target function calls.

        Args:
            source: Java source code containing test assertions.

        Returns:
            Transformed source with assertions replaced by captured function calls.
            The source is returned unchanged when it is blank or has no assertion
            involving the target function.
        """
        if not source or not source.strip():
            return source

        # Detect framework from imports (affects statement_type and Hamcrest matching).
        self._detected_framework = self._detect_framework(source)

        candidates = [a for a in self._find_assertions(source) if a.target_calls or a.is_exception_assertion]
        if not candidates:
            return source

        # Replacements are generated in source order, so _cf_resultN numbering follows
        # the file top-to-bottom. They are then stitched together in a single pass.
        pieces: list[str] = []
        cursor = 0
        for assertion in self._remove_nested_assertions(candidates):
            pieces.append(source[cursor : assertion.start_pos])
            pieces.append(self._generate_replacement(assertion))
            cursor = assertion.end_pos
        pieces.append(source[cursor:])
        return "".join(pieces)

    @staticmethod
    def _remove_nested_assertions(assertions: list[AssertionMatch]) -> list[AssertionMatch]:
        """Keep only outermost, non-overlapping assertions, sorted by position.

        An assertion contained in another (e.g. ``assertEquals`` inside ``assertAll``)
        is dropped. When two finders report the identical span, it is kept once.
        A span that partially overlaps an already-kept one is dropped, because
        splicing both would corrupt the output.
        """
        # Larger spans sort first among equal starts, so containers precede contents.
        ordered = sorted(assertions, key=lambda a: (a.start_pos, -a.end_pos))
        kept: list[AssertionMatch] = []
        covered_until = -1
        for assertion in ordered:
            if assertion.start_pos < covered_until:
                continue
            kept.append(assertion)
            covered_until = assertion.end_pos
        return kept

    def _detect_framework(self, source: str) -> str:
        """Detect which testing framework is being used from imports.

        Checks more specific libraries first (AssertJ, Hamcrest, Truth, TestNG) before
        falling back to JUnit. Defaults to "junit5" when nothing matches.
        """
        imports = self.analyzer.find_imports(source)

        # First pass: specific assertion libraries.
        for imp in imports:
            path = imp.import_path.lower()
            if "org.assertj" in path:
                return "assertj"
            if "org.hamcrest" in path:
                return "hamcrest"
            if "com.google.common.truth" in path:
                return "truth"
            if "org.testng" in path:
                return "testng"

        # Second pass: JUnit versions.
        for imp in imports:
            path = imp.import_path.lower()
            if "junit.jupiter" in path:
                return "junit5"
            if "org.junit" in path:
                return "junit4"

        return "junit5"

    def _find_assertions(self, source: str) -> list[AssertionMatch]:
        """Find all assertion statements in the source code (JUnit, fluent, Hamcrest)."""
        assertions: list[AssertionMatch] = []
        assertions.extend(self._find_junit_assertions(source))
        assertions.extend(self._find_fluent_assertions(source))
        assertions.extend(self._find_hamcrest_assertions(source))
        return assertions

    def _find_junit_assertions(self, source: str) -> list[AssertionMatch]:
        """Find JUnit 4/5 and TestNG style assertions.

        Matches static-import calls (``assertEquals(...)``) and qualified calls
        (``Assert.assertEquals``, ``Assertions.assertEquals``, ``AssertJUnit.assertEquals``,
        ``org.junit.jupiter.api.Assertions.assertEquals``). The qualifier is always part
        of the match, so the replacement never leaves a dangling ``Foo.`` prefix.
        When the assertion result is assigned (``Type var = assertXxx(...)``), the
        span is widened to the start of the declaration.
        """
        assertions: list[AssertionMatch] = []

        # Longest names first so e.g. assertTimeoutPreemptively is tried before assertTimeout.
        names = "|".join(sorted(JUNIT5_ALL_ASSERTIONS, key=len, reverse=True))
        pattern = re.compile(rf"(\s*)(?<![\w.])((?:\w+\.)*({names}))\s*\(")

        detected = self._detected_framework or "junit5"
        stmt_type = "junit5" if ("jupiter" in detected or detected == "junit5") else detected

        for match in pattern.finditer(source):
            leading_ws = match.group(1)
            assertion_method = match.group(3)
            start_pos = match.start()

            args_content, after_paren = self._find_balanced_parens(source, match.end() - 1)
            if args_content is None:
                continue
            end_pos = self._consume_statement_end(source, after_paren)

            # The argument text starts right after "(", i.e. at match.end().
            target_calls = self._extract_target_calls(args_content, match.end())

            is_exception = assertion_method in JUNIT5_EXCEPTION_ASSERTIONS
            lambda_body = self._extract_lambda_body(args_content) if is_exception else None
            exception_class = (
                self._extract_exception_class(args_content) if assertion_method == "assertThrows" else None
            )

            # Detect "Type var = assertXxx(...)" on the text between line start and the match.
            assigned_var_type = None
            assigned_var_name = None
            line_start = source.rfind("\n", 0, start_pos) + 1
            var_match = self._ASSIGN_PREFIX_RE.match(source[line_start:start_pos])
            if var_match:
                assigned_var_type = var_match.group(2)
                assigned_var_name = var_match.group(3)
                if line_start > 0:
                    # Start at the newline ending the previous line so indentation is re-emitted.
                    start_pos = line_start - 1
                    leading_ws = "\n" + var_match.group(1)
                else:
                    start_pos = 0
                    leading_ws = var_match.group(1)

            assertions.append(
                AssertionMatch(
                    start_pos=start_pos,
                    end_pos=end_pos,
                    statement_type=stmt_type,
                    assertion_method=assertion_method,
                    target_calls=target_calls,
                    leading_whitespace=leading_ws,
                    original_text=source[start_pos:end_pos],
                    is_exception_assertion=is_exception,
                    lambda_body=lambda_body,
                    assigned_var_type=assigned_var_type,
                    assigned_var_name=assigned_var_name,
                    exception_class=exception_class,
                )
            )

        return assertions

    def _find_fluent_assertions(self, source: str) -> list[AssertionMatch]:
        """Find AssertJ and Truth style fluent assertions (``assertThat(...).<chain>``).

        A bare ``assertThat(...)`` with no method chain is skipped here. That is the
        Hamcrest form, handled by ``_find_hamcrest_assertions``.
        """
        assertions: list[AssertionMatch] = []
        detected = self._detected_framework or "assertj"
        stmt_type = "assertj" if "assertj" in detected else "truth"

        for match in self._ASSERT_THAT_RE.finditer(source):
            args_content, after_paren = self._find_balanced_parens(source, match.end() - 1)
            if args_content is None:
                continue

            chain_end = self._find_fluent_chain_end(source, after_paren)
            if chain_end == after_paren:
                continue  # no chain: not a fluent assertion

            start_pos = match.start()
            end_pos = self._consume_statement_end(source, chain_end)
            assertions.append(
                AssertionMatch(
                    start_pos=start_pos,
                    end_pos=end_pos,
                    statement_type=stmt_type,
                    assertion_method="assertThat",
                    # Only the assertThat(...) argument is scanned for target calls.
                    target_calls=self._extract_target_calls(args_content, match.end()),
                    leading_whitespace=match.group(1),
                    original_text=source[start_pos:end_pos],
                )
            )

        return assertions

    def _find_hamcrest_assertions(self, source: str) -> list[AssertionMatch]:
        """Find Hamcrest style assertions: ``assertThat(actual, matcher)``.

        Only runs when the file was detected as using Hamcrest.
        """
        if self._detected_framework != "hamcrest":
            return []

        assertions: list[AssertionMatch] = []
        for match in self._ASSERT_THAT_RE.finditer(source):
            args_content, after_paren = self._find_balanced_parens(source, match.end() - 1)
            if args_content is None:
                continue

            start_pos = match.start()
            end_pos = self._consume_statement_end(source, after_paren)
            assertions.append(
                AssertionMatch(
                    start_pos=start_pos,
                    end_pos=end_pos,
                    statement_type="hamcrest",
                    assertion_method="assertThat",
                    # Target calls anywhere in the arguments (reason, actual or matcher) are captured.
                    target_calls=self._extract_target_calls(args_content, match.end()),
                    leading_whitespace=match.group(1),
                    original_text=source[start_pos:end_pos],
                )
            )

        return assertions

    @staticmethod
    def _skip_whitespace(text: str, pos: int) -> int:
        """Return the first index at or after *pos* that is not a space, tab, CR or LF."""
        n = len(text)
        while pos < n and text[pos] in " \t\n\r":
            pos += 1
        return pos

    @classmethod
    def _consume_statement_end(cls, source: str, pos: int) -> int:
        """Return the index just past a ``;`` that follows *pos* after optional whitespace.

        If no semicolon follows, *pos* is returned unchanged, so trailing whitespace
        (and the newline it may contain) is not swallowed into the replaced span.
        """
        probe = cls._skip_whitespace(source, pos)
        if probe < len(source) and source[probe] == ";":
            return probe + 1
        return pos

    def _find_fluent_chain_end(self, source: str, start_pos: int) -> int:
        """Return the end of the ``.method(...)`` chain starting at *start_pos*.

        The result is the index just past the last complete chain segment (its closing
        parenthesis, or its name when it has no argument list). *start_pos* itself is
        returned when no chain follows, even if whitespace does.
        """
        n = len(source)
        chain_end = start_pos
        pos = start_pos
        while True:
            pos = self._skip_whitespace(source, pos)
            if pos >= n or source[pos] != ".":
                break

            pos = self._skip_whitespace(source, pos + 1)
            name_start = pos
            while pos < n and (source[pos].isalnum() or source[pos] == "_"):
                pos += 1
            if pos == name_start:
                break
            segment_end = pos

            pos = self._skip_whitespace(source, pos)
            if pos < n and source[pos] == "(":
                _, after_args = self._find_balanced_parens(source, pos)
                if after_args == -1:
                    break
                segment_end = pos = after_args

            chain_end = segment_end

        return chain_end

    def _extract_target_calls(self, content: str, base_offset: int) -> list[TargetCall]:
        """Find all calls to the target function within assertion argument text using tree-sitter.

        Args:
            content: Text of the assertion's argument list (without outer parentheses).
            base_offset: Offset of ``content`` within the full source, added to positions.
        """
        if not content or not content.strip():
            return []

        content_bytes = content.encode("utf8")
        wrapper_bytes = self._TS_WRAPPER_PREFIX_BYTES + content_bytes + self._TS_WRAPPER_SUFFIX_BYTES
        tree = self.analyzer.parse(wrapper_bytes)

        results: list[TargetCall] = []
        self._collect_target_invocations(tree.root_node, wrapper_bytes, content_bytes, base_offset, results)
        return results

    def _collect_target_invocations(
        self,
        node,
        wrapper_bytes: bytes,
        content_bytes: bytes,
        base_offset: int,
        out: list[TargetCall],
    ) -> None:
        """Recursively walk the AST and collect method_invocation nodes that match self.func_name.

        Matching nodes are still descended into, so nested target calls are also collected.
        """
        prefix_len = len(self._TS_WRAPPER_PREFIX_BYTES)
        if node.type == "method_invocation":
            name_node = node.child_by_field_name("name")
            if name_node and self.analyzer.get_node_text(name_node, wrapper_bytes) == self.func_name:
                start = node.start_byte - prefix_len
                end = node.end_byte - prefix_len
                # Ignore anything that falls in the synthetic wrapper rather than the content.
                if 0 <= start and end <= len(content_bytes):
                    out.append(self._build_target_call(node, wrapper_bytes, content_bytes, start, end, base_offset))
        for child in node.children:
            self._collect_target_invocations(child, wrapper_bytes, content_bytes, base_offset, out)

    def _build_target_call(
        self,
        node,
        wrapper_bytes: bytes,
        content_bytes: bytes,
        start_byte: int,
        end_byte: int,
        base_offset: int,
    ) -> TargetCall:
        """Build a TargetCall from a tree-sitter method_invocation node."""
        get_text = self.analyzer.get_node_text
        object_node = node.child_by_field_name("object")
        args_node = node.child_by_field_name("arguments")

        args_text = get_text(args_node, wrapper_bytes) if args_node else ""
        # argument_list node includes parens, strip them
        if args_text.startswith("(") and args_text.endswith(")"):
            args_text = args_text[1:-1]

        # Byte offsets -> char offsets for correct Python string indexing.
        start_char = len(content_bytes[:start_byte].decode("utf8"))
        end_char = len(content_bytes[:end_byte].decode("utf8"))

        return TargetCall(
            receiver=get_text(object_node, wrapper_bytes) if object_node else None,
            method_name=self.func_name,
            arguments=args_text,
            full_call=get_text(node, wrapper_bytes),
            start_pos=base_offset + start_char,
            end_pos=base_offset + end_char,
        )

    def _detect_variable_assignment(self, source: str, assertion_start: int) -> tuple[str | None, str | None]:
        """Check if assertion is assigned to a variable.

        Detects patterns like:
            IllegalArgumentException exception = assertThrows(...)
            Exception ex = assertThrows(...)

        Args:
            source: The full source code.
            assertion_start: Start position of the assertion.

        Returns:
            Tuple of (variable_type, variable_name) or (None, None).
        """
        line_start = source.rfind("\n", 0, assertion_start)
        line_start = 0 if line_start == -1 else line_start + 1

        match = self._assign_re.search(source, line_start, assertion_start)
        if match:
            return match.group(1).strip(), match.group(2).strip()
        return None, None

    def _extract_exception_class(self, args_content: str) -> str | None:
        """Extract the exception class from assertThrows arguments.

        Example:
            ``IllegalArgumentException.class, () -> ...`` -> ``"IllegalArgumentException"``

        Returns:
            The class name without ``.class``, or None when the first argument is not
            a ``.class`` literal.
        """
        # The first argument ends at the first comma outside (...) and <...>.
        depth = 0
        first_end = len(args_content)
        for idx, char in enumerate(args_content):
            if char in "(<":
                depth += 1
            elif char in ")>":
                depth -= 1
            elif char == "," and depth == 0:
                first_end = idx
                break

        first_arg = args_content[:first_end].strip()
        if first_arg.endswith(".class"):
            return first_arg[: -len(".class")].strip()
        return None

    def _extract_lambda_body(self, content: str) -> str | None:
        """Extract the body of a zero-argument lambda from assertThrows arguments.

        For ``assertThrows(Exception.class, () -> code())`` this returns ``'code()'``.
        For ``assertThrows(Exception.class, () -> { code(); })`` it returns ``'code();'``.
        Returns None when no ``() ->`` lambda is present or a block lambda is unbalanced.
        """
        lambda_match = re.search(r"\(\s*\)\s*->\s*", content)
        if not lambda_match:
            return None
        body_start = lambda_match.end()

        if content[body_start:].lstrip().startswith("{"):
            # Block lambda: () -> { code }
            brace_pos = content.index("{", body_start)
            inner, _ = self._find_balanced_braces(content, brace_pos)
            return inner.strip() if inner is not None else None

        # Expression lambda: the body runs until an unmatched closer or a top-level comma
        # (e.g. before a trailing message argument). Literals are skipped so delimiters
        # inside strings/chars do not end the body.
        n = len(content)
        depth = 0
        pos = body_start
        while pos < n:
            ch = content[pos]
            if ch in "\"'":
                pos = self._skip_literal(content, pos)
                continue
            if ch in "([{":
                depth += 1
            elif ch in ")]}":
                if depth == 0:
                    break
                depth -= 1
            elif ch == "," and depth == 0:
                break
            pos += 1
        return content[body_start:pos].strip()

    @staticmethod
    def _skip_literal(text: str, pos: int) -> int:
        """Return the index just past the string/char literal whose opening quote is at *pos*.

        A backslash always escapes the following character. So an escaped quote does
        not end the literal, and an escaped backslash does not escape the closing
        quote. For an unterminated literal the returned index is past the end of *text*.
        """
        quote = text[pos]
        n = len(text)
        pos += 1
        while pos < n and text[pos] != quote:
            pos += 2 if text[pos] == "\\" else 1
        return pos + 1

    @classmethod
    def _find_balanced(cls, code: str, open_pos: int, open_ch: str, close_ch: str) -> tuple[str | None, int]:
        """Find the text between ``code[open_pos]`` (== open_ch) and its matching close_ch.

        String/char literals, ``//`` line comments and ``/* */`` block comments are
        skipped, so delimiters inside them are not counted.

        Returns:
            (inner text, index just past the closer), or (None, -1) when *open_pos*
            is not open_ch or no matching closer exists.
        """
        n = len(code)
        if open_pos >= n or code[open_pos] != open_ch:
            return None, -1

        depth = 1
        pos = open_pos + 1
        while pos < n:
            ch = code[pos]
            if ch in "\"'":
                pos = cls._skip_literal(code, pos)
                continue
            if ch == "/" and pos + 1 < n:
                nxt = code[pos + 1]
                if nxt == "/":
                    newline = code.find("\n", pos + 2)
                    pos = n if newline == -1 else newline + 1
                    continue
                if nxt == "*":
                    close = code.find("*/", pos + 2)
                    pos = n if close == -1 else close + 2
                    continue
            if ch == open_ch:
                depth += 1
            elif ch == close_ch:
                depth -= 1
                if depth == 0:
                    return code[open_pos + 1 : pos], pos + 1
            pos += 1

        return None, -1

    def _find_balanced_parens(self, code: str, open_paren_pos: int) -> tuple[str | None, int]:
        """Find content within balanced parentheses.

        Args:
            code: The source code.
            open_paren_pos: Position of the opening parenthesis.

        Returns:
            Tuple of (content inside parens, position after closing paren) or (None, -1).
        """
        return self._find_balanced(code, open_paren_pos, "(", ")")

    def _find_balanced_braces(self, code: str, open_brace_pos: int) -> tuple[str | None, int]:
        """Find content within balanced braces.

        Returns:
            Tuple of (content inside braces, position after closing brace) or (None, -1).
        """
        return self._find_balanced(code, open_brace_pos, "{", "}")

    def _generate_replacement(self, assertion: AssertionMatch) -> str:
        """Generate replacement code for an assertion.

        The replacement captures target function return values and removes the assertion.

        Args:
            assertion: The assertion to replace.

        Returns:
            Replacement code string.
        """
        if assertion.is_exception_assertion:
            return self._generate_exception_replacement(assertion)

        if not assertion.target_calls:
            return f"{assertion.leading_whitespace}// Removed assertion: no target calls found"

        # The first capture reuses the full leading whitespace. Later ones use only the
        # indentation, so no extra blank lines are introduced.
        base_indent = assertion.leading_whitespace.lstrip("\n\r")
        lines = []
        for i, call in enumerate(assertion.target_calls):
            self.invocation_counter += 1
            var_name = f"_cf_result{self.invocation_counter}"
            prefix = assertion.leading_whitespace if i == 0 else base_indent
            lines.append(f"{prefix}Object {var_name} = {call.full_call};")
        return "\n".join(lines)

    def _generate_exception_replacement(self, assertion: AssertionMatch) -> str:
        """Generate the replacement for assertThrows/assertDoesNotThrow.

        Transforms:
            assertThrows(Exception.class, () -> calculator.divide(1, 0));
        To:
            try { calculator.divide(1, 0); } catch (Exception _cf_ignored1) {}

        When assigned to a variable:
            IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> code());
        To:
            IllegalArgumentException ex = null;
            try { code(); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; } catch (Exception _cf_ignored1) {}

        The trailing ``catch (Exception ...)`` is omitted when the caught type is already
        Exception/Throwable, which would otherwise not compile. A ``var`` declaration
        is given the concrete exception type, since ``var x = null;`` is illegal Java.
        """
        self.invocation_counter += 1
        counter = self.invocation_counter
        ws = assertion.leading_whitespace
        base_indent = ws.lstrip("\n\r")

        # The code to run is the lambda body if one was extracted, else the first target call.
        if assertion.lambda_body:
            value_expr = assertion.lambda_body.strip()
        elif assertion.target_calls:
            value_expr = assertion.target_calls[0].full_call
        else:
            return f"{ws}// Removed assertThrows: could not extract callable"

        code_to_run = value_expr if value_expr.endswith(";") else value_expr + ";"
        # Statement blocks end in ';' or '}'. An expression that merely contains ';'
        # (e.g. inside a string literal) is still an expression.
        is_expression = ";" not in value_expr or not value_expr.endswith((";", "}"))

        var_name = assertion.assigned_var_name
        var_type = assertion.assigned_var_type
        if not (var_name and var_type):
            return f"{ws}try {{ {code_to_run} }} catch (Exception _cf_ignored{counter}) {{}}"

        if assertion.assertion_method == "assertDoesNotThrow":
            if is_expression:
                return f"{ws}{var_type} {var_name} = {value_expr};"
            # Block lambda consisting of a single "return expr;" yields expr.
            returned = re.fullmatch(r"return\s+([^;]+);", value_expr)
            if returned:
                return f"{ws}{var_type} {var_name} = {returned.group(1).strip()};"
            return f"{ws}{code_to_run}"

        declared_type = var_type
        if declared_type == "var":
            declared_type = assertion.exception_class or "Throwable"
        exception_type = assertion.exception_class or declared_type

        parts = [
            f"{ws}{declared_type} {var_name} = null;\n",
            f"{base_indent}try {{ {code_to_run} }} ",
            f"catch ({exception_type} _cf_caught{counter}) {{ {var_name} = _cf_caught{counter}; }}",
        ]
        if exception_type not in self._CATCH_ALL_TYPES:
            parts.append(f" catch (Exception _cf_ignored{counter}) {{}}")
        return "".join(parts)
def transform_java_assertions(source: str, function_name: str, qualified_name: str | None = None) -> str:
    """Strip assertions from Java test *source*, keeping calls to *function_name*.

    This is the main entry point for the Java assertion transformation. A fresh
    JavaAssertTransformer is created for every call, so capture-variable numbering
    starts over for each source passed in.

    Args:
        source: The Java test source code.
        function_name: Name of the function under test whose calls are preserved.
        qualified_name: Optional fully qualified name of that function.

    Returns:
        The source with assertions replaced by capture statements.
    """
    return JavaAssertTransformer(function_name=function_name, qualified_name=qualified_name).transform(source)
def remove_assertions_from_test(source: str, target_function: FunctionToOptimize) -> str:
    """Strip assertions from *source* for the function described by *target_function*.

    Thin adapter over transform_java_assertions that unpacks the function's plain
    and qualified names from a FunctionToOptimize.

    Args:
        source: The Java test source code.
        target_function: The function being optimized.

    Returns:
        The transformed source code.
    """
    name = target_function.function_name
    qualified = target_function.qualified_name
    return transform_java_assertions(source=source, function_name=name, qualified_name=qualified)