import array # Add import for array import ast import copy import dataclasses import datetime import decimal import re import sys import uuid import weakref from collections import ChainMap, Counter, OrderedDict, UserDict, UserList, UserString, defaultdict, deque, namedtuple from enum import Enum, Flag, IntFlag, auto from pathlib import Path import pydantic import pytest from codeflash.either import Failure, Success from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType from codeflash.verification.comparator import ( PYTEST_TEMP_PATH_PATTERN, PYTHON_TEMPFILE_PATTERN, _extract_exception_from_message, _get_wrapped_exception, _is_temp_path, _normalize_temp_path, comparator, ) from codeflash.verification.equivalence import compare_test_results def test_basic_python_objects() -> None: a = 5 b = 5 c = 6 d = None assert comparator(a, b) assert not comparator(a, c) assert not comparator(a, d) a = 5.0 b = 5.0 c = 6.0 d = None e = None assert comparator(a, b) assert not comparator(a, c) assert not comparator(a, d) assert not comparator(d, a) assert comparator(d, e) a = "Hello" b = "Hello" c = "World" assert comparator(a, b) assert not comparator(a, c) a = [1, 2, 3] b = [1, 2, 3] c = [1, 2, 4] assert comparator(a, b) assert not comparator(a, c) a = {"a": 1, "b": 2} b = {"a": 1, "b": 2} c = {"a": 1, "b": 3} d = {"c": 1, "b": 2} e = {"a": 1, "b": 2, "c": 3} assert comparator(a, b) assert not comparator(a, c) assert not comparator(a, d) assert not comparator(a, e) a = (1, 2, "str") b = (1, 2, "str") c = (1, 2, "str2") d = [1, 2, "str"] assert comparator(a, b) assert not comparator(a, c) assert not comparator(a, d) a = {1, 2, 3} b = {2, 3, 1} c = {1, 2, 4} d = {1, 2, 3, 4} assert comparator(a, b) assert not comparator(a, c) assert not comparator(a, d) a = (65).to_bytes(1, byteorder="big") b = (65).to_bytes(1, byteorder="big") c = (66).to_bytes(1, byteorder="big") assert comparator(a, b) assert not comparator(a, c) a = (65).to_bytes(2, 
byteorder="little") b = (65).to_bytes(2, byteorder="big") assert not comparator(a, b) a = bytearray([65, 64, 63]) b = bytearray([65, 64, 63]) c = bytearray([65, 64, 62]) assert comparator(a, b) assert not comparator(a, c) memoryview_a = memoryview(bytearray([65, 64, 63])) memoryview_b = memoryview(bytearray([65, 64, 63])) memoryview_c = memoryview(bytearray([65, 64, 62])) assert comparator(memoryview_a, memoryview_b) assert not comparator(memoryview_a, memoryview_c) a = frozenset([1, 2, 3]) b = frozenset([2, 3, 1]) c = frozenset([1, 2, 4]) d = frozenset([1, 2, 3, 4]) assert comparator(a, b) assert not comparator(a, c) assert not comparator(a, d) a = map b = pow c = pow d = abs assert comparator(b, c) assert not comparator(a, b) assert not comparator(c, d) a = object() b = object() c = abs assert comparator(a, b) assert not comparator(a, c) a = type([]) b = type([]) c = type({}) assert comparator(a, b) assert not comparator(a, c) def test_weakref() -> None: """Test comparator for weakref.ref objects.""" # Helper class that supports weak references and has comparable __dict__ class Holder: def __init__(self, value): self.value = value # Test weak references to the same object obj = Holder([1, 2, 3]) ref1 = weakref.ref(obj) ref2 = weakref.ref(obj) assert comparator(ref1, ref2) # Test weak references to equivalent but different objects obj1 = Holder({"key": "value"}) obj2 = Holder({"key": "value"}) ref1 = weakref.ref(obj1) ref2 = weakref.ref(obj2) assert comparator(ref1, ref2) # Test weak references to different objects obj1 = Holder([1, 2, 3]) obj2 = Holder([1, 2, 4]) ref1 = weakref.ref(obj1) ref2 = weakref.ref(obj2) assert not comparator(ref1, ref2) # Test weak references with different data obj1 = Holder([1, 2, 3]) obj2 = Holder([1, 2, 3, 4]) ref1 = weakref.ref(obj1) ref2 = weakref.ref(obj2) assert not comparator(ref1, ref2) # Test dead weak references (both dead) obj1 = Holder([1, 2, 3]) obj2 = Holder([1, 2, 3]) ref1 = weakref.ref(obj1) ref2 = weakref.ref(obj2) del 
obj1 del obj2 # Both refs are now dead, should be equal assert comparator(ref1, ref2) # Test one dead, one alive weak reference obj1 = Holder([1, 2, 3]) obj2 = Holder([1, 2, 3]) ref1 = weakref.ref(obj1) ref2 = weakref.ref(obj2) del obj1 # ref1 is dead, ref2 is alive, should not be equal assert not comparator(ref1, ref2) assert not comparator(ref2, ref1) # Test weak references to nested structures obj1 = Holder({"nested": [1, 2, {"inner": "value"}]}) obj2 = Holder({"nested": [1, 2, {"inner": "value"}]}) ref1 = weakref.ref(obj1) ref2 = weakref.ref(obj2) assert comparator(ref1, ref2) # Test weak references to nested structures with differences obj1 = Holder({"nested": [1, 2, {"inner": "value1"}]}) obj2 = Holder({"nested": [1, 2, {"inner": "value2"}]}) ref1 = weakref.ref(obj1) ref2 = weakref.ref(obj2) assert not comparator(ref1, ref2) # Test weak references in a dictionary (simulating __dict__ with weakrefs) obj1 = Holder([1, 2, 3]) obj2 = Holder([1, 2, 3]) dict1 = {"data": 42, "ref": weakref.ref(obj1)} dict2 = {"data": 42, "ref": weakref.ref(obj2)} assert comparator(dict1, dict2) # Test weak references in a dictionary with different referents obj1 = Holder([1, 2, 3]) obj2 = Holder([4, 5, 6]) dict1 = {"data": 42, "ref": weakref.ref(obj1)} dict2 = {"data": 42, "ref": weakref.ref(obj2)} assert not comparator(dict1, dict2) # Test weak references in a list obj1 = Holder({"a": 1}) obj2 = Holder({"a": 1}) list1 = [weakref.ref(obj1), "other"] list2 = [weakref.ref(obj2), "other"] assert comparator(list1, list2) def test_weakref_to_custom_objects() -> None: """Test comparator for weakref.ref to custom class instances.""" class MyClass: def __init__(self, value): self.value = value # Test weak references to equivalent custom objects obj1 = MyClass(42) obj2 = MyClass(42) ref1 = weakref.ref(obj1) ref2 = weakref.ref(obj2) assert comparator(ref1, ref2) # Test weak references to different custom objects obj1 = MyClass(42) obj2 = MyClass(99) ref1 = weakref.ref(obj1) ref2 = 
weakref.ref(obj2) assert not comparator(ref1, ref2) # Test weak references to custom objects with nested data class Container: def __init__(self, items): self.items = items obj1 = Container([1, 2, 3]) obj2 = Container([1, 2, 3]) ref1 = weakref.ref(obj1) ref2 = weakref.ref(obj2) assert comparator(ref1, ref2) obj1 = Container([1, 2, 3]) obj2 = Container([1, 2, 4]) ref1 = weakref.ref(obj1) ref2 = weakref.ref(obj2) assert not comparator(ref1, ref2) def test_weakref_with_callbacks() -> None: """Test that weakrefs with callbacks are compared correctly.""" class Holder: def __init__(self, value): self.value = value callback_called = [] def callback(ref): callback_called.append(ref) obj1 = Holder([1, 2, 3]) obj2 = Holder([1, 2, 3]) # Weakrefs with callbacks should still compare based on referents ref1 = weakref.ref(obj1, callback) ref2 = weakref.ref(obj2, callback) assert comparator(ref1, ref2) obj1 = Holder([1, 2, 3]) obj2 = Holder([4, 5, 6]) ref1 = weakref.ref(obj1, callback) ref2 = weakref.ref(obj2, callback) assert not comparator(ref1, ref2) @pytest.mark.parametrize( "r1, r2, expected", [ (range(1, 10), range(1, 10), True), # equal (range(10), range(1, 10), False), # different start (range(2, 10), range(1, 10), False), (range(1, 5), range(1, 10), False), # different stop (range(1, 20), range(1, 10), False), (range(1, 10, 1), range(1, 10, 2), False), # different step (range(1, 10, 3), range(1, 10, 2), False), (range(-5, 0), range(-5, 0), True), # negative ranges (range(-10, 0), range(-5, 0), False), (range(5, 1), range(10, 5), True), # empty ranges (range(5, 1), range(5, 1), True), (range(7), range(7), True), (range(7), range(0, 7, 1), True), (range(7), range(0, 7, 1), True), ], ) def test_ranges(r1, r2, expected): assert comparator(r1, r2) == expected def test_standard_python_library_objects() -> None: a = datetime.datetime(2020, 2, 2, 2, 2, 2) # type: ignore b = datetime.datetime(2020, 2, 2, 2, 2, 2) # type: ignore c = datetime.datetime(2020, 2, 2, 2, 2, 3) # type: 
ignore assert comparator(a, b) assert not comparator(a, c) a = datetime.date(2020, 2, 2) # type: ignore b = datetime.date(2020, 2, 2) # type: ignore c = datetime.date(2020, 2, 3) # type: ignore assert comparator(a, b) assert not comparator(a, c) a = datetime.timedelta(days=1) # type: ignore b = datetime.timedelta(days=1) # type: ignore c = datetime.timedelta(days=2) # type: ignore assert comparator(a, b) assert not comparator(a, c) a = datetime.time(2, 2, 2) # type: ignore b = datetime.time(2, 2, 2) # type: ignore c = datetime.time(2, 2, 3) # type: ignore assert comparator(a, b) assert not comparator(a, c) a = datetime.timezone.utc # type: ignore b = datetime.timezone.utc # type: ignore c = datetime.timezone(datetime.timedelta(hours=1)) # type: ignore assert comparator(a, b) assert not comparator(a, c) a = decimal.Decimal(3.14) # type: ignore b = decimal.Decimal(3.14) # type: ignore c = decimal.Decimal(3.15) # type: ignore assert comparator(a, b) assert not comparator(a, c) class Color(Flag): RED = auto() GREEN = auto() BLUE = auto() class Color2(Enum): RED = auto() GREEN = auto() BLUE = auto() a = Color.RED # type: ignore b = Color.RED # type: ignore c = Color.GREEN # type: ignore assert comparator(a, b) assert not comparator(a, c) a = Color2.RED # type: ignore b = Color2.RED # type: ignore c = Color2.GREEN # type: ignore assert comparator(a, b) assert not comparator(a, c) class Color4(IntFlag): RED = auto() GREEN = auto() BLUE = auto() a = Color4.RED # type: ignore b = Color4.RED # type: ignore c = Color4.GREEN # type: ignore assert comparator(a, b) assert not comparator(a, c) a: re.Pattern = re.compile("a") b: re.Pattern = re.compile("a") c: re.Pattern = re.compile("b") d: re.Pattern = re.compile("a", re.IGNORECASE) assert comparator(a, b) assert not comparator(a, c) assert not comparator(a, d) arr1 = array.array("i", [1, 2, 3]) arr2 = array.array("i", [1, 2, 3]) arr3 = array.array("i", [4, 5, 6]) arr4 = array.array("f", [1.0, 2.0, 3.0]) assert comparator(arr1, 
arr2) assert not comparator(arr1, arr3) assert not comparator(arr1, arr4) assert not comparator(arr1, [1, 2, 3]) empty_arr_i1 = array.array("i") empty_arr_i2 = array.array("i") empty_arr_f = array.array("f") assert comparator(empty_arr_i1, empty_arr_i2) assert not comparator(empty_arr_i1, empty_arr_f) assert not comparator(empty_arr_i1, arr1) id1 = uuid.uuid4() id3 = uuid.uuid4() assert comparator(id1, id1) assert not comparator(id1, id3) def test_itertools_count() -> None: import itertools # Equal: same start and step (default step=1) assert comparator(itertools.count(0), itertools.count(0)) assert comparator(itertools.count(5), itertools.count(5)) assert comparator(itertools.count(0, 1), itertools.count(0, 1)) assert comparator(itertools.count(10, 3), itertools.count(10, 3)) # Equal: negative start and step assert comparator(itertools.count(-5, -2), itertools.count(-5, -2)) # Equal: float start and step assert comparator(itertools.count(0.5, 0.1), itertools.count(0.5, 0.1)) # Not equal: different start assert not comparator(itertools.count(0), itertools.count(1)) assert not comparator(itertools.count(5), itertools.count(10)) # Not equal: different step assert not comparator(itertools.count(0, 1), itertools.count(0, 2)) assert not comparator(itertools.count(0, 1), itertools.count(0, -1)) # Not equal: different type assert not comparator(itertools.count(0), 0) assert not comparator(itertools.count(0), [0, 1, 2]) # Equal after partial consumption (both advanced to the same state) a = itertools.count(0) b = itertools.count(0) next(a) next(b) assert comparator(a, b) # Not equal after different consumption a = itertools.count(0) b = itertools.count(0) next(a) assert not comparator(a, b) # Works inside containers assert comparator([itertools.count(0)], [itertools.count(0)]) assert comparator({"key": itertools.count(5, 2)}, {"key": itertools.count(5, 2)}) assert not comparator([itertools.count(0)], [itertools.count(1)]) def test_itertools_repeat() -> None: import 
itertools # Equal: infinite repeat assert comparator(itertools.repeat(5), itertools.repeat(5)) assert comparator(itertools.repeat("hello"), itertools.repeat("hello")) # Equal: bounded repeat assert comparator(itertools.repeat(5, 3), itertools.repeat(5, 3)) assert comparator(itertools.repeat(None, 10), itertools.repeat(None, 10)) # Not equal: different value assert not comparator(itertools.repeat(5), itertools.repeat(6)) assert not comparator(itertools.repeat(5, 3), itertools.repeat(6, 3)) # Not equal: different count assert not comparator(itertools.repeat(5, 3), itertools.repeat(5, 4)) # Not equal: bounded vs infinite assert not comparator(itertools.repeat(5), itertools.repeat(5, 3)) # Not equal: different type assert not comparator(itertools.repeat(5), 5) assert not comparator(itertools.repeat(5), [5]) # Equal after partial consumption a = itertools.repeat(5, 5) b = itertools.repeat(5, 5) next(a) next(b) assert comparator(a, b) # Not equal after different consumption a = itertools.repeat(5, 5) b = itertools.repeat(5, 5) next(a) assert not comparator(a, b) # Works inside containers assert comparator([itertools.repeat(5, 3)], [itertools.repeat(5, 3)]) assert not comparator([itertools.repeat(5, 3)], [itertools.repeat(5, 4)]) def test_itertools_cycle() -> None: import itertools # Equal: same sequence assert comparator(itertools.cycle([1, 2, 3]), itertools.cycle([1, 2, 3])) assert comparator(itertools.cycle("abc"), itertools.cycle("abc")) # Not equal: different sequence assert not comparator(itertools.cycle([1, 2, 3]), itertools.cycle([1, 2, 4])) assert not comparator(itertools.cycle([1, 2, 3]), itertools.cycle([1, 2])) # Not equal: different type assert not comparator(itertools.cycle([1, 2, 3]), [1, 2, 3]) # Equal after same partial consumption a = itertools.cycle([1, 2, 3]) b = itertools.cycle([1, 2, 3]) next(a) next(b) assert comparator(a, b) # Not equal after different consumption a = itertools.cycle([1, 2, 3]) b = itertools.cycle([1, 2, 3]) next(a) assert not 
comparator(a, b) # Equal after consuming a full cycle a = itertools.cycle([1, 2, 3]) b = itertools.cycle([1, 2, 3]) for _ in range(3): next(a) next(b) assert comparator(a, b) # Equal at same position across different full-cycle counts a = itertools.cycle([1, 2, 3]) b = itertools.cycle([1, 2, 3]) for _ in range(4): next(a) for _ in range(7): next(b) # Both at position 1 within the cycle (4%3 == 7%3 == 1) assert comparator(a, b) # Works inside containers assert comparator([itertools.cycle([1, 2])], [itertools.cycle([1, 2])]) assert not comparator([itertools.cycle([1, 2])], [itertools.cycle([1, 3])]) def test_itertools_chain() -> None: import itertools assert comparator(itertools.chain([1, 2], [3, 4]), itertools.chain([1, 2], [3, 4])) assert not comparator(itertools.chain([1, 2], [3, 4]), itertools.chain([1, 2], [3, 5])) assert comparator(itertools.chain.from_iterable([[1, 2], [3]]), itertools.chain.from_iterable([[1, 2], [3]])) assert comparator(itertools.chain(), itertools.chain()) assert not comparator(itertools.chain([1]), itertools.chain([1, 2])) def test_itertools_islice() -> None: import itertools assert comparator(itertools.islice(range(10), 5), itertools.islice(range(10), 5)) assert not comparator(itertools.islice(range(10), 5), itertools.islice(range(10), 6)) assert comparator(itertools.islice(range(10), 2, 5), itertools.islice(range(10), 2, 5)) assert not comparator(itertools.islice(range(10), 2, 5), itertools.islice(range(10), 2, 6)) def test_itertools_product() -> None: import itertools assert comparator(itertools.product("AB", repeat=2), itertools.product("AB", repeat=2)) assert not comparator(itertools.product("AB", repeat=2), itertools.product("AC", repeat=2)) assert comparator(itertools.product([1, 2], [3, 4]), itertools.product([1, 2], [3, 4])) assert not comparator(itertools.product([1, 2], [3, 4]), itertools.product([1, 2], [3, 5])) def test_itertools_permutations_combinations() -> None: import itertools assert 
comparator(itertools.permutations("ABC", 2), itertools.permutations("ABC", 2)) assert not comparator(itertools.permutations("ABC", 2), itertools.permutations("ABD", 2)) assert comparator(itertools.combinations("ABCD", 2), itertools.combinations("ABCD", 2)) assert not comparator(itertools.combinations("ABCD", 2), itertools.combinations("ABCD", 3)) assert comparator( itertools.combinations_with_replacement("ABC", 2), itertools.combinations_with_replacement("ABC", 2) ) assert not comparator( itertools.combinations_with_replacement("ABC", 2), itertools.combinations_with_replacement("ABD", 2) ) def test_itertools_accumulate() -> None: import itertools assert comparator(itertools.accumulate([1, 2, 3, 4]), itertools.accumulate([1, 2, 3, 4])) assert not comparator(itertools.accumulate([1, 2, 3, 4]), itertools.accumulate([1, 2, 3, 5])) assert comparator(itertools.accumulate([1, 2, 3], initial=10), itertools.accumulate([1, 2, 3], initial=10)) assert not comparator(itertools.accumulate([1, 2, 3], initial=10), itertools.accumulate([1, 2, 3], initial=0)) def test_itertools_filtering() -> None: import itertools # compress assert comparator( itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]), itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]) ) assert not comparator( itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]), itertools.compress("ABCDEF", [1, 1, 1, 0, 1, 1]) ) # dropwhile assert comparator( itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]), itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]) ) assert not comparator( itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]), itertools.dropwhile(lambda x: x < 5, [1, 4, 7, 4, 1]) ) # takewhile assert comparator( itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]), itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]) ) assert not comparator( itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]), itertools.takewhile(lambda x: x < 5, [1, 3, 6, 4, 1]) ) # filterfalse assert comparator( itertools.filterfalse(lambda x: x % 2, 
range(10)), itertools.filterfalse(lambda x: x % 2, range(10)) ) def test_itertools_starmap() -> None: import itertools assert comparator( itertools.starmap(pow, [(2, 3), (3, 2), (10, 0)]), itertools.starmap(pow, [(2, 3), (3, 2), (10, 0)]) ) assert not comparator(itertools.starmap(pow, [(2, 3), (3, 2)]), itertools.starmap(pow, [(2, 3), (3, 3)])) def test_itertools_zip_longest() -> None: import itertools assert comparator( itertools.zip_longest("AB", "xyz", fillvalue="-"), itertools.zip_longest("AB", "xyz", fillvalue="-") ) assert not comparator( itertools.zip_longest("AB", "xyz", fillvalue="-"), itertools.zip_longest("AB", "xyz", fillvalue="*") ) def test_itertools_groupby() -> None: import itertools assert comparator(itertools.groupby("AAABBBCC"), itertools.groupby("AAABBBCC")) assert not comparator(itertools.groupby("AAABBBCC"), itertools.groupby("AAABBCC")) assert comparator(itertools.groupby([]), itertools.groupby([])) # With key function assert comparator( itertools.groupby([1, 1, 2, 2, 3], key=lambda x: x), itertools.groupby([1, 1, 2, 2, 3], key=lambda x: x) ) @pytest.mark.skipif(sys.version_info < (3, 10), reason="itertools.pairwise requires Python 3.10+") def test_itertools_pairwise() -> None: import itertools assert comparator(itertools.pairwise([1, 2, 3, 4]), itertools.pairwise([1, 2, 3, 4])) assert not comparator(itertools.pairwise([1, 2, 3, 4]), itertools.pairwise([1, 2, 3, 5])) @pytest.mark.skipif(sys.version_info < (3, 12), reason="itertools.batched requires Python 3.12+") def test_itertools_batched() -> None: import itertools assert comparator(itertools.batched("ABCDEFG", 3), itertools.batched("ABCDEFG", 3)) assert not comparator(itertools.batched("ABCDEFG", 3), itertools.batched("ABCDEFG", 2)) def test_itertools_in_containers() -> None: import itertools # Itertools objects nested in dicts/lists assert comparator( {"a": itertools.chain([1], [2]), "b": itertools.islice(range(5), 3)}, {"a": itertools.chain([1], [2]), "b": itertools.islice(range(5), 3)}, 
) assert not comparator([itertools.product("AB", repeat=2)], [itertools.product("AC", repeat=2)]) # Different itertools types should not match assert not comparator(itertools.chain([1, 2]), itertools.islice([1, 2], 2)) def test_numpy(): try: import numpy as np except ImportError: pytest.skip() a = np.array([1, 2, 3]) b = np.array([1, 2, 3]) c = np.array([1, 2, 4]) assert comparator(a, b) assert not comparator(a, c) d = np.array([[1, 2], [3, 4]]) e = np.array([[1, 2], [3, 4]]) f = np.array([[1, 2], [3, 5]]) assert comparator(d, e) assert not comparator(d, f) assert not comparator(a, d) g = np.array([1.0, 2.0, 3.0]) assert not comparator(a, g) h = np.float32(1.0) i = np.float32(1.0) assert comparator(h, i) j = np.float64(1.0) k = np.float64(1.0) assert not comparator(h, j) assert comparator(j, k) l = np.int32(1) m = np.int32(1) assert comparator(l, m) assert not comparator(l, h) assert not comparator(l, j) n = np.int64(1) o = np.int64(1) assert not comparator(n, l) assert comparator(n, o) p = np.uint32(1) q = np.uint32(1) assert comparator(p, q) assert not comparator(p, l) r = np.uint64(1) s = np.uint64(1) assert not comparator(r, p) assert comparator(r, s) t = np.bool_(True) u = np.bool_(True) assert comparator(t, u) assert not comparator(t, r) v = np.complex64(1.0 + 1.0j) w = np.complex64(1.0 + 1.0j) assert comparator(v, w) assert not comparator(v, t) x = np.complex128(1.0 + 1.0j) y = np.complex128(1.0 + 1.0j) assert not comparator(x, v) assert comparator(x, y) # Create numpy array with mixed type object z = np.array([1, 2, "str"], dtype=np.object_) aa = np.array([1, 2, "str"], dtype=np.object_) ab = np.array([1, 2, "str2"], dtype=np.object_) assert comparator(z, aa) assert not comparator(z, ab) ac = np.array([1, 2, "str2"]) ad = np.array([1, 2, "str2"]) assert comparator(ac, ad) # Test for numpy array with nan and inf ae = np.array([1, 2, np.nan]) af = np.array([1, 2, np.nan]) ag = np.array([1, 2, np.inf]) ah = np.array([1, 2, np.inf]) ai = np.inf aj = np.inf ak = 
np.nan al = np.nan assert comparator(ae, af) assert comparator(ag, ah) assert not comparator(ae, ag) assert not comparator(af, ah) assert comparator(ai, aj) assert comparator(ak, al) assert not comparator(ai, ak) dt = np.dtype([("name", "S10"), ("age", np.int32)]) a_struct = np.array([("Alice", 25)], dtype=dt) b_struct = np.array([("Alice", 25)], dtype=dt) c_struct = np.array([("Bob", 30)], dtype=dt) a_void = a_struct[0] b_void = b_struct[0] c_void = c_struct[0] assert isinstance(a_void, np.void) assert comparator(a_void, b_void) assert not comparator(a_void, c_void) def test_numpy_random_generator(): try: import numpy as np except ImportError: pytest.skip() # Test numpy.random.Generator (modern API) # Same seed should produce equal generators rng1 = np.random.default_rng(seed=42) rng2 = np.random.default_rng(seed=42) assert comparator(rng1, rng2) # Different seeds should produce non-equal generators rng3 = np.random.default_rng(seed=123) assert not comparator(rng1, rng3) # After generating numbers, state changes rng4 = np.random.default_rng(seed=42) rng5 = np.random.default_rng(seed=42) rng4.random() # Advance state assert not comparator(rng4, rng5) # Both advanced by same amount should be equal rng5.random() assert comparator(rng4, rng5) # Test with different bit generators from numpy.random import MT19937, PCG64 rng_pcg1 = np.random.Generator(PCG64(seed=42)) rng_pcg2 = np.random.Generator(PCG64(seed=42)) assert comparator(rng_pcg1, rng_pcg2) rng_mt1 = np.random.Generator(MT19937(seed=42)) rng_mt2 = np.random.Generator(MT19937(seed=42)) assert comparator(rng_mt1, rng_mt2) # Different bit generator types should not be equal assert not comparator(rng_pcg1, rng_mt1) def test_numpy_random_state(): try: import numpy as np except ImportError: pytest.skip() # Test numpy.random.RandomState (legacy API) # Same seed should produce equal states rs1 = np.random.RandomState(seed=42) rs2 = np.random.RandomState(seed=42) assert comparator(rs1, rs2) # Different seeds should 
produce non-equal states rs3 = np.random.RandomState(seed=123) assert not comparator(rs1, rs3) # After generating numbers, state changes rs4 = np.random.RandomState(seed=42) rs5 = np.random.RandomState(seed=42) rs4.random() # Advance state assert not comparator(rs4, rs5) # Both advanced by same amount should be equal rs5.random() assert comparator(rs4, rs5) # Test state restoration rs6 = np.random.RandomState(seed=42) state = rs6.get_state() rs6.random() # Advance state rs7 = np.random.RandomState(seed=42) rs7.set_state(state) # rs6 advanced, rs7 restored to original state assert not comparator(rs6, rs7) def test_scipy(): try: import scipy as sp # type: ignore except ImportError: pytest.skip() a = sp.sparse.csr_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 5]]) b = sp.sparse.csr_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 5]]) c = sp.sparse.csr_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 6]]) ca = sp.sparse.csr_matrix([[1, 0, 0, 0], [0, 0, 3, 0], [4, 0, 6, 0]]) assert comparator(a, b) assert not comparator(a, c) assert not comparator(c, ca) d = sp.sparse.csc_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 5]]) e = sp.sparse.csc_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 5]]) f = sp.sparse.csc_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 6]]) fa = sp.sparse.csc_matrix([[1, 0, 0, 0], [0, 0, 3, 0], [4, 0, 6, 0]]) assert comparator(d, e) assert not comparator(d, f) assert not comparator(a, d) assert not comparator(c, f) assert not comparator(f, fa) g = sp.sparse.lil_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 5]]) h = sp.sparse.lil_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 5]]) i = sp.sparse.lil_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 6]]) assert comparator(g, h) assert not comparator(g, i) assert not comparator(a, g) j = sp.sparse.dok_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 5]]) k = sp.sparse.dok_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 5]]) l = sp.sparse.dok_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 6]]) assert comparator(j, k) assert not comparator(j, l) assert not comparator(a, j) m = sp.sparse.dia_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 
5]]) n = sp.sparse.dia_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 5]]) o = sp.sparse.dia_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 6]]) assert comparator(m, n) assert not comparator(m, o) assert not comparator(a, m) p = sp.sparse.coo_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 5]]) q = sp.sparse.coo_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 5]]) r = sp.sparse.coo_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 6]]) assert comparator(p, q) assert not comparator(p, r) assert not comparator(a, p) s = sp.sparse.bsr_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 5]]) t = sp.sparse.bsr_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 5]]) u = sp.sparse.bsr_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 6]]) assert comparator(s, t) assert not comparator(s, u) assert not comparator(a, s) try: import numpy as np row = np.array([0, 3, 1, 0]) col = np.array([0, 3, 1, 2]) data = np.array([4, 5, 7, 9]) v = sp.sparse.coo_array((data, (row, col)), shape=(4, 4)).toarray() w = sp.sparse.coo_array((data, (row, col)), shape=(4, 4)).toarray() assert comparator(v, w) except ImportError: print("Should run tests with numpy installed to test more thoroughly") def test_pandas(): try: import pandas as pd except ImportError: pytest.skip() a = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) b = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) c = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 7]}) ca = pd.DataFrame({"a": [1, 2, 3, 4], "b": [4, 5, 6, 7]}) assert comparator(a, b) assert not comparator(a, c) assert not comparator(c, ca) ak = pd.DataFrame( {"a": [datetime.datetime(2020, 2, 2, 2, 2, 2), datetime.datetime(2020, 2, 2, 2, 2, 2)], "b": [4, 5]} ) al = pd.DataFrame( {"a": [datetime.datetime(2020, 2, 2, 2, 2, 2), datetime.datetime(2020, 2, 2, 2, 2, 2)], "b": [4, 5]} ) am = pd.DataFrame( {"a": [datetime.datetime(2020, 2, 2, 2, 2, 2), datetime.datetime(2020, 2, 2, 2, 2, 3)], "b": [4, 5]} ) assert comparator(ak, al) assert not comparator(ak, am) d = pd.Series([1, 2, 3]) e = pd.Series([1, 2, 3]) f = pd.Series([1, 2, 4]) assert comparator(d, e) assert not 
comparator(d, f) g = pd.Index([1, 2, 3]) h = pd.Index([1, 2, 3]) i = pd.Index([1, 2, 4]) assert comparator(g, h) assert not comparator(g, i) j = pd.MultiIndex.from_tuples([(1, 2), (3, 4)]) k = pd.MultiIndex.from_tuples([(1, 2), (3, 4)]) l = pd.MultiIndex.from_tuples([(1, 2), (3, 5)]) assert comparator(j, k) assert not comparator(j, l) m = pd.Categorical([1, 2, 3]) n = pd.Categorical([1, 2, 3]) o = pd.Categorical([1, 2, 4]) assert comparator(m, n) assert not comparator(m, o) p = pd.Interval(1, 2) q = pd.Interval(1, 2) r = pd.Interval(1, 3) assert comparator(p, q) assert not comparator(p, r) s = pd.IntervalIndex.from_tuples([(1, 2), (3, 4)]) t = pd.IntervalIndex.from_tuples([(1, 2), (3, 4)]) u = pd.IntervalIndex.from_tuples([(1, 2), (3, 5)]) assert comparator(s, t) assert not comparator(s, u) v = pd.Period("2021-01") w = pd.Period("2021-01") x = pd.Period("2021-02") assert comparator(v, w) assert not comparator(v, x) y = pd.period_range(start="2021-01", periods=3, freq="M") z = pd.period_range(start="2021-01", periods=3, freq="M") aa = pd.period_range(start="2021-01", periods=4, freq="M") assert comparator(y, z) assert not comparator(y, aa) ab = pd.Timedelta("1 days") ac = pd.Timedelta("1 days") ad = pd.Timedelta("2 days") assert comparator(ab, ac) assert not comparator(ab, ad) ae = pd.TimedeltaIndex(["1 days", "2 days"]) af = pd.TimedeltaIndex(["1 days", "2 days"]) ag = pd.TimedeltaIndex(["1 days", "3 days"]) assert comparator(ae, af) assert not comparator(ae, ag) ah = pd.Timestamp("2021-01-01") ai = pd.Timestamp("2021-01-01") aj = pd.Timestamp("2021-01-02") assert comparator(ah, ai) assert not comparator(ah, aj) # test cases for sparse pandas arrays an = pd.arrays.SparseArray([1, 2, 3]) ao = pd.arrays.SparseArray([1, 2, 3]) ap = pd.arrays.SparseArray([1, 2, 4]) assert comparator(an, ao) assert not comparator(an, ap) assert comparator(pd.NA, pd.NA) assert not comparator(pd.NA, None) assert not comparator(None, pd.NA) s1 = pd.Series([1, 2, pd.NA, 4]) s2 = 
pd.Series([1, 2, pd.NA, 4]) s3 = pd.Series([1, 2, None, 4]) assert comparator(s1, s2) assert not comparator(s1, s3) df1 = pd.DataFrame({"a": [1, 2, pd.NA], "b": [4, pd.NA, 6]}) df2 = pd.DataFrame({"a": [1, 2, pd.NA], "b": [4, pd.NA, 6]}) df3 = pd.DataFrame({"a": [1, 2, None], "b": [4, None, 6]}) assert comparator(df1, df2) assert not comparator(df1, df3) d1 = {"a": pd.NA, "b": [1, pd.NA, 3]} d2 = {"a": pd.NA, "b": [1, pd.NA, 3]} d3 = {"a": None, "b": [1, None, 3]} assert comparator(d1, d2) assert not comparator(d1, d3) s1 = pd.Series([1, 2, pd.NA, 4]) s2 = pd.Series([1, 2, pd.NA, 4]) filtered1 = s1[s1 > 1] filtered2 = s2[s2 > 1] assert comparator(filtered1, filtered2)


def test_pyarrow():
    """Test comparator on PyArrow containers and metadata objects.

    Covers Table, RecordBatch, Array (including nulls, strings, structs and
    lists), ChunkedArray, Scalar, Schema, Field and DataType. Each section
    builds two equal objects plus variants that differ in values, length,
    column/field name or type, and expects only the equal pair to match.
    """
    try:
        import pyarrow as pa
    except ImportError:
        # pyarrow is optional: skip rather than fail when it is not installed.
        pytest.skip()

    # Test PyArrow Table
    table1 = pa.table({"a": [1, 2, 3], "b": [4, 5, 6]})
    table2 = pa.table({"a": [1, 2, 3], "b": [4, 5, 6]})
    table3 = pa.table({"a": [1, 2, 3], "b": [4, 5, 7]})
    table4 = pa.table({"a": [1, 2, 3, 4], "b": [4, 5, 6, 7]})
    table5 = pa.table({"a": [1, 2, 3], "c": [4, 5, 6]})  # different column name

    assert comparator(table1, table2)
    assert not comparator(table1, table3)
    assert not comparator(table1, table4)
    assert not comparator(table1, table5)

    # Test PyArrow RecordBatch
    batch1 = pa.RecordBatch.from_pydict({"x": [1, 2], "y": [3.0, 4.0]})
    batch2 = pa.RecordBatch.from_pydict({"x": [1, 2], "y": [3.0, 4.0]})
    batch3 = pa.RecordBatch.from_pydict({"x": [1, 2], "y": [3.0, 5.0]})
    batch4 = pa.RecordBatch.from_pydict({"x": [1, 2, 3], "y": [3.0, 4.0, 5.0]})

    assert comparator(batch1, batch2)
    assert not comparator(batch1, batch3)
    assert not comparator(batch1, batch4)

    # Test PyArrow Array
    arr1 = pa.array([1, 2, 3])
    arr2 = pa.array([1, 2, 3])
    arr3 = pa.array([1, 2, 4])
    arr4 = pa.array([1, 2, 3, 4])
    arr5 = pa.array([1.0, 2.0, 3.0])  # different type

    assert comparator(arr1, arr2)
    assert not comparator(arr1, arr3)
    assert not comparator(arr1, arr4)
    assert not comparator(arr1, arr5)

    # Test PyArrow Array with nulls
    arr_null1 = pa.array([1, None, 3])
    arr_null2 = pa.array([1, None, 3])
    arr_null3 = pa.array([1, 2, 3])

    assert comparator(arr_null1, arr_null2)
    assert not comparator(arr_null1, arr_null3)

    # Test PyArrow ChunkedArray
    # chunked4 holds the same kind of data split at a different chunk boundary
    # and with different values; it must not match chunked1.
    chunked1 = pa.chunked_array([[1, 2], [3, 4]])
    chunked2 = pa.chunked_array([[1, 2], [3, 4]])
    chunked3 = pa.chunked_array([[1, 2], [3, 5]])
    chunked4 = pa.chunked_array([[1, 2, 3], [4, 5]])

    assert comparator(chunked1, chunked2)
    assert not comparator(chunked1, chunked3)
    assert not comparator(chunked1, chunked4)

    # Test PyArrow Scalar
    scalar1 = pa.scalar(42)
    scalar2 = pa.scalar(42)
    scalar3 = pa.scalar(43)
    scalar4 = pa.scalar(42.0)  # different type

    assert comparator(scalar1, scalar2)
    assert not comparator(scalar1, scalar3)
    assert not comparator(scalar1, scalar4)

    # Test null scalars
    # Two nulls are only equal when their declared types also match.
    null_scalar1 = pa.scalar(None, type=pa.int64())
    null_scalar2 = pa.scalar(None, type=pa.int64())
    null_scalar3 = pa.scalar(None, type=pa.float64())

    assert comparator(null_scalar1, null_scalar2)
    assert not comparator(null_scalar1, null_scalar3)

    # Test PyArrow Schema
    schema1 = pa.schema([("a", pa.int64()), ("b", pa.float64())])
    schema2 = pa.schema([("a", pa.int64()), ("b", pa.float64())])
    schema3 = pa.schema([("a", pa.int64()), ("c", pa.float64())])
    schema4 = pa.schema([("a", pa.int32()), ("b", pa.float64())])

    assert comparator(schema1, schema2)
    assert not comparator(schema1, schema3)
    assert not comparator(schema1, schema4)

    # Test PyArrow Field
    field1 = pa.field("name", pa.int64())
    field2 = pa.field("name", pa.int64())
    field3 = pa.field("other", pa.int64())
    field4 = pa.field("name", pa.float64())

    assert comparator(field1, field2)
    assert not comparator(field1, field3)
    assert not comparator(field1, field4)

    # Test PyArrow DataType
    type1 = pa.int64()
    type2 = pa.int64()
    type3 = pa.int32()
    type4 = pa.float64()

    assert comparator(type1, type2)
    assert not comparator(type1, type3)
    assert not comparator(type1, type4)

    # Test string arrays
    str_arr1 = pa.array(["hello", "world"])
    str_arr2 = pa.array(["hello", "world"])
    str_arr3 = pa.array(["hello", "there"])

    assert comparator(str_arr1, str_arr2)
    assert not comparator(str_arr1, str_arr3)

    # Test nested types (struct)
    struct_arr1 = pa.array([{"x": 1, "y": 2}, {"x": 3, "y": 4}])
    struct_arr2 = pa.array([{"x": 1, "y": 2}, {"x": 3, "y": 4}])
    struct_arr3 = pa.array([{"x": 1, "y": 2}, {"x": 3, "y": 5}])

    assert comparator(struct_arr1, struct_arr2)
    assert not comparator(struct_arr1, struct_arr3)

    # Test list arrays
    list_arr1 = pa.array([[1, 2], [3, 4, 5]])
    list_arr2 = pa.array([[1, 2], [3, 4, 5]])
    list_arr3 = pa.array([[1, 2], [3, 4, 6]])

    assert comparator(list_arr1, list_arr2)
    assert not comparator(list_arr1, list_arr3)


def test_pyrsistent():
    """Test comparator on pyrsistent persistent collections and record types.

    Covers pmap, pvector, pset, PRecord, PClass, pdeque and PBag. Each section
    expects two equally-constructed objects to match and a variant with one
    differing element or field to not match.
    """
    try:
        from pyrsistent import PBag, PClass, PRecord, field, pdeque, pmap, pset, pvector  # type: ignore
    except ImportError:
        # pyrsistent is optional: skip rather than fail when it is not installed.
        pytest.skip()

    a = pmap({"a": 1, "b": 2})
    b = pmap({"a": 1, "b": 2})
    c = pmap({"a": 1, "b": 3})
    assert comparator(a, b)
    assert not comparator(a, c)

    d = pvector([1, 2, 3])
    e = pvector([1, 2, 3])
    f = pvector([1, 2, 4])
    assert comparator(d, e)
    assert not comparator(d, f)

    # Sets are unordered, so a different insertion order must still match.
    g = pset([1, 2, 3])
    h = pset([2, 3, 1])
    i = pset([1, 2, 4])
    assert comparator(g, h)
    assert not comparator(g, i)

    class TestRecord(PRecord):
        a = field()
        b = field()

    # j and k leave both fields unset; l sets them, so it must differ.
    j = TestRecord()
    k = TestRecord()
    l = TestRecord(a=2, b=3)
    assert comparator(j, k)
    assert not comparator(j, l)

    class TestClass(PClass):
        a = field()
        b = field()

    m = TestClass()
    n = TestClass()
    o = TestClass(a=1, b=3)
    assert comparator(m, n)
    assert not comparator(m, o)

    # The second pdeque argument is the maxlen.
    p = pdeque([1, 2, 3], 3)
    q = pdeque([1, 2, 3], 3)
    r = pdeque([1, 2, 4], 3)
    assert comparator(p, q)
    assert not comparator(p, r)

    # NOTE(review): this calls the PBag class directly with a list instead of
    # the public pbag() factory; confirm that exercising the raw constructor
    # is intended.
    s = PBag([1, 2, 3])
    t = PBag([1, 2, 3])
    u = PBag([1, 2, 4])
    assert comparator(s, t)
    assert not comparator(s, u)

    v = pvector([1, 2, 3])
    w = pvector([1, 2, 3])
    x = pvector([1, 2, 4])
    assert comparator(v, w)
    assert not comparator(v, x)


def test_torch_dtype(): try: import torch # type: ignore except ImportError: pytest.skip() # Test torch.dtype comparisons a =
torch.float32 b = torch.float32 c = torch.float64 d = torch.int32 assert comparator(a, b) assert not comparator(a, c) assert not comparator(a, d) # Test different dtype categories e = torch.int64 f = torch.int64 g = torch.int32 assert comparator(e, f) assert not comparator(e, g) # Test complex dtypes h = torch.complex64 i = torch.complex64 j = torch.complex128 assert comparator(h, i) assert not comparator(h, j) # Test bool dtype k = torch.bool l = torch.bool m = torch.int8 assert comparator(k, l) assert not comparator(k, m)


def test_torch():
    """Test comparator on torch tensors.

    Covers 1-D/2-D/3-D values, dtype and shape mismatches, empty tensors,
    NaN and +/-inf handling, CUDA vs CPU placement (only when a GPU is
    present), requires_grad, complex and boolean tensors.
    """
    try:
        import torch  # type: ignore
    except ImportError:
        # torch is optional: skip rather than fail when it is not installed.
        pytest.skip()

    a = torch.tensor([1, 2, 3])
    b = torch.tensor([1, 2, 3])
    c = torch.tensor([1, 2, 4])
    assert comparator(a, b)
    assert not comparator(a, c)

    d = torch.tensor([[1, 2, 3], [4, 5, 6]])
    e = torch.tensor([[1, 2, 3], [4, 5, 6]])
    f = torch.tensor([[1, 2, 3], [4, 5, 7]])
    assert comparator(d, e)
    assert not comparator(d, f)

    # Test tensors with different data types
    g = torch.tensor([1, 2, 3], dtype=torch.float32)
    h = torch.tensor([1, 2, 3], dtype=torch.float32)
    i = torch.tensor([1, 2, 3], dtype=torch.int64)
    assert comparator(g, h)
    assert not comparator(g, i)

    # Test 3D tensors
    j = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    k = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    l = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 9]]])
    assert comparator(j, k)
    assert not comparator(j, l)

    # Test tensors with different shapes
    # Same elements, but shape (3,) vs (1, 3) must not match.
    m = torch.tensor([1, 2, 3])
    n = torch.tensor([[1, 2, 3]])
    assert not comparator(m, n)

    # Test empty tensors
    o = torch.tensor([])
    p = torch.tensor([])
    q = torch.tensor([1])
    assert comparator(o, p)
    assert not comparator(o, q)

    # Test tensors with NaN values
    r = torch.tensor([1.0, float("nan"), 3.0])
    s = torch.tensor([1.0, float("nan"), 3.0])
    t = torch.tensor([1.0, 2.0, 3.0])
    assert comparator(r, s)  # NaN == NaN
    assert not comparator(r, t)

    # Test tensors with infinity values
    u = torch.tensor([1.0, float("inf"), 3.0])
    v = torch.tensor([1.0, float("inf"), 3.0])
    w = torch.tensor([1.0, float("-inf"), 3.0])
    assert comparator(u, v)
    assert not comparator(u, w)

    # Test tensors with different devices (if CUDA is available)
    # This section is silently skipped on CPU-only machines.
    if torch.cuda.is_available():
        x = torch.tensor([1, 2, 3]).cuda()
        y = torch.tensor([1, 2, 3]).cuda()
        z = torch.tensor([1, 2, 3])
        assert comparator(x, y)
        assert not comparator(x, z)

    # Test tensors with requires_grad
    # Identical values that differ only in requires_grad are expected NOT to match.
    aa = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    bb = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    cc = torch.tensor([1.0, 2.0, 3.0], requires_grad=False)
    assert comparator(aa, bb)
    assert not comparator(aa, cc)

    # Test complex tensors
    dd = torch.tensor([1 + 2j, 3 + 4j])
    ee = torch.tensor([1 + 2j, 3 + 4j])
    ff = torch.tensor([1 + 2j, 3 + 5j])
    assert comparator(dd, ee)
    assert not comparator(dd, ff)

    # Test boolean tensors
    gg = torch.tensor([True, False, True])
    hh = torch.tensor([True, False, True])
    ii = torch.tensor([True, True, True])
    assert comparator(gg, hh)
    assert not comparator(gg, ii)


def test_torch_device(): try: import torch # type: ignore except ImportError: pytest.skip() # Test torch.device comparisons - same device type a = torch.device("cpu") b = torch.device("cpu") assert comparator(a, b) # Test different device types c = torch.device("cpu") d = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") if torch.cuda.is_available(): assert not comparator(c, d) # Test device with index e = torch.device("cpu") f = torch.device("cpu") assert comparator(e, f) # Test cuda devices with different indices (if multiple GPUs available) if torch.cuda.is_available() and torch.cuda.device_count() > 1: g = torch.device("cuda:0") h = torch.device("cuda:0") i = torch.device("cuda:1") assert comparator(g, h) assert not comparator(g, i) # Test cuda device with and without explicit index if torch.cuda.is_available(): j = torch.device("cuda:0") k = torch.device("cuda", 0) assert comparator(j, k) # Test meta device l = torch.device("meta") m = torch.device("meta") n = torch.device("cpu")
assert comparator(l, m) assert not comparator(l, n)


def test_torch_nn_linear():
    """Test comparator for torch.nn.Linear modules.

    Modules are expected to match only when weights, shape (in/out features),
    bias presence and train/eval mode all agree.
    """
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    # Test identical Linear layers
    # Re-seeding with the same value before each construction gives both
    # modules the same random initial weights.
    torch.manual_seed(42)
    a = nn.Linear(10, 5)
    torch.manual_seed(42)
    b = nn.Linear(10, 5)
    assert comparator(a, b)

    # Test Linear layers with different weights (different seeds)
    torch.manual_seed(42)
    c = nn.Linear(10, 5)
    torch.manual_seed(123)
    d = nn.Linear(10, 5)
    assert not comparator(c, d)

    # Test Linear layers with different in_features
    torch.manual_seed(42)
    e = nn.Linear(10, 5)
    torch.manual_seed(42)
    f = nn.Linear(20, 5)
    assert not comparator(e, f)

    # Test Linear layers with different out_features
    torch.manual_seed(42)
    g = nn.Linear(10, 5)
    torch.manual_seed(42)
    h = nn.Linear(10, 10)
    assert not comparator(g, h)

    # Test Linear with and without bias
    torch.manual_seed(42)
    i = nn.Linear(10, 5, bias=True)
    torch.manual_seed(42)
    j = nn.Linear(10, 5, bias=False)
    assert not comparator(i, j)

    # Test Linear layers in train vs eval mode
    # Same weights; only the `training` flag differs.
    torch.manual_seed(42)
    k = nn.Linear(10, 5)
    k.train()
    torch.manual_seed(42)
    l = nn.Linear(10, 5)
    l.eval()
    assert not comparator(k, l)


def test_torch_nn_conv2d():
    """Test comparator for torch.nn.Conv2d modules.

    Expects a mismatch whenever weights, channel counts, kernel_size, stride
    or padding differ.
    """
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    # Test identical Conv2d layers
    torch.manual_seed(42)
    a = nn.Conv2d(3, 16, kernel_size=3)
    torch.manual_seed(42)
    b = nn.Conv2d(3, 16, kernel_size=3)
    assert comparator(a, b)

    # Test Conv2d with different weights
    torch.manual_seed(42)
    c = nn.Conv2d(3, 16, kernel_size=3)
    torch.manual_seed(123)
    d = nn.Conv2d(3, 16, kernel_size=3)
    assert not comparator(c, d)

    # Test Conv2d with different in_channels
    torch.manual_seed(42)
    e = nn.Conv2d(3, 16, kernel_size=3)
    torch.manual_seed(42)
    f = nn.Conv2d(1, 16, kernel_size=3)
    assert not comparator(e, f)

    # Test Conv2d with different out_channels
    torch.manual_seed(42)
    g = nn.Conv2d(3, 16, kernel_size=3)
    torch.manual_seed(42)
    h = nn.Conv2d(3, 32, kernel_size=3)
    assert not comparator(g, h)

    # Test Conv2d with different kernel_size
    torch.manual_seed(42)
    i = nn.Conv2d(3, 16, kernel_size=3)
    torch.manual_seed(42)
    j = nn.Conv2d(3, 16, kernel_size=5)
    assert not comparator(i, j)

    # Test Conv2d with different stride
    # Weight tensors have the same shape here; only the stride hyperparameter differs.
    torch.manual_seed(42)
    k = nn.Conv2d(3, 16, kernel_size=3, stride=1)
    torch.manual_seed(42)
    l = nn.Conv2d(3, 16, kernel_size=3, stride=2)
    assert not comparator(k, l)

    # Test Conv2d with different padding
    torch.manual_seed(42)
    m = nn.Conv2d(3, 16, kernel_size=3, padding=0)
    torch.manual_seed(42)
    n = nn.Conv2d(3, 16, kernel_size=3, padding=1)
    assert not comparator(m, n)


def test_torch_nn_batchnorm():
    """Test comparator for torch.nn.BatchNorm modules.

    Covers hyperparameters (num_features, eps, momentum, affine) and the
    running statistics buffers that change after forward passes in train mode.
    """
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    # Test identical BatchNorm2d layers
    torch.manual_seed(42)
    a = nn.BatchNorm2d(16)
    torch.manual_seed(42)
    b = nn.BatchNorm2d(16)
    assert comparator(a, b)

    # Test BatchNorm2d with different num_features
    torch.manual_seed(42)
    c = nn.BatchNorm2d(16)
    torch.manual_seed(42)
    d = nn.BatchNorm2d(32)
    assert not comparator(c, d)

    # Test BatchNorm2d with different eps
    torch.manual_seed(42)
    e = nn.BatchNorm2d(16, eps=1e-5)
    torch.manual_seed(42)
    f = nn.BatchNorm2d(16, eps=1e-3)
    assert not comparator(e, f)

    # Test BatchNorm2d with different momentum
    torch.manual_seed(42)
    g = nn.BatchNorm2d(16, momentum=0.1)
    torch.manual_seed(42)
    h = nn.BatchNorm2d(16, momentum=0.01)
    assert not comparator(g, h)

    # Test BatchNorm2d with and without affine
    torch.manual_seed(42)
    i = nn.BatchNorm2d(16, affine=True)
    torch.manual_seed(42)
    j = nn.BatchNorm2d(16, affine=False)
    assert not comparator(i, j)

    # Test BatchNorm2d running stats after forward passes
    # In train mode a forward pass updates running_mean / running_var, so the
    # input must come from the same RNG state for the buffers to agree.
    torch.manual_seed(42)
    k = nn.BatchNorm2d(16)
    k.train()
    input_k = torch.randn(4, 16, 8, 8)
    _ = k(input_k)

    torch.manual_seed(42)
    l = nn.BatchNorm2d(16)
    l.train()
    input_l = torch.randn(4, 16, 8, 8)
    _ = l(input_l)

    # Same seed means same running stats
    assert comparator(k, l)

    # Test BatchNorm2d with different running stats
    # Same construction seed, but the inputs are drawn under different seeds,
    # so the running statistics diverge.
    torch.manual_seed(42)
    m = nn.BatchNorm2d(16)
    m.train()
    torch.manual_seed(42)
    _ = m(torch.randn(4, 16, 8, 8))

    torch.manual_seed(42)
    n = nn.BatchNorm2d(16)
    n.train()
    torch.manual_seed(123)
    _ = n(torch.randn(4, 16, 8, 8))
    assert not comparator(m, n)

    # Test BatchNorm1d
    torch.manual_seed(42)
    o = nn.BatchNorm1d(16)
    torch.manual_seed(42)
    p = nn.BatchNorm1d(16)
    assert comparator(o, p)


def test_torch_nn_dropout():
    """Test comparator for torch.nn.Dropout modules.

    Expects matches only when the class, p and inplace all agree.
    """
    try:
        # `torch` is not referenced below; importing it doubles as the
        # availability check.
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    # Test identical Dropout layers
    a = nn.Dropout(p=0.5)
    b = nn.Dropout(p=0.5)
    assert comparator(a, b)

    # Test Dropout with different p values
    c = nn.Dropout(p=0.5)
    d = nn.Dropout(p=0.3)
    assert not comparator(c, d)

    # Test Dropout with different inplace values
    e = nn.Dropout(p=0.5, inplace=False)
    f = nn.Dropout(p=0.5, inplace=True)
    assert not comparator(e, f)

    # Test Dropout2d
    g = nn.Dropout2d(p=0.5)
    h = nn.Dropout2d(p=0.5)
    assert comparator(g, h)

    # Test Dropout vs Dropout2d (different types)
    i = nn.Dropout(p=0.5)
    j = nn.Dropout2d(p=0.5)
    assert not comparator(i, j)


def test_torch_nn_activation(): """Test comparator for torch.nn activation modules.""" try: import torch from torch import nn except ImportError: pytest.skip() # Test ReLU a = nn.ReLU() b = nn.ReLU() assert comparator(a, b) # Test ReLU with different inplace c = nn.ReLU(inplace=False) d = nn.ReLU(inplace=True) assert not comparator(c, d) # Test LeakyReLU e = nn.LeakyReLU(negative_slope=0.01) f = nn.LeakyReLU(negative_slope=0.01) assert comparator(e, f) # Test LeakyReLU with different negative_slope g = nn.LeakyReLU(negative_slope=0.01) h = nn.LeakyReLU(negative_slope=0.1) assert not comparator(g, h) # Test Sigmoid vs ReLU (different types) i = nn.Sigmoid() j = nn.ReLU() assert not comparator(i, j) # Test GELU k = nn.GELU() l = nn.GELU() assert comparator(k, l) # Test Softmax m = nn.Softmax(dim=1) n =
nn.Softmax(dim=1) assert comparator(m, n) # Test Softmax with different dim o = nn.Softmax(dim=1) p = nn.Softmax(dim=0) assert not comparator(o, p)


def test_torch_nn_pooling():
    """Test comparator for torch.nn pooling modules.

    Pooling layers have no learned weights, so these cases exercise only
    class identity and hyperparameters (kernel_size, stride, output_size).
    """
    try:
        # `torch` is not referenced below; importing it doubles as the
        # availability check.
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    # Test MaxPool2d
    a = nn.MaxPool2d(kernel_size=2)
    b = nn.MaxPool2d(kernel_size=2)
    assert comparator(a, b)

    # Test MaxPool2d with different kernel_size
    c = nn.MaxPool2d(kernel_size=2)
    d = nn.MaxPool2d(kernel_size=3)
    assert not comparator(c, d)

    # Test MaxPool2d with different stride
    e = nn.MaxPool2d(kernel_size=2, stride=2)
    f = nn.MaxPool2d(kernel_size=2, stride=1)
    assert not comparator(e, f)

    # Test AvgPool2d
    g = nn.AvgPool2d(kernel_size=2)
    h = nn.AvgPool2d(kernel_size=2)
    assert comparator(g, h)

    # Test MaxPool2d vs AvgPool2d (different types)
    i = nn.MaxPool2d(kernel_size=2)
    j = nn.AvgPool2d(kernel_size=2)
    assert not comparator(i, j)

    # Test AdaptiveAvgPool2d
    k = nn.AdaptiveAvgPool2d(output_size=(1, 1))
    l = nn.AdaptiveAvgPool2d(output_size=(1, 1))
    assert comparator(k, l)

    # Test AdaptiveAvgPool2d with different output_size
    m = nn.AdaptiveAvgPool2d(output_size=(1, 1))
    n = nn.AdaptiveAvgPool2d(output_size=(2, 2))
    assert not comparator(m, n)


def test_torch_nn_embedding():
    """Test comparator for torch.nn.Embedding modules.

    Expects a mismatch when weights, table size, embedding_dim or
    padding_idx differ (including padding_idx set vs unset).
    """
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    # Test identical Embedding layers
    # Same seed before each construction gives identical random weight tables.
    torch.manual_seed(42)
    a = nn.Embedding(1000, 128)
    torch.manual_seed(42)
    b = nn.Embedding(1000, 128)
    assert comparator(a, b)

    # Test Embedding with different weights
    torch.manual_seed(42)
    c = nn.Embedding(1000, 128)
    torch.manual_seed(123)
    d = nn.Embedding(1000, 128)
    assert not comparator(c, d)

    # Test Embedding with different num_embeddings
    torch.manual_seed(42)
    e = nn.Embedding(1000, 128)
    torch.manual_seed(42)
    f = nn.Embedding(2000, 128)
    assert not comparator(e, f)

    # Test Embedding with different embedding_dim
    torch.manual_seed(42)
    g = nn.Embedding(1000, 128)
    torch.manual_seed(42)
    h = nn.Embedding(1000, 256)
    assert not comparator(g, h)

    # Test Embedding with different padding_idx
    torch.manual_seed(42)
    i = nn.Embedding(1000, 128, padding_idx=0)
    torch.manual_seed(42)
    j = nn.Embedding(1000, 128, padding_idx=1)
    assert not comparator(i, j)

    # Test Embedding with and without padding_idx
    torch.manual_seed(42)
    k = nn.Embedding(1000, 128)
    torch.manual_seed(42)
    l = nn.Embedding(1000, 128, padding_idx=0)
    assert not comparator(k, l)


def test_torch_nn_lstm():
    """Test comparator for torch.nn.LSTM modules.

    Covers weights and each structural option: input_size, hidden_size,
    num_layers, bidirectional and batch_first.
    """
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    # Test identical LSTM layers
    torch.manual_seed(42)
    a = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
    torch.manual_seed(42)
    b = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
    assert comparator(a, b)

    # Test LSTM with different weights
    torch.manual_seed(42)
    c = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
    torch.manual_seed(123)
    d = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
    assert not comparator(c, d)

    # Test LSTM with different input_size
    torch.manual_seed(42)
    e = nn.LSTM(input_size=10, hidden_size=20)
    torch.manual_seed(42)
    f = nn.LSTM(input_size=20, hidden_size=20)
    assert not comparator(e, f)

    # Test LSTM with different hidden_size
    torch.manual_seed(42)
    g = nn.LSTM(input_size=10, hidden_size=20)
    torch.manual_seed(42)
    h = nn.LSTM(input_size=10, hidden_size=40)
    assert not comparator(g, h)

    # Test LSTM with different num_layers
    torch.manual_seed(42)
    i = nn.LSTM(input_size=10, hidden_size=20, num_layers=1)
    torch.manual_seed(42)
    j = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)
    assert not comparator(i, j)

    # Test LSTM with different bidirectional
    torch.manual_seed(42)
    k = nn.LSTM(input_size=10, hidden_size=20, bidirectional=False)
    torch.manual_seed(42)
    l = nn.LSTM(input_size=10, hidden_size=20, bidirectional=True)
    assert not comparator(k, l)

    # Test LSTM with different batch_first
    # batch_first changes no weights; only the stored flag differs.
    torch.manual_seed(42)
    m = nn.LSTM(input_size=10, hidden_size=20, batch_first=False)
    torch.manual_seed(42)
    n = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
    assert not comparator(m, n)


def test_torch_nn_gru():
    """Test comparator for torch.nn.GRU modules, including GRU vs LSTM."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    # Test identical GRU layers
    torch.manual_seed(42)
    a = nn.GRU(input_size=10, hidden_size=20, num_layers=2)
    torch.manual_seed(42)
    b = nn.GRU(input_size=10, hidden_size=20, num_layers=2)
    assert comparator(a, b)

    # Test GRU with different hidden_size
    torch.manual_seed(42)
    c = nn.GRU(input_size=10, hidden_size=20)
    torch.manual_seed(42)
    d = nn.GRU(input_size=10, hidden_size=40)
    assert not comparator(c, d)

    # Test GRU vs LSTM (different types)
    torch.manual_seed(42)
    e = nn.GRU(input_size=10, hidden_size=20)
    torch.manual_seed(42)
    f = nn.LSTM(input_size=10, hidden_size=20)
    assert not comparator(e, f)


def test_torch_nn_layernorm():
    """Test comparator for torch.nn.LayerNorm modules.

    Expects a mismatch when normalized_shape, eps or elementwise_affine differ.
    """
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    # Test identical LayerNorm layers
    torch.manual_seed(42)
    a = nn.LayerNorm(normalized_shape=[10])
    torch.manual_seed(42)
    b = nn.LayerNorm(normalized_shape=[10])
    assert comparator(a, b)

    # Test LayerNorm with different normalized_shape
    torch.manual_seed(42)
    c = nn.LayerNorm(normalized_shape=[10])
    torch.manual_seed(42)
    d = nn.LayerNorm(normalized_shape=[20])
    assert not comparator(c, d)

    # Test LayerNorm with different eps
    torch.manual_seed(42)
    e = nn.LayerNorm(normalized_shape=[10], eps=1e-5)
    torch.manual_seed(42)
    f = nn.LayerNorm(normalized_shape=[10], eps=1e-3)
    assert not comparator(e, f)

    # Test LayerNorm with and without elementwise_affine
    torch.manual_seed(42)
    g = nn.LayerNorm(normalized_shape=[10], elementwise_affine=True)
    torch.manual_seed(42)
    h = nn.LayerNorm(normalized_shape=[10], elementwise_affine=False)
    assert not comparator(g, h)


def test_torch_nn_multihead_attention():
    """Test comparator for torch.nn.MultiheadAttention modules.

    Expects a mismatch when weights, embed_dim, num_heads or dropout differ.
    """
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    # Test identical MultiheadAttention layers
    torch.manual_seed(42)
    a = nn.MultiheadAttention(embed_dim=64, num_heads=8)
    torch.manual_seed(42)
    b = nn.MultiheadAttention(embed_dim=64, num_heads=8)
    assert comparator(a, b)

    # Test MultiheadAttention with different weights
    torch.manual_seed(42)
    c = nn.MultiheadAttention(embed_dim=64, num_heads=8)
    torch.manual_seed(123)
    d = nn.MultiheadAttention(embed_dim=64, num_heads=8)
    assert not comparator(c, d)

    # Test MultiheadAttention with different embed_dim
    torch.manual_seed(42)
    e = nn.MultiheadAttention(embed_dim=64, num_heads=8)
    torch.manual_seed(42)
    f = nn.MultiheadAttention(embed_dim=128, num_heads=8)
    assert not comparator(e, f)

    # Test MultiheadAttention with different num_heads
    torch.manual_seed(42)
    g = nn.MultiheadAttention(embed_dim=64, num_heads=8)
    torch.manual_seed(42)
    h = nn.MultiheadAttention(embed_dim=64, num_heads=4)
    assert not comparator(g, h)

    # Test MultiheadAttention with different dropout
    torch.manual_seed(42)
    i = nn.MultiheadAttention(embed_dim=64, num_heads=8, dropout=0.0)
    torch.manual_seed(42)
    j = nn.MultiheadAttention(embed_dim=64, num_heads=8, dropout=0.1)
    assert not comparator(i, j)


def test_torch_nn_sequential(): """Test comparator for torch.nn.Sequential modules.""" try: import torch from torch import nn except ImportError: pytest.skip() # Test identical Sequential modules torch.manual_seed(42) a = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5)) torch.manual_seed(42) b = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5)) assert comparator(a, b) # Test Sequential with different weights torch.manual_seed(42) c = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5)) torch.manual_seed(123) d = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5)) assert not comparator(c, d) # Test Sequential with different number of layers torch.manual_seed(42) e = nn.Sequential(nn.Linear(10, 20), nn.ReLU())
torch.manual_seed(42) f = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5)) assert not comparator(e, f) # Test Sequential with different layer types torch.manual_seed(42) g = nn.Sequential(nn.Linear(10, 20), nn.ReLU()) torch.manual_seed(42) h = nn.Sequential(nn.Linear(10, 20), nn.Sigmoid()) assert not comparator(g, h)


def test_torch_nn_modulelist():
    """Test comparator for torch.nn.ModuleList modules, including length mismatch."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    # Test identical ModuleList
    # The seed is set once; all three Linear layers then draw their weights
    # in the same order for both lists.
    torch.manual_seed(42)
    a = nn.ModuleList([nn.Linear(10, 10) for _ in range(3)])
    torch.manual_seed(42)
    b = nn.ModuleList([nn.Linear(10, 10) for _ in range(3)])
    assert comparator(a, b)

    # Test ModuleList with different number of modules
    torch.manual_seed(42)
    c = nn.ModuleList([nn.Linear(10, 10) for _ in range(3)])
    torch.manual_seed(42)
    d = nn.ModuleList([nn.Linear(10, 10) for _ in range(4)])
    assert not comparator(c, d)


def test_torch_nn_moduledict():
    """Test comparator for torch.nn.ModuleDict modules.

    The second case uses identical seeds and layers, so only the keys differ.
    """
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    # Test identical ModuleDict
    torch.manual_seed(42)
    a = nn.ModuleDict({"fc1": nn.Linear(10, 20), "fc2": nn.Linear(20, 5)})
    torch.manual_seed(42)
    b = nn.ModuleDict({"fc1": nn.Linear(10, 20), "fc2": nn.Linear(20, 5)})
    assert comparator(a, b)

    # Test ModuleDict with different keys
    torch.manual_seed(42)
    c = nn.ModuleDict({"fc1": nn.Linear(10, 20), "fc2": nn.Linear(20, 5)})
    torch.manual_seed(42)
    d = nn.ModuleDict({"layer1": nn.Linear(10, 20), "layer2": nn.Linear(20, 5)})
    assert not comparator(c, d)


def test_torch_nn_custom_module():
    """Test comparator for custom torch.nn.Module subclasses.

    Uses a small two-layer network whose hidden size is a constructor
    argument, to check that both weight values and architecture are compared.
    """
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    class SimpleNet(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.fc1 = nn.Linear(10, hidden_size)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(hidden_size, 5)

        def forward(self, x):
            x = self.fc1(x)
            x = self.relu(x)
            x = self.fc2(x)
            return x

    # Test identical custom modules
    torch.manual_seed(42)
    a = SimpleNet(hidden_size=20)
    torch.manual_seed(42)
    b = SimpleNet(hidden_size=20)
    assert comparator(a, b)

    # Test custom modules with different weights
    torch.manual_seed(42)
    c = SimpleNet(hidden_size=20)
    torch.manual_seed(123)
    d = SimpleNet(hidden_size=20)
    assert not comparator(c, d)

    # Test custom modules with different architecture
    torch.manual_seed(42)
    e = SimpleNet(hidden_size=20)
    torch.manual_seed(42)
    f = SimpleNet(hidden_size=40)
    assert not comparator(e, f)


def test_torch_nn_nested_modules():
    """Test comparator for nested torch.nn.Module structures.

    Encoder contains two EncoderBlock submodules (Conv2d + BatchNorm2d + ReLU)
    plus a MaxPool2d, so a weight difference sits two levels below the root.
    """
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    class EncoderBlock(nn.Module):
        def __init__(self, in_channels, out_channels):
            super().__init__()
            self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
            self.bn = nn.BatchNorm2d(out_channels)
            self.relu = nn.ReLU()

        def forward(self, x):
            return self.relu(self.bn(self.conv(x)))

    class Encoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.block1 = EncoderBlock(3, 16)
            self.block2 = EncoderBlock(16, 32)
            self.pool = nn.MaxPool2d(2)

        def forward(self, x):
            x = self.block1(x)
            x = self.pool(x)
            x = self.block2(x)
            x = self.pool(x)
            return x

    # Test identical nested modules
    torch.manual_seed(42)
    a = Encoder()
    torch.manual_seed(42)
    b = Encoder()
    assert comparator(a, b)

    # Test nested modules with different weights
    torch.manual_seed(42)
    c = Encoder()
    torch.manual_seed(123)
    d = Encoder()
    assert not comparator(c, d)


def test_torch_nn_transformer():
    """Test comparator for torch.nn.Transformer and TransformerEncoder modules."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    # Test identical Transformer
    torch.manual_seed(42)
    a = nn.Transformer(d_model=64, nhead=4, num_encoder_layers=2, num_decoder_layers=2)
    torch.manual_seed(42)
    b = nn.Transformer(d_model=64, nhead=4, num_encoder_layers=2, num_decoder_layers=2)
    assert comparator(a, b)

    # Test Transformer with different d_model
    torch.manual_seed(42)
    c = nn.Transformer(d_model=64, nhead=4)
    torch.manual_seed(42)
    d = nn.Transformer(d_model=128, nhead=4)
    assert not comparator(c, d)

    # Test Transformer with different nhead
    torch.manual_seed(42)
    e = nn.Transformer(d_model=64, nhead=4)
    torch.manual_seed(42)
    f = nn.Transformer(d_model=64, nhead=8)
    assert not comparator(e, f)

    # Test TransformerEncoder
    # The encoder is built from a layer created right after seeding, so both
    # stacks start from the same RNG state.
    torch.manual_seed(42)
    encoder_layer_a = nn.TransformerEncoderLayer(d_model=64, nhead=4)
    g = nn.TransformerEncoder(encoder_layer_a, num_layers=2)
    torch.manual_seed(42)
    encoder_layer_b = nn.TransformerEncoderLayer(d_model=64, nhead=4)
    h = nn.TransformerEncoder(encoder_layer_b, num_layers=2)
    assert comparator(g, h)


def test_torch_nn_parameter_buffer_modification():
    """Test comparator detects parameter and buffer modifications.

    Starts from two equal modules, mutates a single element of one module's
    parameter (Linear.weight) or buffer (BatchNorm2d.running_mean), and
    expects the comparison to flip to unequal.
    """
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    # Test that modifying a parameter is detected
    torch.manual_seed(42)
    a = nn.Linear(10, 5)
    torch.manual_seed(42)
    b = nn.Linear(10, 5)
    assert comparator(a, b)

    # Modify a parameter
    # In-place writes to a parameter that requires grad must be done under no_grad.
    with torch.no_grad():
        b.weight[0, 0] = 999.0
    assert not comparator(a, b)

    # Test that modifying a buffer is detected (BatchNorm running_mean)
    torch.manual_seed(42)
    c = nn.BatchNorm2d(16)
    torch.manual_seed(42)
    d = nn.BatchNorm2d(16)
    assert comparator(c, d)

    # Modify a buffer
    d.running_mean[0] = 999.0
    assert not comparator(c, d)


def test_torch_nn_device_placement():
    """Test comparator handles modules on different devices.

    The CPU-vs-CUDA mismatch case runs only when a GPU is available.
    """
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    # Create modules on CPU
    torch.manual_seed(42)
    cpu_module = nn.Linear(10, 5)
    torch.manual_seed(42)
    cpu_module2 = nn.Linear(10, 5)
    assert comparator(cpu_module, cpu_module2)

    # If CUDA is available, test device mismatch
    if torch.cuda.is_available():
        torch.manual_seed(42)
        cpu_mod = nn.Linear(10, 5)
        torch.manual_seed(42)
        cuda_mod = nn.Linear(10, 5).cuda()
        assert not comparator(cpu_mod, cuda_mod)


def test_torch_nn_conv1d_conv3d():
    """Test comparator for Conv1d and Conv3d modules, including Conv1d vs Conv2d."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    # Test Conv1d
    torch.manual_seed(42)
    a = nn.Conv1d(3, 16, kernel_size=3)
    torch.manual_seed(42)
    b = nn.Conv1d(3, 16, kernel_size=3)
    assert comparator(a, b)

    # Test Conv1d with different out_channels
    torch.manual_seed(42)
    c = nn.Conv1d(3, 16, kernel_size=3)
    torch.manual_seed(42)
    d = nn.Conv1d(3, 32, kernel_size=3)
    assert not comparator(c, d)

    # Test Conv3d
    torch.manual_seed(42)
    e = nn.Conv3d(3, 16, kernel_size=3)
    torch.manual_seed(42)
    f = nn.Conv3d(3, 16, kernel_size=3)
    assert comparator(e, f)

    # Test Conv1d vs Conv2d (different types)
    torch.manual_seed(42)
    g = nn.Conv1d(3, 16, kernel_size=3)
    torch.manual_seed(42)
    h = nn.Conv2d(3, 16, kernel_size=3)
    assert not comparator(g, h)


def test_torch_nn_flatten_unflatten():
    """Test comparator for Flatten and Unflatten modules (hyperparameters only)."""
    try:
        # `torch` is not referenced below; importing it doubles as the
        # availability check.
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    # Test Flatten
    a = nn.Flatten()
    b = nn.Flatten()
    assert comparator(a, b)

    # Test Flatten with different start_dim
    c = nn.Flatten(start_dim=1)
    d = nn.Flatten(start_dim=0)
    assert not comparator(c, d)

    # Test Unflatten
    e = nn.Unflatten(dim=1, unflattened_size=(2, 5))
    f = nn.Unflatten(dim=1, unflattened_size=(2, 5))
    assert comparator(e, f)


def test_torch_nn_identity():
    """Test comparator for Identity module, including Identity vs Linear."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    # Test Identity
    a = nn.Identity()
    b = nn.Identity()
    assert comparator(a, b)

    # Test Identity vs Linear (different types)
    # NOTE(review): nn.Identity has no parameters, so this seed only affects
    # d's random weights; the mismatch under test is the module type.
    torch.manual_seed(42)
    c = nn.Identity()
    d = nn.Linear(10, 10)
    assert not comparator(c, d)


def test_torch_nn_with_superset(): """Test comparator superset_obj mode with nn.Module.""" try: import torch from torch import nn except ImportError: pytest.skip() # For nn.Module, superset_obj should still work torch.manual_seed(42) a = nn.Linear(10, 5) torch.manual_seed(42) b = nn.Linear(10, 5) # superset_obj=True should pass for
identical modules assert comparator(a, b, superset_obj=True) # Different modules should still fail torch.manual_seed(42) c = nn.Linear(10, 5) torch.manual_seed(123) d = nn.Linear(10, 5) assert not comparator(c, d, superset_obj=True)


def test_jax():
    """Test comparator on jax.numpy arrays.

    Mirrors the torch tensor cases: values, dtype, rank/shape, empty arrays,
    NaN and +/-inf handling, complex and boolean arrays.
    """
    try:
        import jax.numpy as jnp
    except ImportError:
        # jax is optional: skip rather than fail when it is not installed.
        pytest.skip()

    # Test basic arrays
    a = jnp.array([1, 2, 3])
    b = jnp.array([1, 2, 3])
    c = jnp.array([1, 2, 4])
    assert comparator(a, b)
    assert not comparator(a, c)

    # Test 2D arrays
    d = jnp.array([[1, 2, 3], [4, 5, 6]])
    e = jnp.array([[1, 2, 3], [4, 5, 6]])
    f = jnp.array([[1, 2, 3], [4, 5, 7]])
    assert comparator(d, e)
    assert not comparator(d, f)

    # Test arrays with different data types
    g = jnp.array([1, 2, 3], dtype=jnp.float32)
    h = jnp.array([1, 2, 3], dtype=jnp.float32)
    i = jnp.array([1, 2, 3], dtype=jnp.int32)
    assert comparator(g, h)
    assert not comparator(g, i)

    # Test 3D arrays
    j = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    k = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
    l = jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 9]]])
    assert comparator(j, k)
    assert not comparator(j, l)

    # Test arrays with different shapes
    # Same elements, but shape (3,) vs (1, 3) must not match.
    m = jnp.array([1, 2, 3])
    n = jnp.array([[1, 2, 3]])
    assert not comparator(m, n)

    # Test empty arrays
    o = jnp.array([])
    p = jnp.array([])
    q = jnp.array([1])
    assert comparator(o, p)
    assert not comparator(o, q)

    # Test arrays with NaN values
    r = jnp.array([1.0, jnp.nan, 3.0])
    s = jnp.array([1.0, jnp.nan, 3.0])
    t = jnp.array([1.0, 2.0, 3.0])
    assert comparator(r, s)  # NaN == NaN
    assert not comparator(r, t)

    # Test arrays with infinity values
    u = jnp.array([1.0, jnp.inf, 3.0])
    v = jnp.array([1.0, jnp.inf, 3.0])
    w = jnp.array([1.0, -jnp.inf, 3.0])
    assert comparator(u, v)
    assert not comparator(u, w)

    # Test complex arrays
    x = jnp.array([1 + 2j, 3 + 4j])
    y = jnp.array([1 + 2j, 3 + 4j])
    z = jnp.array([1 + 2j, 3 + 5j])
    assert comparator(x, y)
    assert not comparator(x, z)

    # Test boolean arrays
    aa = jnp.array([True, False, True])
    bb = jnp.array([True, False, True])
    cc = jnp.array([True, True, True])
    assert comparator(aa, bb)
    assert not comparator(aa, cc)


def test_xarray():
    """Test comparator on xarray DataArray and Dataset objects.

    Covers data values, coordinates, attrs, dimension names, shapes, NaN and
    inf handling, variable names, empty datasets, integer-width differences,
    and DataArray vs Dataset type mismatch.
    """
    try:
        import numpy as np
        import xarray as xr
    except ImportError:
        # xarray (and numpy) are optional: skip when not installed.
        pytest.skip()

    # Test basic DataArray
    a = xr.DataArray([1, 2, 3], dims=["x"])
    b = xr.DataArray([1, 2, 3], dims=["x"])
    c = xr.DataArray([1, 2, 4], dims=["x"])
    assert comparator(a, b)
    assert not comparator(a, c)

    # Test DataArray with coordinates
    d = xr.DataArray([1, 2, 3], coords={"x": [0, 1, 2]}, dims=["x"])
    e = xr.DataArray([1, 2, 3], coords={"x": [0, 1, 2]}, dims=["x"])
    f = xr.DataArray([1, 2, 3], coords={"x": [0, 1, 3]}, dims=["x"])
    assert comparator(d, e)
    assert not comparator(d, f)

    # Test DataArray with attributes
    # Same data; only the attrs metadata differs.
    g = xr.DataArray([1, 2, 3], dims=["x"], attrs={"units": "meters"})
    h = xr.DataArray([1, 2, 3], dims=["x"], attrs={"units": "meters"})
    i = xr.DataArray([1, 2, 3], dims=["x"], attrs={"units": "feet"})
    assert comparator(g, h)
    assert not comparator(g, i)

    # Test 2D DataArray
    j = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=["x", "y"])
    k = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=["x", "y"])
    l = xr.DataArray([[1, 2, 3], [4, 5, 7]], dims=["x", "y"])
    assert comparator(j, k)
    assert not comparator(j, l)

    # Test DataArray with different dimensions
    m = xr.DataArray([1, 2, 3], dims=["x"])
    n = xr.DataArray([1, 2, 3], dims=["y"])
    assert not comparator(m, n)

    # Test DataArray with NaN values
    o = xr.DataArray([1.0, np.nan, 3.0], dims=["x"])
    p = xr.DataArray([1.0, np.nan, 3.0], dims=["x"])
    q = xr.DataArray([1.0, 2.0, 3.0], dims=["x"])
    assert comparator(o, p)
    assert not comparator(o, q)

    # Test Dataset
    r = xr.Dataset({"temp": (["x", "y"], [[1, 2], [3, 4]]), "pressure": (["x", "y"], [[5, 6], [7, 8]])})
    s = xr.Dataset({"temp": (["x", "y"], [[1, 2], [3, 4]]), "pressure": (["x", "y"], [[5, 6], [7, 8]])})
    t = xr.Dataset({"temp": (["x", "y"], [[1, 2], [3, 4]]), "pressure": (["x", "y"], [[5, 6], [7, 9]])})
    assert comparator(r, s)
    assert not comparator(r, t)

    # Test Dataset with coordinates
    u = xr.Dataset({"temp": (["x", "y"], [[1, 2], [3, 4]])}, coords={"x": [0, 1], "y": [0, 1]})
    v = xr.Dataset({"temp": (["x", "y"], [[1, 2], [3, 4]])}, coords={"x": [0, 1], "y": [0, 1]})
    w = xr.Dataset({"temp": (["x", "y"], [[1, 2], [3, 4]])}, coords={"x": [0, 2], "y": [0, 1]})
    assert comparator(u, v)
    assert not comparator(u, w)

    # Test Dataset with attributes
    x = xr.Dataset({"temp": (["x"], [1, 2, 3])}, attrs={"source": "sensor"})
    y = xr.Dataset({"temp": (["x"], [1, 2, 3])}, attrs={"source": "sensor"})
    z = xr.Dataset({"temp": (["x"], [1, 2, 3])}, attrs={"source": "model"})
    assert comparator(x, y)
    assert not comparator(x, z)

    # Test Dataset with different variables
    # Same data under a different variable name must not match.
    aa = xr.Dataset({"temp": (["x"], [1, 2, 3])})
    bb = xr.Dataset({"temp": (["x"], [1, 2, 3])})
    cc = xr.Dataset({"pressure": (["x"], [1, 2, 3])})
    assert comparator(aa, bb)
    assert not comparator(aa, cc)

    # Test empty Dataset
    dd = xr.Dataset()
    ee = xr.Dataset()
    assert comparator(dd, ee)

    # Test DataArray with different shapes
    ff = xr.DataArray([1, 2, 3], dims=["x"])
    gg = xr.DataArray([[1, 2, 3]], dims=["x", "y"])
    assert not comparator(ff, gg)

    # Test DataArray with different data types
    # Note: the two arrays below hold the same values and differ only in integer
    # width (int32 vs int64), not int vs float; the test expects them to match.
    hh = xr.DataArray(np.array([1, 2, 3], dtype="int32"), dims=["x"])
    ii = xr.DataArray(np.array([1, 2, 3], dtype="int64"), dims=["x"])
    # Presumably this relies on xarray comparing values rather than dtypes;
    # confirm against the comparator's xarray branch if that changes.
    assert comparator(hh, ii)

    # Test DataArray with infinity
    jj = xr.DataArray([1.0, np.inf, 3.0], dims=["x"])
    kk = xr.DataArray([1.0, np.inf, 3.0], dims=["x"])
    ll = xr.DataArray([1.0, -np.inf, 3.0], dims=["x"])
    assert comparator(jj, kk)
    assert not comparator(jj, ll)

    # Test Dataset vs DataArray (different types)
    mm = xr.DataArray([1, 2, 3], dims=["x"])
    nn = xr.Dataset({"data": (["x"], [1, 2, 3])})
    assert not comparator(mm, nn)


def test_returns():
    """Test comparator on codeflash.either Success/Failure wrappers.

    Expects wrappers to match only when both the variant (Success vs Failure)
    and the wrapped value agree.
    """
    a = Success(5)
    b = Success(5)
    c = Success(6)
    # Same payload as `a` but a different variant, so it must not match.
    d = Failure(5)
    e = Success((5, 5))
    f = Success((5, 6))
    assert comparator(a, b)
    assert not comparator(a, c)
    assert not comparator(a, d)
    assert not comparator(a, e)
    assert not comparator(e, f)

    g = Success((5, 5))
    h = Success((5, 5))
    i = Success((5, 6))
    assert comparator(g, h)
    assert not comparator(g, i)


def test_custom_object():
    """Test comparator on user-defined classes and their instances.

    Covers:
    - classes with a custom __eq__;
    - plain attribute-bag classes;
    - subclasses;
    - stdlib and pydantic dataclasses;
    - pydantic BaseModel;
    - comparison of class objects themselves (class attributes, __init__
      bodies, aliases).
    """

    class TestClass:
        def __init__(self, value):
            self.value = value

        # __eq__ looks only at `.value`.
        def __eq__(self, other):
            return self.value == other.value

    a = TestClass(5)
    b = TestClass(5)
    c = TestClass(6)
    assert comparator(a, b)
    assert not comparator(a, c)

    class TestClass2:
        def __init__(self, value1, value2=6):
            self.value1 = value1
            self.value2 = value2

    # Instances of unrelated classes must not match; instances of the same
    # class match when all attributes agree.
    a = TestClass(5)
    b = TestClass2(5, 6)
    c = TestClass2(5, 7)
    d = TestClass2(5, 6)
    assert not comparator(a, b)
    assert not comparator(b, c)
    assert comparator(b, d)

    # TestClass3 only adds a method to TestClass.
    class TestClass3(TestClass):
        def print(self):
            print(self.value)

    a = TestClass2(5)
    b = TestClass3(5)
    c = TestClass3(5)
    assert not comparator(a, b)
    assert comparator(b, c)

    @dataclasses.dataclass
    class InventoryItem:
        """Class for keeping track of an item in inventory."""

        name: str
        unit_price: float
        quantity_on_hand: int = 0

        def total_cost(self) -> float:
            return self.unit_price * self.quantity_on_hand

    a = InventoryItem(name="widget", unit_price=3.0, quantity_on_hand=10)
    b = InventoryItem(name="widget", unit_price=3.0, quantity_on_hand=10)
    c = InventoryItem(name="widget", unit_price=3.0, quantity_on_hand=11)
    assert comparator(a, b)
    assert not comparator(a, c)

    @pydantic.dataclasses.dataclass
    class InventoryItemPydantic:
        """Class for keeping track of an item in inventory."""

        name: str
        unit_price: float
        quantity_on_hand: int = 0

        def total_cost(self) -> float:
            return self.unit_price * self.quantity_on_hand

    a = InventoryItemPydantic(name="widget", unit_price=3.0, quantity_on_hand=10)
    b = InventoryItemPydantic(name="widget", unit_price=3.0, quantity_on_hand=10)
    c = InventoryItemPydantic(name="widget", unit_price=3.0, quantity_on_hand=11)
    assert comparator(a, b)
    assert not comparator(a, c)

    class InventoryItemBasePydantic(pydantic.BaseModel):
        name: str
        unit_price: float
        quantity_on_hand: int = 0

        def total_cost(self) -> float:
            return self.unit_price * self.quantity_on_hand

    a = InventoryItemBasePydantic(name="widget", unit_price=3.0, quantity_on_hand=10)
    b = InventoryItemBasePydantic(name="widget", unit_price=3.0, quantity_on_hand=10)
    c = InventoryItemBasePydantic(name="widget", unit_price=3.0, quantity_on_hand=11)
    assert comparator(a, b)
    assert not comparator(a, c)

    # The class objects themselves are compared below, not instances.
    class A:
        items = [1, 2, 3]
        val = 5

    class B:
        items = [1, 2, 4]
        val = 5

    assert comparator(A, A)
    assert not comparator(A, B)

    # C and D have identical class attributes and differ only in the values
    # their __init__ assigns; comparing the classes is expected to fail.
    class C:
        items = [1, 2, 3]
        val = 5

        def __init__(self):
            self.itemm2 = [1, 2, 3]
            self.val2 = 5

    class D:
        items = [1, 2, 3]
        val = 5

        def __init__(self):
            self.itemm2 = [1, 2, 4]
            self.val2 = 5

    assert comparator(C, C)
    assert not comparator(C, D)

    # An alias refers to the very same class object.
    E = C
    assert comparator(C, E)


def test_custom_object_2(): fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_method.py").resolve() original_code = fto_path.read_text("utf-8") from code_to_optimize.bubble_sort_method import BubbleSorter a = BubbleSorter() assert a.x == 0 try: # Remove the module from sys.modules, to get the updated class sys.modules.pop("code_to_optimize.bubble_sort_method", None) from code_to_optimize.bubble_sort_method import BubbleSorter b = BubbleSorter() assert comparator( a, b ) # Note that type(a) != type(b) as the class type objects are different, even if the code is the same.
optimized_code_mutated_attr = """ class BubbleSorter: z = 0 def __init__(self, x=1): self.x = x def sorter(self, arr): for i in range(len(arr)): for j in range(len(arr) - 1): if arr[j] > arr[j + 1]: temp = arr[j] arr[j] = arr[j + 1] arr[j + 1] = temp return arr """ fto_path.write_text(optimized_code_mutated_attr, "utf-8") sys.modules.pop("code_to_optimize.bubble_sort_method", None) from code_to_optimize.bubble_sort_method import BubbleSorter c = BubbleSorter() assert c.x == 1 assert not comparator(a, c) optimized_code_new_attr = """ class BubbleSorter: z = 5 def __init__(self, x=0): self.x = x def sorter(self, arr): for i in range(len(arr)): for j in range(len(arr) - 1): if arr[j] > arr[j + 1]: temp = arr[j] arr[j] = arr[j + 1] arr[j + 1] = temp return arr """ fto_path.write_text(optimized_code_new_attr, "utf-8") sys.modules.pop("code_to_optimize.bubble_sort_method", None) from code_to_optimize.bubble_sort_method import BubbleSorter d = BubbleSorter() assert d.x == 0 # Currently, we do not check if class variables are different, since the code replacer does not allow this. # In the future, if this functionality is allowed, this assert should be false. 
assert comparator(a, d) finally: fto_path.write_text(original_code, "utf-8") def test_superset(): class A: def __init__(self): self.a = 1 obj = A() obj.x = 3 assert comparator(A(), obj, superset_obj=True) assert not comparator(obj, A(), superset_obj=True) assert not comparator(A(), obj) assert not comparator(obj, A()) assert comparator(obj, obj, superset_obj=True) assert comparator(obj, obj) def test_compare_results_fn(): original_results = TestResults() original_results.add( FunctionTestInvocation( id=InvocationId( test_module_path="test_module_path", test_class_name="test_class_name", test_function_name="test_function_name", function_getting_tested="function_getting_tested", iteration_id="0", ), file_name=Path("file_name"), did_pass=True, runtime=5, test_framework="unittest", test_type=TestType.EXISTING_UNIT_TEST, return_value=5, timed_out=False, loop_index=1, ) ) new_results_1 = TestResults() new_results_1.add( FunctionTestInvocation( id=InvocationId( test_module_path="test_module_path", test_class_name="test_class_name", test_function_name="test_function_name", function_getting_tested="function_getting_tested", iteration_id="0", ), file_name=Path("file_name"), did_pass=True, runtime=10, test_framework="unittest", test_type=TestType.EXISTING_UNIT_TEST, return_value=5, timed_out=False, loop_index=1, ) ) match, _ = compare_test_results(original_results, new_results_1) assert match new_results_2 = TestResults() new_results_2.add( FunctionTestInvocation( id=InvocationId( test_module_path="test_module_path", test_class_name="test_class_name", test_function_name="test_function_name", function_getting_tested="function_getting_tested", iteration_id="0", ), file_name=Path("file_name"), did_pass=True, runtime=10, test_framework="unittest", test_type=TestType.EXISTING_UNIT_TEST, return_value=[5], timed_out=False, loop_index=1, ) ) match, _ = compare_test_results(original_results, new_results_2) assert not match new_results_3 = TestResults() new_results_3.add( 
FunctionTestInvocation( id=InvocationId( test_module_path="test_module_path", test_class_name="test_class_name", test_function_name="test_function_name", function_getting_tested="function_getting_tested", iteration_id="0", ), file_name=Path("file_name"), did_pass=True, runtime=10, test_framework="unittest", test_type=TestType.EXISTING_UNIT_TEST, return_value=5, timed_out=False, loop_index=1, ) ) new_results_3.add( FunctionTestInvocation( id=InvocationId( test_module_path="test_module_path", test_class_name="test_class_name", test_function_name="test_function_name", function_getting_tested="function_getting_tested", iteration_id="2", ), file_name=Path("file_name"), did_pass=True, runtime=10, test_framework="unittest", test_type=TestType.EXISTING_UNIT_TEST, return_value=5, timed_out=False, loop_index=1, ) ) match, _ = compare_test_results(original_results, new_results_3) assert match new_results_4 = TestResults() new_results_4.add( FunctionTestInvocation( id=InvocationId( test_module_path="test_module_path", test_class_name="test_class_name", test_function_name="test_function_name", function_getting_tested="function_getting_tested", iteration_id="0", ), file_name=Path("file_name"), did_pass=False, runtime=5, test_framework="unittest", test_type=TestType.EXISTING_UNIT_TEST, return_value=5, timed_out=False, loop_index=1, ) ) match, _ = compare_test_results(original_results, new_results_4) assert not match new_results_5_baseline = TestResults() new_results_5_baseline.add( FunctionTestInvocation( id=InvocationId( test_module_path="test_module_path", test_class_name="test_class_name", test_function_name="test_function_name", function_getting_tested="function_getting_tested", iteration_id="0", ), file_name=Path("file_name"), did_pass=True, runtime=5, test_framework="unittest", test_type=TestType.GENERATED_REGRESSION, return_value=5, timed_out=False, loop_index=1, ) ) new_results_5_opt = TestResults() new_results_5_opt.add( FunctionTestInvocation( id=InvocationId( 
test_module_path="test_module_path", test_class_name="test_class_name", test_function_name="test_function_name", function_getting_tested="function_getting_tested", iteration_id="0", ), file_name=Path("file_name"), did_pass=False, runtime=5, test_framework="unittest", test_type=TestType.GENERATED_REGRESSION, return_value=5, timed_out=False, loop_index=1, ) ) match, _ = compare_test_results(new_results_5_baseline, new_results_5_opt) assert not match new_results_6_baseline = TestResults() new_results_6_baseline.add( FunctionTestInvocation( id=InvocationId( test_module_path="test_module_path", test_class_name="test_class_name", test_function_name="test_function_name", function_getting_tested="function_getting_tested", iteration_id="0", ), file_name=Path("file_name"), did_pass=True, runtime=5, test_framework="unittest", test_type=TestType.REPLAY_TEST, return_value=5, timed_out=False, loop_index=1, ) ) new_results_6_opt = TestResults() new_results_6_opt.add( FunctionTestInvocation( id=InvocationId( test_module_path="test_module_path", test_class_name="test_class_name", test_function_name="test_function_name", function_getting_tested="function_getting_tested", iteration_id="0", ), file_name=Path("file_name"), did_pass=False, runtime=5, test_framework="unittest", test_type=TestType.REPLAY_TEST, return_value=5, timed_out=False, loop_index=1, ) ) match, _ = compare_test_results(new_results_6_baseline, new_results_6_opt) assert not match match, _ = compare_test_results(TestResults(), TestResults()) assert not match def test_exceptions(): type_error = TypeError("This is a type error") type_error_2 = TypeError("This is a type error") assert comparator(type_error, type_error_2) def raise_exception(): raise Exception("This is an exception") def test_exceptions_comparator(): # Currently we are only comparing the exception types and the attributes that don't start with "_" # there are complications with comparing the exception messages try: raise_exception() except Exception as e: 
exception = e try: raise_exception() except Exception as b: exception_2 = b assert comparator(exception, exception_2) exc1 = ValueError("same message") exc2 = ValueError("same message") assert comparator(exc1, exc2) exc_msg1 = ValueError("message one") exc_msg2 = ValueError("message two") # Different messages but same types assert comparator(exc_msg1, exc_msg2) exc1 = ValueError("common message") exc2 = TypeError("common message") assert not comparator(exc1, exc2) exc_file_1 = FileNotFoundError(2, "No such file or directory") exc_file2 = FileNotFoundError(2, "No such file or directory") exc_file4 = FileNotFoundError(2, "File not found") exc_file3 = FileNotFoundError(3, "No such file or directory") assert not comparator(exc1, exc2) assert comparator(exc_file_1, exc_file2) assert comparator(exc_file_1, exc_file3) assert comparator(exc_file_1, exc_file4) assert comparator(exception, exception) assert not comparator(exception, None) assert not comparator(None, exception) assert comparator(None, None) # Different exception types exc_type1 = TypeError("Type error") exc_type2 = TypeError("Another type error") assert comparator(exc_type1, exc_type2) exc_type3 = KeyError("Missing key") exc_type4 = KeyError("Missing key") assert comparator(exc_type3, exc_type4) assert not comparator(exc_type1, exc_type3) # compare the attributes of the exception as well class CustomError(Exception): def __init__(self, message, code): super().__init__(message) self.code = code custom_exc1 = CustomError("Something went wrong", 101) custom_exc2 = CustomError("Something went wrong", 101) assert comparator(custom_exc1, custom_exc2) custom_exc4 = CustomError("Something went wrong", 102) assert not comparator(custom_exc1, custom_exc4) class CustomErrorNoArgs(Exception): pass custom_no_args1 = CustomErrorNoArgs() custom_no_args2 = CustomErrorNoArgs() assert comparator(custom_no_args1, custom_no_args2) exc_empty1 = ValueError("") exc_empty2 = ValueError("") assert comparator(exc_empty1, exc_empty2) 
exc_not_empty = ValueError("Not empty") assert comparator(exc_empty1, exc_not_empty) class CustomValueError(ValueError): pass custom_value_error1 = CustomValueError("A custom value error") value_error1 = ValueError("A custom value error") assert not comparator(custom_value_error1, value_error1) custom_value_error2 = CustomValueError("Another custom value error") assert comparator(custom_value_error1, custom_value_error2) class CustomExceptionWithArgs(Exception): def __init__(self, arg1, arg2): self.args = (arg1, arg2) custom_args_exc1 = CustomExceptionWithArgs(1, "test") custom_args_exc2 = CustomExceptionWithArgs(1, "test") assert comparator(custom_args_exc1, custom_args_exc2) custom_args_exc3 = CustomExceptionWithArgs(1, "different") assert comparator(custom_args_exc1, custom_args_exc3) def raise_specific_exception(): raise ZeroDivisionError("Cannot divide by zero") try: raise_specific_exception() except ZeroDivisionError as z1: zero_division_exc1 = z1 try: raise_specific_exception() except ZeroDivisionError as z2: zero_division_exc2 = z2 assert comparator(zero_division_exc1, zero_division_exc2) zero_division_exc3 = ZeroDivisionError("Different message") assert comparator(zero_division_exc1, zero_division_exc3) assert comparator(..., ...) assert comparator(Ellipsis, Ellipsis) assert not comparator(..., None) assert not comparator(Ellipsis, None) code7 = "a = 1 + 2" module7 = ast.parse(code7) for node in ast.walk(module7): for child in ast.iter_child_nodes(node): child.parent = node # type: ignore module8 = copy.deepcopy(module7) assert comparator(module7, module8) code2 = "a = 1 + 3" module2 = ast.parse(code2) assert not comparator(module7, module2) def test_torch_runtime_error_wrapping(): """Test that TorchRuntimeError wrapping is handled correctly. When torch.compile is used, exceptions are wrapped in TorchRuntimeError. The comparator should consider an IndexError equivalent to a TorchRuntimeError that wraps an IndexError. 
""" # Create a mock TorchRuntimeError class that mimics torch._dynamo.exc.TorchRuntimeError class TorchRuntimeError(Exception): """Mock TorchRuntimeError for testing.""" # Monkey-patch the __module__ to match torch._dynamo.exc TorchRuntimeError.__module__ = "torch._dynamo.exc" # Test 1: TorchRuntimeError with __cause__ set to the same exception type index_error = IndexError("index 0 is out of bounds for dimension 0 with size 0") torch_error = TorchRuntimeError( "Dynamo failed to run FX node with fake tensors: got IndexError('index 0 is out of bounds')" ) torch_error.__cause__ = IndexError("index 0 is out of bounds for dimension 0 with size 0") # These should be considered equivalent since TorchRuntimeError wraps IndexError assert comparator(index_error, torch_error) assert comparator(torch_error, index_error) # Test 2: TorchRuntimeError without __cause__ but with matching error type in message torch_error_no_cause = TorchRuntimeError( "Dynamo failed to run FX node with fake tensors: got IndexError('index 0 is out of bounds')" ) assert comparator(index_error, torch_error_no_cause) assert comparator(torch_error_no_cause, index_error) # Test 3: Different exception types should not be equivalent value_error = ValueError("some value error") torch_error_index = TorchRuntimeError("got IndexError('some error')") torch_error_index.__cause__ = IndexError("some error") assert not comparator(value_error, torch_error_index) assert not comparator(torch_error_index, value_error) # Test 4: TorchRuntimeError wrapping a different type should not match type_error = TypeError("some type error") torch_error_with_index = TorchRuntimeError("got IndexError('index error')") torch_error_with_index.__cause__ = IndexError("index error") assert not comparator(type_error, torch_error_with_index) # Test 5: Two TorchRuntimeErrors wrapping the same exception type torch_error1 = TorchRuntimeError("got IndexError('error 1')") torch_error1.__cause__ = IndexError("error 1") torch_error2 = 
TorchRuntimeError("got IndexError('error 2')") torch_error2.__cause__ = IndexError("error 2") assert comparator(torch_error1, torch_error2) # Test 6: Regular exception comparison still works error1 = IndexError("same error") error2 = IndexError("same error") assert comparator(error1, error2) # Test 7: Exception wrapped in tuple (return value scenario from debug output) orig_return = (("tensor1", "tensor2"), {}, IndexError("index 0 is out of bounds for dimension 0 with size 0")) torch_wrapped_return = ( ("tensor1", "tensor2"), {}, TorchRuntimeError("Dynamo failed: got IndexError('index 0 is out of bounds for dimension 0 with size 0')"), ) torch_wrapped_return[2].__cause__ = IndexError("index 0 is out of bounds for dimension 0 with size 0") assert comparator(orig_return, torch_wrapped_return) def test_extract_exception_from_message(): """Test the _extract_exception_from_message helper function.""" # Test with single-quoted message result = _extract_exception_from_message("got IndexError('some error message')") assert result is not None assert isinstance(result, IndexError) # Test with double-quoted message result = _extract_exception_from_message('got ValueError("another error")') assert result is not None assert isinstance(result, ValueError) # Test with various builtin exception types for exc_name, exc_class in [ ("TypeError", TypeError), ("KeyError", KeyError), ("RuntimeError", RuntimeError), ("AttributeError", AttributeError), ("ZeroDivisionError", ZeroDivisionError), ]: result = _extract_exception_from_message(f"got {exc_name}('test')") assert result is not None assert isinstance(result, exc_class) # Test with no matching pattern result = _extract_exception_from_message("This is a normal error message") assert result is None # Test with non-exception class name result = _extract_exception_from_message("got SomeRandomClass('not an exception')") assert result is None # Test with partial match (no opening quote) result = _extract_exception_from_message("got 
IndexError without quotes") assert result is None # Test with empty string result = _extract_exception_from_message("") assert result is None # Test with torch-like error message format result = _extract_exception_from_message( "Dynamo failed to run FX node with fake tensors: got IndexError('index 0 is out of bounds for dimension 0 with size 0')" ) assert result is not None assert isinstance(result, IndexError) def test_get_wrapped_exception(): """Test the _get_wrapped_exception helper function.""" # Test with __cause__ (explicit chaining) inner_error = ValueError("inner error") outer_error = RuntimeError("outer error") outer_error.__cause__ = inner_error result = _get_wrapped_exception(outer_error) assert result is inner_error # Test with no wrapping plain_error = ValueError("plain error") result = _get_wrapped_exception(plain_error) assert result is None # Test with message pattern error_with_pattern = RuntimeError("got TypeError('some type error')") result = _get_wrapped_exception(error_with_pattern) assert result is not None assert isinstance(result, TypeError) # Test that __cause__ takes precedence over message pattern actual_cause = IndexError("actual cause") error_with_both = RuntimeError("got TypeError('different error in message')") error_with_both.__cause__ = actual_cause result = _get_wrapped_exception(error_with_both) assert result is actual_cause assert isinstance(result, IndexError) @pytest.mark.skipif(sys.version_info < (3, 11), reason="ExceptionGroup requires Python 3.11+") def test_get_wrapped_exception_exception_group(): """Test _get_wrapped_exception with ExceptionGroup (Python 3.11+).""" # ExceptionGroup with single exception inner_error = ValueError("single inner error") group = ExceptionGroup("group", [inner_error]) result = _get_wrapped_exception(group) assert result is inner_error # ExceptionGroup with multiple exceptions - should return None error1 = ValueError("error 1") error2 = TypeError("error 2") multi_group = ExceptionGroup("multi 
group", [error1, error2]) result = _get_wrapped_exception(multi_group) assert result is None @pytest.mark.skipif(sys.version_info < (3, 11), reason="ExceptionGroup requires Python 3.11+") def test_comparator_with_exception_group(): """Test comparator with ExceptionGroup wrapping (Python 3.11+).""" # ExceptionGroup wrapping a single ValueError should match a plain ValueError inner_value_error = ValueError("some value error") group = ExceptionGroup("group", [inner_value_error]) plain_value_error = ValueError("different message but same type") assert comparator(group, plain_value_error) assert comparator(plain_value_error, group) # ExceptionGroup with different exception type should not match inner_type_error = TypeError("type error") type_group = ExceptionGroup("group", [inner_type_error]) assert not comparator(type_group, plain_value_error) # Two ExceptionGroups with same wrapped type should match group1 = ExceptionGroup("group1", [ValueError("error 1")]) group2 = ExceptionGroup("group2", [ValueError("error 2")]) assert comparator(group1, group2) def test_comparator_with_cause_chaining(): """Test comparator with __cause__ exception chaining.""" # Create an exception chain using 'raise from' inner = IndexError("inner index error") outer = RuntimeError("outer runtime error") outer.__cause__ = inner # Outer exception should match the inner exception type plain_index_error = IndexError("different index error") assert comparator(outer, plain_index_error) assert comparator(plain_index_error, outer) # Should not match a different type plain_type_error = TypeError("type error") assert not comparator(outer, plain_type_error) # Two chained exceptions with same wrapper type match (regardless of inner type) # because same-type exceptions compare non-private attributes only (__cause__ is ignored) outer1 = RuntimeError("outer 1") outer1.__cause__ = ValueError("inner 1") outer2 = RuntimeError("outer 2") outer2.__cause__ = ValueError("inner 2") assert comparator(outer1, outer2) # 
Different wrapper types with same inner type - unwrapping makes them match class WrapperA(Exception): pass class WrapperB(Exception): pass wrapper_a = WrapperA("wrapper a") wrapper_a.__cause__ = KeyError("same inner type") wrapper_b = WrapperB("wrapper b") wrapper_b.__cause__ = KeyError("same inner type") # Both unwrap to KeyError, so they should match assert comparator(wrapper_a, wrapper_b) # Different wrapper types with different inner types should not match wrapper_c = WrapperA("wrapper c") wrapper_c.__cause__ = ValueError("value error") wrapper_d = WrapperB("wrapper d") wrapper_d.__cause__ = TypeError("type error") assert not comparator(wrapper_c, wrapper_d) def test_comparator_with_message_pattern(): """Test comparator with exception type extracted from message pattern.""" # Exception with wrapped type in message (no __cause__) wrapper = RuntimeError("Operation failed: got IndexError('list index out of range')") plain_index = IndexError("some index error") assert comparator(wrapper, plain_index) assert comparator(plain_index, wrapper) # Should not match different types plain_key = KeyError("some key error") assert not comparator(wrapper, plain_key) def test_comparator_wrapped_exceptions_bidirectional(): """Test that wrapped exception comparison works in both directions.""" class CustomWrapper(Exception): pass # Create wrapper with __cause__ inner = AttributeError("attr error") wrapper = CustomWrapper("wrapper message") wrapper.__cause__ = inner plain_attr = AttributeError("plain attr error") # Test both directions assert comparator(wrapper, plain_attr) assert comparator(plain_attr, wrapper) # Test with superset_obj flag assert comparator(wrapper, plain_attr, superset_obj=True) assert comparator(plain_attr, wrapper, superset_obj=True) def test_comparator_same_type_exceptions_still_work(): """Ensure that same-type exception comparison still works correctly.""" exc1 = ValueError("message 1") exc2 = ValueError("message 2") assert comparator(exc1, exc2) # With 
custom attributes class CustomError(Exception): def __init__(self, msg, code): super().__init__(msg) self.code = code custom1 = CustomError("msg1", 100) custom2 = CustomError("msg2", 100) assert comparator(custom1, custom2) custom3 = CustomError("msg3", 200) assert not comparator(custom1, custom3) def test_comparator_no_false_positives_for_wrapped_exceptions(): """Test that unrelated exception types don't match due to wrapping logic.""" # Two completely different exception types should never match val_err = ValueError("value error") type_err = TypeError("type error") assert not comparator(val_err, type_err) # Wrapper with different inner type should not match wrapper = RuntimeError("some error") wrapper.__cause__ = KeyError("key error") assert not comparator(wrapper, val_err) assert not comparator(val_err, wrapper) def test_collections() -> None: # Deque a = deque([1, 2, 3]) b = deque([1, 2, 3]) c = deque([1, 2, 4]) d = deque([1, 2]) e = [1, 2, 3] f = deque([1, 2, 3], maxlen=5) assert comparator(a, b) assert comparator(a, f) # same elements, different maxlen is ok assert not comparator(a, c) assert not comparator(a, d) assert not comparator(a, e) g = deque([{"a": 1}, {"b": 2}]) h = deque([{"a": 1}, {"b": 2}]) i = deque([{"a": 1}, {"b": 3}]) assert comparator(g, h) assert not comparator(g, i) empty_deque1 = deque() empty_deque2 = deque() assert comparator(empty_deque1, empty_deque2) assert not comparator(empty_deque1, a) # namedtuple Point = namedtuple("Point", ["x", "y"]) a = Point(x=1, y=2) b = Point(x=1, y=2) c = Point(x=1, y=3) assert comparator(a, b) assert not comparator(a, c) Point2 = namedtuple("Point2", ["x", "y"]) d = Point2(x=1, y=2) assert not comparator(a, d) e = (1, 2) assert not comparator(a, e) # ChainMap map1 = {"a": 1, "b": 2} map2 = {"c": 3, "d": 4} a = ChainMap(map1, map2) b = ChainMap(map1, map2) c = ChainMap(map2, map1) d = {"a": 1, "b": 2, "c": 3, "d": 4} assert comparator(a, b) assert not comparator(a, c) assert not comparator(a, d) # Counter 
a = Counter(["a", "b", "a", "c", "b", "a"]) b = Counter({"a": 3, "b": 2, "c": 1}) c = Counter({"a": 3, "b": 2, "c": 2}) d = {"a": 3, "b": 2, "c": 1} assert comparator(a, b) assert not comparator(a, c) assert not comparator(a, d) # OrderedDict a = OrderedDict([("a", 1), ("b", 2)]) b = OrderedDict([("a", 1), ("b", 2)]) c = OrderedDict([("b", 2), ("a", 1)]) d = {"a": 1, "b": 2} assert comparator(a, b) assert not comparator(a, c) assert not comparator(a, d) # defaultdict a = defaultdict(int, {"a": 1, "b": 2}) b = defaultdict(int, {"a": 1, "b": 2}) c = defaultdict(list, {"a": 1, "b": 2}) d = {"a": 1, "b": 2} e = defaultdict(int, {"a": 1, "b": 3}) assert comparator(a, b) assert comparator(a, c) assert not comparator(a, d) assert not comparator(a, e) # UserDict a = UserDict({"a": 1, "b": 2}) b = UserDict({"a": 1, "b": 2}) c = UserDict({"a": 1, "b": 3}) d = {"a": 1, "b": 2} assert comparator(a, b) assert not comparator(a, c) assert not comparator(a, d) # UserList a = UserList([1, 2, 3]) b = UserList([1, 2, 3]) c = UserList([1, 2, 4]) d = [1, 2, 3] assert comparator(a, b) assert not comparator(a, c) assert not comparator(a, d) # UserString a = UserString("hello") b = UserString("hello") c = UserString("world") d = "hello" assert comparator(a, b) assert not comparator(a, c) assert not comparator(a, d) def test_attrs(): try: import attrs # type: ignore except ImportError: pytest.skip() @attrs.define class Person: name: str age: int = 10 a = Person("Alice", 25) b = Person("Alice", 25) c = Person("Bob", 25) d = Person("Alice", 30) assert comparator(a, b) assert not comparator(a, c) assert not comparator(a, d) @attrs.frozen class Point: x: int y: int p1 = Point(1, 2) p2 = Point(1, 2) p3 = Point(2, 3) assert comparator(p1, p2) assert not comparator(p1, p3) @attrs.define(slots=True) class Vehicle: brand: str model: str year: int = 2020 v1 = Vehicle("Toyota", "Camry", 2021) v2 = Vehicle("Toyota", "Camry", 2021) v3 = Vehicle("Honda", "Civic", 2021) assert comparator(v1, v2) assert 
not comparator(v1, v3) @attrs.define class ComplexClass: public_field: str private_field: str = attrs.field(repr=False) non_eq_field: int = attrs.field(eq=False, default=0) computed: str = attrs.field(init=False, eq=True) def __attrs_post_init__(self): self.computed = f"{self.public_field}_{self.private_field}" c1 = ComplexClass("test", "secret") c2 = ComplexClass("test", "secret") c3 = ComplexClass("different", "secret") c1.non_eq_field = 100 c2.non_eq_field = 200 assert comparator(c1, c2) assert not comparator(c1, c3) @attrs.define class Address: street: str city: str @attrs.define class PersonWithAddress: name: str address: Address addr1 = Address("123 Main St", "Anytown") addr2 = Address("123 Main St", "Anytown") addr3 = Address("456 Oak Ave", "Anytown") person1 = PersonWithAddress("John", addr1) person2 = PersonWithAddress("John", addr2) person3 = PersonWithAddress("John", addr3) assert comparator(person1, person2) assert not comparator(person1, person3) @attrs.define class Container: items: list metadata: dict cont1 = Container([1, 2, 3], {"type": "numbers"}) cont2 = Container([1, 2, 3], {"type": "numbers"}) cont3 = Container([1, 2, 4], {"type": "numbers"}) assert comparator(cont1, cont2) assert not comparator(cont1, cont3) @attrs.define class BaseClass: name: str value: int @attrs.define class ExtendedClass: name: str value: int extra_field: str = "default" base = BaseClass("test", 42) extended = ExtendedClass("test", 42, "extra") assert not comparator(base, extended) @attrs.define class WithNonEqFields: name: str timestamp: float = attrs.field(eq=False) # Should be ignored debug_info: str = attrs.field(eq=False, default="debug") obj1 = WithNonEqFields("test", 1000.0, "info1") obj2 = WithNonEqFields("test", 9999.0, "info2") # Different non-eq fields obj3 = WithNonEqFields("different", 1000.0, "info1") assert comparator(obj1, obj2) # Should be equal despite different timestamp/debug_info assert not comparator(obj1, obj3) # Should be different due to name 
@attrs.define class MinimalClass: name: str value: int @attrs.define class ExtendedClass: name: str value: int extra_field: str = "default" metadata: dict = attrs.field(factory=dict) timestamp: float = attrs.field(eq=False, default=0.0) # This should be ignored minimal = MinimalClass("test", 42) extended = ExtendedClass("test", 42, "extra", {"key": "value"}, 1000.0) assert not comparator(minimal, extended) def test_dict_views() -> None: """Test comparator support for dict_keys, dict_values, and dict_items.""" # Test dict_keys d1 = {"a": 1, "b": 2, "c": 3} d2 = {"a": 1, "b": 2, "c": 3} d3 = {"a": 1, "b": 2, "d": 3} d4 = {"a": 1, "b": 2} # dict_keys - same keys assert comparator(d1.keys(), d2.keys()) # dict_keys - different keys assert not comparator(d1.keys(), d3.keys()) # dict_keys - different length assert not comparator(d1.keys(), d4.keys()) # Test dict_values v1 = {"a": 1, "b": 2, "c": 3} v2 = {"x": 1, "y": 2, "z": 3} # same values, different keys v3 = {"a": 1, "b": 2, "c": 4} # different value v4 = {"a": 1, "b": 2} # different length # dict_values - same values (order matters for values since they're iterable) assert comparator(v1.values(), v2.values()) # dict_values - different values assert not comparator(v1.values(), v3.values()) # dict_values - different length assert not comparator(v1.values(), v4.values()) # Test dict_items i1 = {"a": 1, "b": 2, "c": 3} i2 = {"a": 1, "b": 2, "c": 3} i3 = {"a": 1, "b": 2, "c": 4} # different value i4 = {"a": 1, "b": 2, "d": 3} # different key i5 = {"a": 1, "b": 2} # different length i6 = {"b": 2, "c": 3, "a": 1} # different order # dict_items - same items assert comparator(i1.items(), i2.items()) # dict_items - different value assert not comparator(i1.items(), i3.items()) # dict_items - different key assert not comparator(i1.items(), i4.items()) # dict_items - different length assert not comparator(i1.items(), i5.items()) assert comparator(i1.items(), i6.items()) # Test empty dicts empty1 = {} empty2 = {} assert 
comparator(empty1.keys(), empty2.keys()) assert comparator(empty1.values(), empty2.values()) assert comparator(empty1.items(), empty2.items()) # Test with nested values nested1 = {"a": [1, 2, 3], "b": {"x": 1}} nested2 = {"a": [1, 2, 3], "b": {"x": 1}} nested3 = {"a": [1, 2, 4], "b": {"x": 1}} assert comparator(nested1.values(), nested2.values()) assert not comparator(nested1.values(), nested3.values()) assert comparator(nested1.items(), nested2.items()) assert not comparator(nested1.items(), nested3.items()) # Test that dict views are not equal to lists/sets d = {"a": 1, "b": 2} assert not comparator(d.keys(), ["a", "b"]) assert not comparator(d.keys(), {"a", "b"}) assert not comparator(d.values(), [1, 2]) assert not comparator(d.items(), [("a", 1), ("b", 2)]) def test_mappingproxy() -> None: """Test comparator support for types.MappingProxyType (read-only dict view).""" import types # Basic equality mp1 = types.MappingProxyType({"a": 1, "b": 2, "c": 3}) mp2 = types.MappingProxyType({"a": 1, "b": 2, "c": 3}) assert comparator(mp1, mp2) # Different values mp3 = types.MappingProxyType({"a": 1, "b": 2, "c": 4}) assert not comparator(mp1, mp3) # Different keys mp4 = types.MappingProxyType({"a": 1, "b": 2, "d": 3}) assert not comparator(mp1, mp4) # Different length mp5 = types.MappingProxyType({"a": 1, "b": 2}) assert not comparator(mp1, mp5) # Order doesn't matter (like dict) mp6 = types.MappingProxyType({"c": 3, "a": 1, "b": 2}) assert comparator(mp1, mp6) # Empty mappingproxy empty1 = types.MappingProxyType({}) empty2 = types.MappingProxyType({}) assert comparator(empty1, empty2) # Nested values nested1 = types.MappingProxyType({"a": [1, 2, 3], "b": {"x": 1}}) nested2 = types.MappingProxyType({"a": [1, 2, 3], "b": {"x": 1}}) nested3 = types.MappingProxyType({"a": [1, 2, 4], "b": {"x": 1}}) assert comparator(nested1, nested2) assert not comparator(nested1, nested3) # mappingproxy is not equal to dict (different types) d = {"a": 1, "b": 2} mp = 
types.MappingProxyType({"a": 1, "b": 2}) assert not comparator(mp, d) assert not comparator(d, mp) # Verify class __dict__ is indeed a mappingproxy class MyClass: x = 1 y = 2 assert isinstance(MyClass.__dict__, types.MappingProxyType) def test_mappingproxy_superset() -> None: """Test comparator superset_obj support for mappingproxy.""" import types mp1 = types.MappingProxyType({"a": 1, "b": 2}) mp2 = types.MappingProxyType({"a": 1, "b": 2, "c": 3}) # mp2 is a superset of mp1 assert comparator(mp1, mp2, superset_obj=True) # mp1 is not a superset of mp2 assert not comparator(mp2, mp1, superset_obj=True) # Same mappingproxy with superset_obj=True assert comparator(mp1, mp1, superset_obj=True) # Different values even with superset mp3 = types.MappingProxyType({"a": 1, "b": 99, "c": 3}) assert not comparator(mp1, mp3, superset_obj=True) def test_tensorflow_tensor() -> None: """Test comparator support for TensorFlow Tensor objects.""" try: import tensorflow as tf except ImportError: pytest.skip("tensorflow required for this test") # Test basic 1D tensors a = tf.constant([1, 2, 3]) b = tf.constant([1, 2, 3]) c = tf.constant([1, 2, 4]) assert comparator(a, b) assert not comparator(a, c) # Test 2D tensors d = tf.constant([[1, 2, 3], [4, 5, 6]]) e = tf.constant([[1, 2, 3], [4, 5, 6]]) f = tf.constant([[1, 2, 3], [4, 5, 7]]) assert comparator(d, e) assert not comparator(d, f) # Test tensors with different shapes g = tf.constant([1, 2, 3]) h = tf.constant([[1, 2, 3]]) assert not comparator(g, h) # Test tensors with different dtypes i = tf.constant([1, 2, 3], dtype=tf.float32) j = tf.constant([1, 2, 3], dtype=tf.float32) k = tf.constant([1, 2, 3], dtype=tf.int32) assert comparator(i, j) assert not comparator(i, k) # Test 3D tensors l = tf.constant([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) m = tf.constant([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) n = tf.constant([[[1, 2], [3, 4]], [[5, 6], [7, 9]]]) assert comparator(l, m) assert not comparator(l, n) # Test empty tensors o = 
tf.constant([]) p = tf.constant([]) q = tf.constant([1.0]) assert comparator(o, p) assert not comparator(o, q) # Test tensors with NaN values r = tf.constant([1.0, float("nan"), 3.0]) s = tf.constant([1.0, float("nan"), 3.0]) t = tf.constant([1.0, 2.0, 3.0]) assert comparator(r, s) # NaN == NaN should be True assert not comparator(r, t) # Test tensors with infinity values u = tf.constant([1.0, float("inf"), 3.0]) v = tf.constant([1.0, float("inf"), 3.0]) w = tf.constant([1.0, float("-inf"), 3.0]) assert comparator(u, v) assert not comparator(u, w) # Test complex tensors x = tf.constant([1 + 2j, 3 + 4j]) y = tf.constant([1 + 2j, 3 + 4j]) z = tf.constant([1 + 2j, 3 + 5j]) assert comparator(x, y) assert not comparator(x, z) # Test boolean tensors aa = tf.constant([True, False, True]) bb = tf.constant([True, False, True]) cc = tf.constant([True, True, True]) assert comparator(aa, bb) assert not comparator(aa, cc) # Test string tensors dd = tf.constant(["hello", "world"]) ee = tf.constant(["hello", "world"]) ff = tf.constant(["hello", "there"]) assert comparator(dd, ee) assert not comparator(dd, ff) def test_tensorflow_dtype() -> None: """Test comparator support for TensorFlow DType objects.""" try: import tensorflow as tf except ImportError: pytest.skip("tensorflow required for this test") # Test float dtypes a = tf.float32 b = tf.float32 c = tf.float64 assert comparator(a, b) assert not comparator(a, c) # Test integer dtypes d = tf.int32 e = tf.int32 f = tf.int64 assert comparator(d, e) assert not comparator(d, f) # Test unsigned integer dtypes g = tf.uint8 h = tf.uint8 i = tf.uint16 assert comparator(g, h) assert not comparator(g, i) # Test complex dtypes j = tf.complex64 k = tf.complex64 l = tf.complex128 assert comparator(j, k) assert not comparator(j, l) # Test bool dtype m = tf.bool n = tf.bool o = tf.int8 assert comparator(m, n) assert not comparator(m, o) # Test string dtype p = tf.string q = tf.string r = tf.int32 assert comparator(p, q) assert not 
comparator(p, r) def test_tensorflow_variable() -> None: """Test comparator support for TensorFlow Variable objects.""" try: import tensorflow as tf except ImportError: pytest.skip("tensorflow required for this test") # Test basic variables a = tf.Variable([1, 2, 3], dtype=tf.float32) b = tf.Variable([1, 2, 3], dtype=tf.float32) c = tf.Variable([1, 2, 4], dtype=tf.float32) assert comparator(a, b) assert not comparator(a, c) # Test variables with different dtypes d = tf.Variable([1, 2, 3], dtype=tf.float32) e = tf.Variable([1, 2, 3], dtype=tf.float64) assert not comparator(d, e) # Test 2D variables f = tf.Variable([[1, 2], [3, 4]], dtype=tf.float32) g = tf.Variable([[1, 2], [3, 4]], dtype=tf.float32) h = tf.Variable([[1, 2], [3, 5]], dtype=tf.float32) assert comparator(f, g) assert not comparator(f, h) # Test variables with different shapes i = tf.Variable([1, 2, 3], dtype=tf.float32) j = tf.Variable([[1, 2, 3]], dtype=tf.float32) assert not comparator(i, j) def test_tensorflow_tensor_shape() -> None: """Test comparator support for TensorFlow TensorShape objects.""" try: import tensorflow as tf except ImportError: pytest.skip("tensorflow required for this test") # Test equal shapes a = tf.TensorShape([2, 3, 4]) b = tf.TensorShape([2, 3, 4]) c = tf.TensorShape([2, 3, 5]) assert comparator(a, b) assert not comparator(a, c) # Test different ranks d = tf.TensorShape([2, 3]) e = tf.TensorShape([2, 3, 4]) assert not comparator(d, e) # Test scalar shapes f = tf.TensorShape([]) g = tf.TensorShape([]) h = tf.TensorShape([1]) assert comparator(f, g) assert not comparator(f, h) # Test shapes with None dimensions (unknown dimensions) i = tf.TensorShape([None, 3, 4]) j = tf.TensorShape([None, 3, 4]) k = tf.TensorShape([2, 3, 4]) assert comparator(i, j) assert not comparator(i, k) # Test fully unknown shapes l = tf.TensorShape(None) m = tf.TensorShape(None) n = tf.TensorShape([1, 2]) assert comparator(l, m) assert not comparator(l, n) def test_tensorflow_sparse_tensor() -> None: 
"""Test comparator support for TensorFlow SparseTensor objects.""" try: import tensorflow as tf except ImportError: pytest.skip("tensorflow required for this test") # Test equal sparse tensors a = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=[3, 4]) b = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=[3, 4]) c = tf.SparseTensor( indices=[[0, 0], [1, 2]], values=[1.0, 3.0], # Different value dense_shape=[3, 4], ) assert comparator(a, b) assert not comparator(a, c) # Test sparse tensors with different indices d = tf.SparseTensor( indices=[[0, 0], [1, 3]], # Different index values=[1.0, 2.0], dense_shape=[3, 4], ) assert not comparator(a, d) # Test sparse tensors with different shapes e = tf.SparseTensor( indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=[4, 5], # Different shape ) assert not comparator(a, e) # Test empty sparse tensors f = tf.SparseTensor(indices=tf.zeros([0, 2], dtype=tf.int64), values=[], dense_shape=[3, 4]) g = tf.SparseTensor(indices=tf.zeros([0, 2], dtype=tf.int64), values=[], dense_shape=[3, 4]) assert comparator(f, g) def test_tensorflow_ragged_tensor() -> None: """Test comparator support for TensorFlow RaggedTensor objects.""" try: import tensorflow as tf except ImportError: pytest.skip("tensorflow required for this test") # Test equal ragged tensors a = tf.ragged.constant([[1, 2], [3, 4, 5], [6]]) b = tf.ragged.constant([[1, 2], [3, 4, 5], [6]]) c = tf.ragged.constant([[1, 2], [3, 4, 6], [6]]) # Different value assert comparator(a, b) assert not comparator(a, c) # Test ragged tensors with different row lengths d = tf.ragged.constant([[1, 2, 3], [4, 5], [6]]) # Different structure assert not comparator(a, d) # Test ragged tensors with different dtypes e = tf.ragged.constant([[1.0, 2.0], [3.0, 4.0, 5.0], [6.0]]) f = tf.ragged.constant([[1.0, 2.0], [3.0, 4.0, 5.0], [6.0]]) assert comparator(e, f) assert not comparator(a, e) # int vs float # Test nested ragged tensors g = 
tf.ragged.constant([[[1, 2], [3]], [[4, 5, 6]]]) h = tf.ragged.constant([[[1, 2], [3]], [[4, 5, 6]]]) i = tf.ragged.constant([[[1, 2], [3]], [[4, 5, 7]]]) assert comparator(g, h) assert not comparator(g, i) # Test empty ragged tensors j = tf.ragged.constant([[], [], []]) k = tf.ragged.constant([[], [], []]) assert comparator(j, k) def test_slice() -> None: """Test comparator support for slice objects.""" # Test equal slices a = slice(1, 10, 2) b = slice(1, 10, 2) assert comparator(a, b) # Test slices with different start c = slice(2, 10, 2) assert not comparator(a, c) # Test slices with different stop d = slice(1, 11, 2) assert not comparator(a, d) # Test slices with different step e = slice(1, 10, 3) assert not comparator(a, e) # Test slices with None values f = slice(None, 10, 2) g = slice(None, 10, 2) h = slice(1, 10, 2) assert comparator(f, g) assert not comparator(f, h) # Test slices with all None (equivalent to [:]) i = slice(None, None, None) j = slice(None, None, None) k = slice(None, None, 1) assert comparator(i, j) assert not comparator(i, k) # Test slices with only stop l = slice(5) m = slice(5) n = slice(6) assert comparator(l, m) assert not comparator(l, n) # Test slices with negative values o = slice(-5, -1, 1) p = slice(-5, -1, 1) q = slice(-5, -2, 1) assert comparator(o, p) assert not comparator(o, q) # Test slice is not equal to other types r = slice(1, 10) s = (1, 10) assert not comparator(r, s) def test_numpy_datetime64() -> None: """Test comparator support for numpy datetime64 and timedelta64 types.""" try: import numpy as np except ImportError: pytest.skip("numpy required for this test") # Test datetime64 equality a = np.datetime64("2021-01-01") b = np.datetime64("2021-01-01") c = np.datetime64("2021-01-02") assert comparator(a, b) assert not comparator(a, c) # Test datetime64 with different units d = np.datetime64("2021-01-01", "D") e = np.datetime64("2021-01-01", "D") f = np.datetime64("2021-01-01", "s") # Different unit (seconds) assert 
comparator(d, e) # Note: datetime64 with different units but same moment may or may not be equal # depending on numpy version behavior # Test datetime64 with time g = np.datetime64("2021-01-01T12:00:00") h = np.datetime64("2021-01-01T12:00:00") i = np.datetime64("2021-01-01T12:00:01") assert comparator(g, h) assert not comparator(g, i) # Test timedelta64 equality j = np.timedelta64(1, "D") k = np.timedelta64(1, "D") l = np.timedelta64(2, "D") assert comparator(j, k) assert not comparator(j, l) # Test timedelta64 with different units m = np.timedelta64(1, "h") n = np.timedelta64(1, "h") o = np.timedelta64(60, "m") # Same duration, different unit assert comparator(m, n) # 1 hour == 60 minutes, but they have different units # numpy may treat them as equal or not depending on comparison # Test NaT (Not a Time) - numpy's equivalent of NaN for datetime p = np.datetime64("NaT") q = np.datetime64("NaT") r = np.datetime64("2021-01-01") assert comparator(p, q) # NaT == NaT should be True assert not comparator(p, r) # Test timedelta64 NaT s = np.timedelta64("NaT") t = np.timedelta64("NaT") u = np.timedelta64(1, "D") assert comparator(s, t) # NaT == NaT should be True assert not comparator(s, u) # Test datetime64 is not equal to other types v = np.datetime64("2021-01-01") w = "2021-01-01" assert not comparator(v, w) # Test arrays of datetime64 x = np.array(["2021-01-01", "2021-01-02"], dtype="datetime64") y = np.array(["2021-01-01", "2021-01-02"], dtype="datetime64") z = np.array(["2021-01-01", "2021-01-03"], dtype="datetime64") assert comparator(x, y) assert not comparator(x, z) def test_numpy_0d_array() -> None: """Test comparator handles 0-d numpy arrays without 'iteration over 0-d array' error.""" try: import numpy as np except ImportError: pytest.skip("numpy required for this test") # Test 0-d integer array a = np.array(5) b = np.array(5) c = np.array(6) assert comparator(a, b) assert not comparator(a, c) # Test 0-d float array d = np.array(3.14) e = np.array(3.14) f = 
np.array(2.71) assert comparator(d, e) assert not comparator(d, f) # Test 0-d complex array g = np.array(1 + 2j) h = np.array(1 + 2j) i = np.array(1 + 3j) assert comparator(g, h) assert not comparator(g, i) # Test 0-d string array j = np.array("hello") k = np.array("hello") l = np.array("world") assert comparator(j, k) assert not comparator(j, l) # Test 0-d boolean array m = np.array(True) n = np.array(True) o = np.array(False) assert comparator(m, n) assert not comparator(m, o) # Test 0-d array with NaN p = np.array(np.nan) q = np.array(np.nan) r = np.array(1.0) assert comparator(p, q) # NaN == NaN should be True assert not comparator(p, r) # Test 0-d datetime64 array s = np.array(np.datetime64("2021-01-01")) t = np.array(np.datetime64("2021-01-01")) u = np.array(np.datetime64("2021-01-02")) assert comparator(s, t) assert not comparator(s, u) # Test 0-d array vs scalar v = np.array(5) w = 5 # 0-d array and scalar are different types assert not comparator(v, w) # Test 0-d array vs 1-d array with one element x = np.array(5) y = np.array([5]) # Different shapes assert not comparator(x, y) def test_numpy_dtypes() -> None: """Test comparator for numpy.dtypes types like Float64DType, Int64DType, etc.""" try: import numpy as np from numpy import dtypes except ImportError: pytest.skip("numpy not available") # Test Float64DType a = dtypes.Float64DType() b = dtypes.Float64DType() assert comparator(a, b) # Test Int64DType c = dtypes.Int64DType() d = dtypes.Int64DType() assert comparator(c, d) # Test different DType classes should not be equal assert not comparator(a, c) # Float64DType vs Int64DType # Test various numeric DType classes assert comparator(dtypes.Int8DType(), dtypes.Int8DType()) assert comparator(dtypes.Int16DType(), dtypes.Int16DType()) assert comparator(dtypes.Int32DType(), dtypes.Int32DType()) assert comparator(dtypes.UInt8DType(), dtypes.UInt8DType()) assert comparator(dtypes.UInt16DType(), dtypes.UInt16DType()) assert comparator(dtypes.UInt32DType(), 
dtypes.UInt32DType()) assert comparator(dtypes.UInt64DType(), dtypes.UInt64DType()) assert comparator(dtypes.Float32DType(), dtypes.Float32DType()) assert comparator(dtypes.Complex64DType(), dtypes.Complex64DType()) assert comparator(dtypes.Complex128DType(), dtypes.Complex128DType()) assert comparator(dtypes.BoolDType(), dtypes.BoolDType()) # Test cross-type comparisons should be False assert not comparator(dtypes.Int32DType(), dtypes.Int64DType()) assert not comparator(dtypes.Float32DType(), dtypes.Float64DType()) assert not comparator(dtypes.UInt32DType(), dtypes.Int32DType()) # Test regular np.dtype instances e = np.dtype("float64") f = np.dtype("float64") assert comparator(e, f) g = np.dtype("int64") h = np.dtype("int64") assert comparator(g, h) assert not comparator(e, g) # float64 vs int64 # Test DType class instances vs regular np.dtype (they should be equal if same underlying type) assert comparator(dtypes.Float64DType(), np.dtype("float64")) assert comparator(dtypes.Int64DType(), np.dtype("int64")) assert comparator(dtypes.Int32DType(), np.dtype("int32")) assert comparator(dtypes.BoolDType(), np.dtype("bool")) # Test that DType and np.dtype of different types are not equal assert not comparator(dtypes.Float64DType(), np.dtype("int64")) assert not comparator(dtypes.Int32DType(), np.dtype("float32")) def test_numpy_extended_precision_types() -> None: """Test comparator for numpy extended precision types like clongdouble.""" try: import numpy as np except ImportError: pytest.skip("numpy not available") # Test np.clongdouble (extended precision complex) c1 = np.clongdouble(1 + 2j) c2 = np.clongdouble(1 + 2j) c3 = np.clongdouble(1 + 3j) assert comparator(c1, c2) assert not comparator(c1, c3) # Test np.longdouble (extended precision float) l1 = np.longdouble(1.5) l2 = np.longdouble(1.5) l3 = np.longdouble(2.5) assert comparator(l1, l2) assert not comparator(l1, l3) # Test NaN handling for extended precision complex nan_c1 = np.clongdouble(complex(np.nan, 2)) 
nan_c2 = np.clongdouble(complex(np.nan, 2)) assert comparator(nan_c1, nan_c2) # Test NaN handling for extended precision float nan_l1 = np.longdouble(np.nan) nan_l2 = np.longdouble(np.nan) assert comparator(nan_l1, nan_l2) def test_numpy_typing_types() -> None: """Test comparator for numpy.typing types like NDArray type aliases.""" try: import numpy as np import numpy.typing as npt except ImportError: pytest.skip("numpy or numpy.typing not available") # Test NDArray type alias comparisons arr_type1 = npt.NDArray[np.float64] arr_type2 = npt.NDArray[np.float64] arr_type3 = npt.NDArray[np.int64] assert comparator(arr_type1, arr_type2) assert not comparator(arr_type1, arr_type3) # Test NBitBase (if it can be instantiated) try: nbit1 = npt.NBitBase() nbit2 = npt.NBitBase() # NBitBase instances with empty __dict__ should compare as equal assert comparator(nbit1, nbit2) # Also test with superset_obj=True assert comparator(nbit1, nbit2, superset_obj=True) except TypeError: # NBitBase may not be instantiable in all numpy versions pass def test_numpy_typing_superset_obj() -> None: """Test comparator with superset_obj=True for numpy types.""" try: import numpy as np import numpy.typing as npt except ImportError: pytest.skip("numpy or numpy.typing not available") # Test numpy arrays with object dtype containing dicts (superset scenario) a1 = np.array([{"a": 1}], dtype=object) a2 = np.array([{"a": 1, "b": 2}], dtype=object) # superset assert comparator(a1, a2, superset_obj=True) assert not comparator(a1, a2, superset_obj=False) # Test extended precision types with superset_obj=True c1 = np.clongdouble(1 + 2j) c2 = np.clongdouble(1 + 2j) assert comparator(c1, c2, superset_obj=True) l1 = np.longdouble(1.5) l2 = np.longdouble(1.5) assert comparator(l1, l2, superset_obj=True) # Test NDArray type alias with superset_obj=True arr_type1 = npt.NDArray[np.float64] arr_type2 = npt.NDArray[np.float64] assert comparator(arr_type1, arr_type2, superset_obj=True) # Test numpy structured 
arrays (np.void) with superset_obj=True dt = np.dtype([("name", "S10"), ("age", np.int32)]) a_struct = np.array([("Alice", 25)], dtype=dt) b_struct = np.array([("Alice", 25)], dtype=dt) assert comparator(a_struct[0], b_struct[0], superset_obj=True) # Test numpy random generators with superset_obj=True rng1 = np.random.default_rng(seed=42) rng2 = np.random.default_rng(seed=42) assert comparator(rng1, rng2, superset_obj=True) rs1 = np.random.RandomState(seed=42) rs2 = np.random.RandomState(seed=42) assert comparator(rs1, rs2, superset_obj=True) def test_numba_typed_list() -> None: """Test comparator for numba.typed.List.""" try: import numba from numba.typed import List as NumbaList except ImportError: pytest.skip("numba not available") # Test equal lists a = NumbaList([1, 2, 3]) b = NumbaList([1, 2, 3]) assert comparator(a, b) # Test different values c = NumbaList([1, 2, 4]) assert not comparator(a, c) # Test different lengths d = NumbaList([1, 2, 3, 4]) assert not comparator(a, d) # Test empty lists e = NumbaList.empty_list(item_type=numba.int64) f = NumbaList.empty_list(item_type=numba.int64) assert comparator(e, f) # Test nested values (floats) g = NumbaList([1.0, 2.0, 3.0]) h = NumbaList([1.0, 2.0, 3.0]) assert comparator(g, h) i = NumbaList([1.0, 2.0, 4.0]) assert not comparator(g, i) def test_numba_typed_dict() -> None: """Test comparator for numba.typed.Dict.""" try: import numba from numba.typed import Dict as NumbaDict except ImportError: pytest.skip("numba not available") # Test equal dicts a = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64) a["x"] = 1 a["y"] = 2 b = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64) b["x"] = 1 b["y"] = 2 assert comparator(a, b) # Test different values c = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64) c["x"] = 1 c["y"] = 3 assert not comparator(a, c) # Test different keys d = NumbaDict.empty(key_type=numba.types.unicode_type, 
value_type=numba.int64) d["x"] = 1 d["z"] = 2 assert not comparator(a, d) # Test different lengths e = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64) e["x"] = 1 assert not comparator(a, e) # Test empty dicts f = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64) g = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64) assert comparator(f, g) def test_numba_types() -> None: """Test comparator for numba type objects.""" try: import numba from numba import types except ImportError: pytest.skip("numba not available") # Test basic numeric types from numba module assert comparator(numba.int64, numba.int64) assert comparator(numba.float64, numba.float64) assert comparator(numba.int32, numba.int32) assert comparator(numba.float32, numba.float32) # Test basic numeric types from numba.types module assert comparator(types.int64, types.int64) assert comparator(types.float64, types.float64) assert comparator(types.int8, types.int8) assert comparator(types.int16, types.int16) assert comparator(types.uint8, types.uint8) assert comparator(types.uint16, types.uint16) assert comparator(types.uint32, types.uint32) assert comparator(types.uint64, types.uint64) assert comparator(types.complex64, types.complex64) assert comparator(types.complex128, types.complex128) # Test different types assert not comparator(numba.int64, numba.float64) assert not comparator(numba.int32, numba.int64) assert not comparator(numba.float32, numba.float64) assert not comparator(types.int8, types.int16) assert not comparator(types.uint32, types.int32) assert not comparator(types.complex64, types.complex128) # Test boolean type assert comparator(numba.boolean, numba.boolean) assert comparator(types.boolean, types.boolean) assert not comparator(numba.boolean, numba.int64) # Test special types assert comparator(types.none, types.none) assert comparator(types.void, types.void) assert comparator(types.pyobject, types.pyobject) assert 
comparator(types.unicode_type, types.unicode_type) # Note: types.none and types.void are the same object in numba assert comparator(types.none, types.void) assert not comparator(types.unicode_type, types.pyobject) assert not comparator(types.none, types.int64) # Test array types arr_type1 = types.Array(numba.float64, 1, "C") arr_type2 = types.Array(numba.float64, 1, "C") arr_type3 = types.Array(numba.float64, 2, "C") arr_type4 = types.Array(numba.int64, 1, "C") arr_type5 = types.Array(numba.float64, 1, "F") # Fortran order assert comparator(arr_type1, arr_type2) assert not comparator(arr_type1, arr_type3) # different ndim assert not comparator(arr_type1, arr_type4) # different dtype assert not comparator(arr_type1, arr_type5) # different layout # Test tuple types tuple_type1 = types.UniTuple(types.int64, 3) tuple_type2 = types.UniTuple(types.int64, 3) tuple_type3 = types.UniTuple(types.int64, 4) tuple_type4 = types.UniTuple(types.float64, 3) assert comparator(tuple_type1, tuple_type2) assert not comparator(tuple_type1, tuple_type3) # different count assert not comparator(tuple_type1, tuple_type4) # different dtype # Test heterogeneous tuple types hetero_tuple1 = types.Tuple([types.int64, types.float64]) hetero_tuple2 = types.Tuple([types.int64, types.float64]) hetero_tuple3 = types.Tuple([types.int64, types.int64]) assert comparator(hetero_tuple1, hetero_tuple2) assert not comparator(hetero_tuple1, hetero_tuple3) # Test ListType and DictType list_type1 = types.ListType(types.int64) list_type2 = types.ListType(types.int64) list_type3 = types.ListType(types.float64) assert comparator(list_type1, list_type2) assert not comparator(list_type1, list_type3) dict_type1 = types.DictType(types.unicode_type, types.int64) dict_type2 = types.DictType(types.unicode_type, types.int64) dict_type3 = types.DictType(types.unicode_type, types.float64) dict_type4 = types.DictType(types.int64, types.int64) assert comparator(dict_type1, dict_type2) assert not comparator(dict_type1, 
dict_type3) # different value type assert not comparator(dict_type1, dict_type4) # different key type def test_numba_jit_functions() -> None: """Test comparator for numba JIT-compiled functions.""" try: from numba import jit except ImportError: pytest.skip("numba not available") @jit(nopython=True) def add(x, y): return x + y @jit(nopython=True) def add2(x, y): return x + y @jit(nopython=True) def multiply(x, y): return x * y # Compile the functions by calling them add(1, 2) add2(1, 2) multiply(1, 2) # Same function should compare equal to itself assert comparator(add, add) # Different functions (even with same code) should not compare equal # since they are distinct function objects assert not comparator(add, add2) # Different functions with different code should not compare equal assert not comparator(add, multiply) def test_numba_superset_obj() -> None: """Test comparator for numba types with superset_obj=True.""" try: import numba from numba.typed import Dict as NumbaDict from numba.typed import List as NumbaList except ImportError: pytest.skip("numba not available") # Test NumbaDict with superset_obj=True orig_dict = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64) orig_dict["x"] = 1 orig_dict["y"] = 2 # New dict with same keys - should pass new_dict_same = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64) new_dict_same["x"] = 1 new_dict_same["y"] = 2 assert comparator(orig_dict, new_dict_same, superset_obj=True) # New dict with extra keys - should pass with superset_obj=True new_dict_superset = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64) new_dict_superset["x"] = 1 new_dict_superset["y"] = 2 new_dict_superset["z"] = 3 assert comparator(orig_dict, new_dict_superset, superset_obj=True) # But should fail with superset_obj=False assert not comparator(orig_dict, new_dict_superset, superset_obj=False) # New dict missing keys - should fail even with superset_obj=True new_dict_subset = 
NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64) new_dict_subset["x"] = 1 assert not comparator(orig_dict, new_dict_subset, superset_obj=True) # New dict with different values - should fail new_dict_diff = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64) new_dict_diff["x"] = 1 new_dict_diff["y"] = 99 assert not comparator(orig_dict, new_dict_diff, superset_obj=True) # Test NumbaList with superset_obj=True (lists don't support superset semantics) orig_list = NumbaList([1, 2, 3]) new_list_same = NumbaList([1, 2, 3]) new_list_longer = NumbaList([1, 2, 3, 4]) assert comparator(orig_list, new_list_same, superset_obj=True) # Lists must have same length regardless of superset_obj assert not comparator(orig_list, new_list_longer, superset_obj=True) # Test empty dict with superset_obj=True empty_orig = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64) non_empty_new = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64) non_empty_new["a"] = 1 # Empty orig should match any superset assert comparator(empty_orig, non_empty_new, superset_obj=True) assert not comparator(empty_orig, non_empty_new, superset_obj=False) # ============================================================================= # Tests for pytest temp path normalization (lines 28-69 in comparator.py) # ============================================================================= class TestIsTempPath: """Tests for the _is_temp_path() function.""" def test_standard_pytest_temp_path(self): """Test detection of standard pytest temp paths.""" assert _is_temp_path("/tmp/pytest-of-user/pytest-0/test_something") assert _is_temp_path("/tmp/pytest-of-user/pytest-123/") assert _is_temp_path("/tmp/pytest-of-admin/pytest-999/subdir/file.txt") def test_different_usernames(self): """Test temp paths with various usernames.""" assert _is_temp_path("/tmp/pytest-of-root/pytest-1/") assert 
_is_temp_path("/tmp/pytest-of-john_doe/pytest-42/") assert _is_temp_path("/tmp/pytest-of-user123/pytest-0/test") assert _is_temp_path("/tmp/pytest-of-test-user/pytest-5/data") def test_different_session_numbers(self): """Test temp paths with various session numbers.""" assert _is_temp_path("/tmp/pytest-of-user/pytest-0/") assert _is_temp_path("/tmp/pytest-of-user/pytest-1/") assert _is_temp_path("/tmp/pytest-of-user/pytest-99/") assert _is_temp_path("/tmp/pytest-of-user/pytest-12345/") def test_paths_with_subdirectories(self): """Test temp paths with nested subdirectories.""" assert _is_temp_path("/tmp/pytest-of-user/pytest-0/test_func/subdir") assert _is_temp_path("/tmp/pytest-of-user/pytest-0/a/b/c/d/file.txt") assert _is_temp_path("/tmp/pytest-of-user/pytest-0/test_module0/test_file.py") def test_paths_with_filenames(self): """Test temp paths ending with filenames.""" assert _is_temp_path("/tmp/pytest-of-user/pytest-0/output.json") assert _is_temp_path("/tmp/pytest-of-user/pytest-0/test.log") assert _is_temp_path("/tmp/pytest-of-user/pytest-0/data.csv") def test_non_temp_paths(self): """Test that non-temp paths are correctly identified.""" assert not _is_temp_path("/home/user/project/test.py") assert not _is_temp_path("/tmp/other/directory") assert not _is_temp_path("/var/log/test.log") assert not _is_temp_path("./relative/path") assert not _is_temp_path("test_file.py") def test_similar_but_not_temp_paths(self): """Test paths that look similar but don't match the pattern.""" assert not _is_temp_path("/tmp/pytest-user/pytest-0/") # missing "of-" assert not _is_temp_path("/tmp/pytest-of-user/pytest-/") # no number assert not _is_temp_path("/tmp/pytest-of-/pytest-0/") # empty username assert not _is_temp_path("/tmp/pytest-of-user/pytest-abc/") # non-numeric session def test_edge_cases(self): """Test edge cases for _is_temp_path.""" assert not _is_temp_path("") assert not _is_temp_path("/") assert not _is_temp_path("/tmp/") assert not 
_is_temp_path("/tmp/pytest-of-") def test_path_embedded_in_string(self): """Test that temp paths are detected when embedded in longer strings.""" assert _is_temp_path("Error in /tmp/pytest-of-user/pytest-0/test.py: failed") assert _is_temp_path("File: /tmp/pytest-of-user/pytest-123/output.txt") def test_windows_style_paths(self): """Test that Windows-style paths are not detected as temp paths.""" assert not _is_temp_path("C:\\Users\\test\\pytest") assert not _is_temp_path("D:\\tmp\\pytest-of-user\\pytest-0\\") class TestNormalizeTempPath: """Tests for the _normalize_temp_path() function.""" def test_basic_normalization(self): """Test basic temp path normalization.""" assert _normalize_temp_path("/tmp/pytest-of-user/pytest-0/test") == "/tmp/pytest-temp/test" assert _normalize_temp_path("/tmp/pytest-of-user/pytest-123/test") == "/tmp/pytest-temp/test" def test_different_session_numbers_normalize_same(self): """Test that different session numbers normalize to the same result.""" path1 = _normalize_temp_path("/tmp/pytest-of-user/pytest-0/file.txt") path2 = _normalize_temp_path("/tmp/pytest-of-user/pytest-99/file.txt") path3 = _normalize_temp_path("/tmp/pytest-of-user/pytest-12345/file.txt") assert path1 == path2 == path3 == "/tmp/pytest-temp/file.txt" def test_different_usernames_normalize_same(self): """Test that different usernames normalize to the same result.""" path1 = _normalize_temp_path("/tmp/pytest-of-alice/pytest-0/file.txt") path2 = _normalize_temp_path("/tmp/pytest-of-bob/pytest-0/file.txt") path3 = _normalize_temp_path("/tmp/pytest-of-root/pytest-0/file.txt") assert path1 == path2 == path3 == "/tmp/pytest-temp/file.txt" def test_complex_subdirectories(self): """Test normalization with complex subdirectory structures.""" result = _normalize_temp_path("/tmp/pytest-of-user/pytest-42/test_module/subdir/file.py") assert result == "/tmp/pytest-temp/test_module/subdir/file.py" def test_non_temp_path_unchanged(self): """Test that non-temp paths are returned 
unchanged.""" path = "/home/user/project/test.py" assert _normalize_temp_path(path) == path def test_empty_string(self): """Test normalization of empty string.""" assert _normalize_temp_path("") == "" def test_path_with_multiple_occurrences(self): """Test paths with multiple temp path patterns (unusual but possible in error messages).""" path = "/tmp/pytest-of-user/pytest-0/ref to /tmp/pytest-of-user/pytest-1/other" result = _normalize_temp_path(path) assert result == "/tmp/pytest-temp/ref to /tmp/pytest-temp/other" def test_trailing_slash_handling(self): """Test normalization preserves or removes trailing slashes correctly.""" result1 = _normalize_temp_path("/tmp/pytest-of-user/pytest-0/") result2 = _normalize_temp_path("/tmp/pytest-of-user/pytest-0/subdir/") assert result1 == "/tmp/pytest-temp/" assert result2 == "/tmp/pytest-temp/subdir/" class TestComparatorTempPaths: """Tests for comparator() with temp path strings.""" def test_identical_temp_paths(self): """Test that identical temp paths compare as equal.""" path = "/tmp/pytest-of-user/pytest-0/test.txt" assert comparator(path, path) def test_different_session_numbers(self): """Test that paths differing only in session number are equal.""" path1 = "/tmp/pytest-of-user/pytest-0/output.txt" path2 = "/tmp/pytest-of-user/pytest-99/output.txt" assert comparator(path1, path2) def test_different_usernames(self): """Test that paths differing in username are equal.""" path1 = "/tmp/pytest-of-alice/pytest-0/result.json" path2 = "/tmp/pytest-of-bob/pytest-0/result.json" assert comparator(path1, path2) def test_different_usernames_and_sessions(self): """Test that paths differing in both username and session are equal.""" path1 = "/tmp/pytest-of-alice/pytest-10/data/file.csv" path2 = "/tmp/pytest-of-bob/pytest-999/data/file.csv" assert comparator(path1, path2) def test_different_subdirectories_not_equal(self): """Test that paths with different subdirectories are not equal.""" path1 = 
"/tmp/pytest-of-user/pytest-0/subdir1/file.txt" path2 = "/tmp/pytest-of-user/pytest-0/subdir2/file.txt" assert not comparator(path1, path2) def test_different_filenames_not_equal(self): """Test that paths with different filenames are not equal.""" path1 = "/tmp/pytest-of-user/pytest-0/file1.txt" path2 = "/tmp/pytest-of-user/pytest-0/file2.txt" assert not comparator(path1, path2) def test_temp_path_vs_non_temp_path(self): """Test that temp paths don't match non-temp paths.""" temp_path = "/tmp/pytest-of-user/pytest-0/file.txt" non_temp_path = "/home/user/file.txt" assert not comparator(temp_path, non_temp_path) def test_regular_strings_still_work(self): """Test that regular string comparison still works.""" assert comparator("hello", "hello") assert not comparator("hello", "world") assert comparator("", "") assert not comparator("test", "") def test_non_temp_paths_must_be_exact(self): """Test that non-temp paths require exact equality.""" path1 = "/home/user/project/file.txt" path2 = "/home/user/project/file.txt" path3 = "/home/user/project/other.txt" assert comparator(path1, path2) assert not comparator(path1, path3) class TestComparatorTempPathsInNestedStructures: """Tests for comparator() with temp paths in nested data structures.""" def test_temp_paths_in_list(self): """Test temp paths inside lists.""" list1 = ["/tmp/pytest-of-alice/pytest-0/file.txt", "other"] list2 = ["/tmp/pytest-of-bob/pytest-99/file.txt", "other"] assert comparator(list1, list2) def test_temp_paths_in_tuple(self): """Test temp paths inside tuples.""" tuple1 = ("/tmp/pytest-of-user/pytest-0/a.txt", "/tmp/pytest-of-user/pytest-0/b.txt") tuple2 = ("/tmp/pytest-of-user/pytest-123/a.txt", "/tmp/pytest-of-user/pytest-123/b.txt") assert comparator(tuple1, tuple2) def test_temp_paths_in_dict_values(self): """Test temp paths as dictionary values.""" dict1 = {"path": "/tmp/pytest-of-user/pytest-0/output.json", "name": "test"} dict2 = {"path": "/tmp/pytest-of-user/pytest-999/output.json", "name": 
"test"} assert comparator(dict1, dict2) def test_temp_paths_in_dict_keys_not_supported(self): """Test that temp paths as dictionary keys must match exactly (keys are not normalized).""" # Dict keys use direct comparison, so temp paths as keys won't be normalized # This tests the expected behavior dict1 = {"/tmp/pytest-of-user/pytest-0/key": "value"} dict2 = {"/tmp/pytest-of-user/pytest-0/key": "value"} assert comparator(dict1, dict2) def test_temp_paths_in_nested_dict(self): """Test temp paths in nested dictionaries.""" nested1 = { "config": { "output_path": "/tmp/pytest-of-alice/pytest-5/results", "log_path": "/tmp/pytest-of-alice/pytest-5/logs", } } nested2 = { "config": { "output_path": "/tmp/pytest-of-bob/pytest-10/results", "log_path": "/tmp/pytest-of-bob/pytest-10/logs", } } assert comparator(nested1, nested2) def test_temp_paths_in_deeply_nested_structure(self): """Test temp paths in deeply nested structures.""" deep1 = {"a": {"b": {"c": ["/tmp/pytest-of-user/pytest-0/file.txt"]}}} deep2 = {"a": {"b": {"c": ["/tmp/pytest-of-other/pytest-99/file.txt"]}}} assert comparator(deep1, deep2) def test_mixed_temp_and_regular_paths(self): """Test structures with both temp and regular paths.""" data1 = {"temp": "/tmp/pytest-of-user/pytest-0/temp.txt", "regular": "/home/user/file.txt"} data2 = {"temp": "/tmp/pytest-of-user/pytest-99/temp.txt", "regular": "/home/user/file.txt"} assert comparator(data1, data2) data3 = {"temp": "/tmp/pytest-of-user/pytest-99/temp.txt", "regular": "/home/user/different.txt"} assert not comparator(data1, data3) def test_temp_paths_in_deque(self): """Test temp paths inside deque.""" from collections import deque d1 = deque(["/tmp/pytest-of-user/pytest-0/file.txt"]) d2 = deque(["/tmp/pytest-of-user/pytest-123/file.txt"]) assert comparator(d1, d2) def test_temp_paths_in_chainmap(self): """Test temp paths inside ChainMap.""" from collections import ChainMap cm1 = ChainMap({"path": "/tmp/pytest-of-user/pytest-0/file.txt"}) cm2 = ChainMap({"path": 
"/tmp/pytest-of-user/pytest-99/file.txt"}) assert comparator(cm1, cm2) class TestComparatorTempPathsEdgeCases: """Edge case tests for temp path handling in comparator.""" def test_empty_string_vs_temp_path(self): """Test empty string comparison with temp path.""" assert not comparator("", "/tmp/pytest-of-user/pytest-0/file.txt") assert not comparator("/tmp/pytest-of-user/pytest-0/file.txt", "") def test_path_with_special_characters(self): """Test temp paths containing special characters in filenames.""" path1 = "/tmp/pytest-of-user/pytest-0/file with spaces.txt" path2 = "/tmp/pytest-of-user/pytest-99/file with spaces.txt" assert comparator(path1, path2) path3 = "/tmp/pytest-of-user/pytest-0/file-with-dashes.txt" path4 = "/tmp/pytest-of-user/pytest-99/file-with-dashes.txt" assert comparator(path3, path4) def test_path_with_unicode_characters(self): """Test temp paths with unicode characters.""" path1 = "/tmp/pytest-of-user/pytest-0/файл.txt" path2 = "/tmp/pytest-of-user/pytest-99/файл.txt" assert comparator(path1, path2) def test_very_long_session_number(self): """Test temp paths with very long session numbers.""" path1 = "/tmp/pytest-of-user/pytest-9999999999/file.txt" path2 = "/tmp/pytest-of-user/pytest-0/file.txt" assert comparator(path1, path2) def test_username_with_special_characters(self): """Test temp paths with special characters in username.""" path1 = "/tmp/pytest-of-user-name/pytest-0/file.txt" path2 = "/tmp/pytest-of-other-user/pytest-99/file.txt" assert comparator(path1, path2) def test_path_only_differs_in_temp_portion(self): """Test that only the temp portion is normalized, rest must match.""" path1 = "/tmp/pytest-of-user/pytest-0/subdir/nested/file.txt" path2 = "/tmp/pytest-of-user/pytest-99/subdir/nested/file.txt" assert comparator(path1, path2) path3 = "/tmp/pytest-of-user/pytest-0/subdir/nested/other.txt" assert not comparator(path1, path3) def test_multiple_slashes(self): """Test temp paths with multiple consecutive slashes (should still 
work).""" # Note: The regex handles the standard format, extra slashes may not be normalized path1 = "/tmp/pytest-of-user/pytest-0/file.txt" path2 = "/tmp/pytest-of-user/pytest-99/file.txt" assert comparator(path1, path2) def test_temp_path_at_start_middle_end(self): """Test that temp paths are detected regardless of position in string.""" # Path at start assert _is_temp_path("/tmp/pytest-of-user/pytest-0/test") # Path in middle (embedded in error message) assert _is_temp_path("Error: /tmp/pytest-of-user/pytest-0/test failed") # Path at end assert _is_temp_path("Output saved to /tmp/pytest-of-user/pytest-0/") def test_partial_temp_path_patterns(self): """Test strings that partially match temp path pattern.""" # Missing components assert not _is_temp_path("/tmp/pytest-of-user/") assert not _is_temp_path("/tmp/pytest-0/") assert not _is_temp_path("pytest-of-user/pytest-0/") class TestPytestTempPathPatternRegex: """Tests for the PYTEST_TEMP_PATH_PATTERN regex directly.""" def test_pattern_matches_standard_format(self): """Test regex matches standard pytest temp path format.""" assert PYTEST_TEMP_PATH_PATTERN.search("/tmp/pytest-of-user/pytest-0/") assert PYTEST_TEMP_PATH_PATTERN.search("/tmp/pytest-of-user/pytest-123/file") def test_pattern_captures_correctly(self): """Test that the pattern substitution works correctly.""" result = PYTEST_TEMP_PATH_PATTERN.sub("REPLACED", "/tmp/pytest-of-user/pytest-0/file.txt") assert result == "REPLACEDfile.txt" def test_pattern_handles_multiple_matches(self): """Test pattern with multiple temp paths in same string.""" text = "/tmp/pytest-of-a/pytest-1/ and /tmp/pytest-of-b/pytest-2/" result = PYTEST_TEMP_PATH_PATTERN.sub("X", text) assert result == "X and X" def test_pattern_greedy_behavior(self): """Test that the pattern doesn't over-match.""" # The pattern should stop at the trailing slash of the session number path = "/tmp/pytest-of-user/pytest-0/subdir/pytest-1/file.txt" result = PYTEST_TEMP_PATH_PATTERN.sub("X", path) # The 
first temp path should be replaced, but "pytest-1" in subdir shouldn't trigger assert "subdir" in result class TestComparatorTempPathsWithSuperset: """Tests for temp path comparison with superset_obj=True.""" def test_superset_with_temp_paths_in_dict(self): """Test superset comparison with temp paths in dictionaries.""" orig = {"path": "/tmp/pytest-of-user/pytest-0/file.txt"} new = {"path": "/tmp/pytest-of-user/pytest-99/file.txt", "extra": "data"} assert comparator(orig, new, superset_obj=True) def test_superset_temp_paths_must_still_match(self): """Test that temp paths must still be equivalent in superset mode.""" orig = {"path": "/tmp/pytest-of-user/pytest-0/file.txt"} new = {"path": "/tmp/pytest-of-user/pytest-99/other.txt", "extra": "data"} assert not comparator(orig, new, superset_obj=True) def test_superset_nested_dict_with_temp_paths(self): """Test superset comparison with temp paths in nested dictionaries.""" orig = {"config": {"output": "/tmp/pytest-of-alice/pytest-5/results.json"}} new = { "config": {"output": "/tmp/pytest-of-bob/pytest-100/results.json", "debug": True}, "metadata": {"version": "1.0"}, } assert comparator(orig, new, superset_obj=True) def test_superset_multiple_temp_paths_in_dict(self): """Test superset with multiple temp paths in dictionary values.""" orig = {"input": "/tmp/pytest-of-user/pytest-0/input.txt", "output": "/tmp/pytest-of-user/pytest-0/output.txt"} new = { "input": "/tmp/pytest-of-user/pytest-99/input.txt", "output": "/tmp/pytest-of-user/pytest-99/output.txt", "log": "/tmp/pytest-of-user/pytest-99/debug.log", } assert comparator(orig, new, superset_obj=True) def test_superset_temp_path_in_list_inside_dict(self): """Test superset with temp paths in lists inside dictionaries.""" orig = {"files": ["/tmp/pytest-of-user/pytest-0/a.txt", "/tmp/pytest-of-user/pytest-0/b.txt"]} new = {"files": ["/tmp/pytest-of-user/pytest-99/a.txt", "/tmp/pytest-of-user/pytest-99/b.txt"], "count": 2} assert comparator(orig, new, superset_obj=True) 
def test_superset_false_when_temp_path_missing(self): """Test superset fails when temp path key is missing in new.""" orig = {"path": "/tmp/pytest-of-user/pytest-0/file.txt"} new = {"other": "data"} assert not comparator(orig, new, superset_obj=True) def test_superset_temp_path_with_different_filenames_fails(self): """Test superset fails when normalized temp paths have different filenames.""" orig = {"result": "/tmp/pytest-of-user/pytest-0/output_v1.json"} new = {"result": "/tmp/pytest-of-user/pytest-99/output_v2.json", "extra": "data"} assert not comparator(orig, new, superset_obj=True) def test_superset_mixed_temp_and_regular_paths(self): """Test superset with mix of temp paths and regular paths.""" orig = {"temp_file": "/tmp/pytest-of-user/pytest-0/temp.txt", "config_file": "/etc/app/config.yaml"} new = { "temp_file": "/tmp/pytest-of-user/pytest-99/temp.txt", "config_file": "/etc/app/config.yaml", "extra_key": "extra_value", } assert comparator(orig, new, superset_obj=True) def test_superset_regular_path_must_match_exactly(self): """Test that regular paths must match exactly even in superset mode.""" orig = {"temp_file": "/tmp/pytest-of-user/pytest-0/temp.txt", "config_file": "/etc/app/config.yaml"} new = { "temp_file": "/tmp/pytest-of-user/pytest-99/temp.txt", "config_file": "/etc/app/other.yaml", "extra_key": "extra_value", } assert not comparator(orig, new, superset_obj=True) def test_superset_deeply_nested_temp_paths(self): """Test superset with deeply nested structures containing temp paths.""" orig = {"level1": {"level2": {"level3": {"path": "/tmp/pytest-of-user/pytest-0/deep.txt"}}}} new = { "level1": { "level2": { "level3": {"path": "/tmp/pytest-of-other/pytest-999/deep.txt", "extra": True}, "sibling": "value", } }, "top_level_extra": 123, } assert comparator(orig, new, superset_obj=True) def test_superset_with_attrs_class_containing_temp_paths(self): """Test superset with attrs classes containing temp paths.""" try: import attr except ImportError: 
pytest.skip("attrs not installed") @attr.s class Config: path = attr.ib() name = attr.ib(default="default") # Test that temp paths are normalized in attrs classes orig = Config(path="/tmp/pytest-of-user/pytest-0/config.json") new = Config(path="/tmp/pytest-of-user/pytest-99/config.json") assert comparator(orig, new, superset_obj=True) # Test that different non-temp values still fail orig2 = Config(path="/tmp/pytest-of-user/pytest-0/config.json", name="name1") new2 = Config(path="/tmp/pytest-of-user/pytest-99/config.json", name="name2") assert not comparator(orig2, new2, superset_obj=True) def test_superset_with_class_dict_containing_temp_paths(self): """Test superset with regular class objects containing temp paths.""" class Result: def __init__(self, output_path): self.output_path = output_path class ResultExtended: def __init__(self, output_path, extra=None): self.output_path = output_path self.extra = extra # Note: These are different classes, so type check will fail first # Let's use the same class orig = Result("/tmp/pytest-of-user/pytest-0/result.json") new = Result("/tmp/pytest-of-user/pytest-99/result.json") # Add extra attribute to new new.extra_field = "extra_data" assert comparator(orig, new, superset_obj=True) def test_superset_list_temp_paths_must_have_same_length(self): """Test that lists with temp paths must have same length even in superset mode.""" # superset_obj doesn't apply to list lengths - they must match orig = ["/tmp/pytest-of-user/pytest-0/a.txt"] new = ["/tmp/pytest-of-user/pytest-99/a.txt", "/tmp/pytest-of-user/pytest-99/b.txt"] assert not comparator(orig, new, superset_obj=True) def test_superset_tuple_temp_paths_must_have_same_length(self): """Test that tuples with temp paths must have same length even in superset mode.""" orig = ("/tmp/pytest-of-user/pytest-0/a.txt",) new = ("/tmp/pytest-of-user/pytest-99/a.txt", "/tmp/pytest-of-user/pytest-99/b.txt") assert not comparator(orig, new, superset_obj=True) def 
test_superset_with_exception_containing_temp_path(self): """Test superset with exception objects containing temp paths in attributes.""" class CustomError(Exception): def __init__(self, message, path): super().__init__(message) self.path = path orig = CustomError("File error", "/tmp/pytest-of-user/pytest-0/file.txt") new = CustomError("File error", "/tmp/pytest-of-user/pytest-99/file.txt") new.extra_info = "additional data" assert comparator(orig, new, superset_obj=True) class TestComparatorTempPathsRealisticScenarios: """Tests simulating realistic scenarios where temp path comparison matters.""" def test_test_output_comparison(self): """Simulate comparing test outputs that contain temp paths.""" original_result = { "status": "success", "output_file": "/tmp/pytest-of-ci-runner/pytest-42/test_output/results.json", "log_file": "/tmp/pytest-of-ci-runner/pytest-42/test_output/debug.log", } replay_result = { "status": "success", "output_file": "/tmp/pytest-of-local-user/pytest-0/test_output/results.json", "log_file": "/tmp/pytest-of-local-user/pytest-0/test_output/debug.log", } assert comparator(original_result, replay_result) def test_exception_message_with_temp_path(self): """Test comparing exception-like structures with temp paths.""" exc1 = {"type": "FileNotFoundError", "message": "File not found: /tmp/pytest-of-user/pytest-0/missing.txt"} exc2 = {"type": "FileNotFoundError", "message": "File not found: /tmp/pytest-of-user/pytest-99/missing.txt"} assert comparator(exc1, exc2) def test_function_return_with_temp_path(self): """Test comparing function returns that include temp paths.""" # Simulating a function that returns a created file path return1 = "/tmp/pytest-of-user/pytest-5/generated_file_abc123.txt" return2 = "/tmp/pytest-of-user/pytest-10/generated_file_abc123.txt" assert comparator(return1, return2) def test_list_of_created_files(self): """Test comparing lists of created file paths.""" files1 = [ "/tmp/pytest-of-user/pytest-0/output/file1.txt", 
"/tmp/pytest-of-user/pytest-0/output/file2.txt", "/tmp/pytest-of-user/pytest-0/output/file3.txt", ] files2 = [ "/tmp/pytest-of-user/pytest-99/output/file1.txt", "/tmp/pytest-of-user/pytest-99/output/file2.txt", "/tmp/pytest-of-user/pytest-99/output/file3.txt", ] assert comparator(files1, files2) def test_config_object_with_paths(self): """Test comparing config-like objects with multiple paths.""" config1 = { "temp_dir": "/tmp/pytest-of-user/pytest-0/", "cache_dir": "/tmp/pytest-of-user/pytest-0/cache/", "output_dir": "/tmp/pytest-of-user/pytest-0/output/", "permanent_dir": "/home/user/data/", } config2 = { "temp_dir": "/tmp/pytest-of-other/pytest-100/", "cache_dir": "/tmp/pytest-of-other/pytest-100/cache/", "output_dir": "/tmp/pytest-of-other/pytest-100/output/", "permanent_dir": "/home/user/data/", } assert comparator(config1, config2) class TestPythonTempfilePaths: """Tests for Python tempfile paths (from tempfile.mkdtemp() or TemporaryDirectory()).""" def test_is_temp_path_detects_python_tempfile(self): """Test that _is_temp_path detects Python tempfile paths.""" assert _is_temp_path("/tmp/tmpqtwy7hpf/special.txt") assert _is_temp_path("/tmp/tmpp6wx3tz3/special.txt") assert _is_temp_path("/tmp/tmpabcdef12/") assert _is_temp_path("/tmp/tmp_underscore/file.txt") def test_is_temp_path_various_tempfile_names(self): """Test various tempfile naming patterns.""" assert _is_temp_path("/tmp/tmpABCDEF/file.txt") # uppercase assert _is_temp_path("/tmp/tmp123456/file.txt") # numeric assert _is_temp_path("/tmp/tmpaBc123/file.txt") # mixed assert _is_temp_path("/tmp/tmp_test_dir/subdir/file.txt") # with underscore def test_is_temp_path_non_tempfile(self): """Test that non-tempfile paths are not detected.""" assert not _is_temp_path("/tmp/mydir/file.txt") # doesn't start with tmp assert not _is_temp_path("/tmp/temp/file.txt") # temp, not tmp assert not _is_temp_path("/home/user/tmp123/file.txt") # not in /tmp/ def test_normalize_temp_path_python_tempfile(self): """Test 
normalization of Python tempfile paths.""" path1 = _normalize_temp_path("/tmp/tmpqtwy7hpf/special.txt") path2 = _normalize_temp_path("/tmp/tmpp6wx3tz3/special.txt") assert path1 == path2 == "/tmp/python-temp/special.txt" def test_normalize_temp_path_preserves_subdirs(self): """Test that subdirectories are preserved during normalization.""" result = _normalize_temp_path("/tmp/tmpabcdef12/subdir/nested/file.txt") assert result == "/tmp/python-temp/subdir/nested/file.txt" def test_comparator_python_tempfile_paths_equal(self): """Test that different tempfile paths with same content are equal.""" path1 = "/tmp/tmpqtwy7hpf/special.txt" path2 = "/tmp/tmpp6wx3tz3/special.txt" assert comparator(path1, path2) def test_comparator_python_tempfile_different_filenames_not_equal(self): """Test that different filenames in tempfile paths are not equal.""" path1 = "/tmp/tmpqtwy7hpf/special.txt" path2 = "/tmp/tmpp6wx3tz3/different.txt" assert not comparator(path1, path2) def test_comparator_python_tempfile_in_tuple(self): """Test tempfile paths in tuples.""" orig = ("/tmp/tmpqtwy7hpf/special.txt",) new = ("/tmp/tmpp6wx3tz3/special.txt",) assert comparator(orig, new) def test_comparator_python_tempfile_in_list(self): """Test tempfile paths in lists.""" orig = ["/tmp/tmpabcdef12/file1.txt", "/tmp/tmpabcdef12/file2.txt"] new = ["/tmp/tmpxyz78901/file1.txt", "/tmp/tmpxyz78901/file2.txt"] assert comparator(orig, new) def test_comparator_python_tempfile_in_dict(self): """Test tempfile paths in dictionaries.""" orig = {"output": "/tmp/tmpabcdef12/result.json"} new = {"output": "/tmp/tmpxyz78901/result.json"} assert comparator(orig, new) def test_comparator_mixed_pytest_and_python_tempfile(self): """Test that pytest and Python tempfile paths don't match each other.""" pytest_path = "/tmp/pytest-of-user/pytest-0/file.txt" python_path = "/tmp/tmpabcdef12/file.txt" # These should not be equal - they're different temp path types assert not comparator(pytest_path, python_path) def 
test_python_tempfile_pattern_regex(self): """Test the PYTHON_TEMPFILE_PATTERN regex directly.""" assert PYTHON_TEMPFILE_PATTERN.search("/tmp/tmpabcdef/file.txt") assert PYTHON_TEMPFILE_PATTERN.search("/tmp/tmp123456/") assert not PYTHON_TEMPFILE_PATTERN.search("/tmp/mydir/file.txt") assert not PYTHON_TEMPFILE_PATTERN.search("/home/tmp123/file.txt") @pytest.mark.skipif(sys.version_info < (3, 10), reason="types.UnionType requires Python 3.10+") class TestUnionType: def test_union_type_equal(self): assert comparator(int | str, int | str) def test_union_type_not_equal(self): assert not comparator(int | str, int | float) def test_union_type_order_independent(self): assert comparator(int | str, str | int) def test_union_type_multiple_args(self): assert comparator(int | str | float, int | str | float) def test_union_type_in_list(self): assert comparator([int | str, 1], [int | str, 1]) def test_union_type_in_dict(self): assert comparator({"key": int | str}, {"key": int | str}) def test_union_type_vs_none(self): assert not comparator(int | str, None) class SlotsOnly: __slots__ = ("x", "y") def __init__(self, x, y): self.x = x self.y = y class SlotsInherited(SlotsOnly): __slots__ = ("z",) def __init__(self, x, y, z): super().__init__(x, y) self.z = z class TestSlotsObjects: def test_slots_equal(self): assert comparator(SlotsOnly(1, 2), SlotsOnly(1, 2)) def test_slots_not_equal(self): assert not comparator(SlotsOnly(1, 2), SlotsOnly(1, 3)) def test_slots_inherited_equal(self): assert comparator(SlotsInherited(1, 2, 3), SlotsInherited(1, 2, 3)) def test_slots_inherited_not_equal(self): assert not comparator(SlotsInherited(1, 2, 3), SlotsInherited(1, 2, 4)) def test_slots_nested(self): a = SlotsOnly(SlotsOnly(1, 2), [3, 4]) b = SlotsOnly(SlotsOnly(1, 2), [3, 4]) assert comparator(a, b) def test_slots_nested_not_equal(self): a = SlotsOnly(SlotsOnly(1, 2), [3, 4]) b = SlotsOnly(SlotsOnly(1, 9), [3, 4]) assert not comparator(a, b)