Use exact comparison for bool in approx()

Fixes #9353.
This commit is contained in:
Jakob van Santen 2021-11-30 11:52:28 +01:00
parent fa240b0bb4
commit f11965783c
3 changed files with 46 additions and 11 deletions

View File

@ -0,0 +1 @@
`approx()` now uses strict equality when `type(expected) == bool`.

View File

@ -265,10 +265,16 @@ class ApproxMapping(ApproxBase):
max_abs_diff = max( max_abs_diff = max(
max_abs_diff, abs(approx_value.expected - other_value) max_abs_diff, abs(approx_value.expected - other_value)
) )
try:
max_rel_diff = max( max_rel_diff = max(
max_rel_diff, max_rel_diff,
abs((approx_value.expected - other_value) / approx_value.expected), abs(
(approx_value.expected - other_value)
/ approx_value.expected
),
) )
except ZeroDivisionError:
pass
different_ids.append(approx_key) different_ids.append(approx_key)
message_data = [ message_data = [
@ -395,8 +401,12 @@ class ApproxScalar(ApproxBase):
# Don't show a tolerance for values that aren't compared using # Don't show a tolerance for values that aren't compared using
# tolerances, i.e. non-numerics and infinities. Need to call abs to # tolerances, i.e. non-numerics and infinities. Need to call abs to
# handle complex numbers, e.g. (inf + 1j). # handle complex numbers, e.g. (inf + 1j).
if (not isinstance(self.expected, (Complex, Decimal))) or math.isinf( if (
abs(self.expected) # type: ignore[arg-type] isinstance(self.expected, bool)
or (not isinstance(self.expected, (Complex, Decimal)))
or math.isinf(
abs(self.expected) or isinstance(self.expected, bool) # type: ignore[arg-type]
)
): ):
return str(self.expected) return str(self.expected)
@ -424,17 +434,20 @@ class ApproxScalar(ApproxBase):
# numpy<1.13. See #3748. # numpy<1.13. See #3748.
return all(self.__eq__(a) for a in asarray.flat) return all(self.__eq__(a) for a in asarray.flat)
# Short-circuit exact equality. # Short-circuit exact equality, except for bool
if actual == self.expected: if isinstance(self.expected, bool) and not isinstance(actual, bool):
return False
elif actual == self.expected:
return True return True
# If either type is non-numeric, fall back to strict equality. # If either type is non-numeric, fall back to strict equality.
# NB: we need Complex, rather than just Number, to ensure that __abs__, # NB: we need Complex, rather than just Number, to ensure that __abs__,
# __sub__, and __float__ are defined. # __sub__, and __float__ are defined. Also, consider bool to be
# nonnumeric, even though it has the required arithmetic.
if not ( if not (
isinstance(self.expected, (Complex, Decimal)) isinstance(self.expected, (Complex, Decimal))
and isinstance(actual, (Complex, Decimal)) and isinstance(actual, (Complex, Decimal))
): ) or isinstance(self.expected, bool):
return False return False
# Allow the user to control whether NaNs are considered equal to each # Allow the user to control whether NaNs are considered equal to each

View File

@ -88,7 +88,7 @@ def assert_approx_raises_regex(pytestconfig):
return do_assert return do_assert
SOME_FLOAT = r"[+-]?([0-9]*[.])?[0-9]+\s*" SOME_FLOAT = r"[+-]?((?:([0-9]*[.])?[0-9]+(e-?[0-9]+)?)|inf|nan)\s*"
SOME_INT = r"[0-9]+\s*" SOME_INT = r"[0-9]+\s*"
@ -96,6 +96,19 @@ class TestApprox:
def test_error_messages(self, assert_approx_raises_regex): def test_error_messages(self, assert_approx_raises_regex):
np = pytest.importorskip("numpy") np = pytest.importorskip("numpy")
# treat bool exactly
assert_approx_raises_regex(
{"a": 1.0, "b": True},
{"a": 1.0, "b": False},
[
" comparison failed. Mismatched elements: 1 / 2:",
f" Max absolute difference: {SOME_FLOAT}",
f" Max relative difference: {SOME_FLOAT}",
r" Index\s+\| Obtained\s+\| Expected",
r".*(True|False)\s+",
],
)
assert_approx_raises_regex( assert_approx_raises_regex(
2.0, 2.0,
1.0, 1.0,
@ -546,6 +559,13 @@ class TestApprox:
assert approx(x, rel=5e-6, abs=0) == a assert approx(x, rel=5e-6, abs=0) == a
assert approx(x, rel=5e-7, abs=0) != a assert approx(x, rel=5e-7, abs=0) != a
def test_bool(self):
assert True == approx(True)
assert False == approx(False)
assert True != approx(False)
assert True != approx(False, abs=2)
assert 1 != approx(True)
def test_list(self): def test_list(self):
actual = [1 + 1e-7, 2 + 1e-8] actual = [1 + 1e-7, 2 + 1e-8]
expected = [1, 2] expected = [1, 2]
@ -611,6 +631,7 @@ class TestApprox:
def test_dict_nonnumeric(self): def test_dict_nonnumeric(self):
assert {"a": 1.0, "b": None} == pytest.approx({"a": 1.0, "b": None}) assert {"a": 1.0, "b": None} == pytest.approx({"a": 1.0, "b": None})
assert {"a": 1.0, "b": 1} != pytest.approx({"a": 1.0, "b": None}) assert {"a": 1.0, "b": 1} != pytest.approx({"a": 1.0, "b": None})
assert {"a": 1.0, "b": True} != pytest.approx({"a": 1.0, "b": False}, abs=2)
def test_dict_vs_other(self): def test_dict_vs_other(self):
assert 1 != approx({"a": 0}) assert 1 != approx({"a": 0})