Add support for pytest.approx comparisons between array and scalar
This commit is contained in:
@@ -31,6 +31,10 @@ class ApproxBase(object):
|
||||
or sequences of numbers.
|
||||
"""
|
||||
|
||||
# Tell numpy to use our `__eq__` operator instead of its when left side in a numpy array but right side is
|
||||
# an instance of ApproxBase
|
||||
__array_ufunc__ = None
|
||||
|
||||
def __init__(self, expected, rel=None, abs=None, nan_ok=False):
|
||||
self.expected = expected
|
||||
self.abs = abs
|
||||
@@ -89,7 +93,7 @@ class ApproxNumpy(ApproxBase):
|
||||
except: # noqa
|
||||
raise TypeError("cannot compare '{0}' to numpy.ndarray".format(actual))
|
||||
|
||||
if actual.shape != self.expected.shape:
|
||||
if not np.isscalar(self.expected) and actual.shape != self.expected.shape:
|
||||
return False
|
||||
|
||||
return ApproxBase.__eq__(self, actual)
|
||||
@@ -100,8 +104,13 @@ class ApproxNumpy(ApproxBase):
|
||||
# We can be sure that `actual` is a numpy array, because it's
|
||||
# casted in `__eq__` before being passed to `ApproxBase.__eq__`,
|
||||
# which is the only method that calls this one.
|
||||
for i in np.ndindex(self.expected.shape):
|
||||
yield actual[i], self.expected[i]
|
||||
|
||||
if np.isscalar(self.expected):
|
||||
for i in np.ndindex(actual.shape):
|
||||
yield actual[i], self.expected
|
||||
else:
|
||||
for i in np.ndindex(self.expected.shape):
|
||||
yield actual[i], self.expected[i]
|
||||
|
||||
|
||||
class ApproxMapping(ApproxBase):
|
||||
@@ -189,6 +198,8 @@ class ApproxScalar(ApproxBase):
|
||||
Return true if the given value is equal to the expected value within
|
||||
the pre-specified tolerance.
|
||||
"""
|
||||
if _is_numpy_array(actual):
|
||||
return actual == ApproxNumpy(self.expected, self.abs, self.rel, self.nan_ok)
|
||||
|
||||
# Short-circuit exact equality.
|
||||
if actual == self.expected:
|
||||
@@ -308,12 +319,18 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
|
||||
>>> {'a': 0.1 + 0.2, 'b': 0.2 + 0.4} == approx({'a': 0.3, 'b': 0.6})
|
||||
True
|
||||
|
||||
And ``numpy`` arrays::
|
||||
``numpy`` arrays::
|
||||
|
||||
>>> import numpy as np # doctest: +SKIP
|
||||
>>> np.array([0.1, 0.2]) + np.array([0.2, 0.4]) == approx(np.array([0.3, 0.6])) # doctest: +SKIP
|
||||
True
|
||||
|
||||
And for a ``numpy`` array against a scalar::
|
||||
|
||||
>>> import numpy as np # doctest: +SKIP
|
||||
>>> np.array([0.1, 0.2]) + np.array([0.2, 0.1]) == approx(0.3) # doctest: +SKIP
|
||||
True
|
||||
|
||||
By default, ``approx`` considers numbers within a relative tolerance of
|
||||
``1e-6`` (i.e. one part in a million) of its expected value to be equal.
|
||||
This treatment would lead to surprising results if the expected value was
|
||||
|
||||
Reference in New Issue
Block a user