codeflash/codeflash/languages/java/remove_asserts.py
HeshamHM28 41814cd24b feat: support void method optimization in Java pipeline
Discover void methods, instrument them by serializing the receiver instead
of a return value, and treat all-null comparisons as equivalent.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-02 11:55:10 +02:00

1309 lines
48 KiB
Python

"""Java assertion removal transformer for converting tests to regression tests.
This module removes assertion statements from Java test code while preserving
function calls, enabling behavioral verification by comparing outputs between
original and optimized code.
Supported frameworks:
- JUnit 5 (Jupiter): assertEquals, assertTrue, assertThrows, etc.
- JUnit 4: org.junit.Assert.*
- AssertJ: assertThat(...).isEqualTo(...)
- TestNG: org.testng.Assert.*
- Hamcrest: assertThat(actual, is(expected))
- Truth: assertThat(actual).isEqualTo(expected)
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from codeflash.languages.java.parser import get_java_analyzer
if TYPE_CHECKING:
from tree_sitter import Node
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.java.parser import JavaAnalyzer
logger = logging.getLogger(__name__)

# "Type name = " at the end of a line prefix (generic types allowed): the
# left-hand side of a local variable declaration.
_ASSIGN_RE = re.compile(r"(\w+(?:<[^>]+>)?)\s+(\w+)\s*=\s*$")

# JUnit 5 value assertions: (expected, actual, ...) or (actual, ...).
JUNIT5_VALUE_ASSERTIONS = frozenset(
    (
        "assertEquals", "assertNotEquals",
        "assertSame", "assertNotSame",
        "assertArrayEquals", "assertIterableEquals", "assertLinesMatch",
    )
)
# Assertions taking a single boolean/object argument.
JUNIT5_CONDITION_ASSERTIONS = frozenset(("assertTrue", "assertFalse", "assertNull", "assertNotNull"))
# Exception assertions; their lambdas are unwrapped into try/catch replacements.
JUNIT5_EXCEPTION_ASSERTIONS = frozenset(("assertThrows", "assertDoesNotThrow"))
# Timeout assertions.
JUNIT5_TIMEOUT_ASSERTIONS = frozenset(("assertTimeout", "assertTimeoutPreemptively"))
# Grouping assertion.
JUNIT5_GROUP_ASSERTIONS = frozenset(("assertAll",))
# Every recognised JUnit-style assertion name (joined into the finder regex).
JUNIT5_ALL_ASSERTIONS = frozenset().union(
    JUNIT5_VALUE_ASSERTIONS,
    JUNIT5_CONDITION_ASSERTIONS,
    JUNIT5_EXCEPTION_ASSERTIONS,
    JUNIT5_TIMEOUT_ASSERTIONS,
    JUNIT5_GROUP_ASSERTIONS,
)
# AssertJ chain methods that make up an assertion, grouped by what they check.
ASSERTJ_TERMINAL_METHODS = frozenset(
    (
        # equality / identity / nullness / booleans
        "isEqualTo", "isNotEqualTo", "isSameAs", "isNotSameAs",
        "isNull", "isNotNull", "isTrue", "isFalse",
        # emptiness
        "isEmpty", "isNotEmpty", "isBlank", "isNotBlank",
        # containment / string shape
        "contains", "containsOnly", "containsExactly", "containsExactlyInAnyOrder",
        "doesNotContain", "startsWith", "endsWith", "matches",
        # size
        "hasSize", "hasSizeBetween", "hasSizeGreaterThan", "hasSizeLessThan",
        # ordering / numeric range
        "isGreaterThan", "isGreaterThanOrEqualTo", "isLessThan", "isLessThanOrEqualTo",
        "isBetween", "isCloseTo", "isPositive", "isNegative", "isZero", "isNotZero",
        # type / membership
        "isInstanceOf", "isNotInstanceOf", "isIn", "isNotIn",
        # maps
        "containsKey", "containsKeys", "containsValue", "containsValues", "containsEntry",
        # introspection / navigation / misc
        "hasFieldOrPropertyWithValue", "extracting", "satisfies", "doesNotThrow",
    )
)
# Hamcrest matcher factory names, grouped by category.
HAMCREST_MATCHERS = frozenset(
    (
        # core / logical
        "is", "not", "anything", "allOf", "anyOf",
        # equality, nullness and type
        "equalTo", "nullValue", "notNullValue", "instanceOf",
        # collections
        "hasItem", "hasItems", "hasSize",
        # strings
        "containsString", "startsWith", "endsWith",
        # numbers
        "greaterThan", "lessThan", "closeTo",
    )
)
@dataclass
class TargetCall:
    """A call to the function under test found inside an assertion's arguments.

    ``start_pos``/``end_pos`` are character offsets of ``full_call`` within the
    full source being transformed.
    """

    receiver: str | None  # receiver expression text, e.g. "calc"; None when there is none
    method_name: str  # name of the target function
    arguments: str  # argument text without the surrounding parentheses (may be "")
    full_call: str  # captured expression text, e.g. "calc.fibonacci(10)"
    start_pos: int
    end_pos: int
@dataclass
class AssertionMatch:
    """An assertion statement located in the test source, with its replacement inputs."""

    # Character span of the statement in the source; it may begin at the
    # preceding newline when the assertion initializes a local variable.
    start_pos: int
    end_pos: int
    statement_type: str  # framework tag, e.g. 'junit5', 'junit4', 'testng', 'assertj', 'truth', 'hamcrest'
    assertion_method: str  # e.g. 'assertEquals', 'assertThat'
    target_calls: list[TargetCall] = field(default_factory=list)
    leading_whitespace: str = ""  # whitespace to emit before the replacement
    original_text: str = ""  # the matched statement text
    is_exception_assertion: bool = False  # assertThrows / assertDoesNotThrow
    lambda_body: str | None = None  # body of the `() ->` lambda for exception assertions
    assigned_var_type: str | None = None  # e.g. "IllegalArgumentException" in `T ex = assertThrows(...)`
    assigned_var_name: str | None = None  # e.g. "ex"
    exception_class: str | None = None  # class from assertThrows(X.class, ...), e.g. "IllegalArgumentException"
class JavaAssertTransformer:
"""Transforms Java test code by removing assertions and preserving function calls.
This class uses tree-sitter for AST-based analysis and regex for text manipulation.
It handles various Java testing frameworks including JUnit 5, JUnit 4, AssertJ,
TestNG, Hamcrest, and Truth.
"""
def __init__(
    self,
    function_name: str,
    qualified_name: str | None = None,
    analyzer: JavaAnalyzer | None = None,
    mode: str = "capture",
    target_return_type: str = "",
) -> None:
    """Configure a transformer for one target function.

    Args:
        function_name: Simple name of the method whose calls are preserved.
        qualified_name: Fully qualified name; falls back to ``function_name``.
        analyzer: Java analyzer; when falsy, one is obtained from ``get_java_analyzer()``.
        mode: "capture" (default, instrumentation) or "strip" (clean display).
        target_return_type: Return type of the target; replacement generation
            treats "void" like strip mode.
    """
    self.analyzer = analyzer if analyzer else get_java_analyzer()
    self.func_name = function_name
    self.qualified_name = qualified_name if qualified_name else function_name
    self.mode = mode
    self.target_return_type = target_return_type
    # Numbering for generated _cf_result/_cf_ignored variables; never reset.
    self.invocation_counter = 0
    self._detected_framework = None  # set by transform() from the imports

    # Patterns compiled once per instance instead of on every use.
    self._assign_re = re.compile(r"(\w+(?:<[^>]+>)?)\s+(\w+)\s*=\s*$")
    self._special_re = re.compile(r"[\"'{}()]")  # quotes, braces, parens
    self._LONG_LITERAL_RE = re.compile(r"^-?\d+[lL]$")
    self._INT_LITERAL_RE = re.compile(r"^-?\d+$")
    self._DOUBLE_LITERAL_RE = re.compile(r"^-?\d+\.\d*[dD]?$|^-?\d+[dD]$")
    self._FLOAT_LITERAL_RE = re.compile(r"^-?\d+\.?\d*[fF]$")
    self._CHAR_LITERAL_RE = re.compile(r"^'.'$|^'\\.'$")
    self._cast_re = re.compile(r"^\((\w+)\)")
def transform(self, source: str) -> str:
    """Rewrite ``source`` with every outermost assertion replaced.

    Args:
        source: Java source code containing test assertions.

    Returns:
        The source with each non-nested assertion span substituted by the text
        from ``_generate_replacement``; the input itself when it is blank or
        contains no assertions.
    """
    if not source or not source.strip():
        return source

    self._detected_framework = self._detect_framework(source)
    found = self._find_assertions(source)
    if not found:
        return source

    # Source order matters: replacement generation hands out counter numbers.
    found.sort(key=lambda match: match.start_pos)

    # Keep outermost assertions only. One ending inside an already kept span
    # (e.g. assertEquals inside assertAll) is nested and dropped.
    outermost: list[AssertionMatch] = []
    furthest_end = -1
    for match in found:
        if match.end_pos <= furthest_end:
            continue
        outermost.append(match)
        furthest_end = match.end_pos

    # Splice replacements in a single left-to-right pass.
    pieces: list[str] = []
    cursor = 0
    for match in outermost:
        replacement = self._generate_replacement(match)
        pieces.append(source[cursor : match.start_pos])
        pieces.append(replacement)
        cursor = match.end_pos
    pieces.append(source[cursor:])
    return "".join(pieces)
def _detect_framework(self, source: str) -> str:
    """Identify the assertion framework from the file's imports.

    Dedicated assertion libraries (AssertJ, Hamcrest, Truth, TestNG) are
    looked for first, and the first import naming one of them wins. Only then
    are JUnit 5 ("jupiter") and JUnit 4 considered. Defaults to "junit5".
    """
    specific_markers = (
        ("org.assertj", "assertj"),
        ("org.hamcrest", "hamcrest"),
        ("com.google.common.truth", "truth"),
        ("org.testng", "testng"),
    )
    imports = self.analyzer.find_imports(source)

    for imp in imports:
        lowered = imp.import_path.lower()
        for marker, framework in specific_markers:
            if marker in lowered:
                return framework

    for imp in imports:
        lowered = imp.import_path.lower()
        if "junit.jupiter" in lowered:
            return "junit5"
        if "org.junit" in lowered:
            return "junit4"

    return "junit5"
def _find_assertions(self, source: str) -> list[AssertionMatch]:
    """Collect candidate assertions from every supported style.

    Results are concatenated in the order JUnit/TestNG, fluent (AssertJ/Truth),
    Hamcrest. They are neither sorted nor de-duplicated, and spans may overlap.
    """
    return [
        *self._find_junit_assertions(source),
        *self._find_fluent_assertions(source),
        *self._find_hamcrest_assertions(source),
    ]
def _find_junit_assertions(self, source: str) -> list[AssertionMatch]:
    """Find JUnit 4/5 and TestNG style assertion statements.

    Matches unqualified (static-import) calls such as ``assertEquals(...)`` and
    calls qualified through an ``Assert``/``Assertions`` class, optionally
    package-qualified (``org.junit.jupiter.api.Assertions.assertEquals``).
    A match may not start inside an identifier or right after a ``.``. So
    ``softAssert.assertEquals(...)`` or ``MyAssert.assertEquals(...)`` is left
    untouched instead of being cut in half, which would leave a dangling
    receiver in the rewritten code.

    Args:
        source: Full Java source.

    Returns:
        One AssertionMatch per assertion whose argument parentheses balance.
        If the text between the previous newline and the match is a
        declaration prefix (``[final] Type var =``), the match starts at that
        newline and records the variable's type and name.
    """
    assertions: list[AssertionMatch] = []
    all_assertions = "|".join(JUNIT5_ALL_ASSERTIONS)
    # (?<![\w.]) rejects starts in the middle of an identifier or member access.
    pattern = re.compile(
        rf"(\s*)(?<![\w.])((?:(?:\w+\.)*Assert(?:ions)?\.)?({all_assertions}))\s*\(",
        re.MULTILINE,
    )

    for match in pattern.finditer(source):
        leading_ws = match.group(1)
        assertion_method = match.group(3)

        start_pos = match.start()
        paren_start = match.end() - 1  # the opening parenthesis
        args_content, end_pos = self._find_balanced_parens(source, paren_start)
        if args_content is None:
            continue

        # Absorb whitespace and an optional semicolon after the closing paren.
        while end_pos < len(source) and source[end_pos] in " \t\n\r":
            end_pos += 1
        if end_pos < len(source) and source[end_pos] == ";":
            end_pos += 1

        target_calls = self._extract_target_calls(args_content, match.end())

        is_exception = assertion_method in JUNIT5_EXCEPTION_ASSERTIONS
        lambda_body = None
        exception_class = None
        if is_exception:
            lambda_body = self._extract_lambda_body(args_content)
            if assertion_method == "assertThrows":
                exception_class = self._extract_exception_class(args_content)

        # Detect `Type var = assertXxx(...)` (any assertion kind). The prefix
        # runs from the previous newline (or start of source) up to the match.
        assigned_var_type = None
        assigned_var_name = None
        original_text = source[start_pos:end_pos]
        last_nl_idx = source.rfind("\n", 0, start_pos)
        line_prefix = source[last_nl_idx + 1 : start_pos]
        var_match = re.match(r"([ \t]*)(?:final\s+)?([\w.<>\[\]]+)\s+(\w+)\s*=\s*$", line_prefix)
        if var_match:
            if last_nl_idx >= 0:
                start_pos = last_nl_idx
                leading_ws = "\n" + var_match.group(1)
            else:
                start_pos = 0
                leading_ws = var_match.group(1)
            assigned_var_type = var_match.group(2)
            assigned_var_name = var_match.group(3)
            original_text = source[start_pos:end_pos]  # widened to the declaration

        detected = self._detected_framework or "junit5"
        stmt_type = "junit5" if ("jupiter" in detected or detected == "junit5") else detected

        assertions.append(
            AssertionMatch(
                start_pos=start_pos,
                end_pos=end_pos,
                statement_type=stmt_type,
                assertion_method=assertion_method,
                target_calls=target_calls,
                leading_whitespace=leading_ws,
                original_text=original_text,
                is_exception_assertion=is_exception,
                lambda_body=lambda_body,
                assigned_var_type=assigned_var_type,
                assigned_var_name=assigned_var_name,
                exception_class=exception_class,
            )
        )
    return assertions
def _find_fluent_assertions(self, source: str) -> list[AssertionMatch]:
    """Find AssertJ and Truth style fluent assertions (``assertThat(...)`` chains).

    Accepts bare ``assertThat``, ``Assertions.assertThat``/``Truth.assertThat``,
    and their package-qualified forms. A match may not start inside an
    identifier or right after a ``.``. So ``softly.assertThat(...)`` is left
    untouched rather than rewritten into code with a dangling receiver.
    ``assertThat(...)`` calls not followed by a method chain (Hamcrest style)
    are skipped.

    Args:
        source: Full Java source.

    Returns:
        One AssertionMatch per chained assertion, ending after the optional ';'.
    """
    assertions: list[AssertionMatch] = []
    pattern = re.compile(
        r"(\s*)(?<![\w.])((?:(?:\w+\.)*(?:Assertions?|Truth)\.)?assertThat)\s*\(",
        re.MULTILINE,
    )

    for match in pattern.finditer(source):
        leading_ws = match.group(1)
        start_pos = match.start()
        paren_start = match.end() - 1

        args_content, after_paren = self._find_balanced_parens(source, paren_start)
        if args_content is None:
            continue

        # The chain (e.g. .isEqualTo(5).hasSize(3)) is what makes it fluent.
        chain_end = self._find_fluent_chain_end(source, after_paren)
        if chain_end == after_paren:
            continue

        end_pos = chain_end
        while end_pos < len(source) and source[end_pos] in " \t\n\r":
            end_pos += 1
        if end_pos < len(source) and source[end_pos] == ";":
            end_pos += 1

        target_calls = self._extract_target_calls(args_content, match.end())
        original_text = source[start_pos:end_pos]

        detected = self._detected_framework or "assertj"
        stmt_type = "assertj" if "assertj" in detected else "truth"

        assertions.append(
            AssertionMatch(
                start_pos=start_pos,
                end_pos=end_pos,
                statement_type=stmt_type,
                assertion_method="assertThat",
                target_calls=target_calls,
                leading_whitespace=leading_ws,
                original_text=original_text,
            )
        )
    return assertions
def _find_hamcrest_assertions(self, source: str) -> list[AssertionMatch]:
    """Find Hamcrest style assertions: ``assertThat(actual, matcher)``.

    Only runs when the detected framework is Hamcrest. Accepts bare
    ``assertThat`` and calls qualified through ``MatcherAssert`` or JUnit 4's
    ``Assert``, including package-qualified forms
    (``org.hamcrest.MatcherAssert.assertThat``). A match may not start inside an
    identifier or right after a ``.``, so no dangling qualifier is left behind.

    Args:
        source: Full Java source.

    Returns:
        One AssertionMatch per assertion, ending after the optional ';'.
    """
    assertions: list[AssertionMatch] = []
    if self._detected_framework != "hamcrest":
        return assertions

    pattern = re.compile(
        r"(\s*)(?<![\w.])((?:(?:\w+\.)*(?:MatcherAssert|Assert)\.)?assertThat)\s*\(",
        re.MULTILINE,
    )

    for match in pattern.finditer(source):
        leading_ws = match.group(1)
        start_pos = match.start()
        paren_start = match.end() - 1

        args_content, end_pos = self._find_balanced_parens(source, paren_start)
        if args_content is None:
            continue

        while end_pos < len(source) and source[end_pos] in " \t\n\r":
            end_pos += 1
        if end_pos < len(source) and source[end_pos] == ";":
            end_pos += 1

        # Target calls are searched in all arguments (reason, actual, matcher).
        target_calls = self._extract_target_calls(args_content, match.end())
        original_text = source[start_pos:end_pos]

        assertions.append(
            AssertionMatch(
                start_pos=start_pos,
                end_pos=end_pos,
                statement_type="hamcrest",
                assertion_method="assertThat",
                target_calls=target_calls,
                leading_whitespace=leading_ws,
                original_text=original_text,
            )
        )
    return assertions
def _find_fluent_chain_end(self, source: str, start_pos: int) -> int:
    """Return the end of the method chain that follows ``assertThat(...)``.

    Walks ``.method(args)`` links (whitespace and newlines allowed around the
    dot and before the argument list) and stops at the first character that
    does not continue the chain. Every link is consumed, not only known AssertJ
    terminal methods. So ``assertThat(x).as("desc").isEqualTo(1)`` is one
    assertion and no ``.isEqualTo(...)`` tail is left behind.

    Args:
        source: Full Java source.
        start_pos: Index just past the closing parenthesis of ``assertThat(...)``.

    Returns:
        Index just past the last complete link. It equals ``start_pos`` when no
        chain follows. Trailing whitespace is not included, a dangling ``.``
        is not consumed, and a link whose parentheses never close ends the
        chain before that link.
    """
    n = len(source)
    whitespace = " \t\n\r"
    chain_end = start_pos
    pos = start_pos
    while True:
        while pos < n and source[pos] in whitespace:
            pos += 1
        if pos >= n or source[pos] != ".":
            break
        pos += 1  # the dot
        while pos < n and source[pos] in whitespace:
            pos += 1

        name_start = pos
        while pos < n and (source[pos].isalnum() or source[pos] in "_$"):
            pos += 1
        if pos == name_start:
            break  # '.' not followed by an identifier

        # Look for an argument list without committing the whitespace before it.
        peek = pos
        while peek < n and source[peek] in whitespace:
            peek += 1
        if peek < n and source[peek] == "(":
            _, after_args = self._find_balanced_parens(source, peek)
            if after_args == -1:
                break
            pos = after_args
        chain_end = pos
    return chain_end
# Assertion arguments are only a fragment of Java. To parse them with
# tree-sitter they are embedded in a throwaway method call, e.g.
# "55, obj.fibonacci(10)" -> "class _D { void _m() { _d(55, obj.fibonacci(10)); } }".
_TS_WRAPPER_PREFIX = "class _D { void _m() { _d("
_TS_WRAPPER_SUFFIX = "); } }"
# Subtracting this byte length maps node offsets back onto the fragment.
_TS_WRAPPER_PREFIX_BYTES = bytes(_TS_WRAPPER_PREFIX, "utf8")
def _extract_target_calls(self, content: str, base_offset: int) -> list[TargetCall]:
    """Find calls to the target function in assertion argument text.

    The text is wrapped into a parseable Java snippet, parsed with the
    analyzer, and walked by ``_collect_target_invocations``.

    Args:
        content: Argument text of the assertion.
        base_offset: Character offset of ``content`` in the full source.

    Returns:
        The collected TargetCall objects; empty for blank content.
    """
    if not content or not content.strip():
        return []
    fragment = content.encode("utf8")
    wrapped = b"".join((self._TS_WRAPPER_PREFIX_BYTES, fragment, self._TS_WRAPPER_SUFFIX.encode("utf8")))
    root = self.analyzer.parse(wrapped).root_node
    found: list[TargetCall] = []
    self._collect_target_invocations(root, wrapped, fragment, base_offset, found)
    return found
def _collect_target_invocations(
    self,
    node: Node,
    wrapper_bytes: bytes,
    content_bytes: bytes,
    base_offset: int,
    out: list[TargetCall],
    seen_top_level: set[tuple[int, int]] | None = None,
) -> None:
    """Recursively collect calls to ``self.func_name`` below ``node`` into ``out``.

    When a target call sits inside the argument list of another non-assert call,
    ``_find_top_level_arg_node`` returns the whole top-level argument expression.
    That expression is captured instead, once per (start_byte, end_byte) range,
    so surrounding function calls are preserved.

    Args:
        node: Current tree-sitter node of the parsed wrapper snippet.
        wrapper_bytes: UTF-8 bytes of the wrapper source that was parsed.
        content_bytes: UTF-8 bytes of the original argument text.
        base_offset: Character offset of ``content_bytes`` in the full source.
        out: Accumulator that receives TargetCall objects (mutated in place).
        seen_top_level: Byte ranges of top-level expressions already captured.
            Created on the first call and shared with the recursive calls.
    """
    if seen_top_level is None:
        seen_top_level = set()
    # Node offsets are relative to the wrapper; subtracting the prefix length
    # gives offsets into content_bytes.
    prefix_len = len(self._TS_WRAPPER_PREFIX_BYTES)
    if node.type == "method_invocation":
        name_node = node.child_by_field_name("name")
        if name_node and self.analyzer.get_node_text(name_node, wrapper_bytes) == self.func_name:
            top_node = self._find_top_level_arg_node(node, wrapper_bytes)
            if top_node is not None:
                # Nested target: capture the enclosing top-level argument once.
                range_key = (top_node.start_byte, top_node.end_byte)
                if range_key not in seen_top_level:
                    seen_top_level.add(range_key)
                    start = top_node.start_byte - prefix_len
                    end = top_node.end_byte - prefix_len
                    # Ignore nodes that reach outside the user's argument text.
                    if start >= 0 and end <= len(content_bytes):
                        full_call = self.analyzer.get_node_text(top_node, wrapper_bytes)
                        # Byte offsets -> character offsets by decoding the prefix.
                        start_char = len(content_bytes[:start].decode("utf8"))
                        end_char = len(content_bytes[:end].decode("utf8"))
                        out.append(
                            TargetCall(
                                receiver=None,
                                method_name=self.func_name,
                                arguments="",
                                full_call=full_call,
                                start_pos=base_offset + start_char,
                                end_pos=base_offset + end_char,
                            )
                        )
            else:
                # Direct target call: record it with receiver/arguments split out.
                start = node.start_byte - prefix_len
                end = node.end_byte - prefix_len
                if start >= 0 and end <= len(content_bytes):
                    out.append(self._build_target_call(node, wrapper_bytes, content_bytes, start, end, base_offset))
            # Do not descend into a matched call: anything nested in its
            # arguments is already part of the captured text.
            return
    for child in node.children:
        self._collect_target_invocations(child, wrapper_bytes, content_bytes, base_offset, out, seen_top_level)
def _build_target_call(
    self, node: Node, wrapper_bytes: bytes, content_bytes: bytes, start_byte: int, end_byte: int, base_offset: int
) -> TargetCall:
    """Build a TargetCall from a tree-sitter ``method_invocation`` node.

    Args:
        node: The invocation node (offsets relative to ``wrapper_bytes``).
        wrapper_bytes: Bytes of the parsed wrapper snippet.
        content_bytes: Bytes of the original argument text.
        start_byte: Start of the call, as a byte offset into ``content_bytes``.
        end_byte: End of the call, as a byte offset into ``content_bytes``.
        base_offset: Character offset of the argument text in the full source.
    """

    def node_text(ts_node: Node) -> str:
        return wrapper_bytes[ts_node.start_byte : ts_node.end_byte].decode("utf8")

    receiver_node = node.child_by_field_name("object")
    arg_list_node = node.child_by_field_name("arguments")

    arguments = node_text(arg_list_node) if arg_list_node else ""
    # The argument_list node includes its parentheses; drop them.
    if arguments and arguments[0] == "(" and arguments[-1] == ")":
        arguments = arguments[1:-1]

    to_char = self.analyzer.byte_to_char_index
    return TargetCall(
        receiver=node_text(receiver_node) if receiver_node else None,
        method_name=self.func_name,
        arguments=arguments,
        full_call=node_text(node),
        start_pos=base_offset + to_char(start_byte, content_bytes),
        end_pos=base_offset + to_char(end_byte, content_bytes),
    )
def _find_top_level_arg_node(self, target_node: Node, wrapper_bytes: bytes) -> Node | None:
    """Find the top-level argument expression containing a nested target call.

    Climbs from ``target_node`` toward the synthetic ``_d(...)`` wrapper call.
    A step counts as "inside a regular call" only when it leaves the
    argument_list of a ``method_invocation`` whose name does not start with
    "assert". Method chains such as ``.size()`` do not count.

    Returns:
        The wrapper argument that contains the target, if the climb passed
        through a regular call's arguments. None when the target is itself a
        direct argument or the wrapper is never reached.
    """
    inside_regular_call = False
    node = target_node
    while node.parent is not None:
        container = node.parent
        invocation = container.parent if container.type == "argument_list" else None
        if invocation is not None and invocation.type == "method_invocation":
            name_node = invocation.child_by_field_name("name")
            if name_node:
                callee = self.analyzer.get_node_text(name_node, wrapper_bytes)
                if callee == "_d":
                    # Reached the wrapper: `node` is one whole assertion argument.
                    if inside_regular_call and node != target_node:
                        return node
                    return None
                if not callee.startswith("assert"):
                    inside_regular_call = True
        node = container
    return None
def _detect_variable_assignment(self, source: str, assertion_start: int) -> tuple[str | None, str | None]:
    """Check whether the assertion initializes a local variable.

    Recognises ``Type var = assertXxx(...)`` on the assertion's own line,
    including generic types such as ``List<String> xs = ...``.

    Args:
        source: The full source code.
        assertion_start: Start index of the assertion.

    Returns:
        ``(variable_type, variable_name)``, or ``(None, None)`` when there is no
        such declaration.
    """
    # rfind returns -1 when there is no newline, so +1 yields 0 in that case too.
    line_start = source.rfind("\n", 0, assertion_start) + 1
    found = self._assign_re.search(source, line_start, assertion_start)
    if found is None:
        return None, None
    var_type, var_name = (group.strip() for group in found.groups())
    return var_type, var_name
def _extract_exception_class(self, args_content: str) -> str | None:
    """Extract the exception class from ``assertThrows`` arguments.

    Example:
        ``IllegalArgumentException.class, () -> f()`` -> ``"IllegalArgumentException"``

    Args:
        args_content: Text inside the assertThrows parentheses.

    Returns:
        The first argument without its ``.class`` suffix, or None when the
        first argument is not a class literal.
    """
    # Only the first argument matters: cut at the first comma outside (...)/<...>.
    depth = 0
    cut = len(args_content)
    for index, char in enumerate(args_content):
        if char in "(<":
            depth += 1
        elif char in ")>":
            depth -= 1
        elif char == "," and depth == 0:
            cut = index
            break

    first_arg = args_content[:cut].strip()
    if first_arg.endswith(".class"):
        return first_arg[: -len(".class")].strip()
    return None
def _extract_lambda_body(self, content: str) -> str | None:
    """Extract the body of the zero-argument lambda in an exception assertion.

    ``assertThrows(Exception.class, () -> code())`` yields ``"code()"`` and
    ``assertThrows(Exception.class, () -> { code(); })`` yields ``"code();"``.
    An expression body ends at the first top-level ``,`` or unmatched closing
    bracket. Brackets and commas inside string/char literals (escapes
    honoured) or nested ``()``/``[]``/``{}`` do not end it, so
    ``() -> p.parse("a)b")`` is returned whole.

    Args:
        content: Text between the assertion's parentheses.

    Returns:
        The stripped body, or None when there is no ``() ->`` lambda or a block
        body's braces do not balance.
    """
    lambda_match = re.search(r"\(\s*\)\s*->\s*", content)
    if not lambda_match:
        return None
    body_start = lambda_match.end()

    if content[body_start:].strip().startswith("{"):
        # Block lambda: return what is between the braces.
        brace_pos = content.index("{", body_start)
        _, block_end = self._find_balanced_braces(content, brace_pos)
        if block_end == -1:
            return None
        return content[brace_pos + 1 : block_end - 1].strip()

    # Expression lambda: scan forward to the end of the expression.
    depth = 0
    quote = None
    end = len(content)
    i = body_start
    while i < len(content):
        ch = content[i]
        if quote is not None:
            if ch == "\\":
                i += 2  # skip the escaped character
                continue
            if ch == quote:
                quote = None
        elif ch in "\"'":
            quote = ch
        elif ch in "([{":
            depth += 1
        elif ch in ")]}":
            if depth == 0:
                end = i
                break
            depth -= 1
        elif ch == "," and depth == 0:
            end = i
            break
        i += 1
    return content[body_start:end].strip()
def _find_balanced_parens(self, code: str, open_paren_pos: int) -> tuple[str | None, int]:
    """Find the text enclosed by the parenthesis at ``open_paren_pos``.

    Parentheses inside the following are ignored:
    - string literals, text blocks and char literals;
    - ``//`` and ``/* */`` comments.

    Escapes are honoured by skipping the character after each backslash, so a
    literal ending in an escaped backslash (``"C:\\"``) closes correctly, and an
    apostrophe in a comment (``// it's``) does not start a char literal.

    Args:
        code: The source code.
        open_paren_pos: Index of the opening parenthesis.

    Returns:
        ``(content inside the parens, index just past the closing paren)``, or
        ``(None, -1)`` if ``open_paren_pos`` is not ``(`` or the parentheses (or
        a literal/comment inside them) never close.
    """
    if open_paren_pos >= len(code) or code[open_paren_pos] != "(":
        return None, -1

    # Only these characters can change the scanner state. `re` caches compiled
    # patterns, so compiling here costs a cache lookup per call.
    interesting = re.compile(r"[\"'()/]")
    n = len(code)
    depth = 1
    pos = open_paren_pos + 1
    while depth > 0:
        m = interesting.search(code, pos)
        if m is None:
            return None, -1
        i = m.start()
        ch = code[i]
        if ch == "(":
            depth += 1
            pos = i + 1
        elif ch == ")":
            depth -= 1
            pos = i + 1
        elif ch == "/":
            follower = code[i + 1 : i + 2]
            if follower == "/":  # line comment
                newline = code.find("\n", i + 2)
                if newline == -1:
                    return None, -1
                pos = newline + 1
            elif follower == "*":  # block comment
                close = code.find("*/", i + 2)
                if close == -1:
                    return None, -1
                pos = close + 2
            else:  # division operator
                pos = i + 1
        elif code.startswith('"""', i):  # Java text block
            close = code.find('"""', i + 3)
            if close == -1:
                return None, -1
            pos = close + 3
        else:
            # String or char literal: advance to the matching unescaped quote.
            j = i + 1
            while j < n and code[j] != ch:
                j += 2 if code[j] == "\\" else 1
            if j >= n:
                return None, -1
            pos = j + 1
    return code[open_paren_pos + 1 : pos - 1], pos
def _find_balanced_braces(self, code: str, open_brace_pos: int) -> tuple[str | None, int]:
    """Find the text enclosed by the brace at ``open_brace_pos``.

    Braces inside the following are ignored:
    - string literals, text blocks and char literals;
    - ``//`` and ``/* */`` comments.

    Escapes are honoured by skipping the character after each backslash, so
    ``'\\\\'`` and ``"C:\\\\"`` close correctly.

    Args:
        code: The source code.
        open_brace_pos: Index of the opening brace.

    Returns:
        ``(content inside the braces, index just past the closing brace)``, or
        ``(None, -1)`` if ``open_brace_pos`` is not ``{`` or the braces (or a
        literal/comment inside them) never close.
    """
    if open_brace_pos >= len(code) or code[open_brace_pos] != "{":
        return None, -1

    # `re` caches compiled patterns, so compiling here costs a cache lookup.
    interesting = re.compile(r"[\"'{}/]")
    n = len(code)
    depth = 1
    pos = open_brace_pos + 1
    while depth > 0:
        m = interesting.search(code, pos)
        if m is None:
            return None, -1
        i = m.start()
        ch = code[i]
        if ch == "{":
            depth += 1
            pos = i + 1
        elif ch == "}":
            depth -= 1
            pos = i + 1
        elif ch == "/":
            follower = code[i + 1 : i + 2]
            if follower == "/":  # line comment
                newline = code.find("\n", i + 2)
                if newline == -1:
                    return None, -1
                pos = newline + 1
            elif follower == "*":  # block comment
                close = code.find("*/", i + 2)
                if close == -1:
                    return None, -1
                pos = close + 2
            else:  # division operator
                pos = i + 1
        elif code.startswith('"""', i):  # Java text block
            close = code.find('"""', i + 3)
            if close == -1:
                return None, -1
            pos = close + 3
        else:
            # String or char literal: advance to the matching unescaped quote.
            j = i + 1
            while j < n and code[j] != ch:
                j += 2 if code[j] == "\\" else 1
            if j >= n:
                return None, -1
            pos = j + 1
    return code[open_brace_pos + 1 : pos - 1], pos
def _infer_return_type(self, assertion: AssertionMatch) -> str:
    """Choose the Java type used to declare a captured result.

    - assertTrue/assertFalse: "boolean".
    - assertNull/assertNotNull: "Object".
    - JUnit value assertions: inferred from the expected literal.
    - Anything else, including fluent assertThat: "Object".
    """
    fixed_types = {
        "assertTrue": "boolean",
        "assertFalse": "boolean",
        "assertNull": "Object",
        "assertNotNull": "Object",
    }
    method = assertion.assertion_method
    fixed = fixed_types.get(method)
    if fixed is not None:
        return fixed
    if method in JUNIT5_VALUE_ASSERTIONS:
        return self._infer_type_from_assertion_args(assertion.original_text, method)
    return "Object"
# Java literal patterns for type inference, listed in the order
# _type_from_literal checks them. __init__ binds identical per-instance copies.
_FLOAT_LITERAL_RE = re.compile(r"^-?\d+\.?\d*[fF]$")  # 4.2f, 4f
_DOUBLE_LITERAL_RE = re.compile(r"^-?\d+\.\d*[dD]?$|^-?\d+[dD]$")  # 4.2, 4., 4.2d, 4d
_LONG_LITERAL_RE = re.compile(r"^-?\d+[lL]$")  # 42L
_INT_LITERAL_RE = re.compile(r"^-?\d+$")  # 42, -7
_CHAR_LITERAL_RE = re.compile(r"^'.'$|^'\\.'$")  # 'a', '\n'
def _infer_type_from_assertion_args(self, original_text: str, method: str) -> str:
    """Infer the declared type of a value assertion's captured result.

    Uses the literal type of the expected argument. That is the first argument,
    or the second for JUnit 4 style ``assertEquals("message", expected, actual)``
    (string first argument and at least three arguments).

    For ``assertEquals``/``assertNotEquals`` with an integral expected literal
    and a numeric third argument, e.g. ``assertEquals(5, foo(), 0.001)``, the
    call resolves to a floating-point ``(expected, actual, delta)`` overload. In
    that case "double" is returned: declaring ``int`` would not compile when
    ``foo()`` returns a floating-point value, while ``double`` accepts any
    numeric primitive.

    Args:
        original_text: Text of the assertion statement.
        method: Assertion method name.

    Returns:
        A Java type name, or "Object" when nothing can be inferred.
    """
    paren_idx = original_text.find("(")
    if paren_idx < 0:
        return "Object"
    args_str = original_text[paren_idx + 1 :].rstrip()
    if args_str.endswith(");"):
        args_str = args_str[:-2]
    elif args_str.endswith(")"):
        args_str = args_str[:-1]

    first_arg = self._extract_first_arg(args_str)
    if not first_arg:
        return "Object"
    expected = first_arg.strip()

    if method not in ("assertEquals", "assertNotEquals"):
        return self._type_from_literal(expected)

    all_args = self._split_top_level_args(args_str)
    offset = 0
    if expected.startswith('"') and len(all_args) >= 3:
        # JUnit 4: leading message string; the expected value comes next.
        offset = 1
        expected = all_args[1].strip()
    expected_type = self._type_from_literal(expected)

    if expected_type in ("int", "long", "short", "byte", "char") and len(all_args) == offset + 3:
        delta_type = self._type_from_literal(all_args[offset + 2].strip())
        if delta_type in ("int", "long", "float", "double"):
            return "double"
    return expected_type
def _type_from_literal(self, value: str) -> str:
    """Return the Java type of a literal, or "Object" when it is not one.

    Handles boolean/null keywords, numeric suffixes (f, d, L), char and
    string literals, and a leading primitive cast such as ``(byte)0``.
    """
    keywords = {"true": "boolean", "false": "boolean", "null": "Object"}
    if value in keywords:
        return keywords[value]

    literal_patterns = (
        (self._FLOAT_LITERAL_RE, "float"),
        (self._DOUBLE_LITERAL_RE, "double"),
        (self._LONG_LITERAL_RE, "long"),
        (self._INT_LITERAL_RE, "int"),
        (self._CHAR_LITERAL_RE, "char"),
    )
    for pattern, java_type in literal_patterns:
        if pattern.match(value):
            return java_type

    if value.startswith('"'):
        return "String"

    cast = self._cast_re.match(value)
    return cast.group(1) if cast else "Object"
def _split_top_level_args(self, args_str: str) -> list[str]:
    """Split argument text at commas that are not nested.

    Pieces are returned unstripped.

    - Fast path: text without any quote, brace or parenthesis is split on every
      comma, so a trailing comma yields a trailing empty piece.
    - Otherwise, commas inside (), <>, [], {} or string/char literals (with
      backslash escapes) do not split, and a trailing empty piece is dropped.
    """
    if self._special_re.search(args_str) is None:
        return args_str.split(",")

    pieces: list[str] = []
    buf: list[str] = []
    depth = 0
    quote = ""
    chars = iter(args_str)
    for ch in chars:
        if quote:
            buf.append(ch)
            if ch == "\\":
                escaped = next(chars, None)
                if escaped is not None:
                    buf.append(escaped)
            elif ch == quote:
                quote = ""
            continue
        if ch == "," and depth == 0:
            pieces.append("".join(buf))
            buf = []
            continue
        if ch in "\"'":
            quote = ch
        elif ch in "(<[{":
            depth += 1
        elif ch in ")>]}":
            depth -= 1
        buf.append(ch)
    if buf:
        pieces.append("".join(buf))
    return pieces
def _generate_replacement(self, assertion: AssertionMatch) -> str:
    """Produce the code that replaces one assertion.

    - Exception assertions are delegated to ``_generate_exception_replacement``.
    - Assertions without target calls are removed (empty string).
    - In strip mode, or for void targets, bare calls are emitted via
      ``_generate_strip_replacement``.
    - Otherwise each target call becomes
      ``<type> _cf_resultN = <call>;`` on its own line, with ``N`` continuing
      ``self.invocation_counter``.

    Args:
        assertion: The assertion to replace.

    Returns:
        Replacement code string.
    """
    if assertion.is_exception_assertion:
        return self._generate_exception_replacement(assertion)
    calls = assertion.target_calls
    if not calls:
        return ""
    if self.mode == "strip" or self.target_return_type == "void":
        return self._generate_strip_replacement(assertion)

    # Declare with the inferred type to avoid Object->primitive cast errors.
    return_type = self._infer_return_type(assertion)
    first_ws = assertion.leading_whitespace
    # Later statements reuse only the indentation, not the leading newline(s).
    indent = first_ws.lstrip("\n\r")
    base = self.invocation_counter
    statements = [
        f"{first_ws if offset == 1 else indent}{return_type} _cf_result{base + offset} = {call.full_call};"
        for offset, call in enumerate(calls, start=1)
    ]
    self.invocation_counter = base + len(calls)
    return "\n".join(statements)
def _generate_strip_replacement(self, assertion: AssertionMatch) -> str:
"""Generate clean replacement for strip mode: bare function calls, no capture variables."""
replacements: list[str] = []
leading_ws = assertion.leading_whitespace
base_indent = leading_ws.lstrip("\n\r")
calls = assertion.target_calls
if calls:
replacements.append(f"{leading_ws}{calls[0].full_call};")
for call in calls[1:]:
replacements.append(f"{base_indent}{call.full_call};")
return "\n".join(replacements)
def _generate_exception_replacement(self, assertion: AssertionMatch) -> str:
"""Generate replacement for assertThrows/assertDoesNotThrow.
Transforms:
assertThrows(Exception.class, () -> calculator.divide(1, 0));
To:
try { calculator.divide(1, 0); } catch (Exception _cf_ignored1) {}
When assigned to a variable:
IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> code());
To:
IllegalArgumentException ex = null;
try { code(); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; } catch (Exception _cf_ignored1) {}
In strip mode, exception assertions emit just the lambda body as a bare call
(or try/catch without capture variables).
"""
ws = assertion.leading_whitespace
if self.mode == "strip":
return self._generate_strip_exception_replacement(assertion)
# Increment invocation counter once for this exception handling
inv = self.invocation_counter + 1
self.invocation_counter = inv
counter = inv
base_indent = ws.lstrip("\n\r")
# Extract code to run from lambda body or target calls
code_to_run = None
if assertion.lambda_body:
code_to_run = assertion.lambda_body
# Use a direct last-character check instead of .endswith for lower overhead
if code_to_run and code_to_run[-1] != ";":
code_to_run += ";"
# Handle variable assignment: Type var = assertThrows(...)
if assertion.assigned_var_name and assertion.assigned_var_type:
var_type = assertion.assigned_var_type
var_name = assertion.assigned_var_name
if assertion.assertion_method == "assertDoesNotThrow":
if ";" not in assertion.lambda_body.strip():
return f"{ws}{var_type} {var_name} = {assertion.lambda_body.strip()};"
return f"{ws}{code_to_run}"
# For assertThrows with variable assignment, use exception_class if available
exception_type = assertion.exception_class or var_type
return (
f"{ws}{var_type} {var_name} = null;\n"
f"{base_indent}try {{ {code_to_run} }} "
f"catch ({exception_type} _cf_caught{counter}) {{ {var_name} = _cf_caught{counter}; }} "
f"catch (Exception _cf_ignored{counter}) {{}}"
)
return f"{ws}try {{ {code_to_run} }} catch (Exception _cf_ignored{counter}) {{}}"
# If no lambda body found, try to extract from target calls
if assertion.target_calls:
call = assertion.target_calls[0]
return f"{ws}try {{ {call.full_call}; }} catch (Exception _cf_ignored{counter}) {{}}"
# Fallback: comment out the assertion
return f"{ws}// Removed assertThrows: could not extract callable"
def _generate_strip_exception_replacement(self, assertion: AssertionMatch) -> str:
"""Generate clean replacement for exception assertions in strip mode."""
ws = assertion.leading_whitespace
# Extract code to run from lambda body or target calls
if assertion.lambda_body:
code_to_run = assertion.lambda_body.strip()
if code_to_run and code_to_run[-1] != ";":
code_to_run += ";"
exception_type = assertion.exception_class or "Exception"
return f"{ws}try {{ {code_to_run} }} catch ({exception_type} ignored) {{}}"
if assertion.target_calls:
call = assertion.target_calls[0]
return f"{ws}try {{ {call.full_call}; }} catch (Exception ignored) {{}}"
return ""
def _extract_first_arg(self, args_str: str) -> str | None:
"""Extract the first top-level argument from args_str.
This is a lightweight alternative to splitting all top-level arguments;
it stops at the first top-level comma, respects nested delimiters and strings,
and avoids constructing the full argument list for better performance.
"""
n = len(args_str)
i = 0
# skip leading whitespace
while i < n and args_str[i].isspace():
i += 1
if i >= n:
return None
depth = 0
in_string = False
string_char = ""
cur: list[str] = []
while i < n:
ch = args_str[i]
if in_string:
cur.append(ch)
if ch == "\\" and i + 1 < n:
i += 1
cur.append(args_str[i])
elif ch == string_char:
in_string = False
elif ch in ('"', "'"):
in_string = True
string_char = ch
cur.append(ch)
elif ch in ("(", "<", "[", "{"):
depth += 1
cur.append(ch)
elif ch in (")", ">", "]", "}"):
depth -= 1
cur.append(ch)
elif ch == "," and depth == 0:
break
else:
cur.append(ch)
i += 1
# Trim trailing whitespace from the extracted argument
if not cur:
return None
return "".join(cur).rstrip()
def transform_java_assertions(
    source: str, function_name: str, qualified_name: str | None = None, target_return_type: str = ""
) -> str:
    """Rewrite a Java test so its assertions become captures of the target's results.

    Primary entry point of this module: builds a ``JavaAssertTransformer`` in its
    default (capture) mode and runs it over ``source``.

    Args:
        source: The Java test source code.
        function_name: Name of the function being tested.
        qualified_name: Optional fully qualified name of the function.
        target_return_type: Return type of the target function (e.g., "void", "int").

    Returns:
        Transformed source code with assertions replaced by capture statements.
    """
    return JavaAssertTransformer(
        function_name=function_name,
        qualified_name=qualified_name,
        target_return_type=target_return_type,
    ).transform(source)
def strip_java_assertions(source: str, function_name: str, qualified_name: str | None = None) -> str:
    """Remove assertions from a Java test for clean display in PRs.

    Runs ``JavaAssertTransformer`` in ``"strip"`` mode. Unlike
    ``transform_java_assertions`` (capture mode), the output has no synthetic variables:
    - assertions that call the target function become bare calls to it;
    - assertions that do not call it are removed entirely;
    - exception assertions become plain try/catch blocks without numbered variables.

    Args:
        source: The Java test source code.
        function_name: Name of the function being tested.
        qualified_name: Optional fully qualified name of the function.

    Returns:
        Clean source code suitable for display in PRs.
    """
    stripper = JavaAssertTransformer(
        function_name=function_name,
        qualified_name=qualified_name,
        mode="strip",
    )
    return stripper.transform(source)
def remove_assertions_from_test(source: str, target_function: FunctionToOptimize) -> str:
    """Remove assertions from test code, using a ``FunctionToOptimize`` as the target.

    Convenience wrapper around ``transform_java_assertions`` that reads the function's
    plain and qualified names from ``target_function``.

    NOTE(review): ``target_return_type`` is not forwarded, so this path uses the
    transformer's default (empty) return type rather than the ``"void"`` handling.
    Confirm whether ``FunctionToOptimize`` exposes a return type that should be passed.

    Args:
        source: The Java test source code.
        target_function: The function being optimized.

    Returns:
        Transformed source code.
    """
    name = target_function.function_name
    qualified = target_function.qualified_name
    return transform_java_assertions(source=source, function_name=name, qualified_name=qualified)