fix asserts

This commit is contained in:
HeshamHM28 2026-02-17 03:29:58 +02:00
parent ca4f01f7c5
commit 83f335ed04
5 changed files with 1457 additions and 62 deletions

View file

@ -811,12 +811,12 @@ def instrument_generated_java_test(
Instrumented test source code.
"""
if not test_code or not test_code.strip():
return test_code
from codeflash.languages.java.remove_asserts import transform_java_assertions
# For behavior mode, remove assertions and capture function return values
# This converts the generated test into a regression test that captures outputs
if mode == "behavior":
test_code = transform_java_assertions(test_code, function_name, qualified_name)
test_code = transform_java_assertions(test_code, function_name, qualified_name)
# Extract class name from the test code
# Use pattern that starts at beginning of line to avoid matching words in comments
@ -827,14 +827,8 @@ def instrument_generated_java_test(
original_class_name = class_match.group(1)
# For performance mode, add timing instrumentation
# Use original class name (without suffix) in timing markers for consistency with Python
if mode == "performance":
# Rename class based on mode
if mode == "behavior":
new_class_name = f"{original_class_name}__perfinstrumented"
else:
new_class_name = f"{original_class_name}__perfonlyinstrumented"
new_class_name = f"{original_class_name}__perfonlyinstrumented"
# Rename all references to the original class name in the source.
# This includes the class declaration, return types, constructor calls, etc.
@ -852,6 +846,8 @@ def instrument_generated_java_test(
function_to_optimize=function_to_optimize,
test_class_name=original_class_name,
)
else:
modified_code = test_code
logger.debug("Instrumented generated Java test for %s (mode=%s)", function_name, mode)
return modified_code

View file

@ -213,34 +213,24 @@ class JavaAssertTransformer:
if not assertions:
return source
# Filter to only assertions that contain target calls
assertions_with_targets = [a for a in assertions if a.target_calls or a.is_exception_assertion]
if not assertions_with_targets:
return source
# Sort by position (forward order) to assign counter numbers in source order
assertions_with_targets.sort(key=lambda a: a.start_pos)
assertions.sort(key=lambda a: a.start_pos)
# Filter out nested assertions (e.g., assertEquals inside assertAll)
# An assertion is nested if it's completely contained within another assertion
non_nested: list[AssertionMatch] = []
for i, assertion in enumerate(assertions_with_targets):
for i, assertion in enumerate(assertions):
is_nested = False
for j, other in enumerate(assertions_with_targets):
for j, other in enumerate(assertions):
if i != j:
# Check if 'assertion' is nested inside 'other'
if other.start_pos <= assertion.start_pos and assertion.end_pos <= other.end_pos:
is_nested = True
break
if not is_nested:
non_nested.append(assertion)
assertions_with_targets = non_nested
# Pre-compute all replacements with correct counter values
replacements: list[tuple[int, int, str]] = []
for assertion in assertions_with_targets:
for assertion in non_nested:
replacement = self._generate_replacement(assertion)
replacements.append((assertion.start_pos, assertion.end_pos, replacement))
@ -822,8 +812,7 @@ class JavaAssertTransformer:
return self._generate_exception_replacement(assertion)
if not assertion.target_calls:
# No target calls found, just comment out the assertion
return f"{assertion.leading_whitespace}// Removed assertion: no target calls found"
return ""
# Generate capture statements for each target call
replacements = []

View file

@ -286,7 +286,6 @@ void testCalculator() {
@Test
void testCalculator() {
Object _cf_result1 = calculator.add(2, 3);
assertEquals(6, calculator.multiply(2, 3));
}"""
result = transform_java_assertions(source, "add")
assert result == expected
@ -550,8 +549,13 @@ void testWithVariables() {
int actual = calculator.fibonacci(10);
assertEquals(expected, actual);
}"""
# fibonacci is assigned to 'actual', not in the assertion - no transformation
expected = source
# Variable declarations are preserved, but assertEquals is removed (all assertions removed)
expected = """\
@Test
void testWithVariables() {
int expected = 55;
int actual = calculator.fibonacci(10);
}"""
result = transform_java_assertions(source, "fibonacci")
assert result == expected
@ -670,8 +674,11 @@ void testNoAssertions() {
void testOther() {
assertEquals(5, helper.compute(3));
}"""
# No transformation since target function is not in the assertion
expected = source
# All assertions are removed regardless of target function
expected = """\
@Test
void testOther() {
}"""
result = transform_java_assertions(source, "fibonacci")
assert result == expected
@ -912,9 +919,13 @@ void testBasicCompoundInterest() {
assertNotNull(result);
assertTrue(result.contains("."));
}"""
# assertNotNull(result) and assertTrue(result.contains(".")) don't contain the target function
# so they remain unchanged, and the variable assignment is also preserved
expected = source
# All assertions are removed; variable assignment is preserved
expected = """\
@Test
@DisplayName("should calculate compound interest for basic case")
void testBasicCompoundInterest() {
String result = calculator.calculateCompoundInterest(1000.0, 0.05, 1, 12);
}"""
result = transform_java_assertions(source, "calculateCompoundInterest")
assert result == expected
@ -1018,13 +1029,12 @@ void testSynchronizedBlock() {
assertTrue(cache.containsKey("key"));
}
}"""
# All assertions are removed; target-containing ones get Object capture
expected = """\
@Test
void testSynchronizedBlock() {
synchronized (cache) {
Object _cf_result1 = cache.size();
assertNotNull(cache.get("key"));
assertTrue(cache.containsKey("key"));
}
}"""
result = transform_java_assertions(source, "size")
@ -1210,6 +1220,8 @@ void testCircularBufferOperations() {
assertFalse(buffer.isEmpty());
assertTrue(buffer.put(2));
}"""
# All assertions are removed; target-containing ones get Object capture,
# non-target assertions (assertTrue(buffer.put(2))) are deleted entirely
expected = """\
@Test
void testCircularBufferOperations() {
@ -1217,25 +1229,9 @@ void testCircularBufferOperations() {
Object _cf_result1 = buffer.isEmpty();
buffer.put(1);
Object _cf_result2 = buffer.isEmpty();
Object _cf_result3 = buffer.put(2);
}"""
result = transform_java_assertions(source, "isEmpty")
# isEmpty is target for assertTrue/assertFalse; but put is NOT the target
# so only isEmpty calls inside assertions are transformed
# Actually: assertTrue(buffer.put(2)) also contains a non-target call
# Let's verify what actually happens
# put is not "isEmpty", so assertTrue(buffer.put(2)) has no target call -> untouched
expected_corrected = """\
@Test
void testCircularBufferOperations() {
CircularBuffer<Integer> buffer = new CircularBuffer<>(3);
Object _cf_result1 = buffer.isEmpty();
buffer.put(1);
Object _cf_result2 = buffer.isEmpty();
assertTrue(buffer.put(2));
}"""
result = transform_java_assertions(source, "isEmpty")
assert result == expected_corrected
assert result == expected
def test_concurrent_assertion_with_assertj(self):
"""AssertJ assertion on a synchronized method call is correctly transformed."""
@ -1310,12 +1306,12 @@ void testNegativeInput() {
);
assertEquals("Negative input not allowed", exception.getMessage());
}"""
# assertThrows becomes try/catch, and assertEquals after it is also removed
expected = """\
@Test
void testNegativeInput() {
IllegalArgumentException exception = null;
try { calculator.fibonacci(-1); } catch (IllegalArgumentException _cf_caught1) { exception = _cf_caught1; } catch (Exception _cf_ignored1) {}
assertEquals("Negative input not allowed", exception.getMessage());
}"""
result = transform_java_assertions(source, "fibonacci")
assert result == expected
@ -1330,12 +1326,12 @@ void testInvalidOperation() {
});
assertEquals("Division by zero", ex.getMessage());
}"""
# assertThrows becomes try/catch, and assertEquals after it is also removed
expected = """\
@Test
void testInvalidOperation() {
ArithmeticException ex = null;
try { calculator.divide(10, 0); } catch (ArithmeticException _cf_caught1) { ex = _cf_caught1; } catch (Exception _cf_ignored1) {}
assertEquals("Division by zero", ex.getMessage());
}"""
result = transform_java_assertions(source, "divide")
assert result == expected
@ -1348,12 +1344,12 @@ void testGenericException() {
Exception e = assertThrows(Exception.class, () -> processor.process(null));
assertNotNull(e.getMessage());
}"""
# assertThrows becomes try/catch, and assertNotNull after it is also removed
expected = """\
@Test
void testGenericException() {
Exception e = null;
try { processor.process(null); } catch (Exception _cf_caught1) { e = _cf_caught1; } catch (Exception _cf_ignored1) {}
assertNotNull(e.getMessage());
}"""
result = transform_java_assertions(source, "process")
assert result == expected
@ -1387,13 +1383,13 @@ void testComplexException() {
);
assertTrue(exception.getMessage().contains("not initialized"));
}"""
# assertThrows becomes try/catch, and assertTrue after it is also removed
expected = """\
@Test
void testComplexException() {
IllegalStateException exception = null;
try { processor.initialize();
processor.execute(); } catch (IllegalStateException _cf_caught1) { exception = _cf_caught1; } catch (Exception _cf_ignored1) {}
assertTrue(exception.getMessage().contains("not initialized"));
}"""
result = transform_java_assertions(source, "execute")
assert result == expected

View file

@ -295,7 +295,7 @@ public class CalculatorTest__perfonlyinstrumented {
long _cf_start1 = System.nanoTime();
try {
Calculator calc = new Calculator();
assertEquals(4, calc.add(2, 2));
Object _cf_result1 = calc.add(2, 2);
} finally {
long _cf_end1 = System.nanoTime();
long _cf_dur1 = _cf_end1 - _cf_start1;
@ -360,7 +360,6 @@ public class MathTest__perfonlyinstrumented {
System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + "######$!");
long _cf_start1 = System.nanoTime();
try {
assertEquals(4, add(2, 2));
} finally {
long _cf_end1 = System.nanoTime();
long _cf_dur1 = _cf_end1 - _cf_start1;
@ -382,7 +381,6 @@ public class MathTest__perfonlyinstrumented {
System.out.println("!$######" + _cf_mod2 + ":" + _cf_cls2 + ":" + _cf_fn2 + ":" + _cf_loop2 + ":" + _cf_i2 + "######$!");
long _cf_start2 = System.nanoTime();
try {
assertEquals(0, subtract(2, 2));
} finally {
long _cf_end2 = System.nanoTime();
long _cf_dur2 = _cf_end2 - _cf_start2;
@ -1256,7 +1254,7 @@ public class ImportTest__perfonlyinstrumented {
long _cf_start1 = System.nanoTime();
try {
List<String> list = new ArrayList<>();
assertEquals(0, list.size());
Object _cf_result1 = list.size();
} finally {
long _cf_end1 = System.nanoTime();
long _cf_dur1 = _cf_end1 - _cf_start1;

File diff suppressed because it is too large Load diff