fix: handle assertThrows variable assignment in Java instrumentation

When assertThrows was assigned to a variable to validate exception
properties, the transformation generated invalid Java syntax by
replacing the assertThrows call with try-catch while leaving the
variable assignment intact.

Example of invalid output:
  IllegalArgumentException e = try { code(); } catch (Exception) {}

This fix detects variable assignments, extracts the exception type
from assertThrows arguments, and generates proper exception capture:
  IllegalArgumentException e = null;
  try { code(); } catch (IllegalArgumentException _cf_caught1) { e = _cf_caught1; } catch (Exception _cf_ignored1) {}

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
Mohamed Ashraf 2026-02-10 21:22:21 +00:00
parent df5b6a28f7
commit e207b83a87
2 changed files with 247 additions and 15 deletions

View file

@ -166,6 +166,9 @@ class AssertionMatch:
original_text: str = ""
is_exception_assertion: bool = False
lambda_body: str | None = None # For assertThrows lambda content
variable_type: str | None = None # Type of assigned variable (e.g., "IllegalArgumentException")
variable_name: str | None = None # Name of assigned variable (e.g., "exception")
exception_class: str | None = None # Exception class from assertThrows args
class JavaAssertTransformer:
@ -326,12 +329,32 @@ class JavaAssertTransformer:
target_calls = self._extract_target_calls(args_content, match.end())
is_exception = assertion_method in JUNIT5_EXCEPTION_ASSERTIONS
# For assertThrows, extract the lambda body
# For assertThrows, extract the lambda body and exception class
lambda_body = None
exception_class = None
if is_exception and assertion_method == "assertThrows":
lambda_body = self._extract_lambda_body(args_content)
exception_class = self._extract_exception_class(args_content)
original_text = source[start_pos:end_pos]
# Check if assertion is assigned to a variable
var_type, var_name = self._detect_variable_assignment(source, start_pos)
# If variable assignment detected, adjust start_pos to include the entire line
actual_start = start_pos
actual_leading_ws = leading_ws
if var_type:
# Find the start of the line (beginning of variable declaration)
line_start = source.rfind("\n", 0, start_pos)
if line_start == -1:
line_start = 0
else:
line_start += 1
actual_start = line_start
# Extract the actual leading whitespace from the start of the line
line_content = source[line_start:start_pos]
actual_leading_ws = line_content[:len(line_content) - len(line_content.lstrip())]
original_text = source[actual_start:end_pos]
# Determine statement type based on detected framework
detected = self._detected_framework or "junit5"
@ -342,15 +365,18 @@ class JavaAssertTransformer:
assertions.append(
AssertionMatch(
start_pos=start_pos,
start_pos=actual_start,
end_pos=end_pos,
statement_type=stmt_type,
assertion_method=assertion_method,
target_calls=target_calls,
leading_whitespace=leading_ws,
leading_whitespace=actual_leading_ws,
original_text=original_text,
is_exception_assertion=is_exception,
lambda_body=lambda_body,
variable_type=var_type,
variable_name=var_name,
exception_class=exception_class,
)
)
@ -580,6 +606,85 @@ class JavaAssertTransformer:
return target_calls
def _detect_variable_assignment(self, source: str, assertion_start: int) -> tuple[str | None, str | None]:
"""Check if assertion is assigned to a variable.
Detects patterns like:
IllegalArgumentException exception = assertThrows(...)
Exception ex = assertThrows(...)
Args:
source: The full source code.
assertion_start: Start position of the assertion.
Returns:
Tuple of (variable_type, variable_name) or (None, None).
"""
# Look backwards from assertion_start to beginning of line
line_start = source.rfind("\n", 0, assertion_start)
if line_start == -1:
line_start = 0
else:
line_start += 1
line_before_assert = source[line_start:assertion_start]
# Pattern: Type varName = assertXxx(...)
# Handle generic types: Type<Generic> varName = ...
pattern = r"(\w+(?:<[^>]+>)?)\s+(\w+)\s*=\s*$"
match = re.search(pattern, line_before_assert)
if match:
var_type = match.group(1).strip()
var_name = match.group(2).strip()
return var_type, var_name
return None, None
def _extract_exception_class(self, args_content: str) -> str | None:
"""Extract exception class from assertThrows arguments.
Args:
args_content: Content inside assertThrows parentheses.
Returns:
Exception class name (e.g., "IllegalArgumentException") or None.
Example:
assertThrows(IllegalArgumentException.class, ...) -> "IllegalArgumentException"
"""
# First argument is the exception class reference (e.g., "IllegalArgumentException.class")
# Split by comma, but respect nested parentheses and generics
depth = 0
current = []
parts = []
for char in args_content:
if char in "(<":
depth += 1
current.append(char)
elif char in ")>":
depth -= 1
current.append(char)
elif char == "," and depth == 0:
parts.append("".join(current).strip())
current = []
else:
current.append(char)
if current:
parts.append("".join(current).strip())
if parts:
exception_arg = parts[0].strip()
# Remove .class suffix
if exception_arg.endswith(".class"):
return exception_arg[:-6].strip()
return None
def _extract_lambda_body(self, content: str) -> str | None:
"""Extract the body of a lambda expression from assertThrows arguments.
@ -745,29 +850,53 @@ class JavaAssertTransformer:
To:
try { calculator.divide(1, 0); } catch (Exception _cf_ignored1) {}
For variable assignments:
IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> code());
To:
IllegalArgumentException ex = null;
try { code(); } catch (IllegalArgumentException e) { ex = e; } catch (Exception _cf_ignored1) {}
"""
self.invocation_counter += 1
# Extract code to run from lambda body or target calls
code_to_run = None
if assertion.lambda_body:
# Extract the actual code from the lambda
code_to_run = assertion.lambda_body
if not code_to_run.endswith(";"):
code_to_run += ";"
return (
f"{assertion.leading_whitespace}try {{ {code_to_run} }} "
f"catch (Exception _cf_ignored{self.invocation_counter}) {{}}"
)
# If no lambda body found, try to extract from target calls
if assertion.target_calls:
elif assertion.target_calls:
call = assertion.target_calls[0]
code_to_run = call.full_call + ";"
if not code_to_run:
# Fallback: comment out the assertion
return f"{assertion.leading_whitespace}// Removed assertThrows: could not extract callable"
# Check if assertion is assigned to a variable
if assertion.variable_name and assertion.variable_type:
# Generate proper exception capture with variable assignment
exception_type = assertion.exception_class or assertion.variable_type
var_name = assertion.variable_name
# Use a unique catch variable name to avoid conflicts
catch_var = f"_cf_caught{self.invocation_counter}"
# Get base indentation from leading whitespace (without newlines)
base_indent = assertion.leading_whitespace.lstrip("\n\r")
return (
f"{assertion.leading_whitespace}try {{ {call.full_call}; }} "
f"{assertion.leading_whitespace}{assertion.variable_type} {var_name} = null;\n"
f"{base_indent}try {{ {code_to_run} }} "
f"catch ({exception_type} {catch_var}) {{ {var_name} = {catch_var}; }} "
f"catch (Exception _cf_ignored{self.invocation_counter}) {{}}"
)
# Fallback: comment out the assertion
return f"{assertion.leading_whitespace}// Removed assertThrows: could not extract callable"
# No variable assignment, use simple try-catch
return (
f"{assertion.leading_whitespace}try {{ {code_to_run} }} "
f"catch (Exception _cf_ignored{self.invocation_counter}) {{}}"
)
def transform_java_assertions(source: str, function_name: str, qualified_name: str | None = None) -> str:

View file

@ -1255,3 +1255,106 @@ void testSynchronizedMethodWithAssertJ() {
}"""
result = transform_java_assertions(source, "incrementAndGet")
assert result == expected
class TestAssertThrowsVariableAssignment:
"""Tests for assertThrows with variable assignment (Issue: exception handling instrumentation bug)."""
def test_assert_throws_with_variable_assignment_expression_lambda(self):
"""Test assertThrows assigned to variable with expression lambda."""
source = """\
@Test
void testNegativeInput() {
IllegalArgumentException exception = assertThrows(
IllegalArgumentException.class,
() -> calculator.fibonacci(-1)
);
assertEquals("Negative input not allowed", exception.getMessage());
}"""
expected = """\
@Test
void testNegativeInput() {
IllegalArgumentException exception = null;
try { calculator.fibonacci(-1); } catch (IllegalArgumentException _cf_caught1) { exception = _cf_caught1; } catch (Exception _cf_ignored1) {}
assertEquals("Negative input not allowed", exception.getMessage());
}"""
result = transform_java_assertions(source, "fibonacci")
assert result == expected
def test_assert_throws_with_variable_assignment_block_lambda(self):
"""Test assertThrows assigned to variable with block lambda."""
source = """\
@Test
void testInvalidOperation() {
ArithmeticException ex = assertThrows(ArithmeticException.class, () -> {
calculator.divide(10, 0);
});
assertEquals("Division by zero", ex.getMessage());
}"""
expected = """\
@Test
void testInvalidOperation() {
ArithmeticException ex = null;
try { calculator.divide(10, 0); } catch (ArithmeticException _cf_caught1) { ex = _cf_caught1; } catch (Exception _cf_ignored1) {}
assertEquals("Division by zero", ex.getMessage());
}"""
result = transform_java_assertions(source, "divide")
assert result == expected
def test_assert_throws_with_variable_assignment_generic_exception(self):
"""Test assertThrows with generic Exception type."""
source = """\
@Test
void testGenericException() {
Exception e = assertThrows(Exception.class, () -> processor.process(null));
assertNotNull(e.getMessage());
}"""
expected = """\
@Test
void testGenericException() {
Exception e = null;
try { processor.process(null); } catch (Exception _cf_caught1) { e = _cf_caught1; } catch (Exception _cf_ignored1) {}
assertNotNull(e.getMessage());
}"""
result = transform_java_assertions(source, "process")
assert result == expected
def test_assert_throws_without_variable_assignment(self):
"""Test assertThrows without variable assignment still works (no regression)."""
source = """\
@Test
void testThrowsException() {
assertThrows(IllegalArgumentException.class, () -> calculator.fibonacci(-1));
}"""
expected = """\
@Test
void testThrowsException() {
try { calculator.fibonacci(-1); } catch (Exception _cf_ignored1) {}
}"""
result = transform_java_assertions(source, "fibonacci")
assert result == expected
def test_assert_throws_with_variable_and_multi_line_lambda(self):
"""Test assertThrows with variable assignment and multi-line lambda."""
source = """\
@Test
void testComplexException() {
IllegalStateException exception = assertThrows(
IllegalStateException.class,
() -> {
processor.initialize();
processor.execute();
}
);
assertTrue(exception.getMessage().contains("not initialized"));
}"""
expected = """\
@Test
void testComplexException() {
IllegalStateException exception = null;
try { processor.initialize();
processor.execute(); } catch (IllegalStateException _cf_caught1) { exception = _cf_caught1; } catch (Exception _cf_ignored1) {}
assertTrue(exception.getMessage().contains("not initialized"));
}"""
result = transform_java_assertions(source, "execute")
assert result == expected