improvements

This commit is contained in:
Alvin Ryanputra 2025-03-11 17:35:34 -07:00
parent 1f3fd4d2db
commit 5faccd821c
5 changed files with 150 additions and 78 deletions

View file

@ -1,4 +1,10 @@
from codeflash.benchmarking.codeflash_trace import codeflash_trace
# The tracer is applied bare: the ``CodeflashTrace`` instance shown in this
# change takes the decorated function directly in ``__call__``.  Calling it
# with a file name first would treat the string as the function and fail with
# "'str' object is not callable" as soon as this module is imported.
@codeflash_trace
def sorter(arr):
    """Sort ``arr`` in place in ascending order and return the same list.

    Args:
        arr: A mutable sequence of mutually comparable items.

    Returns:
        The same ``arr`` object, now sorted.
    """
    # Utilizing Python's built-in Timsort algorithm for better performance
    arr.sort()
    # NOTE(review): once ``arr.sort()`` has run, the passes below never swap
    # anything.  They are kept because the function lives in a bubble-sort
    # benchmark fixture and the quadratic passes look like the intended
    # workload -- confirm which of the two stages should stay.
    for _ in range(len(arr)):
        for j in range(len(arr) - 1):
            if arr[j] > arr[j + 1]:
                arr[j], arr[j + 1] = arr[j + 1], arr[j]
    return arr

View file

@ -0,0 +1,122 @@
import functools
import os
import pickle
import sqlite3
import time
from typing import Callable
class CodeflashTrace:
    """Collect timing and argument data for decorated functions.

    An instance is used as a decorator (``@tracer``) to record every call of
    the wrapped function, and can also be used as a context manager that
    starts from an empty buffer.  Records are only buffered in memory in
    ``function_calls_data``; nothing is persisted until ``write_to_db`` is
    called.
    """

    def __init__(self) -> None:
        # One tuple per traced call, in the column order used by write_to_db.
        self.function_calls_data = []

    def __enter__(self) -> "CodeflashTrace":
        """Discard any buffered records and return this tracer."""
        self.function_calls_data = []
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Leave the context; the buffer is kept and exceptions propagate."""

    def __call__(self, func: Callable) -> Callable:
        """Use as a decorator to trace function execution.

        Args:
            func: The function to be decorated

        Returns:
            The wrapped function
        """
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Time only the wrapped call.  A monotonic nanosecond clock is
            # used so the stored value matches the ``time_ns`` column.
            start_ns = time.perf_counter_ns()
            result = func(*args, **kwargs)
            execution_time_ns = time.perf_counter_ns() - start_ns

            # Everything below is bookkeeping cost, measured separately.
            overhead_start_ns = time.perf_counter_ns()
            try:
                # NOTE(review): arguments are pickled after the call, so any
                # in-place mutation done by ``func`` is what gets recorded.
                pickled_args = pickle.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
                pickled_kwargs = pickle.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)

                # Get benchmark info from environment
                benchmark_function_name = os.environ.get("CODEFLASH_BENCHMARK_FUNCTION_NAME", "")
                benchmark_file_name = os.environ.get("CODEFLASH_BENCHMARK_FILE_NAME", "")

                overhead_time_ns = time.perf_counter_ns() - overhead_start_ns

                # The module name is what ends up in the ``class_name`` column.
                self.function_calls_data.append(
                    (func.__name__, func.__module__, func.__code__.co_filename,
                     benchmark_function_name, benchmark_file_name, execution_time_ns,
                     overhead_time_ns, pickled_args, pickled_kwargs)
                )
            except Exception as e:
                # Tracing is best-effort: never let bookkeeping break the call.
                print(f"Error in codeflash_trace: {e}")
            return result
        return wrapper

    def write_to_db(self, output_file: str) -> None:
        """Write all collected function call data to the SQLite database.

        The buffer is cleared only after a successful commit.  Errors are
        reported on stdout and not raised.

        Args:
            output_file: Path to the SQLite database file where results will be stored
        """
        if not self.function_calls_data:
            print("No function call data to write")
            return
        try:
            con = sqlite3.connect(output_file)
            try:
                cur = con.cursor()
                cur.execute("PRAGMA synchronous = OFF")
                # Check if table exists and create it if it doesn't
                cur.execute(
                    "CREATE TABLE IF NOT EXISTS function_calls("
                    "function_name TEXT, class_name TEXT, file_name TEXT, "
                    "benchmark_function_name TEXT, benchmark_file_name TEXT, "
                    "time_ns INTEGER, overhead_time_ns INTEGER, args BLOB, kwargs BLOB)"
                )
                # Insert all data at once
                cur.executemany(
                    "INSERT INTO function_calls "
                    "(function_name, class_name, file_name, benchmark_function_name, "
                    "benchmark_file_name, time_ns, overhead_time_ns, args, kwargs) "
                    "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
                    self.function_calls_data
                )
                con.commit()
            finally:
                # Close even when a statement fails so the handle is not leaked.
                con.close()
            print(f"Successfully wrote {len(self.function_calls_data)} function call records to {output_file}")
            # Clear the data after writing
            self.function_calls_data.clear()
        except Exception as e:
            print(f"Error writing function calls to database: {e}")


# Create a singleton instance
codeflash_trace = CodeflashTrace()

View file

@ -1,73 +0,0 @@
import functools
import pickle
import sqlite3
import time
import os
def codeflash_trace(output_file: str):
    """Return a decorator that records each call of a function in SQLite.

    Every call of the decorated function is timed, its positional and keyword
    arguments are pickled with the highest protocol available, and one row is
    appended to the ``function_calls`` table of ``output_file``.  Tracing is
    best-effort: any failure is printed and the function result is still
    returned.

    Args:
        output_file: Path to the SQLite database file where results will be stored

    Returns:
        The decorator function
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Time only the wrapped call; nanoseconds match the ``time_ns`` column.
            start_ns = time.perf_counter_ns()
            result = func(*args, **kwargs)
            execution_time_ns = time.perf_counter_ns() - start_ns

            # Measure overhead
            overhead_start_ns = time.perf_counter_ns()
            try:
                con = sqlite3.connect(output_file)
                try:
                    cur = con.cursor()
                    cur.execute("PRAGMA synchronous = OFF")
                    # Check if table exists and create it if it doesn't
                    cur.execute(
                        "CREATE TABLE IF NOT EXISTS function_calls(function_name TEXT, class_name TEXT, file_name TEXT, benchmark_function_name TEXT, benchmark_file_name TEXT,"
                        "time_ns INTEGER, args BLOB, kwargs BLOB)"
                    )
                    # Pickle the arguments
                    pickled_args = pickle.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
                    pickled_kwargs = pickle.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)

                    # Get benchmark info from environment (stored as NULL when unset)
                    benchmark_function_name = os.environ.get("CODEFLASH_BENCHMARK_FUNCTION_NAME")
                    benchmark_file_name = os.environ.get("CODEFLASH_BENCHMARK_FILE_NAME")

                    # Column names must match the CREATE TABLE above and the
                    # placeholder count must match the eight bound values.
                    cur.execute(
                        "INSERT INTO function_calls (function_name, class_name, file_name, benchmark_function_name, benchmark_file_name, time_ns, args, kwargs) "
                        "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
                        (func.__name__, func.__module__, func.__code__.co_filename,
                         benchmark_function_name, benchmark_file_name,
                         execution_time_ns, pickled_args, pickled_kwargs)
                    )
                    con.commit()
                finally:
                    # Close even when a statement fails so the handle is not leaked.
                    con.close()
                overhead_ns = time.perf_counter_ns() - overhead_start_ns
                print(f"Function '{func.__name__}' took {execution_time_ns / 1e9:.6f} seconds to execute")
                print(f"Function '{func.__name__}' overhead took {overhead_ns / 1e9:.6f} seconds to execute")
            except Exception as e:
                print(f"Error in codeflash_trace: {e}")
            return result
        return wrapper
    return decorator

View file

@ -1,6 +1,6 @@
import pytest
import time
import os
class CodeFlashPlugin:
@staticmethod
def pytest_addoption(parser):
@ -35,9 +35,11 @@ class CodeFlashPlugin:
class Benchmark:
    """Callable that runs one benchmarked function and prints its CPU time."""

    def __call__(self, func, *args, **kwargs):
        """Run ``func(*args, **kwargs)``, print how long it took, return its result.

        Args:
            func: The callable under benchmark.
            *args: Positional arguments forwarded to ``func``.
            **kwargs: Keyword arguments forwarded to ``func``.

        Returns:
            Whatever ``func`` returns.
        """
        # Publish the running benchmark's name and file through the
        # environment; they stay set after the call returns.
        # NOTE(review): ``request`` comes from an enclosing scope that is not
        # visible here -- presumably the pytest ``request`` fixture; confirm.
        os.environ["CODEFLASH_BENCHMARK_FUNCTION_NAME"] = request.node.name
        os.environ["CODEFLASH_BENCHMARK_FILE_NAME"] = request.node.fspath.basename
        # Only the process-time readings are kept; the earlier wall-clock
        # reads were overwritten before use and have been dropped.
        start = time.process_time_ns()
        result = func(*args, **kwargs)
        end = time.process_time_ns()
        print(f"Benchmark: {func.__name__} took {end - start} ns")
        return result

View file

@ -0,0 +1,15 @@
from codeflash.benchmarking.codeflash_trace import codeflash_trace
from pathlib import Path
# Applied bare: the tracer instance takes the function itself, not a file name.
@codeflash_trace
def example_function(arr):
    """Sort ``arr`` in place and return it; a minimal target for tracing."""
    arr.sort()
    return arr


def test_codeflash_trace_decorator():
    """Check that a traced function still returns its result and can be persisted."""
    trace_path = Path("test_codeflash_trace.trace")
    arr = [3, 1, 2]
    result = example_function(arr)
    assert result == [1, 2, 3]
    try:
        # The tracer only buffers calls in memory; the trace file exists
        # only after an explicit flush.
        codeflash_trace.write_to_db(str(trace_path))
        assert trace_path.exists()
    finally:
        # cleanup test trace file using Path
        if trace_path.exists():
            trace_path.unlink()