fix asserts

This commit is contained in:
HeshamHM28 2026-02-11 01:59:04 +02:00
parent df5b6a28f7
commit 4740725af7
2 changed files with 188 additions and 11 deletions

View file

@ -166,6 +166,8 @@ class AssertionMatch:
original_text: str = ""
is_exception_assertion: bool = False
lambda_body: str | None = None # For assertThrows lambda content
assigned_var_type: str | None = None # For Type var = assertThrows(...)
assigned_var_name: str | None = None
class JavaAssertTransformer:
@ -300,8 +302,11 @@ class JavaAssertTransformer:
# - assertEquals (static import)
# - Assert.assertEquals (JUnit 4)
# - Assertions.assertEquals (JUnit 5)
# - org.junit.jupiter.api.Assertions.assertEquals (fully qualified)
all_assertions = "|".join(JUNIT5_ALL_ASSERTIONS)
pattern = re.compile(rf"(\s*)((?:Assert(?:ions)?\.)?({all_assertions}))\s*\(", re.MULTILINE)
pattern = re.compile(
rf"(\s*)((?:(?:\w+\.)*Assert(?:ions)?\.)?({all_assertions}))\s*\(", re.MULTILINE
)
for match in pattern.finditer(source):
leading_ws = match.group(1)
@ -326,13 +331,38 @@ class JavaAssertTransformer:
target_calls = self._extract_target_calls(args_content, match.end())
is_exception = assertion_method in JUNIT5_EXCEPTION_ASSERTIONS
# For assertThrows, extract the lambda body
# For exception assertions, extract the lambda body
lambda_body = None
if is_exception and assertion_method == "assertThrows":
if is_exception:
lambda_body = self._extract_lambda_body(args_content)
original_text = source[start_pos:end_pos]
# Detect variable assignment: Type var = assertXxx(...)
# This applies to all assertions (assertThrows, assertTimeout, etc.)
assigned_var_type = None
assigned_var_name = None
before = source[:start_pos]
last_nl_idx = before.rfind("\n")
if last_nl_idx >= 0:
line_prefix = source[last_nl_idx + 1 : start_pos]
else:
line_prefix = source[:start_pos]
var_match = re.match(r"([ \t]*)(?:final\s+)?([\w.<>\[\]]+)\s+(\w+)\s*=\s*$", line_prefix)
if var_match:
if last_nl_idx >= 0:
start_pos = last_nl_idx
leading_ws = "\n" + var_match.group(1)
else:
start_pos = 0
leading_ws = var_match.group(1)
assigned_var_type = var_match.group(2)
assigned_var_name = var_match.group(3)
original_text = source[start_pos:end_pos]
# Determine statement type based on detected framework
detected = self._detected_framework or "junit5"
if "jupiter" in detected or detected == "junit5":
@ -351,6 +381,8 @@ class JavaAssertTransformer:
original_text=original_text,
is_exception_assertion=is_exception,
lambda_body=lambda_body,
assigned_var_type=assigned_var_type,
assigned_var_name=assigned_var_name,
)
)
@ -603,9 +635,9 @@ class JavaAssertTransformer:
return brace_content.strip()
else:
# Expression lambda: () -> expr
# Find the end (before the closing paren of assertThrows)
# Find the end (before the closing paren of assertThrows, or comma at depth 0)
depth = 0
end = body_start
end = len(content)
for i, ch in enumerate(content[body_start:]):
if ch == "(":
depth += 1
@ -614,6 +646,9 @@ class JavaAssertTransformer:
end = body_start + i
break
depth -= 1
elif ch == "," and depth == 0:
end = body_start + i
break
return content[body_start:end].strip()
return None
@ -745,29 +780,52 @@ class JavaAssertTransformer:
To:
try { calculator.divide(1, 0); } catch (Exception _cf_ignored1) {}
When assigned to a variable:
IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> calc.divide(1, 0));
To:
IllegalArgumentException ex = null;
try { calc.divide(1, 0); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; }
"""
self.invocation_counter += 1
counter = self.invocation_counter
ws = assertion.leading_whitespace
base_indent = ws.lstrip("\n\r")
if assertion.lambda_body:
# Extract the actual code from the lambda
code_to_run = assertion.lambda_body
if not code_to_run.endswith(";"):
code_to_run += ";"
# Handle variable assignment: Type var = assertThrows(...)
if assertion.assigned_var_name and assertion.assigned_var_type:
var_type = assertion.assigned_var_type
var_name = assertion.assigned_var_name
if assertion.assertion_method == "assertDoesNotThrow":
if ";" not in assertion.lambda_body.strip():
return f"{ws}{var_type} {var_name} = {assertion.lambda_body.strip()};"
return f"{ws}{code_to_run}"
return (
f"{ws}{var_type} {var_name} = null;\n"
f"{base_indent}try {{ {code_to_run} }} "
f"catch ({var_type} _cf_caught{counter}) {{ {var_name} = _cf_caught{counter}; }}"
)
return (
f"{assertion.leading_whitespace}try {{ {code_to_run} }} "
f"catch (Exception _cf_ignored{self.invocation_counter}) {{}}"
f"{ws}try {{ {code_to_run} }} "
f"catch (Exception _cf_ignored{counter}) {{}}"
)
# If no lambda body found, try to extract from target calls
if assertion.target_calls:
call = assertion.target_calls[0]
return (
f"{assertion.leading_whitespace}try {{ {call.full_call}; }} "
f"catch (Exception _cf_ignored{self.invocation_counter}) {{}}"
f"{ws}try {{ {call.full_call}; }} "
f"catch (Exception _cf_ignored{counter}) {{}}"
)
# Fallback: comment out the assertion
return f"{assertion.leading_whitespace}// Removed assertThrows: could not extract callable"
return f"{ws}// Removed assertThrows: could not extract callable"
def transform_java_assertions(source: str, function_name: str, qualified_name: str | None = None) -> str:

View file

@ -1255,3 +1255,122 @@ void testSynchronizedMethodWithAssertJ() {
}"""
result = transform_java_assertions(source, "incrementAndGet")
assert result == expected
class TestFullyQualifiedAssertions:
"""Tests for fully qualified assertion calls like org.junit.jupiter.api.Assertions.assertXxx."""
def test_assert_timeout_fully_qualified_with_variable_assignment(self):
source = """\
@Test
void testLargeInput() {
Long result = org.junit.jupiter.api.Assertions.assertTimeout(
Duration.ofSeconds(1),
() -> Fibonacci.fibonacci(100_000)
);
}"""
expected = """\
@Test
void testLargeInput() {
Object _cf_result1 = Fibonacci.fibonacci(100_000);
}"""
result = transform_java_assertions(source, "fibonacci")
assert result == expected
def test_assert_equals_fully_qualified(self):
source = """\
@Test
void testAdd() {
org.junit.jupiter.api.Assertions.assertEquals(5, calc.add(2, 3));
}"""
expected = """\
@Test
void testAdd() {
Object _cf_result1 = calc.add(2, 3);
}"""
result = transform_java_assertions(source, "add")
assert result == expected
class TestAssertThrowsVariableAssignment:
"""Tests for assertThrows assigned to a variable: Type var = assertThrows(...)."""
def test_assert_throws_assigned_to_variable(self):
source = """\
@Test
void testDivideByZero() {
Calculator calc = new Calculator();
IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> calc.divide(1, 0));
assertEquals("Cannot divide by zero", ex.getMessage());
}"""
expected = """\
@Test
void testDivideByZero() {
Calculator calc = new Calculator();
IllegalArgumentException ex = null;
try { calc.divide(1, 0); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; }
assertEquals("Cannot divide by zero", ex.getMessage());
}"""
result = transform_java_assertions(source, "divide")
assert result == expected
def test_assert_throws_assigned_to_variable_block_lambda(self):
source = """\
@Test
void testDivideByZero() {
ArithmeticException ex = assertThrows(ArithmeticException.class, () -> {
calculator.divide(1, 0);
});
}"""
expected = """\
@Test
void testDivideByZero() {
ArithmeticException ex = null;
try { calculator.divide(1, 0); } catch (ArithmeticException _cf_caught1) { ex = _cf_caught1; }
}"""
result = transform_java_assertions(source, "divide")
assert result == expected
def test_assert_throws_assigned_with_final_modifier(self):
source = """\
@Test
void testDivideByZero() {
final IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> calc.divide(1, 0));
}"""
expected = """\
@Test
void testDivideByZero() {
IllegalArgumentException ex = null;
try { calc.divide(1, 0); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; }
}"""
result = transform_java_assertions(source, "divide")
assert result == expected
def test_assert_throws_not_assigned_unchanged(self):
source = """\
@Test
void testDivideByZero() {
assertThrows(IllegalArgumentException.class, () -> calculator.divide(1, 0));
}"""
expected = """\
@Test
void testDivideByZero() {
try { calculator.divide(1, 0); } catch (Exception _cf_ignored1) {}
}"""
result = transform_java_assertions(source, "divide")
assert result == expected
def test_assert_throws_assigned_with_qualified_assertions(self):
source = """\
@Test
void testDivideByZero() {
IllegalArgumentException ex = Assertions.assertThrows(IllegalArgumentException.class, () -> calc.divide(1, 0));
}"""
expected = """\
@Test
void testDivideByZero() {
IllegalArgumentException ex = null;
try { calc.divide(1, 0); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; }
}"""
result = transform_java_assertions(source, "divide")
assert result == expected