diff --git a/code_to_optimize/bubble_sort_codeflash_trace.py b/code_to_optimize/bubble_sort_codeflash_trace.py
index ee4dbd999..48e9a412b 100644
--- a/code_to_optimize/bubble_sort_codeflash_trace.py
+++ b/code_to_optimize/bubble_sort_codeflash_trace.py
@@ -9,6 +9,24 @@ def sorter(arr):
                 arr[j + 1] = temp
     return arr
 
+@codeflash_trace
+def recursive_bubble_sort(arr, n=None):
+    # Initialize n if not provided
+    if n is None:
+        n = len(arr)
+
+    # Base case: if n is 1, the array is already sorted
+    if n == 1:
+        return arr
+
+    # One pass of bubble sort - move the largest element to the end
+    for i in range(n - 1):
+        if arr[i] > arr[i + 1]:
+            arr[i], arr[i + 1] = arr[i + 1], arr[i]
+
+    # Recursively sort the remaining n-1 elements
+    return recursive_bubble_sort(arr, n - 1)
+
 class Sorter:
     @codeflash_trace
     def __init__(self, arr):
diff --git a/code_to_optimize/tests/pytest/benchmarks_test/test_recursive_example.py b/code_to_optimize/tests/pytest/benchmarks_test/test_recursive_example.py
new file mode 100644
index 000000000..689b1f9ff
--- /dev/null
+++ b/code_to_optimize/tests/pytest/benchmarks_test/test_recursive_example.py
@@ -0,0 +1,6 @@
+from code_to_optimize.bubble_sort_codeflash_trace import recursive_bubble_sort
+
+
+def test_recursive_sort(benchmark):
+    result = benchmark(recursive_bubble_sort, list(reversed(range(500))))
+    assert result == list(range(500))
\ No newline at end of file
diff --git a/code_to_optimize/tests/pytest/benchmarks_test_decorator/test_benchmark_decorator.py b/code_to_optimize/tests/pytest/benchmarks_test_decorator/test_benchmark_decorator.py
new file mode 100644
index 000000000..b924bee7f
--- /dev/null
+++ b/code_to_optimize/tests/pytest/benchmarks_test_decorator/test_benchmark_decorator.py
@@ -0,0 +1,11 @@
+import pytest
+from code_to_optimize.bubble_sort_codeflash_trace import sorter
+
+def test_benchmark_sort(benchmark):
+    @benchmark
+    def do_sort():
+        sorter(list(reversed(range(500))))
+
+@pytest.mark.benchmark(group="benchmark_decorator")
+def test_pytest_mark(benchmark):
+    benchmark(sorter, list(reversed(range(500))))
\ No newline at end of file
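Note on the new tests above: `test_recursive_sort` exercises a recursive function under the `benchmark` fixture, which is exactly the case where a naive tracing decorator misbehaves, because it fires once per recursive call rather than once per logical call. A minimal sketch of the problem, using a hypothetical `traced` decorator rather than the real `@codeflash_trace`:

```python
# Hypothetical sketch: a naive tracing decorator records every recursive call.
import functools

call_log = []

def traced(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        call_log.append(func.__name__)  # fires on every call, recursive or not
        return func(*args, **kwargs)
    return wrapper

@traced
def countdown(n):
    return n if n == 0 else countdown(n - 1)

countdown(10)
print(len(call_log))  # 11 records for one logical call
```

The re-entrancy guard added to the decorator in the next file addresses exactly this.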
diff --git a/codeflash/benchmarking/codeflash_trace.py b/codeflash/benchmarking/codeflash_trace.py
index 776e0e635..95318a38a 100644
--- a/codeflash/benchmarking/codeflash_trace.py
+++ b/codeflash/benchmarking/codeflash_trace.py
@@ -3,6 +3,7 @@ import os
 import pickle
 import sqlite3
 import sys
+import threading
 import time
 from typing import Callable
 
@@ -18,6 +19,8 @@ class CodeflashTrace:
         self.pickle_count_limit = 1000
         self._connection = None
         self._trace_path = None
+        self._thread_local = threading.local()
+        self._thread_local.active_functions = set()
 
     def setup(self, trace_path: str) -> None:
         """Set up the database connection for direct writing.
@@ -98,23 +101,29 @@ class CodeflashTrace:
             The wrapped function
 
         """
+        func_id = (func.__module__, func.__name__)
         @functools.wraps(func)
         def wrapper(*args, **kwargs):
+            # Initialize the thread-local active-functions set if it doesn't exist
+            if not hasattr(self._thread_local, "active_functions"):
+                self._thread_local.active_functions = set()
+            # If we are already inside this function, it's a recursive call: just return the result
+            if func_id in self._thread_local.active_functions:
+                return func(*args, **kwargs)
+            # Track active functions so we can detect recursion
+            self._thread_local.active_functions.add(func_id)
             # Measure execution time
             start_time = time.thread_time_ns()
             result = func(*args, **kwargs)
             end_time = time.thread_time_ns()
             # Calculate execution time
             execution_time = end_time - start_time
-
             self.function_call_count += 1
-            # Measure overhead
-            original_recursion_limit = sys.getrecursionlimit()
             # Check if currently in pytest benchmark fixture
             if os.environ.get("CODEFLASH_BENCHMARKING", "False") == "False":
+                self._thread_local.active_functions.remove(func_id)
                 return result
-
             # Get benchmark info from environment
             benchmark_function_name = os.environ.get("CODEFLASH_BENCHMARK_FUNCTION_NAME", "")
             benchmark_module_path = os.environ.get("CODEFLASH_BENCHMARK_MODULE_PATH", "")
@@ -125,32 +134,54 @@
             if "." in qualname:
                 class_name = qualname.split(".")[0]
 
-            if self.function_call_count <= self.pickle_count_limit:
+            # Limit the pickle count so memory does not explode
+            if self.function_call_count > self.pickle_count_limit:
+                print("Pickle limit reached")
+                self._thread_local.active_functions.remove(func_id)
+                overhead_time = time.thread_time_ns() - end_time
+                self.function_calls_data.append(
+                    (func.__name__, class_name, func.__module__, func.__code__.co_filename,
+                     benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
+                     overhead_time, None, None)
+                )
+                return result
+
+            try:
+                original_recursion_limit = sys.getrecursionlimit()
+                sys.setrecursionlimit(10000)
+                # args = dict(args.items())
+                # if class_name and func.__name__ == "__init__" and "self" in args:
+                #     del args["self"]
+                # Pickle the arguments
+                pickled_args = pickle.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
+                pickled_kwargs = pickle.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
+                sys.setrecursionlimit(original_recursion_limit)
+            except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
+                # Retry with dill if pickle fails. It's slower but more comprehensive
                 try:
-                    sys.setrecursionlimit(1000000)
-                    args = dict(args.items())
-                    if class_name and func.__name__ == "__init__" and "self" in args:
-                        del args["self"]
-                    # Pickle the arguments
-                    pickled_args = pickle.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
-                    pickled_kwargs = pickle.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
+                    pickled_args = dill.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
+                    pickled_kwargs = dill.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
                     sys.setrecursionlimit(original_recursion_limit)
-            except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
-                # we retry with dill if pickle fails. It's slower but more comprehensive
-                try:
-                    pickled_args = dill.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
-                    pickled_kwargs = dill.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
-                    sys.setrecursionlimit(original_recursion_limit)
-                except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError) as e:
-                    print(f"Error pickling arguments for function {func.__name__}: {e}")
-                    return result
+                except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError) as e:
+                    print(f"Error pickling arguments for function {func.__name__}: {e}")
+                    # Record the call without pickled args; used for timing info only
+                    self._thread_local.active_functions.remove(func_id)
+                    overhead_time = time.thread_time_ns() - end_time
+                    self.function_calls_data.append(
+                        (func.__name__, class_name, func.__module__, func.__code__.co_filename,
+                         benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
+                         overhead_time, None, None)
+                    )
+                    return result
+
             # Flush to database every 1000 calls
             if len(self.function_calls_data) > 1000:
                 self.write_function_timings()
-            # Calculate overhead time
-            overhead_time = time.thread_time_ns() - end_time
+            # Record the call with pickled args, to be used for replay tests
+            self._thread_local.active_functions.remove(func_id)
+            overhead_time = time.thread_time_ns() - end_time
             self.function_calls_data.append(
                 (func.__name__, class_name, func.__module__, func.__code__.co_filename,
                  benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
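The guard added above keys each decorated function by `(module, name)` in a `threading.local()` set: only the outermost call on a given thread is timed and recorded, recursive re-entries fall straight through, and concurrent threads cannot clobber each other's state. A self-contained sketch of the same pattern, using only the standard library (names are illustrative, not codeflash's API):

```python
# Illustrative sketch of a per-thread re-entrancy guard for a tracing decorator.
import functools
import threading

_local = threading.local()

def trace_outermost(func):
    func_id = (func.__module__, func.__name__)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        active = getattr(_local, "active", None)
        if active is None:
            active = _local.active = set()
        if func_id in active:
            # Recursive re-entry on this thread: run it, but record nothing.
            return func(*args, **kwargs)
        active.add(func_id)
        try:
            return func(*args, **kwargs)  # only the outermost call is recorded
        finally:
            active.discard(func_id)
    return wrapper
```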
diff --git a/codeflash/benchmarking/plugin/plugin.py b/codeflash/benchmarking/plugin/plugin.py
index c7c11c6d4..f1614b5c8 100644
--- a/codeflash/benchmarking/plugin/plugin.py
+++ b/codeflash/benchmarking/plugin/plugin.py
@@ -175,6 +175,7 @@ class CodeFlashBenchmarkPlugin:
                 benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
                 # Subtract overhead from total time
                 overhead = overhead_by_benchmark.get(benchmark_key, 0)
+                print("benchmark_func:", benchmark_func, "Total time:", time_ns, "Overhead:", overhead, "Result:", time_ns - overhead)
                 result[benchmark_key] = time_ns - overhead
 
         finally:
@@ -210,6 +211,13 @@ class CodeFlashBenchmarkPlugin:
             manager.unregister(plugin)
 
     @staticmethod
+    def pytest_configure(config):
+        """Register the benchmark marker."""
+        config.addinivalue_line(
+            "markers",
+            "benchmark: mark test as a benchmark that should be run with codeflash tracing"
+        )
+    @staticmethod
     def pytest_collection_modifyitems(config, items):
         # Skip tests that don't have the benchmark fixture
         if not config.getoption("--codeflash-trace"):
@@ -217,9 +225,19 @@ class CodeFlashBenchmarkPlugin:
             return
 
         skip_no_benchmark = pytest.mark.skip(reason="Test requires benchmark fixture")
         for item in items:
-            if hasattr(item, "fixturenames") and "benchmark" in item.fixturenames:
-                continue
-            item.add_marker(skip_no_benchmark)
+            # Check for direct benchmark fixture usage
+            has_fixture = hasattr(item, "fixturenames") and "benchmark" in item.fixturenames
+
+            # Check for the @pytest.mark.benchmark marker
+            has_marker = False
+            if hasattr(item, "get_closest_marker"):
+                marker = item.get_closest_marker("benchmark")
+                if marker is not None:
+                    has_marker = True
+
+            # Skip if neither the fixture nor the marker is present
+            if not (has_fixture or has_marker):
+                item.add_marker(skip_no_benchmark)
 
 # Benchmark fixture
 class Benchmark:
@@ -227,44 +245,37 @@
         self.request = request
 
     def __call__(self, func, *args, **kwargs):
-        """Handle behaviour for the benchmark fixture in pytest.
+        """Handle both direct function calls and decorator usage."""
+        if args or kwargs:
+            # Used as benchmark(func, *args, **kwargs)
+            return self._run_benchmark(func, *args, **kwargs)
+        # Used as a @benchmark decorator
+        def wrapped_func(*args, **kwargs):
+            return func(*args, **kwargs)
+        result = self._run_benchmark(func)
+        return wrapped_func
 
-        For example,
-
-        def test_something(benchmark):
-            benchmark(sorter, [3,2,1])
-
-        Args:
-            func: The function to benchmark (e.g. sorter)
-            args: The arguments to pass to the function (e.g. [3,2,1])
-            kwargs: The keyword arguments to pass to the function
-
-        Returns:
-            The return value of the function
-        a
-
-        """
-        benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)), Path(codeflash_benchmark_plugin.project_root))
+    def _run_benchmark(self, func, *args, **kwargs):
+        """Actual benchmark implementation."""
+        benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)),
+                                                           Path(codeflash_benchmark_plugin.project_root))
         benchmark_function_name = self.request.node.name
-        line_number = int(str(sys._getframe(1).f_lineno))  # 1 frame up in the call stack
-
-        # Set env vars so codeflash decorator can identify what benchmark its being run in
+        line_number = int(str(sys._getframe(2).f_lineno))  # 2 frames up in the call stack
+        # Set env vars
         os.environ["CODEFLASH_BENCHMARK_FUNCTION_NAME"] = benchmark_function_name
         os.environ["CODEFLASH_BENCHMARK_MODULE_PATH"] = benchmark_module_path
         os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number)
         os.environ["CODEFLASH_BENCHMARKING"] = "True"
-
-        # Run the function
-        start = time.perf_counter_ns()
+        # Run the function
+        start = time.thread_time_ns()
         result = func(*args, **kwargs)
-        end = time.perf_counter_ns()
-
+        end = time.thread_time_ns()
         # Reset the environment variable
         os.environ["CODEFLASH_BENCHMARKING"] = "False"
 
         # Write function calls
         codeflash_trace.write_function_timings()
-        # Reset function call count after a benchmark is run
+        # Reset function call count
        codeflash_trace.function_call_count = 0
         # Add to the benchmark timings buffer
         codeflash_benchmark_plugin.benchmark_timings.append(
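A subtle point in the fixture refactor above: moving the timing logic from `__call__` into `_run_benchmark` puts one extra frame between the test body and the frame inspection, which is why `sys._getframe(1)` became `sys._getframe(2)`. A minimal sketch of that frame arithmetic (illustrative class, not the plugin's real one):

```python
# Sketch: why the caller's line number is now two frames up.
import sys

class Bench:
    def __call__(self, func):
        # frame 0 (inside _run_benchmark) -> frame 1 (__call__) -> frame 2 (the caller)
        return self._run_benchmark(func)

    def _run_benchmark(self, func):
        line_number = sys._getframe(2).f_lineno  # the bench(...) call site
        print(f"benchmark invoked at line {line_number}")
        return func()

bench = Bench()
bench(lambda: sorted([3, 2, 1]))  # prints the line number of this call
```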
diff --git a/codeflash/benchmarking/pytest_new_process_trace_benchmarks.py b/codeflash/benchmarking/pytest_new_process_trace_benchmarks.py
index 1bb7bbfa4..232c39fa7 100644
--- a/codeflash/benchmarking/pytest_new_process_trace_benchmarks.py
+++ b/codeflash/benchmarking/pytest_new_process_trace_benchmarks.py
@@ -16,7 +16,7 @@ if __name__ == "__main__":
         codeflash_benchmark_plugin.setup(trace_file, project_root)
         codeflash_trace.setup(trace_file)
         exitcode = pytest.main(
-            [benchmarks_root, "--codeflash-trace", "-p", "no:benchmark", "-s", "-o", "addopts="], plugins=[codeflash_benchmark_plugin]
+            [benchmarks_root, "--codeflash-trace", "-p", "no:benchmark", "-p", "no:codspeed", "-p", "no:cov", "-s", "-o", "addopts="], plugins=[codeflash_benchmark_plugin]
         )  # Errors will be printed to stdout, not stderr
     except Exception as e:
diff --git a/codeflash/benchmarking/replay_test.py b/codeflash/benchmarking/replay_test.py
index 63a330774..445957505 100644
--- a/codeflash/benchmarking/replay_test.py
+++ b/codeflash/benchmarking/replay_test.py
@@ -34,7 +34,7 @@ def get_next_arg_and_return(
     )
 
     while (val := cursor.fetchone()) is not None:
-        yield val[9], val[10]  # args and kwargs are at indices 7 and 8
+        yield val[9], val[10]  # pickled_args, pickled_kwargs
 
 
 def get_function_alias(module: str, function_name: str) -> str:
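The widened `-p no:...` list above keeps pytest-benchmark, pytest-codspeed, and pytest-cov from loading during tracing, so none of them can shadow the codeflash `benchmark` fixture or add timing noise of their own. A standalone sketch of an equivalent invocation; the benchmarks path is a placeholder, and the source hunk's fused `"no:cov-s"` is assumed to mean the two arguments `"no:cov"` and `"-s"`:

```python
# Sketch of an equivalent standalone pytest run with conflicting plugins disabled.
import pytest

exit_code = pytest.main([
    "benchmarks/",            # placeholder for benchmarks_root
    "--codeflash-trace",      # option registered by the codeflash plugin
    "-p", "no:benchmark",     # keep pytest-benchmark from claiming the fixture
    "-p", "no:codspeed",      # disable pytest-codspeed
    "-p", "no:cov",           # disable pytest-cov instrumentation
    "-s",                     # don't capture stdout/stderr
    "-o", "addopts=",         # ignore addopts from the project's pytest config
])
```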
diff --git a/tests/test_trace_benchmarks.py b/tests/test_trace_benchmarks.py
index e953d1e81..715955063 100644
--- a/tests/test_trace_benchmarks.py
+++ b/tests/test_trace_benchmarks.py
@@ -31,7 +31,7 @@ def test_trace_benchmarks():
         function_calls = cursor.fetchall()
 
         # Assert the length of function calls
-        assert len(function_calls) == 7, f"Expected 6 function calls, but got {len(function_calls)}"
+        assert len(function_calls) == 8, f"Expected 8 function calls, but got {len(function_calls)}"
 
         bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
         process_and_bubble_sort_path = (project_root / "process_and_bubble_sort_codeflash_trace.py").as_posix()
@@ -64,6 +64,10 @@ def test_trace_benchmarks():
             ("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
              f"{bubble_sort_path}",
              "test_no_func", "tests.pytest.benchmarks_test.test_process_and_sort_example", 8),
+
+            ("recursive_bubble_sort", "", "code_to_optimize.bubble_sort_codeflash_trace",
+             f"{bubble_sort_path}",
+             "test_recursive_sort", "tests.pytest.benchmarks_test.test_recursive_example", 5),
         ]
         for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
             assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
@@ -222,6 +226,62 @@ def test_trace_multithreaded_benchmark() -> None:
         # Close connection
         conn.close()
+    finally:
+        # cleanup
+        output_file.unlink(missing_ok=True)
+
+def test_trace_benchmark_decorator() -> None:
+    project_root = Path(__file__).parent.parent / "code_to_optimize"
+    benchmarks_root = project_root / "tests" / "pytest" / "benchmarks_test_decorator"
+    tests_root = project_root / "tests"
+    output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
+    trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
+    assert output_file.exists()
+    try:
+        # check contents of trace file
+        # connect to database
+        conn = sqlite3.connect(output_file.as_posix())
+        cursor = conn.cursor()
+
+        # Get all records
+        cursor.execute(
+            "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
+        function_calls = cursor.fetchall()
+
+        # Assert the length of function calls
+        assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
+        function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
+        total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
+        function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
+        assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results
+
+        test_name, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_codeflash_trace.sorter"][0]
+        assert total_time > 0.0
+        assert function_time > 0.0
+        assert percent > 0.0
+
+        bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
+        # Expected function calls
+        expected_calls = [
+            ("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
+             f"{bubble_sort_path}",
+             "test_benchmark_sort", "tests.pytest.benchmarks_test_decorator.test_benchmark_decorator", 5),
+            ("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
+             f"{bubble_sort_path}",
+             "test_pytest_mark", "tests.pytest.benchmarks_test_decorator.test_benchmark_decorator", 11),
+        ]
+        for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
+            assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
+            assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name"
+            assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name"
+            assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path"
+            assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
+            assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
+            assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
+        # Close connection
+        conn.close()
+
+    finally:
+        # cleanup
+        output_file.unlink(missing_ok=True)
\ No newline at end of file
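For reference, the new test's assertions read straight from the `benchmark_function_timings` table in the trace file; a minimal standalone query in the same style (the file name is a placeholder):

```python
# Sketch: inspect a trace file the same way the tests above do.
import sqlite3

conn = sqlite3.connect("test_trace_benchmarks.trace")  # placeholder path
rows = conn.execute(
    "SELECT function_name, benchmark_function_name, benchmark_line_number "
    "FROM benchmark_function_timings "
    "ORDER BY benchmark_module_path, benchmark_function_name, function_name"
).fetchall()
for function_name, benchmark_name, line_number in rows:
    print(f"{function_name} traced by {benchmark_name} (line {line_number})")
conn.close()
```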