mirror of https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00

Merge branch 'main' into context-import-bug
commit 04c19bfd6a

53 changed files with 3754 additions and 70 deletions
2  .github/workflows/codeflash-optimize.yaml (vendored)

@@ -68,4 +68,4 @@ jobs:
       id: optimize_code
       run: |
         source .venv/bin/activate
-        poetry run codeflash
+        poetry run codeflash --benchmark
41  .github/workflows/end-to-end-test-benchmark-bubblesort.yaml (vendored, new file)

@@ -0,0 +1,41 @@
name: end-to-end-test

on:
  pull_request:
  workflow_dispatch:

jobs:
  benchmark-bubble-sort-optimization:
    runs-on: ubuntu-latest
    env:
      CODEFLASH_AIS_SERVER: prod
      POSTHOG_API_KEY: ${{ secrets.POSTHOG_API_KEY }}
      CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }}
      COLUMNS: 110
      MAX_RETRIES: 3
      RETRY_DELAY: 5
      EXPECTED_IMPROVEMENT_PCT: 5
      CODEFLASH_END_TO_END: 1
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Set up Python 3.11 for CLI
        uses: astral-sh/setup-uv@v5
        with:
          python-version: 3.11.6

      - name: Install dependencies (CLI)
        run: |
          uv tool install poetry
          uv venv
          source .venv/bin/activate
          poetry install --with dev

      - name: Run Codeflash to optimize code
        id: optimize_code_with_benchmarks
        run: |
          source .venv/bin/activate
          poetry run python tests/scripts/end_to_end_test_benchmark_sort.py
2  .github/workflows/unit-tests.yaml (vendored)

@@ -32,7 +32,7 @@ jobs:
        run: uvx poetry install --with dev

      - name: Unit tests
-       run: uvx poetry run pytest tests/ --cov --cov-report=xml
+       run: uvx poetry run pytest tests/ --cov --cov-report=xml --benchmark-skip -m "not ci_skip"

      - name: Upload coverage reports to Codecov
        uses: codecov/codecov-action@v5
@@ -7,4 +7,4 @@ def sorter(arr):
                arr[j] = arr[j + 1]
                arr[j + 1] = temp
-    print(f"result: {arr}")
     return arr
64  code_to_optimize/bubble_sort_codeflash_trace.py (new file)

@@ -0,0 +1,64 @@
from codeflash.benchmarking.codeflash_trace import codeflash_trace


@codeflash_trace
def sorter(arr):
    for i in range(len(arr)):
        for j in range(len(arr) - 1):
            if arr[j] > arr[j + 1]:
                temp = arr[j]
                arr[j] = arr[j + 1]
                arr[j + 1] = temp
    return arr


@codeflash_trace
def recursive_bubble_sort(arr, n=None):
    # Initialize n if not provided
    if n is None:
        n = len(arr)

    # Base case: if n is 1, the array is already sorted
    if n == 1:
        return arr

    # One pass of bubble sort - move the largest element to the end
    for i in range(n - 1):
        if arr[i] > arr[i + 1]:
            arr[i], arr[i + 1] = arr[i + 1], arr[i]

    # Recursively sort the remaining n-1 elements
    return recursive_bubble_sort(arr, n - 1)


class Sorter:
    @codeflash_trace
    def __init__(self, arr):
        self.arr = arr

    @codeflash_trace
    def sorter(self, multiplier):
        for i in range(len(self.arr)):
            for j in range(len(self.arr) - 1):
                if self.arr[j] > self.arr[j + 1]:
                    temp = self.arr[j]
                    self.arr[j] = self.arr[j + 1]
                    self.arr[j + 1] = temp
        return self.arr * multiplier

    @staticmethod
    @codeflash_trace
    def sort_static(arr):
        for i in range(len(arr)):
            for j in range(len(arr) - 1):
                if arr[j] > arr[j + 1]:
                    temp = arr[j]
                    arr[j] = arr[j + 1]
                    arr[j + 1] = temp
        return arr

    @classmethod
    @codeflash_trace
    def sort_class(cls, arr):
        for i in range(len(arr)):
            for j in range(len(arr) - 1):
                if arr[j] > arr[j + 1]:
                    temp = arr[j]
                    arr[j] = arr[j + 1]
                    arr[j + 1] = temp
        return arr
23  code_to_optimize/bubble_sort_multithread.py (new file)

@@ -0,0 +1,23 @@
# from code_to_optimize.bubble_sort_codeflash_trace import sorter
from code_to_optimize.bubble_sort_codeflash_trace import sorter
import concurrent.futures


def multithreaded_sorter(unsorted_lists: list[list[int]]) -> list[list[int]]:
    # Create a list to store results in the correct order
    sorted_lists = [None] * len(unsorted_lists)

    # Use ThreadPoolExecutor to manage threads
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        # Submit all sorting tasks and map them to their original indices
        future_to_index = {
            executor.submit(sorter, unsorted_list): i
            for i, unsorted_list in enumerate(unsorted_lists)
        }

        # Collect results as they complete
        for future in concurrent.futures.as_completed(future_to_index):
            index = future_to_index[future]
            sorted_lists[index] = future.result()

    return sorted_lists
@@ -0,0 +1,18 @@
from codeflash.benchmarking.codeflash_trace import codeflash_trace


@codeflash_trace
def bubble_sort_with_unused_socket(data_container):
    # Extract the list to sort, leaving the socket untouched
    numbers = data_container.get('numbers', []).copy()

    return sorted(numbers)


@codeflash_trace
def bubble_sort_with_used_socket(data_container):
    # Extract the list to sort, leaving the socket untouched
    numbers = data_container.get('numbers', []).copy()
    socket = data_container.get('socket')
    socket.send("Hello from the optimized function!")
    return sorted(numbers)
46  code_to_optimize/bubble_sort_picklepatch_test_used_socket.py (new file)

@@ -0,0 +1,46 @@
from codeflash.benchmarking.codeflash_trace import codeflash_trace


@codeflash_trace
def bubble_sort_with_used_socket(data_container):
    """
    Performs a bubble sort on a list within the data_container. The data container has the following schema:
    - 'numbers' (list): The list to be sorted.
    - 'socket' (socket): A socket

    Args:
        data_container: A dictionary with at least 'numbers' (list) and 'socket' keys

    Returns:
        list: The sorted list of numbers
    """
    # Extract the list to sort and socket
    numbers = data_container.get('numbers', []).copy()
    socket = data_container.get('socket')

    # Track swap count
    swap_count = 0

    # Classic bubble sort implementation
    n = len(numbers)
    for i in range(n):
        # Flag to optimize by detecting if no swaps occurred
        swapped = False

        # Last i elements are already in place
        for j in range(0, n - i - 1):
            # Swap if the element is greater than the next element
            if numbers[j] > numbers[j + 1]:
                # Perform the swap
                numbers[j], numbers[j + 1] = numbers[j + 1], numbers[j]
                swapped = True
                swap_count += 1

        # If no swapping occurred in this pass, the list is sorted
        if not swapped:
            break

    # Send final summary
    summary = f"Bubble sort completed with {swap_count} swaps"
    socket.send(summary.encode())

    return numbers
28  code_to_optimize/process_and_bubble_sort.py (new file)

@@ -0,0 +1,28 @@
from code_to_optimize.bubble_sort import sorter


def calculate_pairwise_products(arr):
    """
    Calculate the average of all pairwise products in the array.
    """
    sum_of_products = 0
    count = 0

    for i in range(len(arr)):
        for j in range(len(arr)):
            if i != j:
                sum_of_products += arr[i] * arr[j]
                count += 1

    # The average of all pairwise products
    return sum_of_products / count if count > 0 else 0


def compute_and_sort(arr):
    # Compute pairwise products average
    pairwise_average = calculate_pairwise_products(arr)

    # Call sorter function
    sorter(arr.copy())

    return pairwise_average
28  code_to_optimize/process_and_bubble_sort_codeflash_trace.py (new file)

@@ -0,0 +1,28 @@
from code_to_optimize.bubble_sort import sorter
from codeflash.benchmarking.codeflash_trace import codeflash_trace


def calculate_pairwise_products(arr):
    """
    Calculate the average of all pairwise products in the array.
    """
    sum_of_products = 0
    count = 0

    for i in range(len(arr)):
        for j in range(len(arr)):
            if i != j:
                sum_of_products += arr[i] * arr[j]
                count += 1

    # The average of all pairwise products
    return sum_of_products / count if count > 0 else 0


@codeflash_trace
def compute_and_sort(arr):
    # Compute pairwise products average
    pairwise_average = calculate_pairwise_products(arr)

    # Call sorter function
    sorter(arr.copy())

    return pairwise_average
@@ -0,0 +1,13 @@
import pytest

from code_to_optimize.bubble_sort import sorter


def test_sort(benchmark):
    result = benchmark(sorter, list(reversed(range(500))))
    assert result == list(range(500))


# This should not be picked up as a benchmark test
def test_sort2():
    result = sorter(list(reversed(range(500))))
    assert result == list(range(500))
@@ -0,0 +1,8 @@
from code_to_optimize.process_and_bubble_sort import compute_and_sort
from code_to_optimize.bubble_sort import sorter

def test_compute_and_sort(benchmark):
    result = benchmark(compute_and_sort, list(reversed(range(500))))
    assert result == 62208.5

def test_no_func(benchmark):
    benchmark(sorter, list(reversed(range(500))))
@@ -0,0 +1,4 @@
from code_to_optimize.bubble_sort_multithread import multithreaded_sorter

def test_benchmark_sort(benchmark):
    benchmark(multithreaded_sorter, [list(range(1000)) for i in range(10)])
@@ -0,0 +1,20 @@
import socket

from code_to_optimize.bubble_sort_picklepatch_test_unused_socket import bubble_sort_with_unused_socket
from code_to_optimize.bubble_sort_picklepatch_test_used_socket import bubble_sort_with_used_socket

def test_socket_picklepatch(benchmark):
    s1, s2 = socket.socketpair()
    data = {
        "numbers": list(reversed(range(500))),
        "socket": s1
    }
    benchmark(bubble_sort_with_unused_socket, data)

def test_used_socket_picklepatch(benchmark):
    s1, s2 = socket.socketpair()
    data = {
        "numbers": list(reversed(range(500))),
        "socket": s1
    }
    benchmark(bubble_sort_with_used_socket, data)
@@ -0,0 +1,26 @@
import pytest

from code_to_optimize.bubble_sort_codeflash_trace import sorter, Sorter


def test_sort(benchmark):
    result = benchmark(sorter, list(reversed(range(500))))
    assert result == list(range(500))

# This should not be picked up as a benchmark test
def test_sort2():
    result = sorter(list(reversed(range(500))))
    assert result == list(range(500))

def test_class_sort(benchmark):
    obj = Sorter(list(reversed(range(100))))
    result1 = benchmark(obj.sorter, 2)

def test_class_sort2(benchmark):
    result2 = benchmark(Sorter.sort_class, list(reversed(range(100))))

def test_class_sort3(benchmark):
    result3 = benchmark(Sorter.sort_static, list(reversed(range(100))))

def test_class_sort4(benchmark):
    result4 = benchmark(Sorter, [1, 2, 3])
@@ -0,0 +1,8 @@
from code_to_optimize.process_and_bubble_sort_codeflash_trace import compute_and_sort
from code_to_optimize.bubble_sort_codeflash_trace import sorter

def test_compute_and_sort(benchmark):
    result = benchmark(compute_and_sort, list(reversed(range(500))))
    assert result == 62208.5

def test_no_func(benchmark):
    benchmark(sorter, list(reversed(range(500))))
@@ -0,0 +1,6 @@
from code_to_optimize.bubble_sort_codeflash_trace import recursive_bubble_sort


def test_recursive_sort(benchmark):
    result = benchmark(recursive_bubble_sort, list(reversed(range(500))))
    assert result == list(range(500))
@@ -0,0 +1,11 @@
import pytest
from code_to_optimize.bubble_sort_codeflash_trace import sorter

def test_benchmark_sort(benchmark):
    @benchmark
    def do_sort():
        sorter(list(reversed(range(500))))

@pytest.mark.benchmark(group="benchmark_decorator")
def test_pytest_mark(benchmark):
    benchmark(sorter, list(reversed(range(500))))
0  codeflash/benchmarking/__init__.py (new file)
179  codeflash/benchmarking/codeflash_trace.py (new file)

@@ -0,0 +1,179 @@
import functools
import os
import pickle
import sqlite3
import threading
import time
from typing import Callable

from codeflash.picklepatch.pickle_patcher import PicklePatcher


class CodeflashTrace:
    """Decorator class that traces and profiles function execution."""

    def __init__(self) -> None:
        self.function_calls_data = []
        self.function_call_count = 0
        self.pickle_count_limit = 1000
        self._connection = None
        self._trace_path = None
        self._thread_local = threading.local()
        self._thread_local.active_functions = set()

    def setup(self, trace_path: str) -> None:
        """Set up the database connection for direct writing.

        Args:
            trace_path: Path to the trace database file

        """
        try:
            self._trace_path = trace_path
            self._connection = sqlite3.connect(self._trace_path)
            cur = self._connection.cursor()
            cur.execute("PRAGMA synchronous = OFF")
            cur.execute("PRAGMA journal_mode = MEMORY")
            cur.execute(
                "CREATE TABLE IF NOT EXISTS benchmark_function_timings("
                "function_name TEXT, class_name TEXT, module_name TEXT, file_path TEXT,"
                "benchmark_function_name TEXT, benchmark_module_path TEXT, benchmark_line_number INTEGER,"
                "function_time_ns INTEGER, overhead_time_ns INTEGER, args BLOB, kwargs BLOB)"
            )
            self._connection.commit()
        except Exception as e:
            print(f"Database setup error: {e}")
            if self._connection:
                self._connection.close()
                self._connection = None
            raise

    def write_function_timings(self) -> None:
        """Write buffered function call data directly to the database."""
        if not self.function_calls_data:
            return  # No data to write

        if self._connection is None and self._trace_path is not None:
            self._connection = sqlite3.connect(self._trace_path)

        try:
            cur = self._connection.cursor()
            # Insert data into the benchmark_function_timings table
            cur.executemany(
                "INSERT INTO benchmark_function_timings"
                "(function_name, class_name, module_name, file_path, benchmark_function_name, "
                "benchmark_module_path, benchmark_line_number, function_time_ns, overhead_time_ns, args, kwargs) "
                "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
                self.function_calls_data
            )
            self._connection.commit()
            self.function_calls_data = []
        except Exception as e:
            print(f"Error writing to function timings database: {e}")
            if self._connection:
                self._connection.rollback()
            raise

    def open(self) -> None:
        """Open the database connection."""
        if self._connection is None:
            self._connection = sqlite3.connect(self._trace_path)

    def close(self) -> None:
        """Close the database connection."""
        if self._connection:
            self._connection.close()
            self._connection = None

    def __call__(self, func: Callable) -> Callable:
        """Use as a decorator to trace function execution.

        Args:
            func: The function to be decorated

        Returns:
            The wrapped function

        """
        func_id = (func.__module__, func.__name__)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Initialize thread-local active functions set if it doesn't exist
            if not hasattr(self._thread_local, "active_functions"):
                self._thread_local.active_functions = set()
            # If it's in a recursive function, just return the result
            if func_id in self._thread_local.active_functions:
                return func(*args, **kwargs)
            # Track active functions so we can detect recursive functions
            self._thread_local.active_functions.add(func_id)
            # Measure execution time
            start_time = time.thread_time_ns()
            result = func(*args, **kwargs)
            end_time = time.thread_time_ns()
            # Calculate execution time
            execution_time = end_time - start_time
            self.function_call_count += 1

            # Check if currently in pytest benchmark fixture
            if os.environ.get("CODEFLASH_BENCHMARKING", "False") == "False":
                self._thread_local.active_functions.remove(func_id)
                return result
            # Get benchmark info from environment
            benchmark_function_name = os.environ.get("CODEFLASH_BENCHMARK_FUNCTION_NAME", "")
            benchmark_module_path = os.environ.get("CODEFLASH_BENCHMARK_MODULE_PATH", "")
            benchmark_line_number = os.environ.get("CODEFLASH_BENCHMARK_LINE_NUMBER", "")
            # Get class name
            class_name = ""
            qualname = func.__qualname__
            if "." in qualname:
                class_name = qualname.split(".")[0]

            # Limit pickle count so memory does not explode
            if self.function_call_count > self.pickle_count_limit:
                print("Pickle limit reached")
                self._thread_local.active_functions.remove(func_id)
                overhead_time = time.thread_time_ns() - end_time
                self.function_calls_data.append(
                    (func.__name__, class_name, func.__module__, func.__code__.co_filename,
                     benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
                     overhead_time, None, None)
                )
                return result

            try:
                # Pickle the arguments
                pickled_args = PicklePatcher.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
                pickled_kwargs = PicklePatcher.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
            except Exception as e:
                print(f"Error pickling arguments for function {func.__name__}: {e}")
                # Add to the list of function calls without pickled args. Used for timing info only
                self._thread_local.active_functions.remove(func_id)
                overhead_time = time.thread_time_ns() - end_time
                self.function_calls_data.append(
                    (func.__name__, class_name, func.__module__, func.__code__.co_filename,
                     benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
                     overhead_time, None, None)
                )
                return result
            # Flush to database every 100 calls
            if len(self.function_calls_data) > 100:
                self.write_function_timings()

            # Add to the list of function calls with pickled args, to be used for replay tests
            self._thread_local.active_functions.remove(func_id)
            overhead_time = time.thread_time_ns() - end_time
            self.function_calls_data.append(
                (func.__name__, class_name, func.__module__, func.__code__.co_filename,
                 benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
                 overhead_time, pickled_args, pickled_kwargs)
            )
            return result

        return wrapper


# Create a singleton instance
codeflash_trace = CodeflashTrace()
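Note: a minimal usage sketch of the codeflash_trace singleton (not part of the commit; the trace path and sample function are assumptions):

import os

from codeflash.benchmarking.codeflash_trace import codeflash_trace

codeflash_trace.setup("my_trace.sqlite")  # hypothetical path; creates the benchmark_function_timings table

@codeflash_trace
def add(a, b):
    return a + b

os.environ["CODEFLASH_BENCHMARKING"] = "True"  # rows are only buffered while this is "True"
add(1, 2)
codeflash_trace.write_function_timings()  # flush the buffered rows to SQLite
codeflash_trace.close()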
117  codeflash/benchmarking/instrument_codeflash_trace.py (new file)

@@ -0,0 +1,117 @@
from pathlib import Path

import isort
import libcst as cst

from codeflash.discovery.functions_to_optimize import FunctionToOptimize


class AddDecoratorTransformer(cst.CSTTransformer):
    def __init__(self, target_functions: set[tuple[str, str]]) -> None:
        super().__init__()
        self.target_functions = target_functions
        self.added_codeflash_trace = False
        self.class_name = ""
        self.function_name = ""
        self.decorator = cst.Decorator(
            decorator=cst.Name(value="codeflash_trace")
        )

    def leave_ClassDef(self, original_node, updated_node):
        if self.class_name == original_node.name.value:
            self.class_name = ""  # Even if nested classes are not visited, this function is still called on them
        return updated_node

    def visit_ClassDef(self, node):
        if self.class_name:  # Don't go into nested class
            return False
        self.class_name = node.name.value

    def visit_FunctionDef(self, node):
        if self.function_name:  # Don't go into nested function
            return False
        self.function_name = node.name.value

    def leave_FunctionDef(self, original_node, updated_node):
        if self.function_name == original_node.name.value:
            self.function_name = ""
        if (self.class_name, original_node.name.value) in self.target_functions:
            # Add the new decorator after any existing decorators, so it gets executed first
            updated_decorators = list(updated_node.decorators) + [self.decorator]
            self.added_codeflash_trace = True
            return updated_node.with_changes(
                decorators=updated_decorators
            )

        return updated_node

    def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
        # Create import statement for codeflash_trace
        if not self.added_codeflash_trace:
            return updated_node
        import_stmt = cst.SimpleStatementLine(
            body=[
                cst.ImportFrom(
                    module=cst.Attribute(
                        value=cst.Attribute(
                            value=cst.Name(value="codeflash"),
                            attr=cst.Name(value="benchmarking")
                        ),
                        attr=cst.Name(value="codeflash_trace")
                    ),
                    names=[
                        cst.ImportAlias(
                            name=cst.Name(value="codeflash_trace")
                        )
                    ]
                )
            ]
        )

        # Insert at the beginning of the file. We'll use isort later to sort the imports.
        new_body = [import_stmt, *list(updated_node.body)]

        return updated_node.with_changes(body=new_body)


def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[FunctionToOptimize]) -> str:
    """Add the codeflash_trace decorator to the functions to optimize.

    Args:
        code: The source code as a string
        functions_to_optimize: The FunctionToOptimize instances containing function details

    Returns:
        The modified source code as a string

    """
    target_functions = set()
    for function_to_optimize in functions_to_optimize:
        class_name = ""
        if len(function_to_optimize.parents) == 1 and function_to_optimize.parents[0].type == "ClassDef":
            class_name = function_to_optimize.parents[0].name
        target_functions.add((class_name, function_to_optimize.function_name))

    transformer = AddDecoratorTransformer(
        target_functions=target_functions,
    )

    module = cst.parse_module(code)
    modified_module = module.visit(transformer)
    return modified_module.code


def instrument_codeflash_trace_decorator(
    file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]
) -> None:
    """Instrument the codeflash_trace decorator onto the functions to optimize."""
    for file_path, functions_to_optimize in file_to_funcs_to_optimize.items():
        original_code = file_path.read_text(encoding="utf-8")
        new_code = add_codeflash_decorator_to_code(
            original_code,
            functions_to_optimize
        )
        # Sort the imports in the modified code
        modified_code = isort.code(code=new_code, float_to_top=True)

        # Write the modified code back to the file
        file_path.write_text(modified_code, encoding="utf-8")
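Note: a before/after illustration of what add_codeflash_decorator_to_code does to a target function (hypothetical input module; not part of the commit):

# Input module, with ("", "sorter") in target_functions:
#     def sorter(arr):
#         ...
#
# Output after the transform plus the isort pass in instrument_codeflash_trace_decorator:
#     from codeflash.benchmarking.codeflash_trace import codeflash_trace
#
#     @codeflash_trace
#     def sorter(arr):
#         ...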
0  codeflash/benchmarking/plugin/__init__.py (new file)
293  codeflash/benchmarking/plugin/plugin.py (new file)

@@ -0,0 +1,293 @@
from __future__ import annotations

import os
import sqlite3
import sys
import time
from pathlib import Path

import pytest

from codeflash.benchmarking.codeflash_trace import codeflash_trace
from codeflash.code_utils.code_utils import module_name_from_file_path
from codeflash.models.models import BenchmarkKey


class CodeFlashBenchmarkPlugin:
    def __init__(self) -> None:
        self._trace_path = None
        self._connection = None
        self.project_root = None
        self.benchmark_timings = []

    def setup(self, trace_path: str, project_root: str) -> None:
        try:
            # Open connection
            self.project_root = project_root
            self._trace_path = trace_path
            self._connection = sqlite3.connect(self._trace_path)
            cur = self._connection.cursor()
            cur.execute("PRAGMA synchronous = OFF")
            cur.execute("PRAGMA journal_mode = MEMORY")
            cur.execute(
                "CREATE TABLE IF NOT EXISTS benchmark_timings("
                "benchmark_module_path TEXT, benchmark_function_name TEXT, benchmark_line_number INTEGER,"
                "benchmark_time_ns INTEGER)"
            )
            self._connection.commit()
            self.close()  # Reopen only at the end of the pytest session
        except Exception as e:
            print(f"Database setup error: {e}")
            if self._connection:
                self._connection.close()
                self._connection = None
            raise

    def write_benchmark_timings(self) -> None:
        if not self.benchmark_timings:
            return  # No data to write

        if self._connection is None:
            self._connection = sqlite3.connect(self._trace_path)

        try:
            cur = self._connection.cursor()
            # Insert data into the benchmark_timings table
            cur.executemany(
                "INSERT INTO benchmark_timings (benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)",
                self.benchmark_timings
            )
            self._connection.commit()
            self.benchmark_timings = []  # Clear the benchmark timings list
        except Exception as e:
            print(f"Error writing to benchmark timings database: {e}")
            self._connection.rollback()
            raise

    def close(self) -> None:
        if self._connection:
            self._connection.close()
            self._connection = None

    @staticmethod
    def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[BenchmarkKey, int]]:
        """Process the trace file and extract timing data for all functions.

        Args:
            trace_path: Path to the trace file

        Returns:
            A nested dictionary where:
            - Outer keys are the qualified function name (module.class.function)
            - Inner keys are of type BenchmarkKey
            - Values are function timings in nanoseconds

        """
        # Initialize the result dictionary
        result = {}

        # Connect to the SQLite database
        connection = sqlite3.connect(trace_path)
        cursor = connection.cursor()

        try:
            # Query the benchmark_function_timings table for all function calls
            cursor.execute(
                "SELECT module_name, class_name, function_name, "
                "benchmark_module_path, benchmark_function_name, benchmark_line_number, function_time_ns "
                "FROM benchmark_function_timings"
            )

            # Process each row
            for row in cursor.fetchall():
                module_name, class_name, function_name, benchmark_file, benchmark_func, benchmark_line, time_ns = row

                # Create the function key (module_name.class_name.function_name)
                if class_name:
                    qualified_name = f"{module_name}.{class_name}.{function_name}"
                else:
                    qualified_name = f"{module_name}.{function_name}"

                # Create the benchmark key
                benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
                # Initialize the inner dictionary if needed
                if qualified_name not in result:
                    result[qualified_name] = {}

                # If there are multiple calls to the same function in the same benchmark,
                # add the times together
                if benchmark_key in result[qualified_name]:
                    result[qualified_name][benchmark_key] += time_ns
                else:
                    result[qualified_name][benchmark_key] = time_ns

        finally:
            # Close the connection
            connection.close()

        return result

    @staticmethod
    def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
        """Extract total benchmark timings from trace files.

        Args:
            trace_path: Path to the trace file

        Returns:
            A dictionary mapping where:
            - Keys are of type BenchmarkKey
            - Values are total benchmark timings in nanoseconds (with tracing overhead subtracted)

        """
        # Initialize the result dictionary
        result = {}
        overhead_by_benchmark = {}

        # Connect to the SQLite database
        connection = sqlite3.connect(trace_path)
        cursor = connection.cursor()

        try:
            # Query the benchmark_function_timings table to get total overhead for each benchmark
            cursor.execute(
                "SELECT benchmark_module_path, benchmark_function_name, benchmark_line_number, SUM(overhead_time_ns) "
                "FROM benchmark_function_timings "
                "GROUP BY benchmark_module_path, benchmark_function_name, benchmark_line_number"
            )

            # Process overhead information
            for row in cursor.fetchall():
                benchmark_file, benchmark_func, benchmark_line, total_overhead_ns = row
                benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
                overhead_by_benchmark[benchmark_key] = total_overhead_ns or 0  # Handle NULL sum case

            # Query the benchmark_timings table for total times
            cursor.execute(
                "SELECT benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns "
                "FROM benchmark_timings"
            )

            # Process each row and subtract overhead
            for row in cursor.fetchall():
                benchmark_file, benchmark_func, benchmark_line, time_ns = row

                # Create the benchmark key
                benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
                # Subtract overhead from total time
                overhead = overhead_by_benchmark.get(benchmark_key, 0)
                result[benchmark_key] = time_ns - overhead

        finally:
            # Close the connection
            connection.close()

        return result

    # Pytest hooks
    @pytest.hookimpl
    def pytest_sessionfinish(self, session, exitstatus):
        """Execute after the whole test run is completed."""
        # Write any remaining benchmark timings to the database
        codeflash_trace.close()
        if self.benchmark_timings:
            self.write_benchmark_timings()
        # Close the database connection
        self.close()

    @staticmethod
    def pytest_addoption(parser):
        parser.addoption(
            "--codeflash-trace",
            action="store_true",
            default=False,
            help="Enable CodeFlash tracing"
        )

    @staticmethod
    def pytest_plugin_registered(plugin, manager):
        # Not necessary since run with -p no:benchmark, but just in case
        if hasattr(plugin, "name") and plugin.name == "pytest-benchmark":
            manager.unregister(plugin)

    @staticmethod
    def pytest_configure(config):
        """Register the benchmark marker."""
        config.addinivalue_line(
            "markers",
            "benchmark: mark test as a benchmark that should be run with codeflash tracing"
        )

    @staticmethod
    def pytest_collection_modifyitems(config, items):
        # Skip tests that don't have the benchmark fixture
        if not config.getoption("--codeflash-trace"):
            return

        skip_no_benchmark = pytest.mark.skip(reason="Test requires benchmark fixture")
        for item in items:
            # Check for direct benchmark fixture usage
            has_fixture = hasattr(item, "fixturenames") and "benchmark" in item.fixturenames

            # Check for @pytest.mark.benchmark marker
            has_marker = False
            if hasattr(item, "get_closest_marker"):
                marker = item.get_closest_marker("benchmark")
                if marker is not None:
                    has_marker = True

            # Skip if neither fixture nor marker is present
            if not (has_fixture or has_marker):
                item.add_marker(skip_no_benchmark)

    # Benchmark fixture
    class Benchmark:
        def __init__(self, request):
            self.request = request

        def __call__(self, func, *args, **kwargs):
            """Handle both direct function calls and decorator usage."""
            if args or kwargs:
                # Used as benchmark(func, *args, **kwargs)
                return self._run_benchmark(func, *args, **kwargs)

            # Used as @benchmark decorator
            def wrapped_func(*args, **kwargs):
                return func(*args, **kwargs)

            result = self._run_benchmark(func)
            return wrapped_func

        def _run_benchmark(self, func, *args, **kwargs):
            """Actual benchmark implementation."""
            benchmark_module_path = module_name_from_file_path(Path(str(self.request.node.fspath)),
                                                               Path(codeflash_benchmark_plugin.project_root))
            benchmark_function_name = self.request.node.name
            line_number = int(str(sys._getframe(2).f_lineno))  # 2 frames up in the call stack
            # Set env vars
            os.environ["CODEFLASH_BENCHMARK_FUNCTION_NAME"] = benchmark_function_name
            os.environ["CODEFLASH_BENCHMARK_MODULE_PATH"] = benchmark_module_path
            os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number)
            os.environ["CODEFLASH_BENCHMARKING"] = "True"
            # Run the function
            start = time.time_ns()
            result = func(*args, **kwargs)
            end = time.time_ns()
            # Reset the environment variable
            os.environ["CODEFLASH_BENCHMARKING"] = "False"

            # Write function calls
            codeflash_trace.write_function_timings()
            # Reset function call count
            codeflash_trace.function_call_count = 0
            # Add to the benchmark timings buffer
            codeflash_benchmark_plugin.benchmark_timings.append(
                (benchmark_module_path, benchmark_function_name, line_number, end - start))

            return result

    @staticmethod
    @pytest.fixture
    def benchmark(request):
        if not request.config.getoption("--codeflash-trace"):
            return None

        return CodeFlashBenchmarkPlugin.Benchmark(request)


codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin()
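Note: a sketch of reading a finished trace back with the plugin's static helpers (not part of the commit; "trace.sqlite" is a hypothetical path):

from pathlib import Path

from codeflash.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin

func_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(Path("trace.sqlite"))
total_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(Path("trace.sqlite"))
# For a traced function and one of its benchmarks, the fraction of benchmark time spent
# inside the function is func_timings[qualified_name][key] / total_timings[key].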
@@ -0,0 +1,25 @@
import sys
from pathlib import Path

from codeflash.benchmarking.codeflash_trace import codeflash_trace
from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin

benchmarks_root = sys.argv[1]
tests_root = sys.argv[2]
trace_file = sys.argv[3]
# current working directory
project_root = Path.cwd()
if __name__ == "__main__":
    import pytest

    try:
        codeflash_benchmark_plugin.setup(trace_file, project_root)
        codeflash_trace.setup(trace_file)
        exitcode = pytest.main(
            [benchmarks_root, "--codeflash-trace", "-p", "no:benchmark", "-p", "no:codspeed", "-p", "no:cov-s", "-o", "addopts="], plugins=[codeflash_benchmark_plugin]
        )  # Errors will be printed to stdout, not stderr
    except Exception as e:
        print(f"Failed to collect tests: {e!s}", file=sys.stderr)
        exitcode = -1
    sys.exit(exitcode)
286  codeflash/benchmarking/replay_test.py (new file)

@@ -0,0 +1,286 @@
from __future__ import annotations

import sqlite3
import textwrap
from pathlib import Path
from typing import TYPE_CHECKING, Any

import isort

from codeflash.cli_cmds.console import logger
from codeflash.discovery.functions_to_optimize import inspect_top_level_functions_or_methods
from codeflash.verification.verification_utils import get_test_file_path

if TYPE_CHECKING:
    from collections.abc import Generator


def get_next_arg_and_return(
    trace_file: str, benchmark_function_name: str, function_name: str, file_path: str, class_name: str | None = None, num_to_get: int = 256
) -> Generator[Any]:
    db = sqlite3.connect(trace_file)
    cur = db.cursor()
    limit = num_to_get

    if class_name is not None:
        cursor = cur.execute(
            "SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = ? LIMIT ?",
            (benchmark_function_name, function_name, file_path, class_name, limit),
        )
    else:
        cursor = cur.execute(
            "SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = '' LIMIT ?",
            (benchmark_function_name, function_name, file_path, limit),
        )

    while (val := cursor.fetchone()) is not None:
        yield val[9], val[10]  # pickled_args, pickled_kwargs


def get_function_alias(module: str, function_name: str) -> str:
    return "_".join(module.split(".")) + "_" + function_name


def create_trace_replay_test_code(
    trace_file: str,
    functions_data: list[dict[str, Any]],
    test_framework: str = "pytest",
    max_run_count=256
) -> str:
    """Create a replay test for functions based on trace data.

    Args:
        trace_file: Path to the SQLite database file
        functions_data: List of dictionaries with function info extracted from the DB
        test_framework: 'pytest' or 'unittest'
        max_run_count: Maximum number of runs to include in the test

    Returns:
        A string containing the test code

    """
    assert test_framework in ["pytest", "unittest"]

    # Create imports
    imports = f"""from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
{"import unittest" if test_framework == "unittest" else ""}
from codeflash.benchmarking.replay_test import get_next_arg_and_return
"""

    function_imports = []
    for func in functions_data:
        module_name = func.get("module_name")
        function_name = func.get("function_name")
        class_name = func.get("class_name", "")
        if class_name:
            function_imports.append(
                f"from {module_name} import {class_name} as {get_function_alias(module_name, class_name)}"
            )
        else:
            function_imports.append(
                f"from {module_name} import {function_name} as {get_function_alias(module_name, function_name)}"
            )

    imports += "\n".join(function_imports)

    functions_to_optimize = sorted({func.get("function_name") for func in functions_data
                                    if func.get("function_name") != "__init__"})
    metadata = f"""functions = {functions_to_optimize}
trace_file_path = r"{trace_file}"
"""
    # Templates for different types of tests
    test_function_body = textwrap.dedent(
        """\
        for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", num_to_get={max_run_count}):
            args = pickle.loads(args_pkl)
            kwargs = pickle.loads(kwargs_pkl)
            ret = {function_name}(*args, **kwargs)
        """
    )

    test_method_body = textwrap.dedent(
        """\
        for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
            args = pickle.loads(args_pkl)
            kwargs = pickle.loads(kwargs_pkl){filter_variables}
            function_name = "{orig_function_name}"
            if not args:
                raise ValueError("No arguments provided for the method.")
            if function_name == "__init__":
                ret = {class_name_alias}(*args[1:], **kwargs)
            else:
                instance = args[0]  # self
                ret = instance{method_name}(*args[1:], **kwargs)
        """)

    test_class_method_body = textwrap.dedent(
        """\
        for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
            args = pickle.loads(args_pkl)
            kwargs = pickle.loads(kwargs_pkl){filter_variables}
            if not args:
                raise ValueError("No arguments provided for the method.")
            ret = {class_name_alias}{method_name}(*args[1:], **kwargs)
        """
    )
    test_static_method_body = textwrap.dedent(
        """\
        for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
            args = pickle.loads(args_pkl)
            kwargs = pickle.loads(kwargs_pkl){filter_variables}
            ret = {class_name_alias}{method_name}(*args, **kwargs)
        """
    )

    # Create main body
    if test_framework == "unittest":
        self = "self"
        test_template = "\nclass TestTracedFunctions(unittest.TestCase):\n"
    else:
        test_template = ""
        self = ""

    for func in functions_data:
        module_name = func.get("module_name")
        function_name = func.get("function_name")
        class_name = func.get("class_name")
        file_path = func.get("file_path")
        benchmark_function_name = func.get("benchmark_function_name")
        function_properties = func.get("function_properties")
        if not class_name:
            alias = get_function_alias(module_name, function_name)
            test_body = test_function_body.format(
                benchmark_function_name=benchmark_function_name,
                orig_function_name=function_name,
                function_name=alias,
                file_path=file_path,
                max_run_count=max_run_count,
            )
        else:
            class_name_alias = get_function_alias(module_name, class_name)
            alias = get_function_alias(module_name, class_name + "_" + function_name)

            filter_variables = ""
            # filter_variables = '\n    args.pop("cls", None)'
            method_name = "." + function_name if function_name != "__init__" else ""
            if function_properties.is_classmethod:
                test_body = test_class_method_body.format(
                    benchmark_function_name=benchmark_function_name,
                    orig_function_name=function_name,
                    file_path=file_path,
                    class_name_alias=class_name_alias,
                    class_name=class_name,
                    method_name=method_name,
                    max_run_count=max_run_count,
                    filter_variables=filter_variables,
                )
            elif function_properties.is_staticmethod:
                test_body = test_static_method_body.format(
                    benchmark_function_name=benchmark_function_name,
                    orig_function_name=function_name,
                    file_path=file_path,
                    class_name_alias=class_name_alias,
                    class_name=class_name,
                    method_name=method_name,
                    max_run_count=max_run_count,
                    filter_variables=filter_variables,
                )
            else:
                test_body = test_method_body.format(
                    benchmark_function_name=benchmark_function_name,
                    orig_function_name=function_name,
                    file_path=file_path,
                    class_name_alias=class_name_alias,
                    class_name=class_name,
                    method_name=method_name,
                    max_run_count=max_run_count,
                    filter_variables=filter_variables,
                )

        formatted_test_body = textwrap.indent(test_body, "        " if test_framework == "unittest" else "    ")

        test_template += "    " if test_framework == "unittest" else ""
        test_template += f"def test_{alias}({self}):\n{formatted_test_body}\n"

    return imports + "\n" + metadata + "\n" + test_template


def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework: str = "pytest", max_run_count: int = 100) -> int:
    """Generate multiple replay tests from the traced function calls, grouped by benchmark.

    Args:
        trace_file_path: Path to the SQLite database file
        output_dir: Directory to write the generated tests
        test_framework: 'pytest' or 'unittest'
        max_run_count: Maximum number of runs to include per function

    Returns:
        The number of replay test files generated

    """
    count = 0
    try:
        # Connect to the database
        conn = sqlite3.connect(trace_file_path.as_posix())
        cursor = conn.cursor()

        # Get distinct benchmark file paths
        cursor.execute(
            "SELECT DISTINCT benchmark_module_path FROM benchmark_function_timings"
        )
        benchmark_files = cursor.fetchall()

        # Generate a test for each benchmark file
        for benchmark_file in benchmark_files:
            benchmark_module_path = benchmark_file[0]
            # Get all benchmarks and functions associated with this file path
            cursor.execute(
                "SELECT DISTINCT benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number FROM benchmark_function_timings "
                "WHERE benchmark_module_path = ?",
                (benchmark_module_path,)
            )

            functions_data = []
            for row in cursor.fetchall():
                benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number = row
                # Add this function to our list
                functions_data.append({
                    "function_name": function_name,
                    "class_name": class_name,
                    "file_path": file_path,
                    "module_name": module_name,
                    "benchmark_function_name": benchmark_function_name,
                    "benchmark_module_path": benchmark_module_path,
                    "benchmark_line_number": benchmark_line_number,
                    "function_properties": inspect_top_level_functions_or_methods(
                        file_name=Path(file_path),
                        function_or_method_name=function_name,
                        class_name=class_name,
                    )
                })

            if not functions_data:
                logger.info(f"No benchmark test functions found in {benchmark_module_path}")
                continue
            # Generate the test code for this benchmark
            test_code = create_trace_replay_test_code(
                trace_file=trace_file_path.as_posix(),
                functions_data=functions_data,
                test_framework=test_framework,
                max_run_count=max_run_count,
            )
            test_code = isort.code(test_code)
            output_file = get_test_file_path(
                test_dir=Path(output_dir), function_name=benchmark_module_path, test_type="replay"
            )
            # Write test code to file, creating parent directories if needed
            output_dir.mkdir(parents=True, exist_ok=True)
            output_file.write_text(test_code, "utf-8")
            count += 1

        conn.close()
    except Exception as e:
        logger.info(f"Error generating replay tests: {e}")

    return count
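Note: a sketch of driving the replay-test generator directly (not part of the commit; the paths are assumptions):

from pathlib import Path

from codeflash.benchmarking.replay_test import generate_replay_test

count = generate_replay_test(
    trace_file_path=Path("trace.sqlite"),
    output_dir=Path("tests/replay"),
    test_framework="pytest",
    max_run_count=100,
)
print(f"Generated {count} replay test files")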
42  codeflash/benchmarking/trace_benchmarks.py (new file)

@@ -0,0 +1,42 @@
from __future__ import annotations

import re
import subprocess
from pathlib import Path

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE


def trace_benchmarks_pytest(benchmarks_root: Path, tests_root: Path, project_root: Path, trace_file: Path, timeout: int = 300) -> None:
    result = subprocess.run(
        [
            SAFE_SYS_EXECUTABLE,
            Path(__file__).parent / "pytest_new_process_trace_benchmarks.py",
            benchmarks_root,
            tests_root,
            trace_file,
        ],
        cwd=project_root,
        check=False,
        capture_output=True,
        text=True,
        env={"PYTHONPATH": str(project_root)},
        timeout=timeout,
    )
    if result.returncode != 0:
        if "ERROR collecting" in result.stdout:
            # Pattern matches "===== ERRORS =====" (any number of =) and captures everything after
            error_pattern = r"={3,}\s*ERRORS\s*={3,}\n([\s\S]*?)(?:={3,}|$)"
            match = re.search(error_pattern, result.stdout)
            error_section = match.group(1) if match else result.stdout
        elif "FAILURES" in result.stdout:
            # Pattern matches "===== FAILURES =====" (any number of =) and captures everything after
            error_pattern = r"={3,}\s*FAILURES\s*={3,}\n([\s\S]*?)(?:={3,}|$)"
            match = re.search(error_pattern, result.stdout)
            error_section = match.group(1) if match else result.stdout
        else:
            error_section = result.stdout
        logger.warning(
            f"Error collecting benchmarks - Pytest Exit code: {result.returncode}, {error_section}"
        )
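Note: a sketch of the tracing entry point (not part of the commit; the directory layout is an assumption):

from pathlib import Path

from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest

# Runs the benchmarks in a fresh pytest process and writes timings to trace.sqlite
trace_benchmarks_pytest(
    benchmarks_root=Path("tests/benchmarks"),
    tests_root=Path("tests"),
    project_root=Path.cwd(),
    trace_file=Path("trace.sqlite"),
)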
136  codeflash/benchmarking/utils.py (new file)

@@ -0,0 +1,136 @@
from __future__ import annotations

import shutil
from typing import TYPE_CHECKING, Optional

from rich.console import Console
from rich.table import Table

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.models.models import BenchmarkDetail, ProcessedBenchmarkInfo
from codeflash.result.critic import performance_gain

if TYPE_CHECKING:
    from codeflash.models.models import BenchmarkKey


def validate_and_format_benchmark_table(function_benchmark_timings: dict[str, dict[BenchmarkKey, int]],
                                        total_benchmark_timings: dict[BenchmarkKey, int]) -> dict[str, list[tuple[BenchmarkKey, float, float, float]]]:
    function_to_result = {}
    # Process each function's benchmark data
    for func_path, test_times in function_benchmark_timings.items():
        # Sort by percentage (highest first)
        sorted_tests = []
        for benchmark_key, func_time in test_times.items():
            total_time = total_benchmark_timings.get(benchmark_key, 0)
            if func_time > total_time:
                logger.debug(f"Skipping test {benchmark_key} due to func_time {func_time} > total_time {total_time}")
                # If the function time is greater than the total time, there are likely
                # multithreading / multiprocessing issues.
                # Do not try to project the optimization impact for this function.
                sorted_tests.append((benchmark_key, 0.0, 0.0, 0.0))
            elif total_time > 0:
                percentage = (func_time / total_time) * 100
                # Convert nanoseconds to milliseconds
                func_time_ms = func_time / 1_000_000
                total_time_ms = total_time / 1_000_000
                sorted_tests.append((benchmark_key, total_time_ms, func_time_ms, percentage))
        sorted_tests.sort(key=lambda x: x[3], reverse=True)
        function_to_result[func_path] = sorted_tests
    return function_to_result


def print_benchmark_table(function_to_results: dict[str, list[tuple[BenchmarkKey, float, float, float]]]) -> None:
    try:
        terminal_width = int(shutil.get_terminal_size().columns * 0.9)
    except Exception:
        terminal_width = 120  # Fallback width
    console = Console(width=terminal_width)
    for func_path, sorted_tests in function_to_results.items():
        console.print()
        function_name = func_path.split(":")[-1]

        # Create a table for this function
        table = Table(title=f"Function: {function_name}", width=terminal_width, border_style="blue", show_lines=True)
        benchmark_col_width = max(int(terminal_width * 0.4), 40)
        # Add columns - split the benchmark test into two columns
        table.add_column("Benchmark Module Path", width=benchmark_col_width, style="cyan", overflow="fold")
        table.add_column("Test Function", style="magenta", overflow="fold")
        table.add_column("Total Time (ms)", justify="right", style="green")
        table.add_column("Function Time (ms)", justify="right", style="yellow")
        table.add_column("Percentage (%)", justify="right", style="red")

        for benchmark_key, total_time, func_time, percentage in sorted_tests:
            # Split the benchmark test into module path and function name
            module_path = benchmark_key.module_path
            test_function = benchmark_key.function_name

            if total_time == 0.0:
                table.add_row(
                    module_path,
                    test_function,
                    "N/A",
                    "N/A",
                    "N/A"
                )
            else:
                table.add_row(
                    module_path,
                    test_function,
                    f"{total_time:.3f}",
                    f"{func_time:.3f}",
                    f"{percentage:.2f}"
                )

        # Print the table
        console.print(table)


def process_benchmark_data(
    replay_performance_gain: dict[BenchmarkKey, float],
    fto_benchmark_timings: dict[BenchmarkKey, int],
    total_benchmark_timings: dict[BenchmarkKey, int]
) -> Optional[ProcessedBenchmarkInfo]:
    """Process benchmark data and generate detailed benchmark information.

    Args:
        replay_performance_gain: The performance gain measured on the replay tests, per benchmark
        fto_benchmark_timings: Benchmark timings for the function to optimize
        total_benchmark_timings: Total benchmark timings

    Returns:
        ProcessedBenchmarkInfo containing processed benchmark details

    """
    if not replay_performance_gain or not fto_benchmark_timings or not total_benchmark_timings:
        return None

    benchmark_details = []

    for benchmark_key, og_benchmark_timing in fto_benchmark_timings.items():
        total_benchmark_timing = total_benchmark_timings.get(benchmark_key, 0)

        if total_benchmark_timing == 0:
            continue  # Skip benchmarks with zero timing

        # Calculate expected new benchmark timing
        expected_new_benchmark_timing = total_benchmark_timing - og_benchmark_timing + (
            1 / (replay_performance_gain[benchmark_key] + 1)
        ) * og_benchmark_timing

        # Calculate speedup
        benchmark_speedup_percent = performance_gain(original_runtime_ns=total_benchmark_timing, optimized_runtime_ns=int(expected_new_benchmark_timing)) * 100

        benchmark_details.append(
            BenchmarkDetail(
                benchmark_name=benchmark_key.module_path,
                test_function=benchmark_key.function_name,
                original_timing=humanize_runtime(int(total_benchmark_timing)),
                expected_new_timing=humanize_runtime(int(expected_new_benchmark_timing)),
                speedup_percent=benchmark_speedup_percent
            )
        )

    return ProcessedBenchmarkInfo(benchmark_details=benchmark_details)
|
|
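For intuition, the expected new benchmark timing above is an Amdahl-style substitution: only the portion of the benchmark spent inside the function to optimize shrinks, by a factor of 1 / (replay gain + 1), while the rest of the benchmark is unchanged. A minimal sketch with invented numbers (not from a real run):

    total_benchmark_timing = 10_000_000  # ns spent in the whole benchmark
    og_benchmark_timing = 4_000_000      # ns of that spent inside the function to optimize
    replay_gain = 1.0                    # replay tests ran 2x faster, i.e. gain = 1.0

    # Only the function's share shrinks; the remaining 6 ms is untouched.
    expected_new = total_benchmark_timing - og_benchmark_timing + og_benchmark_timing / (replay_gain + 1)
    print(expected_new)  # 8_000_000 ns: the benchmark is expected to drop from 10 ms to 8 ms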
@@ -62,6 +62,10 @@ def parse_args() -> Namespace:
     )
     parser.add_argument("-v", "--verbose", action="store_true", help="Print verbose debug logs")
     parser.add_argument("--version", action="store_true", help="Print the version of codeflash")
+    parser.add_argument("--benchmark", action="store_true", help="Trace benchmark tests and calculate optimization impact on benchmarks")
+    parser.add_argument(
+        "--benchmarks-root", type=str, help="Path to the directory of the project, where all the pytest-benchmark tests are located."
+    )
     args: Namespace = parser.parse_args()
     return process_and_validate_cmd_args(args)
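With these flags, a benchmark-aware run looks like this (the path is illustrative):

    codeflash --benchmark --benchmarks-root tests/benchmarks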
@@ -109,6 +113,7 @@ def process_pyproject_config(args: Namespace) -> Namespace:
     supported_keys = [
         "module_root",
         "tests_root",
+        "benchmarks_root",
         "test_framework",
         "ignore_paths",
         "pytest_cmd",
@@ -127,7 +132,12 @@ def process_pyproject_config(args: Namespace) -> Namespace:
     assert Path(args.module_root).is_dir(), f"--module-root {args.module_root} must be a valid directory"
     assert args.tests_root is not None, "--tests-root must be specified"
     assert Path(args.tests_root).is_dir(), f"--tests-root {args.tests_root} must be a valid directory"
+
+    if args.benchmark:
+        assert args.benchmarks_root is not None, "--benchmarks-root must be specified when running with --benchmark"
+        assert Path(args.benchmarks_root).is_dir(), f"--benchmarks-root {args.benchmarks_root} must be a valid directory"
+        assert Path(args.benchmarks_root).resolve().is_relative_to(Path(args.tests_root).resolve()), (
+            f"--benchmarks-root {args.benchmarks_root} must be a subdirectory of --tests-root {args.tests_root}"
+        )
     if env_utils.get_pr_number() is not None:
         assert env_utils.ensure_codeflash_api_key(), (
             "Codeflash API key not found. When running in a Github Actions Context, provide the "
@@ -49,6 +49,7 @@ CODEFLASH_LOGO: str = (
 class SetupInfo:
     module_root: str
     tests_root: str
+    benchmarks_root: str | None
     test_framework: str
     ignore_paths: list[str]
     formatter: str
@@ -125,8 +126,7 @@ def ask_run_end_to_end_test(args: Namespace) -> None:
     run_end_to_end_test(args, bubble_sort_path, bubble_sort_test_path)


 def should_modify_pyproject_toml() -> bool:
-    """
-    Check if the current directory contains a valid pyproject.toml file with codeflash config
+    """Check if the current directory contains a valid pyproject.toml file with codeflash config
     If it does, ask the user if they want to re-configure it.
     """
     from rich.prompt import Confirm
@@ -135,7 +135,7 @@ def should_modify_pyproject_toml() -> bool:
         return True
     try:
         config, config_file_path = parse_config_file(pyproject_toml_path)
-    except Exception as e:
+    except Exception:
         return True

     if "module_root" not in config or config["module_root"] is None or not Path(config["module_root"]).is_dir():
@@ -144,7 +144,7 @@ def should_modify_pyproject_toml() -> bool:
         return True

     create_toml = Confirm.ask(
-        f"✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?", default=False, show_default=True
+        "✅ A valid Codeflash config already exists in this project. Do you want to re-configure it?", default=False, show_default=True
     )
     return create_toml
@@ -244,6 +244,66 @@ def collect_setup_info() -> SetupInfo:
     ph("cli-test-framework-provided", {"test_framework": test_framework})

+    # Get benchmarks root directory
+    default_benchmarks_subdir = "benchmarks"
+    create_benchmarks_option = f"okay, create a {default_benchmarks_subdir}{os.path.sep} directory for me!"
+    no_benchmarks_option = "I don't need benchmarks"
+
+    # Check if benchmarks directory exists inside tests directory
+    tests_subdirs = []
+    if tests_root.exists():
+        tests_subdirs = [d.name for d in tests_root.iterdir() if d.is_dir() and not d.name.startswith(".")]
+
+    benchmarks_options = []
+    if default_benchmarks_subdir in tests_subdirs:
+        benchmarks_options.append(default_benchmarks_subdir)
+    benchmarks_options.extend([d for d in tests_subdirs if d != default_benchmarks_subdir])
+    benchmarks_options.append(create_benchmarks_option)
+    benchmarks_options.append(custom_dir_option)
+    benchmarks_options.append(no_benchmarks_option)
+
+    benchmarks_answer = inquirer_wrapper(
+        inquirer.list_input,
+        message="Where are your benchmarks located? (benchmarks must be a sub directory of your tests root directory)",
+        choices=benchmarks_options,
+        default=(default_benchmarks_subdir if default_benchmarks_subdir in benchmarks_options else benchmarks_options[0]),
+    )
+
+    if benchmarks_answer == create_benchmarks_option:
+        benchmarks_root = tests_root / default_benchmarks_subdir
+        benchmarks_root.mkdir(exist_ok=True)
+        click.echo(f"✅ Created directory {benchmarks_root}{os.path.sep}{LF}")
+    elif benchmarks_answer == custom_dir_option:
+        custom_benchmarks_answer = inquirer_wrapper_path(
+            "path",
+            message=f"Enter the path to your benchmarks directory inside {tests_root}{os.path.sep} ",
+            path_type=inquirer.Path.DIRECTORY,
+        )
+        if custom_benchmarks_answer:
+            benchmarks_root = tests_root / Path(custom_benchmarks_answer["path"])
+        else:
+            apologize_and_exit()
+    elif benchmarks_answer == no_benchmarks_option:
+        benchmarks_root = None
+    else:
+        benchmarks_root = tests_root / Path(cast(str, benchmarks_answer))
+
+    # TODO: Implement other benchmark framework options
+    # if benchmarks_root:
+    #     benchmarks_root = benchmarks_root.relative_to(curdir)
+    #
+    #     # Ask about benchmark framework
+    #     benchmark_framework_options = ["pytest-benchmark", "asv (Airspeed Velocity)", "custom/other"]
+    #     benchmark_framework = inquirer_wrapper(
+    #         inquirer.list_input,
+    #         message="Which benchmark framework do you use?",
+    #         choices=benchmark_framework_options,
+    #         default=benchmark_framework_options[0],
+    #         carousel=True,
+    #     )

     formatter = inquirer_wrapper(
         inquirer.list_input,
         message="Which code formatter do you use?",
@@ -279,6 +339,7 @@ def collect_setup_info() -> SetupInfo:
     return SetupInfo(
         module_root=str(module_root),
         tests_root=str(tests_root),
+        benchmarks_root=str(benchmarks_root) if benchmarks_root else None,
         test_framework=cast(str, test_framework),
         ignore_paths=ignore_paths,
         formatter=cast(str, formatter),
@@ -437,11 +498,19 @@ def install_github_actions(override_formatter_check: bool = False) -> None:
         return
     workflows_path.mkdir(parents=True, exist_ok=True)
     from importlib.resources import files
+
+    benchmark_mode = False
+    if "benchmarks_root" in config:
+        benchmark_mode = inquirer_wrapper(
+            inquirer.confirm,
+            message="⚡️It looks like you've configured a benchmarks_root in your config. Would you like to run the Github action in benchmark mode? "
+            " This will show the impact of Codeflash's suggested optimizations on your benchmarks",
+            default=True,
+        )
+
     optimize_yml_content = (
         files("codeflash").joinpath("cli_cmds", "workflows", "codeflash-optimize.yaml").read_text(encoding="utf-8")
     )
-    materialized_optimize_yml_content = customize_codeflash_yaml_content(optimize_yml_content, config, git_root)
+    materialized_optimize_yml_content = customize_codeflash_yaml_content(optimize_yml_content, config, git_root, benchmark_mode)
     with optimize_yaml_path.open("w", encoding="utf8") as optimize_yml_file:
         optimize_yml_file.write(materialized_optimize_yml_content)
     click.echo(f"{LF}✅ Created GitHub action workflow at {optimize_yaml_path}{LF}")
@@ -556,7 +625,7 @@ def get_github_action_working_directory(toml_path: Path, git_root: Path) -> str:


 def customize_codeflash_yaml_content(
-    optimize_yml_content: str, config: tuple[dict[str, Any], Path], git_root: Path
+    optimize_yml_content: str, config: tuple[dict[str, Any], Path], git_root: Path, benchmark_mode: bool = False
 ) -> str:
     module_path = str(Path(config["module_root"]).relative_to(git_root) / "**")
     optimize_yml_content = optimize_yml_content.replace("{{ codeflash_module_path }}", module_path)
@@ -587,6 +656,9 @@ def customize_codeflash_yaml_content(
     # Add codeflash command
     codeflash_cmd = get_codeflash_github_action_command(dep_manager)
+
+    if benchmark_mode:
+        codeflash_cmd += " --benchmark"
     return optimize_yml_content.replace("{{ codeflash_command }}", codeflash_cmd)
@@ -608,6 +680,7 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
     codeflash_section["module-root"] = setup_info.module_root
     codeflash_section["tests-root"] = setup_info.tests_root
     codeflash_section["test-framework"] = setup_info.test_framework
+    codeflash_section["benchmarks-root"] = setup_info.benchmarks_root if setup_info.benchmarks_root else ""
     codeflash_section["ignore-paths"] = setup_info.ignore_paths
     if setup_info.git_remote not in ["", "origin"]:
         codeflash_section["git-remote"] = setup_info.git_remote
@@ -52,10 +52,10 @@ def parse_config_file(
     assert isinstance(config, dict)

     # default values:
-    path_keys = {"module-root", "tests-root"}
-    path_list_keys = {"ignore-paths", }
+    path_keys = ["module-root", "tests-root", "benchmarks-root"]
+    path_list_keys = ["ignore-paths"]
     str_keys = {"pytest-cmd": "pytest", "git-remote": "origin"}
-    bool_keys = {"disable-telemetry": False, "disable-imports-sorting": False}
+    bool_keys = {"disable-telemetry": False, "disable-imports-sorting": False, "benchmark": False}
     list_str_keys = {"formatter-cmds": ["black $file"]}

     for key in str_keys:
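After this change, a pyproject.toml that opts into benchmark mode would look roughly like the following sketch (the directory names are assumptions for illustration, not shipped defaults):

    [tool.codeflash]
    module-root = "my_package"
    tests-root = "tests"
    benchmarks-root = "tests/benchmarks"
    test-framework = "pytest"
    benchmark = false  # or pass --benchmark on the command line per run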
@@ -106,7 +106,7 @@ class FunctionWithReturnStatement(ast.NodeVisitor):
 @dataclass(frozen=True, config={"arbitrary_types_allowed": True})
 class FunctionToOptimize:
-    """Represents a function that is a candidate for optimization.
+    """Represent a function that is a candidate for optimization.

     Attributes
     ----------
@@ -145,7 +145,6 @@ class FunctionToOptimize:
     def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
         return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"

-
 def get_functions_to_optimize(
     optimize_all: str | None,
     replay_test: str | None,
@@ -359,9 +358,15 @@ class TopLevelFunctionOrMethodVisitor(ast.NodeVisitor):
                     for decorator in body_node.decorator_list
                 ):
                     self.is_classmethod = True
+                elif any(
+                    isinstance(decorator, ast.Name) and decorator.id == "staticmethod"
+                    for decorator in body_node.decorator_list
+                ):
+                    self.is_staticmethod = True
             return
-        else:
-            # search if the class has a staticmethod with the same name and on the same line number
+        elif self.line_no:
+            # If we have line number info, check if class has a static method with the same line number
+            # This way, if we don't have the class name, we can still find the static method
             for body_node in node.body:
                 if (
                     isinstance(body_node, ast.FunctionDef)
@@ -1,9 +1,11 @@
-from typing import Union
+from __future__ import annotations
+
+from typing import Union, Optional

 from pydantic import BaseModel
 from pydantic.dataclasses import dataclass

 from codeflash.code_utils.time_utils import humanize_runtime
+from codeflash.models.models import BenchmarkDetail
 from codeflash.models.models import TestResults
@@ -18,8 +20,9 @@ class PrComment:
     speedup_pct: str
     winning_behavioral_test_results: TestResults
     winning_benchmarking_test_results: TestResults
+    benchmark_details: Optional[list[BenchmarkDetail]] = None

-    def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str]]:
+    def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str, Optional[list[BenchmarkDetail]]]]:
         report_table = {
             test_type.to_name(): result
             for test_type, result in self.winning_behavioral_test_results.get_test_pass_fail_report_by_type().items()
@@ -36,6 +39,7 @@ class PrComment:
             "speedup_pct": self.speedup_pct,
             "loop_count": self.winning_benchmarking_test_results.number_of_loops(),
             "report_table": report_table,
+            "benchmark_details": self.benchmark_details if self.benchmark_details else None,
         }
@@ -1,5 +1,6 @@
 from __future__ import annotations

+from collections import defaultdict
 from typing import TYPE_CHECKING

 from rich.tree import Tree
@@ -11,7 +12,7 @@ if TYPE_CHECKING:
 import enum
 import re
 import sys
-from collections.abc import Collection, Iterator
+from collections.abc import Collection
 from enum import Enum, IntEnum
 from pathlib import Path
 from re import Pattern
@@ -22,7 +23,7 @@ from pydantic import AfterValidator, BaseModel, ConfigDict, Field
 from pydantic.dataclasses import dataclass

 from codeflash.cli_cmds.console import console, logger
-from codeflash.code_utils.code_utils import validate_python_code
+from codeflash.code_utils.code_utils import module_name_from_file_path, validate_python_code
 from codeflash.code_utils.env_utils import is_end_to_end
 from codeflash.verification.comparator import comparator
@@ -58,28 +59,74 @@ class FunctionSource:
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, FunctionSource):
             return False
-        return (
-            self.file_path == other.file_path
-            and self.qualified_name == other.qualified_name
-            and self.fully_qualified_name == other.fully_qualified_name
-            and self.only_function_name == other.only_function_name
-            and self.source_code == other.source_code
-        )
+        return (self.file_path == other.file_path and
+                self.qualified_name == other.qualified_name and
+                self.fully_qualified_name == other.fully_qualified_name and
+                self.only_function_name == other.only_function_name and
+                self.source_code == other.source_code)

     def __hash__(self) -> int:
-        return hash(
-            (self.file_path, self.qualified_name, self.fully_qualified_name, self.only_function_name, self.source_code)
-        )
+        return hash((self.file_path, self.qualified_name, self.fully_qualified_name,
+                     self.only_function_name, self.source_code))


 class BestOptimization(BaseModel):
     candidate: OptimizedCandidate
     helper_functions: list[FunctionSource]
     runtime: int
+    replay_performance_gain: Optional[dict[BenchmarkKey, float]] = None
     winning_behavioral_test_results: TestResults
     winning_benchmarking_test_results: TestResults
+    winning_replay_benchmarking_test_results: Optional[TestResults] = None
+
+
+@dataclass(frozen=True)
+class BenchmarkKey:
+    module_path: str
+    function_name: str
+
+    def __str__(self) -> str:
+        return f"{self.module_path}::{self.function_name}"
+
+
+@dataclass
+class BenchmarkDetail:
+    benchmark_name: str
+    test_function: str
+    original_timing: str
+    expected_new_timing: str
+    speedup_percent: float
+
+    def to_string(self) -> str:
+        return (
+            f"Original timing for {self.benchmark_name}::{self.test_function}: {self.original_timing}\n"
+            f"Expected new timing for {self.benchmark_name}::{self.test_function}: {self.expected_new_timing}\n"
+            f"Benchmark speedup for {self.benchmark_name}::{self.test_function}: {self.speedup_percent:.2f}%\n"
+        )
+
+    def to_dict(self) -> dict[str, any]:
+        return {
+            "benchmark_name": self.benchmark_name,
+            "test_function": self.test_function,
+            "original_timing": self.original_timing,
+            "expected_new_timing": self.expected_new_timing,
+            "speedup_percent": self.speedup_percent,
+        }
+
+
+@dataclass
+class ProcessedBenchmarkInfo:
+    benchmark_details: list[BenchmarkDetail]
+
+    def to_string(self) -> str:
+        if not self.benchmark_details:
+            return ""
+
+        result = "Benchmark Performance Details:\n"
+        for detail in self.benchmark_details:
+            result += detail.to_string() + "\n"
+        return result
+
+    def to_dict(self) -> dict[str, list[dict[str, any]]]:
+        return {"benchmark_details": [detail.to_dict() for detail in self.benchmark_details]}
+
+
 class CodeString(BaseModel):
     code: Annotated[str, AfterValidator(validate_python_code)]
     file_path: Optional[Path] = None
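A quick sketch of how these models compose (the names and timings are invented for illustration):

    detail = BenchmarkDetail(
        benchmark_name="tests.benchmarks.test_sort",  # assumed benchmark module path
        test_function="test_sorter",                  # assumed benchmark test name
        original_timing="10.0ms",
        expected_new_timing="8.00ms",
        speedup_percent=25.0,
    )
    print(ProcessedBenchmarkInfo(benchmark_details=[detail]).to_string())
    # prints, roughly:
    # Benchmark Performance Details:
    # Original timing for tests.benchmarks.test_sort::test_sorter: 10.0ms
    # Expected new timing for tests.benchmarks.test_sort::test_sorter: 8.00ms
    # Benchmark speedup for tests.benchmarks.test_sort::test_sorter: 25.00%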
@@ -104,8 +151,7 @@ class CodeOptimizationContext(BaseModel):
     read_writable_code: str = Field(min_length=1)
     read_only_context_code: str = ""
     helper_functions: list[FunctionSource]
-    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]]
-
+    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]]

 class CodeContextType(str, Enum):
     READ_WRITABLE = "READ_WRITABLE"
@@ -118,6 +164,7 @@ class OptimizedCandidateResult(BaseModel):
     best_test_runtime: int
     behavior_test_results: TestResults
     benchmarking_test_results: TestResults
+    replay_benchmarking_test_results: Optional[dict[BenchmarkKey, TestResults]] = None
     optimization_candidate_index: int
     total_candidate_timing: int
@@ -222,6 +269,7 @@ class FunctionParent:
 class OriginalCodeBaseline(BaseModel):
     behavioral_test_results: TestResults
     benchmarking_test_results: TestResults
+    replay_benchmarking_test_results: Optional[dict[BenchmarkKey, TestResults]] = None
     line_profile_results: dict
     runtime: int
     coverage_results: Optional[CoverageData]
@@ -299,7 +347,6 @@ class CoverageData:
             status=CoverageStatus.NOT_FOUND,
         )

-
 @dataclass
 class FunctionCoverage:
     """Represents the coverage data for a specific function in a source file."""
@@ -426,6 +473,20 @@ class TestResults(BaseModel):
                 raise ValueError(msg)
             self.test_result_idx[k] = v + original_len

+    def group_by_benchmarks(self, benchmark_keys: list[BenchmarkKey], benchmark_replay_test_dir: Path, project_root: Path) -> dict[BenchmarkKey, TestResults]:
+        """Group TestResults by benchmark for calculating improvements for each benchmark."""
+        test_results_by_benchmark = defaultdict(TestResults)
+        benchmark_module_path = {}
+        for benchmark_key in benchmark_keys:
+            benchmark_module_path[benchmark_key] = module_name_from_file_path(benchmark_replay_test_dir.resolve() / f"test_{benchmark_key.module_path.replace('.', '_')}__replay_test_", project_root)
+        for test_result in self.test_results:
+            if test_result.test_type == TestType.REPLAY_TEST:
+                for benchmark_key, module_path in benchmark_module_path.items():
+                    if test_result.id.test_module_path.startswith(module_path):
+                        test_results_by_benchmark[benchmark_key].add(test_result)
+
+        return test_results_by_benchmark
+
     def get_by_unique_invocation_loop_id(self, unique_invocation_loop_id: str) -> FunctionTestInvocation | None:
         try:
             return self.test_results[self.test_result_idx[unique_invocation_loop_id]]
@@ -19,6 +19,7 @@ from rich.syntax import Syntax
 from rich.tree import Tree

 from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
+from codeflash.benchmarking.utils import process_benchmark_data
 from codeflash.cli_cmds.console import code_print, console, logger, progress_bar
 from codeflash.code_utils import env_utils
 from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
@@ -42,7 +43,6 @@ from codeflash.code_utils.remove_generated_tests import remove_functions_from_ge
 from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
 from codeflash.code_utils.time_utils import humanize_runtime
 from codeflash.context import code_context_extractor
-from codeflash.discovery.functions_to_optimize import FunctionToOptimize
 from codeflash.either import Failure, Success, is_successful
 from codeflash.models.ExperimentMetadata import ExperimentMetadata
 from codeflash.models.models import (
@@ -76,8 +76,9 @@ from codeflash.verification.verifier import generate_tests
 if TYPE_CHECKING:
     from argparse import Namespace

+    from codeflash.discovery.functions_to_optimize import FunctionToOptimize
     from codeflash.either import Result
-    from codeflash.models.models import CoverageData, FunctionSource, OptimizedCandidate
+    from codeflash.models.models import BenchmarkKey, CoverageData, FunctionSource, OptimizedCandidate
     from codeflash.verification.verification_utils import TestConfig
@@ -90,7 +91,10 @@ class FunctionOptimizer:
         function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None,
         function_to_optimize_ast: ast.FunctionDef | None = None,
         aiservice_client: AiServiceClient | None = None,
+        function_benchmark_timings: dict[BenchmarkKey, int] | None = None,
+        total_benchmark_timings: dict[BenchmarkKey, int] | None = None,
         args: Namespace | None = None,
+        replay_tests_dir: Path | None = None,
     ) -> None:
         self.project_root = test_cfg.project_root_path
         self.test_cfg = test_cfg
@@ -113,11 +117,14 @@ class FunctionOptimizer:
         self.experiment_id = os.getenv("CODEFLASH_EXPERIMENT_ID", None)
         self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None
         self.test_files = TestFiles(test_files=[])
-
         self.args = args  # Check defaults for these
         self.function_trace_id: str = str(uuid.uuid4())
         self.original_module_path = module_name_from_file_path(self.function_to_optimize.file_path, self.project_root)
+        self.function_benchmark_timings = function_benchmark_timings if function_benchmark_timings else {}
+        self.total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else {}
+        self.replay_tests_dir = replay_tests_dir if replay_tests_dir else None

     def optimize_function(self) -> Result[BestOptimization, str]:
         should_run_experiment = self.experiment_id is not None
         logger.debug(f"Function Trace ID: {self.function_trace_id}")
@@ -136,8 +143,8 @@ class FunctionOptimizer:
             original_helper_code[helper_function_path] = helper_code
         if has_any_async_functions(code_context.read_writable_code):
             return Failure("Codeflash does not support async functions in the code to optimize.")
-        code_print(code_context.read_writable_code)

+        code_print(code_context.read_writable_code)
         generated_test_paths = [
             get_test_file_path(
                 self.test_cfg.tests_root, self.function_to_optimize.function_name, test_index, test_type="unit"
@@ -261,6 +268,13 @@ class FunctionOptimizer:
                     best_optimization.candidate.explanation, title="Best Candidate Explanation", border_style="blue"
                 )
             )
+            processed_benchmark_info = None
+            if self.args.benchmark:
+                processed_benchmark_info = process_benchmark_data(
+                    replay_performance_gain=best_optimization.replay_performance_gain,
+                    fto_benchmark_timings=self.function_benchmark_timings,
+                    total_benchmark_timings=self.total_benchmark_timings,
+                )
             explanation = Explanation(
                 raw_explanation_message=best_optimization.candidate.explanation,
                 winning_behavioral_test_results=best_optimization.winning_behavioral_test_results,
@@ -269,6 +283,7 @@ class FunctionOptimizer:
                 best_runtime_ns=best_optimization.runtime,
                 function_name=function_to_optimize_qualified_name,
                 file_path=self.function_to_optimize.file_path,
+                benchmark_details=processed_benchmark_info.benchmark_details if processed_benchmark_info else None,
             )

             self.log_successful_optimization(explanation, generated_tests)
@@ -362,7 +377,7 @@ class FunctionOptimizer:
         candidates = deque(candidates)
         # Start a new thread for AI service request, start loop in main thread
         # check if aiservice request is complete, when it is complete, append result to the candidates list
-        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
+        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
             future_line_profile_results = executor.submit(
                 self.aiservice_client.optimize_python_code_line_profiler,
                 source_code=code_context.read_writable_code,
@@ -382,8 +397,8 @@ class FunctionOptimizer:
                 if done and (future_line_profile_results is not None):
                     line_profile_results = future_line_profile_results.result()
                     candidates.extend(line_profile_results)
-                    original_len += len(line_profile_results)
-                    logger.info(f"Added {len(line_profile_results)} results from line profiler to candidates, total candidates now: {original_len}")
+                    original_len += len(candidates)
+                    logger.info(f"Added results from line profiler to candidates, total candidates now: {original_len}")
                     future_line_profile_results = None
                 candidate_index += 1
                 candidate = candidates.popleft()
@@ -410,7 +425,7 @@ class FunctionOptimizer:
                 )
                 continue

-            # Instrument codeflash capture
+
             run_results = self.run_optimized_candidate(
                 optimization_candidate_index=candidate_index,
                 baseline_results=original_code_baseline,
@@ -434,6 +449,7 @@ class FunctionOptimizer:
             speedup_ratios[candidate.optimization_id] = perf_gain

             tree = Tree(f"Candidate #{candidate_index} - Runtime Information")
+            benchmark_tree = None
             if speedup_critic(
                 candidate_result, original_code_baseline.runtime, best_runtime_until_now
             ) and quantity_of_tests_critic(candidate_result):
@@ -446,13 +462,29 @@ class FunctionOptimizer:
                 )
                 tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
                 tree.add(f"Speedup ratio: {perf_gain + 1:.1f}X")
+                replay_perf_gain = {}
+                if self.args.benchmark:
+                    test_results_by_benchmark = candidate_result.benchmarking_test_results.group_by_benchmarks(self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root)
+                    if len(test_results_by_benchmark) > 0:
+                        benchmark_tree = Tree("Speedup percentage on benchmarks:")
+                    for benchmark_key, candidate_test_results in test_results_by_benchmark.items():
+                        original_code_replay_runtime = original_code_baseline.replay_benchmarking_test_results[benchmark_key].total_passed_runtime()
+                        candidate_replay_runtime = candidate_test_results.total_passed_runtime()
+                        replay_perf_gain[benchmark_key] = performance_gain(
+                            original_runtime_ns=original_code_replay_runtime,
+                            optimized_runtime_ns=candidate_replay_runtime,
+                        )
+                        benchmark_tree.add(f"{benchmark_key}: {replay_perf_gain[benchmark_key] * 100:.1f}%")
+
                 best_optimization = BestOptimization(
                     candidate=candidate,
                     helper_functions=code_context.helper_functions,
                     runtime=best_test_runtime,
                     winning_behavioral_test_results=candidate_result.behavior_test_results,
+                    replay_performance_gain=replay_perf_gain if self.args.benchmark else None,
                     winning_benchmarking_test_results=candidate_result.benchmarking_test_results,
+                    winning_replay_benchmarking_test_results=candidate_result.benchmarking_test_results,
                 )
                 best_runtime_until_now = best_test_runtime
             else:
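As a sanity check on the arithmetic here: performance_gain is used throughout as (original − optimized) / optimized, so halving a runtime reports a 100% gain and a 2.0X ratio. A quick illustration with invented numbers:

    original_ns, optimized_ns = 2_000_000_000, 1_000_000_000
    gain = (original_ns - optimized_ns) / optimized_ns  # assumed definition, consistent with its use above
    print(f"{gain * 100:.1f}%")  # 100.0%
    print(f"{gain + 1:.1f}X")    # 2.0X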
@@ -464,6 +496,8 @@ class FunctionOptimizer:
                 tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
                 tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
             console.print(tree)
+            if self.args.benchmark and benchmark_tree:
+                console.print(benchmark_tree)
             console.rule()

         self.write_code_and_helpers(
@@ -507,7 +541,8 @@ class FunctionOptimizer:
             )

-            console.print(Group(explanation_panel, tests_panel))
+            console.print(explanation_panel)
+        else:
             console.print(explanation_panel)

         ph(
             "cli-optimize-success",
@@ -664,6 +699,7 @@ class FunctionOptimizer:
             unique_instrumented_test_files.add(new_behavioral_test_path)
             unique_instrumented_test_files.add(new_perf_test_path)
+
             if not self.test_files.get_by_original_file_path(path_obj_test_file):
                 self.test_files.add(
                     TestFile(
@@ -675,6 +711,7 @@ class FunctionOptimizer:
                     tests_in_file=[t.tests_in_file for t in tests_in_file_list],
                 )
             )
+
         logger.info(
             f"Discovered {existing_test_files_count} existing unit test file"
             f"{'s' if existing_test_files_count != 1 else ''}, {replay_test_files_count} replay test file"
@@ -865,7 +902,6 @@ class FunctionOptimizer:
                 enable_coverage=False,
                 code_context=code_context,
             )
-
         else:
             benchmarking_results = TestResults()
             start_time: float = time.time()
@@ -920,11 +956,15 @@ class FunctionOptimizer:
         )
         console.rule()
         logger.debug(f"Total original code runtime (ns): {total_timing}")
+
+        if self.args.benchmark:
+            replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root)
         return Success(
             (
                 OriginalCodeBaseline(
                     behavioral_test_results=behavioral_results,
                     benchmarking_test_results=benchmarking_results,
+                    replay_benchmarking_test_results=replay_benchmarking_test_results if self.args.benchmark else None,
                     runtime=total_timing,
                     coverage_results=coverage_results,
                     line_profile_results=line_profile_results,
@@ -954,8 +994,6 @@ class FunctionOptimizer:
         test_env["PYTHONPATH"] += os.pathsep + str(self.project_root)

         get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(missing_ok=True)
-        get_run_tmp_file(Path(f"test_return_values_{optimization_candidate_index}.sqlite")).unlink(missing_ok=True)
-
         # Instrument codeflash capture
         candidate_fto_code = Path(self.function_to_optimize.file_path).read_text("utf-8")
         candidate_helper_code = {}
@@ -986,7 +1024,6 @@ class FunctionOptimizer:
                 )
             )
             console.rule()
-
         if compare_test_results(baseline_results.behavioral_test_results, candidate_behavior_results):
             logger.info("Test results matched!")
             console.rule()
@@ -1039,12 +1076,17 @@ class FunctionOptimizer:
             console.rule()

         logger.debug(f"Total optimized code {optimization_candidate_index} runtime (ns): {total_candidate_timing}")
+        if self.args.benchmark:
+            candidate_replay_benchmarking_results = candidate_benchmarking_results.group_by_benchmarks(self.total_benchmark_timings.keys(), self.replay_tests_dir, self.project_root)
+            for benchmark_name, benchmark_results in candidate_replay_benchmarking_results.items():
+                logger.debug(f"Benchmark {benchmark_name} runtime (ns): {humanize_runtime(benchmark_results.total_passed_runtime())}")
         return Success(
             OptimizedCandidateResult(
                 max_loop_count=loop_count,
                 best_test_runtime=total_candidate_timing,
                 behavior_test_results=candidate_behavior_results,
                 benchmarking_test_results=candidate_benchmarking_results,
+                replay_benchmarking_test_results=candidate_replay_benchmarking_results if self.args.benchmark else None,
                 optimization_candidate_index=optimization_candidate_index,
                 total_candidate_timing=total_candidate_timing,
             )
@@ -1086,8 +1128,8 @@ class FunctionOptimizer:
             pytest_cmd=self.test_cfg.pytest_cmd,
             pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
             pytest_target_runtime_seconds=testing_time,
-            pytest_min_loops=pytest_min_loops,
-            pytest_max_loops=pytest_min_loops,
+            pytest_min_loops=1,
+            pytest_max_loops=1,
             test_framework=self.test_cfg.test_framework,
             line_profiler_output_file=line_profiler_output_file,
         )
@@ -5,10 +5,16 @@ import os
 import time
 import shutil
+import tempfile
 from collections import defaultdict
 from pathlib import Path
 from typing import TYPE_CHECKING

 from codeflash.api.aiservice import AiServiceClient, LocalAiServiceClient
+from codeflash.benchmarking.instrument_codeflash_trace import instrument_codeflash_trace_decorator
+from codeflash.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin
+from codeflash.benchmarking.replay_test import generate_replay_test
+from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
+from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table
 from codeflash.cli_cmds.console import console, logger, progress_bar
 from codeflash.code_utils import env_utils
 from codeflash.code_utils.code_replacer import normalize_code, normalize_node
@@ -17,7 +23,7 @@ from codeflash.code_utils.static_analysis import analyze_imported_modules, get_f
 from codeflash.discovery.discover_unit_tests import discover_unit_tests
 from codeflash.discovery.functions_to_optimize import get_functions_to_optimize
 from codeflash.either import is_successful
-from codeflash.models.models import TestType, ValidCode
+from codeflash.models.models import BenchmarkKey, TestType, ValidCode
 from codeflash.optimization.function_optimizer import FunctionOptimizer
 from codeflash.telemetry.posthog_cf import ph
 from codeflash.verification.verification_utils import TestConfig
@@ -39,18 +45,21 @@ class Optimizer:
             project_root_path=args.project_root,
             test_framework=args.test_framework,
             pytest_cmd=args.pytest_cmd,
+            benchmark_tests_root=args.benchmarks_root if "benchmark" in args and "benchmarks_root" in args else None,
         )

         self.aiservice_client = AiServiceClient()
         self.experiment_id = os.getenv("CODEFLASH_EXPERIMENT_ID", None)
         self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None
+        self.replay_tests_dir = None

     def create_function_optimizer(
         self,
         function_to_optimize: FunctionToOptimize,
         function_to_optimize_ast: ast.FunctionDef | None = None,
         function_to_tests: dict[str, list[FunctionCalledInTest]] | None = None,
         function_to_optimize_source_code: str | None = "",
+        function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] | None = None,
+        total_benchmark_timings: dict[BenchmarkKey, float] | None = None,
     ) -> FunctionOptimizer:
         return FunctionOptimizer(
             function_to_optimize=function_to_optimize,
@@ -60,6 +69,9 @@ class Optimizer:
             function_to_optimize_ast=function_to_optimize_ast,
             aiservice_client=self.aiservice_client,
             args=self.args,
+            function_benchmark_timings=function_benchmark_timings if function_benchmark_timings else None,
+            total_benchmark_timings=total_benchmark_timings if total_benchmark_timings else None,
+            replay_tests_dir=self.replay_tests_dir,
         )

     def run(self) -> None:
@@ -71,6 +83,8 @@ class Optimizer:
         function_optimizer = None
         file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]
         num_optimizable_functions: int
+
+        # discover functions
         (file_to_funcs_to_optimize, num_optimizable_functions) = get_functions_to_optimize(
             optimize_all=self.args.all,
             replay_test=self.args.replay_test,
@@ -81,7 +95,42 @@ class Optimizer:
             project_root=self.args.project_root,
             module_root=self.args.module_root,
         )
+        function_benchmark_timings: dict[str, dict[BenchmarkKey, int]] = {}
+        total_benchmark_timings: dict[BenchmarkKey, int] = {}
+        if self.args.benchmark:
+            with progress_bar(
+                f"Running benchmarks in {self.args.benchmarks_root}",
+                transient=True,
+            ):
+                # Insert decorator
+                file_path_to_source_code = defaultdict(str)
+                for file in file_to_funcs_to_optimize:
+                    with file.open("r", encoding="utf8") as f:
+                        file_path_to_source_code[file] = f.read()
+                try:
+                    instrument_codeflash_trace_decorator(file_to_funcs_to_optimize)
+                    trace_file = Path(self.args.benchmarks_root) / "benchmarks.trace"
+                    if trace_file.exists():
+                        trace_file.unlink()
+
+                    self.replay_tests_dir = Path(tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=self.args.benchmarks_root))
+                    trace_benchmarks_pytest(self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file)  # Run all tests that use pytest-benchmark
+                    replay_count = generate_replay_test(trace_file, self.replay_tests_dir)
+                    if replay_count == 0:
+                        logger.info(f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization")
+                    else:
+                        function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(trace_file)
+                        total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(trace_file)
+                        function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
+                        print_benchmark_table(function_to_results)
+                except Exception as e:
+                    logger.info(f"Error while tracing existing benchmarks: {e}")
+                    logger.info("Information on existing benchmarks will not be available for this run.")
+                finally:
+                    # Restore original source code
+                    for file in file_path_to_source_code:
+                        with file.open("w", encoding="utf8") as f:
+                            f.write(file_path_to_source_code[file])
         optimizations_found: int = 0
         function_iterator_count: int = 0
         if self.args.test_framework == "pytest":
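For reference, the benchmarks traced here are ordinary pytest-benchmark tests. A minimal hypothetical example this flow would pick up (file name and arguments are assumptions; the target function exists in this repo):

    # tests/benchmarks/test_sort_benchmark.py
    from code_to_optimize.bubble_sort import sorter

    def test_sorter_benchmark(benchmark):
        # pytest-benchmark injects the `benchmark` fixture; the traced call is
        # what maps this benchmark onto the functions being optimized.
        benchmark(sorter, list(reversed(range(500))))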
@@ -103,6 +152,7 @@ class Optimizer:
         console.rule()
         ph("cli-optimize-discovered-tests", {"num_tests": num_discovered_tests})

+
         for original_module_path in file_to_funcs_to_optimize:
             logger.info(f"Examining file {original_module_path!s}…")
             console.rule()
@@ -159,12 +209,19 @@ class Optimizer:
                         f"Skipping optimization."
                     )
                     continue
-                function_optimizer = self.create_function_optimizer(
-                    function_to_optimize,
-                    function_to_optimize_ast,
-                    function_to_tests,
-                    validated_original_code[original_module_path].source_code,
-                )
+                qualified_name_w_module = function_to_optimize.qualified_name_with_modules_from_root(
+                    self.args.project_root
+                )
+                if self.args.benchmark and function_benchmark_timings and qualified_name_w_module in function_benchmark_timings and total_benchmark_timings:
+                    function_optimizer = self.create_function_optimizer(
+                        function_to_optimize, function_to_optimize_ast, function_to_tests, validated_original_code[original_module_path].source_code, function_benchmark_timings[qualified_name_w_module], total_benchmark_timings
+                    )
+                else:
+                    function_optimizer = self.create_function_optimizer(
+                        function_to_optimize, function_to_optimize_ast, function_to_tests,
+                        validated_original_code[original_module_path].source_code
+                    )

                 best_optimization = function_optimizer.optimize_function()
                 if is_successful(best_optimization):
                     optimizations_found += 1
@@ -189,6 +246,10 @@ class Optimizer:
                     test_file.instrumented_behavior_file_path.unlink(missing_ok=True)
             if function_optimizer.test_cfg.concolic_test_root_dir:
                 shutil.rmtree(function_optimizer.test_cfg.concolic_test_root_dir, ignore_errors=True)
+            if self.args.benchmark:
+                if self.replay_tests_dir.exists():
+                    shutil.rmtree(self.replay_tests_dir, ignore_errors=True)
+                trace_file.unlink(missing_ok=True)
         if hasattr(get_run_tmp_file, "tmpdir"):
             get_run_tmp_file.tmpdir.cleanup()

0	codeflash/picklepatch/__init__.py	Normal file
346	codeflash/picklepatch/pickle_patcher.py	Normal file
@@ -0,0 +1,346 @@
"""PicklePatcher - A utility for safely pickling objects with unpicklable components.

This module provides functions to recursively pickle objects, replacing unpicklable
components with placeholders that provide informative errors when accessed.
"""

import pickle
import types

import dill

from .pickle_placeholder import PicklePlaceholder


class PicklePatcher:
    """A utility class for safely pickling objects with unpicklable components.

    This class provides methods to recursively pickle objects, replacing any
    components that can't be pickled with placeholder objects.
    """

    # Class-level cache of unpicklable types
    _unpicklable_types = set()

    @staticmethod
    def dumps(obj, protocol=None, max_depth=100, **kwargs):
        """Safely pickle an object, replacing unpicklable parts with placeholders.

        Args:
            obj: The object to pickle
            protocol: The pickle protocol version to use
            max_depth: Maximum recursion depth
            **kwargs: Additional arguments for pickle/dill.dumps

        Returns:
            bytes: Pickled data with placeholders for unpicklable objects

        """
        return PicklePatcher._recursive_pickle(obj, max_depth, path=[], protocol=protocol, **kwargs)

    @staticmethod
    def loads(pickled_data):
        """Unpickle data that may contain placeholders.

        Args:
            pickled_data: Pickled data with possible placeholders

        Returns:
            The unpickled object with placeholders for unpicklable parts

        """
        # We use dill for loading since it can handle everything pickle can
        return dill.loads(pickled_data)

    @staticmethod
    def _create_placeholder(obj, error_msg, path):
        """Create a placeholder for an unpicklable object.

        Args:
            obj: The original unpicklable object
            error_msg: Error message explaining why it couldn't be pickled
            path: Path to this object in the object graph

        Returns:
            PicklePlaceholder: A placeholder object

        """
        obj_type = type(obj)
        try:
            obj_str = str(obj)[:100] if hasattr(obj, "__str__") else f"<unprintable object of type {obj_type.__name__}>"
        except Exception:
            obj_str = f"<unprintable object of type {obj_type.__name__}>"

        print(f"Creating placeholder for {obj_type.__name__} at path {'->'.join(path) or 'root'}: {error_msg}")

        placeholder = PicklePlaceholder(obj_type.__name__, obj_str, error_msg, path)

        # Add this type to our known unpicklable types cache
        PicklePatcher._unpicklable_types.add(obj_type)
        return placeholder

    @staticmethod
    def _pickle(obj, path=None, protocol=None, **kwargs):
        """Try to pickle an object using pickle first, then dill. If both fail, create a placeholder.

        Args:
            obj: The object to pickle
            path: Path to this object in the object graph
            protocol: The pickle protocol version to use
            **kwargs: Additional arguments for pickle/dill.dumps

        Returns:
            tuple: (success, result) where success is a boolean and result is either:
                - Pickled bytes if successful
                - Error message if not successful

        """
        # Try standard pickle first
        try:
            return True, pickle.dumps(obj, protocol=protocol, **kwargs)
        except (pickle.PickleError, TypeError, AttributeError, ValueError):
            # Then try dill (which is more powerful)
            try:
                return True, dill.dumps(obj, protocol=protocol, **kwargs)
            except (dill.PicklingError, TypeError, AttributeError, ValueError) as e:
                return False, str(e)

    @staticmethod
    def _recursive_pickle(obj, max_depth, path=None, protocol=None, **kwargs):
        """Recursively try to pickle an object, replacing unpicklable parts with placeholders.

        Args:
            obj: The object to pickle
            max_depth: Maximum recursion depth
            path: Current path in the object graph
            protocol: The pickle protocol version to use
            **kwargs: Additional arguments for pickle/dill.dumps

        Returns:
            bytes: Pickled data with placeholders for unpicklable objects

        """
        if path is None:
            path = []

        obj_type = type(obj)

        # Check if this type is known to be unpicklable
        if obj_type in PicklePatcher._unpicklable_types:
            placeholder = PicklePatcher._create_placeholder(obj, "Known unpicklable type", path)
            return dill.dumps(placeholder, protocol=protocol, **kwargs)

        # Check for max depth
        if max_depth <= 0:
            placeholder = PicklePatcher._create_placeholder(obj, "Max recursion depth exceeded", path)
            return dill.dumps(placeholder, protocol=protocol, **kwargs)

        # Try standard pickling
        success, result = PicklePatcher._pickle(obj, path, protocol, **kwargs)
        if success:
            return result

        error_msg = result  # Error message from the failed pickling attempt

        # Handle different container types
        if isinstance(obj, dict):
            return PicklePatcher._handle_dict(obj, max_depth, error_msg, path, protocol=protocol, **kwargs)
        elif isinstance(obj, (list, tuple, set)):
            return PicklePatcher._handle_sequence(obj, max_depth, error_msg, path, protocol=protocol, **kwargs)
        elif hasattr(obj, "__dict__"):
            result = PicklePatcher._handle_object(obj, max_depth, error_msg, path, protocol=protocol, **kwargs)

            # If this was a failure, add the type to the cache
            unpickled = dill.loads(result)
            if isinstance(unpickled, PicklePlaceholder):
                PicklePatcher._unpicklable_types.add(obj_type)
            return result

        # For other unpicklable objects, use a placeholder
        placeholder = PicklePatcher._create_placeholder(obj, error_msg, path)
        return dill.dumps(placeholder, protocol=protocol, **kwargs)

    @staticmethod
    def _handle_dict(obj_dict, max_depth, error_msg, path, protocol=None, **kwargs):
        """Handle pickling for dictionary objects.

        Args:
            obj_dict: The dictionary to pickle
            max_depth: Maximum recursion depth
            error_msg: Error message from the original pickling attempt
            path: Current path in the object graph
            protocol: The pickle protocol version to use
            **kwargs: Additional arguments for pickle/dill.dumps

        Returns:
            bytes: Pickled data with placeholders for unpicklable objects

        """
        if not isinstance(obj_dict, dict):
            placeholder = PicklePatcher._create_placeholder(
                obj_dict, f"Expected a dictionary, got {type(obj_dict).__name__}", path
            )
            return dill.dumps(placeholder, protocol=protocol, **kwargs)

        result = {}

        for key, value in obj_dict.items():
            # Process the key
            key_success, key_result = PicklePatcher._pickle(key, path, protocol, **kwargs)
            if key_success:
                key_result = key
            else:
                # If the key can't be pickled, use a string representation
                try:
                    key_str = str(key)[:50]
                except Exception:
                    key_str = f"<unprintable key of type {type(key).__name__}>"
                key_result = f"<unpicklable_key:{key_str}>"

            # Process the value
            value_path = path + [f"[{repr(key)[:20]}]"]
            value_success, value_bytes = PicklePatcher._pickle(value, value_path, protocol, **kwargs)

            if value_success:
                value_result = value
            else:
                # Try recursive pickling for the value
                try:
                    value_bytes = PicklePatcher._recursive_pickle(
                        value, max_depth - 1, value_path, protocol=protocol, **kwargs
                    )
                    value_result = dill.loads(value_bytes)
                except Exception as inner_e:
                    value_result = PicklePatcher._create_placeholder(value, str(inner_e), value_path)

            result[key_result] = value_result

        return dill.dumps(result, protocol=protocol, **kwargs)

    @staticmethod
    def _handle_sequence(obj_seq, max_depth, error_msg, path, protocol=None, **kwargs):
        """Handle pickling for sequence types (list, tuple, set).

        Args:
            obj_seq: The sequence to pickle
            max_depth: Maximum recursion depth
            error_msg: Error message from the original pickling attempt
            path: Current path in the object graph
            protocol: The pickle protocol version to use
            **kwargs: Additional arguments for pickle/dill.dumps

        Returns:
            bytes: Pickled data with placeholders for unpicklable objects

        """
        result = []

        for i, item in enumerate(obj_seq):
            item_path = path + [f"[{i}]"]

            # Try to pickle the item directly
            success, _ = PicklePatcher._pickle(item, item_path, protocol, **kwargs)
            if success:
                result.append(item)
                continue

            # If we couldn't pickle directly, try recursively
            try:
                item_bytes = PicklePatcher._recursive_pickle(
                    item, max_depth - 1, item_path, protocol=protocol, **kwargs
                )
                result.append(dill.loads(item_bytes))
            except Exception as inner_e:
                # If recursive pickling fails, use a placeholder
                placeholder = PicklePatcher._create_placeholder(item, str(inner_e), item_path)
                result.append(placeholder)

        # Convert back to the original type
        if isinstance(obj_seq, tuple):
            result = tuple(result)
        elif isinstance(obj_seq, set):
            # Try to create a set from the result
            try:
                result = set(result)
            except Exception:
                # If we can't create a set (unhashable items), keep it as a list
                pass

        return dill.dumps(result, protocol=protocol, **kwargs)

    @staticmethod
    def _handle_object(obj, max_depth, error_msg, path, protocol=None, **kwargs):
        """Handle pickling for custom objects with __dict__.

        Args:
            obj: The object to pickle
            max_depth: Maximum recursion depth
            error_msg: Error message from the original pickling attempt
            path: Current path in the object graph
            protocol: The pickle protocol version to use
            **kwargs: Additional arguments for pickle/dill.dumps

        Returns:
            bytes: Pickled data with placeholders for unpicklable objects

        """
        # Try to create a new instance of the same class
        try:
            # First try to create an empty instance
            new_obj = object.__new__(type(obj))

            # Handle __dict__ attributes if they exist
            if hasattr(obj, "__dict__"):
                for attr_name, attr_value in obj.__dict__.items():
                    attr_path = path + [attr_name]

                    # Try to pickle directly first
                    success, _ = PicklePatcher._pickle(attr_value, attr_path, protocol, **kwargs)
                    if success:
                        setattr(new_obj, attr_name, attr_value)
                        continue

                    # If direct pickling fails, try recursive pickling
                    try:
                        attr_bytes = PicklePatcher._recursive_pickle(
                            attr_value, max_depth - 1, attr_path, protocol=protocol, **kwargs
                        )
                        setattr(new_obj, attr_name, dill.loads(attr_bytes))
                    except Exception as inner_e:
                        # Use placeholder for unpicklable attribute
                        placeholder = PicklePatcher._create_placeholder(attr_value, str(inner_e), attr_path)
                        setattr(new_obj, attr_name, placeholder)

            # Try to pickle the patched object
            success, result = PicklePatcher._pickle(new_obj, path, protocol, **kwargs)
            if success:
                return result
            # Fall through to placeholder creation
        except Exception:
            pass  # Fall through to placeholder creation

        # If we get here, just use a placeholder
        placeholder = PicklePatcher._create_placeholder(obj, error_msg, path)
        return dill.dumps(placeholder, protocol=protocol, **kwargs)
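A minimal sketch of the intended round-trip (the lock is just one example of an object that neither pickle nor dill can serialize):

    import threading
    from codeflash.picklepatch.pickle_patcher import PicklePatcher

    data = {"numbers": [1, 2, 3], "lock": threading.Lock()}
    blob = PicklePatcher.dumps(data)
    restored = PicklePatcher.loads(blob)
    print(restored["numbers"])   # [1, 2, 3] -- picklable parts round-trip intact
    restored["lock"].acquire()   # raises PicklePlaceholderAccessError, naming the path and reason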
71	codeflash/picklepatch/pickle_placeholder.py	Normal file
@ -0,0 +1,71 @@
|
|||
class PicklePlaceholderAccessError(Exception):
    """Custom exception raised when attempting to access an unpicklable object."""


class PicklePlaceholder:
    """A placeholder for an object that couldn't be pickled.

    When unpickled, any attempt to access attributes or call methods on this
    placeholder will raise a PicklePlaceholderAccessError.
    """

    def __init__(self, obj_type, obj_str, error_msg, path=None):
        """Initialize a placeholder for an unpicklable object.

        Args:
            obj_type (str): The type name of the original object
            obj_str (str): String representation of the original object
            error_msg (str): The error message that occurred during pickling
            path (list, optional): Path to this object in the original object graph

        """
        # Store these directly in __dict__ to avoid __getattr__ recursion
        self.__dict__["obj_type"] = obj_type
        self.__dict__["obj_str"] = obj_str
        self.__dict__["error_msg"] = error_msg
        self.__dict__["path"] = path if path is not None else []

    def __getattr__(self, name):
        """Raise a custom error when any attribute is accessed."""
        path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root object"
        raise PicklePlaceholderAccessError(
            f"Attempt to access unpickleable object: Cannot access attribute '{name}' on unpicklable object at {path_str}. "
            f"Original type: {self.__dict__['obj_type']}. Error: {self.__dict__['error_msg']}"
        )

    def __setattr__(self, name, value):
        """Prevent setting attributes."""
        self.__getattr__(name)  # This will raise our custom error

    def __call__(self, *args, **kwargs):
        """Raise a custom error when the object is called."""
        path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root object"
        raise PicklePlaceholderAccessError(
            f"Attempt to access unpickleable object: Cannot call unpicklable object at {path_str}. "
            f"Original type: {self.__dict__['obj_type']}. Error: {self.__dict__['error_msg']}"
        )

    def __repr__(self):
        """Return a string representation of the placeholder."""
        try:
            path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root"
            return f"<PicklePlaceholder at {path_str}: {self.__dict__['obj_type']} {self.__dict__['obj_str']}>"
        except Exception:
            return "<PicklePlaceholder: (error displaying details)>"

    def __str__(self):
        """Return a string representation of the placeholder."""
        return self.__repr__()

    def __reduce__(self):
        """Make sure pickling of the placeholder itself works correctly."""
        return (
            PicklePlaceholder,
            (
                self.__dict__["obj_type"],
                self.__dict__["obj_str"],
                self.__dict__["error_msg"],
                self.__dict__["path"]
            )
        )
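A short usage sketch of the placeholder's contract, derived directly from the class above: attribute access and calls raise, while repr and re-pickling (via __reduce__) keep working. The argument values are illustrative:

import pickle

placeholder = PicklePlaceholder("socket.socket", "<socket ...>", "cannot pickle socket", ["raw_socket"])

repr(placeholder)  # "<PicklePlaceholder at raw_socket: socket.socket <socket ...>>"

try:
    placeholder.recv(1024)  # any attribute access raises
except PicklePlaceholderAccessError as e:
    print(e)

# __reduce__ makes the placeholder itself round-trippable through pickle
assert isinstance(pickle.loads(pickle.dumps(placeholder)), PicklePlaceholder)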
@ -77,6 +77,7 @@ def check_create_pr(
            speedup_pct=explanation.speedup_pct,
            winning_behavioral_test_results=explanation.winning_behavioral_test_results,
            winning_benchmarking_test_results=explanation.winning_benchmarking_test_results,
            benchmark_details=explanation.benchmark_details
        ),
        existing_tests=existing_tests_source,
        generated_tests=generated_original_test_source,

@ -123,6 +124,7 @@ def check_create_pr(
            speedup_pct=explanation.speedup_pct,
            winning_behavioral_test_results=explanation.winning_behavioral_test_results,
            winning_benchmarking_test_results=explanation.winning_benchmarking_test_results,
            benchmark_details=explanation.benchmark_details
        ),
        existing_tests=existing_tests_source,
        generated_tests=generated_original_test_source,
@ -1,9 +1,16 @@
from __future__ import annotations

import shutil
from io import StringIO
from pathlib import Path
from typing import Optional, cast

from pydantic.dataclasses import dataclass
from rich.console import Console
from rich.table import Table

from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.models.models import TestResults
from codeflash.models.models import BenchmarkDetail, TestResults


@dataclass(frozen=True, config={"arbitrary_types_allowed": True})

@ -15,6 +22,7 @@ class Explanation:
    best_runtime_ns: int
    function_name: str
    file_path: Path
    benchmark_details: Optional[list[BenchmarkDetail]] = None

    @property
    def perf_improvement_line(self) -> str:

@ -37,16 +45,55 @@ class Explanation:
        # TODO: Sometimes the explanation says something similar to "This is the code that was optimized", remove such parts
        original_runtime_human = humanize_runtime(self.original_runtime_ns)
        best_runtime_human = humanize_runtime(self.best_runtime_ns)
        benchmark_info = ""

        if self.benchmark_details:
            # Get terminal width (or use a reasonable default if detection fails)
            try:
                terminal_width = int(shutil.get_terminal_size().columns * 0.9)
            except Exception:
                terminal_width = 200  # Fallback width

            # Create a rich table for better formatting
            table = Table(title="Benchmark Performance Details", width=terminal_width, show_lines=True)

            # Add columns - split Benchmark File and Function into separate columns
            # Using proportional width for benchmark file column (40% of terminal width)
            benchmark_col_width = max(int(terminal_width * 0.4), 40)
            table.add_column("Benchmark Module Path", style="cyan", width=benchmark_col_width, overflow="fold")
            table.add_column("Test Function", style="cyan", overflow="fold")
            table.add_column("Original Runtime", style="magenta", justify="right")
            table.add_column("Expected New Runtime", style="green", justify="right")
            table.add_column("Speedup", style="red", justify="right")

            # Add rows with split data
            for detail in self.benchmark_details:
                # Split the benchmark name and test function
                benchmark_name = detail.benchmark_name
                test_function = detail.test_function

                table.add_row(
                    benchmark_name,
                    test_function,
                    f"{detail.original_timing}",
                    f"{detail.expected_new_timing}",
                    f"{detail.speedup_percent:.2f}%"
                )
            # Convert table to string
            string_buffer = StringIO()
            console = Console(file=string_buffer, width=terminal_width)
            console.print(table)
            benchmark_info = cast(StringIO, console.file).getvalue() + "\n"  # Cast for mypy

        return (
            f"Optimized {self.function_name} in {self.file_path}\n"
            f"{self.perf_improvement_line}\n"
            f"Runtime went down from {original_runtime_human} to {best_runtime_human} \n\n"
            + "Explanation:\n"
            + self.raw_explanation_message
            + " \n\n"
            + "The new optimized code was tested for correctness. The results are listed below.\n"
            + f"{TestResults.report_to_string(self.winning_behavioral_test_results.get_test_pass_fail_report_by_type())}\n"
            f"Optimized {self.function_name} in {self.file_path}\n"
            f"{self.perf_improvement_line}\n"
            f"Runtime went down from {original_runtime_human} to {best_runtime_human} \n\n"
            + (benchmark_info if benchmark_info else "")
            + self.raw_explanation_message
            + " \n\n"
            + "The new optimized code was tested for correctness. The results are listed below.\n"
            + f"{TestResults.report_to_string(self.winning_behavioral_test_results.get_test_pass_fail_report_by_type())}\n"
        )

    def explanation_message(self) -> str:
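The table-to-string trick above is worth calling out: rich renders to any file-like object, so pointing a Console at a StringIO captures the formatted table as plain text for splicing into the explanation message. A standalone sketch of the same pattern, using only rich's public API (row values illustrative):

from io import StringIO

from rich.console import Console
from rich.table import Table

table = Table(title="Benchmark Performance Details", show_lines=True)
table.add_column("Benchmark Module Path", style="cyan", overflow="fold")
table.add_column("Speedup", style="red", justify="right")
table.add_row("tests.benchmarks.test_sort", "12.34%")

buffer = StringIO()
console = Console(file=buffer, width=120)
console.print(table)

rendered = buffer.getvalue()  # plain string, ready to embed in a message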
@ -10,6 +10,7 @@ from typing import Any
import sentry_sdk

from codeflash.cli_cmds.console import logger
from codeflash.picklepatch.pickle_placeholder import PicklePlaceholderAccessError

try:
    import numpy as np

@ -90,6 +91,11 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
            return True
        return math.isclose(orig, new)
    if isinstance(orig, BaseException):
        if isinstance(orig, PicklePlaceholderAccessError) or isinstance(new, PicklePlaceholderAccessError):
            # If this error was raised, there was an attempt to access the PicklePlaceholder, which represents an unpicklable object.
            # The test results should be rejected, as the behavior of the unpicklable object is unknown.
            logger.debug("Unable to verify behavior of unpickleable object in replay test")
            return False
        # if str(orig) != str(new):
        #     return False
        # compare the attributes of the two exception objects to determine if they are equivalent.
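In effect, a placeholder-access error on either side is an automatic mismatch, even when both runs raised the same-looking error. A minimal sketch of the resulting behavior (the comparator module path is assumed here; only the logic shown in the hunk above is relied on):

from codeflash.picklepatch.pickle_placeholder import PicklePlaceholderAccessError
from codeflash.verification.comparator import comparator  # module path assumed from this diff

# Two identical-looking errors still compare as unequal: the underlying
# unpicklable object's behavior could not be verified.
err_a = PicklePlaceholderAccessError("cannot access attribute 'recv'")
err_b = PicklePlaceholderAccessError("cannot access attribute 'recv'")
assert comparator(err_a, err_b) is False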
@ -75,3 +75,4 @@ class TestConfig:
    # or for unittest - project_root_from_module_root(args.tests_root, pyproject_file_path)
    concolic_test_root_dir: Optional[Path] = None
    pytest_cmd: str = "pytest"
    benchmark_tests_root: Optional[Path] = None
@ -116,6 +116,7 @@ types-openpyxl = ">=3.1.5.20241020"
types-regex = ">=2024.9.11.20240912"
types-python-dateutil = ">=2.9.0.20241003"
pytest-cov = "^6.0.0"
pytest-benchmark = ">=5.1.0"
types-gevent = "^24.11.0.20241230"
types-greenlet = "^3.1.0.20241221"
types-pexpect = "^4.9.0.20241208"

@ -219,6 +220,7 @@ initial-content = """
[tool.codeflash]
module-root = "codeflash"
tests-root = "tests"
benchmarks-root = "tests/benchmarks"
test-framework = "pytest"
formatter-cmds = [
    "uvx ruff check --exit-zero --fix $file",
31
tests/benchmarks/test_benchmark_code_extract_code_context.py
Normal file

@ -0,0 +1,31 @@
from argparse import Namespace
from pathlib import Path

from codeflash.context.code_context_extractor import get_code_optimization_context
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent
from codeflash.optimization.optimizer import Optimizer


def test_benchmark_extract(benchmark) -> None:
    file_path = Path(__file__).parent.parent.parent.resolve() / "codeflash"
    opt = Optimizer(
        Namespace(
            project_root=file_path.resolve(),
            disable_telemetry=True,
            tests_root=(file_path / "tests").resolve(),
            test_framework="pytest",
            pytest_cmd="pytest",
            experiment_id=None,
            test_project_root=Path.cwd(),
        )
    )
    function_to_optimize = FunctionToOptimize(
        function_name="replace_function_and_helpers_with_optimized_code",
        file_path=file_path / "optimization" / "function_optimizer.py",
        parents=[FunctionParent(name="FunctionOptimizer", type="ClassDef")],
        starting_line=None,
        ending_line=None,
    )

    benchmark(get_code_optimization_context, function_to_optimize, opt.args.project_root)
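These tests rely on pytest-benchmark's `benchmark` fixture (the plugin is added to the dev dependencies above): `benchmark(fn, *args)` calls the target repeatedly and records timing statistics. A minimal standalone sketch of the pattern:

# Requires the pytest-benchmark plugin; run with pytest.
def test_benchmark_sorted(benchmark) -> None:
    data = list(range(1000, 0, -1))
    # benchmark() invokes the callable many times and reports min/mean/stddev
    result = benchmark(sorted, data)
    assert result[0] == 1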
26
tests/benchmarks/test_benchmark_discover_unit_tests.py
Normal file

@ -0,0 +1,26 @@
from pathlib import Path

from codeflash.discovery.discover_unit_tests import discover_unit_tests
from codeflash.verification.verification_utils import TestConfig


def test_benchmark_code_to_optimize_test_discovery(benchmark) -> None:
    project_path = Path(__file__).parent.parent.parent.resolve() / "code_to_optimize"
    tests_path = project_path / "tests" / "pytest"
    test_config = TestConfig(
        tests_root=tests_path,
        project_root_path=project_path,
        test_framework="pytest",
        tests_project_rootdir=tests_path.parent,
    )
    benchmark(discover_unit_tests, test_config)


def test_benchmark_codeflash_test_discovery(benchmark) -> None:
    project_path = Path(__file__).parent.parent.parent.resolve() / "codeflash"
    tests_path = project_path / "tests"
    test_config = TestConfig(
        tests_root=tests_path,
        project_root_path=project_path,
        test_framework="pytest",
        tests_project_rootdir=tests_path.parent,
    )
    benchmark(discover_unit_tests, test_config)
71
tests/benchmarks/test_benchmark_merge_test_results.py
Normal file

@ -0,0 +1,71 @@
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType
from codeflash.verification.parse_test_output import merge_test_results


def generate_test_invocations(count=100):
    """Generate a set number of test invocations for benchmarking."""
    test_results_xml = TestResults()
    test_results_bin = TestResults()

    # Generate test invocations in a loop
    for i in range(count):
        iteration_id = str(i * 3 + 5)  # Generate unique iteration IDs

        # XML results - some with None runtime
        test_results_xml.add(
            FunctionTestInvocation(
                id=InvocationId(
                    test_module_path="code_to_optimize.tests.unittest.test_bubble_sort",
                    test_class_name="TestPigLatin",
                    test_function_name="test_sort",
                    function_getting_tested="sorter",
                    iteration_id=iteration_id,
                ),
                file_name="/tmp/tests/unittest/test_bubble_sort__perfinstrumented.py",
                did_pass=True,
                runtime=None if i % 3 == 0 else i * 100,  # Vary runtime values
                test_framework="unittest",
                test_type=TestType.EXISTING_UNIT_TEST,
                return_value=None,
                timed_out=False,
                loop_index=i,
            )
        )

        # Binary results - with actual runtime values
        test_results_bin.add(
            FunctionTestInvocation(
                id=InvocationId(
                    test_module_path="code_to_optimize.tests.unittest.test_bubble_sort",
                    test_class_name="TestPigLatin",
                    test_function_name="test_sort",
                    function_getting_tested="sorter",
                    iteration_id=iteration_id,
                ),
                file_name="/tmp/tests/unittest/test_bubble_sort__perfinstrumented.py",
                did_pass=True,
                runtime=500 + i * 20,  # Generate varying runtime values
                test_framework="unittest",
                test_type=TestType.EXISTING_UNIT_TEST,
                return_value=None,
                timed_out=False,
                loop_index=i,
            )
        )

    return test_results_xml, test_results_bin


def run_merge_benchmark(count=100):
    test_results_xml, test_results_bin = generate_test_invocations(count)

    # Perform the merge operation that will be benchmarked
    merge_test_results(
        xml_test_results=test_results_xml,
        bin_test_results=test_results_bin,
        test_framework="unittest"
    )


def test_benchmark_merge_test_results(benchmark):
    benchmark(run_merge_benchmark, 1000)  # Benchmark with 1000 test invocations
26
tests/scripts/end_to_end_test_benchmark_sort.py
Normal file

@ -0,0 +1,26 @@
import os
import pathlib

from end_to_end_test_utilities import CoverageExpectation, TestConfig, run_codeflash_command, run_with_retries


def run_test(expected_improvement_pct: int) -> bool:
    cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize").resolve()
    config = TestConfig(
        file_path=pathlib.Path("bubble_sort.py"),
        function_name="sorter",
        benchmarks_root=cwd / "tests" / "pytest" / "benchmarks",
        test_framework="pytest",
        min_improvement_x=1.0,
        coverage_expectations=[
            CoverageExpectation(
                function_name="sorter", expected_coverage=100.0, expected_lines=[2, 3, 4, 5, 6, 7, 8, 9, 10]
            )
        ],
    )

    return run_codeflash_command(cwd, config, expected_improvement_pct)


if __name__ == "__main__":
    exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 5))))
@ -26,6 +26,7 @@ class TestConfig:
    min_improvement_x: float = 0.1
    trace_mode: bool = False
    coverage_expectations: list[CoverageExpectation] = field(default_factory=list)
    benchmarks_root: Optional[pathlib.Path] = None


def clear_directory(directory_path: str | pathlib.Path) -> None:

@ -85,8 +86,8 @@ def run_codeflash_command(
    path_to_file = cwd / config.file_path
    file_contents = path_to_file.read_text("utf-8")
    test_root = cwd / "tests" / (config.test_framework or "")
    command = build_command(cwd, config, test_root)

    command = build_command(cwd, config, test_root, config.benchmarks_root if config.benchmarks_root else None)
    process = subprocess.Popen(
        command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, cwd=str(cwd), env=os.environ.copy()
    )

@ -116,7 +117,7 @@ def run_codeflash_command(
    return validated


def build_command(cwd: pathlib.Path, config: TestConfig, test_root: pathlib.Path) -> list[str]:
def build_command(cwd: pathlib.Path, config: TestConfig, test_root: pathlib.Path, benchmarks_root: pathlib.Path | None = None) -> list[str]:
    python_path = "../../../codeflash/main.py" if "code_directories" in str(cwd) else "../codeflash/main.py"

    base_command = ["python", python_path, "--file", config.file_path, "--no-pr"]

@ -127,7 +128,8 @@ def build_command(cwd: pathlib.Path, config: TestConfig, test_root: pathlib.Path
    base_command.extend(
        ["--test-framework", config.test_framework, "--tests-root", str(test_root), "--module-root", str(cwd)]
    )

    if benchmarks_root:
        base_command.extend(["--benchmark", "--benchmarks-root", str(benchmarks_root)])
    return base_command
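With that change, a configured benchmarks root threads all the way through to the CLI flags. A hypothetical sketch of what build_command produces when benchmarks_root is set (paths illustrative, mirroring only the logic shown above):

import pathlib

# Illustrative only - mirrors the list construction in build_command() above.
test_root = pathlib.Path("code_to_optimize/tests/pytest")
benchmarks_root = test_root / "benchmarks"

base_command = ["python", "../codeflash/main.py", "--file", "bubble_sort.py", "--no-pr",
                "--test-framework", "pytest", "--tests-root", str(test_root)]
if benchmarks_root:
    base_command.extend(["--benchmark", "--benchmarks-root", str(benchmarks_root)])
# -> [..., "--benchmark", "--benchmarks-root", "code_to_optimize/tests/pytest/benchmarks"]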
15
tests/test_codeflash_trace_decorator.py
Normal file

@ -0,0 +1,15 @@
from pathlib import Path

from codeflash.benchmarking.codeflash_trace import codeflash_trace
from codeflash.code_utils.code_utils import get_run_tmp_file


@codeflash_trace
def example_function(arr):
    arr.sort()
    return arr


def test_codeflash_trace_decorator():
    arr = [3, 1, 2]
    result = example_function(arr)
    # cleanup test trace file using Path
    assert result == [1, 2, 3]
547
tests/test_instrument_codeflash_trace.py
Normal file

@ -0,0 +1,547 @@
from __future__ import annotations

import tempfile
from pathlib import Path

from codeflash.benchmarking.instrument_codeflash_trace import (
    add_codeflash_decorator_to_code,
    instrument_codeflash_trace_decorator,
)
from codeflash.discovery.functions_to_optimize import FunctionParent, FunctionToOptimize


def test_add_decorator_to_normal_function() -> None:
    """Test adding decorator to a normal function."""
    code = """
def normal_function():
    return "Hello, World!"
"""

    fto = FunctionToOptimize(
        function_name="normal_function",
        file_path=Path("dummy_path.py"),
        parents=[]
    )

    modified_code = add_codeflash_decorator_to_code(
        code=code,
        functions_to_optimize=[fto]
    )

    expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@codeflash_trace
def normal_function():
    return "Hello, World!"
"""

    assert modified_code.strip() == expected_code.strip()


def test_add_decorator_to_normal_method() -> None:
    """Test adding decorator to a normal method."""
    code = """
class TestClass:
    def normal_method(self):
        return "Hello from method"
"""

    fto = FunctionToOptimize(
        function_name="normal_method",
        file_path=Path("dummy_path.py"),
        parents=[FunctionParent(name="TestClass", type="ClassDef")]
    )

    modified_code = add_codeflash_decorator_to_code(
        code=code,
        functions_to_optimize=[fto]
    )

    expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
class TestClass:
    @codeflash_trace
    def normal_method(self):
        return "Hello from method"
"""

    assert modified_code.strip() == expected_code.strip()


def test_add_decorator_to_classmethod() -> None:
    """Test adding decorator to a classmethod."""
    code = """
class TestClass:
    @classmethod
    def class_method(cls):
        return "Hello from classmethod"
"""

    fto = FunctionToOptimize(
        function_name="class_method",
        file_path=Path("dummy_path.py"),
        parents=[FunctionParent(name="TestClass", type="ClassDef")]
    )

    modified_code = add_codeflash_decorator_to_code(
        code=code,
        functions_to_optimize=[fto]
    )

    expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
class TestClass:
    @classmethod
    @codeflash_trace
    def class_method(cls):
        return "Hello from classmethod"
"""

    assert modified_code.strip() == expected_code.strip()


def test_add_decorator_to_staticmethod() -> None:
    """Test adding decorator to a staticmethod."""
    code = """
class TestClass:
    @staticmethod
    def static_method():
        return "Hello from staticmethod"
"""

    fto = FunctionToOptimize(
        function_name="static_method",
        file_path=Path("dummy_path.py"),
        parents=[FunctionParent(name="TestClass", type="ClassDef")]
    )

    modified_code = add_codeflash_decorator_to_code(
        code=code,
        functions_to_optimize=[fto]
    )

    expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
class TestClass:
    @staticmethod
    @codeflash_trace
    def static_method():
        return "Hello from staticmethod"
"""

    assert modified_code.strip() == expected_code.strip()


def test_add_decorator_to_init_function() -> None:
    """Test adding decorator to an __init__ function."""
    code = """
class TestClass:
    def __init__(self, value):
        self.value = value
"""

    fto = FunctionToOptimize(
        function_name="__init__",
        file_path=Path("dummy_path.py"),
        parents=[FunctionParent(name="TestClass", type="ClassDef")]
    )

    modified_code = add_codeflash_decorator_to_code(
        code=code,
        functions_to_optimize=[fto]
    )

    expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
class TestClass:
    @codeflash_trace
    def __init__(self, value):
        self.value = value
"""

    assert modified_code.strip() == expected_code.strip()


def test_add_decorator_with_multiple_decorators() -> None:
    """Test adding decorator to a function with multiple existing decorators."""
    code = """
class TestClass:
    @property
    @other_decorator
    def property_method(self):
        return self._value
"""

    fto = FunctionToOptimize(
        function_name="property_method",
        file_path=Path("dummy_path.py"),
        parents=[FunctionParent(name="TestClass", type="ClassDef")]
    )

    modified_code = add_codeflash_decorator_to_code(
        code=code,
        functions_to_optimize=[fto]
    )

    expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
class TestClass:
    @property
    @other_decorator
    @codeflash_trace
    def property_method(self):
        return self._value
"""

    assert modified_code.strip() == expected_code.strip()


def test_add_decorator_to_function_in_multiple_classes() -> None:
    """Test that only the right class's method gets the decorator."""
    code = """
class TestClass:
    def test_method(self):
        return "This should get decorated"

class OtherClass:
    def test_method(self):
        return "This should NOT get decorated"
"""

    fto = FunctionToOptimize(
        function_name="test_method",
        file_path=Path("dummy_path.py"),
        parents=[FunctionParent(name="TestClass", type="ClassDef")]
    )

    modified_code = add_codeflash_decorator_to_code(
        code=code,
        functions_to_optimize=[fto]
    )

    expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
class TestClass:
    @codeflash_trace
    def test_method(self):
        return "This should get decorated"

class OtherClass:
    def test_method(self):
        return "This should NOT get decorated"
"""

    assert modified_code.strip() == expected_code.strip()


def test_add_decorator_to_nonexistent_function() -> None:
    """Test that code remains unchanged when the function doesn't exist."""
    code = """
def existing_function():
    return "This exists"
"""

    fto = FunctionToOptimize(
        function_name="nonexistent_function",
        file_path=Path("dummy_path.py"),
        parents=[]
    )

    modified_code = add_codeflash_decorator_to_code(
        code=code,
        functions_to_optimize=[fto]
    )

    # Code should remain unchanged
    assert modified_code.strip() == code.strip()


def test_add_decorator_to_multiple_functions() -> None:
    """Test adding decorator to multiple functions."""
    code = """
def function_one():
    return "First function"

class TestClass:
    def method_one(self):
        return "First method"

    def method_two(self):
        return "Second method"

def function_two():
    return "Second function"
"""

    functions_to_optimize = [
        FunctionToOptimize(
            function_name="function_one",
            file_path=Path("dummy_path.py"),
            parents=[]
        ),
        FunctionToOptimize(
            function_name="method_two",
            file_path=Path("dummy_path.py"),
            parents=[FunctionParent(name="TestClass", type="ClassDef")]
        ),
        FunctionToOptimize(
            function_name="function_two",
            file_path=Path("dummy_path.py"),
            parents=[]
        )
    ]

    modified_code = add_codeflash_decorator_to_code(
        code=code,
        functions_to_optimize=functions_to_optimize
    )

    expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@codeflash_trace
def function_one():
    return "First function"

class TestClass:
    def method_one(self):
        return "First method"

    @codeflash_trace
    def method_two(self):
        return "Second method"

@codeflash_trace
def function_two():
    return "Second function"
"""

    assert modified_code.strip() == expected_code.strip()


def test_instrument_codeflash_trace_decorator_single_file() -> None:
    """Test instrumenting codeflash trace decorator on a single file."""
    with tempfile.TemporaryDirectory() as temp_dir:
        # Create a test Python file
        test_file_path = Path(temp_dir) / "test_module.py"
        test_file_content = """
def function_one():
    return "First function"

class TestClass:
    def method_one(self):
        return "First method"

    def method_two(self):
        return "Second method"

def function_two():
    return "Second function"
"""
        test_file_path.write_text(test_file_content, encoding="utf-8")

        # Define functions to optimize
        functions_to_optimize = [
            FunctionToOptimize(
                function_name="function_one",
                file_path=test_file_path,
                parents=[]
            ),
            FunctionToOptimize(
                function_name="method_two",
                file_path=test_file_path,
                parents=[FunctionParent(name="TestClass", type="ClassDef")]
            )
        ]

        # Execute the function being tested
        instrument_codeflash_trace_decorator({test_file_path: functions_to_optimize})

        # Read the modified file
        modified_content = test_file_path.read_text(encoding="utf-8")

        # Define expected content (with isort applied)
        expected_content = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace


@codeflash_trace
def function_one():
    return "First function"

class TestClass:
    def method_one(self):
        return "First method"

    @codeflash_trace
    def method_two(self):
        return "Second method"

def function_two():
    return "Second function"
"""

        # Compare the modified content with expected content
        assert modified_content.strip() == expected_content.strip()


def test_instrument_codeflash_trace_decorator_multiple_files() -> None:
    """Test instrumenting codeflash trace decorator on multiple files."""
    with tempfile.TemporaryDirectory() as temp_dir:
        # Create first test Python file
        test_file_1_path = Path(temp_dir) / "module_a.py"
        test_file_1_content = """
def function_a():
    return "Function in module A"

class ClassA:
    def method_a(self):
        return "Method in ClassA"
"""
        test_file_1_path.write_text(test_file_1_content, encoding="utf-8")

        # Create second test Python file
        test_file_2_path = Path(temp_dir) / "module_b.py"
        test_file_2_content = """
def function_b():
    return "Function in module B"

class ClassB:
    @staticmethod
    def static_method_b():
        return "Static method in ClassB"
"""
        test_file_2_path.write_text(test_file_2_content, encoding="utf-8")

        # Define functions to optimize
        file_to_funcs_to_optimize = {
            test_file_1_path: [
                FunctionToOptimize(
                    function_name="function_a",
                    file_path=test_file_1_path,
                    parents=[]
                )
            ],
            test_file_2_path: [
                FunctionToOptimize(
                    function_name="static_method_b",
                    file_path=test_file_2_path,
                    parents=[FunctionParent(name="ClassB", type="ClassDef")]
                )
            ]
        }

        # Execute the function being tested
        instrument_codeflash_trace_decorator(file_to_funcs_to_optimize)

        # Read the modified files
        modified_content_1 = test_file_1_path.read_text(encoding="utf-8")
        modified_content_2 = test_file_2_path.read_text(encoding="utf-8")

        # Define expected content for first file (with isort applied)
        expected_content_1 = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace


@codeflash_trace
def function_a():
    return "Function in module A"

class ClassA:
    def method_a(self):
        return "Method in ClassA"
"""

        # Define expected content for second file (with isort applied)
        expected_content_2 = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace


def function_b():
    return "Function in module B"

class ClassB:
    @staticmethod
    @codeflash_trace
    def static_method_b():
        return "Static method in ClassB"
"""

        # Compare the modified content with expected content
        assert modified_content_1.strip() == expected_content_1.strip()
        assert modified_content_2.strip() == expected_content_2.strip()


def test_add_decorator_to_method_after_nested_class() -> None:
    """Test adding decorator to a method that appears after a nested class definition."""
    code = """
class OuterClass:
    class NestedClass:
        def nested_method(self):
            return "Hello from nested class method"

    def target_method(self):
        return "Hello from target method after nested class"
"""

    fto = FunctionToOptimize(
        function_name="target_method",
        file_path=Path("dummy_path.py"),
        parents=[FunctionParent(name="OuterClass", type="ClassDef")]
    )

    modified_code = add_codeflash_decorator_to_code(
        code=code,
        functions_to_optimize=[fto]
    )

    expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
class OuterClass:
    class NestedClass:
        def nested_method(self):
            return "Hello from nested class method"

    @codeflash_trace
    def target_method(self):
        return "Hello from target method after nested class"
"""

    assert modified_code.strip() == expected_code.strip()


def test_add_decorator_to_function_after_nested_function() -> None:
    """Test adding decorator to a function that appears after a function with a nested function."""
    code = """
def function_with_nested():
    def inner_function():
        return "Hello from inner function"

    return inner_function()

def target_function():
    return "Hello from target function after nested function"
"""

    fto = FunctionToOptimize(
        function_name="target_function",
        file_path=Path("dummy_path.py"),
        parents=[]
    )

    modified_code = add_codeflash_decorator_to_code(
        code=code,
        functions_to_optimize=[fto]
    )

    expected_code = """
from codeflash.benchmarking.codeflash_trace import codeflash_trace
def function_with_nested():
    def inner_function():
        return "Hello from inner function"

    return inner_function()

@codeflash_trace
def target_function():
    return "Hello from target function after nested function"
"""

    assert modified_code.strip() == expected_code.strip()
513
tests/test_pickle_patcher.py
Normal file

@ -0,0 +1,513 @@
import os
import pickle
import shutil
import socket
import sqlite3
from argparse import Namespace
from pathlib import Path

import dill
import pytest

from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin
from codeflash.benchmarking.replay_test import generate_replay_test
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
from codeflash.benchmarking.utils import validate_and_format_benchmark_table
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodePosition, TestFile, TestFiles, TestingMode, TestsInFile, TestType
from codeflash.optimization.optimizer import Optimizer
from codeflash.verification.equivalence import compare_test_results

try:
    import sqlalchemy
    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session

    HAS_SQLALCHEMY = True
except ImportError:
    HAS_SQLALCHEMY = False

from codeflash.picklepatch.pickle_patcher import PicklePatcher
from codeflash.picklepatch.pickle_placeholder import PicklePlaceholder, PicklePlaceholderAccessError


def test_picklepatch_simple_nested():
    """Test that a simple nested data structure pickles and unpickles correctly."""
    original_data = {
        "numbers": [1, 2, 3],
        "nested_dict": {"key": "value", "another": 42},
    }

    dumped = PicklePatcher.dumps(original_data)
    reloaded = PicklePatcher.loads(dumped)

    assert reloaded == original_data
    # Everything was pickleable, so no placeholders should appear.


def test_picklepatch_with_socket():
    """Test that a data structure containing a raw socket is replaced by
    PicklePlaceholder rather than raising an error.
    """
    # Create a pair of connected sockets instead of a single socket
    sock1, sock2 = socket.socketpair()

    data_with_socket = {
        "safe_value": 123,
        "raw_socket": sock1,
    }

    # Send a message through sock1, which can be received by sock2
    sock1.send(b"Hello, world!")
    received = sock2.recv(1024)
    assert received == b"Hello, world!"
    # Pickle the data structure containing the socket
    dumped = PicklePatcher.dumps(data_with_socket)
    reloaded = PicklePatcher.loads(dumped)

    # We expect "raw_socket" to be replaced by a placeholder
    assert isinstance(reloaded, dict)
    assert reloaded["safe_value"] == 123
    assert isinstance(reloaded["raw_socket"], PicklePlaceholder)

    # Attempting to use or access attributes => PicklePlaceholderAccessError
    with pytest.raises(PicklePlaceholderAccessError):
        reloaded["raw_socket"].recv(1024)

    # Clean up by closing both sockets
    sock1.close()
    sock2.close()


def test_picklepatch_deeply_nested():
    """Test that deep nesting with unpicklable objects works correctly."""
    # Create a deeply nested structure with an unpicklable object
    deep_nested = {
        "level1": {
            "level2": {
                "level3": {
                    "normal": "value",
                    "socket": socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                }
            }
        }
    }

    dumped = PicklePatcher.dumps(deep_nested)
    reloaded = PicklePatcher.loads(dumped)

    # We should be able to access the normal value
    assert reloaded["level1"]["level2"]["level3"]["normal"] == "value"

    # The socket should be replaced with a placeholder
    assert isinstance(reloaded["level1"]["level2"]["level3"]["socket"], PicklePlaceholder)


def test_picklepatch_class_with_unpicklable_attr():
    """Test that a class with an unpicklable attribute works correctly."""
    class TestClass:
        def __init__(self):
            self.normal = "normal value"
            self.unpicklable = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    obj = TestClass()

    dumped = PicklePatcher.dumps(obj)
    reloaded = PicklePatcher.loads(dumped)

    # Normal attribute should be preserved
    assert reloaded.normal == "normal value"

    # Unpicklable attribute should be replaced with a placeholder
    assert isinstance(reloaded.unpicklable, PicklePlaceholder)


def test_picklepatch_with_database_connection():
    """Test that a data structure containing a database connection is replaced
    by PicklePlaceholder rather than raising an error.
    """
    # SQLite connection - not pickleable
    conn = sqlite3.connect(":memory:")
    cursor = conn.cursor()

    data_with_db = {
        "description": "Database connection",
        "connection": conn,
        "cursor": cursor,
    }

    dumped = PicklePatcher.dumps(data_with_db)
    reloaded = PicklePatcher.loads(dumped)

    # Both connection and cursor should become placeholders
    assert isinstance(reloaded, dict)
    assert reloaded["description"] == "Database connection"
    assert isinstance(reloaded["connection"], PicklePlaceholder)
    assert isinstance(reloaded["cursor"], PicklePlaceholder)

    # Attempting to use attributes => PicklePlaceholderAccessError
    with pytest.raises(PicklePlaceholderAccessError):
        reloaded["connection"].execute("SELECT 1")


def test_picklepatch_with_generator():
    """Test that a data structure containing a generator is replaced by
    PicklePlaceholder rather than raising an error.
    """

    def simple_generator():
        yield 1
        yield 2
        yield 3

    # Create a generator
    gen = simple_generator()

    # Put it in a data structure
    data_with_generator = {
        "description": "Contains a generator",
        "generator": gen,
        "normal_list": [1, 2, 3]
    }

    dumped = PicklePatcher.dumps(data_with_generator)
    reloaded = PicklePatcher.loads(dumped)

    # Generator should be replaced with a placeholder
    assert isinstance(reloaded, dict)
    assert reloaded["description"] == "Contains a generator"
    assert reloaded["normal_list"] == [1, 2, 3]
    assert isinstance(reloaded["generator"], PicklePlaceholder)

    # Iterating the placeholder => TypeError, since it is not an iterator
    with pytest.raises(TypeError):
        next(reloaded["generator"])

    # Attempting to call methods on the generator => PicklePlaceholderAccessError
    with pytest.raises(PicklePlaceholderAccessError):
        reloaded["generator"].send(None)


def test_picklepatch_loads_standard_pickle():
    """Test that PicklePatcher.loads can correctly load data that was pickled
    using the standard pickle module.
    """
    # Create a simple data structure
    original_data = {
        "numbers": [1, 2, 3],
        "nested_dict": {"key": "value", "another": 42},
        "tuple": (1, "two", 3.0),
    }

    # Pickle it with standard pickle
    pickled_data = pickle.dumps(original_data)

    # Load with PicklePatcher
    reloaded = PicklePatcher.loads(pickled_data)

    # Verify the data is correctly loaded
    assert reloaded == original_data
    assert isinstance(reloaded, dict)
    assert reloaded["numbers"] == [1, 2, 3]
    assert reloaded["nested_dict"]["key"] == "value"
    assert reloaded["tuple"] == (1, "two", 3.0)


def test_picklepatch_loads_dill_pickle():
    """Test that PicklePatcher.loads can correctly load data that was pickled
    using the dill module, which can pickle more complex objects than the
    standard pickle module.
    """
    # Create a more complex data structure that includes lambda functions,
    # which dill can handle but standard pickle cannot
    original_data = {
        "numbers": [1, 2, 3],
        "function": lambda x: x * 2,
        "nested": {
            "another_function": lambda y: y ** 2
        }
    }

    # Pickle it with dill
    dilled_data = dill.dumps(original_data)

    # Load with PicklePatcher
    reloaded = PicklePatcher.loads(dilled_data)

    # Verify the data structure
    assert isinstance(reloaded, dict)
    assert reloaded["numbers"] == [1, 2, 3]

    # Test that the functions actually work
    assert reloaded["function"](5) == 10
    assert reloaded["nested"]["another_function"](4) == 16


def test_run_and_parse_picklepatch() -> None:
    """Test the end-to-end functionality of picklepatch, from tracing benchmarks to running the replay tests.

    The first example has an argument (an object containing a socket) that is not pickleable. However, the socket
    attribute is not used, so we are able to compare the original test results with the optimized test results.
    Here, we are simply 'ignoring' the unused unpickleable object.

    The second example also has an argument (an object containing a socket) that is not pickleable. The socket
    attribute is used, which results in an error thrown by the PicklePlaceholder object. Both the original and
    optimized results error out in this case, but this should still be flagged as incorrect behavior when comparing
    test results, since we were not able to reuse the unpickleable object in the replay test.
    """
    # Init paths
    project_root = Path(__file__).parent.parent.resolve()
    tests_root = project_root / "code_to_optimize" / "tests" / "pytest"
    benchmarks_root = project_root / "code_to_optimize" / "tests" / "pytest" / "benchmarks_socket_test"
    replay_tests_dir = benchmarks_root / "codeflash_replay_tests"
    output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
    fto_unused_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_unused_socket.py").resolve()
    fto_used_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_used_socket.py").resolve()
    original_fto_unused_socket_code = fto_unused_socket_path.read_text("utf-8")
    original_fto_used_socket_code = fto_used_socket_path.read_text("utf-8")
    # Trace benchmarks
    trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
    assert output_file.exists()
    try:
        # Check contents
        conn = sqlite3.connect(output_file.as_posix())
        cursor = conn.cursor()

        cursor.execute(
            "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
        function_calls = cursor.fetchall()

        # Assert the length of function calls
        assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
        function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
        total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
        function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
        assert "code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket" in function_to_results

        test_name, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"][0]
        assert total_time > 0.0
        assert function_time > 0.0
        assert percent > 0.0

        bubble_sort_unused_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_unused_socket.py").as_posix()
        bubble_sort_used_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_used_socket.py").as_posix()
        # Expected function calls
        expected_calls = [
            ("bubble_sort_with_unused_socket", "", "code_to_optimize.bubble_sort_picklepatch_test_unused_socket",
             f"{bubble_sort_unused_socket_path}",
             "test_socket_picklepatch", "code_to_optimize.tests.pytest.benchmarks_socket_test.test_socket", 12),
            ("bubble_sort_with_used_socket", "", "code_to_optimize.bubble_sort_picklepatch_test_used_socket",
             f"{bubble_sort_used_socket_path}",
             "test_used_socket_picklepatch", "code_to_optimize.tests.pytest.benchmarks_socket_test.test_socket", 20)
        ]
        for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
            assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
            assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name"
            assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name"
            assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path"
            assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
            assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
            assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
        conn.close()

        # Generate replay test
        generate_replay_test(output_file, replay_tests_dir)
        replay_test_path = replay_tests_dir / Path(
            "test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0.py")
        replay_test_perf_path = replay_tests_dir / Path(
            "test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0_perf.py")
        assert replay_test_path.exists()
        original_replay_test_code = replay_test_path.read_text("utf-8")

        # Instrument the replay test
        func = FunctionToOptimize(function_name="bubble_sort_with_unused_socket", parents=[], file_path=Path(fto_unused_socket_path))
        original_cwd = Path.cwd()
        run_cwd = project_root
        os.chdir(run_cwd)
        success, new_test = inject_profiling_into_existing_test(
            replay_test_path,
            [CodePosition(17, 15)],
            func,
            project_root,
            "pytest",
            mode=TestingMode.BEHAVIOR,
        )
        os.chdir(original_cwd)
        assert success
        assert new_test is not None
        replay_test_path.write_text(new_test)

        opt = Optimizer(
            Namespace(
                project_root=project_root,
                disable_telemetry=True,
                tests_root=tests_root,
                test_framework="pytest",
                pytest_cmd="pytest",
                experiment_id=None,
                test_project_root=project_root,
            )
        )

        # Run the replay test for the original code that does not use the socket
        test_env = os.environ.copy()
        test_env["CODEFLASH_TEST_ITERATION"] = "0"
        test_env["CODEFLASH_LOOP_INDEX"] = "1"
        test_type = TestType.REPLAY_TEST
        replay_test_function = "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket"
        func_optimizer = opt.create_function_optimizer(func)
        func_optimizer.test_files = TestFiles(
            test_files=[
                TestFile(
                    instrumented_behavior_file_path=replay_test_path,
                    test_type=test_type,
                    original_file_path=replay_test_path,
                    benchmarking_file_path=replay_test_perf_path,
                    tests_in_file=[TestsInFile(test_file=replay_test_path, test_class=None, test_function=replay_test_function, test_type=test_type)],
                )
            ]
        )
        test_results_unused_socket, coverage_data = func_optimizer.run_and_parse_tests(
            testing_type=TestingMode.BEHAVIOR,
            test_env=test_env,
            test_files=func_optimizer.test_files,
            optimization_iteration=0,
            pytest_min_loops=1,
            pytest_max_loops=1,
            testing_time=1.0,
        )
        assert len(test_results_unused_socket) == 1
        assert test_results_unused_socket.test_results[0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
        assert test_results_unused_socket.test_results[0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket"
        assert test_results_unused_socket.test_results[0].did_pass is True

        # Replace with optimized candidate
        fto_unused_socket_path.write_text("""
from codeflash.benchmarking.codeflash_trace import codeflash_trace

@codeflash_trace
def bubble_sort_with_unused_socket(data_container):
    # Extract the list to sort, leaving the socket untouched
    numbers = data_container.get('numbers', []).copy()
    return sorted(numbers)
""")
        # Run optimized code for unused socket
        optimized_test_results_unused_socket, coverage_data = func_optimizer.run_and_parse_tests(
            testing_type=TestingMode.BEHAVIOR,
            test_env=test_env,
            test_files=func_optimizer.test_files,
            optimization_iteration=0,
            pytest_min_loops=1,
            pytest_max_loops=1,
            testing_time=1.0,
        )
        assert len(optimized_test_results_unused_socket) == 1
        verification_result = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket)
        assert verification_result is True

        # Remove the previous instrumentation
        replay_test_path.write_text(original_replay_test_code)
        # Instrument the replay test
        func = FunctionToOptimize(function_name="bubble_sort_with_used_socket", parents=[], file_path=Path(fto_used_socket_path))
        success, new_test = inject_profiling_into_existing_test(
            replay_test_path,
            [CodePosition(23, 15)],
            func,
            project_root,
            "pytest",
            mode=TestingMode.BEHAVIOR,
        )
        os.chdir(original_cwd)
        assert success
        assert new_test is not None
        replay_test_path.write_text(new_test)

        # Run test for original function code that uses the socket. This should fail, as the PicklePlaceholder is accessed.
        test_env = os.environ.copy()
        test_env["CODEFLASH_TEST_ITERATION"] = "0"
        test_env["CODEFLASH_LOOP_INDEX"] = "1"
        test_type = TestType.REPLAY_TEST
        func = FunctionToOptimize(function_name="bubble_sort_with_used_socket", parents=[],
                                  file_path=Path(fto_used_socket_path))
        replay_test_function = "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket"
        func_optimizer = opt.create_function_optimizer(func)
        func_optimizer.test_files = TestFiles(
            test_files=[
                TestFile(
                    instrumented_behavior_file_path=replay_test_path,
                    test_type=test_type,
                    original_file_path=replay_test_path,
                    benchmarking_file_path=replay_test_perf_path,
                    tests_in_file=[
                        TestsInFile(test_file=replay_test_path, test_class=None, test_function=replay_test_function,
                                    test_type=test_type)],
                )
            ]
        )
        test_results_used_socket, coverage_data = func_optimizer.run_and_parse_tests(
            testing_type=TestingMode.BEHAVIOR,
            test_env=test_env,
            test_files=func_optimizer.test_files,
            optimization_iteration=0,
            pytest_min_loops=1,
            pytest_max_loops=1,
            testing_time=1.0,
        )
        assert len(test_results_used_socket) == 1
        assert test_results_used_socket.test_results[
            0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
        assert test_results_used_socket.test_results[
            0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket"
        assert test_results_used_socket.test_results[0].did_pass is False
        # Replace with optimized candidate
        fto_used_socket_path.write_text("""
from codeflash.benchmarking.codeflash_trace import codeflash_trace

@codeflash_trace
def bubble_sort_with_used_socket(data_container):
    # Extract the list to sort, then use the socket
    numbers = data_container.get('numbers', []).copy()
    socket = data_container.get('socket')
    socket.send("Hello from the optimized function!")
    return sorted(numbers)
""")

        # Run test for optimized function code that uses the socket. This should fail, as the PicklePlaceholder is accessed.
        optimized_test_results_used_socket, coverage_data = func_optimizer.run_and_parse_tests(
            testing_type=TestingMode.BEHAVIOR,
            test_env=test_env,
            test_files=func_optimizer.test_files,
            optimization_iteration=0,
            pytest_min_loops=1,
            pytest_max_loops=1,
            testing_time=1.0,
        )
        assert len(optimized_test_results_used_socket) == 1
        assert optimized_test_results_used_socket.test_results[
            0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
        assert optimized_test_results_used_socket.test_results[
            0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket"
        assert optimized_test_results_used_socket.test_results[0].did_pass is False

        # Even though both tests threw the same error, we reject this, as the behavior of the unpickleable object could not be determined.
        assert compare_test_results(test_results_used_socket, optimized_test_results_used_socket) is False

    finally:
        # cleanup
        output_file.unlink(missing_ok=True)
        shutil.rmtree(replay_tests_dir, ignore_errors=True)
        fto_unused_socket_path.write_text(original_fto_unused_socket_code)
        fto_used_socket_path.write_text(original_fto_used_socket_code)
288
tests/test_trace_benchmarks.py
Normal file

@ -0,0 +1,288 @@
import multiprocessing
|
||||
import shutil
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin
|
||||
from codeflash.benchmarking.replay_test import generate_replay_test
|
||||
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
|
||||
from codeflash.benchmarking.utils import validate_and_format_benchmark_table
|
||||
|
||||
|
||||
def test_trace_benchmarks() -> None:
|
||||
# Test the trace_benchmarks function
|
||||
project_root = Path(__file__).parent.parent / "code_to_optimize"
|
||||
benchmarks_root = project_root / "tests" / "pytest" / "benchmarks_test"
|
||||
replay_tests_dir = benchmarks_root / "codeflash_replay_tests"
|
||||
tests_root = project_root / "tests"
|
||||
output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
|
||||
trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
|
||||
assert output_file.exists()
|
||||
try:
|
||||
# check contents of trace file
|
||||
# connect to database
|
||||
conn = sqlite3.connect(output_file.as_posix())
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get the count of records
|
||||
# Get all records
|
||||
cursor.execute(
|
||||
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
|
||||
function_calls = cursor.fetchall()
|
||||
|
||||
# Assert the length of function calls
|
||||
assert len(function_calls) == 8, f"Expected 8 function calls, but got {len(function_calls)}"
|
||||
|
||||
bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
|
||||
process_and_bubble_sort_path = (project_root / "process_and_bubble_sort_codeflash_trace.py").as_posix()
|
||||
# Expected function calls
|
||||
expected_calls = [
|
||||
("sorter", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
|
||||
f"{bubble_sort_path}",
|
||||
"test_class_sort", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 17),
|
||||
|
||||
("sort_class", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
|
||||
f"{bubble_sort_path}",
|
||||
"test_class_sort2", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 20),
|
||||
|
||||
("sort_static", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
|
||||
f"{bubble_sort_path}",
|
||||
"test_class_sort3", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 23),
|
||||
|
||||
("__init__", "Sorter", "code_to_optimize.bubble_sort_codeflash_trace",
|
||||
f"{bubble_sort_path}",
|
||||
"test_class_sort4", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 26),
|
||||
|
||||
("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
|
||||
f"{bubble_sort_path}",
|
||||
"test_sort", "tests.pytest.benchmarks_test.test_benchmark_bubble_sort_example", 7),
|
||||
|
||||
("compute_and_sort", "", "code_to_optimize.process_and_bubble_sort_codeflash_trace",
|
||||
f"{process_and_bubble_sort_path}",
|
||||
"test_compute_and_sort", "tests.pytest.benchmarks_test.test_process_and_sort_example", 4),
|
||||
|
||||
("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
|
||||
f"{bubble_sort_path}",
|
||||
"test_no_func", "tests.pytest.benchmarks_test.test_process_and_sort_example", 8),
|
||||
|
||||
("recursive_bubble_sort", "", "code_to_optimize.bubble_sort_codeflash_trace",
|
||||
f"{bubble_sort_path}",
|
||||
"test_recursive_sort", "tests.pytest.benchmarks_test.test_recursive_example", 5),
|
||||
]
        for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
            assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
            assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name"
            assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name"
            assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path"
            assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
            assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
            assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
        # Close connection
        conn.close()
        generate_replay_test(output_file, replay_tests_dir)
        test_class_sort_path = replay_tests_dir / Path("test_tests_pytest_benchmarks_test_test_benchmark_bubble_sort_example__replay_test_0.py")
        assert test_class_sort_path.exists()
        test_class_sort_code = f"""
from code_to_optimize.bubble_sort_codeflash_trace import \\
    Sorter as code_to_optimize_bubble_sort_codeflash_trace_Sorter
from code_to_optimize.bubble_sort_codeflash_trace import \\
    sorter as code_to_optimize_bubble_sort_codeflash_trace_sorter
from codeflash.benchmarking.replay_test import get_next_arg_and_return
from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle

functions = ['sort_class', 'sort_static', 'sorter']
trace_file_path = r"{output_file.as_posix()}"

def test_code_to_optimize_bubble_sort_codeflash_trace_sorter():
    for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_sort", function_name="sorter", file_path=r"{bubble_sort_path}", num_to_get=100):
        args = pickle.loads(args_pkl)
        kwargs = pickle.loads(kwargs_pkl)
        ret = code_to_optimize_bubble_sort_codeflash_trace_sorter(*args, **kwargs)

def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter_sorter():
    for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_class_sort", function_name="sorter", file_path=r"{bubble_sort_path}", class_name="Sorter", num_to_get=100):
        args = pickle.loads(args_pkl)
        kwargs = pickle.loads(kwargs_pkl)
        function_name = "sorter"
        if not args:
            raise ValueError("No arguments provided for the method.")
        if function_name == "__init__":
            ret = code_to_optimize_bubble_sort_codeflash_trace_Sorter(*args[1:], **kwargs)
        else:
            instance = args[0] # self
            ret = instance.sorter(*args[1:], **kwargs)

def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter_sort_class():
    for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_class_sort2", function_name="sort_class", file_path=r"{bubble_sort_path}", class_name="Sorter", num_to_get=100):
        args = pickle.loads(args_pkl)
        kwargs = pickle.loads(kwargs_pkl)
        if not args:
            raise ValueError("No arguments provided for the method.")
        ret = code_to_optimize_bubble_sort_codeflash_trace_Sorter.sort_class(*args[1:], **kwargs)

def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter_sort_static():
    for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_class_sort3", function_name="sort_static", file_path=r"{bubble_sort_path}", class_name="Sorter", num_to_get=100):
        args = pickle.loads(args_pkl)
        kwargs = pickle.loads(kwargs_pkl)
        ret = code_to_optimize_bubble_sort_codeflash_trace_Sorter.sort_static(*args, **kwargs)

def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter___init__():
    for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_class_sort4", function_name="__init__", file_path=r"{bubble_sort_path}", class_name="Sorter", num_to_get=100):
        args = pickle.loads(args_pkl)
        kwargs = pickle.loads(kwargs_pkl)
        function_name = "__init__"
        if not args:
            raise ValueError("No arguments provided for the method.")
        if function_name == "__init__":
            ret = code_to_optimize_bubble_sort_codeflash_trace_Sorter(*args[1:], **kwargs)
        else:
            instance = args[0] # self
            ret = instance(*args[1:], **kwargs)

"""
        assert test_class_sort_path.read_text("utf-8").strip() == test_class_sort_code.strip()

        test_sort_path = replay_tests_dir / Path("test_tests_pytest_benchmarks_test_test_process_and_sort_example__replay_test_0.py")
        assert test_sort_path.exists()
        test_sort_code = f"""
from code_to_optimize.bubble_sort_codeflash_trace import \\
    sorter as code_to_optimize_bubble_sort_codeflash_trace_sorter
from code_to_optimize.process_and_bubble_sort_codeflash_trace import \\
    compute_and_sort as \\
    code_to_optimize_process_and_bubble_sort_codeflash_trace_compute_and_sort
from codeflash.benchmarking.replay_test import get_next_arg_and_return
from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle

functions = ['compute_and_sort', 'sorter']
trace_file_path = r"{output_file}"

def test_code_to_optimize_process_and_bubble_sort_codeflash_trace_compute_and_sort():
    for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_compute_and_sort", function_name="compute_and_sort", file_path=r"{process_and_bubble_sort_path}", num_to_get=100):
        args = pickle.loads(args_pkl)
        kwargs = pickle.loads(kwargs_pkl)
        ret = code_to_optimize_process_and_bubble_sort_codeflash_trace_compute_and_sort(*args, **kwargs)

def test_code_to_optimize_bubble_sort_codeflash_trace_sorter():
    for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_no_func", function_name="sorter", file_path=r"{bubble_sort_path}", num_to_get=100):
        args = pickle.loads(args_pkl)
        kwargs = pickle.loads(kwargs_pkl)
        ret = code_to_optimize_bubble_sort_codeflash_trace_sorter(*args, **kwargs)

"""
        assert test_sort_path.read_text("utf-8").strip() == test_sort_code.strip()
    finally:
        # cleanup
        output_file.unlink(missing_ok=True)
        shutil.rmtree(replay_tests_dir)


# Skip the test in CI as the machine may not be multithreaded
@pytest.mark.ci_skip
def test_trace_multithreaded_benchmark() -> None:
    project_root = Path(__file__).parent.parent / "code_to_optimize"
    benchmarks_root = project_root / "tests" / "pytest" / "benchmarks_multithread"
    tests_root = project_root / "tests"
    output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
    trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
    assert output_file.exists()
    try:
        # Check the contents of the trace file: connect to the database
        conn = sqlite3.connect(output_file.as_posix())
        cursor = conn.cursor()

        # Get all records
        cursor.execute(
            "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
        function_calls = cursor.fetchall()

        # Assert the number of function calls
        assert len(function_calls) == 10, f"Expected 10 function calls, but got {len(function_calls)}"
        function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
        total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
        function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
        assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results

        test_name, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_codeflash_trace.sorter"][0]
        assert total_time > 0.0
        assert function_time > 0.0
        assert percent > 0.0

        bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
        # Expected function calls
        expected_calls = [
            ("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
             f"{bubble_sort_path}",
             "test_benchmark_sort", "tests.pytest.benchmarks_multithread.test_multithread_sort", 4),
        ]
        for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
            assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
            assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name"
            assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name"
            assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path"
            assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
            assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
            assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
        # Close connection
        conn.close()

    finally:
        # cleanup
        output_file.unlink(missing_ok=True)


def test_trace_benchmark_decorator() -> None:
    project_root = Path(__file__).parent.parent / "code_to_optimize"
    benchmarks_root = project_root / "tests" / "pytest" / "benchmarks_test_decorator"
    tests_root = project_root / "tests"
    output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
    trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
    assert output_file.exists()
    try:
        # Check the contents of the trace file: connect to the database
        conn = sqlite3.connect(output_file.as_posix())
        cursor = conn.cursor()

        # Get all records
        cursor.execute(
            "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
        function_calls = cursor.fetchall()

        # Assert the number of function calls
        assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
        function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
        total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
        function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
        assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results

        test_name, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_codeflash_trace.sorter"][0]
        assert total_time > 0.0
        assert function_time > 0.0
        assert percent > 0.0

        bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
        # Expected function calls
        expected_calls = [
            ("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
             f"{bubble_sort_path}",
             "test_benchmark_sort", "tests.pytest.benchmarks_test_decorator.test_benchmark_decorator", 5),
            ("sorter", "", "code_to_optimize.bubble_sort_codeflash_trace",
             f"{bubble_sort_path}",
             "test_pytest_mark", "tests.pytest.benchmarks_test_decorator.test_benchmark_decorator", 11),
        ]
        for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
            assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
            assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name"
            assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name"
            assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path"
            assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
            assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
        # Close connection
        conn.close()

    finally:
        # cleanup
        output_file.unlink(missing_ok=True)
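The tests above treat the trace file as a plain sqlite database. As a rough usage sketch for inspecting one by hand — the helper name is hypothetical, and the schema is only what the queries above exercise:

import sqlite3

def list_traced_calls(trace_file):
    # Hypothetical helper: read back the rows the tests above assert on,
    # straight from the benchmark_function_timings table.
    conn = sqlite3.connect(trace_file)
    try:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT function_name, class_name, module_name, file_path, "
            "benchmark_function_name, benchmark_module_path, benchmark_line_number "
            "FROM benchmark_function_timings "
            "ORDER BY benchmark_module_path, benchmark_function_name, function_name"
        )
        return cursor.fetchall()
    finally:
        conn.close()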

@@ -17,7 +17,19 @@ def test_unit_test_discovery_pytest():
    )
    tests = discover_unit_tests(test_config)
    assert len(tests) > 0
    # print(tests)


def test_benchmark_test_discovery_pytest():
    project_path = Path(__file__).parent.parent.resolve() / "code_to_optimize"
    tests_path = project_path / "tests" / "pytest" / "benchmarks"
    test_config = TestConfig(
        tests_root=tests_path,
        project_root_path=project_path,
        test_framework="pytest",
        tests_project_rootdir=tests_path.parent,
    )
    tests = discover_unit_tests(test_config)
    assert len(tests) == 1  # Should not discover benchmark tests


def test_unit_test_discovery_unittest():
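For reference, the benchmark tests that discover_unit_tests is expected to skip look like ordinary pytest-benchmark tests. A hypothetical minimal example (the actual files under tests/pytest/benchmarks are not shown in this diff):

from code_to_optimize.bubble_sort_codeflash_trace import sorter

def test_benchmark_sort(benchmark):
    # pytest-benchmark's `benchmark` fixture times the call; discovery
    # should classify this as a benchmark, not a unit test.
    benchmark(sorter, list(range(100, 0, -1)))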