mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
fix asserts
This commit is contained in:
parent
ca4f01f7c5
commit
83f335ed04
5 changed files with 1457 additions and 62 deletions
|
|
@ -811,12 +811,12 @@ def instrument_generated_java_test(
|
|||
Instrumented test source code.
|
||||
|
||||
"""
|
||||
if not test_code or not test_code.strip():
|
||||
return test_code
|
||||
|
||||
from codeflash.languages.java.remove_asserts import transform_java_assertions
|
||||
|
||||
# For behavior mode, remove assertions and capture function return values
|
||||
# This converts the generated test into a regression test that captures outputs
|
||||
if mode == "behavior":
|
||||
test_code = transform_java_assertions(test_code, function_name, qualified_name)
|
||||
test_code = transform_java_assertions(test_code, function_name, qualified_name)
|
||||
|
||||
# Extract class name from the test code
|
||||
# Use pattern that starts at beginning of line to avoid matching words in comments
|
||||
|
|
@ -827,14 +827,8 @@ def instrument_generated_java_test(
|
|||
|
||||
original_class_name = class_match.group(1)
|
||||
|
||||
# For performance mode, add timing instrumentation
|
||||
# Use original class name (without suffix) in timing markers for consistency with Python
|
||||
if mode == "performance":
|
||||
# Rename class based on mode
|
||||
if mode == "behavior":
|
||||
new_class_name = f"{original_class_name}__perfinstrumented"
|
||||
else:
|
||||
new_class_name = f"{original_class_name}__perfonlyinstrumented"
|
||||
new_class_name = f"{original_class_name}__perfonlyinstrumented"
|
||||
|
||||
# Rename all references to the original class name in the source.
|
||||
# This includes the class declaration, return types, constructor calls, etc.
|
||||
|
|
@ -852,6 +846,8 @@ def instrument_generated_java_test(
|
|||
function_to_optimize=function_to_optimize,
|
||||
test_class_name=original_class_name,
|
||||
)
|
||||
else:
|
||||
modified_code = test_code
|
||||
|
||||
logger.debug("Instrumented generated Java test for %s (mode=%s)", function_name, mode)
|
||||
return modified_code
|
||||
|
|
|
|||
|
|
@ -213,34 +213,24 @@ class JavaAssertTransformer:
|
|||
if not assertions:
|
||||
return source
|
||||
|
||||
# Filter to only assertions that contain target calls
|
||||
assertions_with_targets = [a for a in assertions if a.target_calls or a.is_exception_assertion]
|
||||
|
||||
if not assertions_with_targets:
|
||||
return source
|
||||
|
||||
# Sort by position (forward order) to assign counter numbers in source order
|
||||
assertions_with_targets.sort(key=lambda a: a.start_pos)
|
||||
assertions.sort(key=lambda a: a.start_pos)
|
||||
|
||||
# Filter out nested assertions (e.g., assertEquals inside assertAll)
|
||||
# An assertion is nested if it's completely contained within another assertion
|
||||
non_nested: list[AssertionMatch] = []
|
||||
for i, assertion in enumerate(assertions_with_targets):
|
||||
for i, assertion in enumerate(assertions):
|
||||
is_nested = False
|
||||
for j, other in enumerate(assertions_with_targets):
|
||||
for j, other in enumerate(assertions):
|
||||
if i != j:
|
||||
# Check if 'assertion' is nested inside 'other'
|
||||
if other.start_pos <= assertion.start_pos and assertion.end_pos <= other.end_pos:
|
||||
is_nested = True
|
||||
break
|
||||
if not is_nested:
|
||||
non_nested.append(assertion)
|
||||
|
||||
assertions_with_targets = non_nested
|
||||
|
||||
# Pre-compute all replacements with correct counter values
|
||||
replacements: list[tuple[int, int, str]] = []
|
||||
for assertion in assertions_with_targets:
|
||||
for assertion in non_nested:
|
||||
replacement = self._generate_replacement(assertion)
|
||||
replacements.append((assertion.start_pos, assertion.end_pos, replacement))
|
||||
|
||||
|
|
@ -822,8 +812,7 @@ class JavaAssertTransformer:
|
|||
return self._generate_exception_replacement(assertion)
|
||||
|
||||
if not assertion.target_calls:
|
||||
# No target calls found, just comment out the assertion
|
||||
return f"{assertion.leading_whitespace}// Removed assertion: no target calls found"
|
||||
return ""
|
||||
|
||||
# Generate capture statements for each target call
|
||||
replacements = []
|
||||
|
|
|
|||
|
|
@ -286,7 +286,6 @@ void testCalculator() {
|
|||
@Test
|
||||
void testCalculator() {
|
||||
Object _cf_result1 = calculator.add(2, 3);
|
||||
assertEquals(6, calculator.multiply(2, 3));
|
||||
}"""
|
||||
result = transform_java_assertions(source, "add")
|
||||
assert result == expected
|
||||
|
|
@ -550,8 +549,13 @@ void testWithVariables() {
|
|||
int actual = calculator.fibonacci(10);
|
||||
assertEquals(expected, actual);
|
||||
}"""
|
||||
# fibonacci is assigned to 'actual', not in the assertion - no transformation
|
||||
expected = source
|
||||
# Variable declarations are preserved, but assertEquals is removed (all assertions removed)
|
||||
expected = """\
|
||||
@Test
|
||||
void testWithVariables() {
|
||||
int expected = 55;
|
||||
int actual = calculator.fibonacci(10);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == expected
|
||||
|
||||
|
|
@ -670,8 +674,11 @@ void testNoAssertions() {
|
|||
void testOther() {
|
||||
assertEquals(5, helper.compute(3));
|
||||
}"""
|
||||
# No transformation since target function is not in the assertion
|
||||
expected = source
|
||||
# All assertions are removed regardless of target function
|
||||
expected = """\
|
||||
@Test
|
||||
void testOther() {
|
||||
}"""
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == expected
|
||||
|
||||
|
|
@ -912,9 +919,13 @@ void testBasicCompoundInterest() {
|
|||
assertNotNull(result);
|
||||
assertTrue(result.contains("."));
|
||||
}"""
|
||||
# assertNotNull(result) and assertTrue(result.contains(".")) don't contain the target function
|
||||
# so they remain unchanged, and the variable assignment is also preserved
|
||||
expected = source
|
||||
# All assertions are removed; variable assignment is preserved
|
||||
expected = """\
|
||||
@Test
|
||||
@DisplayName("should calculate compound interest for basic case")
|
||||
void testBasicCompoundInterest() {
|
||||
String result = calculator.calculateCompoundInterest(1000.0, 0.05, 1, 12);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "calculateCompoundInterest")
|
||||
assert result == expected
|
||||
|
||||
|
|
@ -1018,13 +1029,12 @@ void testSynchronizedBlock() {
|
|||
assertTrue(cache.containsKey("key"));
|
||||
}
|
||||
}"""
|
||||
# All assertions are removed; target-containing ones get Object capture
|
||||
expected = """\
|
||||
@Test
|
||||
void testSynchronizedBlock() {
|
||||
synchronized (cache) {
|
||||
Object _cf_result1 = cache.size();
|
||||
assertNotNull(cache.get("key"));
|
||||
assertTrue(cache.containsKey("key"));
|
||||
}
|
||||
}"""
|
||||
result = transform_java_assertions(source, "size")
|
||||
|
|
@ -1210,6 +1220,8 @@ void testCircularBufferOperations() {
|
|||
assertFalse(buffer.isEmpty());
|
||||
assertTrue(buffer.put(2));
|
||||
}"""
|
||||
# All assertions are removed; target-containing ones get Object capture,
|
||||
# non-target assertions (assertTrue(buffer.put(2))) are deleted entirely
|
||||
expected = """\
|
||||
@Test
|
||||
void testCircularBufferOperations() {
|
||||
|
|
@ -1217,25 +1229,9 @@ void testCircularBufferOperations() {
|
|||
Object _cf_result1 = buffer.isEmpty();
|
||||
buffer.put(1);
|
||||
Object _cf_result2 = buffer.isEmpty();
|
||||
Object _cf_result3 = buffer.put(2);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "isEmpty")
|
||||
# isEmpty is target for assertTrue/assertFalse; but put is NOT the target
|
||||
# so only isEmpty calls inside assertions are transformed
|
||||
# Actually: assertTrue(buffer.put(2)) also contains a non-target call
|
||||
# Let's verify what actually happens
|
||||
# put is not "isEmpty", so assertTrue(buffer.put(2)) has no target call -> untouched
|
||||
expected_corrected = """\
|
||||
@Test
|
||||
void testCircularBufferOperations() {
|
||||
CircularBuffer<Integer> buffer = new CircularBuffer<>(3);
|
||||
Object _cf_result1 = buffer.isEmpty();
|
||||
buffer.put(1);
|
||||
Object _cf_result2 = buffer.isEmpty();
|
||||
assertTrue(buffer.put(2));
|
||||
}"""
|
||||
result = transform_java_assertions(source, "isEmpty")
|
||||
assert result == expected_corrected
|
||||
assert result == expected
|
||||
|
||||
def test_concurrent_assertion_with_assertj(self):
|
||||
"""AssertJ assertion on a synchronized method call is correctly transformed."""
|
||||
|
|
@ -1310,12 +1306,12 @@ void testNegativeInput() {
|
|||
);
|
||||
assertEquals("Negative input not allowed", exception.getMessage());
|
||||
}"""
|
||||
# assertThrows becomes try/catch, and assertEquals after it is also removed
|
||||
expected = """\
|
||||
@Test
|
||||
void testNegativeInput() {
|
||||
IllegalArgumentException exception = null;
|
||||
try { calculator.fibonacci(-1); } catch (IllegalArgumentException _cf_caught1) { exception = _cf_caught1; } catch (Exception _cf_ignored1) {}
|
||||
assertEquals("Negative input not allowed", exception.getMessage());
|
||||
}"""
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == expected
|
||||
|
|
@ -1330,12 +1326,12 @@ void testInvalidOperation() {
|
|||
});
|
||||
assertEquals("Division by zero", ex.getMessage());
|
||||
}"""
|
||||
# assertThrows becomes try/catch, and assertEquals after it is also removed
|
||||
expected = """\
|
||||
@Test
|
||||
void testInvalidOperation() {
|
||||
ArithmeticException ex = null;
|
||||
try { calculator.divide(10, 0); } catch (ArithmeticException _cf_caught1) { ex = _cf_caught1; } catch (Exception _cf_ignored1) {}
|
||||
assertEquals("Division by zero", ex.getMessage());
|
||||
}"""
|
||||
result = transform_java_assertions(source, "divide")
|
||||
assert result == expected
|
||||
|
|
@ -1348,12 +1344,12 @@ void testGenericException() {
|
|||
Exception e = assertThrows(Exception.class, () -> processor.process(null));
|
||||
assertNotNull(e.getMessage());
|
||||
}"""
|
||||
# assertThrows becomes try/catch, and assertNotNull after it is also removed
|
||||
expected = """\
|
||||
@Test
|
||||
void testGenericException() {
|
||||
Exception e = null;
|
||||
try { processor.process(null); } catch (Exception _cf_caught1) { e = _cf_caught1; } catch (Exception _cf_ignored1) {}
|
||||
assertNotNull(e.getMessage());
|
||||
}"""
|
||||
result = transform_java_assertions(source, "process")
|
||||
assert result == expected
|
||||
|
|
@ -1387,13 +1383,13 @@ void testComplexException() {
|
|||
);
|
||||
assertTrue(exception.getMessage().contains("not initialized"));
|
||||
}"""
|
||||
# assertThrows becomes try/catch, and assertTrue after it is also removed
|
||||
expected = """\
|
||||
@Test
|
||||
void testComplexException() {
|
||||
IllegalStateException exception = null;
|
||||
try { processor.initialize();
|
||||
processor.execute(); } catch (IllegalStateException _cf_caught1) { exception = _cf_caught1; } catch (Exception _cf_ignored1) {}
|
||||
assertTrue(exception.getMessage().contains("not initialized"));
|
||||
}"""
|
||||
result = transform_java_assertions(source, "execute")
|
||||
assert result == expected
|
||||
|
|
|
|||
|
|
@ -295,7 +295,7 @@ public class CalculatorTest__perfonlyinstrumented {
|
|||
long _cf_start1 = System.nanoTime();
|
||||
try {
|
||||
Calculator calc = new Calculator();
|
||||
assertEquals(4, calc.add(2, 2));
|
||||
Object _cf_result1 = calc.add(2, 2);
|
||||
} finally {
|
||||
long _cf_end1 = System.nanoTime();
|
||||
long _cf_dur1 = _cf_end1 - _cf_start1;
|
||||
|
|
@ -360,7 +360,6 @@ public class MathTest__perfonlyinstrumented {
|
|||
System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + "######$!");
|
||||
long _cf_start1 = System.nanoTime();
|
||||
try {
|
||||
assertEquals(4, add(2, 2));
|
||||
} finally {
|
||||
long _cf_end1 = System.nanoTime();
|
||||
long _cf_dur1 = _cf_end1 - _cf_start1;
|
||||
|
|
@ -382,7 +381,6 @@ public class MathTest__perfonlyinstrumented {
|
|||
System.out.println("!$######" + _cf_mod2 + ":" + _cf_cls2 + ":" + _cf_fn2 + ":" + _cf_loop2 + ":" + _cf_i2 + "######$!");
|
||||
long _cf_start2 = System.nanoTime();
|
||||
try {
|
||||
assertEquals(0, subtract(2, 2));
|
||||
} finally {
|
||||
long _cf_end2 = System.nanoTime();
|
||||
long _cf_dur2 = _cf_end2 - _cf_start2;
|
||||
|
|
@ -1256,7 +1254,7 @@ public class ImportTest__perfonlyinstrumented {
|
|||
long _cf_start1 = System.nanoTime();
|
||||
try {
|
||||
List<String> list = new ArrayList<>();
|
||||
assertEquals(0, list.size());
|
||||
Object _cf_result1 = list.size();
|
||||
} finally {
|
||||
long _cf_end1 = System.nanoTime();
|
||||
long _cf_dur1 = _cf_end1 - _cf_start1;
|
||||
|
|
|
|||
1416
tests/test_languages/test_java/test_remove_asserts.py
Normal file
1416
tests/test_languages/test_java/test_remove_asserts.py
Normal file
File diff suppressed because it is too large
Load diff
Loading…
Reference in a new issue