fix: handle multi-line function calls in behavior instrumentation

wrap_target_calls_with_treesitter() operated line-by-line but tree-sitter
byte offsets span multiple lines. Multi-line calls like:
  func(arg1,
       arg2);
only had the first line replaced, leaving orphaned continuation lines
that caused compilation errors (80% of Spring AI functions skipped).

Rewrote to operate on the full body text with pre-computed character
offsets, processing calls back-to-front — same approach as
_add_timing_instrumentation which never had this bug.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
HeshamHM28 2026-03-27 05:24:18 +00:00 committed by Ubuntu
parent 32575780bd
commit e1478caaba
2 changed files with 197 additions and 149 deletions

View file

@ -273,9 +273,9 @@ def wrap_target_calls_with_treesitter(
) -> tuple[list[str], int]:
"""Replace target method calls in body_lines with capture + serialize using tree-sitter.
Parses the method body with tree-sitter, walks the AST for method_invocation nodes
matching func_name, and generates capture/serialize lines. Uses the parent node type
to determine whether to keep or remove the original line after replacement.
Operates on the full body text with character offsets (not line-by-line) to correctly
handle calls that span multiple lines. Processes calls back-to-front so earlier offsets
remain valid after later replacements.
For behavior mode (precise_call_timing=True), each call is wrapped in its own
try-finally block with immediate SQLite write to prevent data loss from multiple calls.
@ -302,169 +302,129 @@ def wrap_target_calls_with_treesitter(
if not calls:
return list(body_lines), 0
# Build line byte-start offsets for mapping calls to body_lines indices
line_byte_starts = []
# Build line byte-start offsets for mapping byte offsets to line indices
line_byte_starts: list[int] = []
offset = 0
for line in body_lines:
line_byte_starts.append(offset)
offset += len(line.encode("utf8")) + 1 # +1 for \n from join
# Group non-lambda and non-complex-expression calls by their line index
calls_by_line: dict[int, list[dict[str, Any]]] = {}
for call in calls:
if call["in_lambda"] or call.get("in_complex", False):
logger.debug("Skipping behavior instrumentation for call in lambda or complex expression")
continue
# Filter out lambda and complex-expression calls, sort by start_byte ascending for counter assignment
valid_calls = [c for c in calls if not c["in_lambda"] and not c.get("in_complex", False)]
if not valid_calls:
return list(body_lines), 0
valid_calls.sort(key=lambda c: c["start_byte"])
# Pre-compute character offsets and line info for each call (before any text modifications)
for i, call in enumerate(valid_calls, 1):
call["_counter"] = i
call["_call_start_char"] = len(body_bytes[: call["start_byte"]].decode("utf8"))
call["_call_end_char"] = len(body_bytes[: call["end_byte"]].decode("utf8"))
if call["parent_type"] == "expression_statement":
call["_es_start_char"] = len(body_bytes[: call["es_start_byte"]].decode("utf8"))
call["_es_end_char"] = len(body_bytes[: call["es_end_byte"]].decode("utf8"))
line_idx = _byte_to_line_index(call["start_byte"], line_byte_starts)
calls_by_line.setdefault(line_idx, []).append(call)
call["_line_idx"] = line_idx
call["_line_char_start"] = len(body_bytes[: line_byte_starts[line_idx]].decode("utf8"))
wrapped = []
call_counter = 0
# Process calls back-to-front so earlier character offsets stay valid
for call in reversed(valid_calls):
call_counter = call["_counter"]
line_idx = call["_line_idx"]
call_absolute_line = body_start_line + line_idx + 1
inv_id = f"L{call_absolute_line}_{call_counter}"
for line_idx, body_line in enumerate(body_lines):
if line_idx not in calls_by_line:
wrapped.append(body_line)
continue
orig_line = body_lines[line_idx]
line_indent_str = " " * (len(orig_line) - len(orig_line.lstrip()))
line_calls = sorted(calls_by_line[line_idx], key=lambda c: c["start_byte"], reverse=True)
line_indent_str = " " * (len(body_line) - len(body_line.lstrip()))
line_byte_start = line_byte_starts[line_idx]
line_bytes = body_line.encode("utf8")
var_name = f"_cf_result{iter_id}_{call_counter}"
cast_type = _infer_array_cast_type(orig_line)
if not cast_type and target_return_type and target_return_type != "void":
cast_type = target_return_type
var_with_cast = f"({cast_type}){var_name}" if cast_type else var_name
new_line = body_line
# Track cumulative char shift from earlier edits on this line
char_shift = 0
capture_stmt_with_decl = f"var {var_name} = {call['full_call']};"
capture_stmt_assign = f"{var_name} = {call['full_call']};"
if precise_call_timing:
serialize_stmt = f"_cf_serializedResult{iter_id}_{call_counter} = com.codeflash.Serializer.serialize((Object) {var_name});"
start_stmt = f"_cf_start{iter_id}_{call_counter} = System.nanoTime();"
end_stmt = f"_cf_end{iter_id}_{call_counter} = System.nanoTime();"
else:
serialize_stmt = f"_cf_serializedResult{iter_id} = com.codeflash.Serializer.serialize((Object) {var_name});"
start_stmt = f"_cf_start{iter_id} = System.nanoTime();"
end_stmt = f"_cf_end{iter_id} = System.nanoTime();"
for call in line_calls:
call_counter += 1
# Compute absolute line number (1-indexed) for the invocation ID
call_absolute_line = body_start_line + line_idx + 1
inv_id = f"L{call_absolute_line}_{call_counter}"
var_name = f"_cf_result{iter_id}_{call_counter}"
cast_type = _infer_array_cast_type(body_line)
if not cast_type and target_return_type and target_return_type != "void":
cast_type = target_return_type
var_with_cast = f"({cast_type}){var_name}" if cast_type else var_name
# Use per-call unique variables (with call_counter suffix) for behavior mode
# For behavior mode, we declare the variable outside try block, so use assignment not declaration here
# For performance mode, use shared variables without call_counter suffix
capture_stmt_with_decl = f"var {var_name} = {call['full_call']};"
capture_stmt_assign = f"{var_name} = {call['full_call']};"
if call["parent_type"] == "expression_statement":
es_start = call["_es_start_char"]
es_end = call["_es_end_char"]
if precise_call_timing:
# Behavior mode: per-call unique variables
serialize_stmt = f"_cf_serializedResult{iter_id}_{call_counter} = com.codeflash.Serializer.serialize((Object) {var_name});"
start_stmt = f"_cf_start{iter_id}_{call_counter} = System.nanoTime();"
end_stmt = f"_cf_end{iter_id}_{call_counter} = System.nanoTime();"
else:
# Performance mode: shared variables without call_counter suffix
serialize_stmt = (
f"_cf_serializedResult{iter_id} = com.codeflash.Serializer.serialize((Object) {var_name});"
# No indent on first line — body_text[:es_start] already has leading whitespace.
# Subsequent lines get line_indent_str.
var_decls = [
f"Object {var_name} = null;",
f"long _cf_end{iter_id}_{call_counter} = -1;",
f"long _cf_start{iter_id}_{call_counter} = 0;",
f"byte[] _cf_serializedResult{iter_id}_{call_counter} = null;",
]
start_marker = f'System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + "." + _cf_test{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":{inv_id}" + "######$!");'
try_block = [
"try {",
f" {start_stmt}",
f" {capture_stmt_assign}",
f" {end_stmt}",
f" {serialize_stmt}",
]
finally_block = _generate_sqlite_write_code(
iter_id, call_counter, "", class_name, func_name, test_method_name, invocation_id=inv_id
)
all_lines = [*var_decls, start_marker, *try_block, *finally_block]
replacement = (
all_lines[0] + "\n" + "\n".join(f"{line_indent_str}{repl_line}" for repl_line in all_lines[1:])
)
start_stmt = f"_cf_start{iter_id} = System.nanoTime();"
end_stmt = f"_cf_end{iter_id} = System.nanoTime();"
if call["parent_type"] == "expression_statement":
# Replace the expression_statement IN PLACE with capture+serialize.
# This keeps the code inside whatever scope it's in (e.g. try block),
# preventing calls from being moved outside try-catch blocks.
es_start_byte = call["es_start_byte"] - line_byte_start
es_end_byte = call["es_end_byte"] - line_byte_start
es_start_char = len(line_bytes[:es_start_byte].decode("utf8"))
es_end_char = len(line_bytes[:es_end_byte].decode("utf8"))
if precise_call_timing:
# For behavior mode: wrap each call in its own try-finally with SQLite write.
# This ensures data from all calls is captured independently.
# Declare per-call variables
var_decls = [
f"Object {var_name} = null;",
f"long _cf_end{iter_id}_{call_counter} = -1;",
f"long _cf_start{iter_id}_{call_counter} = 0;",
f"byte[] _cf_serializedResult{iter_id}_{call_counter} = null;",
]
# Start marker
start_marker = f'System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + "." + _cf_test{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":{inv_id}" + "######$!");'
# Try block with capture (use assignment, not declaration, since variable is declared above)
try_block = [
"try {",
f" {start_stmt}",
f" {capture_stmt_assign}",
f" {end_stmt}",
f" {serialize_stmt}",
]
# Finally block with SQLite write
finally_block = _generate_sqlite_write_code(
iter_id, call_counter, "", class_name, func_name, test_method_name, invocation_id=inv_id
)
replacement_lines = [*var_decls, start_marker, *try_block, *finally_block]
# Don't add indent to first line (it's placed after existing indent), but add to subsequent lines
if replacement_lines:
replacement = (
replacement_lines[0]
+ "\n"
+ "\n".join(f"{line_indent_str}{line}" for line in replacement_lines[1:])
)
else:
replacement = ""
else:
replacement = f"{capture_stmt_with_decl} {serialize_stmt}"
adj_start = es_start_char + char_shift
adj_end = es_end_char + char_shift
new_line = new_line[:adj_start] + replacement + new_line[adj_end:]
char_shift += len(replacement) - (es_end_char - es_start_char)
else:
# The call is embedded in a larger expression (assignment, assertion, etc.)
# Emit capture+serialize before the line, then replace the call with the variable.
if precise_call_timing:
# For behavior mode: wrap in try-finally with SQLite write
# Declare per-call variables
wrapped.append(f"{line_indent_str}Object {var_name} = null;")
wrapped.append(f"{line_indent_str}long _cf_end{iter_id}_{call_counter} = -1;")
wrapped.append(f"{line_indent_str}long _cf_start{iter_id}_{call_counter} = 0;")
wrapped.append(f"{line_indent_str}byte[] _cf_serializedResult{iter_id}_{call_counter} = null;")
# Start marker
wrapped.append(
f'{line_indent_str}System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + "." + _cf_test{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":{inv_id}" + "######$!");'
)
# Try block (use assignment, not declaration, since variable is declared above)
wrapped.append(f"{line_indent_str}try {{")
wrapped.append(f"{line_indent_str} {start_stmt}")
wrapped.append(f"{line_indent_str} {capture_stmt_assign}")
wrapped.append(f"{line_indent_str} {end_stmt}")
wrapped.append(f"{line_indent_str} {serialize_stmt}")
# Finally block with SQLite write
finally_lines = _generate_sqlite_write_code(
iter_id,
call_counter,
line_indent_str,
class_name,
func_name,
test_method_name,
invocation_id=inv_id,
)
wrapped.extend(finally_lines)
else:
capture_line = f"{line_indent_str}{capture_stmt_with_decl}"
wrapped.append(capture_line)
serialize_line = f"{line_indent_str}{serialize_stmt}"
wrapped.append(serialize_line)
replacement = f"{capture_stmt_with_decl} {serialize_stmt}"
body_text = body_text[:es_start] + replacement + body_text[es_end:]
else:
# Embedded call: replace call with variable, then insert capture lines before the line
call_start = call["_call_start_char"]
call_end = call["_call_end_char"]
line_char_start = call["_line_char_start"]
call_start_byte = call["start_byte"] - line_byte_start
call_end_byte = call["end_byte"] - line_byte_start
call_start_char = len(line_bytes[:call_start_byte].decode("utf8"))
call_end_char = len(line_bytes[:call_end_byte].decode("utf8"))
adj_start = call_start_char + char_shift
adj_end = call_end_char + char_shift
new_line = new_line[:adj_start] + var_with_cast + new_line[adj_end:]
char_shift += len(var_with_cast) - (call_end_char - call_start_char)
if precise_call_timing:
prefix_lines = [
f"{line_indent_str}Object {var_name} = null;",
f"{line_indent_str}long _cf_end{iter_id}_{call_counter} = -1;",
f"{line_indent_str}long _cf_start{iter_id}_{call_counter} = 0;",
f"{line_indent_str}byte[] _cf_serializedResult{iter_id}_{call_counter} = null;",
f'{line_indent_str}System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + "." + _cf_test{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":{inv_id}" + "######$!");',
f"{line_indent_str}try {{",
f"{line_indent_str} {start_stmt}",
f"{line_indent_str} {capture_stmt_assign}",
f"{line_indent_str} {end_stmt}",
f"{line_indent_str} {serialize_stmt}",
]
finally_lines = _generate_sqlite_write_code(
iter_id,
call_counter,
line_indent_str,
class_name,
func_name,
test_method_name,
invocation_id=inv_id,
)
prefix_lines.extend(finally_lines)
else:
prefix_lines = [f"{line_indent_str}{capture_stmt_with_decl}", f"{line_indent_str}{serialize_stmt}"]
# Keep the modified line only if it has meaningful content left
if new_line.strip():
wrapped.append(new_line)
# Step 1: Replace the call with the variable (at higher offset, safe to do first)
body_text = body_text[:call_start] + var_with_cast + body_text[call_end:]
# Step 2: Insert prefix lines before the line containing the call (at lower offset)
prefix_text = "\n".join(prefix_lines) + "\n"
body_text = body_text[:line_char_start] + prefix_text + body_text[line_char_start:]
return wrapped, call_counter
# Split back into lines, filtering out any lines that became empty from statement replacement
wrapped = [line for line in body_text.split("\n") if line.strip()]
return wrapped, len(valid_calls)
def _collect_calls(

View file

@ -1813,6 +1813,94 @@ public class InnerClassTest__perfonlyinstrumented {
assert result == expected
class TestMultiLineCallInstrumentation:
"""Tests that multi-line function calls are fully replaced during instrumentation.
When a call spans multiple lines the replacement must cover the entire byte range,
not just the first line. Otherwise continuation lines become orphaned.
"""
def test_behavior_mode_multiline_expression_statement(self, tmp_path: Path):
"""Multi-line expression statement call must not leave orphaned continuation lines."""
test_file = tmp_path / "SchemaTest.java"
source = """import org.junit.jupiter.api.Test;
public class SchemaTest {
@Test
public void testAugment() {
augmentToolInputSchema(baseSchema, propertyName,
description, required);
}
}
"""
test_file.write_text(source, encoding="utf-8")
func = FunctionToOptimize(
function_name="augmentToolInputSchema",
file_path=tmp_path / "Schema.java",
starting_line=1,
ending_line=5,
parents=[],
is_method=True,
language="java",
)
success, result = instrument_existing_test(
test_string=source, function_to_optimize=func, mode="behavior", test_path=test_file
)
assert success is True
# The full call text appears inside the capture assignment within a try block
assert "_cf_result1_1 = augmentToolInputSchema(baseSchema, propertyName," in result
# The continuation line is inside the try block (after "try {"), not orphaned
lines = result.split("\n")
try_idx = next(i for i, l in enumerate(lines) if "try {" in l)
desc_indices = [i for i, l in enumerate(lines) if "description, required);" in l]
assert len(desc_indices) == 1, f"Expected exactly 1 continuation line, found {len(desc_indices)}"
assert desc_indices[0] > try_idx, "Continuation line must be inside the try block"
# Balanced braces
assert result.count("{") == result.count("}")
def test_behavior_mode_multiline_embedded_call(self, tmp_path: Path):
"""Multi-line call embedded in assertEquals must not leave orphaned continuation lines."""
test_file = tmp_path / "SchemaTest.java"
source = """import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
public class SchemaTest {
@Test
public void testAugment() {
assertEquals("expected", augmentToolInputSchema(baseSchema,
propertyName));
}
}
"""
test_file.write_text(source, encoding="utf-8")
func = FunctionToOptimize(
function_name="augmentToolInputSchema",
file_path=tmp_path / "Schema.java",
starting_line=1,
ending_line=5,
parents=[],
is_method=True,
language="java",
)
success, result = instrument_existing_test(
test_string=source, function_to_optimize=func, mode="behavior", test_path=test_file
)
assert success is True
# The multi-line call is replaced with a variable in assertEquals
assert 'assertEquals("expected", _cf_result1_1);' in result
# No orphaned continuation line as a standalone statement
lines = result.split("\n")
assert not any(l.strip() == "propertyName));" for l in lines), "Orphaned continuation line found"
# Balanced braces
assert result.count("{") == result.count("}")
class TestMultiByteUtf8Instrumentation:
"""Tests that timing instrumentation handles multi-byte UTF-8 source correctly.