Merge branch 'main' into context-import-bug

Alvin Ryanputra 2025-04-17 18:51:18 -04:00
commit 04c19bfd6a
53 changed files with 3754 additions and 70 deletions

View file

@@ -68,4 +68,4 @@ jobs:
id: optimize_code
run: |
source .venv/bin/activate
poetry run codeflash
poetry run codeflash --benchmark

View file

@@ -0,0 +1,41 @@
name: end-to-end-test
on:
pull_request:
workflow_dispatch:
jobs:
benchmark-bubble-sort-optimization:
runs-on: ubuntu-latest
env:
CODEFLASH_AIS_SERVER: prod
POSTHOG_API_KEY: ${{ secrets.POSTHOG_API_KEY }}
CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }}
COLUMNS: 110
MAX_RETRIES: 3
RETRY_DELAY: 5
EXPECTED_IMPROVEMENT_PCT: 5
CODEFLASH_END_TO_END: 1
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Python 3.11 for CLI
uses: astral-sh/setup-uv@v5
with:
python-version: 3.11.6
- name: Install dependencies (CLI)
run: |
uv tool install poetry
uv venv
source .venv/bin/activate
poetry install --with dev
- name: Run Codeflash to optimize code
id: optimize_code_with_benchmarks
run: |
source .venv/bin/activate
poetry run python tests/scripts/end_to_end_test_benchmark_sort.py

View file

@@ -32,7 +32,7 @@ jobs:
run: uvx poetry install --with dev
- name: Unit tests
run: uvx poetry run pytest tests/ --cov --cov-report=xml
run: uvx poetry run pytest tests/ --cov --cov-report=xml --benchmark-skip -m "not ci_skip"
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5

View file

@@ -7,4 +7,4 @@ def sorter(arr):
arr[j] = arr[j + 1]
arr[j + 1] = temp
print(f"result: {arr}")
return arr
return arr

View file

@@ -0,0 +1,64 @@
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@codeflash_trace
def sorter(arr):
for i in range(len(arr)):
for j in range(len(arr) - 1):
if arr[j] > arr[j + 1]:
temp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = temp
return arr
@codeflash_trace
def recursive_bubble_sort(arr, n=None):
# Initialize n if not provided
if n is None:
n = len(arr)
# Base case: if n is 1, the array is already sorted
if n == 1:
return arr
# One pass of bubble sort - move the largest element to the end
for i in range(n - 1):
if arr[i] > arr[i + 1]:
arr[i], arr[i + 1] = arr[i + 1], arr[i]
# Recursively sort the remaining n-1 elements
return recursive_bubble_sort(arr, n - 1)
class Sorter:
@codeflash_trace
def __init__(self, arr):
self.arr = arr
@codeflash_trace
def sorter(self, multiplier):
for i in range(len(self.arr)):
for j in range(len(self.arr) - 1):
if self.arr[j] > self.arr[j + 1]:
temp = self.arr[j]
self.arr[j] = self.arr[j + 1]
self.arr[j + 1] = temp
return self.arr * multiplier
@staticmethod
@codeflash_trace
def sort_static(arr):
for i in range(len(arr)):
for j in range(len(arr) - 1):
if arr[j] > arr[j + 1]:
temp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = temp
return arr
@classmethod
@codeflash_trace
def sort_class(cls, arr):
for i in range(len(arr)):
for j in range(len(arr) - 1):
if arr[j] > arr[j + 1]:
temp = arr[j]
arr[j] = arr[j + 1]
arr[j + 1] = temp
return arr

View file

@@ -0,0 +1,23 @@
from code_to_optimize.bubble_sort_codeflash_trace import sorter
import concurrent.futures
def multithreaded_sorter(unsorted_lists: list[list[int]]) -> list[list[int]]:
# Create a list to store results in the correct order
sorted_lists = [None] * len(unsorted_lists)
# Use ThreadPoolExecutor to manage threads
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
# Submit all sorting tasks and map them to their original indices
future_to_index = {
executor.submit(sorter, unsorted_list): i
for i, unsorted_list in enumerate(unsorted_lists)
}
# Collect results as they complete
for future in concurrent.futures.as_completed(future_to_index):
index = future_to_index[future]
sorted_lists[index] = future.result()
return sorted_lists
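
A minimal usage sketch (input values assumed for illustration): each inner list is sorted on its own worker thread, and the index map puts results back in their original order even when they complete out of order.

if __name__ == "__main__":
    batches = [[3, 1, 2], [9, 7, 8], [6, 4, 5]]
    print(multithreaded_sorter(batches))
    # [[1, 2, 3], [7, 8, 9], [4, 5, 6]]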

View file

@@ -0,0 +1,18 @@
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@codeflash_trace
def bubble_sort_with_unused_socket(data_container):
# Extract the list to sort, leaving the socket untouched
numbers = data_container.get('numbers', []).copy()
return sorted(numbers)
@codeflash_trace
def bubble_sort_with_used_socket(data_container):
# Extract the list to sort, leaving the socket untouched
numbers = data_container.get('numbers', []).copy()
socket = data_container.get('socket')
socket.send("Hello from the optimized function!")
return sorted(numbers)

View file

@@ -0,0 +1,46 @@
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@codeflash_trace
def bubble_sort_with_used_socket(data_container):
"""
Performs a bubble sort on a list within the data_container. The data container has the following schema:
- 'numbers' (list): The list to be sorted.
- 'socket' (socket): A socket
Args:
data_container: A dictionary with at least 'numbers' (list) and 'socket' keys
Returns:
list: The sorted list of numbers
"""
# Extract the list to sort and socket
numbers = data_container.get('numbers', []).copy()
socket = data_container.get('socket')
# Track swap count
swap_count = 0
# Classic bubble sort implementation
n = len(numbers)
for i in range(n):
# Flag to optimize by detecting if no swaps occurred
swapped = False
# Last i elements are already in place
for j in range(0, n - i - 1):
# Swap if the element is greater than the next element
if numbers[j] > numbers[j + 1]:
# Perform the swap
numbers[j], numbers[j + 1] = numbers[j + 1], numbers[j]
swapped = True
swap_count += 1
# If no swapping occurred in this pass, the list is sorted
if not swapped:
break
# Send final summary
summary = f"Bubble sort completed with {swap_count} swaps"
socket.send(summary.encode())
return numbers

View file

@@ -0,0 +1,28 @@
from code_to_optimize.bubble_sort import sorter
def calculate_pairwise_products(arr):
"""
Calculate the average of all pairwise products in the array.
"""
sum_of_products = 0
count = 0
for i in range(len(arr)):
for j in range(len(arr)):
if i != j:
sum_of_products += arr[i] * arr[j]
count += 1
# The average of all pairwise products
return sum_of_products / count if count > 0 else 0
def compute_and_sort(arr):
# Compute pairwise sums average
pairwise_average = calculate_pairwise_products(arr)
# Call sorter function
sorter(arr.copy())
return pairwise_average
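
A quick sanity check of the math (input assumed): for [1, 2, 3], the ordered pairs with i != j yield the products 2, 3, 2, 6, 3, 6, so the average is 22 / 6.

assert abs(calculate_pairwise_products([1, 2, 3]) - 22 / 6) < 1e-9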

View file

@@ -0,0 +1,28 @@
from code_to_optimize.bubble_sort import sorter
from codeflash.benchmarking.codeflash_trace import codeflash_trace
def calculate_pairwise_products(arr):
"""
Calculate the average of all pairwise products in the array.
"""
sum_of_products = 0
count = 0
for i in range(len(arr)):
for j in range(len(arr)):
if i != j:
sum_of_products += arr[i] * arr[j]
count += 1
# The average of all pairwise products
return sum_of_products / count if count > 0 else 0
@codeflash_trace
def compute_and_sort(arr):
# Compute pairwise sums average
pairwise_average = calculate_pairwise_products(arr)
# Call sorter function
sorter(arr.copy())
return pairwise_average

View file

@@ -0,0 +1,13 @@
import pytest
from code_to_optimize.bubble_sort import sorter
def test_sort(benchmark):
result = benchmark(sorter, list(reversed(range(500))))
assert result == list(range(500))
# This should not be picked up as a benchmark test
def test_sort2():
result = sorter(list(reversed(range(500))))
assert result == list(range(500))

View file

@@ -0,0 +1,8 @@
from code_to_optimize.process_and_bubble_sort import compute_and_sort
from code_to_optimize.bubble_sort import sorter
def test_compute_and_sort(benchmark):
result = benchmark(compute_and_sort, list(reversed(range(500))))
assert result == 62208.5
def test_no_func(benchmark):
benchmark(sorter, list(reversed(range(500))))

View file

@@ -0,0 +1,4 @@
from code_to_optimize.bubble_sort_multithread import multithreaded_sorter
def test_benchmark_sort(benchmark):
benchmark(multithreaded_sorter, [list(range(1000)) for _ in range(10)])

View file

@@ -0,0 +1,20 @@
import socket
from code_to_optimize.bubble_sort_picklepatch_test_unused_socket import bubble_sort_with_unused_socket
from code_to_optimize.bubble_sort_picklepatch_test_used_socket import bubble_sort_with_used_socket
def test_socket_picklepatch(benchmark):
s1, s2 = socket.socketpair()
data = {
"numbers": list(reversed(range(500))),
"socket": s1
}
benchmark(bubble_sort_with_unused_socket, data)
def test_used_socket_picklepatch(benchmark):
s1, s2 = socket.socketpair()
data = {
"numbers": list(reversed(range(500))),
"socket": s1
}
benchmark(bubble_sort_with_used_socket, data)

View file

@@ -0,0 +1,26 @@
import pytest
from code_to_optimize.bubble_sort_codeflash_trace import sorter, Sorter
def test_sort(benchmark):
result = benchmark(sorter, list(reversed(range(500))))
assert result == list(range(500))
# This should not be picked up as a benchmark test
def test_sort2():
result = sorter(list(reversed(range(500))))
assert result == list(range(500))
def test_class_sort(benchmark):
obj = Sorter(list(reversed(range(100))))
result1 = benchmark(obj.sorter, 2)
def test_class_sort2(benchmark):
result2 = benchmark(Sorter.sort_class, list(reversed(range(100))))
def test_class_sort3(benchmark):
result3 = benchmark(Sorter.sort_static, list(reversed(range(100))))
def test_class_sort4(benchmark):
result4 = benchmark(Sorter, [1,2,3])

View file

@@ -0,0 +1,8 @@
from code_to_optimize.process_and_bubble_sort_codeflash_trace import compute_and_sort
from code_to_optimize.bubble_sort_codeflash_trace import sorter
def test_compute_and_sort(benchmark):
result = benchmark(compute_and_sort, list(reversed(range(500))))
assert result == 62208.5
def test_no_func(benchmark):
benchmark(sorter, list(reversed(range(500))))

View file

@@ -0,0 +1,6 @@
from code_to_optimize.bubble_sort_codeflash_trace import recursive_bubble_sort
def test_recursive_sort(benchmark):
result = benchmark(recursive_bubble_sort, list(reversed(range(500))))
assert result == list(range(500))

View file

@@ -0,0 +1,11 @@
import pytest
from code_to_optimize.bubble_sort_codeflash_trace import sorter
def test_benchmark_sort(benchmark):
@benchmark
def do_sort():
sorter(list(reversed(range(500))))
@pytest.mark.benchmark(group="benchmark_decorator")
def test_pytest_mark(benchmark):
benchmark(sorter, list(reversed(range(500))))

View file

View file

@@ -0,0 +1,179 @@
import functools
import os
import pickle
import sqlite3
import threading
import time
from typing import Callable
from codeflash.picklepatch.pickle_patcher import PicklePatcher
class CodeflashTrace:
"""Decorator class that traces and profiles function execution."""
def __init__(self) -> None:
self.function_calls_data = []
self.function_call_count = 0
self.pickle_count_limit = 1000
self._connection = None
self._trace_path = None
self._thread_local = threading.local()
self._thread_local.active_functions = set()
def setup(self, trace_path: str) -> None:
"""Set up the database connection for direct writing.
Args:
trace_path: Path to the trace database file
"""
try:
self._trace_path = trace_path
self._connection = sqlite3.connect(self._trace_path)
cur = self._connection.cursor()
cur.execute("PRAGMA synchronous = OFF")
cur.execute("PRAGMA journal_mode = MEMORY")
cur.execute(
"CREATE TABLE IF NOT EXISTS benchmark_function_timings("
"function_name TEXT, class_name TEXT, module_name TEXT, file_path TEXT,"
"benchmark_function_name TEXT, benchmark_module_path TEXT, benchmark_line_number INTEGER,"
"function_time_ns INTEGER, overhead_time_ns INTEGER, args BLOB, kwargs BLOB)"
)
self._connection.commit()
except Exception as e:
print(f"Database setup error: {e}")
if self._connection:
self._connection.close()
self._connection = None
raise
def write_function_timings(self) -> None:
"""Write function call data directly to the database.
Args:
data: List of function call data tuples to write
"""
if not self.function_calls_data:
return # No data to write
if self._connection is None and self._trace_path is not None:
self._connection = sqlite3.connect(self._trace_path)
try:
cur = self._connection.cursor()
# Insert data into the benchmark_function_timings table
cur.executemany(
"INSERT INTO benchmark_function_timings"
"(function_name, class_name, module_name, file_path, benchmark_function_name, "
"benchmark_module_path, benchmark_line_number, function_time_ns, overhead_time_ns, args, kwargs) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
self.function_calls_data
)
self._connection.commit()
self.function_calls_data = []
except Exception as e:
print(f"Error writing to function timings database: {e}")
if self._connection:
self._connection.rollback()
raise
def open(self) -> None:
"""Open the database connection."""
if self._connection is None:
self._connection = sqlite3.connect(self._trace_path)
def close(self) -> None:
"""Close the database connection."""
if self._connection:
self._connection.close()
self._connection = None
def __call__(self, func: Callable) -> Callable:
"""Use as a decorator to trace function execution.
Args:
func: The function to be decorated
Returns:
The wrapped function
"""
func_id = (func.__module__, func.__name__)
@functools.wraps(func)
def wrapper(*args, **kwargs):
# Initialize thread-local active functions set if it doesn't exist
if not hasattr(self._thread_local, "active_functions"):
self._thread_local.active_functions = set()
# If it's in a recursive function, just return the result
if func_id in self._thread_local.active_functions:
return func(*args, **kwargs)
# Track active functions so we can detect recursive functions
self._thread_local.active_functions.add(func_id)
# Measure execution time
start_time = time.thread_time_ns()
result = func(*args, **kwargs)
end_time = time.thread_time_ns()
# Calculate execution time
execution_time = end_time - start_time
self.function_call_count += 1
# Check if currently in pytest benchmark fixture
if os.environ.get("CODEFLASH_BENCHMARKING", "False") == "False":
self._thread_local.active_functions.remove(func_id)
return result
# Get benchmark info from environment
benchmark_function_name = os.environ.get("CODEFLASH_BENCHMARK_FUNCTION_NAME", "")
benchmark_module_path = os.environ.get("CODEFLASH_BENCHMARK_MODULE_PATH", "")
benchmark_line_number = os.environ.get("CODEFLASH_BENCHMARK_LINE_NUMBER", "")
# Get class name
class_name = ""
qualname = func.__qualname__
if "." in qualname:
class_name = qualname.split(".")[0]
# Limit pickle count so memory does not explode
if self.function_call_count > self.pickle_count_limit:
print("Pickle limit reached")
self._thread_local.active_functions.remove(func_id)
overhead_time = time.thread_time_ns() - end_time
self.function_calls_data.append(
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
overhead_time, None, None)
)
return result
try:
# Pickle the arguments
pickled_args = PicklePatcher.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
pickled_kwargs = PicklePatcher.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
except Exception as e:
print(f"Error pickling arguments for function {func.__name__}: {e}")
# Add to the list of function calls without pickled args. Used for timing info only
self._thread_local.active_functions.remove(func_id)
overhead_time = time.thread_time_ns() - end_time
self.function_calls_data.append(
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
overhead_time, None, None)
)
return result
# Flush to database every 100 calls
if len(self.function_calls_data) > 100:
self.write_function_timings()
# Add to the list of function calls with pickled args, to be used for replay tests
self._thread_local.active_functions.remove(func_id)
overhead_time = time.thread_time_ns() - end_time
self.function_calls_data.append(
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
overhead_time, pickled_args, pickled_kwargs)
)
return result
return wrapper
# Create a singleton instance
codeflash_trace = CodeflashTrace()
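
A minimal sketch of driving the singleton outside pytest (trace path, function names, and env values assumed): point it at a trace database, decorate a function, enable the benchmarking flag, then flush the buffered timings.

import os
from codeflash.benchmarking.codeflash_trace import codeflash_trace

codeflash_trace.setup("/tmp/example.trace")  # trace path assumed

@codeflash_trace
def add(a, b):
    return a + b

os.environ["CODEFLASH_BENCHMARKING"] = "True"
os.environ["CODEFLASH_BENCHMARK_FUNCTION_NAME"] = "manual_example"  # assumed
add(1, 2)  # timing and pickled args are buffered in function_calls_data
os.environ["CODEFLASH_BENCHMARKING"] = "False"

codeflash_trace.write_function_timings()  # flush buffered rows to SQLite
codeflash_trace.close()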

View file

@@ -0,0 +1,117 @@
from pathlib import Path
import isort
import libcst as cst
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
class AddDecoratorTransformer(cst.CSTTransformer):
def __init__(self, target_functions: set[tuple[str, str]]) -> None:
super().__init__()
self.target_functions = target_functions
self.added_codeflash_trace = False
self.class_name = ""
self.function_name = ""
self.decorator = cst.Decorator(
decorator=cst.Name(value="codeflash_trace")
)
def leave_ClassDef(self, original_node, updated_node):
if self.class_name == original_node.name.value:
self.class_name = "" # Even if nested classes are not visited, this function is still called on them
return updated_node
def visit_ClassDef(self, node):
if self.class_name: # Don't go into nested class
return False
self.class_name = node.name.value
def visit_FunctionDef(self, node):
if self.function_name: # Don't go into nested function
return False
self.function_name = node.name.value
def leave_FunctionDef(self, original_node, updated_node):
if self.function_name == original_node.name.value:
self.function_name = ""
if (self.class_name, original_node.name.value) in self.target_functions:
# Add the new decorator after any existing decorators, so it gets executed first
updated_decorators = list(updated_node.decorators) + [self.decorator]
self.added_codeflash_trace = True
return updated_node.with_changes(
decorators=updated_decorators
)
return updated_node
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
# Create import statement for codeflash_trace
if not self.added_codeflash_trace:
return updated_node
import_stmt = cst.SimpleStatementLine(
body=[
cst.ImportFrom(
module=cst.Attribute(
value=cst.Attribute(
value=cst.Name(value="codeflash"),
attr=cst.Name(value="benchmarking")
),
attr=cst.Name(value="codeflash_trace")
),
names=[
cst.ImportAlias(
name=cst.Name(value="codeflash_trace")
)
]
)
]
)
# Insert at the beginning of the file. We'll use isort later to sort the imports.
new_body = [import_stmt, *list(updated_node.body)]
return updated_node.with_changes(body=new_body)
def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[FunctionToOptimize]) -> str:
"""Add codeflash_trace to a function.
Args:
code: The source code as a string
function_to_optimize: The FunctionToOptimize instance containing function details
Returns:
The modified source code as a string
"""
target_functions = set()
for function_to_optimize in functions_to_optimize:
class_name = ""
if len(function_to_optimize.parents) == 1 and function_to_optimize.parents[0].type == "ClassDef":
class_name = function_to_optimize.parents[0].name
target_functions.add((class_name, function_to_optimize.function_name))
transformer = AddDecoratorTransformer(target_functions=target_functions)
module = cst.parse_module(code)
modified_module = module.visit(transformer)
return modified_module.code
def instrument_codeflash_trace_decorator(
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]
) -> None:
"""Instrument codeflash_trace decorator to functions to optimize."""
for file_path, functions_to_optimize in file_to_funcs_to_optimize.items():
original_code = file_path.read_text(encoding="utf-8")
new_code = add_codeflash_decorator_to_code(
original_code,
functions_to_optimize
)
# Modify the code
modified_code = isort.code(code=new_code, float_to_top=True)
# Write the modified code back to the file
file_path.write_text(modified_code, encoding="utf-8")
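
A small sketch of the transformer in isolation (sample source assumed): top-level functions are keyed as ("", name), and the decorator and import are only added when a target matches.

import libcst as cst

sample = "def sorter(arr):\n    return sorted(arr)\n"
transformer = AddDecoratorTransformer(target_functions={("", "sorter")})
print(cst.parse_module(sample).visit(transformer).code)
# from codeflash.benchmarking.codeflash_trace import codeflash_trace
# @codeflash_trace
# def sorter(arr):
#     return sorted(arr)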

View file

@@ -0,0 +1,293 @@
from __future__ import annotations
import os
import sqlite3
import sys
import time
from pathlib import Path
import pytest
from codeflash.benchmarking.codeflash_trace import codeflash_trace
from codeflash.code_utils.code_utils import module_name_from_file_path
from codeflash.models.models import BenchmarkKey
class CodeFlashBenchmarkPlugin:
def __init__(self) -> None:
self._trace_path = None
self._connection = None
self.project_root = None
self.benchmark_timings = []
def setup(self, trace_path: str, project_root: str) -> None:
try:
# Open connection
self.project_root = project_root
self._trace_path = trace_path
self._connection = sqlite3.connect(self._trace_path)
cur = self._connection.cursor()
cur.execute("PRAGMA synchronous = OFF")
cur.execute("PRAGMA journal_mode = MEMORY")
cur.execute(
"CREATE TABLE IF NOT EXISTS benchmark_timings("
"benchmark_module_path TEXT, benchmark_function_name TEXT, benchmark_line_number INTEGER,"
"benchmark_time_ns INTEGER)"
)
self._connection.commit()
self.close() # Reopen only at the end of pytest session
except Exception as e:
print(f"Database setup error: {e}")
if self._connection:
self._connection.close()
self._connection = None
raise
def write_benchmark_timings(self) -> None:
if not self.benchmark_timings:
return # No data to write
if self._connection is None:
self._connection = sqlite3.connect(self._trace_path)
try:
cur = self._connection.cursor()
# Insert data into the benchmark_timings table
cur.executemany(
"INSERT INTO benchmark_timings (benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)",
self.benchmark_timings
)
self._connection.commit()
self.benchmark_timings = [] # Clear the benchmark timings list
except Exception as e:
print(f"Error writing to benchmark timings database: {e}")
self._connection.rollback()
raise
def close(self) -> None:
if self._connection:
self._connection.close()
self._connection = None
@staticmethod
def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[BenchmarkKey, int]]:
"""Process the trace file and extract timing data for all functions.
Args:
trace_path: Path to the trace file
Returns:
A nested dictionary where:
- Outer keys are module_name.qualified_name (module.class.function)
- Inner keys are of type BenchmarkKey
- Values are function timing in nanoseconds
"""
# Initialize the result dictionary
result = {}
# Connect to the SQLite database
connection = sqlite3.connect(trace_path)
cursor = connection.cursor()
try:
# Query the function_calls table for all function calls
cursor.execute(
"SELECT module_name, class_name, function_name, "
"benchmark_module_path, benchmark_function_name, benchmark_line_number, function_time_ns "
"FROM benchmark_function_timings"
)
# Process each row
for row in cursor.fetchall():
module_name, class_name, function_name, benchmark_file, benchmark_func, benchmark_line, time_ns = row
# Create the function key (module_name.class_name.function_name)
if class_name:
qualified_name = f"{module_name}.{class_name}.{function_name}"
else:
qualified_name = f"{module_name}.{function_name}"
# Create the benchmark key (file::function::line)
benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
# Initialize the inner dictionary if needed
if qualified_name not in result:
result[qualified_name] = {}
# If multiple calls to the same function in the same benchmark,
# add the times together
if benchmark_key in result[qualified_name]:
result[qualified_name][benchmark_key] += time_ns
else:
result[qualified_name][benchmark_key] = time_ns
finally:
# Close the connection
connection.close()
return result
@staticmethod
def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
"""Extract total benchmark timings from trace files.
Args:
trace_path: Path to the trace file
Returns:
A dictionary mapping where:
- Keys are of type BenchmarkKey
- Values are total benchmark timing in nanoseconds (with overhead subtracted)
"""
# Initialize the result dictionary
result = {}
overhead_by_benchmark = {}
# Connect to the SQLite database
connection = sqlite3.connect(trace_path)
cursor = connection.cursor()
try:
# Query the benchmark_function_timings table to get total overhead for each benchmark
cursor.execute(
"SELECT benchmark_module_path, benchmark_function_name, benchmark_line_number, SUM(overhead_time_ns) "
"FROM benchmark_function_timings "
"GROUP BY benchmark_module_path, benchmark_function_name, benchmark_line_number"
)
# Process overhead information
for row in cursor.fetchall():
benchmark_file, benchmark_func, benchmark_line, total_overhead_ns = row
benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
overhead_by_benchmark[benchmark_key] = total_overhead_ns or 0 # Handle NULL sum case
# Query the benchmark_timings table for total times
cursor.execute(
"SELECT benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns "
"FROM benchmark_timings"
)
# Process each row and subtract overhead
for row in cursor.fetchall():
benchmark_file, benchmark_func, benchmark_line, time_ns = row
# Create the benchmark key (file::function::line)
benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
# Subtract overhead from total time
overhead = overhead_by_benchmark.get(benchmark_key, 0)
result[benchmark_key] = time_ns - overhead
finally:
# Close the connection
connection.close()
return result
# Pytest hooks
@pytest.hookimpl
def pytest_sessionfinish(self, session, exitstatus):
"""Execute after whole test run is completed."""
# Close the tracer's database connection
codeflash_trace.close()
# Write any remaining benchmark timings to the database
if self.benchmark_timings:
self.write_benchmark_timings()
# Close the database connection
self.close()
@staticmethod
def pytest_addoption(parser):
parser.addoption(
"--codeflash-trace",
action="store_true",
default=False,
help="Enable CodeFlash tracing"
)
@staticmethod
def pytest_plugin_registered(plugin, manager):
# Not necessary since run with -p no:benchmark, but just in case
if hasattr(plugin, "name") and plugin.name == "pytest-benchmark":
manager.unregister(plugin)
@staticmethod
def pytest_configure(config):
"""Register the benchmark marker."""
config.addinivalue_line(
"markers",
"benchmark: mark test as a benchmark that should be run with codeflash tracing"
)
@staticmethod
def pytest_collection_modifyitems(config, items):
# Skip tests that don't have the benchmark fixture
if not config.getoption("--codeflash-trace"):
return
skip_no_benchmark = pytest.mark.skip(reason="Test requires benchmark fixture")
for item in items:
# Check for direct benchmark fixture usage
has_fixture = hasattr(item, "fixturenames") and "benchmark" in item.fixturenames
# Check for @pytest.mark.benchmark marker
has_marker = False
if hasattr(item, "get_closest_marker"):
marker = item.get_closest_marker("benchmark")
if marker is not None:
has_marker = True
# Skip if neither fixture nor marker is present
if not (has_fixture or has_marker):
item.add_marker(skip_no_benchmark)
# Benchmark fixture
class Benchmark:
def __init__(self, request):
self.request = request
def __call__(self, func, *args, **kwargs):
"""Handle both direct function calls and decorator usage."""
if args or kwargs:
# Used as benchmark(func, *args, **kwargs)
return self._run_benchmark(func, *args, **kwargs)
# Used as @benchmark decorator
def wrapped_func(*args, **kwargs):
return func(*args, **kwargs)
self._run_benchmark(func)  # run the benchmark once, at decoration time
return wrapped_func
def _run_benchmark(self, func, *args, **kwargs):
"""Actual benchmark implementation."""
benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)),
Path(codeflash_benchmark_plugin.project_root))
benchmark_function_name = self.request.node.name
line_number = sys._getframe(2).f_lineno  # 2 frames up in the call stack
# Set env vars
os.environ["CODEFLASH_BENCHMARK_FUNCTION_NAME"] = benchmark_function_name
os.environ["CODEFLASH_BENCHMARK_MODULE_PATH"] = benchmark_module_path
os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number)
os.environ["CODEFLASH_BENCHMARKING"] = "True"
# Run the function
start = time.time_ns()
result = func(*args, **kwargs)
end = time.time_ns()
# Reset the environment variable
os.environ["CODEFLASH_BENCHMARKING"] = "False"
# Write function calls
codeflash_trace.write_function_timings()
# Reset function call count
codeflash_trace.function_call_count = 0
# Add to the benchmark timings buffer
codeflash_benchmark_plugin.benchmark_timings.append(
(benchmark_module_path, benchmark_function_name, line_number, end - start))
return result
@staticmethod
@pytest.fixture
def benchmark(request):
if not request.config.getoption("--codeflash-trace"):
return None
return CodeFlashBenchmarkPlugin.Benchmark(request)
codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()

View file

@@ -0,0 +1,25 @@
import sys
from pathlib import Path
from codeflash.benchmarking.codeflash_trace import codeflash_trace
from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin
benchmarks_root = sys.argv[1]
tests_root = sys.argv[2]
trace_file = sys.argv[3]
# current working directory
project_root = Path.cwd()
if __name__ == "__main__":
import pytest
try:
codeflash_benchmark_plugin.setup(trace_file, project_root)
codeflash_trace.setup(trace_file)
exitcode = pytest.main(
[benchmarks_root, "--codeflash-trace", "-p", "no:benchmark","-p", "no:codspeed","-p", "no:cov-s", "-o", "addopts="], plugins=[codeflash_benchmark_plugin]
) # Errors will be printed to stdout, not stderr
except Exception as e:
print(f"Failed to collect tests: {e!s}", file=sys.stderr)
exitcode = -1
sys.exit(exitcode)

View file

@@ -0,0 +1,286 @@
from __future__ import annotations
import sqlite3
import textwrap
from pathlib import Path
from typing import TYPE_CHECKING, Any
import isort
from codeflash.cli_cmds.console import logger
from codeflash.discovery.functions_to_optimize import inspect_top_level_functions_or_methods
from codeflash.verification.verification_utils import get_test_file_path
if TYPE_CHECKING:
from collections.abc import Generator
def get_next_arg_and_return(
trace_file: str, benchmark_function_name: str, function_name: str, file_path: str, class_name: str | None = None, num_to_get: int = 256
) -> Generator[Any]:
db = sqlite3.connect(trace_file)
cur = db.cursor()
limit = num_to_get
if class_name is not None:
cursor = cur.execute(
"SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = ? LIMIT ?",
(benchmark_function_name, function_name, file_path, class_name, limit),
)
else:
cursor = cur.execute(
"SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = '' LIMIT ?",
(benchmark_function_name, function_name, file_path, limit),
)
while (val := cursor.fetchone()) is not None:
yield val[9], val[10] # pickled_args, pickled_kwargs
def get_function_alias(module: str, function_name: str) -> str:
return "_".join(module.split(".")) + "_" + function_name
def create_trace_replay_test_code(
trace_file: str,
functions_data: list[dict[str, Any]],
test_framework: str = "pytest",
max_run_count: int = 256
) -> str:
"""Create a replay test for functions based on trace data.
Args:
trace_file: Path to the SQLite database file
functions_data: List of dictionaries with function info extracted from DB
test_framework: 'pytest' or 'unittest'
max_run_count: Maximum number of runs to include in the test
Returns:
A string containing the test code
"""
assert test_framework in ["pytest", "unittest"]
# Create Imports
imports = f"""from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
{"import unittest" if test_framework == "unittest" else ""}
from codeflash.benchmarking.replay_test import get_next_arg_and_return
"""
function_imports = []
for func in functions_data:
module_name = func.get("module_name")
function_name = func.get("function_name")
class_name = func.get("class_name", "")
if class_name:
function_imports.append(
f"from {module_name} import {class_name} as {get_function_alias(module_name, class_name)}"
)
else:
function_imports.append(
f"from {module_name} import {function_name} as {get_function_alias(module_name, function_name)}"
)
imports += "\n".join(function_imports)
functions_to_optimize = sorted({func.get("function_name") for func in functions_data
if func.get("function_name") != "__init__"})
metadata = f"""functions = {functions_to_optimize}
trace_file_path = r"{trace_file}"
"""
# Templates for different types of tests
test_function_body = textwrap.dedent(
"""\
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", num_to_get={max_run_count}):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl)
ret = {function_name}(*args, **kwargs)
"""
)
test_method_body = textwrap.dedent(
"""\
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl){filter_variables}
function_name = "{orig_function_name}"
if not args:
raise ValueError("No arguments provided for the method.")
if function_name == "__init__":
ret = {class_name_alias}(*args[1:], **kwargs)
else:
instance = args[0] # self
ret = instance{method_name}(*args[1:], **kwargs)
""")
test_class_method_body = textwrap.dedent(
"""\
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl){filter_variables}
if not args:
raise ValueError("No arguments provided for the method.")
ret = {class_name_alias}{method_name}(*args[1:], **kwargs)
"""
)
test_static_method_body = textwrap.dedent(
"""\
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl){filter_variables}
ret = {class_name_alias}{method_name}(*args, **kwargs)
"""
)
# Create main body
if test_framework == "unittest":
self = "self"
test_template = "\nclass TestTracedFunctions(unittest.TestCase):\n"
else:
test_template = ""
self = ""
for func in functions_data:
module_name = func.get("module_name")
function_name = func.get("function_name")
class_name = func.get("class_name")
file_path = func.get("file_path")
benchmark_function_name = func.get("benchmark_function_name")
function_properties = func.get("function_properties")
if not class_name:
alias = get_function_alias(module_name, function_name)
test_body = test_function_body.format(
benchmark_function_name=benchmark_function_name,
orig_function_name=function_name,
function_name=alias,
file_path=file_path,
max_run_count=max_run_count,
)
else:
class_name_alias = get_function_alias(module_name, class_name)
alias = get_function_alias(module_name, class_name + "_" + function_name)
filter_variables = ""
# filter_variables = '\n args.pop("cls", None)'
method_name = "." + function_name if function_name != "__init__" else ""
if function_properties.is_classmethod:
test_body = test_class_method_body.format(
benchmark_function_name=benchmark_function_name,
orig_function_name=function_name,
file_path=file_path,
class_name_alias=class_name_alias,
class_name=class_name,
method_name=method_name,
max_run_count=max_run_count,
filter_variables=filter_variables,
)
elif function_properties.is_staticmethod:
test_body = test_static_method_body.format(
benchmark_function_name=benchmark_function_name,
orig_function_name=function_name,
file_path=file_path,
class_name_alias=class_name_alias,
class_name=class_name,
method_name=method_name,
max_run_count=max_run_count,
filter_variables=filter_variables,
)
else:
test_body = test_method_body.format(
benchmark_function_name=benchmark_function_name,
orig_function_name=function_name,
file_path=file_path,
class_name_alias=class_name_alias,
class_name=class_name,
method_name=method_name,
max_run_count=max_run_count,
filter_variables=filter_variables,
)
formatted_test_body = textwrap.indent(test_body, "        " if test_framework == "unittest" else "    ")
test_template += "    " if test_framework == "unittest" else ""
test_template += f"def test_{alias}({self}):\n{formatted_test_body}\n"
return imports + "\n" + metadata + "\n" + test_template
def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework: str = "pytest", max_run_count: int = 100) -> int:
"""Generate multiple replay tests from the traced function calls, grouped by benchmark.
Args:
trace_file_path: Path to the SQLite database file
output_dir: Directory to write the generated tests
test_framework: 'pytest' or 'unittest'
max_run_count: Maximum number of runs to include per function
Returns:
The number of replay tests generated
"""
count = 0
try:
# Connect to the database
conn = sqlite3.connect(trace_file_path.as_posix())
cursor = conn.cursor()
# Get distinct benchmark file paths
cursor.execute(
"SELECT DISTINCT benchmark_module_path FROM benchmark_function_timings"
)
benchmark_files = cursor.fetchall()
# Generate a test for each benchmark file
for benchmark_file in benchmark_files:
benchmark_module_path = benchmark_file[0]
# Get all benchmarks and functions associated with this file path
cursor.execute(
"SELECT DISTINCT benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number FROM benchmark_function_timings "
"WHERE benchmark_module_path = ?",
(benchmark_module_path,)
)
functions_data = []
for row in cursor.fetchall():
benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number = row
# Add this function to our list
functions_data.append({
"function_name": function_name,
"class_name": class_name,
"file_path": file_path,
"module_name": module_name,
"benchmark_function_name": benchmark_function_name,
"benchmark_module_path": benchmark_module_path,
"benchmark_line_number": benchmark_line_number,
"function_properties": inspect_top_level_functions_or_methods(
file_name=Path(file_path),
function_or_method_name=function_name,
class_name=class_name,
)
})
if not functions_data:
logger.info(f"No benchmark test functions found in {benchmark_module_path}")
continue
# Generate the test code for this benchmark
test_code = create_trace_replay_test_code(
trace_file=trace_file_path.as_posix(),
functions_data=functions_data,
test_framework=test_framework,
max_run_count=max_run_count,
)
test_code = isort.code(test_code)
output_file = get_test_file_path(
test_dir=Path(output_dir), function_name=benchmark_module_path, test_type="replay"
)
# Create the output directory if needed and write the test code to file
output_dir.mkdir(parents=True, exist_ok=True)
output_file.write_text(test_code, "utf-8")
count += 1
conn.close()
except Exception as e:
logger.info(f"Error generating replay tests: {e}")
return count
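
For reference, the assembled output for a single traced top-level function would look roughly like this (module, benchmark name, and paths assumed for illustration):

from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle

from codeflash.benchmarking.replay_test import get_next_arg_and_return
from code_to_optimize.bubble_sort_codeflash_trace import sorter as code_to_optimize_bubble_sort_codeflash_trace_sorter

functions = ['sorter']
trace_file_path = r"/tmp/benchmarks.trace"

def test_code_to_optimize_bubble_sort_codeflash_trace_sorter():
    for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_sort", function_name="sorter", file_path=r"/repo/code_to_optimize/bubble_sort_codeflash_trace.py", num_to_get=100):
        args = pickle.loads(args_pkl)
        kwargs = pickle.loads(kwargs_pkl)
        ret = code_to_optimize_bubble_sort_codeflash_trace_sorter(*args, **kwargs)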

View file

@@ -0,0 +1,42 @@
from __future__ import annotations
import re
import subprocess
from pathlib import Path
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
def trace_benchmarks_pytest(benchmarks_root: Path, tests_root: Path, project_root: Path, trace_file: Path, timeout: int = 300) -> None:
result = subprocess.run(
[
SAFE_SYS_EXECUTABLE,
Path(__file__).parent / "pytest_new_process_trace_benchmarks.py",
benchmarks_root,
tests_root,
trace_file,
],
cwd=project_root,
check=False,
capture_output=True,
text=True,
env={"PYTHONPATH": str(project_root)},
timeout=timeout,
)
if result.returncode != 0:
if "ERROR collecting" in result.stdout:
# Pattern matches "===== ERRORS =====" (any number of =) and captures everything after
error_pattern = r"={3,}\s*ERRORS\s*={3,}\n([\s\S]*?)(?:={3,}|$)"
match = re.search(error_pattern, result.stdout)
error_section = match.group(1) if match else result.stdout
elif "FAILURES" in result.stdout:
# Pattern matches "===== FAILURES =====" (any number of =) and captures everything after
error_pattern = r"={3,}\s*FAILURES\s*={3,}\n([\s\S]*?)(?:={3,}|$)"
match = re.search(error_pattern, result.stdout)
error_section = match.group(1) if match else result.stdout
else:
error_section = result.stdout
logger.warning(
f"Error collecting benchmarks - Pytest Exit code: {result.returncode}, {error_section}"
)
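
A hypothetical invocation of the entry point above (paths assumed), which spawns the pytest run in a fresh process and writes the trace database:

from pathlib import Path

trace_benchmarks_pytest(
    benchmarks_root=Path("tests/benchmarks"),
    tests_root=Path("tests"),
    project_root=Path("."),
    trace_file=Path("/tmp/benchmarks.trace"),
)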

View file

@@ -0,0 +1,136 @@
from __future__ import annotations
import shutil
from typing import TYPE_CHECKING, Optional
from rich.console import Console
from rich.table import Table
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.models.models import BenchmarkDetail, ProcessedBenchmarkInfo
from codeflash.result.critic import performance_gain
if TYPE_CHECKING:
from codeflash.models.models import BenchmarkKey
def validate_and_format_benchmark_table(function_benchmark_timings: dict[str, dict[BenchmarkKey, int]],
total_benchmark_timings: dict[BenchmarkKey, int]) -> dict[str, list[tuple[BenchmarkKey, float, float, float]]]:
function_to_result = {}
# Process each function's benchmark data
for func_path, test_times in function_benchmark_timings.items():
# Sort by percentage (highest first)
sorted_tests = []
for benchmark_key, func_time in test_times.items():
total_time = total_benchmark_timings.get(benchmark_key, 0)
if func_time > total_time:
logger.debug(f"Skipping test {benchmark_key} due to func_time {func_time} > total_time {total_time}")
# If the function time is greater than total time, likely to have multithreading / multiprocessing issues.
# Do not try to project the optimization impact for this function.
sorted_tests.append((benchmark_key, 0.0, 0.0, 0.0))
elif total_time > 0:
percentage = (func_time / total_time) * 100
# Convert nanoseconds to milliseconds
func_time_ms = func_time / 1_000_000
total_time_ms = total_time / 1_000_000
sorted_tests.append((benchmark_key, total_time_ms, func_time_ms, percentage))
sorted_tests.sort(key=lambda x: x[3], reverse=True)
function_to_result[func_path] = sorted_tests
return function_to_result
def print_benchmark_table(function_to_results: dict[str, list[tuple[BenchmarkKey, float, float, float]]]) -> None:
try:
terminal_width = int(shutil.get_terminal_size().columns * 0.9)
except Exception:
terminal_width = 120 # Fallback width
console = Console(width=terminal_width)
for func_path, sorted_tests in function_to_results.items():
console.print()
function_name = func_path.split(":")[-1]
# Create a table for this function
table = Table(title=f"Function: {function_name}", width=terminal_width, border_style="blue", show_lines=True)
benchmark_col_width = max(int(terminal_width * 0.4), 40)
# Add columns - split the benchmark test into two columns
table.add_column("Benchmark Module Path", width=benchmark_col_width, style="cyan", overflow="fold")
table.add_column("Test Function", style="magenta", overflow="fold")
table.add_column("Total Time (ms)", justify="right", style="green")
table.add_column("Function Time (ms)", justify="right", style="yellow")
table.add_column("Percentage (%)", justify="right", style="red")
for benchmark_key, total_time, func_time, percentage in sorted_tests:
# Split the benchmark test into module path and function name
module_path = benchmark_key.module_path
test_function = benchmark_key.function_name
if total_time == 0.0:
table.add_row(
module_path,
test_function,
"N/A",
"N/A",
"N/A"
)
else:
table.add_row(
module_path,
test_function,
f"{total_time:.3f}",
f"{func_time:.3f}",
f"{percentage:.2f}"
)
# Print the table
console.print(table)
def process_benchmark_data(
replay_performance_gain: dict[BenchmarkKey, float],
fto_benchmark_timings: dict[BenchmarkKey, int],
total_benchmark_timings: dict[BenchmarkKey, int]
) -> Optional[ProcessedBenchmarkInfo]:
"""Process benchmark data and generate detailed benchmark information.
Args:
replay_performance_gain: The performance gain from replay
fto_benchmark_timings: Function to optimize benchmark timings
total_benchmark_timings: Total benchmark timings
Returns:
ProcessedBenchmarkInfo containing processed benchmark details
"""
if not replay_performance_gain or not fto_benchmark_timings or not total_benchmark_timings:
return None
benchmark_details = []
for benchmark_key, og_benchmark_timing in fto_benchmark_timings.items():
total_benchmark_timing = total_benchmark_timings.get(benchmark_key, 0)
if total_benchmark_timing == 0:
continue # Skip benchmarks with zero timing
# Calculate expected new benchmark timing
expected_new_benchmark_timing = total_benchmark_timing - og_benchmark_timing + (
1 / (replay_performance_gain[benchmark_key] + 1)
) * og_benchmark_timing
# Calculate speedup
benchmark_speedup_percent = performance_gain(original_runtime_ns=total_benchmark_timing, optimized_runtime_ns=int(expected_new_benchmark_timing)) * 100
benchmark_details.append(
BenchmarkDetail(
benchmark_name=benchmark_key.module_path,
test_function=benchmark_key.function_name,
original_timing=humanize_runtime(int(total_benchmark_timing)),
expected_new_timing=humanize_runtime(int(expected_new_benchmark_timing)),
speedup_percent=benchmark_speedup_percent
)
)
return ProcessedBenchmarkInfo(benchmark_details=benchmark_details)
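
A worked example of the projection (numbers assumed): a benchmark takes 100 ms in total, the function to optimize accounts for 40 ms of it, and the replay tests report a performance gain of 1.0, i.e. the optimized function is twice as fast.

total_ns = 100_000_000  # total benchmark time
og_ns = 40_000_000      # portion spent in the function to optimize
gain = 1.0              # replay performance gain: optimized code runs 2x faster

expected_new = total_ns - og_ns + (1 / (gain + 1)) * og_ns
# 60 ms of untouched work + 20 ms for the optimized function = 80 ms, so the
# benchmark is projected ~25% faster, assuming performance_gain computes
# (original - optimized) / optimized.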

View file

@@ -62,6 +62,10 @@ def parse_args() -> Namespace:
)
parser.add_argument("-v", "--verbose", action="store_true", help="Print verbose debug logs")
parser.add_argument("--version", action="store_true", help="Print the version of codeflash")
parser.add_argument("--benchmark", action="store_true", help="Trace benchmark tests and calculate optimization impact on benchmarks")
parser.add_argument(
"--benchmarks-root", type=str, help="Path to the directory of the project, where all the pytest-benchmark tests are located."
)
args: Namespace = parser.parse_args()
return process_and_validate_cmd_args(args)
@@ -109,6 +113,7 @@ def process_pyproject_config(args: Namespace) -> Namespace:
supported_keys = [
"module_root",
"tests_root",
"benchmarks_root",
"test_framework",
"ignore_paths",
"pytest_cmd",
@@ -127,7 +132,12 @@ def process_pyproject_config(args: Namespace) -> Namespace:
assert Path(args.module_root).is_dir(), f"--module-root {args.module_root} must be a valid directory"
assert args.tests_root is not None, "--tests-root must be specified"
assert Path(args.tests_root).is_dir(), f"--tests-root {args.tests_root} must be a valid directory"
if args.benchmark:
assert args.benchmarks_root is not None, "--benchmarks-root must be specified when running with --benchmark"
assert Path(args.benchmarks_root).is_dir(), f"--benchmarks-root {args.benchmarks_root} must be a valid directory"
assert Path(args.benchmarks_root).resolve().is_relative_to(Path(args.tests_root).resolve()), (
f"--benchmarks-root {args.benchmarks_root} must be a subdirectory of --tests-root {args.tests_root}"
)
if env_utils.get_pr_number() is not None:
assert env_utils.ensure_codeflash_api_key(), (
"Codeflash API key not found. When running in a Github Actions Context, provide the "

View file

@@ -49,6 +49,7 @@ CODEFLASH_LOGO: str = (
class SetupInfo:
module_root: str
tests_root: str
benchmarks_root: str | None
test_framework: str
ignore_paths: list[str]
formatter: str
@@ -125,8 +126,7 @@ def ask_run_end_to_end_test(args: Namespace) -> None:
run_end_to_end_test(args, bubble_sort_path, bubble_sort_test_path)
def should_modify_pyproject_toml() -> bool:
"""
Check if the current directory contains a valid pyproject.toml file with codeflash config
"""Check if the current directory contains a valid pyproject.toml file with codeflash config
If it does, ask the user if they want to re-configure it.
"""
from rich.prompt import Confirm
@@ -135,7 +135,7 @@ def should_modify_pyproject_toml() -> bool:
return True
try:
config, config_file_path = parse_config_file(pyproject_toml_path)
except Exception as e:
except Exception:
return True
if "module_root" not in config or config["module_root"] is None or not Path(config["module_root"]).is_dir():
@@ -144,7 +144,7 @@ def should_modify_pyproject_toml() -> bool:
return True
create_toml = Confirm.ask(
f"✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?", default=False, show_default=True
"✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?", default=False, show_default=True
)
return create_toml
@@ -244,6 +244,66 @@ def collect_setup_info() -> SetupInfo:
ph("cli-test-framework-provided", {"test_framework": test_framework})
# Get benchmarks root directory
default_benchmarks_subdir = "benchmarks"
create_benchmarks_option = f"okay, create a {default_benchmarks_subdir}{os.path.sep} directory for me!"
no_benchmarks_option = "I don't need benchmarks"
# Check if benchmarks directory exists inside tests directory
tests_subdirs = []
if tests_root.exists():
tests_subdirs = [d.name for d in tests_root.iterdir() if d.is_dir() and not d.name.startswith(".")]
benchmarks_options = []
if default_benchmarks_subdir in tests_subdirs:
benchmarks_options.append(default_benchmarks_subdir)
benchmarks_options.extend([d for d in tests_subdirs if d != default_benchmarks_subdir])
benchmarks_options.append(create_benchmarks_option)
benchmarks_options.append(custom_dir_option)
benchmarks_options.append(no_benchmarks_option)
benchmarks_answer = inquirer_wrapper(
inquirer.list_input,
message="Where are your benchmarks located? (benchmarks must be a sub directory of your tests root directory)",
choices=benchmarks_options,
default=(
default_benchmarks_subdir if default_benchmarks_subdir in benchmarks_options else benchmarks_options[0]),
)
if benchmarks_answer == create_benchmarks_option:
benchmarks_root = tests_root / default_benchmarks_subdir
benchmarks_root.mkdir(exist_ok=True)
click.echo(f"✅ Created directory {benchmarks_root}{os.path.sep}{LF}")
elif benchmarks_answer == custom_dir_option:
custom_benchmarks_answer = inquirer_wrapper_path(
"path",
message=f"Enter the path to your benchmarks directory inside {tests_root}{os.path.sep} ",
path_type=inquirer.Path.DIRECTORY,
)
if custom_benchmarks_answer:
benchmarks_root = tests_root / Path(custom_benchmarks_answer["path"])
else:
apologize_and_exit()
elif benchmarks_answer == no_benchmarks_option:
benchmarks_root = None
else:
benchmarks_root = tests_root / Path(cast(str, benchmarks_answer))
# TODO: Implement other benchmark framework options
# if benchmarks_root:
# benchmarks_root = benchmarks_root.relative_to(curdir)
#
# # Ask about benchmark framework
# benchmark_framework_options = ["pytest-benchmark", "asv (Airspeed Velocity)", "custom/other"]
# benchmark_framework = inquirer_wrapper(
# inquirer.list_input,
# message="Which benchmark framework do you use?",
# choices=benchmark_framework_options,
# default=benchmark_framework_options[0],
# carousel=True,
# )
formatter = inquirer_wrapper(
inquirer.list_input,
message="Which code formatter do you use?",
@@ -279,6 +339,7 @@ def collect_setup_info() -> SetupInfo:
return SetupInfo(
module_root=str(module_root),
tests_root=str(tests_root),
benchmarks_root=str(benchmarks_root) if benchmarks_root else None,
test_framework=cast(str, test_framework),
ignore_paths=ignore_paths,
formatter=cast(str, formatter),
@@ -437,11 +498,19 @@ def install_github_actions(override_formatter_check: bool = False) -> None:
return
workflows_path.mkdir(parents=True, exist_ok=True)
from importlib.resources import files
benchmark_mode = False
if "benchmarks_root" in config:
benchmark_mode = inquirer_wrapper(
inquirer.confirm,
message="It looks like you've configured a benchmarks_root in your config. Would you like to run the Github action in benchmark mode? "
" This will show the impact of Codeflash's suggested optimizations on your benchmarks",
default=True,
)
optimize_yml_content = (
files("codeflash").joinpath("cli_cmds", "workflows", "codeflash-optimize.yaml").read_text(encoding="utf-8")
)
materialized_optimize_yml_content = customize_codeflash_yaml_content(optimize_yml_content, config, git_root)
materialized_optimize_yml_content = customize_codeflash_yaml_content(optimize_yml_content, config, git_root, benchmark_mode)
with optimize_yaml_path.open("w", encoding="utf8") as optimize_yml_file:
optimize_yml_file.write(materialized_optimize_yml_content)
click.echo(f"{LF}✅ Created GitHub action workflow at {optimize_yaml_path}{LF}")
@@ -556,7 +625,7 @@ def get_github_action_working_directory(toml_path: Path, git_root: Path) -> str:
def customize_codeflash_yaml_content(
optimize_yml_content: str, config: tuple[dict[str, Any], Path], git_root: Path
optimize_yml_content: str, config: tuple[dict[str, Any], Path], git_root: Path, benchmark_mode: bool = False
) -> str:
module_path = str(Path(config["module_root"]).relative_to(git_root) / "**")
optimize_yml_content = optimize_yml_content.replace("{{ codeflash_module_path }}", module_path)
@@ -587,6 +656,9 @@ def customize_codeflash_yaml_content(
# Add codeflash command
codeflash_cmd = get_codeflash_github_action_command(dep_manager)
if benchmark_mode:
codeflash_cmd += " --benchmark"
return optimize_yml_content.replace("{{ codeflash_command }}", codeflash_cmd)
@@ -608,6 +680,7 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
codeflash_section["module-root"] = setup_info.module_root
codeflash_section["tests-root"] = setup_info.tests_root
codeflash_section["test-framework"] = setup_info.test_framework
codeflash_section["benchmarks-root"] = setup_info.benchmarks_root if setup_info.benchmarks_root else ""
codeflash_section["ignore-paths"] = setup_info.ignore_paths
if setup_info.git_remote not in ["", "origin"]:
codeflash_section["git-remote"] = setup_info.git_remote

View file

@@ -52,10 +52,10 @@ def parse_config_file(
assert isinstance(config, dict)
# default values:
path_keys = {"module-root", "tests-root"}
path_list_keys = {"ignore-paths", }
path_keys = ["module-root", "tests-root", "benchmarks-root"]
path_list_keys = ["ignore-paths"]
str_keys = {"pytest-cmd": "pytest", "git-remote": "origin"}
bool_keys = {"disable-telemetry": False, "disable-imports-sorting": False}
bool_keys = {"disable-telemetry": False, "disable-imports-sorting": False, "benchmark": False}
list_str_keys = {"formatter-cmds": ["black $file"]}
for key in str_keys:

View file

@@ -106,7 +106,7 @@ class FunctionWithReturnStatement(ast.NodeVisitor):
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
class FunctionToOptimize:
"""Represents a function that is a candidate for optimization.
"""Represent a function that is a candidate for optimization.
Attributes
----------
@@ -145,7 +145,6 @@ class FunctionToOptimize:
def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"
def get_functions_to_optimize(
optimize_all: str | None,
replay_test: str | None,
@@ -359,9 +358,15 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
for decorator in body_node.decorator_list
):
self.is_classmethod = True
elif any(
isinstance(decorator, ast.Name) and decorator.id == "staticmethod"
for decorator in body_node.decorator_list
):
self.is_staticmethod = True
return
else:
# search if the class has a staticmethod with the same name and on the same line number
elif self.line_no:
# If we have line number info, check if class has a static method with the same line number
# This way, if we don't have the class name, we can still find the static method
for body_node in node.body:
if (
isinstance(body_node, ast.FunctionDef)

View file

@@ -1,9 +1,11 @@
from typing import Union
from __future__ import annotations
from typing import Union, Optional
from pydantic import BaseModel
from pydantic.dataclasses import dataclass
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.models.models import BenchmarkDetail
from codeflash.models.models import TestResults
@@ -18,8 +20,9 @@ class PrComment:
speedup_pct: str
winning_behavioral_test_results: TestResults
winning_benchmarking_test_results: TestResults
benchmark_details: Optional[list[BenchmarkDetail]] = None
def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str]]:
def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str, Optional[list[BenchmarkDetail]]]]:
report_table = {
test_type.to_name(): result
for test_type, result in self.winning_behavioral_test_results.get_test_pass_fail_report_by_type().items()
@@ -36,6 +39,7 @@ class PrComment:
"speedup_pct": self.speedup_pct,
"loop_count": self.winning_benchmarking_test_results.number_of_loops(),
"report_table": report_table,
"benchmark_details": self.benchmark_details if self.benchmark_details else None,
}

View file

@@ -1,5 +1,6 @@
from __future__ import annotations
from collections import defaultdict
from typing import TYPE_CHECKING
from rich.tree import Tree
@@ -11,7 +12,7 @@ if TYPE_CHECKING:
import enum
import re
import sys
from collections.abc import Collection, Iterator
from collections.abc import Collection
from enum import Enum, IntEnum
from pathlib import Path
from re import Pattern
@ -22,7 +23,7 @@ from pydantic import AfterValidator, BaseModel, ConfigDict, Field
from pydantic.dataclasses import dataclass
from codeflash.cli_cmds.console import console, logger
from codeflash.code_utils.code_utils import validate_python_code
from codeflash.code_utils.code_utils import module_name_from_file_path, validate_python_code
from codeflash.code_utils.env_utils import is_end_to_end
from codeflash.verification.comparator import comparator
@ -58,28 +59,74 @@ class FunctionSource:
def __eq__(self, other: object) -> bool:
if not isinstance(other, FunctionSource):
return False
return (
self.file_path == other.file_path
and self.qualified_name == other.qualified_name
and self.fully_qualified_name == other.fully_qualified_name
and self.only_function_name == other.only_function_name
and self.source_code == other.source_code
)
return (self.file_path == other.file_path and
self.qualified_name == other.qualified_name and
self.fully_qualified_name == other.fully_qualified_name and
self.only_function_name == other.only_function_name and
self.source_code == other.source_code)
def __hash__(self) -> int:
return hash(
(self.file_path, self.qualified_name, self.fully_qualified_name, self.only_function_name, self.source_code)
)
return hash((self.file_path, self.qualified_name, self.fully_qualified_name,
self.only_function_name, self.source_code))
class BestOptimization(BaseModel):
candidate: OptimizedCandidate
helper_functions: list[FunctionSource]
runtime: int
replay_performance_gain: Optional[dict[BenchmarkKey,float]] = None
winning_behavioral_test_results: TestResults
winning_benchmarking_test_results: TestResults
winning_replay_benchmarking_test_results : Optional[TestResults] = None
@dataclass(frozen=True)
class BenchmarkKey:
module_path: str
function_name: str
def __str__(self) -> str:
return f"{self.module_path}::{self.function_name}"
@dataclass
class BenchmarkDetail:
benchmark_name: str
test_function: str
original_timing: str
expected_new_timing: str
speedup_percent: float
def to_string(self) -> str:
return (
f"Original timing for {self.benchmark_name}::{self.test_function}: {self.original_timing}\n"
f"Expected new timing for {self.benchmark_name}::{self.test_function}: {self.expected_new_timing}\n"
f"Benchmark speedup for {self.benchmark_name}::{self.test_function}: {self.speedup_percent:.2f}%\n"
)
def to_dict(self) -> dict[str, any]:
return {
"benchmark_name": self.benchmark_name,
"test_function": self.test_function,
"original_timing": self.original_timing,
"expected_new_timing": self.expected_new_timing,
"speedup_percent": self.speedup_percent
}
@dataclass
class ProcessedBenchmarkInfo:
benchmark_details: list[BenchmarkDetail]
def to_string(self) -> str:
if not self.benchmark_details:
return ""
result = "Benchmark Performance Details:\n"
for detail in self.benchmark_details:
result += detail.to_string() + "\n"
return result
def to_dict(self) -> dict[str, list[dict[str, any]]]:
return {
"benchmark_details": [detail.to_dict() for detail in self.benchmark_details]
}
class CodeString(BaseModel):
code: Annotated[str, AfterValidator(validate_python_code)]
file_path: Optional[Path] = None
@ -104,8 +151,7 @@ class CodeOptimizationContext(BaseModel):
read_writable_code: str = Field(min_length=1)
read_only_context_code: str = ""
helper_functions: list[FunctionSource]
preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]]
preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]]
class CodeContextType(str, Enum):
READ_WRITABLE = "READ_WRITABLE"
@ -118,6 +164,7 @@ class OptimizedCandidateResult(BaseModel):
best_test_runtime: int
behavior_test_results: TestResults
benchmarking_test_results: TestResults
replay_benchmarking_test_results: Optional[dict[BenchmarkKey, TestResults]] = None
optimization_candidate_index: int
total_candidate_timing: int
@ -222,6 +269,7 @@ class FunctionParent:
class OriginalCodeBaseline(BaseModel):
behavioral_test_results: TestResults
benchmarking_test_results: TestResults
replay_benchmarking_test_results: Optional[dict[BenchmarkKey, TestResults]] = None
line_profile_results: dict
runtime: int
coverage_results: Optional[CoverageData]
@ -299,7 +347,6 @@ class CoverageData:
status=CoverageStatus.NOT_FOUND,
)
@dataclass
class FunctionCoverage:
"""Represents the coverage data for a specific function in a source file."""
@ -426,6 +473,20 @@ class TestResults(BaseModel):
raise ValueError(msg)
self.test_result_idx[k] = v + original_len
def group_by_benchmarks(self, benchmark_keys:list[BenchmarkKey], benchmark_replay_test_dir: Path, project_root: Path) -> dict[BenchmarkKey, TestResults]:
"""Group TestResults by benchmark for calculating improvements for each benchmark."""
test_results_by_benchmark = defaultdict(TestResults)
benchmark_module_path = {}
for benchmark_key in benchmark_keys:
benchmark_module_path[benchmark_key] = module_name_from_file_path(benchmark_replay_test_dir.resolve() / f"test_{benchmark_key.module_path.replace('.', '_')}__replay_test_", project_root)
for test_result in self.test_results:
if (test_result.test_type == TestType.REPLAY_TEST):
for benchmark_key, module_path in benchmark_module_path.items():
if test_result.id.test_module_path.startswith(module_path):
test_results_by_benchmark[benchmark_key].add(test_result)
return test_results_by_benchmark
def get_by_unique_invocation_loop_id(self, unique_invocation_loop_id: str) -> FunctionTestInvocation | None:
try:
return self.test_results[self.test_result_idx[unique_invocation_loop_id]]

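Assuming the codeflash package is importable, a small sketch of the new benchmark models (all values invented):

from codeflash.models.models import BenchmarkDetail, BenchmarkKey

key = BenchmarkKey(module_path="tests.benchmarks.test_sort", function_name="test_sort_benchmark")
print(key)  # tests.benchmarks.test_sort::test_sort_benchmark

detail = BenchmarkDetail(
    benchmark_name="tests.benchmarks.test_sort",
    test_function="test_sort_benchmark",
    original_timing="1.20 milliseconds",
    expected_new_timing="800 microseconds",
    speedup_percent=50.0,
)
print(detail.to_string())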
View file

@ -19,6 +19,7 @@ from rich.syntax import Syntax
from rich.tree import Tree
from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
from codeflash.benchmarking.utils import process_benchmark_data
from codeflash.cli_cmds.console import code_print, console, logger, progress_bar
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
@ -42,7 +43,6 @@ from codeflash.code_utils.remove_generated_tests import remove_functions_from_ge
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.context import code_context_extractor
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.either import Failure, Success, is_successful
from codeflash.models.ExperimentMetadata import ExperimentMetadata
from codeflash.models.models import (
@ -76,8 +76,9 @@ from codeflash.verification.verifier import generate_tests
if TYPE_CHECKING:
from argparse import Namespace
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.either import Result
from codeflash.models.models import CoverageData, FunctionSource, OptimizedCandidate
from codeflash.models.models import BenchmarkKey, CoverageData, FunctionSource, OptimizedCandidate
from codeflash.verification.verification_utils import TestConfig
@ -90,7 +91,10 @@ class FunctionOptimizer:
function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None,
function_to_optimize_ast: ast.FunctionDef | None = None,
aiservice_client: AiServiceClient | None = None,
function_benchmark_timings: dict[BenchmarkKey, int] | None = None,
total_benchmark_timings: dict[BenchmarkKey, int] | None = None,
args: Namespace | None = None,
replay_tests_dir: Path|None = None
) -> None:
self.project_root = test_cfg.project_root_path
self.test_cfg = test_cfg
@ -113,11 +117,14 @@ class FunctionOptimizer:
self.experiment_id = os.getenv("CODEFLASH_EXPERIMENT_ID", None)
self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None
self.test_files = TestFiles(test_files=[])
self.args = args # Check defaults for these
self.function_trace_id: str = str(uuid.uuid4())
self.original_module_path = module_name_from_file_path(self.function_to_optimize.file_path, self.project_root)
self.function_benchmark_timings = function_benchmark_timings if function_benchmark_timings else {}
self.total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else {}
self.replay_tests_dir = replay_tests_dir if replay_tests_dir else None
def optimize_function(self) -> Result[BestOptimization, str]:
should_run_experiment = self.experiment_id is not None
logger.debug(f"Function Trace ID: {self.function_trace_id}")
@ -136,8 +143,8 @@ class FunctionOptimizer:
original_helper_code[helper_function_path] = helper_code
if has_any_async_functions(code_context.read_writable_code):
return Failure("Codeflash does not support async functions in the code to optimize.")
code_print(code_context.read_writable_code)
code_print(code_context.read_writable_code)
generated_test_paths = [
get_test_file_path(
self.test_cfg.tests_root, self.function_to_optimize.function_name, test_index, test_type="unit"
@ -261,6 +268,13 @@ class FunctionOptimizer:
best_optimization.candidate.explanation, title="Best Candidate Explanation", border_style="blue"
)
)
processed_benchmark_info = None
if self.args.benchmark:
processed_benchmark_info = process_benchmark_data(
replay_performance_gain=best_optimization.replay_performance_gain,
fto_benchmark_timings=self.function_benchmark_timings,
total_benchmark_timings=self.total_benchmark_timings
)
explanation = Explanation(
raw_explanation_message=best_optimization.candidate.explanation,
winning_behavioral_test_results=best_optimization.winning_behavioral_test_results,
@ -269,6 +283,7 @@ class FunctionOptimizer:
best_runtime_ns=best_optimization.runtime,
function_name=function_to_optimize_qualified_name,
file_path=self.function_to_optimize.file_path,
benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None
)
self.log_successful_optimization(explanation, generated_tests)
@ -362,7 +377,7 @@ class FunctionOptimizer:
candidates = deque(candidates)
# Start a new thread for AI service request, start loop in main thread
# check if aiservice request is complete, when it is complete, append result to the candidates list
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future_line_profile_results = executor.submit(
self.aiservice_client.optimize_python_code_line_profiler,
source_code=code_context.read_writable_code,
@ -382,8 +397,8 @@ class FunctionOptimizer:
if done and (future_line_profile_results is not None):
line_profile_results = future_line_profile_results.result()
candidates.extend(line_profile_results)
original_len+= len(line_profile_results)
logger.info(f"Added {len(line_profile_results)} results from line profiler to candidates, total candidates now: {original_len}")
original_len+= len(candidates)
logger.info(f"Added results from line profiler to candidates, total candidates now: {original_len}")
future_line_profile_results = None
candidate_index += 1
candidate = candidates.popleft()
@ -410,7 +425,7 @@ class FunctionOptimizer:
)
continue
# Instrument codeflash capture
run_results = self.run_optimized_candidate(
optimization_candidate_index=candidate_index,
baseline_results=original_code_baseline,
@ -434,6 +449,7 @@ class FunctionOptimizer:
speedup_ratios[candidate.optimization_id] = perf_gain
tree = Tree(f"Candidate #{candidate_index} - Runtime Information")
benchmark_tree = None
if speedup_critic(
candidate_result, original_code_baseline.runtime, best_runtime_until_now
) and quantity_of_tests_critic(candidate_result):
@ -446,13 +462,29 @@ class FunctionOptimizer:
)
tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
tree.add(f"Speedup ratio: {perf_gain + 1:.1f}X")
replay_perf_gain = {}
if self.args.benchmark:
test_results_by_benchmark = candidate_result.benchmarking_test_results.group_by_benchmarks(self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root)
if len(test_results_by_benchmark) > 0:
benchmark_tree = Tree("Speedup percentage on benchmarks:")
for benchmark_key, candidate_test_results in test_results_by_benchmark.items():
original_code_replay_runtime = original_code_baseline.replay_benchmarking_test_results[benchmark_key].total_passed_runtime()
candidate_replay_runtime = candidate_test_results.total_passed_runtime()
replay_perf_gain[benchmark_key] = performance_gain(
original_runtime_ns=original_code_replay_runtime,
optimized_runtime_ns=candidate_replay_runtime,
)
benchmark_tree.add(f"{benchmark_key}: {replay_perf_gain[benchmark_key] * 100:.1f}%")
best_optimization = BestOptimization(
candidate=candidate,
helper_functions=code_context.helper_functions,
runtime=best_test_runtime,
winning_behavioral_test_results=candidate_result.behavior_test_results,
replay_performance_gain=replay_perf_gain if self.args.benchmark else None,
winning_benchmarking_test_results=candidate_result.benchmarking_test_results,
winning_replay_benchmarking_test_results=candidate_result.benchmarking_test_results,
)
best_runtime_until_now = best_test_runtime
else:
@ -464,6 +496,8 @@ class FunctionOptimizer:
tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
console.print(tree)
if self.args.benchmark and benchmark_tree:
console.print(benchmark_tree)
console.rule()
self.write_code_and_helpers(
@ -507,7 +541,8 @@ class FunctionOptimizer:
)
console.print(Group(explanation_panel, tests_panel))
console.print(explanation_panel)
else:
console.print(explanation_panel)
ph(
"cli-optimize-success",
@ -664,6 +699,7 @@ class FunctionOptimizer:
unique_instrumented_test_files.add(new_behavioral_test_path)
unique_instrumented_test_files.add(new_perf_test_path)
if not self.test_files.get_by_original_file_path(path_obj_test_file):
self.test_files.add(
TestFile(
@ -675,6 +711,7 @@ class FunctionOptimizer:
tests_in_file=[t.tests_in_file for t in tests_in_file_list],
)
)
logger.info(
f"Discovered {existing_test_files_count} existing unit test file"
f"{'s' if existing_test_files_count != 1 else ''}, {replay_test_files_count} replay test file"
@ -865,7 +902,6 @@ class FunctionOptimizer:
enable_coverage=False,
code_context=code_context,
)
else:
benchmarking_results = TestResults()
start_time: float = time.time()
@ -920,11 +956,15 @@ class FunctionOptimizer:
)
console.rule()
logger.debug(f"Total original code runtime (ns): {total_timing}")
if self.args.benchmark:
replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root)
return Success(
(
OriginalCodeBaseline(
behavioral_test_results=behavioral_results,
benchmarking_test_results=benchmarking_results,
replay_benchmarking_test_results = replay_benchmarking_test_results if self.args.benchmark else None,
runtime=total_timing,
coverage_results=coverage_results,
line_profile_results=line_profile_results,
@ -954,8 +994,6 @@ class FunctionOptimizer:
test_env["PYTHONPATH"] += os.pathsep + str(self.project_root)
get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(missing_ok=True)
get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(missing_ok=True)
# Instrument codeflash capture
candidate_fto_code = Path(self.function_to_optimize.file_path).read_text("utf-8")
candidate_helper_code = {}
@ -986,7 +1024,6 @@ class FunctionOptimizer:
)
)
console.rule()
if compare_test_results(baseline_results.behavioral_test_results, candidate_behavior_results):
logger.info("Test results matched!")
console.rule()
@ -1039,12 +1076,17 @@ class FunctionOptimizer:
console.rule()
logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}")
if self.args.benchmark:
candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks(self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root)
for benchmark_name, benchmark_results in candidate_replay_benchmarking_results.items():
logger.debug(f"Benchmark {benchmark_name} runtime (ns): {humanize_runtime(benchmark_results.total_passed_runtime())}")
return Success(
OptimizedCandidateResult(
max_loop_count=loop_count,
best_test_runtime=total_candidate_timing,
behavior_test_results=candidate_behavior_results,
benchmarking_test_results=candidate_benchmarking_results,
replay_benchmarking_test_results = candidate_replay_benchmarking_results if self.args.benchmark else None,
optimization_candidate_index=optimization_candidate_index,
total_candidate_timing=total_candidate_timing,
)
@ -1086,8 +1128,8 @@ class FunctionOptimizer:
pytest_cmd=self.test_cfg.pytest_cmd,
pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
pytest_target_runtime_seconds=testing_time,
pytest_min_loops=pytest_min_loops,
pytest_max_loops=pytest_min_loops,
pytest_min_loops=1,
pytest_max_loops=1,
test_framework=self.test_cfg.test_framework,
line_profiler_output_file=line_profiler_output_file,
)
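The printed speedup lines follow the usual gain arithmetic; a standalone sketch with a local stand-in for codeflash's performance_gain helper (timings are illustrative):

def performance_gain(original_runtime_ns: int, optimized_runtime_ns: int) -> float:
    # Stand-in assumption: fractional gain relative to the optimized runtime,
    # consistent with the ratio above being printed as perf_gain + 1.
    return (original_runtime_ns - optimized_runtime_ns) / optimized_runtime_ns

perf_gain = performance_gain(original_runtime_ns=300_000, optimized_runtime_ns=200_000)
print(f"Speedup percentage: {perf_gain * 100:.1f}%")  # 50.0%
print(f"Speedup ratio: {perf_gain + 1:.1f}X")         # 1.5X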

View file

@ -5,10 +5,16 @@ import os
import time
import shutil
import tempfile
from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING
from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
from codeflash.benchmarking.instrument_codeflash_trace import instrument_codeflash_trace_decorator
from codeflash.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin
from codeflash.benchmarking.replay_test import generate_replay_test
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table
from codeflash.cli_cmds.console import console, logger, progress_bar
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_replacer import normalize_code, normalize_node
@ -17,7 +23,7 @@ from codeflash.code_utils.static_analysis import analyze_imported_modules, get_f
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
from codeflash.either import is_successful
from codeflash.models.models import TestType, ValidCode
from codeflash.models.models import BenchmarkKey, TestType, ValidCode
from codeflash.optimization.function_optimizer import FunctionOptimizer
from codeflash.telemetry.posthog_cf import ph
from codeflash.verification.verification_utils import TestConfig
@ -39,18 +45,21 @@ class Optimizer:
project_root_path=args.project_root,
test_framework=args.test_framework,
pytest_cmd=args.pytest_cmd,
benchmark_tests_root=args.benchmarks_root if "benchmark" in args and "benchmarks_root" in args else None,
)
self.aiservice_client = AiServiceClient()
self.experiment_id = os.getenv("CODEFLASH_EXPERIMENT_ID", None)
self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None
self.replay_tests_dir = None
def create_function_optimizer(
self,
function_to_optimize: FunctionToOptimize,
function_to_optimize_ast: ast.FunctionDef | None = None,
function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None,
function_to_optimize_source_code: str | None = "",
function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None,
total_benchmark_timings: dict[BenchmarkKey, float] | None = None,
) -> FunctionOptimizer:
return FunctionOptimizer(
function_to_optimize=function_to_optimize,
@ -60,6 +69,9 @@ class Optimizer:
function_to_optimize_ast=function_to_optimize_ast,
aiservice_client=self.aiservice_client,
args=self.args,
function_benchmark_timings=function_benchmark_timings if function_benchmark_timings else None,
total_benchmark_timings=total_benchmark_timings if total_benchmark_timings else None,
replay_tests_dir = self.replay_tests_dir
)
def run(self) -> None:
@ -71,6 +83,8 @@ class Optimizer:
function_optimizer = None
file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]
num_optimizable_functions: int
# discover functions
(file_to_funcs_to_optimize, num_optimizable_functions) = get_functions_to_optimize(
optimize_all=self.args.all,
replay_test=self.args.replay_test,
@ -81,7 +95,42 @@ class Optimizer:
project_root=self.args.project_root,
module_root=self.args.module_root,
)
function_benchmark_timings: dict[str, dict[BenchmarkKey, int]] = {}
total_benchmark_timings: dict[BenchmarkKey, int] = {}
if self.args.benchmark:
with progress_bar(
f"Running benchmarks in {self.args.benchmarks_root}",
transient=True,
):
# Insert decorator
file_path_to_source_code = defaultdict(str)
for file in file_to_funcs_to_optimize:
with file.open("r", encoding="utf8") as f:
file_path_to_source_code[file] = f.read()
try:
instrument_codeflash_trace_decorator(file_to_funcs_to_optimize)
trace_file = Path(self.args.benchmarks_root) / "benchmarks.trace"
if trace_file.exists():
trace_file.unlink()
self.replay_tests_dir = Path(tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=self.args.benchmarks_root))
trace_benchmarks_pytest(self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file) # Run all tests that use pytest-benchmark
replay_count = generate_replay_test(trace_file, self.replay_tests_dir)
if replay_count == 0:
logger.info(f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization")
else:
function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(trace_file)
total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(trace_file)
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
print_benchmark_table(function_to_results)
except Exception as e:
logger.info(f"Error while tracing existing benchmarks: {e}")
logger.info("Information on existing benchmarks will not be available for this run.")
finally:
# Restore original source code
for file in file_path_to_source_code:
with file.open("w", encoding="utf8") as f:
f.write(file_path_to_source_code[file])
optimizations_found: int = 0
function_iterator_count: int = 0
if self.args.test_framework == "pytest":
@ -103,6 +152,7 @@ class Optimizer:
console.rule()
ph("cli-optimize-discovered-tests", {"num_tests": num_discovered_tests})
for original_module_path in file_to_funcs_to_optimize:
logger.info(f"Examining file {original_module_path!s}")
console.rule()
@ -159,12 +209,19 @@ class Optimizer:
f"Skipping optimization."
)
continue
function_optimizer = self.create_function_optimizer(
function_to_optimize,
function_to_optimize_ast,
function_to_tests,
validated_original_code[original_module_path].source_code,
qualified_name_w_module = function_to_optimize.qualified_name_with_modules_from_root(
self.args.project_root
)
if self.args.benchmark and function_benchmark_timings and qualified_name_w_module in function_benchmark_timings and total_benchmark_timings:
function_optimizer = self.create_function_optimizer(
function_to_optimize, function_to_optimize_ast, function_to_tests, validated_original_code[original_module_path].source_code, function_benchmark_timings[qualified_name_w_module], total_benchmark_timings
)
else:
function_optimizer = self.create_function_optimizer(
function_to_optimize, function_to_optimize_ast, function_to_tests,
validated_original_code[original_module_path].source_code
)
best_optimization = function_optimizer.optimize_function()
if is_successful(best_optimization):
optimizations_found += 1
@ -189,6 +246,10 @@ class Optimizer:
test_file.instrumented_behavior_file_path.unlink(missing_ok=True)
if function_optimizer.test_cfg.concolic_test_root_dir:
shutil.rmtree(function_optimizer.test_cfg.concolic_test_root_dir, ignore_errors=True)
if self.args.benchmark:
if self.replay_tests_dir.exists():
shutil.rmtree(self.replay_tests_dir, ignore_errors=True)
trace_file.unlink(missing_ok=True)
if hasattr(get_run_tmp_file, "tmpdir"):
get_run_tmp_file.tmpdir.cleanup()
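The instrument-then-restore flow above reduces to snapshotting each file before decoration and writing it back in the finally block; a self-contained sketch:

import tempfile
from collections import defaultdict
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "module.py"
    original = "def f():\n    return 1\n"
    target.write_text(original, encoding="utf8")

    file_path_to_source_code = defaultdict(str)
    for file in [target]:
        file_path_to_source_code[file] = file.read_text(encoding="utf8")
    try:
        # Stand-in for instrument_codeflash_trace_decorator + trace_benchmarks_pytest.
        target.write_text("# @codeflash_trace inserted here\n" + original, encoding="utf8")
    finally:
        # Restore original source code, exactly as the optimizer's finally block does.
        for file, source in file_path_to_source_code.items():
            file.write_text(source, encoding="utf8")
    assert target.read_text(encoding="utf8") == original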

View file

@ -0,0 +1,346 @@
"""PicklePatcher - A utility for safely pickling objects with unpicklable components.
This module provides functions to recursively pickle objects, replacing unpicklable
components with placeholders that provide informative errors when accessed.
"""
import pickle
import types
import dill
from .pickle_placeholder import PicklePlaceholder
class PicklePatcher:
"""A utility class for safely pickling objects with unpicklable components.
This class provides methods to recursively pickle objects, replacing any
components that can't be pickled with placeholder objects.
"""
# Class-level cache of unpicklable types
_unpicklable_types = set()
@staticmethod
def dumps(obj, protocol=None, max_depth=100, **kwargs):
"""Safely pickle an object, replacing unpicklable parts with placeholders.
Args:
obj: The object to pickle
protocol: The pickle protocol version to use
max_depth: Maximum recursion depth
**kwargs: Additional arguments for pickle/dill.dumps
Returns:
bytes: Pickled data with placeholders for unpicklable objects
"""
return PicklePatcher._recursive_pickle(obj, max_depth, path=[], protocol=protocol, **kwargs)
@staticmethod
def loads(pickled_data):
"""Unpickle data that may contain placeholders.
Args:
pickled_data: Pickled data with possible placeholders
Returns:
The unpickled object with placeholders for unpicklable parts
"""
try:
# We use dill for loading since it can handle everything pickle can
return dill.loads(pickled_data)
except Exception:
raise
@staticmethod
def _create_placeholder(obj, error_msg, path):
"""Create a placeholder for an unpicklable object.
Args:
obj: The original unpicklable object
error_msg: Error message explaining why it couldn't be pickled
path: Path to this object in the object graph
Returns:
PicklePlaceholder: A placeholder object
"""
obj_type = type(obj)
try:
obj_str = str(obj)[:100] if hasattr(obj, "__str__") else f"<unprintable object of type {obj_type.__name__}>"
except Exception:
obj_str = f"<unprintable object of type {obj_type.__name__}>"
print(f"Creating placeholder for {obj_type.__name__} at path {'->'.join(path) or 'root'}: {error_msg}")
placeholder = PicklePlaceholder(
obj_type.__name__,
obj_str,
error_msg,
path
)
# Add this type to our known unpicklable types cache
PicklePatcher._unpicklable_types.add(obj_type)
return placeholder
@staticmethod
def _pickle(obj, path=None, protocol=None, **kwargs):
"""Try to pickle an object using pickle first, then dill. If both fail, create a placeholder.
Args:
obj: The object to pickle
path: Path to this object in the object graph
protocol: The pickle protocol version to use
**kwargs: Additional arguments for pickle/dill.dumps
Returns:
tuple: (success, result) where success is a boolean and result is either:
- Pickled bytes if successful
- Error message if not successful
"""
# Try standard pickle first
try:
return True, pickle.dumps(obj, protocol=protocol, **kwargs)
except (pickle.PickleError, TypeError, AttributeError, ValueError):
# Then try dill (which is more powerful)
try:
return True, dill.dumps(obj, protocol=protocol, **kwargs)
except (dill.PicklingError, TypeError, AttributeError, ValueError) as e:
return False, str(e)
@staticmethod
def _recursive_pickle(obj, max_depth, path=None, protocol=None, **kwargs):
"""Recursively try to pickle an object, replacing unpicklable parts with placeholders.
Args:
obj: The object to pickle
max_depth: Maximum recursion depth
path: Current path in the object graph
protocol: The pickle protocol version to use
**kwargs: Additional arguments for pickle/dill.dumps
Returns:
bytes: Pickled data with placeholders for unpicklable objects
"""
if path is None:
path = []
obj_type = type(obj)
# Check if this type is known to be unpicklable
if obj_type in PicklePatcher._unpicklable_types:
placeholder = PicklePatcher._create_placeholder(
obj,
"Known unpicklable type",
path
)
return dill.dumps(placeholder, protocol=protocol, **kwargs)
# Check for max depth
if max_depth <= 0:
placeholder = PicklePatcher._create_placeholder(
obj,
"Max recursion depth exceeded",
path
)
return dill.dumps(placeholder, protocol=protocol, **kwargs)
# Try standard pickling
success, result = PicklePatcher._pickle(obj, path, protocol, **kwargs)
if success:
return result
error_msg = result # Error message from pickling attempt
# Handle different container types
if isinstance(obj, dict):
return PicklePatcher._handle_dict(obj, max_depth, error_msg, path, protocol=protocol, **kwargs)
elif isinstance(obj, (list, tuple, set)):
return PicklePatcher._handle_sequence(obj, max_depth, error_msg, path, protocol=protocol, **kwargs)
elif hasattr(obj, "__dict__"):
result = PicklePatcher._handle_object(obj, max_depth, error_msg, path, protocol=protocol, **kwargs)
# If this was a failure, add the type to the cache
unpickled = dill.loads(result)
if isinstance(unpickled, PicklePlaceholder):
PicklePatcher._unpicklable_types.add(obj_type)
return result
# For other unpicklable objects, use a placeholder
placeholder = PicklePatcher._create_placeholder(obj, error_msg, path)
return dill.dumps(placeholder, protocol=protocol, **kwargs)
@staticmethod
def _handle_dict(obj_dict, max_depth, error_msg, path, protocol=None, **kwargs):
"""Handle pickling for dictionary objects.
Args:
obj_dict: The dictionary to pickle
max_depth: Maximum recursion depth
error_msg: Error message from the original pickling attempt
path: Current path in the object graph
protocol: The pickle protocol version to use
**kwargs: Additional arguments for pickle/dill.dumps
Returns:
bytes: Pickled data with placeholders for unpicklable objects
"""
if not isinstance(obj_dict, dict):
placeholder = PicklePatcher._create_placeholder(
obj_dict,
f"Expected a dictionary, got {type(obj_dict).__name__}",
path
)
return dill.dumps(placeholder, protocol=protocol, **kwargs)
result = {}
for key, value in obj_dict.items():
# Process the key
key_success, key_result = PicklePatcher._pickle(key, path, protocol, **kwargs)
if key_success:
key_result = key
else:
# If the key can't be pickled, use a string representation
try:
key_str = str(key)[:50]
except Exception:
key_str = f"<unprintable key of type {type(key).__name__}>"
key_result = f"<unpicklable_key:{key_str}>"
# Process the value
value_path = path + [f"[{repr(key)[:20]}]"]
value_success, value_bytes = PicklePatcher._pickle(value, value_path, protocol, **kwargs)
if value_success:
value_result = value
else:
# Try recursive pickling for the value
try:
value_bytes = PicklePatcher._recursive_pickle(
value, max_depth - 1, value_path, protocol=protocol, **kwargs
)
value_result = dill.loads(value_bytes)
except Exception as inner_e:
value_result = PicklePatcher._create_placeholder(
value,
str(inner_e),
value_path
)
result[key_result] = value_result
return dill.dumps(result, protocol=protocol, **kwargs)
@staticmethod
def _handle_sequence(obj_seq, max_depth, error_msg, path, protocol=None, **kwargs):
"""Handle pickling for sequence types (list, tuple, set).
Args:
obj_seq: The sequence to pickle
max_depth: Maximum recursion depth
error_msg: Error message from the original pickling attempt
path: Current path in the object graph
protocol: The pickle protocol version to use
**kwargs: Additional arguments for pickle/dill.dumps
Returns:
bytes: Pickled data with placeholders for unpicklable objects
"""
result = []
for i, item in enumerate(obj_seq):
item_path = path + [f"[{i}]"]
# Try to pickle the item directly
success, _ = PicklePatcher._pickle(item, item_path, protocol, **kwargs)
if success:
result.append(item)
continue
# If we couldn't pickle directly, try recursively
try:
item_bytes = PicklePatcher._recursive_pickle(
item, max_depth - 1, item_path, protocol=protocol, **kwargs
)
result.append(dill.loads(item_bytes))
except Exception as inner_e:
# If recursive pickling fails, use a placeholder
placeholder = PicklePatcher._create_placeholder(
item,
str(inner_e),
item_path
)
result.append(placeholder)
# Convert back to the original type
if isinstance(obj_seq, tuple):
result = tuple(result)
elif isinstance(obj_seq, set):
# Try to create a set from the result
try:
result = set(result)
except Exception:
# If we can't create a set (unhashable items), keep it as a list
pass
return dill.dumps(result, protocol=protocol, **kwargs)
@staticmethod
def _handle_object(obj, max_depth, error_msg, path, protocol=None, **kwargs):
"""Handle pickling for custom objects with __dict__.
Args:
obj: The object to pickle
max_depth: Maximum recursion depth
error_msg: Error message from the original pickling attempt
path: Current path in the object graph
protocol: The pickle protocol version to use
**kwargs: Additional arguments for pickle/dill.dumps
Returns:
bytes: Pickled data with placeholders for unpicklable objects
"""
# Try to create a new instance of the same class
try:
# First try to create an empty instance
new_obj = object.__new__(type(obj))
# Handle __dict__ attributes if they exist
if hasattr(obj, "__dict__"):
for attr_name, attr_value in obj.__dict__.items():
attr_path = path + [attr_name]
# Try to pickle directly first
success, _ = PicklePatcher._pickle(attr_value, attr_path, protocol, **kwargs)
if success:
setattr(new_obj, attr_name, attr_value)
continue
# If direct pickling fails, try recursive pickling
try:
attr_bytes = PicklePatcher._recursive_pickle(
attr_value, max_depth - 1, attr_path, protocol=protocol, **kwargs
)
setattr(new_obj, attr_name, dill.loads(attr_bytes))
except Exception as inner_e:
# Use placeholder for unpicklable attribute
placeholder = PicklePatcher._create_placeholder(
attr_value,
str(inner_e),
attr_path
)
setattr(new_obj, attr_name, placeholder)
# Try to pickle the patched object
success, result = PicklePatcher._pickle(new_obj, path, protocol, **kwargs)
if success:
return result
# Fall through to placeholder creation
except Exception:
pass # Fall through to placeholder creation
# If we get here, just use a placeholder
placeholder = PicklePatcher._create_placeholder(obj, error_msg, path)
return dill.dumps(placeholder, protocol=protocol, **kwargs)
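A usage sketch (assumes codeflash and dill are installed); a generator is a convenient value that neither pickle nor dill can serialize:

from codeflash.picklepatch.pickle_patcher import PicklePatcher

data = {"numbers": [1, 2, 3], "stream": (i for i in range(3))}
blob = PicklePatcher.dumps(data)   # "stream" is replaced by a placeholder
restored = PicklePatcher.loads(blob)

print(restored["numbers"])         # [1, 2, 3] round-trips intact
print(repr(restored["stream"]))    # <PicklePlaceholder ...>; attribute access raises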

View file

@ -0,0 +1,71 @@
class PicklePlaceholderAccessError(Exception):
"""Custom exception raised when attempting to access an unpicklable object."""
class PicklePlaceholder:
"""A placeholder for an object that couldn't be pickled.
When unpickled, any attempt to access attributes or call methods on this
placeholder will raise a PicklePlaceholderAccessError.
"""
def __init__(self, obj_type, obj_str, error_msg, path=None):
"""Initialize a placeholder for an unpicklable object.
Args:
obj_type (str): The type name of the original object
obj_str (str): String representation of the original object
error_msg (str): The error message that occurred during pickling
path (list, optional): Path to this object in the original object graph
"""
# Store these directly in __dict__ to avoid __getattr__ recursion
self.__dict__["obj_type"] = obj_type
self.__dict__["obj_str"] = obj_str
self.__dict__["error_msg"] = error_msg
self.__dict__["path"] = path if path is not None else []
def __getattr__(self, name):
"""Raise a custom error when any attribute is accessed."""
path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root object"
raise PicklePlaceholderAccessError(
f"Attempt to access unpickleable object: Cannot access attribute '{name}' on unpicklable object at {path_str}. "
f"Original type: {self.__dict__['obj_type']}. Error: {self.__dict__['error_msg']}"
)
def __setattr__(self, name, value):
"""Prevent setting attributes."""
self.__getattr__(name) # This will raise our custom error
def __call__(self, *args, **kwargs):
"""Raise a custom error when the object is called."""
path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root object"
raise PicklePlaceholderAccessError(
f"Attempt to access unpickleable object: Cannot call unpicklable object at {path_str}. "
f"Original type: {self.__dict__['obj_type']}. Error: {self.__dict__['error_msg']}"
)
def __repr__(self):
"""Return a string representation of the placeholder."""
try:
path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root"
return f"<PicklePlaceholder at {path_str}: {self.__dict__['obj_type']} {self.__dict__['obj_str']}>"
except Exception:
return "<PicklePlaceholder: (error displaying details)>"
def __str__(self):
"""Return a string representation of the placeholder."""
return self.__repr__()
def __reduce__(self):
"""Make sure pickling of the placeholder itself works correctly."""
return (
PicklePlaceholder,
(
self.__dict__["obj_type"],
self.__dict__["obj_str"],
self.__dict__["error_msg"],
self.__dict__["path"]
)
)
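A quick sketch of the failure mode the placeholder enforces (constructor arguments are illustrative):

from codeflash.picklepatch.pickle_placeholder import (
    PicklePlaceholder,
    PicklePlaceholderAccessError,
)

placeholder = PicklePlaceholder("socket", "<socket.socket fd=3>", "cannot pickle 'socket' object", ["conn"])
try:
    placeholder.send(b"ping")  # __getattr__ fires before any call happens
except PicklePlaceholderAccessError as err:
    print(err)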

View file

@ -77,6 +77,7 @@ def check_create_pr(
speedup_pct=explanation.speedup_pct,
winning_behavioral_test_results=explanation.winning_behavioral_test_results,
winning_benchmarking_test_results=explanation.winning_benchmarking_test_results,
benchmark_details=explanation.benchmark_details
),
existing_tests=existing_tests_source,
generated_tests=generated_original_test_source,
@ -123,6 +124,7 @@ def check_create_pr(
speedup_pct=explanation.speedup_pct,
winning_behavioral_test_results=explanation.winning_behavioral_test_results,
winning_benchmarking_test_results=explanation.winning_benchmarking_test_results,
benchmark_details=explanation.benchmark_details
),
existing_tests=existing_tests_source,
generated_tests=generated_original_test_source,

View file

@ -1,9 +1,16 @@
from __future__ import annotations
import shutil
from io import StringIO
from pathlib import Path
from typing import Optional, cast
from pydantic.dataclasses import dataclass
from rich.console import Console
from rich.table import Table
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.models.models import TestResults
from codeflash.models.models import BenchmarkDetail, TestResults
@dataclass(frozen=True, config={"arbitrary_types_allowed": True})
@ -15,6 +22,7 @@ class Explanation:
best_runtime_ns: int
function_name: str
file_path: Path
benchmark_details: Optional[list[BenchmarkDetail]] = None
@property
def perf_improvement_line(self) -> str:
@ -37,16 +45,55 @@ class Explanation:
# TODO: Sometimes the explanation says something similar to "This is the code that was optimized", remove such parts
original_runtime_human = humanize_runtime(self.original_runtime_ns)
best_runtime_human = humanize_runtime(self.best_runtime_ns)
benchmark_info = ""
if self.benchmark_details:
# Get terminal width (or use a reasonable default if detection fails)
try:
terminal_width = int(shutil.get_terminal_size().columns * 0.9)
except Exception:
terminal_width = 200 # Fallback width
# Create a rich table for better formatting
table = Table(title="Benchmark Performance Details", width=terminal_width, show_lines=True)
# Add columns - split Benchmark File and Function into separate columns
# Using proportional width for benchmark file column (40% of terminal width)
benchmark_col_width = max(int(terminal_width * 0.4), 40)
table.add_column("Benchmark Module Path", style="cyan", width=benchmark_col_width, overflow="fold")
table.add_column("Test Function", style="cyan", overflow="fold")
table.add_column("Original Runtime", style="magenta", justify="right")
table.add_column("Expected New Runtime", style="green", justify="right")
table.add_column("Speedup", style="red", justify="right")
# Add rows with split data
for detail in self.benchmark_details:
# Split the benchmark name and test function
benchmark_name = detail.benchmark_name
test_function = detail.test_function
table.add_row(
benchmark_name,
test_function,
f"{detail.original_timing}",
f"{detail.expected_new_timing}",
f"{detail.speedup_percent:.2f}%"
)
# Convert table to string
string_buffer = StringIO()
console = Console(file=string_buffer, width=terminal_width)
console.print(table)
benchmark_info = cast(StringIO, console.file).getvalue() + "\n" # Cast for mypy
return (
f"Optimized {self.function_name} in {self.file_path}\n"
f"{self.perf_improvement_line}\n"
f"Runtime went down from {original_runtime_human} to {best_runtime_human} \n\n"
+ "Explanation:\n"
+ self.raw_explanation_message
+ " \n\n"
+ "The new optimized code was tested for correctness. The results are listed below.\n"
+ f"{TestResults.report_to_string(self.winning_behavioral_test_results.get_test_pass_fail_report_by_type())}\n"
f"Optimized {self.function_name} in {self.file_path}\n"
f"{self.perf_improvement_line}\n"
f"Runtime went down from {original_runtime_human} to {best_runtime_human} \n\n"
+ (benchmark_info if benchmark_info else "")
+ self.raw_explanation_message
+ " \n\n"
+ "The new optimized code was tested for correctness. The results are listed below.\n"
+ f"{TestResults.report_to_string(self.winning_behavioral_test_results.get_test_pass_fail_report_by_type())}\n"
)
def explanation_message(self) -> str:

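Rendering a rich table into a plain string, as done above, is just a Console writing to a StringIO; a minimal sketch:

from io import StringIO

from rich.console import Console
from rich.table import Table

table = Table(title="Benchmark Performance Details", show_lines=True)
table.add_column("Benchmark Module Path", style="cyan", overflow="fold")
table.add_column("Speedup", style="red", justify="right")
table.add_row("tests.benchmarks.test_sort", "50.00%")

string_buffer = StringIO()
Console(file=string_buffer, width=110).print(table)
benchmark_info = string_buffer.getvalue()  # embeddable in the explanation text
print(benchmark_info)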
View file

@ -10,6 +10,7 @@ from typing import Any
import sentry_sdk
from codeflash.cli_cmds.console import logger
from codeflash.picklepatch.pickle_placeholder import PicklePlaceholderAccessError
try:
import numpy as np
@ -90,6 +91,11 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
return True
return math.isclose(orig, new)
if isinstance(orig, BaseException):
if isinstance(orig, PicklePlaceholderAccessError) or isinstance(new, PicklePlaceholderAccessError):
# If this error was raised, there was an attempt to access a PicklePlaceholder, which represents an unpicklable object.
# The test results should be rejected, as the behavior of the unpicklable object is unknown.
logger.debug("Unable to verify behavior of unpicklable object in replay test")
return False
# if str(orig) != str(new):
# return False
# compare the attributes of the two exception objects to determine if they are equivalent.

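Assuming codeflash is importable, the new guard rejects replay results even when both sides raised the same placeholder error:

from codeflash.picklepatch.pickle_placeholder import PicklePlaceholderAccessError
from codeflash.verification.comparator import comparator

original = PicklePlaceholderAccessError("attempt to access unpicklable object")
candidate = PicklePlaceholderAccessError("attempt to access unpicklable object")
print(comparator(original, candidate))  # False: the underlying behavior is unknowable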
View file

@ -75,3 +75,4 @@ class TestConfig:
# or for unittest - project_root_from_module_root(args.tests_root, pyproject_file_path)
concolic_test_root_dir: Optional[Path] = None
pytest_cmd: str = "pytest"
benchmark_tests_root: Optional[Path] = None

View file

@ -116,6 +116,7 @@ types-openpyxl = ">=3.1.5.20241020"
types-regex = ">=2024.9.11.20240912"
types-python-dateutil = ">=2.9.0.20241003"
pytest-cov = "^6.0.0"
pytest-benchmark = ">=5.1.0"
types-gevent = "^24.11.0.20241230"
types-greenlet = "^3.1.0.20241221"
types-pexpect = "^4.9.0.20241208"
@ -219,6 +220,7 @@ initial-content = """
[tool.codeflash]
module-root = "codeflash"
tests-root = "tests"
benchmarks-root = "tests/benchmarks"
test-framework = "pytest"
formatter-cmds = [
"uvx ruff check --exit-zero --fix $file",

View file

@ -0,0 +1,31 @@
from argparse import Namespace
from pathlib import Path
from codeflash.context.code_context_extractor import get_code_optimization_context
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent
from codeflash.optimization.optimizer import Optimizer
def test_benchmark_extract(benchmark)->None:
file_path = Path(__file__).parent.parent.parent.resolve() / "codeflash"
opt = Optimizer(
Namespace(
project_root=file_path.resolve(),
disable_telemetry=True,
tests_root=(file_path / "tests").resolve(),
test_framework="pytest",
pytest_cmd="pytest",
experiment_id=None,
test_project_root=Path.cwd(),
)
)
function_to_optimize = FunctionToOptimize(
function_name="replace_function_and_helpers_with_optimized_code",
file_path=file_path / "optimization" / "function_optimizer.py",
parents=[FunctionParent(name="FunctionOptimizer", type="ClassDef")],
starting_line=None,
ending_line=None,
)
benchmark(get_code_optimization_context, function_to_optimize, opt.args.project_root)

View file

@ -0,0 +1,26 @@
from pathlib import Path
from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.verification.verification_utils import TestConfig
def test_benchmark_code_to_optimize_test_discovery(benchmark) -> None:
project_path = Path(__file__).parent.parent.parent.resolve() / "code_to_optimize"
tests_path = project_path / "tests" / "pytest"
test_config = TestConfig(
tests_root=tests_path,
project_root_path=project_path,
test_framework="pytest",
tests_project_rootdir=tests_path.parent,
)
benchmark(discover_unit_tests, test_config)
def test_benchmark_codeflash_test_discovery(benchmark) -> None:
project_path = Path(__file__).parent.parent.parent.resolve() / "codeflash"
tests_path = project_path / "tests"
test_config = TestConfig(
tests_root=tests_path,
project_root_path=project_path,
test_framework="pytest",
tests_project_rootdir=tests_path.parent,
)
benchmark(discover_unit_tests, test_config)

View file

@ -0,0 +1,71 @@
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType
from codeflash.verification.parse_test_output import merge_test_results
def generate_test_invocations(count=100):
"""Generate a set number of test invocations for benchmarking."""
test_results_xml = TestResults()
test_results_bin = TestResults()
# Generate test invocations in a loop
for i in range(count):
iteration_id = str(i * 3 + 5) # Generate unique iteration IDs
# XML results - some with None runtime
test_results_xml.add(
FunctionTestInvocation(
id=InvocationId(
test_module_path="code_to_optimize.tests.unittest.test_bubble_sort",
test_class_name="TestPigLatin",
test_function_name="test_sort",
function_getting_tested="sorter",
iteration_id=iteration_id,
),
file_name="/tmp/tests/unittest/test_bubble_sort__perfinstrumented.py",
did_pass=True,
runtime=None if i % 3 == 0 else i * 100, # Vary runtime values
test_framework="unittest",
test_type=TestType.EXISTING_UNIT_TEST,
return_value=None,
timed_out=False,
loop_index=i,
)
)
# Binary results - with actual runtime values
test_results_bin.add(
FunctionTestInvocation(
id=InvocationId(
test_module_path="code_to_optimize.tests.unittest.test_bubble_sort",
test_class_name="TestPigLatin",
test_function_name="test_sort",
function_getting_tested="sorter",
iteration_id=iteration_id,
),
file_name="/tmp/tests/unittest/test_bubble_sort__perfinstrumented.py",
did_pass=True,
runtime=500 + i * 20, # Generate varying runtime values
test_framework="unittest",
test_type=TestType.EXISTING_UNIT_TEST,
return_value=None,
timed_out=False,
loop_index=i,
)
)
return test_results_xml, test_results_bin
def run_merge_benchmark(count=100):
test_results_xml, test_results_bin = generate_test_invocations(count)
# Perform the merge operation that will be benchmarked
merge_test_results(
xml_test_results=test_results_xml,
bin_test_results=test_results_bin,
test_framework="unittest"
)
def test_benchmark_merge_test_results(benchmark):
benchmark(run_merge_benchmark, 1000) # Merge 1000 generated test invocations

View file

@ -0,0 +1,26 @@
import os
import pathlib
from end_to_end_test_utilities import CoverageExpectation, TestConfig, run_codeflash_command, run_with_retries
def run_test(expected_improvement_pct: int) -> bool:
cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize").resolve()
config = TestConfig(
file_path=pathlib.Path("bubble_sort.py"),
function_name="sorter",
benchmarks_root=cwd / "tests" / "pytest" / "benchmarks",
test_framework="pytest",
min_improvement_x=1.0,
coverage_expectations=[
CoverageExpectation(
function_name="sorter", expected_coverage=100.0, expected_lines=[2, 3, 4, 5, 6, 7, 8, 9, 10]
)
],
)
return run_codeflash_command(cwd, config, expected_improvement_pct)
if __name__ == "__main__":
exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 5))))

View file

@ -26,6 +26,7 @@ class TestConfig:
min_improvement_x: float = 0.1
trace_mode: bool = False
coverage_expectations: list[CoverageExpectation] = field(default_factory=list)
benchmarks_root: Optional[pathlib.Path] = None
def clear_directory(directory_path: str | pathlib.Path) -> None:
@ -85,8 +86,8 @@ def run_codeflash_command(
path_to_file = cwd / config.file_path
file_contents = path_to_file.read_text("utf-8")
test_root = cwd / "tests" / (config.test_framework or "")
command = build_command(cwd, config, test_root)
command = build_command(cwd, config, test_root, config.benchmarks_root if config.benchmarks_root else None)
process = subprocess.Popen(
command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=os.environ.copy()
)
@ -116,7 +117,7 @@ def run_codeflash_command(
return validated
def build_command(cwd: pathlib.Path, config: TestConfig, test_root: pathlib.Path) -> list[str]:
def build_command(cwd: pathlib.Path, config: TestConfig, test_root: pathlib.Path, benchmarks_root:pathlib.Path|None = None) -> list[str]:
python_path = "../../../codeflash/main.py" if "code_directories" in str(cwd) else "../codeflash/main.py"
base_command = ["python", python_path, "--file", config.file_path, "--no-pr"]
@ -127,7 +128,8 @@ def build_command(cwd: pathlib.Path, config: TestConfig, test_root: pathlib.Path
base_command.extend(
["--test-framework", config.test_framework, "--tests-root", str(test_root), "--module-root", str(cwd)]
)
if benchmarks_root:
base_command.extend(["--benchmark", "--benchmarks-root", str(benchmarks_root)])
return base_command
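With a benchmarks root configured, the assembled command (hypothetical paths) gains the two new flags:

command = [
    "python", "../codeflash/main.py",
    "--file", "bubble_sort.py", "--no-pr",
    "--test-framework", "pytest",
    "--tests-root", "/repo/code_to_optimize/tests/pytest",
    "--module-root", "/repo/code_to_optimize",
    "--benchmark", "--benchmarks-root", "/repo/code_to_optimize/tests/pytest/benchmarks",
]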

View file

@ -0,0 +1,15 @@
from codeflash.benchmarking.codeflash_trace import codeflash_trace
from pathlib import Path
from codeflash.code_utils.code_utils import get_run_tmp_file
@codeflash_trace
def example_function(arr):
arr.sort()
return arr
def test_codeflash_trace_decorator():
arr = [3, 1, 2]
result = example_function(arr)
# TODO: clean up the trace file the decorator writes (get_run_tmp_file is imported above for this)
assert result == [1, 2, 3]

View file

@ -0,0 +1,547 @@
from __future__ import annotations
import tempfile
from pathlib import Path
from codeflash.benchmarking.instrument_codeflash_trace import add_codeflash_decorator_to_code, \
instrument_codeflash_trace_decorator
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize
def test_add_decorator_to_normal_function() -> None:
"""Test adding decorator to a normal function."""
code = """
def normal_function():
return "Hello, World!"
"""
fto = FunctionToOptimize(
function_name="normal_function",
file_path=Path("dummy_path.py"),
parents=[]
)
modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@codeflash_trace
def normal_function():
return "Hello, World!"
"""
assert modified_code.strip() == expected_code.strip()
def test_add_decorator_to_normal_method() -> None:
"""Test adding decorator to a normal method."""
code = """
class TestClass:
def normal_method(self):
return "Hello from method"
"""
fto = FunctionToOptimize(
function_name="normal_method",
file_path=Path("dummy_path.py"),
parents=[FunctionParent(name="TestClass", type="ClassDef")]
)
modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
class TestClass:
@codeflash_trace
def normal_method(self):
return "Hello from method"
"""
assert modified_code.strip() == expected_code.strip()
def test_add_decorator_to_classmethod() -> None:
"""Test adding decorator to a classmethod."""
code = """
class TestClass:
@classmethod
def class_method(cls):
return "Hello from classmethod"
"""
fto = FunctionToOptimize(
function_name="class_method",
file_path=Path("dummy_path.py"),
parents=[FunctionParent(name="TestClass", type="ClassDef")]
)
modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
class TestClass:
@classmethod
@codeflash_trace
def class_method(cls):
return "Hello from classmethod"
"""
assert modified_code.strip() == expected_code.strip()
def test_add_decorator_to_staticmethod() -> None:
"""Test adding decorator to a staticmethod."""
code = """
class TestClass:
@staticmethod
def static_method():
return "Hello from staticmethod"
"""
fto = FunctionToOptimize(
function_name="static_method",
file_path=Path("dummy_path.py"),
parents=[FunctionParent(name="TestClass", type="ClassDef")]
)
modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
class TestClass:
@staticmethod
@codeflash_trace
def static_method():
return "Hello from staticmethod"
"""
assert modified_code.strip() == expected_code.strip()
def test_add_decorator_to_init_function() -> None:
"""Test adding decorator to an __init__ function."""
code = """
class TestClass:
def __init__(self, value):
self.value = value
"""
fto = FunctionToOptimize(
function_name="__init__",
file_path=Path("dummy_path.py"),
parents=[FunctionParent(name="TestClass", type="ClassDef")]
)
modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
class TestClass:
@codeflash_trace
def __init__(self, value):
self.value = value
"""
assert modified_code.strip() == expected_code.strip()
def test_add_decorator_with_multiple_decorators() -> None:
"""Test adding decorator to a function with multiple existing decorators."""
code = """
class TestClass:
@property
@other_decorator
def property_method(self):
return self._value
"""
fto = FunctionToOptimize(
function_name="property_method",
file_path=Path("dummy_path.py"),
parents=[FunctionParent(name="TestClass", type="ClassDef")]
)
modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
class TestClass:
@property
@other_decorator
@codeflash_trace
def property_method(self):
return self._value
"""
assert modified_code.strip() == expected_code.strip()
def test_add_decorator_to_function_in_multiple_classes() -> None:
"""Test that only the right class's method gets the decorator."""
code = """
class TestClass:
def test_method(self):
return "This should get decorated"
class OtherClass:
def test_method(self):
return "This should NOT get decorated"
"""
fto = FunctionToOptimize(
function_name="test_method",
file_path=Path("dummy_path.py"),
parents=[FunctionParent(name="TestClass", type="ClassDef")]
)
modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
class TestClass:
@codeflash_trace
def test_method(self):
return "This should get decorated"
class OtherClass:
def test_method(self):
return "This should NOT get decorated"
"""
assert modified_code.strip() == expected_code.strip()
def test_add_decorator_to_nonexistent_function() -> None:
"""Test that code remains unchanged when function doesn't exist."""
code = """
def existing_function():
return "This exists"
"""
fto = FunctionToOptimize(
function_name="nonexistent_function",
file_path=Path("dummy_path.py"),
parents=[]
)
modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
# Code should remain unchanged
assert modified_code.strip() == code.strip()
def test_add_decorator_to_multiple_functions() -> None:
"""Test adding decorator to multiple functions."""
code = """
def function_one():
return "First function"
class TestClass:
def method_one(self):
return "First method"
def method_two(self):
return "Second method"
def function_two():
return "Second function"
"""
functions_to_optimize = [
FunctionToOptimize(
function_name="function_one",
file_path=Path("dummy_path.py"),
parents=[]
),
FunctionToOptimize(
function_name="method_two",
file_path=Path("dummy_path.py"),
parents=[FunctionParent(name="TestClass", type="ClassDef")]
),
FunctionToOptimize(
function_name="function_two",
file_path=Path("dummy_path.py"),
parents=[]
)
]
modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=functions_to_optimize
)
expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@codeflash_trace
def function_one():
return "First function"
class TestClass:
def method_one(self):
return "First method"
@codeflash_trace
def method_two(self):
return "Second method"
@codeflash_trace
def function_two():
return "Second function"
"""
assert modified_code.strip() == expected_code.strip()
def test_instrument_codeflash_trace_decorator_single_file() -> None:
"""Test instrumenting codeflash trace decorator on a single file."""
with tempfile.TemporaryDirectory() as temp_dir:
# Create a test Python file
test_file_path = Path(temp_dir) / "test_module.py"
test_file_content = """
def function_one():
return "First function"
class TestClass:
def method_one(self):
return "First method"
def method_two(self):
return "Second method"
def function_two():
return "Second function"
"""
test_file_path.write_text(test_file_content, encoding="utf-8")
# Define functions to optimize
functions_to_optimize = [
FunctionToOptimize(
function_name="function_one",
file_path=test_file_path,
parents=[]
),
FunctionToOptimize(
function_name="method_two",
file_path=test_file_path,
parents=[FunctionParent(name="TestClass", type="ClassDef")]
)
]
# Execute the function being tested
instrument_codeflash_trace_decorator({test_file_path: functions_to_optimize})
# Read the modified file
modified_content = test_file_path.read_text(encoding="utf-8")
# Define expected content (with isort applied)
expected_content = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@codeflash_trace
def function_one():
return "First function"
class TestClass:
def method_one(self):
return "First method"
@codeflash_trace
def method_two(self):
return "Second method"
def function_two():
return "Second function"
"""
# Compare the modified content with expected content
assert modified_content.strip() == expected_content.strip()
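        # Note: "(with isort applied)" above reflects that instrument_codeflash_trace_decorator
        # appears to normalize imports after inserting the codeflash_trace import, so the new
        # import always lands at the top of the file in sorted order.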
def test_instrument_codeflash_trace_decorator_multiple_files() -> None:
"""Test instrumenting codeflash trace decorator on multiple files."""
with tempfile.TemporaryDirectory() as temp_dir:
# Create first test Python file
test_file_1_path = Path(temp_dir) / "module_a.py"
test_file_1_content = """
def function_a():
return "Function in module A"
class ClassA:
def method_a(self):
return "Method in ClassA"
"""
test_file_1_path.write_text(test_file_1_content, encoding="utf-8")
# Create second test Python file
test_file_2_path = Path(temp_dir) / "module_b.py"
        test_file_2_content = """
def function_b():
return "Function in module B"
class ClassB:
@staticmethod
def static_method_b():
return "Static method in ClassB"
"""
test_file_2_path.write_text(test_file_2_content, encoding="utf-8")
# Define functions to optimize
file_to_funcs_to_optimize = {
test_file_1_path: [
FunctionToOptimize(
function_name="function_a",
file_path=test_file_1_path,
parents=[]
)
],
test_file_2_path: [
FunctionToOptimize(
function_name="static_method_b",
file_path=test_file_2_path,
parents=[FunctionParent(name="ClassB", type="ClassDef")]
)
]
}
# Execute the function being tested
instrument_codeflash_trace_decorator(file_to_funcs_to_optimize)
# Read the modified files
modified_content_1 = test_file_1_path.read_text(encoding="utf-8")
modified_content_2 = test_file_2_path.read_text(encoding="utf-8")
# Define expected content for first file (with isort applied)
expected_content_1 = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@codeflash_trace
def function_a():
return "Function in module A"
class ClassA:
def method_a(self):
return "Method in ClassA"
"""
# Define expected content for second file (with isort applied)
expected_content_2 = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
def function_b():
return "Function in module B"
class ClassB:
@staticmethod
@codeflash_trace
def static_method_b():
return "Static method in ClassB"
"""
# Compare the modified content with expected content
assert modified_content_1.strip() == expected_content_1.strip()
assert modified_content_2.strip() == expected_content_2.strip()
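        # Note the decorator ordering in the second file: @codeflash_trace is inserted
        # below @staticmethod, closest to the function, so the trace wraps the raw
        # function before staticmethod binding applies.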
def test_add_decorator_to_method_after_nested_class() -> None:
"""Test adding decorator to a method that appears after a nested class definition."""
code = """
class OuterClass:
class NestedClass:
def nested_method(self):
return "Hello from nested class method"
def target_method(self):
return "Hello from target method after nested class"
"""
fto = FunctionToOptimize(
function_name="target_method",
file_path=Path("dummy_path.py"),
parents=[FunctionParent(name="OuterClass", type="ClassDef")]
)
modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
class OuterClass:
class NestedClass:
def nested_method(self):
return "Hello from nested class method"
@codeflash_trace
def target_method(self):
return "Hello from target method after nested class"
"""
assert modified_code.strip() == expected_code.strip()
def test_add_decorator_to_function_after_nested_function() -> None:
"""Test adding decorator to a function that appears after a function with a nested function."""
code = """
def function_with_nested():
def inner_function():
return "Hello from inner function"
return inner_function()
def target_function():
return "Hello from target function after nested function"
"""
fto = FunctionToOptimize(
function_name="target_function",
file_path=Path("dummy_path.py"),
parents=[]
)
modified_code = add_codeflash_decorator_to_code(
code=code,
functions_to_optimize=[fto]
)
expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
def function_with_nested():
def inner_function():
return "Hello from inner function"
return inner_function()
@codeflash_trace
def target_function():
return "Hello from target function after nested function"
"""
assert modified_code.strip() == expected_code.strip()
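# For orientation, a minimal usage sketch of the API exercised above. This is
# illustrative only: "example.py" and slow_function are hypothetical names; the
# calls mirror the tests in this file.
def _example_add_decorator_usage() -> str:
    example_code = '''
def slow_function(items):
    return sorted(items)
'''
    fto = FunctionToOptimize(
        function_name="slow_function",
        file_path=Path("example.py"),
        parents=[],
    )
    # Returns the source with @codeflash_trace applied and the import inserted.
    return add_codeflash_decorator_to_code(code=example_code, functions_to_optimize=[fto])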

View file

@ -0,0 +1,513 @@
import os
import pickle
import shutil
import socket
import sqlite3
from argparse import Namespace
from pathlib import Path
import dill
import pytest
from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin
from codeflash.benchmarking.replay_test import generate_replay_test
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
from codeflash.benchmarking.utils import validate_and_format_benchmark_table
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodePosition, TestFile, TestFiles, TestingMode, TestsInFile, TestType
from codeflash.optimization.optimizer import Optimizer
from codeflash.verification.equivalence import compare_test_results
try:
import sqlalchemy
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
HAS_SQLALCHEMY = True
except ImportError:
HAS_SQLALCHEMY = False
from codeflash.picklepatch.pickle_patcher import PicklePatcher
from codeflash.picklepatch.pickle_placeholder import PicklePlaceholder, PicklePlaceholderAccessError
def test_picklepatch_simple_nested():
"""Test that a simple nested data structure pickles and unpickles correctly.
"""
original_data = {
"numbers": [1, 2, 3],
"nested_dict": {"key": "value", "another": 42},
}
dumped = PicklePatcher.dumps(original_data)
reloaded = PicklePatcher.loads(dumped)
assert reloaded == original_data
# Everything was pickleable, so no placeholders should appear.
def test_picklepatch_with_socket():
"""Test that a data structure containing a raw socket is replaced by
PicklePlaceholder rather than raising an error.
"""
# Create a pair of connected sockets instead of a single socket
sock1, sock2 = socket.socketpair()
data_with_socket = {
"safe_value": 123,
"raw_socket": sock1,
}
# Send a message through sock1, which can be received by sock2
sock1.send(b"Hello, world!")
received = sock2.recv(1024)
assert received == b"Hello, world!"
# Pickle the data structure containing the socket
dumped = PicklePatcher.dumps(data_with_socket)
reloaded = PicklePatcher.loads(dumped)
# We expect "raw_socket" to be replaced by a placeholder
assert isinstance(reloaded, dict)
assert reloaded["safe_value"] == 123
assert isinstance(reloaded["raw_socket"], PicklePlaceholder)
    # Attempting to use or access attributes raises PicklePlaceholderAccessError
    # (our implementation's dedicated error type, rather than RuntimeError as in the original tests)
with pytest.raises(PicklePlaceholderAccessError):
reloaded["raw_socket"].recv(1024)
# Clean up by closing both sockets
sock1.close()
sock2.close()
def test_picklepatch_deeply_nested():
"""Test that deep nesting with unpicklable objects works correctly.
"""
# Create a deeply nested structure with an unpicklable object
deep_nested = {
"level1": {
"level2": {
"level3": {
"normal": "value",
"socket": socket.socket(socket.AF_INET, socket.SOCK_STREAM)
}
}
}
}
dumped = PicklePatcher.dumps(deep_nested)
reloaded = PicklePatcher.loads(dumped)
# We should be able to access the normal value
assert reloaded["level1"]["level2"]["level3"]["normal"] == "value"
# The socket should be replaced with a placeholder
assert isinstance(reloaded["level1"]["level2"]["level3"]["socket"], PicklePlaceholder)
def test_picklepatch_class_with_unpicklable_attr():
"""Test that a class with an unpicklable attribute works correctly.
"""
class TestClass:
def __init__(self):
self.normal = "normal value"
self.unpicklable = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
obj = TestClass()
dumped = PicklePatcher.dumps(obj)
reloaded = PicklePatcher.loads(dumped)
# Normal attribute should be preserved
assert reloaded.normal == "normal value"
# Unpicklable attribute should be replaced with a placeholder
assert isinstance(reloaded.unpicklable, PicklePlaceholder)
def test_picklepatch_with_database_connection():
"""Test that a data structure containing a database connection is replaced
by PicklePlaceholder rather than raising an error.
"""
# SQLite connection - not pickleable
conn = sqlite3.connect(":memory:")
cursor = conn.cursor()
data_with_db = {
"description": "Database connection",
"connection": conn,
"cursor": cursor,
}
dumped = PicklePatcher.dumps(data_with_db)
reloaded = PicklePatcher.loads(dumped)
# Both connection and cursor should become placeholders
assert isinstance(reloaded, dict)
assert reloaded["description"] == "Database connection"
assert isinstance(reloaded["connection"], PicklePlaceholder)
assert isinstance(reloaded["cursor"], PicklePlaceholder)
    # Attempting to use attributes => PicklePlaceholderAccessError
with pytest.raises(PicklePlaceholderAccessError):
reloaded["connection"].execute("SELECT 1")
def test_picklepatch_with_generator():
"""Test that a data structure containing a generator is replaced by
PicklePlaceholder rather than raising an error.
"""
def simple_generator():
yield 1
yield 2
yield 3
# Create a generator
gen = simple_generator()
# Put it in a data structure
data_with_generator = {
"description": "Contains a generator",
"generator": gen,
"normal_list": [1, 2, 3]
}
dumped = PicklePatcher.dumps(data_with_generator)
reloaded = PicklePatcher.loads(dumped)
# Generator should be replaced with a placeholder
assert isinstance(reloaded, dict)
assert reloaded["description"] == "Contains a generator"
assert reloaded["normal_list"] == [1, 2, 3]
assert isinstance(reloaded["generator"], PicklePlaceholder)
    # Iterating the placeholder with next() => TypeError (it is not an iterator)
with pytest.raises(TypeError):
next(reloaded["generator"])
    # Attempting to call methods on the placeholder => PicklePlaceholderAccessError
with pytest.raises(PicklePlaceholderAccessError):
reloaded["generator"].send(None)
def test_picklepatch_loads_standard_pickle():
"""Test that PicklePatcher.loads can correctly load data that was pickled
using the standard pickle module.
"""
# Create a simple data structure
original_data = {
"numbers": [1, 2, 3],
"nested_dict": {"key": "value", "another": 42},
"tuple": (1, "two", 3.0),
}
# Pickle it with standard pickle
pickled_data = pickle.dumps(original_data)
# Load with PicklePatcher
reloaded = PicklePatcher.loads(pickled_data)
# Verify the data is correctly loaded
assert reloaded == original_data
assert isinstance(reloaded, dict)
assert reloaded["numbers"] == [1, 2, 3]
assert reloaded["nested_dict"]["key"] == "value"
assert reloaded["tuple"] == (1, "two", 3.0)
def test_picklepatch_loads_dill_pickle():
"""Test that PicklePatcher.loads can correctly load data that was pickled
using the dill module, which can pickle more complex objects than the
standard pickle module.
"""
# Create a more complex data structure that includes a lambda function
# which dill can handle but standard pickle cannot
original_data = {
"numbers": [1, 2, 3],
"function": lambda x: x * 2,
"nested": {
"another_function": lambda y: y ** 2
}
}
# Pickle it with dill
dilled_data = dill.dumps(original_data)
# Load with PicklePatcher
reloaded = PicklePatcher.loads(dilled_data)
# Verify the data structure
assert isinstance(reloaded, dict)
assert reloaded["numbers"] == [1, 2, 3]
# Test that the functions actually work
assert reloaded["function"](5) == 10
assert reloaded["nested"]["another_function"](4) == 16
def test_run_and_parse_picklepatch() -> None:
"""Test the end to end functionality of picklepatch, from tracing benchmarks to running the replay tests.
The first example has an argument (an object containing a socket) that is not pickleable However, the socket attributs is not used, so we are able to compare the test results with the optimized test results.
Here, we are simply 'ignoring' the unused unpickleable object.
The second example also has an argument (an object containing socket) that is not pickleable. The socket attribute is used, which results in an error thrown by the PicklePlaceholder object.
Both the original and optimized results should error out in this case, but this should be flagged as incorrect behavior when comparing test results,
since we were not able to reuse the unpickleable object in the replay test.
"""
# Init paths
project_root = Path(__file__).parent.parent.resolve()
tests_root = project_root / "code_to_optimize" / "tests" / "pytest"
benchmarks_root = project_root / "code_to_optimize" / "tests" / "pytest" / "benchmarks_socket_test"
replay_tests_dir = benchmarks_root / "codeflash_replay_tests"
output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
fto_unused_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_unused_socket.py").resolve()
fto_used_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_used_socket.py").resolve()
original_fto_unused_socket_code = fto_unused_socket_path.read_text("utf-8")
original_fto_used_socket_code = fto_used_socket_path.read_text("utf-8")
# Trace benchmarks
trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
assert output_file.exists()
try:
# Check contents
conn = sqlite3.connect(output_file.as_posix())
cursor = conn.cursor()
cursor.execute(
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
function_calls = cursor.fetchall()
# Assert the length of function calls
assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
assert "code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket" in function_to_results
test_name, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"][0]
assert total_time > 0.0
assert function_time > 0.0
assert percent > 0.0
        bubble_sort_unused_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_unused_socket.py").as_posix()
bubble_sort_used_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_used_socket.py").as_posix()
# Expected function calls
expected_calls = [
("bubble_sort_with_unused_socket", "", "code_to_optimize.bubble_sort_picklepatch_test_unused_socket",
f"{bubble_sort_unused_socket_path}",
"test_socket_picklepatch", "code_to_optimize.tests.pytest.benchmarks_socket_test.test_socket", 12),
("bubble_sort_with_used_socket", "", "code_to_optimize.bubble_sort_picklepatch_test_used_socket",
f"{bubble_sort_used_socket_path}",
"test_used_socket_picklepatch", "code_to_optimize.tests.pytest.benchmarks_socket_test.test_socket", 20)
]
for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name"
assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name"
assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path"
assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
conn.close()
# Generate replay test
generate_replay_test(output_file, replay_tests_dir)
replay_test_path = replay_tests_dir / Path(
"test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0.py")
replay_test_perf_path = replay_tests_dir / Path(
"test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0_perf.py")
assert replay_test_path.exists()
original_replay_test_code = replay_test_path.read_text("utf-8")
# Instrument the replay test
func = FunctionToOptimize(function_name="bubble_sort_with_unused_socket", parents=[], file_path=Path(fto_unused_socket_path))
original_cwd = Path.cwd()
run_cwd = project_root
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
replay_test_path,
[CodePosition(17, 15)],
func,
project_root,
"pytest",
mode=TestingMode.BEHAVIOR,
)
os.chdir(original_cwd)
assert success
assert new_test is not None
replay_test_path.write_text(new_test)
opt = Optimizer(
Namespace(
project_root=project_root,
disable_telemetry=True,
tests_root=tests_root,
test_framework="pytest",
pytest_cmd="pytest",
experiment_id=None,
test_project_root=project_root,
)
)
# Run the replay test for the original code that does not use the socket
test_env = os.environ.copy()
test_env["CODEFLASH_TEST_ITERATION"] = "0"
test_env["CODEFLASH_LOOP_INDEX"] = "1"
test_type = TestType.REPLAY_TEST
replay_test_function = "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket"
func_optimizer = opt.create_function_optimizer(func)
func_optimizer.test_files = TestFiles(
test_files=[
TestFile(
instrumented_behavior_file_path=replay_test_path,
test_type=test_type,
original_file_path=replay_test_path,
benchmarking_file_path=replay_test_perf_path,
tests_in_file=[TestsInFile(test_file=replay_test_path, test_class=None, test_function=replay_test_function, test_type=test_type)],
)
]
)
test_results_unused_socket, coverage_data = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
test_files=func_optimizer.test_files,
optimization_iteration=0,
pytest_min_loops=1,
pytest_max_loops=1,
testing_time=1.0,
)
assert len(test_results_unused_socket) == 1
assert test_results_unused_socket.test_results[0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
assert test_results_unused_socket.test_results[0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket"
        assert test_results_unused_socket.test_results[0].did_pass is True
# Replace with optimized candidate
fto_unused_socket_path.write_text("""
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@codeflash_trace
def bubble_sort_with_unused_socket(data_container):
# Extract the list to sort, leaving the socket untouched
numbers = data_container.get('numbers', []).copy()
return sorted(numbers)
""")
# Run optimized code for unused socket
optimized_test_results_unused_socket, coverage_data = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
test_files=func_optimizer.test_files,
optimization_iteration=0,
pytest_min_loops=1,
pytest_max_loops=1,
testing_time=1.0,
)
assert len(optimized_test_results_unused_socket) == 1
verification_result = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket)
assert verification_result is True
# Remove the previous instrumentation
replay_test_path.write_text(original_replay_test_code)
# Instrument the replay test
func = FunctionToOptimize(function_name="bubble_sort_with_used_socket", parents=[], file_path=Path(fto_used_socket_path))
        os.chdir(run_cwd)
        success, new_test = inject_profiling_into_existing_test(
replay_test_path,
[CodePosition(23,15)],
func,
project_root,
"pytest",
mode=TestingMode.BEHAVIOR,
)
os.chdir(original_cwd)
assert success
assert new_test is not None
replay_test_path.write_text(new_test)
# Run test for original function code that uses the socket. This should fail, as the PicklePlaceholder is accessed.
test_env = os.environ.copy()
test_env["CODEFLASH_TEST_ITERATION"] = "0"
test_env["CODEFLASH_LOOP_INDEX"] = "1"
test_type = TestType.REPLAY_TEST
func = FunctionToOptimize(function_name="bubble_sort_with_used_socket", parents=[],
file_path=Path(fto_used_socket_path))
replay_test_function = "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket"
func_optimizer = opt.create_function_optimizer(func)
func_optimizer.test_files = TestFiles(
test_files=[
TestFile(
instrumented_behavior_file_path=replay_test_path,
test_type=test_type,
original_file_path=replay_test_path,
benchmarking_file_path=replay_test_perf_path,
tests_in_file=[
TestsInFile(test_file=replay_test_path, test_class=None, test_function=replay_test_function,
test_type=test_type)],
)
]
)
test_results_used_socket, coverage_data = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
test_files=func_optimizer.test_files,
optimization_iteration=0,
pytest_min_loops=1,
pytest_max_loops=1,
testing_time=1.0,
)
assert len(test_results_used_socket) == 1
assert test_results_used_socket.test_results[
0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
assert test_results_used_socket.test_results[
0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket"
assert test_results_used_socket.test_results[0].did_pass is False
print("test results used socket")
print(test_results_used_socket)
# Replace with optimized candidate
fto_used_socket_path.write_text("""
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@codeflash_trace
def bubble_sort_with_used_socket(data_container):
    # Extract the list to sort, then use the socket (accessing the placeholder will error)
numbers = data_container.get('numbers', []).copy()
socket = data_container.get('socket')
socket.send("Hello from the optimized function!")
return sorted(numbers)
""")
# Run test for optimized function code that uses the socket. This should fail, as the PicklePlaceholder is accessed.
optimized_test_results_used_socket, coverage_data = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
test_files=func_optimizer.test_files,
optimization_iteration=0,
pytest_min_loops=1,
pytest_max_loops=1,
testing_time=1.0,
)
        assert len(optimized_test_results_used_socket) == 1
        assert optimized_test_results_used_socket.test_results[
                   0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
        assert optimized_test_results_used_socket.test_results[
                   0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket"
        assert optimized_test_results_used_socket.test_results[0].did_pass is False
# Even though tests threw the same error, we reject this as the behavior of the unpickleable object could not be determined.
assert compare_test_results(test_results_used_socket, optimized_test_results_used_socket) is False
finally:
# cleanup
output_file.unlink(missing_ok=True)
shutil.rmtree(replay_tests_dir, ignore_errors=True)
fto_unused_socket_path.write_text(original_fto_unused_socket_code)
fto_used_socket_path.write_text(original_fto_used_socket_code)

View file

@ -0,0 +1,288 @@
import multiprocessing
import shutil
import sqlite3
from pathlib import Path
import pytest
from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin
from codeflash.benchmarking.replay_test import generate_replay_test
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
from codeflash.benchmarking.utils import validate_and_format_benchmark_table
def test_trace_benchmarks() -> None:
# Test the trace_benchmarks function
project_root = Path(__file__).parent.parent / "code_to_optimize"
benchmarks_root = project_root / "tests" / "pytest" / "benchmarks_test"
replay_tests_dir = benchmarks_root / "codeflash_replay_tests"
tests_root = project_root / "tests"
output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
assert output_file.exists()
try:
# check contents of trace file
# connect to database
conn = sqlite3.connect(output_file.as_posix())
cursor = conn.cursor()
        # Get all records
cursor.execute(
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
function_calls = cursor.fetchall()
# Assert the length of function calls
assert len(function_calls) == 8, f"Expected 8 function calls, but got {len(function_calls)}"
bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
process_and_bubble_sort_path = (project_root / "process_and_bubble_sort_codeflash_trace.py").as_posix()
# Expected function calls
expected_calls = [
("sorter", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
f"{bubble_sort_path}",
"test_class_sort", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 17),
("sort_class", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
f"{bubble_sort_path}",
"test_class_sort2", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 20),
("sort_static", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
f"{bubble_sort_path}",
"test_class_sort3", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 23),
("__init__", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
f"{bubble_sort_path}",
"test_class_sort4", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 26),
("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
f"{bubble_sort_path}",
"test_sort", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 7),
("compute_and_sort", "", "code_to_optimize.process_and_bubble_sort_codeflash_trace",
f"{process_and_bubble_sort_path}",
"test_compute_and_sort", "tests.pytest.benchmarks_test.test_process_and_sort_example", 4),
("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
f"{bubble_sort_path}",
"test_no_func", "tests.pytest.benchmarks_test.test_process_and_sort_example", 8),
("recursive_bubble_sort", "", "code_to_optimize.bubble_sort_codeflash_trace",
f"{bubble_sort_path}",
"test_recursive_sort", "tests.pytest.benchmarks_test.test_recursive_example", 5),
]
for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name"
assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name"
assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path"
assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
# Close connection
conn.close()
generate_replay_test(output_file, replay_tests_dir)
        test_class_sort_path = replay_tests_dir / Path("test_tests_pytest_benchmarks_test_test_benchmark_bubble_sort_example__replay_test_0.py")
assert test_class_sort_path.exists()
test_class_sort_code = f"""
from code_to_optimize.bubble_sort_codeflash_trace import \\
Sorter as code_to_optimize_bubble_sort_codeflash_trace_Sorter
from code_to_optimize.bubble_sort_codeflash_trace import \\
sorter as code_to_optimize_bubble_sort_codeflash_trace_sorter
from codeflash.benchmarking.replay_test import get_next_arg_and_return
from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
functions = ['sort_class', 'sort_static', 'sorter']
trace_file_path = r"{output_file.as_posix()}"
def test_code_to_optimize_bubble_sort_codeflash_trace_sorter():
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_sort", function_name="sorter", file_path=r"{bubble_sort_path}", num_to_get=100):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl)
ret = code_to_optimize_bubble_sort_codeflash_trace_sorter(*args, **kwargs)
def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter_sorter():
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_class_sort", function_name="sorter", file_path=r"{bubble_sort_path}", class_name="Sorter", num_to_get=100):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl)
function_name = "sorter"
if not args:
raise ValueError("No arguments provided for the method.")
if function_name == "__init__":
ret = code_to_optimize_bubble_sort_codeflash_trace_Sorter(*args[1:], **kwargs)
else:
instance = args[0] # self
ret = instance.sorter(*args[1:], **kwargs)
def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter_sort_class():
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_class_sort2", function_name="sort_class", file_path=r"{bubble_sort_path}", class_name="Sorter", num_to_get=100):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl)
if not args:
raise ValueError("No arguments provided for the method.")
ret = code_to_optimize_bubble_sort_codeflash_trace_Sorter.sort_class(*args[1:], **kwargs)
def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter_sort_static():
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_class_sort3", function_name="sort_static", file_path=r"{bubble_sort_path}", class_name="Sorter", num_to_get=100):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl)
ret = code_to_optimize_bubble_sort_codeflash_trace_Sorter.sort_static(*args, **kwargs)
def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter___init__():
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_class_sort4", function_name="__init__", file_path=r"{bubble_sort_path}", class_name="Sorter", num_to_get=100):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl)
function_name = "__init__"
if not args:
raise ValueError("No arguments provided for the method.")
if function_name == "__init__":
ret = code_to_optimize_bubble_sort_codeflash_trace_Sorter(*args[1:], **kwargs)
else:
instance = args[0] # self
ret = instance(*args[1:], **kwargs)
"""
        assert test_class_sort_path.read_text("utf-8").strip() == test_class_sort_code.strip()
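        # The generated code above encodes the replay conventions: free functions are
        # called directly with the unpickled args; instance methods treat args[0] as
        # self; __init__ is replayed by re-constructing the class from args[1:]; static
        # and class methods are invoked on the class itself.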
test_sort_path = replay_tests_dir / Path("test_tests_pytest_benchmarks_test_test_process_and_sort_example__replay_test_0.py")
assert test_sort_path.exists()
test_sort_code = f"""
from code_to_optimize.bubble_sort_codeflash_trace import \\
sorter as code_to_optimize_bubble_sort_codeflash_trace_sorter
from code_to_optimize.process_and_bubble_sort_codeflash_trace import \\
compute_and_sort as \\
code_to_optimize_process_and_bubble_sort_codeflash_trace_compute_and_sort
from codeflash.benchmarking.replay_test import get_next_arg_and_return
from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
functions = ['compute_and_sort', 'sorter']
trace_file_path = r"{output_file}"
def test_code_to_optimize_process_and_bubble_sort_codeflash_trace_compute_and_sort():
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_compute_and_sort", function_name="compute_and_sort", file_path=r"{process_and_bubble_sort_path}", num_to_get=100):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl)
ret = code_to_optimize_process_and_bubble_sort_codeflash_trace_compute_and_sort(*args, **kwargs)
def test_code_to_optimize_bubble_sort_codeflash_trace_sorter():
for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_no_func", function_name="sorter", file_path=r"{bubble_sort_path}", num_to_get=100):
args = pickle.loads(args_pkl)
kwargs = pickle.loads(kwargs_pkl)
ret = code_to_optimize_bubble_sort_codeflash_trace_sorter(*args, **kwargs)
"""
        assert test_sort_path.read_text("utf-8").strip() == test_sort_code.strip()
finally:
# cleanup
output_file.unlink(missing_ok=True)
shutil.rmtree(replay_tests_dir)
# Skip this test in CI, as the runner may not have multiple CPU cores available
@pytest.mark.ci_skip
def test_trace_multithreaded_benchmark() -> None:
project_root = Path(__file__).parent.parent / "code_to_optimize"
benchmarks_root = project_root / "tests" / "pytest" / "benchmarks_multithread"
tests_root = project_root / "tests"
output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
assert output_file.exists()
try:
# check contents of trace file
# connect to database
conn = sqlite3.connect(output_file.as_posix())
cursor = conn.cursor()
        # Get all records
cursor.execute(
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
function_calls = cursor.fetchall()
# Assert the length of function calls
assert len(function_calls) == 10, f"Expected 10 function calls, but got {len(function_calls)}"
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results
test_name, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_codeflash_trace.sorter"][0]
assert total_time > 0.0
assert function_time > 0.0
assert percent > 0.0
bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
# Expected function calls
expected_calls = [
("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
f"{bubble_sort_path}",
"test_benchmark_sort", "tests.pytest.benchmarks_multithread.test_multithread_sort", 4),
]
for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name"
assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name"
assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path"
assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
# Close connection
conn.close()
finally:
# cleanup
output_file.unlink(missing_ok=True)
def test_trace_benchmark_decorator() -> None:
project_root = Path(__file__).parent.parent / "code_to_optimize"
benchmarks_root = project_root / "tests" / "pytest" / "benchmarks_test_decorator"
tests_root = project_root / "tests"
output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
assert output_file.exists()
try:
# check contents of trace file
# connect to database
conn = sqlite3.connect(output_file.as_posix())
cursor = conn.cursor()
        # Get all records
cursor.execute(
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
function_calls = cursor.fetchall()
# Assert the length of function calls
assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results
test_name, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_codeflash_trace.sorter"][0]
assert total_time > 0.0
assert function_time > 0.0
assert percent > 0.0
bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
# Expected function calls
expected_calls = [
("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
f"{bubble_sort_path}",
"test_benchmark_sort", "tests.pytest.benchmarks_test_decorator.test_benchmark_decorator", 5),
("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
f"{bubble_sort_path}",
"test_pytest_mark", "tests.pytest.benchmarks_test_decorator.test_benchmark_decorator", 11),
]
for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name"
assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name"
assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path"
assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
# Close connection
conn.close()
finally:
# cleanup
output_file.unlink(missing_ok=True)
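# For reference, a benchmark in the style these tests trace might look like the
# sketch below (hypothetical; the real files live under
# code_to_optimize/tests/pytest/benchmarks_test_decorator/ and use the standard
# pytest-benchmark fixture):
def _example_benchmark_sketch(benchmark):
    from code_to_optimize.bubble_sort_codeflash_trace import sorter
    benchmark(sorter, [5, 4, 3, 2, 1])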

View file

@ -17,7 +17,19 @@ def test_unit_test_discovery_pytest():
)
tests = discover_unit_tests(test_config)
assert len(tests) > 0
# print(tests)
def test_benchmark_test_discovery_pytest():
project_path = Path(__file__).parent.parent.resolve() / "code_to_optimize"
tests_path = project_path / "tests" / "pytest" / "benchmarks"
test_config = TestConfig(
tests_root=tests_path,
project_root_path=project_path,
test_framework="pytest",
tests_project_rootdir=tests_path.parent,
)
tests = discover_unit_tests(test_config)
assert len(tests) == 1 # Should not discover benchmark tests
def test_unit_test_discovery_unittest():