improvements
This commit is contained in:
parent
1f3fd4d2db
commit
5faccd821c
5 changed files with 150 additions and 78 deletions
|
|
@ -1,4 +1,10 @@
|
|||
from codeflash.benchmarking.codeflash_trace import codeflash_trace
|
||||
@codeflash_trace("bubble_sort.trace")
|
||||
def sorter(arr):
|
||||
# Utilizing Python's built-in Timsort algorithm for better performance
|
||||
arr.sort()
|
||||
for i in range(len(arr)):
|
||||
for j in range(len(arr) - 1):
|
||||
if arr[j] > arr[j + 1]:
|
||||
temp = arr[j]
|
||||
arr[j] = arr[j + 1]
|
||||
arr[j + 1] = temp
|
||||
return arr
|
||||
|
|
|
|||
122
codeflash/benchmarking/codeflash_trace.py
Normal file
122
codeflash/benchmarking/codeflash_trace.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
import functools
|
||||
import os
|
||||
import pickle
|
||||
import sqlite3
|
||||
import time
|
||||
from typing import Callable
|
||||
|
||||
|
||||
class CodeflashTrace:
|
||||
"""A class that provides both a decorator for tracing function calls
|
||||
and a context manager for managing the tracing data lifecycle.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.function_calls_data = []
|
||||
|
||||
def __enter__(self) -> None:
|
||||
# Initialize for context manager use
|
||||
self.function_calls_data = []
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||
# Cleanup is optional here
|
||||
pass
|
||||
|
||||
def __call__(self, func: Callable) -> Callable:
|
||||
"""Use as a decorator to trace function execution.
|
||||
|
||||
Args:
|
||||
func: The function to be decorated
|
||||
|
||||
Returns:
|
||||
The wrapped function
|
||||
|
||||
"""
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Measure execution time
|
||||
start_time = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
end_time = time.time()
|
||||
|
||||
# Calculate execution time
|
||||
execution_time = end_time - start_time
|
||||
|
||||
# Measure overhead
|
||||
overhead_start_time = time.time()
|
||||
overhead_time = 0
|
||||
|
||||
try:
|
||||
# Pickle the arguments
|
||||
pickled_args = pickle.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
pickled_kwargs = pickle.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
# Get benchmark info from environment
|
||||
benchmark_function_name = os.environ.get("CODEFLASH_BENCHMARK_FUNCTION_NAME", "")
|
||||
benchmark_file_name = os.environ.get("CODEFLASH_BENCHMARK_FILE_NAME", "")
|
||||
|
||||
# Calculate overhead time
|
||||
overhead_end_time = time.time()
|
||||
overhead_time = overhead_end_time - overhead_start_time
|
||||
|
||||
self.function_calls_data.append(
|
||||
(func.__name__, func.__module__, func.__code__.co_filename,
|
||||
benchmark_function_name, benchmark_file_name, execution_time,
|
||||
overhead_time, pickled_args, pickled_kwargs)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in codeflash_trace: {e}")
|
||||
|
||||
return result
|
||||
return wrapper
|
||||
|
||||
def write_to_db(self, output_file: str) -> None:
|
||||
"""Write all collected function call data to the SQLite database.
|
||||
|
||||
Args:
|
||||
output_file: Path to the SQLite database file where results will be stored
|
||||
|
||||
"""
|
||||
if not self.function_calls_data:
|
||||
print("No function call data to write")
|
||||
return
|
||||
|
||||
try:
|
||||
# Connect to the database
|
||||
con = sqlite3.connect(output_file)
|
||||
cur = con.cursor()
|
||||
cur.execute("PRAGMA synchronous = OFF")
|
||||
|
||||
# Check if table exists and create it if it doesn't
|
||||
cur.execute(
|
||||
"CREATE TABLE IF NOT EXISTS function_calls("
|
||||
"function_name TEXT, class_name TEXT, file_name TEXT, "
|
||||
"benchmark_function_name TEXT, benchmark_file_name TEXT, "
|
||||
"time_ns INTEGER, overhead_time_ns INTEGER, args BLOB, kwargs BLOB)"
|
||||
)
|
||||
|
||||
# Insert all data at once
|
||||
cur.executemany(
|
||||
"INSERT INTO function_calls "
|
||||
"(function_name, class_name, file_name, benchmark_function_name, "
|
||||
"benchmark_file_name, time_ns, overhead_time_ns, args, kwargs) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||
self.function_calls_data
|
||||
)
|
||||
|
||||
# Commit and close
|
||||
con.commit()
|
||||
con.close()
|
||||
|
||||
print(f"Successfully wrote {len(self.function_calls_data)} function call records to {output_file}")
|
||||
|
||||
# Clear the data after writing
|
||||
self.function_calls_data.clear()
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error writing function calls to database: {e}")
|
||||
|
||||
# Create a singleton instance
|
||||
codeflash_trace = CodeflashTrace()
|
||||
|
|
@ -1,73 +0,0 @@
|
|||
import functools
|
||||
import pickle
|
||||
import sqlite3
|
||||
import time
|
||||
import os
|
||||
|
||||
def codeflash_trace(output_file: str):
|
||||
"""A decorator factory that returns a decorator that measures the execution time
|
||||
of a function and pickles its arguments using the highest protocol available.
|
||||
|
||||
Args:
|
||||
output_file: Path to the SQLite database file where results will be stored
|
||||
|
||||
Returns:
|
||||
The decorator function
|
||||
|
||||
"""
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Measure execution time
|
||||
start_time = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
end_time = time.time()
|
||||
|
||||
# Calculate execution time
|
||||
execution_time = end_time - start_time
|
||||
|
||||
# Measure overhead
|
||||
overhead_start_time = time.time()
|
||||
|
||||
try:
|
||||
# Connect to the database
|
||||
con = sqlite3.connect(output_file)
|
||||
cur = con.cursor()
|
||||
cur.execute("PRAGMA synchronous = OFF")
|
||||
|
||||
# Check if table exists and create it if it doesn't
|
||||
cur.execute(
|
||||
"CREATE TABLE IF NOT EXISTS function_calls(function_name TEXT, class_name TEXT, file_name TEXT, benchmark_function_name TEXT, benchmark_file_name TEXT,"
|
||||
"time_ns INTEGER, args BLOB, kwargs BLOB)"
|
||||
)
|
||||
|
||||
# Pickle the arguments
|
||||
pickled_args = pickle.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
pickled_kwargs = pickle.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
# Get benchmark info from environment
|
||||
benchmark_function_name = os.environ.get("CODEFLASH_BENCHMARK_FUNCTION_NAME")
|
||||
benchmark_file_name = os.environ.get("CODEFLASH_BENCHMARK_FILE_NAME")
|
||||
# Insert the data
|
||||
cur.execute(
|
||||
"INSERT INTO function_calls (function_name, classname, filename, benchmark_function_name, benchmark_file_name, time_ns, args, kwargs) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(func.__name__, func.__module__, func.__code__.co_filename,
|
||||
execution_time, pickled_args, pickled_kwargs)
|
||||
)
|
||||
|
||||
# Commit and close
|
||||
con.commit()
|
||||
con.close()
|
||||
|
||||
overhead_end_time = time.time()
|
||||
|
||||
print(f"Function '{func.__name__}' took {execution_time:.6f} seconds to execute")
|
||||
print(f"Function '{func.__name__}' overhead took {overhead_end_time - overhead_start_time:.6f} seconds to execute")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in codeflash_trace: {e}")
|
||||
|
||||
return result
|
||||
return wrapper
|
||||
return decorator
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
import pytest
|
||||
import time
|
||||
|
||||
import os
|
||||
class CodeFlashPlugin:
|
||||
@staticmethod
|
||||
def pytest_addoption(parser):
|
||||
|
|
@ -35,9 +35,11 @@ class CodeFlashPlugin:
|
|||
|
||||
class Benchmark:
|
||||
def __call__(self, func, *args, **kwargs):
|
||||
start = time.time_ns()
|
||||
os.environ["CODEFLASH_BENCHMARK_FUNCTION_NAME"] = request.node.name
|
||||
os.environ["CODEFLASH_BENCHMARK_FILE_NAME"] = request.node.fspath.basename
|
||||
start = time.process_time_ns()
|
||||
result = func(*args, **kwargs)
|
||||
end = time.time_ns()
|
||||
end = time.process_time_ns()
|
||||
print(f"Benchmark: {func.__name__} took {end - start} ns")
|
||||
return result
|
||||
|
||||
|
|
|
|||
15
tests/test_codeflash_trace_decorator.py
Normal file
15
tests/test_codeflash_trace_decorator.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
from codeflash.benchmarking.codeflash_trace import codeflash_trace
|
||||
from pathlib import Path
|
||||
|
||||
@codeflash_trace("test_codeflash_trace.trace")
|
||||
def example_function(arr):
|
||||
arr.sort()
|
||||
return arr
|
||||
|
||||
|
||||
def test_codeflash_trace_decorator():
|
||||
arr = [3, 1, 2]
|
||||
result = example_function(arr)
|
||||
# cleanup test trace file using Path
|
||||
assert result == [1, 2, 3]
|
||||
Path("test_codeflash_trace.trace").unlink()
|
||||
Loading…
Reference in a new issue