chore: merge omni-java into fix/java-exception-assignment-instrumentation

Resolved conflicts by merging the best of both branches:
- Kept exception_class field from PR for better exception type detection
- Adopted more general variable assignment detection from omni-java
- Combined exception replacement logic to use exception_class with fallback
- Added double catch (specific exception + generic Exception) for robustness
- Merged test cases from both branches with updated expectations

Changes:
- Updated AssertionMatch to include all fields: assigned_var_type, assigned_var_name, exception_class
- Lambda extraction now works for all exception assertions
- Exception class extraction specifically for assertThrows
- Variable assignment detection handles final modifier and fully qualified types
- Exception replacement uses exception_class or falls back to assigned_var_type
- All 80 tests passing

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
Mohamed Ashraf 2026-02-11 12:52:15 +00:00
commit abf2c98994
4 changed files with 155 additions and 67 deletions

View file

@ -83,10 +83,10 @@ class JavaLineProfiler:
lines = source.splitlines(keepends=True)
# Process functions in reverse order to preserve line numbers
for func in sorted(functions, key=lambda f: f.start_line, reverse=True):
for func in sorted(functions, key=lambda f: f.starting_line, reverse=True):
func_lines = self._instrument_function(func, lines, file_path, analyzer)
start_idx = func.start_line - 1
end_idx = func.end_line
start_idx = func.starting_line - 1
end_idx = func.ending_line
lines = lines[:start_idx] + func_lines + lines[end_idx:]
instrumented_source = "".join(lines)
@ -261,7 +261,7 @@ class {self.profiler_class} {{
Instrumented function lines.
"""
func_lines = lines[func.start_line - 1 : func.end_line]
func_lines = lines[func.starting_line - 1 : func.ending_line]
instrumented_lines = []
# Parse the function to find executable lines
@ -271,7 +271,7 @@ class {self.profiler_class} {{
tree = analyzer.parse(source.encode("utf8"))
executable_lines = self._find_executable_lines(tree.root_node)
except Exception as e:
logger.warning("Failed to parse function %s: %s", func.name, e)
logger.warning("Failed to parse function %s: %s", func.function_name, e)
return func_lines
# Add profiling to each executable line
@ -279,7 +279,7 @@ class {self.profiler_class} {{
for local_idx, line in enumerate(func_lines):
local_line_num = local_idx + 1 # 1-indexed within function
global_line_num = func.start_line + local_idx # Global line number
global_line_num = func.starting_line + local_idx # Global line number
stripped = line.strip()
# Add enterFunction() call after the method's opening brace
@ -409,7 +409,7 @@ class {self.profiler_class} {{
"""
if not profile_file.exists():
return {"timings": {}, "unit": 1e-9, "raw_data": {}}
return {"timings": {}, "unit": 1e-9, "raw_data": {}, "str_out": ""}
try:
with profile_file.open("r") as f:
@ -435,15 +435,17 @@ class {self.profiler_class} {{
"content": content,
}
return {
result = {
"timings": timings,
"unit": 1e-9, # nanoseconds
"raw_data": data,
}
result["str_out"] = format_line_profile_results(result)
return result
except Exception as e:
logger.error("Failed to parse line profile results: %s", e)
return {"timings": {}, "unit": 1e-9, "raw_data": {}}
return {"timings": {}, "unit": 1e-9, "raw_data": {}, "str_out": ""}
def format_line_profile_results(results: dict, file_path: Path | None = None) -> str:

View file

@ -166,9 +166,9 @@ class AssertionMatch:
original_text: str = ""
is_exception_assertion: bool = False
lambda_body: str | None = None # For assertThrows lambda content
variable_type: str | None = None # Type of assigned variable (e.g., "IllegalArgumentException")
variable_name: str | None = None # Name of assigned variable (e.g., "exception")
exception_class: str | None = None # Exception class from assertThrows args
assigned_var_type: str | None = None # Type of assigned variable (e.g., "IllegalArgumentException")
assigned_var_name: str | None = None # Name of assigned variable (e.g., "exception")
exception_class: str | None = None # Exception class from assertThrows args (e.g., "IllegalArgumentException")
class JavaAssertTransformer:
@ -306,8 +306,11 @@ class JavaAssertTransformer:
# - assertEquals (static import)
# - Assert.assertEquals (JUnit 4)
# - Assertions.assertEquals (JUnit 5)
# - org.junit.jupiter.api.Assertions.assertEquals (fully qualified)
all_assertions = "|".join(JUNIT5_ALL_ASSERTIONS)
pattern = re.compile(rf"(\s*)((?:Assert(?:ions)?\.)?({all_assertions}))\s*\(", re.MULTILINE)
pattern = re.compile(
rf"(\s*)((?:(?:\w+\.)*Assert(?:ions)?\.)?({all_assertions}))\s*\(", re.MULTILINE
)
for match in pattern.finditer(source):
leading_ws = match.group(1)
@ -332,32 +335,41 @@ class JavaAssertTransformer:
target_calls = self._extract_target_calls(args_content, match.end())
is_exception = assertion_method in JUNIT5_EXCEPTION_ASSERTIONS
# For assertThrows, extract the lambda body and exception class
# For exception assertions, extract the lambda body
lambda_body = None
exception_class = None
if is_exception and assertion_method == "assertThrows":
if is_exception:
lambda_body = self._extract_lambda_body(args_content)
exception_class = self._extract_exception_class(args_content)
# Extract exception class specifically for assertThrows
if assertion_method == "assertThrows":
exception_class = self._extract_exception_class(args_content)
# Check if assertion is assigned to a variable
var_type, var_name = self._detect_variable_assignment(source, start_pos)
# Detect variable assignment: Type var = assertXxx(...)
# This applies to all assertions (assertThrows, assertTimeout, etc.)
assigned_var_type = None
assigned_var_name = None
original_text = source[start_pos:end_pos]
# If variable assignment detected, adjust start_pos to include the entire line
actual_start = start_pos
actual_leading_ws = leading_ws
if var_type:
# Find the start of the line (beginning of variable declaration)
line_start = source.rfind("\n", 0, start_pos)
if line_start == -1:
line_start = 0
before = source[:start_pos]
last_nl_idx = before.rfind("\n")
if last_nl_idx >= 0:
line_prefix = source[last_nl_idx + 1 : start_pos]
else:
line_prefix = source[:start_pos]
var_match = re.match(r"([ \t]*)(?:final\s+)?([\w.<>\[\]]+)\s+(\w+)\s*=\s*$", line_prefix)
if var_match:
if last_nl_idx >= 0:
start_pos = last_nl_idx
leading_ws = "\n" + var_match.group(1)
else:
line_start += 1
actual_start = line_start
# Extract the actual leading whitespace from the start of the line
line_content = source[line_start:start_pos]
actual_leading_ws = line_content[:len(line_content) - len(line_content.lstrip())]
start_pos = 0
leading_ws = var_match.group(1)
original_text = source[actual_start:end_pos]
assigned_var_type = var_match.group(2)
assigned_var_name = var_match.group(3)
original_text = source[start_pos:end_pos] # Update with adjusted range
# Determine statement type based on detected framework
detected = self._detected_framework or "junit5"
@ -368,17 +380,17 @@ class JavaAssertTransformer:
assertions.append(
AssertionMatch(
start_pos=actual_start,
start_pos=start_pos,
end_pos=end_pos,
statement_type=stmt_type,
assertion_method=assertion_method,
target_calls=target_calls,
leading_whitespace=actual_leading_ws,
leading_whitespace=leading_ws,
original_text=original_text,
is_exception_assertion=is_exception,
lambda_body=lambda_body,
variable_type=var_type,
variable_name=var_name,
assigned_var_type=assigned_var_type,
assigned_var_name=assigned_var_name,
exception_class=exception_class,
)
)
@ -709,9 +721,9 @@ class JavaAssertTransformer:
return brace_content.strip()
else:
# Expression lambda: () -> expr
# Find the end (before the closing paren of assertThrows)
# Find the end (before the closing paren of assertThrows, or comma at depth 0)
depth = 0
end = body_start
end = len(content)
for i, ch in enumerate(content[body_start:]):
if ch == "(":
depth += 1
@ -720,6 +732,9 @@ class JavaAssertTransformer:
end = body_start + i
break
depth -= 1
elif ch == "," and depth == 0:
end = body_start + i
break
return content[body_start:end].strip()
return None
@ -851,14 +866,17 @@ class JavaAssertTransformer:
To:
try { calculator.divide(1, 0); } catch (Exception _cf_ignored1) {}
For variable assignments:
When assigned to a variable:
IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> code());
To:
IllegalArgumentException ex = null;
try { code(); } catch (IllegalArgumentException e) { ex = e; } catch (Exception _cf_ignored1) {}
try { code(); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; } catch (Exception _cf_ignored1) {}
"""
self.invocation_counter += 1
counter = self.invocation_counter
ws = assertion.leading_whitespace
base_indent = ws.lstrip("\n\r")
# Extract code to run from lambda body or target calls
code_to_run = None
@ -867,38 +885,39 @@ class JavaAssertTransformer:
# Use a direct last-character check instead of .endswith for lower overhead
if code_to_run and code_to_run[-1] != ";":
code_to_run += ";"
elif assertion.target_calls:
call = assertion.target_calls[0]
code_to_run = call.full_call + ";"
if not code_to_run:
# Fallback: comment out the assertion
return f"{assertion.leading_whitespace}// Removed assertThrows: could not extract callable"
# Check if assertion is assigned to a variable
if assertion.variable_name and assertion.variable_type:
# Generate proper exception capture with variable assignment
exception_type = assertion.exception_class or assertion.variable_type
var_name = assertion.variable_name
# Use a unique catch variable name to avoid conflicts
catch_var = f"_cf_caught{self.invocation_counter}"
# Get base indentation from leading whitespace (without newlines)
base_indent = assertion.leading_whitespace.lstrip("\n\r")
# Handle variable assignment: Type var = assertThrows(...)
if assertion.assigned_var_name and assertion.assigned_var_type:
var_type = assertion.assigned_var_type
var_name = assertion.assigned_var_name
if assertion.assertion_method == "assertDoesNotThrow":
if ";" not in assertion.lambda_body.strip():
return f"{ws}{var_type} {var_name} = {assertion.lambda_body.strip()};"
return f"{ws}{code_to_run}"
# For assertThrows with variable assignment, use exception_class if available
exception_type = assertion.exception_class or var_type
return (
f"{ws}{var_type} {var_name} = null;\n"
f"{base_indent}try {{ {code_to_run} }} "
f"catch ({exception_type} _cf_caught{counter}) {{ {var_name} = _cf_caught{counter}; }} "
f"catch (Exception _cf_ignored{counter}) {{}}"
)
return (
f"{assertion.leading_whitespace}{assertion.variable_type} {var_name} = null;\n"
f"{base_indent}try {{ {code_to_run} }} "
f"catch ({exception_type} {catch_var}) {{ {var_name} = {catch_var}; }} "
f"catch (Exception _cf_ignored{self.invocation_counter}) {{}}"
f"{ws}try {{ {code_to_run} }} "
f"catch (Exception _cf_ignored{counter}) {{}}"
)
# No variable assignment, use simple try-catch
return (
f"{assertion.leading_whitespace}try {{ {code_to_run} }} "
f"catch (Exception _cf_ignored{self.invocation_counter}) {{}}"
)
# If no lambda body found, try to extract from target calls
if assertion.target_calls:
call = assertion.target_calls[0]
return (
f"{ws}try {{ {call.full_call}; }} "
f"catch (Exception _cf_ignored{counter}) {{}}"
)
# Fallback: comment out the assertion
return f"{ws}// Removed assertThrows: could not extract callable"
def transform_java_assertions(source: str, function_name: str, qualified_name: str | None = None) -> str:

View file

@ -322,7 +322,7 @@ class JavaSupport(LanguageSupport):
return True
except Exception as e:
logger.error("Failed to instrument %s for line profiling: %s", func_info.name, e)
logger.error("Failed to instrument %s for line profiling: %s", func_info.function_name, e)
return False
def parse_line_profile_results(self, line_profiler_output_file: Path) -> dict:

View file

@ -1257,6 +1257,41 @@ void testSynchronizedMethodWithAssertJ() {
assert result == expected
class TestFullyQualifiedAssertions:
"""Tests for fully qualified assertion calls like org.junit.jupiter.api.Assertions.assertXxx."""
def test_assert_timeout_fully_qualified_with_variable_assignment(self):
source = """\
@Test
void testLargeInput() {
Long result = org.junit.jupiter.api.Assertions.assertTimeout(
Duration.ofSeconds(1),
() -> Fibonacci.fibonacci(100_000)
);
}"""
expected = """\
@Test
void testLargeInput() {
Object _cf_result1 = Fibonacci.fibonacci(100_000);
}"""
result = transform_java_assertions(source, "fibonacci")
assert result == expected
def test_assert_equals_fully_qualified(self):
source = """\
@Test
void testAdd() {
org.junit.jupiter.api.Assertions.assertEquals(5, calc.add(2, 3));
}"""
expected = """\
@Test
void testAdd() {
Object _cf_result1 = calc.add(2, 3);
}"""
result = transform_java_assertions(source, "add")
assert result == expected
class TestAssertThrowsVariableAssignment:
"""Tests for assertThrows with variable assignment (Issue: exception handling instrumentation bug)."""
@ -1358,3 +1393,35 @@ void testComplexException() {
}"""
result = transform_java_assertions(source, "execute")
assert result == expected
def test_assert_throws_assigned_with_final_modifier(self):
"""Test assertThrows with final modifier on variable."""
source = """\
@Test
void testDivideByZero() {
final IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> calc.divide(1, 0));
}"""
expected = """\
@Test
void testDivideByZero() {
IllegalArgumentException ex = null;
try { calc.divide(1, 0); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; } catch (Exception _cf_ignored1) {}
}"""
result = transform_java_assertions(source, "divide")
assert result == expected
def test_assert_throws_assigned_with_qualified_assertions(self):
"""Test assertThrows with qualified assertion (Assertions.assertThrows)."""
source = """\
@Test
void testDivideByZero() {
IllegalArgumentException ex = Assertions.assertThrows(IllegalArgumentException.class, () -> calc.divide(1, 0));
}"""
expected = """\
@Test
void testDivideByZero() {
IllegalArgumentException ex = null;
try { calc.divide(1, 0); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; } catch (Exception _cf_ignored1) {}
}"""
result = transform_java_assertions(source, "divide")
assert result == expected