# codeflash/codeflash/verification/pytest_plugin.py

from __future__ import annotations

import contextlib
import inspect

# System Imports
import logging
import os
import platform
import re
import sys
import time as _time_module
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Optional
from unittest import TestCase

# PyTest Imports
import pytest
from pluggy import HookspecMarker
from codeflash.code_utils.config_consts import (
    STABILITY_CENTER_TOLERANCE,
    STABILITY_SPREAD_TOLERANCE,
    STABILITY_WINDOW_SIZE,
)

if TYPE_CHECKING:
    from _pytest.config import Config, Parser
    from _pytest.main import Session
    from _pytest.python import Metafunc

SECONDS_IN_HOUR: float = 3600
SECONDS_IN_MINUTE: float = 60
SHORTEST_AMOUNT_OF_TIME: float = 0

hookspec = HookspecMarker("pytest")


class InvalidTimeParameterError(Exception):
    pass


class UnexpectedError(Exception):
    pass


if platform.system() == "Linux":
    import resource

    # We set the memory limit to 85% of total system memory + swap when swap exists
    swap_file_path = Path("/proc/swaps")
    swap_exists = swap_file_path.is_file()
    swap_size = 0
    if swap_exists:
        with swap_file_path.open("r") as f:
            swap_lines = f.readlines()
        swap_exists = len(swap_lines) > 1  # First line is header
        if swap_exists:
            # Parse swap size from lines after header
            for line in swap_lines[1:]:
                parts = line.split()
                if len(parts) >= 3:
                    # Swap size is in KB in the 3rd column
                    with contextlib.suppress(ValueError, IndexError):
                        swap_size += int(parts[2]) * 1024  # Convert KB to bytes

    # Get total system memory
    total_memory = os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES")
    # Add swap to total available memory if swap exists
    if swap_exists:
        total_memory += swap_size

    # Set the memory limit to 85% of total memory (RAM plus swap)
    memory_limit = int(total_memory * 0.85)
    # Set both soft and hard limits
    resource.setrlimit(resource.RLIMIT_AS, (memory_limit, memory_limit))
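    # Note: RLIMIT_AS caps the process's virtual address space in bytes; once the
    # cap is reached, new allocations fail (surfacing in Python as MemoryError)
    # rather than letting a runaway test loop trigger the kernel OOM killer.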


# Store references to original functions before any patching
_ORIGINAL_TIME_TIME = _time_module.time
_ORIGINAL_PERF_COUNTER = _time_module.perf_counter
_ORIGINAL_PERF_COUNTER_NS = _time_module.perf_counter_ns
_ORIGINAL_TIME_SLEEP = _time_module.sleep


# Apply deterministic patches for reproducible test execution
def _apply_deterministic_patches() -> None:
    """Apply patches to make all sources of randomness deterministic."""
    import datetime
    import random
    import time
    import uuid

    # Store original functions (these are already saved globally above)
    _original_time = time.time
    _original_perf_counter = time.perf_counter
    _original_datetime_now = datetime.datetime.now
    _original_datetime_utcnow = datetime.datetime.utcnow
    _original_uuid4 = uuid.uuid4
    _original_uuid1 = uuid.uuid1
    _original_random = random.random

    # Fixed deterministic values
    fixed_timestamp = 1761717605.108106
    fixed_datetime = datetime.datetime(2021, 1, 1, 2, 5, 10, tzinfo=datetime.timezone.utc)
    fixed_uuid = uuid.UUID("12345678-1234-5678-9abc-123456789012")

    # Counter for perf_counter to maintain relative timing
    _perf_counter_start = fixed_timestamp
    _perf_counter_calls = 0

    def mock_time_time() -> float:
        """Return fixed timestamp while preserving performance characteristics."""
        _original_time()  # Maintain performance characteristics
        return fixed_timestamp

    def mock_perf_counter() -> float:
        """Return incrementing counter for relative timing."""
        nonlocal _perf_counter_calls
        _original_perf_counter()  # Maintain performance characteristics
        _perf_counter_calls += 1
        return _perf_counter_start + (_perf_counter_calls * 0.001)  # Increment by 1 ms each call

    def mock_datetime_now(tz: datetime.timezone | None = None) -> datetime.datetime:
        """Return fixed datetime while preserving performance characteristics."""
        _original_datetime_now(tz)  # Maintain performance characteristics
        if tz is None:
            return fixed_datetime
        return fixed_datetime.replace(tzinfo=tz)

    def mock_datetime_utcnow() -> datetime.datetime:
        """Return fixed UTC datetime while preserving performance characteristics."""
        _original_datetime_utcnow()  # Maintain performance characteristics
        return fixed_datetime

    def mock_uuid4() -> uuid.UUID:
        """Return fixed UUID4 while preserving performance characteristics."""
        _original_uuid4()  # Maintain performance characteristics
        return fixed_uuid

    def mock_uuid1(node: int | None = None, clock_seq: int | None = None) -> uuid.UUID:
        """Return fixed UUID1 while preserving performance characteristics."""
        _original_uuid1(node, clock_seq)  # Maintain performance characteristics
        return fixed_uuid

    def mock_random() -> float:
        """Return deterministic random value while preserving performance characteristics."""
        _original_random()  # Maintain performance characteristics
        return 0.123456789  # Fixed random value

    # Apply patches
    time.time = mock_time_time
    time.perf_counter = mock_perf_counter
    uuid.uuid4 = mock_uuid4
    uuid.uuid1 = mock_uuid1

    # Seed random module for other random functions
    random.seed(42)
    random.random = mock_random

    # For datetime, we need a different approach since we can't patch class methods;
    # store the original methods and mocks for potential later use
    import builtins

    builtins._original_datetime_now = _original_datetime_now  # noqa: SLF001
    builtins._original_datetime_utcnow = _original_datetime_utcnow  # noqa: SLF001
    builtins._mock_datetime_now = mock_datetime_now  # noqa: SLF001
    builtins._mock_datetime_utcnow = mock_datetime_utcnow  # noqa: SLF001

    # Patch numpy.random if available
    try:
        import numpy as np

        # Use modern numpy random generator approach
        np.random.default_rng(42)
        np.random.seed(42)  # Keep legacy seed for compatibility  # noqa: NPY002
    except ImportError:
        pass

    # Patch os.urandom if needed
    try:
        import os

        _original_urandom = os.urandom

        def mock_urandom(n: int) -> bytes:
            _original_urandom(n)  # Maintain performance characteristics
            return b"\x42" * n  # Fixed bytes

        os.urandom = mock_urandom
    except (ImportError, AttributeError):
        pass


# Note: Deterministic patches are applied conditionally, not globally.
# They should only be applied when running CodeFlash optimization tests.
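#
# Illustrative effect once the patches are active (values from the constants above):
#   time.time()         -> 1761717605.108106 on every call
#   uuid.uuid4()        -> UUID("12345678-1234-5678-9abc-123456789012")
#   random.random()     -> 0.123456789
#   time.perf_counter() -> still advances, but by a fixed 1 ms per call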


def pytest_addoption(parser: Parser) -> None:
    """Add command line options."""
    pytest_loops = parser.getgroup("loops")
    pytest_loops.addoption(
        "--codeflash_delay",
        action="store",
        default=0,
        type=float,
        help="The amount of time to wait between each test loop.",
    )
    pytest_loops.addoption(
        "--codeflash_hours", action="store", default=0, type=float, help="The number of hours to loop the tests for."
    )
    pytest_loops.addoption(
        "--codeflash_minutes",
        action="store",
        default=0,
        type=float,
        help="The number of minutes to loop the tests for.",
    )
    pytest_loops.addoption(
        "--codeflash_seconds",
        action="store",
        default=0,
        type=float,
        help="The number of seconds to loop the tests for.",
    )
    pytest_loops.addoption(
        "--codeflash_loops", action="store", default=1, type=int, help="The number of times to loop each test."
    )
    pytest_loops.addoption(
        "--codeflash_min_loops",
        action="store",
        default=1,
        type=int,
        help="The minimum number of times to loop each test.",
    )
    pytest_loops.addoption(
        "--codeflash_max_loops",
        action="store",
        default=100_000,
        type=int,
        help="The maximum number of times to loop each test.",
    )
    pytest_loops.addoption(
        "--codeflash_loops_scope",
        action="store",
        default="function",
        type=str,
        choices=("function", "class", "module", "session"),
        help="Scope for looping tests.",
    )
    pytest_loops.addoption(
        "--codeflash_stability_check",
        action="store",
        default="false",
        type=str,
        choices=("true", "false"),
        help="Enable stability checks for the loops.",
    )
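
# Example invocation (illustrative): `pytest --codeflash_seconds 10 --codeflash_min_loops 5`
# re-runs the collected tests until roughly 10 seconds have elapsed, bounded
# below by 5 loops and above by the --codeflash_max_loops default of 100_000.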


@pytest.hookimpl(trylast=True)
def pytest_configure(config: Config) -> None:
    config.addinivalue_line("markers", "loops(n): run the given test function `n` times.")
    config.pluginmanager.register(PytestLoops(config), PytestLoops.name)
    # Apply deterministic patches when the plugin is configured
    _apply_deterministic_patches()


def get_runtime_from_stdout(stdout: str) -> Optional[int]:
    marker_start = "!######"
    marker_end = "######!"
    if not stdout:
        return None
    end = stdout.rfind(marker_end)
    if end == -1:
        return None
    start = stdout.rfind(marker_start, 0, end)
    if start == -1:
        return None
    payload = stdout[start + len(marker_start) : end]
    last_colon = payload.rfind(":")
    if last_colon == -1:
        return None
    try:
        return int(payload[last_colon + 1 :])
    except ValueError:
        return None
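
# Illustrative example (the payload fields between the markers are hypothetical;
# only the digits after the last colon matter to the parser above):
#   get_runtime_from_stdout("... !######tests.test_foo:test_bar:123456######! ...")
#   -> 123456  (the measured runtime in nanoseconds)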

_NODEID_BRACKET_PATTERN = re.compile(r"\s*\[\s*\d+\s*\]\s*$")


def should_stop(
    runtimes: list[int],
    window: int,
    min_window_size: int,
    center_rel_tol: float = STABILITY_CENTER_TOLERANCE,
    spread_rel_tol: float = STABILITY_SPREAD_TOLERANCE,
) -> bool:
    if window <= 0:
        # A non-positive window would slice the entire history below; treat it
        # as "not enough data yet"
        return False
    if len(runtimes) < window:
        return False
    if len(runtimes) < min_window_size:
        return False
    recent = runtimes[-window:]
    # Use a sorted copy for faster median and min/max operations
    recent_sorted = sorted(recent)
    mid = window // 2
    m = recent_sorted[mid] if window % 2 else (recent_sorted[mid - 1] + recent_sorted[mid]) / 2
    # 1) All recent points close to the median
    centered = True
    for r in recent:
        if abs(r - m) / m > center_rel_tol:
            centered = False
            break
    # 2) Window spread is small
    r_min, r_max = recent_sorted[0], recent_sorted[-1]
    if r_min == 0:
        return False
    spread_ok = (r_max - r_min) / r_min <= spread_rel_tol
    return centered and spread_ok
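
# Worked example (illustrative numbers): with window=5 and recent runtimes
# [100, 101, 99, 100, 102], the median m is 100, the largest relative deviation
# from m is 2 / 100 = 0.02, and the spread is (102 - 99) / 99 ~= 0.03, so the
# loop stops iff center_rel_tol >= 0.02 and spread_rel_tol >= 0.03.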


class PytestLoops:
    name: str = "pytest-loops"

    def __init__(self, config: Config) -> None:
        # Turn debug prints on only if "-vv" or more passed
        level = logging.DEBUG if config.option.verbose > 1 else logging.INFO
        logging.basicConfig(level=level)
        self.logger = logging.getLogger(self.name)
        self.runtime_data_by_test_case: dict[str, list[int]] = {}
        self.enable_stability_check: bool = (
            str(getattr(config.option, "codeflash_stability_check", "false")).lower() == "true"
        )

    @pytest.hookimpl
    def pytest_runtest_logreport(self, report: pytest.TestReport) -> None:
        if not self.enable_stability_check:
            return
        if report.when == "call" and report.passed:
            duration_ns = get_runtime_from_stdout(report.capstdout)
            if duration_ns:
                clean_id = _NODEID_BRACKET_PATTERN.sub("", report.nodeid)
                self.runtime_data_by_test_case.setdefault(clean_id, []).append(duration_ns)

    @hookspec(firstresult=True)
    def pytest_runtestloop(self, session: Session) -> bool:
        """Reimplement the test loop, but loop for the user-defined amount of time."""
        if session.testsfailed and not session.config.option.continue_on_collection_errors:
            msg = "{} error{} during collection".format(session.testsfailed, "s" if session.testsfailed != 1 else "")
            raise session.Interrupted(msg)
        if session.config.option.collectonly:
            return True

        start_time: float = _ORIGINAL_TIME_TIME()
        total_time: float = self._get_total_time(session)
        count: int = 0
        runtimes = []
        elapsed_ns = 0
        while total_time >= SHORTEST_AMOUNT_OF_TIME:  # need to run at least once for normal tests
            count += 1
            loop_start = _ORIGINAL_PERF_COUNTER_NS()
            for index, item in enumerate(session.items):
                item: pytest.Item = item  # noqa: PLW0127, PLW2901
                item._report_sections.clear()  # clear reports for new test  # noqa: SLF001
                if total_time > SHORTEST_AMOUNT_OF_TIME:
                    item._nodeid = self._set_nodeid(item._nodeid, count)  # noqa: SLF001
                next_item: pytest.Item = session.items[index + 1] if index + 1 < len(session.items) else None
                self._clear_lru_caches(item)
                item.config.hook.pytest_runtest_protocol(item=item, nextitem=next_item)
                if session.shouldfail:
                    raise session.Failed(session.shouldfail)
                if session.shouldstop:
                    raise session.Interrupted(session.shouldstop)
            if self.enable_stability_check:
                elapsed_ns += _ORIGINAL_PERF_COUNTER_NS() - loop_start
                best_runtime_until_now = sum(min(data) for data in self.runtime_data_by_test_case.values())
                if best_runtime_until_now > 0:
                    runtimes.append(best_runtime_until_now)
                estimated_total_loops = 0
                if elapsed_ns > 0:
                    rate = count / elapsed_ns
                    total_time_ns = total_time * 1e9
                    estimated_total_loops = int(rate * total_time_ns)
                window_size = int(STABILITY_WINDOW_SIZE * estimated_total_loops + 0.5)
                if should_stop(runtimes, window_size, session.config.option.codeflash_min_loops):
                    break
            if self._timed_out(session, start_time, count):
                break
            _ORIGINAL_TIME_SLEEP(self._get_delay_time(session))
        return True
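
    # Illustrative stability-check numbers: if 10 loops took 2e9 ns in total and
    # total_time is 10 s, then rate = 10 / 2e9 loops per ns, so
    # estimated_total_loops = rate * 1e10 = 50; with STABILITY_WINDOW_SIZE = 0.2
    # (an assumed value), window_size = int(0.2 * 50 + 0.5) = 10 recent runtimes.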

    def _clear_lru_caches(self, item: pytest.Item) -> None:
        processed_functions: set[Callable] = set()
        protected_modules = {
            "gc",
            "inspect",
            "os",
            "sys",
            "time",
            "functools",
            "pathlib",
            "typing",
            "dill",
            "pytest",
            "importlib",
        }

        def _clear_cache_for_object(obj: Callable) -> None:
            if obj in processed_functions:
                return
            processed_functions.add(obj)
            if hasattr(obj, "__wrapped__"):
                module_name = obj.__wrapped__.__module__
            else:
                try:
                    obj_module = inspect.getmodule(obj)
                    module_name = obj_module.__name__.split(".")[0] if obj_module is not None else None
                except Exception:
                    module_name = None
            if module_name in protected_modules:
                return
            if hasattr(obj, "cache_clear") and callable(obj.cache_clear):
                with contextlib.suppress(Exception):
                    obj.cache_clear()

        _clear_cache_for_object(item.function)  # type: ignore[attr-defined]
        try:
            if hasattr(item.function, "__module__"):  # type: ignore[attr-defined]
                module_name = item.function.__module__  # type: ignore[attr-defined]
                try:
                    module = sys.modules.get(module_name)
                    if module:
                        for _, obj in inspect.getmembers(module):
                            if callable(obj):
                                _clear_cache_for_object(obj)
                except Exception:
                    pass
        except Exception:
            pass
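
    # Background (illustrative): functions wrapped with functools.lru_cache expose
    # a cache_clear() method; clearing those caches between loops keeps later
    # loops from looking artificially fast due to caches warmed in earlier loops.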

    def _set_nodeid(self, nodeid: str, count: int) -> str:
        """Set the loop count in the node id when looping for a duration.

        :param nodeid: Name of the test function.
        :param count: Current loop count.
        :return: Formatted string for the test name.
        """
        pattern = r"\[ \d+ \]"
        run_str = f"[ {count} ]"
        os.environ["CODEFLASH_LOOP_INDEX"] = str(count)
        return re.sub(pattern, run_str, nodeid) if re.search(pattern, nodeid) else nodeid + run_str
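
    # Example (illustrative): _set_nodeid("tests/test_foo.py::test_bar", 3) returns
    # "tests/test_foo.py::test_bar[ 3 ]"; on later loops the existing "[ n ]"
    # suffix is replaced in place rather than appended again.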

    def _get_delay_time(self, session: Session) -> float:
        """Extract the delay time from the session.

        :param session: Pytest session object.
        :return: The delay time for each test loop.
        """
        return session.config.option.codeflash_delay

    def _get_total_time(self, session: Session) -> float:
        """Add up all the user-provided time options and return the total in seconds.

        :param session: Pytest session object.
        :return: Total amount of time in seconds.
        """
        hours_in_seconds = session.config.option.codeflash_hours * SECONDS_IN_HOUR
        minutes_in_seconds = session.config.option.codeflash_minutes * SECONDS_IN_MINUTE
        seconds = session.config.option.codeflash_seconds
        total_time = hours_in_seconds + minutes_in_seconds + seconds
        if total_time < SHORTEST_AMOUNT_OF_TIME:
            msg = f"Total time cannot be less than: {SHORTEST_AMOUNT_OF_TIME}!"
            raise InvalidTimeParameterError(msg)
        return total_time
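
    # Example (illustrative): --codeflash_minutes 2 --codeflash_seconds 30
    # yields 2 * 60 + 30 = 150.0 seconds of total loop time.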

    def _timed_out(self, session: Session, start_time: float, count: int) -> bool:
        """Check whether the looping budget has been exhausted.

        :param session: Pytest session object.
        :param start_time: Wall-clock time at which the test loop started.
        :param count: Number of loops completed so far.
        :return: True if the maximum loop count was reached, or if the minimum
            loop count was reached and the user-specified time has elapsed.
        """
        return count >= session.config.option.codeflash_max_loops or (
            count >= session.config.option.codeflash_min_loops
            and _ORIGINAL_TIME_TIME() - start_time > self._get_total_time(session)
        )

    @pytest.fixture
    def __pytest_loop_step_number(self, request: pytest.FixtureRequest) -> int:
        """Set the step number for the loop.

        :param request: Pytest fixture request object.
        :return: The current step number (request.param).
        """
        marker = request.node.get_closest_marker("loops")
        count = (marker and marker.args[0]) or request.config.option.codeflash_loops
        if count > 1:
            try:
                return request.param
            except AttributeError:
                if issubclass(request.cls, TestCase):
                    warnings.warn("Repeating unittest class tests is not supported", stacklevel=2)
                else:
                    msg = "This call does not work with pytest-loops. Please consider raising an issue with your usage."
                    raise UnexpectedError(msg) from None
        return count

    @pytest.hookimpl(trylast=True)
    def pytest_generate_tests(self, metafunc: Metafunc) -> None:
        """Create tests based on the loop value.

        :param metafunc: Pytest metafunc object.
        :return: None.
        """
        count = metafunc.config.option.codeflash_loops
        m = metafunc.definition.get_closest_marker("loops")
        if m is not None:
            count = int(m.args[0])
        if count > 1:
            metafunc.fixturenames.append("__pytest_loop_step_number")

            def make_progress_id(i: int, n: int = count) -> str:
                return f"{n}/{i + 1}"

            scope = metafunc.config.option.codeflash_loops_scope
            metafunc.parametrize(
                "__pytest_loop_step_number", range(count), indirect=True, ids=make_progress_id, scope=scope
            )
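
    # Example (illustrative): with --codeflash_loops 3, make_progress_id yields
    # the ids "3/1", "3/2", "3/3", so each looped test appears as test_name[3/1],
    # test_name[3/2], test_name[3/3].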

    @pytest.hookimpl(tryfirst=True)
    def pytest_runtest_setup(self, item: pytest.Item) -> None:
        """Set test context environment variables before each test."""
        test_module_name = item.module.__name__ if item.module else "unknown_module"
        test_class_name = None
        if item.cls:
            test_class_name = item.cls.__name__
        test_function_name = item.name
        if "[" in test_function_name:
            test_function_name = test_function_name.split("[", 1)[0]
        os.environ["CODEFLASH_TEST_MODULE"] = test_module_name
        os.environ["CODEFLASH_TEST_CLASS"] = test_class_name or ""
        os.environ["CODEFLASH_TEST_FUNCTION"] = test_function_name

    @pytest.hookimpl(trylast=True)
    def pytest_runtest_teardown(self, item: pytest.Item) -> None:
        """Clean up test context environment variables after each test."""
        for var in ["CODEFLASH_TEST_MODULE", "CODEFLASH_TEST_CLASS", "CODEFLASH_TEST_FUNCTION"]:
            os.environ.pop(var, None)