Update test_comparator.py

check for ellipsis & ast
change order
add tests
ruff silencio
This commit is contained in:
Kevin Turcios 2025-02-01 17:07:02 -05:00
parent 0b2ff7cfa6
commit a0107159b9
3 changed files with 60 additions and 41 deletions

View file

@ -5,7 +5,7 @@ import math
import re
import types
from typing import Any
import ast
import sentry_sdk
from codeflash.cli_cmds.console import logger
@ -182,7 +182,6 @@ def comparator(orig: Any, new: Any) -> bool:
return orig == new
except Exception:
pass
# For class objects
if hasattr(orig, "__dict__") and hasattr(new, "__dict__"):
orig_keys = orig.__dict__
@ -196,13 +195,18 @@ def comparator(orig: Any, new: Any) -> bool:
orig_keys = {k: v for k, v in orig_keys.items() if not k.startswith("__")}
new_keys = {k: v for k, v in new_keys.items() if not k.startswith("__")}
if isinstance(orig, ast.AST):
orig_keys = {k: v for k, v in orig.__dict__.items() if k != "parent"}
new_keys = {k: v for k, v in new.__dict__.items() if k != "parent"}
return comparator(orig_keys, new_keys)
if type(orig) in [types.BuiltinFunctionType, types.BuiltinMethodType]:
return new == orig
if str(type(orig)) == "<class 'object'>":
return True
if orig is Ellipsis and new is Ellipsis:
return True
# TODO : Add other types here
logger.warning(f"Unknown comparator input type: {type(orig)}")
return False

View file

@ -154,7 +154,7 @@ warn_required_dynamic_aliases = true
line-length = 120
fix = true
show-fixes = true
exclude = ["code_to_optimize/", "pie_test_set/"]
exclude = ["code_to_optimize/", "pie_test_set/", "tests"]
[tool.ruff.lint]
select = ["ALL"]

View file

@ -1,3 +1,5 @@
import ast
import copy
import dataclasses
import datetime
import decimal
@ -6,13 +8,15 @@ from enum import Enum, Flag, IntFlag, auto
import pydantic
import pytest
from pathlib import Path
from codeflash.either import Failure, Success
from codeflash.verification.comparator import comparator
from codeflash.verification.equivalence import compare_test_results
from codeflash.verification.test_results import FunctionTestInvocation, InvocationId, TestResults, TestType
def test_basic_python_objects():
def test_basic_python_objects() -> None:
a = 5
b = 5
c = 6
@ -120,40 +124,40 @@ def test_basic_python_objects():
assert not comparator(a, c)
def test_standard_python_library_objects():
a = datetime.datetime(2020, 2, 2, 2, 2, 2)
b = datetime.datetime(2020, 2, 2, 2, 2, 2)
c = datetime.datetime(2020, 2, 2, 2, 2, 3)
def test_standard_python_library_objects() -> None:
a = datetime.datetime(2020, 2, 2, 2, 2, 2) # type: ignore
b = datetime.datetime(2020, 2, 2, 2, 2, 2) # type: ignore
c = datetime.datetime(2020, 2, 2, 2, 2, 3) # type: ignore
assert comparator(a, b)
assert not comparator(a, c)
a = datetime.date(2020, 2, 2)
b = datetime.date(2020, 2, 2)
c = datetime.date(2020, 2, 3)
a = datetime.date(2020, 2, 2) # type: ignore
b = datetime.date(2020, 2, 2) # type: ignore
c = datetime.date(2020, 2, 3) # type: ignore
assert comparator(a, b)
assert not comparator(a, c)
a = datetime.timedelta(days=1)
b = datetime.timedelta(days=1)
c = datetime.timedelta(days=2)
a = datetime.timedelta(days=1) # type: ignore
b = datetime.timedelta(days=1) # type: ignore
c = datetime.timedelta(days=2) # type: ignore
assert comparator(a, b)
assert not comparator(a, c)
a = datetime.time(2, 2, 2)
b = datetime.time(2, 2, 2)
c = datetime.time(2, 2, 3)
a = datetime.time(2, 2, 2) # type: ignore
b = datetime.time(2, 2, 2) # type: ignore
c = datetime.time(2, 2, 3) # type: ignore
assert comparator(a, b)
assert not comparator(a, c)
a = datetime.timezone.utc
b = datetime.timezone.utc
c = datetime.timezone(datetime.timedelta(hours=1))
a = datetime.timezone.utc # type: ignore
b = datetime.timezone.utc # type: ignore
c = datetime.timezone(datetime.timedelta(hours=1)) # type: ignore
assert comparator(a, b)
assert not comparator(a, c)
a = decimal.Decimal(3.14)
b = decimal.Decimal(3.14)
c = decimal.Decimal(3.15)
a = decimal.Decimal(3.14) # type: ignore
b = decimal.Decimal(3.14) # type: ignore
c = decimal.Decimal(3.15) # type: ignore
assert comparator(a, b)
assert not comparator(a, c)
@ -167,15 +171,15 @@ def test_standard_python_library_objects():
GREEN = auto()
BLUE = auto()
a = Color.RED
b = Color.RED
c = Color.GREEN
a = Color.RED # type: ignore
b = Color.RED # type: ignore
c = Color.GREEN # type: ignore
assert comparator(a, b)
assert not comparator(a, c)
a = Color2.RED
b = Color2.RED
c = Color2.GREEN
a = Color2.RED # type: ignore
b = Color2.RED # type: ignore
c = Color2.GREEN # type: ignore
assert comparator(a, b)
assert not comparator(a, c)
@ -184,9 +188,9 @@ def test_standard_python_library_objects():
GREEN = auto()
BLUE = auto()
a = Color4.RED
b = Color4.RED
c = Color4.GREEN
a = Color4.RED # type: ignore
b = Color4.RED # type: ignore
c = Color4.GREEN # type: ignore
assert comparator(a, b)
assert not comparator(a, c)
@ -296,7 +300,7 @@ def test_numpy():
def test_scipy():
try:
import scipy as sp
import scipy as sp # type: ignore
except ImportError:
pytest.skip()
a = sp.sparse.csr_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 5]])
@ -466,7 +470,7 @@ def test_pandas():
def test_pyrsistent():
try:
from pyrsistent import PBag, PClass, PRecord, field, pdeque, pmap, pset, pvector
from pyrsistent import PBag, PClass, PRecord, field, pdeque, pmap, pset, pvector # type: ignore
except ImportError:
pytest.skip()
@ -678,7 +682,7 @@ def test_compare_results_fn():
function_getting_tested="function_getting_tested",
iteration_id="0",
),
file_name="file_name",
file_name=Path("file_name"),
did_pass=True,
runtime=5,
test_framework="unittest",
@ -699,7 +703,7 @@ def test_compare_results_fn():
function_getting_tested="function_getting_tested",
iteration_id="0",
),
file_name="file_name",
file_name=Path("file_name"),
did_pass=True,
runtime=10,
test_framework="unittest",
@ -722,7 +726,7 @@ def test_compare_results_fn():
function_getting_tested="function_getting_tested",
iteration_id="0",
),
file_name="file_name",
file_name=Path("file_name"),
did_pass=True,
runtime=10,
test_framework="unittest",
@ -745,7 +749,7 @@ def test_compare_results_fn():
function_getting_tested="function_getting_tested",
iteration_id="0",
),
file_name="file_name",
file_name=Path("file_name"),
did_pass=True,
runtime=10,
test_framework="unittest",
@ -764,7 +768,7 @@ def test_compare_results_fn():
function_getting_tested="function_getting_tested",
iteration_id="2",
),
file_name="file_name",
file_name=Path("file_name"),
did_pass=True,
runtime=10,
test_framework="unittest",
@ -787,7 +791,7 @@ def test_compare_results_fn():
function_getting_tested="function_getting_tested",
iteration_id="0",
),
file_name="file_name",
file_name=Path("file_name"),
did_pass=False,
runtime=5,
test_framework="unittest",
@ -941,3 +945,14 @@ def test_exceptions_comparator():
zero_division_exc3 = ZeroDivisionError("Different message")
assert comparator(zero_division_exc1, zero_division_exc3)
assert comparator(..., ...)
assert comparator(Ellipsis, Ellipsis)
code7 = "a = 1 + 2"
module7 = ast.parse(code7)
for node in ast.walk(module7):
for child in ast.iter_child_nodes(node):
child.parent = node # type: ignore
module8 = copy.deepcopy(module7)
assert comparator(module7, module8)