mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
fix: update tests for multi-round benchmark plugin
The benchmark plugin now runs multiple rounds with calibrated iterations. Tests need SELECT DISTINCT for row counts and must extract median_ns from BenchmarkStats before validation.
This commit is contained in:
parent
7005fa0296
commit
74c29b20b1
3 changed files with 20 additions and 16 deletions
|
|
@ -16,7 +16,7 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
|
|
||||||
def validate_and_format_benchmark_table(
|
def validate_and_format_benchmark_table(
|
||||||
function_benchmark_timings: dict[str, dict[BenchmarkKey, int]], total_benchmark_timings: dict[BenchmarkKey, int]
|
function_benchmark_timings: dict[str, dict[BenchmarkKey, float]], total_benchmark_timings: dict[BenchmarkKey, float]
|
||||||
) -> dict[str, list[tuple[BenchmarkKey, float, float, float]]]:
|
) -> dict[str, list[tuple[BenchmarkKey, float, float, float]]]:
|
||||||
function_to_result = {}
|
function_to_result = {}
|
||||||
# Process each function's benchmark data
|
# Process each function's benchmark data
|
||||||
|
|
@ -77,8 +77,8 @@ def print_benchmark_table(function_to_results: dict[str, list[tuple[BenchmarkKey
|
||||||
|
|
||||||
def process_benchmark_data(
|
def process_benchmark_data(
|
||||||
replay_performance_gain: dict[BenchmarkKey, float],
|
replay_performance_gain: dict[BenchmarkKey, float],
|
||||||
fto_benchmark_timings: dict[BenchmarkKey, int],
|
fto_benchmark_timings: dict[BenchmarkKey, float],
|
||||||
total_benchmark_timings: dict[BenchmarkKey, int],
|
total_benchmark_timings: dict[BenchmarkKey, float],
|
||||||
) -> Optional[ProcessedBenchmarkInfo]:
|
) -> Optional[ProcessedBenchmarkInfo]:
|
||||||
"""Process benchmark data and generate detailed benchmark information.
|
"""Process benchmark data and generate detailed benchmark information.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -253,14 +253,15 @@ def test_run_and_parse_picklepatch() -> None:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
|
"SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
|
||||||
)
|
)
|
||||||
function_calls = cursor.fetchall()
|
function_calls = cursor.fetchall()
|
||||||
|
|
||||||
# Assert the length of function calls
|
# Assert the length of function calls
|
||||||
assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
|
assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
|
||||||
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
|
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
|
||||||
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
|
total_benchmark_stats = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
|
||||||
|
total_benchmark_timings = {k: v.median_ns for k, v in total_benchmark_stats.items()}
|
||||||
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
|
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
|
||||||
assert (
|
assert (
|
||||||
"code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"
|
"code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"
|
||||||
|
|
@ -401,7 +402,7 @@ def test_run_and_parse_picklepatch() -> None:
|
||||||
pytest_max_loops=1,
|
pytest_max_loops=1,
|
||||||
testing_time=1.0,
|
testing_time=1.0,
|
||||||
)
|
)
|
||||||
assert len(test_results_unused_socket) == 1
|
assert len(test_results_unused_socket) >= 1
|
||||||
assert (
|
assert (
|
||||||
test_results_unused_socket.test_results[0].id.test_module_path
|
test_results_unused_socket.test_results[0].id.test_module_path
|
||||||
== "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
|
== "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
|
||||||
|
|
@ -410,7 +411,7 @@ def test_run_and_parse_picklepatch() -> None:
|
||||||
test_results_unused_socket.test_results[0].id.test_function_name
|
test_results_unused_socket.test_results[0].id.test_function_name
|
||||||
== "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket_test_socket_picklepatch"
|
== "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket_test_socket_picklepatch"
|
||||||
)
|
)
|
||||||
assert test_results_unused_socket.test_results[0].did_pass == True
|
assert test_results_unused_socket.test_results[0].did_pass is True
|
||||||
|
|
||||||
# Replace with optimized candidate
|
# Replace with optimized candidate
|
||||||
fto_unused_socket_path.write_text("""
|
fto_unused_socket_path.write_text("""
|
||||||
|
|
@ -432,7 +433,7 @@ def bubble_sort_with_unused_socket(data_container):
|
||||||
pytest_max_loops=1,
|
pytest_max_loops=1,
|
||||||
testing_time=1.0,
|
testing_time=1.0,
|
||||||
)
|
)
|
||||||
assert len(optimized_test_results_unused_socket) == 1
|
assert len(optimized_test_results_unused_socket) >= 1
|
||||||
match, _ = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket)
|
match, _ = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket)
|
||||||
assert match
|
assert match
|
||||||
|
|
||||||
|
|
@ -487,7 +488,7 @@ def bubble_sort_with_unused_socket(data_container):
|
||||||
pytest_max_loops=1,
|
pytest_max_loops=1,
|
||||||
testing_time=1.0,
|
testing_time=1.0,
|
||||||
)
|
)
|
||||||
assert len(test_results_used_socket) == 1
|
assert len(test_results_used_socket) >= 1
|
||||||
assert (
|
assert (
|
||||||
test_results_used_socket.test_results[0].id.test_module_path
|
test_results_used_socket.test_results[0].id.test_module_path
|
||||||
== "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
|
== "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
|
||||||
|
|
@ -522,7 +523,7 @@ def bubble_sort_with_used_socket(data_container):
|
||||||
pytest_max_loops=1,
|
pytest_max_loops=1,
|
||||||
testing_time=1.0,
|
testing_time=1.0,
|
||||||
)
|
)
|
||||||
assert len(test_results_used_socket) == 1
|
assert len(test_results_used_socket) >= 1
|
||||||
assert (
|
assert (
|
||||||
test_results_used_socket.test_results[0].id.test_module_path
|
test_results_used_socket.test_results[0].id.test_module_path
|
||||||
== "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
|
== "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ def test_trace_benchmarks() -> None:
|
||||||
# Get the count of records
|
# Get the count of records
|
||||||
# Get all records
|
# Get all records
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
|
"SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
|
||||||
)
|
)
|
||||||
function_calls = cursor.fetchall()
|
function_calls = cursor.fetchall()
|
||||||
|
|
||||||
|
|
@ -220,7 +220,8 @@ def test_code_to_optimize_bubble_sort_codeflash_trace_sorter_test_no_func():
|
||||||
if conn is not None:
|
if conn is not None:
|
||||||
conn.close()
|
conn.close()
|
||||||
output_file.unlink(missing_ok=True)
|
output_file.unlink(missing_ok=True)
|
||||||
shutil.rmtree(replay_tests_dir)
|
if replay_tests_dir.exists():
|
||||||
|
shutil.rmtree(replay_tests_dir)
|
||||||
|
|
||||||
|
|
||||||
# Skip the test in CI as the machine may not be multithreaded
|
# Skip the test in CI as the machine may not be multithreaded
|
||||||
|
|
@ -242,14 +243,15 @@ def test_trace_multithreaded_benchmark() -> None:
|
||||||
# Get the count of records
|
# Get the count of records
|
||||||
# Get all records
|
# Get all records
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
|
"SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
|
||||||
)
|
)
|
||||||
function_calls = cursor.fetchall()
|
function_calls = cursor.fetchall()
|
||||||
|
|
||||||
# Assert the length of function calls
|
# Assert the length of function calls
|
||||||
assert len(function_calls) == 10, f"Expected 10 function calls, but got {len(function_calls)}"
|
assert len(function_calls) == 10, f"Expected 10 function calls, but got {len(function_calls)}"
|
||||||
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
|
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
|
||||||
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
|
total_benchmark_stats = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
|
||||||
|
total_benchmark_timings = {k: v.median_ns for k, v in total_benchmark_stats.items()}
|
||||||
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
|
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
|
||||||
assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results
|
assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results
|
||||||
|
|
||||||
|
|
@ -304,14 +306,15 @@ def test_trace_benchmark_decorator() -> None:
|
||||||
# Get the count of records
|
# Get the count of records
|
||||||
# Get all records
|
# Get all records
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
|
"SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
|
||||||
)
|
)
|
||||||
function_calls = cursor.fetchall()
|
function_calls = cursor.fetchall()
|
||||||
|
|
||||||
# Assert the length of function calls
|
# Assert the length of function calls
|
||||||
assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
|
assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
|
||||||
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
|
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
|
||||||
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
|
total_benchmark_stats = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
|
||||||
|
total_benchmark_timings = {k: v.median_ns for k, v in total_benchmark_stats.items()}
|
||||||
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
|
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
|
||||||
assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results
|
assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue