feat: implement bubble sort optimization and corresponding tests in Java
This commit is contained in:
parent
41814cd24b
commit
f42b58bb98
7 changed files with 316 additions and 2 deletions
105
.github/workflows/e2e-java-void-optimization.yaml
vendored
Normal file
105
.github/workflows/e2e-java-void-optimization.yaml
vendored
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
name: E2E - Java Void Optimization (No Git)
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'codeflash/languages/java/**'
|
||||
- 'codeflash/languages/base.py'
|
||||
- 'codeflash/languages/registry.py'
|
||||
- 'codeflash/optimization/**'
|
||||
- 'codeflash/verification/**'
|
||||
- 'code_to_optimize/java/**'
|
||||
- 'codeflash-java-runtime/**'
|
||||
- 'tests/scripts/end_to_end_test_java_void_optimization.py'
|
||||
- '.github/workflows/e2e-java-void-optimization.yaml'
|
||||
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref_name }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
java-void-optimization-no-git:
|
||||
environment: ${{ (github.event_name == 'workflow_dispatch' || (contains(toJSON(github.event.pull_request.files.*.filename), '.github/workflows/') && github.event.pull_request.user.login != 'misrasaurabh1' && github.event.pull_request.user.login != 'KRRT7')) && 'external-trusted-contributors' || '' }}
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
CODEFLASH_AIS_SERVER: prod
|
||||
POSTHOG_API_KEY: ${{ secrets.POSTHOG_API_KEY }}
|
||||
CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }}
|
||||
COLUMNS: 110
|
||||
MAX_RETRIES: 3
|
||||
RETRY_DELAY: 5
|
||||
EXPECTED_IMPROVEMENT_PCT: 70
|
||||
CODEFLASH_END_TO_END: 1
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.ref }}
|
||||
repository: ${{ github.event.pull_request.head.repo.full_name }}
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Validate PR
|
||||
env:
|
||||
PR_AUTHOR: ${{ github.event.pull_request.user.login }}
|
||||
PR_STATE: ${{ github.event.pull_request.state }}
|
||||
BASE_SHA: ${{ github.event.pull_request.base.sha }}
|
||||
HEAD_SHA: ${{ github.event.pull_request.head.sha }}
|
||||
run: |
|
||||
if git diff --name-only "$BASE_SHA" "$HEAD_SHA" | grep -q "^.github/workflows/"; then
|
||||
echo "⚠️ Workflow changes detected."
|
||||
echo "PR Author: $PR_AUTHOR"
|
||||
if [[ "$PR_AUTHOR" == "misrasaurabh1" || "$PR_AUTHOR" == "KRRT7" ]]; then
|
||||
echo "✅ Authorized user ($PR_AUTHOR). Proceeding."
|
||||
elif [[ "$PR_STATE" == "open" ]]; then
|
||||
echo "✅ PR is open. Proceeding."
|
||||
else
|
||||
echo "⛔ Unauthorized user ($PR_AUTHOR) attempting to modify workflows. Exiting."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo "✅ No workflow file changes detected. Proceeding."
|
||||
fi
|
||||
|
||||
- name: Set up JDK 11
|
||||
uses: actions/setup-java@v4
|
||||
with:
|
||||
java-version: '11'
|
||||
distribution: 'temurin'
|
||||
cache: maven
|
||||
|
||||
- name: Set up Python 3.11 for CLI
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
python-version: 3.11.6
|
||||
|
||||
- name: Install dependencies (CLI)
|
||||
run: uv sync
|
||||
|
||||
- name: Build codeflash-runtime JAR
|
||||
run: |
|
||||
cd codeflash-java-runtime
|
||||
mvn clean package -q -DskipTests
|
||||
mvn install -q -DskipTests
|
||||
|
||||
- name: Verify Java installation
|
||||
run: |
|
||||
java -version
|
||||
mvn --version
|
||||
|
||||
- name: Remove .git
|
||||
run: |
|
||||
if [ -d ".git" ]; then
|
||||
sudo rm -rf .git
|
||||
echo ".git directory removed."
|
||||
else
|
||||
echo ".git directory does not exist."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: Run Codeflash to optimize void function
|
||||
run: |
|
||||
uv run python tests/scripts/end_to_end_test_java_void_optimization.py
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
package com.example;
|
||||
|
||||
public class InPlaceSorter {
|
||||
|
||||
public static void bubbleSortInPlace(int[] arr) {
|
||||
if (arr == null || arr.length <= 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
int n = arr.length;
|
||||
for (int i = 0; i < n; i++) {
|
||||
for (int j = 0; j < n - 1; j++) {
|
||||
if (arr[j] > arr[j + 1]) {
|
||||
int temp = arr[j];
|
||||
arr[j] = arr[j + 1];
|
||||
arr[j + 1] = temp;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
package com.example;
|
||||
|
||||
public class InstanceSorter {
|
||||
|
||||
public void bubbleSortInPlace(int[] arr) {
|
||||
if (arr == null || arr.length <= 1) {
|
||||
return;
|
||||
}
|
||||
|
||||
int n = arr.length;
|
||||
for (int i = 0; i < n; i++) {
|
||||
for (int j = 0; j < n - 1; j++) {
|
||||
if (arr[j] > arr[j + 1]) {
|
||||
int temp = arr[j];
|
||||
arr[j] = arr[j + 1];
|
||||
arr[j + 1] = temp;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
package com.example;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
class InPlaceSorterTest {
|
||||
|
||||
@Test
|
||||
void testBubbleSortInPlace() {
|
||||
int[] arr = {5, 3, 1, 4, 2};
|
||||
InPlaceSorter.bubbleSortInPlace(arr);
|
||||
assertArrayEquals(new int[]{1, 2, 3, 4, 5}, arr);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBubbleSortInPlaceAlreadySorted() {
|
||||
int[] arr = {1, 2, 3, 4, 5};
|
||||
InPlaceSorter.bubbleSortInPlace(arr);
|
||||
assertArrayEquals(new int[]{1, 2, 3, 4, 5}, arr);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBubbleSortInPlaceReversed() {
|
||||
int[] arr = {5, 4, 3, 2, 1};
|
||||
InPlaceSorter.bubbleSortInPlace(arr);
|
||||
assertArrayEquals(new int[]{1, 2, 3, 4, 5}, arr);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBubbleSortInPlaceWithDuplicates() {
|
||||
int[] arr = {3, 2, 4, 1, 3, 2};
|
||||
InPlaceSorter.bubbleSortInPlace(arr);
|
||||
assertArrayEquals(new int[]{1, 2, 2, 3, 3, 4}, arr);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBubbleSortInPlaceWithNegatives() {
|
||||
int[] arr = {3, -2, 7, 0, -5};
|
||||
InPlaceSorter.bubbleSortInPlace(arr);
|
||||
assertArrayEquals(new int[]{-5, -2, 0, 3, 7}, arr);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBubbleSortInPlaceSingleElement() {
|
||||
int[] arr = {42};
|
||||
InPlaceSorter.bubbleSortInPlace(arr);
|
||||
assertArrayEquals(new int[]{42}, arr);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBubbleSortInPlaceEmpty() {
|
||||
int[] arr = {};
|
||||
InPlaceSorter.bubbleSortInPlace(arr);
|
||||
assertArrayEquals(new int[]{}, arr);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBubbleSortInPlaceNull() {
|
||||
InPlaceSorter.bubbleSortInPlace(null);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,69 @@
|
|||
package com.example;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
class InstanceSorterTest {
|
||||
|
||||
@Test
|
||||
void testBubbleSortInPlace() {
|
||||
InstanceSorter sorter = new InstanceSorter();
|
||||
int[] arr = {5, 3, 1, 4, 2};
|
||||
sorter.bubbleSortInPlace(arr);
|
||||
assertArrayEquals(new int[]{1, 2, 3, 4, 5}, arr);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBubbleSortInPlaceAlreadySorted() {
|
||||
InstanceSorter sorter = new InstanceSorter();
|
||||
int[] arr = {1, 2, 3, 4, 5};
|
||||
sorter.bubbleSortInPlace(arr);
|
||||
assertArrayEquals(new int[]{1, 2, 3, 4, 5}, arr);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBubbleSortInPlaceReversed() {
|
||||
InstanceSorter sorter = new InstanceSorter();
|
||||
int[] arr = {5, 4, 3, 2, 1};
|
||||
sorter.bubbleSortInPlace(arr);
|
||||
assertArrayEquals(new int[]{1, 2, 3, 4, 5}, arr);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBubbleSortInPlaceWithDuplicates() {
|
||||
InstanceSorter sorter = new InstanceSorter();
|
||||
int[] arr = {3, 2, 4, 1, 3, 2};
|
||||
sorter.bubbleSortInPlace(arr);
|
||||
assertArrayEquals(new int[]{1, 2, 2, 3, 3, 4}, arr);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBubbleSortInPlaceWithNegatives() {
|
||||
InstanceSorter sorter = new InstanceSorter();
|
||||
int[] arr = {3, -2, 7, 0, -5};
|
||||
sorter.bubbleSortInPlace(arr);
|
||||
assertArrayEquals(new int[]{-5, -2, 0, 3, 7}, arr);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBubbleSortInPlaceSingleElement() {
|
||||
InstanceSorter sorter = new InstanceSorter();
|
||||
int[] arr = {42};
|
||||
sorter.bubbleSortInPlace(arr);
|
||||
assertArrayEquals(new int[]{42}, arr);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBubbleSortInPlaceEmpty() {
|
||||
InstanceSorter sorter = new InstanceSorter();
|
||||
int[] arr = {};
|
||||
sorter.bubbleSortInPlace(arr);
|
||||
assertArrayEquals(new int[]{}, arr);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testBubbleSortInPlaceNull() {
|
||||
InstanceSorter sorter = new InstanceSorter();
|
||||
sorter.bubbleSortInPlace(null);
|
||||
}
|
||||
}
|
||||
|
|
@ -340,6 +340,7 @@ def wrap_target_calls_with_treesitter(
|
|||
is_void = target_return_type == "void"
|
||||
var_name = f"_cf_result{iter_id}_{call_counter}"
|
||||
receiver = call.get("receiver", "this")
|
||||
arg_texts: list[str] = call.get("arg_texts", [])
|
||||
cast_type = _infer_array_cast_type(orig_line)
|
||||
if not cast_type and target_return_type and not is_void:
|
||||
cast_type = target_return_type
|
||||
|
|
@ -347,13 +348,22 @@ def wrap_target_calls_with_treesitter(
|
|||
|
||||
if is_void:
|
||||
bare_call_stmt = f"{call['full_call']};"
|
||||
# For void methods, serialize the post-call state to capture side effects.
|
||||
# For instance methods (receiver is a variable), serialize the receiver.
|
||||
# For static methods (receiver is a class name), serialize the arguments
|
||||
# since the class name itself is not a value and can't be cast to Object.
|
||||
is_static_call = receiver != "this" and receiver[:1].isupper()
|
||||
if is_static_call and arg_texts:
|
||||
serialize_target = f"new Object[]{{{', '.join(arg_texts)}}}"
|
||||
else:
|
||||
serialize_target = f"(Object) {receiver}"
|
||||
if precise_call_timing:
|
||||
serialize_stmt = f"_cf_serializedResult{iter_id}_{call_counter} = com.codeflash.Serializer.serialize((Object) {receiver});"
|
||||
serialize_stmt = f"_cf_serializedResult{iter_id}_{call_counter} = com.codeflash.Serializer.serialize({serialize_target});"
|
||||
start_stmt = f"_cf_start{iter_id}_{call_counter} = System.nanoTime();"
|
||||
end_stmt = f"_cf_end{iter_id}_{call_counter} = System.nanoTime();"
|
||||
else:
|
||||
serialize_stmt = (
|
||||
f"_cf_serializedResult{iter_id} = com.codeflash.Serializer.serialize((Object) {receiver});"
|
||||
f"_cf_serializedResult{iter_id} = com.codeflash.Serializer.serialize({serialize_target});"
|
||||
)
|
||||
start_stmt = f"_cf_start{iter_id} = System.nanoTime();"
|
||||
end_stmt = f"_cf_end{iter_id} = System.nanoTime();"
|
||||
|
|
@ -493,6 +503,13 @@ def _collect_calls(
|
|||
es_end = parent.end_byte - prefix_len
|
||||
object_node = node.child_by_field_name("object")
|
||||
receiver = analyzer.get_node_text(object_node, wrapper_bytes) if object_node else "this"
|
||||
# Extract argument texts for void method serialization
|
||||
args_node = node.child_by_field_name("arguments")
|
||||
arg_texts: list[str] = []
|
||||
if args_node:
|
||||
for child in args_node.children:
|
||||
if child.type not in ("(", ")", ","):
|
||||
arg_texts.append(analyzer.get_node_text(child, wrapper_bytes))
|
||||
out.append(
|
||||
{
|
||||
"start_byte": start,
|
||||
|
|
@ -504,6 +521,7 @@ def _collect_calls(
|
|||
"es_start_byte": es_start,
|
||||
"es_end_byte": es_end,
|
||||
"receiver": receiver,
|
||||
"arg_texts": arg_texts,
|
||||
}
|
||||
)
|
||||
for child in node.children:
|
||||
|
|
|
|||
18
tests/scripts/end_to_end_test_java_void_optimization.py
Normal file
18
tests/scripts/end_to_end_test_java_void_optimization.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
import os
|
||||
import pathlib
|
||||
|
||||
from end_to_end_test_utilities import TestConfig, run_codeflash_command, run_with_retries
|
||||
|
||||
|
||||
def run_test(expected_improvement_pct: int) -> bool:
|
||||
config = TestConfig(
|
||||
file_path="src/main/java/com/example/InPlaceSorter.java",
|
||||
function_name="bubbleSortInPlace",
|
||||
min_improvement_x=0.70,
|
||||
)
|
||||
cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize" / "java").resolve()
|
||||
return run_codeflash_command(cwd, config, expected_improvement_pct)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 70))))
|
||||
Loading…
Reference in a new issue