mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
- test_trace_multithreaded_benchmark: SELECT DISTINCT collapses all 10 threaded sorter calls into 1 row (their metadata is identical), so the expected row count changes from 10 to 1.
- test_trace_benchmark_decorator: accept zero timings, which occur when func_time > total_time triggers the overflow guard in validate_and_format_benchmark_table.
360 lines
18 KiB
Python
360 lines
18 KiB
Python
import shutil
|
|
import sqlite3
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin
|
|
from codeflash.benchmarking.replay_test import generate_replay_test
|
|
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
|
|
from codeflash.benchmarking.utils import validate_and_format_benchmark_table
|
|
|
|
|
|
def test_trace_benchmarks() -> None:
    """Trace the ``benchmarks_test`` suite, then check the trace DB and the generated replay tests.

    Runs ``trace_benchmarks_pytest`` over ``code_to_optimize/tests/pytest/benchmarks_test``.
    It then checks the eight distinct function-timing rows stored in the resulting
    SQLite trace file. Finally it calls ``generate_replay_test`` and compares the two
    replay-test modules it writes against their expected source text (leading and
    trailing whitespace ignored). The trace file and the replay-test directory are
    removed in the ``finally`` block.
    """
    # Test the trace_benchmarks function
    project_root = Path(__file__).parent.parent / "code_to_optimize"
    benchmarks_root = project_root / "tests" / "pytest" / "benchmarks_test"
    replay_tests_dir = benchmarks_root / "codeflash_replay_tests"
    tests_root = project_root / "tests"
    output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
    # Declared before the try so the finally block can always test it.
    conn: sqlite3.Connection | None = None
    trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
    assert output_file.exists()
    try:
        # check contents of trace file
        # connect to database
        conn = sqlite3.connect(output_file.as_posix())
        cursor = conn.cursor()

        # Fetch the distinct function/benchmark metadata rows, ordered by
        # benchmark module, benchmark function name, then traced function name.
        cursor.execute(
            "SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
        )
        function_calls = cursor.fetchall()

        # Assert the length of function calls
        assert len(function_calls) == 8, f"Expected 8 function calls, but got {len(function_calls)}"

        bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
        process_and_bubble_sort_path = (project_root / "process_and_bubble_sort_codeflash_trace.py").as_posix()
        # Expected rows, in the same column order as the SELECT:
        # (function_name, class_name, module_name, file_path,
        #  benchmark_function_name, benchmark_module_path, benchmark_line_number)
        expected_calls = [
            (
                "sorter",
                "Sorter",
                "code_to_optimize.bubble_sort_codeflash_trace",
                f"{bubble_sort_path}",
                "test_class_sort",
                "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example",
                17,
            ),
            (
                "sort_class",
                "Sorter",
                "code_to_optimize.bubble_sort_codeflash_trace",
                f"{bubble_sort_path}",
                "test_class_sort2",
                "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example",
                20,
            ),
            (
                "sort_static",
                "Sorter",
                "code_to_optimize.bubble_sort_codeflash_trace",
                f"{bubble_sort_path}",
                "test_class_sort3",
                "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example",
                23,
            ),
            (
                "__init__",
                "Sorter",
                "code_to_optimize.bubble_sort_codeflash_trace",
                f"{bubble_sort_path}",
                "test_class_sort4",
                "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example",
                26,
            ),
            (
                "sorter",
                "",
                "code_to_optimize.bubble_sort_codeflash_trace",
                f"{bubble_sort_path}",
                "test_sort",
                "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example",
                7,
            ),
            (
                "compute_and_sort",
                "",
                "code_to_optimize.process_and_bubble_sort_codeflash_trace",
                f"{process_and_bubble_sort_path}",
                "test_compute_and_sort",
                "tests.pytest.benchmarks_test.test_process_and_sort_example",
                4,
            ),
            (
                "sorter",
                "",
                "code_to_optimize.bubble_sort_codeflash_trace",
                f"{bubble_sort_path}",
                "test_no_func",
                "tests.pytest.benchmarks_test.test_process_and_sort_example",
                8,
            ),
            (
                "recursive_bubble_sort",
                "",
                "code_to_optimize.bubble_sort_codeflash_trace",
                f"{bubble_sort_path}",
                "test_recursive_sort",
                "tests.pytest.benchmarks_test.test_recursive_example",
                5,
            ),
        ]
        # file_path is compared by file name only, so the directory part of the
        # recorded path is not checked.
        for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
            assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
            assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name"
            assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name"
            assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path"
            assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
            assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
            assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
        # Close the connection and clear `conn` (so `finally` does not close it a
        # second time) before handing the trace file to generate_replay_test.
        conn.close()
        conn = None
        generate_replay_test(output_file, replay_tests_dir)
        test_class_sort_path = replay_tests_dir / Path(
            "test_tests_pytest_benchmarks_test_test_benchmark_bubble_sort_example__replay_test_0.py"
        )
        assert test_class_sort_path.exists()
        # Expected source of the replay test generated for
        # test_benchmark_bubble_sort_example. `\\` in this f-string yields a single
        # backslash (an import line continuation in the generated file).
        test_class_sort_code = f"""
from code_to_optimize.bubble_sort_codeflash_trace import \\
    Sorter as code_to_optimize_bubble_sort_codeflash_trace_Sorter
from code_to_optimize.bubble_sort_codeflash_trace import \\
    sorter as code_to_optimize_bubble_sort_codeflash_trace_sorter
from codeflash.benchmarking.replay_test import get_next_arg_and_return
from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle

functions = ['sort_class', 'sort_static', 'sorter']
trace_file_path = r"{output_file.as_posix()}"

def test_code_to_optimize_bubble_sort_codeflash_trace_sorter_test_sort():
    for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_sort", function_name="sorter", file_path=r"{bubble_sort_path}", num_to_get=100):
        args = pickle.loads(args_pkl)
        kwargs = pickle.loads(kwargs_pkl)
        ret = code_to_optimize_bubble_sort_codeflash_trace_sorter(*args, **kwargs)

def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter_sorter_test_class_sort():
    for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_class_sort", function_name="sorter", file_path=r"{bubble_sort_path}", class_name="Sorter", num_to_get=100):
        args = pickle.loads(args_pkl)
        kwargs = pickle.loads(kwargs_pkl)
        function_name = "sorter"
        if not args:
            raise ValueError("No arguments provided for the method.")
        if function_name == "__init__":
            ret = code_to_optimize_bubble_sort_codeflash_trace_Sorter(*args[1:], **kwargs)
        else:
            ret = code_to_optimize_bubble_sort_codeflash_trace_Sorter.sorter(*args, **kwargs)

def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter_sort_class_test_class_sort2():
    for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_class_sort2", function_name="sort_class", file_path=r"{bubble_sort_path}", class_name="Sorter", num_to_get=100):
        args = pickle.loads(args_pkl)
        kwargs = pickle.loads(kwargs_pkl)
        if not args:
            raise ValueError("No arguments provided for the method.")
        ret = code_to_optimize_bubble_sort_codeflash_trace_Sorter.sort_class(*args[1:], **kwargs)

def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter_sort_static_test_class_sort3():
    for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_class_sort3", function_name="sort_static", file_path=r"{bubble_sort_path}", class_name="Sorter", num_to_get=100):
        args = pickle.loads(args_pkl)
        kwargs = pickle.loads(kwargs_pkl)
        ret = code_to_optimize_bubble_sort_codeflash_trace_Sorter.sort_static(*args, **kwargs)

def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter___init___test_class_sort4():
    for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_class_sort4", function_name="__init__", file_path=r"{bubble_sort_path}", class_name="Sorter", num_to_get=100):
        args = pickle.loads(args_pkl)
        kwargs = pickle.loads(kwargs_pkl)
        function_name = "__init__"
        if not args:
            raise ValueError("No arguments provided for the method.")
        if function_name == "__init__":
            ret = code_to_optimize_bubble_sort_codeflash_trace_Sorter(*args[1:], **kwargs)
        else:
            ret = code_to_optimize_bubble_sort_codeflash_trace_Sorter(*args, **kwargs)

"""
        assert test_class_sort_path.read_text("utf-8").strip() == test_class_sort_code.strip()

        test_sort_path = replay_tests_dir / Path(
            "test_tests_pytest_benchmarks_test_test_process_and_sort_example__replay_test_0.py"
        )
        assert test_sort_path.exists()
        # Expected source of the replay test generated for test_process_and_sort_example.
        test_sort_code = f"""
from code_to_optimize.bubble_sort_codeflash_trace import \\
    sorter as code_to_optimize_bubble_sort_codeflash_trace_sorter
from code_to_optimize.process_and_bubble_sort_codeflash_trace import \\
    compute_and_sort as \\
    code_to_optimize_process_and_bubble_sort_codeflash_trace_compute_and_sort
from codeflash.benchmarking.replay_test import get_next_arg_and_return
from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle

functions = ['compute_and_sort', 'sorter']
trace_file_path = r"{output_file.as_posix()}"

def test_code_to_optimize_process_and_bubble_sort_codeflash_trace_compute_and_sort_test_compute_and_sort():
    for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_compute_and_sort", function_name="compute_and_sort", file_path=r"{process_and_bubble_sort_path}", num_to_get=100):
        args = pickle.loads(args_pkl)
        kwargs = pickle.loads(kwargs_pkl)
        ret = code_to_optimize_process_and_bubble_sort_codeflash_trace_compute_and_sort(*args, **kwargs)

def test_code_to_optimize_bubble_sort_codeflash_trace_sorter_test_no_func():
    for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_no_func", function_name="sorter", file_path=r"{bubble_sort_path}", num_to_get=100):
        args = pickle.loads(args_pkl)
        kwargs = pickle.loads(kwargs_pkl)
        ret = code_to_optimize_bubble_sort_codeflash_trace_sorter(*args, **kwargs)

"""
        assert test_sort_path.read_text("utf-8").strip() == test_sort_code.strip()
    finally:
        # Always release the DB handle and remove every artifact this test created.
        if conn is not None:
            conn.close()
        output_file.unlink(missing_ok=True)
        if replay_tests_dir.exists():
            shutil.rmtree(replay_tests_dir)
|
|
|
|
|
|
# Not run in CI: the CI machine may not be multithreaded.
@pytest.mark.ci_skip
def test_trace_multithreaded_benchmark() -> None:
    """Trace the ``benchmarks_multithread`` suite and validate the recorded ``sorter`` timings.

    Expects exactly one distinct timing row (the threaded ``sorter`` calls share
    identical metadata, so DISTINCT folds them together). It checks that row's
    metadata and requires non-negative total time, function time and percentage
    in the formatted benchmark table. The trace file is deleted afterwards.
    """
    code_dir = Path(__file__).parent.parent / "code_to_optimize"
    suite_dir = code_dir / "tests" / "pytest" / "benchmarks_multithread"
    trace_path = (suite_dir / "test_trace_benchmarks.trace").resolve()

    trace_benchmarks_pytest(suite_dir, code_dir / "tests", code_dir, trace_path)
    assert trace_path.exists()

    db: sqlite3.Connection | None = None
    try:
        db = sqlite3.connect(trace_path.as_posix())
        query = "SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
        rows = db.cursor().execute(query).fetchall()
        assert len(rows) == 1, f"Expected 1 function call, but got {len(rows)}"

        # Build the summary table from per-function timings and median benchmark totals.
        per_function = codeflash_benchmark_plugin.get_function_benchmark_timings(trace_path)
        benchmark_stats = codeflash_benchmark_plugin.get_benchmark_timings(trace_path)
        median_totals = {name: stats.median_ns for name, stats in benchmark_stats.items()}
        table = validate_and_format_benchmark_table(per_function, median_totals)

        sorter_key = "code_to_optimize.bubble_sort_codeflash_trace.sorter"
        assert sorter_key in table
        _, total_time, function_time, percent = table[sorter_key][0]
        for measurement in (total_time, function_time, percent):
            assert measurement >= 0.0

        bubble_sort_path = (code_dir / "bubble_sort_codeflash_trace.py").as_posix()
        expected_rows = [
            (
                "sorter",
                "",
                "code_to_optimize.bubble_sort_codeflash_trace",
                bubble_sort_path,
                "test_benchmark_sort",
                "tests.pytest.benchmarks_multithread.test_multithread_sort",
                4,
            )
        ]
        # Column labels, in the SELECT's column order, used in failure messages.
        column_labels = (
            "function_name",
            "class_name",
            "module_name",
            "file_path",
            "benchmark_function_name",
            "benchmark_module_path",
            "benchmark_line_number",
        )
        for idx, (actual, expected) in enumerate(zip(rows, expected_rows)):
            for col, label in enumerate(column_labels):
                if label == "file_path":
                    # Only the file name is compared; the directory part is not checked.
                    got, want = Path(actual[col]).name, Path(expected[col]).name
                else:
                    got, want = actual[col], expected[col]
                assert got == want, f"Mismatch at index {idx} for {label}"
    finally:
        if db is not None:
            db.close()
        trace_path.unlink(missing_ok=True)
|
|
|
|
|
|
def test_trace_benchmark_decorator() -> None:
    """Trace the ``benchmarks_test_decorator`` suite and validate the recorded ``sorter`` timings.

    Expects two distinct timing rows, for the ``test_benchmark_sort`` and
    ``test_pytest_mark`` benchmarks, and checks their metadata. The formatted
    benchmark table must report non-negative total time, function time and
    percentage for ``sorter`` (zero is accepted). The trace file is deleted
    afterwards.
    """
    code_dir = Path(__file__).parent.parent / "code_to_optimize"
    suite_dir = code_dir / "tests" / "pytest" / "benchmarks_test_decorator"
    trace_path = (suite_dir / "test_trace_benchmarks.trace").resolve()

    trace_benchmarks_pytest(suite_dir, code_dir / "tests", code_dir, trace_path)
    assert trace_path.exists()

    db: sqlite3.Connection | None = None
    try:
        db = sqlite3.connect(trace_path.as_posix())
        query = "SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name"
        rows = db.cursor().execute(query).fetchall()
        assert len(rows) == 2, f"Expected 2 function calls, but got {len(rows)}"

        # Build the summary table from per-function timings and median benchmark totals.
        per_function = codeflash_benchmark_plugin.get_function_benchmark_timings(trace_path)
        benchmark_stats = codeflash_benchmark_plugin.get_benchmark_timings(trace_path)
        median_totals = {name: stats.median_ns for name, stats in benchmark_stats.items()}
        table = validate_and_format_benchmark_table(per_function, median_totals)

        sorter_key = "code_to_optimize.bubble_sort_codeflash_trace.sorter"
        assert sorter_key in table
        _, total_time, function_time, percent = table[sorter_key][0]
        for measurement in (total_time, function_time, percent):
            assert measurement >= 0.0

        bubble_sort_path = (code_dir / "bubble_sort_codeflash_trace.py").as_posix()
        expected_rows = [
            (
                "sorter",
                "",
                "code_to_optimize.bubble_sort_codeflash_trace",
                bubble_sort_path,
                "test_benchmark_sort",
                "tests.pytest.benchmarks_test_decorator.test_benchmark_decorator",
                5,
            ),
            (
                "sorter",
                "",
                "code_to_optimize.bubble_sort_codeflash_trace",
                bubble_sort_path,
                "test_pytest_mark",
                "tests.pytest.benchmarks_test_decorator.test_benchmark_decorator",
                11,
            ),
        ]
        # Labels for the first six SELECT columns. benchmark_line_number is not
        # compared, so the line numbers in expected_rows are currently unchecked.
        # NOTE(review): confirm whether leaving the line number out is intentional.
        column_labels = (
            "function_name",
            "class_name",
            "module_name",
            "file_path",
            "benchmark_function_name",
            "benchmark_module_path",
        )
        for idx, (actual, expected) in enumerate(zip(rows, expected_rows)):
            for col, label in enumerate(column_labels):
                if label == "file_path":
                    # Only the file name is compared; the directory part is not checked.
                    got, want = Path(actual[col]).name, Path(expected[col]).name
                else:
                    got, want = actual[col], expected[col]
                assert got == want, f"Mismatch at index {idx} for {label}"
    finally:
        if db is not None:
            db.close()
        trace_path.unlink(missing_ok=True)
|