Merge pull request #3741 from kalekundert/approx_misc_tweaks
Miscellaneous improvements to approx()
This commit is contained in:
commit
804fc4063a
|
@ -0,0 +1 @@
|
||||||
|
Raise immediately if ``approx()`` is given an expected value of a type it doesn't understand (e.g. strings, nested dicts, etc.).
|
|
@ -0,0 +1 @@
|
||||||
|
Correctly represent the dimensions of an numpy array when calling ``repr()`` on ``approx()``.
|
|
@ -1,5 +1,8 @@
|
||||||
import math
|
import math
|
||||||
|
import pprint
|
||||||
import sys
|
import sys
|
||||||
|
from numbers import Number
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
import py
|
import py
|
||||||
from six.moves import zip, filterfalse
|
from six.moves import zip, filterfalse
|
||||||
|
@ -30,6 +33,15 @@ def _cmp_raises_type_error(self, other):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _non_numeric_type_error(value, at):
|
||||||
|
at_str = " at {}".format(at) if at else ""
|
||||||
|
return TypeError(
|
||||||
|
"cannot make approximate comparisons to non-numeric values: {!r} {}".format(
|
||||||
|
value, at_str
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# builtin pytest.approx helper
|
# builtin pytest.approx helper
|
||||||
|
|
||||||
|
|
||||||
|
@ -39,15 +51,17 @@ class ApproxBase(object):
|
||||||
or sequences of numbers.
|
or sequences of numbers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Tell numpy to use our `__eq__` operator instead of its
|
# Tell numpy to use our `__eq__` operator instead of its.
|
||||||
__array_ufunc__ = None
|
__array_ufunc__ = None
|
||||||
__array_priority__ = 100
|
__array_priority__ = 100
|
||||||
|
|
||||||
def __init__(self, expected, rel=None, abs=None, nan_ok=False):
|
def __init__(self, expected, rel=None, abs=None, nan_ok=False):
|
||||||
|
__tracebackhide__ = True
|
||||||
self.expected = expected
|
self.expected = expected
|
||||||
self.abs = abs
|
self.abs = abs
|
||||||
self.rel = rel
|
self.rel = rel
|
||||||
self.nan_ok = nan_ok
|
self.nan_ok = nan_ok
|
||||||
|
self._check_type()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -75,21 +89,32 @@ class ApproxBase(object):
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _check_type(self):
|
||||||
|
"""
|
||||||
|
Raise a TypeError if the expected value is not a valid type.
|
||||||
|
"""
|
||||||
|
# This is only a concern if the expected value is a sequence. In every
|
||||||
|
# other case, the approx() function ensures that the expected value has
|
||||||
|
# a numeric type. For this reason, the default is to do nothing. The
|
||||||
|
# classes that deal with sequences should reimplement this method to
|
||||||
|
# raise if there are any non-numeric elements in the sequence.
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _recursive_list_map(f, x):
|
||||||
|
if isinstance(x, list):
|
||||||
|
return list(_recursive_list_map(f, xi) for xi in x)
|
||||||
|
else:
|
||||||
|
return f(x)
|
||||||
|
|
||||||
|
|
||||||
class ApproxNumpy(ApproxBase):
|
class ApproxNumpy(ApproxBase):
|
||||||
"""
|
"""
|
||||||
Perform approximate comparisons for numpy arrays.
|
Perform approximate comparisons where the expected value is numpy array.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
# It might be nice to rewrite this function to account for the
|
list_scalars = _recursive_list_map(self._approx_scalar, self.expected.tolist())
|
||||||
# shape of the array...
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
list_scalars = []
|
|
||||||
for x in np.ndindex(self.expected.shape):
|
|
||||||
list_scalars.append(self._approx_scalar(np.asscalar(self.expected[x])))
|
|
||||||
|
|
||||||
return "approx({!r})".format(list_scalars)
|
return "approx({!r})".format(list_scalars)
|
||||||
|
|
||||||
if sys.version_info[0] == 2:
|
if sys.version_info[0] == 2:
|
||||||
|
@ -128,8 +153,8 @@ class ApproxNumpy(ApproxBase):
|
||||||
|
|
||||||
class ApproxMapping(ApproxBase):
|
class ApproxMapping(ApproxBase):
|
||||||
"""
|
"""
|
||||||
Perform approximate comparisons for mappings where the values are numbers
|
Perform approximate comparisons where the expected value is a mapping with
|
||||||
(the keys can be anything).
|
numeric values (the keys can be anything).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
@ -147,10 +172,20 @@ class ApproxMapping(ApproxBase):
|
||||||
for k in self.expected.keys():
|
for k in self.expected.keys():
|
||||||
yield actual[k], self.expected[k]
|
yield actual[k], self.expected[k]
|
||||||
|
|
||||||
|
def _check_type(self):
|
||||||
|
__tracebackhide__ = True
|
||||||
|
for key, value in self.expected.items():
|
||||||
|
if isinstance(value, type(self.expected)):
|
||||||
|
msg = "pytest.approx() does not support nested dictionaries: key={!r} value={!r}\n full mapping={}"
|
||||||
|
raise TypeError(msg.format(key, value, pprint.pformat(self.expected)))
|
||||||
|
elif not isinstance(value, Number):
|
||||||
|
raise _non_numeric_type_error(self.expected, at="key={!r}".format(key))
|
||||||
|
|
||||||
|
|
||||||
class ApproxSequence(ApproxBase):
|
class ApproxSequence(ApproxBase):
|
||||||
"""
|
"""
|
||||||
Perform approximate comparisons for sequences of numbers.
|
Perform approximate comparisons where the expected value is a sequence of
|
||||||
|
numbers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
@ -169,10 +204,21 @@ class ApproxSequence(ApproxBase):
|
||||||
def _yield_comparisons(self, actual):
|
def _yield_comparisons(self, actual):
|
||||||
return zip(actual, self.expected)
|
return zip(actual, self.expected)
|
||||||
|
|
||||||
|
def _check_type(self):
|
||||||
|
__tracebackhide__ = True
|
||||||
|
for index, x in enumerate(self.expected):
|
||||||
|
if isinstance(x, type(self.expected)):
|
||||||
|
msg = "pytest.approx() does not support nested data structures: {!r} at index {}\n full sequence: {}"
|
||||||
|
raise TypeError(msg.format(x, index, pprint.pformat(self.expected)))
|
||||||
|
elif not isinstance(x, Number):
|
||||||
|
raise _non_numeric_type_error(
|
||||||
|
self.expected, at="index {}".format(index)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ApproxScalar(ApproxBase):
|
class ApproxScalar(ApproxBase):
|
||||||
"""
|
"""
|
||||||
Perform approximate comparisons for single numbers only.
|
Perform approximate comparisons where the expected value is a single number.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DEFAULT_ABSOLUTE_TOLERANCE = 1e-12
|
DEFAULT_ABSOLUTE_TOLERANCE = 1e-12
|
||||||
|
@ -286,7 +332,9 @@ class ApproxScalar(ApproxBase):
|
||||||
|
|
||||||
|
|
||||||
class ApproxDecimal(ApproxScalar):
|
class ApproxDecimal(ApproxScalar):
|
||||||
from decimal import Decimal
|
"""
|
||||||
|
Perform approximate comparisons where the expected value is a decimal.
|
||||||
|
"""
|
||||||
|
|
||||||
DEFAULT_ABSOLUTE_TOLERANCE = Decimal("1e-12")
|
DEFAULT_ABSOLUTE_TOLERANCE = Decimal("1e-12")
|
||||||
DEFAULT_RELATIVE_TOLERANCE = Decimal("1e-6")
|
DEFAULT_RELATIVE_TOLERANCE = Decimal("1e-6")
|
||||||
|
@ -445,32 +493,35 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
|
||||||
__ https://docs.python.org/3/reference/datamodel.html#object.__ge__
|
__ https://docs.python.org/3/reference/datamodel.html#object.__ge__
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from decimal import Decimal
|
|
||||||
|
|
||||||
# Delegate the comparison to a class that knows how to deal with the type
|
# Delegate the comparison to a class that knows how to deal with the type
|
||||||
# of the expected value (e.g. int, float, list, dict, numpy.array, etc).
|
# of the expected value (e.g. int, float, list, dict, numpy.array, etc).
|
||||||
#
|
#
|
||||||
# This architecture is really driven by the need to support numpy arrays.
|
# The primary responsibility of these classes is to implement ``__eq__()``
|
||||||
# The only way to override `==` for arrays without requiring that approx be
|
# and ``__repr__()``. The former is used to actually check if some
|
||||||
# the left operand is to inherit the approx object from `numpy.ndarray`.
|
# "actual" value is equivalent to the given expected value within the
|
||||||
# But that can't be a general solution, because it requires (1) numpy to be
|
# allowed tolerance. The latter is used to show the user the expected
|
||||||
# installed and (2) the expected value to be a numpy array. So the general
|
# value and tolerance, in the case that a test failed.
|
||||||
# solution is to delegate each type of expected value to a different class.
|
|
||||||
#
|
#
|
||||||
# This has the advantage that it made it easy to support mapping types
|
# The actual logic for making approximate comparisons can be found in
|
||||||
# (i.e. dict). The old code accepted mapping types, but would only compare
|
# ApproxScalar, which is used to compare individual numbers. All of the
|
||||||
# their keys, which is probably not what most people would expect.
|
# other Approx classes eventually delegate to this class. The ApproxBase
|
||||||
|
# class provides some convenient methods and overloads, but isn't really
|
||||||
|
# essential.
|
||||||
|
|
||||||
if _is_numpy_array(expected):
|
__tracebackhide__ = True
|
||||||
cls = ApproxNumpy
|
|
||||||
|
if isinstance(expected, Decimal):
|
||||||
|
cls = ApproxDecimal
|
||||||
|
elif isinstance(expected, Number):
|
||||||
|
cls = ApproxScalar
|
||||||
elif isinstance(expected, Mapping):
|
elif isinstance(expected, Mapping):
|
||||||
cls = ApproxMapping
|
cls = ApproxMapping
|
||||||
elif isinstance(expected, Sequence) and not isinstance(expected, STRING_TYPES):
|
elif isinstance(expected, Sequence) and not isinstance(expected, STRING_TYPES):
|
||||||
cls = ApproxSequence
|
cls = ApproxSequence
|
||||||
elif isinstance(expected, Decimal):
|
elif _is_numpy_array(expected):
|
||||||
cls = ApproxDecimal
|
cls = ApproxNumpy
|
||||||
else:
|
else:
|
||||||
cls = ApproxScalar
|
raise _non_numeric_type_error(expected, at=None)
|
||||||
|
|
||||||
return cls(expected, rel, abs, nan_ok)
|
return cls(expected, rel, abs, nan_ok)
|
||||||
|
|
||||||
|
@ -480,17 +531,11 @@ def _is_numpy_array(obj):
|
||||||
Return true if the given object is a numpy array. Make a special effort to
|
Return true if the given object is a numpy array. Make a special effort to
|
||||||
avoid importing numpy unless it's really necessary.
|
avoid importing numpy unless it's really necessary.
|
||||||
"""
|
"""
|
||||||
import inspect
|
import sys
|
||||||
|
|
||||||
for cls in inspect.getmro(type(obj)):
|
|
||||||
if cls.__module__ == "numpy":
|
|
||||||
try:
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
return isinstance(obj, np.ndarray)
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
np = sys.modules.get("numpy")
|
||||||
|
if np is not None:
|
||||||
|
return isinstance(obj, np.ndarray)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -59,17 +59,21 @@ class TestApprox(object):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_repr_0d_array(self, plus_minus):
|
@pytest.mark.parametrize(
|
||||||
|
"value, repr_string",
|
||||||
|
[
|
||||||
|
(5., "approx(5.0 {pm} 5.0e-06)"),
|
||||||
|
([5.], "approx([5.0 {pm} 5.0e-06])"),
|
||||||
|
([[5.]], "approx([[5.0 {pm} 5.0e-06]])"),
|
||||||
|
([[5., 6.]], "approx([[5.0 {pm} 5.0e-06, 6.0 {pm} 6.0e-06]])"),
|
||||||
|
([[5.], [6.]], "approx([[5.0 {pm} 5.0e-06], [6.0 {pm} 6.0e-06]])"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_repr_nd_array(self, plus_minus, value, repr_string):
|
||||||
|
"""Make sure that arrays of all different dimensions are repr'd correctly."""
|
||||||
np = pytest.importorskip("numpy")
|
np = pytest.importorskip("numpy")
|
||||||
np_array = np.array(5.)
|
np_array = np.array(value)
|
||||||
assert approx(np_array) == 5.0
|
assert repr(approx(np_array)) == repr_string.format(pm=plus_minus)
|
||||||
string_expected = "approx([5.0 {} 5.0e-06])".format(plus_minus)
|
|
||||||
|
|
||||||
assert repr(approx(np_array)) == string_expected
|
|
||||||
|
|
||||||
np_array = np.array([5.])
|
|
||||||
assert approx(np_array) == 5.0
|
|
||||||
assert repr(approx(np_array)) == string_expected
|
|
||||||
|
|
||||||
def test_operator_overloading(self):
|
def test_operator_overloading(self):
|
||||||
assert 1 == approx(1, rel=1e-6, abs=1e-12)
|
assert 1 == approx(1, rel=1e-6, abs=1e-12)
|
||||||
|
@ -439,6 +443,21 @@ class TestApprox(object):
|
||||||
["*At index 0 diff: 3 != 4 * {}".format(expected), "=* 1 failed in *="]
|
["*At index 0 diff: 3 != 4 * {}".format(expected), "=* 1 failed in *="]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"x",
|
||||||
|
[
|
||||||
|
pytest.param(None),
|
||||||
|
pytest.param("string"),
|
||||||
|
pytest.param(["string"], id="nested-str"),
|
||||||
|
pytest.param([[1]], id="nested-list"),
|
||||||
|
pytest.param({"key": "string"}, id="dict-with-string"),
|
||||||
|
pytest.param({"key": {"key": 1}}, id="nested-dict"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_expected_value_type_error(self, x):
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
approx(x)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"op",
|
"op",
|
||||||
[
|
[
|
||||||
|
|
Loading…
Reference in New Issue