Support recursive functions and the @benchmark / @pytest.mark.benchmark ways of using the benchmark fixture; add tests for all of them

Alvin Ryanputra 2025-04-07 14:52:18 -07:00
parent c997b90394
commit d6ed1c33c4
8 changed files with 192 additions and 55 deletions
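
For orientation, a minimal sketch of the three usage styles this commit supports, condensed from the tests added below; sorter comes from this repository, while the test names here are illustrative:

import pytest
from code_to_optimize.bubble_sort_codeflash_trace import sorter

def test_call_style(benchmark):
    # Style 1: call the fixture directly with the function and its arguments
    result = benchmark(sorter, list(reversed(range(500))))
    assert result == list(range(500))

def test_decorator_style(benchmark):
    # Style 2: use the fixture as a decorator; the decorated body runs once under timing
    @benchmark
    def do_sort():
        sorter(list(reversed(range(500))))

@pytest.mark.benchmark(group="benchmark_decorator")
def test_marker_style(benchmark):
    # Style 3: the marker keeps the test from being skipped by the plugin's collection hook
    benchmark(sorter, list(reversed(range(500))))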

View file

@@ -9,6 +9,24 @@ def sorter(arr):
arr[j + 1] = temp
return arr
@codeflash_trace
def recursive_bubble_sort(arr, n=None):
# Initialize n if not provided
if n is None:
n = len(arr)
# Base case: if n is 1, the array is already sorted
if n == 1:
return arr
# One pass of bubble sort - move the largest element to the end
for i in range(n - 1):
if arr[i] > arr[i + 1]:
arr[i], arr[i + 1] = arr[i + 1], arr[i]
# Recursively sort the remaining n-1 elements
return recursive_bubble_sort(arr, n - 1)
class Sorter:
@codeflash_trace
def __init__(self, arr):

View file

@@ -0,0 +1,6 @@
from code_to_optimize.bubble_sort_codeflash_trace import recursive_bubble_sort
def test_recursive_sort(benchmark):
result = benchmark(recursive_bubble_sort, list(reversed(range(500))))
assert result == list(range(500))

View file

@@ -0,0 +1,11 @@
import pytest
from code_to_optimize.bubble_sort_codeflash_trace import sorter
def test_benchmark_sort(benchmark):
@benchmark
def do_sort():
sorter(list(reversed(range(500))))
@pytest.mark.benchmark(group="benchmark_decorator")
def test_pytest_mark(benchmark):
benchmark(sorter, list(reversed(range(500))))

View file

@@ -3,6 +3,7 @@ import os
import pickle
import sqlite3
import sys
import threading
import time
from typing import Callable
@@ -18,6 +19,8 @@ class CodeflashTrace:
self.pickle_count_limit = 1000
self._connection = None
self._trace_path = None
self._thread_local = threading.local()
self._thread_local.active_functions = set()
def setup(self, trace_path: str) -> None:
"""Set up the database connection for direct writing.
@@ -98,23 +101,29 @@
The wrapped function
"""
func_id = (func.__module__,func.__name__)
@functools.wraps(func)
def wrapper(*args, **kwargs):
# Initialize thread-local active functions set if it doesn't exist
if not hasattr(self._thread_local, "active_functions"):
self._thread_local.active_functions = set()
# If it's in a recursive function, just return the result
if func_id in self._thread_local.active_functions:
return func(*args, **kwargs)
# Track active functions so we can detect recursive functions
self._thread_local.active_functions.add(func_id)
# Measure execution time
start_time = time.thread_time_ns()
result = func(*args, **kwargs)
end_time = time.thread_time_ns()
# Calculate execution time
execution_time = end_time - start_time
self.function_call_count += 1
# Measure overhead
original_recursion_limit = sys.getrecursionlimit()
# Check if currently in pytest benchmark fixture
if os.environ.get("CODEFLASH_BENCHMARKING", "False") == "False":
self._thread_local.active_functions.remove(func_id)
return result
# Get benchmark info from environment
benchmark_function_name = os.environ.get("CODEFLASH_BENCHMARK_FUNCTION_NAME", "")
benchmark_module_path = os.environ.get("CODEFLASH_BENCHMARK_MODULE_PATH", "")
@@ -125,32 +134,54 @@
if "." in qualname:
class_name = qualname.split(".")[0]
if self.function_call_count <= self.pickle_count_limit:
# Limit pickle count so memory does not explode
if self.function_call_count > self.pickle_count_limit:
print("Pickle limit reached")
self._thread_local.active_functions.remove(func_id)
overhead_time = time.thread_time_ns() - end_time
self.function_calls_data.append(
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
overhead_time, None, None)
)
return result
try:
original_recursion_limit = sys.getrecursionlimit()
sys.setrecursionlimit(10000)
# args = dict(args.items())
# if class_name and func.__name__ == "__init__" and "self" in args:
# del args["self"]
# Pickle the arguments
pickled_args = pickle.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
pickled_kwargs = pickle.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
sys.setrecursionlimit(original_recursion_limit)
except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
# Retry with dill if pickle fails. It's slower but more comprehensive
try:
sys.setrecursionlimit(1000000)
args = dict(args.items())
if class_name and func.__name__ == "__init__" and "self" in args:
del args["self"]
# Pickle the arguments
pickled_args = pickle.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
pickled_kwargs = pickle.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
pickled_args = dill.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
pickled_kwargs = dill.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
sys.setrecursionlimit(original_recursion_limit)
except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
# we retry with dill if pickle fails. It's slower but more comprehensive
try:
pickled_args = dill.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
pickled_kwargs = dill.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
sys.setrecursionlimit(original_recursion_limit)
except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError) as e:
print(f"Error pickling arguments for function {func.__name__}: {e}")
return result
except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError) as e:
print(f"Error pickling arguments for function {func.__name__}: {e}")
# Add to the list of function calls without pickled args. Used for timing info only
self._thread_local.active_functions.remove(func_id)
overhead_time = time.thread_time_ns() - end_time
self.function_calls_data.append(
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
overhead_time, None, None)
)
return result
# Flush to database every 1000 calls
if len(self.function_calls_data) > 1000:
self.write_function_timings()
# Calculate overhead time
overhead_time = time.thread_time_ns() - end_time
# Add to the list of function calls with pickled args, to be used for replay tests
self._thread_local.active_functions.remove(func_id)
overhead_time = time.thread_time_ns() - end_time
self.function_calls_data.append(
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,

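Two mechanisms in the wrapper above are worth isolating: recursion detection via a per-thread set of active function ids (so only the outermost call of a recursive function is recorded), and the pickle-then-dill fallback for serializing arguments. A condensed sketch of both, with the tracing bookkeeping stripped out; the names here are illustrative, not the library's API:

import functools
import pickle
import threading

import dill  # slower than pickle but handles more object types

_local = threading.local()
_recorded = []  # illustrative sink for (function name, pickled args) records

def trace(func):
    func_id = (func.__module__, func.__name__)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Lazily create the per-thread set of functions currently executing
        if not hasattr(_local, "active"):
            _local.active = set()
        # Recursive call: run it, but record only the outermost invocation
        if func_id in _local.active:
            return func(*args, **kwargs)
        _local.active.add(func_id)
        try:
            result = func(*args, **kwargs)
            try:
                blob = pickle.dumps((args, kwargs), protocol=pickle.HIGHEST_PROTOCOL)
            except (TypeError, pickle.PicklingError, AttributeError, RecursionError):
                # Fall back to dill when pickle cannot serialize the arguments
                blob = dill.dumps((args, kwargs), protocol=pickle.HIGHEST_PROTOCOL)
            _recorded.append((func.__name__, blob))
            return result
        finally:
            _local.active.remove(func_id)

    return wrapper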
View file

@@ -175,6 +175,7 @@ class CodeFlashBenchmarkPlugin:
benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
# Subtract overhead from total time
overhead = overhead_by_benchmark.get(benchmark_key, 0)
print("benchmark_func:", benchmark_func, "Total time:", time_ns, "Overhead:", overhead, "Result:", time_ns - overhead)
result[benchmark_key] = time_ns - overhead
finally:
@@ -210,6 +211,13 @@ class CodeFlashBenchmarkPlugin:
manager.unregister(plugin)
@staticmethod
def pytest_configure(config):
"""Register the benchmark marker."""
config.addinivalue_line(
"markers",
"benchmark: mark test as a benchmark that should be run with codeflash tracing"
)
@staticmethod
def pytest_collection_modifyitems(config, items):
# Skip tests that don't have the benchmark fixture
if not config.getoption("--codeflash-trace"):
@@ -217,9 +225,19 @@ class CodeFlashBenchmarkPlugin:
skip_no_benchmark = pytest.mark.skip(reason="Test requires benchmark fixture")
for item in items:
if hasattr(item, "fixturenames") and "benchmark" in item.fixturenames:
continue
item.add_marker(skip_no_benchmark)
# Check for direct benchmark fixture usage
has_fixture = hasattr(item, "fixturenames") and "benchmark" in item.fixturenames
# Check for @pytest.mark.benchmark marker
has_marker = False
if hasattr(item, "get_closest_marker"):
marker = item.get_closest_marker("benchmark")
if marker is not None:
has_marker = True
# Skip if neither fixture nor marker is present
if not (has_fixture or has_marker):
item.add_marker(skip_no_benchmark)
# Benchmark fixture
class Benchmark:
@@ -227,44 +245,37 @@
self.request = request
def __call__(self, func, *args, **kwargs):
"""Handle behaviour for the benchmark fixture in pytest.
"""Handle both direct function calls and decorator usage."""
if args or kwargs:
# Used as benchmark(func, *args, **kwargs)
return self._run_benchmark(func, *args, **kwargs)
# Used as @benchmark decorator
def wrapped_func(*args, **kwargs):
return func(*args, **kwargs)
result = self._run_benchmark(func)
return wrapped_func
For example,
def test_something(benchmark):
benchmark(sorter, [3,2,1])
Args:
func: The function to benchmark (e.g. sorter)
args: The arguments to pass to the function (e.g. [3,2,1])
kwargs: The keyword arguments to pass to the function
Returns:
The return value of the function
"""
benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root))
def _run_benchmark(self, func, *args, **kwargs):
"""Actual benchmark implementation."""
benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)),
Path(codeflash_benchmark_plugin.project_root))
benchmark_function_name = self.request.node.name
line_number = int(str(sys._getframe(1).f_lineno)) # 1 frame up in the call stack
# Set env vars so codeflash decorator can identify what benchmark its being run in
line_number = int(str(sys._getframe(2).f_lineno)) # 2 frames up in the call stack
# Set env vars
os.environ["CODEFLASH_BENCHMARK_FUNCTION_NAME"] = benchmark_function_name
os.environ["CODEFLASH_BENCHMARK_MODULE_PATH"] = benchmark_module_path
os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number)
os.environ["CODEFLASH_BENCHMARKING"] = "True"
# Run the function
start = time.perf_counter_ns()
# Run the function
start = time.thread_time_ns()
result = func(*args, **kwargs)
end = time.perf_counter_ns()
end = time.thread_time_ns()
# Reset the environment variable
os.environ["CODEFLASH_BENCHMARKING"] = "False"
# Write function calls
codeflash_trace.write_function_timings()
# Reset function call count after a benchmark is run
# Reset function call count
codeflash_trace.function_call_count = 0
# Add to the benchmark timings buffer
codeflash_benchmark_plugin.benchmark_timings.append(

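The fixture's dual dispatch above reduces to a simple shape: called with extra arguments it benchmarks immediately; called with only a function it behaves as a decorator, running the body once and handing back a plain wrapper. A stripped-down sketch, timing only; the real plugin also sets env vars, records line numbers, and writes trace data:

import time

class Benchmark:
    def __call__(self, func, *args, **kwargs):
        if args or kwargs:
            # benchmark(func, *args, **kwargs): time the call and return its result
            return self._run(func, *args, **kwargs)
        # @benchmark usage: run the decorated body once, return an untimed wrapper
        self._run(func)
        def wrapped(*a, **kw):
            return func(*a, **kw)
        return wrapped

    def _run(self, func, *args, **kwargs):
        start = time.thread_time_ns()
        result = func(*args, **kwargs)
        elapsed = time.thread_time_ns() - start
        print(f"{func.__name__}: {elapsed} ns")  # illustrative; the plugin buffers this instead
        return result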
View file

@@ -16,7 +16,7 @@ if __name__ == "__main__":
codeflash_benchmark_plugin.setup(trace_file, project_root)
codeflash_trace.setup(trace_file)
exitcode = pytest.main(
[benchmarks_root, "--codeflash-trace", "-p", "no:benchmark", "-s", "-o", "addopts="], plugins=[codeflash_benchmark_plugin]
[benchmarks_root, "--codeflash-trace", "-p", "no:benchmark", "-p", "no:codspeed", "-p", "no:cov", "-s", "-o", "addopts="], plugins=[codeflash_benchmark_plugin]
) # Errors will be printed to stdout, not stderr
except Exception as e:

View file

@@ -34,7 +34,7 @@ def get_next_arg_and_return(
)
while (val := cursor.fetchone()) is not None:
yield val[9], val[10] # args and kwargs are at indices 7 and 8
yield val[9], val[10] # pickled_args, pickled_kwargs
def get_function_alias(module: str, function_name: str) -> str:

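For reference, a hedged sketch of how a replay test could consume the (pickled_args, pickled_kwargs) pairs yielded above; replay_calls is an illustrative helper, not part of this diff, and arguments that were serialized with dill would need dill.loads instead:

import pickle

def replay_calls(rows, func):
    # rows: iterable of (pickled_args, pickled_kwargs) tuples, as yielded above
    for pickled_args, pickled_kwargs in rows:
        args = pickle.loads(pickled_args)
        kwargs = pickle.loads(pickled_kwargs)
        func(*args, **kwargs)  # re-invoke the traced function with the recorded inputs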
View file

@@ -31,7 +31,7 @@ def test_trace_benchmarks():
function_calls = cursor.fetchall()
# Assert the length of function calls
assert len(function_calls) == 7, f"Expected 6 function calls, but got {len(function_calls)}"
assert len(function_calls) == 8, f"Expected 8 function calls, but got {len(function_calls)}"
bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
process_and_bubble_sort_path = (project_root / "process_and_bubble_sort_codeflash_trace.py").as_posix()
@@ -64,6 +64,10 @@ def test_trace_benchmarks():
("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
f"{bubble_sort_path}",
"test_no_func", "tests.pytest.benchmarks_test.test_process_and_sort_example", 8),
("recursive_bubble_sort", "", "code_to_optimize.bubble_sort_codeflash_trace",
f"{bubble_sort_path}",
"test_recursive_sort", "tests.pytest.benchmarks_test.test_recursive_example", 5),
]
for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
@@ -222,6 +226,62 @@ def test_trace_multithreaded_benchmark() -> None:
# Close connection
conn.close()
finally:
# cleanup
output_file.unlink(missing_ok=True)
def test_trace_benchmark_decorator() -> None:
project_root = Path(__file__).parent.parent / "code_to_optimize"
benchmarks_root = project_root / "tests" / "pytest" / "benchmarks_test_decorator"
tests_root = project_root / "tests"
output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
assert output_file.exists()
try:
# check contents of trace file
# connect to database
conn = sqlite3.connect(output_file.as_posix())
cursor = conn.cursor()
# Get the count of records
# Get all records
cursor.execute(
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
function_calls = cursor.fetchall()
# Assert the length of function calls
assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results
test_name, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_codeflash_trace.sorter"][0]
assert total_time > 0.0
assert function_time > 0.0
assert percent > 0.0
bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
# Expected function calls
expected_calls = [
("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
f"{bubble_sort_path}",
"test_benchmark_sort", "tests.pytest.benchmarks_test_decorator.test_benchmark_decorator", 5),
("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
f"{bubble_sort_path}",
"test_pytest_mark", "tests.pytest.benchmarks_test_decorator.test_benchmark_decorator", 11),
]
for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name"
assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name"
assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path"
assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
# Close connection
conn.close()
finally:
# cleanup
output_file.unlink(missing_ok=True)