fix nested targeted function got removed

This commit is contained in:
HeshamHM28 2026-02-17 21:32:37 +02:00
parent 83f335ed04
commit fb6de47c1f
3 changed files with 113 additions and 11 deletions

View file

@ -551,20 +551,50 @@ class JavaAssertTransformer:
def _collect_target_invocations(
self, node, wrapper_bytes: bytes, content_bytes: bytes,
base_offset: int, out: list[TargetCall],
seen_top_level: set[tuple[int, int]] | None = None,
) -> None:
"""Recursively walk the AST and collect method_invocation nodes that match self.func_name."""
"""Recursively walk the AST and collect method_invocation nodes that match self.func_name.
When a target call is nested inside another function call within an assertion argument,
the entire top-level expression is captured instead of just the target call, preserving
surrounding function calls.
"""
if seen_top_level is None:
seen_top_level = set()
prefix_len = len(self._TS_WRAPPER_PREFIX_BYTES)
if node.type == "method_invocation":
name_node = node.child_by_field_name("name")
if name_node and self.analyzer.get_node_text(name_node, wrapper_bytes) == self.func_name:
top_node = self._find_top_level_arg_node(node, wrapper_bytes)
if top_node is not None:
range_key = (top_node.start_byte, top_node.end_byte)
if range_key not in seen_top_level:
seen_top_level.add(range_key)
start = top_node.start_byte - prefix_len
end = top_node.end_byte - prefix_len
if 0 <= start and end <= len(content_bytes):
full_call = self.analyzer.get_node_text(top_node, wrapper_bytes)
start_char = len(content_bytes[:start].decode("utf8"))
end_char = len(content_bytes[:end].decode("utf8"))
out.append(TargetCall(
receiver=None,
method_name=self.func_name,
arguments="",
full_call=full_call,
start_pos=base_offset + start_char,
end_pos=base_offset + end_char,
))
else:
start = node.start_byte - prefix_len
end = node.end_byte - prefix_len
if 0 <= start and end <= len(content_bytes):
out.append(self._build_target_call(node, wrapper_bytes, content_bytes, start, end, base_offset))
return
for child in node.children:
self._collect_target_invocations(child, wrapper_bytes, content_bytes, base_offset, out)
self._collect_target_invocations(child, wrapper_bytes, content_bytes, base_offset, out, seen_top_level)
def _build_target_call(
self, node, wrapper_bytes: bytes, content_bytes: bytes,
@ -593,6 +623,36 @@ class JavaAssertTransformer:
end_pos=base_offset + end_char,
)
def _find_top_level_arg_node(self, target_node, wrapper_bytes: bytes):
"""Find the top-level argument expression containing a nested target call.
Walks up the AST from target_node to the wrapper _d() call's argument_list.
Only considers the target as nested if it passes through the argument_list of
a regular (non-assertion) function call. Assertion methods (assertEquals, etc.)
and non-argument relationships (method chains like .size()) are not counted.
Returns the top-level expression node if the target is nested inside a regular
function call, or None if the target is direct.
"""
current = target_node
passed_through_regular_call = False
while current.parent is not None:
parent = current.parent
if parent.type == "argument_list" and parent.parent is not None:
grandparent = parent.parent
if grandparent.type == "method_invocation":
gp_name = grandparent.child_by_field_name("name")
if gp_name:
name = self.analyzer.get_node_text(gp_name, wrapper_bytes)
if name == "_d":
if passed_through_regular_call and current != target_node:
return current
return None
if not name.startswith("assert"):
passed_through_regular_call = True
current = current.parent
return None
def _detect_variable_assignment(self, source: str, assertion_start: int) -> tuple[str | None, str | None]:
"""Check if assertion is assigned to a variable.

View file

@ -407,11 +407,53 @@ void testDeep() {
expected = """\
@Test
void testDeep() {
Object _cf_result1 = calculator.fibonacci(5);
Object _cf_result1 = outer.process(inner.compute(calculator.fibonacci(5)));
}"""
result = transform_java_assertions(source, "fibonacci")
assert result == expected
def test_target_nested_in_non_target_call(self):
source = """\
@Test
void testSubtract() {
assertEquals(0, add(2, subtract(2, 2)));
}"""
expected = """\
@Test
void testSubtract() {
Object _cf_result1 = add(2, subtract(2, 2));
}"""
result = transform_java_assertions(source, "subtract")
assert result == expected
def test_non_target_nested_in_target_call(self):
source = """\
@Test
void testAdd() {
assertEquals(0, subtract(2, add(2, 3)));
}"""
expected = """\
@Test
void testAdd() {
Object _cf_result1 = subtract(2, add(2, 3));
}"""
result = transform_java_assertions(source, "add")
assert result == expected
def test_multiple_targets_nested_in_same_outer_call(self):
source = """\
@Test
void testOuter() {
assertEquals(0, outer(subtract(1, 1), subtract(2, 2)));
}"""
expected = """\
@Test
void testOuter() {
Object _cf_result1 = outer(subtract(1, 1), subtract(2, 2));
}"""
result = transform_java_assertions(source, "subtract")
assert result == expected
class TestWhitespacePreservation:
"""Tests for whitespace and indentation preservation."""

View file

@ -478,8 +478,8 @@ public class FibonacciTest {
"""
result = transform_java_assertions(source, "fibonacci")
assert "assertTrue" not in result
assert "Object _cf_result1 = Fibonacci.fibonacci(5);" in result
assert "Object _cf_result2 = Fibonacci.fibonacci(6);" in result
# Both fibonacci calls are preserved inside the containing areConsecutiveFibonacci call
assert "Object _cf_result1 = Fibonacci.areConsecutiveFibonacci(Fibonacci.fibonacci(5), Fibonacci.fibonacci(6));" in result
def test_multiple_assertions_in_one_method(self):
source = """\
@ -626,8 +626,8 @@ public class FibonacciTest {
"""
result = transform_java_assertions(source, "fibonacci")
assert "assertEquals" not in result
# Should capture the inner fibonacci call
assert "Object _cf_result1 = Fibonacci.fibonacci(10);" in result
# Should capture the full top-level expression containing the target call
assert "Object _cf_result1 = Fibonacci.fibonacciIndex(Fibonacci.fibonacci(10));" in result
def test_chained_method_on_result(self):
"""Target function call with chained method (e.g., result.toString())."""