diff --git a/codeflash/benchmarking/utils.py b/codeflash/benchmarking/utils.py index db89c4c33..6a84dc5f3 100644 --- a/codeflash/benchmarking/utils.py +++ b/codeflash/benchmarking/utils.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: def validate_and_format_benchmark_table( - function_benchmark_timings: dict[str, dict[BenchmarkKey, int]], total_benchmark_timings: dict[BenchmarkKey, int] + function_benchmark_timings: dict[str, dict[BenchmarkKey, float]], total_benchmark_timings: dict[BenchmarkKey, float] ) -> dict[str, list[tuple[BenchmarkKey, float, float, float]]]: function_to_result = {} # Process each function's benchmark data @@ -77,8 +77,8 @@ def print_benchmark_table(function_to_results: dict[str, list[tuple[BenchmarkKey def process_benchmark_data( replay_performance_gain: dict[BenchmarkKey, float], - fto_benchmark_timings: dict[BenchmarkKey, int], - total_benchmark_timings: dict[BenchmarkKey, int], + fto_benchmark_timings: dict[BenchmarkKey, float], + total_benchmark_timings: dict[BenchmarkKey, float], ) -> Optional[ProcessedBenchmarkInfo]: """Process benchmark data and generate detailed benchmark information. diff --git a/tests/test_pickle_patcher.py b/tests/test_pickle_patcher.py index 804ff137b..ccf89312a 100644 --- a/tests/test_pickle_patcher.py +++ b/tests/test_pickle_patcher.py @@ -253,14 +253,15 @@ def test_run_and_parse_picklepatch() -> None: cursor = conn.cursor() cursor.execute( - "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name" + "SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name" ) function_calls = cursor.fetchall() # Assert the length of function calls assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}" function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file) - total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file) + total_benchmark_stats = codeflash_benchmark_plugin.get_benchmark_timings(output_file) + total_benchmark_timings = {k: v.median_ns for k, v in total_benchmark_stats.items()} function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings) assert ( "code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket" @@ -401,7 +402,7 @@ def test_run_and_parse_picklepatch() -> None: pytest_max_loops=1, testing_time=1.0, ) - assert len(test_results_unused_socket) == 1 + assert len(test_results_unused_socket) >= 1 assert ( test_results_unused_socket.test_results[0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0" @@ -410,7 +411,7 @@ def test_run_and_parse_picklepatch() -> None: test_results_unused_socket.test_results[0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket_test_socket_picklepatch" ) - assert test_results_unused_socket.test_results[0].did_pass == True + assert test_results_unused_socket.test_results[0].did_pass is True # Replace with optimized candidate fto_unused_socket_path.write_text(""" @@ -432,7 +433,7 @@ def bubble_sort_with_unused_socket(data_container): pytest_max_loops=1, testing_time=1.0, ) - assert len(optimized_test_results_unused_socket) == 1 + assert len(optimized_test_results_unused_socket) >= 1 match, _ = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket) assert match @@ -487,7 +488,7 @@ def bubble_sort_with_unused_socket(data_container): pytest_max_loops=1, testing_time=1.0, ) - assert len(test_results_used_socket) == 1 + assert len(test_results_used_socket) >= 1 assert ( test_results_used_socket.test_results[0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0" @@ -522,7 +523,7 @@ def bubble_sort_with_used_socket(data_container): pytest_max_loops=1, testing_time=1.0, ) - assert len(test_results_used_socket) == 1 + assert len(test_results_used_socket) >= 1 assert ( test_results_used_socket.test_results[0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0" diff --git a/tests/test_trace_benchmarks.py b/tests/test_trace_benchmarks.py index 4e0f7be47..c8e18de9c 100644 --- a/tests/test_trace_benchmarks.py +++ b/tests/test_trace_benchmarks.py @@ -29,7 +29,7 @@ def test_trace_benchmarks() -> None: # Get the count of records # Get all records cursor.execute( - "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name" + "SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name" ) function_calls = cursor.fetchall() @@ -220,7 +220,8 @@ def test_code_to_optimize_bubble_sort_codeflash_trace_sorter_test_no_func(): if conn is not None: conn.close() output_file.unlink(missing_ok=True) - shutil.rmtree(replay_tests_dir) + if replay_tests_dir.exists(): + shutil.rmtree(replay_tests_dir) # Skip the test in CI as the machine may not be multithreaded @@ -242,14 +243,15 @@ def test_trace_multithreaded_benchmark() -> None: # Get the count of records # Get all records cursor.execute( - "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name" + "SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name" ) function_calls = cursor.fetchall() # Assert the length of function calls assert len(function_calls) == 10, f"Expected 10 function calls, but got {len(function_calls)}" function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file) - total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file) + total_benchmark_stats = codeflash_benchmark_plugin.get_benchmark_timings(output_file) + total_benchmark_timings = {k: v.median_ns for k, v in total_benchmark_stats.items()} function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings) assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results @@ -304,14 +306,15 @@ def test_trace_benchmark_decorator() -> None: # Get the count of records # Get all records cursor.execute( - "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name" + "SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name" ) function_calls = cursor.fetchall() # Assert the length of function calls assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}" function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file) - total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file) + total_benchmark_stats = codeflash_benchmark_plugin.get_benchmark_timings(output_file) + total_benchmark_timings = {k: v.median_ns for k, v in total_benchmark_stats.items()} function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings) assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results