fix: infer Java return types in assertion transformer instead of using Object

The assertion transformer always declared `Object _cf_resultN = call()` when
replacing assertions, losing the actual return type. This caused compilation
failures when the result was used in a context expecting a primitive type
(e.g., int, boolean).

Now infers the return type from assertion context:
- assertEquals(int_literal, call()) -> int
- assertTrue/assertFalse(call()) -> boolean
- assertEquals("string", call()) -> String
- Falls back to Object when type can't be determined

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Mohamed Ashraf 2026-02-25 20:19:13 +00:00
parent 39bea14c99
commit 9e5880f032
2 changed files with 181 additions and 55 deletions

View file

@ -894,6 +894,129 @@ class JavaAssertTransformer:
return code[open_brace_pos + 1 : pos - 1], pos
def _infer_return_type(self, assertion: AssertionMatch) -> str:
"""Infer the Java return type from the assertion context.
For assertEquals(expected, actual) patterns, the expected literal determines the type.
For assertTrue/assertFalse, the result is boolean.
Falls back to Object when the type cannot be determined.
"""
method = assertion.assertion_method
# assertTrue/assertFalse always deal with boolean values
if method in ("assertTrue", "assertFalse"):
return "boolean"
# assertNull/assertNotNull — keep Object (reference type)
if method in ("assertNull", "assertNotNull"):
return "Object"
# For assertEquals/assertNotEquals/assertSame, try to infer from the expected literal
if method in JUNIT5_VALUE_ASSERTIONS:
return self._infer_type_from_assertion_args(assertion.original_text, method)
# For fluent assertions (assertThat), type inference is harder — keep Object
return "Object"
# Regex patterns for Java literal type inference
_LONG_LITERAL_RE = re.compile(r"^-?\d+[lL]$")
_INT_LITERAL_RE = re.compile(r"^-?\d+$")
_DOUBLE_LITERAL_RE = re.compile(r"^-?\d+\.\d*[dD]?$|^-?\d+[dD]$")
_FLOAT_LITERAL_RE = re.compile(r"^-?\d+\.?\d*[fF]$")
_CHAR_LITERAL_RE = re.compile(r"^'.'$|^'\\.'$")
def _infer_type_from_assertion_args(self, original_text: str, method: str) -> str:
"""Infer the return type from assertEquals/assertNotEquals expected value."""
# Extract the args portion from the assertion text
# Pattern: assertXxx( args... )
paren_idx = original_text.find("(")
if paren_idx < 0:
return "Object"
args_str = original_text[paren_idx + 1 :]
# Remove trailing ");", whitespace
args_str = args_str.rstrip()
if args_str.endswith(");"):
args_str = args_str[:-2]
elif args_str.endswith(")"):
args_str = args_str[:-1]
# Split top-level args (respecting parens, strings, generics)
args = self._split_top_level_args(args_str)
if not args:
return "Object"
# assertEquals has (expected, actual) or (expected, actual, message/delta)
# Some overloads have (message, expected, actual) in JUnit 4 but JUnit 5 uses (expected, actual[, message])
# Try the first argument as the expected value
expected = args[0].strip()
return self._type_from_literal(expected)
def _type_from_literal(self, value: str) -> str:
"""Determine the Java type of a literal value."""
if value in ("true", "false"):
return "boolean"
if value == "null":
return "Object"
if self._FLOAT_LITERAL_RE.match(value):
return "float"
if self._DOUBLE_LITERAL_RE.match(value):
return "double"
if self._LONG_LITERAL_RE.match(value):
return "long"
if self._INT_LITERAL_RE.match(value):
return "int"
if self._CHAR_LITERAL_RE.match(value):
return "char"
if value.startswith('"'):
return "String"
# Cast expression like (byte)0, (short)1
cast_match = re.match(r"^\((\w+)\)", value)
if cast_match:
return cast_match.group(1)
return "Object"
def _split_top_level_args(self, args_str: str) -> list[str]:
"""Split assertion arguments at top-level commas, respecting parens/strings/generics."""
args: list[str] = []
depth = 0
current: list[str] = []
i = 0
in_string = False
string_char = ""
while i < len(args_str):
ch = args_str[i]
if in_string:
current.append(ch)
if ch == "\\" and i + 1 < len(args_str):
i += 1
current.append(args_str[i])
elif ch == string_char:
in_string = False
elif ch in ('"', "'"):
in_string = True
string_char = ch
current.append(ch)
elif ch in ("(", "<", "[", "{"):
depth += 1
current.append(ch)
elif ch in (")", ">", "]", "}"):
depth -= 1
current.append(ch)
elif ch == "," and depth == 0:
args.append("".join(current))
current = []
else:
current.append(ch)
i += 1
if current:
args.append("".join(current))
return args
def _generate_replacement(self, assertion: AssertionMatch) -> str:
"""Generate replacement code for an assertion.
@ -912,6 +1035,9 @@ class JavaAssertTransformer:
if not assertion.target_calls:
return ""
# Infer the return type from assertion context to avoid Object→primitive cast errors
return_type = self._infer_return_type(assertion)
# Generate capture statements for each target call
replacements = []
# For the first replacement, use the full leading whitespace
@ -921,9 +1047,9 @@ class JavaAssertTransformer:
self.invocation_counter += 1
var_name = f"_cf_result{self.invocation_counter}"
if i == 0:
replacements.append(f"{assertion.leading_whitespace}Object {var_name} = {call.full_call};")
replacements.append(f"{assertion.leading_whitespace}{return_type} {var_name} = {call.full_call};")
else:
replacements.append(f"{base_indent}Object {var_name} = {call.full_call};")
replacements.append(f"{base_indent}{return_type} {var_name} = {call.full_call};")
return "\n".join(replacements)

View file

@ -41,7 +41,7 @@ import static org.junit.Assert.*;
public class BitSetTest {
@Test
public void testGet_IndexZero_ReturnsFalse() {
Object _cf_result1 = instance.get(0);
boolean _cf_result1 = instance.get(0);
}
}
"""
@ -67,7 +67,7 @@ import static org.junit.Assert.*;
public class BitSetTest {
@Test
public void testGet_SetBit_DetectedTrue() {
Object _cf_result1 = bs.get(67);
boolean _cf_result1 = bs.get(67);
}
}
"""
@ -93,7 +93,7 @@ import static org.junit.Assert.*;
public class FibonacciTest {
@Test
public void testFibonacci() {
Object _cf_result1 = Fibonacci.fibonacci(10);
int _cf_result1 = Fibonacci.fibonacci(10);
}
}
"""
@ -121,7 +121,7 @@ public class CalculatorTest {
@Test
public void testAdd() {
Calculator calc = new Calculator();
Object _cf_result1 = calc.add(2, 2);
int _cf_result1 = calc.add(2, 2);
}
}
"""
@ -199,7 +199,7 @@ import static org.junit.Assert.*;
public class CalculatorTest {
@Test
public void testSubtract() {
Object _cf_result1 = calc.subtract(5, 3);
int _cf_result1 = calc.subtract(5, 3);
}
}
"""
@ -251,7 +251,7 @@ import org.junit.Assert;
public class CalculatorTest {
@Test
public void testAdd() {
Object _cf_result1 = calc.add(2, 2);
int _cf_result1 = calc.add(2, 2);
}
}
"""
@ -298,9 +298,9 @@ import static org.junit.jupiter.api.Assertions.*;
public class FibonacciTest {
@Test
void testFibonacci() {
Object _cf_result1 = Fibonacci.fibonacci(0);
Object _cf_result2 = Fibonacci.fibonacci(1);
Object _cf_result3 = Fibonacci.fibonacci(10);
int _cf_result1 = Fibonacci.fibonacci(0);
int _cf_result2 = Fibonacci.fibonacci(1);
int _cf_result3 = Fibonacci.fibonacci(10);
}
}
"""
@ -326,7 +326,7 @@ import org.junit.jupiter.api.Assertions;
public class FibonacciTest {
@Test
void testFibonacci() {
Object _cf_result1 = Fibonacci.fibonacci(10);
int _cf_result1 = Fibonacci.fibonacci(10);
}
}
"""
@ -485,7 +485,7 @@ import static org.junit.jupiter.api.Assertions.*;
public class FibonacciTest {
@Test
void testIsFibonacci() {
Object _cf_result1 = Fibonacci.isFibonacci(5);
boolean _cf_result1 = Fibonacci.isFibonacci(5);
}
}
"""
@ -511,7 +511,7 @@ import static org.junit.jupiter.api.Assertions.*;
public class FibonacciTest {
@Test
void testIsNotFibonacci() {
Object _cf_result1 = Fibonacci.isFibonacci(4);
boolean _cf_result1 = Fibonacci.isFibonacci(4);
}
}
"""
@ -709,7 +709,7 @@ import static org.junit.jupiter.api.Assertions.*;
public class FibonacciTest {
@Test
void testConsecutive() {
Object _cf_result1 = Fibonacci.areConsecutiveFibonacci(Fibonacci.fibonacci(5), Fibonacci.fibonacci(6));
boolean _cf_result1 = Fibonacci.areConsecutiveFibonacci(Fibonacci.fibonacci(5), Fibonacci.fibonacci(6));
}
}
"""
@ -739,11 +739,11 @@ import static org.junit.jupiter.api.Assertions.*;
public class FibonacciTest {
@Test
void testMultiple() {
Object _cf_result1 = Fibonacci.fibonacci(0);
Object _cf_result2 = Fibonacci.fibonacci(1);
Object _cf_result3 = Fibonacci.fibonacci(2);
Object _cf_result4 = Fibonacci.fibonacci(3);
Object _cf_result5 = Fibonacci.fibonacci(5);
int _cf_result1 = Fibonacci.fibonacci(0);
int _cf_result2 = Fibonacci.fibonacci(1);
int _cf_result3 = Fibonacci.fibonacci(2);
int _cf_result4 = Fibonacci.fibonacci(3);
int _cf_result5 = Fibonacci.fibonacci(5);
}
}
"""
@ -774,7 +774,7 @@ import static org.junit.jupiter.api.Assertions.*;
public class SetupTest {
@Test
void testSetup() {
Object _cf_result1 = Fibonacci.fibonacci(10);
int _cf_result1 = Fibonacci.fibonacci(10);
}
}
"""
@ -829,7 +829,7 @@ import static org.junit.jupiter.api.Assertions.*;
public class FibonacciTest {
@Test
void testFibonacci() {
Object _cf_result1 = Fibonacci.fibonacci(10);
int _cf_result1 = Fibonacci.fibonacci(10);
}
}
"""
@ -855,7 +855,7 @@ import static org.junit.jupiter.api.Assertions.*;
public class ParserTest {
@Test
void testParse() {
Object _cf_result1 = parser.parse("input(1)");
String _cf_result1 = parser.parse("input(1)");
}
}
"""
@ -911,7 +911,7 @@ import static org.junit.jupiter.api.Assertions.*;
public class FibonacciTest {
@Test
void testIndex() {
Object _cf_result1 = Fibonacci.fibonacciIndex(Fibonacci.fibonacci(10));
int _cf_result1 = Fibonacci.fibonacciIndex(Fibonacci.fibonacci(10));
}
}
"""
@ -937,7 +937,7 @@ import static org.junit.jupiter.api.Assertions.*;
public class FibonacciTest {
@Test
void testUpTo() {
Object _cf_result1 = Fibonacci.fibonacciUpTo(20);
int _cf_result1 = Fibonacci.fibonacciUpTo(20);
}
}
"""
@ -1053,24 +1053,24 @@ public class BitSetTest {
@Test
public void testGet_IndexZero_ReturnsFalse() {
Object _cf_result1 = instance.get(0);
boolean _cf_result1 = instance.get(0);
}
@Test
public void testGet_SpecificIndexWithinRange_ReturnsFalse() {
Object _cf_result2 = instance.get(100);
boolean _cf_result2 = instance.get(100);
}
@Test
public void testGet_LastIndexOfInitialRange_ReturnsFalse() {
int lastIndex = 16 * BitSet.BITS_PER_WORD - 1;
Object _cf_result3 = instance.get(lastIndex);
boolean _cf_result3 = instance.get(lastIndex);
}
@Test
public void testGet_IndexBeyondAllocated_ReturnsFalse() {
int beyond = 16 * BitSet.BITS_PER_WORD;
Object _cf_result4 = instance.get(beyond);
boolean _cf_result4 = instance.get(beyond);
}
@Test(expected = ArrayIndexOutOfBoundsException.class)
@ -1086,22 +1086,22 @@ public class BitSetTest {
long[] words = new long[2];
words[1] = 1L << 3;
wordsField.set(bs, words);
Object _cf_result5 = bs.get(64 + 3);
boolean _cf_result5 = bs.get(64 + 3);
}
@Test
public void testGet_LargeIndexDoesNotThrow_ReturnsFalse() {
Object _cf_result6 = instance.get(Integer.MAX_VALUE);
boolean _cf_result6 = instance.get(Integer.MAX_VALUE);
}
@Test
public void testGet_BitBoundaryWordEdge63_ReturnsFalse() {
Object _cf_result7 = instance.get(63);
boolean _cf_result7 = instance.get(63);
}
@Test
public void testGet_BitBoundaryWordEdge64_ReturnsFalse() {
Object _cf_result8 = instance.get(64);
boolean _cf_result8 = instance.get(64);
}
@Test
@ -1109,7 +1109,7 @@ public class BitSetTest {
int nBits = 1_000_000;
BitSet big = new BitSet(nBits);
int last = nBits - 1;
Object _cf_result9 = big.get(last);
boolean _cf_result9 = big.get(last);
}
}
"""
@ -1240,15 +1240,15 @@ import static org.junit.jupiter.api.Assertions.*;
public class FibTest {
@Test
void testA() {
Object _cf_result1 = Fibonacci.fibonacci(0);
int _cf_result1 = Fibonacci.fibonacci(0);
}
@Test
void testB() {
Object _cf_result2 = Fibonacci.fibonacci(10);
int _cf_result2 = Fibonacci.fibonacci(10);
}
@Test
void testC() {
Object _cf_result3 = Fibonacci.fibonacci(1);
int _cf_result3 = Fibonacci.fibonacci(1);
}
}
"""
@ -1329,7 +1329,7 @@ import static org.junit.jupiter.api.Assertions.*;
public class FibTest {
@Test
void test() {
Object _cf_result1 = Fibonacci.fibonacci(0);
int _cf_result1 = Fibonacci.fibonacci(0);
}
}
"""
@ -1362,11 +1362,11 @@ import static org.junit.jupiter.api.Assertions.*;
public class CalcTest {
@Test
void test() {
Object _cf_result1 = engine.compute(1);
Object _cf_result2 = engine.compute(2);
Object _cf_result3 = engine.compute(3);
Object _cf_result4 = engine.compute(4);
Object _cf_result5 = engine.compute(5);
int _cf_result1 = engine.compute(1);
int _cf_result2 = engine.compute(2);
int _cf_result3 = engine.compute(3);
int _cf_result4 = engine.compute(4);
int _cf_result5 = engine.compute(5);
}
}
"""
@ -1400,8 +1400,8 @@ import static org.junit.jupiter.api.Assertions.*;
public class FibTest {
@Test
void test() {
Object _cf_result1 = Fibonacci.fibonacci(0);
Object _cf_result2 = Fibonacci.fibonacci(1);
int _cf_result1 = Fibonacci.fibonacci(0);
int _cf_result2 = Fibonacci.fibonacci(1);
}
}
"""
@ -1461,7 +1461,7 @@ import static org.junit.jupiter.api.Assertions.*;
public class FibTest {
@Test
void test() {
Object _cf_result1 = Fibonacci.fibonacci(10);
int _cf_result1 = Fibonacci.fibonacci(10);
}
}
"""
@ -1489,8 +1489,8 @@ import static org.junit.jupiter.api.Assertions.*;
public class CalcTest {
@Test
void test() {
Object _cf_result1 = calc.add(1, 2);
Object _cf_result2 = calc.add(3, 4);
int _cf_result1 = calc.add(1, 2);
int _cf_result2 = calc.add(3, 4);
}
}
"""
@ -1546,12 +1546,12 @@ import static org.junit.jupiter.api.Assertions.*;
public class FibTest {
@Test
void test1() {
Object _cf_result1 = Fibonacci.fibonacci(0);
int _cf_result1 = Fibonacci.fibonacci(0);
}
@Test
void test2() {
Object _cf_result2 = Fibonacci.fibonacci(10);
int _cf_result2 = Fibonacci.fibonacci(10);
}
}
"""
@ -1784,9 +1784,9 @@ public class FibonacciTest {
@Test
void testFibonacci() {
Object _cf_result1 = Fibonacci.fibonacci(0);
Object _cf_result2 = Fibonacci.fibonacci(1);
Object _cf_result3 = Fibonacci.fibonacci(5);
int _cf_result1 = Fibonacci.fibonacci(0);
int _cf_result2 = Fibonacci.fibonacci(1);
int _cf_result3 = Fibonacci.fibonacci(5);
}
@Test
@ -1846,7 +1846,7 @@ public class CalcTest {
void testAdd() {
Calculator calc = new Calculator();
int result = calc.setup();
Object _cf_result1 = calc.add(2, 3);
int _cf_result1 = calc.add(2, 3);
}
}
"""
@ -1902,7 +1902,7 @@ import static org.junit.jupiter.api.Assertions.*;
public class MixedTest {
@Test
void test() {
Object _cf_result1 = obj.target(1);
int _cf_result1 = obj.target(1);
}
}
"""