diff --git a/_pytest/python_api.py b/_pytest/python_api.py index e2c83aeab..9de4dd2a8 100644 --- a/_pytest/python_api.py +++ b/_pytest/python_api.py @@ -87,14 +87,15 @@ class ApproxNumpy(ApproxBase): def __eq__(self, actual): import numpy as np + # self.expected is supposed to always be an array here + if not np.isscalar(actual): try: actual = np.asarray(actual) except: # noqa raise TypeError("cannot compare '{0}' to numpy.ndarray".format(actual)) - if (not np.isscalar(self.expected) and not np.isscalar(actual) - and actual.shape != self.expected.shape): + if not np.isscalar(actual) and actual.shape != self.expected.shape: return False return ApproxBase.__eq__(self, actual) @@ -102,16 +103,11 @@ class ApproxNumpy(ApproxBase): def _yield_comparisons(self, actual): import numpy as np - # For both `actual` and `self.expected`, they can independently be - # either a `numpy.array` or a scalar (but both can't be scalar, - # in this case an `ApproxScalar` is used). - # They are treated in `__eq__` before being passed to - # `ApproxBase.__eq__`, which is the only method that calls this one. + # `actual` can either be a numpy array or a scalar, it is treated in + # `__eq__` before being passed to `ApproxBase.__eq__`, which is the + # only method that calls this one. - if np.isscalar(self.expected): - for i in np.ndindex(actual.shape): - yield np.asscalar(actual[i]), self.expected - elif np.isscalar(actual): + if np.isscalar(actual): for i in np.ndindex(self.expected.shape): yield actual, np.asscalar(self.expected[i]) else: @@ -202,7 +198,7 @@ class ApproxScalar(ApproxBase): the pre-specified tolerance. """ if _is_numpy_array(actual): - return actual == ApproxNumpy(self.expected, self.abs, self.rel, self.nan_ok) + return ApproxNumpy(actual, self.abs, self.rel, self.nan_ok) == self.expected # Short-circuit exact equality. if actual == self.expected: