mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
chore: merge omni-java into fix/java-exception-assignment-instrumentation
Resolved conflicts by merging the best of both branches: - Kept exception_class field from PR for better exception type detection - Adopted more general variable assignment detection from omni-java - Combined exception replacement logic to use exception_class with fallback - Added double catch (specific exception + generic Exception) for robustness - Merged test cases from both branches with updated expectations Changes: - Updated AssertionMatch to include all fields: assigned_var_type, assigned_var_name, exception_class - Lambda extraction now works for all exception assertions - Exception class extraction specifically for assertThrows - Variable assignment detection handles final modifier and fully qualified types - Exception replacement uses exception_class or falls back to assigned_var_type - All 80 tests passing Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
commit
abf2c98994
4 changed files with 155 additions and 67 deletions
|
|
@ -83,10 +83,10 @@ class JavaLineProfiler:
|
|||
lines = source.splitlines(keepends=True)
|
||||
|
||||
# Process functions in reverse order to preserve line numbers
|
||||
for func in sorted(functions, key=lambda f: f.start_line, reverse=True):
|
||||
for func in sorted(functions, key=lambda f: f.starting_line, reverse=True):
|
||||
func_lines = self._instrument_function(func, lines, file_path, analyzer)
|
||||
start_idx = func.start_line - 1
|
||||
end_idx = func.end_line
|
||||
start_idx = func.starting_line - 1
|
||||
end_idx = func.ending_line
|
||||
lines = lines[:start_idx] + func_lines + lines[end_idx:]
|
||||
|
||||
instrumented_source = "".join(lines)
|
||||
|
|
@ -261,7 +261,7 @@ class {self.profiler_class} {{
|
|||
Instrumented function lines.
|
||||
|
||||
"""
|
||||
func_lines = lines[func.start_line - 1 : func.end_line]
|
||||
func_lines = lines[func.starting_line - 1 : func.ending_line]
|
||||
instrumented_lines = []
|
||||
|
||||
# Parse the function to find executable lines
|
||||
|
|
@ -271,7 +271,7 @@ class {self.profiler_class} {{
|
|||
tree = analyzer.parse(source.encode("utf8"))
|
||||
executable_lines = self._find_executable_lines(tree.root_node)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to parse function %s: %s", func.name, e)
|
||||
logger.warning("Failed to parse function %s: %s", func.function_name, e)
|
||||
return func_lines
|
||||
|
||||
# Add profiling to each executable line
|
||||
|
|
@ -279,7 +279,7 @@ class {self.profiler_class} {{
|
|||
|
||||
for local_idx, line in enumerate(func_lines):
|
||||
local_line_num = local_idx + 1 # 1-indexed within function
|
||||
global_line_num = func.start_line + local_idx # Global line number
|
||||
global_line_num = func.starting_line + local_idx # Global line number
|
||||
stripped = line.strip()
|
||||
|
||||
# Add enterFunction() call after the method's opening brace
|
||||
|
|
@ -409,7 +409,7 @@ class {self.profiler_class} {{
|
|||
|
||||
"""
|
||||
if not profile_file.exists():
|
||||
return {"timings": {}, "unit": 1e-9, "raw_data": {}}
|
||||
return {"timings": {}, "unit": 1e-9, "raw_data": {}, "str_out": ""}
|
||||
|
||||
try:
|
||||
with profile_file.open("r") as f:
|
||||
|
|
@ -435,15 +435,17 @@ class {self.profiler_class} {{
|
|||
"content": content,
|
||||
}
|
||||
|
||||
return {
|
||||
result = {
|
||||
"timings": timings,
|
||||
"unit": 1e-9, # nanoseconds
|
||||
"raw_data": data,
|
||||
}
|
||||
result["str_out"] = format_line_profile_results(result)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to parse line profile results: %s", e)
|
||||
return {"timings": {}, "unit": 1e-9, "raw_data": {}}
|
||||
return {"timings": {}, "unit": 1e-9, "raw_data": {}, "str_out": ""}
|
||||
|
||||
|
||||
def format_line_profile_results(results: dict, file_path: Path | None = None) -> str:
|
||||
|
|
|
|||
|
|
@ -166,9 +166,9 @@ class AssertionMatch:
|
|||
original_text: str = ""
|
||||
is_exception_assertion: bool = False
|
||||
lambda_body: str | None = None # For assertThrows lambda content
|
||||
variable_type: str | None = None # Type of assigned variable (e.g., "IllegalArgumentException")
|
||||
variable_name: str | None = None # Name of assigned variable (e.g., "exception")
|
||||
exception_class: str | None = None # Exception class from assertThrows args
|
||||
assigned_var_type: str | None = None # Type of assigned variable (e.g., "IllegalArgumentException")
|
||||
assigned_var_name: str | None = None # Name of assigned variable (e.g., "exception")
|
||||
exception_class: str | None = None # Exception class from assertThrows args (e.g., "IllegalArgumentException")
|
||||
|
||||
|
||||
class JavaAssertTransformer:
|
||||
|
|
@ -306,8 +306,11 @@ class JavaAssertTransformer:
|
|||
# - assertEquals (static import)
|
||||
# - Assert.assertEquals (JUnit 4)
|
||||
# - Assertions.assertEquals (JUnit 5)
|
||||
# - org.junit.jupiter.api.Assertions.assertEquals (fully qualified)
|
||||
all_assertions = "|".join(JUNIT5_ALL_ASSERTIONS)
|
||||
pattern = re.compile(rf"(\s*)((?:Assert(?:ions)?\.)?({all_assertions}))\s*\(", re.MULTILINE)
|
||||
pattern = re.compile(
|
||||
rf"(\s*)((?:(?:\w+\.)*Assert(?:ions)?\.)?({all_assertions}))\s*\(", re.MULTILINE
|
||||
)
|
||||
|
||||
for match in pattern.finditer(source):
|
||||
leading_ws = match.group(1)
|
||||
|
|
@ -332,32 +335,41 @@ class JavaAssertTransformer:
|
|||
target_calls = self._extract_target_calls(args_content, match.end())
|
||||
is_exception = assertion_method in JUNIT5_EXCEPTION_ASSERTIONS
|
||||
|
||||
# For assertThrows, extract the lambda body and exception class
|
||||
# For exception assertions, extract the lambda body
|
||||
lambda_body = None
|
||||
exception_class = None
|
||||
if is_exception and assertion_method == "assertThrows":
|
||||
if is_exception:
|
||||
lambda_body = self._extract_lambda_body(args_content)
|
||||
exception_class = self._extract_exception_class(args_content)
|
||||
# Extract exception class specifically for assertThrows
|
||||
if assertion_method == "assertThrows":
|
||||
exception_class = self._extract_exception_class(args_content)
|
||||
|
||||
# Check if assertion is assigned to a variable
|
||||
var_type, var_name = self._detect_variable_assignment(source, start_pos)
|
||||
# Detect variable assignment: Type var = assertXxx(...)
|
||||
# This applies to all assertions (assertThrows, assertTimeout, etc.)
|
||||
assigned_var_type = None
|
||||
assigned_var_name = None
|
||||
original_text = source[start_pos:end_pos]
|
||||
|
||||
# If variable assignment detected, adjust start_pos to include the entire line
|
||||
actual_start = start_pos
|
||||
actual_leading_ws = leading_ws
|
||||
if var_type:
|
||||
# Find the start of the line (beginning of variable declaration)
|
||||
line_start = source.rfind("\n", 0, start_pos)
|
||||
if line_start == -1:
|
||||
line_start = 0
|
||||
before = source[:start_pos]
|
||||
last_nl_idx = before.rfind("\n")
|
||||
if last_nl_idx >= 0:
|
||||
line_prefix = source[last_nl_idx + 1 : start_pos]
|
||||
else:
|
||||
line_prefix = source[:start_pos]
|
||||
|
||||
var_match = re.match(r"([ \t]*)(?:final\s+)?([\w.<>\[\]]+)\s+(\w+)\s*=\s*$", line_prefix)
|
||||
if var_match:
|
||||
if last_nl_idx >= 0:
|
||||
start_pos = last_nl_idx
|
||||
leading_ws = "\n" + var_match.group(1)
|
||||
else:
|
||||
line_start += 1
|
||||
actual_start = line_start
|
||||
# Extract the actual leading whitespace from the start of the line
|
||||
line_content = source[line_start:start_pos]
|
||||
actual_leading_ws = line_content[:len(line_content) - len(line_content.lstrip())]
|
||||
start_pos = 0
|
||||
leading_ws = var_match.group(1)
|
||||
|
||||
original_text = source[actual_start:end_pos]
|
||||
assigned_var_type = var_match.group(2)
|
||||
assigned_var_name = var_match.group(3)
|
||||
original_text = source[start_pos:end_pos] # Update with adjusted range
|
||||
|
||||
# Determine statement type based on detected framework
|
||||
detected = self._detected_framework or "junit5"
|
||||
|
|
@ -368,17 +380,17 @@ class JavaAssertTransformer:
|
|||
|
||||
assertions.append(
|
||||
AssertionMatch(
|
||||
start_pos=actual_start,
|
||||
start_pos=start_pos,
|
||||
end_pos=end_pos,
|
||||
statement_type=stmt_type,
|
||||
assertion_method=assertion_method,
|
||||
target_calls=target_calls,
|
||||
leading_whitespace=actual_leading_ws,
|
||||
leading_whitespace=leading_ws,
|
||||
original_text=original_text,
|
||||
is_exception_assertion=is_exception,
|
||||
lambda_body=lambda_body,
|
||||
variable_type=var_type,
|
||||
variable_name=var_name,
|
||||
assigned_var_type=assigned_var_type,
|
||||
assigned_var_name=assigned_var_name,
|
||||
exception_class=exception_class,
|
||||
)
|
||||
)
|
||||
|
|
@ -709,9 +721,9 @@ class JavaAssertTransformer:
|
|||
return brace_content.strip()
|
||||
else:
|
||||
# Expression lambda: () -> expr
|
||||
# Find the end (before the closing paren of assertThrows)
|
||||
# Find the end (before the closing paren of assertThrows, or comma at depth 0)
|
||||
depth = 0
|
||||
end = body_start
|
||||
end = len(content)
|
||||
for i, ch in enumerate(content[body_start:]):
|
||||
if ch == "(":
|
||||
depth += 1
|
||||
|
|
@ -720,6 +732,9 @@ class JavaAssertTransformer:
|
|||
end = body_start + i
|
||||
break
|
||||
depth -= 1
|
||||
elif ch == "," and depth == 0:
|
||||
end = body_start + i
|
||||
break
|
||||
return content[body_start:end].strip()
|
||||
|
||||
return None
|
||||
|
|
@ -851,14 +866,17 @@ class JavaAssertTransformer:
|
|||
To:
|
||||
try { calculator.divide(1, 0); } catch (Exception _cf_ignored1) {}
|
||||
|
||||
For variable assignments:
|
||||
When assigned to a variable:
|
||||
IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> code());
|
||||
To:
|
||||
IllegalArgumentException ex = null;
|
||||
try { code(); } catch (IllegalArgumentException e) { ex = e; } catch (Exception _cf_ignored1) {}
|
||||
try { code(); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; } catch (Exception _cf_ignored1) {}
|
||||
|
||||
"""
|
||||
self.invocation_counter += 1
|
||||
counter = self.invocation_counter
|
||||
ws = assertion.leading_whitespace
|
||||
base_indent = ws.lstrip("\n\r")
|
||||
|
||||
# Extract code to run from lambda body or target calls
|
||||
code_to_run = None
|
||||
|
|
@ -867,38 +885,39 @@ class JavaAssertTransformer:
|
|||
# Use a direct last-character check instead of .endswith for lower overhead
|
||||
if code_to_run and code_to_run[-1] != ";":
|
||||
code_to_run += ";"
|
||||
elif assertion.target_calls:
|
||||
call = assertion.target_calls[0]
|
||||
code_to_run = call.full_call + ";"
|
||||
|
||||
if not code_to_run:
|
||||
# Fallback: comment out the assertion
|
||||
return f"{assertion.leading_whitespace}// Removed assertThrows: could not extract callable"
|
||||
|
||||
# Check if assertion is assigned to a variable
|
||||
if assertion.variable_name and assertion.variable_type:
|
||||
# Generate proper exception capture with variable assignment
|
||||
exception_type = assertion.exception_class or assertion.variable_type
|
||||
var_name = assertion.variable_name
|
||||
|
||||
# Use a unique catch variable name to avoid conflicts
|
||||
catch_var = f"_cf_caught{self.invocation_counter}"
|
||||
|
||||
# Get base indentation from leading whitespace (without newlines)
|
||||
base_indent = assertion.leading_whitespace.lstrip("\n\r")
|
||||
# Handle variable assignment: Type var = assertThrows(...)
|
||||
if assertion.assigned_var_name and assertion.assigned_var_type:
|
||||
var_type = assertion.assigned_var_type
|
||||
var_name = assertion.assigned_var_name
|
||||
if assertion.assertion_method == "assertDoesNotThrow":
|
||||
if ";" not in assertion.lambda_body.strip():
|
||||
return f"{ws}{var_type} {var_name} = {assertion.lambda_body.strip()};"
|
||||
return f"{ws}{code_to_run}"
|
||||
# For assertThrows with variable assignment, use exception_class if available
|
||||
exception_type = assertion.exception_class or var_type
|
||||
return (
|
||||
f"{ws}{var_type} {var_name} = null;\n"
|
||||
f"{base_indent}try {{ {code_to_run} }} "
|
||||
f"catch ({exception_type} _cf_caught{counter}) {{ {var_name} = _cf_caught{counter}; }} "
|
||||
f"catch (Exception _cf_ignored{counter}) {{}}"
|
||||
)
|
||||
|
||||
return (
|
||||
f"{assertion.leading_whitespace}{assertion.variable_type} {var_name} = null;\n"
|
||||
f"{base_indent}try {{ {code_to_run} }} "
|
||||
f"catch ({exception_type} {catch_var}) {{ {var_name} = {catch_var}; }} "
|
||||
f"catch (Exception _cf_ignored{self.invocation_counter}) {{}}"
|
||||
f"{ws}try {{ {code_to_run} }} "
|
||||
f"catch (Exception _cf_ignored{counter}) {{}}"
|
||||
)
|
||||
|
||||
# No variable assignment, use simple try-catch
|
||||
return (
|
||||
f"{assertion.leading_whitespace}try {{ {code_to_run} }} "
|
||||
f"catch (Exception _cf_ignored{self.invocation_counter}) {{}}"
|
||||
)
|
||||
# If no lambda body found, try to extract from target calls
|
||||
if assertion.target_calls:
|
||||
call = assertion.target_calls[0]
|
||||
return (
|
||||
f"{ws}try {{ {call.full_call}; }} "
|
||||
f"catch (Exception _cf_ignored{counter}) {{}}"
|
||||
)
|
||||
|
||||
# Fallback: comment out the assertion
|
||||
return f"{ws}// Removed assertThrows: could not extract callable"
|
||||
|
||||
|
||||
def transform_java_assertions(source: str, function_name: str, qualified_name: str | None = None) -> str:
|
||||
|
|
|
|||
|
|
@ -322,7 +322,7 @@ class JavaSupport(LanguageSupport):
|
|||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("Failed to instrument %s for line profiling: %s", func_info.name, e)
|
||||
logger.error("Failed to instrument %s for line profiling: %s", func_info.function_name, e)
|
||||
return False
|
||||
|
||||
def parse_line_profile_results(self, line_profiler_output_file: Path) -> dict:
|
||||
|
|
|
|||
|
|
@ -1257,6 +1257,41 @@ void testSynchronizedMethodWithAssertJ() {
|
|||
assert result == expected
|
||||
|
||||
|
||||
class TestFullyQualifiedAssertions:
|
||||
"""Tests for fully qualified assertion calls like org.junit.jupiter.api.Assertions.assertXxx."""
|
||||
|
||||
def test_assert_timeout_fully_qualified_with_variable_assignment(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testLargeInput() {
|
||||
Long result = org.junit.jupiter.api.Assertions.assertTimeout(
|
||||
Duration.ofSeconds(1),
|
||||
() -> Fibonacci.fibonacci(100_000)
|
||||
);
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testLargeInput() {
|
||||
Object _cf_result1 = Fibonacci.fibonacci(100_000);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "fibonacci")
|
||||
assert result == expected
|
||||
|
||||
def test_assert_equals_fully_qualified(self):
|
||||
source = """\
|
||||
@Test
|
||||
void testAdd() {
|
||||
org.junit.jupiter.api.Assertions.assertEquals(5, calc.add(2, 3));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testAdd() {
|
||||
Object _cf_result1 = calc.add(2, 3);
|
||||
}"""
|
||||
result = transform_java_assertions(source, "add")
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestAssertThrowsVariableAssignment:
|
||||
"""Tests for assertThrows with variable assignment (Issue: exception handling instrumentation bug)."""
|
||||
|
||||
|
|
@ -1358,3 +1393,35 @@ void testComplexException() {
|
|||
}"""
|
||||
result = transform_java_assertions(source, "execute")
|
||||
assert result == expected
|
||||
|
||||
def test_assert_throws_assigned_with_final_modifier(self):
|
||||
"""Test assertThrows with final modifier on variable."""
|
||||
source = """\
|
||||
@Test
|
||||
void testDivideByZero() {
|
||||
final IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> calc.divide(1, 0));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testDivideByZero() {
|
||||
IllegalArgumentException ex = null;
|
||||
try { calc.divide(1, 0); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; } catch (Exception _cf_ignored1) {}
|
||||
}"""
|
||||
result = transform_java_assertions(source, "divide")
|
||||
assert result == expected
|
||||
|
||||
def test_assert_throws_assigned_with_qualified_assertions(self):
|
||||
"""Test assertThrows with qualified assertion (Assertions.assertThrows)."""
|
||||
source = """\
|
||||
@Test
|
||||
void testDivideByZero() {
|
||||
IllegalArgumentException ex = Assertions.assertThrows(IllegalArgumentException.class, () -> calc.divide(1, 0));
|
||||
}"""
|
||||
expected = """\
|
||||
@Test
|
||||
void testDivideByZero() {
|
||||
IllegalArgumentException ex = null;
|
||||
try { calc.divide(1, 0); } catch (IllegalArgumentException _cf_caught1) { ex = _cf_caught1; } catch (Exception _cf_ignored1) {}
|
||||
}"""
|
||||
result = transform_java_assertions(source, "divide")
|
||||
assert result == expected
|
||||
|
|
|
|||
Loading…
Reference in a new issue