mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
fix: handle assertThrows variable assignment in Java instrumentation
When assertThrows was assigned to a variable to validate exception
properties, the transformation generated invalid Java syntax by
replacing the assertThrows call with try-catch while leaving the
variable assignment intact.
Example of invalid output:
IllegalArgumentException e = try { code(); } catch (Exception) {}
This fix detects variable assignments, extracts the exception type
from assertThrows arguments, and generates proper exception capture:
IllegalArgumentException e = null;
try { code(); } catch (IllegalArgumentException _cf_caught1) { e = _cf_caught1; } catch (Exception _cf_ignored1) {}
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
df5b6a28f7
commit
e207b83a87
2 changed files with 247 additions and 15 deletions
|
|
@ -166,6 +166,9 @@ class AssertionMatch:
|
|||
original_text: str = ""
|
||||
is_exception_assertion: bool = False
|
||||
lambda_body: str | None = None # For assertThrows lambda content
|
||||
variable_type: str | None = None # Type of assigned variable (e.g., "IllegalArgumentException")
|
||||
variable_name: str | None = None # Name of assigned variable (e.g., "exception")
|
||||
exception_class: str | None = None # Exception class from assertThrows args
|
||||
|
||||
|
||||
class JavaAssertTransformer:
|
||||
|
|
@ -326,12 +329,32 @@ class JavaAssertTransformer:
|
|||
target_calls = self._extract_target_calls(args_content, match.end())
|
||||
is_exception = assertion_method in JUNIT5_EXCEPTION_ASSERTIONS
|
||||
|
||||
# For assertThrows, extract the lambda body
|
||||
# For assertThrows, extract the lambda body and exception class
|
||||
lambda_body = None
|
||||
exception_class = None
|
||||
if is_exception and assertion_method == "assertThrows":
|
||||
lambda_body = self._extract_lambda_body(args_content)
|
||||
exception_class = self._extract_exception_class(args_content)
|
||||
|
||||
original_text = source[start_pos:end_pos]
|
||||
# Check if assertion is assigned to a variable
|
||||
var_type, var_name = self._detect_variable_assignment(source, start_pos)
|
||||
|
||||
# If variable assignment detected, adjust start_pos to include the entire line
|
||||
actual_start = start_pos
|
||||
actual_leading_ws = leading_ws
|
||||
if var_type:
|
||||
# Find the start of the line (beginning of variable declaration)
|
||||
line_start = source.rfind("\n", 0, start_pos)
|
||||
if line_start == -1:
|
||||
line_start = 0
|
||||
else:
|
||||
line_start += 1
|
||||
actual_start = line_start
|
||||
# Extract the actual leading whitespace from the start of the line
|
||||
line_content = source[line_start:start_pos]
|
||||
actual_leading_ws = line_content[:len(line_content) - len(line_content.lstrip())]
|
||||
|
||||
original_text = source[actual_start:end_pos]
|
||||
|
||||
# Determine statement type based on detected framework
|
||||
detected = self._detected_framework or "junit5"
|
||||
|
|
@ -342,15 +365,18 @@ class JavaAssertTransformer:
|
|||
|
||||
assertions.append(
|
||||
AssertionMatch(
|
||||
start_pos=start_pos,
|
||||
start_pos=actual_start,
|
||||
end_pos=end_pos,
|
||||
statement_type=stmt_type,
|
||||
assertion_method=assertion_method,
|
||||
target_calls=target_calls,
|
||||
leading_whitespace=leading_ws,
|
||||
leading_whitespace=actual_leading_ws,
|
||||
original_text=original_text,
|
||||
is_exception_assertion=is_exception,
|
||||
lambda_body=lambda_body,
|
||||
variable_type=var_type,
|
||||
variable_name=var_name,
|
||||
exception_class=exception_class,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -580,6 +606,85 @@ class JavaAssertTransformer:
|
|||
|
||||
return target_calls
|
||||
|
||||
def _detect_variable_assignment(self, source: str, assertion_start: int) -> tuple[str | None, str | None]:
|
||||
"""Check if assertion is assigned to a variable.
|
||||
|
||||
Detects patterns like:
|
||||
IllegalArgumentException exception = assertThrows(...)
|
||||
Exception ex = assertThrows(...)
|
||||
|
||||
Args:
|
||||
source: The full source code.
|
||||
assertion_start: Start position of the assertion.
|
||||
|
||||
Returns:
|
||||
Tuple of (variable_type, variable_name) or (None, None).
|
||||
|
||||
"""
|
||||
# Look backwards from assertion_start to beginning of line
|
||||
line_start = source.rfind("\n", 0, assertion_start)
|
||||
if line_start == -1:
|
||||
line_start = 0
|
||||
else:
|
||||
line_start += 1
|
||||
|
||||
line_before_assert = source[line_start:assertion_start]
|
||||
|
||||
# Pattern: Type varName = assertXxx(...)
|
||||
# Handle generic types: Type<Generic> varName = ...
|
||||
pattern = r"(\w+(?:<[^>]+>)?)\s+(\w+)\s*=\s*$"
|
||||
match = re.search(pattern, line_before_assert)
|
||||
|
||||
if match:
|
||||
var_type = match.group(1).strip()
|
||||
var_name = match.group(2).strip()
|
||||
return var_type, var_name
|
||||
|
||||
return None, None
|
||||
|
||||
def _extract_exception_class(self, args_content: str) -> str | None:
|
||||
"""Extract exception class from assertThrows arguments.
|
||||
|
||||
Args:
|
||||
args_content: Content inside assertThrows parentheses.
|
||||
|
||||
Returns:
|
||||
Exception class name (e.g., "IllegalArgumentException") or None.
|
||||
|
||||
Example:
|
||||
assertThrows(IllegalArgumentException.class, ...) -> "IllegalArgumentException"
|
||||
|
||||
"""
|
||||
# First argument is the exception class reference (e.g., "IllegalArgumentException.class")
|
||||
# Split by comma, but respect nested parentheses and generics
|
||||
depth = 0
|
||||
current = []
|
||||
parts = []
|
||||
|
||||
for char in args_content:
|
||||
if char in "(<":
|
||||
depth += 1
|
||||
current.append(char)
|
||||
elif char in ")>":
|
||||
depth -= 1
|
||||
current.append(char)
|
||||
elif char == "," and depth == 0:
|
||||
parts.append("".join(current).strip())
|
||||
current = []
|
||||
else:
|
||||
current.append(char)
|
||||
|
||||
if current:
|
||||
parts.append("".join(current).strip())
|
||||
|
||||
if parts:
|
||||
exception_arg = parts[0].strip()
|
||||
# Remove .class suffix
|
||||
if exception_arg.endswith(".class"):
|
||||
return exception_arg[:-6].strip()
|
||||
|
||||
return None
|
||||
|
||||
def _extract_lambda_body(self, content: str) -> str | None:
|
||||
"""Extract the body of a lambda expression from assertThrows arguments.
|
||||
|
||||
|
|
@ -745,29 +850,53 @@ class JavaAssertTransformer:
|
|||
To:
|
||||
try { calculator.divide(1, 0); } catch (Exception _cf_ignored1) {}
|
||||
|
||||
For variable assignments:
|
||||
IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> code());
|
||||
To:
|
||||
IllegalArgumentException ex = null;
|
||||
try { code(); } catch (IllegalArgumentException e) { ex = e; } catch (Exception _cf_ignored1) {}
|
||||
|
||||
"""
|
||||
self.invocation_counter += 1
|
||||
|
||||
# Extract code to run from lambda body or target calls
|
||||
code_to_run = None
|
||||
if assertion.lambda_body:
|
||||
# Extract the actual code from the lambda
|
||||
code_to_run = assertion.lambda_body
|
||||
if not code_to_run.endswith(";"):
|
||||
code_to_run += ";"
|
||||
return (
|
||||
f"{assertion.leading_whitespace}try {{ {code_to_run} }} "
|
||||
f"catch (Exception _cf_ignored{self.invocation_counter}) {{}}"
|
||||
)
|
||||
|
||||
# If no lambda body found, try to extract from target calls
|
||||
if assertion.target_calls:
|
||||
elif assertion.target_calls:
|
||||
call = assertion.target_calls[0]
|
||||
code_to_run = call.full_call + ";"
|
||||
|
||||
if not code_to_run:
|
||||
# Fallback: comment out the assertion
|
||||
return f"{assertion.leading_whitespace}// Removed assertThrows: could not extract callable"
|
||||
|
||||
# Check if assertion is assigned to a variable
|
||||
if assertion.variable_name and assertion.variable_type:
|
||||
# Generate proper exception capture with variable assignment
|
||||
exception_type = assertion.exception_class or assertion.variable_type
|
||||
var_name = assertion.variable_name
|
||||
|
||||
# Use a unique catch variable name to avoid conflicts
|
||||
catch_var = f"_cf_caught{self.invocation_counter}"
|
||||
|
||||
# Get base indentation from leading whitespace (without newlines)
|
||||
base_indent = assertion.leading_whitespace.lstrip("\n\r")
|
||||
|
||||
return (
|
||||
f"{assertion.leading_whitespace}try {{ {call.full_call}; }} "
|
||||
f"{assertion.leading_whitespace}{assertion.variable_type} {var_name} = null;\n"
|
||||
f"{base_indent}try {{ {code_to_run} }} "
|
||||
f"catch ({exception_type} {catch_var}) {{ {var_name} = {catch_var}; }} "
|
||||
f"catch (Exception _cf_ignored{self.invocation_counter}) {{}}"
|
||||
)
|
||||
|
||||
# Fallback: comment out the assertion
|
||||
return f"{assertion.leading_whitespace}// Removed assertThrows: could not extract callable"
|
||||
# No variable assignment, use simple try-catch
|
||||
return (
|
||||
f"{assertion.leading_whitespace}try {{ {code_to_run} }} "
|
||||
f"catch (Exception _cf_ignored{self.invocation_counter}) {{}}"
|
||||
)
|
||||
|
||||
|
||||
def transform_java_assertions(source: str, function_name: str, qualified_name: str | None = None) -> str:
|
||||
|
|
|
|||
|
|
@ -1255,3 +1255,106 @@ void testSynchronizedMethodWithAssertJ() {
|
|||
}"""
|
||||
result = transform_java_assertions(source, "incrementAndGet")
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestAssertThrowsVariableAssignment:
|
||||
"""Tests for assertThrows with variable assignment (Issue: exception handling instrumentation bug)."""
|
||||
|
||||
def test_assert_throws_with_variable_assignment_expression_lambda(self):
|
||||
"""Test assertThrows assigned to variable with expression lambda."""
|
||||
source = """\
|
||||
@Test
|
||||
void testNegativeInput() {
|
||||
IllegalArgumentException exception = assertThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> calculator.fibonacci(-1)
|
||||
);
|
||||
assertEquals("Negative input not allowed", exception.getMessage());
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testNegativeInput() {
|
||||
IllegalArgumentException exception = null;
|
||||
try { calculator.fibonacci(-1); } catch (IllegalArgumentException _cf_caught1) { exception = _cf_caught1; } catch (Exception _cf_ignored1) {}
|
||||
assertEquals("Negative input not allowed", exception.getMessage());
|
||||
}"""
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == expected
|
||||
|
||||
def test_assert_throws_with_variable_assignment_block_lambda(self):
|
||||
"""Test assertThrows assigned to variable with block lambda."""
|
||||
source = """\
|
||||
@Test
|
||||
void testInvalidOperation() {
|
||||
ArithmeticException ex = assertThrows(ArithmeticException.class, () -> {
|
||||
calculator.divide(10, 0);
|
||||
});
|
||||
assertEquals("Division by zero", ex.getMessage());
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testInvalidOperation() {
|
||||
ArithmeticException ex = null;
|
||||
try { calculator.divide(10, 0); } catch (ArithmeticException _cf_caught1) { ex = _cf_caught1; } catch (Exception _cf_ignored1) {}
|
||||
assertEquals("Division by zero", ex.getMessage());
|
||||
}"""
|
||||
result = transform_java_assertions(source, "divide")
|
||||
assert result == expected
|
||||
|
||||
def test_assert_throws_with_variable_assignment_generic_exception(self):
|
||||
"""Test assertThrows with generic Exception type."""
|
||||
source = """\
|
||||
@Test
|
||||
void testGenericException() {
|
||||
Exception e = assertThrows(Exception.class, () -> processor.process(null));
|
||||
assertNotNull(e.getMessage());
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testGenericException() {
|
||||
Exception e = null;
|
||||
try { processor.process(null); } catch (Exception _cf_caught1) { e = _cf_caught1; } catch (Exception _cf_ignored1) {}
|
||||
assertNotNull(e.getMessage());
|
||||
}"""
|
||||
result = transform_java_assertions(source, "process")
|
||||
assert result == expected
|
||||
|
||||
def test_assert_throws_without_variable_assignment(self):
|
||||
"""Test assertThrows without variable assignment still works (no regression)."""
|
||||
source = """\
|
||||
@Test
|
||||
void testThrowsException() {
|
||||
assertThrows(IllegalArgumentException.class, () -> calculator.fibonacci(-1));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testThrowsException() {
|
||||
try { calculator.fibonacci(-1); } catch (Exception _cf_ignored1) {}
|
||||
}"""
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == expected
|
||||
|
||||
def test_assert_throws_with_variable_and_multi_line_lambda(self):
|
||||
"""Test assertThrows with variable assignment and multi-line lambda."""
|
||||
source = """\
|
||||
@Test
|
||||
void testComplexException() {
|
||||
IllegalStateException exception = assertThrows(
|
||||
IllegalStateException.class,
|
||||
() -> {
|
||||
processor.initialize();
|
||||
processor.execute();
|
||||
}
|
||||
);
|
||||
assertTrue(exception.getMessage().contains("not initialized"));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testComplexException() {
|
||||
IllegalStateException exception = null;
|
||||
try { processor.initialize();
|
||||
processor.execute(); } catch (IllegalStateException _cf_caught1) { exception = _cf_caught1; } catch (Exception _cf_ignored1) {}
|
||||
assertTrue(exception.getMessage().contains("not initialized"));
|
||||
}"""
|
||||
result = transform_java_assertions(source, "execute")
|
||||
assert result == expected
|
||||
|
|
|
|||
Loading…
Reference in a new issue