mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
end to end test that proves picklepatcher works. example shown is a socket (which is unpickleable) that's used or not used
This commit is contained in:
parent
40e416e0d0
commit
3158f9cc1c
13 changed files with 351 additions and 193 deletions
|
|
@ -0,0 +1,18 @@
|
||||||
|
|
||||||
|
from codeflash.benchmarking.codeflash_trace import codeflash_trace
|
||||||
|
|
||||||
|
|
||||||
|
@codeflash_trace
|
||||||
|
def bubble_sort_with_unused_socket(data_container):
|
||||||
|
# Extract the list to sort, leaving the socket untouched
|
||||||
|
numbers = data_container.get('numbers', []).copy()
|
||||||
|
|
||||||
|
return sorted(numbers)
|
||||||
|
|
||||||
|
@codeflash_trace
|
||||||
|
def bubble_sort_with_used_socket(data_container):
|
||||||
|
# Extract the list to sort, leaving the socket untouched
|
||||||
|
numbers = data_container.get('numbers', []).copy()
|
||||||
|
socket = data_container.get('socket')
|
||||||
|
socket.send("Hello from the optimized function!")
|
||||||
|
return sorted(numbers)
|
||||||
|
|
@ -1,38 +1,6 @@
|
||||||
def bubble_sort_with_unused_socket(data_container):
|
from codeflash.benchmarking.codeflash_trace import codeflash_trace
|
||||||
"""
|
|
||||||
Performs a bubble sort on a list within the data_container. The data container has the following schema:
|
|
||||||
- 'numbers' (list): The list to be sorted.
|
|
||||||
- 'socket' (socket): A socket
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data_container: A dictionary with at least 'numbers' (list) and 'socket' keys
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: The sorted list of numbers
|
|
||||||
"""
|
|
||||||
# Extract the list to sort, leaving the socket untouched
|
|
||||||
numbers = data_container.get('numbers', []).copy()
|
|
||||||
|
|
||||||
# Classic bubble sort implementation
|
|
||||||
n = len(numbers)
|
|
||||||
for i in range(n):
|
|
||||||
# Flag to optimize by detecting if no swaps occurred
|
|
||||||
swapped = False
|
|
||||||
|
|
||||||
# Last i elements are already in place
|
|
||||||
for j in range(0, n - i - 1):
|
|
||||||
# Swap if the element is greater than the next element
|
|
||||||
if numbers[j] > numbers[j + 1]:
|
|
||||||
numbers[j], numbers[j + 1] = numbers[j + 1], numbers[j]
|
|
||||||
swapped = True
|
|
||||||
|
|
||||||
# If no swapping occurred in this pass, the list is sorted
|
|
||||||
if not swapped:
|
|
||||||
break
|
|
||||||
|
|
||||||
return numbers
|
|
||||||
|
|
||||||
|
|
||||||
|
@codeflash_trace
|
||||||
def bubble_sort_with_used_socket(data_container):
|
def bubble_sort_with_used_socket(data_container):
|
||||||
"""
|
"""
|
||||||
Performs a bubble sort on a list within the data_container. The data container has the following schema:
|
Performs a bubble sort on a list within the data_container. The data container has the following schema:
|
||||||
|
|
@ -0,0 +1,20 @@
|
||||||
|
import socket
|
||||||
|
|
||||||
|
from code_to_optimize.bubble_sort_picklepatch_test_unused_socket import bubble_sort_with_unused_socket
|
||||||
|
from code_to_optimize.bubble_sort_picklepatch_test_used_socket import bubble_sort_with_used_socket
|
||||||
|
|
||||||
|
def test_socket_picklepatch(benchmark):
|
||||||
|
s1, s2 = socket.socketpair()
|
||||||
|
data = {
|
||||||
|
"numbers": list(reversed(range(500))),
|
||||||
|
"socket": s1
|
||||||
|
}
|
||||||
|
benchmark(bubble_sort_with_unused_socket, data)
|
||||||
|
|
||||||
|
def test_used_socket_picklepatch(benchmark):
|
||||||
|
s1, s2 = socket.socketpair()
|
||||||
|
data = {
|
||||||
|
"numbers": list(reversed(range(500))),
|
||||||
|
"socket": s1
|
||||||
|
}
|
||||||
|
benchmark(bubble_sort_with_used_socket, data)
|
||||||
|
|
@ -1,34 +0,0 @@
|
||||||
import socket
|
|
||||||
from unittest.mock import Mock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from code_to_optimize.bubble_sort_picklepatch import bubble_sort_with_unused_socket, bubble_sort_with_used_socket
|
|
||||||
|
|
||||||
|
|
||||||
def test_bubble_sort_with_unused_socket():
|
|
||||||
mock_socket = Mock()
|
|
||||||
# Test case 1: Regular unsorted list
|
|
||||||
data_container = {
|
|
||||||
'numbers': [5, 2, 9, 1, 5, 6],
|
|
||||||
'socket': mock_socket
|
|
||||||
}
|
|
||||||
|
|
||||||
result = bubble_sort_with_unused_socket(data_container)
|
|
||||||
|
|
||||||
# Check that the result is correctly sorted
|
|
||||||
assert result == [1, 2, 5, 5, 6, 9]
|
|
||||||
|
|
||||||
def test_bubble_sort_with_used_socket():
|
|
||||||
mock_socket = Mock()
|
|
||||||
# Test case 1: Regular unsorted list
|
|
||||||
data_container = {
|
|
||||||
'numbers': [5, 2, 9, 1, 5, 6],
|
|
||||||
'socket': mock_socket
|
|
||||||
}
|
|
||||||
|
|
||||||
result = bubble_sort_with_used_socket(data_container)
|
|
||||||
|
|
||||||
# Check that the result is correctly sorted
|
|
||||||
assert result == [1, 2, 5, 5, 6, 9]
|
|
||||||
|
|
||||||
|
|
@ -2,12 +2,11 @@ import functools
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import sys
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import dill
|
from codeflash.picklepatch.pickle_patcher import PicklePatcher
|
||||||
|
|
||||||
|
|
||||||
class CodeflashTrace:
|
class CodeflashTrace:
|
||||||
|
|
@ -147,34 +146,20 @@ class CodeflashTrace:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
try:
|
try:
|
||||||
original_recursion_limit = sys.getrecursionlimit()
|
|
||||||
sys.setrecursionlimit(10000)
|
|
||||||
# args = dict(args.items())
|
|
||||||
# if class_name and func.__name__ == "__init__" and "self" in args:
|
|
||||||
# del args["self"]
|
|
||||||
# Pickle the arguments
|
# Pickle the arguments
|
||||||
pickled_args = pickle.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
|
pickled_args = PicklePatcher.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
pickled_kwargs = pickle.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
|
pickled_kwargs = PicklePatcher.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
sys.setrecursionlimit(original_recursion_limit)
|
except Exception as e:
|
||||||
except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
|
print(f"Error pickling arguments for function {func.__name__}: {e}")
|
||||||
# Retry with dill if pickle fails. It's slower but more comprehensive
|
# Add to the list of function calls without pickled args. Used for timing info only
|
||||||
try:
|
self._thread_local.active_functions.remove(func_id)
|
||||||
pickled_args = dill.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
|
overhead_time = time.thread_time_ns() - end_time
|
||||||
pickled_kwargs = dill.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
|
self.function_calls_data.append(
|
||||||
sys.setrecursionlimit(original_recursion_limit)
|
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
|
||||||
|
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
|
||||||
except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError) as e:
|
overhead_time, None, None)
|
||||||
print(f"Error pickling arguments for function {func.__name__}: {e}")
|
)
|
||||||
# Add to the list of function calls without pickled args. Used for timing info only
|
return result
|
||||||
self._thread_local.active_functions.remove(func_id)
|
|
||||||
overhead_time = time.thread_time_ns() - end_time
|
|
||||||
self.function_calls_data.append(
|
|
||||||
(func.__name__, class_name, func.__module__, func.__code__.co_filename,
|
|
||||||
benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time,
|
|
||||||
overhead_time, None, None)
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
# Flush to database every 1000 calls
|
# Flush to database every 1000 calls
|
||||||
if len(self.function_calls_data) > 1000:
|
if len(self.function_calls_data) > 1000:
|
||||||
self.write_function_timings()
|
self.write_function_timings()
|
||||||
|
|
|
||||||
|
|
@ -175,7 +175,6 @@ class CodeFlashBenchmarkPlugin:
|
||||||
benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
|
benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
|
||||||
# Subtract overhead from total time
|
# Subtract overhead from total time
|
||||||
overhead = overhead_by_benchmark.get(benchmark_key, 0)
|
overhead = overhead_by_benchmark.get(benchmark_key, 0)
|
||||||
print("benchmark_func:", benchmark_func, "Total time:", time_ns, "Overhead:", overhead, "Result:", time_ns - overhead)
|
|
||||||
result[benchmark_key] = time_ns - overhead
|
result[benchmark_key] = time_ns - overhead
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
|
|
@ -267,9 +266,9 @@ class CodeFlashBenchmarkPlugin:
|
||||||
os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number)
|
os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number)
|
||||||
os.environ["CODEFLASH_BENCHMARKING"] = "True"
|
os.environ["CODEFLASH_BENCHMARKING"] = "True"
|
||||||
# Run the function
|
# Run the function
|
||||||
start = time.thread_time_ns()
|
start = time.time_ns()
|
||||||
result = func(*args, **kwargs)
|
result = func(*args, **kwargs)
|
||||||
end = time.thread_time_ns()
|
end = time.time_ns()
|
||||||
# Reset the environment variable
|
# Reset the environment variable
|
||||||
os.environ["CODEFLASH_BENCHMARKING"] = "False"
|
os.environ["CODEFLASH_BENCHMARKING"] = "False"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -62,7 +62,7 @@ def create_trace_replay_test_code(
|
||||||
assert test_framework in ["pytest", "unittest"]
|
assert test_framework in ["pytest", "unittest"]
|
||||||
|
|
||||||
# Create Imports
|
# Create Imports
|
||||||
imports = f"""import dill as pickle
|
imports = f"""from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
|
||||||
{"import unittest" if test_framework == "unittest" else ""}
|
{"import unittest" if test_framework == "unittest" else ""}
|
||||||
from codeflash.benchmarking.replay_test import get_next_arg_and_return
|
from codeflash.benchmarking.replay_test import get_next_arg_and_return
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,7 @@ from collections.abc import Collection
|
||||||
from enum import Enum, IntEnum
|
from enum import Enum, IntEnum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from re import Pattern
|
from re import Pattern
|
||||||
from typing import Annotated, Any, Optional, Union, cast
|
from typing import Annotated, Optional, cast
|
||||||
|
|
||||||
from jedi.api.classes import Name
|
from jedi.api.classes import Name
|
||||||
from pydantic import AfterValidator, BaseModel, ConfigDict, Field
|
from pydantic import AfterValidator, BaseModel, ConfigDict, Field
|
||||||
|
|
@ -362,6 +362,7 @@ class FunctionCoverage:
|
||||||
class TestingMode(enum.Enum):
|
class TestingMode(enum.Enum):
|
||||||
BEHAVIOR = "behavior"
|
BEHAVIOR = "behavior"
|
||||||
PERFORMANCE = "performance"
|
PERFORMANCE = "performance"
|
||||||
|
LINE_PROFILE = "line_profile"
|
||||||
|
|
||||||
|
|
||||||
class VerificationType(str, Enum):
|
class VerificationType(str, Enum):
|
||||||
|
|
@ -533,7 +534,7 @@ class TestResults(BaseModel):
|
||||||
tree.add(
|
tree.add(
|
||||||
f"{test_type.to_name()} - Passed: {report[test_type]['passed']}, Failed: {report[test_type]['failed']}"
|
f"{test_type.to_name()} - Passed: {report[test_type]['passed']}, Failed: {report[test_type]['failed']}"
|
||||||
)
|
)
|
||||||
return
|
return tree
|
||||||
|
|
||||||
def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]:
|
def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]:
|
||||||
|
|
||||||
|
|
@ -606,4 +607,4 @@ class TestResults(BaseModel):
|
||||||
sys.setrecursionlimit(original_recursion_limit)
|
sys.setrecursionlimit(original_recursion_limit)
|
||||||
return False
|
return False
|
||||||
sys.setrecursionlimit(original_recursion_limit)
|
sys.setrecursionlimit(original_recursion_limit)
|
||||||
return True
|
return True
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,8 @@
|
||||||
|
class PicklePlaceholderAccessError(Exception):
|
||||||
|
"""Custom exception raised when attempting to access an unpicklable object."""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class PicklePlaceholder:
|
class PicklePlaceholder:
|
||||||
"""A placeholder for an object that couldn't be pickled.
|
"""A placeholder for an object that couldn't be pickled.
|
||||||
|
|
||||||
|
|
@ -22,22 +27,22 @@ class PicklePlaceholder:
|
||||||
self.__dict__["path"] = path if path is not None else []
|
self.__dict__["path"] = path if path is not None else []
|
||||||
|
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
"""Raise an error when any attribute is accessed."""
|
"""Raise a custom error when any attribute is accessed."""
|
||||||
path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root object"
|
path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root object"
|
||||||
raise AttributeError(
|
raise PicklePlaceholderAccessError(
|
||||||
f"Cannot access attribute '{name}' on unpicklable object at {path_str}. "
|
f"Attempt to access unpickleable object: Cannot access attribute '{name}' on unpicklable object at {path_str}. "
|
||||||
f"Original type: {self.__dict__['obj_type']}. Error: {self.__dict__['error_msg']}"
|
f"Original type: {self.__dict__['obj_type']}. Error: {self.__dict__['error_msg']}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def __setattr__(self, name, value):
|
def __setattr__(self, name, value):
|
||||||
"""Prevent setting attributes."""
|
"""Prevent setting attributes."""
|
||||||
self.__getattr__(name) # This will raise an AttributeError
|
self.__getattr__(name) # This will raise our custom error
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
"""Raise an error when the object is called."""
|
"""Raise a custom error when the object is called."""
|
||||||
path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root object"
|
path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root object"
|
||||||
raise TypeError(
|
raise PicklePlaceholderAccessError(
|
||||||
f"Cannot call unpicklable object at {path_str}. "
|
f"Attempt to access unpickleable object: Cannot call unpicklable object at {path_str}. "
|
||||||
f"Original type: {self.__dict__['obj_type']}. Error: {self.__dict__['error_msg']}"
|
f"Original type: {self.__dict__['obj_type']}. Error: {self.__dict__['error_msg']}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ from typing import Any
|
||||||
import sentry_sdk
|
import sentry_sdk
|
||||||
|
|
||||||
from codeflash.cli_cmds.console import logger
|
from codeflash.cli_cmds.console import logger
|
||||||
|
from codeflash.picklepatch.pickle_placeholder import PicklePlaceholderAccessError
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -64,7 +65,11 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
|
||||||
if len(orig) != len(new):
|
if len(orig) != len(new):
|
||||||
return False
|
return False
|
||||||
return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))
|
return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))
|
||||||
|
if isinstance(orig, PicklePlaceholderAccessError) or isinstance(new, PicklePlaceholderAccessError):
|
||||||
|
# If this error was raised, there was an attempt to access the PicklePlaceholder, which represents an unpickleable object.
|
||||||
|
# The test results should be rejected as the behavior of the unpickleable object is unknown.
|
||||||
|
logger.debug("Unable to verify behavior of unpickleable object in replay test")
|
||||||
|
return False
|
||||||
if isinstance(
|
if isinstance(
|
||||||
orig,
|
orig,
|
||||||
(
|
(
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import dill as pickle
|
||||||
from junitparser.xunit2 import JUnitXml
|
from junitparser.xunit2 import JUnitXml
|
||||||
from lxml.etree import XMLParser, parse
|
from lxml.etree import XMLParser, parse
|
||||||
|
|
||||||
|
|
@ -20,7 +21,6 @@ from codeflash.code_utils.code_utils import (
|
||||||
)
|
)
|
||||||
from codeflash.discovery.discover_unit_tests import discover_parameters_unittest
|
from codeflash.discovery.discover_unit_tests import discover_parameters_unittest
|
||||||
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType, VerificationType
|
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType, VerificationType
|
||||||
from codeflash.picklepatch.pickle_patcher import PicklePatcher
|
|
||||||
from codeflash.verification.coverage_utils import CoverageUtils
|
from codeflash.verification.coverage_utils import CoverageUtils
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
@ -75,7 +75,7 @@ def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, tes
|
||||||
|
|
||||||
test_type = test_files.get_test_type_by_instrumented_file_path(test_file_path)
|
test_type = test_files.get_test_type_by_instrumented_file_path(test_file_path)
|
||||||
try:
|
try:
|
||||||
test_pickle = PicklePatcher.loads(test_pickle_bin) if loop_index == 1 else None
|
test_pickle = pickle.loads(test_pickle_bin) if loop_index == 1 else None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if DEBUG_MODE:
|
if DEBUG_MODE:
|
||||||
logger.exception(f"Failed to load pickle file for {encoded_test_name} Exception: {e}")
|
logger.exception(f"Failed to load pickle file for {encoded_test_name} Exception: {e}")
|
||||||
|
|
@ -133,7 +133,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
|
||||||
# TODO : this is because sqlite writes original file module path. Should make it consistent
|
# TODO : this is because sqlite writes original file module path. Should make it consistent
|
||||||
test_type = test_files.get_test_type_by_original_file_path(test_file_path)
|
test_type = test_files.get_test_type_by_original_file_path(test_file_path)
|
||||||
try:
|
try:
|
||||||
ret_val = (PicklePatcher.loads(val[7]) if loop_index == 1 else None,)
|
ret_val = (pickle.loads(val[7]) if loop_index == 1 else None,)
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
test_results.add(
|
test_results.add(
|
||||||
|
|
|
||||||
|
|
@ -1,34 +1,40 @@
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
|
import shutil
|
||||||
import socket
|
import socket
|
||||||
|
import sqlite3
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import dill
|
import dill
|
||||||
import pytest
|
import pytest
|
||||||
import requests
|
|
||||||
import sqlite3
|
|
||||||
|
|
||||||
|
from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin
|
||||||
|
from codeflash.benchmarking.replay_test import generate_replay_test
|
||||||
|
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
|
||||||
|
from codeflash.benchmarking.utils import validate_and_format_benchmark_table
|
||||||
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
|
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
|
||||||
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
|
||||||
from codeflash.models.models import CodePosition, TestingMode, TestType, TestFiles, TestFile
|
from codeflash.models.models import CodePosition, TestFile, TestFiles, TestingMode, TestsInFile, TestType
|
||||||
from codeflash.optimization.optimizer import Optimizer
|
from codeflash.optimization.optimizer import Optimizer
|
||||||
|
from codeflash.verification.equivalence import compare_test_results
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy import Column, Integer, String, create_engine
|
||||||
from sqlalchemy import create_engine, Column, Integer, String
|
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
HAS_SQLALCHEMY = True
|
HAS_SQLALCHEMY = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
HAS_SQLALCHEMY = False
|
HAS_SQLALCHEMY = False
|
||||||
|
|
||||||
from codeflash.picklepatch.pickle_patcher import PicklePatcher
|
from codeflash.picklepatch.pickle_patcher import PicklePatcher
|
||||||
from codeflash.picklepatch.pickle_placeholder import PicklePlaceholder
|
from codeflash.picklepatch.pickle_placeholder import PicklePlaceholder, PicklePlaceholderAccessError
|
||||||
|
|
||||||
|
|
||||||
def test_picklepatch_simple_nested():
|
def test_picklepatch_simple_nested():
|
||||||
"""
|
"""Test that a simple nested data structure pickles and unpickles correctly.
|
||||||
Test that a simple nested data structure pickles and unpickles correctly.
|
|
||||||
"""
|
"""
|
||||||
original_data = {
|
original_data = {
|
||||||
"numbers": [1, 2, 3],
|
"numbers": [1, 2, 3],
|
||||||
|
|
@ -41,17 +47,24 @@ def test_picklepatch_simple_nested():
|
||||||
assert reloaded == original_data
|
assert reloaded == original_data
|
||||||
# Everything was pickleable, so no placeholders should appear.
|
# Everything was pickleable, so no placeholders should appear.
|
||||||
|
|
||||||
|
|
||||||
def test_picklepatch_with_socket():
|
def test_picklepatch_with_socket():
|
||||||
"""
|
"""Test that a data structure containing a raw socket is replaced by
|
||||||
Test that a data structure containing a raw socket is replaced by
|
|
||||||
PicklePlaceholder rather than raising an error.
|
PicklePlaceholder rather than raising an error.
|
||||||
"""
|
"""
|
||||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
# Create a pair of connected sockets instead of a single socket
|
||||||
|
sock1, sock2 = socket.socketpair()
|
||||||
|
|
||||||
data_with_socket = {
|
data_with_socket = {
|
||||||
"safe_value": 123,
|
"safe_value": 123,
|
||||||
"raw_socket": s,
|
"raw_socket": sock1,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Send a message through sock1, which can be received by sock2
|
||||||
|
sock1.send(b"Hello, world!")
|
||||||
|
received = sock2.recv(1024)
|
||||||
|
assert received == b"Hello, world!"
|
||||||
|
# Pickle the data structure containing the socket
|
||||||
dumped = PicklePatcher.dumps(data_with_socket)
|
dumped = PicklePatcher.dumps(data_with_socket)
|
||||||
reloaded = PicklePatcher.loads(dumped)
|
reloaded = PicklePatcher.loads(dumped)
|
||||||
|
|
||||||
|
|
@ -60,15 +73,18 @@ def test_picklepatch_with_socket():
|
||||||
assert reloaded["safe_value"] == 123
|
assert reloaded["safe_value"] == 123
|
||||||
assert isinstance(reloaded["raw_socket"], PicklePlaceholder)
|
assert isinstance(reloaded["raw_socket"], PicklePlaceholder)
|
||||||
|
|
||||||
# Attempting to use or access attributes => AttributeError
|
# Attempting to use or access attributes => AttributeError
|
||||||
# (not RuntimeError as in original tests, our implementation uses AttributeError)
|
# (not RuntimeError as in original tests, our implementation uses AttributeError)
|
||||||
with pytest.raises(AttributeError) :
|
with pytest.raises(PicklePlaceholderAccessError):
|
||||||
reloaded["raw_socket"].recv(1024)
|
reloaded["raw_socket"].recv(1024)
|
||||||
|
|
||||||
|
# Clean up by closing both sockets
|
||||||
|
sock1.close()
|
||||||
|
sock2.close()
|
||||||
|
|
||||||
|
|
||||||
def test_picklepatch_deeply_nested():
|
def test_picklepatch_deeply_nested():
|
||||||
"""
|
"""Test that deep nesting with unpicklable objects works correctly.
|
||||||
Test that deep nesting with unpicklable objects works correctly.
|
|
||||||
"""
|
"""
|
||||||
# Create a deeply nested structure with an unpicklable object
|
# Create a deeply nested structure with an unpicklable object
|
||||||
deep_nested = {
|
deep_nested = {
|
||||||
|
|
@ -92,8 +108,7 @@ def test_picklepatch_deeply_nested():
|
||||||
assert isinstance(reloaded["level1"]["level2"]["level3"]["socket"], PicklePlaceholder)
|
assert isinstance(reloaded["level1"]["level2"]["level3"]["socket"], PicklePlaceholder)
|
||||||
|
|
||||||
def test_picklepatch_class_with_unpicklable_attr():
|
def test_picklepatch_class_with_unpicklable_attr():
|
||||||
"""
|
"""Test that a class with an unpicklable attribute works correctly.
|
||||||
Test that a class with an unpicklable attribute works correctly.
|
|
||||||
"""
|
"""
|
||||||
class TestClass:
|
class TestClass:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
@ -115,12 +130,11 @@ def test_picklepatch_class_with_unpicklable_attr():
|
||||||
|
|
||||||
|
|
||||||
def test_picklepatch_with_database_connection():
|
def test_picklepatch_with_database_connection():
|
||||||
"""
|
"""Test that a data structure containing a database connection is replaced
|
||||||
Test that a data structure containing a database connection is replaced
|
|
||||||
by PicklePlaceholder rather than raising an error.
|
by PicklePlaceholder rather than raising an error.
|
||||||
"""
|
"""
|
||||||
# SQLite connection - not pickleable
|
# SQLite connection - not pickleable
|
||||||
conn = sqlite3.connect(':memory:')
|
conn = sqlite3.connect(":memory:")
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
data_with_db = {
|
data_with_db = {
|
||||||
|
|
@ -139,13 +153,12 @@ def test_picklepatch_with_database_connection():
|
||||||
assert isinstance(reloaded["cursor"], PicklePlaceholder)
|
assert isinstance(reloaded["cursor"], PicklePlaceholder)
|
||||||
|
|
||||||
# Attempting to use attributes => AttributeError
|
# Attempting to use attributes => AttributeError
|
||||||
with pytest.raises(AttributeError):
|
with pytest.raises(PicklePlaceholderAccessError):
|
||||||
reloaded["connection"].execute("SELECT 1")
|
reloaded["connection"].execute("SELECT 1")
|
||||||
|
|
||||||
|
|
||||||
def test_picklepatch_with_generator():
|
def test_picklepatch_with_generator():
|
||||||
"""
|
"""Test that a data structure containing a generator is replaced by
|
||||||
Test that a data structure containing a generator is replaced by
|
|
||||||
PicklePlaceholder rather than raising an error.
|
PicklePlaceholder rather than raising an error.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -178,13 +191,12 @@ def test_picklepatch_with_generator():
|
||||||
next(reloaded["generator"])
|
next(reloaded["generator"])
|
||||||
|
|
||||||
# Attempting to call methods on the generator => AttributeError
|
# Attempting to call methods on the generator => AttributeError
|
||||||
with pytest.raises(AttributeError):
|
with pytest.raises(PicklePlaceholderAccessError):
|
||||||
reloaded["generator"].send(None)
|
reloaded["generator"].send(None)
|
||||||
|
|
||||||
|
|
||||||
def test_picklepatch_loads_standard_pickle():
|
def test_picklepatch_loads_standard_pickle():
|
||||||
"""
|
"""Test that PicklePatcher.loads can correctly load data that was pickled
|
||||||
Test that PicklePatcher.loads can correctly load data that was pickled
|
|
||||||
using the standard pickle module.
|
using the standard pickle module.
|
||||||
"""
|
"""
|
||||||
# Create a simple data structure
|
# Create a simple data structure
|
||||||
|
|
@ -209,12 +221,10 @@ def test_picklepatch_loads_standard_pickle():
|
||||||
|
|
||||||
|
|
||||||
def test_picklepatch_loads_dill_pickle():
|
def test_picklepatch_loads_dill_pickle():
|
||||||
"""
|
"""Test that PicklePatcher.loads can correctly load data that was pickled
|
||||||
Test that PicklePatcher.loads can correctly load data that was pickled
|
|
||||||
using the dill module, which can pickle more complex objects than the
|
using the dill module, which can pickle more complex objects than the
|
||||||
standard pickle module.
|
standard pickle module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Create a more complex data structure that includes a lambda function
|
# Create a more complex data structure that includes a lambda function
|
||||||
# which dill can handle but standard pickle cannot
|
# which dill can handle but standard pickle cannot
|
||||||
original_data = {
|
original_data = {
|
||||||
|
|
@ -240,80 +250,264 @@ def test_picklepatch_loads_dill_pickle():
|
||||||
assert reloaded["nested"]["another_function"](4) == 16
|
assert reloaded["nested"]["another_function"](4) == 16
|
||||||
|
|
||||||
def test_run_and_parse_picklepatch() -> None:
|
def test_run_and_parse_picklepatch() -> None:
|
||||||
|
"""Test the end to end functionality of picklepatch, from tracing benchmarks to running the replay tests.
|
||||||
|
|
||||||
test_path = (
|
The first example has an argument (an object containing a socket) that is not pickleable However, the socket attributs is not used, so we are able to compare the test results with the optimized test results.
|
||||||
Path(__file__).parent.resolve()
|
Here, we are simply 'ignoring' the unused unpickleable object.
|
||||||
/ "../code_to_optimize/tests/pytest/test_bubble_sort_picklepatch.py"
|
|
||||||
).resolve()
|
The second example also has an argument (an object containing socket) that is not pickleable. The socket attribute is used, which results in an error thrown by the PicklePlaceholder object.
|
||||||
test_path_perf = (
|
Both the original and optimized results should error out in this case, but this should be flagged as incorrect behavior when comparing test results,
|
||||||
Path(__file__).parent.resolve()
|
since we were not able to reuse the unpickleable object in the replay test.
|
||||||
/ "../code_to_optimize/tests/pytest/test_bubble_sort_picklepatch_perf.py"
|
"""
|
||||||
).resolve()
|
# Init paths
|
||||||
fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_picklepatch.py").resolve()
|
project_root = Path(__file__).parent.parent.resolve()
|
||||||
original_test =test_path.read_text("utf-8")
|
tests_root = project_root / "code_to_optimize" / "tests" / "pytest"
|
||||||
|
benchmarks_root = project_root / "code_to_optimize" / "tests" / "pytest" / "benchmarks_socket_test"
|
||||||
|
replay_tests_dir = benchmarks_root / "codeflash_replay_tests"
|
||||||
|
output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
|
||||||
|
fto_unused_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_unused_socket.py").resolve()
|
||||||
|
fto_used_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_used_socket.py").resolve()
|
||||||
|
original_fto_unused_socket_code = fto_unused_socket_path.read_text("utf-8")
|
||||||
|
original_fto_used_socket_code = fto_used_socket_path.read_text("utf-8")
|
||||||
|
# Trace benchmarks
|
||||||
|
trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
|
||||||
|
assert output_file.exists()
|
||||||
try:
|
try:
|
||||||
tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve()
|
# Check contents
|
||||||
project_root_path = (Path(__file__).parent / "..").resolve()
|
conn = sqlite3.connect(output_file.as_posix())
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
cursor.execute(
|
||||||
|
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
|
||||||
|
function_calls = cursor.fetchall()
|
||||||
|
|
||||||
|
# Assert the length of function calls
|
||||||
|
assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
|
||||||
|
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
|
||||||
|
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
|
||||||
|
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
|
||||||
|
assert "code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket" in function_to_results
|
||||||
|
|
||||||
|
test_name, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"][0]
|
||||||
|
assert total_time > 0.0
|
||||||
|
assert function_time > 0.0
|
||||||
|
assert percent > 0.0
|
||||||
|
|
||||||
|
test_name, total_time, function_time, percent = \
|
||||||
|
function_to_results["code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"][0]
|
||||||
|
assert total_time > 0.0
|
||||||
|
assert function_time > 0.0
|
||||||
|
assert percent > 0.0
|
||||||
|
|
||||||
|
bubble_sort_unused_socket_path = (project_root / "code_to_optimize"/ "bubble_sort_picklepatch_test_unused_socket.py").as_posix()
|
||||||
|
bubble_sort_used_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_used_socket.py").as_posix()
|
||||||
|
# Expected function calls
|
||||||
|
expected_calls = [
|
||||||
|
("bubble_sort_with_unused_socket", "", "code_to_optimize.bubble_sort_picklepatch_test_unused_socket",
|
||||||
|
f"{bubble_sort_unused_socket_path}",
|
||||||
|
"test_socket_picklepatch", "code_to_optimize.tests.pytest.benchmarks_socket_test.test_socket", 12),
|
||||||
|
("bubble_sort_with_used_socket", "", "code_to_optimize.bubble_sort_picklepatch_test_used_socket",
|
||||||
|
f"{bubble_sort_used_socket_path}",
|
||||||
|
"test_used_socket_picklepatch", "code_to_optimize.tests.pytest.benchmarks_socket_test.test_socket", 20)
|
||||||
|
]
|
||||||
|
for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
|
||||||
|
assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
|
||||||
|
assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name"
|
||||||
|
assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name"
|
||||||
|
assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path"
|
||||||
|
assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
|
||||||
|
assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
|
||||||
|
assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
# Generate replay test
|
||||||
|
generate_replay_test(output_file, replay_tests_dir)
|
||||||
|
replay_test_path = replay_tests_dir / Path(
|
||||||
|
"test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0.py")
|
||||||
|
replay_test_perf_path = replay_tests_dir / Path(
|
||||||
|
"test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0_perf.py")
|
||||||
|
assert replay_test_path.exists()
|
||||||
|
original_replay_test_code = replay_test_path.read_text("utf-8")
|
||||||
|
|
||||||
|
# Instrument the replay test
|
||||||
|
func = FunctionToOptimize(function_name="bubble_sort_with_unused_socket", parents=[], file_path=Path(fto_unused_socket_path))
|
||||||
original_cwd = Path.cwd()
|
original_cwd = Path.cwd()
|
||||||
run_cwd = Path(__file__).parent.parent.resolve()
|
run_cwd = project_root
|
||||||
func = FunctionToOptimize(function_name="bubble_sort_with_unused_socket", parents=[], file_path=Path(fto_path))
|
|
||||||
os.chdir(run_cwd)
|
os.chdir(run_cwd)
|
||||||
success, new_test = inject_profiling_into_existing_test(
|
success, new_test = inject_profiling_into_existing_test(
|
||||||
test_path,
|
replay_test_path,
|
||||||
[CodePosition(13,14), CodePosition(31,14)],
|
[CodePosition(17, 15)],
|
||||||
func,
|
func,
|
||||||
project_root_path,
|
project_root,
|
||||||
"pytest",
|
"pytest",
|
||||||
mode=TestingMode.BEHAVIOR,
|
mode=TestingMode.BEHAVIOR,
|
||||||
)
|
)
|
||||||
os.chdir(original_cwd)
|
os.chdir(original_cwd)
|
||||||
assert success
|
assert success
|
||||||
assert new_test is not None
|
assert new_test is not None
|
||||||
|
replay_test_path.write_text(new_test)
|
||||||
with test_path.open("w") as f:
|
|
||||||
f.write(new_test)
|
|
||||||
|
|
||||||
opt = Optimizer(
|
opt = Optimizer(
|
||||||
Namespace(
|
Namespace(
|
||||||
project_root=project_root_path,
|
project_root=project_root,
|
||||||
disable_telemetry=True,
|
disable_telemetry=True,
|
||||||
tests_root=tests_root,
|
tests_root=tests_root,
|
||||||
test_framework="pytest",
|
test_framework="pytest",
|
||||||
pytest_cmd="pytest",
|
pytest_cmd="pytest",
|
||||||
experiment_id=None,
|
experiment_id=None,
|
||||||
test_project_root=project_root_path,
|
test_project_root=project_root,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Run the replay test for the original code that does not use the socket
|
||||||
test_env = os.environ.copy()
|
test_env = os.environ.copy()
|
||||||
test_env["CODEFLASH_TEST_ITERATION"] = "0"
|
test_env["CODEFLASH_TEST_ITERATION"] = "0"
|
||||||
test_env["CODEFLASH_LOOP_INDEX"] = "1"
|
test_env["CODEFLASH_LOOP_INDEX"] = "1"
|
||||||
test_type = TestType.EXISTING_UNIT_TEST
|
test_type = TestType.REPLAY_TEST
|
||||||
|
replay_test_function = "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket"
|
||||||
func_optimizer = opt.create_function_optimizer(func)
|
func_optimizer = opt.create_function_optimizer(func)
|
||||||
func_optimizer.test_files = TestFiles(
|
func_optimizer.test_files = TestFiles(
|
||||||
test_files=[
|
test_files=[
|
||||||
TestFile(
|
TestFile(
|
||||||
instrumented_behavior_file_path=test_path,
|
instrumented_behavior_file_path=replay_test_path,
|
||||||
test_type=test_type,
|
test_type=test_type,
|
||||||
original_file_path=test_path,
|
original_file_path=replay_test_path,
|
||||||
benchmarking_file_path=test_path_perf,
|
benchmarking_file_path=replay_test_perf_path,
|
||||||
|
tests_in_file=[TestsInFile(test_file=replay_test_path, test_class=None, test_function=replay_test_function, test_type=test_type)],
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
test_results, coverage_data = func_optimizer.run_and_parse_tests(
|
test_results_unused_socket, coverage_data = func_optimizer.run_and_parse_tests(
|
||||||
testing_type=TestingMode.BEHAVIOR,
|
testing_type=TestingMode.BEHAVIOR,
|
||||||
test_env=test_env,
|
test_env=test_env,
|
||||||
test_files=func_optimizer.test_files,
|
test_files=func_optimizer.test_files,
|
||||||
optimization_iteration=0,
|
optimization_iteration=0,
|
||||||
pytest_min_loops=1,
|
pytest_min_loops=1,
|
||||||
pytest_max_loops=1,
|
pytest_max_loops=1,
|
||||||
testing_time=0.1,
|
testing_time=1.0,
|
||||||
)
|
)
|
||||||
assert test_results.test_results[0].id.test_function_name =="test_bubble_sort_with_unused_socket"
|
assert len(test_results_unused_socket) == 1
|
||||||
assert test_results.test_results[0].did_pass ==True
|
assert test_results_unused_socket.test_results[0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
|
||||||
assert test_results.test_results[1].id.test_function_name =="test_bubble_sort_with_used_socket"
|
assert test_results_unused_socket.test_results[0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket"
|
||||||
assert test_results.test_results[1].did_pass ==False
|
assert test_results_unused_socket.test_results[0].did_pass == True
|
||||||
# assert pickle placeholder problem
|
|
||||||
print(test_results)
|
# Replace with optimized candidate
|
||||||
|
fto_unused_socket_path.write_text("""
|
||||||
|
from codeflash.benchmarking.codeflash_trace import codeflash_trace
|
||||||
|
|
||||||
|
@codeflash_trace
|
||||||
|
def bubble_sort_with_unused_socket(data_container):
|
||||||
|
# Extract the list to sort, leaving the socket untouched
|
||||||
|
numbers = data_container.get('numbers', []).copy()
|
||||||
|
return sorted(numbers)
|
||||||
|
""")
|
||||||
|
# Run optimized code for unused socket
|
||||||
|
optimized_test_results_unused_socket, coverage_data = func_optimizer.run_and_parse_tests(
|
||||||
|
testing_type=TestingMode.BEHAVIOR,
|
||||||
|
test_env=test_env,
|
||||||
|
test_files=func_optimizer.test_files,
|
||||||
|
optimization_iteration=0,
|
||||||
|
pytest_min_loops=1,
|
||||||
|
pytest_max_loops=1,
|
||||||
|
testing_time=1.0,
|
||||||
|
)
|
||||||
|
assert len(optimized_test_results_unused_socket) == 1
|
||||||
|
verification_result = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket)
|
||||||
|
assert verification_result is True
|
||||||
|
|
||||||
|
# Remove the previous instrumentation
|
||||||
|
replay_test_path.write_text(original_replay_test_code)
|
||||||
|
# Instrument the replay test
|
||||||
|
func = FunctionToOptimize(function_name="bubble_sort_with_used_socket", parents=[], file_path=Path(fto_used_socket_path))
|
||||||
|
success, new_test = inject_profiling_into_existing_test(
|
||||||
|
replay_test_path,
|
||||||
|
[CodePosition(23,15)],
|
||||||
|
func,
|
||||||
|
project_root,
|
||||||
|
"pytest",
|
||||||
|
mode=TestingMode.BEHAVIOR,
|
||||||
|
)
|
||||||
|
os.chdir(original_cwd)
|
||||||
|
assert success
|
||||||
|
assert new_test is not None
|
||||||
|
replay_test_path.write_text(new_test)
|
||||||
|
|
||||||
|
# Run test for original function code that uses the socket. This should fail, as the PicklePlaceholder is accessed.
|
||||||
|
test_env = os.environ.copy()
|
||||||
|
test_env["CODEFLASH_TEST_ITERATION"] = "0"
|
||||||
|
test_env["CODEFLASH_LOOP_INDEX"] = "1"
|
||||||
|
test_type = TestType.REPLAY_TEST
|
||||||
|
func = FunctionToOptimize(function_name="bubble_sort_with_used_socket", parents=[],
|
||||||
|
file_path=Path(fto_used_socket_path))
|
||||||
|
replay_test_function = "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket"
|
||||||
|
func_optimizer = opt.create_function_optimizer(func)
|
||||||
|
func_optimizer.test_files = TestFiles(
|
||||||
|
test_files=[
|
||||||
|
TestFile(
|
||||||
|
instrumented_behavior_file_path=replay_test_path,
|
||||||
|
test_type=test_type,
|
||||||
|
original_file_path=replay_test_path,
|
||||||
|
benchmarking_file_path=replay_test_perf_path,
|
||||||
|
tests_in_file=[
|
||||||
|
TestsInFile(test_file=replay_test_path, test_class=None, test_function=replay_test_function,
|
||||||
|
test_type=test_type)],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
test_results_used_socket, coverage_data = func_optimizer.run_and_parse_tests(
|
||||||
|
testing_type=TestingMode.BEHAVIOR,
|
||||||
|
test_env=test_env,
|
||||||
|
test_files=func_optimizer.test_files,
|
||||||
|
optimization_iteration=0,
|
||||||
|
pytest_min_loops=1,
|
||||||
|
pytest_max_loops=1,
|
||||||
|
testing_time=1.0,
|
||||||
|
)
|
||||||
|
assert len(test_results_used_socket) == 1
|
||||||
|
assert test_results_used_socket.test_results[
|
||||||
|
0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
|
||||||
|
assert test_results_used_socket.test_results[
|
||||||
|
0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket"
|
||||||
|
assert test_results_used_socket.test_results[0].did_pass is False
|
||||||
|
print("test results used socket")
|
||||||
|
print(test_results_used_socket)
|
||||||
|
# Replace with optimized candidate
|
||||||
|
fto_used_socket_path.write_text("""
|
||||||
|
from codeflash.benchmarking.codeflash_trace import codeflash_trace
|
||||||
|
|
||||||
|
@codeflash_trace
|
||||||
|
def bubble_sort_with_used_socket(data_container):
|
||||||
|
# Extract the list to sort, leaving the socket untouched
|
||||||
|
numbers = data_container.get('numbers', []).copy()
|
||||||
|
socket = data_container.get('socket')
|
||||||
|
socket.send("Hello from the optimized function!")
|
||||||
|
return sorted(numbers)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Run test for optimized function code that uses the socket. This should fail, as the PicklePlaceholder is accessed.
|
||||||
|
optimized_test_results_used_socket, coverage_data = func_optimizer.run_and_parse_tests(
|
||||||
|
testing_type=TestingMode.BEHAVIOR,
|
||||||
|
test_env=test_env,
|
||||||
|
test_files=func_optimizer.test_files,
|
||||||
|
optimization_iteration=0,
|
||||||
|
pytest_min_loops=1,
|
||||||
|
pytest_max_loops=1,
|
||||||
|
testing_time=1.0,
|
||||||
|
)
|
||||||
|
assert len(test_results_used_socket) == 1
|
||||||
|
assert test_results_used_socket.test_results[
|
||||||
|
0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
|
||||||
|
assert test_results_used_socket.test_results[
|
||||||
|
0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket"
|
||||||
|
assert test_results_used_socket.test_results[0].did_pass is False
|
||||||
|
|
||||||
|
# Even though tests threw the same error, we reject this as the behavior of the unpickleable object could not be determined.
|
||||||
|
assert compare_test_results(test_results_used_socket, optimized_test_results_used_socket) is False
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
test_path.write_text(original_test)
|
# cleanup
|
||||||
|
output_file.unlink(missing_ok=True)
|
||||||
|
shutil.rmtree(replay_tests_dir, ignore_errors=True)
|
||||||
|
fto_unused_socket_path.write_text(original_fto_unused_socket_code)
|
||||||
|
fto_used_socket_path.write_text(original_fto_used_socket_code)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,14 @@
|
||||||
|
import shutil
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
|
||||||
from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin
|
|
||||||
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
|
|
||||||
from codeflash.benchmarking.replay_test import generate_replay_test
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table
|
from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin
|
||||||
import shutil
|
from codeflash.benchmarking.replay_test import generate_replay_test
|
||||||
|
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
|
||||||
|
from codeflash.benchmarking.utils import validate_and_format_benchmark_table
|
||||||
|
|
||||||
|
|
||||||
def test_trace_benchmarks():
|
def test_trace_benchmarks() -> None:
|
||||||
# Test the trace_benchmarks function
|
# Test the trace_benchmarks function
|
||||||
project_root = Path(__file__).parent.parent / "code_to_optimize"
|
project_root = Path(__file__).parent.parent / "code_to_optimize"
|
||||||
benchmarks_root = project_root / "tests" / "pytest" / "benchmarks_test"
|
benchmarks_root = project_root / "tests" / "pytest" / "benchmarks_test"
|
||||||
|
|
@ -83,13 +82,12 @@ def test_trace_benchmarks():
|
||||||
test_class_sort_path = replay_tests_dir/ Path("test_tests_pytest_benchmarks_test_test_benchmark_bubble_sort_example__replay_test_0.py")
|
test_class_sort_path = replay_tests_dir/ Path("test_tests_pytest_benchmarks_test_test_benchmark_bubble_sort_example__replay_test_0.py")
|
||||||
assert test_class_sort_path.exists()
|
assert test_class_sort_path.exists()
|
||||||
test_class_sort_code = f"""
|
test_class_sort_code = f"""
|
||||||
import dill as pickle
|
|
||||||
|
|
||||||
from code_to_optimize.bubble_sort_codeflash_trace import \\
|
from code_to_optimize.bubble_sort_codeflash_trace import \\
|
||||||
Sorter as code_to_optimize_bubble_sort_codeflash_trace_Sorter
|
Sorter as code_to_optimize_bubble_sort_codeflash_trace_Sorter
|
||||||
from code_to_optimize.bubble_sort_codeflash_trace import \\
|
from code_to_optimize.bubble_sort_codeflash_trace import \\
|
||||||
sorter as code_to_optimize_bubble_sort_codeflash_trace_sorter
|
sorter as code_to_optimize_bubble_sort_codeflash_trace_sorter
|
||||||
from codeflash.benchmarking.replay_test import get_next_arg_and_return
|
from codeflash.benchmarking.replay_test import get_next_arg_and_return
|
||||||
|
from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
|
||||||
|
|
||||||
functions = ['sort_class', 'sort_static', 'sorter']
|
functions = ['sort_class', 'sort_static', 'sorter']
|
||||||
trace_file_path = r"{output_file.as_posix()}"
|
trace_file_path = r"{output_file.as_posix()}"
|
||||||
|
|
@ -146,14 +144,13 @@ def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter___init__():
|
||||||
test_sort_path = replay_tests_dir / Path("test_tests_pytest_benchmarks_test_test_process_and_sort_example__replay_test_0.py")
|
test_sort_path = replay_tests_dir / Path("test_tests_pytest_benchmarks_test_test_process_and_sort_example__replay_test_0.py")
|
||||||
assert test_sort_path.exists()
|
assert test_sort_path.exists()
|
||||||
test_sort_code = f"""
|
test_sort_code = f"""
|
||||||
import dill as pickle
|
|
||||||
|
|
||||||
from code_to_optimize.bubble_sort_codeflash_trace import \\
|
from code_to_optimize.bubble_sort_codeflash_trace import \\
|
||||||
sorter as code_to_optimize_bubble_sort_codeflash_trace_sorter
|
sorter as code_to_optimize_bubble_sort_codeflash_trace_sorter
|
||||||
from code_to_optimize.process_and_bubble_sort_codeflash_trace import \\
|
from code_to_optimize.process_and_bubble_sort_codeflash_trace import \\
|
||||||
compute_and_sort as \\
|
compute_and_sort as \\
|
||||||
code_to_optimize_process_and_bubble_sort_codeflash_trace_compute_and_sort
|
code_to_optimize_process_and_bubble_sort_codeflash_trace_compute_and_sort
|
||||||
from codeflash.benchmarking.replay_test import get_next_arg_and_return
|
from codeflash.benchmarking.replay_test import get_next_arg_and_return
|
||||||
|
from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
|
||||||
|
|
||||||
functions = ['compute_and_sort', 'sorter']
|
functions = ['compute_and_sort', 'sorter']
|
||||||
trace_file_path = r"{output_file}"
|
trace_file_path = r"{output_file}"
|
||||||
|
|
@ -284,4 +281,4 @@ def test_trace_benchmark_decorator() -> None:
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# cleanup
|
# cleanup
|
||||||
output_file.unlink(missing_ok=True)
|
output_file.unlink(missing_ok=True)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue