diff --git a/code_to_optimize/bubble_sort_picklepatch_test_unused_socket.py b/code_to_optimize/bubble_sort_picklepatch_test_unused_socket.py new file mode 100644 index 000000000..2b75a8c34 --- /dev/null +++ b/code_to_optimize/bubble_sort_picklepatch_test_unused_socket.py @@ -0,0 +1,18 @@ + +from codeflash.benchmarking.codeflash_trace import codeflash_trace + + +@codeflash_trace +def bubble_sort_with_unused_socket(data_container): + # Extract the list to sort, leaving the socket untouched + numbers = data_container.get('numbers', []).copy() + + return sorted(numbers) + +@codeflash_trace +def bubble_sort_with_used_socket(data_container): + # Extract the list to sort, leaving the socket untouched + numbers = data_container.get('numbers', []).copy() + socket = data_container.get('socket') + socket.send("Hello from the optimized function!") + return sorted(numbers) diff --git a/code_to_optimize/bubble_sort_picklepatch.py b/code_to_optimize/bubble_sort_picklepatch_test_used_socket.py similarity index 55% rename from code_to_optimize/bubble_sort_picklepatch.py rename to code_to_optimize/bubble_sort_picklepatch_test_used_socket.py index 25cbe9628..390e090cd 100644 --- a/code_to_optimize/bubble_sort_picklepatch.py +++ b/code_to_optimize/bubble_sort_picklepatch_test_used_socket.py @@ -1,38 +1,6 @@ -def bubble_sort_with_unused_socket(data_container): - """ - Performs a bubble sort on a list within the data_container. The data container has the following schema: - - 'numbers' (list): The list to be sorted. - - 'socket' (socket): A socket - - Args: - data_container: A dictionary with at least 'numbers' (list) and 'socket' keys - - Returns: - list: The sorted list of numbers - """ - # Extract the list to sort, leaving the socket untouched - numbers = data_container.get('numbers', []).copy() - - # Classic bubble sort implementation - n = len(numbers) - for i in range(n): - # Flag to optimize by detecting if no swaps occurred - swapped = False - - # Last i elements are already in place - for j in range(0, n - i - 1): - # Swap if the element is greater than the next element - if numbers[j] > numbers[j + 1]: - numbers[j], numbers[j + 1] = numbers[j + 1], numbers[j] - swapped = True - - # If no swapping occurred in this pass, the list is sorted - if not swapped: - break - - return numbers - +from codeflash.benchmarking.codeflash_trace import codeflash_trace +@codeflash_trace def bubble_sort_with_used_socket(data_container): """ Performs a bubble sort on a list within the data_container. The data container has the following schema: diff --git a/code_to_optimize/tests/pytest/benchmarks_socket_test/test_socket.py b/code_to_optimize/tests/pytest/benchmarks_socket_test/test_socket.py new file mode 100644 index 000000000..bd05af487 --- /dev/null +++ b/code_to_optimize/tests/pytest/benchmarks_socket_test/test_socket.py @@ -0,0 +1,20 @@ +import socket + +from code_to_optimize.bubble_sort_picklepatch_test_unused_socket import bubble_sort_with_unused_socket +from code_to_optimize.bubble_sort_picklepatch_test_used_socket import bubble_sort_with_used_socket + +def test_socket_picklepatch(benchmark): + s1, s2 = socket.socketpair() + data = { + "numbers": list(reversed(range(500))), + "socket": s1 + } + benchmark(bubble_sort_with_unused_socket, data) + +def test_used_socket_picklepatch(benchmark): + s1, s2 = socket.socketpair() + data = { + "numbers": list(reversed(range(500))), + "socket": s1 + } + benchmark(bubble_sort_with_used_socket, data) \ No newline at end of file diff --git a/code_to_optimize/tests/pytest/test_bubble_sort_picklepatch.py b/code_to_optimize/tests/pytest/test_bubble_sort_picklepatch.py deleted file mode 100644 index 9f3e0f9af..000000000 --- a/code_to_optimize/tests/pytest/test_bubble_sort_picklepatch.py +++ /dev/null @@ -1,34 +0,0 @@ -import socket -from unittest.mock import Mock - -import pytest - -from code_to_optimize.bubble_sort_picklepatch import bubble_sort_with_unused_socket, bubble_sort_with_used_socket - - -def test_bubble_sort_with_unused_socket(): - mock_socket = Mock() - # Test case 1: Regular unsorted list - data_container = { - 'numbers': [5, 2, 9, 1, 5, 6], - 'socket': mock_socket - } - - result = bubble_sort_with_unused_socket(data_container) - - # Check that the result is correctly sorted - assert result == [1, 2, 5, 5, 6, 9] - -def test_bubble_sort_with_used_socket(): - mock_socket = Mock() - # Test case 1: Regular unsorted list - data_container = { - 'numbers': [5, 2, 9, 1, 5, 6], - 'socket': mock_socket - } - - result = bubble_sort_with_used_socket(data_container) - - # Check that the result is correctly sorted - assert result == [1, 2, 5, 5, 6, 9] - diff --git a/codeflash/benchmarking/codeflash_trace.py b/codeflash/benchmarking/codeflash_trace.py index 95318a38a..2694532f3 100644 --- a/codeflash/benchmarking/codeflash_trace.py +++ b/codeflash/benchmarking/codeflash_trace.py @@ -2,12 +2,11 @@ import functools import os import pickle import sqlite3 -import sys import threading import time from typing import Callable -import dill +from codeflash.picklepatch.pickle_patcher import PicklePatcher class CodeflashTrace: @@ -147,34 +146,20 @@ class CodeflashTrace: return result try: - original_recursion_limit = sys.getrecursionlimit() - sys.setrecursionlimit(10000) - # args = dict(args.items()) - # if class_name and func.__name__ == "__init__" and "self" in args: - # del args["self"] # Pickle the arguments - pickled_args = pickle.dumps(args, protocol=pickle.HIGHEST_PROTOCOL) - pickled_kwargs = pickle.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL) - sys.setrecursionlimit(original_recursion_limit) - except (TypeError, pickle.PicklingError, AttributeError, RecursionError, OSError): - # Retry with dill if pickle fails. It's slower but more comprehensive - try: - pickled_args = dill.dumps(args, protocol=pickle.HIGHEST_PROTOCOL) - pickled_kwargs = dill.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL) - sys.setrecursionlimit(original_recursion_limit) - - except (TypeError, dill.PicklingError, AttributeError, RecursionError, OSError) as e: - print(f"Error pickling arguments for function {func.__name__}: {e}") - # Add to the list of function calls without pickled args. Used for timing info only - self._thread_local.active_functions.remove(func_id) - overhead_time = time.thread_time_ns() - end_time - self.function_calls_data.append( - (func.__name__, class_name, func.__module__, func.__code__.co_filename, - benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time, - overhead_time, None, None) - ) - return result - + pickled_args = PicklePatcher.dumps(args, protocol=pickle.HIGHEST_PROTOCOL) + pickled_kwargs = PicklePatcher.dumps(kwargs, protocol=pickle.HIGHEST_PROTOCOL) + except Exception as e: + print(f"Error pickling arguments for function {func.__name__}: {e}") + # Add to the list of function calls without pickled args. Used for timing info only + self._thread_local.active_functions.remove(func_id) + overhead_time = time.thread_time_ns() - end_time + self.function_calls_data.append( + (func.__name__, class_name, func.__module__, func.__code__.co_filename, + benchmark_function_name, benchmark_module_path, benchmark_line_number, execution_time, + overhead_time, None, None) + ) + return result # Flush to database every 1000 calls if len(self.function_calls_data) > 1000: self.write_function_timings() diff --git a/codeflash/benchmarking/plugin/plugin.py b/codeflash/benchmarking/plugin/plugin.py index f1614b5c8..313817041 100644 --- a/codeflash/benchmarking/plugin/plugin.py +++ b/codeflash/benchmarking/plugin/plugin.py @@ -175,7 +175,6 @@ class CodeFlashBenchmarkPlugin: benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func) # Subtract overhead from total time overhead = overhead_by_benchmark.get(benchmark_key, 0) - print("benchmark_func:", benchmark_func, "Total time:", time_ns, "Overhead:", overhead, "Result:", time_ns - overhead) result[benchmark_key] = time_ns - overhead finally: @@ -267,9 +266,9 @@ class CodeFlashBenchmarkPlugin: os.environ["CODEFLASH_BENCHMARK_LINE_NUMBER"] = str(line_number) os.environ["CODEFLASH_BENCHMARKING"] = "True" # Run the function - start = time.thread_time_ns() + start = time.time_ns() result = func(*args, **kwargs) - end = time.thread_time_ns() + end = time.time_ns() # Reset the environment variable os.environ["CODEFLASH_BENCHMARKING"] = "False" diff --git a/codeflash/benchmarking/replay_test.py b/codeflash/benchmarking/replay_test.py index 445957505..ee1107241 100644 --- a/codeflash/benchmarking/replay_test.py +++ b/codeflash/benchmarking/replay_test.py @@ -62,7 +62,7 @@ def create_trace_replay_test_code( assert test_framework in ["pytest", "unittest"] # Create Imports - imports = f"""import dill as pickle + imports = f"""from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle {"import unittest" if test_framework == "unittest" else ""} from codeflash.benchmarking.replay_test import get_next_arg_and_return """ diff --git a/codeflash/models/models.py b/codeflash/models/models.py index aede322a1..791912b8a 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -16,7 +16,7 @@ from collections.abc import Collection from enum import Enum, IntEnum from pathlib import Path from re import Pattern -from typing import Annotated, Any, Optional, Union, cast +from typing import Annotated, Optional, cast from jedi.api.classes import Name from pydantic import AfterValidator, BaseModel, ConfigDict, Field @@ -362,6 +362,7 @@ class FunctionCoverage: class TestingMode(enum.Enum): BEHAVIOR = "behavior" PERFORMANCE = "performance" + LINE_PROFILE = "line_profile" class VerificationType(str, Enum): @@ -533,7 +534,7 @@ class TestResults(BaseModel): tree.add( f"{test_type.to_name()} - Passed: {report[test_type]['passed']}, Failed: {report[test_type]['failed']}" ) - return + return tree def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]: @@ -606,4 +607,4 @@ class TestResults(BaseModel): sys.setrecursionlimit(original_recursion_limit) return False sys.setrecursionlimit(original_recursion_limit) - return True \ No newline at end of file + return True diff --git a/codeflash/picklepatch/pickle_placeholder.py b/codeflash/picklepatch/pickle_placeholder.py index cddb6535a..a422abb45 100644 --- a/codeflash/picklepatch/pickle_placeholder.py +++ b/codeflash/picklepatch/pickle_placeholder.py @@ -1,3 +1,8 @@ +class PicklePlaceholderAccessError(Exception): + """Custom exception raised when attempting to access an unpicklable object.""" + + + class PicklePlaceholder: """A placeholder for an object that couldn't be pickled. @@ -22,22 +27,22 @@ class PicklePlaceholder: self.__dict__["path"] = path if path is not None else [] def __getattr__(self, name): - """Raise an error when any attribute is accessed.""" + """Raise a custom error when any attribute is accessed.""" path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root object" - raise AttributeError( - f"Cannot access attribute '{name}' on unpicklable object at {path_str}. " + raise PicklePlaceholderAccessError( + f"Attempt to access unpickleable object: Cannot access attribute '{name}' on unpicklable object at {path_str}. " f"Original type: {self.__dict__['obj_type']}. Error: {self.__dict__['error_msg']}" ) def __setattr__(self, name, value): """Prevent setting attributes.""" - self.__getattr__(name) # This will raise an AttributeError + self.__getattr__(name) # This will raise our custom error def __call__(self, *args, **kwargs): - """Raise an error when the object is called.""" + """Raise a custom error when the object is called.""" path_str = ".".join(self.__dict__["path"]) if self.__dict__["path"] else "root object" - raise TypeError( - f"Cannot call unpicklable object at {path_str}. " + raise PicklePlaceholderAccessError( + f"Attempt to access unpickleable object: Cannot call unpicklable object at {path_str}. " f"Original type: {self.__dict__['obj_type']}. Error: {self.__dict__['error_msg']}" ) diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index f047d5b3c..8a7048c57 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -10,6 +10,7 @@ from typing import Any import sentry_sdk from codeflash.cli_cmds.console import logger +from codeflash.picklepatch.pickle_placeholder import PicklePlaceholderAccessError try: import numpy as np @@ -64,7 +65,11 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: if len(orig) != len(new): return False return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new)) - + if isinstance(orig, PicklePlaceholderAccessError) or isinstance(new, PicklePlaceholderAccessError): + # If this error was raised, there was an attempt to access the PicklePlaceholder, which represents an unpickleable object. + # The test results should be rejected as the behavior of the unpickleable object is unknown. + logger.debug("Unable to verify behavior of unpickleable object in replay test") + return False if isinstance( orig, ( diff --git a/codeflash/verification/parse_test_output.py b/codeflash/verification/parse_test_output.py index 80d711894..2228559f9 100644 --- a/codeflash/verification/parse_test_output.py +++ b/codeflash/verification/parse_test_output.py @@ -8,6 +8,7 @@ from collections import defaultdict from pathlib import Path from typing import TYPE_CHECKING +import dill as pickle from junitparser.xunit2 import JUnitXml from lxml.etree import XMLParser, parse @@ -20,7 +21,6 @@ from codeflash.code_utils.code_utils import ( ) from codeflash.discovery.discover_unit_tests import discover_parameters_unittest from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType, VerificationType -from codeflash.picklepatch.pickle_patcher import PicklePatcher from codeflash.verification.coverage_utils import CoverageUtils if TYPE_CHECKING: @@ -75,7 +75,7 @@ def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, tes test_type = test_files.get_test_type_by_instrumented_file_path(test_file_path) try: - test_pickle = PicklePatcher.loads(test_pickle_bin) if loop_index == 1 else None + test_pickle = pickle.loads(test_pickle_bin) if loop_index == 1 else None except Exception as e: if DEBUG_MODE: logger.exception(f"Failed to load pickle file for {encoded_test_name} Exception: {e}") @@ -133,7 +133,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes # TODO : this is because sqlite writes original file module path. Should make it consistent test_type = test_files.get_test_type_by_original_file_path(test_file_path) try: - ret_val = (PicklePatcher.loads(val[7]) if loop_index == 1 else None,) + ret_val = (pickle.loads(val[7]) if loop_index == 1 else None,) except Exception: continue test_results.add( diff --git a/tests/test_pickle_patcher.py b/tests/test_pickle_patcher.py index 05bd06f15..3d2f21b66 100644 --- a/tests/test_pickle_patcher.py +++ b/tests/test_pickle_patcher.py @@ -1,34 +1,40 @@ import os import pickle +import shutil import socket +import sqlite3 from argparse import Namespace from pathlib import Path import dill import pytest -import requests -import sqlite3 +from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin +from codeflash.benchmarking.replay_test import generate_replay_test +from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest +from codeflash.benchmarking.utils import validate_and_format_benchmark_table from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import CodePosition, TestingMode, TestType, TestFiles, TestFile +from codeflash.models.models import CodePosition, TestFile, TestFiles, TestingMode, TestsInFile, TestType from codeflash.optimization.optimizer import Optimizer +from codeflash.verification.equivalence import compare_test_results try: import sqlalchemy - from sqlalchemy.orm import Session - from sqlalchemy import create_engine, Column, Integer, String + from sqlalchemy import Column, Integer, String, create_engine from sqlalchemy.ext.declarative import declarative_base + from sqlalchemy.orm import Session HAS_SQLALCHEMY = True except ImportError: HAS_SQLALCHEMY = False from codeflash.picklepatch.pickle_patcher import PicklePatcher -from codeflash.picklepatch.pickle_placeholder import PicklePlaceholder +from codeflash.picklepatch.pickle_placeholder import PicklePlaceholder, PicklePlaceholderAccessError + + def test_picklepatch_simple_nested(): - """ - Test that a simple nested data structure pickles and unpickles correctly. + """Test that a simple nested data structure pickles and unpickles correctly. """ original_data = { "numbers": [1, 2, 3], @@ -41,17 +47,24 @@ def test_picklepatch_simple_nested(): assert reloaded == original_data # Everything was pickleable, so no placeholders should appear. + def test_picklepatch_with_socket(): - """ - Test that a data structure containing a raw socket is replaced by + """Test that a data structure containing a raw socket is replaced by PicklePlaceholder rather than raising an error. """ - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # Create a pair of connected sockets instead of a single socket + sock1, sock2 = socket.socketpair() + data_with_socket = { "safe_value": 123, - "raw_socket": s, + "raw_socket": sock1, } + # Send a message through sock1, which can be received by sock2 + sock1.send(b"Hello, world!") + received = sock2.recv(1024) + assert received == b"Hello, world!" + # Pickle the data structure containing the socket dumped = PicklePatcher.dumps(data_with_socket) reloaded = PicklePatcher.loads(dumped) @@ -60,15 +73,18 @@ def test_picklepatch_with_socket(): assert reloaded["safe_value"] == 123 assert isinstance(reloaded["raw_socket"], PicklePlaceholder) - # Attempting to use or access attributes => AttributeError + # Attempting to use or access attributes => AttributeError # (not RuntimeError as in original tests, our implementation uses AttributeError) - with pytest.raises(AttributeError) : + with pytest.raises(PicklePlaceholderAccessError): reloaded["raw_socket"].recv(1024) + # Clean up by closing both sockets + sock1.close() + sock2.close() + def test_picklepatch_deeply_nested(): - """ - Test that deep nesting with unpicklable objects works correctly. + """Test that deep nesting with unpicklable objects works correctly. """ # Create a deeply nested structure with an unpicklable object deep_nested = { @@ -92,8 +108,7 @@ def test_picklepatch_deeply_nested(): assert isinstance(reloaded["level1"]["level2"]["level3"]["socket"], PicklePlaceholder) def test_picklepatch_class_with_unpicklable_attr(): - """ - Test that a class with an unpicklable attribute works correctly. + """Test that a class with an unpicklable attribute works correctly. """ class TestClass: def __init__(self): @@ -115,12 +130,11 @@ def test_picklepatch_class_with_unpicklable_attr(): def test_picklepatch_with_database_connection(): - """ - Test that a data structure containing a database connection is replaced + """Test that a data structure containing a database connection is replaced by PicklePlaceholder rather than raising an error. """ # SQLite connection - not pickleable - conn = sqlite3.connect(':memory:') + conn = sqlite3.connect(":memory:") cursor = conn.cursor() data_with_db = { @@ -139,13 +153,12 @@ def test_picklepatch_with_database_connection(): assert isinstance(reloaded["cursor"], PicklePlaceholder) # Attempting to use attributes => AttributeError - with pytest.raises(AttributeError): + with pytest.raises(PicklePlaceholderAccessError): reloaded["connection"].execute("SELECT 1") def test_picklepatch_with_generator(): - """ - Test that a data structure containing a generator is replaced by + """Test that a data structure containing a generator is replaced by PicklePlaceholder rather than raising an error. """ @@ -178,13 +191,12 @@ def test_picklepatch_with_generator(): next(reloaded["generator"]) # Attempting to call methods on the generator => AttributeError - with pytest.raises(AttributeError): + with pytest.raises(PicklePlaceholderAccessError): reloaded["generator"].send(None) def test_picklepatch_loads_standard_pickle(): - """ - Test that PicklePatcher.loads can correctly load data that was pickled + """Test that PicklePatcher.loads can correctly load data that was pickled using the standard pickle module. """ # Create a simple data structure @@ -209,12 +221,10 @@ def test_picklepatch_loads_standard_pickle(): def test_picklepatch_loads_dill_pickle(): - """ - Test that PicklePatcher.loads can correctly load data that was pickled + """Test that PicklePatcher.loads can correctly load data that was pickled using the dill module, which can pickle more complex objects than the standard pickle module. """ - # Create a more complex data structure that includes a lambda function # which dill can handle but standard pickle cannot original_data = { @@ -240,80 +250,264 @@ def test_picklepatch_loads_dill_pickle(): assert reloaded["nested"]["another_function"](4) == 16 def test_run_and_parse_picklepatch() -> None: + """Test the end to end functionality of picklepatch, from tracing benchmarks to running the replay tests. - test_path = ( - Path(__file__).parent.resolve() - / "../code_to_optimize/tests/pytest/test_bubble_sort_picklepatch.py" - ).resolve() - test_path_perf = ( - Path(__file__).parent.resolve() - / "../code_to_optimize/tests/pytest/test_bubble_sort_picklepatch_perf.py" - ).resolve() - fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_picklepatch.py").resolve() - original_test =test_path.read_text("utf-8") + The first example has an argument (an object containing a socket) that is not pickleable However, the socket attributs is not used, so we are able to compare the test results with the optimized test results. + Here, we are simply 'ignoring' the unused unpickleable object. + + The second example also has an argument (an object containing socket) that is not pickleable. The socket attribute is used, which results in an error thrown by the PicklePlaceholder object. + Both the original and optimized results should error out in this case, but this should be flagged as incorrect behavior when comparing test results, + since we were not able to reuse the unpickleable object in the replay test. + """ + # Init paths + project_root = Path(__file__).parent.parent.resolve() + tests_root = project_root / "code_to_optimize" / "tests" / "pytest" + benchmarks_root = project_root / "code_to_optimize" / "tests" / "pytest" / "benchmarks_socket_test" + replay_tests_dir = benchmarks_root / "codeflash_replay_tests" + output_file = (benchmarks_root / Path("test_trace_benchmarks.trace")).resolve() + fto_unused_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_unused_socket.py").resolve() + fto_used_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_used_socket.py").resolve() + original_fto_unused_socket_code = fto_unused_socket_path.read_text("utf-8") + original_fto_used_socket_code = fto_used_socket_path.read_text("utf-8") + # Trace benchmarks + trace_benchmarks_pytest(benchmarks_root, tests_root, project_root, output_file) + assert output_file.exists() try: - tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve() - project_root_path = (Path(__file__).parent / "..").resolve() + # Check contents + conn = sqlite3.connect(output_file.as_posix()) + cursor = conn.cursor() + + cursor.execute( + "SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name") + function_calls = cursor.fetchall() + + # Assert the length of function calls + assert len(function_calls) == 2, f"Expected 2 function calls, but got {len(function_calls)}" + function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file) + total_benchmark_timings = codeflash_benchmark_plugin.get_benchmark_timings(output_file) + function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings) + assert "code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket" in function_to_results + + test_name, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"][0] + assert total_time > 0.0 + assert function_time > 0.0 + assert percent > 0.0 + + test_name, total_time, function_time, percent = \ + function_to_results["code_to_optimize.bubble_sort_picklepatch_test_unused_socket.bubble_sort_with_unused_socket"][0] + assert total_time > 0.0 + assert function_time > 0.0 + assert percent > 0.0 + + bubble_sort_unused_socket_path = (project_root / "code_to_optimize"/ "bubble_sort_picklepatch_test_unused_socket.py").as_posix() + bubble_sort_used_socket_path = (project_root / "code_to_optimize" / "bubble_sort_picklepatch_test_used_socket.py").as_posix() + # Expected function calls + expected_calls = [ + ("bubble_sort_with_unused_socket", "", "code_to_optimize.bubble_sort_picklepatch_test_unused_socket", + f"{bubble_sort_unused_socket_path}", + "test_socket_picklepatch", "code_to_optimize.tests.pytest.benchmarks_socket_test.test_socket", 12), + ("bubble_sort_with_used_socket", "", "code_to_optimize.bubble_sort_picklepatch_test_used_socket", + f"{bubble_sort_used_socket_path}", + "test_used_socket_picklepatch", "code_to_optimize.tests.pytest.benchmarks_socket_test.test_socket", 20) + ] + for idx, (actual, expected) in enumerate(zip(function_calls, expected_calls)): + assert actual[0] == expected[0], f"Mismatch at index {idx} for function_name" + assert actual[1] == expected[1], f"Mismatch at index {idx} for class_name" + assert actual[2] == expected[2], f"Mismatch at index {idx} for module_name" + assert Path(actual[3]).name == Path(expected[3]).name, f"Mismatch at index {idx} for file_path" + assert actual[4] == expected[4], f"Mismatch at index {idx} for benchmark_function_name" + assert actual[5] == expected[5], f"Mismatch at index {idx} for benchmark_module_path" + assert actual[6] == expected[6], f"Mismatch at index {idx} for benchmark_line_number" + conn.close() + + # Generate replay test + generate_replay_test(output_file, replay_tests_dir) + replay_test_path = replay_tests_dir / Path( + "test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0.py") + replay_test_perf_path = replay_tests_dir / Path( + "test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0_perf.py") + assert replay_test_path.exists() + original_replay_test_code = replay_test_path.read_text("utf-8") + + # Instrument the replay test + func = FunctionToOptimize(function_name="bubble_sort_with_unused_socket", parents=[], file_path=Path(fto_unused_socket_path)) original_cwd = Path.cwd() - run_cwd = Path(__file__).parent.parent.resolve() - func = FunctionToOptimize(function_name="bubble_sort_with_unused_socket", parents=[], file_path=Path(fto_path)) + run_cwd = project_root os.chdir(run_cwd) success, new_test = inject_profiling_into_existing_test( - test_path, - [CodePosition(13,14), CodePosition(31,14)], + replay_test_path, + [CodePosition(17, 15)], func, - project_root_path, + project_root, "pytest", mode=TestingMode.BEHAVIOR, ) os.chdir(original_cwd) assert success assert new_test is not None - - with test_path.open("w") as f: - f.write(new_test) + replay_test_path.write_text(new_test) opt = Optimizer( Namespace( - project_root=project_root_path, + project_root=project_root, disable_telemetry=True, tests_root=tests_root, test_framework="pytest", pytest_cmd="pytest", experiment_id=None, - test_project_root=project_root_path, + test_project_root=project_root, ) ) + + # Run the replay test for the original code that does not use the socket test_env = os.environ.copy() test_env["CODEFLASH_TEST_ITERATION"] = "0" test_env["CODEFLASH_LOOP_INDEX"] = "1" - test_type = TestType.EXISTING_UNIT_TEST - + test_type = TestType.REPLAY_TEST + replay_test_function = "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket" func_optimizer = opt.create_function_optimizer(func) func_optimizer.test_files = TestFiles( test_files=[ TestFile( - instrumented_behavior_file_path=test_path, + instrumented_behavior_file_path=replay_test_path, test_type=test_type, - original_file_path=test_path, - benchmarking_file_path=test_path_perf, + original_file_path=replay_test_path, + benchmarking_file_path=replay_test_perf_path, + tests_in_file=[TestsInFile(test_file=replay_test_path, test_class=None, test_function=replay_test_function, test_type=test_type)], ) ] ) - test_results, coverage_data = func_optimizer.run_and_parse_tests( + test_results_unused_socket, coverage_data = func_optimizer.run_and_parse_tests( testing_type=TestingMode.BEHAVIOR, test_env=test_env, test_files=func_optimizer.test_files, optimization_iteration=0, pytest_min_loops=1, pytest_max_loops=1, - testing_time=0.1, + testing_time=1.0, ) - assert test_results.test_results[0].id.test_function_name =="test_bubble_sort_with_unused_socket" - assert test_results.test_results[0].did_pass ==True - assert test_results.test_results[1].id.test_function_name =="test_bubble_sort_with_used_socket" - assert test_results.test_results[1].did_pass ==False - # assert pickle placeholder problem - print(test_results) + assert len(test_results_unused_socket) == 1 + assert test_results_unused_socket.test_results[0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0" + assert test_results_unused_socket.test_results[0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_unused_socket_bubble_sort_with_unused_socket" + assert test_results_unused_socket.test_results[0].did_pass == True + + # Replace with optimized candidate + fto_unused_socket_path.write_text(""" +from codeflash.benchmarking.codeflash_trace import codeflash_trace + +@codeflash_trace +def bubble_sort_with_unused_socket(data_container): + # Extract the list to sort, leaving the socket untouched + numbers = data_container.get('numbers', []).copy() + return sorted(numbers) +""") + # Run optimized code for unused socket + optimized_test_results_unused_socket, coverage_data = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=1.0, + ) + assert len(optimized_test_results_unused_socket) == 1 + verification_result = compare_test_results(test_results_unused_socket, optimized_test_results_unused_socket) + assert verification_result is True + + # Remove the previous instrumentation + replay_test_path.write_text(original_replay_test_code) + # Instrument the replay test + func = FunctionToOptimize(function_name="bubble_sort_with_used_socket", parents=[], file_path=Path(fto_used_socket_path)) + success, new_test = inject_profiling_into_existing_test( + replay_test_path, + [CodePosition(23,15)], + func, + project_root, + "pytest", + mode=TestingMode.BEHAVIOR, + ) + os.chdir(original_cwd) + assert success + assert new_test is not None + replay_test_path.write_text(new_test) + + # Run test for original function code that uses the socket. This should fail, as the PicklePlaceholder is accessed. + test_env = os.environ.copy() + test_env["CODEFLASH_TEST_ITERATION"] = "0" + test_env["CODEFLASH_LOOP_INDEX"] = "1" + test_type = TestType.REPLAY_TEST + func = FunctionToOptimize(function_name="bubble_sort_with_used_socket", parents=[], + file_path=Path(fto_used_socket_path)) + replay_test_function = "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket" + func_optimizer = opt.create_function_optimizer(func) + func_optimizer.test_files = TestFiles( + test_files=[ + TestFile( + instrumented_behavior_file_path=replay_test_path, + test_type=test_type, + original_file_path=replay_test_path, + benchmarking_file_path=replay_test_perf_path, + tests_in_file=[ + TestsInFile(test_file=replay_test_path, test_class=None, test_function=replay_test_function, + test_type=test_type)], + ) + ] + ) + test_results_used_socket, coverage_data = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=1.0, + ) + assert len(test_results_used_socket) == 1 + assert test_results_used_socket.test_results[ + 0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0" + assert test_results_used_socket.test_results[ + 0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket" + assert test_results_used_socket.test_results[0].did_pass is False + print("test results used socket") + print(test_results_used_socket) + # Replace with optimized candidate + fto_used_socket_path.write_text(""" +from codeflash.benchmarking.codeflash_trace import codeflash_trace + +@codeflash_trace +def bubble_sort_with_used_socket(data_container): + # Extract the list to sort, leaving the socket untouched + numbers = data_container.get('numbers', []).copy() + socket = data_container.get('socket') + socket.send("Hello from the optimized function!") + return sorted(numbers) + """) + + # Run test for optimized function code that uses the socket. This should fail, as the PicklePlaceholder is accessed. + optimized_test_results_used_socket, coverage_data = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.BEHAVIOR, + test_env=test_env, + test_files=func_optimizer.test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=1.0, + ) + assert len(test_results_used_socket) == 1 + assert test_results_used_socket.test_results[ + 0].id.test_module_path == "code_to_optimize.tests.pytest.benchmarks_socket_test.codeflash_replay_tests.test_code_to_optimize_tests_pytest_benchmarks_socket_test_test_socket__replay_test_0" + assert test_results_used_socket.test_results[ + 0].id.test_function_name == "test_code_to_optimize_bubble_sort_picklepatch_test_used_socket_bubble_sort_with_used_socket" + assert test_results_used_socket.test_results[0].did_pass is False + + # Even though tests threw the same error, we reject this as the behavior of the unpickleable object could not be determined. + assert compare_test_results(test_results_used_socket, optimized_test_results_used_socket) is False + finally: - test_path.write_text(original_test) \ No newline at end of file + # cleanup + output_file.unlink(missing_ok=True) + shutil.rmtree(replay_tests_dir, ignore_errors=True) + fto_unused_socket_path.write_text(original_fto_unused_socket_code) + fto_used_socket_path.write_text(original_fto_used_socket_code) + diff --git a/tests/test_trace_benchmarks.py b/tests/test_trace_benchmarks.py index 715955063..af9a1e3f3 100644 --- a/tests/test_trace_benchmarks.py +++ b/tests/test_trace_benchmarks.py @@ -1,15 +1,14 @@ +import shutil import sqlite3 - -from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin -from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest -from codeflash.benchmarking.replay_test import generate_replay_test from pathlib import Path -from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table -import shutil +from codeflash.benchmarking.plugin.plugin import codeflash_benchmark_plugin +from codeflash.benchmarking.replay_test import generate_replay_test +from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest +from codeflash.benchmarking.utils import validate_and_format_benchmark_table -def test_trace_benchmarks(): +def test_trace_benchmarks() -> None: # Test the trace_benchmarks function project_root = Path(__file__).parent.parent / "code_to_optimize" benchmarks_root = project_root / "tests" / "pytest" / "benchmarks_test" @@ -83,13 +82,12 @@ def test_trace_benchmarks(): test_class_sort_path = replay_tests_dir/ Path("test_tests_pytest_benchmarks_test_test_benchmark_bubble_sort_example__replay_test_0.py") assert test_class_sort_path.exists() test_class_sort_code = f""" -import dill as pickle - from code_to_optimize.bubble_sort_codeflash_trace import \\ Sorter as code_to_optimize_bubble_sort_codeflash_trace_Sorter from code_to_optimize.bubble_sort_codeflash_trace import \\ sorter as code_to_optimize_bubble_sort_codeflash_trace_sorter from codeflash.benchmarking.replay_test import get_next_arg_and_return +from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle functions = ['sort_class', 'sort_static', 'sorter'] trace_file_path = r"{output_file.as_posix()}" @@ -146,14 +144,13 @@ def test_code_to_optimize_bubble_sort_codeflash_trace_Sorter___init__(): test_sort_path = replay_tests_dir / Path("test_tests_pytest_benchmarks_test_test_process_and_sort_example__replay_test_0.py") assert test_sort_path.exists() test_sort_code = f""" -import dill as pickle - from code_to_optimize.bubble_sort_codeflash_trace import \\ sorter as code_to_optimize_bubble_sort_codeflash_trace_sorter from code_to_optimize.process_and_bubble_sort_codeflash_trace import \\ compute_and_sort as \\ code_to_optimize_process_and_bubble_sort_codeflash_trace_compute_and_sort from codeflash.benchmarking.replay_test import get_next_arg_and_return +from codeflash.picklepatch.pickle_patcher import PicklePatcher as pickle functions = ['compute_and_sort', 'sorter'] trace_file_path = r"{output_file}" @@ -284,4 +281,4 @@ def test_trace_benchmark_decorator() -> None: finally: # cleanup - output_file.unlink(missing_ok=True) \ No newline at end of file + output_file.unlink(missing_ok=True)