fix: update tests for multi-round benchmark plugin
The benchmark plugin now runs multiple rounds with calibrated iterations, so each benchmark can record several timing rows per function. Tests must therefore use SELECT DISTINCT when counting traced function calls, and must reduce the BenchmarkStats returned by get_benchmark_timings to median_ns before validation.
parent 7005fa0296
commit 74c29b20b1
3 changed files with 20 additions and 16 deletions
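
In code terms, the new consumption pattern looks roughly like the minimal sketch below. Only the median_ns field is confirmed by this diff, so the class is a hypothetical stand-in, and string keys stand in for BenchmarkKey:

    from dataclasses import dataclass

    # Hypothetical stand-in for the plugin's BenchmarkStats: only median_ns is
    # confirmed by this diff; any other fields on the real class are unknown here.
    @dataclass
    class BenchmarkStats:
        median_ns: float  # median timing across the calibrated rounds, in nanoseconds

    # get_benchmark_timings now yields stats objects rather than bare ints, so
    # callers reduce each one to a single number before validation:
    total_benchmark_stats = {"test_sort[benchmark]": BenchmarkStats(median_ns=1200.5)}
    total_benchmark_timings = {k: v.median_ns for k, v in total_benchmark_stats.items()}
    assert total_benchmark_timings == {"test_sort[benchmark]": 1200.5}
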
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
 def validate_and_format_benchmark_table(
-    function_benchmark_timings: dict[str, dict[BenchmarkKey, int]], total_benchmark_timings: dict[BenchmarkKey, int]
+    function_benchmark_timings: dict[str, dict[BenchmarkKey, float]], total_benchmark_timings: dict[BenchmarkKey, float]
 ) -> dict[str, list[tuple[BenchmarkKey, float, float, float]]]:
     function_to_result = {}
     # Process each function's benchmark data
@@ -77,8 +77,8 @@ def print_benchmark_table(function_to_results: dict[str, list[tuple[BenchmarkKey
 def process_benchmark_data(
     replay_performance_gain: dict[BenchmarkKey, float],
-    fto_benchmark_timings: dict[BenchmarkKey, int],
-    total_benchmark_timings: dict[BenchmarkKey, int],
+    fto_benchmark_timings: dict[BenchmarkKey, float],
+    total_benchmark_timings: dict[BenchmarkKey, float],
 ) -> Optional[ProcessedBenchmarkInfo]:
     """Process benchmark data and generate detailed benchmark information.
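
One plausible reason for the int-to-float widening in the hunks above (an inference; the commit message does not say): once per-benchmark timings are medians over multiple rounds, integer nanosecond samples can aggregate to a fractional value. A quick illustration:

    import statistics

    # The median of an even number of integer samples is the mean of the two
    # middle values, so aggregated nanosecond timings are naturally floats.
    samples_ns = [100, 101, 103, 110]
    print(statistics.median(samples_ns))  # 102.0
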
@@ -253,14 +253,15 @@ def test_run_and_parse_picklepatch() -> None:
     cursor = conn.cursor()

     cursor.execute(
-        "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
+        "SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
     )
     function_calls = cursor.fetchall()

     # Assert the length of function calls
     assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
     function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
-    total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
+    total_benchmark_stats = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
+    total_benchmark_timings = {k: v.median_ns for k, v in total_benchmark_stats.items()}
     function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
     assert (
         "code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"
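
The switch to SELECT DISTINCT above guards against multi-round runs: each round inserts its own timing row for the same function, so a plain row count would scale with the number of rounds. A minimal sqlite3 sketch with an assumed two-column schema (the real table has the columns listed in the query above):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE benchmark_function_timings (function_name TEXT, time_ns INTEGER)")
    # Three rounds of the same benchmark record the same function three times,
    # each with its own timing.
    conn.executemany(
        "INSERT INTO benchmark_function_timings VALUES (?, ?)",
        [("sorter", 100), ("sorter", 98), ("sorter", 101)],
    )
    plain = conn.execute("SELECT function_name FROM benchmark_function_timings").fetchall()
    distinct = conn.execute("SELECT DISTINCT function_name FROM benchmark_function_timings").fetchall()
    assert len(plain) == 3 and len(distinct) == 1  # DISTINCT collapses per-round duplicates
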
@@ -401,7 +402,7 @@ def test_run_and_parse_picklepatch() -> None:
         pytest_max_loops=1,
         testing_time=1.0,
     )
-    assert len(test_results_unused_socket) == 1
+    assert len(test_results_unused_socket) >= 1
     assert (
         test_results_unused_socket.test_results[0].id.test_module_path
         == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
@@ -410,7 +411,7 @@ def test_run_and_parse_picklepatch() -> None:
         test_results_unused_socket.test_results[0].id.test_function_name
         == "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket_test_socket_picklepatch"
     )
-    assert test_results_unused_socket.test_results[0].did_pass == True
+    assert test_results_unused_socket.test_results[0].did_pass is True

     # Replace with optimized candidate
     fto_unused_socket_path.write_text("""
@@ -432,7 +433,7 @@ def bubble_sort_with_unused_socket(data_container):
         pytest_max_loops=1,
         testing_time=1.0,
     )
-    assert len(optimized_test_results_unused_socket) == 1
+    assert len(optimized_test_results_unused_socket) >= 1
     match, _ = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket)
     assert match

@@ -487,7 +488,7 @@ def bubble_sort_with_unused_socket(data_container):
         pytest_max_loops=1,
         testing_time=1.0,
     )
-    assert len(test_results_used_socket) == 1
+    assert len(test_results_used_socket) >= 1
     assert (
         test_results_used_socket.test_results[0].id.test_module_path
         == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
@@ -522,7 +523,7 @@ def bubble_sort_with_used_socket(data_container):
         pytest_max_loops=1,
         testing_time=1.0,
     )
-    assert len(test_results_used_socket) == 1
+    assert len(test_results_used_socket) >= 1
     assert (
         test_results_used_socket.test_results[0].id.test_module_path
         == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"

@@ -29,7 +29,7 @@ def test_trace_benchmarks() -> None:
-        # Get the count of records
+        # Get all records
         cursor.execute(
-            "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
+            "SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
         )
         function_calls = cursor.fetchall()
@@ -220,7 +220,8 @@ def test_code_to_optimize_bubble_sort_codeflash_trace_sorter_test_no_func():
         if conn is not None:
             conn.close()
         output_file.unlink(missing_ok=True)
-        shutil.rmtree(replay_tests_dir)
+        if replay_tests_dir.exists():
+            shutil.rmtree(replay_tests_dir)


 # Skip the test in CI as the machine may not be multithreaded
@@ -242,14 +243,15 @@ def test_trace_multithreaded_benchmark() -> None:
-        # Get the count of records
+        # Get all records
         cursor.execute(
-            "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
+            "SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
         )
         function_calls = cursor.fetchall()

         # Assert the length of function calls
         assert len(function_calls) == 10, f"Expected 10 function calls, but got {len(function_calls)}"
         function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
-        total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
+        total_benchmark_stats = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
+        total_benchmark_timings = {k: v.median_ns for k, v in total_benchmark_stats.items()}
         function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
         assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results
@@ -304,14 +306,15 @@ def test_trace_benchmark_decorator() -> None:
-        # Get the count of records
+        # Get all records
         cursor.execute(
-            "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
+            "SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
         )
         function_calls = cursor.fetchall()

         # Assert the length of function calls
         assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
         function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
-        total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
+        total_benchmark_stats = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
+        total_benchmark_timings = {k: v.median_ns for k, v in total_benchmark_stats.items()}
         function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
         assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results