Improve `numpy.approx` array-scalar comparisons

So that `self.expected` in ApproxNumpy is always a numpy array.
This commit is contained in:
Tadeu Manoel 2018-03-16 09:01:18 -03:00
parent 42c84f4f30
commit a754f00ae7
1 changed files with 8 additions and 12 deletions

View File

@ -87,14 +87,15 @@ class ApproxNumpy(ApproxBase):
def __eq__(self, actual): def __eq__(self, actual):
import numpy as np import numpy as np
# self.expected is supposed to always be an array here
if not np.isscalar(actual): if not np.isscalar(actual):
try: try:
actual = np.asarray(actual) actual = np.asarray(actual)
except: # noqa except: # noqa
raise TypeError("cannot compare '{0}' to numpy.ndarray".format(actual)) raise TypeError("cannot compare '{0}' to numpy.ndarray".format(actual))
if (not np.isscalar(self.expected) and not np.isscalar(actual) if not np.isscalar(actual) and actual.shape != self.expected.shape:
and actual.shape != self.expected.shape):
return False return False
return ApproxBase.__eq__(self, actual) return ApproxBase.__eq__(self, actual)
@ -102,16 +103,11 @@ class ApproxNumpy(ApproxBase):
def _yield_comparisons(self, actual): def _yield_comparisons(self, actual):
import numpy as np import numpy as np
# For both `actual` and `self.expected`, they can independently be # `actual` can either be a numpy array or a scalar, it is treated in
# either a `numpy.array` or a scalar (but both can't be scalar, # `__eq__` before being passed to `ApproxBase.__eq__`, which is the
# in this case an `ApproxScalar` is used). # only method that calls this one.
# They are treated in `__eq__` before being passed to
# `ApproxBase.__eq__`, which is the only method that calls this one.
if np.isscalar(self.expected): if np.isscalar(actual):
for i in np.ndindex(actual.shape):
yield np.asscalar(actual[i]), self.expected
elif np.isscalar(actual):
for i in np.ndindex(self.expected.shape): for i in np.ndindex(self.expected.shape):
yield actual, np.asscalar(self.expected[i]) yield actual, np.asscalar(self.expected[i])
else: else:
@ -202,7 +198,7 @@ class ApproxScalar(ApproxBase):
the pre-specified tolerance. the pre-specified tolerance.
""" """
if _is_numpy_array(actual): if _is_numpy_array(actual):
return actual == ApproxNumpy(self.expected, self.abs, self.rel, self.nan_ok) return ApproxNumpy(actual, self.abs, self.rel, self.nan_ok) == self.expected
# Short-circuit exact equality. # Short-circuit exact equality.
if actual == self.expected: if actual == self.expected: