Mirror of https://github.com/codeflash-ai/codeflash.git, synced 2026-05-04 18:25:17 +00:00
Support recursive functions, and the @benchmark / @pytest.mark.benchmark ways of using the benchmark fixture. Created tests for all of them.
This commit is contained in:
  parent c997b90394
  commit d6ed1c33c4

8 changed files with 192 additions and 55 deletions
@@ -9,6 +9,24 @@ def sorter(arr):
                arr[j + 1] = temp
    return arr

@codeflash_trace
def recursive_bubble_sort(arr, n=None):
    # Initialize n if not provided
    if n is None:
        n = len(arr)

    # Base case: if n is 1, the array is already sorted
    if n == 1:
        return arr

    # One pass of bubble sort - move the largest element to the end
    for i in range(n - 1):
        if arr[i] > arr[i + 1]:
            arr[i], arr[i + 1] = arr[i + 1], arr[i]

    # Recursively sort the remaining n-1 elements
    return recursive_bubble_sort(arr, n - 1)

class Sorter:
    @codeflash_trace
    def __init__(self, arr):
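What makes recursive_bubble_sort interesting for tracing is that a decorated recursive function re-enters its own wrapper once per recursive level. A minimal sketch of the problem the new recursion tracking solves, using a hypothetical counting decorator that is not part of the diff:

import functools

def naive_trace(func):
    # Hypothetical stand-in for a tracing decorator with no recursion guard.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        wrapper.calls += 1          # every recursive level re-enters the wrapper
        return func(*args, **kwargs)
    wrapper.calls = 0
    return wrapper

@naive_trace
def recursive_bubble_sort(arr, n=None):
    n = len(arr) if n is None else n
    if n == 1:
        return arr
    for i in range(n - 1):
        if arr[i] > arr[i + 1]:
            arr[i], arr[i + 1] = arr[i + 1], arr[i]
    return recursive_bubble_sort(arr, n - 1)

recursive_bubble_sort(list(reversed(range(500))))
print(recursive_bubble_sort.calls)  # 500: one wrapper entry per level, instead of one per top-level call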
@@ -0,0 +1,6 @@
from code_to_optimize.bubble_sort_codeflash_trace import recursive_bubble_sort


def test_recursive_sort(benchmark):
    result = benchmark(recursive_bubble_sort, list(reversed(range(500))))
    assert result == list(range(500))
@@ -0,0 +1,11 @@
import pytest
from code_to_optimize.bubble_sort_codeflash_trace import sorter

def test_benchmark_sort(benchmark):
    @benchmark
    def do_sort():
        sorter(list(reversed(range(500))))

@pytest.mark.benchmark(group="benchmark_decorator")
def test_pytest_mark(benchmark):
    benchmark(sorter, list(reversed(range(500))))
@@ -3,6 +3,7 @@ import os
import pickle
import sqlite3
import sys
import threading
import time
from typing import Callable

@@ -18,6 +19,8 @@ class CodeflashTrace:
        self.pickle_count_limit = 1000
        self._connection = None
        self._trace_path = None
        self._thread_local = threading.local()
        self._thread_local.active_functions = set()

    def setup(self, trace_path: str) -> None:
        """Set up the database connection for direct writing.
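A condensed sketch of the thread-local bookkeeping introduced above (the apparent intent: each thread keeps its own set of in-flight functions, so recursion is detected per thread rather than globally; names here are illustrative, not the diff's):

import threading

_local = threading.local()

def active_functions() -> set:
    # threading.local() attributes are created lazily and independently in every thread,
    # so the set has to be initialized on first access in each thread.
    if not hasattr(_local, "active_functions"):
        _local.active_functions = set()
    return _local.active_functions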
@@ -98,23 +101,29 @@ class CodeflashTrace:
            The wrapped function

        """
        func_id = (func.__module__,func.__name__)
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Initialize thread-local active functions set if it doesn't exist
            if not hasattr(self._thread_local, "active_functions"):
                self._thread_local.active_functions = set()
            # If it's in a recursive function, just return the result
            if func_id in self._thread_local.active_functions:
                return func(*args, **kwargs)
            # Track active functions so we can detect recursive functions
            self._thread_local.active_functions.add(func_id)
            # Measure execution time
            start_time = time.thread_time_ns()
            result = func(*args, **kwargs)
            end_time = time.thread_time_ns()
            # Calculate execution time
            execution_time = end_time - start_time

            self.function_call_count += 1

            # Measure overhead
            original_recursion_limit = sys.getrecursionlimit()
            # Check if currently in pytest benchmark fixture
            if os.environ.get("CODEFLASH_BENCHMARKING", "False") == "False":
                self._thread_local.active_functions.remove(func_id)
                return result

            # Get benchmark info from environment
            benchmark_function_name = os.environ.get("CODEFLASH_BENCHMARK_FUNCTION_NAME", "")
            benchmark_module_path = os.environ.get("CODEFLASH_BENCHMARK_MODULE_PATH", "")
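The decorator and the benchmark fixture coordinate through environment variables: the fixture sets them before running the benchmarked callable, and the wrapper above only records a full row while CODEFLASH_BENCHMARKING is "True". A simplified sketch of that handshake (variable names are taken from the diff; the helper function itself is hypothetical):

import os

def with_benchmark_context(func, *args, benchmark_name="test_example", module_path="tests.test_example", line_number=1):
    # Mirror of what the fixture does around the benchmarked call (simplified).
    os.environ["CODEFLASH_BENCHMARK_FUNCTION_NAME"] = benchmark_name
    os.environ["CODEFLASH_BENCHMARK_MODULE_PATH"] = module_path
    os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number)
    os.environ["CODEFLASH_BENCHMARKING"] = "True"
    try:
        return func(*args)
    finally:
        # Outside this window the wrapper returns early and records nothing.
        os.environ["CODEFLASH_BENCHMARKING"] = "False"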
@@ -125,32 +134,54 @@ class CodeflashTrace:
            if "." in qualname:
                class_name = qualname.split(".")[0]

            if self.function_call_count <= self.pickle_count_limit:
            # Limit pickle count so memory does not explode
            if self.function_call_count > self.pickle_count_limit:
                print("Pickle limit reached")
                self._thread_local.active_functions.remove(func_id)
                overhead_time = time.thread_time_ns() - end_time
                self.function_calls_data.append(
                    (func.__name__, class_name, func.__module__, func.__code__.co_filename,
                     benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
                     overhead_time, None, None)
                )
                return result

            try:
                original_recursion_limit = sys.getrecursionlimit()
                sys.setrecursionlimit(10000)
                # args = dict(args.items())
                # if class_name and func.__name__ == "__init__" and "self" in args:
                #     del args["self"]
                # Pickle the arguments
                pickled_args = pickle.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
                pickled_kwargs = pickle.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
                sys.setrecursionlimit(original_recursion_limit)
            except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
                # Retry with dill if pickle fails. It's slower but more comprehensive
                try:
                    sys.setrecursionlimit(1000000)
                    args = dict(args.items())
                    if class_name and func.__name__ == "__init__" and "self" in args:
                        del args["self"]
                    # Pickle the arguments
                    pickled_args = pickle.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
                    pickled_kwargs = pickle.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
                    pickled_args = dill.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
                    pickled_kwargs = dill.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
                    sys.setrecursionlimit(original_recursion_limit)
                except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
                    # we retry with dill if pickle fails. It's slower but more comprehensive
                    try:
                        pickled_args = dill.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
                        pickled_kwargs = dill.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
                        sys.setrecursionlimit(original_recursion_limit)

                    except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError) as e:
                        print(f"Error pickling arguments for function {func.__name__}: {e}")
                        return result
                except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError) as e:
                    print(f"Error pickling arguments for function {func.__name__}: {e}")
                    # Add to the list of function calls without pickled args. Used for timing info only
                    self._thread_local.active_functions.remove(func_id)
                    overhead_time = time.thread_time_ns() - end_time
                    self.function_calls_data.append(
                        (func.__name__, class_name, func.__module__, func.__code__.co_filename,
                         benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
                         overhead_time, None, None)
                    )
                    return result

            # Flush to database every 1000 calls
            if len(self.function_calls_data) > 1000:
                self.write_function_timings()
            # Calculate overhead time
            overhead_time = time.thread_time_ns() - end_time

            # Add to the list of function calls with pickled args, to be used for replay tests
            self._thread_local.active_functions.remove(func_id)
            overhead_time = time.thread_time_ns() - end_time
            self.function_calls_data.append(
                (func.__name__, class_name, func.__module__, func.__code__.co_filename,
                 benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
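The argument-pickling logic above boils down to: raise the recursion limit, try the fast stdlib pickler, and fall back to dill, which is slower but can serialize closures, lambdas and other objects that pickle rejects. A self-contained sketch of that fallback, assuming dill is available as the diff implies:

import pickle
import sys

import dill  # assumed dependency, as implied by the diff

def dump_call_args(args, kwargs, recursion_limit=10000):
    original_limit = sys.getrecursionlimit()
    sys.setrecursionlimit(recursion_limit)
    try:
        try:
            return (pickle.dumps(args, protocol=pickle.HIGHEST_PROTOCOL),
                    pickle.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL))
        except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
            # dill's errors subclass pickle's, so a second failure propagates to the caller.
            return (dill.dumps(args, protocol=pickle.HIGHEST_PROTOCOL),
                    dill.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL))
    finally:
        sys.setrecursionlimit(original_limit)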
@@ -175,6 +175,7 @@ class CodeFlashBenchmarkPlugin:
                benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
                # Subtract overhead from total time
                overhead = overhead_by_benchmark.get(benchmark_key, 0)
                print("benchmark_func:", benchmark_func, "Total time:", time_ns, "Overhead:", overhead, "Result:", time_ns - overhead)
                result[benchmark_key] = time_ns - overhead

        finally:
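The correction above works on plain nanosecond integers: every traced call records how long the decorator's own bookkeeping took, those values are summed per benchmark, and the sum is subtracted from the benchmark's measured time. A tiny illustration with made-up numbers and a string key standing in for a BenchmarkKey:

overhead_by_benchmark = {"tests.test_benchmark_decorator::test_benchmark_sort": 350_000}  # ns, accumulated from traced calls
time_ns = 12_000_000                                                                       # ns, measured for the whole benchmark
key = "tests.test_benchmark_decorator::test_benchmark_sort"
corrected = time_ns - overhead_by_benchmark.get(key, 0)
assert corrected == 11_650_000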
@@ -210,6 +211,13 @@ class CodeFlashBenchmarkPlugin:
        manager.unregister(plugin)

    @staticmethod
    def pytest_configure(config):
        """Register the benchmark marker."""
        config.addinivalue_line(
            "markers",
            "benchmark: mark test as a benchmark that should be run with codeflash tracing"
        )
    @staticmethod
    def pytest_collection_modifyitems(config, items):
        # Skip tests that don't have the benchmark fixture
        if not config.getoption("--codeflash-trace"):
@@ -217,9 +225,19 @@ class CodeFlashBenchmarkPlugin:

        skip_no_benchmark = pytest.mark.skip(reason="Test requires benchmark fixture")
        for item in items:
            if hasattr(item, "fixturenames") and "benchmark" in item.fixturenames:
                continue
            item.add_marker(skip_no_benchmark)
            # Check for direct benchmark fixture usage
            has_fixture = hasattr(item, "fixturenames") and "benchmark" in item.fixturenames

            # Check for @pytest.mark.benchmark marker
            has_marker = False
            if hasattr(item, "get_closest_marker"):
                marker = item.get_closest_marker("benchmark")
                if marker is not None:
                    has_marker = True

            # Skip if neither fixture nor marker is present
            if not (has_fixture or has_marker):
                item.add_marker(skip_no_benchmark)

    # Benchmark fixture
    class Benchmark:
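Taken together with the marker registration above, the collection rule now keeps a test under --codeflash-trace when it either requests the benchmark fixture or carries @pytest.mark.benchmark, and skips everything else. A hedged illustration with hypothetical test names:

import pytest

def test_plain():                      # no fixture, no marker -> gets the skip marker
    assert 1 + 1 == 2

def test_uses_fixture(benchmark):      # requests the benchmark fixture -> collected
    benchmark(sum, range(100))

@pytest.mark.benchmark(group="demo")   # carries the registered marker -> collected
def test_uses_marker(benchmark):
    benchmark(sum, range(100))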
@@ -227,44 +245,37 @@ class CodeFlashBenchmarkPlugin:
            self.request = request

        def __call__(self, func, *args, **kwargs):
            """Handle behaviour for the benchmark fixture in pytest.
            """Handle both direct function calls and decorator usage."""
            if args or kwargs:
                # Used as benchmark(func, *args, **kwargs)
                return self._run_benchmark(func, *args, **kwargs)
            # Used as @benchmark decorator
            def wrapped_func(*args, **kwargs):
                return func(*args, **kwargs)
            result = self._run_benchmark(func)
            return wrapped_func

            For example,

            def test_something(benchmark):
                benchmark(sorter, [3,2,1])

            Args:
                func: The function to benchmark (e.g. sorter)
                args: The arguments to pass to the function (e.g. [3,2,1])
                kwargs: The keyword arguments to pass to the function

            Returns:
                The return value of the function
                a

            """
            benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root))
        def _run_benchmark(self, func, *args, **kwargs):
            """Actual benchmark implementation."""
            benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)),
                                                               Path(codeflash_benchmark_plugin.project_root))
            benchmark_function_name = self.request.node.name
            line_number = int(str(sys._getframe(1).f_lineno))  # 1 frame up in the call stack

            # Set env vars so codeflash decorator can identify what benchmark its being run in
            line_number = int(str(sys._getframe(2).f_lineno))  # 2 frames up in the call stack
            # Set env vars
            os.environ["CODEFLASH_BENCHMARK_FUNCTION_NAME"] = benchmark_function_name
            os.environ["CODEFLASH_BENCHMARK_MODULE_PATH"] = benchmark_module_path
            os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number)
            os.environ["CODEFLASH_BENCHMARKING"] = "True"

            # Run the function
            start = time.perf_counter_ns()
            # Run the function
            start = time.thread_time_ns()
            result = func(*args, **kwargs)
            end = time.perf_counter_ns()

            end = time.thread_time_ns()
            # Reset the environment variable
            os.environ["CODEFLASH_BENCHMARKING"] = "False"

            # Write function calls
            codeflash_trace.write_function_timings()
            # Reset function call count after a benchmark is run
            # Reset function call count
            codeflash_trace.function_call_count = 0
            # Add to the benchmark timings buffer
            codeflash_benchmark_plugin.benchmark_timings.append(
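Two things stand out in the rewrite above: the timing source moves from time.perf_counter_ns() to time.thread_time_ns() (CPU time of the calling thread, which ignores sleeps and other threads), and __call__ now dispatches on whether positional or keyword arguments were supplied. A stripped-down sketch of that dispatch (the real _run_benchmark also sets the environment variables and records timings):

class BenchmarkSketch:
    def __call__(self, func, *args, **kwargs):
        if args or kwargs:
            # benchmark(sorter, [3, 2, 1]) -- run immediately and return the result
            return self._run_benchmark(func, *args, **kwargs)
        # @benchmark -- run the zero-argument function once, then hand it back to the test
        self._run_benchmark(func)
        return func

    def _run_benchmark(self, func, *args, **kwargs):
        return func(*args, **kwargs)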
@@ -16,7 +16,7 @@ if __name__ == "__main__":
        codeflash_benchmark_plugin.setup(trace_file, project_root)
        codeflash_trace.setup(trace_file)
        exitcode = pytest.main(
            [benchmarks_root, "--codeflash-trace", "-p", "no:benchmark", "-s", "-o", "addopts="], plugins=[codeflash_benchmark_plugin]
            [benchmarks_root, "--codeflash-trace", "-p", "no:benchmark", "-p", "no:codspeed", "-p", "no:cov-s", "-o", "addopts="], plugins=[codeflash_benchmark_plugin]
        )  # Errors will be printed to stdout, not stderr

    except Exception as e:
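The -p no:benchmark and -p no:codspeed flags disable plugins (pytest-benchmark, pytest-codspeed) that would otherwise provide their own benchmark fixture, so the codeflash fixture is the one injected into the tests. A condensed sketch of the programmatic call (the path is a placeholder; the real invocation also passes codeflash_benchmark_plugin in plugins=[...]):

import pytest

exitcode = pytest.main(
    ["path/to/benchmarks", "--codeflash-trace",
     "-p", "no:benchmark", "-p", "no:codspeed",
     "-s", "-o", "addopts="],
)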
@@ -34,7 +34,7 @@ def get_next_arg_and_return(
    )

    while (val := cursor.fetchone()) is not None:
        yield val[9], val[10]  # args and kwargs are at indices 7 and 8
        yield val[9], val[10]  # pickled_args, pickled_kwargs


def get_function_alias(module: str, function_name: str) -> str:
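A short sketch of how a replay-test generator can consume the (pickled_args, pickled_kwargs) pairs yielded above (the helper is hypothetical; if the arguments were serialized with dill rather than pickle, dill.loads would be needed instead):

import pickle

def replay_call(func, pickled_args: bytes, pickled_kwargs: bytes):
    args = pickle.loads(pickled_args)
    kwargs = pickle.loads(pickled_kwargs)
    return func(*args, **kwargs)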
@@ -31,7 +31,7 @@ def test_trace_benchmarks():
    function_calls = cursor.fetchall()

    # Assert the length of function calls
    assert len(function_calls) == 7, f"Expected 6 function calls, but got {len(function_calls)}"
    assert len(function_calls) == 8, f"Expected 8 function calls, but got {len(function_calls)}"

    bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
    process_and_bubble_sort_path = (project_root / "process_and_bubble_sort_codeflash_trace.py").as_posix()
@@ -64,6 +64,10 @@ def test_trace_benchmarks():
        ("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
         f"{bubble_sort_path}",
         "test_no_func", "tests.pytest.benchmarks_test.test_process_and_sort_example", 8),

        ("recursive_bubble_sort", "", "code_to_optimize.bubble_sort_codeflash_trace",
         f"{bubble_sort_path}",
         "test_recursive_sort", "tests.pytest.benchmarks_test.test_recursive_example", 5),
    ]
    for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
        assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
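Each expected_calls tuple lines up with the columns the test selects from benchmark_function_timings. A hypothetical namedtuple (not in the diff) that names the positions:

from collections import namedtuple

TraceRow = namedtuple("TraceRow", [
    "function_name", "class_name", "module_name", "file_path",
    "benchmark_function_name", "benchmark_module_path", "benchmark_line_number",
])

example = TraceRow("recursive_bubble_sort", "", "code_to_optimize.bubble_sort_codeflash_trace",
                   "bubble_sort_codeflash_trace.py",
                   "test_recursive_sort", "tests.pytest.benchmarks_test.test_recursive_example", 5)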
@@ -225,3 +229,59 @@ def test_trace_multithreaded_benchmark() -> None:
    finally:
        # cleanup
        output_file.unlink(missing_ok=True)

def test_trace_benchmark_decorator() -> None:
    project_root = Path(__file__).parent.parent / "code_to_optimize"
    benchmarks_root = project_root / "tests" / "pytest" / "benchmarks_test_decorator"
    tests_root = project_root / "tests"
    output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
    trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
    assert output_file.exists()
    try:
        # check contents of trace file
        # connect to database
        conn = sqlite3.connect(output_file.as_posix())
        cursor = conn.cursor()

        # Get the count of records
        # Get all records
        cursor.execute(
            "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
        function_calls = cursor.fetchall()

        # Assert the length of function calls
        assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
        function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
        total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
        function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
        assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results

        test_name, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_codeflash_trace.sorter"][0]
        assert total_time > 0.0
        assert function_time > 0.0
        assert percent > 0.0

        bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
        # Expected function calls
        expected_calls = [
            ("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
             f"{bubble_sort_path}",
             "test_benchmark_sort", "tests.pytest.benchmarks_test_decorator.test_benchmark_decorator", 5),
            ("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
             f"{bubble_sort_path}",
             "test_pytest_mark", "tests.pytest.benchmarks_test_decorator.test_benchmark_decorator", 11),
        ]
        for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
            assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
            assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name"
            assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name"
            assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path"
            assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
            assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
            assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
        # Close connection
        conn.close()

    finally:
        # cleanup
        output_file.unlink(missing_ok=True)
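The unpacked row above suggests the shape validate_and_format_benchmark_table produces: a benchmark's total time, the traced function's share of it, and that share as a percentage. A hedged sketch of that arithmetic (the real implementation lives in codeflash and may differ in units and rounding):

def table_row(test_name: str, total_ns: int, function_ns: int):
    percent = (function_ns / total_ns) * 100 if total_ns else 0.0
    return (test_name, total_ns, function_ns, percent)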