perf: cache module scan in _clear_lru_caches and expand test coverage
Cache inspect.getmembers() results per module so repeated loop iterations skip the expensive rescan. Add tests for get_runtime_from_stdout, should_stop, _set_nodeid, _get_total_time, _timed_out, logreport, and setup/teardown hooks.
This commit is contained in:
parent
5a37f9a3ca
commit
1689a7bbb5
2 changed files with 451 additions and 64 deletions
|
|
@ -367,6 +367,7 @@ class PytestLoops:
|
|||
self.enable_stability_check: bool = (
|
||||
str(getattr(config.option, "codeflash_stability_check", "false")).lower() == "true"
|
||||
)
|
||||
self._module_clearables: dict[str, list[Callable]] = {}
|
||||
|
||||
@pytest.hookimpl
|
||||
def pytest_runtest_logreport(self, report: pytest.TestReport) -> None:
|
||||
|
|
@ -438,44 +439,55 @@ class PytestLoops:
|
|||
return True
|
||||
|
||||
def _clear_lru_caches(self, item: pytest.Item) -> None:
|
||||
processed_functions: set[Callable] = set()
|
||||
func = item.function # type: ignore[attr-defined]
|
||||
|
||||
def _clear_cache_for_object(obj: obj) -> None:
|
||||
if obj in processed_functions:
|
||||
return
|
||||
processed_functions.add(obj)
|
||||
# Always clear the test function itself
|
||||
if hasattr(func, "cache_clear") and callable(func.cache_clear):
|
||||
with contextlib.suppress(Exception):
|
||||
func.cache_clear()
|
||||
|
||||
module_name = getattr(func, "__module__", None)
|
||||
if not module_name:
|
||||
return
|
||||
|
||||
try:
|
||||
clearables = self._module_clearables.get(module_name)
|
||||
if clearables is None:
|
||||
clearables = self._scan_module_clearables(module_name)
|
||||
self._module_clearables[module_name] = clearables
|
||||
|
||||
for obj in clearables:
|
||||
with contextlib.suppress(Exception):
|
||||
obj.cache_clear()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _scan_module_clearables(self, module_name: str) -> list[Callable]:
|
||||
module = sys.modules.get(module_name)
|
||||
if not module:
|
||||
return []
|
||||
|
||||
clearables: list[Callable] = []
|
||||
for _, obj in inspect.getmembers(module):
|
||||
if not callable(obj):
|
||||
continue
|
||||
|
||||
if hasattr(obj, "__wrapped__"):
|
||||
module_name = obj.__wrapped__.__module__
|
||||
top_module = obj.__wrapped__.__module__
|
||||
else:
|
||||
try:
|
||||
obj_module = inspect.getmodule(obj)
|
||||
module_name = obj_module.__name__.split(".")[0] if obj_module is not None else None
|
||||
top_module = obj_module.__name__.split(".")[0] if obj_module is not None else None
|
||||
except Exception:
|
||||
module_name = None
|
||||
top_module = None
|
||||
|
||||
if module_name in _PROTECTED_MODULES:
|
||||
return
|
||||
if top_module in _PROTECTED_MODULES:
|
||||
continue
|
||||
|
||||
if hasattr(obj, "cache_clear") and callable(obj.cache_clear):
|
||||
with contextlib.suppress(Exception):
|
||||
obj.cache_clear()
|
||||
clearables.append(obj)
|
||||
|
||||
_clear_cache_for_object(item.function) # type: ignore[attr-defined]
|
||||
|
||||
try:
|
||||
if hasattr(item.function, "__module__"): # type: ignore[attr-defined]
|
||||
module_name = item.function.__module__ # type: ignore[attr-defined]
|
||||
try:
|
||||
module = sys.modules.get(module_name)
|
||||
if module:
|
||||
for _, obj in inspect.getmembers(module):
|
||||
if callable(obj):
|
||||
_clear_cache_for_object(obj)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
return clearables
|
||||
|
||||
def _set_nodeid(self, nodeid: str, count: int) -> str:
|
||||
"""Set loop count when using duration.
|
||||
|
|
|
|||
|
|
@ -1,10 +1,18 @@
|
|||
import os
|
||||
import sys
|
||||
import types
|
||||
from typing import NoReturn
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from _pytest.config import Config
|
||||
|
||||
from codeflash.verification.pytest_plugin import PytestLoops
|
||||
from codeflash.verification.pytest_plugin import (
|
||||
InvalidTimeParameterError,
|
||||
PytestLoops,
|
||||
get_runtime_from_stdout,
|
||||
should_stop,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -15,39 +23,301 @@ def pytest_loops_instance(pytestconfig: Config) -> PytestLoops:
|
|||
@pytest.fixture
|
||||
def mock_item() -> type:
|
||||
class MockItem:
|
||||
def __init__(self, function: types.FunctionType) -> None:
|
||||
def __init__(self, function: types.FunctionType, name: str = "test_func", cls: type = None, module: types.ModuleType = None) -> None:
|
||||
self.function = function
|
||||
self.name = name
|
||||
self.cls = cls
|
||||
self.module = module
|
||||
|
||||
return MockItem
|
||||
|
||||
|
||||
def create_mock_module(module_name: str, source_code: str) -> types.ModuleType:
|
||||
def create_mock_module(module_name: str, source_code: str, register: bool = False) -> types.ModuleType:
|
||||
module = types.ModuleType(module_name)
|
||||
exec(source_code, module.__dict__) # noqa: S102
|
||||
if register:
|
||||
sys.modules[module_name] = module
|
||||
return module
|
||||
|
||||
|
||||
def test_clear_lru_caches_function(pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
source_code = """
|
||||
def mock_session(**kwargs):
|
||||
"""Create a mock session with config options."""
|
||||
defaults = {
|
||||
"codeflash_hours": 0,
|
||||
"codeflash_minutes": 0,
|
||||
"codeflash_seconds": 10,
|
||||
"codeflash_delay": 0.0,
|
||||
"codeflash_loops": 1,
|
||||
"codeflash_min_loops": 1,
|
||||
"codeflash_max_loops": 100_000,
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
|
||||
class Option:
|
||||
pass
|
||||
|
||||
option = Option()
|
||||
for k, v in defaults.items():
|
||||
setattr(option, k, v)
|
||||
|
||||
class MockConfig:
|
||||
pass
|
||||
|
||||
config = MockConfig()
|
||||
config.option = option
|
||||
|
||||
class MockSession:
|
||||
pass
|
||||
|
||||
session = MockSession()
|
||||
session.config = config
|
||||
return session
|
||||
|
||||
|
||||
# --- get_runtime_from_stdout ---
|
||||
|
||||
|
||||
class TestGetRuntimeFromStdout:
|
||||
def test_valid_payload(self) -> None:
|
||||
assert get_runtime_from_stdout("!######test_func:12345######!") == 12345
|
||||
|
||||
def test_valid_payload_with_surrounding_text(self) -> None:
|
||||
assert get_runtime_from_stdout("some output\n!######mod.func:99999######!\nmore output") == 99999
|
||||
|
||||
def test_empty_string(self) -> None:
|
||||
assert get_runtime_from_stdout("") is None
|
||||
|
||||
def test_no_markers(self) -> None:
|
||||
assert get_runtime_from_stdout("just some output") is None
|
||||
|
||||
def test_missing_end_marker(self) -> None:
|
||||
assert get_runtime_from_stdout("!######test:123") is None
|
||||
|
||||
def test_missing_start_marker(self) -> None:
|
||||
assert get_runtime_from_stdout("test:123######!") is None
|
||||
|
||||
def test_no_colon_in_payload(self) -> None:
|
||||
assert get_runtime_from_stdout("!######nocolon######!") is None
|
||||
|
||||
def test_non_integer_value(self) -> None:
|
||||
assert get_runtime_from_stdout("!######test:notanumber######!") is None
|
||||
|
||||
def test_multiple_markers_uses_last(self) -> None:
|
||||
stdout = "!######first:111######! middle !######second:222######!"
|
||||
assert get_runtime_from_stdout(stdout) == 222
|
||||
|
||||
|
||||
# --- should_stop ---
|
||||
|
||||
|
||||
class TestShouldStop:
|
||||
def test_not_enough_data_for_window(self) -> None:
|
||||
assert should_stop([100, 100], window=5, min_window_size=3) is False
|
||||
|
||||
def test_below_min_window_size(self) -> None:
|
||||
assert should_stop([100, 100], window=2, min_window_size=5) is False
|
||||
|
||||
def test_stable_runtimes_stops(self) -> None:
|
||||
runtimes = [1000000] * 10
|
||||
assert should_stop(runtimes, window=5, min_window_size=3, center_rel_tol=0.01, spread_rel_tol=0.01) is True
|
||||
|
||||
def test_unstable_runtimes_continues(self) -> None:
|
||||
runtimes = [100, 200, 100, 200, 100]
|
||||
assert should_stop(runtimes, window=5, min_window_size=3, center_rel_tol=0.01, spread_rel_tol=0.01) is False
|
||||
|
||||
def test_zero_runtimes_raises(self) -> None:
|
||||
# All-zero runtimes cause ZeroDivisionError in median check.
|
||||
# In practice the caller guards with best_runtime_until_now > 0.
|
||||
runtimes = [0, 0, 0, 0, 0]
|
||||
with pytest.raises(ZeroDivisionError):
|
||||
should_stop(runtimes, window=5, min_window_size=3)
|
||||
|
||||
def test_even_window_median(self) -> None:
|
||||
# Even window: median is average of two middle values
|
||||
runtimes = [1000, 1000, 1001, 1001]
|
||||
assert should_stop(runtimes, window=4, min_window_size=2, center_rel_tol=0.01, spread_rel_tol=0.01) is True
|
||||
|
||||
def test_centered_but_spread_too_large(self) -> None:
|
||||
# All close to median but spread exceeds tolerance
|
||||
runtimes = [1000, 1050, 1000, 1050, 1000]
|
||||
assert should_stop(runtimes, window=5, min_window_size=3, center_rel_tol=0.1, spread_rel_tol=0.001) is False
|
||||
|
||||
|
||||
# --- _set_nodeid ---
|
||||
|
||||
|
||||
class TestSetNodeid:
|
||||
def test_appends_count_to_plain_nodeid(self, pytest_loops_instance: PytestLoops) -> None:
|
||||
result = pytest_loops_instance._set_nodeid("test_module.py::test_func", 3) # noqa: SLF001
|
||||
assert result == "test_module.py::test_func[ 3 ]"
|
||||
assert os.environ["CODEFLASH_LOOP_INDEX"] == "3"
|
||||
|
||||
def test_replaces_existing_count(self, pytest_loops_instance: PytestLoops) -> None:
|
||||
result = pytest_loops_instance._set_nodeid("test_module.py::test_func[ 1 ]", 5) # noqa: SLF001
|
||||
assert result == "test_module.py::test_func[ 5 ]"
|
||||
|
||||
def test_replaces_only_loop_pattern(self, pytest_loops_instance: PytestLoops) -> None:
|
||||
# Parametrize brackets like [param0] should not be replaced
|
||||
result = pytest_loops_instance._set_nodeid("test_mod.py::test_func[param0]", 2) # noqa: SLF001
|
||||
assert result == "test_mod.py::test_func[param0][ 2 ]"
|
||||
|
||||
|
||||
# --- _get_total_time ---
|
||||
|
||||
|
||||
class TestGetTotalTime:
|
||||
def test_seconds_only(self, pytest_loops_instance: PytestLoops) -> None:
|
||||
session = mock_session(codeflash_seconds=30)
|
||||
assert pytest_loops_instance._get_total_time(session) == 30 # noqa: SLF001
|
||||
|
||||
def test_mixed_units(self, pytest_loops_instance: PytestLoops) -> None:
|
||||
session = mock_session(codeflash_hours=1, codeflash_minutes=30, codeflash_seconds=45)
|
||||
assert pytest_loops_instance._get_total_time(session) == 3600 + 1800 + 45 # noqa: SLF001
|
||||
|
||||
def test_zero_time_is_valid(self, pytest_loops_instance: PytestLoops) -> None:
|
||||
session = mock_session(codeflash_hours=0, codeflash_minutes=0, codeflash_seconds=0)
|
||||
assert pytest_loops_instance._get_total_time(session) == 0 # noqa: SLF001
|
||||
|
||||
def test_negative_time_raises(self, pytest_loops_instance: PytestLoops) -> None:
|
||||
session = mock_session(codeflash_hours=0, codeflash_minutes=0, codeflash_seconds=-1)
|
||||
with pytest.raises(InvalidTimeParameterError):
|
||||
pytest_loops_instance._get_total_time(session) # noqa: SLF001
|
||||
|
||||
|
||||
# --- _timed_out ---
|
||||
|
||||
|
||||
class TestTimedOut:
|
||||
def test_exceeds_max_loops(self, pytest_loops_instance: PytestLoops) -> None:
|
||||
session = mock_session(codeflash_max_loops=10, codeflash_min_loops=1, codeflash_seconds=9999)
|
||||
assert pytest_loops_instance._timed_out(session, start_time=0, count=10) is True # noqa: SLF001
|
||||
|
||||
def test_below_min_loops_never_times_out(self, pytest_loops_instance: PytestLoops) -> None:
|
||||
session = mock_session(codeflash_max_loops=100_000, codeflash_min_loops=50, codeflash_seconds=0)
|
||||
# Even with 0 seconds budget, count < min_loops means not timed out
|
||||
assert pytest_loops_instance._timed_out(session, start_time=0, count=5) is False # noqa: SLF001
|
||||
|
||||
def test_above_min_loops_and_time_exceeded(self, pytest_loops_instance: PytestLoops) -> None:
|
||||
session = mock_session(codeflash_max_loops=100_000, codeflash_min_loops=1, codeflash_seconds=1)
|
||||
# start_time far in the past → time exceeded
|
||||
assert pytest_loops_instance._timed_out(session, start_time=0, count=2) is True # noqa: SLF001
|
||||
|
||||
|
||||
# --- _get_delay_time ---
|
||||
|
||||
|
||||
class TestGetDelayTime:
|
||||
def test_returns_configured_delay(self, pytest_loops_instance: PytestLoops) -> None:
|
||||
session = mock_session(codeflash_delay=0.5)
|
||||
assert pytest_loops_instance._get_delay_time(session) == 0.5 # noqa: SLF001
|
||||
|
||||
|
||||
# --- pytest_runtest_logreport ---
|
||||
|
||||
|
||||
class TestRunTestLogReport:
|
||||
def test_skipped_when_stability_check_disabled(self, pytestconfig: Config) -> None:
|
||||
instance = PytestLoops(pytestconfig)
|
||||
instance.enable_stability_check = False
|
||||
|
||||
class MockReport:
|
||||
when = "call"
|
||||
passed = True
|
||||
capstdout = "!######func:12345######!"
|
||||
nodeid = "test::func"
|
||||
|
||||
instance.pytest_runtest_logreport(MockReport())
|
||||
assert instance.runtime_data_by_test_case == {}
|
||||
|
||||
def test_records_runtime_on_passed_call(self, pytestconfig: Config) -> None:
|
||||
instance = PytestLoops(pytestconfig)
|
||||
instance.enable_stability_check = True
|
||||
|
||||
class MockReport:
|
||||
when = "call"
|
||||
passed = True
|
||||
capstdout = "!######func:12345######!"
|
||||
nodeid = "test::func [ 1 ]"
|
||||
|
||||
instance.pytest_runtest_logreport(MockReport())
|
||||
assert "test::func" in instance.runtime_data_by_test_case
|
||||
assert instance.runtime_data_by_test_case["test::func"] == [12345]
|
||||
|
||||
def test_ignores_non_call_phase(self, pytestconfig: Config) -> None:
|
||||
instance = PytestLoops(pytestconfig)
|
||||
instance.enable_stability_check = True
|
||||
|
||||
class MockReport:
|
||||
when = "setup"
|
||||
passed = True
|
||||
capstdout = "!######func:12345######!"
|
||||
nodeid = "test::func"
|
||||
|
||||
instance.pytest_runtest_logreport(MockReport())
|
||||
assert instance.runtime_data_by_test_case == {}
|
||||
|
||||
|
||||
# --- pytest_runtest_setup / teardown ---
|
||||
|
||||
|
||||
class TestRunTestSetupTeardown:
|
||||
def test_setup_sets_env_vars(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
module = types.ModuleType("my_test_module")
|
||||
|
||||
class MyTestClass:
|
||||
pass
|
||||
|
||||
item = mock_item(lambda: None, name="test_something[param1]", cls=MyTestClass, module=module)
|
||||
pytest_loops_instance.pytest_runtest_setup(item)
|
||||
|
||||
assert os.environ["CODEFLASH_TEST_MODULE"] == "my_test_module"
|
||||
assert os.environ["CODEFLASH_TEST_CLASS"] == "MyTestClass"
|
||||
assert os.environ["CODEFLASH_TEST_FUNCTION"] == "test_something"
|
||||
|
||||
def test_setup_no_class(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
module = types.ModuleType("my_test_module")
|
||||
item = mock_item(lambda: None, name="test_plain", cls=None, module=module)
|
||||
pytest_loops_instance.pytest_runtest_setup(item)
|
||||
|
||||
assert os.environ["CODEFLASH_TEST_CLASS"] == ""
|
||||
|
||||
def test_teardown_clears_env_vars(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
os.environ["CODEFLASH_TEST_MODULE"] = "leftover"
|
||||
os.environ["CODEFLASH_TEST_CLASS"] = "leftover"
|
||||
os.environ["CODEFLASH_TEST_FUNCTION"] = "leftover"
|
||||
|
||||
item = mock_item(lambda: None)
|
||||
pytest_loops_instance.pytest_runtest_teardown(item)
|
||||
|
||||
assert "CODEFLASH_TEST_MODULE" not in os.environ
|
||||
assert "CODEFLASH_TEST_CLASS" not in os.environ
|
||||
assert "CODEFLASH_TEST_FUNCTION" not in os.environ
|
||||
|
||||
|
||||
# --- _clear_lru_caches ---
|
||||
|
||||
|
||||
class TestClearLruCaches:
|
||||
def test_clears_lru_cached_function(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
source_code = """
|
||||
import functools
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def my_func(x):
|
||||
return x * 2
|
||||
|
||||
my_func(10) # miss the cache
|
||||
my_func(10) # hit the cache
|
||||
my_func(10)
|
||||
my_func(10)
|
||||
"""
|
||||
mock_module = create_mock_module("test_module_func", source_code)
|
||||
item = mock_item(mock_module.my_func)
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
assert mock_module.my_func.cache_info().hits == 0
|
||||
assert mock_module.my_func.cache_info().misses == 0
|
||||
assert mock_module.my_func.cache_info().currsize == 0
|
||||
mock_module = create_mock_module("test_module_func", source_code)
|
||||
item = mock_item(mock_module.my_func)
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
assert mock_module.my_func.cache_info().hits == 0
|
||||
assert mock_module.my_func.cache_info().misses == 0
|
||||
assert mock_module.my_func.cache_info().currsize == 0
|
||||
|
||||
|
||||
def test_clear_lru_caches_class_method(pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
source_code = """
|
||||
def test_clears_class_method_cache(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
source_code = """
|
||||
import functools
|
||||
|
||||
class MyClass:
|
||||
|
|
@ -56,32 +326,137 @@ class MyClass:
|
|||
return x * 3
|
||||
|
||||
obj = MyClass()
|
||||
obj.my_method(5) # Pre-populate the cache
|
||||
obj.my_method(5) # Hit the cache
|
||||
obj.my_method(5)
|
||||
obj.my_method(5)
|
||||
# """
|
||||
mock_module = create_mock_module("test_module_class", source_code)
|
||||
item = mock_item(mock_module.MyClass.my_method)
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
assert mock_module.MyClass.my_method.cache_info().hits == 0
|
||||
assert mock_module.MyClass.my_method.cache_info().misses == 0
|
||||
assert mock_module.MyClass.my_method.cache_info().currsize == 0
|
||||
mock_module = create_mock_module("test_module_class", source_code)
|
||||
item = mock_item(mock_module.MyClass.my_method)
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
assert mock_module.MyClass.my_method.cache_info().hits == 0
|
||||
assert mock_module.MyClass.my_method.cache_info().misses == 0
|
||||
assert mock_module.MyClass.my_method.cache_info().currsize == 0
|
||||
|
||||
def test_handles_exception_in_cache_clear(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
class BrokenCache:
|
||||
def cache_clear(self) -> NoReturn:
|
||||
msg = "Cache clearing failed!"
|
||||
raise ValueError(msg)
|
||||
|
||||
def test_clear_lru_caches_exception_handling(pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
"""Test that exceptions during clearing are handled."""
|
||||
item = mock_item(BrokenCache())
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
|
||||
class BrokenCache:
|
||||
def cache_clear(self) -> NoReturn:
|
||||
msg = "Cache clearing failed!"
|
||||
raise ValueError(msg)
|
||||
def test_handles_no_cache(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
def no_cache_func(x: int) -> int:
|
||||
return x
|
||||
|
||||
item = mock_item(BrokenCache())
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
item = mock_item(no_cache_func)
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
|
||||
def test_clears_module_level_caches_via_sys_modules(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
module_name = "_cf_test_module_scan"
|
||||
source_code = """
|
||||
import functools
|
||||
|
||||
def test_clear_lru_caches_no_cache(pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
def no_cache_func(x: int) -> int:
|
||||
return x
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def cached_a(x):
|
||||
return x + 1
|
||||
|
||||
item = mock_item(no_cache_func)
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def cached_b(x):
|
||||
return x + 2
|
||||
|
||||
def plain_func(x):
|
||||
return x
|
||||
|
||||
cached_a(1)
|
||||
cached_a(1)
|
||||
cached_b(2)
|
||||
cached_b(2)
|
||||
"""
|
||||
mock_module = create_mock_module(module_name, source_code, register=True)
|
||||
try:
|
||||
item = mock_item(mock_module.plain_func)
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
|
||||
assert mock_module.cached_a.cache_info().currsize == 0
|
||||
assert mock_module.cached_b.cache_info().currsize == 0
|
||||
finally:
|
||||
sys.modules.pop(module_name, None)
|
||||
|
||||
def test_skips_protected_modules(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
module_name = "_cf_test_protected"
|
||||
source_code = """
|
||||
import functools
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def user_func(x):
|
||||
return x
|
||||
"""
|
||||
mock_module = create_mock_module(module_name, source_code, register=True)
|
||||
try:
|
||||
mock_module.os_exists = os.path.exists
|
||||
item = mock_item(mock_module.user_func)
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
finally:
|
||||
sys.modules.pop(module_name, None)
|
||||
|
||||
def test_caches_scan_result(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
module_name = "_cf_test_cache_reuse"
|
||||
source_code = """
|
||||
import functools
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def cached_fn(x):
|
||||
return x
|
||||
"""
|
||||
mock_module = create_mock_module(module_name, source_code, register=True)
|
||||
try:
|
||||
item = mock_item(mock_module.cached_fn)
|
||||
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
assert module_name in pytest_loops_instance._module_clearables # noqa: SLF001
|
||||
|
||||
mock_module.cached_fn(42)
|
||||
assert mock_module.cached_fn.cache_info().currsize == 1
|
||||
|
||||
with patch("codeflash.verification.pytest_plugin.inspect.getmembers") as mock_getmembers:
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
mock_getmembers.assert_not_called()
|
||||
|
||||
assert mock_module.cached_fn.cache_info().currsize == 0
|
||||
finally:
|
||||
sys.modules.pop(module_name, None)
|
||||
|
||||
def test_handles_wrapped_function(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
module_name = "_cf_test_wrapped"
|
||||
source_code = """
|
||||
import functools
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def inner(x):
|
||||
return x
|
||||
|
||||
def wrapper(x):
|
||||
return inner(x)
|
||||
|
||||
wrapper.__wrapped__ = inner
|
||||
wrapper.__module__ = __name__
|
||||
|
||||
inner(1)
|
||||
inner(1)
|
||||
"""
|
||||
mock_module = create_mock_module(module_name, source_code, register=True)
|
||||
try:
|
||||
item = mock_item(mock_module.wrapper)
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
assert mock_module.inner.cache_info().currsize == 0
|
||||
finally:
|
||||
sys.modules.pop(module_name, None)
|
||||
|
||||
def test_handles_function_without_module(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
|
||||
def func() -> None:
|
||||
pass
|
||||
|
||||
func.__module__ = None # type: ignore[assignment]
|
||||
item = mock_item(func)
|
||||
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
|
||||
|
|
|
|||
Loading…
Reference in a new issue