Improve `numpy.approx` array-scalar comparisons
So that `self.expected` in ApproxNumpy is always a numpy array.
This commit is contained in:
parent
42c84f4f30
commit
a754f00ae7
|
@ -87,14 +87,15 @@ class ApproxNumpy(ApproxBase):
|
||||||
def __eq__(self, actual):
|
def __eq__(self, actual):
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
# self.expected is supposed to always be an array here
|
||||||
|
|
||||||
if not np.isscalar(actual):
|
if not np.isscalar(actual):
|
||||||
try:
|
try:
|
||||||
actual = np.asarray(actual)
|
actual = np.asarray(actual)
|
||||||
except: # noqa
|
except: # noqa
|
||||||
raise TypeError("cannot compare '{0}' to numpy.ndarray".format(actual))
|
raise TypeError("cannot compare '{0}' to numpy.ndarray".format(actual))
|
||||||
|
|
||||||
if (not np.isscalar(self.expected) and not np.isscalar(actual)
|
if not np.isscalar(actual) and actual.shape != self.expected.shape:
|
||||||
and actual.shape != self.expected.shape):
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return ApproxBase.__eq__(self, actual)
|
return ApproxBase.__eq__(self, actual)
|
||||||
|
@ -102,16 +103,11 @@ class ApproxNumpy(ApproxBase):
|
||||||
def _yield_comparisons(self, actual):
|
def _yield_comparisons(self, actual):
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
# For both `actual` and `self.expected`, they can independently be
|
# `actual` can either be a numpy array or a scalar, it is treated in
|
||||||
# either a `numpy.array` or a scalar (but both can't be scalar,
|
# `__eq__` before being passed to `ApproxBase.__eq__`, which is the
|
||||||
# in this case an `ApproxScalar` is used).
|
# only method that calls this one.
|
||||||
# They are treated in `__eq__` before being passed to
|
|
||||||
# `ApproxBase.__eq__`, which is the only method that calls this one.
|
|
||||||
|
|
||||||
if np.isscalar(self.expected):
|
if np.isscalar(actual):
|
||||||
for i in np.ndindex(actual.shape):
|
|
||||||
yield np.asscalar(actual[i]), self.expected
|
|
||||||
elif np.isscalar(actual):
|
|
||||||
for i in np.ndindex(self.expected.shape):
|
for i in np.ndindex(self.expected.shape):
|
||||||
yield actual, np.asscalar(self.expected[i])
|
yield actual, np.asscalar(self.expected[i])
|
||||||
else:
|
else:
|
||||||
|
@ -202,7 +198,7 @@ class ApproxScalar(ApproxBase):
|
||||||
the pre-specified tolerance.
|
the pre-specified tolerance.
|
||||||
"""
|
"""
|
||||||
if _is_numpy_array(actual):
|
if _is_numpy_array(actual):
|
||||||
return actual == ApproxNumpy(self.expected, self.abs, self.rel, self.nan_ok)
|
return ApproxNumpy(actual, self.abs, self.rel, self.nan_ok) == self.expected
|
||||||
|
|
||||||
# Short-circuit exact equality.
|
# Short-circuit exact equality.
|
||||||
if actual == self.expected:
|
if actual == self.expected:
|
||||||
|
|
Loading…
Reference in New Issue