mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
end to end test that proves picklepatcher works. example shown is a socket (which is unpickleable) that's used or not used
This commit is contained in:
parent
40e416e0d0
commit
3158f9cc1c
13 changed files with 351 additions and 193 deletions
|
|
@ -0,0 +1,18 @@
|
|||
|
||||
from codeflash.benchmarking.codeflash_trace import codeflash_trace
|
||||
|
||||
|
||||
@codeflash_trace
|
||||
def bubble_sort_with_unused_socket(data_container):
|
||||
# Extract the list to sort, leaving the socket untouched
|
||||
numbers = data_container.get('numbers', []).copy()
|
||||
|
||||
return sorted(numbers)
|
||||
|
||||
@codeflash_trace
|
||||
def bubble_sort_with_used_socket(data_container):
|
||||
# Extract the list to sort, leaving the socket untouched
|
||||
numbers = data_container.get('numbers', []).copy()
|
||||
socket = data_container.get('socket')
|
||||
socket.send("Hello from the optimized function!")
|
||||
return sorted(numbers)
|
||||
|
|
@ -1,38 +1,6 @@
|
|||
def bubble_sort_with_unused_socket(data_container):
|
||||
"""
|
||||
Performs a bubble sort on a list within the data_container. The data container has the following schema:
|
||||
- 'numbers' (list): The list to be sorted.
|
||||
- 'socket' (socket): A socket
|
||||
|
||||
Args:
|
||||
data_container: A dictionary with at least 'numbers' (list) and 'socket' keys
|
||||
|
||||
Returns:
|
||||
list: The sorted list of numbers
|
||||
"""
|
||||
# Extract the list to sort, leaving the socket untouched
|
||||
numbers = data_container.get('numbers', []).copy()
|
||||
|
||||
# Classic bubble sort implementation
|
||||
n = len(numbers)
|
||||
for i in range(n):
|
||||
# Flag to optimize by detecting if no swaps occurred
|
||||
swapped = False
|
||||
|
||||
# Last i elements are already in place
|
||||
for j in range(0, n - i - 1):
|
||||
# Swap if the element is greater than the next element
|
||||
if numbers[j] > numbers[j + 1]:
|
||||
numbers[j], numbers[j + 1] = numbers[j + 1], numbers[j]
|
||||
swapped = True
|
||||
|
||||
# If no swapping occurred in this pass, the list is sorted
|
||||
if not swapped:
|
||||
break
|
||||
|
||||
return numbers
|
||||
|
||||
from codeflash.benchmarking.codeflash_trace import codeflash_trace
|
||||
|
||||
@codeflash_trace
|
||||
def bubble_sort_with_used_socket(data_container):
|
||||
"""
|
||||
Performs a bubble sort on a list within the data_container. The data container has the following schema:
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
import socket
|
||||
|
||||
from code_to_optimize.bubble_sort_picklepatch_test_unused_socket import bubble_sort_with_unused_socket
|
||||
from code_to_optimize.bubble_sort_picklepatch_test_used_socket import bubble_sort_with_used_socket
|
||||
|
||||
def test_socket_picklepatch(benchmark):
|
||||
s1, s2 = socket.socketpair()
|
||||
data = {
|
||||
"numbers": list(reversed(range(500))),
|
||||
"socket": s1
|
||||
}
|
||||
benchmark(bubble_sort_with_unused_socket, data)
|
||||
|
||||
def test_used_socket_picklepatch(benchmark):
|
||||
s1, s2 = socket.socketpair()
|
||||
data = {
|
||||
"numbers": list(reversed(range(500))),
|
||||
"socket": s1
|
||||
}
|
||||
benchmark(bubble_sort_with_used_socket, data)
|
||||
|
|
@ -1,34 +0,0 @@
|
|||
import socket
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from code_to_optimize.bubble_sort_picklepatch import bubble_sort_with_unused_socket, bubble_sort_with_used_socket
|
||||
|
||||
|
||||
def test_bubble_sort_with_unused_socket():
|
||||
mock_socket = Mock()
|
||||
# Test case 1: Regular unsorted list
|
||||
data_container = {
|
||||
'numbers': [5, 2, 9, 1, 5, 6],
|
||||
'socket': mock_socket
|
||||
}
|
||||
|
||||
result = bubble_sort_with_unused_socket(data_container)
|
||||
|
||||
# Check that the result is correctly sorted
|
||||
assert result == [1, 2, 5, 5, 6, 9]
|
||||
|
||||
def test_bubble_sort_with_used_socket():
|
||||
mock_socket = Mock()
|
||||
# Test case 1: Regular unsorted list
|
||||
data_container = {
|
||||
'numbers': [5, 2, 9, 1, 5, 6],
|
||||
'socket': mock_socket
|
||||
}
|
||||
|
||||
result = bubble_sort_with_used_socket(data_container)
|
||||
|
||||
# Check that the result is correctly sorted
|
||||
assert result == [1, 2, 5, 5, 6, 9]
|
||||
|
||||
|
|
@ -2,12 +2,11 @@ import functools
|
|||
import os
|
||||
import pickle
|
||||
import sqlite3
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable
|
||||
|
||||
import dill
|
||||
from codeflash.picklepatch.pickle_patcher import PicklePatcher
|
||||
|
||||
|
||||
class CodeflashTrace:
|
||||
|
|
@ -147,23 +146,10 @@ class CodeflashTrace:
|
|||
return result
|
||||
|
||||
try:
|
||||
original_recursion_limit = sys.getrecursionlimit()
|
||||
sys.setrecursionlimit(10000)
|
||||
# args = dict(args.items())
|
||||
# if class_name and func.__name__ == "__init__" and "self" in args:
|
||||
# del args["self"]
|
||||
# Pickle the arguments
|
||||
pickled_args = pickle.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
pickled_kwargs = pickle.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
sys.setrecursionlimit(original_recursion_limit)
|
||||
except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
|
||||
# Retry with dill if pickle fails. It's slower but more comprehensive
|
||||
try:
|
||||
pickled_args = dill.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
pickled_kwargs = dill.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
sys.setrecursionlimit(original_recursion_limit)
|
||||
|
||||
except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError) as e:
|
||||
pickled_args = PicklePatcher.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
pickled_kwargs = PicklePatcher.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
except Exception as e:
|
||||
print(f"Error pickling arguments for function {func.__name__}: {e}")
|
||||
# Add to the list of function calls without pickled args. Used for timing info only
|
||||
self._thread_local.active_functions.remove(func_id)
|
||||
|
|
@ -174,7 +160,6 @@ class CodeflashTrace:
|
|||
overhead_time, None, None)
|
||||
)
|
||||
return result
|
||||
|
||||
# Flush to database every 1000 calls
|
||||
if len(self.function_calls_data) > 1000:
|
||||
self.write_function_timings()
|
||||
|
|
|
|||
|
|
@ -175,7 +175,6 @@ class CodeFlashBenchmarkPlugin:
|
|||
benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
|
||||
# Subtract overhead from total time
|
||||
overhead = overhead_by_benchmark.get(benchmark_key, 0)
|
||||
print("benchmark_func:", benchmark_func, "Total time:", time_ns, "Overhead:", overhead, "Result:", time_ns - overhead)
|
||||
result[benchmark_key] = time_ns - overhead
|
||||
|
||||
finally:
|
||||
|
|
@ -267,9 +266,9 @@ class CodeFlashBenchmarkPlugin:
|
|||
os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number)
|
||||
os.environ["CODEFLASH_BENCHMARKING"] = "True"
|
||||
# Run the function
|
||||
start = time.thread_time_ns()
|
||||
start = time.time_ns()
|
||||
result = func(*args, **kwargs)
|
||||
end = time.thread_time_ns()
|
||||
end = time.time_ns()
|
||||
# Reset the environment variable
|
||||
os.environ["CODEFLASH_BENCHMARKING"] = "False"
|
||||
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ def create_trace_replay_test_code(
|
|||
assert test_framework in ["pytest", "unittest"]
|
||||
|
||||
# Create Imports
|
||||
imports = f"""import dill as pickle
|
||||
imports = f"""from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
|
||||
{"import unittest" if test_framework == "unittest" else ""}
|
||||
from codeflash.benchmarking.replay_test import get_next_arg_and_return
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from collections.abc import Collection
|
|||
from enum import Enum, IntEnum
|
||||
from pathlib import Path
|
||||
from re import Pattern
|
||||
from typing import Annotated, Any, Optional, Union, cast
|
||||
from typing import Annotated, Optional, cast
|
||||
|
||||
from jedi.api.classes import Name
|
||||
from pydantic import AfterValidator, BaseModel, ConfigDict, Field
|
||||
|
|
@ -362,6 +362,7 @@ class FunctionCoverage:
|
|||
class TestingMode(enum.Enum):
|
||||
BEHAVIOR = "behavior"
|
||||
PERFORMANCE = "performance"
|
||||
LINE_PROFILE = "line_profile"
|
||||
|
||||
|
||||
class VerificationType(str, Enum):
|
||||
|
|
@ -533,7 +534,7 @@ class TestResults(BaseModel):
|
|||
tree.add(
|
||||
f"{test_type.to_name()} - Passed: {report[test_type]['passed']}, Failed: {report[test_type]['failed']}"
|
||||
)
|
||||
return
|
||||
return tree
|
||||
|
||||
def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]:
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,8 @@
|
|||
class PicklePlaceholderAccessError(Exception):
|
||||
"""Custom exception raised when attempting to access an unpicklable object."""
|
||||
|
||||
|
||||
|
||||
class PicklePlaceholder:
|
||||
"""A placeholder for an object that couldn't be pickled.
|
||||
|
||||
|
|
@ -22,22 +27,22 @@ class PicklePlaceholder:
|
|||
self.__dict__["path"] = path if path is not None else []
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Raise an error when any attribute is accessed."""
|
||||
"""Raise a custom error when any attribute is accessed."""
|
||||
path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root object"
|
||||
raise AttributeError(
|
||||
f"Cannot access attribute '{name}' on unpicklable object at {path_str}. "
|
||||
raise PicklePlaceholderAccessError(
|
||||
f"Attempt to access unpickleable object: Cannot access attribute '{name}' on unpicklable object at {path_str}. "
|
||||
f"Original type: {self.__dict__['obj_type']}. Error: {self.__dict__['error_msg']}"
|
||||
)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
"""Prevent setting attributes."""
|
||||
self.__getattr__(name) # This will raise an AttributeError
|
||||
self.__getattr__(name) # This will raise our custom error
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Raise an error when the object is called."""
|
||||
"""Raise a custom error when the object is called."""
|
||||
path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root object"
|
||||
raise TypeError(
|
||||
f"Cannot call unpicklable object at {path_str}. "
|
||||
raise PicklePlaceholderAccessError(
|
||||
f"Attempt to access unpickleable object: Cannot call unpicklable object at {path_str}. "
|
||||
f"Original type: {self.__dict__['obj_type']}. Error: {self.__dict__['error_msg']}"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from typing import Any
|
|||
import sentry_sdk
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
from codeflash.picklepatch.pickle_placeholder import PicklePlaceholderAccessError
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
|
|
@ -64,7 +65,11 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
|
|||
if len(orig) != len(new):
|
||||
return False
|
||||
return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))
|
||||
|
||||
if isinstance(orig, PicklePlaceholderAccessError) or isinstance(new, PicklePlaceholderAccessError):
|
||||
# If this error was raised, there was an attempt to access the PicklePlaceholder, which represents an unpickleable object.
|
||||
# The test results should be rejected as the behavior of the unpickleable object is unknown.
|
||||
logger.debug("Unable to verify behavior of unpickleable object in replay test")
|
||||
return False
|
||||
if isinstance(
|
||||
orig,
|
||||
(
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from collections import defaultdict
|
|||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import dill as pickle
|
||||
from junitparser.xunit2 import JUnitXml
|
||||
from lxml.etree import XMLParser, parse
|
||||
|
||||
|
|
@ -20,7 +21,6 @@ from codeflash.code_utils.code_utils import (
|
|||
)
|
||||
from codeflash.discovery.discover_unit_tests import discover_parameters_unittest
|
||||
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType, VerificationType
|
||||
from codeflash.picklepatch.pickle_patcher import PicklePatcher
|
||||
from codeflash.verification.coverage_utils import CoverageUtils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -75,7 +75,7 @@ def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, tes
|
|||
|
||||
test_type = test_files.get_test_type_by_instrumented_file_path(test_file_path)
|
||||
try:
|
||||
test_pickle = PicklePatcher.loads(test_pickle_bin) if loop_index == 1 else None
|
||||
test_pickle = pickle.loads(test_pickle_bin) if loop_index == 1 else None
|
||||
except Exception as e:
|
||||
if DEBUG_MODE:
|
||||
logger.exception(f"Failed to load pickle file for {encoded_test_name} Exception: {e}")
|
||||
|
|
@ -133,7 +133,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
|
|||
# TODO : this is because sqlite writes original file module path. Should make it consistent
|
||||
test_type = test_files.get_test_type_by_original_file_path(test_file_path)
|
||||
try:
|
||||
ret_val = (PicklePatcher.loads(val[7]) if loop_index == 1 else None,)
|
||||
ret_val = (pickle.loads(val[7]) if loop_index == 1 else None,)
|
||||
except Exception:
|
||||
continue
|
||||
test_results.add(
|
||||
|
|
|
|||
|
|
@ -1,34 +1,40 @@
|
|||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
import socket
|
||||
import sqlite3
|
||||
from argparse import Namespace
|
||||
from pathlib import Path
|
||||
|
||||
import dill
|
||||
import pytest
|
||||
import requests
|
||||
import sqlite3
|
||||
|
||||
from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin
|
||||
from codeflash.benchmarking.replay_test import generate_replay_test
|
||||
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
|
||||
from codeflash.benchmarking.utils import validate_and_format_benchmark_table
|
||||
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
|
||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||
from codeflash.models.models import CodePosition, TestingMode, TestType, TestFiles, TestFile
|
||||
from codeflash.models.models import CodePosition, TestFile, TestFiles, TestingMode, TestsInFile, TestType
|
||||
from codeflash.optimization.optimizer import Optimizer
|
||||
from codeflash.verification.equivalence import compare_test_results
|
||||
|
||||
try:
|
||||
import sqlalchemy
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import create_engine, Column, Integer, String
|
||||
from sqlalchemy import Column, Integer, String, create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
HAS_SQLALCHEMY = True
|
||||
except ImportError:
|
||||
HAS_SQLALCHEMY = False
|
||||
|
||||
from codeflash.picklepatch.pickle_patcher import PicklePatcher
|
||||
from codeflash.picklepatch.pickle_placeholder import PicklePlaceholder
|
||||
from codeflash.picklepatch.pickle_placeholder import PicklePlaceholder, PicklePlaceholderAccessError
|
||||
|
||||
|
||||
def test_picklepatch_simple_nested():
|
||||
"""
|
||||
Test that a simple nested data structure pickles and unpickles correctly.
|
||||
"""Test that a simple nested data structure pickles and unpickles correctly.
|
||||
"""
|
||||
original_data = {
|
||||
"numbers": [1, 2, 3],
|
||||
|
|
@ -41,17 +47,24 @@ def test_picklepatch_simple_nested():
|
|||
assert reloaded == original_data
|
||||
# Everything was pickleable, so no placeholders should appear.
|
||||
|
||||
|
||||
def test_picklepatch_with_socket():
|
||||
"""
|
||||
Test that a data structure containing a raw socket is replaced by
|
||||
"""Test that a data structure containing a raw socket is replaced by
|
||||
PicklePlaceholder rather than raising an error.
|
||||
"""
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
# Create a pair of connected sockets instead of a single socket
|
||||
sock1, sock2 = socket.socketpair()
|
||||
|
||||
data_with_socket = {
|
||||
"safe_value": 123,
|
||||
"raw_socket": s,
|
||||
"raw_socket": sock1,
|
||||
}
|
||||
|
||||
# Send a message through sock1, which can be received by sock2
|
||||
sock1.send(b"Hello, world!")
|
||||
received = sock2.recv(1024)
|
||||
assert received == b"Hello, world!"
|
||||
# Pickle the data structure containing the socket
|
||||
dumped = PicklePatcher.dumps(data_with_socket)
|
||||
reloaded = PicklePatcher.loads(dumped)
|
||||
|
||||
|
|
@ -62,13 +75,16 @@ def test_picklepatch_with_socket():
|
|||
|
||||
# Attempting to use or access attributes => AttributeError
|
||||
# (not RuntimeError as in original tests, our implementation uses AttributeError)
|
||||
with pytest.raises(AttributeError) :
|
||||
with pytest.raises(PicklePlaceholderAccessError):
|
||||
reloaded["raw_socket"].recv(1024)
|
||||
|
||||
# Clean up by closing both sockets
|
||||
sock1.close()
|
||||
sock2.close()
|
||||
|
||||
|
||||
def test_picklepatch_deeply_nested():
|
||||
"""
|
||||
Test that deep nesting with unpicklable objects works correctly.
|
||||
"""Test that deep nesting with unpicklable objects works correctly.
|
||||
"""
|
||||
# Create a deeply nested structure with an unpicklable object
|
||||
deep_nested = {
|
||||
|
|
@ -92,8 +108,7 @@ def test_picklepatch_deeply_nested():
|
|||
assert isinstance(reloaded["level1"]["level2"]["level3"]["socket"], PicklePlaceholder)
|
||||
|
||||
def test_picklepatch_class_with_unpicklable_attr():
|
||||
"""
|
||||
Test that a class with an unpicklable attribute works correctly.
|
||||
"""Test that a class with an unpicklable attribute works correctly.
|
||||
"""
|
||||
class TestClass:
|
||||
def __init__(self):
|
||||
|
|
@ -115,12 +130,11 @@ def test_picklepatch_class_with_unpicklable_attr():
|
|||
|
||||
|
||||
def test_picklepatch_with_database_connection():
|
||||
"""
|
||||
Test that a data structure containing a database connection is replaced
|
||||
"""Test that a data structure containing a database connection is replaced
|
||||
by PicklePlaceholder rather than raising an error.
|
||||
"""
|
||||
# SQLite connection - not pickleable
|
||||
conn = sqlite3.connect(':memory:')
|
||||
conn = sqlite3.connect(":memory:")
|
||||
cursor = conn.cursor()
|
||||
|
||||
data_with_db = {
|
||||
|
|
@ -139,13 +153,12 @@ def test_picklepatch_with_database_connection():
|
|||
assert isinstance(reloaded["cursor"], PicklePlaceholder)
|
||||
|
||||
# Attempting to use attributes => AttributeError
|
||||
with pytest.raises(AttributeError):
|
||||
with pytest.raises(PicklePlaceholderAccessError):
|
||||
reloaded["connection"].execute("SELECT 1")
|
||||
|
||||
|
||||
def test_picklepatch_with_generator():
|
||||
"""
|
||||
Test that a data structure containing a generator is replaced by
|
||||
"""Test that a data structure containing a generator is replaced by
|
||||
PicklePlaceholder rather than raising an error.
|
||||
"""
|
||||
|
||||
|
|
@ -178,13 +191,12 @@ def test_picklepatch_with_generator():
|
|||
next(reloaded["generator"])
|
||||
|
||||
# Attempting to call methods on the generator => AttributeError
|
||||
with pytest.raises(AttributeError):
|
||||
with pytest.raises(PicklePlaceholderAccessError):
|
||||
reloaded["generator"].send(None)
|
||||
|
||||
|
||||
def test_picklepatch_loads_standard_pickle():
|
||||
"""
|
||||
Test that PicklePatcher.loads can correctly load data that was pickled
|
||||
"""Test that PicklePatcher.loads can correctly load data that was pickled
|
||||
using the standard pickle module.
|
||||
"""
|
||||
# Create a simple data structure
|
||||
|
|
@ -209,12 +221,10 @@ def test_picklepatch_loads_standard_pickle():
|
|||
|
||||
|
||||
def test_picklepatch_loads_dill_pickle():
|
||||
"""
|
||||
Test that PicklePatcher.loads can correctly load data that was pickled
|
||||
"""Test that PicklePatcher.loads can correctly load data that was pickled
|
||||
using the dill module, which can pickle more complex objects than the
|
||||
standard pickle module.
|
||||
"""
|
||||
|
||||
# Create a more complex data structure that includes a lambda function
|
||||
# which dill can handle but standard pickle cannot
|
||||
original_data = {
|
||||
|
|
@ -240,80 +250,264 @@ def test_picklepatch_loads_dill_pickle():
|
|||
assert reloaded["nested"]["another_function"](4) == 16
|
||||
|
||||
def test_run_and_parse_picklepatch() -> None:
|
||||
"""Test the end to end functionality of picklepatch, from tracing benchmarks to running the replay tests.
|
||||
|
||||
test_path = (
|
||||
Path(__file__).parent.resolve()
|
||||
/ "../code_to_optimize/tests/pytest/test_bubble_sort_picklepatch.py"
|
||||
).resolve()
|
||||
test_path_perf = (
|
||||
Path(__file__).parent.resolve()
|
||||
/ "../code_to_optimize/tests/pytest/test_bubble_sort_picklepatch_perf.py"
|
||||
).resolve()
|
||||
fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_picklepatch.py").resolve()
|
||||
original_test =test_path.read_text("utf-8")
|
||||
The first example has an argument (an object containing a socket) that is not pickleable However, the socket attributs is not used, so we are able to compare the test results with the optimized test results.
|
||||
Here, we are simply 'ignoring' the unused unpickleable object.
|
||||
|
||||
The second example also has an argument (an object containing socket) that is not pickleable. The socket attribute is used, which results in an error thrown by the PicklePlaceholder object.
|
||||
Both the original and optimized results should error out in this case, but this should be flagged as incorrect behavior when comparing test results,
|
||||
since we were not able to reuse the unpickleable object in the replay test.
|
||||
"""
|
||||
# Init paths
|
||||
project_root = Path(__file__).parent.parent.resolve()
|
||||
tests_root = project_root / "code_to_optimize" / "tests" / "pytest"
|
||||
benchmarks_root = project_root / "code_to_optimize" / "tests" / "pytest" / "benchmarks_socket_test"
|
||||
replay_tests_dir = benchmarks_root / "codeflash_replay_tests"
|
||||
output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
|
||||
fto_unused_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_unused_socket.py").resolve()
|
||||
fto_used_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_used_socket.py").resolve()
|
||||
original_fto_unused_socket_code = fto_unused_socket_path.read_text("utf-8")
|
||||
original_fto_used_socket_code = fto_used_socket_path.read_text("utf-8")
|
||||
# Trace benchmarks
|
||||
trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
|
||||
assert output_file.exists()
|
||||
try:
|
||||
tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve()
|
||||
project_root_path = (Path(__file__).parent / "..").resolve()
|
||||
# Check contents
|
||||
conn = sqlite3.connect(output_file.as_posix())
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute(
|
||||
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
|
||||
function_calls = cursor.fetchall()
|
||||
|
||||
# Assert the length of function calls
|
||||
assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
|
||||
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
|
||||
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
|
||||
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
|
||||
assert "code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket" in function_to_results
|
||||
|
||||
test_name, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"][0]
|
||||
assert total_time > 0.0
|
||||
assert function_time > 0.0
|
||||
assert percent > 0.0
|
||||
|
||||
test_name, total_time, function_time, percent = \
|
||||
function_to_results["code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"][0]
|
||||
assert total_time > 0.0
|
||||
assert function_time > 0.0
|
||||
assert percent > 0.0
|
||||
|
||||
bubble_sort_unused_socket_path = (project_root / "code_to_optimize"/ "bubble_sort_picklepatch_test_unused_socket.py").as_posix()
|
||||
bubble_sort_used_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_used_socket.py").as_posix()
|
||||
# Expected function calls
|
||||
expected_calls = [
|
||||
("bubble_sort_with_unused_socket", "", "code_to_optimize.bubble_sort_picklepatch_test_unused_socket",
|
||||
f"{bubble_sort_unused_socket_path}",
|
||||
"test_socket_picklepatch", "code_to_optimize.tests.pytest.benchmarks_socket_test.test_socket", 12),
|
||||
("bubble_sort_with_used_socket", "", "code_to_optimize.bubble_sort_picklepatch_test_used_socket",
|
||||
f"{bubble_sort_used_socket_path}",
|
||||
"test_used_socket_picklepatch", "code_to_optimize.tests.pytest.benchmarks_socket_test.test_socket", 20)
|
||||
]
|
||||
for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
|
||||
assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
|
||||
assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name"
|
||||
assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name"
|
||||
assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path"
|
||||
assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
|
||||
assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
|
||||
assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
|
||||
conn.close()
|
||||
|
||||
# Generate replay test
|
||||
generate_replay_test(output_file, replay_tests_dir)
|
||||
replay_test_path = replay_tests_dir / Path(
|
||||
"test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0.py")
|
||||
replay_test_perf_path = replay_tests_dir / Path(
|
||||
"test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0_perf.py")
|
||||
assert replay_test_path.exists()
|
||||
original_replay_test_code = replay_test_path.read_text("utf-8")
|
||||
|
||||
# Instrument the replay test
|
||||
func = FunctionToOptimize(function_name="bubble_sort_with_unused_socket", parents=[], file_path=Path(fto_unused_socket_path))
|
||||
original_cwd = Path.cwd()
|
||||
run_cwd = Path(__file__).parent.parent.resolve()
|
||||
func = FunctionToOptimize(function_name="bubble_sort_with_unused_socket", parents=[], file_path=Path(fto_path))
|
||||
run_cwd = project_root
|
||||
os.chdir(run_cwd)
|
||||
success, new_test = inject_profiling_into_existing_test(
|
||||
test_path,
|
||||
[CodePosition(13,14), CodePosition(31,14)],
|
||||
replay_test_path,
|
||||
[CodePosition(17, 15)],
|
||||
func,
|
||||
project_root_path,
|
||||
project_root,
|
||||
"pytest",
|
||||
mode=TestingMode.BEHAVIOR,
|
||||
)
|
||||
os.chdir(original_cwd)
|
||||
assert success
|
||||
assert new_test is not None
|
||||
|
||||
with test_path.open("w") as f:
|
||||
f.write(new_test)
|
||||
replay_test_path.write_text(new_test)
|
||||
|
||||
opt = Optimizer(
|
||||
Namespace(
|
||||
project_root=project_root_path,
|
||||
project_root=project_root,
|
||||
disable_telemetry=True,
|
||||
tests_root=tests_root,
|
||||
test_framework="pytest",
|
||||
pytest_cmd="pytest",
|
||||
experiment_id=None,
|
||||
test_project_root=project_root_path,
|
||||
test_project_root=project_root,
|
||||
)
|
||||
)
|
||||
|
||||
# Run the replay test for the original code that does not use the socket
|
||||
test_env = os.environ.copy()
|
||||
test_env["CODEFLASH_TEST_ITERATION"] = "0"
|
||||
test_env["CODEFLASH_LOOP_INDEX"] = "1"
|
||||
test_type = TestType.EXISTING_UNIT_TEST
|
||||
|
||||
test_type = TestType.REPLAY_TEST
|
||||
replay_test_function = "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket"
|
||||
func_optimizer = opt.create_function_optimizer(func)
|
||||
func_optimizer.test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
instrumented_behavior_file_path=test_path,
|
||||
instrumented_behavior_file_path=replay_test_path,
|
||||
test_type=test_type,
|
||||
original_file_path=test_path,
|
||||
benchmarking_file_path=test_path_perf,
|
||||
original_file_path=replay_test_path,
|
||||
benchmarking_file_path=replay_test_perf_path,
|
||||
tests_in_file=[TestsInFile(test_file=replay_test_path, test_class=None, test_function=replay_test_function, test_type=test_type)],
|
||||
)
|
||||
]
|
||||
)
|
||||
test_results, coverage_data = func_optimizer.run_and_parse_tests(
|
||||
test_results_unused_socket, coverage_data = func_optimizer.run_and_parse_tests(
|
||||
testing_type=TestingMode.BEHAVIOR,
|
||||
test_env=test_env,
|
||||
test_files=func_optimizer.test_files,
|
||||
optimization_iteration=0,
|
||||
pytest_min_loops=1,
|
||||
pytest_max_loops=1,
|
||||
testing_time=0.1,
|
||||
testing_time=1.0,
|
||||
)
|
||||
assert test_results.test_results[0].id.test_function_name =="test_bubble_sort_with_unused_socket"
|
||||
assert test_results.test_results[0].did_pass ==True
|
||||
assert test_results.test_results[1].id.test_function_name =="test_bubble_sort_with_used_socket"
|
||||
assert test_results.test_results[1].did_pass ==False
|
||||
# assert pickle placeholder problem
|
||||
print(test_results)
|
||||
assert len(test_results_unused_socket) == 1
|
||||
assert test_results_unused_socket.test_results[0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
|
||||
assert test_results_unused_socket.test_results[0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket"
|
||||
assert test_results_unused_socket.test_results[0].did_pass == True
|
||||
|
||||
# Replace with optimized candidate
|
||||
fto_unused_socket_path.write_text("""
|
||||
from codeflash.benchmarking.codeflash_trace import codeflash_trace
|
||||
|
||||
@codeflash_trace
|
||||
def bubble_sort_with_unused_socket(data_container):
|
||||
# Extract the list to sort, leaving the socket untouched
|
||||
numbers = data_container.get('numbers', []).copy()
|
||||
return sorted(numbers)
|
||||
""")
|
||||
# Run optimized code for unused socket
|
||||
optimized_test_results_unused_socket, coverage_data = func_optimizer.run_and_parse_tests(
|
||||
testing_type=TestingMode.BEHAVIOR,
|
||||
test_env=test_env,
|
||||
test_files=func_optimizer.test_files,
|
||||
optimization_iteration=0,
|
||||
pytest_min_loops=1,
|
||||
pytest_max_loops=1,
|
||||
testing_time=1.0,
|
||||
)
|
||||
assert len(optimized_test_results_unused_socket) == 1
|
||||
verification_result = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket)
|
||||
assert verification_result is True
|
||||
|
||||
# Remove the previous instrumentation
|
||||
replay_test_path.write_text(original_replay_test_code)
|
||||
# Instrument the replay test
|
||||
func = FunctionToOptimize(function_name="bubble_sort_with_used_socket", parents=[], file_path=Path(fto_used_socket_path))
|
||||
success, new_test = inject_profiling_into_existing_test(
|
||||
replay_test_path,
|
||||
[CodePosition(23,15)],
|
||||
func,
|
||||
project_root,
|
||||
"pytest",
|
||||
mode=TestingMode.BEHAVIOR,
|
||||
)
|
||||
os.chdir(original_cwd)
|
||||
assert success
|
||||
assert new_test is not None
|
||||
replay_test_path.write_text(new_test)
|
||||
|
||||
# Run test for original function code that uses the socket. This should fail, as the PicklePlaceholder is accessed.
|
||||
test_env = os.environ.copy()
|
||||
test_env["CODEFLASH_TEST_ITERATION"] = "0"
|
||||
test_env["CODEFLASH_LOOP_INDEX"] = "1"
|
||||
test_type = TestType.REPLAY_TEST
|
||||
func = FunctionToOptimize(function_name="bubble_sort_with_used_socket", parents=[],
|
||||
file_path=Path(fto_used_socket_path))
|
||||
replay_test_function = "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket"
|
||||
func_optimizer = opt.create_function_optimizer(func)
|
||||
func_optimizer.test_files = TestFiles(
|
||||
test_files=[
|
||||
TestFile(
|
||||
instrumented_behavior_file_path=replay_test_path,
|
||||
test_type=test_type,
|
||||
original_file_path=replay_test_path,
|
||||
benchmarking_file_path=replay_test_perf_path,
|
||||
tests_in_file=[
|
||||
TestsInFile(test_file=replay_test_path, test_class=None, test_function=replay_test_function,
|
||||
test_type=test_type)],
|
||||
)
|
||||
]
|
||||
)
|
||||
test_results_used_socket, coverage_data = func_optimizer.run_and_parse_tests(
|
||||
testing_type=TestingMode.BEHAVIOR,
|
||||
test_env=test_env,
|
||||
test_files=func_optimizer.test_files,
|
||||
optimization_iteration=0,
|
||||
pytest_min_loops=1,
|
||||
pytest_max_loops=1,
|
||||
testing_time=1.0,
|
||||
)
|
||||
assert len(test_results_used_socket) == 1
|
||||
assert test_results_used_socket.test_results[
|
||||
0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
|
||||
assert test_results_used_socket.test_results[
|
||||
0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket"
|
||||
assert test_results_used_socket.test_results[0].did_pass is False
|
||||
print("test results used socket")
|
||||
print(test_results_used_socket)
|
||||
# Replace with optimized candidate
|
||||
fto_used_socket_path.write_text("""
|
||||
from codeflash.benchmarking.codeflash_trace import codeflash_trace
|
||||
|
||||
@codeflash_trace
|
||||
def bubble_sort_with_used_socket(data_container):
|
||||
# Extract the list to sort, leaving the socket untouched
|
||||
numbers = data_container.get('numbers', []).copy()
|
||||
socket = data_container.get('socket')
|
||||
socket.send("Hello from the optimized function!")
|
||||
return sorted(numbers)
|
||||
""")
|
||||
|
||||
# Run test for optimized function code that uses the socket. This should fail, as the PicklePlaceholder is accessed.
|
||||
optimized_test_results_used_socket, coverage_data = func_optimizer.run_and_parse_tests(
|
||||
testing_type=TestingMode.BEHAVIOR,
|
||||
test_env=test_env,
|
||||
test_files=func_optimizer.test_files,
|
||||
optimization_iteration=0,
|
||||
pytest_min_loops=1,
|
||||
pytest_max_loops=1,
|
||||
testing_time=1.0,
|
||||
)
|
||||
assert len(test_results_used_socket) == 1
|
||||
assert test_results_used_socket.test_results[
|
||||
0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
|
||||
assert test_results_used_socket.test_results[
|
||||
0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket"
|
||||
assert test_results_used_socket.test_results[0].did_pass is False
|
||||
|
||||
# Even though tests threw the same error, we reject this as the behavior of the unpickleable object could not be determined.
|
||||
assert compare_test_results(test_results_used_socket, optimized_test_results_used_socket) is False
|
||||
|
||||
finally:
|
||||
test_path.write_text(original_test)
|
||||
# cleanup
|
||||
output_file.unlink(missing_ok=True)
|
||||
shutil.rmtree(replay_tests_dir, ignore_errors=True)
|
||||
fto_unused_socket_path.write_text(original_fto_unused_socket_code)
|
||||
fto_used_socket_path.write_text(original_fto_used_socket_code)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,15 +1,14 @@
|
|||
import shutil
|
||||
import sqlite3
|
||||
|
||||
from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin
|
||||
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
|
||||
from codeflash.benchmarking.replay_test import generate_replay_test
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table
|
||||
import shutil
|
||||
from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin
|
||||
from codeflash.benchmarking.replay_test import generate_replay_test
|
||||
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
|
||||
from codeflash.benchmarking.utils import validate_and_format_benchmark_table
|
||||
|
||||
|
||||
def test_trace_benchmarks():
|
||||
def test_trace_benchmarks() -> None:
|
||||
# Test the trace_benchmarks function
|
||||
project_root = Path(__file__).parent.parent / "code_to_optimize"
|
||||
benchmarks_root = project_root / "tests" / "pytest" / "benchmarks_test"
|
||||
|
|
@ -83,13 +82,12 @@ def test_trace_benchmarks():
|
|||
test_class_sort_path = replay_tests_dir/ Path("test_tests_pytest_benchmarks_test_test_benchmark_bubble_sort_example__replay_test_0.py")
|
||||
assert test_class_sort_path.exists()
|
||||
test_class_sort_code = f"""
|
||||
import dill as pickle
|
||||
|
||||
from code_to_optimize.bubble_sort_codeflash_trace import \\
|
||||
Sorter as code_to_optimize_bubble_sort_codeflash_trace_Sorter
|
||||
from code_to_optimize.bubble_sort_codeflash_trace import \\
|
||||
sorter as code_to_optimize_bubble_sort_codeflash_trace_sorter
|
||||
from codeflash.benchmarking.replay_test import get_next_arg_and_return
|
||||
from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
|
||||
|
||||
functions = ['sort_class', 'sort_static', 'sorter']
|
||||
trace_file_path = r"{output_file.as_posix()}"
|
||||
|
|
@ -146,14 +144,13 @@ def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter___init__():
|
|||
test_sort_path = replay_tests_dir / Path("test_tests_pytest_benchmarks_test_test_process_and_sort_example__replay_test_0.py")
|
||||
assert test_sort_path.exists()
|
||||
test_sort_code = f"""
|
||||
import dill as pickle
|
||||
|
||||
from code_to_optimize.bubble_sort_codeflash_trace import \\
|
||||
sorter as code_to_optimize_bubble_sort_codeflash_trace_sorter
|
||||
from code_to_optimize.process_and_bubble_sort_codeflash_trace import \\
|
||||
compute_and_sort as \\
|
||||
code_to_optimize_process_and_bubble_sort_codeflash_trace_compute_and_sort
|
||||
from codeflash.benchmarking.replay_test import get_next_arg_and_return
|
||||
from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
|
||||
|
||||
functions = ['compute_and_sort', 'sorter']
|
||||
trace_file_path = r"{output_file}"
|
||||
|
|
|
|||
Loading…
Reference in a new issue