Update test_comparator.py
check for ellipsis & ast change order add tests ruff silencio
This commit is contained in:
parent
0b2ff7cfa6
commit
a0107159b9
3 changed files with 60 additions and 41 deletions
|
|
@ -5,7 +5,7 @@ import math
|
|||
import re
|
||||
import types
|
||||
from typing import Any
|
||||
|
||||
import ast
|
||||
import sentry_sdk
|
||||
|
||||
from codeflash.cli_cmds.console import logger
|
||||
|
|
@ -182,7 +182,6 @@ def comparator(orig: Any, new: Any) -> bool:
|
|||
return orig == new
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# For class objects
|
||||
if hasattr(orig, "__dict__") and hasattr(new, "__dict__"):
|
||||
orig_keys = orig.__dict__
|
||||
|
|
@ -196,13 +195,18 @@ def comparator(orig: Any, new: Any) -> bool:
|
|||
orig_keys = {k: v for k, v in orig_keys.items() if not k.startswith("__")}
|
||||
new_keys = {k: v for k, v in new_keys.items() if not k.startswith("__")}
|
||||
|
||||
if isinstance(orig, ast.AST):
|
||||
orig_keys = {k: v for k, v in orig.__dict__.items() if k != "parent"}
|
||||
new_keys = {k: v for k, v in new.__dict__.items() if k != "parent"}
|
||||
|
||||
return comparator(orig_keys, new_keys)
|
||||
|
||||
if type(orig) in [types.BuiltinFunctionType, types.BuiltinMethodType]:
|
||||
return new == orig
|
||||
if str(type(orig)) == "<class 'object'>":
|
||||
return True
|
||||
|
||||
if orig is Ellipsis and new is Ellipsis:
|
||||
return True
|
||||
# TODO : Add other types here
|
||||
logger.warning(f"Unknown comparator input type: {type(orig)}")
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -154,7 +154,7 @@ warn_required_dynamic_aliases = true
|
|||
line-length = 120
|
||||
fix = true
|
||||
show-fixes = true
|
||||
exclude = ["code_to_optimize/", "pie_test_set/"]
|
||||
exclude = ["code_to_optimize/", "pie_test_set/", "tests"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["ALL"]
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
import ast
|
||||
import copy
|
||||
import dataclasses
|
||||
import datetime
|
||||
import decimal
|
||||
|
|
@ -6,13 +8,15 @@ from enum import Enum, Flag, IntFlag, auto
|
|||
|
||||
import pydantic
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from codeflash.either import Failure, Success
|
||||
from codeflash.verification.comparator import comparator
|
||||
from codeflash.verification.equivalence import compare_test_results
|
||||
from codeflash.verification.test_results import FunctionTestInvocation, InvocationId, TestResults, TestType
|
||||
|
||||
|
||||
def test_basic_python_objects():
|
||||
def test_basic_python_objects() -> None:
|
||||
a = 5
|
||||
b = 5
|
||||
c = 6
|
||||
|
|
@ -120,40 +124,40 @@ def test_basic_python_objects():
|
|||
assert not comparator(a, c)
|
||||
|
||||
|
||||
def test_standard_python_library_objects():
|
||||
a = datetime.datetime(2020, 2, 2, 2, 2, 2)
|
||||
b = datetime.datetime(2020, 2, 2, 2, 2, 2)
|
||||
c = datetime.datetime(2020, 2, 2, 2, 2, 3)
|
||||
def test_standard_python_library_objects() -> None:
|
||||
a = datetime.datetime(2020, 2, 2, 2, 2, 2) # type: ignore
|
||||
b = datetime.datetime(2020, 2, 2, 2, 2, 2) # type: ignore
|
||||
c = datetime.datetime(2020, 2, 2, 2, 2, 3) # type: ignore
|
||||
assert comparator(a, b)
|
||||
assert not comparator(a, c)
|
||||
|
||||
a = datetime.date(2020, 2, 2)
|
||||
b = datetime.date(2020, 2, 2)
|
||||
c = datetime.date(2020, 2, 3)
|
||||
a = datetime.date(2020, 2, 2) # type: ignore
|
||||
b = datetime.date(2020, 2, 2) # type: ignore
|
||||
c = datetime.date(2020, 2, 3) # type: ignore
|
||||
assert comparator(a, b)
|
||||
assert not comparator(a, c)
|
||||
|
||||
a = datetime.timedelta(days=1)
|
||||
b = datetime.timedelta(days=1)
|
||||
c = datetime.timedelta(days=2)
|
||||
a = datetime.timedelta(days=1) # type: ignore
|
||||
b = datetime.timedelta(days=1) # type: ignore
|
||||
c = datetime.timedelta(days=2) # type: ignore
|
||||
assert comparator(a, b)
|
||||
assert not comparator(a, c)
|
||||
|
||||
a = datetime.time(2, 2, 2)
|
||||
b = datetime.time(2, 2, 2)
|
||||
c = datetime.time(2, 2, 3)
|
||||
a = datetime.time(2, 2, 2) # type: ignore
|
||||
b = datetime.time(2, 2, 2) # type: ignore
|
||||
c = datetime.time(2, 2, 3) # type: ignore
|
||||
assert comparator(a, b)
|
||||
assert not comparator(a, c)
|
||||
|
||||
a = datetime.timezone.utc
|
||||
b = datetime.timezone.utc
|
||||
c = datetime.timezone(datetime.timedelta(hours=1))
|
||||
a = datetime.timezone.utc # type: ignore
|
||||
b = datetime.timezone.utc # type: ignore
|
||||
c = datetime.timezone(datetime.timedelta(hours=1)) # type: ignore
|
||||
assert comparator(a, b)
|
||||
assert not comparator(a, c)
|
||||
|
||||
a = decimal.Decimal(3.14)
|
||||
b = decimal.Decimal(3.14)
|
||||
c = decimal.Decimal(3.15)
|
||||
a = decimal.Decimal(3.14) # type: ignore
|
||||
b = decimal.Decimal(3.14) # type: ignore
|
||||
c = decimal.Decimal(3.15) # type: ignore
|
||||
assert comparator(a, b)
|
||||
assert not comparator(a, c)
|
||||
|
||||
|
|
@ -167,15 +171,15 @@ def test_standard_python_library_objects():
|
|||
GREEN = auto()
|
||||
BLUE = auto()
|
||||
|
||||
a = Color.RED
|
||||
b = Color.RED
|
||||
c = Color.GREEN
|
||||
a = Color.RED # type: ignore
|
||||
b = Color.RED # type: ignore
|
||||
c = Color.GREEN # type: ignore
|
||||
assert comparator(a, b)
|
||||
assert not comparator(a, c)
|
||||
|
||||
a = Color2.RED
|
||||
b = Color2.RED
|
||||
c = Color2.GREEN
|
||||
a = Color2.RED # type: ignore
|
||||
b = Color2.RED # type: ignore
|
||||
c = Color2.GREEN # type: ignore
|
||||
assert comparator(a, b)
|
||||
assert not comparator(a, c)
|
||||
|
||||
|
|
@ -184,9 +188,9 @@ def test_standard_python_library_objects():
|
|||
GREEN = auto()
|
||||
BLUE = auto()
|
||||
|
||||
a = Color4.RED
|
||||
b = Color4.RED
|
||||
c = Color4.GREEN
|
||||
a = Color4.RED # type: ignore
|
||||
b = Color4.RED # type: ignore
|
||||
c = Color4.GREEN # type: ignore
|
||||
assert comparator(a, b)
|
||||
assert not comparator(a, c)
|
||||
|
||||
|
|
@ -296,7 +300,7 @@ def test_numpy():
|
|||
|
||||
def test_scipy():
|
||||
try:
|
||||
import scipy as sp
|
||||
import scipy as sp # type: ignore
|
||||
except ImportError:
|
||||
pytest.skip()
|
||||
a = sp.sparse.csr_matrix([[1, 0, 0], [0, 0, 3], [4, 0, 5]])
|
||||
|
|
@ -466,7 +470,7 @@ def test_pandas():
|
|||
|
||||
def test_pyrsistent():
|
||||
try:
|
||||
from pyrsistent import PBag, PClass, PRecord, field, pdeque, pmap, pset, pvector
|
||||
from pyrsistent import PBag, PClass, PRecord, field, pdeque, pmap, pset, pvector # type: ignore
|
||||
except ImportError:
|
||||
pytest.skip()
|
||||
|
||||
|
|
@ -678,7 +682,7 @@ def test_compare_results_fn():
|
|||
function_getting_tested="function_getting_tested",
|
||||
iteration_id="0",
|
||||
),
|
||||
file_name="file_name",
|
||||
file_name=Path("file_name"),
|
||||
did_pass=True,
|
||||
runtime=5,
|
||||
test_framework="unittest",
|
||||
|
|
@ -699,7 +703,7 @@ def test_compare_results_fn():
|
|||
function_getting_tested="function_getting_tested",
|
||||
iteration_id="0",
|
||||
),
|
||||
file_name="file_name",
|
||||
file_name=Path("file_name"),
|
||||
did_pass=True,
|
||||
runtime=10,
|
||||
test_framework="unittest",
|
||||
|
|
@ -722,7 +726,7 @@ def test_compare_results_fn():
|
|||
function_getting_tested="function_getting_tested",
|
||||
iteration_id="0",
|
||||
),
|
||||
file_name="file_name",
|
||||
file_name=Path("file_name"),
|
||||
did_pass=True,
|
||||
runtime=10,
|
||||
test_framework="unittest",
|
||||
|
|
@ -745,7 +749,7 @@ def test_compare_results_fn():
|
|||
function_getting_tested="function_getting_tested",
|
||||
iteration_id="0",
|
||||
),
|
||||
file_name="file_name",
|
||||
file_name=Path("file_name"),
|
||||
did_pass=True,
|
||||
runtime=10,
|
||||
test_framework="unittest",
|
||||
|
|
@ -764,7 +768,7 @@ def test_compare_results_fn():
|
|||
function_getting_tested="function_getting_tested",
|
||||
iteration_id="2",
|
||||
),
|
||||
file_name="file_name",
|
||||
file_name=Path("file_name"),
|
||||
did_pass=True,
|
||||
runtime=10,
|
||||
test_framework="unittest",
|
||||
|
|
@ -787,7 +791,7 @@ def test_compare_results_fn():
|
|||
function_getting_tested="function_getting_tested",
|
||||
iteration_id="0",
|
||||
),
|
||||
file_name="file_name",
|
||||
file_name=Path("file_name"),
|
||||
did_pass=False,
|
||||
runtime=5,
|
||||
test_framework="unittest",
|
||||
|
|
@ -941,3 +945,14 @@ def test_exceptions_comparator():
|
|||
|
||||
zero_division_exc3 = ZeroDivisionError("Different message")
|
||||
assert comparator(zero_division_exc1, zero_division_exc3)
|
||||
|
||||
assert comparator(..., ...)
|
||||
assert comparator(Ellipsis, Ellipsis)
|
||||
|
||||
code7 = "a = 1 + 2"
|
||||
module7 = ast.parse(code7)
|
||||
for node in ast.walk(module7):
|
||||
for child in ast.iter_child_nodes(node):
|
||||
child.parent = node # type: ignore
|
||||
module8 = copy.deepcopy(module7)
|
||||
assert comparator(module7, module8)
|
||||
|
|
|
|||
Loading…
Reference in a new issue