fix: handle multi-line function calls in behavior instrumentation
wrap_target_calls_with_treesitter() operated line-by-line, but tree-sitter
byte offsets can span multiple lines. Multi-line calls like:
    func(arg1,
         arg2);
only had the first line replaced, leaving orphaned continuation lines
that caused compilation errors (as a result, 80% of Spring AI functions were skipped).
Rewrote it to operate on the full body text with pre-computed character
offsets, processing calls back-to-front — the same approach as
_add_timing_instrumentation, which never had this bug.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
32575780bd
commit
e1478caaba
2 changed files with 197 additions and 149 deletions
|
|
@ -273,9 +273,9 @@ def wrap_target_calls_with_treesitter(
|
|||
) -> tuple[list[str], int]:
|
||||
"""Replace target method calls in body_lines with capture + serialize using tree-sitter.
|
||||
|
||||
Parses the method body with tree-sitter, walks the AST for method_invocation nodes
|
||||
matching func_name, and generates capture/serialize lines. Uses the parent node type
|
||||
to determine whether to keep or remove the original line after replacement.
|
||||
Operates on the full body text with character offsets (not line-by-line) to correctly
|
||||
handle calls that span multiple lines. Processes calls back-to-front so earlier offsets
|
||||
remain valid after later replacements.
|
||||
|
||||
For behavior mode (precise_call_timing=True), each call is wrapped in its own
|
||||
try-finally block with immediate SQLite write to prevent data loss from multiple calls.
|
||||
|
|
@ -302,169 +302,129 @@ def wrap_target_calls_with_treesitter(
|
|||
if not calls:
|
||||
return list(body_lines), 0
|
||||
|
||||
# Build line byte-start offsets for mapping calls to body_lines indices
|
||||
line_byte_starts = []
|
||||
# Build line byte-start offsets for mapping byte offsets to line indices
|
||||
line_byte_starts: list[int] = []
|
||||
offset = 0
|
||||
for line in body_lines:
|
||||
line_byte_starts.append(offset)
|
||||
offset += len(line.encode("utf8")) + 1 # +1 for \n from join
|
||||
|
||||
# Group non-lambda and non-complex-expression calls by their line index
|
||||
calls_by_line: dict[int, list[dict[str, Any]]] = {}
|
||||
for call in calls:
|
||||
if call["in_lambda"] or call.get("in_complex", False):
|
||||
logger.debug("Skipping behavior instrumentation for call in lambda or complex expression")
|
||||
continue
|
||||
# Filter out lambda and complex-expression calls, sort by start_byte ascending for counter assignment
|
||||
valid_calls = [c for c in calls if not c["in_lambda"] and not c.get("in_complex", False)]
|
||||
if not valid_calls:
|
||||
return list(body_lines), 0
|
||||
valid_calls.sort(key=lambda c: c["start_byte"])
|
||||
|
||||
# Pre-compute character offsets and line info for each call (before any text modifications)
|
||||
for i, call in enumerate(valid_calls, 1):
|
||||
call["_counter"] = i
|
||||
call["_call_start_char"] = len(body_bytes[: call["start_byte"]].decode("utf8"))
|
||||
call["_call_end_char"] = len(body_bytes[: call["end_byte"]].decode("utf8"))
|
||||
if call["parent_type"] == "expression_statement":
|
||||
call["_es_start_char"] = len(body_bytes[: call["es_start_byte"]].decode("utf8"))
|
||||
call["_es_end_char"] = len(body_bytes[: call["es_end_byte"]].decode("utf8"))
|
||||
line_idx = _byte_to_line_index(call["start_byte"], line_byte_starts)
|
||||
calls_by_line.setdefault(line_idx, []).append(call)
|
||||
call["_line_idx"] = line_idx
|
||||
call["_line_char_start"] = len(body_bytes[: line_byte_starts[line_idx]].decode("utf8"))
|
||||
|
||||
wrapped = []
|
||||
call_counter = 0
|
||||
# Process calls back-to-front so earlier character offsets stay valid
|
||||
for call in reversed(valid_calls):
|
||||
call_counter = call["_counter"]
|
||||
line_idx = call["_line_idx"]
|
||||
call_absolute_line = body_start_line + line_idx + 1
|
||||
inv_id = f"L{call_absolute_line}_{call_counter}"
|
||||
|
||||
for line_idx, body_line in enumerate(body_lines):
|
||||
if line_idx not in calls_by_line:
|
||||
wrapped.append(body_line)
|
||||
continue
|
||||
orig_line = body_lines[line_idx]
|
||||
line_indent_str = " " * (len(orig_line) - len(orig_line.lstrip()))
|
||||
|
||||
line_calls = sorted(calls_by_line[line_idx], key=lambda c: c["start_byte"], reverse=True)
|
||||
line_indent_str = " " * (len(body_line) - len(body_line.lstrip()))
|
||||
line_byte_start = line_byte_starts[line_idx]
|
||||
line_bytes = body_line.encode("utf8")
|
||||
var_name = f"_cf_result{iter_id}_{call_counter}"
|
||||
cast_type = _infer_array_cast_type(orig_line)
|
||||
if not cast_type and target_return_type and target_return_type != "void":
|
||||
cast_type = target_return_type
|
||||
var_with_cast = f"({cast_type}){var_name}" if cast_type else var_name
|
||||
|
||||
new_line = body_line
|
||||
# Track cumulative char shift from earlier edits on this line
|
||||
char_shift = 0
|
||||
capture_stmt_with_decl = f"var {var_name} = {call['full_call']};"
|
||||
capture_stmt_assign = f"{var_name} = {call['full_call']};"
|
||||
if precise_call_timing:
|
||||
serialize_stmt = f"_cf_serializedResult{iter_id}_{call_counter} = com.codeflash.Serializer.serialize((Object) {var_name});"
|
||||
start_stmt = f"_cf_start{iter_id}_{call_counter} = System.nanoTime();"
|
||||
end_stmt = f"_cf_end{iter_id}_{call_counter} = System.nanoTime();"
|
||||
else:
|
||||
serialize_stmt = f"_cf_serializedResult{iter_id} = com.codeflash.Serializer.serialize((Object) {var_name});"
|
||||
start_stmt = f"_cf_start{iter_id} = System.nanoTime();"
|
||||
end_stmt = f"_cf_end{iter_id} = System.nanoTime();"
|
||||
|
||||
for call in line_calls:
|
||||
call_counter += 1
|
||||
# Compute absolute line number (1-indexed) for the invocation ID
|
||||
call_absolute_line = body_start_line + line_idx + 1
|
||||
inv_id = f"L{call_absolute_line}_{call_counter}"
|
||||
|
||||
var_name = f"_cf_result{iter_id}_{call_counter}"
|
||||
cast_type = _infer_array_cast_type(body_line)
|
||||
if not cast_type and target_return_type and target_return_type != "void":
|
||||
cast_type = target_return_type
|
||||
var_with_cast = f"({cast_type}){var_name}" if cast_type else var_name
|
||||
|
||||
# Use per-call unique variables (with call_counter suffix) for behavior mode
|
||||
# For behavior mode, we declare the variable outside try block, so use assignment not declaration here
|
||||
# For performance mode, use shared variables without call_counter suffix
|
||||
capture_stmt_with_decl = f"var {var_name} = {call['full_call']};"
|
||||
capture_stmt_assign = f"{var_name} = {call['full_call']};"
|
||||
if call["parent_type"] == "expression_statement":
|
||||
es_start = call["_es_start_char"]
|
||||
es_end = call["_es_end_char"]
|
||||
if precise_call_timing:
|
||||
# Behavior mode: per-call unique variables
|
||||
serialize_stmt = f"_cf_serializedResult{iter_id}_{call_counter} = com.codeflash.Serializer.serialize((Object) {var_name});"
|
||||
start_stmt = f"_cf_start{iter_id}_{call_counter} = System.nanoTime();"
|
||||
end_stmt = f"_cf_end{iter_id}_{call_counter} = System.nanoTime();"
|
||||
else:
|
||||
# Performance mode: shared variables without call_counter suffix
|
||||
serialize_stmt = (
|
||||
f"_cf_serializedResult{iter_id} = com.codeflash.Serializer.serialize((Object) {var_name});"
|
||||
# No indent on first line — body_text[:es_start] already has leading whitespace.
|
||||
# Subsequent lines get line_indent_str.
|
||||
var_decls = [
|
||||
f"Object {var_name} = null;",
|
||||
f"long _cf_end{iter_id}_{call_counter} = -1;",
|
||||
f"long _cf_start{iter_id}_{call_counter} = 0;",
|
||||
f"byte[] _cf_serializedResult{iter_id}_{call_counter} = null;",
|
||||
]
|
||||
start_marker = f'System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + "." + _cf_test{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":{inv_id}" + "######$!");'
|
||||
try_block = [
|
||||
"try {",
|
||||
f" {start_stmt}",
|
||||
f" {capture_stmt_assign}",
|
||||
f" {end_stmt}",
|
||||
f" {serialize_stmt}",
|
||||
]
|
||||
finally_block = _generate_sqlite_write_code(
|
||||
iter_id, call_counter, "", class_name, func_name, test_method_name, invocation_id=inv_id
|
||||
)
|
||||
all_lines = [*var_decls, start_marker, *try_block, *finally_block]
|
||||
replacement = (
|
||||
all_lines[0] + "\n" + "\n".join(f"{line_indent_str}{repl_line}" for repl_line in all_lines[1:])
|
||||
)
|
||||
start_stmt = f"_cf_start{iter_id} = System.nanoTime();"
|
||||
end_stmt = f"_cf_end{iter_id} = System.nanoTime();"
|
||||
|
||||
if call["parent_type"] == "expression_statement":
|
||||
# Replace the expression_statement IN PLACE with capture+serialize.
|
||||
# This keeps the code inside whatever scope it's in (e.g. try block),
|
||||
# preventing calls from being moved outside try-catch blocks.
|
||||
es_start_byte = call["es_start_byte"] - line_byte_start
|
||||
es_end_byte = call["es_end_byte"] - line_byte_start
|
||||
es_start_char = len(line_bytes[:es_start_byte].decode("utf8"))
|
||||
es_end_char = len(line_bytes[:es_end_byte].decode("utf8"))
|
||||
if precise_call_timing:
|
||||
# For behavior mode: wrap each call in its own try-finally with SQLite write.
|
||||
# This ensures data from all calls is captured independently.
|
||||
# Declare per-call variables
|
||||
var_decls = [
|
||||
f"Object {var_name} = null;",
|
||||
f"long _cf_end{iter_id}_{call_counter} = -1;",
|
||||
f"long _cf_start{iter_id}_{call_counter} = 0;",
|
||||
f"byte[] _cf_serializedResult{iter_id}_{call_counter} = null;",
|
||||
]
|
||||
# Start marker
|
||||
start_marker = f'System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + "." + _cf_test{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":{inv_id}" + "######$!");'
|
||||
# Try block with capture (use assignment, not declaration, since variable is declared above)
|
||||
try_block = [
|
||||
"try {",
|
||||
f" {start_stmt}",
|
||||
f" {capture_stmt_assign}",
|
||||
f" {end_stmt}",
|
||||
f" {serialize_stmt}",
|
||||
]
|
||||
# Finally block with SQLite write
|
||||
finally_block = _generate_sqlite_write_code(
|
||||
iter_id, call_counter, "", class_name, func_name, test_method_name, invocation_id=inv_id
|
||||
)
|
||||
|
||||
replacement_lines = [*var_decls, start_marker, *try_block, *finally_block]
|
||||
# Don't add indent to first line (it's placed after existing indent), but add to subsequent lines
|
||||
if replacement_lines:
|
||||
replacement = (
|
||||
replacement_lines[0]
|
||||
+ "\n"
|
||||
+ "\n".join(f"{line_indent_str}{line}" for line in replacement_lines[1:])
|
||||
)
|
||||
else:
|
||||
replacement = ""
|
||||
else:
|
||||
replacement = f"{capture_stmt_with_decl} {serialize_stmt}"
|
||||
adj_start = es_start_char + char_shift
|
||||
adj_end = es_end_char + char_shift
|
||||
new_line = new_line[:adj_start] + replacement + new_line[adj_end:]
|
||||
char_shift += len(replacement) - (es_end_char - es_start_char)
|
||||
else:
|
||||
# The call is embedded in a larger expression (assignment, assertion, etc.)
|
||||
# Emit capture+serialize before the line, then replace the call with the variable.
|
||||
if precise_call_timing:
|
||||
# For behavior mode: wrap in try-finally with SQLite write
|
||||
# Declare per-call variables
|
||||
wrapped.append(f"{line_indent_str}Object {var_name} = null;")
|
||||
wrapped.append(f"{line_indent_str}long _cf_end{iter_id}_{call_counter} = -1;")
|
||||
wrapped.append(f"{line_indent_str}long _cf_start{iter_id}_{call_counter} = 0;")
|
||||
wrapped.append(f"{line_indent_str}byte[] _cf_serializedResult{iter_id}_{call_counter} = null;")
|
||||
# Start marker
|
||||
wrapped.append(
|
||||
f'{line_indent_str}System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + "." + _cf_test{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":{inv_id}" + "######$!");'
|
||||
)
|
||||
# Try block (use assignment, not declaration, since variable is declared above)
|
||||
wrapped.append(f"{line_indent_str}try {{")
|
||||
wrapped.append(f"{line_indent_str} {start_stmt}")
|
||||
wrapped.append(f"{line_indent_str} {capture_stmt_assign}")
|
||||
wrapped.append(f"{line_indent_str} {end_stmt}")
|
||||
wrapped.append(f"{line_indent_str} {serialize_stmt}")
|
||||
# Finally block with SQLite write
|
||||
finally_lines = _generate_sqlite_write_code(
|
||||
iter_id,
|
||||
call_counter,
|
||||
line_indent_str,
|
||||
class_name,
|
||||
func_name,
|
||||
test_method_name,
|
||||
invocation_id=inv_id,
|
||||
)
|
||||
wrapped.extend(finally_lines)
|
||||
else:
|
||||
capture_line = f"{line_indent_str}{capture_stmt_with_decl}"
|
||||
wrapped.append(capture_line)
|
||||
serialize_line = f"{line_indent_str}{serialize_stmt}"
|
||||
wrapped.append(serialize_line)
|
||||
replacement = f"{capture_stmt_with_decl} {serialize_stmt}"
|
||||
body_text = body_text[:es_start] + replacement + body_text[es_end:]
|
||||
else:
|
||||
# Embedded call: replace call with variable, then insert capture lines before the line
|
||||
call_start = call["_call_start_char"]
|
||||
call_end = call["_call_end_char"]
|
||||
line_char_start = call["_line_char_start"]
|
||||
|
||||
call_start_byte = call["start_byte"] - line_byte_start
|
||||
call_end_byte = call["end_byte"] - line_byte_start
|
||||
call_start_char = len(line_bytes[:call_start_byte].decode("utf8"))
|
||||
call_end_char = len(line_bytes[:call_end_byte].decode("utf8"))
|
||||
adj_start = call_start_char + char_shift
|
||||
adj_end = call_end_char + char_shift
|
||||
new_line = new_line[:adj_start] + var_with_cast + new_line[adj_end:]
|
||||
char_shift += len(var_with_cast) - (call_end_char - call_start_char)
|
||||
if precise_call_timing:
|
||||
prefix_lines = [
|
||||
f"{line_indent_str}Object {var_name} = null;",
|
||||
f"{line_indent_str}long _cf_end{iter_id}_{call_counter} = -1;",
|
||||
f"{line_indent_str}long _cf_start{iter_id}_{call_counter} = 0;",
|
||||
f"{line_indent_str}byte[] _cf_serializedResult{iter_id}_{call_counter} = null;",
|
||||
f'{line_indent_str}System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + "." + _cf_test{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":{inv_id}" + "######$!");',
|
||||
f"{line_indent_str}try {{",
|
||||
f"{line_indent_str} {start_stmt}",
|
||||
f"{line_indent_str} {capture_stmt_assign}",
|
||||
f"{line_indent_str} {end_stmt}",
|
||||
f"{line_indent_str} {serialize_stmt}",
|
||||
]
|
||||
finally_lines = _generate_sqlite_write_code(
|
||||
iter_id,
|
||||
call_counter,
|
||||
line_indent_str,
|
||||
class_name,
|
||||
func_name,
|
||||
test_method_name,
|
||||
invocation_id=inv_id,
|
||||
)
|
||||
prefix_lines.extend(finally_lines)
|
||||
else:
|
||||
prefix_lines = [f"{line_indent_str}{capture_stmt_with_decl}", f"{line_indent_str}{serialize_stmt}"]
|
||||
|
||||
# Keep the modified line only if it has meaningful content left
|
||||
if new_line.strip():
|
||||
wrapped.append(new_line)
|
||||
# Step 1: Replace the call with the variable (at higher offset, safe to do first)
|
||||
body_text = body_text[:call_start] + var_with_cast + body_text[call_end:]
|
||||
# Step 2: Insert prefix lines before the line containing the call (at lower offset)
|
||||
prefix_text = "\n".join(prefix_lines) + "\n"
|
||||
body_text = body_text[:line_char_start] + prefix_text + body_text[line_char_start:]
|
||||
|
||||
return wrapped, call_counter
|
||||
# Split back into lines, filtering out any lines that became empty from statement replacement
|
||||
wrapped = [line for line in body_text.split("\n") if line.strip()]
|
||||
return wrapped, len(valid_calls)
|
||||
|
||||
|
||||
def _collect_calls(
|
||||
|
|
|
|||
|
|
@ -1813,6 +1813,94 @@ public class InnerClassTest__perfonlyinstrumented {
|
|||
assert result == expected
|
||||
|
||||
|
||||
class TestMultiLineCallInstrumentation:
|
||||
"""Tests that multi-line function calls are fully replaced during instrumentation.
|
||||
|
||||
When a call spans multiple lines the replacement must cover the entire byte range,
|
||||
not just the first line. Otherwise continuation lines become orphaned.
|
||||
"""
|
||||
|
||||
def test_behavior_mode_multiline_expression_statement(self, tmp_path: Path):
|
||||
"""Multi-line expression statement call must not leave orphaned continuation lines."""
|
||||
test_file = tmp_path / "SchemaTest.java"
|
||||
source = """import org.junit.jupiter.api.Test;
|
||||
|
||||
public class SchemaTest {
|
||||
@Test
|
||||
public void testAugment() {
|
||||
augmentToolInputSchema(baseSchema, propertyName,
|
||||
description, required);
|
||||
}
|
||||
}
|
||||
"""
|
||||
test_file.write_text(source, encoding="utf-8")
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="augmentToolInputSchema",
|
||||
file_path=tmp_path / "Schema.java",
|
||||
starting_line=1,
|
||||
ending_line=5,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language="java",
|
||||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
test_string=source, function_to_optimize=func, mode="behavior", test_path=test_file
|
||||
)
|
||||
|
||||
assert success is True
|
||||
# The full call text appears inside the capture assignment within a try block
|
||||
assert "_cf_result1_1 = augmentToolInputSchema(baseSchema, propertyName," in result
|
||||
# The continuation line is inside the try block (after "try {"), not orphaned
|
||||
lines = result.split("\n")
|
||||
try_idx = next(i for i, l in enumerate(lines) if "try {" in l)
|
||||
desc_indices = [i for i, l in enumerate(lines) if "description, required);" in l]
|
||||
assert len(desc_indices) == 1, f"Expected exactly 1 continuation line, found {len(desc_indices)}"
|
||||
assert desc_indices[0] > try_idx, "Continuation line must be inside the try block"
|
||||
# Balanced braces
|
||||
assert result.count("{") == result.count("}")
|
||||
|
||||
def test_behavior_mode_multiline_embedded_call(self, tmp_path: Path):
|
||||
"""Multi-line call embedded in assertEquals must not leave orphaned continuation lines."""
|
||||
test_file = tmp_path / "SchemaTest.java"
|
||||
source = """import org.junit.jupiter.api.Test;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class SchemaTest {
|
||||
@Test
|
||||
public void testAugment() {
|
||||
assertEquals("expected", augmentToolInputSchema(baseSchema,
|
||||
propertyName));
|
||||
}
|
||||
}
|
||||
"""
|
||||
test_file.write_text(source, encoding="utf-8")
|
||||
|
||||
func = FunctionToOptimize(
|
||||
function_name="augmentToolInputSchema",
|
||||
file_path=tmp_path / "Schema.java",
|
||||
starting_line=1,
|
||||
ending_line=5,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language="java",
|
||||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
test_string=source, function_to_optimize=func, mode="behavior", test_path=test_file
|
||||
)
|
||||
|
||||
assert success is True
|
||||
# The multi-line call is replaced with a variable in assertEquals
|
||||
assert 'assertEquals("expected", _cf_result1_1);' in result
|
||||
# No orphaned continuation line as a standalone statement
|
||||
lines = result.split("\n")
|
||||
assert not any(l.strip() == "propertyName));" for l in lines), "Orphaned continuation line found"
|
||||
# Balanced braces
|
||||
assert result.count("{") == result.count("}")
|
||||
|
||||
|
||||
class TestMultiByteUtf8Instrumentation:
|
||||
"""Tests that timing instrumentation handles multi-byte UTF-8 source correctly.
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue