mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
fix: infer Java return types in assertion transformer instead of using Object
The assertion transformer always declared `Object _cf_resultN = call()` when
replacing assertions, losing the actual return type. This caused compilation
failures when the result was used in a context expecting a primitive type
(e.g., int, boolean).
Now infers the return type from assertion context:
- assertEquals(int_literal, call()) -> int
- assertTrue/assertFalse(call()) -> boolean
- assertEquals("string", call()) -> String
- Falls back to Object when type can't be determined
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
39bea14c99
commit
9e5880f032
2 changed files with 181 additions and 55 deletions
|
|
@ -894,6 +894,129 @@ class JavaAssertTransformer:
|
|||
|
||||
return code[open_brace_pos + 1 : pos - 1], pos
|
||||
|
||||
def _infer_return_type(self, assertion: AssertionMatch) -> str:
|
||||
"""Infer the Java return type from the assertion context.
|
||||
|
||||
For assertEquals(expected, actual) patterns, the expected literal determines the type.
|
||||
For assertTrue/assertFalse, the result is boolean.
|
||||
Falls back to Object when the type cannot be determined.
|
||||
"""
|
||||
method = assertion.assertion_method
|
||||
|
||||
# assertTrue/assertFalse always deal with boolean values
|
||||
if method in ("assertTrue", "assertFalse"):
|
||||
return "boolean"
|
||||
|
||||
# assertNull/assertNotNull — keep Object (reference type)
|
||||
if method in ("assertNull", "assertNotNull"):
|
||||
return "Object"
|
||||
|
||||
# For assertEquals/assertNotEquals/assertSame, try to infer from the expected literal
|
||||
if method in JUNIT5_VALUE_ASSERTIONS:
|
||||
return self._infer_type_from_assertion_args(assertion.original_text, method)
|
||||
|
||||
# For fluent assertions (assertThat), type inference is harder — keep Object
|
||||
return "Object"
|
||||
|
||||
# Regex patterns for Java literal type inference
|
||||
_LONG_LITERAL_RE = re.compile(r"^-?\d+[lL]$")
|
||||
_INT_LITERAL_RE = re.compile(r"^-?\d+$")
|
||||
_DOUBLE_LITERAL_RE = re.compile(r"^-?\d+\.\d*[dD]?$|^-?\d+[dD]$")
|
||||
_FLOAT_LITERAL_RE = re.compile(r"^-?\d+\.?\d*[fF]$")
|
||||
_CHAR_LITERAL_RE = re.compile(r"^'.'$|^'\\.'$")
|
||||
|
||||
def _infer_type_from_assertion_args(self, original_text: str, method: str) -> str:
|
||||
"""Infer the return type from assertEquals/assertNotEquals expected value."""
|
||||
# Extract the args portion from the assertion text
|
||||
# Pattern: assertXxx( args... )
|
||||
paren_idx = original_text.find("(")
|
||||
if paren_idx < 0:
|
||||
return "Object"
|
||||
|
||||
args_str = original_text[paren_idx + 1 :]
|
||||
# Remove trailing ");", whitespace
|
||||
args_str = args_str.rstrip()
|
||||
if args_str.endswith(");"):
|
||||
args_str = args_str[:-2]
|
||||
elif args_str.endswith(")"):
|
||||
args_str = args_str[:-1]
|
||||
|
||||
# Split top-level args (respecting parens, strings, generics)
|
||||
args = self._split_top_level_args(args_str)
|
||||
if not args:
|
||||
return "Object"
|
||||
|
||||
# assertEquals has (expected, actual) or (expected, actual, message/delta)
|
||||
# Some overloads have (message, expected, actual) in JUnit 4 but JUnit 5 uses (expected, actual[, message])
|
||||
# Try the first argument as the expected value
|
||||
expected = args[0].strip()
|
||||
|
||||
return self._type_from_literal(expected)
|
||||
|
||||
def _type_from_literal(self, value: str) -> str:
|
||||
"""Determine the Java type of a literal value."""
|
||||
if value in ("true", "false"):
|
||||
return "boolean"
|
||||
if value == "null":
|
||||
return "Object"
|
||||
if self._FLOAT_LITERAL_RE.match(value):
|
||||
return "float"
|
||||
if self._DOUBLE_LITERAL_RE.match(value):
|
||||
return "double"
|
||||
if self._LONG_LITERAL_RE.match(value):
|
||||
return "long"
|
||||
if self._INT_LITERAL_RE.match(value):
|
||||
return "int"
|
||||
if self._CHAR_LITERAL_RE.match(value):
|
||||
return "char"
|
||||
if value.startswith('"'):
|
||||
return "String"
|
||||
# Cast expression like (byte)0, (short)1
|
||||
cast_match = re.match(r"^\((\w+)\)", value)
|
||||
if cast_match:
|
||||
return cast_match.group(1)
|
||||
return "Object"
|
||||
|
||||
def _split_top_level_args(self, args_str: str) -> list[str]:
|
||||
"""Split assertion arguments at top-level commas, respecting parens/strings/generics."""
|
||||
args: list[str] = []
|
||||
depth = 0
|
||||
current: list[str] = []
|
||||
i = 0
|
||||
in_string = False
|
||||
string_char = ""
|
||||
|
||||
while i < len(args_str):
|
||||
ch = args_str[i]
|
||||
|
||||
if in_string:
|
||||
current.append(ch)
|
||||
if ch == "\\" and i + 1 < len(args_str):
|
||||
i += 1
|
||||
current.append(args_str[i])
|
||||
elif ch == string_char:
|
||||
in_string = False
|
||||
elif ch in ('"', "'"):
|
||||
in_string = True
|
||||
string_char = ch
|
||||
current.append(ch)
|
||||
elif ch in ("(", "<", "[", "{"):
|
||||
depth += 1
|
||||
current.append(ch)
|
||||
elif ch in (")", ">", "]", "}"):
|
||||
depth -= 1
|
||||
current.append(ch)
|
||||
elif ch == "," and depth == 0:
|
||||
args.append("".join(current))
|
||||
current = []
|
||||
else:
|
||||
current.append(ch)
|
||||
i += 1
|
||||
|
||||
if current:
|
||||
args.append("".join(current))
|
||||
return args
|
||||
|
||||
def _generate_replacement(self, assertion: AssertionMatch) -> str:
|
||||
"""Generate replacement code for an assertion.
|
||||
|
||||
|
|
@ -912,6 +1035,9 @@ class JavaAssertTransformer:
|
|||
if not assertion.target_calls:
|
||||
return ""
|
||||
|
||||
# Infer the return type from assertion context to avoid Object→primitive cast errors
|
||||
return_type = self._infer_return_type(assertion)
|
||||
|
||||
# Generate capture statements for each target call
|
||||
replacements = []
|
||||
# For the first replacement, use the full leading whitespace
|
||||
|
|
@ -921,9 +1047,9 @@ class JavaAssertTransformer:
|
|||
self.invocation_counter += 1
|
||||
var_name = f"_cf_result{self.invocation_counter}"
|
||||
if i == 0:
|
||||
replacements.append(f"{assertion.leading_whitespace}Object {var_name} = {call.full_call};")
|
||||
replacements.append(f"{assertion.leading_whitespace}{return_type} {var_name} = {call.full_call};")
|
||||
else:
|
||||
replacements.append(f"{base_indent}Object {var_name} = {call.full_call};")
|
||||
replacements.append(f"{base_indent}{return_type} {var_name} = {call.full_call};")
|
||||
|
||||
return "\n".join(replacements)
|
||||
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ import static org.junit.Assert.*;
|
|||
public class BitSetTest {
|
||||
@Test
|
||||
public void testGet_IndexZero_ReturnsFalse() {
|
||||
Object _cf_result1 = instance.get(0);
|
||||
boolean _cf_result1 = instance.get(0);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -67,7 +67,7 @@ import static org.junit.Assert.*;
|
|||
public class BitSetTest {
|
||||
@Test
|
||||
public void testGet_SetBit_DetectedTrue() {
|
||||
Object _cf_result1 = bs.get(67);
|
||||
boolean _cf_result1 = bs.get(67);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -93,7 +93,7 @@ import static org.junit.Assert.*;
|
|||
public class FibonacciTest {
|
||||
@Test
|
||||
public void testFibonacci() {
|
||||
Object _cf_result1 = Fibonacci.fibonacci(10);
|
||||
int _cf_result1 = Fibonacci.fibonacci(10);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -121,7 +121,7 @@ public class CalculatorTest {
|
|||
@Test
|
||||
public void testAdd() {
|
||||
Calculator calc = new Calculator();
|
||||
Object _cf_result1 = calc.add(2, 2);
|
||||
int _cf_result1 = calc.add(2, 2);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -199,7 +199,7 @@ import static org.junit.Assert.*;
|
|||
public class CalculatorTest {
|
||||
@Test
|
||||
public void testSubtract() {
|
||||
Object _cf_result1 = calc.subtract(5, 3);
|
||||
int _cf_result1 = calc.subtract(5, 3);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -251,7 +251,7 @@ import org.junit.Assert;
|
|||
public class CalculatorTest {
|
||||
@Test
|
||||
public void testAdd() {
|
||||
Object _cf_result1 = calc.add(2, 2);
|
||||
int _cf_result1 = calc.add(2, 2);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -298,9 +298,9 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
public class FibonacciTest {
|
||||
@Test
|
||||
void testFibonacci() {
|
||||
Object _cf_result1 = Fibonacci.fibonacci(0);
|
||||
Object _cf_result2 = Fibonacci.fibonacci(1);
|
||||
Object _cf_result3 = Fibonacci.fibonacci(10);
|
||||
int _cf_result1 = Fibonacci.fibonacci(0);
|
||||
int _cf_result2 = Fibonacci.fibonacci(1);
|
||||
int _cf_result3 = Fibonacci.fibonacci(10);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -326,7 +326,7 @@ import org.junit.jupiter.api.Assertions;
|
|||
public class FibonacciTest {
|
||||
@Test
|
||||
void testFibonacci() {
|
||||
Object _cf_result1 = Fibonacci.fibonacci(10);
|
||||
int _cf_result1 = Fibonacci.fibonacci(10);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -485,7 +485,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
public class FibonacciTest {
|
||||
@Test
|
||||
void testIsFibonacci() {
|
||||
Object _cf_result1 = Fibonacci.isFibonacci(5);
|
||||
boolean _cf_result1 = Fibonacci.isFibonacci(5);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -511,7 +511,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
public class FibonacciTest {
|
||||
@Test
|
||||
void testIsNotFibonacci() {
|
||||
Object _cf_result1 = Fibonacci.isFibonacci(4);
|
||||
boolean _cf_result1 = Fibonacci.isFibonacci(4);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -709,7 +709,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
public class FibonacciTest {
|
||||
@Test
|
||||
void testConsecutive() {
|
||||
Object _cf_result1 = Fibonacci.areConsecutiveFibonacci(Fibonacci.fibonacci(5), Fibonacci.fibonacci(6));
|
||||
boolean _cf_result1 = Fibonacci.areConsecutiveFibonacci(Fibonacci.fibonacci(5), Fibonacci.fibonacci(6));
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -739,11 +739,11 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
public class FibonacciTest {
|
||||
@Test
|
||||
void testMultiple() {
|
||||
Object _cf_result1 = Fibonacci.fibonacci(0);
|
||||
Object _cf_result2 = Fibonacci.fibonacci(1);
|
||||
Object _cf_result3 = Fibonacci.fibonacci(2);
|
||||
Object _cf_result4 = Fibonacci.fibonacci(3);
|
||||
Object _cf_result5 = Fibonacci.fibonacci(5);
|
||||
int _cf_result1 = Fibonacci.fibonacci(0);
|
||||
int _cf_result2 = Fibonacci.fibonacci(1);
|
||||
int _cf_result3 = Fibonacci.fibonacci(2);
|
||||
int _cf_result4 = Fibonacci.fibonacci(3);
|
||||
int _cf_result5 = Fibonacci.fibonacci(5);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -774,7 +774,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
public class SetupTest {
|
||||
@Test
|
||||
void testSetup() {
|
||||
Object _cf_result1 = Fibonacci.fibonacci(10);
|
||||
int _cf_result1 = Fibonacci.fibonacci(10);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -829,7 +829,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
public class FibonacciTest {
|
||||
@Test
|
||||
void testFibonacci() {
|
||||
Object _cf_result1 = Fibonacci.fibonacci(10);
|
||||
int _cf_result1 = Fibonacci.fibonacci(10);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -855,7 +855,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
public class ParserTest {
|
||||
@Test
|
||||
void testParse() {
|
||||
Object _cf_result1 = parser.parse("input(1)");
|
||||
String _cf_result1 = parser.parse("input(1)");
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -911,7 +911,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
public class FibonacciTest {
|
||||
@Test
|
||||
void testIndex() {
|
||||
Object _cf_result1 = Fibonacci.fibonacciIndex(Fibonacci.fibonacci(10));
|
||||
int _cf_result1 = Fibonacci.fibonacciIndex(Fibonacci.fibonacci(10));
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -937,7 +937,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
public class FibonacciTest {
|
||||
@Test
|
||||
void testUpTo() {
|
||||
Object _cf_result1 = Fibonacci.fibonacciUpTo(20);
|
||||
int _cf_result1 = Fibonacci.fibonacciUpTo(20);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -1053,24 +1053,24 @@ public class BitSetTest {
|
|||
|
||||
@Test
|
||||
public void testGet_IndexZero_ReturnsFalse() {
|
||||
Object _cf_result1 = instance.get(0);
|
||||
boolean _cf_result1 = instance.get(0);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGet_SpecificIndexWithinRange_ReturnsFalse() {
|
||||
Object _cf_result2 = instance.get(100);
|
||||
boolean _cf_result2 = instance.get(100);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGet_LastIndexOfInitialRange_ReturnsFalse() {
|
||||
int lastIndex = 16 * BitSet.BITS_PER_WORD - 1;
|
||||
Object _cf_result3 = instance.get(lastIndex);
|
||||
boolean _cf_result3 = instance.get(lastIndex);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGet_IndexBeyondAllocated_ReturnsFalse() {
|
||||
int beyond = 16 * BitSet.BITS_PER_WORD;
|
||||
Object _cf_result4 = instance.get(beyond);
|
||||
boolean _cf_result4 = instance.get(beyond);
|
||||
}
|
||||
|
||||
@Test(expected = ArrayIndexOutOfBoundsException.class)
|
||||
|
|
@ -1086,22 +1086,22 @@ public class BitSetTest {
|
|||
long[] words = new long[2];
|
||||
words[1] = 1L << 3;
|
||||
wordsField.set(bs, words);
|
||||
Object _cf_result5 = bs.get(64 + 3);
|
||||
boolean _cf_result5 = bs.get(64 + 3);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGet_LargeIndexDoesNotThrow_ReturnsFalse() {
|
||||
Object _cf_result6 = instance.get(Integer.MAX_VALUE);
|
||||
boolean _cf_result6 = instance.get(Integer.MAX_VALUE);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGet_BitBoundaryWordEdge63_ReturnsFalse() {
|
||||
Object _cf_result7 = instance.get(63);
|
||||
boolean _cf_result7 = instance.get(63);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGet_BitBoundaryWordEdge64_ReturnsFalse() {
|
||||
Object _cf_result8 = instance.get(64);
|
||||
boolean _cf_result8 = instance.get(64);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
@ -1109,7 +1109,7 @@ public class BitSetTest {
|
|||
int nBits = 1_000_000;
|
||||
BitSet big = new BitSet(nBits);
|
||||
int last = nBits - 1;
|
||||
Object _cf_result9 = big.get(last);
|
||||
boolean _cf_result9 = big.get(last);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -1240,15 +1240,15 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
public class FibTest {
|
||||
@Test
|
||||
void testA() {
|
||||
Object _cf_result1 = Fibonacci.fibonacci(0);
|
||||
int _cf_result1 = Fibonacci.fibonacci(0);
|
||||
}
|
||||
@Test
|
||||
void testB() {
|
||||
Object _cf_result2 = Fibonacci.fibonacci(10);
|
||||
int _cf_result2 = Fibonacci.fibonacci(10);
|
||||
}
|
||||
@Test
|
||||
void testC() {
|
||||
Object _cf_result3 = Fibonacci.fibonacci(1);
|
||||
int _cf_result3 = Fibonacci.fibonacci(1);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -1329,7 +1329,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
public class FibTest {
|
||||
@Test
|
||||
void test() {
|
||||
Object _cf_result1 = Fibonacci.fibonacci(0);
|
||||
int _cf_result1 = Fibonacci.fibonacci(0);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -1362,11 +1362,11 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
public class CalcTest {
|
||||
@Test
|
||||
void test() {
|
||||
Object _cf_result1 = engine.compute(1);
|
||||
Object _cf_result2 = engine.compute(2);
|
||||
Object _cf_result3 = engine.compute(3);
|
||||
Object _cf_result4 = engine.compute(4);
|
||||
Object _cf_result5 = engine.compute(5);
|
||||
int _cf_result1 = engine.compute(1);
|
||||
int _cf_result2 = engine.compute(2);
|
||||
int _cf_result3 = engine.compute(3);
|
||||
int _cf_result4 = engine.compute(4);
|
||||
int _cf_result5 = engine.compute(5);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -1400,8 +1400,8 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
public class FibTest {
|
||||
@Test
|
||||
void test() {
|
||||
Object _cf_result1 = Fibonacci.fibonacci(0);
|
||||
Object _cf_result2 = Fibonacci.fibonacci(1);
|
||||
int _cf_result1 = Fibonacci.fibonacci(0);
|
||||
int _cf_result2 = Fibonacci.fibonacci(1);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -1461,7 +1461,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
public class FibTest {
|
||||
@Test
|
||||
void test() {
|
||||
Object _cf_result1 = Fibonacci.fibonacci(10);
|
||||
int _cf_result1 = Fibonacci.fibonacci(10);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -1489,8 +1489,8 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
public class CalcTest {
|
||||
@Test
|
||||
void test() {
|
||||
Object _cf_result1 = calc.add(1, 2);
|
||||
Object _cf_result2 = calc.add(3, 4);
|
||||
int _cf_result1 = calc.add(1, 2);
|
||||
int _cf_result2 = calc.add(3, 4);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -1546,12 +1546,12 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
public class FibTest {
|
||||
@Test
|
||||
void test1() {
|
||||
Object _cf_result1 = Fibonacci.fibonacci(0);
|
||||
int _cf_result1 = Fibonacci.fibonacci(0);
|
||||
}
|
||||
|
||||
@Test
|
||||
void test2() {
|
||||
Object _cf_result2 = Fibonacci.fibonacci(10);
|
||||
int _cf_result2 = Fibonacci.fibonacci(10);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -1784,9 +1784,9 @@ public class FibonacciTest {
|
|||
|
||||
@Test
|
||||
void testFibonacci() {
|
||||
Object _cf_result1 = Fibonacci.fibonacci(0);
|
||||
Object _cf_result2 = Fibonacci.fibonacci(1);
|
||||
Object _cf_result3 = Fibonacci.fibonacci(5);
|
||||
int _cf_result1 = Fibonacci.fibonacci(0);
|
||||
int _cf_result2 = Fibonacci.fibonacci(1);
|
||||
int _cf_result3 = Fibonacci.fibonacci(5);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
@ -1846,7 +1846,7 @@ public class CalcTest {
|
|||
void testAdd() {
|
||||
Calculator calc = new Calculator();
|
||||
int result = calc.setup();
|
||||
Object _cf_result1 = calc.add(2, 3);
|
||||
int _cf_result1 = calc.add(2, 3);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -1902,7 +1902,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
public class MixedTest {
|
||||
@Test
|
||||
void test() {
|
||||
Object _cf_result1 = obj.target(1);
|
||||
int _cf_result1 = obj.target(1);
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in a new issue