Merge pull request #3313 from tadeu/approx-array-scalar
Add support for `pytest.approx` comparisons between array and scalar
This commit is contained in:
commit
86d6804e60
|
@ -31,6 +31,10 @@ class ApproxBase(object):
|
||||||
or sequences of numbers.
|
or sequences of numbers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Tell numpy to use our `__eq__` operator instead of its
|
||||||
|
__array_ufunc__ = None
|
||||||
|
__array_priority__ = 100
|
||||||
|
|
||||||
def __init__(self, expected, rel=None, abs=None, nan_ok=False):
|
def __init__(self, expected, rel=None, abs=None, nan_ok=False):
|
||||||
self.expected = expected
|
self.expected = expected
|
||||||
self.abs = abs
|
self.abs = abs
|
||||||
|
@ -69,14 +73,13 @@ class ApproxNumpy(ApproxBase):
|
||||||
Perform approximate comparisons for numpy arrays.
|
Perform approximate comparisons for numpy arrays.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Tell numpy to use our `__eq__` operator instead of its.
|
|
||||||
__array_priority__ = 100
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
# It might be nice to rewrite this function to account for the
|
# It might be nice to rewrite this function to account for the
|
||||||
# shape of the array...
|
# shape of the array...
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
return "approx({0!r})".format(list(
|
return "approx({0!r})".format(list(
|
||||||
self._approx_scalar(x) for x in self.expected))
|
self._approx_scalar(x) for x in np.asarray(self.expected)))
|
||||||
|
|
||||||
if sys.version_info[0] == 2:
|
if sys.version_info[0] == 2:
|
||||||
__cmp__ = _cmp_raises_type_error
|
__cmp__ = _cmp_raises_type_error
|
||||||
|
@ -84,12 +87,15 @@ class ApproxNumpy(ApproxBase):
|
||||||
def __eq__(self, actual):
|
def __eq__(self, actual):
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
try:
|
# self.expected is supposed to always be an array here
|
||||||
actual = np.asarray(actual)
|
|
||||||
except: # noqa
|
|
||||||
raise TypeError("cannot compare '{0}' to numpy.ndarray".format(actual))
|
|
||||||
|
|
||||||
if actual.shape != self.expected.shape:
|
if not np.isscalar(actual):
|
||||||
|
try:
|
||||||
|
actual = np.asarray(actual)
|
||||||
|
except: # noqa
|
||||||
|
raise TypeError("cannot compare '{0}' to numpy.ndarray".format(actual))
|
||||||
|
|
||||||
|
if not np.isscalar(actual) and actual.shape != self.expected.shape:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return ApproxBase.__eq__(self, actual)
|
return ApproxBase.__eq__(self, actual)
|
||||||
|
@ -97,11 +103,16 @@ class ApproxNumpy(ApproxBase):
|
||||||
def _yield_comparisons(self, actual):
|
def _yield_comparisons(self, actual):
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
# We can be sure that `actual` is a numpy array, because it's
|
# `actual` can either be a numpy array or a scalar, it is treated in
|
||||||
# casted in `__eq__` before being passed to `ApproxBase.__eq__`,
|
# `__eq__` before being passed to `ApproxBase.__eq__`, which is the
|
||||||
# which is the only method that calls this one.
|
# only method that calls this one.
|
||||||
for i in np.ndindex(self.expected.shape):
|
|
||||||
yield actual[i], self.expected[i]
|
if np.isscalar(actual):
|
||||||
|
for i in np.ndindex(self.expected.shape):
|
||||||
|
yield actual, np.asscalar(self.expected[i])
|
||||||
|
else:
|
||||||
|
for i in np.ndindex(self.expected.shape):
|
||||||
|
yield np.asscalar(actual[i]), np.asscalar(self.expected[i])
|
||||||
|
|
||||||
|
|
||||||
class ApproxMapping(ApproxBase):
|
class ApproxMapping(ApproxBase):
|
||||||
|
@ -131,9 +142,6 @@ class ApproxSequence(ApproxBase):
|
||||||
Perform approximate comparisons for sequences of numbers.
|
Perform approximate comparisons for sequences of numbers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Tell numpy to use our `__eq__` operator instead of its.
|
|
||||||
__array_priority__ = 100
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
seq_type = type(self.expected)
|
seq_type = type(self.expected)
|
||||||
if seq_type not in (tuple, list, set):
|
if seq_type not in (tuple, list, set):
|
||||||
|
@ -189,6 +197,8 @@ class ApproxScalar(ApproxBase):
|
||||||
Return true if the given value is equal to the expected value within
|
Return true if the given value is equal to the expected value within
|
||||||
the pre-specified tolerance.
|
the pre-specified tolerance.
|
||||||
"""
|
"""
|
||||||
|
if _is_numpy_array(actual):
|
||||||
|
return ApproxNumpy(actual, self.abs, self.rel, self.nan_ok) == self.expected
|
||||||
|
|
||||||
# Short-circuit exact equality.
|
# Short-circuit exact equality.
|
||||||
if actual == self.expected:
|
if actual == self.expected:
|
||||||
|
@ -308,12 +318,18 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
|
||||||
>>> {'a': 0.1 + 0.2, 'b': 0.2 + 0.4} == approx({'a': 0.3, 'b': 0.6})
|
>>> {'a': 0.1 + 0.2, 'b': 0.2 + 0.4} == approx({'a': 0.3, 'b': 0.6})
|
||||||
True
|
True
|
||||||
|
|
||||||
And ``numpy`` arrays::
|
``numpy`` arrays::
|
||||||
|
|
||||||
>>> import numpy as np # doctest: +SKIP
|
>>> import numpy as np # doctest: +SKIP
|
||||||
>>> np.array([0.1, 0.2]) + np.array([0.2, 0.4]) == approx(np.array([0.3, 0.6])) # doctest: +SKIP
|
>>> np.array([0.1, 0.2]) + np.array([0.2, 0.4]) == approx(np.array([0.3, 0.6])) # doctest: +SKIP
|
||||||
True
|
True
|
||||||
|
|
||||||
|
And for a ``numpy`` array against a scalar::
|
||||||
|
|
||||||
|
>>> import numpy as np # doctest: +SKIP
|
||||||
|
>>> np.array([0.1, 0.2]) + np.array([0.2, 0.1]) == approx(0.3) # doctest: +SKIP
|
||||||
|
True
|
||||||
|
|
||||||
By default, ``approx`` considers numbers within a relative tolerance of
|
By default, ``approx`` considers numbers within a relative tolerance of
|
||||||
``1e-6`` (i.e. one part in a million) of its expected value to be equal.
|
``1e-6`` (i.e. one part in a million) of its expected value to be equal.
|
||||||
This treatment would lead to surprising results if the expected value was
|
This treatment would lead to surprising results if the expected value was
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
``pytest.approx`` now accepts comparing a numpy array with a scalar.
|
|
@ -391,3 +391,25 @@ class TestApprox(object):
|
||||||
"""
|
"""
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
op(1, approx(1, rel=1e-6, abs=1e-12))
|
op(1, approx(1, rel=1e-6, abs=1e-12))
|
||||||
|
|
||||||
|
def test_numpy_array_with_scalar(self):
|
||||||
|
np = pytest.importorskip('numpy')
|
||||||
|
|
||||||
|
actual = np.array([1 + 1e-7, 1 - 1e-8])
|
||||||
|
expected = 1.0
|
||||||
|
|
||||||
|
assert actual == approx(expected, rel=5e-7, abs=0)
|
||||||
|
assert actual != approx(expected, rel=5e-8, abs=0)
|
||||||
|
assert approx(expected, rel=5e-7, abs=0) == actual
|
||||||
|
assert approx(expected, rel=5e-8, abs=0) != actual
|
||||||
|
|
||||||
|
def test_numpy_scalar_with_array(self):
|
||||||
|
np = pytest.importorskip('numpy')
|
||||||
|
|
||||||
|
actual = 1.0
|
||||||
|
expected = np.array([1 + 1e-7, 1 - 1e-8])
|
||||||
|
|
||||||
|
assert actual == approx(expected, rel=5e-7, abs=0)
|
||||||
|
assert actual != approx(expected, rel=5e-8, abs=0)
|
||||||
|
assert approx(expected, rel=5e-7, abs=0) == actual
|
||||||
|
assert approx(expected, rel=5e-8, abs=0) != actual
|
||||||
|
|
Loading…
Reference in New Issue