codeflash/tests/test_comparator.py
Kevin Turcios 19bd6e4bad test: sync test files from main (safe, main-only changes)
34 test files updated with main's refactored tests for new language
support protocol, JS/TS improvements, and code context extraction.
2026-03-02 15:25:50 -05:00

5588 lines
178 KiB
Python

import array # Add import for array
import ast
import copy
import dataclasses
import datetime
import decimal
import re
import sys
import uuid
import weakref
from collections import ChainMap, Counter, OrderedDict, UserDict, UserList, UserString, defaultdict, deque, namedtuple
from enum import Enum, Flag, IntFlag, auto
from pathlib import Path
import pydantic
import pytest
from codeflash.either import Failure, Success
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType
from codeflash.verification.comparator import (
PYTEST_TEMP_PATH_PATTERN,
PYTHON_TEMPFILE_PATTERN,
_extract_exception_from_message,
_get_wrapped_exception,
_is_temp_path,
_normalize_temp_path,
comparator,
)
from codeflash.verification.equivalence import compare_test_results
def test_basic_python_objects() -> None:
    """Builtin scalars, containers, binary types, builtin callables and type objects."""
    # Each entry: (left, right, whether comparator should consider them equal).
    cases = [
        # ints vs. ints and None
        (5, 5, True),
        (5, 6, False),
        (5, None, False),
        # floats and None, checked in both argument orders
        (5.0, 5.0, True),
        (5.0, 6.0, False),
        (5.0, None, False),
        (None, 5.0, False),
        (None, None, True),
        # strings
        ("Hello", "Hello", True),
        ("Hello", "World", False),
        # lists
        ([1, 2, 3], [1, 2, 3], True),
        ([1, 2, 3], [1, 2, 4], False),
        # dicts: differing value, differing key, extra key
        ({"a": 1, "b": 2}, {"a": 1, "b": 2}, True),
        ({"a": 1, "b": 2}, {"a": 1, "b": 3}, False),
        ({"a": 1, "b": 2}, {"c": 1, "b": 2}, False),
        ({"a": 1, "b": 2}, {"a": 1, "b": 2, "c": 3}, False),
        # tuples, including a tuple vs. a list holding the same items
        ((1, 2, "str"), (1, 2, "str"), True),
        ((1, 2, "str"), (1, 2, "str2"), False),
        ((1, 2, "str"), [1, 2, "str"], False),
        # sets are order-insensitive
        ({1, 2, 3}, {2, 3, 1}, True),
        ({1, 2, 3}, {1, 2, 4}, False),
        ({1, 2, 3}, {1, 2, 3, 4}, False),
        # bytes: 65 == ord("A"), 66 == ord("B")
        (b"A", b"A", True),
        (b"A", b"B", False),
        # 65 encoded in two bytes: little-endian vs. big-endian
        (b"A\x00", b"\x00A", False),
        # bytearray and memoryview
        (bytearray([65, 64, 63]), bytearray([65, 64, 63]), True),
        (bytearray([65, 64, 63]), bytearray([65, 64, 62]), False),
        (memoryview(bytearray([65, 64, 63])), memoryview(bytearray([65, 64, 63])), True),
        (memoryview(bytearray([65, 64, 63])), memoryview(bytearray([65, 64, 62])), False),
        # frozensets
        (frozenset([1, 2, 3]), frozenset([2, 3, 1]), True),
        (frozenset([1, 2, 3]), frozenset([1, 2, 4]), False),
        (frozenset([1, 2, 3]), frozenset([1, 2, 3, 4]), False),
        # builtin callables
        (pow, pow, True),
        (map, pow, False),
        (pow, abs, False),
        # bare object() instances vs. each other and vs. a builtin
        (object(), object(), True),
        (object(), abs, False),
        # type objects (type([]) is list, type({}) is dict)
        (list, list, True),
        (list, dict, False),
    ]
    for left, right, expected in cases:
        assert bool(comparator(left, right)) is expected, (left, right)
def test_weakref() -> None:
    """Test comparator for weakref.ref objects.

    Live references are checked against equal and differing referents, both
    standalone and nested inside dicts/lists. Two dead references must compare
    equal; a dead reference must not equal a live one (in either order).
    """
    import gc

    # Helper class that supports weak references and has comparable __dict__
    class Holder:
        def __init__(self, value):
            self.value = value

    # Test weak references to the same object
    obj = Holder([1, 2, 3])
    ref1 = weakref.ref(obj)
    ref2 = weakref.ref(obj)
    assert comparator(ref1, ref2)

    # Test weak references to equivalent but different objects
    obj1 = Holder({"key": "value"})
    obj2 = Holder({"key": "value"})
    ref1 = weakref.ref(obj1)
    ref2 = weakref.ref(obj2)
    assert comparator(ref1, ref2)

    # Test weak references to different objects
    obj1 = Holder([1, 2, 3])
    obj2 = Holder([1, 2, 4])
    ref1 = weakref.ref(obj1)
    ref2 = weakref.ref(obj2)
    assert not comparator(ref1, ref2)

    # Test weak references whose referents hold lists of different lengths
    obj1 = Holder([1, 2, 3])
    obj2 = Holder([1, 2, 3, 4])
    ref1 = weakref.ref(obj1)
    ref2 = weakref.ref(obj2)
    assert not comparator(ref1, ref2)

    # Test dead weak references (both dead)
    obj1 = Holder([1, 2, 3])
    obj2 = Holder([1, 2, 3])
    ref1 = weakref.ref(obj1)
    ref2 = weakref.ref(obj2)
    del obj1
    del obj2
    # CPython frees the referents immediately through refcounting; other
    # interpreters (e.g. PyPy) only do so on a collection pass.
    gc.collect()
    # Guard the precondition: if a referent survived, the assertion below would
    # silently exercise the live-reference path instead of the dead one.
    assert ref1() is None
    assert ref2() is None
    # Both refs are now dead, should be equal
    assert comparator(ref1, ref2)

    # Test one dead, one alive weak reference
    obj1 = Holder([1, 2, 3])
    obj2 = Holder([1, 2, 3])
    ref1 = weakref.ref(obj1)
    ref2 = weakref.ref(obj2)
    del obj1
    gc.collect()
    assert ref1() is None
    assert ref2() is not None
    # ref1 is dead, ref2 is alive, should not be equal
    assert not comparator(ref1, ref2)
    assert not comparator(ref2, ref1)

    # Test weak references to nested structures
    obj1 = Holder({"nested": [1, 2, {"inner": "value"}]})
    obj2 = Holder({"nested": [1, 2, {"inner": "value"}]})
    ref1 = weakref.ref(obj1)
    ref2 = weakref.ref(obj2)
    assert comparator(ref1, ref2)

    # Test weak references to nested structures with differences
    obj1 = Holder({"nested": [1, 2, {"inner": "value1"}]})
    obj2 = Holder({"nested": [1, 2, {"inner": "value2"}]})
    ref1 = weakref.ref(obj1)
    ref2 = weakref.ref(obj2)
    assert not comparator(ref1, ref2)

    # Test weak references in a dictionary (simulating __dict__ with weakrefs)
    obj1 = Holder([1, 2, 3])
    obj2 = Holder([1, 2, 3])
    dict1 = {"data": 42, "ref": weakref.ref(obj1)}
    dict2 = {"data": 42, "ref": weakref.ref(obj2)}
    assert comparator(dict1, dict2)

    # Test weak references in a dictionary with different referents
    obj1 = Holder([1, 2, 3])
    obj2 = Holder([4, 5, 6])
    dict1 = {"data": 42, "ref": weakref.ref(obj1)}
    dict2 = {"data": 42, "ref": weakref.ref(obj2)}
    assert not comparator(dict1, dict2)

    # Test weak references in a list
    obj1 = Holder({"a": 1})
    obj2 = Holder({"a": 1})
    list1 = [weakref.ref(obj1), "other"]
    list2 = [weakref.ref(obj2), "other"]
    assert comparator(list1, list2)
def test_weakref_to_custom_objects() -> None:
    """Test comparator for weakref.ref to custom class instances."""

    class MyClass:
        def __init__(self, value):
            self.value = value

    class Container:
        def __init__(self, items):
            self.items = items

    # (class, constructor arg for the left referent, arg for the right, expected)
    cases = [
        (MyClass, 42, 42, True),
        (MyClass, 42, 99, False),
        (Container, [1, 2, 3], [1, 2, 3], True),
        (Container, [1, 2, 3], [1, 2, 4], False),
    ]
    for cls, left_arg, right_arg, expected in cases:
        # Bind the referents to names so they stay alive during the comparison.
        left, right = cls(left_arg), cls(right_arg)
        left_ref, right_ref = weakref.ref(left), weakref.ref(right)
        assert bool(comparator(left_ref, right_ref)) is expected, (cls.__name__, left_arg, right_arg)
def test_weakref_with_callbacks() -> None:
    """Weakrefs registered with a callback still compare by their referents."""

    class Holder:
        def __init__(self, value):
            self.value = value

    fired = []

    def on_collect(dead_ref):
        fired.append(dead_ref)

    def refs_to(first, second):
        # Both references share the same callback.
        return weakref.ref(first, on_collect), weakref.ref(second, on_collect)

    left, right = Holder([1, 2, 3]), Holder([1, 2, 3])
    left_ref, right_ref = refs_to(left, right)
    assert comparator(left_ref, right_ref)

    left, right = Holder([1, 2, 3]), Holder([4, 5, 6])
    left_ref, right_ref = refs_to(left, right)
    assert not comparator(left_ref, right_ref)
@pytest.mark.parametrize(
    ("r1", "r2", "expected"),
    [
        (range(1, 10), range(1, 10), True),  # equal
        (range(10), range(1, 10), False),  # different start
        (range(2, 10), range(1, 10), False),
        (range(1, 5), range(1, 10), False),  # different stop
        (range(1, 20), range(1, 10), False),
        (range(1, 10, 1), range(1, 10, 2), False),  # different step
        (range(1, 10, 3), range(1, 10, 2), False),
        (range(-5, 0), range(-5, 0), True),  # negative ranges
        (range(-10, 0), range(-5, 0), False),
        (range(5, 1), range(10, 5), True),  # empty ranges with different bounds
        (range(5, 1), range(5, 1), True),
        (range(7), range(7), True),
        (range(7), range(0, 7, 1), True),  # implicit vs. explicit start/step
        (range(10, 0, -1), range(10, 0, -1), True),  # descending ranges
        (range(10, 0, -1), range(0, 10), False),  # same values, reversed order
    ],
)
def test_ranges(r1: range, r2: range, expected: bool) -> None:
    """Ranges: equal, differing start/stop/step, negative, empty and descending cases."""
    assert comparator(r1, r2) == expected
def test_standard_python_library_objects() -> None:
    """datetime, decimal, enum, re, array and uuid values.

    Each type is checked against an independently built equal value and
    against a value that differs in one respect.
    """
    # datetime family
    moment = datetime.datetime(2020, 2, 2, 2, 2, 2)
    assert comparator(moment, datetime.datetime(2020, 2, 2, 2, 2, 2))
    assert not comparator(moment, datetime.datetime(2020, 2, 2, 2, 2, 3))

    day = datetime.date(2020, 2, 2)
    assert comparator(day, datetime.date(2020, 2, 2))
    assert not comparator(day, datetime.date(2020, 2, 3))

    span = datetime.timedelta(days=1)
    assert comparator(span, datetime.timedelta(days=1))
    assert not comparator(span, datetime.timedelta(days=2))

    clock = datetime.time(2, 2, 2)
    assert comparator(clock, datetime.time(2, 2, 2))
    assert not comparator(clock, datetime.time(2, 2, 3))

    utc = datetime.timezone.utc
    assert comparator(utc, datetime.timezone.utc)
    assert not comparator(utc, datetime.timezone(datetime.timedelta(hours=1)))

    # Build Decimals from strings: Decimal(3.14) captures the float's binary
    # approximation (3.14000000000000012434...) rather than exactly 3.14.
    price = decimal.Decimal("3.14")
    assert comparator(price, decimal.Decimal("3.14"))
    assert not comparator(price, decimal.Decimal("3.15"))

    # Enum, Flag and IntFlag members
    class FlagColor(Flag):
        RED = auto()
        GREEN = auto()
        BLUE = auto()

    class EnumColor(Enum):
        RED = auto()
        GREEN = auto()
        BLUE = auto()

    class IntFlagColor(IntFlag):
        RED = auto()
        GREEN = auto()
        BLUE = auto()

    for color_cls in (FlagColor, EnumColor, IntFlagColor):
        assert comparator(color_cls.RED, color_cls.RED), color_cls.__name__
        assert not comparator(color_cls.RED, color_cls.GREEN), color_cls.__name__

    # Compiled patterns: re.compile caches recent patterns, so two consecutive
    # re.compile("a") calls return the *same* object. Purge the cache so the
    # equal case compares two distinct Pattern objects instead of one object to itself.
    pattern = re.compile("a")
    re.purge()
    pattern_twin = re.compile("a")
    assert comparator(pattern, pattern_twin)
    assert not comparator(pattern, re.compile("b"))
    assert not comparator(pattern, re.compile("a", re.IGNORECASE))

    # array.array: typecode and contents both matter
    ints = array.array("i", [1, 2, 3])
    assert comparator(ints, array.array("i", [1, 2, 3]))
    assert not comparator(ints, array.array("i", [4, 5, 6]))
    assert not comparator(ints, array.array("f", [1.0, 2.0, 3.0]))
    assert not comparator(ints, [1, 2, 3])

    empty_ints = array.array("i")
    assert comparator(empty_ints, array.array("i"))
    assert not comparator(empty_ints, array.array("f"))
    assert not comparator(empty_ints, ints)

    # UUIDs: identical object, equal-but-distinct object, different value
    uid = uuid.uuid4()
    uid_copy = uuid.UUID(str(uid))
    assert uid_copy is not uid
    assert comparator(uid, uid)
    assert comparator(uid, uid_copy)
    assert not comparator(uid, uuid.uuid4())
def test_itertools_count() -> None:
    """itertools.count: matching start/step/consumption is equal, anything else is not."""
    from itertools import count

    # Equal: default step, explicit step, negative and float arguments
    assert comparator(count(0), count(0))
    assert comparator(count(5), count(5))
    for start, step in ((0, 1), (10, 3), (-5, -2), (0.5, 0.1)):
        assert comparator(count(start, step), count(start, step)), (start, step)

    # Not equal: different start
    assert not comparator(count(0), count(1))
    assert not comparator(count(5), count(10))
    # Not equal: different step
    assert not comparator(count(0, 1), count(0, 2))
    assert not comparator(count(0, 1), count(0, -1))
    # Not equal: a count vs. a non-count value
    assert not comparator(count(0), 0)
    assert not comparator(count(0), [0, 1, 2])

    # Both advanced once -> same state
    left, right = count(0), count(0)
    next(left)
    next(right)
    assert comparator(left, right)

    # Only one side advanced -> different state
    left, right = count(0), count(0)
    next(left)
    assert not comparator(left, right)

    # Nested inside containers
    assert comparator([count(0)], [count(0)])
    assert comparator({"key": count(5, 2)}, {"key": count(5, 2)})
    assert not comparator([count(0)], [count(1)])
def test_itertools_repeat() -> None:
    """itertools.repeat: value, remaining count and consumption all matter."""
    from itertools import repeat

    # Equal: infinite and bounded repeats of the same value
    assert comparator(repeat(5), repeat(5))
    assert comparator(repeat("hello"), repeat("hello"))
    assert comparator(repeat(5, 3), repeat(5, 3))
    assert comparator(repeat(None, 10), repeat(None, 10))

    # Not equal: different repeated value
    assert not comparator(repeat(5), repeat(6))
    assert not comparator(repeat(5, 3), repeat(6, 3))
    # Not equal: different count, or bounded vs. infinite
    assert not comparator(repeat(5, 3), repeat(5, 4))
    assert not comparator(repeat(5), repeat(5, 3))
    # Not equal: a repeat vs. a non-repeat value
    assert not comparator(repeat(5), 5)
    assert not comparator(repeat(5), [5])

    # Same partial consumption
    left, right = repeat(5, 5), repeat(5, 5)
    next(left)
    next(right)
    assert comparator(left, right)

    # Different consumption
    left, right = repeat(5, 5), repeat(5, 5)
    next(left)
    assert not comparator(left, right)

    # Nested inside containers
    assert comparator([repeat(5, 3)], [repeat(5, 3)])
    assert not comparator([repeat(5, 3)], [repeat(5, 4)])
def test_itertools_cycle() -> None:
    """itertools.cycle: sequence contents and position within the cycle."""
    from itertools import cycle

    def advanced(iterator, steps):
        # Pull `steps` items so a cycle can be moved to a chosen position.
        for _ in range(steps):
            next(iterator)
        return iterator

    # Equal: same sequence
    assert comparator(cycle([1, 2, 3]), cycle([1, 2, 3]))
    assert comparator(cycle("abc"), cycle("abc"))
    # Not equal: different or shorter sequence
    assert not comparator(cycle([1, 2, 3]), cycle([1, 2, 4]))
    assert not comparator(cycle([1, 2, 3]), cycle([1, 2]))
    # Not equal: a cycle vs. a plain list
    assert not comparator(cycle([1, 2, 3]), [1, 2, 3])

    # Same partial consumption matches; one-sided consumption does not.
    assert comparator(advanced(cycle([1, 2, 3]), 1), advanced(cycle([1, 2, 3]), 1))
    assert not comparator(advanced(cycle([1, 2, 3]), 1), cycle([1, 2, 3]))
    # After one full pass both sides are back at the start.
    assert comparator(advanced(cycle([1, 2, 3]), 3), advanced(cycle([1, 2, 3]), 3))
    # 4 % 3 == 7 % 3 == 1: same position reached after different numbers of passes.
    assert comparator(advanced(cycle([1, 2, 3]), 4), advanced(cycle([1, 2, 3]), 7))

    # Nested inside containers
    assert comparator([cycle([1, 2])], [cycle([1, 2])])
    assert not comparator([cycle([1, 2])], [cycle([1, 3])])
def test_itertools_chain() -> None:
    """itertools.chain and chain.from_iterable, including the empty chain."""
    from itertools import chain

    assert comparator(chain([1, 2], [3, 4]), chain([1, 2], [3, 4]))
    assert not comparator(chain([1, 2], [3, 4]), chain([1, 2], [3, 5]))
    assert comparator(chain.from_iterable([[1, 2], [3]]), chain.from_iterable([[1, 2], [3]]))
    assert comparator(chain(), chain())
    assert not comparator(chain([1]), chain([1, 2]))
def test_itertools_islice() -> None:
    """itertools.islice with a stop only and with start/stop bounds."""
    from itertools import islice

    source = range(10)
    assert comparator(islice(source, 5), islice(source, 5))
    assert not comparator(islice(source, 5), islice(source, 6))
    assert comparator(islice(source, 2, 5), islice(source, 2, 5))
    assert not comparator(islice(source, 2, 5), islice(source, 2, 6))
def test_itertools_product() -> None:
    """itertools.product with `repeat=` and with several input iterables."""
    from itertools import product

    assert comparator(product("AB", repeat=2), product("AB", repeat=2))
    assert not comparator(product("AB", repeat=2), product("AC", repeat=2))

    assert comparator(product([1, 2], [3, 4]), product([1, 2], [3, 4]))
    assert not comparator(product([1, 2], [3, 4]), product([1, 2], [3, 5]))
def test_itertools_permutations_combinations() -> None:
    """permutations, combinations and combinations_with_replacement."""
    from itertools import combinations, combinations_with_replacement, permutations

    assert comparator(permutations("ABC", 2), permutations("ABC", 2))
    assert not comparator(permutations("ABC", 2), permutations("ABD", 2))

    assert comparator(combinations("ABCD", 2), combinations("ABCD", 2))
    assert not comparator(combinations("ABCD", 2), combinations("ABCD", 3))

    assert comparator(combinations_with_replacement("ABC", 2), combinations_with_replacement("ABC", 2))
    assert not comparator(combinations_with_replacement("ABC", 2), combinations_with_replacement("ABD", 2))
def test_itertools_accumulate() -> None:
    """itertools.accumulate with and without an `initial` value."""
    from itertools import accumulate

    assert comparator(accumulate([1, 2, 3, 4]), accumulate([1, 2, 3, 4]))
    assert not comparator(accumulate([1, 2, 3, 4]), accumulate([1, 2, 3, 5]))

    assert comparator(accumulate([1, 2, 3], initial=10), accumulate([1, 2, 3], initial=10))
    assert not comparator(accumulate([1, 2, 3], initial=10), accumulate([1, 2, 3], initial=0))
def test_itertools_filtering() -> None:
    """compress, dropwhile, takewhile and filterfalse.

    Each iterator must match an identically built twin (with its own, separate
    lambda) and must not match one built over different data.
    """
    import itertools

    # compress
    assert comparator(
        itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]),
        itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]),
    )
    assert not comparator(
        itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]),
        itertools.compress("ABCDEF", [1, 1, 1, 0, 1, 1]),
    )

    # dropwhile
    assert comparator(
        itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
        itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
    )
    assert not comparator(
        itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
        itertools.dropwhile(lambda x: x < 5, [1, 4, 7, 4, 1]),
    )

    # takewhile
    assert comparator(
        itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
        itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
    )
    assert not comparator(
        itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]),
        itertools.takewhile(lambda x: x < 5, [1, 3, 6, 4, 1]),
    )

    # filterfalse
    assert comparator(
        itertools.filterfalse(lambda x: x % 2, range(10)),
        itertools.filterfalse(lambda x: x % 2, range(10)),
    )
    # Negative case, matching the other sections: range(11) yields an extra 10.
    assert not comparator(
        itertools.filterfalse(lambda x: x % 2, range(10)),
        itertools.filterfalse(lambda x: x % 2, range(11)),
    )
def test_itertools_starmap() -> None:
    """itertools.starmap over argument tuples for the builtin pow."""
    from itertools import starmap

    three_pairs = [(2, 3), (3, 2), (10, 0)]
    assert comparator(starmap(pow, three_pairs), starmap(pow, [(2, 3), (3, 2), (10, 0)]))
    assert not comparator(starmap(pow, [(2, 3), (3, 2)]), starmap(pow, [(2, 3), (3, 3)]))
def test_itertools_zip_longest() -> None:
    """itertools.zip_longest: the fill value participates in the comparison."""
    from itertools import zip_longest

    dashed = zip_longest("AB", "xyz", fillvalue="-")
    assert comparator(dashed, zip_longest("AB", "xyz", fillvalue="-"))

    dashed = zip_longest("AB", "xyz", fillvalue="-")
    assert not comparator(dashed, zip_longest("AB", "xyz", fillvalue="*"))
def test_itertools_groupby() -> None:
    """itertools.groupby with default key, empty input, and an explicit key."""
    from itertools import groupby

    assert comparator(groupby("AAABBBCC"), groupby("AAABBBCC"))
    assert not comparator(groupby("AAABBBCC"), groupby("AAABBCC"))
    assert comparator(groupby([]), groupby([]))

    # Each side gets its own (separately created) key function.
    assert comparator(
        groupby([1, 1, 2, 2, 3], key=lambda item: item),
        groupby([1, 1, 2, 2, 3], key=lambda item: item),
    )
@pytest.mark.skipif(sys.version_info < (3, 10), reason="itertools.pairwise requires Python 3.10+")
def test_itertools_pairwise() -> None:
    """itertools.pairwise (Python 3.10+)."""
    from itertools import pairwise

    assert comparator(pairwise([1, 2, 3, 4]), pairwise([1, 2, 3, 4]))
    assert not comparator(pairwise([1, 2, 3, 4]), pairwise([1, 2, 3, 5]))
@pytest.mark.skipif(sys.version_info < (3, 12), reason="itertools.batched requires Python 3.12+")
def test_itertools_batched() -> None:
    """itertools.batched (Python 3.12+): batch size matters."""
    from itertools import batched

    assert comparator(batched("ABCDEFG", 3), batched("ABCDEFG", 3))
    assert not comparator(batched("ABCDEFG", 3), batched("ABCDEFG", 2))
def test_itertools_in_containers() -> None:
    """itertools objects nested in dicts/lists, and mismatched itertools types."""
    from itertools import chain, islice, product

    assert comparator(
        {"a": chain([1], [2]), "b": islice(range(5), 3)},
        {"a": chain([1], [2]), "b": islice(range(5), 3)},
    )
    assert not comparator([product("AB", repeat=2)], [product("AC", repeat=2)])

    # A chain and an islice over the same items are still different types.
    assert not comparator(chain([1, 2]), islice([1, 2], 2))
def test_numpy() -> None:
    """numpy arrays, scalar dtypes, NaN/inf handling and structured (void) records.

    Skipped with an explicit reason when numpy is not installed.
    """
    np = pytest.importorskip("numpy")

    # 1-D and 2-D integer arrays; shape and dtype mismatches
    vec = np.array([1, 2, 3])
    assert comparator(vec, np.array([1, 2, 3]))
    assert not comparator(vec, np.array([1, 2, 4]))
    mat = np.array([[1, 2], [3, 4]])
    assert comparator(mat, np.array([[1, 2], [3, 4]]))
    assert not comparator(mat, np.array([[1, 2], [3, 5]]))
    assert not comparator(vec, mat)
    assert not comparator(vec, np.array([1.0, 2.0, 3.0]))

    # Scalar types: equal within a dtype, unequal across dtypes even for equal values
    f32_a, f32_b = np.float32(1.0), np.float32(1.0)
    f64_a, f64_b = np.float64(1.0), np.float64(1.0)
    assert comparator(f32_a, f32_b)
    assert not comparator(f32_a, f64_a)
    assert comparator(f64_a, f64_b)

    i32_a, i32_b = np.int32(1), np.int32(1)
    assert comparator(i32_a, i32_b)
    assert not comparator(i32_a, f32_a)
    assert not comparator(i32_a, f64_a)

    i64_a, i64_b = np.int64(1), np.int64(1)
    assert not comparator(i64_a, i32_a)
    assert comparator(i64_a, i64_b)

    u32_a, u32_b = np.uint32(1), np.uint32(1)
    assert comparator(u32_a, u32_b)
    assert not comparator(u32_a, i32_a)

    u64_a, u64_b = np.uint64(1), np.uint64(1)
    assert not comparator(u64_a, u32_a)
    assert comparator(u64_a, u64_b)

    bool_a, bool_b = np.bool_(True), np.bool_(True)
    assert comparator(bool_a, bool_b)
    assert not comparator(bool_a, u64_a)

    c64_a, c64_b = np.complex64(1.0 + 1.0j), np.complex64(1.0 + 1.0j)
    assert comparator(c64_a, c64_b)
    assert not comparator(c64_a, bool_a)

    c128_a, c128_b = np.complex128(1.0 + 1.0j), np.complex128(1.0 + 1.0j)
    assert not comparator(c128_a, c64_a)
    assert comparator(c128_a, c128_b)

    # Object-dtype arrays with mixed element types
    obj_arr = np.array([1, 2, "str"], dtype=np.object_)
    assert comparator(obj_arr, np.array([1, 2, "str"], dtype=np.object_))
    assert not comparator(obj_arr, np.array([1, 2, "str2"], dtype=np.object_))

    # Without dtype=object the mixed list becomes a string array
    assert comparator(np.array([1, 2, "str2"]), np.array([1, 2, "str2"]))

    # NaN and inf, in arrays and as bare floats
    nan_arr = np.array([1, 2, np.nan])
    nan_arr_twin = np.array([1, 2, np.nan])
    inf_arr = np.array([1, 2, np.inf])
    inf_arr_twin = np.array([1, 2, np.inf])
    assert comparator(nan_arr, nan_arr_twin)
    assert comparator(inf_arr, inf_arr_twin)
    assert not comparator(nan_arr, inf_arr)
    assert not comparator(nan_arr_twin, inf_arr_twin)
    assert comparator(np.inf, np.inf)
    assert comparator(np.nan, np.nan)
    assert not comparator(np.inf, np.nan)

    # Structured arrays: indexing one element yields an np.void record
    record_dtype = np.dtype([("name", "S10"), ("age", np.int32)])
    alice = np.array([("Alice", 25)], dtype=record_dtype)[0]
    alice_twin = np.array([("Alice", 25)], dtype=record_dtype)[0]
    bob = np.array([("Bob", 30)], dtype=record_dtype)[0]
    assert isinstance(alice, np.void)
    assert comparator(alice, alice_twin)
    assert not comparator(alice, bob)
def test_numpy_random_generator() -> None:
    """numpy.random.Generator (modern API): seed, advancement and bit-generator type.

    Skipped with an explicit reason when numpy is not installed.
    """
    np = pytest.importorskip("numpy")
    from numpy.random import MT19937, PCG64

    # Same seed should produce equal generators
    rng1 = np.random.default_rng(seed=42)
    rng2 = np.random.default_rng(seed=42)
    assert comparator(rng1, rng2)

    # Different seeds should produce non-equal generators
    rng3 = np.random.default_rng(seed=123)
    assert not comparator(rng1, rng3)

    # After generating numbers, state changes
    rng4 = np.random.default_rng(seed=42)
    rng5 = np.random.default_rng(seed=42)
    rng4.random()  # Advance state
    assert not comparator(rng4, rng5)

    # Both advanced by same amount should be equal
    rng5.random()
    assert comparator(rng4, rng5)

    # Explicit bit generators with the same seed
    rng_pcg1 = np.random.Generator(PCG64(seed=42))
    rng_pcg2 = np.random.Generator(PCG64(seed=42))
    assert comparator(rng_pcg1, rng_pcg2)

    rng_mt1 = np.random.Generator(MT19937(seed=42))
    rng_mt2 = np.random.Generator(MT19937(seed=42))
    assert comparator(rng_mt1, rng_mt2)

    # Different bit generator types should not be equal
    assert not comparator(rng_pcg1, rng_mt1)
def test_numpy_random_state() -> None:
    """numpy.random.RandomState (legacy API): seed, advancement and state restoration.

    Skipped with an explicit reason when numpy is not installed.
    """
    np = pytest.importorskip("numpy")

    # Same seed should produce equal states
    rs1 = np.random.RandomState(seed=42)
    rs2 = np.random.RandomState(seed=42)
    assert comparator(rs1, rs2)

    # Different seeds should produce non-equal states
    rs3 = np.random.RandomState(seed=123)
    assert not comparator(rs1, rs3)

    # After generating numbers, state changes
    rs4 = np.random.RandomState(seed=42)
    rs5 = np.random.RandomState(seed=42)
    rs4.random()  # Advance state
    assert not comparator(rs4, rs5)

    # Both advanced by same amount should be equal
    rs5.random()
    assert comparator(rs4, rs5)

    # Test state restoration
    rs6 = np.random.RandomState(seed=42)
    state = rs6.get_state()
    rs6.random()  # Advance past the snapshot
    rs7 = np.random.RandomState(seed=42)
    rs7.set_state(state)
    # rs6 advanced, rs7 restored to original state
    assert not comparator(rs6, rs7)

    # rs7 already started at seed 42, so set_state() there proves nothing.
    # Restoring the snapshot into a differently seeded generator must make it
    # indistinguishable from a freshly seeded one.
    rs8 = np.random.RandomState(seed=0)
    rs8.set_state(state)
    assert comparator(rs8, np.random.RandomState(seed=42))
def test_scipy() -> None:
    """scipy.sparse matrices across formats, plus dense arrays built via coo_array.

    Skipped with an explicit reason when scipy is not installed. numpy is a hard
    dependency of scipy, so it is imported unconditionally once scipy is present.
    """
    # Import the submodule explicitly instead of relying on `scipy.sparse`
    # being reachable as an attribute after a bare `import scipy`.
    sparse = pytest.importorskip("scipy.sparse")
    import numpy as np

    base = [[1, 0, 0], [0, 0, 3], [4, 0, 5]]
    changed = [[1, 0, 0], [0, 0, 3], [4, 0, 6]]
    wider = [[1, 0, 0, 0], [0, 0, 3, 0], [4, 0, 6, 0]]

    # CSR: equal twin, changed value, changed shape
    csr = sparse.csr_matrix(base)
    csr_changed = sparse.csr_matrix(changed)
    assert comparator(csr, sparse.csr_matrix(base))
    assert not comparator(csr, csr_changed)
    assert not comparator(csr_changed, sparse.csr_matrix(wider))

    # CSC: same checks, plus CSR vs. CSC with identical contents
    csc = sparse.csc_matrix(base)
    csc_changed = sparse.csc_matrix(changed)
    assert comparator(csc, sparse.csc_matrix(base))
    assert not comparator(csc, csc_changed)
    assert not comparator(csr, csc)
    assert not comparator(csr_changed, csc_changed)
    assert not comparator(csc_changed, sparse.csc_matrix(wider))

    # Remaining formats: equal to a twin, unequal to a changed twin, unequal to CSR.
    for fmt in (
        sparse.lil_matrix,
        sparse.dok_matrix,
        sparse.dia_matrix,
        sparse.coo_matrix,
        sparse.bsr_matrix,
    ):
        mat = fmt(base)
        assert comparator(mat, fmt(base)), fmt.__name__
        assert not comparator(mat, fmt(changed)), fmt.__name__
        assert not comparator(csr, mat), fmt.__name__

    # Dense numpy arrays produced from a COO sparse array
    row = np.array([0, 3, 1, 0])
    col = np.array([0, 3, 1, 2])
    data = np.array([4, 5, 7, 9])
    dense = sparse.coo_array((data, (row, col)), shape=(4, 4)).toarray()
    dense_twin = sparse.coo_array((data, (row, col)), shape=(4, 4)).toarray()
    assert comparator(dense, dense_twin)
def test_pandas() -> None:
    """pandas containers, scalars, sparse arrays and pd.NA handling.

    Skipped with an explicit reason when pandas is not installed.
    """
    pd = pytest.importorskip("pandas")

    # DataFrames: equal twin, changed value, changed length
    frame = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    frame_changed = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 7]})
    frame_longer = pd.DataFrame({"a": [1, 2, 3, 4], "b": [4, 5, 6, 7]})
    assert comparator(frame, pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}))
    assert not comparator(frame, frame_changed)
    assert not comparator(frame_changed, frame_longer)

    ts = datetime.datetime(2020, 2, 2, 2, 2, 2)
    ts_later = datetime.datetime(2020, 2, 2, 2, 2, 3)

    # Each entry: (label, value, independently built equal twin, differing value).
    cases = [
        (
            "datetime DataFrame",
            pd.DataFrame({"a": [ts, ts], "b": [4, 5]}),
            pd.DataFrame({"a": [ts, ts], "b": [4, 5]}),
            pd.DataFrame({"a": [ts, ts_later], "b": [4, 5]}),
        ),
        ("Series", pd.Series([1, 2, 3]), pd.Series([1, 2, 3]), pd.Series([1, 2, 4])),
        ("Index", pd.Index([1, 2, 3]), pd.Index([1, 2, 3]), pd.Index([1, 2, 4])),
        (
            "MultiIndex",
            pd.MultiIndex.from_tuples([(1, 2), (3, 4)]),
            pd.MultiIndex.from_tuples([(1, 2), (3, 4)]),
            pd.MultiIndex.from_tuples([(1, 2), (3, 5)]),
        ),
        ("Categorical", pd.Categorical([1, 2, 3]), pd.Categorical([1, 2, 3]), pd.Categorical([1, 2, 4])),
        ("Interval", pd.Interval(1, 2), pd.Interval(1, 2), pd.Interval(1, 3)),
        (
            "IntervalIndex",
            pd.IntervalIndex.from_tuples([(1, 2), (3, 4)]),
            pd.IntervalIndex.from_tuples([(1, 2), (3, 4)]),
            pd.IntervalIndex.from_tuples([(1, 2), (3, 5)]),
        ),
        ("Period", pd.Period("2021-01"), pd.Period("2021-01"), pd.Period("2021-02")),
        (
            "period_range",
            pd.period_range(start="2021-01", periods=3, freq="M"),
            pd.period_range(start="2021-01", periods=3, freq="M"),
            pd.period_range(start="2021-01", periods=4, freq="M"),
        ),
        ("Timedelta", pd.Timedelta("1 days"), pd.Timedelta("1 days"), pd.Timedelta("2 days")),
        (
            "TimedeltaIndex",
            pd.TimedeltaIndex(["1 days", "2 days"]),
            pd.TimedeltaIndex(["1 days", "2 days"]),
            pd.TimedeltaIndex(["1 days", "3 days"]),
        ),
        ("Timestamp", pd.Timestamp("2021-01-01"), pd.Timestamp("2021-01-01"), pd.Timestamp("2021-01-02")),
        (
            "SparseArray",
            pd.arrays.SparseArray([1, 2, 3]),
            pd.arrays.SparseArray([1, 2, 3]),
            pd.arrays.SparseArray([1, 2, 4]),
        ),
        # pd.NA must not be conflated with None inside containers.
        (
            "Series with NA",
            pd.Series([1, 2, pd.NA, 4]),
            pd.Series([1, 2, pd.NA, 4]),
            pd.Series([1, 2, None, 4]),
        ),
        (
            "DataFrame with NA",
            pd.DataFrame({"a": [1, 2, pd.NA], "b": [4, pd.NA, 6]}),
            pd.DataFrame({"a": [1, 2, pd.NA], "b": [4, pd.NA, 6]}),
            pd.DataFrame({"a": [1, 2, None], "b": [4, None, 6]}),
        ),
        (
            "dict with NA",
            {"a": pd.NA, "b": [1, pd.NA, 3]},
            {"a": pd.NA, "b": [1, pd.NA, 3]},
            {"a": None, "b": [1, None, 3]},
        ),
    ]
    for label, value, twin, different in cases:
        assert comparator(value, twin), label
        assert not comparator(value, different), label

    # pd.NA as a bare scalar, in both argument orders against None
    assert comparator(pd.NA, pd.NA)
    assert not comparator(pd.NA, None)
    assert not comparator(None, pd.NA)

    # Boolean-mask filtering of a Series that contains pd.NA
    with_na = pd.Series([1, 2, pd.NA, 4])
    with_na_twin = pd.Series([1, 2, pd.NA, 4])
    assert comparator(with_na[with_na > 1], with_na_twin[with_na_twin > 1])
def test_pyarrow():
    """Check comparator on PyArrow tables, batches, arrays, scalars and type metadata.

    Every section builds a reference object, asserts it matches an
    independently constructed twin, then asserts it differs from each
    "near miss" variant. Skipped when pyarrow cannot be imported.
    """
    try:
        import pyarrow as pa
    except ImportError:
        pytest.skip()

    def check(reference, twin, *near_misses):
        # The reference must equal its twin and differ from every near miss.
        assert comparator(reference, twin)
        for variant in near_misses:
            assert not comparator(reference, variant)

    # Tables: a changed cell, an extra row and a renamed column all break equality.
    check(
        pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}),
        pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}),
        pa.table({"a": [1, 2, 3], "b": [4, 5, 7]}),
        pa.table({"a": [1, 2, 3, 4], "b": [4, 5, 6, 7]}),
        pa.table({"a": [1, 2, 3], "c": [4, 5, 6]}),
    )

    # Record batches: changed value, extra row.
    check(
        pa.RecordBatch.from_pydict({"x": [1, 2], "y": [3.0, 4.0]}),
        pa.RecordBatch.from_pydict({"x": [1, 2], "y": [3.0, 4.0]}),
        pa.RecordBatch.from_pydict({"x": [1, 2], "y": [3.0, 5.0]}),
        pa.RecordBatch.from_pydict({"x": [1, 2, 3], "y": [3.0, 4.0, 5.0]}),
    )

    # Arrays: changed value, extra element, and int64 vs double element type.
    check(
        pa.array([1, 2, 3]),
        pa.array([1, 2, 3]),
        pa.array([1, 2, 4]),
        pa.array([1, 2, 3, 4]),
        pa.array([1.0, 2.0, 3.0]),
    )

    # A null slot must line up with a null slot, not with a concrete value.
    check(pa.array([1, None, 3]), pa.array([1, None, 3]), pa.array([1, 2, 3]))

    # Chunked arrays: changed value, different chunk contents/length.
    check(
        pa.chunked_array([[1, 2], [3, 4]]),
        pa.chunked_array([[1, 2], [3, 4]]),
        pa.chunked_array([[1, 2], [3, 5]]),
        pa.chunked_array([[1, 2, 3], [4, 5]]),
    )

    # Scalars: changed value, and int vs float scalar type.
    check(pa.scalar(42), pa.scalar(42), pa.scalar(43), pa.scalar(42.0))

    # Null scalars are still distinguished by their declared type.
    check(
        pa.scalar(None, type=pa.int64()),
        pa.scalar(None, type=pa.int64()),
        pa.scalar(None, type=pa.float64()),
    )

    # Schemas: renamed field, and retyped field.
    check(
        pa.schema([("a", pa.int64()), ("b", pa.float64())]),
        pa.schema([("a", pa.int64()), ("b", pa.float64())]),
        pa.schema([("a", pa.int64()), ("c", pa.float64())]),
        pa.schema([("a", pa.int32()), ("b", pa.float64())]),
    )

    # Fields: renamed, and retyped.
    check(
        pa.field("name", pa.int64()),
        pa.field("name", pa.int64()),
        pa.field("other", pa.int64()),
        pa.field("name", pa.float64()),
    )

    # Bare data types.
    check(pa.int64(), pa.int64(), pa.int32(), pa.float64())

    # String, struct and list arrays with one differing leaf value.
    check(pa.array(["hello", "world"]), pa.array(["hello", "world"]), pa.array(["hello", "there"]))
    check(
        pa.array([{"x": 1, "y": 2}, {"x": 3, "y": 4}]),
        pa.array([{"x": 1, "y": 2}, {"x": 3, "y": 4}]),
        pa.array([{"x": 1, "y": 2}, {"x": 3, "y": 5}]),
    )
    check(
        pa.array([[1, 2], [3, 4, 5]]),
        pa.array([[1, 2], [3, 4, 5]]),
        pa.array([[1, 2], [3, 4, 6]]),
    )
def test_pyrsistent():
    """Check comparator on pyrsistent persistent collections and record types.

    Covers pmap, pvector, pset, PRecord, PClass, pdeque and pbag. Skipped
    when pyrsistent is not installed.
    """
    try:
        from pyrsistent import PClass, PRecord, field, pbag, pdeque, pmap, pset, pvector  # type: ignore
    except ImportError:
        pytest.skip("pyrsistent is not installed")

    # Persistent map: a differing value breaks equality.
    a = pmap({"a": 1, "b": 2})
    b = pmap({"a": 1, "b": 2})
    c = pmap({"a": 1, "b": 3})
    assert comparator(a, b)
    assert not comparator(a, c)

    # Persistent vector: element-wise comparison.
    d = pvector([1, 2, 3])
    e = pvector([1, 2, 3])
    f = pvector([1, 2, 4])
    assert comparator(d, e)
    assert not comparator(d, f)

    # Persistent set: construction order is irrelevant.
    g = pset([1, 2, 3])
    h = pset([2, 3, 1])
    i = pset([1, 2, 4])
    assert comparator(g, h)
    assert not comparator(g, i)

    class TestRecord(PRecord):
        a = field()
        b = field()

    # Two empty records match; a populated one does not.
    j = TestRecord()
    k = TestRecord()
    l = TestRecord(a=2, b=3)
    assert comparator(j, k)
    assert not comparator(j, l)

    class TestClass(PClass):
        a = field()
        b = field()

    m = TestClass()
    n = TestClass()
    o = TestClass(a=1, b=3)
    assert comparator(m, n)
    assert not comparator(m, o)

    # Persistent deque with maxlen 3.
    p = pdeque([1, 2, 3], 3)
    q = pdeque([1, 2, 3], 3)
    r = pdeque([1, 2, 4], 3)
    assert comparator(p, q)
    assert not comparator(p, r)

    # Persistent bag. ``PBag`` is not meant to be instantiated directly: its
    # constructor takes pyrsistent's internal element->count map. Passing a
    # list there produces a malformed bag that merely wraps the list, so use
    # the public ``pbag`` factory to exercise a real multiset.
    s = pbag([1, 2, 3])
    t = pbag([1, 2, 3])
    u = pbag([1, 2, 4])
    assert comparator(s, t)
    assert not comparator(s, u)
def test_torch_dtype():
    """Verify that comparator matches identical torch dtypes and rejects every other dtype."""
    try:
        import torch  # type: ignore
    except ImportError:
        pytest.skip()

    # Each entry: (dtype, an equal dtype, dtypes that must NOT match it).
    cases = [
        (torch.float32, torch.float32, (torch.float64, torch.int32)),  # float width / category
        (torch.int64, torch.int64, (torch.int32,)),  # integer width
        (torch.complex64, torch.complex64, (torch.complex128,)),  # complex width
        (torch.bool, torch.bool, (torch.int8,)),  # bool vs 1-byte int
    ]
    for dtype, same, others in cases:
        assert comparator(dtype, same)
        for other in others:
            assert not comparator(dtype, other)
def test_torch():
    """Verify comparator on torch tensors.

    Exercises element values, rank/shape, dtype, empty tensors, NaN and
    infinity handling, device placement (only when CUDA is available), the
    ``requires_grad`` flag, and complex and boolean tensors.
    """
    try:
        import torch  # type: ignore
    except ImportError:
        pytest.skip()

    tensor = torch.tensor

    def expect(reference, twin, mismatch):
        # ``reference`` must equal ``twin`` and differ from ``mismatch``.
        assert comparator(reference, twin)
        assert not comparator(reference, mismatch)

    # Rank-1 and rank-2 integer tensors differing in one element.
    expect(tensor([1, 2, 3]), tensor([1, 2, 3]), tensor([1, 2, 4]))
    expect(
        tensor([[1, 2, 3], [4, 5, 6]]),
        tensor([[1, 2, 3], [4, 5, 6]]),
        tensor([[1, 2, 3], [4, 5, 7]]),
    )

    # Same values, different dtype.
    expect(
        tensor([1, 2, 3], dtype=torch.float32),
        tensor([1, 2, 3], dtype=torch.float32),
        tensor([1, 2, 3], dtype=torch.int64),
    )

    # Rank-3 tensors differing in the last element.
    expect(
        tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]),
        tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]),
        tensor([[[1, 2], [3, 4]], [[5, 6], [7, 9]]]),
    )

    # Same values, but one tensor has an extra leading dimension.
    assert not comparator(tensor([1, 2, 3]), tensor([[1, 2, 3]]))

    # Empty tensors match each other but not a one-element tensor.
    expect(tensor([]), tensor([]), tensor([1]))

    # NaN in the same position is expected to compare equal.
    nan, inf = float("nan"), float("inf")
    expect(tensor([1.0, nan, 3.0]), tensor([1.0, nan, 3.0]), tensor([1.0, 2.0, 3.0]))
    # The sign of infinity matters.
    expect(tensor([1.0, inf, 3.0]), tensor([1.0, inf, 3.0]), tensor([1.0, -inf, 3.0]))

    # A CUDA tensor is expected not to match its CPU counterpart.
    if torch.cuda.is_available():
        expect(tensor([1, 2, 3]).cuda(), tensor([1, 2, 3]).cuda(), tensor([1, 2, 3]))

    # The autograd flag is part of the comparison.
    expect(
        tensor([1.0, 2.0, 3.0], requires_grad=True),
        tensor([1.0, 2.0, 3.0], requires_grad=True),
        tensor([1.0, 2.0, 3.0], requires_grad=False),
    )

    # Complex and boolean tensors.
    expect(tensor([1 + 2j, 3 + 4j]), tensor([1 + 2j, 3 + 4j]), tensor([1 + 2j, 3 + 5j]))
    expect(tensor([True, False, True]), tensor([True, False, True]), tensor([True, True, True]))
def test_torch_device():
    """Check comparator on ``torch.device`` objects.

    CUDA cases run only when a GPU is present. Cases that need two distinct
    CUDA indices run only when more than one GPU is present.
    """
    try:
        import torch  # type: ignore
    except ImportError:
        pytest.skip("torch is not installed")

    cuda_available = torch.cuda.is_available()

    # Same device type, no explicit index.
    assert comparator(torch.device("cpu"), torch.device("cpu"))

    # Different device types.
    if cuda_available:
        assert not comparator(torch.device("cpu"), torch.device("cuda"))

    # Explicit index: the "type:index" string form and the (type, index)
    # form describe the same device.
    assert comparator(torch.device("cpu:0"), torch.device("cpu", 0))

    # CUDA devices with different indices (needs at least two GPUs).
    if cuda_available and torch.cuda.device_count() > 1:
        assert comparator(torch.device("cuda:0"), torch.device("cuda:0"))
        assert not comparator(torch.device("cuda:0"), torch.device("cuda:1"))

    # CUDA device spelled as a string vs. as (type, index).
    if cuda_available:
        assert comparator(torch.device("cuda:0"), torch.device("cuda", 0))

    # The meta device matches itself but not the CPU.
    meta = torch.device("meta")
    assert comparator(meta, torch.device("meta"))
    assert not comparator(meta, torch.device("cpu"))
def test_torch_nn_linear():
    """Verify comparator on nn.Linear: weights, feature sizes, bias presence and training mode."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def linear(seed, *args, **kwargs):
        # Reseed immediately before construction so the weight init is reproducible.
        torch.manual_seed(seed)
        return nn.Linear(*args, **kwargs)

    # Same seed and shape -> identical parameters.
    assert comparator(linear(42, 10, 5), linear(42, 10, 5))
    # Different seeds -> different random weights.
    assert not comparator(linear(42, 10, 5), linear(123, 10, 5))
    # in_features, then out_features, differ.
    assert not comparator(linear(42, 10, 5), linear(42, 20, 5))
    assert not comparator(linear(42, 10, 5), linear(42, 10, 10))
    # Bias present on one side only.
    assert not comparator(linear(42, 10, 5, bias=True), linear(42, 10, 5, bias=False))
    # Identical weights, but one module is in train mode and the other in eval mode
    # (Module.train()/eval() return the module itself).
    assert not comparator(linear(42, 10, 5).train(), linear(42, 10, 5).eval())
def test_torch_nn_conv2d():
    """Verify comparator on nn.Conv2d: equal when rebuilt from the same seed, unequal
    when the weights or any single constructor hyperparameter differ."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def conv(seed, in_ch=3, out_ch=16, **kwargs):
        # Reseed right before construction so both sides get identical initial weights.
        torch.manual_seed(seed)
        return nn.Conv2d(in_ch, out_ch, **kwargs)

    assert comparator(conv(42, kernel_size=3), conv(42, kernel_size=3))
    assert not comparator(conv(42, kernel_size=3), conv(123, kernel_size=3))

    # Both sides use seed 42 and differ in exactly one hyperparameter.
    variants = [
        ({"kernel_size": 3}, {"in_ch": 1, "kernel_size": 3}),  # in_channels
        ({"kernel_size": 3}, {"out_ch": 32, "kernel_size": 3}),  # out_channels
        ({"kernel_size": 3}, {"kernel_size": 5}),  # kernel_size
        ({"kernel_size": 3, "stride": 1}, {"kernel_size": 3, "stride": 2}),  # stride
        ({"kernel_size": 3, "padding": 0}, {"kernel_size": 3, "padding": 1}),  # padding
    ]
    for left, right in variants:
        assert not comparator(conv(42, **left), conv(42, **right))
def test_torch_nn_batchnorm():
    """Verify comparator on BatchNorm layers: hyperparameters, the affine flag, and
    running statistics accumulated by forward passes in training mode."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def fresh(layer_cls, *args, **kwargs):
        # Every layer in this test is built right after resetting the RNG to 42.
        torch.manual_seed(42)
        return layer_cls(*args, **kwargs)

    # Same configuration -> equal.
    assert comparator(fresh(nn.BatchNorm2d, 16), fresh(nn.BatchNorm2d, 16))

    # One differing hyperparameter per comparison -> not equal.
    assert not comparator(fresh(nn.BatchNorm2d, 16), fresh(nn.BatchNorm2d, 32))
    assert not comparator(fresh(nn.BatchNorm2d, 16, eps=1e-5), fresh(nn.BatchNorm2d, 16, eps=1e-3))
    assert not comparator(fresh(nn.BatchNorm2d, 16, momentum=0.1), fresh(nn.BatchNorm2d, 16, momentum=0.01))
    assert not comparator(fresh(nn.BatchNorm2d, 16, affine=True), fresh(nn.BatchNorm2d, 16, affine=False))

    def trained(batch_seed=None):
        # Build a BatchNorm2d in training mode and run one forward pass on a random
        # batch. When ``batch_seed`` is given, the RNG is reset before drawing the batch.
        layer = fresh(nn.BatchNorm2d, 16)
        layer.train()
        if batch_seed is not None:
            torch.manual_seed(batch_seed)
        _ = layer(torch.randn(4, 16, 8, 8))
        return layer

    # Identical RNG history -> identical input batch -> identical running stats.
    assert comparator(trained(), trained())
    # Batches drawn from different seeds -> running stats diverge.
    assert not comparator(trained(42), trained(123))

    # BatchNorm1d behaves the same way.
    assert comparator(fresh(nn.BatchNorm1d, 16), fresh(nn.BatchNorm1d, 16))
def test_torch_nn_dropout():
    """Verify comparator on dropout layers: probability, inplace flag and concrete class all matter."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    # (left, right, should_match)
    cases = [
        (nn.Dropout(p=0.5), nn.Dropout(p=0.5), True),
        (nn.Dropout(p=0.5), nn.Dropout(p=0.3), False),  # drop probability
        (nn.Dropout(p=0.5, inplace=False), nn.Dropout(p=0.5, inplace=True), False),  # inplace flag
        (nn.Dropout2d(p=0.5), nn.Dropout2d(p=0.5), True),
        (nn.Dropout(p=0.5), nn.Dropout2d(p=0.5), False),  # different module class
    ]
    for left, right, should_match in cases:
        assert bool(comparator(left, right)) == should_match, (left, right)
def test_torch_nn_activation():
    """Verify comparator on activation modules: constructor arguments and concrete class must both agree."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    # (left, right, should_match)
    cases = [
        (nn.ReLU(), nn.ReLU(), True),
        (nn.ReLU(inplace=False), nn.ReLU(inplace=True), False),  # inplace flag
        (nn.LeakyReLU(negative_slope=0.01), nn.LeakyReLU(negative_slope=0.01), True),
        (nn.LeakyReLU(negative_slope=0.01), nn.LeakyReLU(negative_slope=0.1), False),  # slope
        (nn.Sigmoid(), nn.ReLU(), False),  # different module class
        (nn.GELU(), nn.GELU(), True),
        (nn.Softmax(dim=1), nn.Softmax(dim=1), True),
        (nn.Softmax(dim=1), nn.Softmax(dim=0), False),  # reduction dim
    ]
    for left, right, should_match in cases:
        assert bool(comparator(left, right)) == should_match, (left, right)
def test_torch_nn_pooling():
    """Verify comparator on pooling modules: window size, stride, output size and class."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    # (left, right, should_match)
    cases = [
        (nn.MaxPool2d(kernel_size=2), nn.MaxPool2d(kernel_size=2), True),
        (nn.MaxPool2d(kernel_size=2), nn.MaxPool2d(kernel_size=3), False),  # kernel size
        (nn.MaxPool2d(kernel_size=2, stride=2), nn.MaxPool2d(kernel_size=2, stride=1), False),  # stride
        (nn.AvgPool2d(kernel_size=2), nn.AvgPool2d(kernel_size=2), True),
        (nn.MaxPool2d(kernel_size=2), nn.AvgPool2d(kernel_size=2), False),  # max vs average
        (nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.AdaptiveAvgPool2d(output_size=(1, 1)), True),
        (nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.AdaptiveAvgPool2d(output_size=(2, 2)), False),
    ]
    for left, right, should_match in cases:
        assert bool(comparator(left, right)) == should_match, (left, right)
def test_torch_nn_embedding():
    """Verify comparator on nn.Embedding: weights, table size, embedding width and padding_idx."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def embedding(seed, *args, **kwargs):
        # Reseed right before construction so the lookup table is reproducible.
        torch.manual_seed(seed)
        return nn.Embedding(*args, **kwargs)

    assert comparator(embedding(42, 1000, 128), embedding(42, 1000, 128))
    # Different seeds -> different random tables.
    assert not comparator(embedding(42, 1000, 128), embedding(123, 1000, 128))
    # Different number of embeddings, then different embedding dimension.
    assert not comparator(embedding(42, 1000, 128), embedding(42, 2000, 128))
    assert not comparator(embedding(42, 1000, 128), embedding(42, 1000, 256))
    # padding_idx differs, or is set on only one side.
    assert not comparator(embedding(42, 1000, 128, padding_idx=0), embedding(42, 1000, 128, padding_idx=1))
    assert not comparator(embedding(42, 1000, 128), embedding(42, 1000, 128, padding_idx=0))
def test_torch_nn_lstm():
    """Verify comparator on nn.LSTM: weights and each structural constructor argument."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def lstm(seed, **config):
        # Reseed right before construction so the recurrent weights are reproducible.
        torch.manual_seed(seed)
        return nn.LSTM(**config)

    base = {"input_size": 10, "hidden_size": 20}
    two_layer = {**base, "num_layers": 2}

    assert comparator(lstm(42, **two_layer), lstm(42, **two_layer))
    assert not comparator(lstm(42, **two_layer), lstm(123, **two_layer))

    # Each pair shares seed 42 and differs in exactly one constructor argument.
    for left, right in (
        (base, {**base, "input_size": 20}),
        (base, {**base, "hidden_size": 40}),
        ({**base, "num_layers": 1}, two_layer),
        ({**base, "bidirectional": False}, {**base, "bidirectional": True}),
        ({**base, "batch_first": False}, {**base, "batch_first": True}),
    ):
        assert not comparator(lstm(42, **left), lstm(42, **right))
def test_torch_nn_gru():
    """Verify comparator on nn.GRU, including that a GRU never matches an LSTM."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def recurrent(cls, **config):
        # Every module here is built right after resetting the RNG to 42, with input_size=10.
        torch.manual_seed(42)
        return cls(input_size=10, **config)

    assert comparator(
        recurrent(nn.GRU, hidden_size=20, num_layers=2),
        recurrent(nn.GRU, hidden_size=20, num_layers=2),
    )
    # Hidden size differs.
    assert not comparator(recurrent(nn.GRU, hidden_size=20), recurrent(nn.GRU, hidden_size=40))
    # Same arguments, different recurrent cell class.
    assert not comparator(recurrent(nn.GRU, hidden_size=20), recurrent(nn.LSTM, hidden_size=20))
def test_torch_nn_layernorm():
    """Verify comparator on nn.LayerNorm: normalized shape, eps and elementwise_affine."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def layer_norm(width=10, **config):
        # Reset the RNG to 42, then normalize over a single trailing dimension of size ``width``.
        torch.manual_seed(42)
        return nn.LayerNorm(normalized_shape=[width], **config)

    assert comparator(layer_norm(), layer_norm())
    assert not comparator(layer_norm(10), layer_norm(20))
    assert not comparator(layer_norm(eps=1e-5), layer_norm(eps=1e-3))
    assert not comparator(layer_norm(elementwise_affine=True), layer_norm(elementwise_affine=False))
def test_torch_nn_multihead_attention():
    """Verify comparator on nn.MultiheadAttention: weights, embed_dim, num_heads and dropout."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def attention(seed=42, embed_dim=64, num_heads=8, **config):
        # Reseed right before construction so projection weights are reproducible.
        torch.manual_seed(seed)
        return nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, **config)

    assert comparator(attention(), attention())
    # Different seed -> different projection weights.
    assert not comparator(attention(), attention(seed=123))
    # Structural arguments that must be detected.
    assert not comparator(attention(), attention(embed_dim=128))
    assert not comparator(attention(), attention(num_heads=4))
    assert not comparator(attention(dropout=0.0), attention(dropout=0.1))
def test_torch_nn_sequential():
    """Verify comparator on nn.Sequential: weights, depth and per-layer module class."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def three_layer(seed):
        # Linear(10, 20) -> ReLU -> Linear(20, 5), built right after reseeding.
        torch.manual_seed(seed)
        return nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))

    def two_layer(seed, activation):
        # Linear(10, 20) followed by an instance of ``activation``, built right after reseeding.
        torch.manual_seed(seed)
        return nn.Sequential(nn.Linear(10, 20), activation())

    assert comparator(three_layer(42), three_layer(42))
    # Different seed -> different Linear weights.
    assert not comparator(three_layer(42), three_layer(123))
    # Different number of layers.
    assert not comparator(two_layer(42, nn.ReLU), three_layer(42))
    # Same depth, different activation class.
    assert not comparator(two_layer(42, nn.ReLU), two_layer(42, nn.Sigmoid))
def test_torch_nn_modulelist():
    """Verify comparator on nn.ModuleList: equal for identical stacks, unequal for different lengths."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def stack(depth):
        # ``depth`` Linear(10, 10) layers, built right after resetting the RNG to 42.
        torch.manual_seed(42)
        return nn.ModuleList([nn.Linear(10, 10) for _ in range(depth)])

    assert comparator(stack(3), stack(3))
    assert not comparator(stack(3), stack(4))
def test_torch_nn_moduledict():
    """Verify comparator on nn.ModuleDict: identical entries match, renamed keys do not."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def two_stage(first_key, second_key):
        # Linear(10, 20) then Linear(20, 5) under the given keys, built after reseeding to 42.
        torch.manual_seed(42)
        return nn.ModuleDict({first_key: nn.Linear(10, 20), second_key: nn.Linear(20, 5)})

    assert comparator(two_stage("fc1", "fc2"), two_stage("fc1", "fc2"))
    # Same weights, different key names.
    assert not comparator(two_stage("fc1", "fc2"), two_stage("layer1", "layer2"))
def test_torch_nn_custom_module():
    """Verify comparator on a user-defined nn.Module subclass: weights and architecture."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    class SimpleNet(nn.Module):
        """Two-layer MLP: Linear(10, hidden) -> ReLU -> Linear(hidden, 5)."""

        def __init__(self, hidden_size):
            super().__init__()
            self.fc1 = nn.Linear(10, hidden_size)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(hidden_size, 5)

        def forward(self, x):
            return self.fc2(self.relu(self.fc1(x)))

    def net(seed, hidden_size=20):
        # Reseed right before construction so both Linear layers are reproducible.
        torch.manual_seed(seed)
        return SimpleNet(hidden_size=hidden_size)

    assert comparator(net(42), net(42))
    # Different seed -> different weights.
    assert not comparator(net(42), net(123))
    # Different hidden width -> different parameter shapes.
    assert not comparator(net(42), net(42, hidden_size=40))
def test_torch_nn_nested_modules():
    """Verify comparator recurses through nested nn.Module hierarchies."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    class EncoderBlock(nn.Module):
        """3x3 conv (padding 1) -> BatchNorm2d -> ReLU."""

        def __init__(self, in_channels, out_channels):
            super().__init__()
            self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
            self.bn = nn.BatchNorm2d(out_channels)
            self.relu = nn.ReLU()

        def forward(self, x):
            return self.relu(self.bn(self.conv(x)))

    class Encoder(nn.Module):
        """Two EncoderBlocks (3->16->32 channels), each followed by 2x2 max pooling."""

        def __init__(self):
            super().__init__()
            self.block1 = EncoderBlock(3, 16)
            self.block2 = EncoderBlock(16, 32)
            self.pool = nn.MaxPool2d(2)

        def forward(self, x):
            for block in (self.block1, self.block2):
                x = self.pool(block(x))
            return x

    def encoder(seed):
        # Reseed right before construction so every nested conv gets reproducible weights.
        torch.manual_seed(seed)
        return Encoder()

    assert comparator(encoder(42), encoder(42))
    assert not comparator(encoder(42), encoder(123))
def test_torch_nn_transformer():
    """Test comparator for torch.nn.Transformer and torch.nn.TransformerEncoder.

    The models are deliberately small: two encoder and two decoder layers and
    a 128-wide feed-forward block. PyTorch's defaults (6 + 6 layers,
    dim_feedforward=2048) allocate millions of parameters per model and slow
    the test down without making any comparison more meaningful. Every
    property checked here (seeded equality, d_model, nhead) is independent of
    depth and feed-forward width.
    """
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip("torch is not installed")

    small = {"num_encoder_layers": 2, "num_decoder_layers": 2, "dim_feedforward": 128}

    # Identical Transformers built from the same seed.
    torch.manual_seed(42)
    a = nn.Transformer(d_model=64, nhead=4, **small)
    torch.manual_seed(42)
    b = nn.Transformer(d_model=64, nhead=4, **small)
    assert comparator(a, b)

    # Different d_model -> different parameter shapes.
    torch.manual_seed(42)
    c = nn.Transformer(d_model=64, nhead=4, **small)
    torch.manual_seed(42)
    d = nn.Transformer(d_model=128, nhead=4, **small)
    assert not comparator(c, d)

    # Different nhead. The weight shapes are the same for both head counts,
    # so only the attention modules' configuration tells them apart.
    torch.manual_seed(42)
    e = nn.Transformer(d_model=64, nhead=4, **small)
    torch.manual_seed(42)
    f = nn.Transformer(d_model=64, nhead=8, **small)
    assert not comparator(e, f)

    # Standalone TransformerEncoder built from identically seeded layers.
    torch.manual_seed(42)
    encoder_layer_a = nn.TransformerEncoderLayer(d_model=64, nhead=4, dim_feedforward=128)
    g = nn.TransformerEncoder(encoder_layer_a, num_layers=2)
    torch.manual_seed(42)
    encoder_layer_b = nn.TransformerEncoderLayer(d_model=64, nhead=4, dim_feedforward=128)
    h = nn.TransformerEncoder(encoder_layer_b, num_layers=2)
    assert comparator(g, h)
def test_torch_nn_parameter_buffer_modification():
    """Verify that comparator notices in-place edits to a parameter and to a buffer."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def seeded(build):
        # Reset the RNG to 42, then build the module with the given zero-arg factory.
        torch.manual_seed(42)
        return build()

    # Parameter edit: overwrite one weight of an otherwise identical Linear.
    reference, tampered = seeded(lambda: nn.Linear(10, 5)), seeded(lambda: nn.Linear(10, 5))
    assert comparator(reference, tampered)
    with torch.no_grad():
        tampered.weight[0, 0] = 999.0
    assert not comparator(reference, tampered)

    # Buffer edit: overwrite one entry of BatchNorm's running_mean buffer.
    reference_bn, tampered_bn = seeded(lambda: nn.BatchNorm2d(16)), seeded(lambda: nn.BatchNorm2d(16))
    assert comparator(reference_bn, tampered_bn)
    tampered_bn.running_mean[0] = 999.0
    assert not comparator(reference_bn, tampered_bn)
def test_torch_nn_device_placement():
    """Verify comparator on modules placed on CPU versus CUDA (CUDA part only when available)."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def linear():
        # Linear(10, 5) built right after resetting the RNG to 42.
        torch.manual_seed(42)
        return nn.Linear(10, 5)

    # Two CPU copies with identical weights.
    assert comparator(linear(), linear())

    # Same weights, but one copy has been moved to the GPU.
    if torch.cuda.is_available():
        assert not comparator(linear(), linear().cuda())
def test_torch_nn_conv1d_conv3d():
    """Verify comparator on Conv1d and Conv3d, and that convs of different rank never match."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def conv(cls, out_channels=16):
        # 3 input channels, kernel_size=3, built right after resetting the RNG to 42.
        torch.manual_seed(42)
        return cls(3, out_channels, kernel_size=3)

    assert comparator(conv(nn.Conv1d), conv(nn.Conv1d))
    assert not comparator(conv(nn.Conv1d), conv(nn.Conv1d, out_channels=32))
    assert comparator(conv(nn.Conv3d), conv(nn.Conv3d))
    # Same arguments, different spatial rank.
    assert not comparator(conv(nn.Conv1d), conv(nn.Conv2d))
def test_torch_nn_flatten_unflatten():
    """Verify comparator on nn.Flatten and nn.Unflatten configuration."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    # (left, right, should_match)
    cases = [
        (nn.Flatten(), nn.Flatten(), True),
        (nn.Flatten(start_dim=1), nn.Flatten(start_dim=0), False),  # start_dim
        (nn.Unflatten(dim=1, unflattened_size=(2, 5)), nn.Unflatten(dim=1, unflattened_size=(2, 5)), True),
    ]
    for left, right, should_match in cases:
        assert bool(comparator(left, right)) == should_match, (left, right)
def test_torch_nn_identity():
    """Verify comparator on nn.Identity, and that it never matches a Linear layer."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    # Two parameter-free Identity modules are interchangeable.
    assert comparator(nn.Identity(), nn.Identity())

    # An Identity versus a (seeded) Linear: different class and parameters.
    torch.manual_seed(42)
    passthrough = nn.Identity()
    dense = nn.Linear(10, 10)
    assert not comparator(passthrough, dense)
def test_torch_nn_with_superset():
    """Verify that comparator's ``superset_obj=True`` mode still compares nn.Module weights."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def linear(seed):
        # Linear(10, 5) built right after reseeding the global RNG.
        torch.manual_seed(seed)
        return nn.Linear(10, 5)

    # Identical modules pass in superset mode...
    assert comparator(linear(42), linear(42), superset_obj=True)
    # ...but differently initialized ones still fail.
    assert not comparator(linear(42), linear(123), superset_obj=True)
def test_jax():
    """Verify comparator on jax.numpy arrays.

    Exercises element values, rank/shape, dtype, empty arrays, NaN and
    infinity handling, and complex and boolean arrays. Skipped when jax is
    not installed.
    """
    try:
        import jax.numpy as jnp
    except ImportError:
        pytest.skip()

    arr = jnp.array

    def expect(reference, twin, mismatch):
        # ``reference`` must equal ``twin`` and differ from ``mismatch``.
        assert comparator(reference, twin)
        assert not comparator(reference, mismatch)

    # Rank-1 and rank-2 arrays differing in one element.
    expect(arr([1, 2, 3]), arr([1, 2, 3]), arr([1, 2, 4]))
    expect(arr([[1, 2, 3], [4, 5, 6]]), arr([[1, 2, 3], [4, 5, 6]]), arr([[1, 2, 3], [4, 5, 7]]))

    # Same values, different dtype.
    expect(
        arr([1, 2, 3], dtype=jnp.float32),
        arr([1, 2, 3], dtype=jnp.float32),
        arr([1, 2, 3], dtype=jnp.int32),
    )

    # Rank-3 arrays differing in the last element.
    expect(
        arr([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]),
        arr([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]),
        arr([[[1, 2], [3, 4]], [[5, 6], [7, 9]]]),
    )

    # Same values, extra leading dimension.
    assert not comparator(arr([1, 2, 3]), arr([[1, 2, 3]]))

    # Empty arrays match each other but not a one-element array.
    expect(arr([]), arr([]), arr([1]))

    # NaN in the same position is expected to compare equal.
    expect(arr([1.0, jnp.nan, 3.0]), arr([1.0, jnp.nan, 3.0]), arr([1.0, 2.0, 3.0]))
    # The sign of infinity matters.
    expect(arr([1.0, jnp.inf, 3.0]), arr([1.0, jnp.inf, 3.0]), arr([1.0, -jnp.inf, 3.0]))

    # Complex and boolean arrays.
    expect(arr([1 + 2j, 3 + 4j]), arr([1 + 2j, 3 + 4j]), arr([1 + 2j, 3 + 5j]))
    expect(arr([True, False, True]), arr([True, False, True]), arr([True, True, True]))
def test_xarray():
    """comparator on xarray DataArray and Dataset objects."""
    try:
        import numpy as np
        import xarray as xr
    except ImportError:
        pytest.skip()
    # Each row: (reference, equal copy, object that must NOT match the reference).
    cases = [
        # Plain DataArray, one differing value
        (
            xr.DataArray([1, 2, 3], dims=["x"]),
            xr.DataArray([1, 2, 3], dims=["x"]),
            xr.DataArray([1, 2, 4], dims=["x"]),
        ),
        # Differing coordinate values
        (
            xr.DataArray([1, 2, 3], coords={"x": [0, 1, 2]}, dims=["x"]),
            xr.DataArray([1, 2, 3], coords={"x": [0, 1, 2]}, dims=["x"]),
            xr.DataArray([1, 2, 3], coords={"x": [0, 1, 3]}, dims=["x"]),
        ),
        # Differing attrs
        (
            xr.DataArray([1, 2, 3], dims=["x"], attrs={"units": "meters"}),
            xr.DataArray([1, 2, 3], dims=["x"], attrs={"units": "meters"}),
            xr.DataArray([1, 2, 3], dims=["x"], attrs={"units": "feet"}),
        ),
        # 2-D DataArray
        (
            xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=["x", "y"]),
            xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=["x", "y"]),
            xr.DataArray([[1, 2, 3], [4, 5, 7]], dims=["x", "y"]),
        ),
        # NaN in the same position counts as equal
        (
            xr.DataArray([1.0, np.nan, 3.0], dims=["x"]),
            xr.DataArray([1.0, np.nan, 3.0], dims=["x"]),
            xr.DataArray([1.0, 2.0, 3.0], dims=["x"]),
        ),
        # +inf vs -inf
        (
            xr.DataArray([1.0, np.inf, 3.0], dims=["x"]),
            xr.DataArray([1.0, np.inf, 3.0], dims=["x"]),
            xr.DataArray([1.0, -np.inf, 3.0], dims=["x"]),
        ),
        # Dataset with one differing variable value
        (
            xr.Dataset({"temp": (["x", "y"], [[1, 2], [3, 4]]), "pressure": (["x", "y"], [[5, 6], [7, 8]])}),
            xr.Dataset({"temp": (["x", "y"], [[1, 2], [3, 4]]), "pressure": (["x", "y"], [[5, 6], [7, 8]])}),
            xr.Dataset({"temp": (["x", "y"], [[1, 2], [3, 4]]), "pressure": (["x", "y"], [[5, 6], [7, 9]])}),
        ),
        # Dataset with differing coordinates
        (
            xr.Dataset({"temp": (["x", "y"], [[1, 2], [3, 4]])}, coords={"x": [0, 1], "y": [0, 1]}),
            xr.Dataset({"temp": (["x", "y"], [[1, 2], [3, 4]])}, coords={"x": [0, 1], "y": [0, 1]}),
            xr.Dataset({"temp": (["x", "y"], [[1, 2], [3, 4]])}, coords={"x": [0, 2], "y": [0, 1]}),
        ),
        # Dataset with differing attrs
        (
            xr.Dataset({"temp": (["x"], [1, 2, 3])}, attrs={"source": "sensor"}),
            xr.Dataset({"temp": (["x"], [1, 2, 3])}, attrs={"source": "sensor"}),
            xr.Dataset({"temp": (["x"], [1, 2, 3])}, attrs={"source": "model"}),
        ),
        # Dataset with a differently named variable
        (
            xr.Dataset({"temp": (["x"], [1, 2, 3])}),
            xr.Dataset({"temp": (["x"], [1, 2, 3])}),
            xr.Dataset({"pressure": (["x"], [1, 2, 3])}),
        ),
    ]
    for reference, equal_copy, different in cases:
        assert comparator(reference, equal_copy)
        assert not comparator(reference, different)
    # Pairs that must compare equal.
    assert comparator(xr.Dataset(), xr.Dataset())
    # int32 vs int64 DataArrays holding the same values are treated as equal
    # (xarray is permissive about dtype here).
    assert comparator(
        xr.DataArray(np.array([1, 2, 3], dtype="int32"), dims=["x"]),
        xr.DataArray(np.array([1, 2, 3], dtype="int64"), dims=["x"]),
    )
    # Pairs that must NOT compare equal.
    mismatched = [
        # Different dimension names
        (xr.DataArray([1, 2, 3], dims=["x"]), xr.DataArray([1, 2, 3], dims=["y"])),
        # Different shapes
        (xr.DataArray([1, 2, 3], dims=["x"]), xr.DataArray([[1, 2, 3]], dims=["x", "y"])),
        # DataArray vs Dataset (different types)
        (xr.DataArray([1, 2, 3], dims=["x"]), xr.Dataset({"data": (["x"], [1, 2, 3])})),
    ]
    for left, right in mismatched:
        assert not comparator(left, right)
def test_returns():
    """Success/Failure containers compare by variant and by wrapped value."""
    five = Success(5)
    assert comparator(five, Success(5))
    # Different value, different variant (Failure), and different payload shape all mismatch.
    for other in (Success(6), Failure(5), Success((5, 5))):
        assert not comparator(five, other)
    # Tuple payloads.
    assert not comparator(Success((5, 5)), Success((5, 6)))
    assert comparator(Success((5, 5)), Success((5, 5)))
    assert not comparator(Success((5, 5)), Success((5, 6)))
def test_custom_object():
    # --- user-defined class that defines __eq__ ---
    class TestClass:
        def __init__(self, value):
            self.value = value

        def __eq__(self, other):
            return self.value == other.value

    assert comparator(TestClass(5), TestClass(5))
    assert not comparator(TestClass(5), TestClass(6))

    # --- unrelated classes never match; same class with equal attributes does ---
    class TestClass2:
        def __init__(self, value1, value2=6):
            self.value1 = value1
            self.value2 = value2

    assert not comparator(TestClass(5), TestClass2(5, 6))
    assert not comparator(TestClass2(5, 6), TestClass2(5, 7))
    assert comparator(TestClass2(5, 6), TestClass2(5, 6))

    # --- subclass instances: equal to each other, not to an unrelated class ---
    class TestClass3(TestClass):
        def print(self):
            print(self.value)

    assert not comparator(TestClass2(5), TestClass3(5))
    assert comparator(TestClass3(5), TestClass3(5))

    # --- stdlib dataclass, pydantic dataclass and pydantic BaseModel ---
    @dataclasses.dataclass
    class InventoryItem:
        """Class for keeping track of an item in inventory."""

        name: str
        unit_price: float
        quantity_on_hand: int = 0

        def total_cost(self) -> float:
            return self.unit_price * self.quantity_on_hand

    @pydantic.dataclasses.dataclass
    class InventoryItemPydantic:
        """Class for keeping track of an item in inventory."""

        name: str
        unit_price: float
        quantity_on_hand: int = 0

        def total_cost(self) -> float:
            return self.unit_price * self.quantity_on_hand

    class InventoryItemBasePydantic(pydantic.BaseModel):
        name: str
        unit_price: float
        quantity_on_hand: int = 0

        def total_cost(self) -> float:
            return self.unit_price * self.quantity_on_hand

    def widget(item_cls, quantity):
        # Build a "widget" item that varies only in quantity_on_hand.
        return item_cls(name="widget", unit_price=3.0, quantity_on_hand=quantity)

    for item_cls in (InventoryItem, InventoryItemPydantic, InventoryItemBasePydantic):
        assert comparator(widget(item_cls, 10), widget(item_cls, 10))
        assert not comparator(widget(item_cls, 10), widget(item_cls, 11))

    # --- class objects themselves ---
    class A:
        items = [1, 2, 3]
        val = 5

    class B:
        items = [1, 2, 4]
        val = 5

    assert comparator(A, A)
    assert not comparator(A, B)

    class C:
        items = [1, 2, 3]
        val = 5

        def __init__(self):
            self.itemm2 = [1, 2, 3]
            self.val2 = 5

    class D:
        items = [1, 2, 3]
        val = 5

        def __init__(self):
            self.itemm2 = [1, 2, 4]
            self.val2 = 5

    assert comparator(C, C)
    assert not comparator(C, D)
    # An alias is the very same class object.
    alias_of_c = C
    assert comparator(C, alias_of_c)
def test_custom_object_2():
    """comparator across re-imports of a module whose source is rewritten on disk.

    Temporarily overwrites ``code_to_optimize/bubble_sort_method.py``. The ``finally`` clause
    always restores the original source and evicts the mutated module (and its cached
    bytecode), so later tests importing the module see the original class.
    """
    import importlib
    import importlib.util

    module_name = "code_to_optimize.bubble_sort_method"
    fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_method.py").resolve()
    original_code = fto_path.read_text("utf-8")

    def forget_module():
        # Drop the loaded module so the next import re-executes the source file.
        sys.modules.pop(module_name, None)
        # Timestamp-based .pyc validation compares the source mtime in whole seconds plus its
        # byte size. The two rewritten sources below have identical sizes and are written
        # milliseconds apart, so a stale .pyc could otherwise be reused. Delete it to force
        # recompilation from the file we just wrote.
        Path(importlib.util.cache_from_source(str(fto_path))).unlink(missing_ok=True)

    def load_bubble_sorter():
        forget_module()
        return importlib.import_module(module_name).BubbleSorter

    from code_to_optimize.bubble_sort_method import BubbleSorter

    a = BubbleSorter()
    assert a.x == 0
    try:
        # Re-import the unchanged source to get a fresh class object.
        BubbleSorter = load_bubble_sorter()
        b = BubbleSorter()
        # Note that type(a) != type(b) as the class type objects are different, even if the code is the same.
        assert comparator(a, b)
        optimized_code_mutated_attr = """
class BubbleSorter:
    z = 0
    def __init__(self, x=1):
        self.x = x
    def sorter(self, arr):
        for i in range(len(arr)):
            for j in range(len(arr) - 1):
                if arr[j] > arr[j + 1]:
                    temp = arr[j]
                    arr[j] = arr[j + 1]
                    arr[j + 1] = temp
        return arr
"""
        fto_path.write_text(optimized_code_mutated_attr, "utf-8")
        BubbleSorter = load_bubble_sorter()
        c = BubbleSorter()
        assert c.x == 1
        assert not comparator(a, c)
        optimized_code_new_attr = """
class BubbleSorter:
    z = 5
    def __init__(self, x=0):
        self.x = x
    def sorter(self, arr):
        for i in range(len(arr)):
            for j in range(len(arr) - 1):
                if arr[j] > arr[j + 1]:
                    temp = arr[j]
                    arr[j] = arr[j + 1]
                    arr[j + 1] = temp
        return arr
"""
        fto_path.write_text(optimized_code_new_attr, "utf-8")
        BubbleSorter = load_bubble_sorter()
        d = BubbleSorter()
        assert d.x == 0
        # Currently, we do not check if class variables are different, since the code replacer does not allow this.
        # In the future, if this functionality is allowed, this assert should be false.
        assert comparator(a, d)
    finally:
        fto_path.write_text(original_code, "utf-8")
        # Don't leak the mutated module/bytecode to later tests that import it.
        forget_module()
def test_superset():
    """superset_obj=True lets the second object carry extra attributes, but not the first."""
    class A:
        def __init__(self):
            self.a = 1

    extended = A()
    extended.x = 3  # extra attribute that a fresh A() lacks
    assert comparator(A(), extended, superset_obj=True)
    assert not comparator(extended, A(), superset_obj=True)
    # Without superset_obj, the attribute mismatch fails in both directions.
    for left, right in ((A(), extended), (extended, A())):
        assert not comparator(left, right)
    # An object always matches itself, in either mode.
    assert comparator(extended, extended, superset_obj=True)
    assert comparator(extended, extended)
def test_compare_results_fn():
    """compare_test_results on small TestResults collections built from one invocation template."""

    def make_invocation(
        *,
        iteration_id="0",
        did_pass=True,
        runtime=5,
        test_type=TestType.EXISTING_UNIT_TEST,
        return_value=5,
    ):
        # Every invocation shares the same test identity; only the keyword fields vary.
        return FunctionTestInvocation(
            id=InvocationId(
                test_module_path="test_module_path",
                test_class_name="test_class_name",
                test_function_name="test_function_name",
                function_getting_tested="function_getting_tested",
                iteration_id=iteration_id,
            ),
            file_name=Path("file_name"),
            did_pass=did_pass,
            runtime=runtime,
            test_framework="unittest",
            test_type=test_type,
            return_value=return_value,
            timed_out=False,
            loop_index=1,
        )

    def make_results(*invocations):
        results = TestResults()
        for invocation in invocations:
            results.add(invocation)
        return results

    original_results = make_results(make_invocation())

    # Only the runtime differs -> still a match.
    match, _ = compare_test_results(original_results, make_results(make_invocation(runtime=10)))
    assert match
    # Return value [5] vs 5 -> mismatch.
    match, _ = compare_test_results(
        original_results, make_results(make_invocation(runtime=10, return_value=[5]))
    )
    assert not match
    # Candidate has an extra invocation (iteration "2") absent from the original -> still a match.
    match, _ = compare_test_results(
        original_results,
        make_results(make_invocation(runtime=10), make_invocation(iteration_id="2", runtime=10)),
    )
    assert match
    # Existing unit test flips from passing to failing -> mismatch.
    match, _ = compare_test_results(original_results, make_results(make_invocation(did_pass=False)))
    assert not match
    # Same flip for generated-regression and replay tests -> mismatch.
    for test_type in (TestType.GENERATED_REGRESSION, TestType.REPLAY_TEST):
        baseline = make_results(make_invocation(test_type=test_type))
        optimized = make_results(make_invocation(did_pass=False, test_type=test_type))
        match, _ = compare_test_results(baseline, optimized)
        assert not match
    # Two empty result sets are not considered a match.
    match, _ = compare_test_results(TestResults(), TestResults())
    assert not match
def test_exceptions():
    # Two separately constructed exceptions of the same type and message are equivalent.
    assert comparator(TypeError("This is a type error"), TypeError("This is a type error"))
def raise_exception():
    """Raise a plain ``Exception`` with a fixed message, so tests can capture a really-raised exception."""
    message = "This is an exception"
    raise Exception(message)
def test_exceptions_comparator():
    """comparator on exceptions, plus Ellipsis and AST-node comparisons.

    Currently we are only comparing the exception types and the attributes that don't start
    with "_"; there are complications with comparing the exception messages.
    """
    # Exceptions captured from real raises compare equal.
    try:
        raise_exception()
    except Exception as e:
        exception = e
    try:
        raise_exception()
    except Exception as b:
        exception_2 = b
    assert comparator(exception, exception_2)
    exc1 = ValueError("same message")
    exc2 = ValueError("same message")
    assert comparator(exc1, exc2)
    # Different messages but same types
    exc_msg1 = ValueError("message one")
    exc_msg2 = ValueError("message two")
    assert comparator(exc_msg1, exc_msg2)
    # Same message but different types
    exc1 = ValueError("common message")
    exc2 = TypeError("common message")
    assert not comparator(exc1, exc2)
    # FileNotFoundError instances match even when errno or strerror differ.
    exc_file1 = FileNotFoundError(2, "No such file or directory")
    exc_file2 = FileNotFoundError(2, "No such file or directory")
    exc_file3 = FileNotFoundError(3, "No such file or directory")
    exc_file4 = FileNotFoundError(2, "File not found")
    assert comparator(exc_file1, exc_file2)
    assert comparator(exc_file1, exc_file3)
    assert comparator(exc_file1, exc_file4)
    # Identity and None handling
    assert comparator(exception, exception)
    assert not comparator(exception, None)
    assert not comparator(None, exception)
    assert comparator(None, None)
    # Same type with different messages matches; different types do not.
    exc_type1 = TypeError("Type error")
    exc_type2 = TypeError("Another type error")
    assert comparator(exc_type1, exc_type2)
    exc_type3 = KeyError("Missing key")
    exc_type4 = KeyError("Missing key")
    assert comparator(exc_type3, exc_type4)
    assert not comparator(exc_type1, exc_type3)

    # Public attributes of the exception are compared as well.
    class CustomError(Exception):
        def __init__(self, message, code):
            super().__init__(message)
            self.code = code

    custom_exc1 = CustomError("Something went wrong", 101)
    custom_exc2 = CustomError("Something went wrong", 101)
    assert comparator(custom_exc1, custom_exc2)
    custom_exc4 = CustomError("Something went wrong", 102)
    assert not comparator(custom_exc1, custom_exc4)

    class CustomErrorNoArgs(Exception):
        pass

    custom_no_args1 = CustomErrorNoArgs()
    custom_no_args2 = CustomErrorNoArgs()
    assert comparator(custom_no_args1, custom_no_args2)
    # Empty vs non-empty messages still match (messages are not compared).
    exc_empty1 = ValueError("")
    exc_empty2 = ValueError("")
    assert comparator(exc_empty1, exc_empty2)
    exc_not_empty = ValueError("Not empty")
    assert comparator(exc_empty1, exc_not_empty)

    # A subclass is a different type from its base.
    class CustomValueError(ValueError):
        pass

    custom_value_error1 = CustomValueError("A custom value error")
    value_error1 = ValueError("A custom value error")
    assert not comparator(custom_value_error1, value_error1)
    custom_value_error2 = CustomValueError("Another custom value error")
    assert comparator(custom_value_error1, custom_value_error2)

    # Differing ``args`` do not affect the result.
    class CustomExceptionWithArgs(Exception):
        def __init__(self, arg1, arg2):
            self.args = (arg1, arg2)

    custom_args_exc1 = CustomExceptionWithArgs(1, "test")
    custom_args_exc2 = CustomExceptionWithArgs(1, "test")
    assert comparator(custom_args_exc1, custom_args_exc2)
    custom_args_exc3 = CustomExceptionWithArgs(1, "different")
    assert comparator(custom_args_exc1, custom_args_exc3)

    def raise_specific_exception():
        raise ZeroDivisionError("Cannot divide by zero")

    try:
        raise_specific_exception()
    except ZeroDivisionError as z1:
        zero_division_exc1 = z1
    try:
        raise_specific_exception()
    except ZeroDivisionError as z2:
        zero_division_exc2 = z2
    assert comparator(zero_division_exc1, zero_division_exc2)
    zero_division_exc3 = ZeroDivisionError("Different message")
    assert comparator(zero_division_exc1, zero_division_exc3)

    # Ellipsis equals itself and not None.
    assert comparator(..., ...)
    assert comparator(Ellipsis, Ellipsis)
    assert not comparator(..., None)
    assert not comparator(Ellipsis, None)

    # AST: a deep copy, including ad-hoc ``parent`` back-links, matches; a different literal does not.
    code7 = "a = 1 + 2"
    module7 = ast.parse(code7)
    for node in ast.walk(module7):
        for child in ast.iter_child_nodes(node):
            child.parent = node  # type: ignore
    module8 = copy.deepcopy(module7)
    assert comparator(module7, module8)
    code2 = "a = 1 + 3"
    module2 = ast.parse(code2)
    assert not comparator(module7, module2)
def test_torch_runtime_error_wrapping():
    """Exceptions wrapped in a torch.compile-style TorchRuntimeError match the wrapped type.

    torch.compile wraps exceptions in TorchRuntimeError; comparator should treat an IndexError
    as equivalent to a TorchRuntimeError that wraps an IndexError, but not to one wrapping
    something else.
    """
    # Mock of torch._dynamo.exc.TorchRuntimeError.
    class TorchRuntimeError(Exception):
        """Mock TorchRuntimeError for testing."""

    # Make the mock look like it lives in torch._dynamo.exc.
    TorchRuntimeError.__module__ = "torch._dynamo.exc"

    def wrap(message, cause=None):
        # Build a TorchRuntimeError, optionally chaining `cause` via __cause__.
        wrapped = TorchRuntimeError(message)
        if cause is not None:
            wrapped.__cause__ = cause
        return wrapped

    oob_text = "index 0 is out of bounds for dimension 0 with size 0"
    dynamo_text = "Dynamo failed to run FX node with fake tensors: got IndexError('index 0 is out of bounds')"
    index_error = IndexError(oob_text)

    # 1. Wrapper whose __cause__ has the same type -> equivalent in both orders.
    with_cause = wrap(dynamo_text, IndexError(oob_text))
    assert comparator(index_error, with_cause)
    assert comparator(with_cause, index_error)

    # 2. No __cause__, but the wrapped type is named in the message -> still equivalent.
    message_only = wrap(dynamo_text)
    assert comparator(index_error, message_only)
    assert comparator(message_only, index_error)

    # 3. A wrapped IndexError never matches a ValueError.
    wrapped_index = wrap("got IndexError('some error')", IndexError("some error"))
    value_error = ValueError("some value error")
    assert not comparator(value_error, wrapped_index)
    assert not comparator(wrapped_index, value_error)

    # 4. ...nor a TypeError.
    assert not comparator(
        TypeError("some type error"), wrap("got IndexError('index error')", IndexError("index error"))
    )

    # 5. Two wrappers around the same inner type match.
    assert comparator(
        wrap("got IndexError('error 1')", IndexError("error 1")),
        wrap("got IndexError('error 2')", IndexError("error 2")),
    )

    # 6. Plain same-type comparison is unaffected.
    assert comparator(IndexError("same error"), IndexError("same error"))

    # 7. The exception nested inside a return-value tuple.
    orig_return = (("tensor1", "tensor2"), {}, IndexError(oob_text))
    torch_wrapped_return = (
        ("tensor1", "tensor2"),
        {},
        wrap(
            "Dynamo failed: got IndexError('index 0 is out of bounds for dimension 0 with size 0')",
            IndexError(oob_text),
        ),
    )
    assert comparator(orig_return, torch_wrapped_return)
def test_extract_exception_from_message():
    """Test the _extract_exception_from_message helper function."""
    # Messages naming a builtin exception in the "got Name('...')" form yield an instance of it.
    recognised = [
        ("got IndexError('some error message')", IndexError),
        ('got ValueError("another error")', ValueError),
        *(
            (f"got {exc_class.__name__}('test')", exc_class)
            for exc_class in (TypeError, KeyError, RuntimeError, AttributeError, ZeroDivisionError)
        ),
        (
            "Dynamo failed to run FX node with fake tensors: got IndexError('index 0 is out of bounds for dimension 0 with size 0')",
            IndexError,
        ),
    ]
    for message, expected_type in recognised:
        extracted = _extract_exception_from_message(message)
        assert extracted is not None
        assert isinstance(extracted, expected_type)
    # No pattern, a non-exception class name, a missing "('" and an empty string all yield None.
    unrecognised = [
        "This is a normal error message",
        "got SomeRandomClass('not an exception')",
        "got IndexError without quotes",
        "",
    ]
    for message in unrecognised:
        assert _extract_exception_from_message(message) is None
def test_get_wrapped_exception():
    """Test the _get_wrapped_exception helper function."""
    # Explicit chaining: the exact __cause__ object is returned.
    cause = ValueError("inner error")
    chained = RuntimeError("outer error")
    chained.__cause__ = cause
    assert _get_wrapped_exception(chained) is cause
    # Nothing wrapped -> None.
    assert _get_wrapped_exception(ValueError("plain error")) is None
    # No __cause__: the type is recovered from a "got X('...')" message.
    recovered = _get_wrapped_exception(RuntimeError("got TypeError('some type error')"))
    assert recovered is not None
    assert isinstance(recovered, TypeError)
    # Both available: __cause__ wins over the message pattern.
    real_cause = IndexError("actual cause")
    ambiguous = RuntimeError("got TypeError('different error in message')")
    ambiguous.__cause__ = real_cause
    winner = _get_wrapped_exception(ambiguous)
    assert winner is real_cause
    assert isinstance(winner, IndexError)
@pytest.mark.skipif(sys.version_info < (3, 11), reason="ExceptionGroup requires Python 3.11+")
def test_get_wrapped_exception_exception_group():
    """Test _get_wrapped_exception with ExceptionGroup (Python 3.11+)."""
    # A group holding exactly one exception yields that very exception.
    only_member = ValueError("single inner error")
    assert _get_wrapped_exception(ExceptionGroup("group", [only_member])) is only_member
    # A group holding several exceptions yields None.
    several = ExceptionGroup("multi group", [ValueError("error 1"), TypeError("error 2")])
    assert _get_wrapped_exception(several) is None
@pytest.mark.skipif(sys.version_info < (3, 11), reason="ExceptionGroup requires Python 3.11+")
def test_comparator_with_exception_group():
    """Test comparator with ExceptionGroup wrapping (Python 3.11+)."""
    bare_value_error = ValueError("different message but same type")
    # A single-member group matches a bare exception of its member's type, in both orders.
    value_group = ExceptionGroup("group", [ValueError("some value error")])
    assert comparator(value_group, bare_value_error)
    assert comparator(bare_value_error, value_group)
    # A group around a different type does not match.
    assert not comparator(ExceptionGroup("group", [TypeError("type error")]), bare_value_error)
    # Two groups around the same type match, whatever their messages.
    assert comparator(
        ExceptionGroup("group1", [ValueError("error 1")]),
        ExceptionGroup("group2", [ValueError("error 2")]),
    )
def test_comparator_with_cause_chaining():
    """Test comparator with __cause__ exception chaining."""
    # Build the chain by assigning __cause__ directly; this is the link that
    # `raise outer from inner` would record, without actually raising.
    inner = IndexError("inner index error")
    outer = RuntimeError("outer runtime error")
    outer.__cause__ = inner
    # Outer exception should match the inner exception type
    plain_index_error = IndexError("different index error")
    assert comparator(outer, plain_index_error)
    assert comparator(plain_index_error, outer)
    # Should not match a different type
    plain_type_error = TypeError("type error")
    assert not comparator(outer, plain_type_error)
    # Two chained exceptions with same wrapper type match (regardless of inner type)
    # because same-type exceptions compare non-private attributes only (__cause__ is ignored)
    outer1 = RuntimeError("outer 1")
    outer1.__cause__ = ValueError("inner 1")
    outer2 = RuntimeError("outer 2")
    outer2.__cause__ = ValueError("inner 2")
    assert comparator(outer1, outer2)

    # Different wrapper types with same inner type - unwrapping makes them match
    class WrapperA(Exception):
        pass

    class WrapperB(Exception):
        pass

    wrapper_a = WrapperA("wrapper a")
    wrapper_a.__cause__ = KeyError("same inner type")
    wrapper_b = WrapperB("wrapper b")
    wrapper_b.__cause__ = KeyError("same inner type")
    # Both unwrap to KeyError, so they should match
    assert comparator(wrapper_a, wrapper_b)
    # Different wrapper types with different inner types should not match
    wrapper_c = WrapperA("wrapper c")
    wrapper_c.__cause__ = ValueError("value error")
    wrapper_d = WrapperB("wrapper d")
    wrapper_d.__cause__ = TypeError("type error")
    assert not comparator(wrapper_c, wrapper_d)
def test_comparator_with_message_pattern():
    """Test comparator with exception type extracted from message pattern."""
    # No __cause__: the wrapped type is only visible in the "got IndexError('...')" text.
    message_wrapper = RuntimeError("Operation failed: got IndexError('list index out of range')")
    bare_index = IndexError("some index error")
    assert comparator(message_wrapper, bare_index)
    assert comparator(bare_index, message_wrapper)
    # A type not named in the message does not match.
    assert not comparator(message_wrapper, KeyError("some key error"))
def test_comparator_wrapped_exceptions_bidirectional():
    """Test that wrapped exception comparison works in both directions."""
    class CustomWrapper(Exception):
        pass

    wrapper = CustomWrapper("wrapper message")
    wrapper.__cause__ = AttributeError("attr error")
    plain_attr = AttributeError("plain attr error")
    # Check both argument orders, first with default options, then with superset_obj=True.
    for extra_kwargs in ({}, {"superset_obj": True}):
        assert comparator(wrapper, plain_attr, **extra_kwargs)
        assert comparator(plain_attr, wrapper, **extra_kwargs)
def test_comparator_same_type_exceptions_still_work():
    """Ensure that same-type exception comparison still works correctly."""
    # Same builtin type, different messages -> equal.
    assert comparator(ValueError("message 1"), ValueError("message 2"))

    class CustomError(Exception):
        def __init__(self, msg, code):
            super().__init__(msg)
            self.code = code

    reference = CustomError("msg1", 100)
    # Equal public `code`, different message -> equal.
    assert comparator(reference, CustomError("msg2", 100))
    # Different public `code` -> not equal.
    assert not comparator(reference, CustomError("msg3", 200))
def test_comparator_no_false_positives_for_wrapped_exceptions():
    """Test that unrelated exception types don't match due to wrapping logic."""
    val_err = ValueError("value error")
    # Two completely different exception types never match.
    assert not comparator(val_err, TypeError("type error"))
    # A wrapper whose __cause__ is an unrelated type must not match in either order.
    key_wrapper = RuntimeError("some error")
    key_wrapper.__cause__ = KeyError("key error")
    for left, right in ((key_wrapper, val_err), (val_err, key_wrapper)):
        assert not comparator(left, right)
def test_collections() -> None:
    """Exercise comparator on the container types provided by the collections module."""
    # deque: compared element-wise; maxlen is ignored, but a plain list never matches.
    base_dq = deque([1, 2, 3])
    assert comparator(base_dq, deque([1, 2, 3]))
    assert comparator(base_dq, deque([1, 2, 3], maxlen=5))  # same elements, different maxlen is ok
    assert not comparator(base_dq, deque([1, 2, 4]))
    assert not comparator(base_dq, deque([1, 2]))
    assert not comparator(base_dq, [1, 2, 3])

    dicts_dq = deque([{"a": 1}, {"b": 2}])
    assert comparator(dicts_dq, deque([{"a": 1}, {"b": 2}]))
    assert not comparator(dicts_dq, deque([{"a": 1}, {"b": 3}]))

    assert comparator(deque(), deque())
    assert not comparator(deque(), base_dq)

    # namedtuple: field values and the namedtuple class both matter; a bare tuple never matches.
    Point = namedtuple("Point", ["x", "y"])
    Point2 = namedtuple("Point2", ["x", "y"])
    origin = Point(x=1, y=2)
    assert comparator(origin, Point(x=1, y=2))
    assert not comparator(origin, Point(x=1, y=3))
    assert not comparator(origin, Point2(x=1, y=2))
    assert not comparator(origin, (1, 2))

    # ChainMap: the order of the underlying maps matters, and a flattened dict does not match.
    first = {"a": 1, "b": 2}
    second = {"c": 3, "d": 4}
    chain = ChainMap(first, second)
    assert comparator(chain, ChainMap(first, second))
    assert not comparator(chain, ChainMap(second, first))
    assert not comparator(chain, {"a": 1, "b": 2, "c": 3, "d": 4})

    # Counter: construction style is irrelevant, counts are not; a dict does not match.
    tally = Counter(["a", "b", "a", "c", "b", "a"])
    assert comparator(tally, Counter({"a": 3, "b": 2, "c": 1}))
    assert not comparator(tally, Counter({"a": 3, "b": 2, "c": 2}))
    assert not comparator(tally, {"a": 3, "b": 2, "c": 1})

    # OrderedDict: key order matters; a dict does not match.
    ordered = OrderedDict([("a", 1), ("b", 2)])
    assert comparator(ordered, OrderedDict([("a", 1), ("b", 2)]))
    assert not comparator(ordered, OrderedDict([("b", 2), ("a", 1)]))
    assert not comparator(ordered, {"a": 1, "b": 2})

    # defaultdict: default_factory is ignored, contents and type are not.
    counts = defaultdict(int, {"a": 1, "b": 2})
    assert comparator(counts, defaultdict(int, {"a": 1, "b": 2}))
    assert comparator(counts, defaultdict(list, {"a": 1, "b": 2}))
    assert not comparator(counts, {"a": 1, "b": 2})
    assert not comparator(counts, defaultdict(int, {"a": 1, "b": 3}))

    # UserDict / UserList / UserString: equal only to the same wrapper holding equal data.
    for wrapped, twin, changed, plain in (
        (UserDict({"a": 1, "b": 2}), UserDict({"a": 1, "b": 2}), UserDict({"a": 1, "b": 3}), {"a": 1, "b": 2}),
        (UserList([1, 2, 3]), UserList([1, 2, 3]), UserList([1, 2, 4]), [1, 2, 3]),
        (UserString("hello"), UserString("hello"), UserString("world"), "hello"),
    ):
        assert comparator(wrapped, twin)
        assert not comparator(wrapped, changed)
        assert not comparator(wrapped, plain)
def test_attrs():
    """Check comparator on attrs classes.

    Covers @attrs.define, @attrs.frozen, slotted classes, fields declared with eq=False,
    init=False computed fields, nested attrs instances, container fields, and instances
    of different attrs classes that share a prefix of fields.
    """
    try:
        import attrs  # type: ignore
    except ImportError:
        pytest.skip("attrs required for this test")

    @attrs.define
    class Person:
        name: str
        age: int = 10

    a = Person("Alice", 25)
    b = Person("Alice", 25)
    c = Person("Bob", 25)
    d = Person("Alice", 30)
    assert comparator(a, b)
    assert not comparator(a, c)
    assert not comparator(a, d)

    @attrs.frozen
    class Point:
        x: int
        y: int

    p1 = Point(1, 2)
    p2 = Point(1, 2)
    p3 = Point(2, 3)
    assert comparator(p1, p2)
    assert not comparator(p1, p3)

    @attrs.define(slots=True)
    class Vehicle:
        brand: str
        model: str
        year: int = 2020

    v1 = Vehicle("Toyota", "Camry", 2021)
    v2 = Vehicle("Toyota", "Camry", 2021)
    v3 = Vehicle("Honda", "Civic", 2021)
    assert comparator(v1, v2)
    assert not comparator(v1, v3)

    @attrs.define
    class ComplexClass:
        public_field: str
        private_field: str = attrs.field(repr=False)
        non_eq_field: int = attrs.field(eq=False, default=0)
        computed: str = attrs.field(init=False, eq=True)

        def __attrs_post_init__(self):
            self.computed = f"{self.public_field}_{self.private_field}"

    c1 = ComplexClass("test", "secret")
    c2 = ComplexClass("test", "secret")
    c3 = ComplexClass("different", "secret")
    # Give the eq=False field different values; the instances are still expected to match.
    c1.non_eq_field = 100
    c2.non_eq_field = 200
    assert comparator(c1, c2)
    assert not comparator(c1, c3)

    @attrs.define
    class Address:
        street: str
        city: str

    @attrs.define
    class PersonWithAddress:
        name: str
        address: Address

    addr1 = Address("123 Main St", "Anytown")
    addr2 = Address("123 Main St", "Anytown")
    addr3 = Address("456 Oak Ave", "Anytown")
    person1 = PersonWithAddress("John", addr1)
    person2 = PersonWithAddress("John", addr2)
    person3 = PersonWithAddress("John", addr3)
    assert comparator(person1, person2)
    assert not comparator(person1, person3)

    @attrs.define
    class Container:
        items: list
        metadata: dict

    cont1 = Container([1, 2, 3], {"type": "numbers"})
    cont2 = Container([1, 2, 3], {"type": "numbers"})
    cont3 = Container([1, 2, 4], {"type": "numbers"})
    assert comparator(cont1, cont2)
    assert not comparator(cont1, cont3)

    @attrs.define
    class BaseClass:
        name: str
        value: int

    @attrs.define
    class ExtendedClass:
        name: str
        value: int
        extra_field: str = "default"

    base = BaseClass("test", 42)
    extended = ExtendedClass("test", 42, "extra")
    assert not comparator(base, extended)

    @attrs.define
    class WithNonEqFields:
        name: str
        timestamp: float = attrs.field(eq=False)  # Should be ignored
        debug_info: str = attrs.field(eq=False, default="debug")

    obj1 = WithNonEqFields("test", 1000.0, "info1")
    obj2 = WithNonEqFields("test", 9999.0, "info2")  # Different non-eq fields
    obj3 = WithNonEqFields("different", 1000.0, "info1")
    assert comparator(obj1, obj2)  # Should be equal despite different timestamp/debug_info
    assert not comparator(obj1, obj3)  # Should be different due to name

    @attrs.define
    class MinimalClass:
        name: str
        value: int

    # Distinct name so this class does not shadow the ExtendedClass defined above (ruff F811).
    @attrs.define
    class ExtendedWithMetadata:
        name: str
        value: int
        extra_field: str = "default"
        metadata: dict = attrs.field(factory=dict)
        timestamp: float = attrs.field(eq=False, default=0.0)  # This should be ignored

    minimal = MinimalClass("test", 42)
    extended_with_metadata = ExtendedWithMetadata("test", 42, "extra", {"key": "value"}, 1000.0)
    assert not comparator(minimal, extended_with_metadata)
def test_dict_views() -> None:
    """Check comparator on the dict_keys, dict_values and dict_items view objects."""
    abc = {"a": 1, "b": 2, "c": 3}

    # keys: same key set matches; a different key or a different length does not.
    assert comparator(abc.keys(), {"a": 1, "b": 2, "c": 3}.keys())
    assert not comparator(abc.keys(), {"a": 1, "b": 2, "d": 3}.keys())
    assert not comparator(abc.keys(), {"a": 1, "b": 2}.keys())

    # values: the keys they came from are irrelevant; the values themselves are not.
    assert comparator(abc.values(), {"x": 1, "y": 2, "z": 3}.values())
    assert not comparator(abc.values(), {"a": 1, "b": 2, "c": 4}.values())
    assert not comparator(abc.values(), {"a": 1, "b": 2}.values())

    # items: value, key and length all matter; insertion order does not.
    assert comparator(abc.items(), {"a": 1, "b": 2, "c": 3}.items())
    assert not comparator(abc.items(), {"a": 1, "b": 2, "c": 4}.items())
    assert not comparator(abc.items(), {"a": 1, "b": 2, "d": 3}.items())
    assert not comparator(abc.items(), {"a": 1, "b": 2}.items())
    assert comparator(abc.items(), {"b": 2, "c": 3, "a": 1}.items())

    # Views over empty dicts.
    for make_view in (dict.keys, dict.values, dict.items):
        assert comparator(make_view({}), make_view({}))

    # Nested container values are compared deeply.
    nested = {"a": [1, 2, 3], "b": {"x": 1}}
    nested_same = {"a": [1, 2, 3], "b": {"x": 1}}
    nested_diff = {"a": [1, 2, 4], "b": {"x": 1}}
    for make_view in (dict.values, dict.items):
        assert comparator(make_view(nested), make_view(nested_same))
        assert not comparator(make_view(nested), make_view(nested_diff))

    # A view never equals a list or set holding the same elements.
    small = {"a": 1, "b": 2}
    assert not comparator(small.keys(), ["a", "b"])
    assert not comparator(small.keys(), {"a", "b"})
    assert not comparator(small.values(), [1, 2])
    assert not comparator(small.items(), [("a", 1), ("b", 2)])
def test_mappingproxy() -> None:
    """Check comparator on types.MappingProxyType, the read-only dict view."""
    import types

    proxy = types.MappingProxyType

    reference = proxy({"a": 1, "b": 2, "c": 3})
    assert comparator(reference, proxy({"a": 1, "b": 2, "c": 3}))
    assert not comparator(reference, proxy({"a": 1, "b": 2, "c": 4}))  # value differs
    assert not comparator(reference, proxy({"a": 1, "b": 2, "d": 3}))  # key differs
    assert not comparator(reference, proxy({"a": 1, "b": 2}))  # length differs
    assert comparator(reference, proxy({"c": 3, "a": 1, "b": 2}))  # key order is irrelevant

    assert comparator(proxy({}), proxy({}))

    nested = proxy({"a": [1, 2, 3], "b": {"x": 1}})
    assert comparator(nested, proxy({"a": [1, 2, 3], "b": {"x": 1}}))
    assert not comparator(nested, proxy({"a": [1, 2, 4], "b": {"x": 1}}))

    # A mappingproxy never equals a plain dict, whichever side it is on.
    plain = {"a": 1, "b": 2}
    read_only = proxy({"a": 1, "b": 2})
    assert not comparator(read_only, plain)
    assert not comparator(plain, read_only)

    # Sanity check: a class namespace is exposed as a mappingproxy.
    class MyClass:
        x = 1
        y = 2

    assert isinstance(MyClass.__dict__, types.MappingProxyType)
def test_mappingproxy_superset() -> None:
    """Check superset_obj=True semantics for mappingproxy objects."""
    from types import MappingProxyType

    small = MappingProxyType({"a": 1, "b": 2})
    large = MappingProxyType({"a": 1, "b": 2, "c": 3})
    conflicting = MappingProxyType({"a": 1, "b": 99, "c": 3})

    assert comparator(small, large, superset_obj=True)  # extra keys on the second argument are fine
    assert not comparator(large, small, superset_obj=True)  # missing keys are not
    assert comparator(small, small, superset_obj=True)
    assert not comparator(small, conflicting, superset_obj=True)  # shared key with a different value
def test_tensorflow_tensor() -> None:
    """Check comparator on tf.Tensor values of various ranks, dtypes and special float values."""
    try:
        import tensorflow as tf
    except ImportError:
        pytest.skip("tensorflow required for this test")

    def expect_match(reference, twin, different):
        # `reference` must equal `twin` and differ from `different`.
        assert comparator(reference, twin)
        assert not comparator(reference, different)

    # 1-D and 2-D integer tensors.
    expect_match(tf.constant([1, 2, 3]), tf.constant([1, 2, 3]), tf.constant([1, 2, 4]))
    expect_match(
        tf.constant([[1, 2, 3], [4, 5, 6]]),
        tf.constant([[1, 2, 3], [4, 5, 6]]),
        tf.constant([[1, 2, 3], [4, 5, 7]]),
    )

    # Same values, different rank.
    assert not comparator(tf.constant([1, 2, 3]), tf.constant([[1, 2, 3]]))

    # Same values, different dtype.
    expect_match(
        tf.constant([1, 2, 3], dtype=tf.float32),
        tf.constant([1, 2, 3], dtype=tf.float32),
        tf.constant([1, 2, 3], dtype=tf.int32),
    )

    # 3-D tensors.
    expect_match(
        tf.constant([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]),
        tf.constant([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]),
        tf.constant([[[1, 2], [3, 4]], [[5, 6], [7, 9]]]),
    )

    # Empty tensors.
    expect_match(tf.constant([]), tf.constant([]), tf.constant([1.0]))

    # NaN == NaN should be True.
    expect_match(
        tf.constant([1.0, float("nan"), 3.0]),
        tf.constant([1.0, float("nan"), 3.0]),
        tf.constant([1.0, 2.0, 3.0]),
    )

    # Infinities keep their sign.
    expect_match(
        tf.constant([1.0, float("inf"), 3.0]),
        tf.constant([1.0, float("inf"), 3.0]),
        tf.constant([1.0, float("-inf"), 3.0]),
    )

    # Complex, boolean and string tensors.
    expect_match(tf.constant([1 + 2j, 3 + 4j]), tf.constant([1 + 2j, 3 + 4j]), tf.constant([1 + 2j, 3 + 5j]))
    expect_match(
        tf.constant([True, False, True]),
        tf.constant([True, False, True]),
        tf.constant([True, True, True]),
    )
    expect_match(
        tf.constant(["hello", "world"]),
        tf.constant(["hello", "world"]),
        tf.constant(["hello", "there"]),
    )
def test_tensorflow_dtype() -> None:
    """Check comparator on TensorFlow DType objects."""
    try:
        import tensorflow as tf
    except ImportError:
        pytest.skip("tensorflow required for this test")

    # Each dtype must match itself and differ from its partner:
    # float, signed int, unsigned int, complex, bool and string families.
    for dtype, other in (
        (tf.float32, tf.float64),
        (tf.int32, tf.int64),
        (tf.uint8, tf.uint16),
        (tf.complex64, tf.complex128),
        (tf.bool, tf.int8),
        (tf.string, tf.int32),
    ):
        assert comparator(dtype, dtype)
        assert not comparator(dtype, other)
def test_tensorflow_variable() -> None:
    """Check comparator on tf.Variable objects (values, dtype and shape)."""
    try:
        import tensorflow as tf
    except ImportError:
        pytest.skip("tensorflow required for this test")

    def variable(values, dtype=tf.float32):
        # Small factory: every variable in this test defaults to float32.
        return tf.Variable(values, dtype=dtype)

    vector = variable([1, 2, 3])
    assert comparator(vector, variable([1, 2, 3]))
    assert not comparator(vector, variable([1, 2, 4]))

    # Same values, different dtype.
    assert not comparator(variable([1, 2, 3]), variable([1, 2, 3], dtype=tf.float64))

    matrix = variable([[1, 2], [3, 4]])
    assert comparator(matrix, variable([[1, 2], [3, 4]]))
    assert not comparator(matrix, variable([[1, 2], [3, 5]]))

    # Same values, different rank.
    assert not comparator(variable([1, 2, 3]), variable([[1, 2, 3]]))
def test_tensorflow_tensor_shape() -> None:
    """Check comparator on tf.TensorShape, including unknown dimensions and unknown rank."""
    try:
        import tensorflow as tf
    except ImportError:
        pytest.skip("tensorflow required for this test")

    def assert_shape_semantics(dims, different_dims):
        # A shape must equal a freshly built copy of itself and differ from `different_dims`.
        reference = tf.TensorShape(dims)
        assert comparator(reference, tf.TensorShape(dims))
        assert not comparator(reference, tf.TensorShape(different_dims))

    assert_shape_semantics([2, 3, 4], [2, 3, 5])

    # Different rank.
    assert not comparator(tf.TensorShape([2, 3]), tf.TensorShape([2, 3, 4]))

    # Scalar shape.
    assert_shape_semantics([], [1])

    # Unknown (None) leading dimension versus a known one.
    assert_shape_semantics([None, 3, 4], [2, 3, 4])

    # Fully unknown shape.
    assert_shape_semantics(None, [1, 2])
def test_tensorflow_sparse_tensor() -> None:
    """Check comparator on tf.SparseTensor (indices, values and dense_shape)."""
    try:
        import tensorflow as tf
    except ImportError:
        pytest.skip("tensorflow required for this test")

    base_spec = {"indices": [[0, 0], [1, 2]], "values": [1.0, 2.0], "dense_shape": [3, 4]}
    reference = tf.SparseTensor(**base_spec)
    assert comparator(reference, tf.SparseTensor(**base_spec))

    # Change one component at a time; each variant must differ from the reference.
    for component, replacement in (
        ("values", [1.0, 3.0]),
        ("indices", [[0, 0], [1, 3]]),
        ("dense_shape", [4, 5]),
    ):
        variant = tf.SparseTensor(**{**base_spec, component: replacement})
        assert not comparator(reference, variant)

    # Sparse tensors with no stored entries.
    assert comparator(
        tf.SparseTensor(indices=tf.zeros([0, 2], dtype=tf.int64), values=[], dense_shape=[3, 4]),
        tf.SparseTensor(indices=tf.zeros([0, 2], dtype=tf.int64), values=[], dense_shape=[3, 4]),
    )
def test_tensorflow_ragged_tensor() -> None:
    """Check comparator on tf.RaggedTensor values, row structure, dtype and nesting."""
    try:
        import tensorflow as tf
    except ImportError:
        pytest.skip("tensorflow required for this test")

    ragged = tf.ragged.constant

    ints = ragged([[1, 2], [3, 4, 5], [6]])
    assert comparator(ints, ragged([[1, 2], [3, 4, 5], [6]]))
    assert not comparator(ints, ragged([[1, 2], [3, 4, 6], [6]]))  # one value differs
    assert not comparator(ints, ragged([[1, 2, 3], [4, 5], [6]]))  # same flat values, different rows

    floats = ragged([[1.0, 2.0], [3.0, 4.0, 5.0], [6.0]])
    assert comparator(floats, ragged([[1.0, 2.0], [3.0, 4.0, 5.0], [6.0]]))
    assert not comparator(ints, floats)  # int vs float

    nested = ragged([[[1, 2], [3]], [[4, 5, 6]]])
    assert comparator(nested, ragged([[[1, 2], [3]], [[4, 5, 6]]]))
    assert not comparator(nested, ragged([[[1, 2], [3]], [[4, 5, 7]]]))

    assert comparator(ragged([[], [], []]), ragged([[], [], []]))
def test_slice() -> None:
    """Check comparator on built-in slice objects."""
    base = slice(1, 10, 2)
    assert comparator(base, slice(1, 10, 2))
    # Changing any one of start / stop / step breaks equality.
    for changed in (slice(2, 10, 2), slice(1, 11, 2), slice(1, 10, 3)):
        assert not comparator(base, changed)

    # (reference, equal copy, different) triples.
    for reference, twin, other in (
        (slice(None, 10, 2), slice(None, 10, 2), slice(1, 10, 2)),  # None start
        (slice(None, None, None), slice(None, None, None), slice(None, None, 1)),  # [:] vs explicit step
        (slice(5), slice(5), slice(6)),  # stop only
        (slice(-5, -1, 1), slice(-5, -1, 1), slice(-5, -2, 1)),  # negative bounds
    ):
        assert comparator(reference, twin)
        assert not comparator(reference, other)

    # A slice is not equal to a tuple holding the same bounds.
    assert not comparator(slice(1, 10), (1, 10))
def test_numpy_datetime64() -> None:
    """Test comparator support for numpy datetime64 / timedelta64 scalars, NaT, and datetime64 arrays."""
    try:
        import numpy as np
    except ImportError:
        pytest.skip("numpy required for this test")

    # datetime64 equality (unit inferred from the string)
    new_year = np.datetime64("2021-01-01")
    assert comparator(new_year, np.datetime64("2021-01-01"))
    assert not comparator(new_year, np.datetime64("2021-01-02"))

    # datetime64 with an explicit unit.
    # NOTE: the same moment expressed in different units (e.g. "D" vs "s") is deliberately
    # not asserted; whether those compare equal may depend on numpy version behavior.
    assert comparator(np.datetime64("2021-01-01", "D"), np.datetime64("2021-01-01", "D"))

    # datetime64 with a time-of-day component
    noon = np.datetime64("2021-01-01T12:00:00")
    assert comparator(noon, np.datetime64("2021-01-01T12:00:00"))
    assert not comparator(noon, np.datetime64("2021-01-01T12:00:01"))

    # timedelta64 equality
    one_day = np.timedelta64(1, "D")
    assert comparator(one_day, np.timedelta64(1, "D"))
    assert not comparator(one_day, np.timedelta64(2, "D"))

    # timedelta64 with a non-day unit. As above, equal durations in different units
    # (1 "h" vs 60 "m") are intentionally not asserted.
    assert comparator(np.timedelta64(1, "h"), np.timedelta64(1, "h"))

    # NaT (Not a Time) is numpy's datetime analogue of NaN; NaT == NaT should be True.
    assert comparator(np.datetime64("NaT"), np.datetime64("NaT"))
    assert not comparator(np.datetime64("NaT"), np.datetime64("2021-01-01"))
    assert comparator(np.timedelta64("NaT"), np.timedelta64("NaT"))
    assert not comparator(np.timedelta64("NaT"), np.timedelta64(1, "D"))

    # datetime64 is not equal to its ISO string
    assert not comparator(np.datetime64("2021-01-01"), "2021-01-01")

    # arrays of datetime64
    dates = np.array(["2021-01-01", "2021-01-02"], dtype="datetime64")
    assert comparator(dates, np.array(["2021-01-01", "2021-01-02"], dtype="datetime64"))
    assert not comparator(dates, np.array(["2021-01-01", "2021-01-03"], dtype="datetime64"))
def test_numpy_0d_array() -> None:
    """Zero-dimensional numpy arrays compare without the 'iteration over 0-d array' error."""
    try:
        import numpy as np
    except ImportError:
        pytest.skip("numpy required for this test")

    # (reference, equal copy, different) 0-d arrays across several dtypes.
    for reference, twin, other in (
        (np.array(5), np.array(5), np.array(6)),
        (np.array(3.14), np.array(3.14), np.array(2.71)),
        (np.array(1 + 2j), np.array(1 + 2j), np.array(1 + 3j)),
        (np.array("hello"), np.array("hello"), np.array("world")),
        (np.array(True), np.array(True), np.array(False)),
        (np.array(np.nan), np.array(np.nan), np.array(1.0)),  # NaN == NaN should be True
        (
            np.array(np.datetime64("2021-01-01")),
            np.array(np.datetime64("2021-01-01")),
            np.array(np.datetime64("2021-01-02")),
        ),
    ):
        assert comparator(reference, twin)
        assert not comparator(reference, other)

    # A 0-d array is a different type from the Python scalar it wraps.
    assert not comparator(np.array(5), 5)
    # Shape () differs from shape (1,).
    assert not comparator(np.array(5), np.array([5]))
def test_numpy_dtypes() -> None:
    """Test comparator for numpy.dtypes classes (Float64DType, Int64DType, ...) and np.dtype instances."""
    try:
        import numpy as np
    except ImportError:
        pytest.skip("numpy not available")
    try:
        # numpy.dtypes only exists in numpy >= 1.25. Import it separately so an older numpy
        # is reported accurately instead of as "numpy not available".
        from numpy import dtypes
    except ImportError:
        pytest.skip("numpy.dtypes not available (requires numpy >= 1.25)")

    float64_dt = dtypes.Float64DType()
    int64_dt = dtypes.Int64DType()
    assert comparator(float64_dt, dtypes.Float64DType())
    assert comparator(int64_dt, dtypes.Int64DType())
    assert not comparator(float64_dt, int64_dt)  # Float64DType vs Int64DType

    # Each DType class compares equal to a fresh instance of itself.
    for dtype_cls in (
        dtypes.Int8DType,
        dtypes.Int16DType,
        dtypes.Int32DType,
        dtypes.UInt8DType,
        dtypes.UInt16DType,
        dtypes.UInt32DType,
        dtypes.UInt64DType,
        dtypes.Float32DType,
        dtypes.Complex64DType,
        dtypes.Complex128DType,
        dtypes.BoolDType,
    ):
        assert comparator(dtype_cls(), dtype_cls())

    # Different DType classes never compare equal.
    assert not comparator(dtypes.Int32DType(), dtypes.Int64DType())
    assert not comparator(dtypes.Float32DType(), dtypes.Float64DType())
    assert not comparator(dtypes.UInt32DType(), dtypes.Int32DType())

    # Plain np.dtype instances.
    assert comparator(np.dtype("float64"), np.dtype("float64"))
    assert comparator(np.dtype("int64"), np.dtype("int64"))
    assert not comparator(np.dtype("float64"), np.dtype("int64"))

    # A DType class instance equals np.dtype of the same underlying type...
    for dtype_obj, name in (
        (dtypes.Float64DType(), "float64"),
        (dtypes.Int64DType(), "int64"),
        (dtypes.Int32DType(), "int32"),
        (dtypes.BoolDType(), "bool"),
    ):
        assert comparator(dtype_obj, np.dtype(name))

    # ...and differs from np.dtype of a different type.
    assert not comparator(dtypes.Float64DType(), np.dtype("int64"))
    assert not comparator(dtypes.Int32DType(), np.dtype("float32"))
def test_numpy_extended_precision_types() -> None:
    """Check comparator on numpy extended precision scalars (clongdouble, longdouble), including NaN."""
    try:
        import numpy as np
    except ImportError:
        pytest.skip("numpy not available")

    # Equal values match; different values do not.
    for ctor, value, other in ((np.clongdouble, 1 + 2j, 1 + 3j), (np.longdouble, 1.5, 2.5)):
        assert comparator(ctor(value), ctor(value))
        assert not comparator(ctor(value), ctor(other))

    # NaN components must still compare equal to themselves.
    for ctor, nan_value in ((np.clongdouble, complex(np.nan, 2)), (np.longdouble, np.nan)):
        assert comparator(ctor(nan_value), ctor(nan_value))
def test_numpy_typing_types() -> None:
    """Test comparator for numpy.typing objects: NDArray type aliases and NBitBase instances."""
    try:
        import numpy as np
        import numpy.typing as npt
    except ImportError:
        pytest.skip("numpy or numpy.typing not available")

    # NDArray aliases: same parametrization matches, different dtype does not.
    float_alias = npt.NDArray[np.float64]
    assert comparator(float_alias, npt.NDArray[np.float64])
    assert not comparator(float_alias, npt.NDArray[np.int64])

    # NBitBase may not be instantiable in all numpy versions, so only the construction is
    # guarded. The comparator assertions live in `else` so that a TypeError raised by
    # comparator itself fails the test instead of being silently swallowed.
    try:
        nbit1 = npt.NBitBase()
        nbit2 = npt.NBitBase()
    except TypeError:
        pass
    else:
        # NBitBase instances with empty __dict__ should compare as equal.
        assert comparator(nbit1, nbit2)
        assert comparator(nbit1, nbit2, superset_obj=True)
def test_numpy_typing_superset_obj() -> None:
    """Check superset_obj=True across several numpy value kinds."""
    try:
        import numpy as np
        import numpy.typing as npt
    except ImportError:
        pytest.skip("numpy or numpy.typing not available")

    # Object arrays of dicts: the extra key in the second array is accepted only in superset mode.
    narrow = np.array([{"a": 1}], dtype=object)
    wide = np.array([{"a": 1, "b": 2}], dtype=object)
    assert comparator(narrow, wide, superset_obj=True)
    assert not comparator(narrow, wide, superset_obj=False)

    # The remaining pairs are equal and must still match when superset_obj=True.
    record = np.dtype([("name", "S10"), ("age", np.int32)])
    for old, new in (
        (np.clongdouble(1 + 2j), np.clongdouble(1 + 2j)),
        (np.longdouble(1.5), np.longdouble(1.5)),
        (npt.NDArray[np.float64], npt.NDArray[np.float64]),
        # structured array elements (np.void)
        (np.array([("Alice", 25)], dtype=record)[0], np.array([("Alice", 25)], dtype=record)[0]),
        (np.random.default_rng(seed=42), np.random.default_rng(seed=42)),
        (np.random.RandomState(seed=42), np.random.RandomState(seed=42)),
    ):
        assert comparator(old, new, superset_obj=True)
def test_numba_typed_list() -> None:
    """Check comparator on numba.typed.List containers."""
    try:
        import numba
        from numba.typed import List as NumbaList
    except ImportError:
        pytest.skip("numba not available")

    ints = NumbaList([1, 2, 3])
    assert comparator(ints, NumbaList([1, 2, 3]))
    assert not comparator(ints, NumbaList([1, 2, 4]))  # one element differs
    assert not comparator(ints, NumbaList([1, 2, 3, 4]))  # length differs

    assert comparator(
        NumbaList.empty_list(item_type=numba.int64),
        NumbaList.empty_list(item_type=numba.int64),
    )

    floats = NumbaList([1.0, 2.0, 3.0])
    assert comparator(floats, NumbaList([1.0, 2.0, 3.0]))
    assert not comparator(floats, NumbaList([1.0, 2.0, 4.0]))
def test_numba_typed_dict() -> None:
    """Check comparator on numba.typed.Dict containers (unicode keys, int64 values)."""
    try:
        import numba
        from numba.typed import Dict as NumbaDict
    except ImportError:
        pytest.skip("numba not available")

    def typed_dict(**entries):
        # Build a str -> int64 numba dict, inserting entries in keyword order.
        result = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
        for key, value in entries.items():
            result[key] = value
        return result

    reference = typed_dict(x=1, y=2)
    assert comparator(reference, typed_dict(x=1, y=2))
    assert not comparator(reference, typed_dict(x=1, y=3))  # value differs
    assert not comparator(reference, typed_dict(x=1, z=2))  # key differs
    assert not comparator(reference, typed_dict(x=1))  # length differs

    assert comparator(typed_dict(), typed_dict())
def test_numba_types() -> None:
    """Check comparator on numba type descriptors: scalars, special types, arrays, tuples, containers."""
    try:
        import numba
        from numba import types
    except ImportError:
        pytest.skip("numba not available")

    # Scalar types, reached via both the numba and numba.types namespaces, match themselves.
    for scalar in (
        numba.int64,
        numba.float64,
        numba.int32,
        numba.float32,
        types.int64,
        types.float64,
        types.int8,
        types.int16,
        types.uint8,
        types.uint16,
        types.uint32,
        types.uint64,
        types.complex64,
        types.complex128,
    ):
        assert comparator(scalar, scalar)

    # Distinct scalar types never match.
    for left, right in (
        (numba.int64, numba.float64),
        (numba.int32, numba.int64),
        (numba.float32, numba.float64),
        (types.int8, types.int16),
        (types.uint32, types.int32),
        (types.complex64, types.complex128),
    ):
        assert not comparator(left, right)

    # Boolean type.
    assert comparator(numba.boolean, numba.boolean)
    assert comparator(types.boolean, types.boolean)
    assert not comparator(numba.boolean, numba.int64)

    # Special types.
    for special in (types.none, types.void, types.pyobject, types.unicode_type):
        assert comparator(special, special)
    # Note: types.none and types.void are the same object in numba.
    assert comparator(types.none, types.void)
    assert not comparator(types.unicode_type, types.pyobject)
    assert not comparator(types.none, types.int64)

    # Array types: ndim, dtype and layout all participate.
    c_vector = types.Array(numba.float64, 1, "C")
    assert comparator(c_vector, types.Array(numba.float64, 1, "C"))
    for variant in (
        types.Array(numba.float64, 2, "C"),  # different ndim
        types.Array(numba.int64, 1, "C"),  # different dtype
        types.Array(numba.float64, 1, "F"),  # different layout (Fortran order)
    ):
        assert not comparator(c_vector, variant)

    # Homogeneous tuple types.
    triple = types.UniTuple(types.int64, 3)
    assert comparator(triple, types.UniTuple(types.int64, 3))
    assert not comparator(triple, types.UniTuple(types.int64, 4))  # different count
    assert not comparator(triple, types.UniTuple(types.float64, 3))  # different dtype

    # Heterogeneous tuple types.
    mixed = types.Tuple([types.int64, types.float64])
    assert comparator(mixed, types.Tuple([types.int64, types.float64]))
    assert not comparator(mixed, types.Tuple([types.int64, types.int64]))

    # ListType and DictType.
    int_list = types.ListType(types.int64)
    assert comparator(int_list, types.ListType(types.int64))
    assert not comparator(int_list, types.ListType(types.float64))

    str_to_int = types.DictType(types.unicode_type, types.int64)
    assert comparator(str_to_int, types.DictType(types.unicode_type, types.int64))
    assert not comparator(str_to_int, types.DictType(types.unicode_type, types.float64))  # value type
    assert not comparator(str_to_int, types.DictType(types.int64, types.int64))  # key type
def test_numba_jit_functions() -> None:
    """Numba jitted functions compare by identity, not by their source code.

    A dispatcher equals itself; two separately decorated functions are distinct
    objects and must not compare equal, even when their bodies are identical.
    """
    try:
        from numba import jit
    except ImportError:
        pytest.skip("numba not available")

    @jit(nopython=True)
    def add(x, y):
        return x + y

    @jit(nopython=True)
    def add2(x, y):
        return x + y

    @jit(nopython=True)
    def multiply(x, y):
        return x * y

    # Invoke each dispatcher once so it gets compiled before comparison.
    for dispatcher in (add, add2, multiply):
        dispatcher(1, 2)

    assert comparator(add, add)  # the very same object
    assert not comparator(add, add2)  # identical code, distinct objects
    assert not comparator(add, multiply)  # different code entirely
def test_numba_superset_obj() -> None:
    """Exercise ``superset_obj`` on numba typed containers.

    With ``superset_obj=True`` a typed dict on the new side may carry extra
    keys, but every original key must be present with an equal value. Typed
    lists get no such leniency: their lengths must still match.
    """
    try:
        import numba
        from numba.typed import Dict as NumbaDict
        from numba.typed import List as NumbaList
    except ImportError:
        pytest.skip("numba not available")

    def make_dict(**entries):
        # Build a unicode -> int64 typed dict filled in the given order.
        typed = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
        for key, value in entries.items():
            typed[key] = value
        return typed

    reference = make_dict(x=1, y=2)

    # Identical contents pass.
    assert comparator(reference, make_dict(x=1, y=2), superset_obj=True)

    # Extra keys on the new side are accepted only in superset mode.
    wider = make_dict(x=1, y=2, z=3)
    assert comparator(reference, wider, superset_obj=True)
    assert not comparator(reference, wider, superset_obj=False)

    # A key missing from the new side fails even in superset mode.
    assert not comparator(reference, make_dict(x=1), superset_obj=True)

    # A differing value fails.
    assert not comparator(reference, make_dict(x=1, y=99), superset_obj=True)

    # Typed lists: equal contents pass; a longer list fails regardless of superset_obj.
    reference_list = NumbaList([1, 2, 3])
    assert comparator(reference_list, NumbaList([1, 2, 3]), superset_obj=True)
    assert not comparator(reference_list, NumbaList([1, 2, 3, 4]), superset_obj=True)

    # An empty original is a subset of any dict, but only superset mode accepts that.
    empty_reference = make_dict()
    populated = make_dict(a=1)
    assert comparator(empty_reference, populated, superset_obj=True)
    assert not comparator(empty_reference, populated, superset_obj=False)
# =============================================================================
# Tests for temp path detection and normalization (_is_temp_path / _normalize_temp_path in comparator.py)
# =============================================================================
class TestIsTempPath:
    """Unit tests for _is_temp_path() on pytest-style and look-alike paths."""

    def test_standard_pytest_temp_path(self):
        """Canonical /tmp/pytest-of-<user>/pytest-<n>/ paths are recognised."""
        for candidate in (
            "/tmp/pytest-of-user/pytest-0/test_something",
            "/tmp/pytest-of-user/pytest-123/",
            "/tmp/pytest-of-admin/pytest-999/subdir/file.txt",
        ):
            assert _is_temp_path(candidate), candidate

    def test_different_usernames(self):
        """Any non-empty username segment is accepted."""
        for candidate in (
            "/tmp/pytest-of-root/pytest-1/",
            "/tmp/pytest-of-john_doe/pytest-42/",
            "/tmp/pytest-of-user123/pytest-0/test",
            "/tmp/pytest-of-test-user/pytest-5/data",
        ):
            assert _is_temp_path(candidate), candidate

    def test_different_session_numbers(self):
        """Session numbers of any length are accepted."""
        for candidate in (
            "/tmp/pytest-of-user/pytest-0/",
            "/tmp/pytest-of-user/pytest-1/",
            "/tmp/pytest-of-user/pytest-99/",
            "/tmp/pytest-of-user/pytest-12345/",
        ):
            assert _is_temp_path(candidate), candidate

    def test_paths_with_subdirectories(self):
        """Nested directories below the session root are still detected."""
        for candidate in (
            "/tmp/pytest-of-user/pytest-0/test_func/subdir",
            "/tmp/pytest-of-user/pytest-0/a/b/c/d/file.txt",
            "/tmp/pytest-of-user/pytest-0/test_module0/test_file.py",
        ):
            assert _is_temp_path(candidate), candidate

    def test_paths_with_filenames(self):
        """Paths ending in a filename are still detected."""
        for candidate in (
            "/tmp/pytest-of-user/pytest-0/output.json",
            "/tmp/pytest-of-user/pytest-0/test.log",
            "/tmp/pytest-of-user/pytest-0/data.csv",
        ):
            assert _is_temp_path(candidate), candidate

    def test_non_temp_paths(self):
        """Ordinary absolute, relative and bare paths are rejected."""
        for candidate in (
            "/home/user/project/test.py",
            "/tmp/other/directory",
            "/var/log/test.log",
            "./relative/path",
            "test_file.py",
        ):
            assert not _is_temp_path(candidate), candidate

    def test_similar_but_not_temp_paths(self):
        """Near-misses of the pattern are rejected."""
        for candidate in (
            "/tmp/pytest-user/pytest-0/",  # "of-" is missing
            "/tmp/pytest-of-user/pytest-/",  # session number is missing
            "/tmp/pytest-of-/pytest-0/",  # username is empty
            "/tmp/pytest-of-user/pytest-abc/",  # session is not numeric
        ):
            assert not _is_temp_path(candidate), candidate

    def test_edge_cases(self):
        """Empty and truncated inputs are rejected."""
        for candidate in ("", "/", "/tmp/", "/tmp/pytest-of-"):
            assert not _is_temp_path(candidate), candidate

    def test_path_embedded_in_string(self):
        """A temp path inside a longer message is still detected."""
        for candidate in (
            "Error in /tmp/pytest-of-user/pytest-0/test.py: failed",
            "File: /tmp/pytest-of-user/pytest-123/output.txt",
        ):
            assert _is_temp_path(candidate), candidate

    def test_windows_style_paths(self):
        """Backslash-separated Windows paths are not treated as temp paths."""
        for candidate in (
            "C:\\Users\\test\\pytest",
            "D:\\tmp\\pytest-of-user\\pytest-0\\",
        ):
            assert not _is_temp_path(candidate), candidate
class TestNormalizeTempPath:
    """Unit tests for _normalize_temp_path()."""

    def test_basic_normalization(self):
        """The pytest-of-<user>/pytest-<n> prefix becomes /tmp/pytest-temp."""
        expected = "/tmp/pytest-temp/test"
        for raw in ("/tmp/pytest-of-user/pytest-0/test", "/tmp/pytest-of-user/pytest-123/test"):
            assert _normalize_temp_path(raw) == expected

    def test_different_session_numbers_normalize_same(self):
        """The session number is erased by normalization."""
        normalized = {
            _normalize_temp_path(raw)
            for raw in (
                "/tmp/pytest-of-user/pytest-0/file.txt",
                "/tmp/pytest-of-user/pytest-99/file.txt",
                "/tmp/pytest-of-user/pytest-12345/file.txt",
            )
        }
        assert normalized == {"/tmp/pytest-temp/file.txt"}

    def test_different_usernames_normalize_same(self):
        """The username is erased by normalization."""
        normalized = {
            _normalize_temp_path(raw)
            for raw in (
                "/tmp/pytest-of-alice/pytest-0/file.txt",
                "/tmp/pytest-of-bob/pytest-0/file.txt",
                "/tmp/pytest-of-root/pytest-0/file.txt",
            )
        }
        assert normalized == {"/tmp/pytest-temp/file.txt"}

    def test_complex_subdirectories(self):
        """Everything after the session directory is kept verbatim."""
        normalized = _normalize_temp_path("/tmp/pytest-of-user/pytest-42/test_module/subdir/file.py")
        assert normalized == "/tmp/pytest-temp/test_module/subdir/file.py"

    def test_non_temp_path_unchanged(self):
        """Paths without a temp prefix pass through untouched."""
        untouched = "/home/user/project/test.py"
        assert _normalize_temp_path(untouched) == untouched

    def test_empty_string(self):
        """The empty string stays empty."""
        assert _normalize_temp_path("") == ""

    def test_path_with_multiple_occurrences(self):
        """Every temp prefix in a string is rewritten, not just the first."""
        message = "/tmp/pytest-of-user/pytest-0/ref to /tmp/pytest-of-user/pytest-1/other"
        assert _normalize_temp_path(message) == "/tmp/pytest-temp/ref to /tmp/pytest-temp/other"

    def test_trailing_slash_handling(self):
        """Trailing slashes survive normalization."""
        assert _normalize_temp_path("/tmp/pytest-of-user/pytest-0/") == "/tmp/pytest-temp/"
        assert _normalize_temp_path("/tmp/pytest-of-user/pytest-0/subdir/") == "/tmp/pytest-temp/subdir/"
class TestComparatorTempPaths:
    """comparator() on plain strings that may hold pytest temp paths."""

    def test_identical_temp_paths(self):
        """A temp path equals itself."""
        same = "/tmp/pytest-of-user/pytest-0/test.txt"
        assert comparator(same, same)

    def test_different_session_numbers(self):
        """Paths differing only in the session number are equivalent."""
        assert comparator(
            "/tmp/pytest-of-user/pytest-0/output.txt",
            "/tmp/pytest-of-user/pytest-99/output.txt",
        )

    def test_different_usernames(self):
        """Paths differing only in the username are equivalent."""
        assert comparator(
            "/tmp/pytest-of-alice/pytest-0/result.json",
            "/tmp/pytest-of-bob/pytest-0/result.json",
        )

    def test_different_usernames_and_sessions(self):
        """Username and session number may differ at the same time."""
        assert comparator(
            "/tmp/pytest-of-alice/pytest-10/data/file.csv",
            "/tmp/pytest-of-bob/pytest-999/data/file.csv",
        )

    def test_different_subdirectories_not_equal(self):
        """Subdirectories below the temp root must still agree."""
        assert not comparator(
            "/tmp/pytest-of-user/pytest-0/subdir1/file.txt",
            "/tmp/pytest-of-user/pytest-0/subdir2/file.txt",
        )

    def test_different_filenames_not_equal(self):
        """Different filenames under the same temp root are unequal."""
        assert not comparator(
            "/tmp/pytest-of-user/pytest-0/file1.txt",
            "/tmp/pytest-of-user/pytest-0/file2.txt",
        )

    def test_temp_path_vs_non_temp_path(self):
        """A temp path never matches an ordinary path."""
        assert not comparator("/tmp/pytest-of-user/pytest-0/file.txt", "/home/user/file.txt")

    def test_regular_strings_still_work(self):
        """Strings without temp paths keep plain equality semantics."""
        assert comparator("hello", "hello")
        assert comparator("", "")
        assert not comparator("hello", "world")
        assert not comparator("test", "")

    def test_non_temp_paths_must_be_exact(self):
        """Ordinary paths are compared verbatim."""
        base = "/home/user/project/file.txt"
        assert comparator(base, "/home/user/project/file.txt")
        assert not comparator(base, "/home/user/project/other.txt")
class TestComparatorTempPathsInNestedStructures:
    """Tests for comparator() with temp paths in nested data structures."""

    def test_temp_paths_in_list(self):
        """Test temp paths inside lists."""
        list1 = ["/tmp/pytest-of-alice/pytest-0/file.txt", "other"]
        list2 = ["/tmp/pytest-of-bob/pytest-99/file.txt", "other"]
        assert comparator(list1, list2)

    def test_temp_paths_in_tuple(self):
        """Test temp paths inside tuples."""
        tuple1 = ("/tmp/pytest-of-user/pytest-0/a.txt", "/tmp/pytest-of-user/pytest-0/b.txt")
        tuple2 = ("/tmp/pytest-of-user/pytest-123/a.txt", "/tmp/pytest-of-user/pytest-123/b.txt")
        assert comparator(tuple1, tuple2)

    def test_temp_paths_in_dict_values(self):
        """Test temp paths as dictionary values."""
        dict1 = {"path": "/tmp/pytest-of-user/pytest-0/output.json", "name": "test"}
        dict2 = {"path": "/tmp/pytest-of-user/pytest-999/output.json", "name": "test"}
        assert comparator(dict1, dict2)

    def test_temp_paths_in_dict_keys_not_supported(self):
        """Test that temp paths used as dictionary keys must match exactly (keys are not normalized)."""
        # Dict keys use direct comparison rather than the temp-path normalization
        # applied to values, so only byte-identical keys line up.
        dict1 = {"/tmp/pytest-of-user/pytest-0/key": "value"}
        dict2 = {"/tmp/pytest-of-user/pytest-0/key": "value"}
        assert comparator(dict1, dict2)
        # Keys that differ only in the session number are NOT treated as equal.
        dict3 = {"/tmp/pytest-of-user/pytest-99/key": "value"}
        assert not comparator(dict1, dict3)

    def test_temp_paths_in_nested_dict(self):
        """Test temp paths in nested dictionaries."""
        nested1 = {
            "config": {
                "output_path": "/tmp/pytest-of-alice/pytest-5/results",
                "log_path": "/tmp/pytest-of-alice/pytest-5/logs",
            }
        }
        nested2 = {
            "config": {
                "output_path": "/tmp/pytest-of-bob/pytest-10/results",
                "log_path": "/tmp/pytest-of-bob/pytest-10/logs",
            }
        }
        assert comparator(nested1, nested2)

    def test_temp_paths_in_deeply_nested_structure(self):
        """Test temp paths in deeply nested structures."""
        deep1 = {"a": {"b": {"c": ["/tmp/pytest-of-user/pytest-0/file.txt"]}}}
        deep2 = {"a": {"b": {"c": ["/tmp/pytest-of-other/pytest-99/file.txt"]}}}
        assert comparator(deep1, deep2)

    def test_mixed_temp_and_regular_paths(self):
        """Test structures with both temp and regular paths."""
        data1 = {"temp": "/tmp/pytest-of-user/pytest-0/temp.txt", "regular": "/home/user/file.txt"}
        data2 = {"temp": "/tmp/pytest-of-user/pytest-99/temp.txt", "regular": "/home/user/file.txt"}
        assert comparator(data1, data2)
        data3 = {"temp": "/tmp/pytest-of-user/pytest-99/temp.txt", "regular": "/home/user/different.txt"}
        assert not comparator(data1, data3)

    def test_temp_paths_in_deque(self):
        """Test temp paths inside deque (``deque`` is imported at module level)."""
        d1 = deque(["/tmp/pytest-of-user/pytest-0/file.txt"])
        d2 = deque(["/tmp/pytest-of-user/pytest-123/file.txt"])
        assert comparator(d1, d2)

    def test_temp_paths_in_chainmap(self):
        """Test temp paths inside ChainMap (``ChainMap`` is imported at module level)."""
        cm1 = ChainMap({"path": "/tmp/pytest-of-user/pytest-0/file.txt"})
        cm2 = ChainMap({"path": "/tmp/pytest-of-user/pytest-99/file.txt"})
        assert comparator(cm1, cm2)
class TestComparatorTempPathsEdgeCases:
    """Edge case tests for temp path handling in comparator."""

    def test_empty_string_vs_temp_path(self):
        """Test empty string comparison with temp path."""
        assert not comparator("", "/tmp/pytest-of-user/pytest-0/file.txt")
        assert not comparator("/tmp/pytest-of-user/pytest-0/file.txt", "")

    def test_path_with_special_characters(self):
        """Test temp paths containing special characters in filenames."""
        path1 = "/tmp/pytest-of-user/pytest-0/file with spaces.txt"
        path2 = "/tmp/pytest-of-user/pytest-99/file with spaces.txt"
        assert comparator(path1, path2)
        path3 = "/tmp/pytest-of-user/pytest-0/file-with-dashes.txt"
        path4 = "/tmp/pytest-of-user/pytest-99/file-with-dashes.txt"
        assert comparator(path3, path4)

    def test_path_with_unicode_characters(self):
        """Test temp paths with unicode characters."""
        path1 = "/tmp/pytest-of-user/pytest-0/файл.txt"
        path2 = "/tmp/pytest-of-user/pytest-99/файл.txt"
        assert comparator(path1, path2)

    def test_very_long_session_number(self):
        """Test temp paths with very long session numbers."""
        path1 = "/tmp/pytest-of-user/pytest-9999999999/file.txt"
        path2 = "/tmp/pytest-of-user/pytest-0/file.txt"
        assert comparator(path1, path2)

    def test_username_with_special_characters(self):
        """Test temp paths with special characters in username."""
        path1 = "/tmp/pytest-of-user-name/pytest-0/file.txt"
        path2 = "/tmp/pytest-of-other-user/pytest-99/file.txt"
        assert comparator(path1, path2)

    def test_path_only_differs_in_temp_portion(self):
        """Test that only the temp portion is normalized, rest must match."""
        path1 = "/tmp/pytest-of-user/pytest-0/subdir/nested/file.txt"
        path2 = "/tmp/pytest-of-user/pytest-99/subdir/nested/file.txt"
        assert comparator(path1, path2)
        path3 = "/tmp/pytest-of-user/pytest-0/subdir/nested/other.txt"
        assert not comparator(path1, path3)

    def test_multiple_slashes(self):
        """Test temp paths containing consecutive slashes after the session directory.

        Both sides carry the same doubled slashes, so the comparison should hold
        whether or not the comparator collapses extra slashes. Only the temp
        prefix differs between each pair.
        """
        # Baseline: single slashes.
        path1 = "/tmp/pytest-of-user/pytest-0/file.txt"
        path2 = "/tmp/pytest-of-user/pytest-99/file.txt"
        assert comparator(path1, path2)
        # Doubled slash directly after the session directory.
        assert comparator("/tmp/pytest-of-user/pytest-0//file.txt", "/tmp/pytest-of-user/pytest-99//file.txt")
        # Doubled slash deeper inside the path.
        assert comparator(
            "/tmp/pytest-of-user/pytest-0/subdir//file.txt", "/tmp/pytest-of-user/pytest-99/subdir//file.txt"
        )

    def test_temp_path_at_start_middle_end(self):
        """Test that temp paths are detected regardless of position in string."""
        # Path at start
        assert _is_temp_path("/tmp/pytest-of-user/pytest-0/test")
        # Path in middle (embedded in error message)
        assert _is_temp_path("Error: /tmp/pytest-of-user/pytest-0/test failed")
        # Path at end
        assert _is_temp_path("Output saved to /tmp/pytest-of-user/pytest-0/")

    def test_partial_temp_path_patterns(self):
        """Test strings that partially match temp path pattern."""
        # Missing components
        assert not _is_temp_path("/tmp/pytest-of-user/")
        assert not _is_temp_path("/tmp/pytest-0/")
        assert not _is_temp_path("pytest-of-user/pytest-0/")
class TestPytestTempPathPatternRegex:
    """Tests for the PYTEST_TEMP_PATH_PATTERN regex directly."""

    def test_pattern_matches_standard_format(self):
        """Test regex matches standard pytest temp path format."""
        assert PYTEST_TEMP_PATH_PATTERN.search("/tmp/pytest-of-user/pytest-0/")
        assert PYTEST_TEMP_PATH_PATTERN.search("/tmp/pytest-of-user/pytest-123/file")

    def test_pattern_captures_correctly(self):
        """Test that the pattern substitution consumes the prefix up to and including the session slash."""
        result = PYTEST_TEMP_PATH_PATTERN.sub("REPLACED", "/tmp/pytest-of-user/pytest-0/file.txt")
        assert result == "REPLACEDfile.txt"

    def test_pattern_handles_multiple_matches(self):
        """Test pattern with multiple temp paths in same string."""
        text = "/tmp/pytest-of-a/pytest-1/ and /tmp/pytest-of-b/pytest-2/"
        result = PYTEST_TEMP_PATH_PATTERN.sub("X", text)
        assert result == "X and X"

    def test_pattern_greedy_behavior(self):
        """Test that the pattern doesn't over-match."""
        # The match must stop at the slash after the first session number; a
        # later "pytest-1" segment (not preceded by "pytest-of-") is left alone.
        path = "/tmp/pytest-of-user/pytest-0/subdir/pytest-1/file.txt"
        result = PYTEST_TEMP_PATH_PATTERN.sub("X", path)
        # Pin the exact result: a "subdir" substring check alone would also
        # accept partial over-matches that swallow part of the tail.
        assert result == "Xsubdir/pytest-1/file.txt"
class TestComparatorTempPathsWithSuperset:
    """Tests for temp path comparison with superset_obj=True."""

    def test_superset_with_temp_paths_in_dict(self):
        """Test superset comparison with temp paths in dictionaries."""
        orig = {"path": "/tmp/pytest-of-user/pytest-0/file.txt"}
        new = {"path": "/tmp/pytest-of-user/pytest-99/file.txt", "extra": "data"}
        assert comparator(orig, new, superset_obj=True)

    def test_superset_temp_paths_must_still_match(self):
        """Test that temp paths must still be equivalent in superset mode."""
        orig = {"path": "/tmp/pytest-of-user/pytest-0/file.txt"}
        new = {"path": "/tmp/pytest-of-user/pytest-99/other.txt", "extra": "data"}
        assert not comparator(orig, new, superset_obj=True)

    def test_superset_nested_dict_with_temp_paths(self):
        """Test superset comparison with temp paths in nested dictionaries."""
        orig = {"config": {"output": "/tmp/pytest-of-alice/pytest-5/results.json"}}
        new = {
            "config": {"output": "/tmp/pytest-of-bob/pytest-100/results.json", "debug": True},
            "metadata": {"version": "1.0"},
        }
        assert comparator(orig, new, superset_obj=True)

    def test_superset_multiple_temp_paths_in_dict(self):
        """Test superset with multiple temp paths in dictionary values."""
        orig = {"input": "/tmp/pytest-of-user/pytest-0/input.txt", "output": "/tmp/pytest-of-user/pytest-0/output.txt"}
        new = {
            "input": "/tmp/pytest-of-user/pytest-99/input.txt",
            "output": "/tmp/pytest-of-user/pytest-99/output.txt",
            "log": "/tmp/pytest-of-user/pytest-99/debug.log",
        }
        assert comparator(orig, new, superset_obj=True)

    def test_superset_temp_path_in_list_inside_dict(self):
        """Test superset with temp paths in lists inside dictionaries."""
        orig = {"files": ["/tmp/pytest-of-user/pytest-0/a.txt", "/tmp/pytest-of-user/pytest-0/b.txt"]}
        new = {"files": ["/tmp/pytest-of-user/pytest-99/a.txt", "/tmp/pytest-of-user/pytest-99/b.txt"], "count": 2}
        assert comparator(orig, new, superset_obj=True)

    def test_superset_false_when_temp_path_missing(self):
        """Test superset fails when temp path key is missing in new."""
        orig = {"path": "/tmp/pytest-of-user/pytest-0/file.txt"}
        new = {"other": "data"}
        assert not comparator(orig, new, superset_obj=True)

    def test_superset_temp_path_with_different_filenames_fails(self):
        """Test superset fails when normalized temp paths have different filenames."""
        orig = {"result": "/tmp/pytest-of-user/pytest-0/output_v1.json"}
        new = {"result": "/tmp/pytest-of-user/pytest-99/output_v2.json", "extra": "data"}
        assert not comparator(orig, new, superset_obj=True)

    def test_superset_mixed_temp_and_regular_paths(self):
        """Test superset with mix of temp paths and regular paths."""
        orig = {"temp_file": "/tmp/pytest-of-user/pytest-0/temp.txt", "config_file": "/etc/app/config.yaml"}
        new = {
            "temp_file": "/tmp/pytest-of-user/pytest-99/temp.txt",
            "config_file": "/etc/app/config.yaml",
            "extra_key": "extra_value",
        }
        assert comparator(orig, new, superset_obj=True)

    def test_superset_regular_path_must_match_exactly(self):
        """Test that regular paths must match exactly even in superset mode."""
        orig = {"temp_file": "/tmp/pytest-of-user/pytest-0/temp.txt", "config_file": "/etc/app/config.yaml"}
        new = {
            "temp_file": "/tmp/pytest-of-user/pytest-99/temp.txt",
            "config_file": "/etc/app/other.yaml",
            "extra_key": "extra_value",
        }
        assert not comparator(orig, new, superset_obj=True)

    def test_superset_deeply_nested_temp_paths(self):
        """Test superset with deeply nested structures containing temp paths."""
        orig = {"level1": {"level2": {"level3": {"path": "/tmp/pytest-of-user/pytest-0/deep.txt"}}}}
        new = {
            "level1": {
                "level2": {
                    "level3": {"path": "/tmp/pytest-of-other/pytest-999/deep.txt", "extra": True},
                    "sibling": "value",
                }
            },
            "top_level_extra": 123,
        }
        assert comparator(orig, new, superset_obj=True)

    def test_superset_with_attrs_class_containing_temp_paths(self):
        """Test superset with attrs classes containing temp paths."""
        try:
            import attr
        except ImportError:
            pytest.skip("attrs not installed")

        @attr.s
        class Config:
            path = attr.ib()
            name = attr.ib(default="default")

        # Test that temp paths are normalized in attrs classes
        orig = Config(path="/tmp/pytest-of-user/pytest-0/config.json")
        new = Config(path="/tmp/pytest-of-user/pytest-99/config.json")
        assert comparator(orig, new, superset_obj=True)
        # Test that different non-temp values still fail
        orig2 = Config(path="/tmp/pytest-of-user/pytest-0/config.json", name="name1")
        new2 = Config(path="/tmp/pytest-of-user/pytest-99/config.json", name="name2")
        assert not comparator(orig2, new2, superset_obj=True)

    def test_superset_with_class_dict_containing_temp_paths(self):
        """Test superset with regular class objects containing temp paths."""

        class Result:
            def __init__(self, output_path):
                self.output_path = output_path

        # Both objects use the same class. Objects of different classes would
        # fail the comparator's type check before any attributes are compared.
        orig = Result("/tmp/pytest-of-user/pytest-0/result.json")
        new = Result("/tmp/pytest-of-user/pytest-99/result.json")
        # An attribute present only on `new` is what superset mode should tolerate.
        new.extra_field = "extra_data"
        assert comparator(orig, new, superset_obj=True)

    def test_superset_list_temp_paths_must_have_same_length(self):
        """Test that lists with temp paths must have same length even in superset mode."""
        # superset_obj doesn't apply to list lengths - they must match
        orig = ["/tmp/pytest-of-user/pytest-0/a.txt"]
        new = ["/tmp/pytest-of-user/pytest-99/a.txt", "/tmp/pytest-of-user/pytest-99/b.txt"]
        assert not comparator(orig, new, superset_obj=True)

    def test_superset_tuple_temp_paths_must_have_same_length(self):
        """Test that tuples with temp paths must have same length even in superset mode."""
        orig = ("/tmp/pytest-of-user/pytest-0/a.txt",)
        new = ("/tmp/pytest-of-user/pytest-99/a.txt", "/tmp/pytest-of-user/pytest-99/b.txt")
        assert not comparator(orig, new, superset_obj=True)

    def test_superset_with_exception_containing_temp_path(self):
        """Test superset with exception objects containing temp paths in attributes."""

        class CustomError(Exception):
            def __init__(self, message, path):
                super().__init__(message)
                self.path = path

        orig = CustomError("File error", "/tmp/pytest-of-user/pytest-0/file.txt")
        new = CustomError("File error", "/tmp/pytest-of-user/pytest-99/file.txt")
        new.extra_info = "additional data"
        assert comparator(orig, new, superset_obj=True)
class TestComparatorTempPathsRealisticScenarios:
    """End-to-end style cases where recorded and replayed results carry temp paths."""

    def test_test_output_comparison(self):
        """A result dict recorded on CI matches one replayed locally."""
        recorded = {
            "status": "success",
            "output_file": "/tmp/pytest-of-ci-runner/pytest-42/test_output/results.json",
            "log_file": "/tmp/pytest-of-ci-runner/pytest-42/test_output/debug.log",
        }
        replayed = {
            "status": "success",
            "output_file": "/tmp/pytest-of-local-user/pytest-0/test_output/results.json",
            "log_file": "/tmp/pytest-of-local-user/pytest-0/test_output/debug.log",
        }
        assert comparator(recorded, replayed)

    def test_exception_message_with_temp_path(self):
        """Error payloads whose messages embed temp paths still match."""
        first_error = {"type": "FileNotFoundError", "message": "File not found: /tmp/pytest-of-user/pytest-0/missing.txt"}
        second_error = {"type": "FileNotFoundError", "message": "File not found: /tmp/pytest-of-user/pytest-99/missing.txt"}
        assert comparator(first_error, second_error)

    def test_function_return_with_temp_path(self):
        """A returned path to a generated file matches across sessions."""
        # Stands in for a function that hands back the path of a file it created.
        assert comparator(
            "/tmp/pytest-of-user/pytest-5/generated_file_abc123.txt",
            "/tmp/pytest-of-user/pytest-10/generated_file_abc123.txt",
        )

    def test_list_of_created_files(self):
        """Lists of created-file paths match element by element."""
        first_run = [
            "/tmp/pytest-of-user/pytest-0/output/file1.txt",
            "/tmp/pytest-of-user/pytest-0/output/file2.txt",
            "/tmp/pytest-of-user/pytest-0/output/file3.txt",
        ]
        second_run = [
            "/tmp/pytest-of-user/pytest-99/output/file1.txt",
            "/tmp/pytest-of-user/pytest-99/output/file2.txt",
            "/tmp/pytest-of-user/pytest-99/output/file3.txt",
        ]
        assert comparator(first_run, second_run)

    def test_config_object_with_paths(self):
        """A config mixing temp directories and a fixed directory still matches."""
        first_config = {
            "temp_dir": "/tmp/pytest-of-user/pytest-0/",
            "cache_dir": "/tmp/pytest-of-user/pytest-0/cache/",
            "output_dir": "/tmp/pytest-of-user/pytest-0/output/",
            "permanent_dir": "/home/user/data/",
        }
        second_config = {
            "temp_dir": "/tmp/pytest-of-other/pytest-100/",
            "cache_dir": "/tmp/pytest-of-other/pytest-100/cache/",
            "output_dir": "/tmp/pytest-of-other/pytest-100/output/",
            "permanent_dir": "/home/user/data/",
        }
        assert comparator(first_config, second_config)
class TestPythonTempfilePaths:
    """Paths produced by tempfile.mkdtemp() / TemporaryDirectory() under /tmp/tmp*."""

    def test_is_temp_path_detects_python_tempfile(self):
        """/tmp/tmp<suffix>/ directories are recognised as temp paths."""
        for candidate in (
            "/tmp/tmpqtwy7hpf/special.txt",
            "/tmp/tmpp6wx3tz3/special.txt",
            "/tmp/tmpabcdef12/",
            "/tmp/tmp_underscore/file.txt",
        ):
            assert _is_temp_path(candidate), candidate

    def test_is_temp_path_various_tempfile_names(self):
        """The random suffix may be upper-case, numeric, mixed or contain underscores."""
        for candidate in (
            "/tmp/tmpABCDEF/file.txt",
            "/tmp/tmp123456/file.txt",
            "/tmp/tmpaBc123/file.txt",
            "/tmp/tmp_test_dir/subdir/file.txt",
        ):
            assert _is_temp_path(candidate), candidate

    def test_is_temp_path_non_tempfile(self):
        """Directories that merely resemble tempfile output are rejected."""
        for candidate in (
            "/tmp/mydir/file.txt",  # name lacks the "tmp" prefix
            "/tmp/temp/file.txt",  # "temp" rather than "tmp"
            "/home/user/tmp123/file.txt",  # outside /tmp/
        ):
            assert not _is_temp_path(candidate), candidate

    def test_normalize_temp_path_python_tempfile(self):
        """Different random suffixes normalize to the same python-temp path."""
        normalized = {
            _normalize_temp_path(raw)
            for raw in ("/tmp/tmpqtwy7hpf/special.txt", "/tmp/tmpp6wx3tz3/special.txt")
        }
        assert normalized == {"/tmp/python-temp/special.txt"}

    def test_normalize_temp_path_preserves_subdirs(self):
        """Subdirectories below the tempfile root survive normalization."""
        normalized = _normalize_temp_path("/tmp/tmpabcdef12/subdir/nested/file.txt")
        assert normalized == "/tmp/python-temp/subdir/nested/file.txt"

    def test_comparator_python_tempfile_paths_equal(self):
        """comparator treats differing tempfile roots with the same tail as equal."""
        assert comparator("/tmp/tmpqtwy7hpf/special.txt", "/tmp/tmpp6wx3tz3/special.txt")

    def test_comparator_python_tempfile_different_filenames_not_equal(self):
        """Differing filenames under tempfile roots remain unequal."""
        assert not comparator("/tmp/tmpqtwy7hpf/special.txt", "/tmp/tmpp6wx3tz3/different.txt")

    def test_comparator_python_tempfile_in_tuple(self):
        """Normalization applies to tuple elements."""
        assert comparator(("/tmp/tmpqtwy7hpf/special.txt",), ("/tmp/tmpp6wx3tz3/special.txt",))

    def test_comparator_python_tempfile_in_list(self):
        """Normalization applies to list elements."""
        before = ["/tmp/tmpabcdef12/file1.txt", "/tmp/tmpabcdef12/file2.txt"]
        after = ["/tmp/tmpxyz78901/file1.txt", "/tmp/tmpxyz78901/file2.txt"]
        assert comparator(before, after)

    def test_comparator_python_tempfile_in_dict(self):
        """Normalization applies to dict values."""
        assert comparator({"output": "/tmp/tmpabcdef12/result.json"}, {"output": "/tmp/tmpxyz78901/result.json"})

    def test_comparator_mixed_pytest_and_python_tempfile(self):
        """A pytest temp path and a tempfile path never match each other."""
        # They normalize to different roots, so they must stay distinct.
        assert not comparator("/tmp/pytest-of-user/pytest-0/file.txt", "/tmp/tmpabcdef12/file.txt")

    def test_python_tempfile_pattern_regex(self):
        """PYTHON_TEMPFILE_PATTERN matches only /tmp/tmp<suffix>/ prefixes."""
        for hit in ("/tmp/tmpabcdef/file.txt", "/tmp/tmp123456/"):
            assert PYTHON_TEMPFILE_PATTERN.search(hit), hit
        for miss in ("/tmp/mydir/file.txt", "/home/tmp123/file.txt"):
            assert not PYTHON_TEMPFILE_PATTERN.search(miss), miss
@pytest.mark.skipif(sys.version_info < (3, 10), reason="types.UnionType requires Python 3.10+")
class TestUnionType:
    """comparator() on PEP 604 ``X | Y`` union objects.

    The unions are built inside each test body rather than at class level, so
    importing this module on Python < 3.10 does not evaluate ``int | str``.
    """

    def test_union_type_equal(self):
        left = int | str
        right = int | str
        assert comparator(left, right)

    def test_union_type_not_equal(self):
        left = int | str
        right = int | float
        assert not comparator(left, right)

    def test_union_type_order_independent(self):
        forward = int | str
        backward = str | int
        assert comparator(forward, backward)

    def test_union_type_multiple_args(self):
        left = int | str | float
        right = int | str | float
        assert comparator(left, right)

    def test_union_type_in_list(self):
        left = [int | str, 1]
        right = [int | str, 1]
        assert comparator(left, right)

    def test_union_type_in_dict(self):
        left = {"key": int | str}
        right = {"key": int | str}
        assert comparator(left, right)

    def test_union_type_vs_none(self):
        union = int | str
        assert not comparator(union, None)
class SlotsOnly:
    """Two-field test fixture whose state lives only in ``__slots__`` (no ``__dict__``)."""

    __slots__ = ("x", "y")

    def __init__(self, x, y):
        self.x, self.y = x, y
class SlotsInherited(SlotsOnly):
    """SlotsOnly subclass that adds a third slot, ``z``, on top of ``x`` and ``y``."""

    __slots__ = ("z",)

    def __init__(self, x, y, z):
        # Let the parent fill its own slots before setting the new one.
        super().__init__(x=x, y=y)
        self.z = z
class TestSlotsObjects:
    """comparator() on objects whose attributes are stored in ``__slots__``."""

    def test_slots_equal(self):
        left, right = SlotsOnly(1, 2), SlotsOnly(1, 2)
        assert comparator(left, right)

    def test_slots_not_equal(self):
        left, right = SlotsOnly(1, 2), SlotsOnly(1, 3)
        assert not comparator(left, right)

    def test_slots_inherited_equal(self):
        left, right = SlotsInherited(1, 2, 3), SlotsInherited(1, 2, 3)
        assert comparator(left, right)

    def test_slots_inherited_not_equal(self):
        left, right = SlotsInherited(1, 2, 3), SlotsInherited(1, 2, 4)
        assert not comparator(left, right)

    def test_slots_nested(self):
        left = SlotsOnly(SlotsOnly(1, 2), [3, 4])
        right = SlotsOnly(SlotsOnly(1, 2), [3, 4])
        assert comparator(left, right)

    def test_slots_nested_not_equal(self):
        # Only the inner object's y differs (2 vs 9).
        left = SlotsOnly(SlotsOnly(1, 2), [3, 4])
        right = SlotsOnly(SlotsOnly(1, 9), [3, 4])
        assert not comparator(left, right)