2025-03-14 22:03:36 +00:00
import sqlite3
2025-03-25 16:51:12 +00:00
from codeflash . benchmarking . benchmark_database_utils import BenchmarkDatabaseUtils
2025-02-28 00:25:07 +00:00
from codeflash . benchmarking . trace_benchmarks import trace_benchmarks_pytest
2025-03-14 22:03:36 +00:00
from codeflash . benchmarking . replay_test import generate_replay_test
2025-02-28 00:25:07 +00:00
from pathlib import Path
2025-03-24 23:45:13 +00:00
from codeflash . benchmarking . utils import print_benchmark_table , validate_and_format_benchmark_table
2025-03-14 01:14:38 +00:00
import shutil
2025-02-28 00:25:07 +00:00
2025-03-14 22:03:36 +00:00
2025-02-28 00:25:07 +00:00
def test_trace_benchmarks():
    """Trace the ``benchmarks_test`` suite and check the trace DB and generated replay tests.

    Runs ``trace_benchmarks_pytest`` over
    ``code_to_optimize/tests/pytest/benchmarks_test``. It then asserts every row of
    the ``function_calls`` table in the resulting trace file. Finally it runs
    ``generate_replay_test`` on that trace and compares two of the generated
    files against their expected source text.

    The scratch directory ``code_to_optimize/tests/test_trace_benchmarks`` is
    created by this test. It is always removed afterwards, including when tracing
    itself fails.
    """
    project_root = Path(__file__).parent.parent / "code_to_optimize"
    benchmarks_root = project_root / "tests" / "pytest" / "benchmarks_test"
    tests_root = project_root / "tests" / "test_trace_benchmarks"
    # exist_ok=False: never reuse (and later rmtree) a directory this test did not create.
    tests_root.mkdir(parents=False, exist_ok=False)
    # Everything after mkdir lives inside try/finally. A failure while tracing then
    # still removes tests_root instead of breaking every later run with FileExistsError.
    try:
        output_file = (tests_root / Path("test_trace_benchmarks.trace")).resolve()
        trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
        assert output_file.exists()

        # The trace file is an SQLite database; read all recorded calls in a stable order.
        conn = sqlite3.connect(output_file.as_posix())
        try:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT function_name, class_name, module_name, file_name, benchmark_function_name, benchmark_file_name, benchmark_line_number FROM function_calls ORDER BY benchmark_file_name, benchmark_function_name, function_name")
            function_calls = cursor.fetchall()
        finally:
            # Close even if the query raises, so the DB file is not left locked.
            conn.close()

        assert len(function_calls) == 7, f"Expected 7 function calls, but got {len(function_calls)}"

        bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
        process_and_bubble_sort_path = (project_root / "process_and_bubble_sort_codeflash_trace.py").as_posix()

        # Expected rows, in the same ORDER BY as the query above.
        expected_calls = [
            ("__init__", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
             bubble_sort_path,
             "test_class_sort", "test_benchmark_bubble_sort.py", 20),
            ("sort_class", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
             bubble_sort_path,
             "test_class_sort", "test_benchmark_bubble_sort.py", 18),
            ("sort_static", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
             bubble_sort_path,
             "test_class_sort", "test_benchmark_bubble_sort.py", 19),
            ("sorter", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
             bubble_sort_path,
             "test_class_sort", "test_benchmark_bubble_sort.py", 17),
            ("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
             bubble_sort_path,
             "test_sort", "test_benchmark_bubble_sort.py", 7),
            ("compute_and_sort", "", "code_to_optimize.process_and_bubble_sort_codeflash_trace",
             process_and_bubble_sort_path,
             "test_compute_and_sort", "test_process_and_sort.py", 4),
            ("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
             bubble_sort_path,
             "test_no_func", "test_process_and_sort.py", 8),
        ]
        # Column names, positionally matching the SELECT list; used in failure messages.
        field_names = (
            "function_name",
            "class_name",
            "module_name",
            "file_name",
            "benchmark_function_name",
            "benchmark_file_name",
            "benchmark_line_number",
        )
        for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
            for field, got, want in zip(field_names, actual, expected):
                if field == "file_name":
                    # file_name is compared by basename only.
                    got, want = Path(got).name, Path(want).name
                assert got == want, f"Mismatch at index {idx} for {field}"

        generate_replay_test(output_file, tests_root)

        test_class_sort_path = tests_root / Path("test_benchmark_bubble_sort_py_test_class_sort__replay_test_0.py")
        assert test_class_sort_path.exists()
        test_class_sort_code = f"""
import dill as pickle

from code_to_optimize.bubble_sort_codeflash_trace import \\
    Sorter as code_to_optimize_bubble_sort_codeflash_trace_Sorter
from codeflash.benchmarking.replay_test import get_next_arg_and_return

functions = ['sorter', 'sort_class', 'sort_static']
trace_file_path = r"{output_file.as_posix()}"

def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter_sorter():
    for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="sorter", file_name=r"{bubble_sort_path}", class_name="Sorter", num_to_get=100):
        args = pickle.loads(args_pkl)
        kwargs = pickle.loads(kwargs_pkl)
        function_name = "sorter"
        if not args:
            raise ValueError("No arguments provided for the method.")
        if function_name == "__init__":
            ret = code_to_optimize_bubble_sort_codeflash_trace_Sorter(*args[1:], **kwargs)
        else:
            instance = args[0] # self
            ret = instance.sorter(*args[1:], **kwargs)

def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter_sort_class():
    for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="sort_class", file_name=r"{bubble_sort_path}", class_name="Sorter", num_to_get=100):
        args = pickle.loads(args_pkl)
        kwargs = pickle.loads(kwargs_pkl)
        if not args:
            raise ValueError("No arguments provided for the method.")
        ret = code_to_optimize_bubble_sort_codeflash_trace_Sorter.sort_class(*args[1:], **kwargs)

def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter_sort_static():
    for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="sort_static", file_name=r"{bubble_sort_path}", class_name="Sorter", num_to_get=100):
        args = pickle.loads(args_pkl)
        kwargs = pickle.loads(kwargs_pkl)
        ret = code_to_optimize_bubble_sort_codeflash_trace_Sorter.sort_static(*args, **kwargs)

def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter___init__():
    for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="__init__", file_name=r"{bubble_sort_path}", class_name="Sorter", num_to_get=100):
        args = pickle.loads(args_pkl)
        kwargs = pickle.loads(kwargs_pkl)
        function_name = "__init__"
        if not args:
            raise ValueError("No arguments provided for the method.")
        if function_name == "__init__":
            ret = code_to_optimize_bubble_sort_codeflash_trace_Sorter(*args[1:], **kwargs)
        else:
            instance = args[0] # self
            ret = instance(*args[1:], **kwargs)

"""
        assert test_class_sort_path.read_text("utf-8").strip() == test_class_sort_code.strip()

        test_sort_path = tests_root / Path("test_benchmark_bubble_sort_py_test_sort__replay_test_0.py")
        assert test_sort_path.exists()
        # NOTE(review): this expectation interpolates str(output_file), while the one
        # above uses output_file.as_posix(). The two only agree on POSIX; confirm which
        # form generate_replay_test actually writes.
        test_sort_code = f"""
import dill as pickle

from code_to_optimize.bubble_sort_codeflash_trace import \\
    sorter as code_to_optimize_bubble_sort_codeflash_trace_sorter
from codeflash.benchmarking.replay_test import get_next_arg_and_return

functions = ['sorter']
trace_file_path = r"{output_file}"

def test_code_to_optimize_bubble_sort_codeflash_trace_sorter():
    for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="sorter", file_name=r"{bubble_sort_path}", num_to_get=100):
        args = pickle.loads(args_pkl)
        kwargs = pickle.loads(kwargs_pkl)
        ret = code_to_optimize_bubble_sort_codeflash_trace_sorter(*args, **kwargs)

"""
        assert test_sort_path.read_text("utf-8").strip() == test_sort_code.strip()
    finally:
        # Remove the scratch directory (trace DB and generated replay tests).
        shutil.rmtree(tests_root)
def test_trace_multithreaded_benchmark() -> None:
    """Trace the ``benchmarks_multithread`` suite and check the recorded calls and timings.

    Runs ``trace_benchmarks_pytest`` over
    ``code_to_optimize/tests/pytest/benchmarks_multithread``. It asserts that the
    trace database holds 10 ``function_calls`` rows, and that the benchmark table
    built from the recorded timings has positive total, function and percent
    values for ``bubble_sort_codeflash_trace.sorter``. It then spot-checks the
    first recorded row. The scratch directory
    ``code_to_optimize/tests/test_trace_benchmarks`` is always removed afterwards.
    """
    project_root = Path(__file__).parent.parent / "code_to_optimize"
    benchmarks_root = project_root / "tests" / "pytest" / "benchmarks_multithread"
    tests_root = project_root / "tests" / "test_trace_benchmarks"
    # exist_ok=False: never reuse (and later rmtree) a directory this test did not create.
    tests_root.mkdir(parents=False, exist_ok=False)
    # Everything after mkdir lives inside try/finally so a failing trace still cleans up.
    try:
        output_file = (tests_root / Path("test_trace_benchmarks.trace")).resolve()
        trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
        assert output_file.exists()

        # Read every traced call in a stable order; always release the connection.
        conn = sqlite3.connect(output_file.as_posix())
        try:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT function_name, class_name, module_name, file_name, benchmark_function_name, benchmark_file_name, benchmark_line_number FROM function_calls ORDER BY benchmark_file_name, benchmark_function_name, function_name")
            function_calls = cursor.fetchall()
        finally:
            conn.close()

        assert len(function_calls) == 10, f"Expected 10 function calls, but got {len(function_calls)}"

        function_benchmark_timings = BenchmarkDatabaseUtils.get_function_benchmark_timings(output_file)
        total_benchmark_timings = BenchmarkDatabaseUtils.get_benchmark_timings(output_file)
        function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
        assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results
        # Each entry unpacks as (test name, total time, function time, percent); the name is unused here.
        _, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_codeflash_trace.sorter"][0]
        assert total_time > 0.0
        assert function_time > 0.0
        assert percent > 0.0

        bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
        expected_calls = [
            ("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
             bubble_sort_path,
             "test_benchmark_sort", "test_multithread_sort.py", 4),
        ]
        # Column names, positionally matching the SELECT list; used in failure messages.
        field_names = (
            "function_name",
            "class_name",
            "module_name",
            "file_name",
            "benchmark_function_name",
            "benchmark_file_name",
            "benchmark_line_number",
        )
        # zip() stops at the shorter sequence, so only the first of the 10 rows is
        # compared against expected_calls here.
        for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
            for field, got, want in zip(field_names, actual, expected):
                if field == "file_name":
                    # file_name is compared by basename only.
                    got, want = Path(got).name, Path(want).name
                assert got == want, f"Mismatch at index {idx} for {field}"
    finally:
        # Remove the scratch directory holding the trace DB.
        shutil.rmtree(tests_root)