codeflash/codeflash/benchmarking/utils.py
Saurabh Misra 291ac2c294 fix test
ruff reformat and fix linting
2025-06-05 20:01:09 -07:00

132 lines
5.6 KiB
Python

from __future__ import annotations
import shutil
from typing import TYPE_CHECKING, Optional
from rich.console import Console
from rich.table import Table
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.models.models import BenchmarkDetail, ProcessedBenchmarkInfo
from codeflash.result.critic import performance_gain
if TYPE_CHECKING:
from codeflash.models.models import BenchmarkKey
def validate_and_format_benchmark_table(
    function_benchmark_timings: dict[str, dict[BenchmarkKey, int]], total_benchmark_timings: dict[BenchmarkKey, int]
) -> dict[str, list[tuple[BenchmarkKey, float, float, float]]]:
    """Build per-function benchmark rows, sorted by the function's share of total time.

    Each row is ``(benchmark_key, total_ms, func_ms, percentage)``. A row of zeros
    marks a benchmark whose per-function time exceeded the total (likely a
    multithreading/multiprocessing artifact), so no impact is projected for it.
    Benchmarks with no positive total timing are dropped entirely.
    """
    results: dict[str, list[tuple[BenchmarkKey, float, float, float]]] = {}
    for func_path, per_benchmark in function_benchmark_timings.items():
        rows: list[tuple[BenchmarkKey, float, float, float]] = []
        for benchmark_key, func_time in per_benchmark.items():
            total_time = total_benchmark_timings.get(benchmark_key, 0)
            if func_time > total_time:
                # Function time above total time: measurement is unreliable,
                # so record a sentinel row instead of a projection.
                logger.debug(f"Skipping test {benchmark_key} due to func_time {func_time} > total_time {total_time}")
                rows.append((benchmark_key, 0.0, 0.0, 0.0))
            elif total_time > 0:
                # Timings arrive in nanoseconds; report milliseconds.
                rows.append(
                    (benchmark_key, total_time / 1_000_000, func_time / 1_000_000, (func_time / total_time) * 100)
                )
        # Highest percentage first.
        rows.sort(key=lambda row: row[3], reverse=True)
        results[func_path] = rows
    return results
def print_benchmark_table(function_to_results: dict[str, list[tuple[BenchmarkKey, float, float, float]]]) -> None:
    """Render one rich table per function, listing each benchmark's timings and share.

    Rows whose total time is 0.0 (the unreliable-measurement sentinel) are shown
    as "N/A" instead of numbers. Output width adapts to the terminal.
    """
    try:
        width = int(shutil.get_terminal_size().columns * 0.9)
    except Exception:
        width = 120  # Fallback width
    console = Console(width=width)
    for func_path, rows in function_to_results.items():
        console.print()
        # func_path is "module:qualified_name"; show just the function part.
        function_name = func_path.split(":")[-1]
        table = Table(title=f"Function: {function_name}", width=width, border_style="blue", show_lines=True)
        # The module-path column gets ~40% of the width (min 40 cols);
        # the benchmark identifier is split across two columns.
        path_col_width = max(int(width * 0.4), 40)
        table.add_column("Benchmark Module Path", width=path_col_width, style="cyan", overflow="fold")
        table.add_column("Test Function", style="magenta", overflow="fold")
        table.add_column("Total Time (ms)", justify="right", style="green")
        table.add_column("Function Time (ms)", justify="right", style="yellow")
        table.add_column("Percentage (%)", justify="right", style="red")
        for key, total_ms, func_ms, pct in rows:
            if total_ms == 0.0:
                table.add_row(key.module_path, key.function_name, "N/A", "N/A", "N/A")
            else:
                table.add_row(key.module_path, key.function_name, f"{total_ms:.3f}", f"{func_ms:.3f}", f"{pct:.2f}")
        console.print(table)
def process_benchmark_data(
    replay_performance_gain: dict[BenchmarkKey, float],
    fto_benchmark_timings: dict[BenchmarkKey, int],
    total_benchmark_timings: dict[BenchmarkKey, int],
) -> Optional[ProcessedBenchmarkInfo]:
    """Process benchmark data and generate detailed benchmark information.

    Args:
    ----
        replay_performance_gain: The performance gain from replay
        fto_benchmark_timings: Function to optimize benchmark timings
        total_benchmark_timings: Total benchmark timings

    Returns:
    -------
        ProcessedBenchmarkInfo containing processed benchmark details

    """
    if not replay_performance_gain or not fto_benchmark_timings or not total_benchmark_timings:
        return None

    benchmark_details = []
    for benchmark_key, og_benchmark_timing in fto_benchmark_timings.items():
        total_benchmark_timing = total_benchmark_timings.get(benchmark_key, 0)
        if total_benchmark_timing == 0:
            continue  # Skip benchmarks with zero timing

        # Fix: nothing guarantees every timed benchmark has a replay measurement;
        # indexing directly raised KeyError for missing keys. Skip them instead,
        # consistent with the zero-timing skip above.
        gain = replay_performance_gain.get(benchmark_key)
        if gain is None:
            continue

        # Project the new total: unchanged overhead (total - og) plus the
        # function's own time scaled down by the measured gain.
        expected_new_benchmark_timing = total_benchmark_timing - og_benchmark_timing + (1 / (gain + 1)) * og_benchmark_timing

        # Express the projected whole-benchmark improvement as a percentage.
        benchmark_speedup_percent = (
            performance_gain(
                original_runtime_ns=total_benchmark_timing, optimized_runtime_ns=int(expected_new_benchmark_timing)
            )
            * 100
        )

        benchmark_details.append(
            BenchmarkDetail(
                benchmark_name=benchmark_key.module_path,
                test_function=benchmark_key.function_name,
                original_timing=humanize_runtime(int(total_benchmark_timing)),
                expected_new_timing=humanize_runtime(int(expected_new_benchmark_timing)),
                speedup_percent=benchmark_speedup_percent,
            )
        )
    return ProcessedBenchmarkInfo(benchmark_details=benchmark_details)