Update to fix issue #6

This commit adds a fixed enhancement for issue #6 to do with pytest not working for comparison of Decimal type and float (or other scalar) type
This commit is contained in:
samuel-barrett 2022-04-13 21:49:11 -07:00
parent 00ad12b9db
commit 99e1fe5f7b
3 changed files with 54 additions and 6 deletions

View File

@ -450,9 +450,26 @@ class ApproxScalar(ApproxBase):
# for compatibility with complex numbers. # for compatibility with complex numbers.
if math.isinf(abs(self.expected)): # type: ignore[arg-type] if math.isinf(abs(self.expected)): # type: ignore[arg-type]
return False return False
tolerance = self.tolerance
expected = self.expected
if (isinstance(self.expected, Decimal) and not isinstance(actual, Decimal)):
try:
actual = Decimal(str(actual))
tolerance = Decimal(str(tolerance))
except TypeError: #
return False
elif (isinstance(actual, Decimal) and not isinstance(self.expected, Decimal)):
try :
expected = Decimal(str(expected))
tolerance = Decimal(str(tolerance))
except TypeError:
return False
# Return true if the two numbers are within the tolerance. # Return true if the two numbers are within the tolerance.
result: bool = abs(self.expected - actual) <= self.tolerance result: bool = (abs(actual - expected) <= tolerance ) # type: ignore[arg-type]
return result return result
# Ignore type because of https://github.com/python/mypy/issues/4266. # Ignore type because of https://github.com/python/mypy/issues/4266.
@ -491,9 +508,15 @@ class ApproxScalar(ApproxBase):
# we've made sure the user didn't ask for an absolute tolerance only, # we've made sure the user didn't ask for an absolute tolerance only,
# because we don't want to raise errors about the relative tolerance if # because we don't want to raise errors about the relative tolerance if
# we aren't even going to use it. # we aren't even going to use it.
relative_tolerance = set_default(
self.rel, self.DEFAULT_RELATIVE_TOLERANCE if isinstance(self.rel, Decimal):
) * abs(self.expected) relative_tolerance = set_default(
self.rel, self.DEFAULT_RELATIVE_TOLERANCE
)*Decimal(str(abs(self.expected)))
else:
relative_tolerance = set_default(
self.rel, self.DEFAULT_RELATIVE_TOLERANCE
) * abs(self.expected)
if relative_tolerance < 0: if relative_tolerance < 0:
raise ValueError( raise ValueError(
@ -572,7 +595,7 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
>>> {'a': 0.1 + 0.2, 'b': 0.2 + 0.4} == approx({'a': 0.3, 'b': 0.6}) >>> {'a': 0.1 + 0.2, 'b': 0.2 + 0.4} == approx({'a': 0.3, 'b': 0.6})
True True
The comparison will be true if both mappings have the same keys and their The comparision will be true if both mappings have the same keys and their
respective values match the expected tolerances. respective values match the expected tolerances.
**Tolerances** **Tolerances**

View File

@ -341,6 +341,8 @@ class TestApprox:
(0.0, -0.0), (0.0, -0.0),
(345678, 345678), (345678, 345678),
(Decimal("1.0001"), Decimal("1.0001")), (Decimal("1.0001"), Decimal("1.0001")),
(1.0001, Decimal("1.0001")),
(Decimal("1.0001"), 1.0001),
(Fraction(1, 3), Fraction(-1, -3)), (Fraction(1, 3), Fraction(-1, -3)),
] ]
for a, x in examples: for a, x in examples:
@ -515,6 +517,10 @@ class TestApprox:
within_1e6 = [ within_1e6 = [
(Decimal("1.000001"), Decimal("1.0")), (Decimal("1.000001"), Decimal("1.0")),
(Decimal("-1.000001"), Decimal("-1.0")), (Decimal("-1.000001"), Decimal("-1.0")),
(-1.000001, Decimal("-1.0")),
(1.000001, Decimal("1.0")),
(Decimal("1.000001"), 1.0),
(Decimal("-1.000001"), -1.0),
] ]
for a, x in within_1e6: for a, x in within_1e6:
assert a == approx(x) assert a == approx(x)
@ -562,6 +568,15 @@ class TestApprox:
expected = [Decimal("1"), Decimal("2")] expected = [Decimal("1"), Decimal("2")]
assert actual == approx(expected) assert actual == approx(expected)
expected = [1, 2]
assert actual == approx(expected)
expected = [Decimal("1"), Decimal("2")]
actual = [1.000001, 2.000001]
assert actual == approx(expected)
def test_list_wrong_len(self): def test_list_wrong_len(self):
assert [1, 2] != approx([1]) assert [1, 2] != approx([1])
@ -603,6 +618,15 @@ class TestApprox:
expected = {"b": Decimal("2"), "a": Decimal("1")} expected = {"b": Decimal("2"), "a": Decimal("1")}
assert actual == approx(expected) assert actual == approx(expected)
actual = {"a": 1.000001, "b": 2.000001}
assert actual == approx(expected)
actual = {"a": Decimal("1.000001"), "b": Decimal("2.000001")}
expected = {"b": 2, "a": 1}
assert actual == approx(expected)
def test_dict_wrong_len(self): def test_dict_wrong_len(self):
assert {"a": 1, "b": 2} != approx({"a": 1}) assert {"a": 1, "b": 2} != approx({"a": 1})
@ -878,3 +902,4 @@ class TestApprox:
"""pytest.approx() should raise an error on unordered sequences (#9692).""" """pytest.approx() should raise an error on unordered sequences (#9692)."""
with pytest.raises(TypeError, match="only supports ordered sequences"): with pytest.raises(TypeError, match="only supports ordered sequences"):
assert {1, 2, 3} == approx({1, 2, 3}) assert {1, 2, 3} == approx({1, 2, 3})

View File

@ -51,7 +51,7 @@ class TestCollector:
fn3 = pytester.collect_by_name(modcol, "test_fail") fn3 = pytester.collect_by_name(modcol, "test_fail")
assert isinstance(fn3, pytest.Function) assert isinstance(fn3, pytest.Function)
assert not (fn1 == fn3) assert fn1 != fn3
assert fn1 != fn3 assert fn1 != fn3
for fn in fn1, fn2, fn3: for fn in fn1, fn2, fn3: