perf: add frozenset fast-path for comparator type dispatch

Use O(1) frozenset membership test with type identity before falling
through to isinstance MRO traversal. Backported from codeflash-python.
This commit is contained in:
Kevin Turcios 2026-04-10 00:53:55 -05:00
parent accbab4a16
commit 4c3c6ea167

View file

@ -74,6 +74,27 @@ _DICT_KEYS_TYPE = type({}.keys())
_DICT_VALUES_TYPE = type({}.values())
_DICT_ITEMS_TYPE = type({}.items())
_IDENTITY_EQ_TYPES: frozenset[type[Any]] = frozenset(
{
int,
bool,
complex,
type(None),
type(Ellipsis),
decimal.Decimal,
set,
bytes,
bytearray,
memoryview,
frozenset,
type,
range,
slice,
OrderedDict,
types.GenericAlias,
}
)
_EQUALITY_TYPES = (
int,
bool,
@ -184,12 +205,18 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
return False
if type(orig) is not type(new):
type_obj = type(orig)
new_type_obj = type(new)
orig_type = type(orig)
if orig_type is not type(new):
# distinct type objects are created at runtime, even if the class code is exactly the same, so we can only compare the names
if type_obj.__name__ != new_type_obj.__name__ or type_obj.__qualname__ != new_type_obj.__qualname__:
if orig_type.__name__ != type(new).__name__ or orig_type.__qualname__ != type(new).__qualname__:
return False
# Fast-path: O(1) frozenset lookup for common types (avoids isinstance MRO traversal)
if orig_type in _IDENTITY_EQ_TYPES:
return orig == new
if orig_type is float:
if math.isnan(orig) and math.isnan(new):
return True
return math.isclose(orig, new)
if isinstance(orig, (list, tuple, deque, ChainMap)):
if len(orig) != len(new):
return False
@ -204,12 +231,9 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool:
return _normalize_temp_path(orig) == _normalize_temp_path(new)
return False
# enum.Enum subclasses and UnionType fall through from the frozenset fast-path
if isinstance(orig, _EQUALITY_TYPES):
return orig == new
if isinstance(orig, float):
if math.isnan(orig) and math.isnan(new):
return True
return math.isclose(orig, new)
# Handle weak references (e.g., found in torch.nn.LSTM/GRU modules)
if isinstance(orig, weakref.ref):