mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
fix asserts
This commit is contained in:
parent
df5b6a28f7
commit
4740725af7
2 changed files with 188 additions and 11 deletions
|
|
@ -166,6 +166,8 @@ class AssertionMatch:
|
|||
original_text: str = ""
|
||||
is_exception_assertion: bool = False
|
||||
lambda_body: str | None = None # For assertThrows lambda content
|
||||
assigned_var_type: str | None = None # For Type var = assertThrows(...)
|
||||
assigned_var_name: str | None = None
|
||||
|
||||
|
||||
class JavaAssertTransformer:
|
||||
|
|
@ -300,8 +302,11 @@ class JavaAssertTransformer:
|
|||
# - assertEquals (static import)
|
||||
# - Assert.assertEquals (JUnit 4)
|
||||
# - Assertions.assertEquals (JUnit 5)
|
||||
# - org.junit.jupiter.api.Assertions.assertEquals (fully qualified)
|
||||
all_assertions = "|".join(JUNIT5_ALL_ASSERTIONS)
|
||||
pattern = re.compile(rf"(\s*)((?:Assert(?:ions)?\.)?({all_assertions}))\s*\(", re.MULTILINE)
|
||||
pattern = re.compile(
|
||||
rf"(\s*)((?:(?:\w+\.)*Assert(?:ions)?\.)?({all_assertions}))\s*\(", re.MULTILINE
|
||||
)
|
||||
|
||||
for match in pattern.finditer(source):
|
||||
leading_ws = match.group(1)
|
||||
|
|
@ -326,13 +331,38 @@ class JavaAssertTransformer:
|
|||
target_calls = self._extract_target_calls(args_content, match.end())
|
||||
is_exception = assertion_method in JUNIT5_EXCEPTION_ASSERTIONS
|
||||
|
||||
# For assertThrows, extract the lambda body
|
||||
# For exception assertions, extract the lambda body
|
||||
lambda_body = None
|
||||
if is_exception and assertion_method == "assertThrows":
|
||||
if is_exception:
|
||||
lambda_body = self._extract_lambda_body(args_content)
|
||||
|
||||
original_text = source[start_pos:end_pos]
|
||||
|
||||
# Detect variable assignment: Type var = assertXxx(...)
|
||||
# This applies to all assertions (assertThrows, assertTimeout, etc.)
|
||||
assigned_var_type = None
|
||||
assigned_var_name = None
|
||||
|
||||
before = source[:start_pos]
|
||||
last_nl_idx = before.rfind("\n")
|
||||
if last_nl_idx >= 0:
|
||||
line_prefix = source[last_nl_idx + 1 : start_pos]
|
||||
else:
|
||||
line_prefix = source[:start_pos]
|
||||
|
||||
var_match = re.match(r"([ \t]*)(?:final\s+)?([\w.<>\[\]]+)\s+(\w+)\s*=\s*$", line_prefix)
|
||||
if var_match:
|
||||
if last_nl_idx >= 0:
|
||||
start_pos = last_nl_idx
|
||||
leading_ws = "\n" + var_match.group(1)
|
||||
else:
|
||||
start_pos = 0
|
||||
leading_ws = var_match.group(1)
|
||||
|
||||
assigned_var_type = var_match.group(2)
|
||||
assigned_var_name = var_match.group(3)
|
||||
original_text = source[start_pos:end_pos]
|
||||
|
||||
# Determine statement type based on detected framework
|
||||
detected = self._detected_framework or "junit5"
|
||||
if "jupiter" in detected or detected == "junit5":
|
||||
|
|
@ -351,6 +381,8 @@ class JavaAssertTransformer:
|
|||
original_text=original_text,
|
||||
is_exception_assertion=is_exception,
|
||||
lambda_body=lambda_body,
|
||||
assigned_var_type=assigned_var_type,
|
||||
assigned_var_name=assigned_var_name,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -603,9 +635,9 @@ class JavaAssertTransformer:
|
|||
return brace_content.strip()
|
||||
else:
|
||||
# Expression lambda: () -> expr
|
||||
# Find the end (before the closing paren of assertThrows)
|
||||
# Find the end (before the closing paren of assertThrows, or comma at depth 0)
|
||||
depth = 0
|
||||
end = body_start
|
||||
end = len(content)
|
||||
for i, ch in enumerate(content[body_start:]):
|
||||
if ch == "(":
|
||||
depth += 1
|
||||
|
|
@ -614,6 +646,9 @@ class JavaAssertTransformer:
|
|||
end = body_start + i
|
||||
break
|
||||
depth -= 1
|
||||
elif ch == "," and depth == 0:
|
||||
end = body_start + i
|
||||
break
|
||||
return content[body_start:end].strip()
|
||||
|
||||
return None
|
||||
|
|
@ -745,29 +780,52 @@ class JavaAssertTransformer:
|
|||
To:
|
||||
try { calculator.divide(1, 0); } catch (Exception _cf_ignored1) {}
|
||||
|
||||
When assigned to a variable:
|
||||
IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> calc.divide(1, 0));
|
||||
To:
|
||||
IllegalArgumentException ex = null;
|
||||
try { calc.divide(1, 0); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; }
|
||||
|
||||
"""
|
||||
self.invocation_counter += 1
|
||||
counter = self.invocation_counter
|
||||
ws = assertion.leading_whitespace
|
||||
base_indent = ws.lstrip("\n\r")
|
||||
|
||||
if assertion.lambda_body:
|
||||
# Extract the actual code from the lambda
|
||||
code_to_run = assertion.lambda_body
|
||||
if not code_to_run.endswith(";"):
|
||||
code_to_run += ";"
|
||||
|
||||
# Handle variable assignment: Type var = assertThrows(...)
|
||||
if assertion.assigned_var_name and assertion.assigned_var_type:
|
||||
var_type = assertion.assigned_var_type
|
||||
var_name = assertion.assigned_var_name
|
||||
if assertion.assertion_method == "assertDoesNotThrow":
|
||||
if ";" not in assertion.lambda_body.strip():
|
||||
return f"{ws}{var_type} {var_name} = {assertion.lambda_body.strip()};"
|
||||
return f"{ws}{code_to_run}"
|
||||
return (
|
||||
f"{ws}{var_type} {var_name} = null;\n"
|
||||
f"{base_indent}try {{ {code_to_run} }} "
|
||||
f"catch ({var_type} _cf_caught{counter}) {{ {var_name} = _cf_caught{counter}; }}"
|
||||
)
|
||||
|
||||
return (
|
||||
f"{assertion.leading_whitespace}try {{ {code_to_run} }} "
|
||||
f"catch (Exception _cf_ignored{self.invocation_counter}) {{}}"
|
||||
f"{ws}try {{ {code_to_run} }} "
|
||||
f"catch (Exception _cf_ignored{counter}) {{}}"
|
||||
)
|
||||
|
||||
# If no lambda body found, try to extract from target calls
|
||||
if assertion.target_calls:
|
||||
call = assertion.target_calls[0]
|
||||
return (
|
||||
f"{assertion.leading_whitespace}try {{ {call.full_call}; }} "
|
||||
f"catch (Exception _cf_ignored{self.invocation_counter}) {{}}"
|
||||
f"{ws}try {{ {call.full_call}; }} "
|
||||
f"catch (Exception _cf_ignored{counter}) {{}}"
|
||||
)
|
||||
|
||||
# Fallback: comment out the assertion
|
||||
return f"{assertion.leading_whitespace}// Removed assertThrows: could not extract callable"
|
||||
return f"{ws}// Removed assertThrows: could not extract callable"
|
||||
|
||||
|
||||
def transform_java_assertions(source: str, function_name: str, qualified_name: str | None = None) -> str:
|
||||
|
|
|
|||
|
|
@ -1255,3 +1255,122 @@ void testSynchronizedMethodWithAssertJ() {
|
|||
}"""
|
||||
result = transform_java_assertions(source, "incrementAndGet")
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestFullyQualifiedAssertions:
|
||||
"""Tests for fully qualified assertion calls like org.junit.jupiter.api.Assertions.assertXxx."""
|
||||
|
||||
def test_assert_timeout_fully_qualified_with_variable_assignment(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testLargeInput() {
|
||||
Long result = org.junit.jupiter.api.Assertions.assertTimeout(
|
||||
Duration.ofSeconds(1),
|
||||
() -> Fibonacci.fibonacci(100_000)
|
||||
);
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testLargeInput() {
|
||||
Object _cf_result1 = Fibonacci.fibonacci(100_000);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == expected
|
||||
|
||||
def test_assert_equals_fully_qualified(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testAdd() {
|
||||
org.junit.jupiter.api.Assertions.assertEquals(5, calc.add(2, 3));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testAdd() {
|
||||
Object _cf_result1 = calc.add(2, 3);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "add")
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestAssertThrowsVariableAssignment:
|
||||
"""Tests for assertThrows assigned to a variable: Type var = assertThrows(...)."""
|
||||
|
||||
def test_assert_throws_assigned_to_variable(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testDivideByZero() {
|
||||
Calculator calc = new Calculator();
|
||||
IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> calc.divide(1, 0));
|
||||
assertEquals("Cannot divide by zero", ex.getMessage());
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testDivideByZero() {
|
||||
Calculator calc = new Calculator();
|
||||
IllegalArgumentException ex = null;
|
||||
try { calc.divide(1, 0); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; }
|
||||
assertEquals("Cannot divide by zero", ex.getMessage());
|
||||
}"""
|
||||
result = transform_java_assertions(source, "divide")
|
||||
assert result == expected
|
||||
|
||||
def test_assert_throws_assigned_to_variable_block_lambda(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testDivideByZero() {
|
||||
ArithmeticException ex = assertThrows(ArithmeticException.class, () -> {
|
||||
calculator.divide(1, 0);
|
||||
});
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testDivideByZero() {
|
||||
ArithmeticException ex = null;
|
||||
try { calculator.divide(1, 0); } catch (ArithmeticException _cf_caught1) { ex = _cf_caught1; }
|
||||
}"""
|
||||
result = transform_java_assertions(source, "divide")
|
||||
assert result == expected
|
||||
|
||||
def test_assert_throws_assigned_with_final_modifier(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testDivideByZero() {
|
||||
final IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> calc.divide(1, 0));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testDivideByZero() {
|
||||
IllegalArgumentException ex = null;
|
||||
try { calc.divide(1, 0); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; }
|
||||
}"""
|
||||
result = transform_java_assertions(source, "divide")
|
||||
assert result == expected
|
||||
|
||||
def test_assert_throws_not_assigned_unchanged(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testDivideByZero() {
|
||||
assertThrows(IllegalArgumentException.class, () -> calculator.divide(1, 0));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testDivideByZero() {
|
||||
try { calculator.divide(1, 0); } catch (Exception _cf_ignored1) {}
|
||||
}"""
|
||||
result = transform_java_assertions(source, "divide")
|
||||
assert result == expected
|
||||
|
||||
def test_assert_throws_assigned_with_qualified_assertions(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testDivideByZero() {
|
||||
IllegalArgumentException ex = Assertions.assertThrows(IllegalArgumentException.class, () -> calc.divide(1, 0));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testDivideByZero() {
|
||||
IllegalArgumentException ex = null;
|
||||
try { calc.divide(1, 0); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; }
|
||||
}"""
|
||||
result = transform_java_assertions(source, "divide")
|
||||
assert result == expected
|
||||
|
|
|
|||
Loading…
Reference in a new issue