mirror of
https://github.com/codeflash-ai/codeflash.git
synced 2026-05-04 18:25:17 +00:00
Extend comparator support to class objects without a user defined __eq__ function
This commit is contained in:
parent
7c4cd23476
commit
80e2950524
2 changed files with 17 additions and 10 deletions
|
|
@ -114,7 +114,8 @@ def comparator(orig: Any, new: Any) -> bool:
|
|||
return orig.equals(new)
|
||||
|
||||
if HAS_PANDAS and isinstance(
|
||||
orig, (pandas.CategoricalDtype, pandas.Interval, pandas.Period),
|
||||
orig,
|
||||
(pandas.CategoricalDtype, pandas.Interval, pandas.Period),
|
||||
):
|
||||
return orig == new
|
||||
|
||||
|
|
@ -134,13 +135,19 @@ def comparator(orig: Any, new: Any) -> bool:
|
|||
return orig == new
|
||||
|
||||
# If the object passed has a user defined __eq__ method, use that
|
||||
# This could fail if the user defined __eq__ is defined with cython
|
||||
# This could fail if the user defined __eq__ is defined with C-extensions
|
||||
try:
|
||||
if hasattr(orig, "__eq__") and str(type(orig.__eq__)) == "<class 'method'>":
|
||||
return orig == new
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# For class objects
|
||||
if hasattr(orig, "__dict__") and hasattr(new, "__dict__"):
|
||||
orig_keys = orig.__dict__
|
||||
new_keys = new.__dict__
|
||||
return comparator(orig_keys, new_keys)
|
||||
|
||||
# TODO : Add other types here
|
||||
logging.warning(f"Unknown comparator input type: {type(orig)}")
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -387,17 +387,17 @@ def test_custom_object():
|
|||
assert not comparator(a, c)
|
||||
|
||||
class TestClass2:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
def __init__(self, value1, value2=6):
|
||||
self.value1 = value1
|
||||
self.value2 = value2
|
||||
|
||||
a = TestClass(5)
|
||||
b = TestClass2(5)
|
||||
c = TestClass2(5)
|
||||
b = TestClass2(5, 6)
|
||||
c = TestClass2(5, 7)
|
||||
d = TestClass2(5, 6)
|
||||
assert not comparator(a, b)
|
||||
assert comparator(
|
||||
b,
|
||||
c,
|
||||
) # This is a fallback to True right now since we don't know how to compare them. This can be improved later
|
||||
assert not comparator(b, c)
|
||||
assert comparator(b, d)
|
||||
|
||||
class TestClass3(TestClass):
|
||||
def print(self):
|
||||
|
|
|
|||
Loading…
Reference in a new issue