codeflash/tests/test_languages/test_java/test_instrumentation.py
2026-02-16 08:32:55 +02:00

2435 lines
84 KiB
Python

"""Tests for Java code instrumentation.
Tests the instrumentation functions with exact string equality assertions
to ensure the generated code matches expected output exactly.
Also includes end-to-end execution tests that:
1. Instrument Java code
2. Execute with Maven
3. Parse JUnit XML and timing markers from stdout
4. Verify the parsed results are correct
"""
import os
import re
import shutil
import subprocess
from pathlib import Path
import pytest
# Set API key for tests that instantiate Optimizer
os.environ["CODEFLASH_API_KEY"] = "cf-test-key"
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.languages.base import Language
from codeflash.languages.current import set_current_language
from codeflash.models.function_types import FunctionParent
from codeflash.languages.java.build_tools import find_maven_executable
from codeflash.languages.java.discovery import discover_functions_from_source
from codeflash.languages.java.instrumentation import (
_add_behavior_instrumentation,
_add_timing_instrumentation,
create_benchmark_test,
instrument_existing_test,
instrument_for_behavior,
instrument_for_benchmarking,
instrument_generated_java_test,
remove_instrumentation,
)
class TestInstrumentForBehavior:
    """Tests for instrument_for_behavior."""

    # Java class fed to instrument_for_behavior in both scenarios below.
    _CALCULATOR_JAVA = """public class Calculator {
public int add(int a, int b) {
return a + b;
}
}
"""

    def test_returns_source_unchanged(self):
        """Source comes back unchanged when the discovered functions are passed in.

        (Java uses JUnit pass/fail.)
        """
        java_source = self._CALCULATOR_JAVA
        discovered = discover_functions_from_source(java_source)
        assert instrument_for_behavior(java_source, discovered) == java_source

    def test_no_functions_unchanged(self):
        """Source comes back unchanged when an empty function list is passed in."""
        java_source = self._CALCULATOR_JAVA
        assert instrument_for_behavior(java_source, []) == java_source
class TestInstrumentForBenchmarking:
    """Tests for instrument_for_benchmarking."""

    def test_returns_source_unchanged(self):
        """The test source is handed back as-is (Java uses Maven Surefire timing)."""
        target = FunctionToOptimize(
            function_name="add",
            file_path=Path("Calculator.java"),
            starting_line=1,
            ending_line=5,
            parents=[],
            is_method=True,
            language="java",
        )
        java_test = """import org.junit.jupiter.api.Test;
public class CalculatorTest {
@Test
public void testAdd() {
Calculator calc = new Calculator();
assertEquals(4, calc.add(2, 2));
}
}
"""
        assert instrument_for_benchmarking(java_test, target) == java_test
class TestInstrumentExistingTest:
    """Tests for instrument_existing_test.

    Performance-mode tests compare the whole instrumented file against an
    expected string; behavior-mode tests only check that key fragments are
    present in the output.
    """

    @staticmethod
    def _java_method(name: str, file_path: Path, ending_line: int = 5):
        """Build the FunctionToOptimize describing the Java method under test.

        Args:
            name: Name of the target Java method.
            file_path: Path recorded as the method's source file.
            ending_line: Last line of the method; the first line is always 1.

        Returns:
            A FunctionToOptimize marked as a Java method with no parents.
        """
        return FunctionToOptimize(
            function_name=name,
            file_path=file_path,
            starting_line=1,
            ending_line=ending_line,
            parents=[],
            is_method=True,
            language="java",
        )

    def test_instrument_behavior_mode_simple(self, tmp_path: Path):
        """Test instrumenting a simple test in behavior mode."""
        test_file = tmp_path / "CalculatorTest.java"
        source = """import org.junit.jupiter.api.Test;
public class CalculatorTest {
@Test
public void testAdd() {
Calculator calc = new Calculator();
assertEquals(4, calc.add(2, 2));
}
}
"""
        test_file.write_text(source)
        func = self._java_method("add", tmp_path / "Calculator.java")
        success, result = instrument_existing_test(
            test_string=source,
            function_to_optimize=func,
            mode="behavior",
            test_path=test_file,
        )
        assert success is True
        # Behavior mode now adds SQLite instrumentation
        # Verify key elements are present
        assert "import java.sql.Connection;" in result
        assert "import java.sql.DriverManager;" in result
        assert "import java.sql.PreparedStatement;" in result
        # Note: java.sql.Statement is used fully qualified to avoid conflicts with other Statement classes
        assert "java.sql.Statement" in result
        assert "class CalculatorTest__perfinstrumented" in result
        assert "CODEFLASH_OUTPUT_FILE" in result
        assert "CREATE TABLE IF NOT EXISTS test_results" in result
        assert "INSERT INTO test_results VALUES" in result
        assert "_cf_loop1" in result
        assert "_cf_iter1" in result
        assert "System.nanoTime()" in result
        assert "com.codeflash.Serializer.serialize((Object)" in result

    def test_instrument_behavior_mode_assert_throws_expression_lambda(self, tmp_path: Path):
        """Test that assertThrows expression lambdas are not broken by behavior instrumentation.

        When a target function call is inside an expression lambda (e.g., () -> Fibonacci.fibonacci(-1)),
        the instrumentation must NOT wrap it in a variable assignment, as that would turn
        the void-compatible lambda into a value-returning lambda and break compilation.
        """
        test_file = tmp_path / "FibonacciTest.java"
        source = """import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
public class FibonacciTest {
@Test
void testNegativeInput_ThrowsIllegalArgumentException() {
assertThrows(IllegalArgumentException.class, () -> Fibonacci.fibonacci(-1));
}
@Test
void testZeroInput_ReturnsZero() {
assertEquals(0L, Fibonacci.fibonacci(0));
}
}
"""
        test_file.write_text(source)
        func = self._java_method("fibonacci", tmp_path / "Fibonacci.java", ending_line=10)
        success, result = instrument_existing_test(
            test_string=source,
            function_to_optimize=func,
            mode="behavior",
            test_path=test_file,
        )
        assert success is True
        # The assertThrows lambda line should remain unchanged (not wrapped in variable assignment)
        assert "() -> Fibonacci.fibonacci(-1)" in result
        # The non-lambda call should still be wrapped
        assert "_cf_result" in result

    def test_instrument_behavior_mode_assert_throws_block_lambda(self, tmp_path: Path):
        """Test that assertThrows block lambdas are not broken by behavior instrumentation.

        When a target function call is inside a block lambda (e.g., () -> { func(); }),
        the instrumentation must NOT wrap it in a variable assignment.
        """
        test_file = tmp_path / "FibonacciTest.java"
        source = """import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
public class FibonacciTest {
@Test
void testNegativeInput_ThrowsIllegalArgumentException() {
assertThrows(IllegalArgumentException.class, () -> {
Fibonacci.fibonacci(-1);
});
}
@Test
void testZeroInput_ReturnsZero() {
assertEquals(0L, Fibonacci.fibonacci(0));
}
}
"""
        test_file.write_text(source)
        func = self._java_method("fibonacci", tmp_path / "Fibonacci.java", ending_line=10)
        success, result = instrument_existing_test(
            test_string=source,
            function_to_optimize=func,
            mode="behavior",
            test_path=test_file,
        )
        assert success is True
        # The call inside the block lambda must survive verbatim.
        assert "Fibonacci.fibonacci(-1);" in result
        assert "() -> {" in result
        # The plain fibonacci(0) call has to be captured into a _cf_result variable.
        wrapped_calls = [
            line
            for line in result.split("\n")
            if "var _cf_result" in line and "Fibonacci.fibonacci(0)" in line
        ]
        assert wrapped_calls, "Non-lambda call to fibonacci(0) should be wrapped"

    def test_instrument_performance_mode_simple(self, tmp_path: Path):
        """Test instrumenting a simple test in performance mode with inner loop."""
        test_file = tmp_path / "CalculatorTest.java"
        source = """import org.junit.jupiter.api.Test;
public class CalculatorTest {
@Test
public void testAdd() {
Calculator calc = new Calculator();
assertEquals(4, calc.add(2, 2));
}
}
"""
        test_file.write_text(source)
        func = self._java_method("add", tmp_path / "Calculator.java")
        success, result = instrument_existing_test(
            test_string=source,
            function_to_optimize=func,
            mode="performance",
            test_path=test_file,
        )
        expected = """import org.junit.jupiter.api.Test;
public class CalculatorTest__perfonlyinstrumented {
@Test
public void testAdd() {
// Codeflash timing instrumentation with inner loop for JIT warmup
int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));
int _cf_innerIterations1 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100"));
String _cf_mod1 = "CalculatorTest";
String _cf_cls1 = "CalculatorTest";
String _cf_fn1 = "add";
for (int _cf_i1 = 0; _cf_i1 < _cf_innerIterations1; _cf_i1++) {
System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + "######$!");
long _cf_start1 = System.nanoTime();
try {
Calculator calc = new Calculator();
assertEquals(4, calc.add(2, 2));
} finally {
long _cf_end1 = System.nanoTime();
long _cf_dur1 = _cf_end1 - _cf_start1;
System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + ":" + _cf_dur1 + "######!");
}
}
}
}
"""
        assert success is True
        assert result == expected

    def test_instrument_performance_mode_multiple_tests(self, tmp_path: Path):
        """Test instrumenting multiple test methods in performance mode with inner loop."""
        test_file = tmp_path / "MathTest.java"
        source = """import org.junit.jupiter.api.Test;
public class MathTest {
@Test
public void testAdd() {
assertEquals(4, add(2, 2));
}
@Test
public void testSubtract() {
assertEquals(0, subtract(2, 2));
}
}
"""
        test_file.write_text(source)
        func = self._java_method("calculate", tmp_path / "Math.java")
        success, result = instrument_existing_test(
            test_string=source,
            function_to_optimize=func,
            mode="performance",
            test_path=test_file,
        )
        # Each @Test method gets its own numeric suffix on the _cf_* variables.
        expected = """import org.junit.jupiter.api.Test;
public class MathTest__perfonlyinstrumented {
@Test
public void testAdd() {
// Codeflash timing instrumentation with inner loop for JIT warmup
int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));
int _cf_innerIterations1 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100"));
String _cf_mod1 = "MathTest";
String _cf_cls1 = "MathTest";
String _cf_fn1 = "calculate";
for (int _cf_i1 = 0; _cf_i1 < _cf_innerIterations1; _cf_i1++) {
System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + "######$!");
long _cf_start1 = System.nanoTime();
try {
assertEquals(4, add(2, 2));
} finally {
long _cf_end1 = System.nanoTime();
long _cf_dur1 = _cf_end1 - _cf_start1;
System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + ":" + _cf_dur1 + "######!");
}
}
}
@Test
public void testSubtract() {
// Codeflash timing instrumentation with inner loop for JIT warmup
int _cf_loop2 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));
int _cf_innerIterations2 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100"));
String _cf_mod2 = "MathTest";
String _cf_cls2 = "MathTest";
String _cf_fn2 = "calculate";
for (int _cf_i2 = 0; _cf_i2 < _cf_innerIterations2; _cf_i2++) {
System.out.println("!$######" + _cf_mod2 + ":" + _cf_cls2 + ":" + _cf_fn2 + ":" + _cf_loop2 + ":" + _cf_i2 + "######$!");
long _cf_start2 = System.nanoTime();
try {
assertEquals(0, subtract(2, 2));
} finally {
long _cf_end2 = System.nanoTime();
long _cf_dur2 = _cf_end2 - _cf_start2;
System.out.println("!######" + _cf_mod2 + ":" + _cf_cls2 + ":" + _cf_fn2 + ":" + _cf_loop2 + ":" + _cf_i2 + ":" + _cf_dur2 + "######!");
}
}
}
}
"""
        assert success is True
        assert result == expected

    def test_instrument_preserves_annotations(self, tmp_path: Path):
        """Test that annotations other than @Test are preserved with inner loop."""
        test_file = tmp_path / "ServiceTest.java"
        source = """import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Disabled;
public class ServiceTest {
@Test
@DisplayName("Test service call")
public void testService() {
service.call();
}
@Disabled
@Test
public void testDisabled() {
service.other();
}
}
"""
        test_file.write_text(source)
        func = self._java_method("call", tmp_path / "Service.java")
        success, result = instrument_existing_test(
            test_string=source,
            function_to_optimize=func,
            mode="performance",
            test_path=test_file,
        )
        expected = """import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Disabled;
public class ServiceTest__perfonlyinstrumented {
@Test
@DisplayName("Test service call")
public void testService() {
// Codeflash timing instrumentation with inner loop for JIT warmup
int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));
int _cf_innerIterations1 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100"));
String _cf_mod1 = "ServiceTest";
String _cf_cls1 = "ServiceTest";
String _cf_fn1 = "call";
for (int _cf_i1 = 0; _cf_i1 < _cf_innerIterations1; _cf_i1++) {
System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + "######$!");
long _cf_start1 = System.nanoTime();
try {
service.call();
} finally {
long _cf_end1 = System.nanoTime();
long _cf_dur1 = _cf_end1 - _cf_start1;
System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + ":" + _cf_dur1 + "######!");
}
}
}
@Disabled
@Test
public void testDisabled() {
// Codeflash timing instrumentation with inner loop for JIT warmup
int _cf_loop2 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));
int _cf_innerIterations2 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100"));
String _cf_mod2 = "ServiceTest";
String _cf_cls2 = "ServiceTest";
String _cf_fn2 = "call";
for (int _cf_i2 = 0; _cf_i2 < _cf_innerIterations2; _cf_i2++) {
System.out.println("!$######" + _cf_mod2 + ":" + _cf_cls2 + ":" + _cf_fn2 + ":" + _cf_loop2 + ":" + _cf_i2 + "######$!");
long _cf_start2 = System.nanoTime();
try {
service.other();
} finally {
long _cf_end2 = System.nanoTime();
long _cf_dur2 = _cf_end2 - _cf_start2;
System.out.println("!######" + _cf_mod2 + ":" + _cf_cls2 + ":" + _cf_fn2 + ":" + _cf_loop2 + ":" + _cf_i2 + ":" + _cf_dur2 + "######!");
}
}
}
}
"""
        assert success is True
        assert result == expected

    def test_missing_file(self, tmp_path: Path):
        """Test that a call with an empty test string and no test_path raises ValueError."""
        func = self._java_method("add", tmp_path / "Calculator.java")
        # test_path is deliberately omitted here; every other test in this class passes it.
        with pytest.raises(ValueError):
            instrument_existing_test(
                test_string="",
                function_to_optimize=func,
                mode="behavior",
            )
class TestKryoSerializerUsage:
    """Tests for Kryo Serializer usage in behavior mode."""

    # Every scenario instruments this same single-assertion JUnit test.
    _TEST_SOURCE = """import org.junit.jupiter.api.Test;
public class MyTest {
@Test
public void testFoo() {
assertEquals(0, obj.foo());
}
}
"""

    def _behavior_output(self):
        """Return the behavior-instrumented form of the shared test source."""
        return _add_behavior_instrumentation(self._TEST_SOURCE, "MyTest", "foo")

    def test_serializer_used_for_return_values(self):
        """Captured return values go through com.codeflash.Serializer.serialize()."""
        instrumented = self._behavior_output()
        assert "com.codeflash.Serializer.serialize((Object)" in instrumented
        # The old _cfSerialize helper must be gone.
        assert "_cfSerialize" not in instrumented

    def test_byte_array_result_variable(self):
        """The serialized result variable is declared as byte[], not String."""
        instrumented = self._behavior_output()
        assert "byte[] _cf_serializedResult" in instrumented
        assert "String _cf_serializedResult" not in instrumented

    def test_blob_column_in_schema(self):
        """The SQLite schema declares return_value as BLOB."""
        instrumented = self._behavior_output()
        assert "return_value BLOB" in instrumented
        assert "return_value TEXT" not in instrumented

    def test_set_bytes_for_blob_write(self):
        """setBytes writes the BLOB data to SQLite."""
        instrumented = self._behavior_output()
        assert "setBytes(8, _cf_serializedResult" in instrumented
        # setString must not be used for the return value.
        assert "setString(8, _cf_serializedResult" not in instrumented

    def test_no_inline_helper_injected(self):
        """No inline _cfSerialize helper method is injected."""
        assert "private static String _cfSerialize" not in self._behavior_output()

    def test_serializer_not_used_in_performance_mode(self):
        """Timing (performance) instrumentation never references the Serializer."""
        timed = _add_timing_instrumentation(self._TEST_SOURCE, "MyTest", "foo")
        assert "Serializer.serialize" not in timed
        assert "_cfSerialize" not in timed
class TestAddTimingInstrumentation:
    """Tests for _add_timing_instrumentation helper function with inner loop."""

    def test_single_test_method(self):
        """A lone @Test method is wrapped in the timing loop."""
        java_test = """public class SimpleTest {
@Test
public void testSomething() {
doSomething();
}
}
"""
        golden = """public class SimpleTest {
@Test
public void testSomething() {
// Codeflash timing instrumentation with inner loop for JIT warmup
int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));
int _cf_innerIterations1 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100"));
String _cf_mod1 = "SimpleTest";
String _cf_cls1 = "SimpleTest";
String _cf_fn1 = "targetFunc";
for (int _cf_i1 = 0; _cf_i1 < _cf_innerIterations1; _cf_i1++) {
System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + "######$!");
long _cf_start1 = System.nanoTime();
try {
doSomething();
} finally {
long _cf_end1 = System.nanoTime();
long _cf_dur1 = _cf_end1 - _cf_start1;
System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + ":" + _cf_dur1 + "######!");
}
}
}
}
"""
        instrumented = _add_timing_instrumentation(java_test, "SimpleTest", "targetFunc")
        assert instrumented == golden

    def test_multiple_test_methods(self):
        """Two @Test methods each get their own numbered timing loop."""
        java_test = """public class MultiTest {
@Test
public void testFirst() {
first();
}
@Test
public void testSecond() {
second();
}
}
"""
        golden = """public class MultiTest {
@Test
public void testFirst() {
// Codeflash timing instrumentation with inner loop for JIT warmup
int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));
int _cf_innerIterations1 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100"));
String _cf_mod1 = "MultiTest";
String _cf_cls1 = "MultiTest";
String _cf_fn1 = "func";
for (int _cf_i1 = 0; _cf_i1 < _cf_innerIterations1; _cf_i1++) {
System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + "######$!");
long _cf_start1 = System.nanoTime();
try {
first();
} finally {
long _cf_end1 = System.nanoTime();
long _cf_dur1 = _cf_end1 - _cf_start1;
System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + ":" + _cf_dur1 + "######!");
}
}
}
@Test
public void testSecond() {
// Codeflash timing instrumentation with inner loop for JIT warmup
int _cf_loop2 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));
int _cf_innerIterations2 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100"));
String _cf_mod2 = "MultiTest";
String _cf_cls2 = "MultiTest";
String _cf_fn2 = "func";
for (int _cf_i2 = 0; _cf_i2 < _cf_innerIterations2; _cf_i2++) {
System.out.println("!$######" + _cf_mod2 + ":" + _cf_cls2 + ":" + _cf_fn2 + ":" + _cf_loop2 + ":" + _cf_i2 + "######$!");
long _cf_start2 = System.nanoTime();
try {
second();
} finally {
long _cf_end2 = System.nanoTime();
long _cf_dur2 = _cf_end2 - _cf_start2;
System.out.println("!######" + _cf_mod2 + ":" + _cf_cls2 + ":" + _cf_fn2 + ":" + _cf_loop2 + ":" + _cf_i2 + ":" + _cf_dur2 + "######!");
}
}
}
}
"""
        instrumented = _add_timing_instrumentation(java_test, "MultiTest", "func")
        assert instrumented == golden

    def test_timing_markers_format(self):
        """The markers carry the class and function names passed as arguments.

        The class name given to the helper ("TestClass") intentionally differs
        from the class declared in the Java source ("MarkerTest").
        """
        java_test = """public class MarkerTest {
@Test
public void testMarkers() {
action();
}
}
"""
        golden = """public class MarkerTest {
@Test
public void testMarkers() {
// Codeflash timing instrumentation with inner loop for JIT warmup
int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));
int _cf_innerIterations1 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100"));
String _cf_mod1 = "TestClass";
String _cf_cls1 = "TestClass";
String _cf_fn1 = "targetMethod";
for (int _cf_i1 = 0; _cf_i1 < _cf_innerIterations1; _cf_i1++) {
System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + "######$!");
long _cf_start1 = System.nanoTime();
try {
action();
} finally {
long _cf_end1 = System.nanoTime();
long _cf_dur1 = _cf_end1 - _cf_start1;
System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + ":" + _cf_dur1 + "######!");
}
}
}
}
"""
        instrumented = _add_timing_instrumentation(java_test, "TestClass", "targetMethod")
        assert instrumented == golden
class TestCreateBenchmarkTest:
    """Tests for create_benchmark_test."""

    def test_create_benchmark(self):
        """A benchmark class is generated for add() with 1000 measured iterations."""
        target = FunctionToOptimize(
            function_name="add",
            file_path=Path("Calculator.java"),
            starting_line=1,
            ending_line=5,
            parents=[],
            is_method=True,
            language="java",
        )
        golden = """
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.DisplayName;
/**
* Benchmark test for add.
* Generated by CodeFlash.
*/
public class TargetBenchmark {
@Test
@DisplayName("Benchmark add")
public void benchmarkAdd() {
Calculator calc = new Calculator();
// Warmup phase
for (int i = 0; i < 100; i++) {
calc.add(2, 2);
}
// Measurement phase
long startTime = System.nanoTime();
for (int i = 0; i < 1000; i++) {
calc.add(2, 2);
}
long endTime = System.nanoTime();
long totalNanos = endTime - startTime;
long avgNanos = totalNanos / 1000;
System.out.println("CODEFLASH_BENCHMARK:add:total_ns=" + totalNanos + ",avg_ns=" + avgNanos + ",iterations=1000");
}
}
"""
        generated = create_benchmark_test(
            target,
            test_setup_code="Calculator calc = new Calculator();",
            invocation_code="calc.add(2, 2)",
            iterations=1000,
        )
        assert generated == golden

    def test_create_benchmark_different_iterations(self):
        """The iteration count is reflected in the loops and the printed summary."""
        target = FunctionToOptimize(
            function_name="multiply",
            file_path=Path("Math.java"),
            starting_line=1,
            ending_line=3,
            parents=[],
            is_method=True,
            language="java",
        )
        generated = create_benchmark_test(
            target,
            test_setup_code="",
            invocation_code="multiply(5, 3)",
            iterations=5000,
        )
        # An empty test_setup_code still leaves a whitespace-only line behind.
        # NOTE(review): the original comment calls this an 8-space indent, but the
        # literal below holds a single space - confirm against the generator output.
        golden = (
            "\n"
            "import org.junit.jupiter.api.Test;\n"
            "import org.junit.jupiter.api.DisplayName;\n"
            "\n"
            "/**\n"
            " * Benchmark test for multiply.\n"
            " * Generated by CodeFlash.\n"
            " */\n"
            "public class TargetBenchmark {\n"
            "\n"
            " @Test\n"
            ' @DisplayName("Benchmark multiply")\n'
            " public void benchmarkMultiply() {\n"
            " \n"  # whitespace-only line left by the empty setup code
            "\n"
            " // Warmup phase\n"
            " for (int i = 0; i < 500; i++) {\n"
            " multiply(5, 3);\n"
            " }\n"
            "\n"
            " // Measurement phase\n"
            " long startTime = System.nanoTime();\n"
            " for (int i = 0; i < 5000; i++) {\n"
            " multiply(5, 3);\n"
            " }\n"
            " long endTime = System.nanoTime();\n"
            "\n"
            " long totalNanos = endTime - startTime;\n"
            " long avgNanos = totalNanos / 5000;\n"
            "\n"
            ' System.out.println("CODEFLASH_BENCHMARK:multiply:total_ns=" + totalNanos + ",avg_ns=" + avgNanos + ",iterations=5000");\n'
            " }\n"
            "}\n"
        )
        assert generated == golden
class TestRemoveInstrumentation:
    """Tests for remove_instrumentation."""

    def test_returns_source_unchanged(self):
        """A file importing com.codeflash.CodeFlash comes back unchanged (no-op for Java)."""
        java_source = """import com.codeflash.CodeFlash;
import org.junit.jupiter.api.Test;
public class Test {}
"""
        assert remove_instrumentation(java_source) == java_source

    def test_preserves_regular_code(self):
        """Plain Java code without any instrumentation is preserved."""
        java_source = """public class Calculator {
public int add(int a, int b) {
return a + b;
}
}
"""
        assert remove_instrumentation(java_source) == java_source
class TestInstrumentGeneratedJavaTest:
    """Tests for instrument_generated_java_test."""

    def test_instrument_generated_test_behavior_mode(self):
        """Behavior mode adds the full instrumentation to a generated test.

        The output is checked for the __perfinstrumented class name, the
        captured _cf_result, the Serializer call and the SQLite output setup.
        """
        target = FunctionToOptimize(
            function_name="add",
            file_path=Path("Calculator.java"),
            starting_line=1,
            ending_line=5,
            parents=[],
            is_method=True,
            language="java",
        )
        generated_test = """import org.junit.jupiter.api.Test;
public class CalculatorTest {
@Test
public void testAdd() {
assertEquals(4, new Calculator().add(2, 2));
}
}
"""
        instrumented = instrument_generated_java_test(
            generated_test,
            function_name="add",
            qualified_name="Calculator.add",
            mode="behavior",
            function_to_optimize=target,
        )
        # Behavior mode now adds full instrumentation (SQLite, timing markers, etc.)
        required_fragments = (
            "CalculatorTest__perfinstrumented",
            "_cf_result",
            "com.codeflash.Serializer.serialize",
            "CODEFLASH_OUTPUT_FILE",
            "CREATE TABLE IF NOT EXISTS test_results",
        )
        for fragment in required_fragments:
            assert fragment in instrumented

    def test_instrument_generated_test_performance_mode(self):
        """Performance mode renames the class and wraps the body in the timing loop."""
        target = FunctionToOptimize(
            function_name="method",
            file_path=Path("Target.java"),
            starting_line=1,
            ending_line=5,
            parents=[],
            is_method=True,
            language="java",
        )
        generated_test = """import org.junit.jupiter.api.Test;
public class GeneratedTest {
@Test
public void testMethod() {
target.method();
}
}
"""
        golden = """import org.junit.jupiter.api.Test;
public class GeneratedTest__perfonlyinstrumented {
@Test
public void testMethod() {
// Codeflash timing instrumentation with inner loop for JIT warmup
int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));
int _cf_innerIterations1 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100"));
String _cf_mod1 = "GeneratedTest";
String _cf_cls1 = "GeneratedTest";
String _cf_fn1 = "method";
for (int _cf_i1 = 0; _cf_i1 < _cf_innerIterations1; _cf_i1++) {
System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + "######$!");
long _cf_start1 = System.nanoTime();
try {
target.method();
} finally {
long _cf_end1 = System.nanoTime();
long _cf_dur1 = _cf_end1 - _cf_start1;
System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + ":" + _cf_dur1 + "######!");
}
}
}
}
"""
        instrumented = instrument_generated_java_test(
            generated_test,
            function_name="method",
            qualified_name="Target.method",
            mode="performance",
            function_to_optimize=target,
        )
        assert instrumented == golden
class TestTimingMarkerParsing:
    """Tests for parsing timing markers from stdout."""

    # Marker regexes, per the original comment the same patterns as in parse_test_output.py.
    # Start marker: !$######module:class:func:loopIndex:iterationId######$!
    _START_MARKER = re.compile(r"!\$######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+)######\$!")
    # End marker: !######module:class:func:loopIndex:iterationId:durationNs######!
    _END_MARKER = re.compile(r"!######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+):([^:]+)######!")

    def test_timing_markers_can_be_parsed(self):
        """One start/end marker pair is parsed into its individual fields."""
        # Simulated stdout of an instrumented test run.
        captured = """
!$######TestModule:TestClass:targetFunc:1:1######$!
Running test...
!######TestModule:TestClass:targetFunc:1:1:12345678######!
"""
        starts = self._START_MARKER.findall(captured)
        ends = self._END_MARKER.findall(captured)
        assert len(starts) == 1
        assert len(ends) == 1

        # Field-by-field check of the single start marker.
        module, cls, func, loop_index, iteration_id = starts[0]
        assert module == "TestModule"
        assert cls == "TestClass"
        assert func == "targetFunc"
        assert loop_index == "1"
        assert iteration_id == "1"

        # The end marker repeats those fields and appends the duration.
        module, cls, func, loop_index, iteration_id, duration = ends[0]
        assert module == "TestModule"
        assert cls == "TestClass"
        assert func == "targetFunc"
        assert loop_index == "1"
        assert iteration_id == "1"
        assert duration == "12345678"  # Duration in nanoseconds

    def test_multiple_timing_markers(self):
        """Three marker pairs yield three durations in output order."""
        captured = """
!$######Module:Class:func:1:1######$!
test 1
!######Module:Class:func:1:1:100000######!
!$######Module:Class:func:2:1######$!
test 2
!######Module:Class:func:2:1:200000######!
!$######Module:Class:func:3:1######$!
test 3
!######Module:Class:func:3:1:150000######!
"""
        ends = self._END_MARKER.findall(captured)
        assert len(ends) == 3
        # The last field of every end marker is the duration.
        assert [int(fields[5]) for fields in ends] == [100000, 200000, 150000]

    def test_inner_loop_timing_markers(self):
        """Markers emitted by inner-loop iterations are parsed per iteration.

        With the inner loop, each test method produces N timing markers (one per iteration).
        The iterationId (5th field) represents the inner iteration number (0, 1, 2, ..., N-1).
        """
        # Simulated stdout of three inner iterations (inner_iterations=3).
        captured = """
!$######Module:Class:func:1:0######$!
iteration 0
!######Module:Class:func:1:0:150000######!
!$######Module:Class:func:1:1######$!
iteration 1
!######Module:Class:func:1:1:50000######!
!$######Module:Class:func:1:2######$!
iteration 2
!######Module:Class:func:1:2:45000######!
"""
        starts = self._START_MARKER.findall(captured)
        ends = self._END_MARKER.findall(captured)
        # One start and one end marker per inner iteration.
        assert len(starts) == 3
        assert len(ends) == 3

        # Same loopIndex (1) everywhere, iterationIds counting 0, 1, 2.
        for markers in (starts, ends):
            assert [fields[3] for fields in markers] == ["1", "1", "1"]
            assert [fields[4] for fields in markers] == ["0", "1", "2"]

        # Iteration 0 is the slow one (JIT warmup); the later ones are faster.
        durations = [int(fields[5]) for fields in ends]
        assert durations == [150000, 50000, 45000]
        # Min-runtime logic would pick the fastest iteration.
        assert min(durations) == 45000
class TestInstrumentedCodeValidity:
"""Tests to verify that instrumented code is syntactically valid Java with inner loop."""
def test_instrumented_code_has_balanced_braces(self, tmp_path: Path):
    """Test bodies containing nested blocks (if / for) are wrapped correctly.

    The whole instrumented file is compared against the expected source, so
    every opening and closing brace has to line up.
    """
    java_test = """import org.junit.jupiter.api.Test;
public class BraceTest {
@Test
public void testOne() {
if (true) {
doSomething();
}
}
@Test
public void testTwo() {
for (int i = 0; i < 10; i++) {
process(i);
}
}
}
"""
    java_test_path = tmp_path / "BraceTest.java"
    java_test_path.write_text(java_test)
    target = FunctionToOptimize(
        function_name="process",
        file_path=tmp_path / "Processor.java",
        starting_line=1,
        ending_line=5,
        parents=[],
        is_method=True,
        language="java",
    )
    golden = """import org.junit.jupiter.api.Test;
public class BraceTest__perfonlyinstrumented {
@Test
public void testOne() {
// Codeflash timing instrumentation with inner loop for JIT warmup
int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));
int _cf_innerIterations1 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100"));
String _cf_mod1 = "BraceTest";
String _cf_cls1 = "BraceTest";
String _cf_fn1 = "process";
for (int _cf_i1 = 0; _cf_i1 < _cf_innerIterations1; _cf_i1++) {
System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + "######$!");
long _cf_start1 = System.nanoTime();
try {
if (true) {
doSomething();
}
} finally {
long _cf_end1 = System.nanoTime();
long _cf_dur1 = _cf_end1 - _cf_start1;
System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + ":" + _cf_dur1 + "######!");
}
}
}
@Test
public void testTwo() {
// Codeflash timing instrumentation with inner loop for JIT warmup
int _cf_loop2 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));
int _cf_innerIterations2 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100"));
String _cf_mod2 = "BraceTest";
String _cf_cls2 = "BraceTest";
String _cf_fn2 = "process";
for (int _cf_i2 = 0; _cf_i2 < _cf_innerIterations2; _cf_i2++) {
System.out.println("!$######" + _cf_mod2 + ":" + _cf_cls2 + ":" + _cf_fn2 + ":" + _cf_loop2 + ":" + _cf_i2 + "######$!");
long _cf_start2 = System.nanoTime();
try {
for (int i = 0; i < 10; i++) {
process(i);
}
} finally {
long _cf_end2 = System.nanoTime();
long _cf_dur2 = _cf_end2 - _cf_start2;
System.out.println("!######" + _cf_mod2 + ":" + _cf_cls2 + ":" + _cf_fn2 + ":" + _cf_loop2 + ":" + _cf_i2 + ":" + _cf_dur2 + "######!");
}
}
}
}
"""
    ok, instrumented = instrument_existing_test(
        test_string=java_test,
        function_to_optimize=target,
        mode="performance",
        test_path=java_test_path,
    )
    assert ok is True
    assert instrumented == golden
def test_instrumented_code_preserves_imports(self, tmp_path: Path):
"""Test that imports are preserved in instrumented code with inner loop."""
test_file = tmp_path / "ImportTest.java"
source = """package com.example;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
import java.util.List;
import java.util.ArrayList;
public class ImportTest {
@Test
public void testCollections() {
List<String> list = new ArrayList<>();
assertEquals(0, list.size());
}
}
"""
test_file.write_text(source)
func = FunctionToOptimize(
function_name="size",
file_path=tmp_path / "Collection.java",
starting_line=1,
ending_line=5,
parents=[],
is_method=True,
language="java",
)
success, result = instrument_existing_test(
test_string=source,
function_to_optimize=func,
mode="performance",
test_path=test_file,
)
expected = """package com.example;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
import java.util.List;
import java.util.ArrayList;
public class ImportTest__perfonlyinstrumented {
@Test
public void testCollections() {
// Codeflash timing instrumentation with inner loop for JIT warmup
int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));
int _cf_innerIterations1 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100"));
String _cf_mod1 = "ImportTest";
String _cf_cls1 = "ImportTest";
String _cf_fn1 = "size";
for (int _cf_i1 = 0; _cf_i1 < _cf_innerIterations1; _cf_i1++) {
System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + "######$!");
long _cf_start1 = System.nanoTime();
try {
List<String> list = new ArrayList<>();
assertEquals(0, list.size());
} finally {
long _cf_end1 = System.nanoTime();
long _cf_dur1 = _cf_end1 - _cf_start1;
System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + ":" + _cf_dur1 + "######!");
}
}
}
}
"""
assert success is True
assert result == expected
class TestEdgeCases:
    """Unusual test-class shapes that performance instrumentation must still handle."""

    @staticmethod
    def _instrument_performance(test_file: Path, java_source: str, function_name: str, target_file: Path):
        """Write ``java_source`` to ``test_file`` and instrument it in performance mode.

        Returns the value of ``instrument_existing_test`` unchanged; the tests
        unpack it as ``(success, instrumented_source)``.
        """
        test_file.write_text(java_source)
        target = FunctionToOptimize(
            function_name=function_name,
            file_path=target_file,
            starting_line=1,
            ending_line=5,
            parents=[],
            is_method=True,
            language="java",
        )
        return instrument_existing_test(
            test_string=java_source,
            function_to_optimize=target,
            mode="performance",
            test_path=test_file,
        )

    def test_empty_test_method(self, tmp_path: Path):
        """A test method with an empty body is still wrapped, producing an empty ``try`` block."""
        java_source = """import org.junit.jupiter.api.Test;
public class EmptyTest {
@Test
public void testEmpty() {
}
}
"""
        ok, instrumented = self._instrument_performance(
            tmp_path / "EmptyTest.java",
            java_source,
            "empty",
            tmp_path / "Empty.java",
        )
        expected_java = """import org.junit.jupiter.api.Test;
public class EmptyTest__perfonlyinstrumented {
@Test
public void testEmpty() {
// Codeflash timing instrumentation with inner loop for JIT warmup
int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));
int _cf_innerIterations1 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100"));
String _cf_mod1 = "EmptyTest";
String _cf_cls1 = "EmptyTest";
String _cf_fn1 = "empty";
for (int _cf_i1 = 0; _cf_i1 < _cf_innerIterations1; _cf_i1++) {
System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + "######$!");
long _cf_start1 = System.nanoTime();
try {
} finally {
long _cf_end1 = System.nanoTime();
long _cf_dur1 = _cf_end1 - _cf_start1;
System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + ":" + _cf_dur1 + "######!");
}
}
}
}
"""
        assert ok is True
        assert instrumented == expected_java

    def test_test_with_nested_braces(self, tmp_path: Path):
        """Deeply nested if/for blocks inside the test body are carried into the ``try`` verbatim."""
        java_source = """import org.junit.jupiter.api.Test;
public class NestedTest {
@Test
public void testNested() {
if (condition) {
for (int i = 0; i < 10; i++) {
if (i > 5) {
process(i);
}
}
}
}
}
"""
        ok, instrumented = self._instrument_performance(
            tmp_path / "NestedTest.java",
            java_source,
            "process",
            tmp_path / "Processor.java",
        )
        expected_java = """import org.junit.jupiter.api.Test;
public class NestedTest__perfonlyinstrumented {
@Test
public void testNested() {
// Codeflash timing instrumentation with inner loop for JIT warmup
int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));
int _cf_innerIterations1 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100"));
String _cf_mod1 = "NestedTest";
String _cf_cls1 = "NestedTest";
String _cf_fn1 = "process";
for (int _cf_i1 = 0; _cf_i1 < _cf_innerIterations1; _cf_i1++) {
System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + "######$!");
long _cf_start1 = System.nanoTime();
try {
if (condition) {
for (int i = 0; i < 10; i++) {
if (i > 5) {
process(i);
}
}
}
} finally {
long _cf_end1 = System.nanoTime();
long _cf_dur1 = _cf_end1 - _cf_start1;
System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + ":" + _cf_dur1 + "######!");
}
}
}
}
"""
        assert ok is True
        assert instrumented == expected_java

    def test_class_with_inner_class(self, tmp_path: Path):
        """Test methods in a ``@Nested`` inner class are instrumented; only the outer class is renamed."""
        java_source = """import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Nested;
public class InnerClassTest {
@Test
public void testOuter() {
outerMethod();
}
@Nested
class InnerTests {
@Test
public void testInner() {
innerMethod();
}
}
}
"""
        ok, instrumented = self._instrument_performance(
            tmp_path / "InnerClassTest.java",
            java_source,
            "testMethod",
            tmp_path / "Target.java",
        )
        # The nested method continues the numbering (2) and reports the outer class name.
        expected_java = """import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Nested;
public class InnerClassTest__perfonlyinstrumented {
@Test
public void testOuter() {
// Codeflash timing instrumentation with inner loop for JIT warmup
int _cf_loop1 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));
int _cf_innerIterations1 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100"));
String _cf_mod1 = "InnerClassTest";
String _cf_cls1 = "InnerClassTest";
String _cf_fn1 = "testMethod";
for (int _cf_i1 = 0; _cf_i1 < _cf_innerIterations1; _cf_i1++) {
System.out.println("!$######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + "######$!");
long _cf_start1 = System.nanoTime();
try {
outerMethod();
} finally {
long _cf_end1 = System.nanoTime();
long _cf_dur1 = _cf_end1 - _cf_start1;
System.out.println("!######" + _cf_mod1 + ":" + _cf_cls1 + ":" + _cf_fn1 + ":" + _cf_loop1 + ":" + _cf_i1 + ":" + _cf_dur1 + "######!");
}
}
}
@Nested
class InnerTests {
@Test
public void testInner() {
// Codeflash timing instrumentation with inner loop for JIT warmup
int _cf_loop2 = Integer.parseInt(System.getenv("CODEFLASH_LOOP_INDEX"));
int _cf_innerIterations2 = Integer.parseInt(System.getenv().getOrDefault("CODEFLASH_INNER_ITERATIONS", "100"));
String _cf_mod2 = "InnerClassTest";
String _cf_cls2 = "InnerClassTest";
String _cf_fn2 = "testMethod";
for (int _cf_i2 = 0; _cf_i2 < _cf_innerIterations2; _cf_i2++) {
System.out.println("!$######" + _cf_mod2 + ":" + _cf_cls2 + ":" + _cf_fn2 + ":" + _cf_loop2 + ":" + _cf_i2 + "######$!");
long _cf_start2 = System.nanoTime();
try {
innerMethod();
} finally {
long _cf_end2 = System.nanoTime();
long _cf_dur2 = _cf_end2 - _cf_start2;
System.out.println("!######" + _cf_mod2 + ":" + _cf_cls2 + ":" + _cf_fn2 + ":" + _cf_loop2 + ":" + _cf_i2 + ":" + _cf_dur2 + "######!");
}
}
}
}
}
"""
        assert ok is True
        assert instrumented == expected_java
# End-to-end tests shell out to Maven; when no executable can be located they
# are skipped rather than failed.
_maven_missing = find_maven_executable() is None
requires_maven = pytest.mark.skipif(_maven_missing, reason="Maven not found - skipping execution tests")
@requires_maven
class TestRunAndParseTests:
"""End-to-end tests using the real run_and_parse_tests entry point."""
# Minimal pom.xml written into every temporary project by the java_project fixture.
# It targets Java 11 and declares test-scoped dependencies on JUnit Jupiter, the
# JUnit Platform console launcher, sqlite-jdbc, Gson and codeflash-runtime.
# Surefire is configured with redirectTestOutputToFile=false; presumably this keeps
# test stdout (where the timing markers are printed) in Maven's console output
# so the tests can parse it - confirm against the Surefire documentation.
POM_CONTENT = """<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.example</groupId>
<artifactId>codeflash-test</artifactId>
<version>1.0.0</version>
<packaging>jar</packaging>
<properties>
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<dependencies>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<version>5.9.3</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.platform</groupId>
<artifactId>junit-platform-console-standalone</artifactId>
<version>1.9.3</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.xerial</groupId>
<artifactId>sqlite-jdbc</artifactId>
<version>3.44.1.0</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<version>2.10.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.codeflash</groupId>
<artifactId>codeflash-runtime</artifactId>
<version>1.0.0</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>3.1.2</version>
<configuration>
<redirectTestOutputToFile>false</redirectTestOutputToFile>
</configuration>
</plugin>
</plugins>
</build>
</project>
"""
@pytest.fixture
def java_project(self, tmp_path: Path):
    """Build a throwaway Maven project and switch the current language to Java.

    Yields a ``(project_root, main_source_dir, test_source_dir)`` tuple; both
    source directories are the ``com/example`` package folders. After the test
    the current language is switched back to Python.
    """
    import codeflash.languages.current as current_module
    from codeflash.languages.base import Language
    from codeflash.languages.current import set_current_language

    # Clear the cached language before setting it, so the switch to Java is forced.
    current_module._current_language = None
    set_current_language(Language.JAVA)

    # Standard Maven layout: src/{main,test}/java/<package path>.
    package_path = ("java", "com", "example")
    main_pkg_dir = tmp_path.joinpath("src", "main", *package_path)
    test_pkg_dir = tmp_path.joinpath("src", "test", *package_path)
    for pkg_dir in (main_pkg_dir, test_pkg_dir):
        pkg_dir.mkdir(parents=True)
    pom_file = tmp_path / "pom.xml"
    pom_file.write_text(self.POM_CONTENT, encoding="utf-8")

    yield tmp_path, main_pkg_dir, test_pkg_dir

    # Teardown: put the language back to Python for later tests.
    current_module._current_language = None
    set_current_language(Language.PYTHON)
def test_run_and_parse_behavior_mode(self, java_project):
    """Run a passing, behavior-instrumented test through ``run_and_parse_tests``.

    The first parsed result must be marked as passed and carry a positive runtime.
    """
    from argparse import Namespace

    from codeflash.discovery.functions_to_optimize import FunctionToOptimize
    from codeflash.models.models import TestFile, TestFiles, TestingMode, TestType
    from codeflash.optimization.optimizer import Optimizer

    root, main_dir, tests_dir = java_project
    calculator_path = main_dir / "Calculator.java"

    # Class under test.
    calculator_java = """package com.example;
public class Calculator {
public int add(int a, int b) {
return a + b;
}
}
"""
    calculator_path.write_text(calculator_java, encoding="utf-8")

    # Original JUnit test, written to disk and then instrumented.
    junit_source = """package com.example;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
public class CalculatorTest {
@Test
public void testAdd() {
Calculator calc = new Calculator();
assertEquals(4, calc.add(2, 2));
}
}
"""
    original_test = tests_dir / "CalculatorTest.java"
    original_test.write_text(junit_source, encoding="utf-8")

    target = FunctionToOptimize(
        function_name="add",
        file_path=calculator_path,
        starting_line=4,
        ending_line=6,
        parents=[],
        is_method=True,
        language="java",
    )
    ok, instrumented_source = instrument_existing_test(
        test_string=junit_source,
        function_to_optimize=target,
        mode="behavior",
        test_path=original_test,
    )
    assert ok

    behavior_test = tests_dir / "CalculatorTest__perfinstrumented.java"
    behavior_test.write_text(instrumented_source, encoding="utf-8")

    # Build the Optimizer and the per-function optimizer that owns run_and_parse_tests.
    optimizer_target = FunctionToOptimize(
        function_name="add",
        file_path=calculator_path,
        parents=[],
        language="java",
    )
    cli_args = Namespace(
        project_root=root,
        disable_telemetry=True,
        tests_root=tests_dir,
        test_project_root=root,
        pytest_cmd="pytest",
        experiment_id=None,
    )
    func_optimizer = Optimizer(cli_args).create_function_optimizer(optimizer_target)
    assert func_optimizer is not None

    registered = TestFile(
        instrumented_behavior_file_path=behavior_test,
        test_type=TestType.EXISTING_UNIT_TEST,
        original_file_path=original_test,
        benchmarking_file_path=behavior_test,  # Use same file for behavior tests
    )
    func_optimizer.test_files = TestFiles(test_files=[registered])

    # Execute the tests and parse their output.
    run_env = {**os.environ, "CODEFLASH_TEST_ITERATION": "0"}
    parsed, _ = func_optimizer.run_and_parse_tests(
        testing_type=TestingMode.BEHAVIOR,
        test_env=run_env,
        test_files=func_optimizer.test_files,
        optimization_iteration=0,
        pytest_min_loops=1,
        pytest_max_loops=1,
        testing_time=0.1,
    )

    # Inspect the first parsed result.
    assert len(parsed.test_results) >= 1
    first = parsed.test_results[0]
    assert first.did_pass is True
    assert first.runtime is not None
    assert first.runtime > 0
def test_run_and_parse_performance_mode(self, java_project):
    """Run a performance-instrumented test through ``run_and_parse_tests``.

    Steps covered:
    1. Instrument the test so its body runs inside the inner (JIT warmup) loop.
    2. Run with ``CODEFLASH_INNER_ITERATIONS=2`` to keep the test fast.
    3. Expect at least two parsed results, one per inner iteration.
    4. Expect every result to have passed with a positive runtime.
    """
    from argparse import Namespace

    from codeflash.discovery.functions_to_optimize import FunctionToOptimize
    from codeflash.models.models import TestFile, TestFiles, TestingMode, TestType
    from codeflash.optimization.optimizer import Optimizer

    root, main_dir, tests_dir = java_project
    math_utils_path = main_dir / "MathUtils.java"

    # Class under test.
    math_utils_java = """package com.example;
public class MathUtils {
public int multiply(int a, int b) {
return a * b;
}
}
"""
    math_utils_path.write_text(math_utils_java, encoding="utf-8")

    # Original JUnit test, written to disk and then instrumented.
    junit_source = """package com.example;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
public class MathUtilsTest {
@Test
public void testMultiply() {
MathUtils math = new MathUtils();
assertEquals(6, math.multiply(2, 3));
}
}
"""
    original_test = tests_dir / "MathUtilsTest.java"
    original_test.write_text(junit_source, encoding="utf-8")

    target = FunctionToOptimize(
        function_name="multiply",
        file_path=math_utils_path,
        starting_line=4,
        ending_line=6,
        parents=[],
        is_method=True,
        language="java",
    )
    ok, instrumented_source = instrument_existing_test(
        test_string=junit_source,
        function_to_optimize=target,
        mode="performance",
        test_path=original_test,
    )
    assert ok

    # The generated Java must contain the inner warmup loop.
    assert "CODEFLASH_INNER_ITERATIONS" in instrumented_source, "Performance mode should use inner loop"
    assert "for (int _cf_i1 = 0; _cf_i1 < _cf_innerIterations1; _cf_i1++)" in instrumented_source

    perf_test = tests_dir / "MathUtilsTest__perfonlyinstrumented.java"
    perf_test.write_text(instrumented_source, encoding="utf-8")

    # Build the Optimizer and the per-function optimizer that owns run_and_parse_tests.
    optimizer_target = FunctionToOptimize(
        function_name="multiply",
        file_path=math_utils_path,
        parents=[],
        language="java",
    )
    cli_args = Namespace(
        project_root=root,
        disable_telemetry=True,
        tests_root=tests_dir,
        test_project_root=root,
        pytest_cmd="pytest",
        experiment_id=None,
    )
    func_optimizer = Optimizer(cli_args).create_function_optimizer(optimizer_target)
    assert func_optimizer is not None

    registered = TestFile(
        instrumented_behavior_file_path=original_test,
        test_type=TestType.EXISTING_UNIT_TEST,
        original_file_path=original_test,
        benchmarking_file_path=perf_test,
    )
    func_optimizer.test_files = TestFiles(test_files=[registered])

    # Two inner iterations and a single outer loop (Maven invocation) keep this fast.
    run_env = {
        **os.environ,
        "CODEFLASH_TEST_ITERATION": "0",
        "CODEFLASH_INNER_ITERATIONS": "2",
    }
    parsed, _ = func_optimizer.run_and_parse_tests(
        testing_type=TestingMode.PERFORMANCE,
        test_env=run_env,
        test_files=func_optimizer.test_files,
        optimization_iteration=0,
        pytest_min_loops=1,
        pytest_max_loops=1,
        testing_time=1.0,
    )

    # One result is expected per inner iteration.
    assert len(parsed.test_results) >= 2, (
        f"Expected at least 2 results from inner loop (inner_iterations=2), got {len(parsed.test_results)}"
    )

    # Every result must have passed and carry a usable timing.
    for outcome in parsed.test_results:
        assert outcome.did_pass is True
        assert outcome.runtime is not None
        assert outcome.runtime > 0
    runtimes = [outcome.runtime for outcome in parsed.test_results]

    assert len(runtimes) >= 2, f"Expected at least 2 runtimes, got {len(runtimes)}"

    # Informational only: the minimum is what a benchmarking comparison would pick.
    print(f"Inner loop runtimes: min={min(runtimes)}ns, max={max(runtimes)}ns, count={len(runtimes)}")
def test_run_and_parse_multiple_test_methods(self, java_project):
    """Run a behavior-instrumented test class with three ``@Test`` methods.

    At least one result must be parsed, and every parsed result must have passed.
    """
    from argparse import Namespace

    from codeflash.discovery.functions_to_optimize import FunctionToOptimize
    from codeflash.models.models import TestFile, TestFiles, TestingMode, TestType
    from codeflash.optimization.optimizer import Optimizer

    project_root, src_dir, test_dir = java_project

    # Create source file
    (src_dir / "StringUtils.java").write_text("""package com.example;
public class StringUtils {
public String reverse(String s) {
return new StringBuilder(s).reverse().toString();
}
}
""", encoding="utf-8")

    # Create test with multiple methods
    test_source = """package com.example;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
public class StringUtilsTest {
@Test
public void testReverseHello() {
assertEquals("olleh", new StringUtils().reverse("hello"));
}
@Test
public void testReverseEmpty() {
assertEquals("", new StringUtils().reverse(""));
}
@Test
public void testReverseSingle() {
assertEquals("a", new StringUtils().reverse("a"));
}
}
"""
    test_file = test_dir / "StringUtilsTest.java"
    test_file.write_text(test_source, encoding="utf-8")

    func_info = FunctionToOptimize(
        function_name="reverse",
        file_path=src_dir / "StringUtils.java",
        starting_line=4,
        ending_line=6,
        parents=[],
        is_method=True,
        language="java",
    )
    success, instrumented = instrument_existing_test(
        test_string=test_source, function_to_optimize=func_info, mode="behavior", test_path=test_file
    )
    assert success

    instrumented_file = test_dir / "StringUtilsTest__perfinstrumented.java"
    instrumented_file.write_text(instrumented, encoding="utf-8")

    fto = FunctionToOptimize(
        function_name="reverse",
        file_path=src_dir / "StringUtils.java",
        parents=[],
        language="java",
    )
    opt = Optimizer(Namespace(
        project_root=project_root,
        disable_telemetry=True,
        tests_root=test_dir,
        test_project_root=project_root,
        pytest_cmd="pytest",
        experiment_id=None,
    ))
    func_optimizer = opt.create_function_optimizer(fto)
    # Fail with a clear assertion (as the sibling tests do) rather than an
    # AttributeError on the next line if no optimizer could be created.
    assert func_optimizer is not None
    func_optimizer.test_files = TestFiles(test_files=[
        TestFile(
            instrumented_behavior_file_path=instrumented_file,
            test_type=TestType.EXISTING_UNIT_TEST,
            original_file_path=test_file,
            benchmarking_file_path=instrumented_file,  # Use same file for behavior tests
        )
    ])

    test_env = os.environ.copy()
    test_env["CODEFLASH_TEST_ITERATION"] = "0"
    test_results, _ = func_optimizer.run_and_parse_tests(
        testing_type=TestingMode.BEHAVIOR,
        test_env=test_env,
        test_files=func_optimizer.test_files,
        optimization_iteration=0,
        pytest_min_loops=1,
        pytest_max_loops=1,
        testing_time=0.1,
    )

    # Should have results for test methods - at least 1 from JUnit XML parsing
    # Note: With behavior mode instrumentation, all 3 tests should be parsed
    assert len(test_results.test_results) >= 1, (
        f"Expected at least 1 test result but got {len(test_results.test_results)}"
    )
    for result in test_results.test_results:
        assert result.did_pass is True, f"Test {result.id.test_function_name} should have passed"
def test_run_and_parse_failing_test(self, java_project):
    """Check that a failing JUnit assertion is reported as ``did_pass is False``.

    The Java class under test deliberately returns ``a + b + 1`` so that the
    JUnit ``assertEquals(4, calc.add(2, 2))`` fails.
    """
    from argparse import Namespace

    from codeflash.discovery.functions_to_optimize import FunctionToOptimize
    from codeflash.models.models import TestFile, TestFiles, TestingMode, TestType
    from codeflash.optimization.optimizer import Optimizer

    project_root, src_dir, test_dir = java_project

    # Create source file with a bug
    (src_dir / "BrokenCalc.java").write_text("""package com.example;
public class BrokenCalc {
public int add(int a, int b) {
return a + b + 1; // Bug: adds extra 1
}
}
""", encoding="utf-8")

    # Create test that will fail
    test_source = """package com.example;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
public class BrokenCalcTest {
@Test
public void testAdd() {
BrokenCalc calc = new BrokenCalc();
assertEquals(4, calc.add(2, 2)); // Will fail: 5 != 4
}
}
"""
    test_file = test_dir / "BrokenCalcTest.java"
    test_file.write_text(test_source, encoding="utf-8")

    func_info = FunctionToOptimize(
        function_name="add",
        file_path=src_dir / "BrokenCalc.java",
        starting_line=4,
        ending_line=6,
        parents=[],
        is_method=True,
        language="java",
    )
    success, instrumented = instrument_existing_test(
        test_string=test_source, function_to_optimize=func_info, mode="behavior", test_path=test_file
    )
    assert success

    instrumented_file = test_dir / "BrokenCalcTest__perfinstrumented.java"
    instrumented_file.write_text(instrumented, encoding="utf-8")

    fto = FunctionToOptimize(
        function_name="add",
        file_path=src_dir / "BrokenCalc.java",
        parents=[],
        language="java",
    )
    opt = Optimizer(Namespace(
        project_root=project_root,
        disable_telemetry=True,
        tests_root=test_dir,
        test_project_root=project_root,
        pytest_cmd="pytest",
        experiment_id=None,
    ))
    func_optimizer = opt.create_function_optimizer(fto)
    # Fail with a clear assertion (as the sibling tests do) rather than an
    # AttributeError on the next line if no optimizer could be created.
    assert func_optimizer is not None
    func_optimizer.test_files = TestFiles(test_files=[
        TestFile(
            instrumented_behavior_file_path=instrumented_file,
            test_type=TestType.EXISTING_UNIT_TEST,
            original_file_path=test_file,
            benchmarking_file_path=instrumented_file,  # Use same file for behavior tests
        )
    ])

    test_env = os.environ.copy()
    test_env["CODEFLASH_TEST_ITERATION"] = "0"
    test_results, _ = func_optimizer.run_and_parse_tests(
        testing_type=TestingMode.BEHAVIOR,
        test_env=test_env,
        test_files=func_optimizer.test_files,
        optimization_iteration=0,
        pytest_min_loops=1,
        pytest_max_loops=1,
        testing_time=0.1,
    )

    # Should have result for the failing test
    assert len(test_results.test_results) >= 1
    result = test_results.test_results[0]
    assert result.did_pass is False
def test_behavior_mode_writes_to_sqlite(self, java_project):
    """Check that a behavior-mode run records its results in a SQLite file.

    The test instruments a single JUnit method, runs it through
    ``run_and_parse_tests`` and then opens the resulting
    ``test_return_values_0.sqlite`` file to validate every row of the
    ``test_results`` table.
    """
    import sqlite3
    import tempfile
    from argparse import Namespace
    from contextlib import closing

    from codeflash.code_utils.code_utils import get_run_tmp_file
    from codeflash.discovery.functions_to_optimize import FunctionToOptimize
    from codeflash.models.models import TestFile, TestFiles, TestingMode, TestType
    from codeflash.optimization.optimizer import Optimizer

    # Clean up any existing SQLite files from previous tests
    sqlite_file = get_run_tmp_file(Path("test_return_values_0.sqlite"))
    if sqlite_file.exists():
        sqlite_file.unlink()

    project_root, src_dir, test_dir = java_project

    # Create source file
    (src_dir / "Counter.java").write_text("""package com.example;
public class Counter {
private int value = 0;
public int increment() {
return ++value;
}
}
""", encoding="utf-8")

    # Create test file - single test method for simplicity
    test_source = """package com.example;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
public class CounterTest {
@Test
public void testIncrement() {
Counter counter = new Counter();
assertEquals(1, counter.increment());
}
}
"""
    test_file = test_dir / "CounterTest.java"
    test_file.write_text(test_source, encoding="utf-8")

    # Instrument for BEHAVIOR mode (this should include SQLite writing)
    func_info = FunctionToOptimize(
        function_name="increment",
        file_path=src_dir / "Counter.java",
        starting_line=6,
        ending_line=8,
        parents=[],
        is_method=True,
        language="java",
    )
    success, instrumented = instrument_existing_test(
        test_string=test_source, function_to_optimize=func_info, mode="behavior", test_path=test_file
    )
    assert success

    # Verify SQLite imports were added
    assert "import java.sql.Connection;" in instrumented
    assert "import java.sql.DriverManager;" in instrumented
    assert "import java.sql.PreparedStatement;" in instrumented

    # Verify SQLite writing code was added
    assert "CODEFLASH_OUTPUT_FILE" in instrumented
    assert "CREATE TABLE IF NOT EXISTS test_results" in instrumented
    assert "INSERT INTO test_results VALUES" in instrumented

    instrumented_file = test_dir / "CounterTest__perfinstrumented.java"
    instrumented_file.write_text(instrumented, encoding="utf-8")

    # Create Optimizer and FunctionOptimizer
    fto = FunctionToOptimize(
        function_name="increment",
        file_path=src_dir / "Counter.java",
        parents=[],
        language="java",
    )
    opt = Optimizer(Namespace(
        project_root=project_root,
        disable_telemetry=True,
        tests_root=test_dir,
        test_project_root=project_root,
        pytest_cmd="pytest",
        experiment_id=None,
    ))
    func_optimizer = opt.create_function_optimizer(fto)
    assert func_optimizer is not None
    func_optimizer.test_files = TestFiles(test_files=[
        TestFile(
            instrumented_behavior_file_path=instrumented_file,
            test_type=TestType.EXISTING_UNIT_TEST,
            original_file_path=test_file,
            benchmarking_file_path=instrumented_file,
        )
    ])

    # Run tests
    test_env = os.environ.copy()
    test_env["CODEFLASH_TEST_ITERATION"] = "0"
    test_results, _ = func_optimizer.run_and_parse_tests(
        testing_type=TestingMode.BEHAVIOR,
        test_env=test_env,
        test_files=func_optimizer.test_files,
        optimization_iteration=0,
        pytest_min_loops=1,
        pytest_max_loops=1,
        testing_time=0.1,
    )

    # Verify tests passed - at least 1 result from JUnit XML parsing
    assert len(test_results.test_results) >= 1, (
        f"Expected at least 1 test result but got {len(test_results.test_results)}"
    )
    for result in test_results.test_results:
        assert result.did_pass is True, f"Test {result.id.test_function_name} should have passed"

    # Find the SQLite file that was created
    # SQLite is created at get_run_tmp_file path
    sqlite_file = get_run_tmp_file(Path("test_return_values_0.sqlite"))
    if not sqlite_file.exists():
        # Fall back to checking temp directory for any SQLite files
        sqlite_files = list(Path(tempfile.gettempdir()).glob("**/test_return_values_*.sqlite"))
        assert len(sqlite_files) >= 1, f"SQLite file should have been created at {sqlite_file} or in temp dir"
        # Pick the most recently modified candidate.
        sqlite_file = max(sqlite_files, key=lambda p: p.stat().st_mtime)

    # Verify SQLite contents. closing() guarantees the connection is released
    # even when one of the assertions below fails; a bare sqlite3 connection
    # used as a context manager would only commit/rollback, not close.
    with closing(sqlite3.connect(str(sqlite_file))) as conn:
        cursor = conn.cursor()

        # Check that test_results table exists and has data
        cursor.execute("SELECT COUNT(*) FROM test_results")
        count = cursor.fetchone()[0]
        assert count >= 1, f"Expected at least 1 result in SQLite, got {count}"

        # Check the data structure
        cursor.execute("SELECT * FROM test_results")
        rows = cursor.fetchall()

        for row in rows:
            # Unpacking also enforces that each row has exactly nine columns;
            # underscore-prefixed names are columns this test does not inspect.
            test_module_path, test_class_name, _test_function_name, function_getting_tested, \
                loop_index, _iteration_id, runtime, return_value, verification_type = row

            # Verify fields
            assert test_module_path == "CounterTest"
            assert test_class_name == "CounterTest"
            assert function_getting_tested == "increment"
            assert loop_index == 1
            assert runtime > 0, f"Should have a positive runtime, got {runtime}"
            assert verification_type == "function_call"  # Updated from "output"
            assert return_value is not None, "Return value should be serialized, not null"
            assert isinstance(return_value, bytes), f"Expected bytes (Kryo binary), got: {type(return_value)}"
            assert len(return_value) > 0, "Kryo-serialized return value should not be empty"
def test_performance_mode_inner_loop_timing_markers(self, java_project):
    """Check the raw timing markers printed by a performance-instrumented test.

    With ``inner_iterations=2`` and a single outer loop the Maven output must contain:
    1. exactly two start markers and two end markers,
    2. iteration IDs 0 and 1,
    3. loop index 1 on every start marker,
    4. a positive duration on every end marker.
    """
    from codeflash.languages.java.test_runner import run_benchmarking_tests

    root, main_dir, tests_dir = java_project

    # Function to benchmark.
    fibonacci_java = """package com.example;
public class Fibonacci {
public int fib(int n) {
if (n <= 1) return n;
return fib(n - 1) + fib(n - 2);
}
}
"""
    fibonacci_path = main_dir / "Fibonacci.java"
    fibonacci_path.write_text(fibonacci_java, encoding="utf-8")

    # JUnit test exercising it.
    junit_source = """package com.example;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
public class FibonacciTest {
@Test
public void testFib() {
Fibonacci fib = new Fibonacci();
assertEquals(5, fib.fib(5));
}
}
"""
    original_test = tests_dir / "FibonacciTest.java"
    original_test.write_text(junit_source, encoding="utf-8")

    # Performance-mode instrumentation adds the inner loop.
    target = FunctionToOptimize(
        function_name="fib",
        file_path=fibonacci_path,
        starting_line=4,
        ending_line=7,
        parents=[],
        is_method=True,
        language="java",
    )
    ok, instrumented_source = instrument_existing_test(
        test_string=junit_source,
        function_to_optimize=target,
        mode="performance",
        test_path=original_test,
    )
    assert ok

    assert "CODEFLASH_INNER_ITERATIONS" in instrumented_source
    assert "for (int _cf_i1 = 0; _cf_i1 < _cf_innerIterations1; _cf_i1++)" in instrumented_source

    perf_test = tests_dir / "FibonacciTest__perfonlyinstrumented.java"
    perf_test.write_text(instrumented_source, encoding="utf-8")

    run_env = os.environ.copy()

    # Lightweight stand-ins exposing only the attributes set below, in place of
    # the real TestFiles/TestFile models.
    class MockTestFiles:
        def __init__(self, files):
            self.test_files = files

    class MockTestFile:
        def __init__(self, path):
            self.benchmarking_file_path = path
            self.instrumented_behavior_file_path = path

    mock_files = MockTestFiles([MockTestFile(perf_test)])

    # One outer loop, two inner iterations: enough to see two markers, still fast.
    _result_xml_path, completed = run_benchmarking_tests(
        test_paths=mock_files,
        test_env=run_env,
        cwd=root,
        timeout=120,
        project_root=root,
        min_loops=1,
        max_loops=1,
        target_duration_seconds=1.0,
        inner_iterations=2,
    )

    assert completed.returncode == 0, f"Maven test failed: {completed.stderr}"

    # Extract the start/end timing markers from Maven's stdout.
    captured = completed.stdout
    start_matches = re.findall(r"!\$######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+)######\$!", captured)
    end_matches = re.findall(r"!######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+):([^:]+)######!", captured)

    # inner_iterations=2 means two markers of each kind.
    assert len(start_matches) == 2, f"Expected 2 start markers, got {len(start_matches)}: {start_matches}"
    assert len(end_matches) == 2, f"Expected 2 end markers, got {len(end_matches)}: {end_matches}"

    # Field 4 of a start marker is the inner-iteration ID.
    iteration_ids = [marker[4] for marker in start_matches]
    for wanted_id in ("0", "1"):
        assert wanted_id in iteration_ids, f"Expected iteration ID {wanted_id}, got: {iteration_ids}"

    # Field 3 is the outer loop index, which must be 1 everywhere.
    loop_indices = [marker[3] for marker in start_matches]
    assert all(idx == "1" for idx in loop_indices), f"Expected all loop indices to be 1, got: {loop_indices}"

    # Field 5 of an end marker is the measured duration.
    durations = [int(marker[5]) for marker in end_matches]
    assert all(d > 0 for d in durations), f"Expected positive durations, got: {durations}"
def test_performance_mode_multiple_methods_inner_loop(self, java_project):
    """Check the inner loop when the test class has several test methods.

    A JUnit class with two ``@Test`` methods calling ``MathOps.add`` is
    instrumented in ``"performance"`` mode and run with ``max_loops=1`` and
    ``inner_iterations=2``. The test expects 2 test methods x 2 inner
    iterations = 4 end markers on stdout, two with iteration ID ``0`` and
    two with iteration ID ``1``.

    Args:
        java_project: Fixture yielding ``(project_root, src_dir, test_dir)``.
    """
    from codeflash.languages.java.test_runner import run_benchmarking_tests

    project_root, src_dir, test_dir = java_project

    # Class under optimization.
    (src_dir / "MathOps.java").write_text("""package com.example;
public class MathOps {
public int add(int a, int b) {
return a + b;
}
}
""", encoding="utf-8")

    # JUnit class with two independent test methods.
    test_source = """package com.example;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.*;
public class MathOpsTest {
@Test
public void testAddPositive() {
MathOps math = new MathOps();
assertEquals(5, math.add(2, 3));
}
@Test
public void testAddNegative() {
MathOps math = new MathOps();
assertEquals(-1, math.add(2, -3));
}
}
"""
    test_file = test_dir / "MathOpsTest.java"
    test_file.write_text(test_source, encoding="utf-8")

    # NOTE(review): starting_line/ending_line are passed through unchanged;
    # confirm they still match add's span in the Java file written above.
    func_info = FunctionToOptimize(
        function_name="add",
        file_path=src_dir / "MathOps.java",
        starting_line=4,
        ending_line=6,
        parents=[],
        is_method=True,
        language="java",
    )

    success, instrumented = instrument_existing_test(
        test_string=test_source, function_to_optimize=func_info, mode="performance", test_path=test_file
    )
    assert success

    instrumented_file = test_dir / "MathOpsTest__perfonlyinstrumented.java"
    instrumented_file.write_text(instrumented, encoding="utf-8")

    test_env = os.environ.copy()

    # Minimal stand-ins exposing only the attributes set below; the same
    # instrumented file is used for both path attributes.
    class MockTestFiles:
        def __init__(self, files):
            self.test_files = files

    class MockTestFile:
        def __init__(self, path):
            self.benchmarking_file_path = path
            self.instrumented_behavior_file_path = path

    test_files = MockTestFiles([MockTestFile(instrumented_file)])

    # The first element of the returned pair (XML result path) is not needed here.
    _, result = run_benchmarking_tests(
        test_paths=test_files,
        test_env=test_env,
        cwd=project_root,
        timeout=120,
        project_root=project_root,
        min_loops=1,
        max_loops=1,
        target_duration_seconds=1.0,
        inner_iterations=2,
    )

    stdout = result.stdout

    # Maven typically logs build and test failures to stdout, so stderr alone
    # is often not enough to diagnose a failure; include the tail of stdout too.
    assert result.returncode == 0, (
        f"Maven test failed: {result.stderr}\n--- stdout (tail) ---\n{(stdout or '')[-4000:]}"
    )

    # End markers: colon-separated fields wrapped in sentinel strings.
    end_pattern = re.compile(r"!######([^:]*):([^:]*):([^:]*):([^:]*):([^:]+):([^:]+)######!")
    end_matches = end_pattern.findall(stdout)

    # Should have 4 timing markers (2 test methods x 2 inner iterations)
    assert len(end_matches) == 4, f"Expected 4 end markers, got {len(end_matches)}: {end_matches}"

    # Fifth field is the inner iteration ID; each ID should appear once per test method.
    iter_0_count = sum(1 for m in end_matches if m[4] == "0")
    iter_1_count = sum(1 for m in end_matches if m[4] == "1")
    assert iter_0_count == 2, f"Expected 2 markers for iteration 0, got {iter_0_count}"
    assert iter_1_count == 2, f"Expected 2 markers for iteration 1, got {iter_1_count}"