mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Replace Regex with tree-sitter
This commit is contained in:
parent
17c8cf8a19
commit
4c976415ef
5 changed files with 308 additions and 325 deletions
|
|
@ -61,82 +61,162 @@ def _is_test_annotation(stripped_line: str) -> bool:
|
|||
return bool(_TEST_ANNOTATION_RE.match(stripped_line))
|
||||
|
||||
|
||||
def _find_balanced_end(text: str, start: int) -> int:
|
||||
"""Find the position after the closing paren that balances the opening paren at start.
|
||||
def _is_inside_lambda(node) -> bool:
|
||||
"""Check if a tree-sitter node is inside a lambda_expression."""
|
||||
current = node.parent
|
||||
while current is not None:
|
||||
if current.type == "lambda_expression":
|
||||
return True
|
||||
if current.type == "method_declaration":
|
||||
return False
|
||||
current = current.parent
|
||||
return False
|
||||
|
||||
Args:
|
||||
text: The source text.
|
||||
start: Index of the opening parenthesis '('.
|
||||
|
||||
Returns:
|
||||
Index one past the matching closing ')', or -1 if not found.
|
||||
_TS_BODY_PREFIX = "class _D { void _m() {\n"
|
||||
_TS_BODY_SUFFIX = "\n}}"
|
||||
_TS_BODY_PREFIX_BYTES = _TS_BODY_PREFIX.encode("utf8")
|
||||
|
||||
|
||||
def wrap_target_calls_with_treesitter(body_lines: list[str], func_name: str, iter_id: int) -> tuple[list[str], int]:
    """Replace target method calls in body_lines with capture + serialize statements.

    The method body is wrapped in a dummy class/method so tree-sitter can parse it,
    the AST is searched for ``method_invocation`` nodes named ``func_name``, and
    every selected call is rewritten on the line where it starts:

    * A call whose parent is an ``expression_statement`` is replaced in place by
      ``var <tmp> = <call>; <serialize>``. Rewriting in place keeps the code in
      whatever scope it was in (e.g. a ``try`` block).
    * A call embedded in a larger expression (assignment, assertion, ...) gets a
      capture line and a serialize line emitted before its line, and the call text
      is replaced by the temporary (prefixed with a cast when
      ``_infer_array_cast_type`` returns one for the line).

    Calls inside lambdas are never rewritten. A call nested inside, or chained
    onto, another selected call is not rewritten separately, because overlapping
    text edits would corrupt the line. A call whose source span does not fit on
    one line is left untouched, because the rewrite is line-based.

    Args:
        body_lines: Lines of the Java method body (without trailing newlines).
        func_name: Name of the method whose calls should be captured.
        iter_id: Identifier embedded in the generated variable names.

    Returns:
        ``(wrapped_body_lines, call_counter)`` where ``call_counter`` is the number
        of calls that were rewritten.
    """
    from codeflash.languages.java.parser import get_java_analyzer

    analyzer = get_java_analyzer()
    body_bytes = "\n".join(body_lines).encode("utf8")
    prefix_len = len(_TS_BODY_PREFIX_BYTES)

    wrapper_bytes = _TS_BODY_PREFIX_BYTES + body_bytes + _TS_BODY_SUFFIX.encode("utf8")
    tree = analyzer.parse(wrapper_bytes)

    # Collect all matching calls with their metadata
    calls = []
    _collect_calls(tree.root_node, wrapper_bytes, body_bytes, prefix_len, func_name, analyzer, calls)

    if not calls:
        return list(body_lines), 0

    # Byte offset at which every line starts inside the joined body text.
    line_byte_starts = []
    offset = 0
    for line in body_lines:
        line_byte_starts.append(offset)
        offset += len(line.encode("utf8")) + 1  # +1 for \n from join

    # Select the outermost non-lambda calls and group them by line index.
    # Sorting by (start, -end) visits an enclosing call before anything it contains,
    # so a call starting before `covered_until` overlaps an already selected one.
    candidates = sorted(
        (c for c in calls if not c["in_lambda"]),
        key=lambda c: (c["start_byte"], -c["end_byte"]),
    )
    calls_by_line: dict[int, list] = {}
    covered_until = -1
    for call in candidates:
        if call["start_byte"] < covered_until:
            continue
        covered_until = call["end_byte"]
        line_idx = _byte_to_line_index(call["start_byte"], line_byte_starts)
        calls_by_line.setdefault(line_idx, []).append(call)

    wrapped = []
    call_counter = 0

    for line_idx, body_line in enumerate(body_lines):
        line_calls = calls_by_line.get(line_idx)
        if not line_calls:
            wrapped.append(body_line)
            continue

        line_indent_str = " " * (len(body_line) - len(body_line.lstrip()))
        line_byte_start = line_byte_starts[line_idx]
        line_bytes = body_line.encode("utf8")
        # The cast depends only on the original line, so compute it once per line.
        cast_type = _infer_array_cast_type(body_line)

        new_line = body_line

        # `line_calls` is in ascending start order; edit right-to-left so that the
        # offsets of calls further left (computed on the original line) stay valid
        # without any shift bookkeeping. Temporaries are numbered in that order.
        for call in reversed(line_calls):
            in_place = call["parent_type"] == "expression_statement"
            if in_place:
                span_start = call["es_start_byte"] - line_byte_start
                span_end = call["es_end_byte"] - line_byte_start
            else:
                span_start = call["start_byte"] - line_byte_start
                span_end = call["end_byte"] - line_byte_start

            # A span that leaves this line cannot be rewritten line-by-line without
            # leaving dangling fragments on the following lines; keep it untouched.
            if span_start < 0 or span_end > len(line_bytes):
                continue

            call_counter += 1
            var_name = f"_cf_result{iter_id}_{call_counter}"
            capture_stmt = f"var {var_name} = {call['full_call']};"
            serialize_stmt = f"_cf_serializedResult{iter_id} = com.codeflash.Serializer.serialize((Object) {var_name});"

            # tree-sitter reports byte offsets; convert to str indices for slicing.
            start_char = len(line_bytes[:span_start].decode("utf8"))
            end_char = len(line_bytes[:span_end].decode("utf8"))

            if in_place:
                # Replace the expression_statement IN PLACE with capture+serialize.
                # This keeps the code inside whatever scope it's in (e.g. try block),
                # preventing calls from being moved outside try-catch blocks.
                replacement = f"{capture_stmt} {serialize_stmt}"
            else:
                # The call is embedded in a larger expression (assignment, assertion, etc.)
                # Emit capture+serialize before the line, then replace the call with the variable.
                wrapped.append(f"{line_indent_str}{capture_stmt}")
                wrapped.append(f"{line_indent_str}{serialize_stmt}")
                replacement = f"({cast_type}){var_name}" if cast_type else var_name

            new_line = new_line[:start_char] + replacement + new_line[end_char:]

        # Keep the modified line only if it has meaningful content left
        if new_line.strip():
            wrapped.append(new_line)

    return wrapped, call_counter
|
||||
|
||||
|
||||
def _collect_calls(node, wrapper_bytes, body_bytes, prefix_len, func_name, analyzer, out):
|
||||
"""Recursively collect method_invocation nodes matching func_name."""
|
||||
if node.type == "method_invocation":
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node and analyzer.get_node_text(name_node, wrapper_bytes) == func_name:
|
||||
start = node.start_byte - prefix_len
|
||||
end = node.end_byte - prefix_len
|
||||
if start >= 0 and end <= len(body_bytes):
|
||||
parent = node.parent
|
||||
parent_type = parent.type if parent else ""
|
||||
es_start = es_end = 0
|
||||
if parent_type == "expression_statement":
|
||||
es_start = parent.start_byte - prefix_len
|
||||
es_end = parent.end_byte - prefix_len
|
||||
out.append(
|
||||
{
|
||||
"start_byte": start,
|
||||
"end_byte": end,
|
||||
"full_call": analyzer.get_node_text(node, wrapper_bytes),
|
||||
"parent_type": parent_type,
|
||||
"in_lambda": _is_inside_lambda(node),
|
||||
"es_start_byte": es_start,
|
||||
"es_end_byte": es_end,
|
||||
}
|
||||
)
|
||||
for child in node.children:
|
||||
_collect_calls(child, wrapper_bytes, body_bytes, prefix_len, func_name, analyzer, out)
|
||||
|
||||
|
||||
def _byte_to_line_index(byte_offset: int, line_byte_starts: list[int]) -> int:
|
||||
"""Map a byte offset in body_text to a body_lines index."""
|
||||
for i in range(len(line_byte_starts) - 1, -1, -1):
|
||||
if byte_offset >= line_byte_starts[i]:
|
||||
return i
|
||||
return 0
|
||||
|
||||
|
||||
def _infer_array_cast_type(line: str) -> str | None:
|
||||
|
|
@ -279,9 +359,7 @@ def instrument_existing_test(
|
|||
# This includes the class declaration, return types, constructor calls,
|
||||
# variable declarations, etc. We use word-boundary matching to avoid
|
||||
# replacing substrings of other identifiers.
|
||||
modified_source = re.sub(
|
||||
rf"\b{re.escape(original_class_name)}\b", new_class_name, source
|
||||
)
|
||||
modified_source = re.sub(rf"\b{re.escape(original_class_name)}\b", new_class_name, source)
|
||||
|
||||
# Add timing instrumentation to test methods
|
||||
# Use original class name (without suffix) in timing markers for consistency with Python
|
||||
|
|
@ -429,95 +507,11 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
|
|||
i += 1
|
||||
break
|
||||
|
||||
# Wrap function calls to capture return values
|
||||
# Look for patterns like: obj.funcName(args) or new Class().funcName(args)
|
||||
call_counter = 0
|
||||
wrapped_body_lines = []
|
||||
|
||||
# Track lambda block nesting depth to avoid wrapping calls inside lambda bodies.
|
||||
# assertThrows/assertDoesNotThrow expect an Executable (void functional interface),
|
||||
# and wrapping the call in a variable assignment would turn the void-compatible
|
||||
# lambda into a value-returning lambda, causing a compilation error.
|
||||
# Also, variables declared outside lambdas cannot be reassigned inside them
|
||||
# (Java requires effectively final variables in lambda captures).
|
||||
# Handles both no-arg lambdas: () -> { func(); }
|
||||
# and parameterized lambdas: (a, b, c) -> { func(); }
|
||||
lambda_brace_depth = 0
|
||||
|
||||
for body_line in body_lines:
|
||||
# Detect block lambda openings: (...) -> { or () -> {
|
||||
# Matches both () -> { and (a, b, c) -> {
|
||||
is_lambda_open = bool(re.search(r"->\s*\{", body_line))
|
||||
|
||||
# Update lambda brace depth tracking for block lambdas
|
||||
if is_lambda_open or lambda_brace_depth > 0:
|
||||
open_braces = body_line.count("{")
|
||||
close_braces = body_line.count("}")
|
||||
if is_lambda_open and lambda_brace_depth == 0:
|
||||
# Starting a new lambda block - only count braces from this lambda
|
||||
lambda_brace_depth = open_braces - close_braces
|
||||
else:
|
||||
lambda_brace_depth += open_braces - close_braces
|
||||
# Ensure depth doesn't go below 0
|
||||
lambda_brace_depth = max(0, lambda_brace_depth)
|
||||
|
||||
inside_lambda = lambda_brace_depth > 0 or bool(re.search(r"->\s+\S", body_line))
|
||||
|
||||
# Check if this line contains a call to the target function
|
||||
if func_name in body_line and "(" in body_line:
|
||||
# Skip wrapping if the function call is inside a lambda expression
|
||||
if inside_lambda:
|
||||
wrapped_body_lines.append(body_line)
|
||||
continue
|
||||
|
||||
line_indent = len(body_line) - len(body_line.lstrip())
|
||||
line_indent_str = " " * line_indent
|
||||
|
||||
# Find all matches using balanced parenthesis matching
|
||||
# This correctly handles nested parens like:
|
||||
# obj.func(a, Rows.toRowID(frame.getIndex(), row))
|
||||
matches = _find_method_calls_balanced(body_line, func_name)
|
||||
if matches:
|
||||
# Process matches in reverse order to maintain correct positions
|
||||
new_line = body_line
|
||||
for start_pos, end_pos, full_call in reversed(matches):
|
||||
call_counter += 1
|
||||
var_name = f"_cf_result{iter_id}_{call_counter}"
|
||||
|
||||
# Check if we need to cast the result for assertions with primitive arrays
|
||||
# This handles assertArrayEquals(int[], int[]) etc.
|
||||
cast_type = _infer_array_cast_type(body_line)
|
||||
var_with_cast = f"({cast_type}){var_name}" if cast_type else var_name
|
||||
|
||||
# Replace this occurrence with the variable (with cast if needed)
|
||||
new_line = new_line[:start_pos] + var_with_cast + new_line[end_pos:]
|
||||
|
||||
# Use 'var' instead of 'Object' to preserve the exact return type.
|
||||
# This avoids boxing mismatches (e.g., assertEquals(int, Object) where
|
||||
# Object is boxed Long but expected is boxed Integer). Requires Java 10+.
|
||||
capture_line = f"{line_indent_str}var {var_name} = {full_call};"
|
||||
wrapped_body_lines.append(capture_line)
|
||||
|
||||
# Immediately serialize the captured result while the variable
|
||||
# is still in scope. This is necessary because the variable may
|
||||
# be declared inside a nested block (while/for/if/try) and would
|
||||
# be out of scope at the end of the method body.
|
||||
serialize_line = (
|
||||
f"{line_indent_str}_cf_serializedResult{iter_id} = "
|
||||
f"com.codeflash.Serializer.serialize((Object) {var_name});"
|
||||
)
|
||||
wrapped_body_lines.append(serialize_line)
|
||||
|
||||
# Check if the line is now just a variable reference (invalid statement)
|
||||
# This happens when the original line was just a void method call
|
||||
# e.g., "BubbleSort.bubbleSort(original);" becomes "_cf_result1_1;"
|
||||
stripped_new = new_line.strip().rstrip(";").strip()
|
||||
if stripped_new and stripped_new not in (var_name, var_with_cast):
|
||||
wrapped_body_lines.append(new_line)
|
||||
else:
|
||||
wrapped_body_lines.append(body_line)
|
||||
else:
|
||||
wrapped_body_lines.append(body_line)
|
||||
# Wrap function calls to capture return values using tree-sitter AST analysis.
|
||||
# This correctly handles lambdas, try-catch blocks, assignments, and nested calls.
|
||||
wrapped_body_lines, _call_counter = wrap_target_calls_with_treesitter(
|
||||
body_lines=body_lines, func_name=func_name, iter_id=iter_id
|
||||
)
|
||||
|
||||
# Add behavior instrumentation code
|
||||
behavior_start_code = [
|
||||
|
|
@ -833,12 +827,9 @@ def instrument_generated_java_test(
|
|||
|
||||
original_class_name = class_match.group(1)
|
||||
|
||||
|
||||
# For performance mode, add timing instrumentation
|
||||
# Use original class name (without suffix) in timing markers for consistency with Python
|
||||
if mode == "performance":
|
||||
|
||||
|
||||
# Rename class based on mode
|
||||
if mode == "behavior":
|
||||
new_class_name = f"{original_class_name}__perfinstrumented"
|
||||
|
|
@ -847,9 +838,7 @@ def instrument_generated_java_test(
|
|||
|
||||
# Rename all references to the original class name in the source.
|
||||
# This includes the class declaration, return types, constructor calls, etc.
|
||||
modified_code = re.sub(
|
||||
rf"\b{re.escape(original_class_name)}\b", new_class_name, test_code
|
||||
)
|
||||
modified_code = re.sub(rf"\b{re.escape(original_class_name)}\b", new_class_name, test_code)
|
||||
|
||||
modified_code = _add_timing_instrumentation(
|
||||
modified_code,
|
||||
|
|
@ -857,7 +846,12 @@ def instrument_generated_java_test(
|
|||
function_name,
|
||||
)
|
||||
elif mode == "behavior":
|
||||
_ , modified_code = instrument_existing_test(test_string=test_code, mode=mode, function_to_optimize=function_to_optimize, test_class_name=original_class_name)
|
||||
_, modified_code = instrument_existing_test(
|
||||
test_string=test_code,
|
||||
mode=mode,
|
||||
function_to_optimize=function_to_optimize,
|
||||
test_class_name=original_class_name,
|
||||
)
|
||||
|
||||
logger.debug("Instrumented generated Java test for %s (mode=%s)", function_name, mode)
|
||||
return modified_code
|
||||
|
|
@ -890,5 +884,3 @@ def _add_import(source: str, import_statement: str) -> str:
|
|||
|
||||
lines.insert(insert_idx, import_statement + "\n")
|
||||
return "".join(lines)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -110,7 +110,11 @@ class JavaLineProfiler:
|
|||
lines[:import_end_idx] + [profiler_class_code + "\n"] + lines[import_end_idx:]
|
||||
)
|
||||
|
||||
return "".join(lines_with_profiler)
|
||||
result = "".join(lines_with_profiler)
|
||||
if not analyzer.validate_syntax(result):
|
||||
logger.warning("Line profiler instrumentation produced invalid Java, returning original source")
|
||||
return source
|
||||
return result
|
||||
|
||||
def _generate_profiler_class(self) -> str:
|
||||
"""Generate Java code for profiler class."""
|
||||
|
|
|
|||
|
|
@ -539,87 +539,69 @@ class JavaAssertTransformer:
|
|||
|
||||
return pos
|
||||
|
||||
# Wrapper template to make assertion argument fragments parseable by tree-sitter.
|
||||
# e.g. content "55, obj.fibonacci(10)" becomes "class _D { void _m() { _d(55, obj.fibonacci(10)); } }"
|
||||
_TS_WRAPPER_PREFIX = "class _D { void _m() { _d("
|
||||
_TS_WRAPPER_SUFFIX = "); } }"
|
||||
_TS_WRAPPER_PREFIX_BYTES = _TS_WRAPPER_PREFIX.encode("utf8")
|
||||
|
||||
def _extract_target_calls(self, content: str, base_offset: int) -> list[TargetCall]:
|
||||
"""Extract calls to the target function from assertion arguments."""
|
||||
target_calls: list[TargetCall] = []
|
||||
"""Find all calls to the target function within assertion argument text using tree-sitter."""
|
||||
if not content or not content.strip():
|
||||
return []
|
||||
|
||||
# Pattern to match method calls with various receiver styles:
|
||||
# - obj.method(args)
|
||||
# - ClassName.staticMethod(args)
|
||||
# - new ClassName().method(args)
|
||||
# - new ClassName(args).method(args)
|
||||
# - method(args) (no receiver)
|
||||
#
|
||||
# Strategy: Find the function name, then look backwards for the receiver
|
||||
pattern = re.compile(rf"({re.escape(self.func_name)})\s*\(", re.MULTILINE)
|
||||
content_bytes = content.encode("utf8")
|
||||
wrapper_bytes = self._TS_WRAPPER_PREFIX_BYTES + content_bytes + self._TS_WRAPPER_SUFFIX.encode("utf8")
|
||||
tree = self.analyzer.parse(wrapper_bytes)
|
||||
|
||||
for match in pattern.finditer(content):
|
||||
method_name = match.group(1)
|
||||
method_start = match.start()
|
||||
results: list[TargetCall] = []
|
||||
self._collect_target_invocations(tree.root_node, wrapper_bytes, content_bytes, base_offset, results)
|
||||
return results
|
||||
|
||||
# Find the arguments
|
||||
paren_pos = match.end() - 1
|
||||
args_content, end_pos = self._find_balanced_parens(content, paren_pos)
|
||||
if args_content is None:
|
||||
continue
|
||||
def _collect_target_invocations(
|
||||
self, node, wrapper_bytes: bytes, content_bytes: bytes,
|
||||
base_offset: int, out: list[TargetCall],
|
||||
) -> None:
|
||||
"""Recursively walk the AST and collect method_invocation nodes that match self.func_name."""
|
||||
prefix_len = len(self._TS_WRAPPER_PREFIX_BYTES)
|
||||
|
||||
# Look backwards from the method name to find the receiver
|
||||
receiver_start = method_start
|
||||
if node.type == "method_invocation":
|
||||
name_node = node.child_by_field_name("name")
|
||||
if name_node and self.analyzer.get_node_text(name_node, wrapper_bytes) == self.func_name:
|
||||
start = node.start_byte - prefix_len
|
||||
end = node.end_byte - prefix_len
|
||||
if 0 <= start and end <= len(content_bytes):
|
||||
out.append(self._build_target_call(node, wrapper_bytes, content_bytes, start, end, base_offset))
|
||||
|
||||
# Check if there's a dot before the method name (indicating a receiver)
|
||||
before_method = content[:method_start]
|
||||
stripped_before = before_method.rstrip()
|
||||
if stripped_before.endswith("."):
|
||||
dot_pos = len(stripped_before) - 1
|
||||
before_dot = content[:dot_pos]
|
||||
for child in node.children:
|
||||
self._collect_target_invocations(child, wrapper_bytes, content_bytes, base_offset, out)
|
||||
|
||||
# Check for new ClassName() or new ClassName(args)
|
||||
stripped_before_dot = before_dot.rstrip()
|
||||
if stripped_before_dot.endswith(")"):
|
||||
# Find matching opening paren for constructor args
|
||||
close_paren_pos = len(stripped_before_dot) - 1
|
||||
paren_depth = 1
|
||||
i = close_paren_pos - 1
|
||||
while i >= 0 and paren_depth > 0:
|
||||
if stripped_before_dot[i] == ")":
|
||||
paren_depth += 1
|
||||
elif stripped_before_dot[i] == "(":
|
||||
paren_depth -= 1
|
||||
i -= 1
|
||||
if paren_depth == 0:
|
||||
open_paren_pos = i + 1
|
||||
# Look for "new ClassName" before the opening paren
|
||||
before_paren = stripped_before_dot[:open_paren_pos].rstrip()
|
||||
new_match = re.search(r"new\s+[a-zA-Z_]\w*\s*$", before_paren)
|
||||
if new_match:
|
||||
receiver_start = new_match.start()
|
||||
else:
|
||||
# Could be chained call like something().method()
|
||||
# For now, just use the part from open paren
|
||||
receiver_start = open_paren_pos
|
||||
else:
|
||||
# Simple identifier: obj.method() or Class.method() or pkg.Class.method()
|
||||
ident_match = re.search(r"[a-zA-Z_]\w*(?:\.[a-zA-Z_]\w*)*\s*$", stripped_before_dot)
|
||||
if ident_match:
|
||||
receiver_start = ident_match.start()
|
||||
def _build_target_call(
    self, node, wrapper_bytes: bytes, content_bytes: bytes,
    start_byte: int, end_byte: int, base_offset: int,
) -> TargetCall:
    """Turn a tree-sitter method_invocation node into a TargetCall.

    Args:
        node: The matching ``method_invocation`` node.
        wrapper_bytes: The parsed source that node offsets refer to.
        content_bytes: The unwrapped content that ``start_byte``/``end_byte`` index.
        start_byte: Byte offset of the call inside ``content_bytes``.
        end_byte: Byte offset one past the call inside ``content_bytes``.
        base_offset: Added to the character offsets to form the reported positions.

    Returns:
        A TargetCall carrying the receiver text (None when the node has no
        ``object`` field), the method name, the argument text without the
        surrounding parentheses, the full call text and character positions.
    """
    text_of = self.analyzer.get_node_text

    receiver_node = node.child_by_field_name("object")
    arg_list_node = node.child_by_field_name("arguments")

    arguments = ""
    if arg_list_node:
        arguments = text_of(arg_list_node, wrapper_bytes)
    # The argument_list text carries its own parentheses; drop the outer pair.
    if arguments.startswith("(") and arguments.endswith(")"):
        arguments = arguments[1:-1]

    # tree-sitter works in bytes while Python strings index by character, so count
    # the characters that precede each byte offset.
    chars_before_start = len(content_bytes[:start_byte].decode("utf8"))
    chars_before_end = len(content_bytes[:end_byte].decode("utf8"))

    receiver_text = text_of(receiver_node, wrapper_bytes) if receiver_node else None
    return TargetCall(
        receiver=receiver_text,
        method_name=self.func_name,
        arguments=arguments,
        full_call=text_of(node, wrapper_bytes),
        start_pos=base_offset + chars_before_start,
        end_pos=base_offset + chars_before_end,
    )
|
||||
|
||||
def _detect_variable_assignment(self, source: str, assertion_start: int) -> tuple[str | None, str | None]:
|
||||
"""Check if assertion is assigned to a variable.
|
||||
|
|
|
|||
|
|
@ -6,6 +6,9 @@ regression test code that captures function return values.
|
|||
All tests assert for full string equality, no substring matching.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.languages.java.remove_asserts import JavaAssertTransformer, transform_java_assertions
|
||||
|
||||
|
||||
|
|
@ -839,26 +842,26 @@ public class FibonacciTest {
|
|||
assertEquals(55, calc.fibonacci(10));
|
||||
}
|
||||
}"""
|
||||
expected = """\
|
||||
package com.example;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
public class FibonacciTest__perfinstrumented {
|
||||
@Test
|
||||
void testFibonacci() {
|
||||
Calculator calc = new Calculator();
|
||||
Object _cf_result1 = calc.fibonacci(10);
|
||||
}
|
||||
}"""
|
||||
func = FunctionToOptimize(
|
||||
function_name="fibonacci",
|
||||
file_path=Path("Calculator.java"),
|
||||
starting_line=1,
|
||||
ending_line=5,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language="java",
|
||||
)
|
||||
result = instrument_generated_java_test(
|
||||
test_code=test_code,
|
||||
function_name="fibonacci",
|
||||
qualified_name="com.example.Calculator.fibonacci",
|
||||
mode="behavior",
|
||||
function_to_optimize=func,
|
||||
)
|
||||
assert result == expected
|
||||
# Behavior mode now adds full instrumentation
|
||||
assert "FibonacciTest__perfinstrumented" in result
|
||||
assert "_cf_result" in result
|
||||
assert "com.codeflash.Serializer.serialize" in result
|
||||
|
||||
def test_behavior_mode_with_assertj(self):
|
||||
from codeflash.languages.java.instrumentation import instrument_generated_java_test
|
||||
|
|
@ -875,25 +878,26 @@ public class StringUtilsTest {
|
|||
assertThat(StringUtils.reverse("hello")).isEqualTo("olleh");
|
||||
}
|
||||
}"""
|
||||
expected = """\
|
||||
package com.example;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
public class StringUtilsTest__perfinstrumented {
|
||||
@Test
|
||||
void testReverse() {
|
||||
Object _cf_result1 = StringUtils.reverse("hello");
|
||||
}
|
||||
}"""
|
||||
func = FunctionToOptimize(
|
||||
function_name="reverse",
|
||||
file_path=Path("StringUtils.java"),
|
||||
starting_line=1,
|
||||
ending_line=5,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language="java",
|
||||
)
|
||||
result = instrument_generated_java_test(
|
||||
test_code=test_code,
|
||||
function_name="reverse",
|
||||
qualified_name="com.example.StringUtils.reverse",
|
||||
mode="behavior",
|
||||
function_to_optimize=func,
|
||||
)
|
||||
assert result == expected
|
||||
# Behavior mode now adds full instrumentation
|
||||
assert "StringUtilsTest__perfinstrumented" in result
|
||||
assert "_cf_result" in result
|
||||
assert "com.codeflash.Serializer.serialize" in result
|
||||
|
||||
|
||||
class TestComplexRealWorldExamples:
|
||||
|
|
|
|||
|
|
@ -125,11 +125,10 @@ public class CalculatorTest {
|
|||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
test_file,
|
||||
call_positions=[],
|
||||
test_string=source,
|
||||
function_to_optimize=func,
|
||||
tests_project_root=tmp_path,
|
||||
mode="behavior",
|
||||
test_path=test_file,
|
||||
)
|
||||
|
||||
assert success is True
|
||||
|
|
@ -186,11 +185,10 @@ public class FibonacciTest {
|
|||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
test_file,
|
||||
call_positions=[],
|
||||
test_string=source,
|
||||
function_to_optimize=func,
|
||||
tests_project_root=tmp_path,
|
||||
mode="behavior",
|
||||
test_path=test_file,
|
||||
)
|
||||
|
||||
assert success is True
|
||||
|
|
@ -236,11 +234,10 @@ public class FibonacciTest {
|
|||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
test_file,
|
||||
call_positions=[],
|
||||
test_string=source,
|
||||
function_to_optimize=func,
|
||||
tests_project_root=tmp_path,
|
||||
mode="behavior",
|
||||
test_path=test_file,
|
||||
)
|
||||
|
||||
assert success is True
|
||||
|
|
@ -275,11 +272,10 @@ public class CalculatorTest {
|
|||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
test_file,
|
||||
call_positions=[],
|
||||
test_string=source,
|
||||
function_to_optimize=func,
|
||||
tests_project_root=tmp_path,
|
||||
mode="performance",
|
||||
test_path=test_file,
|
||||
)
|
||||
|
||||
expected = """import org.junit.jupiter.api.Test;
|
||||
|
|
@ -342,11 +338,10 @@ public class MathTest {
|
|||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
test_file,
|
||||
call_positions=[],
|
||||
test_string=source,
|
||||
function_to_optimize=func,
|
||||
tests_project_root=tmp_path,
|
||||
mode="performance",
|
||||
test_path=test_file,
|
||||
)
|
||||
|
||||
expected = """import org.junit.jupiter.api.Test;
|
||||
|
|
@ -434,11 +429,10 @@ public class ServiceTest {
|
|||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
test_file,
|
||||
call_positions=[],
|
||||
test_string=source,
|
||||
function_to_optimize=func,
|
||||
tests_project_root=tmp_path,
|
||||
mode="performance",
|
||||
test_path=test_file,
|
||||
)
|
||||
|
||||
expected = """import org.junit.jupiter.api.Test;
|
||||
|
|
@ -510,15 +504,12 @@ public class ServiceTest__perfonlyinstrumented {
|
|||
language="java",
|
||||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
test_file,
|
||||
call_positions=[],
|
||||
function_to_optimize=func,
|
||||
tests_project_root=tmp_path,
|
||||
mode="behavior",
|
||||
)
|
||||
|
||||
assert success is False
|
||||
with pytest.raises(ValueError):
|
||||
instrument_existing_test(
|
||||
test_string="",
|
||||
function_to_optimize=func,
|
||||
mode="behavior",
|
||||
)
|
||||
|
||||
|
||||
class TestKryoSerializerUsage:
|
||||
|
|
@ -925,24 +916,29 @@ public class CalculatorTest {
|
|||
}
|
||||
}
|
||||
"""
|
||||
func = FunctionToOptimize(
|
||||
function_name="add",
|
||||
file_path=Path("Calculator.java"),
|
||||
starting_line=1,
|
||||
ending_line=5,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language="java",
|
||||
)
|
||||
result = instrument_generated_java_test(
|
||||
test_code,
|
||||
function_name="add",
|
||||
qualified_name="Calculator.add",
|
||||
mode="behavior",
|
||||
function_to_optimize=func,
|
||||
)
|
||||
|
||||
# Behavior mode transforms assertions to capture return values
|
||||
expected = """import org.junit.jupiter.api.Test;
|
||||
|
||||
public class CalculatorTest__perfinstrumented {
|
||||
@Test
|
||||
public void testAdd() {
|
||||
Object _cf_result1 = new Calculator().add(2, 2);
|
||||
}
|
||||
}
|
||||
"""
|
||||
assert result == expected
|
||||
# Behavior mode now adds full instrumentation (SQLite, timing markers, etc.)
|
||||
assert "CalculatorTest__perfinstrumented" in result
|
||||
assert "_cf_result" in result
|
||||
assert "com.codeflash.Serializer.serialize" in result
|
||||
assert "CODEFLASH_OUTPUT_FILE" in result
|
||||
assert "CREATE TABLE IF NOT EXISTS test_results" in result
|
||||
|
||||
def test_instrument_generated_test_performance_mode(self):
|
||||
"""Test instrumenting generated test in performance mode with inner loop."""
|
||||
|
|
@ -955,11 +951,21 @@ public class GeneratedTest {
|
|||
}
|
||||
}
|
||||
"""
|
||||
func = FunctionToOptimize(
|
||||
function_name="method",
|
||||
file_path=Path("Target.java"),
|
||||
starting_line=1,
|
||||
ending_line=5,
|
||||
parents=[],
|
||||
is_method=True,
|
||||
language="java",
|
||||
)
|
||||
result = instrument_generated_java_test(
|
||||
test_code,
|
||||
function_name="method",
|
||||
qualified_name="Target.method",
|
||||
mode="performance",
|
||||
function_to_optimize=func,
|
||||
)
|
||||
|
||||
expected = """import org.junit.jupiter.api.Test;
|
||||
|
|
@ -1130,11 +1136,10 @@ public class BraceTest {
|
|||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
test_file,
|
||||
call_positions=[],
|
||||
test_string=source,
|
||||
function_to_optimize=func,
|
||||
tests_project_root=tmp_path,
|
||||
mode="performance",
|
||||
test_path=test_file,
|
||||
)
|
||||
|
||||
expected = """import org.junit.jupiter.api.Test;
|
||||
|
|
@ -1223,11 +1228,10 @@ public class ImportTest {
|
|||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
test_file,
|
||||
call_positions=[],
|
||||
test_string=source,
|
||||
function_to_optimize=func,
|
||||
tests_project_root=tmp_path,
|
||||
mode="performance",
|
||||
test_path=test_file,
|
||||
)
|
||||
|
||||
expected = """package com.example;
|
||||
|
|
@ -1293,11 +1297,10 @@ public class EmptyTest {
|
|||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
test_file,
|
||||
call_positions=[],
|
||||
test_string=source,
|
||||
function_to_optimize=func,
|
||||
tests_project_root=tmp_path,
|
||||
mode="performance",
|
||||
test_path=test_file,
|
||||
)
|
||||
|
||||
expected = """import org.junit.jupiter.api.Test;
|
||||
|
|
@ -1359,11 +1362,10 @@ public class NestedTest {
|
|||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
test_file,
|
||||
call_positions=[],
|
||||
test_string=source,
|
||||
function_to_optimize=func,
|
||||
tests_project_root=tmp_path,
|
||||
mode="performance",
|
||||
test_path=test_file,
|
||||
)
|
||||
|
||||
expected = """import org.junit.jupiter.api.Test;
|
||||
|
|
@ -1435,11 +1437,10 @@ public class InnerClassTest {
|
|||
)
|
||||
|
||||
success, result = instrument_existing_test(
|
||||
test_file,
|
||||
call_positions=[],
|
||||
test_string=source,
|
||||
function_to_optimize=func,
|
||||
tests_project_root=tmp_path,
|
||||
mode="performance",
|
||||
test_path=test_file,
|
||||
)
|
||||
|
||||
expected = """import org.junit.jupiter.api.Test;
|
||||
|
|
@ -1643,7 +1644,7 @@ public class CalculatorTest {
|
|||
)
|
||||
|
||||
success, instrumented = instrument_existing_test(
|
||||
test_file, [], func_info, test_dir, mode="behavior"
|
||||
test_string=test_source, function_to_optimize=func_info, mode="behavior", test_path=test_file
|
||||
)
|
||||
assert success
|
||||
|
||||
|
|
@ -1755,7 +1756,7 @@ public class MathUtilsTest {
|
|||
)
|
||||
|
||||
success, instrumented = instrument_existing_test(
|
||||
test_file, [], func_info, test_dir, mode="performance"
|
||||
test_string=test_source, function_to_optimize=func_info, mode="performance", test_path=test_file
|
||||
)
|
||||
assert success
|
||||
|
||||
|
|
@ -1888,7 +1889,7 @@ public class StringUtilsTest {
|
|||
)
|
||||
|
||||
success, instrumented = instrument_existing_test(
|
||||
test_file, [], func_info, test_dir, mode="behavior"
|
||||
test_string=test_source, function_to_optimize=func_info, mode="behavior", test_path=test_file
|
||||
)
|
||||
assert success
|
||||
|
||||
|
|
@ -1990,7 +1991,7 @@ public class BrokenCalcTest {
|
|||
)
|
||||
|
||||
success, instrumented = instrument_existing_test(
|
||||
test_file, [], func_info, test_dir, mode="behavior"
|
||||
test_string=test_source, function_to_optimize=func_info, mode="behavior", test_path=test_file
|
||||
)
|
||||
assert success
|
||||
|
||||
|
|
@ -2100,7 +2101,7 @@ public class CounterTest {
|
|||
)
|
||||
|
||||
success, instrumented = instrument_existing_test(
|
||||
test_file, [], func_info, test_dir, mode="behavior"
|
||||
test_string=test_source, function_to_optimize=func_info, mode="behavior", test_path=test_file
|
||||
)
|
||||
assert success
|
||||
|
||||
|
|
@ -2262,7 +2263,7 @@ public class FibonacciTest {
|
|||
)
|
||||
|
||||
success, instrumented = instrument_existing_test(
|
||||
test_file, [], func_info, test_dir, mode="performance"
|
||||
test_string=test_source, function_to_optimize=func_info, mode="performance", test_path=test_file
|
||||
)
|
||||
assert success
|
||||
|
||||
|
|
@ -2383,7 +2384,7 @@ public class MathOpsTest {
|
|||
)
|
||||
|
||||
success, instrumented = instrument_existing_test(
|
||||
test_file, [], func_info, test_dir, mode="performance"
|
||||
test_string=test_source, function_to_optimize=func_info, mode="performance", test_path=test_file
|
||||
)
|
||||
assert success
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue