Use exact comparison for bool in approx()

Fixes #9353.
This commit is contained in:
Jakob van Santen 2021-11-30 11:52:28 +01:00
parent fa240b0bb4
commit f11965783c
3 changed files with 46 additions and 11 deletions

View File

@ -0,0 +1 @@
`approx()` now uses strict equality when `type(expected) == bool`.

View File

@ -265,10 +265,16 @@ class ApproxMapping(ApproxBase):
max_abs_diff = max(
max_abs_diff, abs(approx_value.expected - other_value)
)
max_rel_diff = max(
max_rel_diff,
abs((approx_value.expected - other_value) / approx_value.expected),
)
try:
max_rel_diff = max(
max_rel_diff,
abs(
(approx_value.expected - other_value)
/ approx_value.expected
),
)
except ZeroDivisionError:
pass
different_ids.append(approx_key)
message_data = [
@ -395,8 +401,12 @@ class ApproxScalar(ApproxBase):
# Don't show a tolerance for values that aren't compared using
# tolerances, i.e. non-numerics and infinities. Need to call abs to
# handle complex numbers, e.g. (inf + 1j).
if (not isinstance(self.expected, (Complex, Decimal))) or math.isinf(
abs(self.expected) # type: ignore[arg-type]
if (
isinstance(self.expected, bool)
or (not isinstance(self.expected, (Complex, Decimal)))
or math.isinf(
abs(self.expected) or isinstance(self.expected, bool) # type: ignore[arg-type]
)
):
return str(self.expected)
@ -424,17 +434,20 @@ class ApproxScalar(ApproxBase):
# numpy<1.13. See #3748.
return all(self.__eq__(a) for a in asarray.flat)
# Short-circuit exact equality.
if actual == self.expected:
# Short-circuit exact equality, except for bool
if isinstance(self.expected, bool) and not isinstance(actual, bool):
return False
elif actual == self.expected:
return True
# If either type is non-numeric, fall back to strict equality.
# NB: we need Complex, rather than just Number, to ensure that __abs__,
# __sub__, and __float__ are defined.
# __sub__, and __float__ are defined. Also, consider bool to be
# nonnumeric, even though it has the required arithmetic.
if not (
isinstance(self.expected, (Complex, Decimal))
and isinstance(actual, (Complex, Decimal))
):
) or isinstance(self.expected, bool):
return False
# Allow the user to control whether NaNs are considered equal to each

View File

@ -88,7 +88,7 @@ def assert_approx_raises_regex(pytestconfig):
return do_assert
SOME_FLOAT = r"[+-]?([0-9]*[.])?[0-9]+\s*"
SOME_FLOAT = r"[+-]?((?:([0-9]*[.])?[0-9]+(e-?[0-9]+)?)|inf|nan)\s*"
SOME_INT = r"[0-9]+\s*"
@ -96,6 +96,19 @@ class TestApprox:
def test_error_messages(self, assert_approx_raises_regex):
np = pytest.importorskip("numpy")
# treat bool exactly
assert_approx_raises_regex(
{"a": 1.0, "b": True},
{"a": 1.0, "b": False},
[
" comparison failed. Mismatched elements: 1 / 2:",
f" Max absolute difference: {SOME_FLOAT}",
f" Max relative difference: {SOME_FLOAT}",
r" Index\s+\| Obtained\s+\| Expected",
r".*(True|False)\s+",
],
)
assert_approx_raises_regex(
2.0,
1.0,
@ -546,6 +559,13 @@ class TestApprox:
assert approx(x, rel=5e-6, abs=0) == a
assert approx(x, rel=5e-7, abs=0) != a
def test_bool(self):
assert True == approx(True)
assert False == approx(False)
assert True != approx(False)
assert True != approx(False, abs=2)
assert 1 != approx(True)
def test_list(self):
actual = [1 + 1e-7, 2 + 1e-8]
expected = [1, 2]
@ -611,6 +631,7 @@ class TestApprox:
def test_dict_nonnumeric(self):
assert {"a": 1.0, "b": None} == pytest.approx({"a": 1.0, "b": None})
assert {"a": 1.0, "b": 1} != pytest.approx({"a": 1.0, "b": None})
assert {"a": 1.0, "b": True} != pytest.approx({"a": 1.0, "b": False}, abs=2)
def test_dict_vs_other(self):
assert 1 != approx({"a": 0})