support approx comparisons between Decimal and float

This commit is contained in:
Kale Kundert 2021-04-02 16:38:48 -04:00
parent 9c151a65c8
commit e3794a1dc7
No known key found for this signature in database
GPG Key ID: C6238221D17CAFAE
2 changed files with 26 additions and 8 deletions

View File

@ -273,8 +273,20 @@ class ApproxScalar(ApproxBase):
if math.isinf(abs(self.expected)): # type: ignore[arg-type]
return False
# Return true if the two numbers are within the tolerance.
# Return true if the two numbers are within the tolerance. In order to
# be flexible about which types are supported, try making the
# comparison in two ways. The first requires that the actual value can
# be subtracted from the expected value. This is necessary for complex
# numbers, which don't implement comparison operators. The second
# requires that the actual and expected values can be compared. This
# is necessary for comparing Decimals with floats, see #8495.
try:
result: bool = abs(self.expected - actual) <= self.tolerance
except TypeError:
low = self.expected - self.tolerance
high = self.expected + self.tolerance
result: bool = low <= actual <= high
return result
# Ignore type because of https://github.com/python/mypy/issues/4266.

View File

@ -276,15 +276,21 @@ class TestApprox:
def test_decimal(self):
within_1e6 = [
(Decimal("1.000001"), Decimal("1.0")),
(Decimal("-1.000001"), Decimal("-1.0")),
(Decimal("1.0000005"), Decimal("1.0")),
(Decimal("-1.0000005"), Decimal("-1.0")),
(Decimal("1.0000005"), 1.0),
(Decimal("-1.0000005"), -1.0),
]
for a, x in within_1e6:
T = type(x)
# Need to test the default values here, because some code is needed
# to account for the facts that you can't add floats to decimals.
assert a == approx(x)
assert a == approx(x, rel=Decimal("5e-6"), abs=0)
assert a != approx(x, rel=Decimal("5e-7"), abs=0)
assert approx(x, rel=Decimal("5e-6"), abs=0) == a
assert approx(x, rel=Decimal("5e-7"), abs=0) != a
assert a == approx(x, rel=T("1e-6"), abs=0)
assert a != approx(x, rel=T("1e-7"), abs=0)
assert approx(x, rel=T("1e-6"), abs=0) == a
assert approx(x, rel=T("1e-7"), abs=0) != a
def test_fraction(self):
within_1e6 = [