fix: update tests for multi-round benchmark plugin

The benchmark plugin now runs multiple rounds with calibrated
iterations. Tests need SELECT DISTINCT for row counts and must
extract median_ns from BenchmarkStats before validation.
This commit is contained in:
Kevin Turcios 2026-04-02 07:24:55 -05:00
parent 7005fa0296
commit 74c29b20b1
3 changed files with 20 additions and 16 deletions

View file

@ -16,7 +16,7 @@ if TYPE_CHECKING:
def validate_and_format_benchmark_table(
function_benchmark_timings: dict[str, dict[BenchmarkKey, int]], total_benchmark_timings: dict[BenchmarkKey, int]
function_benchmark_timings: dict[str, dict[BenchmarkKey, float]], total_benchmark_timings: dict[BenchmarkKey, float]
) -> dict[str, list[tuple[BenchmarkKey, float, float, float]]]:
function_to_result = {}
# Process each function's benchmark data
@ -77,8 +77,8 @@ def print_benchmark_table(function_to_results: dict[str, list[tuple[BenchmarkKey
def process_benchmark_data(
replay_performance_gain: dict[BenchmarkKey, float],
fto_benchmark_timings: dict[BenchmarkKey, int],
total_benchmark_timings: dict[BenchmarkKey, int],
fto_benchmark_timings: dict[BenchmarkKey, float],
total_benchmark_timings: dict[BenchmarkKey, float],
) -> Optional[ProcessedBenchmarkInfo]:
"""Process benchmark data and generate detailed benchmark information.

View file

@ -253,14 +253,15 @@ def test_run_and_parse_picklepatch() -> None:
cursor = conn.cursor()
cursor.execute(
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
"SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
)
function_calls = cursor.fetchall()
# Assert the length of function calls
assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
total_benchmark_stats = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
total_benchmark_timings = {k: v.median_ns for k, v in total_benchmark_stats.items()}
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
assert (
"code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"
@ -401,7 +402,7 @@ def test_run_and_parse_picklepatch() -> None:
pytest_max_loops=1,
testing_time=1.0,
)
assert len(test_results_unused_socket) == 1
assert len(test_results_unused_socket) >= 1
assert (
test_results_unused_socket.test_results[0].id.test_module_path
== "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
@ -410,7 +411,7 @@ def test_run_and_parse_picklepatch() -> None:
test_results_unused_socket.test_results[0].id.test_function_name
== "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket_test_socket_picklepatch"
)
assert test_results_unused_socket.test_results[0].did_pass == True
assert test_results_unused_socket.test_results[0].did_pass is True
# Replace with optimized candidate
fto_unused_socket_path.write_text("""
@ -432,7 +433,7 @@ def bubble_sort_with_unused_socket(data_container):
pytest_max_loops=1,
testing_time=1.0,
)
assert len(optimized_test_results_unused_socket) == 1
assert len(optimized_test_results_unused_socket) >= 1
match, _ = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket)
assert match
@ -487,7 +488,7 @@ def bubble_sort_with_unused_socket(data_container):
pytest_max_loops=1,
testing_time=1.0,
)
assert len(test_results_used_socket) == 1
assert len(test_results_used_socket) >= 1
assert (
test_results_used_socket.test_results[0].id.test_module_path
== "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
@ -522,7 +523,7 @@ def bubble_sort_with_used_socket(data_container):
pytest_max_loops=1,
testing_time=1.0,
)
assert len(test_results_used_socket) == 1
assert len(test_results_used_socket) >= 1
assert (
test_results_used_socket.test_results[0].id.test_module_path
== "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"

View file

@ -29,7 +29,7 @@ def test_trace_benchmarks() -> None:
# Get the count of records
# Get all records
cursor.execute(
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
"SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
)
function_calls = cursor.fetchall()
@ -220,7 +220,8 @@ def test_code_to_optimize_bubble_sort_codeflash_trace_sorter_test_no_func():
if conn is not None:
conn.close()
output_file.unlink(missing_ok=True)
shutil.rmtree(replay_tests_dir)
if replay_tests_dir.exists():
shutil.rmtree(replay_tests_dir)
# Skip the test in CI as the machine may not be multithreaded
@ -242,14 +243,15 @@ def test_trace_multithreaded_benchmark() -> None:
# Get the count of records
# Get all records
cursor.execute(
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
"SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
)
function_calls = cursor.fetchall()
# Assert the length of function calls
assert len(function_calls) == 10, f"Expected 10 function calls, but got {len(function_calls)}"
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
total_benchmark_stats = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
total_benchmark_timings = {k: v.median_ns for k, v in total_benchmark_stats.items()}
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results
@ -304,14 +306,15 @@ def test_trace_benchmark_decorator() -> None:
# Get the count of records
# Get all records
cursor.execute(
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
"SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
)
function_calls = cursor.fetchall()
# Assert the length of function calls
assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
total_benchmark_stats = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
total_benchmark_timings = {k: v.median_ns for k, v in total_benchmark_stats.items()}
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results