diff --git a/_pytest/python_api.py b/_pytest/python_api.py index 3dce7f6b4..9de4dd2a8 100644 --- a/_pytest/python_api.py +++ b/_pytest/python_api.py @@ -31,6 +31,10 @@ class ApproxBase(object): or sequences of numbers. """ + # Tell numpy to use our `__eq__` operator instead of its + __array_ufunc__ = None + __array_priority__ = 100 + def __init__(self, expected, rel=None, abs=None, nan_ok=False): self.expected = expected self.abs = abs @@ -69,14 +73,13 @@ class ApproxNumpy(ApproxBase): Perform approximate comparisons for numpy arrays. """ - # Tell numpy to use our `__eq__` operator instead of its. - __array_priority__ = 100 - def __repr__(self): # It might be nice to rewrite this function to account for the # shape of the array... + import numpy as np + return "approx({0!r})".format(list( - self._approx_scalar(x) for x in self.expected)) + self._approx_scalar(x) for x in np.asarray(self.expected))) if sys.version_info[0] == 2: __cmp__ = _cmp_raises_type_error @@ -84,12 +87,15 @@ class ApproxNumpy(ApproxBase): def __eq__(self, actual): import numpy as np - try: - actual = np.asarray(actual) - except: # noqa - raise TypeError("cannot compare '{0}' to numpy.ndarray".format(actual)) + # self.expected is supposed to always be an array here - if actual.shape != self.expected.shape: + if not np.isscalar(actual): + try: + actual = np.asarray(actual) + except: # noqa + raise TypeError("cannot compare '{0}' to numpy.ndarray".format(actual)) + + if not np.isscalar(actual) and actual.shape != self.expected.shape: return False return ApproxBase.__eq__(self, actual) @@ -97,11 +103,16 @@ class ApproxNumpy(ApproxBase): def _yield_comparisons(self, actual): import numpy as np - # We can be sure that `actual` is a numpy array, because it's - # casted in `__eq__` before being passed to `ApproxBase.__eq__`, - # which is the only method that calls this one. - for i in np.ndindex(self.expected.shape): - yield actual[i], self.expected[i] + # `actual` can either be a numpy array or a scalar, it is treated in + # `__eq__` before being passed to `ApproxBase.__eq__`, which is the + # only method that calls this one. + + if np.isscalar(actual): + for i in np.ndindex(self.expected.shape): + yield actual, np.asscalar(self.expected[i]) + else: + for i in np.ndindex(self.expected.shape): + yield np.asscalar(actual[i]), np.asscalar(self.expected[i]) class ApproxMapping(ApproxBase): @@ -131,9 +142,6 @@ class ApproxSequence(ApproxBase): Perform approximate comparisons for sequences of numbers. """ - # Tell numpy to use our `__eq__` operator instead of its. - __array_priority__ = 100 - def __repr__(self): seq_type = type(self.expected) if seq_type not in (tuple, list, set): @@ -189,6 +197,8 @@ class ApproxScalar(ApproxBase): Return true if the given value is equal to the expected value within the pre-specified tolerance. """ + if _is_numpy_array(actual): + return ApproxNumpy(actual, self.abs, self.rel, self.nan_ok) == self.expected # Short-circuit exact equality. if actual == self.expected: @@ -308,12 +318,18 @@ def approx(expected, rel=None, abs=None, nan_ok=False): >>> {'a': 0.1 + 0.2, 'b': 0.2 + 0.4} == approx({'a': 0.3, 'b': 0.6}) True - And ``numpy`` arrays:: + ``numpy`` arrays:: >>> import numpy as np # doctest: +SKIP >>> np.array([0.1, 0.2]) + np.array([0.2, 0.4]) == approx(np.array([0.3, 0.6])) # doctest: +SKIP True + And for a ``numpy`` array against a scalar:: + + >>> import numpy as np # doctest: +SKIP + >>> np.array([0.1, 0.2]) + np.array([0.2, 0.1]) == approx(0.3) # doctest: +SKIP + True + By default, ``approx`` considers numbers within a relative tolerance of ``1e-6`` (i.e. one part in a million) of its expected value to be equal. This treatment would lead to surprising results if the expected value was diff --git a/changelog/3312.feature b/changelog/3312.feature new file mode 100644 index 000000000..ffb4df8e9 --- /dev/null +++ b/changelog/3312.feature @@ -0,0 +1 @@ +``pytest.approx`` now accepts comparing a numpy array with a scalar. diff --git a/testing/python/approx.py b/testing/python/approx.py index 341e5fcff..9ca21bdf8 100644 --- a/testing/python/approx.py +++ b/testing/python/approx.py @@ -391,3 +391,25 @@ class TestApprox(object): """ with pytest.raises(TypeError): op(1, approx(1, rel=1e-6, abs=1e-12)) + + def test_numpy_array_with_scalar(self): + np = pytest.importorskip('numpy') + + actual = np.array([1 + 1e-7, 1 - 1e-8]) + expected = 1.0 + + assert actual == approx(expected, rel=5e-7, abs=0) + assert actual != approx(expected, rel=5e-8, abs=0) + assert approx(expected, rel=5e-7, abs=0) == actual + assert approx(expected, rel=5e-8, abs=0) != actual + + def test_numpy_scalar_with_array(self): + np = pytest.importorskip('numpy') + + actual = 1.0 + expected = np.array([1 + 1e-7, 1 - 1e-8]) + + assert actual == approx(expected, rel=5e-7, abs=0) + assert actual != approx(expected, rel=5e-8, abs=0) + assert approx(expected, rel=5e-7, abs=0) == actual + assert approx(expected, rel=5e-8, abs=0) != actual