codeflash/tests/test_comparator.py
2025-03-28 15:26:27 -07:00

1173 lines
33 KiB
Python

import ast
import copy
import dataclasses
import datetime
import decimal
import re
import sys
from enum import Enum, Flag, IntFlag, auto
from pathlib import Path
import pydantic
import pytest
from codeflash.either import Failure, Success
from codeflash.models.models import FunctionTestInvocation, InvocationId, TestResults, TestType
from codeflash.verification.comparator import comparator
from codeflash.verification.equivalence import compare_test_results
def test_basic_python_objects() -> None:
    """Builtin scalars, containers, bytes-like objects, builtin callables and type objects."""

    def check(reference, same, *different):
        # ``reference`` must match ``same`` and must not match any of ``different``.
        assert comparator(reference, same)
        for other in different:
            assert not comparator(reference, other)

    # Scalars and None.
    check(5, 5, 6, None)
    check(5.0, 5.0, 6.0, None)
    check(None, None, 5.0)
    check("Hello", "Hello", "World")

    # Containers: element values, keys, length and container type all matter.
    check([1, 2, 3], [1, 2, 3], [1, 2, 4])
    check({"a": 1, "b": 2}, {"a": 1, "b": 2}, {"a": 1, "b": 3}, {"c": 1, "b": 2}, {"a": 1, "b": 2, "c": 3})
    check((1, 2, "str"), (1, 2, "str"), (1, 2, "str2"), [1, 2, "str"])
    check({1, 2, 3}, {2, 3, 1}, {1, 2, 4}, {1, 2, 3, 4})
    check(frozenset([1, 2, 3]), frozenset([2, 3, 1]), frozenset([1, 2, 4]), frozenset([1, 2, 3, 4]))

    # Bytes-like objects.
    check(
        (65).to_bytes(1, byteorder="big"),
        (65).to_bytes(1, byteorder="big"),
        (66).to_bytes(1, byteorder="big"),
    )
    assert not comparator((65).to_bytes(2, byteorder="little"), (65).to_bytes(2, byteorder="big"))
    check(bytearray([65, 64, 63]), bytearray([65, 64, 63]), bytearray([65, 64, 62]))
    check(
        memoryview(bytearray([65, 64, 63])),
        memoryview(bytearray([65, 64, 63])),
        memoryview(bytearray([65, 64, 62])),
    )

    # Builtin callables, bare objects and type objects.
    check(pow, pow, abs)
    assert not comparator(map, pow)
    check(object(), object(), abs)
    check(type([]), type([]), type({}))
def test_standard_python_library_objects() -> None:
    """datetime, decimal, enum and compiled-regex objects from the standard library."""

    def check(reference, same, *different):
        # ``reference`` must match ``same`` and must not match any of ``different``.
        assert comparator(reference, same)
        for other in different:
            assert not comparator(reference, other)

    check(
        datetime.datetime(2020, 2, 2, 2, 2, 2),
        datetime.datetime(2020, 2, 2, 2, 2, 2),
        datetime.datetime(2020, 2, 2, 2, 2, 3),
    )
    check(datetime.date(2020, 2, 2), datetime.date(2020, 2, 2), datetime.date(2020, 2, 3))
    check(datetime.timedelta(days=1), datetime.timedelta(days=1), datetime.timedelta(days=2))
    check(datetime.time(2, 2, 2), datetime.time(2, 2, 2), datetime.time(2, 2, 3))
    check(datetime.timezone.utc, datetime.timezone.utc, datetime.timezone(datetime.timedelta(hours=1)))
    check(decimal.Decimal(3.14), decimal.Decimal(3.14), decimal.Decimal(3.15))

    class Color(Flag):
        RED = auto()
        GREEN = auto()
        BLUE = auto()

    class Color2(Enum):
        RED = auto()
        GREEN = auto()
        BLUE = auto()

    class Color4(IntFlag):
        RED = auto()
        GREEN = auto()
        BLUE = auto()

    # Flag, Enum and IntFlag members: the same member matches, a different member does not.
    for palette in (Color, Color2, Color4):
        check(palette.RED, palette.RED, palette.GREEN)

    # Compiled patterns differ by pattern text and by flags.
    check(re.compile("a"), re.compile("a"), re.compile("b"), re.compile("a", re.IGNORECASE))
def test_numpy():
    """numpy arrays and scalars: values, shape, dtype, object/str arrays, NaN and inf."""
    try:
        import numpy as np
    except ImportError:
        pytest.skip()

    def check(reference, same, different):
        # ``reference`` must match ``same`` and must not match ``different``.
        assert comparator(reference, same)
        assert not comparator(reference, different)

    vector = np.array([1, 2, 3])
    matrix = np.array([[1, 2], [3, 4]])
    check(vector, np.array([1, 2, 3]), np.array([1, 2, 4]))
    check(matrix, np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 5]]))
    # A different shape or dtype is a difference even when the values coincide.
    assert not comparator(vector, matrix)
    assert not comparator(vector, np.array([1.0, 2.0, 3.0]))

    # Scalars equal a fresh scalar of the same dtype and value...
    scalar_kinds = (
        (np.float32, 1.0),
        (np.float64, 1.0),
        (np.int32, 1),
        (np.int64, 1),
        (np.uint32, 1),
        (np.uint64, 1),
        (np.bool_, True),
        (np.complex64, 1.0 + 1.0j),
        (np.complex128, 1.0 + 1.0j),
    )
    for make, value in scalar_kinds:
        assert comparator(make(value), make(value))

    # ...but not the same value held in a different dtype.
    f32, f64 = np.float32(1.0), np.float64(1.0)
    i32, i64 = np.int32(1), np.int64(1)
    u32, u64 = np.uint32(1), np.uint64(1)
    flag = np.bool_(True)
    c64, c128 = np.complex64(1.0 + 1.0j), np.complex128(1.0 + 1.0j)
    mismatched = (
        (f32, f64),
        (i32, f32),
        (i32, f64),
        (i64, i32),
        (u32, i32),
        (u64, u32),
        (flag, u64),
        (c64, flag),
        (c128, c64),
    )
    for lhs, rhs in mismatched:
        assert not comparator(lhs, rhs)

    # Object arrays with mixed element types, and plain string arrays.
    check(
        np.array([1, 2, "str"], dtype=np.object_),
        np.array([1, 2, "str"], dtype=np.object_),
        np.array([1, 2, "str2"], dtype=np.object_),
    )
    assert comparator(np.array([1, 2, "str2"]), np.array([1, 2, "str2"]))

    # NaN matches NaN and inf matches inf, in arrays and as bare scalars; NaN and inf differ.
    with_nan = np.array([1, 2, np.nan])
    with_inf = np.array([1, 2, np.inf])
    assert comparator(with_nan, np.array([1, 2, np.nan]))
    assert comparator(with_inf, np.array([1, 2, np.inf]))
    assert not comparator(with_nan, with_inf)
    assert not comparator(np.array([1, 2, np.nan]), np.array([1, 2, np.inf]))
    assert comparator(np.inf, np.inf)
    assert comparator(np.nan, np.nan)
    assert not comparator(np.inf, np.nan)
def test_scipy():
    """scipy.sparse matrices are equal only when format, shape and stored values all match."""
    try:
        import scipy as sp  # type: ignore
        # Older SciPy releases do not load submodules on ``import scipy``, so load scipy.sparse explicitly.
        import scipy.sparse  # type: ignore # noqa: F401
    except ImportError:
        pytest.skip("scipy is not installed")
    # numpy is a hard dependency of scipy, so once scipy has imported this import cannot fail.
    import numpy as np

    base = [[1, 0, 0], [0, 0, 3], [4, 0, 5]]
    changed = [[1, 0, 0], [0, 0, 3], [4, 0, 6]]
    # Same values as ``changed`` plus an extra all-zero column, i.e. only the shape differs.
    widened = [[1, 0, 0, 0], [0, 0, 3, 0], [4, 0, 6, 0]]

    a = sp.sparse.csr_matrix(base)
    b = sp.sparse.csr_matrix(base)
    c = sp.sparse.csr_matrix(changed)
    ca = sp.sparse.csr_matrix(widened)
    assert comparator(a, b)
    assert not comparator(a, c)
    assert not comparator(c, ca)

    d = sp.sparse.csc_matrix(base)
    e = sp.sparse.csc_matrix(base)
    f = sp.sparse.csc_matrix(changed)
    fa = sp.sparse.csc_matrix(widened)
    assert comparator(d, e)
    assert not comparator(d, f)
    # CSR and CSC holding the same data are still different formats.
    assert not comparator(a, d)
    assert not comparator(c, f)
    assert not comparator(f, fa)

    # The remaining formats follow the same pattern, and none of them equals the CSR matrix.
    other_formats = (
        sp.sparse.lil_matrix,
        sp.sparse.dok_matrix,
        sp.sparse.dia_matrix,
        sp.sparse.coo_matrix,
        sp.sparse.bsr_matrix,
    )
    for matrix_type in other_formats:
        reference = matrix_type(base)
        assert comparator(reference, matrix_type(base))
        assert not comparator(reference, matrix_type(changed))
        assert not comparator(a, reference)

    # Dense arrays materialized from a sparse COO array.
    row = np.array([0, 3, 1, 0])
    col = np.array([0, 3, 1, 2])
    data = np.array([4, 5, 7, 9])
    v = sp.sparse.coo_array((data, (row, col)), shape=(4, 4)).toarray()
    w = sp.sparse.coo_array((data, (row, col)), shape=(4, 4)).toarray()
    assert comparator(v, w)
def test_pandas():
    """pandas containers, scalars, extension arrays and pd.NA handling."""
    try:
        import pandas as pd
    except ImportError:
        pytest.skip()

    def check(reference, same, different):
        # ``reference`` must match ``same`` and must not match ``different``.
        assert comparator(reference, same)
        assert not comparator(reference, different)

    # DataFrames: both values and shape matter.
    changed_frame = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 7]})
    check(
        pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}),
        pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}),
        changed_frame,
    )
    assert not comparator(changed_frame, pd.DataFrame({"a": [1, 2, 3, 4], "b": [4, 5, 6, 7]}))

    def stamped_frame(last_second):
        # Two datetime rows; only the seconds of the second row vary.
        return pd.DataFrame(
            {
                "a": [datetime.datetime(2020, 2, 2, 2, 2, 2), datetime.datetime(2020, 2, 2, 2, 2, last_second)],
                "b": [4, 5],
            }
        )

    check(stamped_frame(2), stamped_frame(2), stamped_frame(3))

    # Series and index types.
    check(pd.Series([1, 2, 3]), pd.Series([1, 2, 3]), pd.Series([1, 2, 4]))
    check(pd.Index([1, 2, 3]), pd.Index([1, 2, 3]), pd.Index([1, 2, 4]))
    check(
        pd.MultiIndex.from_tuples([(1, 2), (3, 4)]),
        pd.MultiIndex.from_tuples([(1, 2), (3, 4)]),
        pd.MultiIndex.from_tuples([(1, 2), (3, 5)]),
    )
    check(pd.Categorical([1, 2, 3]), pd.Categorical([1, 2, 3]), pd.Categorical([1, 2, 4]))
    check(pd.Interval(1, 2), pd.Interval(1, 2), pd.Interval(1, 3))
    check(
        pd.IntervalIndex.from_tuples([(1, 2), (3, 4)]),
        pd.IntervalIndex.from_tuples([(1, 2), (3, 4)]),
        pd.IntervalIndex.from_tuples([(1, 2), (3, 5)]),
    )

    # Time-related scalars and indexes.
    check(pd.Period("2021-01"), pd.Period("2021-01"), pd.Period("2021-02"))
    check(
        pd.period_range(start="2021-01", periods=3, freq="M"),
        pd.period_range(start="2021-01", periods=3, freq="M"),
        pd.period_range(start="2021-01", periods=4, freq="M"),
    )
    check(pd.Timedelta("1 days"), pd.Timedelta("1 days"), pd.Timedelta("2 days"))
    check(
        pd.TimedeltaIndex(["1 days", "2 days"]),
        pd.TimedeltaIndex(["1 days", "2 days"]),
        pd.TimedeltaIndex(["1 days", "3 days"]),
    )
    check(pd.Timestamp("2021-01-01"), pd.Timestamp("2021-01-01"), pd.Timestamp("2021-01-02"))

    # Sparse extension arrays.
    check(pd.arrays.SparseArray([1, 2, 3]), pd.arrays.SparseArray([1, 2, 3]), pd.arrays.SparseArray([1, 2, 4]))

    # pd.NA equals itself but not None, in either direction, also when nested in containers.
    check(pd.NA, pd.NA, None)
    assert not comparator(None, pd.NA)
    check(pd.Series([1, 2, pd.NA, 4]), pd.Series([1, 2, pd.NA, 4]), pd.Series([1, 2, None, 4]))
    check(
        pd.DataFrame({"a": [1, 2, pd.NA], "b": [4, pd.NA, 6]}),
        pd.DataFrame({"a": [1, 2, pd.NA], "b": [4, pd.NA, 6]}),
        pd.DataFrame({"a": [1, 2, None], "b": [4, None, 6]}),
    )
    check({"a": pd.NA, "b": [1, pd.NA, 3]}, {"a": pd.NA, "b": [1, pd.NA, 3]}, {"a": None, "b": [1, None, 3]})

    # Boolean-mask filtering of Series that contain pd.NA.
    left = pd.Series([1, 2, pd.NA, 4])
    right = pd.Series([1, 2, pd.NA, 4])
    assert comparator(left[left > 1], right[right > 1])
def test_pyrsistent():
    """pyrsistent persistent collections, records and classes compare by content."""
    try:
        from pyrsistent import PClass, PRecord, field, pbag, pdeque, pmap, pset, pvector  # type: ignore
    except ImportError:
        pytest.skip("pyrsistent is not installed")

    a = pmap({"a": 1, "b": 2})
    b = pmap({"a": 1, "b": 2})
    c = pmap({"a": 1, "b": 3})
    assert comparator(a, b)
    assert not comparator(a, c)

    d = pvector([1, 2, 3])
    e = pvector([1, 2, 3])
    f = pvector([1, 2, 4])
    assert comparator(d, e)
    assert not comparator(d, f)

    # Sets ignore insertion order.
    g = pset([1, 2, 3])
    h = pset([2, 3, 1])
    i = pset([1, 2, 4])
    assert comparator(g, h)
    assert not comparator(g, i)

    class TestRecord(PRecord):
        a = field()
        b = field()

    j = TestRecord()
    k = TestRecord()
    l = TestRecord(a=2, b=3)
    assert comparator(j, k)
    assert not comparator(j, l)

    class TestClass(PClass):
        a = field()
        b = field()

    m = TestClass()
    n = TestClass()
    o = TestClass(a=1, b=3)
    assert comparator(m, n)
    assert not comparator(m, o)

    p = pdeque([1, 2, 3], 3)
    q = pdeque([1, 2, 3], 3)
    r = pdeque([1, 2, 4], 3)
    assert comparator(p, q)
    assert not comparator(p, r)

    # Build bags with the ``pbag`` factory. The ``PBag`` constructor expects a mapping of element
    # counts, so ``PBag([1, 2, 3])`` produced a malformed bag wrapping a plain list.
    s = pbag([1, 2, 3])
    t = pbag([1, 2, 3])
    u = pbag([1, 2, 4])
    assert comparator(s, t)
    assert not comparator(s, u)
def test_torch():
    """torch tensors: values, dtype, shape, device, requires_grad, NaN/inf, complex and bool."""
    try:
        import torch  # type: ignore
    except ImportError:
        pytest.skip()

    def check(reference, same, different):
        # ``reference`` must match ``same`` and must not match ``different``.
        assert comparator(reference, same)
        assert not comparator(reference, different)

    tensor = torch.tensor
    nan = float("nan")
    inf = float("inf")

    # Integer tensors of rank 1, 2 and 3.
    check(tensor([1, 2, 3]), tensor([1, 2, 3]), tensor([1, 2, 4]))
    check(tensor([[1, 2, 3], [4, 5, 6]]), tensor([[1, 2, 3], [4, 5, 6]]), tensor([[1, 2, 3], [4, 5, 7]]))
    check(
        tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]),
        tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]),
        tensor([[[1, 2], [3, 4]], [[5, 6], [7, 9]]]),
    )

    # Same values with a different dtype, or a different shape, are differences.
    check(
        tensor([1, 2, 3], dtype=torch.float32),
        tensor([1, 2, 3], dtype=torch.float32),
        tensor([1, 2, 3], dtype=torch.int64),
    )
    assert not comparator(tensor([1, 2, 3]), tensor([[1, 2, 3]]))

    # Empty tensors.
    check(tensor([]), tensor([]), tensor([1]))

    # NaN == NaN at the same position; +inf and -inf differ.
    check(tensor([1.0, nan, 3.0]), tensor([1.0, nan, 3.0]), tensor([1.0, 2.0, 3.0]))
    check(tensor([1.0, inf, 3.0]), tensor([1.0, inf, 3.0]), tensor([1.0, float("-inf"), 3.0]))

    # Device placement matters (only checked when CUDA is available).
    if torch.cuda.is_available():
        check(tensor([1, 2, 3]).cuda(), tensor([1, 2, 3]).cuda(), tensor([1, 2, 3]))

    # The requires_grad flag matters.
    check(
        tensor([1.0, 2.0, 3.0], requires_grad=True),
        tensor([1.0, 2.0, 3.0], requires_grad=True),
        tensor([1.0, 2.0, 3.0], requires_grad=False),
    )

    # Complex and boolean tensors.
    check(tensor([1 + 2j, 3 + 4j]), tensor([1 + 2j, 3 + 4j]), tensor([1 + 2j, 3 + 5j]))
    check(tensor([True, False, True]), tensor([True, False, True]), tensor([True, True, True]))
def test_returns():
    """Success/Failure containers compare by variant and by wrapped value."""
    five = Success(5)
    assert comparator(five, Success(5))
    # A different value, the Failure variant, or a different payload type are all differences.
    for other in (Success(6), Failure(5), Success((5, 5))):
        assert not comparator(five, other)

    # Tuple payloads must match exactly.
    pair = Success((5, 5))
    assert not comparator(pair, Success((5, 6)))
    assert comparator(Success((5, 5)), Success((5, 5)))
    assert not comparator(Success((5, 5)), Success((5, 6)))
def test_custom_object():
    """User-defined classes, dataclasses, pydantic models and class objects themselves."""

    class TestClass:
        def __init__(self, value):
            self.value = value

        def __eq__(self, other):
            return self.value == other.value

    assert comparator(TestClass(5), TestClass(5))
    assert not comparator(TestClass(5), TestClass(6))

    class TestClass2:
        def __init__(self, value1, value2=6):
            self.value1 = value1
            self.value2 = value2

    five_six = TestClass2(5, 6)
    # Instances of unrelated classes never match; same-class instances match by attributes.
    assert not comparator(TestClass(5), five_six)
    assert not comparator(five_six, TestClass2(5, 7))
    assert comparator(five_six, TestClass2(5, 6))

    class TestClass3(TestClass):
        def print(self):
            print(self.value)

    assert not comparator(TestClass2(5), TestClass3(5))
    assert comparator(TestClass3(5), TestClass3(5))

    @dataclasses.dataclass
    class InventoryItem:
        """Class for keeping track of an item in inventory."""

        name: str
        unit_price: float
        quantity_on_hand: int = 0

        def total_cost(self) -> float:
            return self.unit_price * self.quantity_on_hand

    @pydantic.dataclasses.dataclass
    class InventoryItemPydantic:
        """Class for keeping track of an item in inventory."""

        name: str
        unit_price: float
        quantity_on_hand: int = 0

        def total_cost(self) -> float:
            return self.unit_price * self.quantity_on_hand

    class InventoryItemBasePydantic(pydantic.BaseModel):
        name: str
        unit_price: float
        quantity_on_hand: int = 0

        def total_cost(self) -> float:
            return self.unit_price * self.quantity_on_hand

    # Stdlib dataclass, pydantic dataclass and pydantic BaseModel all compare by field values.
    for model in (InventoryItem, InventoryItemPydantic, InventoryItemBasePydantic):
        stocked = model(name="widget", unit_price=3.0, quantity_on_hand=10)
        assert comparator(stocked, model(name="widget", unit_price=3.0, quantity_on_hand=10))
        assert not comparator(stocked, model(name="widget", unit_price=3.0, quantity_on_hand=11))

    class A:
        items = [1, 2, 3]
        val = 5

    class B:
        items = [1, 2, 4]
        val = 5

    class C:
        items = [1, 2, 3]
        val = 5

        def __init__(self):
            self.itemm2 = [1, 2, 3]
            self.val2 = 5

    class D:
        items = [1, 2, 3]
        val = 5

        def __init__(self):
            self.itemm2 = [1, 2, 4]
            self.val2 = 5

    # Class objects themselves are compared as well; an alias of a class matches it.
    alias = C
    assert comparator(A, A)
    assert not comparator(A, B)
    assert comparator(C, C)
    assert not comparator(C, D)
    assert comparator(C, alias)
def test_custom_object_2():
    """Instances of a re-imported class compare structurally, not by class identity.

    The test rewrites ``code_to_optimize/bubble_sort_method.py`` on disk, re-imports it after
    each rewrite, and always restores the original source in ``finally``.
    """
    import importlib.util

    module_name = "code_to_optimize.bubble_sort_method"
    fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_method.py").resolve()
    original_code = fto_path.read_text("utf-8")

    def forget_module() -> None:
        # Drop the module from sys.modules so the next import re-executes the source file.
        # Also delete its cached bytecode. CPython accepts a .pyc whose recorded source mtime
        # (whole seconds) and size match. The two rewritten variants below have the same size
        # and may be written within the same second, so the second import could otherwise load
        # the first variant's stale bytecode.
        sys.modules.pop(module_name, None)
        Path(importlib.util.cache_from_source(str(fto_path))).unlink(missing_ok=True)

    from code_to_optimize.bubble_sort_method import BubbleSorter

    a = BubbleSorter()
    assert a.x == 0
    try:
        # Remove the module from sys.modules, to get the updated class
        forget_module()
        from code_to_optimize.bubble_sort_method import BubbleSorter

        b = BubbleSorter()
        # Note that type(a) != type(b) as the class type objects are different, even if the code is the same.
        assert comparator(a, b)

        optimized_code_mutated_attr = """
class BubbleSorter:
    z = 0

    def __init__(self, x=1):
        self.x = x

    def sorter(self, arr):
        for i in range(len(arr)):
            for j in range(len(arr) - 1):
                if arr[j] > arr[j + 1]:
                    temp = arr[j]
                    arr[j] = arr[j + 1]
                    arr[j + 1] = temp
        return arr
"""
        fto_path.write_text(optimized_code_mutated_attr, "utf-8")
        forget_module()
        from code_to_optimize.bubble_sort_method import BubbleSorter

        c = BubbleSorter()
        assert c.x == 1
        assert not comparator(a, c)

        optimized_code_new_attr = """
class BubbleSorter:
    z = 5

    def __init__(self, x=0):
        self.x = x

    def sorter(self, arr):
        for i in range(len(arr)):
            for j in range(len(arr) - 1):
                if arr[j] > arr[j + 1]:
                    temp = arr[j]
                    arr[j] = arr[j + 1]
                    arr[j + 1] = temp
        return arr
"""
        fto_path.write_text(optimized_code_new_attr, "utf-8")
        forget_module()
        from code_to_optimize.bubble_sort_method import BubbleSorter

        d = BubbleSorter()
        assert d.x == 0
        # Currently, we do not check if class variables are different, since the code replacer does not allow this.
        # In the future, if this functionality is allowed, this assert should be false.
        assert comparator(a, d)
    finally:
        fto_path.write_text(original_code, "utf-8")
        # Later tests that import this module must get the restored original,
        # not the last rewritten variant left behind in sys.modules.
        forget_module()
def test_superset():
    """With superset_obj=True the second object may carry extra attributes; the first may not."""

    class A:
        def __init__(self):
            self.a = 1

    extended = A()
    extended.x = 3

    # Only the (original, superset) direction passes, and only in superset mode.
    assert comparator(A(), extended, superset_obj=True)
    assert not comparator(extended, A(), superset_obj=True)
    for lhs, rhs in ((A(), extended), (extended, A())):
        assert not comparator(lhs, rhs)

    # An object always matches itself, in either mode.
    assert comparator(extended, extended, superset_obj=True)
    assert comparator(extended, extended)
def test_compare_results_fn():
    """compare_test_results over TestResults built from single-test invocations."""

    def invocation(*, iteration_id="0", did_pass=True, runtime=10, return_value=5):
        # Every invocation in this test shares the same test identity apart from iteration_id.
        return FunctionTestInvocation(
            id=InvocationId(
                test_module_path="test_module_path",
                test_class_name="test_class_name",
                test_function_name="test_function_name",
                function_getting_tested="function_getting_tested",
                iteration_id=iteration_id,
            ),
            file_name=Path("file_name"),
            did_pass=did_pass,
            runtime=runtime,
            test_framework="unittest",
            test_type=TestType.EXISTING_UNIT_TEST,
            return_value=return_value,
            timed_out=False,
            loop_index=1,
        )

    def results_of(*invocations):
        # Collect the given invocations into a fresh TestResults.
        collected = TestResults()
        for item in invocations:
            collected.add(item)
        return collected

    original_results = results_of(invocation(runtime=5))

    # Same return value at a different runtime: equivalent.
    assert compare_test_results(original_results, results_of(invocation()))
    # A different return value ([5] instead of 5): not equivalent.
    assert not compare_test_results(original_results, results_of(invocation(return_value=[5])))
    # An extra invocation (iteration "2") alongside a matching one: still equivalent.
    assert compare_test_results(original_results, results_of(invocation(), invocation(iteration_id="2")))
    # A failing run with the same return value and runtime: not equivalent.
    assert not compare_test_results(original_results, results_of(invocation(did_pass=False, runtime=5)))
    # Two empty result sets: not equivalent.
    assert not compare_test_results(TestResults(), TestResults())
def test_exceptions():
    """Two separately constructed exceptions of the same type and message compare equal."""
    message = "This is a type error"
    first, second = TypeError(message), TypeError(message)
    assert comparator(first, second)
def raise_exception():
    """Raise a plain ``Exception`` with a fixed message, so callers can capture a real raised instance."""
    message = "This is an exception"
    raise Exception(message)
def test_exceptions_comparator():
    """Exceptions compare by type and custom public attributes; messages are not compared.

    The tail of the test also covers ``Ellipsis`` and ``ast`` trees.
    """
    # Currently we are only comparing the exception types and the attributes that don't start with "_".
    # There are complications with comparing the exception messages, so messages are not compared.
    try:
        raise_exception()
    except Exception as e:
        exception = e
    try:
        raise_exception()
    except Exception as e:
        exception_2 = e
    assert comparator(exception, exception_2)

    # Same type: equal whether or not the messages match.
    assert comparator(ValueError("same message"), ValueError("same message"))
    assert comparator(ValueError("message one"), ValueError("message two"))
    # Same message, different types: not equal.
    assert not comparator(ValueError("common message"), TypeError("common message"))

    # FileNotFoundError: differences in the errno or strerror arguments are not detected either.
    exc_file_1 = FileNotFoundError(2, "No such file or directory")
    assert comparator(exc_file_1, FileNotFoundError(2, "No such file or directory"))
    assert comparator(exc_file_1, FileNotFoundError(3, "No such file or directory"))
    assert comparator(exc_file_1, FileNotFoundError(2, "File not found"))

    # An exception matches itself, never matches None in either direction; None matches None.
    assert comparator(exception, exception)
    assert not comparator(exception, None)
    assert not comparator(None, exception)
    assert comparator(None, None)

    # Same type with different or identical messages: equal. Different types: not equal.
    exc_type1 = TypeError("Type error")
    assert comparator(exc_type1, TypeError("Another type error"))
    exc_type3 = KeyError("Missing key")
    assert comparator(exc_type3, KeyError("Missing key"))
    assert not comparator(exc_type1, exc_type3)

    # Custom public attributes (here ``code``) are compared.
    class CustomError(Exception):
        def __init__(self, message, code):
            super().__init__(message)
            self.code = code

    custom_exc1 = CustomError("Something went wrong", 101)
    assert comparator(custom_exc1, CustomError("Something went wrong", 101))
    assert not comparator(custom_exc1, CustomError("Something went wrong", 102))

    class CustomErrorNoArgs(Exception):
        pass

    assert comparator(CustomErrorNoArgs(), CustomErrorNoArgs())

    # Empty vs. non-empty messages are still equal.
    exc_empty1 = ValueError("")
    assert comparator(exc_empty1, ValueError(""))
    assert comparator(exc_empty1, ValueError("Not empty"))

    # A subclass is a different type from its base, even with the same message.
    class CustomValueError(ValueError):
        pass

    custom_value_error1 = CustomValueError("A custom value error")
    assert not comparator(custom_value_error1, ValueError("A custom value error"))
    assert comparator(custom_value_error1, CustomValueError("Another custom value error"))

    # Explicitly assigned ``args`` differing in their second element still compare equal.
    class CustomExceptionWithArgs(Exception):
        def __init__(self, arg1, arg2):
            self.args = (arg1, arg2)

    custom_args_exc1 = CustomExceptionWithArgs(1, "test")
    assert comparator(custom_args_exc1, CustomExceptionWithArgs(1, "test"))
    assert comparator(custom_args_exc1, CustomExceptionWithArgs(1, "different"))

    # Exceptions captured from real raises, and versus a directly constructed instance.
    def raise_specific_exception():
        raise ZeroDivisionError("Cannot divide by zero")

    try:
        raise_specific_exception()
    except ZeroDivisionError as z:
        zero_division_exc1 = z
    try:
        raise_specific_exception()
    except ZeroDivisionError as z:
        zero_division_exc2 = z
    assert comparator(zero_division_exc1, zero_division_exc2)
    assert comparator(zero_division_exc1, ZeroDivisionError("Different message"))

    # Ellipsis matches itself (via either spelling) and differs from None.
    assert comparator(..., ...)
    assert comparator(Ellipsis, Ellipsis)
    assert not comparator(..., None)
    assert not comparator(Ellipsis, None)

    # AST trees: a deep copy, including ad-hoc ``parent`` back-links, matches the original;
    # a tree parsed from different source does not.
    module7 = ast.parse("a = 1 + 2")
    for node in ast.walk(module7):
        for child in ast.iter_child_nodes(node):
            child.parent = node  # type: ignore
    module8 = copy.deepcopy(module7)
    assert comparator(module7, module8)
    module2 = ast.parse("a = 1 + 3")
    assert not comparator(module7, module2)