Merge pull request #1980 from codeflash-ai/cf-java-void-optimization

feat: support void method optimization in Java pipeline
This commit is contained in:
Hesham Mohamed 2026-04-09 16:24:29 +02:00 committed by GitHub
commit 602cb68239
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 966 additions and 34 deletions

View file

@ -0,0 +1,105 @@
name: E2E - Java Void Optimization (No Git)
on:
pull_request:
paths:
- 'codeflash/languages/java/**'
- 'codeflash/languages/base.py'
- 'codeflash/languages/registry.py'
- 'codeflash/optimization/**'
- 'codeflash/verification/**'
- 'code_to_optimize/java/**'
- 'codeflash-java-runtime/**'
- 'tests/scripts/end_to_end_test_java_void_optimization.py'
- '.github/workflows/e2e-java-void-optimization.yaml'
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.ref_name }}
cancel-in-progress: true
jobs:
java-void-optimization-no-git:
environment: ${{ (github.event_name == 'workflow_dispatch' || (contains(toJSON(github.event.pull_request.files.*.filename), '.github/workflows/') && github.event.pull_request.user.login != 'misrasaurabh1' && github.event.pull_request.user.login != 'KRRT7')) && 'external-trusted-contributors' || '' }}
runs-on: ubuntu-latest
env:
CODEFLASH_AIS_SERVER: prod
POSTHOG_API_KEY: ${{ secrets.POSTHOG_API_KEY }}
CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }}
COLUMNS: 110
MAX_RETRIES: 3
RETRY_DELAY: 5
EXPECTED_IMPROVEMENT_PCT: 70
CODEFLASH_END_TO_END: 1
steps:
- name: Checkout
uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.ref }}
repository: ${{ github.event.pull_request.head.repo.full_name }}
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Validate PR
env:
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
PR_STATE: ${{ github.event.pull_request.state }}
BASE_SHA: ${{ github.event.pull_request.base.sha }}
HEAD_SHA: ${{ github.event.pull_request.head.sha }}
run: |
if git diff --name-only "$BASE_SHA" "$HEAD_SHA" | grep -q "^.github/workflows/"; then
echo "⚠️ Workflow changes detected."
echo "PR Author: $PR_AUTHOR"
if [[ "$PR_AUTHOR" == "misrasaurabh1" || "$PR_AUTHOR" == "KRRT7" ]]; then
echo "✅ Authorized user ($PR_AUTHOR). Proceeding."
elif [[ "$PR_STATE" == "open" ]]; then
echo "✅ PR is open. Proceeding."
else
echo "⛔ Unauthorized user ($PR_AUTHOR) attempting to modify workflows. Exiting."
exit 1
fi
else
echo "✅ No workflow file changes detected. Proceeding."
fi
- name: Set up JDK 11
uses: actions/setup-java@v4
with:
java-version: '11'
distribution: 'temurin'
cache: maven
- name: Set up Python 3.11 for CLI
uses: astral-sh/setup-uv@v6
with:
python-version: 3.11.6
- name: Install dependencies (CLI)
run: uv sync
- name: Build codeflash-runtime JAR
run: |
cd codeflash-java-runtime
mvn clean package -q -DskipTests
mvn install -q -DskipTests
- name: Verify Java installation
run: |
java -version
mvn --version
- name: Remove .git
run: |
if [ -d ".git" ]; then
sudo rm -rf .git
echo ".git directory removed."
else
echo ".git directory does not exist."
exit 1
fi
- name: Run Codeflash to optimize void function
run: |
uv run python tests/scripts/end_to_end_test_java_void_optimization.py

View file

@ -0,0 +1,21 @@
package com.example;
public class InPlaceSorter {
public static void bubbleSortInPlace(int[] arr) {
if (arr == null || arr.length <= 1) {
return;
}
int n = arr.length;
for (int i = 0; i < n; i++) {
for (int j = 0; j < n - 1; j++) {
if (arr[j] > arr[j + 1]) {
int temp = arr[j];
arr[j] = arr[j + 1];
arr[j + 1] = temp;
}
}
}
}
}

View file

@ -0,0 +1,21 @@
package com.example;
public class InstanceSorter {
public void bubbleSortInPlace(int[] arr) {
if (arr == null || arr.length <= 1) {
return;
}
int n = arr.length;
for (int i = 0; i < n; i++) {
for (int j = 0; j < n - 1; j++) {
if (arr[j] > arr[j + 1]) {
int temp = arr[j];
arr[j] = arr[j + 1];
arr[j + 1] = temp;
}
}
}
}
}

View file

@ -0,0 +1,62 @@
package com.example;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class InPlaceSorterTest {
@Test
void testBubbleSortInPlace() {
int[] arr = {5, 3, 1, 4, 2};
InPlaceSorter.bubbleSortInPlace(arr);
assertArrayEquals(new int[]{1, 2, 3, 4, 5}, arr);
}
@Test
void testBubbleSortInPlaceAlreadySorted() {
int[] arr = {1, 2, 3, 4, 5};
InPlaceSorter.bubbleSortInPlace(arr);
assertArrayEquals(new int[]{1, 2, 3, 4, 5}, arr);
}
@Test
void testBubbleSortInPlaceReversed() {
int[] arr = {5, 4, 3, 2, 1};
InPlaceSorter.bubbleSortInPlace(arr);
assertArrayEquals(new int[]{1, 2, 3, 4, 5}, arr);
}
@Test
void testBubbleSortInPlaceWithDuplicates() {
int[] arr = {3, 2, 4, 1, 3, 2};
InPlaceSorter.bubbleSortInPlace(arr);
assertArrayEquals(new int[]{1, 2, 2, 3, 3, 4}, arr);
}
@Test
void testBubbleSortInPlaceWithNegatives() {
int[] arr = {3, -2, 7, 0, -5};
InPlaceSorter.bubbleSortInPlace(arr);
assertArrayEquals(new int[]{-5, -2, 0, 3, 7}, arr);
}
@Test
void testBubbleSortInPlaceSingleElement() {
int[] arr = {42};
InPlaceSorter.bubbleSortInPlace(arr);
assertArrayEquals(new int[]{42}, arr);
}
@Test
void testBubbleSortInPlaceEmpty() {
int[] arr = {};
InPlaceSorter.bubbleSortInPlace(arr);
assertArrayEquals(new int[]{}, arr);
}
@Test
void testBubbleSortInPlaceNull() {
InPlaceSorter.bubbleSortInPlace(null);
}
}

View file

@ -0,0 +1,69 @@
package com.example;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
class InstanceSorterTest {
@Test
void testBubbleSortInPlace() {
InstanceSorter sorter = new InstanceSorter();
int[] arr = {5, 3, 1, 4, 2};
sorter.bubbleSortInPlace(arr);
assertArrayEquals(new int[]{1, 2, 3, 4, 5}, arr);
}
@Test
void testBubbleSortInPlaceAlreadySorted() {
InstanceSorter sorter = new InstanceSorter();
int[] arr = {1, 2, 3, 4, 5};
sorter.bubbleSortInPlace(arr);
assertArrayEquals(new int[]{1, 2, 3, 4, 5}, arr);
}
@Test
void testBubbleSortInPlaceReversed() {
InstanceSorter sorter = new InstanceSorter();
int[] arr = {5, 4, 3, 2, 1};
sorter.bubbleSortInPlace(arr);
assertArrayEquals(new int[]{1, 2, 3, 4, 5}, arr);
}
@Test
void testBubbleSortInPlaceWithDuplicates() {
InstanceSorter sorter = new InstanceSorter();
int[] arr = {3, 2, 4, 1, 3, 2};
sorter.bubbleSortInPlace(arr);
assertArrayEquals(new int[]{1, 2, 2, 3, 3, 4}, arr);
}
@Test
void testBubbleSortInPlaceWithNegatives() {
InstanceSorter sorter = new InstanceSorter();
int[] arr = {3, -2, 7, 0, -5};
sorter.bubbleSortInPlace(arr);
assertArrayEquals(new int[]{-5, -2, 0, 3, 7}, arr);
}
@Test
void testBubbleSortInPlaceSingleElement() {
InstanceSorter sorter = new InstanceSorter();
int[] arr = {42};
sorter.bubbleSortInPlace(arr);
assertArrayEquals(new int[]{42}, arr);
}
@Test
void testBubbleSortInPlaceEmpty() {
InstanceSorter sorter = new InstanceSorter();
int[] arr = {};
sorter.bubbleSortInPlace(arr);
assertArrayEquals(new int[]{}, arr);
}
@Test
void testBubbleSortInPlaceNull() {
InstanceSorter sorter = new InstanceSorter();
sorter.bubbleSortInPlace(null);
}
}

View file

@ -209,6 +209,87 @@ class ComparatorCorrectnessTest {
assertFalse(Comparator.isDeserializationError(42)); assertFalse(Comparator.isDeserializationError(42));
} }
// ============================================================
// VOID METHOD STATE COMPARISON proves we actually compare
// post-call state for void methods, not just skip them
// ============================================================
@Test
@DisplayName("void state: both sides sorted identically → equivalent")
void testVoidState_identicalMutation_equivalent() throws Exception {
createTestDb(originalDb);
createTestDb(candidateDb);
// Simulate: bubbleSortInPlace(arr) both original and candidate sort correctly
// Post-call state: Object[]{sorted_array}
int[] sortedArr = {1, 2, 3, 4, 5};
byte[] origState = Serializer.serialize(new Object[]{sortedArr});
byte[] candState = Serializer.serialize(new Object[]{new int[]{1, 2, 3, 4, 5}});
insertRow(originalDb, "L1_1", 1, origState);
insertRow(candidateDb, "L1_1", 1, candState);
String json = Comparator.compareDatabases(originalDb.toString(), candidateDb.toString());
Map<String, Object> result = parseJson(json);
assertTrue((Boolean) result.get("equivalent"),
"Both sides produce same sorted array — should be equivalent");
assertEquals(1, ((Number) result.get("actualComparisons")).intValue());
}
@Test
@DisplayName("void state: candidate mutates array differently → NOT equivalent")
void testVoidState_differentMutation_rejected() throws Exception {
createTestDb(originalDb);
createTestDb(candidateDb);
// Simulate: original sorts [3,1,2] [1,2,3]
// Bad optimization doesn't sort correctly [3,1,2] unchanged
byte[] origState = Serializer.serialize(new Object[]{new int[]{1, 2, 3}});
byte[] candState = Serializer.serialize(new Object[]{new int[]{3, 1, 2}});
insertRow(originalDb, "L1_1", 1, origState);
insertRow(candidateDb, "L1_1", 1, candState);
String json = Comparator.compareDatabases(originalDb.toString(), candidateDb.toString());
Map<String, Object> result = parseJson(json);
assertFalse((Boolean) result.get("equivalent"),
"Candidate produced wrong array — must be rejected");
assertEquals(1, ((Number) result.get("actualComparisons")).intValue());
}
@Test
@DisplayName("void state: receiver + args both compared — wrong receiver state rejected")
void testVoidState_receiverAndArgs_wrongReceiverRejected() throws Exception {
createTestDb(originalDb);
createTestDb(candidateDb);
// Simulate: instance method sorter.sort(data)
// Post-call state is Object[]{receiver_fields_map, mutated_data}
// Original: receiver has size=3, data is [1,2,3]
// Candidate: receiver has size=0 (wrong), data is [1,2,3]
Map<String, Object> origReceiver = new HashMap<>();
origReceiver.put("size", 3);
origReceiver.put("sorted", true);
Map<String, Object> candReceiver = new HashMap<>();
candReceiver.put("size", 0);
candReceiver.put("sorted", true);
byte[] origState = Serializer.serialize(new Object[]{origReceiver, new int[]{1, 2, 3}});
byte[] candState = Serializer.serialize(new Object[]{candReceiver, new int[]{1, 2, 3}});
insertRow(originalDb, "L1_1", 1, origState);
insertRow(candidateDb, "L1_1", 1, candState);
String json = Comparator.compareDatabases(originalDb.toString(), candidateDb.toString());
Map<String, Object> result = parseJson(json);
assertFalse((Boolean) result.get("equivalent"),
"Receiver state differs (size 3 vs 0) — must be rejected even though args match");
assertEquals(1, ((Number) result.get("actualComparisons")).intValue());
}
// --- Helpers --- // --- Helpers ---
private void createTestDb(Path dbPath) throws Exception { private void createTestDb(Path dbPath) throws Exception {

View file

@ -195,7 +195,8 @@ def _find_all_functions_via_language_support(file_path: Path) -> dict[Path, list
try: try:
lang_support = get_language_support(file_path) lang_support = get_language_support(file_path)
criteria = FunctionFilterCriteria(require_return=True) require_return = lang_support.language != Language.JAVA
criteria = FunctionFilterCriteria(require_return=require_return)
functions[file_path] = lang_support.discover_functions(file_path, criteria) functions[file_path] = lang_support.discover_functions(file_path, criteria)
except Exception as e: except Exception as e:
logger.debug(f"Failed to discover functions in {file_path}: {e}") logger.debug(f"Failed to discover functions in {file_path}: {e}")
@ -454,7 +455,8 @@ def find_all_functions_in_file(file_path: Path) -> dict[Path, list[FunctionToOpt
from codeflash.languages.base import FunctionFilterCriteria from codeflash.languages.base import FunctionFilterCriteria
lang_support = get_language_support(file_path) lang_support = get_language_support(file_path)
criteria = FunctionFilterCriteria(require_return=True) require_return = lang_support.language != Language.JAVA
criteria = FunctionFilterCriteria(require_return=require_return)
source = file_path.read_text(encoding="utf-8") source = file_path.read_text(encoding="utf-8")
return {file_path: lang_support.discover_functions(source, file_path, criteria)} return {file_path: lang_support.discover_functions(source, file_path, criteria)}
except Exception as e: except Exception as e:

View file

@ -203,6 +203,7 @@ def _generate_sqlite_write_code(
func_name: str, func_name: str,
test_method_name: str, test_method_name: str,
invocation_id: str = "", invocation_id: str = "",
verification_type: str = "function_call",
) -> list[str]: ) -> list[str]:
"""Generate SQLite write code for a single function call. """Generate SQLite write code for a single function call.
@ -249,7 +250,7 @@ def _generate_sqlite_write_code(
f'{inner_indent} _cf_pstmt{id_pair}.setString(6, "{inv_id_str}");', f'{inner_indent} _cf_pstmt{id_pair}.setString(6, "{inv_id_str}");',
f"{inner_indent} _cf_pstmt{id_pair}.setLong(7, _cf_dur{id_pair});", f"{inner_indent} _cf_pstmt{id_pair}.setLong(7, _cf_dur{id_pair});",
f"{inner_indent} _cf_pstmt{id_pair}.setBytes(8, _cf_serializedResult{id_pair});", f"{inner_indent} _cf_pstmt{id_pair}.setBytes(8, _cf_serializedResult{id_pair});",
f'{inner_indent} _cf_pstmt{id_pair}.setString(9, "function_call");', f'{inner_indent} _cf_pstmt{id_pair}.setString(9, "{verification_type}");',
f"{inner_indent} _cf_pstmt{id_pair}.executeUpdate();", f"{inner_indent} _cf_pstmt{id_pair}.executeUpdate();",
f"{inner_indent} }}", f"{inner_indent} }}",
f"{inner_indent} }}", f"{inner_indent} }}",
@ -337,22 +338,53 @@ def wrap_target_calls_with_treesitter(
orig_line = body_lines[line_idx] orig_line = body_lines[line_idx]
line_indent_str = " " * (len(orig_line) - len(orig_line.lstrip())) line_indent_str = " " * (len(orig_line) - len(orig_line.lstrip()))
is_void = target_return_type == "void"
var_name = f"_cf_result{iter_id}_{call_counter}" var_name = f"_cf_result{iter_id}_{call_counter}"
receiver = call.get("receiver", "this")
arg_texts: list[str] = call.get("arg_texts", [])
cast_type = _infer_array_cast_type(orig_line) cast_type = _infer_array_cast_type(orig_line)
if not cast_type and target_return_type and target_return_type != "void": if not cast_type and target_return_type and not is_void:
cast_type = target_return_type cast_type = target_return_type
var_with_cast = f"({cast_type}){var_name}" if cast_type else var_name var_with_cast = f"({cast_type}){var_name}" if cast_type else var_name
capture_stmt_with_decl = f"var {var_name} = {call['full_call']};" if is_void:
capture_stmt_assign = f"{var_name} = {call['full_call']};" bare_call_stmt = f"{call['full_call']};"
if precise_call_timing: # For void methods, serialize the post-call state to capture side effects.
serialize_stmt = f"_cf_serializedResult{iter_id}_{call_counter} = com.codeflash.Serializer.serialize((Object) {var_name});" # We always serialize the arguments (which are mutated in place).
start_stmt = f"_cf_start{iter_id}_{call_counter} = System.nanoTime();" # For instance methods, we also include the receiver to capture object state changes.
end_stmt = f"_cf_end{iter_id}_{call_counter} = System.nanoTime();" # For static methods, the receiver is a class name (not a value), so args only.
is_static_call = receiver != "this" and receiver[:1].isupper()
parts: list[str] = []
if not is_static_call:
parts.append(receiver)
parts.extend(arg_texts)
if parts:
serialize_target = f"new Object[]{{{', '.join(parts)}}}"
else:
serialize_target = "new Object[]{}"
if precise_call_timing:
serialize_stmt = f"_cf_serializedResult{iter_id}_{call_counter} = com.codeflash.Serializer.serialize({serialize_target});"
start_stmt = f"_cf_start{iter_id}_{call_counter} = System.nanoTime();"
end_stmt = f"_cf_end{iter_id}_{call_counter} = System.nanoTime();"
else:
serialize_stmt = (
f"_cf_serializedResult{iter_id} = com.codeflash.Serializer.serialize({serialize_target});"
)
start_stmt = f"_cf_start{iter_id} = System.nanoTime();"
end_stmt = f"_cf_end{iter_id} = System.nanoTime();"
else: else:
serialize_stmt = f"_cf_serializedResult{iter_id} = com.codeflash.Serializer.serialize((Object) {var_name});" capture_stmt_with_decl = f"var {var_name} = {call['full_call']};"
start_stmt = f"_cf_start{iter_id} = System.nanoTime();" capture_stmt_assign = f"{var_name} = {call['full_call']};"
end_stmt = f"_cf_end{iter_id} = System.nanoTime();" if precise_call_timing:
serialize_stmt = f"_cf_serializedResult{iter_id}_{call_counter} = com.codeflash.Serializer.serialize((Object) {var_name});"
start_stmt = f"_cf_start{iter_id}_{call_counter} = System.nanoTime();"
end_stmt = f"_cf_end{iter_id}_{call_counter} = System.nanoTime();"
else:
serialize_stmt = (
f"_cf_serializedResult{iter_id} = com.codeflash.Serializer.serialize((Object) {var_name});"
)
start_stmt = f"_cf_start{iter_id} = System.nanoTime();"
end_stmt = f"_cf_end{iter_id} = System.nanoTime();"
if call["parent_type"] == "expression_statement": if call["parent_type"] == "expression_statement":
es_start = call["_es_start_char"] es_start = call["_es_start_char"]
@ -360,31 +392,61 @@ def wrap_target_calls_with_treesitter(
if precise_call_timing: if precise_call_timing:
# No indent on first line — body_text[:es_start] already has leading whitespace. # No indent on first line — body_text[:es_start] already has leading whitespace.
# Subsequent lines get line_indent_str. # Subsequent lines get line_indent_str.
var_decls = [ if is_void:
f"Object {var_name} = null;", var_decls = [
f"long _cf_end{iter_id}_{call_counter} = -1;", f"long _cf_end{iter_id}_{call_counter} = -1;",
f"long _cf_start{iter_id}_{call_counter} = 0;", f"long _cf_start{iter_id}_{call_counter} = 0;",
f"byte[] _cf_serializedResult{iter_id}_{call_counter} = null;", f"byte[] _cf_serializedResult{iter_id}_{call_counter} = null;",
] ]
else:
var_decls = [
f"Object {var_name} = null;",
f"long _cf_end{iter_id}_{call_counter} = -1;",
f"long _cf_start{iter_id}_{call_counter} = 0;",
f"byte[] _cf_serializedResult{iter_id}_{call_counter} = null;",
]
start_marker = f'System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + "." + _cf_test{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":{inv_id}" + "######$!");' start_marker = f'System.out.println("!$######" + _cf_mod{iter_id} + ":" + _cf_cls{iter_id} + "." + _cf_test{iter_id} + ":" + _cf_fn{iter_id} + ":" + _cf_loop{iter_id} + ":{inv_id}" + "######$!");'
try_block = [ if is_void:
"try {", try_block = [
f" {start_stmt}", "try {",
f" {capture_stmt_assign}", f" {start_stmt}",
f" {end_stmt}", f" {bare_call_stmt}",
f" {serialize_stmt}", f" {end_stmt}",
] f" {serialize_stmt}",
]
else:
try_block = [
"try {",
f" {start_stmt}",
f" {capture_stmt_assign}",
f" {end_stmt}",
f" {serialize_stmt}",
]
finally_block = _generate_sqlite_write_code( finally_block = _generate_sqlite_write_code(
iter_id, call_counter, "", class_name, func_name, test_method_name, invocation_id=inv_id iter_id,
call_counter,
"",
class_name,
func_name,
test_method_name,
invocation_id=inv_id,
verification_type="void_state" if is_void else "function_call",
) )
all_lines = [*var_decls, start_marker, *try_block, *finally_block] all_lines = [*var_decls, start_marker, *try_block, *finally_block]
replacement = ( replacement = (
all_lines[0] + "\n" + "\n".join(f"{line_indent_str}{repl_line}" for repl_line in all_lines[1:]) all_lines[0] + "\n" + "\n".join(f"{line_indent_str}{repl_line}" for repl_line in all_lines[1:])
) )
elif is_void:
replacement = f"{bare_call_stmt} {serialize_stmt}"
else: else:
replacement = f"{capture_stmt_with_decl} {serialize_stmt}" replacement = f"{capture_stmt_with_decl} {serialize_stmt}"
body_text = body_text[:es_start] + replacement + body_text[es_end:] body_text = body_text[:es_start] + replacement + body_text[es_end:]
else: else:
if is_void:
# Void calls cannot be embedded in expressions in valid Java — skip instrumentation
logger.warning("Skipping instrumentation of embedded void call: %s", call["full_call"])
continue
# Embedded call: replace call with variable, then insert capture lines before the line # Embedded call: replace call with variable, then insert capture lines before the line
call_start = call["_call_start_char"] call_start = call["_call_start_char"]
call_end = call["_call_end_char"] call_end = call["_call_end_char"]
@ -451,6 +513,15 @@ def _collect_calls(
if parent_type == "expression_statement": if parent_type == "expression_statement":
es_start = parent.start_byte - prefix_len es_start = parent.start_byte - prefix_len
es_end = parent.end_byte - prefix_len es_end = parent.end_byte - prefix_len
object_node = node.child_by_field_name("object")
receiver = analyzer.get_node_text(object_node, wrapper_bytes) if object_node else "this"
# Extract argument texts for void method serialization
args_node = node.child_by_field_name("arguments")
arg_texts: list[str] = []
if args_node:
for child in args_node.children:
if child.type not in ("(", ")", ","):
arg_texts.append(analyzer.get_node_text(child, wrapper_bytes))
out.append( out.append(
{ {
"start_byte": start, "start_byte": start,
@ -461,6 +532,8 @@ def _collect_calls(
"in_complex": _is_inside_complex_expression(node), "in_complex": _is_inside_complex_expression(node),
"es_start_byte": es_start, "es_start_byte": es_start,
"es_end_byte": es_end, "es_end_byte": es_end,
"receiver": receiver,
"arg_texts": arg_texts,
} }
) )
for child in node.children: for child in node.children:

View file

@ -189,6 +189,7 @@ class JavaAssertTransformer:
qualified_name: str | None = None, qualified_name: str | None = None,
analyzer: JavaAnalyzer | None = None, analyzer: JavaAnalyzer | None = None,
mode: str = "capture", mode: str = "capture",
target_return_type: str = "",
) -> None: ) -> None:
self.analyzer = analyzer or get_java_analyzer() self.analyzer = analyzer or get_java_analyzer()
self.func_name = function_name self.func_name = function_name
@ -196,6 +197,7 @@ class JavaAssertTransformer:
self.invocation_counter = 0 self.invocation_counter = 0
self._detected_framework: str | None = None self._detected_framework: str | None = None
self.mode = mode # "capture" (default, instrumentation) or "strip" (clean display) self.mode = mode # "capture" (default, instrumentation) or "strip" (clean display)
self.target_return_type = target_return_type
# Precompile the assignment-detection regex to avoid recompiling on each call. # Precompile the assignment-detection regex to avoid recompiling on each call.
self._assign_re = re.compile(r"(\w+(?:<[^>]+>)?)\s+(\w+)\s*=\s*$") self._assign_re = re.compile(r"(\w+(?:<[^>]+>)?)\s+(\w+)\s*=\s*$")
@ -1062,7 +1064,7 @@ class JavaAssertTransformer:
if not assertion.target_calls: if not assertion.target_calls:
return "" return ""
if self.mode == "strip": if self.mode == "strip" or self.target_return_type == "void":
return self._generate_strip_replacement(assertion) return self._generate_strip_replacement(assertion)
# Infer the return type from assertion context to avoid Object→primitive cast errors # Infer the return type from assertion context to avoid Object→primitive cast errors
@ -1244,7 +1246,9 @@ class JavaAssertTransformer:
return "".join(cur).rstrip() return "".join(cur).rstrip()
def transform_java_assertions(source: str, function_name: str, qualified_name: str | None = None) -> str: def transform_java_assertions(
source: str, function_name: str, qualified_name: str | None = None, target_return_type: str = ""
) -> str:
"""Transform Java test code by removing assertions and capturing function calls. """Transform Java test code by removing assertions and capturing function calls.
This is the main entry point for Java assertion transformation. This is the main entry point for Java assertion transformation.
@ -1253,12 +1257,15 @@ def transform_java_assertions(source: str, function_name: str, qualified_name: s
source: The Java test source code. source: The Java test source code.
function_name: Name of the function being tested. function_name: Name of the function being tested.
qualified_name: Optional fully qualified name of the function. qualified_name: Optional fully qualified name of the function.
target_return_type: Return type of the target function (e.g., "void", "int").
Returns: Returns:
Transformed source code with assertions replaced by capture statements. Transformed source code with assertions replaced by capture statements.
""" """
transformer = JavaAssertTransformer(function_name=function_name, qualified_name=qualified_name) transformer = JavaAssertTransformer(
function_name=function_name, qualified_name=qualified_name, target_return_type=target_return_type
)
return transformer.transform(source) return transformer.transform(source)

View file

@ -740,6 +740,7 @@ class VerificationType(str, Enum):
) )
INIT_STATE_FTO = "init_state_fto" # Correctness verification for fto class instance attributes after init INIT_STATE_FTO = "init_state_fto" # Correctness verification for fto class instance attributes after init
INIT_STATE_HELPER = "init_state_helper" # Correctness verification for helper class instance attributes after init INIT_STATE_HELPER = "init_state_helper" # Correctness verification for helper class instance attributes after init
VOID_STATE = "void_state" # Correctness verification for void methods (no return value)
@dataclass(frozen=True, slots=True) @dataclass(frozen=True, slots=True)

View file

@ -0,0 +1,18 @@
import os
import pathlib
from end_to_end_test_utilities import TestConfig, run_codeflash_command, run_with_retries
def run_test(expected_improvement_pct: int) -> bool:
config = TestConfig(
file_path="src/main/java/com/example/InPlaceSorter.java",
function_name="bubbleSortInPlace",
min_improvement_x=0.70,
)
cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize" / "java").resolve()
return run_codeflash_command(cwd, config, expected_improvement_pct)
if __name__ == "__main__":
exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 70))))

View file

@ -11,7 +11,18 @@
<maven.compiler.source>11</maven.compiler.source> <maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target> <maven.compiler.target>11</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties> <checkstyle.skip>true</checkstyle.skip>
<disable.checks>true</disable.checks>
<spotbugs.skip>true</spotbugs.skip>
<pmd.skip>true</pmd.skip>
<rat.skip>true</rat.skip>
<enforcer.skip>true</enforcer.skip>
<japicmp.skip>true</japicmp.skip>
<checkstyle.failOnViolation>false</checkstyle.failOnViolation>
<checkstyle.failsOnError>false</checkstyle.failsOnError>
<maven-checkstyle-plugin.failsOnError>false</maven-checkstyle-plugin.failsOnError>
<maven-checkstyle-plugin.failOnViolation>false</maven-checkstyle-plugin.failOnViolation>
</properties>
<dependencies> <dependencies>
<dependency> <dependency>
@ -62,6 +73,26 @@
</execution> </execution>
</executions> </executions>
</plugin> </plugin>
</plugins> <!-- codeflash-validation-skip -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-checkstyle-plugin</artifactId>
<configuration>
<skip>true</skip>
<failOnViolation>false</failOnViolation>
<failsOnError>false</failsOnError>
</configuration>
</plugin>
<plugin>
<groupId>com.github.spotbugs</groupId>
<artifactId>spotbugs-maven-plugin</artifactId>
<configuration><skip>true</skip></configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-pmd-plugin</artifactId>
<configuration><skip>true</skip></configuration>
</plugin>
</plugins>
</build> </build>
</project> </project>

View file

@ -22,7 +22,6 @@ os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import Language from codeflash.languages.base import Language
from codeflash.languages.current import set_current_language from codeflash.languages.current import set_current_language
from codeflash.languages.java.maven_strategy import MavenStrategy
from codeflash.languages.java.discovery import discover_functions_from_source from codeflash.languages.java.discovery import discover_functions_from_source
from codeflash.languages.java.instrumentation import ( from codeflash.languages.java.instrumentation import (
_add_behavior_instrumentation, _add_behavior_instrumentation,
@ -34,6 +33,7 @@ from codeflash.languages.java.instrumentation import (
instrument_generated_java_test, instrument_generated_java_test,
remove_instrumentation, remove_instrumentation,
) )
from codeflash.languages.java.maven_strategy import MavenStrategy
class TestInstrumentForBehavior: class TestInstrumentForBehavior:
@ -2177,7 +2177,7 @@ public class AccentTest {
# Skip all E2E tests if Maven is not available # Skip all E2E tests if Maven is not available
requires_maven = pytest.mark.skipif( requires_maven = pytest.mark.skipif(
MavenStrategy().find_executable(Path(".")) is None, reason="Maven not found - skipping execution tests" MavenStrategy().find_executable(Path()) is None, reason="Maven not found - skipping execution tests"
) )
@ -3485,3 +3485,444 @@ public class SpinWaitTest__perfonlyinstrumented {
assert math.isclose(duration, 100_000_000, rel_tol=0.15), ( assert math.isclose(duration, 100_000_000, rel_tol=0.15), (
f"Long spin measured {duration}ns, expected ~100_000_000ns (15% tolerance)" f"Long spin measured {duration}ns, expected ~100_000_000ns (15% tolerance)"
) )
class TestVoidMethodInstrumentation:
"""Tests for void method instrumentation — behavior mode captures receiver state."""
def test_behavior_mode_void_method_serializes_receiver(self, tmp_path: Path):
"""Void method instrumentation should serialize the receiver, not a return value."""
source_file = (tmp_path / "Sorter.java").resolve()
source_file.write_text(
"public class Sorter {\n"
" public void sort(int[] data) {\n"
" java.util.Arrays.sort(data);\n"
" }\n"
"}\n",
encoding="utf-8",
)
test_file = (tmp_path / "SorterTest.java").resolve()
test_source = (
"import org.junit.jupiter.api.Test;\n"
"\n"
"public class SorterTest {\n"
" @Test\n"
" public void testSort() {\n"
" Sorter sorter = new Sorter();\n"
" int[] data = {3, 1, 2};\n"
" sorter.sort(data);\n"
" }\n"
"}\n"
)
test_file.write_text(test_source, encoding="utf-8")
func = FunctionToOptimize(
function_name="sort",
file_path=source_file,
starting_line=2,
ending_line=4,
parents=[],
is_method=True,
language="java",
)
success, result = instrument_existing_test(
test_string=test_source, function_to_optimize=func, mode="behavior", test_path=test_file
)
assert success is True
assert result == (
"import org.junit.jupiter.api.Test;\n"
"import java.sql.Connection;\n"
"import java.sql.DriverManager;\n"
"import java.sql.PreparedStatement;\n"
"\n"
'@SuppressWarnings("CheckReturnValue")\n'
"public class SorterTest__perfinstrumented {\n"
" @Test\n"
" public void testSort() {\n"
" // Codeflash behavior instrumentation\n"
' int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));\n'
" int _cf_iter1 = 1;\n"
' String _cf_mod1 = "SorterTest__perfinstrumented";\n'
' String _cf_cls1 = "SorterTest__perfinstrumented";\n'
' String _cf_fn1 = "sort";\n'
' String _cf_outputFile1 = System.getenv("CODEFLASH_OUTPUT_FILE");\n'
' String _cf_testIteration1 = System.getenv("CODEFLASH_TEST_ITERATION");\n'
' if (_cf_testIteration1 == null) _cf_testIteration1 = "0";\n'
' String _cf_test1 = "testSort";\n'
" Sorter sorter = new Sorter();\n"
" int[] data = {3, 1, 2};\n"
" long _cf_end1_1 = -1;\n"
" long _cf_start1_1 = 0;\n"
" byte[] _cf_serializedResult1_1 = null;\n"
' System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":L12_1" + "######$!");\n'
" try {\n"
" _cf_start1_1 = System.nanoTime();\n"
" sorter.sort(data);\n"
" _cf_end1_1 = System.nanoTime();\n"
" _cf_serializedResult1_1 = com.codeflash.Serializer.serialize(new Object[]{sorter, data});\n"
" } finally {\n"
" long _cf_end1_1_finally = System.nanoTime();\n"
" long _cf_dur1_1 = (_cf_end1_1 != -1 ? _cf_end1_1 : _cf_end1_1_finally) - _cf_start1_1;\n"
' System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + "L12_1" + "######!");\n'
" // Write to SQLite if output file is set\n"
" if (_cf_outputFile1 != null && !_cf_outputFile1.isEmpty()) {\n"
" try {\n"
' Class.forName("org.sqlite.JDBC");\n'
' try (Connection _cf_conn1_1 = DriverManager.getConnection("jdbc:sqlite:" + _cf_outputFile1)) {\n'
" try (java.sql.Statement _cf_stmt1_1 = _cf_conn1_1.createStatement()) {\n"
' _cf_stmt1_1.execute("CREATE TABLE IF NOT EXISTS test_results (" +\n'
' "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " +\n'
' "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " +\n'
' "runtime INTEGER, return_value BLOB, verification_type TEXT)");\n'
" }\n"
' String _cf_sql1_1 = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";\n'
" try (PreparedStatement _cf_pstmt1_1 = _cf_conn1_1.prepareStatement(_cf_sql1_1)) {\n"
" _cf_pstmt1_1.setString(1, _cf_mod1);\n"
" _cf_pstmt1_1.setString(2, _cf_cls1);\n"
" _cf_pstmt1_1.setString(3, _cf_test1);\n"
" _cf_pstmt1_1.setString(4, _cf_fn1);\n"
" _cf_pstmt1_1.setInt(5, _cf_loop1);\n"
' _cf_pstmt1_1.setString(6, "L12_1");\n'
" _cf_pstmt1_1.setLong(7, _cf_dur1_1);\n"
" _cf_pstmt1_1.setBytes(8, _cf_serializedResult1_1);\n"
' _cf_pstmt1_1.setString(9, "void_state");\n'
" _cf_pstmt1_1.executeUpdate();\n"
" }\n"
" }\n"
" } catch (Exception _cf_e1_1) {\n"
' System.err.println("CodeflashHelper: SQLite error: " + _cf_e1_1.getMessage());\n'
" }\n"
" }\n"
" }\n"
" }\n"
"}\n"
)
def test_behavior_mode_void_method_implicit_this_receiver(self, tmp_path: Path):
"""Void method with no explicit receiver uses 'this' for serialization."""
source_file = (tmp_path / "Container.java").resolve()
source_file.write_text(
"public class Container {\n"
" public void clear() {\n"
" // clears internal state\n"
" }\n"
"}\n",
encoding="utf-8",
)
test_file = (tmp_path / "ContainerTest.java").resolve()
test_source = (
"import org.junit.jupiter.api.Test;\n"
"\n"
"public class ContainerTest {\n"
" @Test\n"
" public void testClear() {\n"
" clear();\n"
" }\n"
"}\n"
)
test_file.write_text(test_source, encoding="utf-8")
func = FunctionToOptimize(
function_name="clear",
file_path=source_file,
starting_line=2,
ending_line=4,
parents=[],
is_method=True,
language="java",
)
success, result = instrument_existing_test(
test_string=test_source, function_to_optimize=func, mode="behavior", test_path=test_file
)
assert success is True
assert result == (
"import org.junit.jupiter.api.Test;\n"
"import java.sql.Connection;\n"
"import java.sql.DriverManager;\n"
"import java.sql.PreparedStatement;\n"
"\n"
'@SuppressWarnings("CheckReturnValue")\n'
"public class ContainerTest__perfinstrumented {\n"
" @Test\n"
" public void testClear() {\n"
" // Codeflash behavior instrumentation\n"
' int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));\n'
" int _cf_iter1 = 1;\n"
' String _cf_mod1 = "ContainerTest__perfinstrumented";\n'
' String _cf_cls1 = "ContainerTest__perfinstrumented";\n'
' String _cf_fn1 = "clear";\n'
' String _cf_outputFile1 = System.getenv("CODEFLASH_OUTPUT_FILE");\n'
' String _cf_testIteration1 = System.getenv("CODEFLASH_TEST_ITERATION");\n'
' if (_cf_testIteration1 == null) _cf_testIteration1 = "0";\n'
' String _cf_test1 = "testClear";\n'
" long _cf_end1_1 = -1;\n"
" long _cf_start1_1 = 0;\n"
" byte[] _cf_serializedResult1_1 = null;\n"
' System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":L10_1" + "######$!");\n'
" try {\n"
" _cf_start1_1 = System.nanoTime();\n"
" clear();\n"
" _cf_end1_1 = System.nanoTime();\n"
" _cf_serializedResult1_1 = com.codeflash.Serializer.serialize(new Object[]{this});\n"
" } finally {\n"
" long _cf_end1_1_finally = System.nanoTime();\n"
" long _cf_dur1_1 = (_cf_end1_1 != -1 ? _cf_end1_1 : _cf_end1_1_finally) - _cf_start1_1;\n"
' System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + "L10_1" + "######!");\n'
" // Write to SQLite if output file is set\n"
" if (_cf_outputFile1 != null && !_cf_outputFile1.isEmpty()) {\n"
" try {\n"
' Class.forName("org.sqlite.JDBC");\n'
' try (Connection _cf_conn1_1 = DriverManager.getConnection("jdbc:sqlite:" + _cf_outputFile1)) {\n'
" try (java.sql.Statement _cf_stmt1_1 = _cf_conn1_1.createStatement()) {\n"
' _cf_stmt1_1.execute("CREATE TABLE IF NOT EXISTS test_results (" +\n'
' "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " +\n'
' "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " +\n'
' "runtime INTEGER, return_value BLOB, verification_type TEXT)");\n'
" }\n"
' String _cf_sql1_1 = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";\n'
" try (PreparedStatement _cf_pstmt1_1 = _cf_conn1_1.prepareStatement(_cf_sql1_1)) {\n"
" _cf_pstmt1_1.setString(1, _cf_mod1);\n"
" _cf_pstmt1_1.setString(2, _cf_cls1);\n"
" _cf_pstmt1_1.setString(3, _cf_test1);\n"
" _cf_pstmt1_1.setString(4, _cf_fn1);\n"
" _cf_pstmt1_1.setInt(5, _cf_loop1);\n"
' _cf_pstmt1_1.setString(6, "L10_1");\n'
" _cf_pstmt1_1.setLong(7, _cf_dur1_1);\n"
" _cf_pstmt1_1.setBytes(8, _cf_serializedResult1_1);\n"
' _cf_pstmt1_1.setString(9, "void_state");\n'
" _cf_pstmt1_1.executeUpdate();\n"
" }\n"
" }\n"
" } catch (Exception _cf_e1_1) {\n"
' System.err.println("CodeflashHelper: SQLite error: " + _cf_e1_1.getMessage());\n'
" }\n"
" }\n"
" }\n"
" }\n"
"}\n"
)
def test_behavior_mode_non_void_still_captures_result(self, tmp_path: Path):
"""Non-void methods should still capture the return value (not the receiver)."""
source_file = (tmp_path / "Calculator.java").resolve()
source_file.write_text(
"public class Calculator {\n"
" public int add(int a, int b) {\n"
" return a + b;\n"
" }\n"
"}\n",
encoding="utf-8",
)
test_file = (tmp_path / "CalculatorTest.java").resolve()
test_source = (
"import org.junit.jupiter.api.Test;\n"
"\n"
"public class CalculatorTest {\n"
" @Test\n"
" public void testAdd() {\n"
" Calculator calc = new Calculator();\n"
" assertEquals(4, calc.add(2, 2));\n"
" }\n"
"}\n"
)
test_file.write_text(test_source, encoding="utf-8")
func = FunctionToOptimize(
function_name="add",
file_path=source_file,
starting_line=2,
ending_line=4,
parents=[],
is_method=True,
language="java",
)
success, result = instrument_existing_test(
test_string=test_source, function_to_optimize=func, mode="behavior", test_path=test_file
)
assert success is True
assert result == (
"import org.junit.jupiter.api.Test;\n"
"import java.sql.Connection;\n"
"import java.sql.DriverManager;\n"
"import java.sql.PreparedStatement;\n"
"\n"
'@SuppressWarnings("CheckReturnValue")\n'
"public class CalculatorTest__perfinstrumented {\n"
" @Test\n"
" public void testAdd() {\n"
" // Codeflash behavior instrumentation\n"
' int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));\n'
" int _cf_iter1 = 1;\n"
' String _cf_mod1 = "CalculatorTest__perfinstrumented";\n'
' String _cf_cls1 = "CalculatorTest__perfinstrumented";\n'
' String _cf_fn1 = "add";\n'
' String _cf_outputFile1 = System.getenv("CODEFLASH_OUTPUT_FILE");\n'
' String _cf_testIteration1 = System.getenv("CODEFLASH_TEST_ITERATION");\n'
' if (_cf_testIteration1 == null) _cf_testIteration1 = "0";\n'
' String _cf_test1 = "testAdd";\n'
" Calculator calc = new Calculator();\n"
" Object _cf_result1_1 = null;\n"
" long _cf_end1_1 = -1;\n"
" long _cf_start1_1 = 0;\n"
" byte[] _cf_serializedResult1_1 = null;\n"
' System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":L11_1" + "######$!");\n'
" try {\n"
" _cf_start1_1 = System.nanoTime();\n"
" _cf_result1_1 = calc.add(2, 2);\n"
" _cf_end1_1 = System.nanoTime();\n"
" _cf_serializedResult1_1 = com.codeflash.Serializer.serialize((Object) _cf_result1_1);\n"
" } finally {\n"
" long _cf_end1_1_finally = System.nanoTime();\n"
" long _cf_dur1_1 = (_cf_end1_1 != -1 ? _cf_end1_1 : _cf_end1_1_finally) - _cf_start1_1;\n"
' System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + "L11_1" + "######!");\n'
" // Write to SQLite if output file is set\n"
" if (_cf_outputFile1 != null && !_cf_outputFile1.isEmpty()) {\n"
" try {\n"
' Class.forName("org.sqlite.JDBC");\n'
' try (Connection _cf_conn1_1 = DriverManager.getConnection("jdbc:sqlite:" + _cf_outputFile1)) {\n'
" try (java.sql.Statement _cf_stmt1_1 = _cf_conn1_1.createStatement()) {\n"
' _cf_stmt1_1.execute("CREATE TABLE IF NOT EXISTS test_results (" +\n'
' "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " +\n'
' "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " +\n'
' "runtime INTEGER, return_value BLOB, verification_type TEXT)");\n'
" }\n"
' String _cf_sql1_1 = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";\n'
" try (PreparedStatement _cf_pstmt1_1 = _cf_conn1_1.prepareStatement(_cf_sql1_1)) {\n"
" _cf_pstmt1_1.setString(1, _cf_mod1);\n"
" _cf_pstmt1_1.setString(2, _cf_cls1);\n"
" _cf_pstmt1_1.setString(3, _cf_test1);\n"
" _cf_pstmt1_1.setString(4, _cf_fn1);\n"
" _cf_pstmt1_1.setInt(5, _cf_loop1);\n"
' _cf_pstmt1_1.setString(6, "L11_1");\n'
" _cf_pstmt1_1.setLong(7, _cf_dur1_1);\n"
" _cf_pstmt1_1.setBytes(8, _cf_serializedResult1_1);\n"
' _cf_pstmt1_1.setString(9, "function_call");\n'
" _cf_pstmt1_1.executeUpdate();\n"
" }\n"
" }\n"
" } catch (Exception _cf_e1_1) {\n"
' System.err.println("CodeflashHelper: SQLite error: " + _cf_e1_1.getMessage());\n'
" }\n"
" }\n"
" }\n"
" assertEquals(4, (int)_cf_result1_1);\n"
" }\n"
"}\n"
)
def test_void_discovery_with_require_return_false(self):
"""Void methods should be discovered when require_return=False."""
from codeflash.languages.base import FunctionFilterCriteria
from codeflash.languages.java.discovery import discover_functions_from_source
source = (
"public class Example {\n"
" public void doSomething() {\n"
' System.out.println("hello");\n'
" }\n"
"\n"
" public int getValue() {\n"
" return 42;\n"
" }\n"
"}\n"
)
criteria_no_return = FunctionFilterCriteria(require_return=False)
functions = discover_functions_from_source(source, filter_criteria=criteria_no_return)
method_names = {f.function_name for f in functions}
assert "doSomething" in method_names
assert "getValue" in method_names
criteria_require_return = FunctionFilterCriteria(require_return=True)
functions = discover_functions_from_source(source, filter_criteria=criteria_require_return)
method_names = {f.function_name for f in functions}
assert "doSomething" not in method_names
assert "getValue" in method_names
def test_performance_mode_void_method_generates_valid_code(self, tmp_path: Path):
"""Void methods in performance mode should generate valid timing code."""
source_file = (tmp_path / "Sorter.java").resolve()
source_file.write_text(
"public class Sorter {\n"
" public void sort(int[] data) {\n"
" java.util.Arrays.sort(data);\n"
" }\n"
"}\n",
encoding="utf-8",
)
test_file = (tmp_path / "SorterTest.java").resolve()
test_source = (
"import org.junit.jupiter.api.Test;\n"
"\n"
"public class SorterTest {\n"
" @Test\n"
" public void testSort() {\n"
" Sorter sorter = new Sorter();\n"
" int[] data = {3, 1, 2};\n"
" sorter.sort(data);\n"
" }\n"
"}\n"
)
test_file.write_text(test_source, encoding="utf-8")
func = FunctionToOptimize(
function_name="sort",
file_path=source_file,
starting_line=2,
ending_line=4,
parents=[],
is_method=True,
language="java",
)
success, result = instrument_existing_test(
test_string=test_source, function_to_optimize=func, mode="performance", test_path=test_file
)
assert success is True
assert result == (
"import org.junit.jupiter.api.Test;\n"
"\n"
'@SuppressWarnings("CheckReturnValue")\n'
"public class SorterTest__perfonlyinstrumented {\n"
" @Test\n"
" public void testSort() {\n"
" // Codeflash timing instrumentation with inner loop for JIT warmup\n"
' int _cf_outerLoop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));\n'
' int _cf_maxInnerIterations1 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "10"));\n'
' int _cf_innerIterations1 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "10"));\n'
' String _cf_mod1 = "SorterTest__perfonlyinstrumented";\n'
' String _cf_cls1 = "SorterTest__perfonlyinstrumented";\n'
' String _cf_test1 = "testSort";\n'
' String _cf_fn1 = "sort";\n'
" \n"
" Sorter sorter = new Sorter();\n"
" int[] data = {3, 1, 2};\n"
" for (int _cf_i1 = 0; _cf_i1 < _cf_innerIterations1; _cf_i1++) {\n"
" int _cf_loopId1 = _cf_outerLoop1 * _cf_maxInnerIterations1 + _cf_i1;\n"
' System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loopId1 + ":" + "L9_1" + "######$!");\n'
" long _cf_end1 = -1;\n"
" long _cf_start1 = 0;\n"
" try {\n"
" _cf_start1 = System.nanoTime();\n"
" sorter.sort(data);\n"
" _cf_end1 = System.nanoTime();\n"
" } finally {\n"
" long _cf_end1_finally = System.nanoTime();\n"
" long _cf_dur1 = (_cf_end1 != -1 ? _cf_end1 : _cf_end1_finally) - _cf_start1;\n"
' System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + "." + _cf_test1 + ":" + _cf_fn1 + ":" + _cf_loopId1 + ":" + "L9_1" + ":" + _cf_dur1 + "######!");\n'
" }\n"
" }\n"
" }\n"
"}\n"
)