fix: update tests for multi-round benchmark plugin

The benchmark plugin now runs multiple rounds with calibrated
iterations. Tests need SELECT DISTINCT for row counts and must
extract median_ns from BenchmarkStats before validation.
This commit is contained in:
Kevin Turcios 2026-04-02 07:24:55 -05:00
parent 7005fa0296
commit 74c29b20b1
3 changed files with 20 additions and 16 deletions

View file

@ -16,7 +16,7 @@ if TYPE_CHECKING:
def validate_and_format_benchmark_table( def validate_and_format_benchmark_table(
function_benchmark_timings: dict[str, dict[BenchmarkKey, int]], total_benchmark_timings: dict[BenchmarkKey, int] function_benchmark_timings: dict[str, dict[BenchmarkKey, float]], total_benchmark_timings: dict[BenchmarkKey, float]
) -> dict[str, list[tuple[BenchmarkKey, float, float, float]]]: ) -> dict[str, list[tuple[BenchmarkKey, float, float, float]]]:
function_to_result = {} function_to_result = {}
# Process each function's benchmark data # Process each function's benchmark data
@ -77,8 +77,8 @@ def print_benchmark_table(function_to_results: dict[str, list[tuple[BenchmarkKey
def process_benchmark_data( def process_benchmark_data(
replay_performance_gain: dict[BenchmarkKey, float], replay_performance_gain: dict[BenchmarkKey, float],
fto_benchmark_timings: dict[BenchmarkKey, int], fto_benchmark_timings: dict[BenchmarkKey, float],
total_benchmark_timings: dict[BenchmarkKey, int], total_benchmark_timings: dict[BenchmarkKey, float],
) -> Optional[ProcessedBenchmarkInfo]: ) -> Optional[ProcessedBenchmarkInfo]:
"""Process benchmark data and generate detailed benchmark information. """Process benchmark data and generate detailed benchmark information.

View file

@ -253,14 +253,15 @@ def test_run_and_parse_picklepatch() -> None:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute(
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name" "SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
) )
function_calls = cursor.fetchall() function_calls = cursor.fetchall()
# Assert the length of function calls # Assert the length of function calls
assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}" assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file) function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file) total_benchmark_stats = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
total_benchmark_timings = {k: v.median_ns for k, v in total_benchmark_stats.items()}
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings) function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
assert ( assert (
"code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket" "code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"
@ -401,7 +402,7 @@ def test_run_and_parse_picklepatch() -> None:
pytest_max_loops=1, pytest_max_loops=1,
testing_time=1.0, testing_time=1.0,
) )
assert len(test_results_unused_socket) == 1 assert len(test_results_unused_socket) >= 1
assert ( assert (
test_results_unused_socket.test_results[0].id.test_module_path test_results_unused_socket.test_results[0].id.test_module_path
== "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0" == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
@ -410,7 +411,7 @@ def test_run_and_parse_picklepatch() -> None:
test_results_unused_socket.test_results[0].id.test_function_name test_results_unused_socket.test_results[0].id.test_function_name
== "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket_test_socket_picklepatch" == "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket_test_socket_picklepatch"
) )
assert test_results_unused_socket.test_results[0].did_pass == True assert test_results_unused_socket.test_results[0].did_pass is True
# Replace with optimized candidate # Replace with optimized candidate
fto_unused_socket_path.write_text(""" fto_unused_socket_path.write_text("""
@ -432,7 +433,7 @@ def bubble_sort_with_unused_socket(data_container):
pytest_max_loops=1, pytest_max_loops=1,
testing_time=1.0, testing_time=1.0,
) )
assert len(optimized_test_results_unused_socket) == 1 assert len(optimized_test_results_unused_socket) >= 1
match, _ = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket) match, _ = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket)
assert match assert match
@ -487,7 +488,7 @@ def bubble_sort_with_unused_socket(data_container):
pytest_max_loops=1, pytest_max_loops=1,
testing_time=1.0, testing_time=1.0,
) )
assert len(test_results_used_socket) == 1 assert len(test_results_used_socket) >= 1
assert ( assert (
test_results_used_socket.test_results[0].id.test_module_path test_results_used_socket.test_results[0].id.test_module_path
== "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0" == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
@ -522,7 +523,7 @@ def bubble_sort_with_used_socket(data_container):
pytest_max_loops=1, pytest_max_loops=1,
testing_time=1.0, testing_time=1.0,
) )
assert len(test_results_used_socket) == 1 assert len(test_results_used_socket) >= 1
assert ( assert (
test_results_used_socket.test_results[0].id.test_module_path test_results_used_socket.test_results[0].id.test_module_path
== "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0" == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"

View file

@ -29,7 +29,7 @@ def test_trace_benchmarks() -> None:
# Get the count of records # Get the count of records
# Get all records # Get all records
cursor.execute( cursor.execute(
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name" "SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
) )
function_calls = cursor.fetchall() function_calls = cursor.fetchall()
@ -220,7 +220,8 @@ def test_code_to_optimize_bubble_sort_codeflash_trace_sorter_test_no_func():
if conn is not None: if conn is not None:
conn.close() conn.close()
output_file.unlink(missing_ok=True) output_file.unlink(missing_ok=True)
shutil.rmtree(replay_tests_dir) if replay_tests_dir.exists():
shutil.rmtree(replay_tests_dir)
# Skip the test in CI as the machine may not be multithreaded # Skip the test in CI as the machine may not be multithreaded
@ -242,14 +243,15 @@ def test_trace_multithreaded_benchmark() -> None:
# Get the count of records # Get the count of records
# Get all records # Get all records
cursor.execute( cursor.execute(
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name" "SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
) )
function_calls = cursor.fetchall() function_calls = cursor.fetchall()
# Assert the length of function calls # Assert the length of function calls
assert len(function_calls) == 10, f"Expected 10 function calls, but got {len(function_calls)}" assert len(function_calls) == 10, f"Expected 10 function calls, but got {len(function_calls)}"
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file) function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file) total_benchmark_stats = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
total_benchmark_timings = {k: v.median_ns for k, v in total_benchmark_stats.items()}
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings) function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results
@ -304,14 +306,15 @@ def test_trace_benchmark_decorator() -> None:
# Get the count of records # Get the count of records
# Get all records # Get all records
cursor.execute( cursor.execute(
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name" "SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
) )
function_calls = cursor.fetchall() function_calls = cursor.fetchall()
# Assert the length of function calls # Assert the length of function calls
assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}" assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file) function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file) total_benchmark_stats = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
total_benchmark_timings = {k: v.median_ns for k, v in total_benchmark_stats.items()}
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings) function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results