perf: cache module scan in _clear_lru_caches and expand test coverage

Cache inspect.getmembers() results per module so repeated loop
iterations skip the expensive rescan. Add tests for get_runtime_from_stdout,
should_stop, _set_nodeid, _get_total_time, _timed_out, logreport, and
setup/teardown hooks.
This commit is contained in:
Kevin Turcios 2026-02-22 01:17:05 -05:00
parent 5a37f9a3ca
commit 1689a7bbb5
2 changed files with 451 additions and 64 deletions

View file

@ -367,6 +367,7 @@ class PytestLoops:
self.enable_stability_check: bool = (
str(getattr(config.option, "codeflash_stability_check", "false")).lower() == "true"
)
self._module_clearables: dict[str, list[Callable]] = {}
@pytest.hookimpl
def pytest_runtest_logreport(self, report: pytest.TestReport) -> None:
@ -438,44 +439,55 @@ class PytestLoops:
return True
def _clear_lru_caches(self, item: pytest.Item) -> None:
processed_functions: set[Callable] = set()
func = item.function # type: ignore[attr-defined]
def _clear_cache_for_object(obj: obj) -> None:
if obj in processed_functions:
return
processed_functions.add(obj)
# Always clear the test function itself
if hasattr(func, "cache_clear") and callable(func.cache_clear):
with contextlib.suppress(Exception):
func.cache_clear()
module_name = getattr(func, "__module__", None)
if not module_name:
return
try:
clearables = self._module_clearables.get(module_name)
if clearables is None:
clearables = self._scan_module_clearables(module_name)
self._module_clearables[module_name] = clearables
for obj in clearables:
with contextlib.suppress(Exception):
obj.cache_clear()
except Exception:
pass
def _scan_module_clearables(self, module_name: str) -> list[Callable]:
    """Collect the members of a loaded module whose caches may be cleared.

    Only inspects; it never clears anything itself. A member is collected
    when it is callable, exposes a callable ``cache_clear``, and its owning
    module is not listed in ``_PROTECTED_MODULES``. The owning module is
    taken from ``__wrapped__.__module__`` when the member has a
    ``__wrapped__`` attribute and from ``inspect.getmodule`` otherwise. Both
    the full dotted name and its top-level package are checked against
    ``_PROTECTED_MODULES``, so the two branches apply the same rule.

    Args:
        module_name: Key to look up in ``sys.modules``.

    Returns:
        The clearable members, or an empty list when the module is not
        loaded. Always a list, so the caller can store and iterate it.
    """
    module = sys.modules.get(module_name)
    if not module:
        return []

    clearables: list[Callable] = []
    for _, obj in inspect.getmembers(module):
        if not callable(obj):
            continue

        # Resolving the owner can fail for exotic members; treat that as
        # "owner unknown" rather than aborting the scan of the whole module.
        try:
            if hasattr(obj, "__wrapped__"):
                owner = obj.__wrapped__.__module__
            else:
                obj_module = inspect.getmodule(obj)
                owner = obj_module.__name__ if obj_module is not None else None
        except Exception:
            owner = None

        if isinstance(owner, str) and (
            owner in _PROTECTED_MODULES or owner.split(".")[0] in _PROTECTED_MODULES
        ):
            # Skip this member only; later members still need examining.
            continue

        if hasattr(obj, "cache_clear") and callable(obj.cache_clear):
            clearables.append(obj)

    return clearables
def _set_nodeid(self, nodeid: str, count: int) -> str:
"""Set loop count when using duration.

View file

@ -1,10 +1,18 @@
import os
import sys
import types
from typing import NoReturn
from unittest.mock import patch
import pytest
from _pytest.config import Config
from codeflash.verification.pytest_plugin import PytestLoops
from codeflash.verification.pytest_plugin import (
InvalidTimeParameterError,
PytestLoops,
get_runtime_from_stdout,
should_stop,
)
@pytest.fixture
@ -15,39 +23,301 @@ def pytest_loops_instance(pytestconfig: Config) -> PytestLoops:
@pytest.fixture
def mock_item() -> type:
    """Provide a minimal stand-in class for the items handed to ``PytestLoops`` hooks.

    Returns:
        The ``MockItem`` class (not an instance); tests instantiate it with
        the callable under test and, optionally, ``name``, ``cls`` and
        ``module``.
    """

    class MockItem:
        def __init__(
            self,
            function: types.FunctionType,
            name: str = "test_func",
            cls: "type | None" = None,
            module: "types.ModuleType | None" = None,
        ) -> None:
            # Plain attribute bag: only these four attributes are provided.
            self.function = function
            self.name = name
            self.cls = cls
            self.module = module

    return MockItem
def create_mock_module(module_name: str, source_code: str, register: bool = False) -> types.ModuleType:
    """Create a throwaway module by executing ``source_code`` in its namespace.

    Args:
        module_name: Value used for the new module's ``__name__``.
        source_code: Python source executed with the module's ``__dict__`` as globals.
        register: When true, also store the module in ``sys.modules`` under
            ``module_name``. The caller is responsible for removing it again.

    Returns:
        The populated module object.
    """
    module = types.ModuleType(module_name)
    exec(source_code, module.__dict__)  # noqa: S102
    # Register only after the source ran, so a failing exec leaves sys.modules untouched.
    if register:
        sys.modules[module_name] = module
    return module
def test_clear_lru_caches_function(pytest_loops_instance: PytestLoops, mock_item: type) -> None:
source_code = """
def mock_session(**kwargs):
    """Create a mock session with config options.

    Keyword arguments override or extend the defaults listed below; every
    resulting key becomes an attribute of ``session.config.option``.
    """

    class Option:
        pass

    class MockConfig:
        pass

    class MockSession:
        pass

    settings = {
        "codeflash_hours": 0,
        "codeflash_minutes": 0,
        "codeflash_seconds": 10,
        "codeflash_delay": 0.0,
        "codeflash_loops": 1,
        "codeflash_min_loops": 1,
        "codeflash_max_loops": 100_000,
        **kwargs,
    }

    option = Option()
    for attr_name in settings:
        setattr(option, attr_name, settings[attr_name])

    config = MockConfig()
    config.option = option

    session = MockSession()
    session.config = config
    return session
# --- get_runtime_from_stdout ---
class TestGetRuntimeFromStdout:
    """Checks for ``get_runtime_from_stdout`` against ``!######name:value######!`` markers."""

    def test_valid_payload(self) -> None:
        captured = "!######test_func:12345######!"
        assert get_runtime_from_stdout(captured) == 12345

    def test_valid_payload_with_surrounding_text(self) -> None:
        captured = "some output\n!######mod.func:99999######!\nmore output"
        assert get_runtime_from_stdout(captured) == 99999

    def test_empty_string(self) -> None:
        parsed = get_runtime_from_stdout("")
        assert parsed is None

    def test_no_markers(self) -> None:
        parsed = get_runtime_from_stdout("just some output")
        assert parsed is None

    def test_missing_end_marker(self) -> None:
        parsed = get_runtime_from_stdout("!######test:123")
        assert parsed is None

    def test_missing_start_marker(self) -> None:
        parsed = get_runtime_from_stdout("test:123######!")
        assert parsed is None

    def test_no_colon_in_payload(self) -> None:
        parsed = get_runtime_from_stdout("!######nocolon######!")
        assert parsed is None

    def test_non_integer_value(self) -> None:
        parsed = get_runtime_from_stdout("!######test:notanumber######!")
        assert parsed is None

    def test_multiple_markers_uses_last(self) -> None:
        # Two complete markers: the value of the later one is expected.
        captured = "!######first:111######! middle !######second:222######!"
        assert get_runtime_from_stdout(captured) == 222
# --- should_stop ---
class TestShouldStop:
    """Checks for ``should_stop`` on lists of recorded runtimes."""

    def test_not_enough_data_for_window(self) -> None:
        samples = [100, 100]
        assert should_stop(samples, window=5, min_window_size=3) is False

    def test_below_min_window_size(self) -> None:
        samples = [100, 100]
        assert should_stop(samples, window=2, min_window_size=5) is False

    def test_stable_runtimes_stops(self) -> None:
        samples = [1000000 for _ in range(10)]
        verdict = should_stop(samples, window=5, min_window_size=3, center_rel_tol=0.01, spread_rel_tol=0.01)
        assert verdict is True

    def test_unstable_runtimes_continues(self) -> None:
        samples = [100, 200, 100, 200, 100]
        verdict = should_stop(samples, window=5, min_window_size=3, center_rel_tol=0.01, spread_rel_tol=0.01)
        assert verdict is False

    def test_zero_runtimes_raises(self) -> None:
        # All-zero runtimes cause ZeroDivisionError in median check.
        # In practice the caller guards with best_runtime_until_now > 0.
        samples = [0] * 5
        with pytest.raises(ZeroDivisionError):
            should_stop(samples, window=5, min_window_size=3)

    def test_even_window_median(self) -> None:
        # Even window: median is average of two middle values
        samples = [1000, 1000, 1001, 1001]
        verdict = should_stop(samples, window=4, min_window_size=2, center_rel_tol=0.01, spread_rel_tol=0.01)
        assert verdict is True

    def test_centered_but_spread_too_large(self) -> None:
        # All close to median but spread exceeds tolerance
        samples = [1000, 1050, 1000, 1050, 1000]
        verdict = should_stop(samples, window=5, min_window_size=3, center_rel_tol=0.1, spread_rel_tol=0.001)
        assert verdict is False
# --- _set_nodeid ---
class TestSetNodeid:
    """Checks for ``PytestLoops._set_nodeid``.

    The first test shows that the call writes ``CODEFLASH_LOOP_INDEX`` into
    ``os.environ``. Every test therefore runs inside ``patch.dict(os.environ)``
    so the process environment is restored afterwards and the value cannot
    leak into later tests.
    """

    def test_appends_count_to_plain_nodeid(self, pytest_loops_instance: PytestLoops) -> None:
        with patch.dict(os.environ):
            result = pytest_loops_instance._set_nodeid("test_module.py::test_func", 3)  # noqa: SLF001
            assert result == "test_module.py::test_func[ 3 ]"
            assert os.environ["CODEFLASH_LOOP_INDEX"] == "3"

    def test_replaces_existing_count(self, pytest_loops_instance: PytestLoops) -> None:
        with patch.dict(os.environ):
            result = pytest_loops_instance._set_nodeid("test_module.py::test_func[ 1 ]", 5)  # noqa: SLF001
            assert result == "test_module.py::test_func[ 5 ]"

    def test_replaces_only_loop_pattern(self, pytest_loops_instance: PytestLoops) -> None:
        # Parametrize brackets like [param0] should not be replaced
        with patch.dict(os.environ):
            result = pytest_loops_instance._set_nodeid("test_mod.py::test_func[param0]", 2)  # noqa: SLF001
            assert result == "test_mod.py::test_func[param0][ 2 ]"
# --- _get_total_time ---
class TestGetTotalTime:
    """Checks for ``PytestLoops._get_total_time`` with hours/minutes/seconds options."""

    def test_seconds_only(self, pytest_loops_instance: PytestLoops) -> None:
        session = mock_session(codeflash_seconds=30)
        total = pytest_loops_instance._get_total_time(session)  # noqa: SLF001
        assert total == 30

    def test_mixed_units(self, pytest_loops_instance: PytestLoops) -> None:
        session = mock_session(codeflash_hours=1, codeflash_minutes=30, codeflash_seconds=45)
        # 1 hour + 30 minutes + 45 seconds, expressed in seconds.
        expected = 3600 + 1800 + 45
        total = pytest_loops_instance._get_total_time(session)  # noqa: SLF001
        assert total == expected

    def test_zero_time_is_valid(self, pytest_loops_instance: PytestLoops) -> None:
        session = mock_session(codeflash_hours=0, codeflash_minutes=0, codeflash_seconds=0)
        total = pytest_loops_instance._get_total_time(session)  # noqa: SLF001
        assert total == 0

    def test_negative_time_raises(self, pytest_loops_instance: PytestLoops) -> None:
        session = mock_session(codeflash_hours=0, codeflash_minutes=0, codeflash_seconds=-1)
        with pytest.raises(InvalidTimeParameterError):
            pytest_loops_instance._get_total_time(session)  # noqa: SLF001
# --- _timed_out ---
class TestTimedOut:
    """Checks for ``PytestLoops._timed_out`` across loop-count and time limits."""

    def test_exceeds_max_loops(self, pytest_loops_instance: PytestLoops) -> None:
        session = mock_session(codeflash_max_loops=10, codeflash_min_loops=1, codeflash_seconds=9999)
        timed_out = pytest_loops_instance._timed_out(session, start_time=0, count=10)  # noqa: SLF001
        assert timed_out is True

    def test_below_min_loops_never_times_out(self, pytest_loops_instance: PytestLoops) -> None:
        session = mock_session(codeflash_max_loops=100_000, codeflash_min_loops=50, codeflash_seconds=0)
        # Even with 0 seconds budget, count < min_loops means not timed out
        timed_out = pytest_loops_instance._timed_out(session, start_time=0, count=5)  # noqa: SLF001
        assert timed_out is False

    def test_above_min_loops_and_time_exceeded(self, pytest_loops_instance: PytestLoops) -> None:
        session = mock_session(codeflash_max_loops=100_000, codeflash_min_loops=1, codeflash_seconds=1)
        # start_time far in the past → time exceeded
        timed_out = pytest_loops_instance._timed_out(session, start_time=0, count=2)  # noqa: SLF001
        assert timed_out is True
# --- _get_delay_time ---
class TestGetDelayTime:
    """Checks for ``PytestLoops._get_delay_time``."""

    def test_returns_configured_delay(self, pytest_loops_instance: PytestLoops) -> None:
        session = mock_session(codeflash_delay=0.5)
        delay = pytest_loops_instance._get_delay_time(session)  # noqa: SLF001
        assert delay == 0.5
# --- pytest_runtest_logreport ---
class TestRunTestLogReport:
    """Checks for ``PytestLoops.pytest_runtest_logreport`` using hand-built report objects."""

    @staticmethod
    def _report(when: str, nodeid: str) -> object:
        """Return a passing fake report whose captured stdout carries a 12345 runtime marker."""
        attrs = {
            "when": when,
            "passed": True,
            "capstdout": "!######func:12345######!",
            "nodeid": nodeid,
        }
        return type("MockReport", (), attrs)()

    def test_skipped_when_stability_check_disabled(self, pytestconfig: Config) -> None:
        instance = PytestLoops(pytestconfig)
        instance.enable_stability_check = False

        instance.pytest_runtest_logreport(self._report("call", "test::func"))

        assert instance.runtime_data_by_test_case == {}

    def test_records_runtime_on_passed_call(self, pytestconfig: Config) -> None:
        instance = PytestLoops(pytestconfig)
        instance.enable_stability_check = True

        instance.pytest_runtest_logreport(self._report("call", "test::func [ 1 ]"))

        # The runtime is expected under the nodeid without its " [ 1 ]" suffix.
        recorded = instance.runtime_data_by_test_case
        assert "test::func" in recorded
        assert recorded["test::func"] == [12345]

    def test_ignores_non_call_phase(self, pytestconfig: Config) -> None:
        instance = PytestLoops(pytestconfig)
        instance.enable_stability_check = True

        instance.pytest_runtest_logreport(self._report("setup", "test::func"))

        assert instance.runtime_data_by_test_case == {}
# --- pytest_runtest_setup / teardown ---
class TestRunTestSetupTeardown:
    """Checks for the ``pytest_runtest_setup`` / ``pytest_runtest_teardown`` hooks.

    These hooks are exercised through the ``CODEFLASH_TEST_MODULE``,
    ``CODEFLASH_TEST_CLASS`` and ``CODEFLASH_TEST_FUNCTION`` environment
    variables. Each test runs inside ``patch.dict(os.environ)`` so the
    process environment is restored afterwards, even when an assertion fails.
    """

    def test_setup_sets_env_vars(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
        module = types.ModuleType("my_test_module")

        class MyTestClass:
            pass

        item = mock_item(lambda: None, name="test_something[param1]", cls=MyTestClass, module=module)
        with patch.dict(os.environ):
            pytest_loops_instance.pytest_runtest_setup(item)
            assert os.environ["CODEFLASH_TEST_MODULE"] == "my_test_module"
            assert os.environ["CODEFLASH_TEST_CLASS"] == "MyTestClass"
            # The "[param1]" suffix of the item name is not expected in the variable.
            assert os.environ["CODEFLASH_TEST_FUNCTION"] == "test_something"

    def test_setup_no_class(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
        module = types.ModuleType("my_test_module")
        item = mock_item(lambda: None, name="test_plain", cls=None, module=module)
        with patch.dict(os.environ):
            pytest_loops_instance.pytest_runtest_setup(item)
            assert os.environ["CODEFLASH_TEST_CLASS"] == ""

    def test_teardown_clears_env_vars(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
        item = mock_item(lambda: None)
        with patch.dict(os.environ):
            os.environ["CODEFLASH_TEST_MODULE"] = "leftover"
            os.environ["CODEFLASH_TEST_CLASS"] = "leftover"
            os.environ["CODEFLASH_TEST_FUNCTION"] = "leftover"
            pytest_loops_instance.pytest_runtest_teardown(item)
            assert "CODEFLASH_TEST_MODULE" not in os.environ
            assert "CODEFLASH_TEST_CLASS" not in os.environ
            assert "CODEFLASH_TEST_FUNCTION" not in os.environ
# --- _clear_lru_caches ---
class TestClearLruCaches:
def test_clears_lru_cached_function(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
    """An ``lru_cache``-wrapped function used as the test item has its cache emptied."""
    source_code = """
import functools
@functools.lru_cache(maxsize=None)
def my_func(x):
    return x * 2
my_func(10)
my_func(10)
"""
    mock_module = create_mock_module("test_module_func", source_code)
    # Precondition: the cache really holds an entry, so the checks below
    # cannot pass on a cache that was never populated.
    assert mock_module.my_func.cache_info().currsize == 1

    item = mock_item(mock_module.my_func)
    pytest_loops_instance._clear_lru_caches(item)  # noqa: SLF001

    info = mock_module.my_func.cache_info()
    assert info.hits == 0
    assert info.misses == 0
    assert info.currsize == 0
def test_clear_lru_caches_class_method(pytest_loops_instance: PytestLoops, mock_item: type) -> None:
source_code = """
def test_clears_class_method_cache(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
source_code = """
import functools
class MyClass:
@ -56,32 +326,137 @@ class MyClass:
return x * 3
obj = MyClass()
obj.my_method(5) # Pre-populate the cache
obj.my_method(5) # Hit the cache
obj.my_method(5)
obj.my_method(5)
# """
mock_module = create_mock_module("test_module_class", source_code)
item = mock_item(mock_module.MyClass.my_method)
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
assert mock_module.MyClass.my_method.cache_info().hits == 0
assert mock_module.MyClass.my_method.cache_info().misses == 0
assert mock_module.MyClass.my_method.cache_info().currsize == 0
mock_module = create_mock_module("test_module_class", source_code)
item = mock_item(mock_module.MyClass.my_method)
pytest_loops_instance._clear_lru_caches(item) # noqa: SLF001
assert mock_module.MyClass.my_method.cache_info().hits == 0
assert mock_module.MyClass.my_method.cache_info().misses == 0
assert mock_module.MyClass.my_method.cache_info().currsize == 0
def test_handles_exception_in_cache_clear(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
    """A ``cache_clear`` that raises must not propagate out of ``_clear_lru_caches``."""

    class BrokenCache:
        def cache_clear(self) -> NoReturn:
            msg = "Cache clearing failed!"
            raise ValueError(msg)

    item = mock_item(BrokenCache())
    # Passing means the call returned without raising; there is nothing else to assert.
    pytest_loops_instance._clear_lru_caches(item)  # noqa: SLF001
def test_handles_no_cache(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
    """A plain function without ``cache_clear`` is accepted without error."""

    def no_cache_func(x: int) -> int:
        return x

    item = mock_item(no_cache_func)
    # Passing means the call returned without raising.
    pytest_loops_instance._clear_lru_caches(item)  # noqa: SLF001
def test_clears_module_level_caches_via_sys_modules(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
    """Caches of sibling functions in the item's registered module are cleared too."""
    module_name = "_cf_test_module_scan"
    source_code = """
import functools
@functools.lru_cache(maxsize=None)
def cached_a(x):
    return x + 1
@functools.lru_cache(maxsize=None)
def cached_b(x):
    return x + 2
def plain_func(x):
    return x
cached_a(1)
cached_a(1)
cached_b(2)
cached_b(2)
"""
    mock_module = create_mock_module(module_name, source_code, register=True)
    try:
        # Precondition: both caches are populated before clearing.
        assert mock_module.cached_a.cache_info().currsize == 1
        assert mock_module.cached_b.cache_info().currsize == 1

        # The item is the uncached function, so only the module scan can reach the caches.
        item = mock_item(mock_module.plain_func)
        pytest_loops_instance._clear_lru_caches(item)  # noqa: SLF001

        assert mock_module.cached_a.cache_info().currsize == 0
        assert mock_module.cached_b.cache_info().currsize == 0
    finally:
        sys.modules.pop(module_name, None)
def test_skips_protected_modules(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
    """A stdlib callable placed in the module is not collected; the user cache is."""
    module_name = "_cf_test_protected"
    source_code = """
import functools
@functools.lru_cache(maxsize=None)
def user_func(x):
    return x
"""
    mock_module = create_mock_module(module_name, source_code, register=True)
    try:
        mock_module.os_exists = os.path.exists
        item = mock_item(mock_module.user_func)
        pytest_loops_instance._clear_lru_caches(item)  # noqa: SLF001

        clearables = pytest_loops_instance._module_clearables[module_name]  # noqa: SLF001
        assert mock_module.user_func in clearables
        # NOTE(review): os.path.exists has no cache_clear, so this cannot tell
        # "protected" apart from "not clearable" - consider a cached stdlib
        # callable from a module known to be in _PROTECTED_MODULES.
        assert os.path.exists not in clearables
    finally:
        sys.modules.pop(module_name, None)
def test_caches_scan_result(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
    """A second clear for the same module reuses the stored scan and skips ``inspect.getmembers``."""
    module_name = "_cf_test_cache_reuse"
    source_code = """
import functools
@functools.lru_cache(maxsize=None)
def cached_fn(x):
    return x
"""
    mock_module = create_mock_module(module_name, source_code, register=True)
    try:
        item = mock_item(mock_module.cached_fn)

        # First call performs the scan and stores its result per module name.
        pytest_loops_instance._clear_lru_caches(item)  # noqa: SLF001
        assert module_name in pytest_loops_instance._module_clearables  # noqa: SLF001

        # Repopulate so the second clear has something observable to remove.
        mock_module.cached_fn(42)
        assert mock_module.cached_fn.cache_info().currsize == 1

        getmembers_target = "codeflash.verification.pytest_plugin.inspect.getmembers"
        with patch(getmembers_target) as mock_getmembers:
            pytest_loops_instance._clear_lru_caches(item)  # noqa: SLF001
            mock_getmembers.assert_not_called()

        assert mock_module.cached_fn.cache_info().currsize == 0
    finally:
        sys.modules.pop(module_name, None)
def test_handles_wrapped_function(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
    """Clearing via a hand-made wrapper with ``__wrapped__`` still empties the inner cache."""
    module_name = "_cf_test_wrapped"
    source_code = """
import functools
@functools.lru_cache(maxsize=None)
def inner(x):
    return x
def wrapper(x):
    return inner(x)
wrapper.__wrapped__ = inner
wrapper.__module__ = __name__
inner(1)
inner(1)
"""
    mock_module = create_mock_module(module_name, source_code, register=True)
    try:
        wrapped_item = mock_item(mock_module.wrapper)
        pytest_loops_instance._clear_lru_caches(wrapped_item)  # noqa: SLF001

        inner_info = mock_module.inner.cache_info()
        assert inner_info.currsize == 0
    finally:
        sys.modules.pop(module_name, None)
def test_handles_function_without_module(self, pytest_loops_instance: PytestLoops, mock_item: type) -> None:
    """A function whose ``__module__`` is ``None`` is accepted without error."""

    def func() -> None:
        pass

    func.__module__ = None  # type: ignore[assignment]

    moduleless_item = mock_item(func)
    # Passing means the call returned without raising.
    pytest_loops_instance._clear_lru_caches(moduleless_item)  # noqa: SLF001