# Conflicts: # .claude/rules/architecture.md # .claude/rules/code-style.md # .github/workflows/claude.yml # .github/workflows/duplicate-code-detector.yml # codeflash/api/aiservice.py # codeflash/cli_cmds/console.py # codeflash/cli_cmds/logging_config.py # codeflash/code_utils/deduplicate_code.py # codeflash/discovery/discover_unit_tests.py # codeflash/languages/base.py # codeflash/languages/code_replacer.py # codeflash/languages/javascript/mocha_runner.py # codeflash/languages/javascript/support.py # codeflash/languages/python/support.py # codeflash/optimization/function_optimizer.py # codeflash/verification/parse_test_output.py # codeflash/verification/verification_utils.py # codeflash/verification/verifier.py # packages/codeflash/package-lock.json # packages/codeflash/package.json # tests/languages/javascript/test_support_dispatch.py # tests/test_codeflash_capture.py # tests/test_languages/test_javascript_test_runner.py # tests/test_multi_file_code_replacement.py
5528 lines
178 KiB
Python
5528 lines
178 KiB
Python
import array # Add import for array
|
|
import ast
|
|
import copy
|
|
import dataclasses
|
|
import datetime
|
|
import decimal
|
|
import re
|
|
import sys
|
|
import uuid
|
|
import weakref
|
|
from collections import ChainMap, Counter, OrderedDict, UserDict, UserList, UserString, defaultdict, deque, namedtuple
|
|
from enum import Enum, Flag, IntFlag, auto
|
|
from pathlib import Path
|
|
|
|
import pydantic
|
|
import pytest
|
|
|
|
from codeflash.either import Failure, Success
|
|
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType
|
|
from codeflash.verification.comparator import (
|
|
PYTEST_TEMP_PATH_PATTERN,
|
|
PYTHON_TEMPFILE_PATTERN,
|
|
_extract_exception_from_message,
|
|
_get_wrapped_exception,
|
|
_is_temp_path,
|
|
_normalize_temp_path,
|
|
comparator,
|
|
)
|
|
from codeflash.verification.equivalence import compare_test_results
|
|
|
|
|
|
def test_basic_python_objects() -> None:
    """Check comparator against builtin scalars, containers, callables and types."""
    # --- int vs int / None ---
    int_x, int_y, int_other, missing = 5, 5, 6, None
    assert comparator(int_x, int_y)
    assert not comparator(int_x, int_other)
    assert not comparator(int_x, missing)

    # --- float vs float / None; None only matches None, in either position ---
    flt_x, flt_y, flt_other = 5.0, 5.0, 6.0
    none_x, none_y = None, None
    assert comparator(flt_x, flt_y)
    assert not comparator(flt_x, flt_other)
    assert not comparator(flt_x, none_x)
    assert not comparator(none_x, flt_x)
    assert comparator(none_x, none_y)

    # --- str ---
    str_x, str_y, str_other = "Hello", "Hello", "World"
    assert comparator(str_x, str_y)
    assert not comparator(str_x, str_other)

    # --- list ---
    list_x, list_y, list_other = [1, 2, 3], [1, 2, 3], [1, 2, 4]
    assert comparator(list_x, list_y)
    assert not comparator(list_x, list_other)

    # --- dict: a changed value, a changed key and an extra key all mismatch ---
    dict_x = {"a": 1, "b": 2}
    dict_y = {"a": 1, "b": 2}
    dict_new_value = {"a": 1, "b": 3}
    dict_new_key = {"c": 1, "b": 2}
    dict_extra_key = {"a": 1, "b": 2, "c": 3}
    assert comparator(dict_x, dict_y)
    for mismatch in (dict_new_value, dict_new_key, dict_extra_key):
        assert not comparator(dict_x, mismatch)

    # --- tuple; a list holding the same items is not a tuple ---
    tup_x, tup_y, tup_other = (1, 2, "str"), (1, 2, "str"), (1, 2, "str2")
    tup_as_list = [1, 2, "str"]
    assert comparator(tup_x, tup_y)
    assert not comparator(tup_x, tup_other)
    assert not comparator(tup_x, tup_as_list)

    # --- set: element order is irrelevant ---
    set_x, set_y = {1, 2, 3}, {2, 3, 1}
    set_other, set_superset = {1, 2, 4}, {1, 2, 3, 4}
    assert comparator(set_x, set_y)
    assert not comparator(set_x, set_other)
    assert not comparator(set_x, set_superset)

    # --- bytes: content and byte order both matter ---
    byte_a = (65).to_bytes(1, byteorder="big")
    byte_a_again = (65).to_bytes(1, byteorder="big")
    byte_b = (66).to_bytes(1, byteorder="big")
    assert comparator(byte_a, byte_a_again)
    assert not comparator(byte_a, byte_b)
    little_endian = (65).to_bytes(2, byteorder="little")
    big_endian = (65).to_bytes(2, byteorder="big")
    assert not comparator(little_endian, big_endian)

    # --- bytearray ---
    ba_x = bytearray([65, 64, 63])
    ba_y = bytearray([65, 64, 63])
    ba_other = bytearray([65, 64, 62])
    assert comparator(ba_x, ba_y)
    assert not comparator(ba_x, ba_other)

    # --- memoryview over a bytearray ---
    view_x = memoryview(bytearray([65, 64, 63]))
    view_y = memoryview(bytearray([65, 64, 63]))
    view_other = memoryview(bytearray([65, 64, 62]))
    assert comparator(view_x, view_y)
    assert not comparator(view_x, view_other)

    # --- frozenset: element order is irrelevant ---
    fs_x, fs_y = frozenset([1, 2, 3]), frozenset([2, 3, 1])
    fs_other, fs_superset = frozenset([1, 2, 4]), frozenset([1, 2, 3, 4])
    assert comparator(fs_x, fs_y)
    assert not comparator(fs_x, fs_other)
    assert not comparator(fs_x, fs_superset)

    # --- builtin callables: a builtin matches itself, never a different builtin ---
    assert comparator(pow, pow)
    assert not comparator(map, pow)
    assert not comparator(pow, abs)

    # --- bare object() instances match each other but not a builtin function ---
    plain_x, plain_y = object(), object()
    assert comparator(plain_x, plain_y)
    assert not comparator(plain_x, abs)

    # --- type objects (type([]) is list, type({}) is dict) ---
    assert comparator(list, list)
    assert not comparator(list, dict)
|
|
|
|
|
|
def test_weakref() -> None:
    """Test comparator for weakref.ref objects.

    Live references compare by their referents, including when the references
    sit inside dicts and lists. Two dead references compare equal, while a dead
    reference never equals a live one (in either argument order).
    """
    import gc

    # Helper class that supports weak references and has a comparable __dict__.
    class Holder:
        def __init__(self, value):
            self.value = value

    # Weak references to the same object.
    obj = Holder([1, 2, 3])
    ref1 = weakref.ref(obj)
    ref2 = weakref.ref(obj)
    assert comparator(ref1, ref2)

    # Weak references to equivalent but different objects.
    obj1 = Holder({"key": "value"})
    obj2 = Holder({"key": "value"})
    ref1 = weakref.ref(obj1)
    ref2 = weakref.ref(obj2)
    assert comparator(ref1, ref2)

    # Weak references to different objects.
    obj1 = Holder([1, 2, 3])
    obj2 = Holder([1, 2, 4])
    ref1 = weakref.ref(obj1)
    ref2 = weakref.ref(obj2)
    assert not comparator(ref1, ref2)

    # Weak references to objects whose data differs in length.
    obj1 = Holder([1, 2, 3])
    obj2 = Holder([1, 2, 3, 4])
    ref1 = weakref.ref(obj1)
    ref2 = weakref.ref(obj2)
    assert not comparator(ref1, ref2)

    # Dead weak references (both dead).
    obj1 = Holder([1, 2, 3])
    obj2 = Holder([1, 2, 3])
    ref1 = weakref.ref(obj1)
    ref2 = weakref.ref(obj2)
    del obj1
    del obj2
    # Reclaiming the referent right at `del` is a CPython refcounting detail.
    # Force a collection so the references are really dead on any interpreter,
    # and fail here (not inside the comparator) if they are not.
    gc.collect()
    assert ref1() is None
    assert ref2() is None
    # Both refs are now dead, so they should be equal.
    assert comparator(ref1, ref2)

    # One dead, one alive weak reference.
    obj1 = Holder([1, 2, 3])
    obj2 = Holder([1, 2, 3])
    ref1 = weakref.ref(obj1)
    ref2 = weakref.ref(obj2)
    del obj1
    gc.collect()
    assert ref1() is None
    assert ref2() is not None
    # ref1 is dead, ref2 is alive: never equal, whichever side is dead.
    assert not comparator(ref1, ref2)
    assert not comparator(ref2, ref1)

    # Weak references to nested structures.
    obj1 = Holder({"nested": [1, 2, {"inner": "value"}]})
    obj2 = Holder({"nested": [1, 2, {"inner": "value"}]})
    ref1 = weakref.ref(obj1)
    ref2 = weakref.ref(obj2)
    assert comparator(ref1, ref2)

    # Weak references to nested structures with a deep difference.
    obj1 = Holder({"nested": [1, 2, {"inner": "value1"}]})
    obj2 = Holder({"nested": [1, 2, {"inner": "value2"}]})
    ref1 = weakref.ref(obj1)
    ref2 = weakref.ref(obj2)
    assert not comparator(ref1, ref2)

    # Weak references stored in a dictionary (simulating a __dict__ with weakrefs).
    obj1 = Holder([1, 2, 3])
    obj2 = Holder([1, 2, 3])
    dict1 = {"data": 42, "ref": weakref.ref(obj1)}
    dict2 = {"data": 42, "ref": weakref.ref(obj2)}
    assert comparator(dict1, dict2)

    # Weak references stored in a dictionary, with different referents.
    obj1 = Holder([1, 2, 3])
    obj2 = Holder([4, 5, 6])
    dict1 = {"data": 42, "ref": weakref.ref(obj1)}
    dict2 = {"data": 42, "ref": weakref.ref(obj2)}
    assert not comparator(dict1, dict2)

    # Weak references stored in a list.
    obj1 = Holder({"a": 1})
    obj2 = Holder({"a": 1})
    list1 = [weakref.ref(obj1), "other"]
    list2 = [weakref.ref(obj2), "other"]
    assert comparator(list1, list2)
|
|
|
|
|
|
def test_weakref_to_custom_objects() -> None:
    """Weak references to user-defined instances match exactly when the referents hold equal data."""

    class MyClass:
        def __init__(self, value):
            self.value = value

    class Container:
        def __init__(self, items):
            self.items = items

    # The referents stay bound to locals so every weakref below is alive
    # while it is being compared.

    # Scalar attribute: equal value -> equal refs; different value -> unequal.
    same_left, same_right = MyClass(42), MyClass(42)
    assert comparator(weakref.ref(same_left), weakref.ref(same_right))

    diff_left, diff_right = MyClass(42), MyClass(99)
    assert not comparator(weakref.ref(diff_left), weakref.ref(diff_right))

    # List attribute: compared element by element.
    box_left, box_right = Container([1, 2, 3]), Container([1, 2, 3])
    assert comparator(weakref.ref(box_left), weakref.ref(box_right))

    changed_left, changed_right = Container([1, 2, 3]), Container([1, 2, 4])
    assert not comparator(weakref.ref(changed_left), weakref.ref(changed_right))
|
|
|
|
|
|
def test_weakref_with_callbacks() -> None:
    """Test that weakrefs carrying callbacks are still compared by their referents.

    Also checks that comparing live weakrefs never triggers their callbacks.
    """

    class Holder:
        def __init__(self, value):
            self.value = value

    # Records every weakref whose callback has fired.
    callback_called = []

    def callback(ref):
        callback_called.append(ref)

    # Equivalent referents: equal, regardless of the attached callback.
    obj1 = Holder([1, 2, 3])
    obj2 = Holder([1, 2, 3])
    ref1 = weakref.ref(obj1, callback)
    ref2 = weakref.ref(obj2, callback)
    assert comparator(ref1, ref2)
    # obj1/obj2 are still bound, so comparing must not have fired any callback.
    assert callback_called == []

    # Different referents: unequal. Fresh names are used (instead of rebinding
    # obj1/obj2) so the first pair stays alive and no callback fires as a side
    # effect of reassignment.
    obj3 = Holder([1, 2, 3])
    obj4 = Holder([4, 5, 6])
    ref3 = weakref.ref(obj3, callback)
    ref4 = weakref.ref(obj4, callback)
    assert not comparator(ref3, ref4)
    assert callback_called == []
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("r1", "r2", "expected"),
    [
        (range(1, 10), range(1, 10), True),  # equal
        (range(10), range(1, 10), False),  # different start
        (range(2, 10), range(1, 10), False),
        (range(1, 5), range(1, 10), False),  # different stop
        (range(1, 20), range(1, 10), False),
        (range(1, 10, 1), range(1, 10, 2), False),  # different step
        (range(1, 10, 3), range(1, 10, 2), False),
        (range(-5, 0), range(-5, 0), True),  # negative ranges
        (range(-10, 0), range(-5, 0), False),
        (range(5, 1), range(10, 5), True),  # empty ranges with different bounds
        (range(5, 1), range(5, 1), True),
        (range(7), range(7), True),
        (range(7), range(0, 7, 1), True),  # implicit vs explicit start/step
    ],
)
def test_ranges(r1: range, r2: range, expected: bool) -> None:
    """comparator on two ranges must return the expected equality verdict.

    The case table previously listed ``(range(7), range(0, 7, 1), True)`` twice;
    the duplicate has been dropped.
    """
    assert comparator(r1, r2) == expected
|
|
|
|
|
|
def test_standard_python_library_objects() -> None:
    """Check comparator against common standard-library value types.

    Covers datetime types and timezones, Decimal, Flag/Enum/IntFlag members,
    compiled regexes (pattern text and flags), array.array (typecode and
    contents) and UUIDs. Each type gets its own variable names, so no
    re-annotation or ``# type: ignore`` is needed.
    """
    # datetime.datetime
    dt_x = datetime.datetime(2020, 2, 2, 2, 2, 2)
    dt_y = datetime.datetime(2020, 2, 2, 2, 2, 2)
    dt_other = datetime.datetime(2020, 2, 2, 2, 2, 3)
    assert comparator(dt_x, dt_y)
    assert not comparator(dt_x, dt_other)

    # datetime.date
    day_x = datetime.date(2020, 2, 2)
    day_y = datetime.date(2020, 2, 2)
    day_other = datetime.date(2020, 2, 3)
    assert comparator(day_x, day_y)
    assert not comparator(day_x, day_other)

    # datetime.timedelta
    delta_x = datetime.timedelta(days=1)
    delta_y = datetime.timedelta(days=1)
    delta_other = datetime.timedelta(days=2)
    assert comparator(delta_x, delta_y)
    assert not comparator(delta_x, delta_other)

    # datetime.time
    time_x = datetime.time(2, 2, 2)
    time_y = datetime.time(2, 2, 2)
    time_other = datetime.time(2, 2, 3)
    assert comparator(time_x, time_y)
    assert not comparator(time_x, time_other)

    # datetime.timezone
    tz_x = datetime.timezone.utc
    tz_y = datetime.timezone.utc
    tz_other = datetime.timezone(datetime.timedelta(hours=1))
    assert comparator(tz_x, tz_y)
    assert not comparator(tz_x, tz_other)

    # decimal.Decimal, built from strings so the values are exactly 3.14 / 3.15
    # rather than the nearest binary floating-point approximations.
    dec_x = decimal.Decimal("3.14")
    dec_y = decimal.Decimal("3.14")
    dec_other = decimal.Decimal("3.15")
    assert comparator(dec_x, dec_y)
    assert not comparator(dec_x, dec_other)

    class FlagColor(Flag):
        RED = auto()
        GREEN = auto()
        BLUE = auto()

    class EnumColor(Enum):
        RED = auto()
        GREEN = auto()
        BLUE = auto()

    class IntFlagColor(IntFlag):
        RED = auto()
        GREEN = auto()
        BLUE = auto()

    # Members of each enum flavour: same member equal, different member not.
    for color_cls in (FlagColor, EnumColor, IntFlagColor):
        assert comparator(color_cls.RED, color_cls.RED)
        assert not comparator(color_cls.RED, color_cls.GREEN)

    # re.Pattern: both the pattern text and the flags must match.
    pat_a = re.compile("a")
    pat_a_again = re.compile("a")
    pat_b = re.compile("b")
    pat_a_ignorecase = re.compile("a", re.IGNORECASE)
    assert comparator(pat_a, pat_a_again)
    assert not comparator(pat_a, pat_b)
    assert not comparator(pat_a, pat_a_ignorecase)

    # array.array: typecode and contents both matter; a list is not an array.
    arr1 = array.array("i", [1, 2, 3])
    arr2 = array.array("i", [1, 2, 3])
    arr3 = array.array("i", [4, 5, 6])
    arr4 = array.array("f", [1.0, 2.0, 3.0])
    assert comparator(arr1, arr2)
    assert not comparator(arr1, arr3)
    assert not comparator(arr1, arr4)
    assert not comparator(arr1, [1, 2, 3])

    empty_arr_i1 = array.array("i")
    empty_arr_i2 = array.array("i")
    empty_arr_f = array.array("f")
    assert comparator(empty_arr_i1, empty_arr_i2)
    assert not comparator(empty_arr_i1, empty_arr_f)
    assert not comparator(empty_arr_i1, arr1)

    # uuid.UUID. Comparing an object with itself could pass on identity alone,
    # so also compare against an equal but distinct UUID instance.
    id1 = uuid.uuid4()
    id1_copy = uuid.UUID(str(id1))
    id3 = uuid.uuid4()
    assert id1_copy is not id1
    assert comparator(id1, id1)
    assert comparator(id1, id1_copy)
    assert not comparator(id1, id3)
|
|
|
|
|
|
def test_itertools_count() -> None:
    """comparator on itertools.count must distinguish start, step and current position."""
    import itertools

    count = itertools.count

    def advance(it, steps):
        # Pull `steps` items so the counter's position moves forward.
        for _ in range(steps):
            next(it)
        return it

    # Same constructor arguments: default step, explicit step, negative, float.
    for args in [(0,), (5,), (0, 1), (10, 3), (-5, -2), (0.5, 0.1)]:
        assert comparator(count(*args), count(*args))

    # Different start or different step.
    for left_args, right_args in [((0,), (1,)), ((5,), (10,)), ((0, 1), (0, 2)), ((0, 1), (0, -1))]:
        assert not comparator(count(*left_args), count(*right_args))

    # A count never equals a plain value or list.
    assert not comparator(count(0), 0)
    assert not comparator(count(0), [0, 1, 2])

    # Consumption: equal when advanced equally, unequal otherwise.
    assert comparator(advance(count(0), 1), advance(count(0), 1))
    assert not comparator(advance(count(0), 1), count(0))

    # Nested inside containers.
    assert comparator([count(0)], [count(0)])
    assert comparator({"key": count(5, 2)}, {"key": count(5, 2)})
    assert not comparator([count(0)], [count(1)])
|
|
|
|
|
|
def test_itertools_repeat() -> None:
    """comparator on itertools.repeat must distinguish value, remaining count and boundedness."""
    import itertools

    repeat = itertools.repeat

    def advance(it, steps):
        # Pull `steps` items so the remaining count goes down.
        for _ in range(steps):
            next(it)
        return it

    # Same arguments, infinite and bounded.
    for args in [(5,), ("hello",), (5, 3), (None, 10)]:
        assert comparator(repeat(*args), repeat(*args))

    # Different value, different count, or bounded vs infinite.
    for left_args, right_args in [((5,), (6,)), ((5, 3), (6, 3)), ((5, 3), (5, 4)), ((5,), (5, 3))]:
        assert not comparator(repeat(*left_args), repeat(*right_args))

    # A repeat never equals a plain value or list.
    assert not comparator(repeat(5), 5)
    assert not comparator(repeat(5), [5])

    # Consumption: equal when advanced equally, unequal otherwise.
    assert comparator(advance(repeat(5, 5), 1), advance(repeat(5, 5), 1))
    assert not comparator(advance(repeat(5, 5), 1), repeat(5, 5))

    # Nested inside containers.
    assert comparator([repeat(5, 3)], [repeat(5, 3)])
    assert not comparator([repeat(5, 3)], [repeat(5, 4)])
|
|
|
|
|
|
def test_itertools_cycle() -> None:
    """comparator on itertools.cycle must compare the source sequence and the position within a lap."""
    import itertools

    cycle = itertools.cycle

    def advance(it, steps):
        # Pull `steps` items from the iterator.
        for _ in range(steps):
            next(it)
        return it

    # Same source sequence.
    assert comparator(cycle([1, 2, 3]), cycle([1, 2, 3]))
    assert comparator(cycle("abc"), cycle("abc"))

    # Different element, shorter sequence, or not a cycle at all.
    assert not comparator(cycle([1, 2, 3]), cycle([1, 2, 4]))
    assert not comparator(cycle([1, 2, 3]), cycle([1, 2]))
    assert not comparator(cycle([1, 2, 3]), [1, 2, 3])

    # Position inside the lap matters; whole laps do not (4 % 3 == 7 % 3 == 1).
    for left_steps, right_steps, expected in [(1, 1, True), (1, 0, False), (3, 3, True), (4, 7, True)]:
        left = advance(cycle([1, 2, 3]), left_steps)
        right = advance(cycle([1, 2, 3]), right_steps)
        assert bool(comparator(left, right)) is expected

    # Nested inside containers.
    assert comparator([cycle([1, 2])], [cycle([1, 2])])
    assert not comparator([cycle([1, 2])], [cycle([1, 3])])
|
|
|
|
|
|
def test_itertools_chain() -> None:
    """comparator on itertools.chain (and chain.from_iterable) must follow the chained items."""
    import itertools

    chain = itertools.chain
    cases = [
        (chain([1, 2], [3, 4]), chain([1, 2], [3, 4]), True),
        (chain([1, 2], [3, 4]), chain([1, 2], [3, 5]), False),
        (chain.from_iterable([[1, 2], [3]]), chain.from_iterable([[1, 2], [3]]), True),
        (chain(), chain(), True),
        (chain([1]), chain([1, 2]), False),
    ]
    for left, right, expected in cases:
        assert bool(comparator(left, right)) is expected
|
|
|
|
|
|
def test_itertools_islice() -> None:
    """comparator on itertools.islice must notice a changed start or stop."""
    import itertools

    islice = itertools.islice
    cases = [
        (islice(range(10), 5), islice(range(10), 5), True),
        (islice(range(10), 5), islice(range(10), 6), False),
        (islice(range(10), 2, 5), islice(range(10), 2, 5), True),
        (islice(range(10), 2, 5), islice(range(10), 2, 6), False),
    ]
    for left, right, expected in cases:
        assert bool(comparator(left, right)) is expected
|
|
|
|
|
|
def test_itertools_product() -> None:
    """comparator on itertools.product, both with ``repeat=`` and with several iterables."""
    import itertools

    product = itertools.product
    cases = [
        (product("AB", repeat=2), product("AB", repeat=2), True),
        (product("AB", repeat=2), product("AC", repeat=2), False),
        (product([1, 2], [3, 4]), product([1, 2], [3, 4]), True),
        (product([1, 2], [3, 4]), product([1, 2], [3, 5]), False),
    ]
    for left, right, expected in cases:
        assert bool(comparator(left, right)) is expected
|
|
|
|
|
|
def test_itertools_permutations_combinations() -> None:
    """comparator on permutations, combinations and combinations_with_replacement."""
    import itertools

    cases = [
        (itertools.permutations("ABC", 2), itertools.permutations("ABC", 2), True),
        (itertools.permutations("ABC", 2), itertools.permutations("ABD", 2), False),
        (itertools.combinations("ABCD", 2), itertools.combinations("ABCD", 2), True),
        (itertools.combinations("ABCD", 2), itertools.combinations("ABCD", 3), False),
        (
            itertools.combinations_with_replacement("ABC", 2),
            itertools.combinations_with_replacement("ABC", 2),
            True,
        ),
        (
            itertools.combinations_with_replacement("ABC", 2),
            itertools.combinations_with_replacement("ABD", 2),
            False,
        ),
    ]
    for left, right, expected in cases:
        assert bool(comparator(left, right)) is expected
|
|
|
|
|
|
def test_itertools_accumulate() -> None:
    """comparator on itertools.accumulate, with and without an ``initial`` value."""
    import itertools

    accumulate = itertools.accumulate
    cases = [
        (accumulate([1, 2, 3, 4]), accumulate([1, 2, 3, 4]), True),
        (accumulate([1, 2, 3, 4]), accumulate([1, 2, 3, 5]), False),
        (accumulate([1, 2, 3], initial=10), accumulate([1, 2, 3], initial=10), True),
        (accumulate([1, 2, 3], initial=10), accumulate([1, 2, 3], initial=0), False),
    ]
    for left, right, expected in cases:
        assert bool(comparator(left, right)) is expected
|
|
|
|
|
|
def test_itertools_filtering() -> None:
    """comparator on compress / dropwhile / takewhile / filterfalse iterators.

    Each iterator kind gets a matching pair and a pair whose produced items
    differ. Previously ``filterfalse`` had no negative case.
    """
    import itertools

    # compress
    assert comparator(
        itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]), itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1])
    )
    assert not comparator(
        itertools.compress("ABCDEF", [1, 0, 1, 0, 1, 1]), itertools.compress("ABCDEF", [1, 1, 1, 0, 1, 1])
    )

    # dropwhile
    assert comparator(
        itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]), itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1])
    )
    assert not comparator(
        itertools.dropwhile(lambda x: x < 5, [1, 4, 6, 4, 1]), itertools.dropwhile(lambda x: x < 5, [1, 4, 7, 4, 1])
    )

    # takewhile
    assert comparator(
        itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]), itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1])
    )
    assert not comparator(
        itertools.takewhile(lambda x: x < 5, [1, 4, 6, 4, 1]), itertools.takewhile(lambda x: x < 5, [1, 3, 6, 4, 1])
    )

    # filterfalse
    assert comparator(
        itertools.filterfalse(lambda x: x % 2, range(10)), itertools.filterfalse(lambda x: x % 2, range(10))
    )
    # range(12) additionally yields 10, so the produced items differ.
    assert not comparator(
        itertools.filterfalse(lambda x: x % 2, range(10)), itertools.filterfalse(lambda x: x % 2, range(12))
    )
|
|
|
|
|
|
def test_itertools_starmap() -> None:
    """comparator on itertools.starmap must notice a changed argument tuple."""
    import itertools

    starmap = itertools.starmap

    left = starmap(pow, [(2, 3), (3, 2), (10, 0)])
    right = starmap(pow, [(2, 3), (3, 2), (10, 0)])
    assert comparator(left, right)

    # (3, 2) vs (3, 3) changes the second produced value.
    shorter = starmap(pow, [(2, 3), (3, 2)])
    changed = starmap(pow, [(2, 3), (3, 3)])
    assert not comparator(shorter, changed)
|
|
|
|
|
|
def test_itertools_zip_longest() -> None:
    """comparator on itertools.zip_longest must take the fill value into account."""
    import itertools

    zip_longest = itertools.zip_longest

    dash_left = zip_longest("AB", "xyz", fillvalue="-")
    dash_right = zip_longest("AB", "xyz", fillvalue="-")
    assert comparator(dash_left, dash_right)

    # Only the fill value differs, which changes the padded last tuple.
    dash = zip_longest("AB", "xyz", fillvalue="-")
    star = zip_longest("AB", "xyz", fillvalue="*")
    assert not comparator(dash, star)
|
|
|
|
|
|
def test_itertools_groupby() -> None:
    """comparator on itertools.groupby, with the default key and an explicit key function."""
    import itertools

    groupby = itertools.groupby
    cases = [
        (groupby("AAABBBCC"), groupby("AAABBBCC"), True),
        (groupby("AAABBBCC"), groupby("AAABBCC"), False),
        (groupby([]), groupby([]), True),
        # Explicit key function (two distinct lambda objects with the same body).
        (groupby([1, 1, 2, 2, 3], key=lambda x: x), groupby([1, 1, 2, 2, 3], key=lambda x: x), True),
    ]
    for left, right, expected in cases:
        assert bool(comparator(left, right)) is expected
|
|
|
|
|
|
@pytest.mark.skipif(sys.version_info < (3, 10), reason="itertools.pairwise requires Python 3.10+")
def test_itertools_pairwise() -> None:
    """comparator on itertools.pairwise must detect a change in the last element."""
    import itertools

    pairwise = itertools.pairwise
    cases = [
        (pairwise([1, 2, 3, 4]), pairwise([1, 2, 3, 4]), True),
        (pairwise([1, 2, 3, 4]), pairwise([1, 2, 3, 5]), False),
    ]
    for left, right, expected in cases:
        assert bool(comparator(left, right)) is expected
|
|
|
|
|
|
@pytest.mark.skipif(sys.version_info < (3, 12), reason="itertools.batched requires Python 3.12+")
def test_itertools_batched() -> None:
    """comparator on itertools.batched must detect a different batch size."""
    import itertools

    batched = itertools.batched
    cases = [
        (batched("ABCDEFG", 3), batched("ABCDEFG", 3), True),
        (batched("ABCDEFG", 3), batched("ABCDEFG", 2), False),
    ]
    for left, right, expected in cases:
        assert bool(comparator(left, right)) is expected
|
|
|
|
|
|
def test_itertools_in_containers() -> None:
    """itertools objects nested in dicts/lists, and mismatched itertools types."""
    import itertools

    def make_mapping():
        # Fresh iterators on every call, so each side gets its own objects.
        return {"a": itertools.chain([1], [2]), "b": itertools.islice(range(5), 3)}

    assert comparator(make_mapping(), make_mapping())

    ab_pairs = [itertools.product("AB", repeat=2)]
    ac_pairs = [itertools.product("AC", repeat=2)]
    assert not comparator(ab_pairs, ac_pairs)

    # Same items, but a chain is never equal to an islice.
    assert not comparator(itertools.chain([1, 2]), itertools.islice([1, 2], 2))
|
|
|
|
|
|
def test_numpy() -> None:
    """comparator on numpy arrays, scalar types, object arrays, NaN/inf and np.void.

    Skips with an explanatory reason when numpy is not installed.
    """
    np = pytest.importorskip("numpy")

    # 1-D integer arrays.
    arr = np.array([1, 2, 3])
    arr_same = np.array([1, 2, 3])
    arr_changed = np.array([1, 2, 4])
    assert comparator(arr, arr_same)
    assert not comparator(arr, arr_changed)

    # 2-D arrays; a 1-D array never matches a 2-D one.
    grid = np.array([[1, 2], [3, 4]])
    grid_same = np.array([[1, 2], [3, 4]])
    grid_changed = np.array([[1, 2], [3, 5]])
    assert comparator(grid, grid_same)
    assert not comparator(grid, grid_changed)
    assert not comparator(arr, grid)

    # Same numeric values but a float dtype.
    arr_float = np.array([1.0, 2.0, 3.0])
    assert not comparator(arr, arr_float)

    # numpy scalars: equal values of a different numpy type never match.
    f32_a, f32_b = np.float32(1.0), np.float32(1.0)
    assert comparator(f32_a, f32_b)

    f64_a, f64_b = np.float64(1.0), np.float64(1.0)
    assert not comparator(f32_a, f64_a)
    assert comparator(f64_a, f64_b)

    i32_a, i32_b = np.int32(1), np.int32(1)
    assert comparator(i32_a, i32_b)
    assert not comparator(i32_a, f32_a)
    assert not comparator(i32_a, f64_a)

    i64_a, i64_b = np.int64(1), np.int64(1)
    assert not comparator(i64_a, i32_a)
    assert comparator(i64_a, i64_b)

    u32_a, u32_b = np.uint32(1), np.uint32(1)
    assert comparator(u32_a, u32_b)
    assert not comparator(u32_a, i32_a)

    u64_a, u64_b = np.uint64(1), np.uint64(1)
    assert not comparator(u64_a, u32_a)
    assert comparator(u64_a, u64_b)

    bool_a, bool_b = np.bool_(True), np.bool_(True)
    assert comparator(bool_a, bool_b)
    assert not comparator(bool_a, u64_a)

    c64_a, c64_b = np.complex64(1.0 + 1.0j), np.complex64(1.0 + 1.0j)
    assert comparator(c64_a, c64_b)
    assert not comparator(c64_a, bool_a)

    c128_a, c128_b = np.complex128(1.0 + 1.0j), np.complex128(1.0 + 1.0j)
    assert not comparator(c128_a, c64_a)
    assert comparator(c128_a, c128_b)

    # Object-dtype arrays holding mixed Python types.
    obj_arr = np.array([1, 2, "str"], dtype=np.object_)
    obj_arr_same = np.array([1, 2, "str"], dtype=np.object_)
    obj_arr_changed = np.array([1, 2, "str2"], dtype=np.object_)
    assert comparator(obj_arr, obj_arr_same)
    assert not comparator(obj_arr, obj_arr_changed)

    # Mixed input without dtype=object (numpy picks a common dtype itself).
    mixed = np.array([1, 2, "str2"])
    mixed_same = np.array([1, 2, "str2"])
    assert comparator(mixed, mixed_same)

    # NaN and inf, inside arrays and as bare floats: NaN matches NaN here.
    nan_arr = np.array([1, 2, np.nan])
    nan_arr_same = np.array([1, 2, np.nan])
    inf_arr = np.array([1, 2, np.inf])
    inf_arr_same = np.array([1, 2, np.inf])
    inf_a, inf_b = np.inf, np.inf
    nan_a, nan_b = np.nan, np.nan
    assert comparator(nan_arr, nan_arr_same)
    assert comparator(inf_arr, inf_arr_same)
    assert not comparator(nan_arr, inf_arr)
    assert not comparator(nan_arr_same, inf_arr_same)
    assert comparator(inf_a, inf_b)
    assert comparator(nan_a, nan_b)
    assert not comparator(inf_a, nan_a)

    # Structured records (np.void) taken from structured arrays.
    dt = np.dtype([("name", "S10"), ("age", np.int32)])
    a_struct = np.array([("Alice", 25)], dtype=dt)
    b_struct = np.array([("Alice", 25)], dtype=dt)
    c_struct = np.array([("Bob", 30)], dtype=dt)

    a_void = a_struct[0]
    b_void = b_struct[0]
    c_void = c_struct[0]

    assert isinstance(a_void, np.void)
    assert comparator(a_void, b_void)
    assert not comparator(a_void, c_void)
|
|
|
|
|
|
def test_numpy_random_generator() -> None:
    """comparator on numpy.random.Generator must follow seed, draws and bit-generator type.

    Skips with an explanatory reason when numpy is not installed.
    """
    np = pytest.importorskip("numpy")

    # Modern Generator API: same seed -> equal generators.
    rng1 = np.random.default_rng(seed=42)
    rng2 = np.random.default_rng(seed=42)
    assert comparator(rng1, rng2)

    # Different seeds -> unequal generators.
    rng3 = np.random.default_rng(seed=123)
    assert not comparator(rng1, rng3)

    # Drawing a number advances the state...
    rng4 = np.random.default_rng(seed=42)
    rng5 = np.random.default_rng(seed=42)
    rng4.random()
    assert not comparator(rng4, rng5)

    # ...and advancing the other by the same amount makes them equal again.
    rng5.random()
    assert comparator(rng4, rng5)

    # Explicitly chosen bit generators.
    rng_pcg1 = np.random.Generator(np.random.PCG64(seed=42))
    rng_pcg2 = np.random.Generator(np.random.PCG64(seed=42))
    assert comparator(rng_pcg1, rng_pcg2)

    rng_mt1 = np.random.Generator(np.random.MT19937(seed=42))
    rng_mt2 = np.random.Generator(np.random.MT19937(seed=42))
    assert comparator(rng_mt1, rng_mt2)

    # Different bit generator types are never equal.
    assert not comparator(rng_pcg1, rng_mt1)
|
|
|
|
|
|
def test_numpy_random_state() -> None:
    """comparator on legacy numpy.random.RandomState must follow seed, draws and restored state.

    Skips with an explanatory reason when numpy is not installed.
    """
    np = pytest.importorskip("numpy")

    # Same seed -> equal states.
    rs1 = np.random.RandomState(seed=42)
    rs2 = np.random.RandomState(seed=42)
    assert comparator(rs1, rs2)

    # Different seeds -> unequal states.
    rs3 = np.random.RandomState(seed=123)
    assert not comparator(rs1, rs3)

    # Drawing a number advances the state...
    rs4 = np.random.RandomState(seed=42)
    rs5 = np.random.RandomState(seed=42)
    rs4.random()
    assert not comparator(rs4, rs5)

    # ...and advancing the other by the same amount makes them equal again.
    rs5.random()
    assert comparator(rs4, rs5)

    # State restoration. rs7 starts from a *different* seed, so set_state really
    # changes it (with seed 42 the restore would be a no-op and prove nothing).
    rs6 = np.random.RandomState(seed=42)
    state = rs6.get_state()
    rs6.random()
    rs7 = np.random.RandomState(seed=0)
    rs7.set_state(state)
    # rs6 has advanced; rs7 now holds rs6's original, freshly-seeded state.
    assert not comparator(rs6, rs7)
    assert comparator(rs7, np.random.RandomState(seed=42))
|
|
|
|
|
|
def test_scipy() -> None:
    """comparator on each scipy.sparse matrix format, plus a dense array built from coo_array.

    Equal matrices of one format match. Changed data, a different shape, or a
    different sparse format holding the same values do not. Skips with an
    explanatory reason when SciPy is not installed.
    """
    # `import scipy` alone is not guaranteed to load the sparse subpackage,
    # so import scipy.sparse explicitly.
    sparse = pytest.importorskip("scipy.sparse")
    # NumPy is a required dependency of SciPy, so this import cannot fail here.
    import numpy as np

    base = [[1, 0, 0], [0, 0, 3], [4, 0, 5]]
    changed = [[1, 0, 0], [0, 0, 3], [4, 0, 6]]
    changed_wide = [[1, 0, 0, 0], [0, 0, 3, 0], [4, 0, 6, 0]]

    # CSR: equal twin, changed value, changed shape.
    csr = sparse.csr_matrix(base)
    csr_changed = sparse.csr_matrix(changed)
    assert comparator(csr, sparse.csr_matrix(base))
    assert not comparator(csr, csr_changed)
    assert not comparator(csr_changed, sparse.csr_matrix(changed_wide))

    # CSC: the same checks, plus CSR vs CSC with identical values.
    csc = sparse.csc_matrix(base)
    csc_changed = sparse.csc_matrix(changed)
    assert comparator(csc, sparse.csc_matrix(base))
    assert not comparator(csc, csc_changed)
    assert not comparator(csr, csc)
    assert not comparator(csr_changed, csc_changed)
    assert not comparator(csc_changed, sparse.csc_matrix(changed_wide))

    # Remaining formats: equal to a twin, unequal to changed data, and never
    # equal to the CSR matrix holding the same values.
    other_formats = (
        sparse.lil_matrix,
        sparse.dok_matrix,
        sparse.dia_matrix,
        sparse.coo_matrix,
        sparse.bsr_matrix,
    )
    for matrix_type in other_formats:
        matrix = matrix_type(base)
        assert comparator(matrix, matrix_type(base))
        assert not comparator(matrix, matrix_type(changed))
        assert not comparator(csr, matrix)

    # Dense arrays produced from coo_array.
    row = np.array([0, 3, 1, 0])
    col = np.array([0, 3, 1, 2])
    data = np.array([4, 5, 7, 9])
    dense_a = sparse.coo_array((data, (row, col)), shape=(4, 4)).toarray()
    dense_b = sparse.coo_array((data, (row, col)), shape=(4, 4)).toarray()
    assert comparator(dense_a, dense_b)
|
|
|
|
|
|
def test_pandas():
    """Equal pandas objects must compare equal and any value change must be detected.

    Covers DataFrames, Series, the index flavours, categoricals, intervals, periods,
    timedeltas, timestamps, sparse arrays and the handling of ``pd.NA``.
    """
    try:
        import pandas as pd
    except ImportError:
        pytest.skip()

    def frame(a_column, b_column):
        # Two-column DataFrame with columns "a" then "b".
        return pd.DataFrame({"a": a_column, "b": b_column})

    stamp = datetime.datetime(2020, 2, 2, 2, 2, 2)
    one_second_later = datetime.datetime(2020, 2, 2, 2, 2, 3)

    # Each entry: (reference, equal copy, object that differs from the reference).
    triples = [
        (frame([1, 2, 3], [4, 5, 6]), frame([1, 2, 3], [4, 5, 6]), frame([1, 2, 3], [4, 5, 7])),
        (frame([stamp, stamp], [4, 5]), frame([stamp, stamp], [4, 5]), frame([stamp, one_second_later], [4, 5])),
        (pd.Series([1, 2, 3]), pd.Series([1, 2, 3]), pd.Series([1, 2, 4])),
        (pd.Index([1, 2, 3]), pd.Index([1, 2, 3]), pd.Index([1, 2, 4])),
        (
            pd.MultiIndex.from_tuples([(1, 2), (3, 4)]),
            pd.MultiIndex.from_tuples([(1, 2), (3, 4)]),
            pd.MultiIndex.from_tuples([(1, 2), (3, 5)]),
        ),
        (pd.Categorical([1, 2, 3]), pd.Categorical([1, 2, 3]), pd.Categorical([1, 2, 4])),
        (pd.Interval(1, 2), pd.Interval(1, 2), pd.Interval(1, 3)),
        (
            pd.IntervalIndex.from_tuples([(1, 2), (3, 4)]),
            pd.IntervalIndex.from_tuples([(1, 2), (3, 4)]),
            pd.IntervalIndex.from_tuples([(1, 2), (3, 5)]),
        ),
        (pd.Period("2021-01"), pd.Period("2021-01"), pd.Period("2021-02")),
        (
            pd.period_range(start="2021-01", periods=3, freq="M"),
            pd.period_range(start="2021-01", periods=3, freq="M"),
            pd.period_range(start="2021-01", periods=4, freq="M"),
        ),
        (pd.Timedelta("1 days"), pd.Timedelta("1 days"), pd.Timedelta("2 days")),
        (
            pd.TimedeltaIndex(["1 days", "2 days"]),
            pd.TimedeltaIndex(["1 days", "2 days"]),
            pd.TimedeltaIndex(["1 days", "3 days"]),
        ),
        (pd.Timestamp("2021-01-01"), pd.Timestamp("2021-01-01"), pd.Timestamp("2021-01-02")),
        (pd.arrays.SparseArray([1, 2, 3]), pd.arrays.SparseArray([1, 2, 3]), pd.arrays.SparseArray([1, 2, 4])),
        # pd.NA inside containers must not be confused with None.
        (pd.Series([1, 2, pd.NA, 4]), pd.Series([1, 2, pd.NA, 4]), pd.Series([1, 2, None, 4])),
        (frame([1, 2, pd.NA], [4, pd.NA, 6]), frame([1, 2, pd.NA], [4, pd.NA, 6]), frame([1, 2, None], [4, None, 6])),
        ({"a": pd.NA, "b": [1, pd.NA, 3]}, {"a": pd.NA, "b": [1, pd.NA, 3]}, {"a": None, "b": [1, None, 3]}),
    ]
    for reference, twin, changed in triples:
        assert comparator(reference, twin)
        assert not comparator(reference, changed)

    # Frames with a different number of rows.
    assert not comparator(frame([1, 2, 3], [4, 5, 7]), frame([1, 2, 3, 4], [4, 5, 6, 7]))

    # The pd.NA singleton only matches itself, in either argument position.
    assert comparator(pd.NA, pd.NA)
    assert not comparator(pd.NA, None)
    assert not comparator(None, pd.NA)

    # Boolean-mask filtering of NA-bearing series yields comparable results.
    left = pd.Series([1, 2, pd.NA, 4])
    right = pd.Series([1, 2, pd.NA, 4])
    assert comparator(left[left > 1], right[right > 1])
|
|
|
|
|
|
def test_pyarrow():
    """Identical pyarrow objects compare equal; data, length, type or name changes do not."""
    try:
        import pyarrow as pa
    except ImportError:
        pytest.skip()

    # Each group: (what is checked, reference, equal twin, variants that must NOT match).
    groups = [
        (
            "Table",
            pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}),
            pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}),
            [
                pa.table({"a": [1, 2, 3], "b": [4, 5, 7]}),
                pa.table({"a": [1, 2, 3, 4], "b": [4, 5, 6, 7]}),
                pa.table({"a": [1, 2, 3], "c": [4, 5, 6]}),  # renamed column
            ],
        ),
        (
            "RecordBatch",
            pa.RecordBatch.from_pydict({"x": [1, 2], "y": [3.0, 4.0]}),
            pa.RecordBatch.from_pydict({"x": [1, 2], "y": [3.0, 4.0]}),
            [
                pa.RecordBatch.from_pydict({"x": [1, 2], "y": [3.0, 5.0]}),
                pa.RecordBatch.from_pydict({"x": [1, 2, 3], "y": [3.0, 4.0, 5.0]}),
            ],
        ),
        (
            "Array",
            pa.array([1, 2, 3]),
            pa.array([1, 2, 3]),
            [pa.array([1, 2, 4]), pa.array([1, 2, 3, 4]), pa.array([1.0, 2.0, 3.0])],  # value, length, type
        ),
        ("Array with nulls", pa.array([1, None, 3]), pa.array([1, None, 3]), [pa.array([1, 2, 3])]),
        (
            "ChunkedArray",
            pa.chunked_array([[1, 2], [3, 4]]),
            pa.chunked_array([[1, 2], [3, 4]]),
            [pa.chunked_array([[1, 2], [3, 5]]), pa.chunked_array([[1, 2, 3], [4, 5]])],
        ),
        ("Scalar", pa.scalar(42), pa.scalar(42), [pa.scalar(43), pa.scalar(42.0)]),
        (
            "null Scalar",
            pa.scalar(None, type=pa.int64()),
            pa.scalar(None, type=pa.int64()),
            [pa.scalar(None, type=pa.float64())],
        ),
        (
            "Schema",
            pa.schema([("a", pa.int64()), ("b", pa.float64())]),
            pa.schema([("a", pa.int64()), ("b", pa.float64())]),
            [
                pa.schema([("a", pa.int64()), ("c", pa.float64())]),
                pa.schema([("a", pa.int32()), ("b", pa.float64())]),
            ],
        ),
        (
            "Field",
            pa.field("name", pa.int64()),
            pa.field("name", pa.int64()),
            [pa.field("other", pa.int64()), pa.field("name", pa.float64())],
        ),
        ("DataType", pa.int64(), pa.int64(), [pa.int32(), pa.float64()]),
        ("string Array", pa.array(["hello", "world"]), pa.array(["hello", "world"]), [pa.array(["hello", "there"])]),
        (
            "struct Array",
            pa.array([{"x": 1, "y": 2}, {"x": 3, "y": 4}]),
            pa.array([{"x": 1, "y": 2}, {"x": 3, "y": 4}]),
            [pa.array([{"x": 1, "y": 2}, {"x": 3, "y": 5}])],
        ),
        (
            "list Array",
            pa.array([[1, 2], [3, 4, 5]]),
            pa.array([[1, 2], [3, 4, 5]]),
            [pa.array([[1, 2], [3, 4, 6]])],
        ),
    ]
    for label, reference, twin, variants in groups:
        assert comparator(reference, twin), label
        for variant in variants:
            assert not comparator(reference, variant), label
|
|
|
|
|
|
def test_pyrsistent():
    """Check the comparator on pyrsistent persistent collections, records and classes."""
    try:
        from pyrsistent import PBag, PClass, PRecord, field, pbag, pdeque, pmap, pset, pvector  # type: ignore
    except ImportError:
        pytest.skip()

    a = pmap({"a": 1, "b": 2})
    b = pmap({"a": 1, "b": 2})
    c = pmap({"a": 1, "b": 3})
    assert comparator(a, b)
    assert not comparator(a, c)

    d = pvector([1, 2, 3])
    e = pvector([1, 2, 3])
    f = pvector([1, 2, 4])
    assert comparator(d, e)
    assert not comparator(d, f)

    # Sets are unordered, so construction order must not matter.
    g = pset([1, 2, 3])
    h = pset([2, 3, 1])
    i = pset([1, 2, 4])
    assert comparator(g, h)
    assert not comparator(g, i)

    class TestRecord(PRecord):
        a = field()
        b = field()

    j = TestRecord()
    k = TestRecord()
    l = TestRecord(a=2, b=3)
    assert comparator(j, k)
    assert not comparator(j, l)

    class TestClass(PClass):
        a = field()
        b = field()

    m = TestClass()
    n = TestClass()
    o = TestClass(a=1, b=3)
    assert comparator(m, n)
    assert not comparator(m, o)

    # Second positional argument is maxlen.
    p = pdeque([1, 2, 3], 3)
    q = pdeque([1, 2, 3], 3)
    r = pdeque([1, 2, 4], 3)
    assert comparator(p, q)
    assert not comparator(p, r)

    # Bags must be built with the ``pbag`` factory. The ``PBag`` constructor takes an
    # internal pmap of element counts, so passing it a list creates a malformed bag.
    s = pbag([1, 2, 3])
    t = pbag([3, 2, 1])  # a bag is an unordered multiset
    u = pbag([1, 2, 4])
    assert isinstance(s, PBag)
    assert comparator(s, t)
    assert not comparator(s, u)
    # Element multiplicity is part of a bag's identity.
    assert not comparator(pbag([1, 1, 2]), pbag([1, 2, 2]))
|
|
|
|
|
|
def test_torch_dtype():
    """A torch.dtype must match itself and nothing else."""
    try:
        import torch  # type: ignore
    except ImportError:
        pytest.skip()

    # Reference dtype -> dtypes it must NOT be considered equal to.
    expected_mismatches = {
        torch.float32: (torch.float64, torch.int32),
        torch.int64: (torch.int32,),
        torch.complex64: (torch.complex128,),
        torch.bool: (torch.int8,),
    }
    for dtype, mismatches in expected_mismatches.items():
        assert comparator(dtype, dtype)
        for other in mismatches:
            assert not comparator(dtype, other)
|
|
|
|
|
|
def test_torch():
    """Tensors compare by values, shape, dtype, device and autograd flag; NaN equals NaN."""
    try:
        import torch  # type: ignore
    except ImportError:
        pytest.skip()

    tensor = torch.tensor

    # Each entry: (reference, equal twin, tensor that must differ from the reference).
    cases = [
        (tensor([1, 2, 3]), tensor([1, 2, 3]), tensor([1, 2, 4])),
        (tensor([[1, 2, 3], [4, 5, 6]]), tensor([[1, 2, 3], [4, 5, 6]]), tensor([[1, 2, 3], [4, 5, 7]])),
        # Same values, different dtype.
        (
            tensor([1, 2, 3], dtype=torch.float32),
            tensor([1, 2, 3], dtype=torch.float32),
            tensor([1, 2, 3], dtype=torch.int64),
        ),
        (
            tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]),
            tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]),
            tensor([[[1, 2], [3, 4]], [[5, 6], [7, 9]]]),
        ),
        (tensor([]), tensor([]), tensor([1])),
        # NaN positions are expected to compare equal to each other.
        (tensor([1.0, float("nan"), 3.0]), tensor([1.0, float("nan"), 3.0]), tensor([1.0, 2.0, 3.0])),
        (tensor([1.0, float("inf"), 3.0]), tensor([1.0, float("inf"), 3.0]), tensor([1.0, float("-inf"), 3.0])),
        (
            tensor([1.0, 2.0, 3.0], requires_grad=True),
            tensor([1.0, 2.0, 3.0], requires_grad=True),
            tensor([1.0, 2.0, 3.0], requires_grad=False),
        ),
        (tensor([1 + 2j, 3 + 4j]), tensor([1 + 2j, 3 + 4j]), tensor([1 + 2j, 3 + 5j])),
        (tensor([True, False, True]), tensor([True, False, True]), tensor([True, True, True])),
    ]
    for reference, twin, different in cases:
        assert comparator(reference, twin)
        assert not comparator(reference, different)

    # Identical values but an extra leading dimension.
    assert not comparator(tensor([1, 2, 3]), tensor([[1, 2, 3]]))

    # Device placement matters (only checkable with a GPU).
    if torch.cuda.is_available():
        on_gpu = tensor([1, 2, 3]).cuda()
        assert comparator(on_gpu, tensor([1, 2, 3]).cuda())
        assert not comparator(on_gpu, tensor([1, 2, 3]))
|
|
|
|
|
|
def test_torch_device():
    """Check the comparator on torch.device objects: type and index must both agree."""
    try:
        import torch  # type: ignore
    except ImportError:
        pytest.skip()

    # Same device type, no explicit index.
    assert comparator(torch.device("cpu"), torch.device("cpu"))

    # Same device with an explicit index, spelled two equivalent ways.
    assert comparator(torch.device("cpu:0"), torch.device("cpu", 0))

    # The meta device only matches itself.
    meta = torch.device("meta")
    assert comparator(meta, torch.device("meta"))
    assert not comparator(meta, torch.device("cpu"))

    if torch.cuda.is_available():
        # Different device types.
        assert not comparator(torch.device("cpu"), torch.device("cuda"))
        # Index embedded in the string vs passed as a separate argument.
        assert comparator(torch.device("cuda:0"), torch.device("cuda", 0))

        if torch.cuda.device_count() > 1:
            # Same device type, different indices.
            first_gpu = torch.device("cuda:0")
            assert comparator(first_gpu, torch.device("cuda:0"))
            assert not comparator(first_gpu, torch.device("cuda:1"))
|
|
|
|
|
|
def test_torch_nn_linear():
    """nn.Linear equality depends on weights, shape, bias presence and train/eval mode."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def linear(seed, *args, **kwargs):
        # Seed right before construction so the initial weights are reproducible.
        torch.manual_seed(seed)
        return nn.Linear(*args, **kwargs)

    # Same seed and configuration.
    assert comparator(linear(42, 10, 5), linear(42, 10, 5))

    # Each pair below differs in exactly one respect.
    assert not comparator(linear(42, 10, 5), linear(123, 10, 5))  # initial weights
    assert not comparator(linear(42, 10, 5), linear(42, 20, 5))  # in_features
    assert not comparator(linear(42, 10, 5), linear(42, 10, 10))  # out_features
    assert not comparator(linear(42, 10, 5, bias=True), linear(42, 10, 5, bias=False))  # bias

    # Identical parameters, but one module is in training mode and one in eval mode.
    training = linear(42, 10, 5)
    training.train()
    evaluating = linear(42, 10, 5)
    evaluating.eval()
    assert not comparator(training, evaluating)
|
|
|
|
|
|
def test_torch_nn_conv2d():
    """nn.Conv2d layers match only when weights and every hyper-parameter agree."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def conv(seed, *args, **kwargs):
        # Seed right before construction so weight initialisation is reproducible.
        torch.manual_seed(seed)
        return nn.Conv2d(*args, **kwargs)

    assert comparator(conv(42, 3, 16, kernel_size=3), conv(42, 3, 16, kernel_size=3))

    # Each pair below differs in exactly one respect.
    assert not comparator(conv(42, 3, 16, kernel_size=3), conv(123, 3, 16, kernel_size=3))  # weights
    assert not comparator(conv(42, 3, 16, kernel_size=3), conv(42, 1, 16, kernel_size=3))  # in_channels
    assert not comparator(conv(42, 3, 16, kernel_size=3), conv(42, 3, 32, kernel_size=3))  # out_channels
    assert not comparator(conv(42, 3, 16, kernel_size=3), conv(42, 3, 16, kernel_size=5))  # kernel_size
    assert not comparator(
        conv(42, 3, 16, kernel_size=3, stride=1),
        conv(42, 3, 16, kernel_size=3, stride=2),
    )
    assert not comparator(
        conv(42, 3, 16, kernel_size=3, padding=0),
        conv(42, 3, 16, kernel_size=3, padding=1),
    )
|
|
|
|
|
|
def test_torch_nn_batchnorm():
    """BatchNorm equality depends on configuration, affine flag and running statistics."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def norm(*args, **kwargs):
        # Every BatchNorm2d in this test is constructed right after seeding 42.
        torch.manual_seed(42)
        return nn.BatchNorm2d(*args, **kwargs)

    assert comparator(norm(16), norm(16))
    assert not comparator(norm(16), norm(32))  # num_features
    assert not comparator(norm(16, eps=1e-5), norm(16, eps=1e-3))
    assert not comparator(norm(16, momentum=0.1), norm(16, momentum=0.01))
    assert not comparator(norm(16, affine=True), norm(16, affine=False))

    def after_forward(input_seed=None):
        # One training-mode forward pass, which updates the layer's running statistics.
        # Without `input_seed` the batch is drawn from whatever RNG state norm() left.
        layer = norm(16)
        layer.train()
        if input_seed is not None:
            torch.manual_seed(input_seed)
        _ = layer(torch.randn(4, 16, 8, 8))
        return layer

    # Same input batch -> same running statistics.
    assert comparator(after_forward(), after_forward())
    # Batches drawn under different seeds -> different running statistics.
    assert not comparator(after_forward(42), after_forward(123))

    # The 1d variant behaves the same way.
    torch.manual_seed(42)
    first_1d = nn.BatchNorm1d(16)
    torch.manual_seed(42)
    second_1d = nn.BatchNorm1d(16)
    assert comparator(first_1d, second_1d)
|
|
|
|
|
|
def test_torch_nn_dropout():
    """Dropout modules match only when class, drop probability and inplace flag agree."""
    try:
        import torch  # noqa: F401
        from torch import nn
    except ImportError:
        pytest.skip()

    # (left, right, should the comparator treat them as equal?)
    cases = [
        (nn.Dropout(p=0.5), nn.Dropout(p=0.5), True),
        (nn.Dropout(p=0.5), nn.Dropout(p=0.3), False),
        (nn.Dropout(p=0.5, inplace=False), nn.Dropout(p=0.5, inplace=True), False),
        (nn.Dropout2d(p=0.5), nn.Dropout2d(p=0.5), True),
        (nn.Dropout(p=0.5), nn.Dropout2d(p=0.5), False),  # different module classes
    ]
    for left, right, expected in cases:
        assert bool(comparator(left, right)) is expected
|
|
|
|
|
|
def test_torch_nn_activation():
    """Activation modules compare by class and by their configuration arguments."""
    try:
        import torch  # noqa: F401
        from torch import nn
    except ImportError:
        pytest.skip()

    # (left, right, should the comparator treat them as equal?)
    cases = [
        (nn.ReLU(), nn.ReLU(), True),
        (nn.ReLU(inplace=False), nn.ReLU(inplace=True), False),
        (nn.LeakyReLU(negative_slope=0.01), nn.LeakyReLU(negative_slope=0.01), True),
        (nn.LeakyReLU(negative_slope=0.01), nn.LeakyReLU(negative_slope=0.1), False),
        (nn.Sigmoid(), nn.ReLU(), False),  # different module classes
        (nn.GELU(), nn.GELU(), True),
        (nn.Softmax(dim=1), nn.Softmax(dim=1), True),
        (nn.Softmax(dim=1), nn.Softmax(dim=0), False),
    ]
    for left, right, expected in cases:
        assert bool(comparator(left, right)) is expected
|
|
|
|
|
|
def test_torch_nn_pooling():
    """Pooling modules compare by class, kernel size, stride and output size."""
    try:
        import torch  # noqa: F401
        from torch import nn
    except ImportError:
        pytest.skip()

    # (left, right, should the comparator treat them as equal?)
    cases = [
        (nn.MaxPool2d(kernel_size=2), nn.MaxPool2d(kernel_size=2), True),
        (nn.MaxPool2d(kernel_size=2), nn.MaxPool2d(kernel_size=3), False),
        (nn.MaxPool2d(kernel_size=2, stride=2), nn.MaxPool2d(kernel_size=2, stride=1), False),
        (nn.AvgPool2d(kernel_size=2), nn.AvgPool2d(kernel_size=2), True),
        (nn.MaxPool2d(kernel_size=2), nn.AvgPool2d(kernel_size=2), False),  # different module classes
        (nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.AdaptiveAvgPool2d(output_size=(1, 1)), True),
        (nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.AdaptiveAvgPool2d(output_size=(2, 2)), False),
    ]
    for left, right, expected in cases:
        assert bool(comparator(left, right)) is expected
|
|
|
|
|
|
def test_torch_nn_embedding():
    """nn.Embedding equality depends on weights, table size, width and padding_idx."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def embedding(seed, *args, **kwargs):
        # Seed right before construction so the embedding table is reproducible.
        torch.manual_seed(seed)
        return nn.Embedding(*args, **kwargs)

    assert comparator(embedding(42, 1000, 128), embedding(42, 1000, 128))

    # Each pair below differs in exactly one respect.
    assert not comparator(embedding(42, 1000, 128), embedding(123, 1000, 128))  # weights
    assert not comparator(embedding(42, 1000, 128), embedding(42, 2000, 128))  # num_embeddings
    assert not comparator(embedding(42, 1000, 128), embedding(42, 1000, 256))  # embedding_dim
    assert not comparator(
        embedding(42, 1000, 128, padding_idx=0),
        embedding(42, 1000, 128, padding_idx=1),
    )
    assert not comparator(embedding(42, 1000, 128), embedding(42, 1000, 128, padding_idx=0))
|
|
|
|
|
|
def test_torch_nn_lstm():
    """nn.LSTM equality depends on weights and on every structural option."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def lstm(seed, **config):
        # Seed right before construction so the recurrent weights are reproducible.
        torch.manual_seed(seed)
        return nn.LSTM(**config)

    base = {"input_size": 10, "hidden_size": 20}

    assert comparator(lstm(42, **base, num_layers=2), lstm(42, **base, num_layers=2))

    # Each pair below differs in exactly one respect.
    assert not comparator(lstm(42, **base, num_layers=2), lstm(123, **base, num_layers=2))  # weights
    assert not comparator(lstm(42, **base), lstm(42, input_size=20, hidden_size=20))  # input_size
    assert not comparator(lstm(42, **base), lstm(42, input_size=10, hidden_size=40))  # hidden_size
    assert not comparator(lstm(42, **base, num_layers=1), lstm(42, **base, num_layers=2))
    assert not comparator(lstm(42, **base, bidirectional=False), lstm(42, **base, bidirectional=True))
    assert not comparator(lstm(42, **base, batch_first=False), lstm(42, **base, batch_first=True))
|
|
|
|
|
|
def test_torch_nn_gru():
    """nn.GRU equality depends on configuration and on the recurrent cell class itself."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def seeded(seed, module_cls, **config):
        # Seed right before construction so the weights are reproducible.
        torch.manual_seed(seed)
        return module_cls(**config)

    assert comparator(
        seeded(42, nn.GRU, input_size=10, hidden_size=20, num_layers=2),
        seeded(42, nn.GRU, input_size=10, hidden_size=20, num_layers=2),
    )
    assert not comparator(
        seeded(42, nn.GRU, input_size=10, hidden_size=20),
        seeded(42, nn.GRU, input_size=10, hidden_size=40),
    )
    # Same arguments, but a GRU is not an LSTM.
    assert not comparator(
        seeded(42, nn.GRU, input_size=10, hidden_size=20),
        seeded(42, nn.LSTM, input_size=10, hidden_size=20),
    )
|
|
|
|
|
|
def test_torch_nn_layernorm():
    """nn.LayerNorm equality depends on normalized_shape, eps and elementwise_affine."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def layer_norm(**config):
        # Every LayerNorm in this test is constructed right after seeding 42.
        torch.manual_seed(42)
        return nn.LayerNorm(**config)

    assert comparator(layer_norm(normalized_shape=[10]), layer_norm(normalized_shape=[10]))
    assert not comparator(layer_norm(normalized_shape=[10]), layer_norm(normalized_shape=[20]))
    assert not comparator(
        layer_norm(normalized_shape=[10], eps=1e-5),
        layer_norm(normalized_shape=[10], eps=1e-3),
    )
    assert not comparator(
        layer_norm(normalized_shape=[10], elementwise_affine=True),
        layer_norm(normalized_shape=[10], elementwise_affine=False),
    )
|
|
|
|
|
|
def test_torch_nn_multihead_attention():
    """nn.MultiheadAttention equality depends on weights, width, head count and dropout."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def attention(seed, **config):
        # Seed right before construction so the projection weights are reproducible.
        torch.manual_seed(seed)
        return nn.MultiheadAttention(**config)

    assert comparator(attention(42, embed_dim=64, num_heads=8), attention(42, embed_dim=64, num_heads=8))

    # Each pair below differs in exactly one respect.
    assert not comparator(attention(42, embed_dim=64, num_heads=8), attention(123, embed_dim=64, num_heads=8))
    assert not comparator(attention(42, embed_dim=64, num_heads=8), attention(42, embed_dim=128, num_heads=8))
    assert not comparator(attention(42, embed_dim=64, num_heads=8), attention(42, embed_dim=64, num_heads=4))
    assert not comparator(
        attention(42, embed_dim=64, num_heads=8, dropout=0.0),
        attention(42, embed_dim=64, num_heads=8, dropout=0.1),
    )
|
|
|
|
|
|
def test_torch_nn_sequential():
    """nn.Sequential equality depends on child weights, child count and child types."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def seeded(seed, factory):
        # Seed right before building so every child layer's weights are reproducible.
        torch.manual_seed(seed)
        return factory()

    def mlp():
        return nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))

    def short_relu():
        return nn.Sequential(nn.Linear(10, 20), nn.ReLU())

    def short_sigmoid():
        return nn.Sequential(nn.Linear(10, 20), nn.Sigmoid())

    assert comparator(seeded(42, mlp), seeded(42, mlp))
    assert not comparator(seeded(42, mlp), seeded(123, mlp))  # weights
    assert not comparator(seeded(42, short_relu), seeded(42, mlp))  # number of layers
    assert not comparator(seeded(42, short_relu), seeded(42, short_sigmoid))  # layer types
|
|
|
|
|
|
def test_torch_nn_modulelist():
    """nn.ModuleList equality depends on how many modules it holds."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def linear_stack(count):
        # `count` identically-shaped Linear layers, built right after seeding 42.
        torch.manual_seed(42)
        return nn.ModuleList([nn.Linear(10, 10) for _ in range(count)])

    assert comparator(linear_stack(3), linear_stack(3))
    assert not comparator(linear_stack(3), linear_stack(4))
|
|
|
|
|
|
def test_torch_nn_moduledict():
    """Check comparator on nn.ModuleDict: identical keys and weights match, renamed keys do not."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def seeded_dict(first_key, second_key):
        # Same seed and same layer shapes; only the key names are parameterised.
        torch.manual_seed(42)
        return nn.ModuleDict({first_key: nn.Linear(10, 20), second_key: nn.Linear(20, 5)})

    assert comparator(seeded_dict("fc1", "fc2"), seeded_dict("fc1", "fc2"))
    # Same layers under different keys.
    assert not comparator(seeded_dict("fc1", "fc2"), seeded_dict("layer1", "layer2"))
|
|
|
|
|
|
def test_torch_nn_custom_module():
    """Check comparator on a user-defined nn.Module subclass (weights and architecture)."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    class SimpleNet(nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.fc1 = nn.Linear(10, hidden_size)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(hidden_size, 5)

        def forward(self, x):
            return self.fc2(self.relu(self.fc1(x)))

    def net(seed, hidden_size=20):
        torch.manual_seed(seed)
        return SimpleNet(hidden_size=hidden_size)

    # Same seed and width -> identical parameters.
    assert comparator(net(42), net(42))
    # Same width, different seed -> different weights.
    assert not comparator(net(42), net(123))
    # Same seed, different hidden width -> different parameter shapes.
    assert not comparator(net(42, hidden_size=20), net(42, hidden_size=40))
|
|
|
|
|
|
def test_torch_nn_nested_modules():
    """Check comparator on nn.Modules that contain other custom nn.Modules."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    class EncoderBlock(nn.Module):
        def __init__(self, in_channels, out_channels):
            super().__init__()
            self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
            self.bn = nn.BatchNorm2d(out_channels)
            self.relu = nn.ReLU()

        def forward(self, x):
            convolved = self.conv(x)
            return self.relu(self.bn(convolved))

    class Encoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.block1 = EncoderBlock(3, 16)
            self.block2 = EncoderBlock(16, 32)
            self.pool = nn.MaxPool2d(2)

        def forward(self, x):
            # block -> pool, applied for each block in order.
            for block in (self.block1, self.block2):
                x = self.pool(block(x))
            return x

    def encoder_with_seed(seed):
        torch.manual_seed(seed)
        return Encoder()

    # Identical seeds -> identical nested parameters.
    assert comparator(encoder_with_seed(42), encoder_with_seed(42))
    # Different seeds -> nested weights differ.
    assert not comparator(encoder_with_seed(42), encoder_with_seed(123))
|
|
|
|
|
|
def test_torch_nn_transformer():
    """Check comparator on nn.Transformer and nn.TransformerEncoder modules."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def seeded_transformer(**kwargs):
        torch.manual_seed(42)
        return nn.Transformer(**kwargs)

    def seeded_encoder():
        # Seed once, then build the layer and the stack from it.
        torch.manual_seed(42)
        layer = nn.TransformerEncoderLayer(d_model=64, nhead=4)
        return nn.TransformerEncoder(layer, num_layers=2)

    small = {"d_model": 64, "nhead": 4, "num_encoder_layers": 2, "num_decoder_layers": 2}
    assert comparator(seeded_transformer(**small), seeded_transformer(**small))

    # Different model width.
    assert not comparator(seeded_transformer(d_model=64, nhead=4), seeded_transformer(d_model=128, nhead=4))
    # Different number of attention heads.
    assert not comparator(seeded_transformer(d_model=64, nhead=4), seeded_transformer(d_model=64, nhead=8))

    assert comparator(seeded_encoder(), seeded_encoder())
|
|
|
|
|
|
def test_torch_nn_parameter_buffer_modification():
    """Check that comparator notices in-place edits to a parameter and to a buffer."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def seeded(build):
        torch.manual_seed(42)
        return build()

    # Parameter: overwrite one weight of an otherwise identical Linear.
    linear_ref = seeded(lambda: nn.Linear(10, 5))
    linear_edit = seeded(lambda: nn.Linear(10, 5))
    assert comparator(linear_ref, linear_edit)
    with torch.no_grad():
        linear_edit.weight[0, 0] = 999.0
    assert not comparator(linear_ref, linear_edit)

    # Buffer: overwrite one entry of BatchNorm's running_mean.
    bn_ref = seeded(lambda: nn.BatchNorm2d(16))
    bn_edit = seeded(lambda: nn.BatchNorm2d(16))
    assert comparator(bn_ref, bn_edit)
    bn_edit.running_mean[0] = 999.0
    assert not comparator(bn_ref, bn_edit)
|
|
|
|
|
|
def test_torch_nn_device_placement():
    """Check comparator on CPU modules, and on a CPU/CUDA pair when CUDA is available."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def seeded_linear():
        torch.manual_seed(42)
        return nn.Linear(10, 5)

    assert comparator(seeded_linear(), seeded_linear())

    if not torch.cuda.is_available():
        return

    # Same weights, different devices -> must not compare equal.
    on_cpu = seeded_linear()
    on_cuda = seeded_linear().cuda()
    assert not comparator(on_cpu, on_cuda)
|
|
|
|
|
|
def test_torch_nn_conv1d_conv3d():
    """Check comparator on Conv1d/Conv3d, including a Conv1d vs Conv2d type mismatch."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def seeded_conv(conv_cls, in_channels, out_channels):
        torch.manual_seed(42)
        return conv_cls(in_channels, out_channels, kernel_size=3)

    assert comparator(seeded_conv(nn.Conv1d, 3, 16), seeded_conv(nn.Conv1d, 3, 16))
    # Different out_channels.
    assert not comparator(seeded_conv(nn.Conv1d, 3, 16), seeded_conv(nn.Conv1d, 3, 32))
    assert comparator(seeded_conv(nn.Conv3d, 3, 16), seeded_conv(nn.Conv3d, 3, 16))
    # Different convolution dimensionality (module type).
    assert not comparator(seeded_conv(nn.Conv1d, 3, 16), seeded_conv(nn.Conv2d, 3, 16))
|
|
|
|
|
|
def test_torch_nn_flatten_unflatten():
    """Check comparator on parameter-free Flatten/Unflatten modules."""
    try:
        import torch  # imported only so the test is skipped when torch is missing
        from torch import nn
    except ImportError:
        pytest.skip()

    assert comparator(nn.Flatten(), nn.Flatten())
    # Configuration differences (start_dim) must be detected.
    assert not comparator(nn.Flatten(start_dim=1), nn.Flatten(start_dim=0))
    assert comparator(
        nn.Unflatten(dim=1, unflattened_size=(2, 5)),
        nn.Unflatten(dim=1, unflattened_size=(2, 5)),
    )
|
|
|
|
|
|
def test_torch_nn_identity():
    """Check comparator on nn.Identity and against a different module type."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    assert comparator(nn.Identity(), nn.Identity())

    torch.manual_seed(42)
    identity = nn.Identity()
    linear = nn.Linear(10, 10)
    # Different module types never match.
    assert not comparator(identity, linear)
|
|
|
|
|
|
def test_torch_nn_with_superset():
    """Check that superset_obj=True still compares nn.Module weights."""
    try:
        import torch
        from torch import nn
    except ImportError:
        pytest.skip()

    def seeded_linear(seed):
        torch.manual_seed(seed)
        return nn.Linear(10, 5)

    # Identical modules pass in superset mode.
    assert comparator(seeded_linear(42), seeded_linear(42), superset_obj=True)
    # Different weights still fail in superset mode.
    assert not comparator(seeded_linear(42), seeded_linear(123), superset_obj=True)
|
|
|
|
|
|
def test_jax():
    """Check comparator on jax.numpy arrays: values, dtype, rank, NaN/inf, complex and bool."""
    try:
        import jax.numpy as jnp
    except ImportError:
        pytest.skip()

    # Each entry is (reference, equal to reference, different from reference).
    triples = [
        # 1-D integer arrays
        (jnp.array([1, 2, 3]), jnp.array([1, 2, 3]), jnp.array([1, 2, 4])),
        # 2-D arrays
        (
            jnp.array([[1, 2, 3], [4, 5, 6]]),
            jnp.array([[1, 2, 3], [4, 5, 6]]),
            jnp.array([[1, 2, 3], [4, 5, 7]]),
        ),
        # Same values, different dtype
        (
            jnp.array([1, 2, 3], dtype=jnp.float32),
            jnp.array([1, 2, 3], dtype=jnp.float32),
            jnp.array([1, 2, 3], dtype=jnp.int32),
        ),
        # 3-D arrays
        (
            jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]),
            jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]),
            jnp.array([[[1, 2], [3, 4]], [[5, 6], [7, 9]]]),
        ),
        # Empty vs non-empty
        (jnp.array([]), jnp.array([]), jnp.array([1])),
        # NaN in the same position compares equal
        (jnp.array([1.0, jnp.nan, 3.0]), jnp.array([1.0, jnp.nan, 3.0]), jnp.array([1.0, 2.0, 3.0])),
        # +inf vs -inf
        (jnp.array([1.0, jnp.inf, 3.0]), jnp.array([1.0, jnp.inf, 3.0]), jnp.array([1.0, -jnp.inf, 3.0])),
        # Complex values
        (jnp.array([1 + 2j, 3 + 4j]), jnp.array([1 + 2j, 3 + 4j]), jnp.array([1 + 2j, 3 + 5j])),
        # Booleans
        (jnp.array([True, False, True]), jnp.array([True, False, True]), jnp.array([True, True, True])),
    ]
    for reference, equal, different in triples:
        assert comparator(reference, equal)
        assert not comparator(reference, different)

    # Same values, different rank/shape.
    assert not comparator(jnp.array([1, 2, 3]), jnp.array([[1, 2, 3]]))
|
|
|
|
|
|
def test_xarray():
    """Check comparator on xarray DataArray and Dataset objects."""
    try:
        import numpy as np
        import xarray as xr
    except ImportError:
        pytest.skip()

    # Each entry is (reference, equal to reference, different from reference).
    triples = [
        # Plain 1-D DataArray
        (
            xr.DataArray([1, 2, 3], dims=["x"]),
            xr.DataArray([1, 2, 3], dims=["x"]),
            xr.DataArray([1, 2, 4], dims=["x"]),
        ),
        # Coordinate values differ
        (
            xr.DataArray([1, 2, 3], coords={"x": [0, 1, 2]}, dims=["x"]),
            xr.DataArray([1, 2, 3], coords={"x": [0, 1, 2]}, dims=["x"]),
            xr.DataArray([1, 2, 3], coords={"x": [0, 1, 3]}, dims=["x"]),
        ),
        # Attributes differ
        (
            xr.DataArray([1, 2, 3], dims=["x"], attrs={"units": "meters"}),
            xr.DataArray([1, 2, 3], dims=["x"], attrs={"units": "meters"}),
            xr.DataArray([1, 2, 3], dims=["x"], attrs={"units": "feet"}),
        ),
        # 2-D DataArray
        (
            xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=["x", "y"]),
            xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=["x", "y"]),
            xr.DataArray([[1, 2, 3], [4, 5, 7]], dims=["x", "y"]),
        ),
        # NaN in the same position compares equal
        (
            xr.DataArray([1.0, np.nan, 3.0], dims=["x"]),
            xr.DataArray([1.0, np.nan, 3.0], dims=["x"]),
            xr.DataArray([1.0, 2.0, 3.0], dims=["x"]),
        ),
        # Dataset variable values differ
        (
            xr.Dataset({"temp": (["x", "y"], [[1, 2], [3, 4]]), "pressure": (["x", "y"], [[5, 6], [7, 8]])}),
            xr.Dataset({"temp": (["x", "y"], [[1, 2], [3, 4]]), "pressure": (["x", "y"], [[5, 6], [7, 8]])}),
            xr.Dataset({"temp": (["x", "y"], [[1, 2], [3, 4]]), "pressure": (["x", "y"], [[5, 6], [7, 9]])}),
        ),
        # Dataset coordinates differ
        (
            xr.Dataset({"temp": (["x", "y"], [[1, 2], [3, 4]])}, coords={"x": [0, 1], "y": [0, 1]}),
            xr.Dataset({"temp": (["x", "y"], [[1, 2], [3, 4]])}, coords={"x": [0, 1], "y": [0, 1]}),
            xr.Dataset({"temp": (["x", "y"], [[1, 2], [3, 4]])}, coords={"x": [0, 2], "y": [0, 1]}),
        ),
        # Dataset attributes differ
        (
            xr.Dataset({"temp": (["x"], [1, 2, 3])}, attrs={"source": "sensor"}),
            xr.Dataset({"temp": (["x"], [1, 2, 3])}, attrs={"source": "sensor"}),
            xr.Dataset({"temp": (["x"], [1, 2, 3])}, attrs={"source": "model"}),
        ),
        # Dataset variable names differ
        (
            xr.Dataset({"temp": (["x"], [1, 2, 3])}),
            xr.Dataset({"temp": (["x"], [1, 2, 3])}),
            xr.Dataset({"pressure": (["x"], [1, 2, 3])}),
        ),
        # +inf vs -inf
        (
            xr.DataArray([1.0, np.inf, 3.0], dims=["x"]),
            xr.DataArray([1.0, np.inf, 3.0], dims=["x"]),
            xr.DataArray([1.0, -np.inf, 3.0], dims=["x"]),
        ),
    ]
    for reference, equal, different in triples:
        assert comparator(reference, equal)
        assert not comparator(reference, different)

    # Pairs that must never compare equal.
    never_equal = [
        # Different dimension names
        (xr.DataArray([1, 2, 3], dims=["x"]), xr.DataArray([1, 2, 3], dims=["y"])),
        # Different shapes
        (xr.DataArray([1, 2, 3], dims=["x"]), xr.DataArray([[1, 2, 3]], dims=["x", "y"])),
        # DataArray vs Dataset
        (xr.DataArray([1, 2, 3], dims=["x"]), xr.Dataset({"data": (["x"], [1, 2, 3])})),
    ]
    for left, right in never_equal:
        assert not comparator(left, right)

    # Two empty Datasets compare equal.
    assert comparator(xr.Dataset(), xr.Dataset())

    # int32 and int64 arrays holding the same values currently compare equal
    # (presumably because the comparison defers to xarray's permissive identity check).
    as_int32 = xr.DataArray(np.array([1, 2, 3], dtype="int32"), dims=["x"])
    as_int64 = xr.DataArray(np.array([1, 2, 3], dtype="int64"), dims=["x"])
    assert comparator(as_int32, as_int64)
|
|
|
|
|
|
def test_returns():
    """Check comparator on Success/Failure result containers."""
    five = Success(5)
    assert comparator(five, Success(5))
    # Different payload, different container kind, different payload type.
    assert not comparator(five, Success(6))
    assert not comparator(five, Failure(5))
    assert not comparator(five, Success((5, 5)))
    assert not comparator(Success((5, 5)), Success((5, 6)))

    # Tuple payloads compare element-wise.
    pair = Success((5, 5))
    assert comparator(pair, Success((5, 5)))
    assert not comparator(pair, Success((5, 6)))
|
|
|
|
|
|
def test_custom_object():
    """Check comparator on plain instances, dataclasses, pydantic models and class objects."""

    class TestClass:
        def __init__(self, value):
            self.value = value

        def __eq__(self, other):
            return self.value == other.value

    five = TestClass(5)
    assert comparator(five, TestClass(5))
    assert not comparator(five, TestClass(6))

    class TestClass2:
        def __init__(self, value1, value2=6):
            self.value1 = value1
            self.value2 = value2

    pair_5_6 = TestClass2(5, 6)
    # A TestClass instance does not match a TestClass2 instance.
    assert not comparator(TestClass(5), pair_5_6)
    assert not comparator(pair_5_6, TestClass2(5, 7))
    assert comparator(pair_5_6, TestClass2(5, 6))

    class TestClass3(TestClass):
        def print(self):
            print(self.value)

    sub_five = TestClass3(5)
    assert not comparator(TestClass2(5), sub_five)
    assert comparator(sub_five, TestClass3(5))

    @dataclasses.dataclass
    class InventoryItem:
        """Class for keeping track of an item in inventory."""

        name: str
        unit_price: float
        quantity_on_hand: int = 0

        def total_cost(self) -> float:
            return self.unit_price * self.quantity_on_hand

    @pydantic.dataclasses.dataclass
    class InventoryItemPydantic:
        """Class for keeping track of an item in inventory."""

        name: str
        unit_price: float
        quantity_on_hand: int = 0

        def total_cost(self) -> float:
            return self.unit_price * self.quantity_on_hand

    class InventoryItemBasePydantic(pydantic.BaseModel):
        name: str
        unit_price: float
        quantity_on_hand: int = 0

        def total_cost(self) -> float:
            return self.unit_price * self.quantity_on_hand

    # Stdlib dataclass, pydantic dataclass and pydantic BaseModel: a one-field difference must be detected.
    for item_cls in (InventoryItem, InventoryItemPydantic, InventoryItemBasePydantic):
        ten_on_hand = item_cls(name="widget", unit_price=3.0, quantity_on_hand=10)
        assert comparator(ten_on_hand, item_cls(name="widget", unit_price=3.0, quantity_on_hand=10))
        assert not comparator(ten_on_hand, item_cls(name="widget", unit_price=3.0, quantity_on_hand=11))

    # Class objects: a class matches itself (or an alias of itself) but not a different class.
    class A:
        items = [1, 2, 3]
        val = 5

    class B:
        items = [1, 2, 4]
        val = 5

    assert comparator(A, A)
    assert not comparator(A, B)

    class C:
        items = [1, 2, 3]
        val = 5

        def __init__(self):
            self.itemm2 = [1, 2, 3]
            self.val2 = 5

    class D:
        items = [1, 2, 3]
        val = 5

        def __init__(self):
            self.itemm2 = [1, 2, 4]
            self.val2 = 5

    assert comparator(C, C)
    assert not comparator(C, D)

    alias_of_c = C
    assert comparator(C, alias_of_c)
|
|
|
|
|
|
def test_custom_object_2():
    """Compare BubbleSorter instances across edited-and-reimported versions of its module.

    ``code_to_optimize/bubble_sort_method.py`` is temporarily overwritten with mutated
    sources and re-imported after evicting it from ``sys.modules``. The original file
    contents are always written back, and the module is evicted again, in ``finally``
    so later tests import the restored class rather than the last mutated one.
    """
    module_name = "code_to_optimize.bubble_sort_method"
    fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_method.py").resolve()
    original_code = fto_path.read_text("utf-8")
    from code_to_optimize.bubble_sort_method import BubbleSorter

    a = BubbleSorter()
    assert a.x == 0

    # The two mutated sources below have exactly the same byte length. Cached .pyc
    # files are validated against the source's mtime (whole seconds) and size, so two
    # writes within the same second could make the second re-import load the stale
    # bytecode of the first. Don't write bytecode while the file is being mutated.
    saved_dont_write_bytecode = sys.dont_write_bytecode
    sys.dont_write_bytecode = True
    try:
        # Remove the module from sys.modules, to get the updated class
        sys.modules.pop(module_name, None)
        from code_to_optimize.bubble_sort_method import BubbleSorter

        b = BubbleSorter()
        # Note that type(a) != type(b) as the class type objects are different, even if the code is the same.
        assert comparator(a, b)

        optimized_code_mutated_attr = """
class BubbleSorter:
    z = 0

    def __init__(self, x=1):
        self.x = x

    def sorter(self, arr):
        for i in range(len(arr)):
            for j in range(len(arr) - 1):
                if arr[j] > arr[j + 1]:
                    temp = arr[j]
                    arr[j] = arr[j + 1]
                    arr[j + 1] = temp
        return arr
"""
        fto_path.write_text(optimized_code_mutated_attr, "utf-8")
        sys.modules.pop(module_name, None)
        from code_to_optimize.bubble_sort_method import BubbleSorter

        c = BubbleSorter()
        assert c.x == 1
        # An instance attribute with a different value must be detected.
        assert not comparator(a, c)

        optimized_code_new_attr = """
class BubbleSorter:
    z = 5

    def __init__(self, x=0):
        self.x = x

    def sorter(self, arr):
        for i in range(len(arr)):
            for j in range(len(arr) - 1):
                if arr[j] > arr[j + 1]:
                    temp = arr[j]
                    arr[j] = arr[j + 1]
                    arr[j + 1] = temp
        return arr
"""
        fto_path.write_text(optimized_code_new_attr, "utf-8")
        sys.modules.pop(module_name, None)
        from code_to_optimize.bubble_sort_method import BubbleSorter

        d = BubbleSorter()
        assert d.x == 0
        # Currently, we do not check if class variables are different, since the code replacer does not allow this.
        # In the future, if this functionality is allowed, this assert should be false.
        assert comparator(a, d)
    finally:
        fto_path.write_text(original_code, "utf-8")
        # Evict the last mutated module so later imports re-read the restored source.
        sys.modules.pop(module_name, None)
        sys.dont_write_bytecode = saved_dont_write_bytecode
|
|
|
|
|
|
def test_superset():
    """With superset_obj=True, only the second object may carry extra attributes."""

    class A:
        def __init__(self):
            self.a = 1

    extended = A()
    extended.x = 3

    # Superset mode is directional: the right-hand side may have extras, the left may not.
    assert comparator(A(), extended, superset_obj=True)
    assert not comparator(extended, A(), superset_obj=True)
    # In strict mode the extra attribute breaks equality in either direction.
    assert not comparator(A(), extended)
    assert not comparator(extended, A())
    # An object always matches itself, in both modes.
    assert comparator(extended, extended, superset_obj=True)
    assert comparator(extended, extended)
|
|
|
|
|
|
def test_compare_results_fn():
    """Check compare_test_results on runtime, return-value, pass/fail and empty-result differences."""

    def make_invocation(
        *,
        runtime,
        iteration_id="0",
        did_pass=True,
        return_value=5,
        test_type=TestType.EXISTING_UNIT_TEST,
    ):
        """Build a FunctionTestInvocation sharing the fixed fake module/class/function identifiers."""
        return FunctionTestInvocation(
            id=InvocationId(
                test_module_path="test_module_path",
                test_class_name="test_class_name",
                test_function_name="test_function_name",
                function_getting_tested="function_getting_tested",
                iteration_id=iteration_id,
            ),
            file_name=Path("file_name"),
            did_pass=did_pass,
            runtime=runtime,
            test_framework="unittest",
            test_type=test_type,
            return_value=return_value,
            timed_out=False,
            loop_index=1,
        )

    def make_results(*invocations):
        """Collect invocations, in order, into a fresh TestResults."""
        collected = TestResults()
        for invocation in invocations:
            collected.add(invocation)
        return collected

    original_results = make_results(make_invocation(runtime=5))

    # A runtime difference alone does not break a match.
    match, _ = compare_test_results(original_results, make_results(make_invocation(runtime=10)))
    assert match

    # Return value 5 vs [5] -> mismatch.
    match, _ = compare_test_results(original_results, make_results(make_invocation(runtime=10, return_value=[5])))
    assert not match

    # An extra invocation with another iteration id alongside the matching one -> still a match.
    match, _ = compare_test_results(
        original_results,
        make_results(make_invocation(runtime=10), make_invocation(runtime=10, iteration_id="2")),
    )
    assert match

    # Candidate fails where the original passed -> mismatch.
    match, _ = compare_test_results(original_results, make_results(make_invocation(runtime=5, did_pass=False)))
    assert not match

    # The same pass->fail check for generated regression tests and replay tests.
    for test_type in (TestType.GENERATED_REGRESSION, TestType.REPLAY_TEST):
        baseline = make_results(make_invocation(runtime=5, test_type=test_type))
        candidate = make_results(make_invocation(runtime=5, test_type=test_type, did_pass=False))
        match, _ = compare_test_results(baseline, candidate)
        assert not match

    # Two empty result sets are not reported as a match.
    match, _ = compare_test_results(TestResults(), TestResults())
    assert not match
|
|
|
|
|
|
def test_exceptions():
    """Two separately constructed TypeErrors with the same message compare equal."""
    message = "This is a type error"
    assert comparator(TypeError(message), TypeError(message))
|
|
|
|
|
|
def raise_exception():
    """Raise a plain ``Exception`` with a fixed message; used to obtain real raised exceptions."""
    message = "This is an exception"
    raise Exception(message)
|
|
|
|
|
|
def test_exceptions_comparator():
    """Exercise comparator on exceptions, Ellipsis and AST nodes.

    For exceptions, the assertions below pin that the exception type and public
    attributes (e.g. ``code`` on a custom error) matter, while messages, ``args``
    and OSError ``errno``/``strerror`` differences do not.
    """
    # Currently we are only comparing the exception types and the attributes that don't start with "_";
    # there are complications with comparing the exception messages.
    try:
        raise_exception()
    except Exception as e:
        exception = e

    try:
        raise_exception()
    except Exception as b:
        exception_2 = b

    assert comparator(exception, exception_2)

    exc1 = ValueError("same message")
    exc2 = ValueError("same message")
    assert comparator(exc1, exc2)

    # Different messages but same types
    exc_msg1 = ValueError("message one")
    exc_msg2 = ValueError("message two")
    assert comparator(exc_msg1, exc_msg2)

    # Same message but different types
    exc1 = ValueError("common message")
    exc2 = TypeError("common message")
    assert not comparator(exc1, exc2)

    # FileNotFoundError: neither errno nor strerror affects the comparison.
    exc_file1 = FileNotFoundError(2, "No such file or directory")
    exc_file2 = FileNotFoundError(2, "No such file or directory")
    exc_file3 = FileNotFoundError(3, "No such file or directory")
    exc_file4 = FileNotFoundError(2, "File not found")
    assert comparator(exc_file1, exc_file2)
    assert comparator(exc_file1, exc_file3)
    assert comparator(exc_file1, exc_file4)

    assert comparator(exception, exception)

    assert not comparator(exception, None)
    assert not comparator(None, exception)
    assert comparator(None, None)

    # Same exception type, different messages
    exc_type1 = TypeError("Type error")
    exc_type2 = TypeError("Another type error")
    assert comparator(exc_type1, exc_type2)

    exc_type3 = KeyError("Missing key")
    exc_type4 = KeyError("Missing key")
    assert comparator(exc_type3, exc_type4)
    # Different exception types
    assert not comparator(exc_type1, exc_type3)

    # compare the attributes of the exception as well
    class CustomError(Exception):
        def __init__(self, message, code):
            super().__init__(message)
            self.code = code

    custom_exc1 = CustomError("Something went wrong", 101)
    custom_exc2 = CustomError("Something went wrong", 101)
    assert comparator(custom_exc1, custom_exc2)

    custom_exc4 = CustomError("Something went wrong", 102)
    assert not comparator(custom_exc1, custom_exc4)

    class CustomErrorNoArgs(Exception):
        pass

    custom_no_args1 = CustomErrorNoArgs()
    custom_no_args2 = CustomErrorNoArgs()
    assert comparator(custom_no_args1, custom_no_args2)

    exc_empty1 = ValueError("")
    exc_empty2 = ValueError("")
    assert comparator(exc_empty1, exc_empty2)

    exc_not_empty = ValueError("Not empty")
    assert comparator(exc_empty1, exc_not_empty)

    # A subclass instance does not match a base-class instance, even with the same message.
    class CustomValueError(ValueError):
        pass

    custom_value_error1 = CustomValueError("A custom value error")
    value_error1 = ValueError("A custom value error")
    assert not comparator(custom_value_error1, value_error1)

    custom_value_error2 = CustomValueError("Another custom value error")
    assert comparator(custom_value_error1, custom_value_error2)

    # Differing ``args`` alone do not break equality.
    class CustomExceptionWithArgs(Exception):
        def __init__(self, arg1, arg2):
            self.args = (arg1, arg2)

    custom_args_exc1 = CustomExceptionWithArgs(1, "test")
    custom_args_exc2 = CustomExceptionWithArgs(1, "test")
    assert comparator(custom_args_exc1, custom_args_exc2)

    custom_args_exc3 = CustomExceptionWithArgs(1, "different")
    assert comparator(custom_args_exc1, custom_args_exc3)

    def raise_specific_exception():
        raise ZeroDivisionError("Cannot divide by zero")

    try:
        raise_specific_exception()
    except ZeroDivisionError as z1:
        zero_division_exc1 = z1

    try:
        raise_specific_exception()
    except ZeroDivisionError as z2:
        zero_division_exc2 = z2

    assert comparator(zero_division_exc1, zero_division_exc2)

    zero_division_exc3 = ZeroDivisionError("Different message")
    assert comparator(zero_division_exc1, zero_division_exc3)

    # Ellipsis is a singleton distinct from None.
    assert comparator(..., ...)
    assert comparator(Ellipsis, Ellipsis)
    assert not comparator(..., None)
    assert not comparator(Ellipsis, None)

    # AST: a deep copy (including the ad-hoc ``parent`` links) matches; a different expression does not.
    code7 = "a = 1 + 2"
    module7 = ast.parse(code7)
    for node in ast.walk(module7):
        for child in ast.iter_child_nodes(node):
            child.parent = node  # type: ignore
    module8 = copy.deepcopy(module7)
    assert comparator(module7, module8)

    code2 = "a = 1 + 3"
    module2 = ast.parse(code2)
    assert not comparator(module7, module2)
|
|
|
|
|
|
def test_torch_runtime_error_wrapping():
    """TorchRuntimeError wrappers compare equal to the exception they wrap.

    torch.compile surfaces failures as ``torch._dynamo.exc.TorchRuntimeError``.
    The comparator should treat such a wrapper as equivalent to the wrapped
    exception type, whether the inner exception is attached via ``__cause__``
    or is only named in the message text (``got IndexError('...')``).
    """

    # Stand-in for torch._dynamo.exc.TorchRuntimeError: only the class name and
    # __module__ are reproduced, so torch itself is not needed.
    class TorchRuntimeError(Exception):
        """Mock TorchRuntimeError for testing."""

    TorchRuntimeError.__module__ = "torch._dynamo.exc"

    def torch_error(message, cause=None):
        # Build a mock TorchRuntimeError, chaining `cause` when one is given.
        err = TorchRuntimeError(message)
        if cause is not None:
            err.__cause__ = cause
        return err

    oob = "index 0 is out of bounds for dimension 0 with size 0"
    fx_failure = "Dynamo failed to run FX node with fake tensors: got IndexError('index 0 is out of bounds')"
    plain_index_error = IndexError(oob)

    # Wrapped IndexError (first via __cause__, then via the message alone)
    # matches a plain IndexError in both directions.
    for wrapped in (torch_error(fx_failure, IndexError(oob)), torch_error(fx_failure)):
        assert comparator(plain_index_error, wrapped)
        assert comparator(wrapped, plain_index_error)

    # A wrapped IndexError does not match unrelated plain exception types.
    value_error = ValueError("some value error")
    wrapped_index = torch_error("got IndexError('some error')", IndexError("some error"))
    assert not comparator(value_error, wrapped_index)
    assert not comparator(wrapped_index, value_error)
    assert not comparator(
        TypeError("some type error"),
        torch_error("got IndexError('index error')", IndexError("index error")),
    )

    # Two wrappers around the same inner exception type are equivalent.
    assert comparator(
        torch_error("got IndexError('error 1')", IndexError("error 1")),
        torch_error("got IndexError('error 2')", IndexError("error 2")),
    )

    # Ordinary same-type comparison is unaffected.
    assert comparator(IndexError("same error"), IndexError("same error"))

    # A wrapped exception nested inside a return-value tuple.
    original_return = (("tensor1", "tensor2"), {}, IndexError(oob))
    compiled_return = (
        ("tensor1", "tensor2"),
        {},
        torch_error(
            "Dynamo failed: got IndexError('index 0 is out of bounds for dimension 0 with size 0')",
            IndexError(oob),
        ),
    )
    assert comparator(original_return, compiled_return)
|
|
|
|
|
|
def test_extract_exception_from_message():
    """_extract_exception_from_message turns ``got Xxx('...')`` text into a builtin exception instance."""
    recognised = [
        ("got IndexError('some error message')", IndexError),  # single-quoted argument
        ('got ValueError("another error")', ValueError),  # double-quoted argument
        ("got TypeError('test')", TypeError),
        ("got KeyError('test')", KeyError),
        ("got RuntimeError('test')", RuntimeError),
        ("got AttributeError('test')", AttributeError),
        ("got ZeroDivisionError('test')", ZeroDivisionError),
    ]
    for message, expected_type in recognised:
        extracted = _extract_exception_from_message(message)
        assert extracted is not None
        assert isinstance(extracted, expected_type)

    unrecognised = [
        "This is a normal error message",  # no pattern at all
        "got SomeRandomClass('not an exception')",  # non-exception class name
        "got IndexError without quotes",  # partial match: no opening quote
        "",  # empty string
    ]
    for message in unrecognised:
        assert _extract_exception_from_message(message) is None

    # Message shaped like the ones torch.compile produces.
    extracted = _extract_exception_from_message(
        "Dynamo failed to run FX node with fake tensors: got IndexError('index 0 is out of bounds for dimension 0 with size 0')"
    )
    assert extracted is not None
    assert isinstance(extracted, IndexError)
|
|
|
|
|
|
def test_get_wrapped_exception():
    """_get_wrapped_exception unwraps via __cause__ first, then via the message pattern."""
    # Explicit chaining: the very same cause object is returned.
    cause = ValueError("inner error")
    chained = RuntimeError("outer error")
    chained.__cause__ = cause
    assert _get_wrapped_exception(chained) is cause

    # An exception with no cause and no pattern wraps nothing.
    assert _get_wrapped_exception(ValueError("plain error")) is None

    # Without a cause, the type named in "got Xxx('...')" is reconstructed.
    from_message = _get_wrapped_exception(RuntimeError("got TypeError('some type error')"))
    assert from_message is not None
    assert isinstance(from_message, TypeError)

    # When both are present, __cause__ takes precedence over the message text.
    real_cause = IndexError("actual cause")
    ambiguous = RuntimeError("got TypeError('different error in message')")
    ambiguous.__cause__ = real_cause
    unwrapped = _get_wrapped_exception(ambiguous)
    assert unwrapped is real_cause
    assert isinstance(unwrapped, IndexError)
|
|
|
|
|
|
@pytest.mark.skipif(sys.version_info < (3, 11), reason="ExceptionGroup requires Python 3.11+")
def test_get_wrapped_exception_exception_group():
    """_get_wrapped_exception unwraps an ExceptionGroup only when it holds one exception (3.11+)."""
    lone = ValueError("single inner error")
    assert _get_wrapped_exception(ExceptionGroup("group", [lone])) is lone

    # With several members there is no single wrapped exception to return.
    several = ExceptionGroup("multi group", [ValueError("error 1"), TypeError("error 2")])
    assert _get_wrapped_exception(several) is None
|
|
|
|
|
|
@pytest.mark.skipif(sys.version_info < (3, 11), reason="ExceptionGroup requires Python 3.11+")
def test_comparator_with_exception_group():
    """A single-member ExceptionGroup compares as its member's type (Python 3.11+)."""
    value_group = ExceptionGroup("group", [ValueError("some value error")])
    plain_value = ValueError("different message but same type")

    # Matches a plain exception of the wrapped type, in both directions.
    assert comparator(value_group, plain_value)
    assert comparator(plain_value, value_group)

    # A group wrapping a different type does not match.
    assert not comparator(ExceptionGroup("group", [TypeError("type error")]), plain_value)

    # Two groups wrapping the same type match even with different group messages.
    assert comparator(
        ExceptionGroup("group1", [ValueError("error 1")]),
        ExceptionGroup("group2", [ValueError("error 2")]),
    )
|
|
|
|
|
|
def test_comparator_with_cause_chaining():
    """An exception carrying an explicit __cause__ compares as its cause's type."""

    def chain(outer, inner):
        # Attach `inner` as the explicit __cause__ of `outer`, then return `outer`.
        outer.__cause__ = inner
        return outer

    runtime_over_index = chain(RuntimeError("outer runtime error"), IndexError("inner index error"))
    index_error = IndexError("different index error")

    # The chained RuntimeError matches a plain IndexError both ways...
    assert comparator(runtime_over_index, index_error)
    assert comparator(index_error, runtime_over_index)
    # ...but not an unrelated type.
    assert not comparator(runtime_over_index, TypeError("type error"))

    # Same wrapper type on both sides: they are compared as same-type exceptions,
    # which (per the comparator's same-type rule) look at non-private attributes
    # only, so __cause__ is ignored.
    assert comparator(
        chain(RuntimeError("outer 1"), ValueError("inner 1")),
        chain(RuntimeError("outer 2"), ValueError("inner 2")),
    )

    class WrapperA(Exception):
        pass

    class WrapperB(Exception):
        pass

    # Different wrapper types: unwrapping leaves only the inner types to compare.
    assert comparator(
        chain(WrapperA("wrapper a"), KeyError("same inner type")),
        chain(WrapperB("wrapper b"), KeyError("same inner type")),
    )
    assert not comparator(
        chain(WrapperA("wrapper c"), ValueError("value error")),
        chain(WrapperB("wrapper d"), TypeError("type error")),
    )
|
|
|
|
|
|
def test_comparator_with_message_pattern():
    """A type named in ``got Xxx('...')`` (with no __cause__) drives the comparison."""
    wrapper = RuntimeError("Operation failed: got IndexError('list index out of range')")
    index_error = IndexError("some index error")

    assert comparator(wrapper, index_error)
    assert comparator(index_error, wrapper)

    # The extracted type must still agree with the other side.
    assert not comparator(wrapper, KeyError("some key error"))
|
|
|
|
|
|
def test_comparator_wrapped_exceptions_bidirectional():
    """Wrapped-exception matching is symmetric, with and without superset_obj."""

    class CustomWrapper(Exception):
        pass

    wrapper = CustomWrapper("wrapper message")
    wrapper.__cause__ = AttributeError("attr error")
    plain = AttributeError("plain attr error")

    # First with default arguments, then with superset_obj=True.
    for extra_kwargs in ({}, {"superset_obj": True}):
        assert comparator(wrapper, plain, **extra_kwargs)
        assert comparator(plain, wrapper, **extra_kwargs)
|
|
|
|
|
|
def test_comparator_same_type_exceptions_still_work():
    """Same-type exceptions ignore the message but still compare custom attributes."""
    assert comparator(ValueError("message 1"), ValueError("message 2"))

    class CustomError(Exception):
        def __init__(self, msg, code):
            super().__init__(msg)
            self.code = code

    reference = CustomError("msg1", 100)
    # Different message, same `code`: equal.
    assert comparator(reference, CustomError("msg2", 100))
    # Different `code`: not equal.
    assert not comparator(reference, CustomError("msg3", 200))
|
|
|
|
|
|
def test_comparator_no_false_positives_for_wrapped_exceptions():
    """The unwrapping logic must not make unrelated exception types compare equal."""
    value_error = ValueError("value error")
    assert not comparator(value_error, TypeError("type error"))

    # A RuntimeError wrapping a KeyError is unrelated to a ValueError, in either order.
    key_wrapper = RuntimeError("some error")
    key_wrapper.__cause__ = KeyError("key error")
    for left, right in ((key_wrapper, value_error), (value_error, key_wrapper)):
        assert not comparator(left, right)
|
|
|
|
|
|
def test_collections() -> None:
    """comparator handles each ``collections`` container and rejects look-alike types."""

    def check(reference, *, equal=(), unequal=()):
        # Assert `reference` matches every item of `equal` and none of `unequal`, in order.
        for other in equal:
            assert comparator(reference, other)
        for other in unequal:
            assert not comparator(reference, other)

    # deque: contents matter, maxlen does not, and a list is not a deque.
    numbers = deque([1, 2, 3])
    check(
        numbers,
        equal=[deque([1, 2, 3]), deque([1, 2, 3], maxlen=5)],
        unequal=[deque([1, 2, 4]), deque([1, 2]), [1, 2, 3]],
    )
    check(
        deque([{"a": 1}, {"b": 2}]),
        equal=[deque([{"a": 1}, {"b": 2}])],
        unequal=[deque([{"a": 1}, {"b": 3}])],
    )
    check(deque(), equal=[deque()], unequal=[numbers])

    # namedtuple: the namedtuple class matters, not only the field values.
    Point = namedtuple("Point", ["x", "y"])
    Point2 = namedtuple("Point2", ["x", "y"])
    check(
        Point(x=1, y=2),
        equal=[Point(x=1, y=2)],
        unequal=[Point(x=1, y=3), Point2(x=1, y=2), (1, 2)],
    )

    # ChainMap: the order of the underlying maps matters; a flattened dict does not match.
    first = {"a": 1, "b": 2}
    second = {"c": 3, "d": 4}
    check(
        ChainMap(first, second),
        equal=[ChainMap(first, second)],
        unequal=[ChainMap(second, first), {"a": 1, "b": 2, "c": 3, "d": 4}],
    )

    # Counter: iterable and mapping construction agree; a plain dict with the same counts does not match.
    check(
        Counter(["a", "b", "a", "c", "b", "a"]),
        equal=[Counter({"a": 3, "b": 2, "c": 1})],
        unequal=[Counter({"a": 3, "b": 2, "c": 2}), {"a": 3, "b": 2, "c": 1}],
    )

    # OrderedDict: insertion order matters.
    check(
        OrderedDict([("a", 1), ("b", 2)]),
        equal=[OrderedDict([("a", 1), ("b", 2)])],
        unequal=[OrderedDict([("b", 2), ("a", 1)]), {"a": 1, "b": 2}],
    )

    # defaultdict: default_factory is ignored; contents and type are not.
    check(
        defaultdict(int, {"a": 1, "b": 2}),
        equal=[defaultdict(int, {"a": 1, "b": 2}), defaultdict(list, {"a": 1, "b": 2})],
        unequal=[{"a": 1, "b": 2}, defaultdict(int, {"a": 1, "b": 3})],
    )

    # User* wrappers do not match the builtin they wrap.
    check(
        UserDict({"a": 1, "b": 2}),
        equal=[UserDict({"a": 1, "b": 2})],
        unequal=[UserDict({"a": 1, "b": 3}), {"a": 1, "b": 2}],
    )
    check(UserList([1, 2, 3]), equal=[UserList([1, 2, 3])], unequal=[UserList([1, 2, 4]), [1, 2, 3]])
    check(UserString("hello"), equal=[UserString("hello")], unequal=[UserString("world"), "hello"])
|
|
|
|
|
|
def test_attrs():
    """comparator compares attrs classes field by field.

    Fields declared with ``eq=False`` are ignored, nested attrs instances are
    compared recursively, and classes with different field sets never match.
    Skipped when the optional ``attrs`` package is not installed.
    """
    try:
        import attrs  # type: ignore
    except ImportError:
        # Give the skip a reason, matching the file's other optional-dependency tests.
        pytest.skip("attrs required for this test")

    @attrs.define
    class Person:
        name: str
        age: int = 10

    a = Person("Alice", 25)
    b = Person("Alice", 25)
    c = Person("Bob", 25)
    d = Person("Alice", 30)
    assert comparator(a, b)
    assert not comparator(a, c)
    assert not comparator(a, d)

    @attrs.frozen
    class Point:
        x: int
        y: int

    p1 = Point(1, 2)
    p2 = Point(1, 2)
    p3 = Point(2, 3)
    assert comparator(p1, p2)
    assert not comparator(p1, p3)

    @attrs.define(slots=True)
    class Vehicle:
        brand: str
        model: str
        year: int = 2020

    v1 = Vehicle("Toyota", "Camry", 2021)
    v2 = Vehicle("Toyota", "Camry", 2021)
    v3 = Vehicle("Honda", "Civic", 2021)
    assert comparator(v1, v2)
    assert not comparator(v1, v3)

    @attrs.define
    class ComplexClass:
        public_field: str
        private_field: str = attrs.field(repr=False)
        non_eq_field: int = attrs.field(eq=False, default=0)
        computed: str = attrs.field(init=False, eq=True)

        def __attrs_post_init__(self):
            self.computed = f"{self.public_field}_{self.private_field}"

    c1 = ComplexClass("test", "secret")
    c2 = ComplexClass("test", "secret")
    c3 = ComplexClass("different", "secret")

    # non_eq_field is declared eq=False, so differing values must not matter.
    c1.non_eq_field = 100
    c2.non_eq_field = 200

    assert comparator(c1, c2)
    assert not comparator(c1, c3)

    @attrs.define
    class Address:
        street: str
        city: str

    @attrs.define
    class PersonWithAddress:
        name: str
        address: Address

    addr1 = Address("123 Main St", "Anytown")
    addr2 = Address("123 Main St", "Anytown")
    addr3 = Address("456 Oak Ave", "Anytown")

    person1 = PersonWithAddress("John", addr1)
    person2 = PersonWithAddress("John", addr2)
    person3 = PersonWithAddress("John", addr3)

    # Nested attrs instances are compared field by field as well.
    assert comparator(person1, person2)
    assert not comparator(person1, person3)

    @attrs.define
    class Container:
        items: list
        metadata: dict

    cont1 = Container([1, 2, 3], {"type": "numbers"})
    cont2 = Container([1, 2, 3], {"type": "numbers"})
    cont3 = Container([1, 2, 4], {"type": "numbers"})

    assert comparator(cont1, cont2)
    assert not comparator(cont1, cont3)

    @attrs.define
    class BaseClass:
        name: str
        value: int

    @attrs.define
    class ExtendedClass:
        name: str
        value: int
        extra_field: str = "default"

    base = BaseClass("test", 42)
    extended = ExtendedClass("test", 42, "extra")

    # Shared fields agree, but the classes (and their field sets) differ.
    assert not comparator(base, extended)

    @attrs.define
    class WithNonEqFields:
        name: str
        timestamp: float = attrs.field(eq=False)  # Should be ignored
        debug_info: str = attrs.field(eq=False, default="debug")

    obj1 = WithNonEqFields("test", 1000.0, "info1")
    obj2 = WithNonEqFields("test", 9999.0, "info2")  # Different non-eq fields
    obj3 = WithNonEqFields("different", 1000.0, "info1")

    assert comparator(obj1, obj2)  # Should be equal despite different timestamp/debug_info
    assert not comparator(obj1, obj3)  # Should be different due to name

    @attrs.define
    class MinimalClass:
        name: str
        value: int

    # Distinct name so it does not shadow the ExtendedClass defined above.
    @attrs.define
    class ExtendedClassWithMetadata:
        name: str
        value: int
        extra_field: str = "default"
        metadata: dict = attrs.field(factory=dict)
        timestamp: float = attrs.field(eq=False, default=0.0)  # This should be ignored

    minimal = MinimalClass("test", 42)
    extended_with_metadata = ExtendedClassWithMetadata("test", 42, "extra", {"key": "value"}, 1000.0)

    assert not comparator(minimal, extended_with_metadata)
|
|
|
|
|
|
def test_dict_views() -> None:
    """comparator handles dict_keys, dict_values and dict_items views."""
    base = {"a": 1, "b": 2, "c": 3}

    # keys(): same keys match; a different or missing key does not.
    assert comparator(base.keys(), {"a": 1, "b": 2, "c": 3}.keys())
    for other in ({"a": 1, "b": 2, "d": 3}, {"a": 1, "b": 2}):
        assert not comparator(base.keys(), other.keys())

    # values(): only the values are compared, not the keys holding them.
    assert comparator(base.values(), {"x": 1, "y": 2, "z": 3}.values())
    for other in ({"a": 1, "b": 2, "c": 4}, {"a": 1, "b": 2}):
        assert not comparator(base.values(), other.values())

    # items(): key/value pairs must agree, but insertion order does not matter.
    assert comparator(base.items(), {"a": 1, "b": 2, "c": 3}.items())
    for other in ({"a": 1, "b": 2, "c": 4}, {"a": 1, "b": 2, "d": 3}, {"a": 1, "b": 2}):
        assert not comparator(base.items(), other.items())
    assert comparator(base.items(), {"b": 2, "c": 3, "a": 1}.items())

    # Empty views of every kind match.
    empty_left, empty_right = {}, {}
    assert comparator(empty_left.keys(), empty_right.keys())
    assert comparator(empty_left.values(), empty_right.values())
    assert comparator(empty_left.items(), empty_right.items())

    # Nested containers inside values()/items() are compared deeply.
    nested = {"a": [1, 2, 3], "b": {"x": 1}}
    nested_same = {"a": [1, 2, 3], "b": {"x": 1}}
    nested_diff = {"a": [1, 2, 4], "b": {"x": 1}}
    for view in (dict.values, dict.items):
        assert comparator(view(nested), view(nested_same))
        assert not comparator(view(nested), view(nested_diff))

    # A view does not match a list or set with the same contents.
    small = {"a": 1, "b": 2}
    assert not comparator(small.keys(), ["a", "b"])
    assert not comparator(small.keys(), {"a", "b"})
    assert not comparator(small.values(), [1, 2])
    assert not comparator(small.items(), [("a", 1), ("b", 2)])
|
|
|
|
|
|
def test_mappingproxy() -> None:
    """comparator compares types.MappingProxyType like a dict, but never equal to a plain dict."""
    import types

    proxy = types.MappingProxyType
    reference = proxy({"a": 1, "b": 2, "c": 3})

    # Same contents match, regardless of key order.
    for same in (proxy({"a": 1, "b": 2, "c": 3}), proxy({"c": 3, "a": 1, "b": 2})):
        assert comparator(reference, same)

    # A different value, a different key, or a missing key breaks equality.
    for different in (
        proxy({"a": 1, "b": 2, "c": 4}),
        proxy({"a": 1, "b": 2, "d": 3}),
        proxy({"a": 1, "b": 2}),
    ):
        assert not comparator(reference, different)

    assert comparator(proxy({}), proxy({}))

    # Nested values are compared deeply.
    nested = proxy({"a": [1, 2, 3], "b": {"x": 1}})
    assert comparator(nested, proxy({"a": [1, 2, 3], "b": {"x": 1}}))
    assert not comparator(nested, proxy({"a": [1, 2, 4], "b": {"x": 1}}))

    # Type matters: a proxy and a dict with identical contents differ, in both orders.
    plain = {"a": 1, "b": 2}
    wrapped = proxy({"a": 1, "b": 2})
    assert not comparator(wrapped, plain)
    assert not comparator(plain, wrapped)

    # Sanity check: a class namespace is exposed as a mappingproxy.
    class MyClass:
        x = 1
        y = 2

    assert isinstance(MyClass.__dict__, types.MappingProxyType)
|
|
|
|
|
|
def test_mappingproxy_superset() -> None:
    """With superset_obj=True, the second mappingproxy may carry extra keys."""
    from types import MappingProxyType

    subset = MappingProxyType({"a": 1, "b": 2})
    superset = MappingProxyType({"a": 1, "b": 2, "c": 3})

    assert comparator(subset, superset, superset_obj=True)
    # Extra keys are tolerated only on the second argument.
    assert not comparator(superset, subset, superset_obj=True)
    assert comparator(subset, subset, superset_obj=True)
    # Shared keys must still hold equal values.
    assert not comparator(subset, MappingProxyType({"a": 1, "b": 99, "c": 3}), superset_obj=True)
|
|
|
|
|
|
def test_tensorflow_tensor() -> None:
    """comparator compares tf.Tensor objects by value, shape and dtype."""
    try:
        import tensorflow as tf
    except ImportError:
        pytest.skip("tensorflow required for this test")

    const = tf.constant
    nan = float("nan")
    inf = float("inf")

    # Each row: (reference, equal tensor, tensor differing in one value or in dtype).
    triples = [
        (const([1, 2, 3]), const([1, 2, 3]), const([1, 2, 4])),
        (const([[1, 2, 3], [4, 5, 6]]), const([[1, 2, 3], [4, 5, 6]]), const([[1, 2, 3], [4, 5, 7]])),
        (
            const([1, 2, 3], dtype=tf.float32),
            const([1, 2, 3], dtype=tf.float32),
            const([1, 2, 3], dtype=tf.int32),  # same values, different dtype
        ),
        (
            const([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]),
            const([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]),
            const([[[1, 2], [3, 4]], [[5, 6], [7, 9]]]),
        ),
        (const([]), const([]), const([1.0])),  # empty vs non-empty
        (const([1.0, nan, 3.0]), const([1.0, nan, 3.0]), const([1.0, 2.0, 3.0])),  # NaN must equal NaN
        (const([1.0, inf, 3.0]), const([1.0, inf, 3.0]), const([1.0, -inf, 3.0])),
        (const([1 + 2j, 3 + 4j]), const([1 + 2j, 3 + 4j]), const([1 + 2j, 3 + 5j])),
        (const([True, False, True]), const([True, False, True]), const([True, True, True])),
        (const(["hello", "world"]), const(["hello", "world"]), const(["hello", "there"])),
    ]
    for reference, same, different in triples:
        assert comparator(reference, same)
        assert not comparator(reference, different)

    # Identical values at a different rank do not match.
    assert not comparator(const([1, 2, 3]), const([[1, 2, 3]]))
|
|
|
|
|
|
def test_tensorflow_dtype() -> None:
    """comparator compares tf.DType objects by identity of the dtype."""
    try:
        import tensorflow as tf
    except ImportError:
        pytest.skip("tensorflow required for this test")

    # Each pair: (dtype, a different dtype it must not match).
    for dtype, other in (
        (tf.float32, tf.float64),
        (tf.int32, tf.int64),
        (tf.uint8, tf.uint16),
        (tf.complex64, tf.complex128),
        (tf.bool, tf.int8),
        (tf.string, tf.int32),
    ):
        assert comparator(dtype, dtype)
        assert not comparator(dtype, other)
|
|
|
|
|
|
def test_tensorflow_variable() -> None:
    """comparator compares tf.Variable objects by value, dtype and shape."""
    try:
        import tensorflow as tf
    except ImportError:
        pytest.skip("tensorflow required for this test")

    def variable(value, dtype=tf.float32):
        # All variables in this test are float32 unless stated otherwise.
        return tf.Variable(value, dtype=dtype)

    vector = variable([1, 2, 3])
    assert comparator(vector, variable([1, 2, 3]))
    assert not comparator(vector, variable([1, 2, 4]))

    # Same values, different dtype.
    assert not comparator(variable([1, 2, 3]), variable([1, 2, 3], tf.float64))

    matrix = variable([[1, 2], [3, 4]])
    assert comparator(matrix, variable([[1, 2], [3, 4]]))
    assert not comparator(matrix, variable([[1, 2], [3, 5]]))

    # Same values, different shape.
    assert not comparator(variable([1, 2, 3]), variable([[1, 2, 3]]))
|
|
|
|
|
|
def test_tensorflow_tensor_shape() -> None:
    """comparator compares tf.TensorShape objects, including unknown dimensions."""
    try:
        import tensorflow as tf
    except ImportError:
        pytest.skip("tensorflow required for this test")

    shape = tf.TensorShape

    # Each row: (reference dims, equal dims, different dims).
    for ref_dims, same_dims, other_dims in (
        ([2, 3, 4], [2, 3, 4], [2, 3, 5]),
        ([], [], [1]),  # scalar shape
        ([None, 3, 4], [None, 3, 4], [2, 3, 4]),  # unknown leading dimension
        (None, None, [1, 2]),  # fully unknown rank
    ):
        reference = shape(ref_dims)
        assert comparator(reference, shape(same_dims))
        assert not comparator(reference, shape(other_dims))

    # Different ranks never match.
    assert not comparator(shape([2, 3]), shape([2, 3, 4]))
|
|
|
|
|
|
def test_tensorflow_sparse_tensor() -> None:
    """comparator compares tf.SparseTensor indices, values and dense_shape."""
    try:
        import tensorflow as tf
    except ImportError:
        pytest.skip("tensorflow required for this test")

    def sparse(indices, values, dense_shape):
        return tf.SparseTensor(indices=indices, values=values, dense_shape=dense_shape)

    reference = sparse([[0, 0], [1, 2]], [1.0, 2.0], [3, 4])
    assert comparator(reference, sparse([[0, 0], [1, 2]], [1.0, 2.0], [3, 4]))

    for different in (
        sparse([[0, 0], [1, 2]], [1.0, 3.0], [3, 4]),  # Different value
        sparse([[0, 0], [1, 3]], [1.0, 2.0], [3, 4]),  # Different index
        sparse([[0, 0], [1, 2]], [1.0, 2.0], [4, 5]),  # Different shape
    ):
        assert not comparator(reference, different)

    # Sparse tensors with no stored entries.
    assert comparator(
        sparse(tf.zeros([0, 2], dtype=tf.int64), [], [3, 4]),
        sparse(tf.zeros([0, 2], dtype=tf.int64), [], [3, 4]),
    )
|
|
|
|
|
|
def test_tensorflow_ragged_tensor() -> None:
    """comparator compares tf.RaggedTensor values, row structure and dtype."""
    try:
        import tensorflow as tf
    except ImportError:
        pytest.skip("tensorflow required for this test")

    ragged = tf.ragged.constant

    ints = ragged([[1, 2], [3, 4, 5], [6]])
    assert comparator(ints, ragged([[1, 2], [3, 4, 5], [6]]))
    assert not comparator(ints, ragged([[1, 2], [3, 4, 6], [6]]))  # one value differs
    assert not comparator(ints, ragged([[1, 2, 3], [4, 5], [6]]))  # same values, different row lengths

    floats = ragged([[1.0, 2.0], [3.0, 4.0, 5.0], [6.0]])
    assert comparator(floats, ragged([[1.0, 2.0], [3.0, 4.0, 5.0], [6.0]]))
    assert not comparator(ints, floats)  # int vs float

    nested = ragged([[[1, 2], [3]], [[4, 5, 6]]])
    assert comparator(nested, ragged([[[1, 2], [3]], [[4, 5, 6]]]))
    assert not comparator(nested, ragged([[[1, 2], [3]], [[4, 5, 7]]]))

    # Rows that are all empty.
    assert comparator(ragged([[], [], []]), ragged([[], [], []]))
|
|
|
|
|
|
def test_slice() -> None:
    """Check comparator equality semantics for built-in ``slice`` objects."""
    reference = slice(1, 10, 2)

    # Identical start/stop/step.
    assert comparator(reference, slice(1, 10, 2))

    # Changing any one of start, stop or step breaks equality.
    for changed in (slice(2, 10, 2), slice(1, 11, 2), slice(1, 10, 3)):
        assert not comparator(reference, changed)

    # None as the start bound.
    open_start = slice(None, 10, 2)
    assert comparator(open_start, slice(None, 10, 2))
    assert not comparator(open_start, slice(1, 10, 2))

    # The fully open slice ([:]) differs from one with an explicit step of 1.
    full = slice(None, None, None)
    assert comparator(full, slice(None, None, None))
    assert not comparator(full, slice(None, None, 1))

    # Stop-only constructor form.
    stop_only = slice(5)
    assert comparator(stop_only, slice(5))
    assert not comparator(stop_only, slice(6))

    # Negative bounds.
    negative = slice(-5, -1, 1)
    assert comparator(negative, slice(-5, -1, 1))
    assert not comparator(negative, slice(-5, -2, 1))

    # A slice never equals a tuple carrying the same numbers.
    assert not comparator(slice(1, 10), (1, 10))
|
|
|
|
|
|
def test_numpy_datetime64() -> None:
    """Test comparator support for numpy datetime64 and timedelta64 types.

    Equal instants or durations written in *different* units (e.g. days vs
    seconds, 1 hour vs 60 minutes) are deliberately not asserted. As noted
    when this test was written, their result may vary with the numpy version.
    """
    try:
        import numpy as np
    except ImportError:
        pytest.skip("numpy required for this test")

    # Test datetime64 equality
    a = np.datetime64("2021-01-01")
    b = np.datetime64("2021-01-01")
    c = np.datetime64("2021-01-02")
    assert comparator(a, b)
    assert not comparator(a, c)

    # Test datetime64 with an explicit unit
    d = np.datetime64("2021-01-01", "D")
    e = np.datetime64("2021-01-01", "D")
    assert comparator(d, e)

    # Test datetime64 with time
    g = np.datetime64("2021-01-01T12:00:00")
    h = np.datetime64("2021-01-01T12:00:00")
    i = np.datetime64("2021-01-01T12:00:01")
    assert comparator(g, h)
    assert not comparator(g, i)

    # Test timedelta64 equality
    one_day = np.timedelta64(1, "D")
    one_day_again = np.timedelta64(1, "D")
    two_days = np.timedelta64(2, "D")
    assert comparator(one_day, one_day_again)
    assert not comparator(one_day, two_days)

    # Test timedelta64 with a non-day unit
    m = np.timedelta64(1, "h")
    n = np.timedelta64(1, "h")
    assert comparator(m, n)

    # Test NaT (Not a Time) - numpy's equivalent of NaN for datetime
    p = np.datetime64("NaT")
    q = np.datetime64("NaT")
    r = np.datetime64("2021-01-01")
    assert comparator(p, q)  # NaT == NaT should be True
    assert not comparator(p, r)

    # Test timedelta64 NaT
    s = np.timedelta64("NaT")
    t = np.timedelta64("NaT")
    u = np.timedelta64(1, "D")
    assert comparator(s, t)  # NaT == NaT should be True
    assert not comparator(s, u)

    # Test datetime64 is not equal to other types
    v = np.datetime64("2021-01-01")
    w = "2021-01-01"
    assert not comparator(v, w)

    # Test arrays of datetime64
    x = np.array(["2021-01-01", "2021-01-02"], dtype="datetime64")
    y = np.array(["2021-01-01", "2021-01-02"], dtype="datetime64")
    z = np.array(["2021-01-01", "2021-01-03"], dtype="datetime64")
    assert comparator(x, y)
    assert not comparator(x, z)
|
|
|
|
|
|
def test_numpy_0d_array() -> None:
    """Check comparator on 0-d numpy arrays, which cannot be iterated.

    Guards against the 'iteration over 0-d array' error for several dtypes,
    and checks that 0-d arrays differ from bare scalars and from 1-element
    1-d arrays.
    """
    try:
        import numpy as np
    except ImportError:
        pytest.skip("numpy required for this test")

    # (value, different_value) pairs, one per dtype family.
    cases = [
        (5, 6),  # integer
        (3.14, 2.71),  # float
        (1 + 2j, 1 + 3j),  # complex
        ("hello", "world"),  # string
        (True, False),  # boolean
        (np.nan, 1.0),  # NaN must compare equal to NaN
        (np.datetime64("2021-01-01"), np.datetime64("2021-01-02")),  # datetime64
    ]
    for value, other in cases:
        assert comparator(np.array(value), np.array(value))
        assert not comparator(np.array(value), np.array(other))

    # A 0-d array is a different type from the plain Python scalar.
    assert not comparator(np.array(5), 5)

    # Shape () is not shape (1,).
    assert not comparator(np.array(5), np.array([5]))
|
|
|
|
|
|
def test_numpy_dtypes() -> None:
    """Test comparator for numpy.dtypes types like Float64DType, Int64DType, etc.

    ``numpy.dtypes`` was added in numpy 1.25. The two imports are guarded
    separately so an older numpy is reported accurately instead of as
    "numpy not available".
    """
    try:
        import numpy as np
    except ImportError:
        pytest.skip("numpy not available")
    try:
        from numpy import dtypes
    except ImportError:
        # numpy is installed but predates the numpy.dtypes module.
        pytest.skip("numpy.dtypes not available (requires numpy >= 1.25)")

    # Test Float64DType
    a = dtypes.Float64DType()
    b = dtypes.Float64DType()
    assert comparator(a, b)

    # Test Int64DType
    c = dtypes.Int64DType()
    d = dtypes.Int64DType()
    assert comparator(c, d)

    # Test different DType classes should not be equal
    assert not comparator(a, c)  # Float64DType vs Int64DType

    # Test various numeric DType classes: each instance equals a fresh instance
    for dtype_cls in (
        dtypes.Int8DType,
        dtypes.Int16DType,
        dtypes.Int32DType,
        dtypes.UInt8DType,
        dtypes.UInt16DType,
        dtypes.UInt32DType,
        dtypes.UInt64DType,
        dtypes.Float32DType,
        dtypes.Complex64DType,
        dtypes.Complex128DType,
        dtypes.BoolDType,
    ):
        assert comparator(dtype_cls(), dtype_cls())

    # Test cross-type comparisons should be False
    assert not comparator(dtypes.Int32DType(), dtypes.Int64DType())
    assert not comparator(dtypes.Float32DType(), dtypes.Float64DType())
    assert not comparator(dtypes.UInt32DType(), dtypes.Int32DType())

    # Test regular np.dtype instances
    e = np.dtype("float64")
    f = np.dtype("float64")
    assert comparator(e, f)

    g = np.dtype("int64")
    h = np.dtype("int64")
    assert comparator(g, h)

    assert not comparator(e, g)  # float64 vs int64

    # Test DType class instances vs regular np.dtype (equal if same underlying type)
    assert comparator(dtypes.Float64DType(), np.dtype("float64"))
    assert comparator(dtypes.Int64DType(), np.dtype("int64"))
    assert comparator(dtypes.Int32DType(), np.dtype("int32"))
    assert comparator(dtypes.BoolDType(), np.dtype("bool"))

    # Test that DType and np.dtype of different types are not equal
    assert not comparator(dtypes.Float64DType(), np.dtype("int64"))
    assert not comparator(dtypes.Int32DType(), np.dtype("float32"))
|
|
|
|
|
|
def test_numpy_extended_precision_types() -> None:
    """Check comparator on numpy extended-precision scalars (clongdouble, longdouble)."""
    try:
        import numpy as np
    except ImportError:
        pytest.skip("numpy not available")

    # (constructor, value, different value): equal to itself, not to the other.
    for make, value, other in (
        (np.clongdouble, 1 + 2j, 1 + 3j),  # extended-precision complex
        (np.longdouble, 1.5, 2.5),  # extended-precision float
    ):
        assert comparator(make(value), make(value))
        assert not comparator(make(value), make(other))

    # NaN components must still compare equal.
    assert comparator(np.clongdouble(complex(np.nan, 2)), np.clongdouble(complex(np.nan, 2)))
    assert comparator(np.longdouble(np.nan), np.longdouble(np.nan))
|
|
|
|
|
|
def test_numpy_typing_types() -> None:
    """Test comparator for numpy.typing types like NDArray type aliases."""
    try:
        import numpy as np
        import numpy.typing as npt
    except ImportError:
        pytest.skip("numpy or numpy.typing not available")

    # Test NDArray type alias comparisons
    arr_type1 = npt.NDArray[np.float64]
    arr_type2 = npt.NDArray[np.float64]
    arr_type3 = npt.NDArray[np.int64]
    assert comparator(arr_type1, arr_type2)
    assert not comparator(arr_type1, arr_type3)

    # Test NBitBase instances, if the class can be instantiated.
    # Only construction is guarded. A TypeError raised by comparator itself
    # must fail the test rather than be mistaken for "not instantiable".
    try:
        nbit1 = npt.NBitBase()
        nbit2 = npt.NBitBase()
    except TypeError:
        # NBitBase may not be instantiable in all numpy versions
        pass
    else:
        # NBitBase instances with empty __dict__ should compare as equal
        assert comparator(nbit1, nbit2)
        # Also test with superset_obj=True
        assert comparator(nbit1, nbit2, superset_obj=True)
|
|
|
|
|
|
def test_numpy_typing_superset_obj() -> None:
    """Exercise ``comparator(..., superset_obj=True)`` on assorted numpy objects."""
    try:
        import numpy as np
        import numpy.typing as npt
    except ImportError:
        pytest.skip("numpy or numpy.typing not available")

    # Object arrays holding dicts: extra keys in the new dict pass only in superset mode.
    narrow = np.array([{"a": 1}], dtype=object)
    wide = np.array([{"a": 1, "b": 2}], dtype=object)
    assert comparator(narrow, wide, superset_obj=True)
    assert not comparator(narrow, wide, superset_obj=False)

    # Extended-precision scalars.
    assert comparator(np.clongdouble(1 + 2j), np.clongdouble(1 + 2j), superset_obj=True)
    assert comparator(np.longdouble(1.5), np.longdouble(1.5), superset_obj=True)

    # NDArray generic aliases.
    assert comparator(npt.NDArray[np.float64], npt.NDArray[np.float64], superset_obj=True)

    # Structured-array records (np.void elements).
    record_dtype = np.dtype([("name", "S10"), ("age", np.int32)])
    left_records = np.array([("Alice", 25)], dtype=record_dtype)
    right_records = np.array([("Alice", 25)], dtype=record_dtype)
    assert comparator(left_records[0], right_records[0], superset_obj=True)

    # Identically seeded random generators, new-style and legacy.
    assert comparator(np.random.default_rng(seed=42), np.random.default_rng(seed=42), superset_obj=True)
    assert comparator(np.random.RandomState(seed=42), np.random.RandomState(seed=42), superset_obj=True)
|
|
|
|
|
|
def test_numba_typed_list() -> None:
    """Check comparator on ``numba.typed.List`` values (ints, floats, empty)."""
    try:
        import numba
        from numba.typed import List as NumbaList
    except ImportError:
        pytest.skip("numba not available")

    ints = NumbaList([1, 2, 3])
    assert comparator(ints, NumbaList([1, 2, 3]))
    # One element changed.
    assert not comparator(ints, NumbaList([1, 2, 4]))
    # One extra trailing element.
    assert not comparator(ints, NumbaList([1, 2, 3, 4]))

    # Two empty typed lists with the same item type.
    assert comparator(
        NumbaList.empty_list(item_type=numba.int64),
        NumbaList.empty_list(item_type=numba.int64),
    )

    floats = NumbaList([1.0, 2.0, 3.0])
    assert comparator(floats, NumbaList([1.0, 2.0, 3.0]))
    assert not comparator(floats, NumbaList([1.0, 2.0, 4.0]))
|
|
|
|
|
|
def test_numba_typed_dict() -> None:
    """Check comparator on ``numba.typed.Dict`` with unicode keys and int64 values."""
    try:
        import numba
        from numba.typed import Dict as NumbaDict
    except ImportError:
        pytest.skip("numba not available")

    def build(**entries):
        """Return a new unicode->int64 typed dict filled with *entries*, in order."""
        typed = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
        for key, value in entries.items():
            typed[key] = value
        return typed

    reference = build(x=1, y=2)
    assert comparator(reference, build(x=1, y=2))
    assert not comparator(reference, build(x=1, y=3))  # different value
    assert not comparator(reference, build(x=1, z=2))  # different key
    assert not comparator(reference, build(x=1))  # fewer entries
    assert comparator(build(), build())  # both empty
|
|
|
|
|
|
def test_numba_types() -> None:
    """Check comparator on numba type descriptors.

    Covers scalars, special types, arrays, tuples, and typed containers.
    """
    try:
        import numba
        from numba import types
    except ImportError:
        pytest.skip("numba not available")

    # Scalar numeric types, via both the numba and numba.types namespaces, equal themselves.
    for scalar in (
        numba.int64,
        numba.float64,
        numba.int32,
        numba.float32,
        types.int64,
        types.float64,
        types.int8,
        types.int16,
        types.uint8,
        types.uint16,
        types.uint32,
        types.uint64,
        types.complex64,
        types.complex128,
    ):
        assert comparator(scalar, scalar)

    # Distinct scalar types never match.
    for left, right in (
        (numba.int64, numba.float64),
        (numba.int32, numba.int64),
        (numba.float32, numba.float64),
        (types.int8, types.int16),
        (types.uint32, types.int32),
        (types.complex64, types.complex128),
    ):
        assert not comparator(left, right)

    # Boolean type.
    assert comparator(numba.boolean, numba.boolean)
    assert comparator(types.boolean, types.boolean)
    assert not comparator(numba.boolean, numba.int64)

    # Special types.
    for special in (types.none, types.void, types.pyobject, types.unicode_type):
        assert comparator(special, special)
    # types.none and types.void are the same object in numba, so they match.
    assert comparator(types.none, types.void)
    assert not comparator(types.unicode_type, types.pyobject)
    assert not comparator(types.none, types.int64)

    # Array types: ndim, dtype and layout all participate.
    c_vector = types.Array(numba.float64, 1, "C")
    assert comparator(c_vector, types.Array(numba.float64, 1, "C"))
    assert not comparator(c_vector, types.Array(numba.float64, 2, "C"))  # different ndim
    assert not comparator(c_vector, types.Array(numba.int64, 1, "C"))  # different dtype
    assert not comparator(c_vector, types.Array(numba.float64, 1, "F"))  # different layout

    # Homogeneous tuple types.
    triple = types.UniTuple(types.int64, 3)
    assert comparator(triple, types.UniTuple(types.int64, 3))
    assert not comparator(triple, types.UniTuple(types.int64, 4))  # different count
    assert not comparator(triple, types.UniTuple(types.float64, 3))  # different dtype

    # Heterogeneous tuple types.
    mixed = types.Tuple([types.int64, types.float64])
    assert comparator(mixed, types.Tuple([types.int64, types.float64]))
    assert not comparator(mixed, types.Tuple([types.int64, types.int64]))

    # Typed list container types.
    int_list = types.ListType(types.int64)
    assert comparator(int_list, types.ListType(types.int64))
    assert not comparator(int_list, types.ListType(types.float64))

    # Typed dict container types.
    str_to_int = types.DictType(types.unicode_type, types.int64)
    assert comparator(str_to_int, types.DictType(types.unicode_type, types.int64))
    assert not comparator(str_to_int, types.DictType(types.unicode_type, types.float64))  # value type
    assert not comparator(str_to_int, types.DictType(types.int64, types.int64))  # key type
|
|
|
|
|
|
def test_numba_jit_functions() -> None:
    """Check comparator on numba JIT-compiled functions.

    A dispatcher equals itself. Separately defined dispatchers do not match,
    even when their source is identical.
    """
    try:
        from numba import jit
    except ImportError:
        pytest.skip("numba not available")

    @jit(nopython=True)
    def add(x, y):
        return x + y

    @jit(nopython=True)
    def add2(x, y):
        return x + y

    @jit(nopython=True)
    def multiply(x, y):
        return x * y

    # Call each once so it is compiled before comparison.
    for compiled in (add, add2, multiply):
        compiled(1, 2)

    assert comparator(add, add)
    # Same code, distinct function objects: not equal.
    assert not comparator(add, add2)
    # Different code: not equal.
    assert not comparator(add, multiply)
|
|
|
|
|
|
def test_numba_superset_obj() -> None:
    """Check ``superset_obj`` handling for numba typed Dict and List."""
    try:
        import numba
        from numba.typed import Dict as NumbaDict
        from numba.typed import List as NumbaList
    except ImportError:
        pytest.skip("numba not available")

    def build(**entries):
        """Return a new unicode->int64 typed dict filled with *entries*, in order."""
        typed = NumbaDict.empty(key_type=numba.types.unicode_type, value_type=numba.int64)
        for key, value in entries.items():
            typed[key] = value
        return typed

    orig_dict = build(x=1, y=2)

    # Same keys and values.
    assert comparator(orig_dict, build(x=1, y=2), superset_obj=True)

    # Extra keys in the new dict are accepted only in superset mode.
    superset = build(x=1, y=2, z=3)
    assert comparator(orig_dict, superset, superset_obj=True)
    assert not comparator(orig_dict, superset, superset_obj=False)

    # Missing keys or changed values fail even in superset mode.
    assert not comparator(orig_dict, build(x=1), superset_obj=True)
    assert not comparator(orig_dict, build(x=1, y=99), superset_obj=True)

    # Lists get no superset leniency: lengths must match.
    orig_list = NumbaList([1, 2, 3])
    assert comparator(orig_list, NumbaList([1, 2, 3]), superset_obj=True)
    assert not comparator(orig_list, NumbaList([1, 2, 3, 4]), superset_obj=True)

    # An empty original matches any new dict in superset mode only.
    empty_orig = build()
    non_empty_new = build(a=1)
    assert comparator(empty_orig, non_empty_new, superset_obj=True)
    assert not comparator(empty_orig, non_empty_new, superset_obj=False)
|
|
|
|
|
|
# =============================================================================
|
|
# Tests for pytest temp path normalization (_is_temp_path / _normalize_temp_path in comparator.py)
|
|
# =============================================================================
|
|
|
|
|
|
class TestIsTempPath:
    """Unit tests for ``_is_temp_path``: which strings count as pytest temp paths."""

    def test_standard_pytest_temp_path(self) -> None:
        """Canonical ``/tmp/pytest-of-<user>/pytest-<n>/...`` paths are detected."""
        for candidate in (
            "/tmp/pytest-of-user/pytest-0/test_something",
            "/tmp/pytest-of-user/pytest-123/",
            "/tmp/pytest-of-admin/pytest-999/subdir/file.txt",
        ):
            assert _is_temp_path(candidate)

    def test_different_usernames(self) -> None:
        """Usernames with digits, underscores or dashes are accepted."""
        for candidate in (
            "/tmp/pytest-of-root/pytest-1/",
            "/tmp/pytest-of-john_doe/pytest-42/",
            "/tmp/pytest-of-user123/pytest-0/test",
            "/tmp/pytest-of-test-user/pytest-5/data",
        ):
            assert _is_temp_path(candidate)

    def test_different_session_numbers(self) -> None:
        """Any numeric session suffix is accepted."""
        for candidate in (
            "/tmp/pytest-of-user/pytest-0/",
            "/tmp/pytest-of-user/pytest-1/",
            "/tmp/pytest-of-user/pytest-99/",
            "/tmp/pytest-of-user/pytest-12345/",
        ):
            assert _is_temp_path(candidate)

    def test_paths_with_subdirectories(self) -> None:
        """Nested directories below the session dir are still detected."""
        for candidate in (
            "/tmp/pytest-of-user/pytest-0/test_func/subdir",
            "/tmp/pytest-of-user/pytest-0/a/b/c/d/file.txt",
            "/tmp/pytest-of-user/pytest-0/test_module0/test_file.py",
        ):
            assert _is_temp_path(candidate)

    def test_paths_with_filenames(self) -> None:
        """Files directly inside the session dir are detected."""
        for candidate in (
            "/tmp/pytest-of-user/pytest-0/output.json",
            "/tmp/pytest-of-user/pytest-0/test.log",
            "/tmp/pytest-of-user/pytest-0/data.csv",
        ):
            assert _is_temp_path(candidate)

    def test_non_temp_paths(self) -> None:
        """Ordinary absolute, relative and bare paths are rejected."""
        for candidate in (
            "/home/user/project/test.py",
            "/tmp/other/directory",
            "/var/log/test.log",
            "./relative/path",
            "test_file.py",
        ):
            assert not _is_temp_path(candidate)

    def test_similar_but_not_temp_paths(self) -> None:
        """Near-miss variants of the pattern are rejected."""
        for candidate in (
            "/tmp/pytest-user/pytest-0/",  # missing "of-"
            "/tmp/pytest-of-user/pytest-/",  # no number
            "/tmp/pytest-of-/pytest-0/",  # empty username
            "/tmp/pytest-of-user/pytest-abc/",  # non-numeric session
        ):
            assert not _is_temp_path(candidate)

    def test_edge_cases(self) -> None:
        """Empty, root and truncated prefixes are rejected."""
        for candidate in ("", "/", "/tmp/", "/tmp/pytest-of-"):
            assert not _is_temp_path(candidate)

    def test_path_embedded_in_string(self) -> None:
        """A temp path surrounded by other text is still found."""
        for candidate in (
            "Error in /tmp/pytest-of-user/pytest-0/test.py: failed",
            "File: /tmp/pytest-of-user/pytest-123/output.txt",
        ):
            assert _is_temp_path(candidate)

    def test_windows_style_paths(self) -> None:
        """Backslash-separated Windows paths are rejected."""
        for candidate in (
            "C:\\Users\\test\\pytest",
            "D:\\tmp\\pytest-of-user\\pytest-0\\",
        ):
            assert not _is_temp_path(candidate)
|
|
|
|
|
|
class TestNormalizeTempPath:
    """Unit tests for ``_normalize_temp_path``: rewriting user/session segments."""

    def test_basic_normalization(self) -> None:
        """The user and session segments collapse to ``/tmp/pytest-temp``."""
        expected = "/tmp/pytest-temp/test"
        for raw in ("/tmp/pytest-of-user/pytest-0/test", "/tmp/pytest-of-user/pytest-123/test"):
            assert _normalize_temp_path(raw) == expected

    def test_different_session_numbers_normalize_same(self) -> None:
        """Paths that differ only by session number normalize identically."""
        normalized = [
            _normalize_temp_path(raw)
            for raw in (
                "/tmp/pytest-of-user/pytest-0/file.txt",
                "/tmp/pytest-of-user/pytest-99/file.txt",
                "/tmp/pytest-of-user/pytest-12345/file.txt",
            )
        ]
        assert normalized == ["/tmp/pytest-temp/file.txt"] * 3

    def test_different_usernames_normalize_same(self) -> None:
        """Paths that differ only by username normalize identically."""
        normalized = [
            _normalize_temp_path(raw)
            for raw in (
                "/tmp/pytest-of-alice/pytest-0/file.txt",
                "/tmp/pytest-of-bob/pytest-0/file.txt",
                "/tmp/pytest-of-root/pytest-0/file.txt",
            )
        ]
        assert normalized == ["/tmp/pytest-temp/file.txt"] * 3

    def test_complex_subdirectories(self) -> None:
        """Everything after the session directory is kept verbatim."""
        raw = "/tmp/pytest-of-user/pytest-42/test_module/subdir/file.py"
        assert _normalize_temp_path(raw) == "/tmp/pytest-temp/test_module/subdir/file.py"

    def test_non_temp_path_unchanged(self) -> None:
        """Strings without a temp-path prefix are returned as-is."""
        untouched = "/home/user/project/test.py"
        assert _normalize_temp_path(untouched) == untouched

    def test_empty_string(self) -> None:
        """The empty string maps to itself."""
        assert _normalize_temp_path("") == ""

    def test_path_with_multiple_occurrences(self) -> None:
        """Every occurrence in the string is rewritten, as in error messages."""
        raw = "/tmp/pytest-of-user/pytest-0/ref to /tmp/pytest-of-user/pytest-1/other"
        assert _normalize_temp_path(raw) == "/tmp/pytest-temp/ref to /tmp/pytest-temp/other"

    def test_trailing_slash_handling(self) -> None:
        """Trailing slashes survive normalization."""
        expectations = {
            "/tmp/pytest-of-user/pytest-0/": "/tmp/pytest-temp/",
            "/tmp/pytest-of-user/pytest-0/subdir/": "/tmp/pytest-temp/subdir/",
        }
        for raw, expected in expectations.items():
            assert _normalize_temp_path(raw) == expected
|
|
|
|
|
|
class TestComparatorTempPaths:
    """comparator() on plain strings that may contain pytest temp paths."""

    def test_identical_temp_paths(self) -> None:
        """A temp path equals itself."""
        same = "/tmp/pytest-of-user/pytest-0/test.txt"
        assert comparator(same, same)

    def test_different_session_numbers(self) -> None:
        """Only the ``pytest-<n>`` segment differs, so the paths are equal."""
        assert comparator(
            "/tmp/pytest-of-user/pytest-0/output.txt",
            "/tmp/pytest-of-user/pytest-99/output.txt",
        )

    def test_different_usernames(self) -> None:
        """Only the ``pytest-of-<user>`` segment differs, so the paths are equal."""
        assert comparator(
            "/tmp/pytest-of-alice/pytest-0/result.json",
            "/tmp/pytest-of-bob/pytest-0/result.json",
        )

    def test_different_usernames_and_sessions(self) -> None:
        """Both the user and session segments differ, and the paths are still equal."""
        assert comparator(
            "/tmp/pytest-of-alice/pytest-10/data/file.csv",
            "/tmp/pytest-of-bob/pytest-999/data/file.csv",
        )

    def test_different_subdirectories_not_equal(self) -> None:
        """Differences below the session directory still matter."""
        assert not comparator(
            "/tmp/pytest-of-user/pytest-0/subdir1/file.txt",
            "/tmp/pytest-of-user/pytest-0/subdir2/file.txt",
        )

    def test_different_filenames_not_equal(self) -> None:
        """A different file name means the paths are not equal."""
        assert not comparator(
            "/tmp/pytest-of-user/pytest-0/file1.txt",
            "/tmp/pytest-of-user/pytest-0/file2.txt",
        )

    def test_temp_path_vs_non_temp_path(self) -> None:
        """A temp path never matches an ordinary path."""
        assert not comparator("/tmp/pytest-of-user/pytest-0/file.txt", "/home/user/file.txt")

    def test_regular_strings_still_work(self) -> None:
        """Non-path strings keep plain equality semantics."""
        assert comparator("hello", "hello")
        assert not comparator("hello", "world")
        assert comparator("", "")
        assert not comparator("test", "")

    def test_non_temp_paths_must_be_exact(self) -> None:
        """Ordinary paths are compared exactly."""
        project_file = "/home/user/project/file.txt"
        assert comparator(project_file, "/home/user/project/file.txt")
        assert not comparator(project_file, "/home/user/project/other.txt")
|
|
|
|
|
|
class TestComparatorTempPathsInNestedStructures:
|
|
"""Tests for comparator() with temp paths in nested data structures."""
|
|
|
|
def test_temp_paths_in_list(self):
|
|
"""Test temp paths inside lists."""
|
|
list1 = ["/tmp/pytest-of-alice/pytest-0/file.txt", "other"]
|
|
list2 = ["/tmp/pytest-of-bob/pytest-99/file.txt", "other"]
|
|
assert comparator(list1, list2)
|
|
|
|
def test_temp_paths_in_tuple(self):
|
|
"""Test temp paths inside tuples."""
|
|
tuple1 = ("/tmp/pytest-of-user/pytest-0/a.txt", "/tmp/pytest-of-user/pytest-0/b.txt")
|
|
tuple2 = ("/tmp/pytest-of-user/pytest-123/a.txt", "/tmp/pytest-of-user/pytest-123/b.txt")
|
|
assert comparator(tuple1, tuple2)
|
|
|
|
def test_temp_paths_in_dict_values(self):
|
|
"""Test temp paths as dictionary values."""
|
|
dict1 = {"path": "/tmp/pytest-of-user/pytest-0/output.json", "name": "test"}
|
|
dict2 = {"path": "/tmp/pytest-of-user/pytest-999/output.json", "name": "test"}
|
|
assert comparator(dict1, dict2)
|
|
|
|
def test_temp_paths_in_dict_keys_not_supported(self):
|
|
"""Test that temp paths as dictionary keys must match exactly (keys are not normalized)."""
|
|
# Dict keys use direct comparison, so temp paths as keys won't be normalized
|
|
# This tests the expected behavior
|
|
dict1 = {"/tmp/pytest-of-user/pytest-0/key": "value"}
|
|
dict2 = {"/tmp/pytest-of-user/pytest-0/key": "value"}
|
|
assert comparator(dict1, dict2)
|
|
|
|
def test_temp_paths_in_nested_dict(self):
|
|
"""Test temp paths in nested dictionaries."""
|
|
nested1 = {
|
|
"config": {
|
|
"output_path": "/tmp/pytest-of-alice/pytest-5/results",
|
|
"log_path": "/tmp/pytest-of-alice/pytest-5/logs",
|
|
}
|
|
}
|
|
nested2 = {
|
|
"config": {
|
|
"output_path": "/tmp/pytest-of-bob/pytest-10/results",
|
|
"log_path": "/tmp/pytest-of-bob/pytest-10/logs",
|
|
}
|
|
}
|
|
assert comparator(nested1, nested2)
|
|
|
|
def test_temp_paths_in_deeply_nested_structure(self):
|
|
"""Test temp paths in deeply nested structures."""
|
|
deep1 = {"a": {"b": {"c": ["/tmp/pytest-of-user/pytest-0/file.txt"]}}}
|
|
deep2 = {"a": {"b": {"c": ["/tmp/pytest-of-other/pytest-99/file.txt"]}}}
|
|
assert comparator(deep1, deep2)
|
|
|
|
def test_mixed_temp_and_regular_paths(self):
    """Only temp paths are normalized; ordinary paths in the same structure must match exactly."""
    baseline = {"temp": "/tmp/pytest-of-user/pytest-0/temp.txt", "regular": "/home/user/file.txt"}

    # Different session directory, same regular path -> equal.
    other_session = {"temp": "/tmp/pytest-of-user/pytest-99/temp.txt", "regular": "/home/user/file.txt"}
    assert comparator(baseline, other_session)

    # Regular (non-temp) path differs -> not equal.
    other_regular = {"temp": "/tmp/pytest-of-user/pytest-99/temp.txt", "regular": "/home/user/different.txt"}
    assert not comparator(baseline, other_regular)
|
|
|
|
def test_temp_paths_in_deque(self):
    """Test temp paths inside deque."""
    # `deque` is already imported at module level; no local re-import needed.
    d1 = deque(["/tmp/pytest-of-user/pytest-0/file.txt"])
    d2 = deque(["/tmp/pytest-of-user/pytest-123/file.txt"])
    assert comparator(d1, d2)
|
|
|
|
def test_temp_paths_in_chainmap(self):
    """Test temp paths inside ChainMap."""
    # `ChainMap` is already imported at module level; no local re-import needed.
    cm1 = ChainMap({"path": "/tmp/pytest-of-user/pytest-0/file.txt"})
    cm2 = ChainMap({"path": "/tmp/pytest-of-user/pytest-99/file.txt"})
    assert comparator(cm1, cm2)
|
|
|
|
|
|
class TestComparatorTempPathsEdgeCases:
    """Edge case tests for temp path handling in comparator."""

    def test_empty_string_vs_temp_path(self):
        """Test empty string comparison with temp path."""
        assert not comparator("", "/tmp/pytest-of-user/pytest-0/file.txt")
        assert not comparator("/tmp/pytest-of-user/pytest-0/file.txt", "")

    def test_path_with_special_characters(self):
        """Test temp paths containing special characters in filenames."""
        path1 = "/tmp/pytest-of-user/pytest-0/file with spaces.txt"
        path2 = "/tmp/pytest-of-user/pytest-99/file with spaces.txt"
        assert comparator(path1, path2)

        path3 = "/tmp/pytest-of-user/pytest-0/file-with-dashes.txt"
        path4 = "/tmp/pytest-of-user/pytest-99/file-with-dashes.txt"
        assert comparator(path3, path4)

    def test_path_with_unicode_characters(self):
        """Test temp paths with unicode characters."""
        path1 = "/tmp/pytest-of-user/pytest-0/файл.txt"
        path2 = "/tmp/pytest-of-user/pytest-99/файл.txt"
        assert comparator(path1, path2)

    def test_very_long_session_number(self):
        """Test temp paths with very long session numbers."""
        path1 = "/tmp/pytest-of-user/pytest-9999999999/file.txt"
        path2 = "/tmp/pytest-of-user/pytest-0/file.txt"
        assert comparator(path1, path2)

    def test_username_with_special_characters(self):
        """Test temp paths with special characters in username."""
        path1 = "/tmp/pytest-of-user-name/pytest-0/file.txt"
        path2 = "/tmp/pytest-of-other-user/pytest-99/file.txt"
        assert comparator(path1, path2)

    def test_path_only_differs_in_temp_portion(self):
        """Test that only the temp portion is normalized, rest must match."""
        path1 = "/tmp/pytest-of-user/pytest-0/subdir/nested/file.txt"
        path2 = "/tmp/pytest-of-user/pytest-99/subdir/nested/file.txt"
        assert comparator(path1, path2)

        path3 = "/tmp/pytest-of-user/pytest-0/subdir/nested/other.txt"
        assert not comparator(path1, path3)

    def test_multiple_slashes(self):
        """Test temp-path normalization next to consecutive slashes.

        The regex handles the standard ``/tmp/pytest-of-<user>/pytest-<N>/``
        prefix; extra slashes *inside* that prefix may not be normalized, so
        this only covers a doubled slash right after the prefix, where the
        remainder of both paths is identical.
        """
        # Baseline: standard format.
        path1 = "/tmp/pytest-of-user/pytest-0/file.txt"
        path2 = "/tmp/pytest-of-user/pytest-99/file.txt"
        assert comparator(path1, path2)

        # Doubled slash after the session directory: the prefix still matches and
        # the leftover "/file.txt" is the same on both sides.
        path3 = "/tmp/pytest-of-user/pytest-0//file.txt"
        path4 = "/tmp/pytest-of-user/pytest-99//file.txt"
        assert comparator(path3, path4)

    def test_temp_path_at_start_middle_end(self):
        """Test that temp paths are detected regardless of position in string."""
        # Path at start
        assert _is_temp_path("/tmp/pytest-of-user/pytest-0/test")
        # Path in middle (embedded in error message)
        assert _is_temp_path("Error: /tmp/pytest-of-user/pytest-0/test failed")
        # Path at end
        assert _is_temp_path("Output saved to /tmp/pytest-of-user/pytest-0/")

    def test_partial_temp_path_patterns(self):
        """Test strings that partially match temp path pattern."""
        # Missing components
        assert not _is_temp_path("/tmp/pytest-of-user/")
        assert not _is_temp_path("/tmp/pytest-0/")
        assert not _is_temp_path("pytest-of-user/pytest-0/")
|
|
|
|
|
|
class TestPytestTempPathPatternRegex:
    """Tests for the PYTEST_TEMP_PATH_PATTERN regex directly."""

    def test_pattern_matches_standard_format(self):
        """Test regex matches standard pytest temp path format."""
        assert PYTEST_TEMP_PATH_PATTERN.search("/tmp/pytest-of-user/pytest-0/")
        assert PYTEST_TEMP_PATH_PATTERN.search("/tmp/pytest-of-user/pytest-123/file")

    def test_pattern_captures_correctly(self):
        """Test that the pattern substitution works correctly."""
        result = PYTEST_TEMP_PATH_PATTERN.sub("REPLACED", "/tmp/pytest-of-user/pytest-0/file.txt")
        assert result == "REPLACEDfile.txt"

    def test_pattern_handles_multiple_matches(self):
        """Test pattern with multiple temp paths in same string."""
        text = "/tmp/pytest-of-a/pytest-1/ and /tmp/pytest-of-b/pytest-2/"
        result = PYTEST_TEMP_PATH_PATTERN.sub("X", text)
        assert result == "X and X"

    def test_pattern_greedy_behavior(self):
        """Test that the pattern doesn't over-match past the session directory."""
        path = "/tmp/pytest-of-user/pytest-0/subdir/pytest-1/file.txt"
        result = PYTEST_TEMP_PATH_PATTERN.sub("X", path)
        # Only the leading prefix (through the session number's trailing slash)
        # is replaced; the nested "pytest-1/" is not a temp root and must stay.
        # An exact check catches both over-matching and a spurious second match,
        # which a bare `"subdir" in result` check would miss.
        assert result == "Xsubdir/pytest-1/file.txt"
|
|
|
|
|
|
class TestComparatorTempPathsWithSuperset:
    """Tests for temp path comparison with superset_obj=True."""

    def test_superset_with_temp_paths_in_dict(self):
        """Test superset comparison with temp paths in dictionaries."""
        orig = {"path": "/tmp/pytest-of-user/pytest-0/file.txt"}
        new = {"path": "/tmp/pytest-of-user/pytest-99/file.txt", "extra": "data"}
        assert comparator(orig, new, superset_obj=True)

    def test_superset_temp_paths_must_still_match(self):
        """Test that temp paths must still be equivalent in superset mode."""
        orig = {"path": "/tmp/pytest-of-user/pytest-0/file.txt"}
        new = {"path": "/tmp/pytest-of-user/pytest-99/other.txt", "extra": "data"}
        assert not comparator(orig, new, superset_obj=True)

    def test_superset_nested_dict_with_temp_paths(self):
        """Test superset comparison with temp paths in nested dictionaries."""
        orig = {"config": {"output": "/tmp/pytest-of-alice/pytest-5/results.json"}}
        new = {
            "config": {"output": "/tmp/pytest-of-bob/pytest-100/results.json", "debug": True},
            "metadata": {"version": "1.0"},
        }
        assert comparator(orig, new, superset_obj=True)

    def test_superset_multiple_temp_paths_in_dict(self):
        """Test superset with multiple temp paths in dictionary values."""
        orig = {"input": "/tmp/pytest-of-user/pytest-0/input.txt", "output": "/tmp/pytest-of-user/pytest-0/output.txt"}
        new = {
            "input": "/tmp/pytest-of-user/pytest-99/input.txt",
            "output": "/tmp/pytest-of-user/pytest-99/output.txt",
            "log": "/tmp/pytest-of-user/pytest-99/debug.log",
        }
        assert comparator(orig, new, superset_obj=True)

    def test_superset_temp_path_in_list_inside_dict(self):
        """Test superset with temp paths in lists inside dictionaries."""
        orig = {"files": ["/tmp/pytest-of-user/pytest-0/a.txt", "/tmp/pytest-of-user/pytest-0/b.txt"]}
        new = {"files": ["/tmp/pytest-of-user/pytest-99/a.txt", "/tmp/pytest-of-user/pytest-99/b.txt"], "count": 2}
        assert comparator(orig, new, superset_obj=True)

    def test_superset_false_when_temp_path_missing(self):
        """Test superset fails when temp path key is missing in new."""
        orig = {"path": "/tmp/pytest-of-user/pytest-0/file.txt"}
        new = {"other": "data"}
        assert not comparator(orig, new, superset_obj=True)

    def test_superset_temp_path_with_different_filenames_fails(self):
        """Test superset fails when normalized temp paths have different filenames."""
        orig = {"result": "/tmp/pytest-of-user/pytest-0/output_v1.json"}
        new = {"result": "/tmp/pytest-of-user/pytest-99/output_v2.json", "extra": "data"}
        assert not comparator(orig, new, superset_obj=True)

    def test_superset_mixed_temp_and_regular_paths(self):
        """Test superset with mix of temp paths and regular paths."""
        orig = {"temp_file": "/tmp/pytest-of-user/pytest-0/temp.txt", "config_file": "/etc/app/config.yaml"}
        new = {
            "temp_file": "/tmp/pytest-of-user/pytest-99/temp.txt",
            "config_file": "/etc/app/config.yaml",
            "extra_key": "extra_value",
        }
        assert comparator(orig, new, superset_obj=True)

    def test_superset_regular_path_must_match_exactly(self):
        """Test that regular paths must match exactly even in superset mode."""
        orig = {"temp_file": "/tmp/pytest-of-user/pytest-0/temp.txt", "config_file": "/etc/app/config.yaml"}
        new = {
            "temp_file": "/tmp/pytest-of-user/pytest-99/temp.txt",
            "config_file": "/etc/app/other.yaml",
            "extra_key": "extra_value",
        }
        assert not comparator(orig, new, superset_obj=True)

    def test_superset_deeply_nested_temp_paths(self):
        """Test superset with deeply nested structures containing temp paths."""
        orig = {"level1": {"level2": {"level3": {"path": "/tmp/pytest-of-user/pytest-0/deep.txt"}}}}
        new = {
            "level1": {
                "level2": {
                    "level3": {"path": "/tmp/pytest-of-other/pytest-999/deep.txt", "extra": True},
                    "sibling": "value",
                }
            },
            "top_level_extra": 123,
        }
        assert comparator(orig, new, superset_obj=True)

    def test_superset_with_attrs_class_containing_temp_paths(self):
        """Test superset with attrs classes containing temp paths."""
        # Skips the test when attrs is not installed (same as the old try/except + skip).
        attr = pytest.importorskip("attr")

        @attr.s
        class Config:
            path = attr.ib()
            name = attr.ib(default="default")

        # Test that temp paths are normalized in attrs classes
        orig = Config(path="/tmp/pytest-of-user/pytest-0/config.json")
        new = Config(path="/tmp/pytest-of-user/pytest-99/config.json")
        assert comparator(orig, new, superset_obj=True)

        # Test that different non-temp values still fail
        orig2 = Config(path="/tmp/pytest-of-user/pytest-0/config.json", name="name1")
        new2 = Config(path="/tmp/pytest-of-user/pytest-99/config.json", name="name2")
        assert not comparator(orig2, new2, superset_obj=True)

    def test_superset_with_class_dict_containing_temp_paths(self):
        """Test superset with regular class objects containing temp paths.

        Both objects are instances of the same class, because objects of
        different classes would fail the comparator's type check. ``new``
        carries an extra instance attribute, which superset mode should allow.
        """

        class Result:
            def __init__(self, output_path):
                self.output_path = output_path

        orig = Result("/tmp/pytest-of-user/pytest-0/result.json")
        new = Result("/tmp/pytest-of-user/pytest-99/result.json")
        # Add extra attribute to new
        new.extra_field = "extra_data"
        assert comparator(orig, new, superset_obj=True)

    def test_superset_list_temp_paths_must_have_same_length(self):
        """Test that lists with temp paths must have same length even in superset mode."""
        # superset_obj doesn't apply to list lengths - they must match
        orig = ["/tmp/pytest-of-user/pytest-0/a.txt"]
        new = ["/tmp/pytest-of-user/pytest-99/a.txt", "/tmp/pytest-of-user/pytest-99/b.txt"]
        assert not comparator(orig, new, superset_obj=True)

    def test_superset_tuple_temp_paths_must_have_same_length(self):
        """Test that tuples with temp paths must have same length even in superset mode."""
        orig = ("/tmp/pytest-of-user/pytest-0/a.txt",)
        new = ("/tmp/pytest-of-user/pytest-99/a.txt", "/tmp/pytest-of-user/pytest-99/b.txt")
        assert not comparator(orig, new, superset_obj=True)

    def test_superset_with_exception_containing_temp_path(self):
        """Test superset with exception objects containing temp paths in attributes."""

        class CustomError(Exception):
            def __init__(self, message, path):
                super().__init__(message)
                self.path = path

        orig = CustomError("File error", "/tmp/pytest-of-user/pytest-0/file.txt")
        new = CustomError("File error", "/tmp/pytest-of-user/pytest-99/file.txt")
        new.extra_info = "additional data"
        assert comparator(orig, new, superset_obj=True)
|
|
|
|
|
|
class TestComparatorTempPathsRealisticScenarios:
    """Scenario-style checks where recorded and replayed values embed pytest temp paths."""

    def test_test_output_comparison(self):
        """A result recorded on CI matches a local replay despite different temp roots."""
        recorded = {
            "status": "success",
            "output_file": "/tmp/pytest-of-ci-runner/pytest-42/test_output/results.json",
            "log_file": "/tmp/pytest-of-ci-runner/pytest-42/test_output/debug.log",
        }
        replayed = {
            "status": "success",
            "output_file": "/tmp/pytest-of-local-user/pytest-0/test_output/results.json",
            "log_file": "/tmp/pytest-of-local-user/pytest-0/test_output/debug.log",
        }
        assert comparator(recorded, replayed)

    def test_exception_message_with_temp_path(self):
        """Exception-shaped dicts whose messages mention temp paths compare equal."""
        first_error = {"type": "FileNotFoundError", "message": "File not found: /tmp/pytest-of-user/pytest-0/missing.txt"}
        second_error = {"type": "FileNotFoundError", "message": "File not found: /tmp/pytest-of-user/pytest-99/missing.txt"}
        assert comparator(first_error, second_error)

    def test_function_return_with_temp_path(self):
        """A returned file path is compared modulo its pytest session directory."""
        # Stand-in for a function that creates a file and returns its path.
        returned_in_session_5 = "/tmp/pytest-of-user/pytest-5/generated_file_abc123.txt"
        returned_in_session_10 = "/tmp/pytest-of-user/pytest-10/generated_file_abc123.txt"
        assert comparator(returned_in_session_5, returned_in_session_10)

    def test_list_of_created_files(self):
        """Lists of created file paths match element-wise after normalization."""
        created_first_run = [
            "/tmp/pytest-of-user/pytest-0/output/file1.txt",
            "/tmp/pytest-of-user/pytest-0/output/file2.txt",
            "/tmp/pytest-of-user/pytest-0/output/file3.txt",
        ]
        created_second_run = [
            "/tmp/pytest-of-user/pytest-99/output/file1.txt",
            "/tmp/pytest-of-user/pytest-99/output/file2.txt",
            "/tmp/pytest-of-user/pytest-99/output/file3.txt",
        ]
        assert comparator(created_first_run, created_second_run)

    def test_config_object_with_paths(self):
        """A config dict mixing several temp dirs with one permanent dir compares equal."""
        user_config = {
            "temp_dir": "/tmp/pytest-of-user/pytest-0/",
            "cache_dir": "/tmp/pytest-of-user/pytest-0/cache/",
            "output_dir": "/tmp/pytest-of-user/pytest-0/output/",
            "permanent_dir": "/home/user/data/",
        }
        other_config = {
            "temp_dir": "/tmp/pytest-of-other/pytest-100/",
            "cache_dir": "/tmp/pytest-of-other/pytest-100/cache/",
            "output_dir": "/tmp/pytest-of-other/pytest-100/output/",
            "permanent_dir": "/home/user/data/",
        }
        assert comparator(user_config, other_config)
|
|
|
|
|
|
class TestPythonTempfilePaths:
    """Paths created by ``tempfile.mkdtemp()`` / ``TemporaryDirectory()`` (``/tmp/tmp<random>/``)."""

    def test_is_temp_path_detects_python_tempfile(self):
        """``_is_temp_path`` recognizes typical tempfile directories."""
        for candidate in (
            "/tmp/tmpqtwy7hpf/special.txt",
            "/tmp/tmpp6wx3tz3/special.txt",
            "/tmp/tmpabcdef12/",
            "/tmp/tmp_underscore/file.txt",
        ):
            assert _is_temp_path(candidate)

    def test_is_temp_path_various_tempfile_names(self):
        """The random suffix may be upper-case, numeric, mixed, or contain underscores."""
        for candidate in (
            "/tmp/tmpABCDEF/file.txt",  # uppercase
            "/tmp/tmp123456/file.txt",  # numeric
            "/tmp/tmpaBc123/file.txt",  # mixed
            "/tmp/tmp_test_dir/subdir/file.txt",  # with underscore
        ):
            assert _is_temp_path(candidate)

    def test_is_temp_path_non_tempfile(self):
        """Directories that merely resemble tempfile names are not treated as temp."""
        for candidate in (
            "/tmp/mydir/file.txt",  # doesn't start with tmp
            "/tmp/temp/file.txt",  # temp, not tmp
            "/home/user/tmp123/file.txt",  # not in /tmp/
        ):
            assert not _is_temp_path(candidate)

    def test_normalize_temp_path_python_tempfile(self):
        """Different random tempfile dirs collapse to the same normalized path."""
        normalized = {
            _normalize_temp_path(raw)
            for raw in ("/tmp/tmpqtwy7hpf/special.txt", "/tmp/tmpp6wx3tz3/special.txt")
        }
        assert normalized == {"/tmp/python-temp/special.txt"}

    def test_normalize_temp_path_preserves_subdirs(self):
        """Everything after the tempfile directory survives normalization."""
        normalized = _normalize_temp_path("/tmp/tmpabcdef12/subdir/nested/file.txt")
        assert normalized == "/tmp/python-temp/subdir/nested/file.txt"

    def test_comparator_python_tempfile_paths_equal(self):
        """Same filename under different tempfile dirs compares equal."""
        first, second = "/tmp/tmpqtwy7hpf/special.txt", "/tmp/tmpp6wx3tz3/special.txt"
        assert comparator(first, second)

    def test_comparator_python_tempfile_different_filenames_not_equal(self):
        """Different filenames under tempfile dirs do not compare equal."""
        first, second = "/tmp/tmpqtwy7hpf/special.txt", "/tmp/tmpp6wx3tz3/different.txt"
        assert not comparator(first, second)

    def test_comparator_python_tempfile_in_tuple(self):
        """Tempfile paths are normalized inside tuples."""
        assert comparator(("/tmp/tmpqtwy7hpf/special.txt",), ("/tmp/tmpp6wx3tz3/special.txt",))

    def test_comparator_python_tempfile_in_list(self):
        """Tempfile paths are normalized inside lists."""
        before = ["/tmp/tmpabcdef12/file1.txt", "/tmp/tmpabcdef12/file2.txt"]
        after = ["/tmp/tmpxyz78901/file1.txt", "/tmp/tmpxyz78901/file2.txt"]
        assert comparator(before, after)

    def test_comparator_python_tempfile_in_dict(self):
        """Tempfile paths are normalized inside dict values."""
        before = {"output": "/tmp/tmpabcdef12/result.json"}
        after = {"output": "/tmp/tmpxyz78901/result.json"}
        assert comparator(before, after)

    def test_comparator_mixed_pytest_and_python_tempfile(self):
        """A pytest temp path and a tempfile path normalize differently and do not match."""
        from_pytest = "/tmp/pytest-of-user/pytest-0/file.txt"
        from_tempfile = "/tmp/tmpabcdef12/file.txt"
        assert not comparator(from_pytest, from_tempfile)

    def test_python_tempfile_pattern_regex(self):
        """Direct checks of ``PYTHON_TEMPFILE_PATTERN``."""
        for matching in ("/tmp/tmpabcdef/file.txt", "/tmp/tmp123456/"):
            assert PYTHON_TEMPFILE_PATTERN.search(matching)
        for non_matching in ("/tmp/mydir/file.txt", "/home/tmp123/file.txt"):
            assert not PYTHON_TEMPFILE_PATTERN.search(non_matching)
|
|
|
|
|
|
@pytest.mark.skipif(sys.version_info < (3, 10), reason="types.UnionType requires Python 3.10+")
class TestUnionType:
    """Comparator behavior for PEP 604 ``X | Y`` union objects."""

    def test_union_type_equal(self):
        union = int | str
        assert comparator(union, int | str)

    def test_union_type_not_equal(self):
        union = int | str
        assert not comparator(union, int | float)

    def test_union_type_order_independent(self):
        union = int | str
        assert comparator(union, str | int)

    def test_union_type_multiple_args(self):
        union = int | str | float
        assert comparator(union, int | str | float)

    def test_union_type_in_list(self):
        left = [int | str, 1]
        right = [int | str, 1]
        assert comparator(left, right)

    def test_union_type_in_dict(self):
        left = {"key": int | str}
        right = {"key": int | str}
        assert comparator(left, right)

    def test_union_type_vs_none(self):
        union = int | str
        assert not comparator(union, None)
|
|
|
|
|
|
class SlotsOnly:
    """Two-attribute class whose state lives only in ``__slots__`` (no instance ``__dict__``)."""

    __slots__ = ("x", "y")

    def __init__(self, x, y):
        # Tuple assignment sets x first, then y, in the same order as two separate statements.
        self.x, self.y = x, y
|
|
|
|
|
|
class SlotsInherited(SlotsOnly):
    """``SlotsOnly`` subclass that declares one more slot, ``z``."""

    __slots__ = ("z",)

    def __init__(self, x, y, z):
        # The base initializer fills the inherited x/y slots; this class then sets its own slot.
        super().__init__(x, y)
        self.z = z
|
|
|
|
|
|
class TestSlotsObjects:
    """Comparator support for objects whose state is stored in ``__slots__``."""

    def test_slots_equal(self):
        left, right = SlotsOnly(1, 2), SlotsOnly(1, 2)
        assert comparator(left, right)

    def test_slots_not_equal(self):
        left, right = SlotsOnly(1, 2), SlotsOnly(1, 3)
        assert not comparator(left, right)

    def test_slots_inherited_equal(self):
        left, right = SlotsInherited(1, 2, 3), SlotsInherited(1, 2, 3)
        assert comparator(left, right)

    def test_slots_inherited_not_equal(self):
        left, right = SlotsInherited(1, 2, 3), SlotsInherited(1, 2, 4)
        assert not comparator(left, right)

    def test_slots_nested(self):
        left = SlotsOnly(SlotsOnly(1, 2), [3, 4])
        right = SlotsOnly(SlotsOnly(1, 2), [3, 4])
        assert comparator(left, right)

    def test_slots_nested_not_equal(self):
        left = SlotsOnly(SlotsOnly(1, 2), [3, 4])
        right = SlotsOnly(SlotsOnly(1, 9), [3, 4])
        assert not comparator(left, right)
|