Replace Regex with tree-sitter

This commit is contained in:
HeshamHM28 2026-02-16 08:32:55 +02:00
parent 17c8cf8a19
commit 4c976415ef
5 changed files with 308 additions and 325 deletions

View file

@ -61,82 +61,162 @@ def _is_test_annotation(stripped_line: str) -> bool:
return bool(_TEST_ANNOTATION_RE.match(stripped_line))
def _find_balanced_end(text: str, start: int) -> int:
"""Find the position after the closing paren that balances the opening paren at start.
def _is_inside_lambda(node) -> bool:
"""Check if a tree-sitter node is inside a lambda_expression."""
current = node.parent
while current is not None:
if current.type == "lambda_expression":
return True
if current.type == "method_declaration":
return False
current = current.parent
return False
Args:
text: The source text.
start: Index of the opening parenthesis '('.
Returns:
Index one past the matching closing ')', or -1 if not found.
_TS_BODY_PREFIX = "class _D { void _m() {\n"
_TS_BODY_SUFFIX = "\n}}"
_TS_BODY_PREFIX_BYTES = _TS_BODY_PREFIX.encode("utf8")
def wrap_target_calls_with_treesitter(body_lines: list[str], func_name: str, iter_id: int) -> tuple[list[str], int]:
"""Replace target method calls in body_lines with capture + serialize using tree-sitter.
Parses the method body with tree-sitter, walks the AST for method_invocation nodes
matching func_name, and generates capture/serialize lines. Uses the parent node type
to determine whether to keep or remove the original line after replacement.
Returns (wrapped_body_lines, call_counter).
"""
if start >= len(text) or text[start] != "(":
return -1
depth = 1
pos = start + 1
in_string = False
string_char = None
in_char = False
while pos < len(text) and depth > 0:
ch = text[pos]
prev = text[pos - 1] if pos > 0 else ""
if ch == "'" and not in_string and prev != "\\":
in_char = not in_char
elif ch == '"' and not in_char and prev != "\\":
if not in_string:
in_string = True
string_char = ch
elif ch == string_char:
in_string = False
string_char = None
elif not in_string and not in_char:
if ch == "(":
depth += 1
elif ch == ")":
depth -= 1
pos += 1
return pos if depth == 0 else -1
from codeflash.languages.java.parser import get_java_analyzer
analyzer = get_java_analyzer()
body_text = "\n".join(body_lines)
body_bytes = body_text.encode("utf8")
prefix_len = len(_TS_BODY_PREFIX_BYTES)
def _find_method_calls_balanced(line: str, func_name: str):
"""Find method calls to func_name with properly balanced parentheses.
wrapper_bytes = _TS_BODY_PREFIX_BYTES + body_bytes + _TS_BODY_SUFFIX.encode("utf8")
tree = analyzer.parse(wrapper_bytes)
Handles nested parentheses in arguments correctly, unlike a pure regex approach.
Returns a list of (start, end, full_call) tuples where start/end are positions
in the line and full_call is the matched text (receiver.funcName(args)).
# Collect all matching calls with their metadata
calls = []
_collect_calls(tree.root_node, wrapper_bytes, body_bytes, prefix_len, func_name, analyzer, calls)
Args:
line: A single line of Java source code.
func_name: The method name to look for.
if not calls:
return list(body_lines), 0
Returns:
List of (start_pos, end_pos, full_call_text) tuples.
# Build line byte-start offsets for mapping calls to body_lines indices
line_byte_starts = []
offset = 0
for line in body_lines:
line_byte_starts.append(offset)
offset += len(line.encode("utf8")) + 1 # +1 for \n from join
"""
# First find all occurrences of .funcName( in the line using regex
# to locate the method name, then use balanced paren finding for args
prefix_pattern = re.compile(
rf"((?:new\s+\w+\s*\([^)]*\)|[a-zA-Z_]\w*))\s*\.\s*{re.escape(func_name)}\s*\("
)
results = []
search_start = 0
while search_start < len(line):
m = prefix_pattern.search(line, search_start)
if not m:
break
# m.end() - 1 is the position of the opening paren
open_paren_pos = m.end() - 1
close_pos = _find_balanced_end(line, open_paren_pos)
if close_pos == -1:
# Unbalanced parens - skip this match
search_start = m.end()
# Group non-lambda calls by their line index
calls_by_line: dict[int, list] = {}
for call in calls:
if call["in_lambda"]:
continue
full_call = line[m.start():close_pos]
results.append((m.start(), close_pos, full_call))
search_start = close_pos
return results
line_idx = _byte_to_line_index(call["start_byte"], line_byte_starts)
calls_by_line.setdefault(line_idx, []).append(call)
wrapped = []
call_counter = 0
for line_idx, body_line in enumerate(body_lines):
if line_idx not in calls_by_line:
wrapped.append(body_line)
continue
line_calls = sorted(calls_by_line[line_idx], key=lambda c: c["start_byte"], reverse=True)
line_indent_str = " " * (len(body_line) - len(body_line.lstrip()))
line_byte_start = line_byte_starts[line_idx]
line_bytes = body_line.encode("utf8")
new_line = body_line
# Track cumulative char shift from earlier edits on this line
char_shift = 0
for call in line_calls:
call_counter += 1
var_name = f"_cf_result{iter_id}_{call_counter}"
cast_type = _infer_array_cast_type(body_line)
var_with_cast = f"({cast_type}){var_name}" if cast_type else var_name
capture_stmt = f"var {var_name} = {call['full_call']};"
serialize_stmt = f"_cf_serializedResult{iter_id} = com.codeflash.Serializer.serialize((Object) {var_name});"
if call["parent_type"] == "expression_statement":
# Replace the expression_statement IN PLACE with capture+serialize.
# This keeps the code inside whatever scope it's in (e.g. try block),
# preventing calls from being moved outside try-catch blocks.
es_start_byte = call["es_start_byte"] - line_byte_start
es_end_byte = call["es_end_byte"] - line_byte_start
es_start_char = len(line_bytes[:es_start_byte].decode("utf8"))
es_end_char = len(line_bytes[:es_end_byte].decode("utf8"))
replacement = f"{capture_stmt} {serialize_stmt}"
adj_start = es_start_char + char_shift
adj_end = es_end_char + char_shift
new_line = new_line[:adj_start] + replacement + new_line[adj_end:]
char_shift += len(replacement) - (es_end_char - es_start_char)
else:
# The call is embedded in a larger expression (assignment, assertion, etc.)
# Emit capture+serialize before the line, then replace the call with the variable.
capture_line = f"{line_indent_str}{capture_stmt}"
serialize_line = f"{line_indent_str}{serialize_stmt}"
wrapped.append(capture_line)
wrapped.append(serialize_line)
call_start_byte = call["start_byte"] - line_byte_start
call_end_byte = call["end_byte"] - line_byte_start
call_start_char = len(line_bytes[:call_start_byte].decode("utf8"))
call_end_char = len(line_bytes[:call_end_byte].decode("utf8"))
adj_start = call_start_char + char_shift
adj_end = call_end_char + char_shift
new_line = new_line[:adj_start] + var_with_cast + new_line[adj_end:]
char_shift += len(var_with_cast) - (call_end_char - call_start_char)
# Keep the modified line only if it has meaningful content left
if new_line.strip():
wrapped.append(new_line)
return wrapped, call_counter
def _collect_calls(node, wrapper_bytes, body_bytes, prefix_len, func_name, analyzer, out):
"""Recursively collect method_invocation nodes matching func_name."""
if node.type == "method_invocation":
name_node = node.child_by_field_name("name")
if name_node and analyzer.get_node_text(name_node, wrapper_bytes) == func_name:
start = node.start_byte - prefix_len
end = node.end_byte - prefix_len
if start >= 0 and end <= len(body_bytes):
parent = node.parent
parent_type = parent.type if parent else ""
es_start = es_end = 0
if parent_type == "expression_statement":
es_start = parent.start_byte - prefix_len
es_end = parent.end_byte - prefix_len
out.append(
{
"start_byte": start,
"end_byte": end,
"full_call": analyzer.get_node_text(node, wrapper_bytes),
"parent_type": parent_type,
"in_lambda": _is_inside_lambda(node),
"es_start_byte": es_start,
"es_end_byte": es_end,
}
)
for child in node.children:
_collect_calls(child, wrapper_bytes, body_bytes, prefix_len, func_name, analyzer, out)
def _byte_to_line_index(byte_offset: int, line_byte_starts: list[int]) -> int:
"""Map a byte offset in body_text to a body_lines index."""
for i in range(len(line_byte_starts) - 1, -1, -1):
if byte_offset >= line_byte_starts[i]:
return i
return 0
def _infer_array_cast_type(line: str) -> str | None:
@ -279,9 +359,7 @@ def instrument_existing_test(
# This includes the class declaration, return types, constructor calls,
# variable declarations, etc. We use word-boundary matching to avoid
# replacing substrings of other identifiers.
modified_source = re.sub(
rf"\b{re.escape(original_class_name)}\b", new_class_name, source
)
modified_source = re.sub(rf"\b{re.escape(original_class_name)}\b", new_class_name, source)
# Add timing instrumentation to test methods
# Use original class name (without suffix) in timing markers for consistency with Python
@ -429,95 +507,11 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
i += 1
break
# Wrap function calls to capture return values
# Look for patterns like: obj.funcName(args) or new Class().funcName(args)
call_counter = 0
wrapped_body_lines = []
# Track lambda block nesting depth to avoid wrapping calls inside lambda bodies.
# assertThrows/assertDoesNotThrow expect an Executable (void functional interface),
# and wrapping the call in a variable assignment would turn the void-compatible
# lambda into a value-returning lambda, causing a compilation error.
# Also, variables declared outside lambdas cannot be reassigned inside them
# (Java requires effectively final variables in lambda captures).
# Handles both no-arg lambdas: () -> { func(); }
# and parameterized lambdas: (a, b, c) -> { func(); }
lambda_brace_depth = 0
for body_line in body_lines:
# Detect block lambda openings: (...) -> { or () -> {
# Matches both () -> { and (a, b, c) -> {
is_lambda_open = bool(re.search(r"->\s*\{", body_line))
# Update lambda brace depth tracking for block lambdas
if is_lambda_open or lambda_brace_depth > 0:
open_braces = body_line.count("{")
close_braces = body_line.count("}")
if is_lambda_open and lambda_brace_depth == 0:
# Starting a new lambda block - only count braces from this lambda
lambda_brace_depth = open_braces - close_braces
else:
lambda_brace_depth += open_braces - close_braces
# Ensure depth doesn't go below 0
lambda_brace_depth = max(0, lambda_brace_depth)
inside_lambda = lambda_brace_depth > 0 or bool(re.search(r"->\s+\S", body_line))
# Check if this line contains a call to the target function
if func_name in body_line and "(" in body_line:
# Skip wrapping if the function call is inside a lambda expression
if inside_lambda:
wrapped_body_lines.append(body_line)
continue
line_indent = len(body_line) - len(body_line.lstrip())
line_indent_str = " " * line_indent
# Find all matches using balanced parenthesis matching
# This correctly handles nested parens like:
# obj.func(a, Rows.toRowID(frame.getIndex(), row))
matches = _find_method_calls_balanced(body_line, func_name)
if matches:
# Process matches in reverse order to maintain correct positions
new_line = body_line
for start_pos, end_pos, full_call in reversed(matches):
call_counter += 1
var_name = f"_cf_result{iter_id}_{call_counter}"
# Check if we need to cast the result for assertions with primitive arrays
# This handles assertArrayEquals(int[], int[]) etc.
cast_type = _infer_array_cast_type(body_line)
var_with_cast = f"({cast_type}){var_name}" if cast_type else var_name
# Replace this occurrence with the variable (with cast if needed)
new_line = new_line[:start_pos] + var_with_cast + new_line[end_pos:]
# Use 'var' instead of 'Object' to preserve the exact return type.
# This avoids boxing mismatches (e.g., assertEquals(int, Object) where
# Object is boxed Long but expected is boxed Integer). Requires Java 10+.
capture_line = f"{line_indent_str}var {var_name} = {full_call};"
wrapped_body_lines.append(capture_line)
# Immediately serialize the captured result while the variable
# is still in scope. This is necessary because the variable may
# be declared inside a nested block (while/for/if/try) and would
# be out of scope at the end of the method body.
serialize_line = (
f"{line_indent_str}_cf_serializedResult{iter_id} = "
f"com.codeflash.Serializer.serialize((Object) {var_name});"
)
wrapped_body_lines.append(serialize_line)
# Check if the line is now just a variable reference (invalid statement)
# This happens when the original line was just a void method call
# e.g., "BubbleSort.bubbleSort(original);" becomes "_cf_result1_1;"
stripped_new = new_line.strip().rstrip(";").strip()
if stripped_new and stripped_new not in (var_name, var_with_cast):
wrapped_body_lines.append(new_line)
else:
wrapped_body_lines.append(body_line)
else:
wrapped_body_lines.append(body_line)
# Wrap function calls to capture return values using tree-sitter AST analysis.
# This correctly handles lambdas, try-catch blocks, assignments, and nested calls.
wrapped_body_lines, _call_counter = wrap_target_calls_with_treesitter(
body_lines=body_lines, func_name=func_name, iter_id=iter_id
)
# Add behavior instrumentation code
behavior_start_code = [
@ -833,12 +827,9 @@ def instrument_generated_java_test(
original_class_name = class_match.group(1)
# For performance mode, add timing instrumentation
# Use original class name (without suffix) in timing markers for consistency with Python
if mode == "performance":
# Rename class based on mode
if mode == "behavior":
new_class_name = f"{original_class_name}__perfinstrumented"
@ -847,9 +838,7 @@ def instrument_generated_java_test(
# Rename all references to the original class name in the source.
# This includes the class declaration, return types, constructor calls, etc.
modified_code = re.sub(
rf"\b{re.escape(original_class_name)}\b", new_class_name, test_code
)
modified_code = re.sub(rf"\b{re.escape(original_class_name)}\b", new_class_name, test_code)
modified_code = _add_timing_instrumentation(
modified_code,
@ -857,7 +846,12 @@ def instrument_generated_java_test(
function_name,
)
elif mode == "behavior":
_ , modified_code = instrument_existing_test(test_string=test_code, mode=mode, function_to_optimize=function_to_optimize, test_class_name=original_class_name)
_, modified_code = instrument_existing_test(
test_string=test_code,
mode=mode,
function_to_optimize=function_to_optimize,
test_class_name=original_class_name,
)
logger.debug("Instrumented generated Java test for %s (mode=%s)", function_name, mode)
return modified_code
@ -890,5 +884,3 @@ def _add_import(source: str, import_statement: str) -> str:
lines.insert(insert_idx, import_statement + "\n")
return "".join(lines)

View file

@ -110,7 +110,11 @@ class JavaLineProfiler:
lines[:import_end_idx] + [profiler_class_code + "\n"] + lines[import_end_idx:]
)
return "".join(lines_with_profiler)
result = "".join(lines_with_profiler)
if not analyzer.validate_syntax(result):
logger.warning("Line profiler instrumentation produced invalid Java, returning original source")
return source
return result
def _generate_profiler_class(self) -> str:
"""Generate Java code for profiler class."""

View file

@ -539,87 +539,69 @@ class JavaAssertTransformer:
return pos
# Wrapper template to make assertion argument fragments parseable by tree-sitter.
# e.g. content "55, obj.fibonacci(10)" becomes "class _D { void _m() { _d(55, obj.fibonacci(10)); } }"
_TS_WRAPPER_PREFIX = "class _D { void _m() { _d("
_TS_WRAPPER_SUFFIX = "); } }"
_TS_WRAPPER_PREFIX_BYTES = _TS_WRAPPER_PREFIX.encode("utf8")
def _extract_target_calls(self, content: str, base_offset: int) -> list[TargetCall]:
"""Extract calls to the target function from assertion arguments."""
target_calls: list[TargetCall] = []
"""Find all calls to the target function within assertion argument text using tree-sitter."""
if not content or not content.strip():
return []
# Pattern to match method calls with various receiver styles:
# - obj.method(args)
# - ClassName.staticMethod(args)
# - new ClassName().method(args)
# - new ClassName(args).method(args)
# - method(args) (no receiver)
#
# Strategy: Find the function name, then look backwards for the receiver
pattern = re.compile(rf"({re.escape(self.func_name)})\s*\(", re.MULTILINE)
content_bytes = content.encode("utf8")
wrapper_bytes = self._TS_WRAPPER_PREFIX_BYTES + content_bytes + self._TS_WRAPPER_SUFFIX.encode("utf8")
tree = self.analyzer.parse(wrapper_bytes)
for match in pattern.finditer(content):
method_name = match.group(1)
method_start = match.start()
results: list[TargetCall] = []
self._collect_target_invocations(tree.root_node, wrapper_bytes, content_bytes, base_offset, results)
return results
# Find the arguments
paren_pos = match.end() - 1
args_content, end_pos = self._find_balanced_parens(content, paren_pos)
if args_content is None:
continue
def _collect_target_invocations(
self, node, wrapper_bytes: bytes, content_bytes: bytes,
base_offset: int, out: list[TargetCall],
) -> None:
"""Recursively walk the AST and collect method_invocation nodes that match self.func_name."""
prefix_len = len(self._TS_WRAPPER_PREFIX_BYTES)
# Look backwards from the method name to find the receiver
receiver_start = method_start
if node.type == "method_invocation":
name_node = node.child_by_field_name("name")
if name_node and self.analyzer.get_node_text(name_node, wrapper_bytes) == self.func_name:
start = node.start_byte - prefix_len
end = node.end_byte - prefix_len
if 0 <= start and end <= len(content_bytes):
out.append(self._build_target_call(node, wrapper_bytes, content_bytes, start, end, base_offset))
# Check if there's a dot before the method name (indicating a receiver)
before_method = content[:method_start]
stripped_before = before_method.rstrip()
if stripped_before.endswith("."):
dot_pos = len(stripped_before) - 1
before_dot = content[:dot_pos]
for child in node.children:
self._collect_target_invocations(child, wrapper_bytes, content_bytes, base_offset, out)
# Check for new ClassName() or new ClassName(args)
stripped_before_dot = before_dot.rstrip()
if stripped_before_dot.endswith(")"):
# Find matching opening paren for constructor args
close_paren_pos = len(stripped_before_dot) - 1
paren_depth = 1
i = close_paren_pos - 1
while i >= 0 and paren_depth > 0:
if stripped_before_dot[i] == ")":
paren_depth += 1
elif stripped_before_dot[i] == "(":
paren_depth -= 1
i -= 1
if paren_depth == 0:
open_paren_pos = i + 1
# Look for "new ClassName" before the opening paren
before_paren = stripped_before_dot[:open_paren_pos].rstrip()
new_match = re.search(r"new\s+[a-zA-Z_]\w*\s*$", before_paren)
if new_match:
receiver_start = new_match.start()
else:
# Could be chained call like something().method()
# For now, just use the part from open paren
receiver_start = open_paren_pos
else:
# Simple identifier: obj.method() or Class.method() or pkg.Class.method()
ident_match = re.search(r"[a-zA-Z_]\w*(?:\.[a-zA-Z_]\w*)*\s*$", stripped_before_dot)
if ident_match:
receiver_start = ident_match.start()
def _build_target_call(
self, node, wrapper_bytes: bytes, content_bytes: bytes,
start_byte: int, end_byte: int, base_offset: int,
) -> TargetCall:
"""Build a TargetCall from a tree-sitter method_invocation node."""
get_text = self.analyzer.get_node_text
full_call = content[receiver_start:end_pos]
receiver = (
content[receiver_start:method_start].rstrip(".").strip() if receiver_start < method_start else None
)
object_node = node.child_by_field_name("object")
args_node = node.child_by_field_name("arguments")
args_text = get_text(args_node, wrapper_bytes) if args_node else ""
# argument_list node includes parens, strip them
if args_text.startswith("(") and args_text.endswith(")"):
args_text = args_text[1:-1]
target_calls.append(
TargetCall(
receiver=receiver,
method_name=method_name,
arguments=args_content,
full_call=full_call,
start_pos=base_offset + receiver_start,
end_pos=base_offset + end_pos,
)
)
# Byte offsets -> char offsets for correct Python string indexing
start_char = len(content_bytes[:start_byte].decode("utf8"))
end_char = len(content_bytes[:end_byte].decode("utf8"))
return target_calls
return TargetCall(
receiver=get_text(object_node, wrapper_bytes) if object_node else None,
method_name=self.func_name,
arguments=args_text,
full_call=get_text(node, wrapper_bytes),
start_pos=base_offset + start_char,
end_pos=base_offset + end_char,
)
def _detect_variable_assignment(self, source: str, assertion_start: int) -> tuple[str | None, str | None]:
"""Check if assertion is assigned to a variable.

View file

@ -6,6 +6,9 @@ regression test code that captures function return values.
All tests assert for full string equality, no substring matching.
"""
from pathlib import Path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.java.remove_asserts import JavaAssertTransformer, transform_java_assertions
@ -839,26 +842,26 @@ public class FibonacciTest {
assertEquals(55, calc.fibonacci(10));
}
}"""
expected = """\
package com.example;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
public class FibonacciTest__perfinstrumented {
@Test
void testFibonacci() {
Calculator calc = new Calculator();
Object _cf_result1 = calc.fibonacci(10);
}
}"""
func = FunctionToOptimize(
function_name="fibonacci",
file_path=Path("Calculator.java"),
starting_line=1,
ending_line=5,
parents=[],
is_method=True,
language="java",
)
result = instrument_generated_java_test(
test_code=test_code,
function_name="fibonacci",
qualified_name="com.example.Calculator.fibonacci",
mode="behavior",
function_to_optimize=func,
)
assert result == expected
# Behavior mode now adds full instrumentation
assert "FibonacciTest__perfinstrumented" in result
assert "_cf_result" in result
assert "com.codeflash.Serializer.serialize" in result
def test_behavior_mode_with_assertj(self):
from codeflash.languages.java.instrumentation import instrument_generated_java_test
@ -875,25 +878,26 @@ public class StringUtilsTest {
assertThat(StringUtils.reverse("hello")).isEqualTo("olleh");
}
}"""
expected = """\
package com.example;
import org.junit.jupiter.api.Test;
import static org.assertj.core.api.Assertions.assertThat;
public class StringUtilsTest__perfinstrumented {
@Test
void testReverse() {
Object _cf_result1 = StringUtils.reverse("hello");
}
}"""
func = FunctionToOptimize(
function_name="reverse",
file_path=Path("StringUtils.java"),
starting_line=1,
ending_line=5,
parents=[],
is_method=True,
language="java",
)
result = instrument_generated_java_test(
test_code=test_code,
function_name="reverse",
qualified_name="com.example.StringUtils.reverse",
mode="behavior",
function_to_optimize=func,
)
assert result == expected
# Behavior mode now adds full instrumentation
assert "StringUtilsTest__perfinstrumented" in result
assert "_cf_result" in result
assert "com.codeflash.Serializer.serialize" in result
class TestComplexRealWorldExamples:

View file

@ -125,11 +125,10 @@ public class CalculatorTest {
)
success, result = instrument_existing_test(
test_file,
call_positions=[],
test_string=source,
function_to_optimize=func,
tests_project_root=tmp_path,
mode="behavior",
test_path=test_file,
)
assert success is True
@ -186,11 +185,10 @@ public class FibonacciTest {
)
success, result = instrument_existing_test(
test_file,
call_positions=[],
test_string=source,
function_to_optimize=func,
tests_project_root=tmp_path,
mode="behavior",
test_path=test_file,
)
assert success is True
@ -236,11 +234,10 @@ public class FibonacciTest {
)
success, result = instrument_existing_test(
test_file,
call_positions=[],
test_string=source,
function_to_optimize=func,
tests_project_root=tmp_path,
mode="behavior",
test_path=test_file,
)
assert success is True
@ -275,11 +272,10 @@ public class CalculatorTest {
)
success, result = instrument_existing_test(
test_file,
call_positions=[],
test_string=source,
function_to_optimize=func,
tests_project_root=tmp_path,
mode="performance",
test_path=test_file,
)
expected = """import org.junit.jupiter.api.Test;
@ -342,11 +338,10 @@ public class MathTest {
)
success, result = instrument_existing_test(
test_file,
call_positions=[],
test_string=source,
function_to_optimize=func,
tests_project_root=tmp_path,
mode="performance",
test_path=test_file,
)
expected = """import org.junit.jupiter.api.Test;
@ -434,11 +429,10 @@ public class ServiceTest {
)
success, result = instrument_existing_test(
test_file,
call_positions=[],
test_string=source,
function_to_optimize=func,
tests_project_root=tmp_path,
mode="performance",
test_path=test_file,
)
expected = """import org.junit.jupiter.api.Test;
@ -510,15 +504,12 @@ public class ServiceTest__perfonlyinstrumented {
language="java",
)
success, result = instrument_existing_test(
test_file,
call_positions=[],
function_to_optimize=func,
tests_project_root=tmp_path,
mode="behavior",
)
assert success is False
with pytest.raises(ValueError):
instrument_existing_test(
test_string="",
function_to_optimize=func,
mode="behavior",
)
class TestKryoSerializerUsage:
@ -925,24 +916,29 @@ public class CalculatorTest {
}
}
"""
func = FunctionToOptimize(
function_name="add",
file_path=Path("Calculator.java"),
starting_line=1,
ending_line=5,
parents=[],
is_method=True,
language="java",
)
result = instrument_generated_java_test(
test_code,
function_name="add",
qualified_name="Calculator.add",
mode="behavior",
function_to_optimize=func,
)
# Behavior mode transforms assertions to capture return values
expected = """import org.junit.jupiter.api.Test;
public class CalculatorTest__perfinstrumented {
@Test
public void testAdd() {
Object _cf_result1 = new Calculator().add(2, 2);
}
}
"""
assert result == expected
# Behavior mode now adds full instrumentation (SQLite, timing markers, etc.)
assert "CalculatorTest__perfinstrumented" in result
assert "_cf_result" in result
assert "com.codeflash.Serializer.serialize" in result
assert "CODEFLASH_OUTPUT_FILE" in result
assert "CREATE TABLE IF NOT EXISTS test_results" in result
def test_instrument_generated_test_performance_mode(self):
"""Test instrumenting generated test in performance mode with inner loop."""
@ -955,11 +951,21 @@ public class GeneratedTest {
}
}
"""
func = FunctionToOptimize(
function_name="method",
file_path=Path("Target.java"),
starting_line=1,
ending_line=5,
parents=[],
is_method=True,
language="java",
)
result = instrument_generated_java_test(
test_code,
function_name="method",
qualified_name="Target.method",
mode="performance",
function_to_optimize=func,
)
expected = """import org.junit.jupiter.api.Test;
@ -1130,11 +1136,10 @@ public class BraceTest {
)
success, result = instrument_existing_test(
test_file,
call_positions=[],
test_string=source,
function_to_optimize=func,
tests_project_root=tmp_path,
mode="performance",
test_path=test_file,
)
expected = """import org.junit.jupiter.api.Test;
@ -1223,11 +1228,10 @@ public class ImportTest {
)
success, result = instrument_existing_test(
test_file,
call_positions=[],
test_string=source,
function_to_optimize=func,
tests_project_root=tmp_path,
mode="performance",
test_path=test_file,
)
expected = """package com.example;
@ -1293,11 +1297,10 @@ public class EmptyTest {
)
success, result = instrument_existing_test(
test_file,
call_positions=[],
test_string=source,
function_to_optimize=func,
tests_project_root=tmp_path,
mode="performance",
test_path=test_file,
)
expected = """import org.junit.jupiter.api.Test;
@ -1359,11 +1362,10 @@ public class NestedTest {
)
success, result = instrument_existing_test(
test_file,
call_positions=[],
test_string=source,
function_to_optimize=func,
tests_project_root=tmp_path,
mode="performance",
test_path=test_file,
)
expected = """import org.junit.jupiter.api.Test;
@ -1435,11 +1437,10 @@ public class InnerClassTest {
)
success, result = instrument_existing_test(
test_file,
call_positions=[],
test_string=source,
function_to_optimize=func,
tests_project_root=tmp_path,
mode="performance",
test_path=test_file,
)
expected = """import org.junit.jupiter.api.Test;
@ -1643,7 +1644,7 @@ public class CalculatorTest {
)
success, instrumented = instrument_existing_test(
test_file, [], func_info, test_dir, mode="behavior"
test_string=test_source, function_to_optimize=func_info, mode="behavior", test_path=test_file
)
assert success
@ -1755,7 +1756,7 @@ public class MathUtilsTest {
)
success, instrumented = instrument_existing_test(
test_file, [], func_info, test_dir, mode="performance"
test_string=test_source, function_to_optimize=func_info, mode="performance", test_path=test_file
)
assert success
@ -1888,7 +1889,7 @@ public class StringUtilsTest {
)
success, instrumented = instrument_existing_test(
test_file, [], func_info, test_dir, mode="behavior"
test_string=test_source, function_to_optimize=func_info, mode="behavior", test_path=test_file
)
assert success
@ -1990,7 +1991,7 @@ public class BrokenCalcTest {
)
success, instrumented = instrument_existing_test(
test_file, [], func_info, test_dir, mode="behavior"
test_string=test_source, function_to_optimize=func_info, mode="behavior", test_path=test_file
)
assert success
@ -2100,7 +2101,7 @@ public class CounterTest {
)
success, instrumented = instrument_existing_test(
test_file, [], func_info, test_dir, mode="behavior"
test_string=test_source, function_to_optimize=func_info, mode="behavior", test_path=test_file
)
assert success
@ -2262,7 +2263,7 @@ public class FibonacciTest {
)
success, instrumented = instrument_existing_test(
test_file, [], func_info, test_dir, mode="performance"
test_string=test_source, function_to_optimize=func_info, mode="performance", test_path=test_file
)
assert success
@ -2383,7 +2384,7 @@ public class MathOpsTest {
)
success, instrumented = instrument_existing_test(
test_file, [], func_info, test_dir, mode="performance"
test_string=test_source, function_to_optimize=func_info, mode="performance", test_path=test_file
)
assert success