End-to-end test that proves PicklePatcher works. The example uses a socket (which is unpicklable) that is either used or left unused by the function under test.

This commit is contained in:
Alvin Ryanputra 2025-04-10 21:43:56 -04:00
parent 40e416e0d0
commit 3158f9cc1c
13 changed files with 351 additions and 193 deletions

View file

@ -0,0 +1,18 @@
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@codeflash_trace
def bubble_sort_with_unused_socket(data_container):
    """Return the values stored under 'numbers' in ascending order.

    Only the 'numbers' entry of ``data_container`` is read (an empty list is
    used when the key is absent); no other entry, such as a socket, is ever
    accessed. The list is copied before sorting, so the input is not mutated.
    """
    snapshot = data_container.get('numbers', []).copy()
    ordered = sorted(snapshot)
    return ordered
@codeflash_trace
def bubble_sort_with_used_socket(data_container):
    """Send a greeting over the container's socket, then return its numbers sorted.

    Args:
        data_container: Mapping read via ``.get``. 'numbers' is the list to
            sort (an empty list is used when the key is absent) and 'socket'
            is an object exposing ``send``.

    Returns:
        list: A new list with the numbers in ascending order; the input list
        is not mutated.
    """
    # Work on a copy so the caller's list keeps its original order.
    numbers = data_container.get('numbers', []).copy()
    # This function deliberately touches the socket entry before sorting.
    sock = data_container.get('socket')
    # socket.send() accepts only bytes-like objects; a str payload raises
    # TypeError on a real socket, so the message is sent as bytes.
    sock.send(b"Hello from the optimized function!")
    return sorted(numbers)

View file

@ -1,38 +1,6 @@
def bubble_sort_with_unused_socket(data_container):
    """Bubble-sort a copy of the list stored under 'numbers' in data_container.

    Layout of data_container:
    - 'numbers' (list): the values to sort; read with ``.get`` and treated as
      an empty list when missing.
    - 'socket' (socket): a socket; this function never reads it.

    Args:
        data_container: A dictionary with at least 'numbers' (list) and 'socket' keys

    Returns:
        list: The sorted list of numbers (a new list; the input is not mutated)
    """
    # Sort a copy so the caller's list keeps its original order.
    values = data_container.get('numbers', []).copy()
    # After each pass the largest remaining value has bubbled to the end, so
    # the unsorted prefix shrinks by one element per pass.
    for unsorted_len in range(len(values), 0, -1):
        clean_pass = True
        for left in range(unsorted_len - 1):
            right = left + 1
            if values[left] > values[right]:
                values[left], values[right] = values[right], values[left]
                clean_pass = False
        # A pass without any swap means the list is already in order.
        if clean_pass:
            break
    return values
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@codeflash_trace
def bubble_sort_with_used_socket(data_container):
"""
Performs a bubble sort on a list within the data_container. The data container has the following schema:

View file

@ -0,0 +1,20 @@
import socket
from code_to_optimize.bubble_sort_picklepatch_test_unused_socket import bubble_sort_with_unused_socket
from code_to_optimize.bubble_sort_picklepatch_test_used_socket import bubble_sort_with_used_socket
def test_socket_picklepatch(benchmark):  # benchmarks the sort that never reads the socket entry
    s1, s2 = socket.socketpair()  # two connected sockets; only s1 goes into the payload
    data = {
        "numbers": list(reversed(range(500))),  # 499..0, i.e. reverse-sorted input
        "socket": s1
    }
    benchmark(bubble_sort_with_unused_socket, data)  # NOTE(review): this call's line number appears to be asserted by the end-to-end picklepatch test - confirm before adding lines
def test_used_socket_picklepatch(benchmark):  # benchmarks the sort that calls send() on the socket entry
    s1, s2 = socket.socketpair()  # two connected sockets; only s1 goes into the payload
    data = {
        "numbers": list(reversed(range(500))),  # 499..0, i.e. reverse-sorted input
        "socket": s1
    }
    benchmark(bubble_sort_with_used_socket, data)  # NOTE(review): this call's line number appears to be asserted by the end-to-end picklepatch test - confirm before adding lines

View file

@ -1,34 +0,0 @@
import socket
from unittest.mock import Mock
import pytest
from code_to_optimize.bubble_sort_picklepatch import bubble_sort_with_unused_socket, bubble_sort_with_used_socket
def test_bubble_sort_with_unused_socket():
    """The sort must return the right order when a socket entry is present but unused."""
    # A Mock stands in for the socket, so no real network resource is needed.
    fake_socket = Mock()
    # Regular unsorted list containing a duplicate value.
    payload = {'numbers': [5, 2, 9, 1, 5, 6], 'socket': fake_socket}
    outcome = bubble_sort_with_unused_socket(payload)
    # The returned list must be in ascending order.
    assert outcome == [1, 2, 5, 5, 6, 9]
def test_bubble_sort_with_used_socket():
    """The sort must return the right order when it is handed a mock socket."""
    # A Mock stands in for the socket, so no real network resource is needed.
    fake_socket = Mock()
    # Regular unsorted list containing a duplicate value.
    payload = {'numbers': [5, 2, 9, 1, 5, 6], 'socket': fake_socket}
    outcome = bubble_sort_with_used_socket(payload)
    # The returned list must be in ascending order.
    assert outcome == [1, 2, 5, 5, 6, 9]

View file

@ -2,12 +2,11 @@ import functools
import os
import pickle
import sqlite3
import sys
import threading
import time
from typing import Callable
import dill
from codeflash.picklepatch.pickle_patcher import PicklePatcher
class CodeflashTrace:
@ -147,23 +146,10 @@ class CodeflashTrace:
return result
try:
original_recursion_limit = sys.getrecursionlimit()
sys.setrecursionlimit(10000)
# args = dict(args.items())
# if class_name and func.__name__ == "__init__" and "self" in args:
# del args["self"]
# Pickle the arguments
pickled_args = pickle.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
pickled_kwargs = pickle.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
sys.setrecursionlimit(original_recursion_limit)
except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError):
# Retry with dill if pickle fails. It's slower but more comprehensive
try:
pickled_args = dill.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
pickled_kwargs = dill.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
sys.setrecursionlimit(original_recursion_limit)
except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError) as e:
pickled_args = PicklePatcher.dumps(args, protocol=pickle.HIGHEST_PROTOCOL)
pickled_kwargs = PicklePatcher.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL)
except Exception as e:
print(f"Error pickling arguments for function {func.__name__}: {e}")
# Add to the list of function calls without pickled args. Used for timing info only
self._thread_local.active_functions.remove(func_id)
@ -174,7 +160,6 @@ class CodeflashTrace:
overhead_time, None, None)
)
return result
# Flush to database every 1000 calls
if len(self.function_calls_data) > 1000:
self.write_function_timings()

View file

@ -175,7 +175,6 @@ class CodeFlashBenchmarkPlugin:
benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
# Subtract overhead from total time
overhead = overhead_by_benchmark.get(benchmark_key, 0)
print("benchmark_func:", benchmark_func, "Total time:", time_ns, "Overhead:", overhead, "Result:", time_ns - overhead)
result[benchmark_key] = time_ns - overhead
finally:
@ -267,9 +266,9 @@ class CodeFlashBenchmarkPlugin:
os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number)
os.environ["CODEFLASH_BENCHMARKING"] = "True"
# Run the function
start = time.thread_time_ns()
start = time.time_ns()
result = func(*args, **kwargs)
end = time.thread_time_ns()
end = time.time_ns()
# Reset the environment variable
os.environ["CODEFLASH_BENCHMARKING"] = "False"

View file

@ -62,7 +62,7 @@ def create_trace_replay_test_code(
assert test_framework in ["pytest", "unittest"]
# Create Imports
imports = f"""import dill as pickle
imports = f"""from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
{"import unittest" if test_framework == "unittest" else ""}
from codeflash.benchmarking.replay_test import get_next_arg_and_return
"""

View file

@ -16,7 +16,7 @@ from collections.abc import Collection
from enum import Enum, IntEnum
from pathlib import Path
from re import Pattern
from typing import Annotated, Any, Optional, Union, cast
from typing import Annotated, Optional, cast
from jedi.api.classes import Name
from pydantic import AfterValidator, BaseModel, ConfigDict, Field
@ -362,6 +362,7 @@ class FunctionCoverage:
class TestingMode(enum.Enum):
BEHAVIOR = "behavior"
PERFORMANCE = "performance"
LINE_PROFILE = "line_profile"
class VerificationType(str, Enum):
@ -533,7 +534,7 @@ class TestResults(BaseModel):
tree.add(
f"{test_type.to_name()} - Passed: {report[test_type]['passed']}, Failed: {report[test_type]['failed']}"
)
return
return tree
def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]:

View file

@ -1,3 +1,8 @@
class PicklePlaceholderAccessError(Exception):
    """Signal that code tried to use a stand-in for an object that could not be pickled.

    Derives directly from ``Exception``, so handlers written for
    ``AttributeError`` or ``TypeError`` do not catch it.
    """
class PicklePlaceholder:
"""A placeholder for an object that couldn't be pickled.
@ -22,22 +27,22 @@ class PicklePlaceholder:
self.__dict__["path"] = path if path is not None else []
def __getattr__(self, name):
"""Raise an error when any attribute is accessed."""
"""Raise a custom error when any attribute is accessed."""
path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root object"
raise AttributeError(
f"Cannot access attribute '{name}' on unpicklable object at {path_str}. "
raise PicklePlaceholderAccessError(
f"Attempt to access unpickleable object: Cannot access attribute '{name}' on unpicklable object at {path_str}. "
f"Original type: {self.__dict__['obj_type']}. Error: {self.__dict__['error_msg']}"
)
def __setattr__(self, name, value):
"""Prevent setting attributes."""
self.__getattr__(name) # This will raise an AttributeError
self.__getattr__(name) # This will raise our custom error
def __call__(self, *args, **kwargs):
"""Raise an error when the object is called."""
"""Raise a custom error when the object is called."""
path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root object"
raise TypeError(
f"Cannot call unpicklable object at {path_str}. "
raise PicklePlaceholderAccessError(
f"Attempt to access unpickleable object: Cannot call unpicklable object at {path_str}. "
f"Original type: {self.__dict__['obj_type']}. Error: {self.__dict__['error_msg']}"
)

View file

@ -10,6 +10,7 @@ from typing import Any
import sentry_sdk
from codeflash.cli_cmds.console import logger
from codeflash.picklepatch.pickle_placeholder import PicklePlaceholderAccessError
try:
import numpy as np
@ -64,7 +65,11 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
if len(orig) != len(new):
return False
return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new))
if isinstance(orig, PicklePlaceholderAccessError) or isinstance(new, PicklePlaceholderAccessError):
# If this error was raised, there was an attempt to access the PicklePlaceholder, which represents an unpickleable object.
# The test results should be rejected as the behavior of the unpickleable object is unknown.
logger.debug("Unable to verify behavior of unpickleable object in replay test")
return False
if isinstance(
orig,
(

View file

@ -8,6 +8,7 @@ from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING
import dill as pickle
from junitparser.xunit2 import JUnitXml
from lxml.etree import XMLParser, parse
@ -20,7 +21,6 @@ from codeflash.code_utils.code_utils import (
)
from codeflash.discovery.discover_unit_tests import discover_parameters_unittest
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType, VerificationType
from codeflash.picklepatch.pickle_patcher import PicklePatcher
from codeflash.verification.coverage_utils import CoverageUtils
if TYPE_CHECKING:
@ -75,7 +75,7 @@ def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, tes
test_type = test_files.get_test_type_by_instrumented_file_path(test_file_path)
try:
test_pickle = PicklePatcher.loads(test_pickle_bin) if loop_index == 1 else None
test_pickle = pickle.loads(test_pickle_bin) if loop_index == 1 else None
except Exception as e:
if DEBUG_MODE:
logger.exception(f"Failed to load pickle file for {encoded_test_name} Exception: {e}")
@ -133,7 +133,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
# TODO : this is because sqlite writes original file module path. Should make it consistent
test_type = test_files.get_test_type_by_original_file_path(test_file_path)
try:
ret_val = (PicklePatcher.loads(val[7]) if loop_index == 1 else None,)
ret_val = (pickle.loads(val[7]) if loop_index == 1 else None,)
except Exception:
continue
test_results.add(

View file

@ -1,34 +1,40 @@
import os
import pickle
import shutil
import socket
import sqlite3
from argparse import Namespace
from pathlib import Path
import dill
import pytest
import requests
import sqlite3
from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin
from codeflash.benchmarking.replay_test import generate_replay_test
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
from codeflash.benchmarking.utils import validate_and_format_benchmark_table
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodePosition, TestingMode, TestType, TestFiles, TestFile
from codeflash.models.models import CodePosition, TestFile, TestFiles, TestingMode, TestsInFile, TestType
from codeflash.optimization.optimizer import Optimizer
from codeflash.verification.equivalence import compare_test_results
try:
import sqlalchemy
from sqlalchemy.orm import Session
from sqlalchemy import create_engine, Column, Integer, String
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
HAS_SQLALCHEMY = True
except ImportError:
HAS_SQLALCHEMY = False
from codeflash.picklepatch.pickle_patcher import PicklePatcher
from codeflash.picklepatch.pickle_placeholder import PicklePlaceholder
from codeflash.picklepatch.pickle_placeholder import PicklePlaceholder, PicklePlaceholderAccessError
def test_picklepatch_simple_nested():
"""
Test that a simple nested data structure pickles and unpickles correctly.
"""Test that a simple nested data structure pickles and unpickles correctly.
"""
original_data = {
"numbers": [1, 2, 3],
@ -41,17 +47,24 @@ def test_picklepatch_simple_nested():
assert reloaded == original_data
# Everything was pickleable, so no placeholders should appear.
def test_picklepatch_with_socket():
"""
Test that a data structure containing a raw socket is replaced by
"""Test that a data structure containing a raw socket is replaced by
PicklePlaceholder rather than raising an error.
"""
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Create a pair of connected sockets instead of a single socket
sock1, sock2 = socket.socketpair()
data_with_socket = {
"safe_value": 123,
"raw_socket": s,
"raw_socket": sock1,
}
# Send a message through sock1, which can be received by sock2
sock1.send(b"Hello, world!")
received = sock2.recv(1024)
assert received == b"Hello, world!"
# Pickle the data structure containing the socket
dumped = PicklePatcher.dumps(data_with_socket)
reloaded = PicklePatcher.loads(dumped)
@ -62,13 +75,16 @@ def test_picklepatch_with_socket():
# Attempting to use or access attributes => PicklePlaceholderAccessError
# (the placeholder raises this custom error rather than AttributeError or RuntimeError)
with pytest.raises(AttributeError) :
with pytest.raises(PicklePlaceholderAccessError):
reloaded["raw_socket"].recv(1024)
# Clean up by closing both sockets
sock1.close()
sock2.close()
def test_picklepatch_deeply_nested():
"""
Test that deep nesting with unpicklable objects works correctly.
"""Test that deep nesting with unpicklable objects works correctly.
"""
# Create a deeply nested structure with an unpicklable object
deep_nested = {
@ -92,8 +108,7 @@ def test_picklepatch_deeply_nested():
assert isinstance(reloaded["level1"]["level2"]["level3"]["socket"], PicklePlaceholder)
def test_picklepatch_class_with_unpicklable_attr():
"""
Test that a class with an unpicklable attribute works correctly.
"""Test that a class with an unpicklable attribute works correctly.
"""
class TestClass:
def __init__(self):
@ -115,12 +130,11 @@ def test_picklepatch_class_with_unpicklable_attr():
def test_picklepatch_with_database_connection():
"""
Test that a data structure containing a database connection is replaced
"""Test that a data structure containing a database connection is replaced
by PicklePlaceholder rather than raising an error.
"""
# SQLite connection - not pickleable
conn = sqlite3.connect(':memory:')
conn = sqlite3.connect(":memory:")
cursor = conn.cursor()
data_with_db = {
@ -139,13 +153,12 @@ def test_picklepatch_with_database_connection():
assert isinstance(reloaded["cursor"], PicklePlaceholder)
# Attempting to use attributes => PicklePlaceholderAccessError
with pytest.raises(AttributeError):
with pytest.raises(PicklePlaceholderAccessError):
reloaded["connection"].execute("SELECT 1")
def test_picklepatch_with_generator():
"""
Test that a data structure containing a generator is replaced by
"""Test that a data structure containing a generator is replaced by
PicklePlaceholder rather than raising an error.
"""
@ -178,13 +191,12 @@ def test_picklepatch_with_generator():
next(reloaded["generator"])
# Attempting to call methods on the generator => PicklePlaceholderAccessError
with pytest.raises(AttributeError):
with pytest.raises(PicklePlaceholderAccessError):
reloaded["generator"].send(None)
def test_picklepatch_loads_standard_pickle():
"""
Test that PicklePatcher.loads can correctly load data that was pickled
"""Test that PicklePatcher.loads can correctly load data that was pickled
using the standard pickle module.
"""
# Create a simple data structure
@ -209,12 +221,10 @@ def test_picklepatch_loads_standard_pickle():
def test_picklepatch_loads_dill_pickle():
"""
Test that PicklePatcher.loads can correctly load data that was pickled
"""Test that PicklePatcher.loads can correctly load data that was pickled
using the dill module, which can pickle more complex objects than the
standard pickle module.
"""
# Create a more complex data structure that includes a lambda function
# which dill can handle but standard pickle cannot
original_data = {
@ -240,80 +250,264 @@ def test_picklepatch_loads_dill_pickle():
assert reloaded["nested"]["another_function"](4) == 16
def test_run_and_parse_picklepatch() -> None:
"""Test the end to end functionality of picklepatch, from tracing benchmarks to running the replay tests.
test_path = (
Path(__file__).parent.resolve()
/ "../code_to_optimize/tests/pytest/test_bubble_sort_picklepatch.py"
).resolve()
test_path_perf = (
Path(__file__).parent.resolve()
/ "../code_to_optimize/tests/pytest/test_bubble_sort_picklepatch_perf.py"
).resolve()
fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_picklepatch.py").resolve()
original_test =test_path.read_text("utf-8")
The first example has an argument (an object containing a socket) that is not pickleable. However, the socket attribute is not used, so we are able to compare the test results with the optimized test results.
Here, we are simply 'ignoring' the unused unpickleable object.
The second example also has an argument (an object containing socket) that is not pickleable. The socket attribute is used, which results in an error thrown by the PicklePlaceholder object.
Both the original and optimized results should error out in this case, but this should be flagged as incorrect behavior when comparing test results,
since we were not able to reuse the unpickleable object in the replay test.
"""
# Init paths
project_root = Path(__file__).parent.parent.resolve()
tests_root = project_root / "code_to_optimize" / "tests" / "pytest"
benchmarks_root = project_root / "code_to_optimize" / "tests" / "pytest" / "benchmarks_socket_test"
replay_tests_dir = benchmarks_root / "codeflash_replay_tests"
output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve()
fto_unused_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_unused_socket.py").resolve()
fto_used_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_used_socket.py").resolve()
original_fto_unused_socket_code = fto_unused_socket_path.read_text("utf-8")
original_fto_used_socket_code = fto_used_socket_path.read_text("utf-8")
# Trace benchmarks
trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file)
assert output_file.exists()
try:
tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve()
project_root_path = (Path(__file__).parent / "..").resolve()
# Check contents
conn = sqlite3.connect(output_file.as_posix())
cursor = conn.cursor()
cursor.execute(
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
function_calls = cursor.fetchall()
# Assert the length of function calls
assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}"
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file)
function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
assert "code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket" in function_to_results
test_name, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"][0]
assert total_time > 0.0
assert function_time > 0.0
assert percent > 0.0
test_name, total_time, function_time, percent = \
function_to_results["code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"][0]
assert total_time > 0.0
assert function_time > 0.0
assert percent > 0.0
bubble_sort_unused_socket_path = (project_root / "code_to_optimize"/ "bubble_sort_picklepatch_test_unused_socket.py").as_posix()
bubble_sort_used_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_used_socket.py").as_posix()
# Expected function calls
expected_calls = [
("bubble_sort_with_unused_socket", "", "code_to_optimize.bubble_sort_picklepatch_test_unused_socket",
f"{bubble_sort_unused_socket_path}",
"test_socket_picklepatch", "code_to_optimize.tests.pytest.benchmarks_socket_test.test_socket", 12),
("bubble_sort_with_used_socket", "", "code_to_optimize.bubble_sort_picklepatch_test_used_socket",
f"{bubble_sort_used_socket_path}",
"test_used_socket_picklepatch", "code_to_optimize.tests.pytest.benchmarks_socket_test.test_socket", 20)
]
for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)):
assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name"
assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name"
assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name"
assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path"
assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name"
assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path"
assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number"
conn.close()
# Generate replay test
generate_replay_test(output_file, replay_tests_dir)
replay_test_path = replay_tests_dir / Path(
"test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0.py")
replay_test_perf_path = replay_tests_dir / Path(
"test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0_perf.py")
assert replay_test_path.exists()
original_replay_test_code = replay_test_path.read_text("utf-8")
# Instrument the replay test
func = FunctionToOptimize(function_name="bubble_sort_with_unused_socket", parents=[], file_path=Path(fto_unused_socket_path))
original_cwd = Path.cwd()
run_cwd = Path(__file__).parent.parent.resolve()
func = FunctionToOptimize(function_name="bubble_sort_with_unused_socket", parents=[], file_path=Path(fto_path))
run_cwd = project_root
os.chdir(run_cwd)
success, new_test = inject_profiling_into_existing_test(
test_path,
[CodePosition(13,14), CodePosition(31,14)],
replay_test_path,
[CodePosition(17, 15)],
func,
project_root_path,
project_root,
"pytest",
mode=TestingMode.BEHAVIOR,
)
os.chdir(original_cwd)
assert success
assert new_test is not None
with test_path.open("w") as f:
f.write(new_test)
replay_test_path.write_text(new_test)
opt = Optimizer(
Namespace(
project_root=project_root_path,
project_root=project_root,
disable_telemetry=True,
tests_root=tests_root,
test_framework="pytest",
pytest_cmd="pytest",
experiment_id=None,
test_project_root=project_root_path,
test_project_root=project_root,
)
)
# Run the replay test for the original code that does not use the socket
test_env = os.environ.copy()
test_env["CODEFLASH_TEST_ITERATION"] = "0"
test_env["CODEFLASH_LOOP_INDEX"] = "1"
test_type = TestType.EXISTING_UNIT_TEST
test_type = TestType.REPLAY_TEST
replay_test_function = "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket"
func_optimizer = opt.create_function_optimizer(func)
func_optimizer.test_files = TestFiles(
test_files=[
TestFile(
instrumented_behavior_file_path=test_path,
instrumented_behavior_file_path=replay_test_path,
test_type=test_type,
original_file_path=test_path,
benchmarking_file_path=test_path_perf,
original_file_path=replay_test_path,
benchmarking_file_path=replay_test_perf_path,
tests_in_file=[TestsInFile(test_file=replay_test_path, test_class=None, test_function=replay_test_function, test_type=test_type)],
)
]
)
test_results, coverage_data = func_optimizer.run_and_parse_tests(
test_results_unused_socket, coverage_data = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
test_files=func_optimizer.test_files,
optimization_iteration=0,
pytest_min_loops=1,
pytest_max_loops=1,
testing_time=0.1,
testing_time=1.0,
)
assert test_results.test_results[0].id.test_function_name =="test_bubble_sort_with_unused_socket"
assert test_results.test_results[0].did_pass ==True
assert test_results.test_results[1].id.test_function_name =="test_bubble_sort_with_used_socket"
assert test_results.test_results[1].did_pass ==False
# assert pickle placeholder problem
print(test_results)
assert len(test_results_unused_socket) == 1
assert test_results_unused_socket.test_results[0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
assert test_results_unused_socket.test_results[0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket"
assert test_results_unused_socket.test_results[0].did_pass == True
# Replace with optimized candidate
fto_unused_socket_path.write_text("""
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@codeflash_trace
def bubble_sort_with_unused_socket(data_container):
# Extract the list to sort, leaving the socket untouched
numbers = data_container.get('numbers', []).copy()
return sorted(numbers)
""")
# Run optimized code for unused socket
optimized_test_results_unused_socket, coverage_data = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
test_files=func_optimizer.test_files,
optimization_iteration=0,
pytest_min_loops=1,
pytest_max_loops=1,
testing_time=1.0,
)
assert len(optimized_test_results_unused_socket) == 1
verification_result = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket)
assert verification_result is True
# Remove the previous instrumentation
replay_test_path.write_text(original_replay_test_code)
# Instrument the replay test
func = FunctionToOptimize(function_name="bubble_sort_with_used_socket", parents=[], file_path=Path(fto_used_socket_path))
success, new_test = inject_profiling_into_existing_test(
replay_test_path,
[CodePosition(23,15)],
func,
project_root,
"pytest",
mode=TestingMode.BEHAVIOR,
)
os.chdir(original_cwd)
assert success
assert new_test is not None
replay_test_path.write_text(new_test)
# Run test for original function code that uses the socket. This should fail, as the PicklePlaceholder is accessed.
test_env = os.environ.copy()
test_env["CODEFLASH_TEST_ITERATION"] = "0"
test_env["CODEFLASH_LOOP_INDEX"] = "1"
test_type = TestType.REPLAY_TEST
func = FunctionToOptimize(function_name="bubble_sort_with_used_socket", parents=[],
file_path=Path(fto_used_socket_path))
replay_test_function = "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket"
func_optimizer = opt.create_function_optimizer(func)
func_optimizer.test_files = TestFiles(
test_files=[
TestFile(
instrumented_behavior_file_path=replay_test_path,
test_type=test_type,
original_file_path=replay_test_path,
benchmarking_file_path=replay_test_perf_path,
tests_in_file=[
TestsInFile(test_file=replay_test_path, test_class=None, test_function=replay_test_function,
test_type=test_type)],
)
]
)
test_results_used_socket, coverage_data = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
test_files=func_optimizer.test_files,
optimization_iteration=0,
pytest_min_loops=1,
pytest_max_loops=1,
testing_time=1.0,
)
assert len(test_results_used_socket) == 1
assert test_results_used_socket.test_results[
0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
assert test_results_used_socket.test_results[
0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket"
assert test_results_used_socket.test_results[0].did_pass is False
print("test results used socket")
print(test_results_used_socket)
# Replace with optimized candidate
fto_used_socket_path.write_text("""
from codeflash.benchmarking.codeflash_trace import codeflash_trace
@codeflash_trace
def bubble_sort_with_used_socket(data_container):
# Extract the list to sort, leaving the socket untouched
numbers = data_container.get('numbers', []).copy()
socket = data_container.get('socket')
socket.send("Hello from the optimized function!")
return sorted(numbers)
""")
# Run test for optimized function code that uses the socket. This should fail, as the PicklePlaceholder is accessed.
optimized_test_results_used_socket, coverage_data = func_optimizer.run_and_parse_tests(
testing_type=TestingMode.BEHAVIOR,
test_env=test_env,
test_files=func_optimizer.test_files,
optimization_iteration=0,
pytest_min_loops=1,
pytest_max_loops=1,
testing_time=1.0,
)
assert len(test_results_used_socket) == 1
assert test_results_used_socket.test_results[
0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0"
assert test_results_used_socket.test_results[
0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket"
assert test_results_used_socket.test_results[0].did_pass is False
# Even though tests threw the same error, we reject this as the behavior of the unpickleable object could not be determined.
assert compare_test_results(test_results_used_socket, optimized_test_results_used_socket) is False
finally:
test_path.write_text(original_test)
# cleanup
output_file.unlink(missing_ok=True)
shutil.rmtree(replay_tests_dir, ignore_errors=True)
fto_unused_socket_path.write_text(original_fto_unused_socket_code)
fto_used_socket_path.write_text(original_fto_used_socket_code)

View file

@ -1,15 +1,14 @@
import shutil
import sqlite3
from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
from codeflash.benchmarking.replay_test import generate_replay_test
from pathlib import Path
from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table
import shutil
from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin
from codeflash.benchmarking.replay_test import generate_replay_test
from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
from codeflash.benchmarking.utils import validate_and_format_benchmark_table
def test_trace_benchmarks():
def test_trace_benchmarks() -> None:
# Test the trace_benchmarks function
project_root = Path(__file__).parent.parent / "code_to_optimize"
benchmarks_root = project_root / "tests" / "pytest" / "benchmarks_test"
@ -83,13 +82,12 @@ def test_trace_benchmarks():
test_class_sort_path = replay_tests_dir/ Path("test_tests_pytest_benchmarks_test_test_benchmark_bubble_sort_example__replay_test_0.py")
assert test_class_sort_path.exists()
test_class_sort_code = f"""
import dill as pickle
from code_to_optimize.bubble_sort_codeflash_trace import \\
Sorter as code_to_optimize_bubble_sort_codeflash_trace_Sorter
from code_to_optimize.bubble_sort_codeflash_trace import \\
sorter as code_to_optimize_bubble_sort_codeflash_trace_sorter
from codeflash.benchmarking.replay_test import get_next_arg_and_return
from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
functions = ['sort_class', 'sort_static', 'sorter']
trace_file_path = r"{output_file.as_posix()}"
@ -146,14 +144,13 @@ def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter___init__():
test_sort_path = replay_tests_dir / Path("test_tests_pytest_benchmarks_test_test_process_and_sort_example__replay_test_0.py")
assert test_sort_path.exists()
test_sort_code = f"""
import dill as pickle
from code_to_optimize.bubble_sort_codeflash_trace import \\
sorter as code_to_optimize_bubble_sort_codeflash_trace_sorter
from code_to_optimize.process_and_bubble_sort_codeflash_trace import \\
compute_and_sort as \\
code_to_optimize_process_and_bubble_sort_codeflash_trace_compute_and_sort
from codeflash.benchmarking.replay_test import get_next_arg_and_return
from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle
functions = ['compute_and_sort', 'sorter']
trace_file_path = r"{output_file}"