Merge pull request #3741 from kalekundert/approx_misc_tweaks

Miscellaneous improvements to approx()
This commit is contained in:
Bruno Oliveira 2018-08-01 23:40:21 -03:00 committed by GitHub
commit 804fc4063a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 117 additions and 51 deletions

View File

@ -0,0 +1 @@
Raise immediately if ``approx()`` is given an expected value of a type it doesn't understand (e.g. strings, nested dicts, etc.).

View File

@ -0,0 +1 @@
Correctly represent the dimensions of an numpy array when calling ``repr()`` on ``approx()``.

View File

@ -1,5 +1,8 @@
import math import math
import pprint
import sys import sys
from numbers import Number
from decimal import Decimal
import py import py
from six.moves import zip, filterfalse from six.moves import zip, filterfalse
@ -30,6 +33,15 @@ def _cmp_raises_type_error(self, other):
) )
def _non_numeric_type_error(value, at):
at_str = " at {}".format(at) if at else ""
return TypeError(
"cannot make approximate comparisons to non-numeric values: {!r} {}".format(
value, at_str
)
)
# builtin pytest.approx helper # builtin pytest.approx helper
@ -39,15 +51,17 @@ class ApproxBase(object):
or sequences of numbers. or sequences of numbers.
""" """
# Tell numpy to use our `__eq__` operator instead of its # Tell numpy to use our `__eq__` operator instead of its.
__array_ufunc__ = None __array_ufunc__ = None
__array_priority__ = 100 __array_priority__ = 100
def __init__(self, expected, rel=None, abs=None, nan_ok=False): def __init__(self, expected, rel=None, abs=None, nan_ok=False):
__tracebackhide__ = True
self.expected = expected self.expected = expected
self.abs = abs self.abs = abs
self.rel = rel self.rel = rel
self.nan_ok = nan_ok self.nan_ok = nan_ok
self._check_type()
def __repr__(self): def __repr__(self):
raise NotImplementedError raise NotImplementedError
@ -75,21 +89,32 @@ class ApproxBase(object):
""" """
raise NotImplementedError raise NotImplementedError
def _check_type(self):
"""
Raise a TypeError if the expected value is not a valid type.
"""
# This is only a concern if the expected value is a sequence. In every
# other case, the approx() function ensures that the expected value has
# a numeric type. For this reason, the default is to do nothing. The
# classes that deal with sequences should reimplement this method to
# raise if there are any non-numeric elements in the sequence.
pass
def _recursive_list_map(f, x):
if isinstance(x, list):
return list(_recursive_list_map(f, xi) for xi in x)
else:
return f(x)
class ApproxNumpy(ApproxBase): class ApproxNumpy(ApproxBase):
""" """
Perform approximate comparisons for numpy arrays. Perform approximate comparisons where the expected value is numpy array.
""" """
def __repr__(self): def __repr__(self):
# It might be nice to rewrite this function to account for the list_scalars = _recursive_list_map(self._approx_scalar, self.expected.tolist())
# shape of the array...
import numpy as np
list_scalars = []
for x in np.ndindex(self.expected.shape):
list_scalars.append(self._approx_scalar(np.asscalar(self.expected[x])))
return "approx({!r})".format(list_scalars) return "approx({!r})".format(list_scalars)
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
@ -128,8 +153,8 @@ class ApproxNumpy(ApproxBase):
class ApproxMapping(ApproxBase): class ApproxMapping(ApproxBase):
""" """
Perform approximate comparisons for mappings where the values are numbers Perform approximate comparisons where the expected value is a mapping with
(the keys can be anything). numeric values (the keys can be anything).
""" """
def __repr__(self): def __repr__(self):
@ -147,10 +172,20 @@ class ApproxMapping(ApproxBase):
for k in self.expected.keys(): for k in self.expected.keys():
yield actual[k], self.expected[k] yield actual[k], self.expected[k]
def _check_type(self):
__tracebackhide__ = True
for key, value in self.expected.items():
if isinstance(value, type(self.expected)):
msg = "pytest.approx() does not support nested dictionaries: key={!r} value={!r}\n full mapping={}"
raise TypeError(msg.format(key, value, pprint.pformat(self.expected)))
elif not isinstance(value, Number):
raise _non_numeric_type_error(self.expected, at="key={!r}".format(key))
class ApproxSequence(ApproxBase): class ApproxSequence(ApproxBase):
""" """
Perform approximate comparisons for sequences of numbers. Perform approximate comparisons where the expected value is a sequence of
numbers.
""" """
def __repr__(self): def __repr__(self):
@ -169,10 +204,21 @@ class ApproxSequence(ApproxBase):
def _yield_comparisons(self, actual): def _yield_comparisons(self, actual):
return zip(actual, self.expected) return zip(actual, self.expected)
def _check_type(self):
__tracebackhide__ = True
for index, x in enumerate(self.expected):
if isinstance(x, type(self.expected)):
msg = "pytest.approx() does not support nested data structures: {!r} at index {}\n full sequence: {}"
raise TypeError(msg.format(x, index, pprint.pformat(self.expected)))
elif not isinstance(x, Number):
raise _non_numeric_type_error(
self.expected, at="index {}".format(index)
)
class ApproxScalar(ApproxBase): class ApproxScalar(ApproxBase):
""" """
Perform approximate comparisons for single numbers only. Perform approximate comparisons where the expected value is a single number.
""" """
DEFAULT_ABSOLUTE_TOLERANCE = 1e-12 DEFAULT_ABSOLUTE_TOLERANCE = 1e-12
@ -286,7 +332,9 @@ class ApproxScalar(ApproxBase):
class ApproxDecimal(ApproxScalar): class ApproxDecimal(ApproxScalar):
from decimal import Decimal """
Perform approximate comparisons where the expected value is a decimal.
"""
DEFAULT_ABSOLUTE_TOLERANCE = Decimal("1e-12") DEFAULT_ABSOLUTE_TOLERANCE = Decimal("1e-12")
DEFAULT_RELATIVE_TOLERANCE = Decimal("1e-6") DEFAULT_RELATIVE_TOLERANCE = Decimal("1e-6")
@ -445,32 +493,35 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
__ https://docs.python.org/3/reference/datamodel.html#object.__ge__ __ https://docs.python.org/3/reference/datamodel.html#object.__ge__
""" """
from decimal import Decimal
# Delegate the comparison to a class that knows how to deal with the type # Delegate the comparison to a class that knows how to deal with the type
# of the expected value (e.g. int, float, list, dict, numpy.array, etc). # of the expected value (e.g. int, float, list, dict, numpy.array, etc).
# #
# This architecture is really driven by the need to support numpy arrays. # The primary responsibility of these classes is to implement ``__eq__()``
# The only way to override `==` for arrays without requiring that approx be # and ``__repr__()``. The former is used to actually check if some
# the left operand is to inherit the approx object from `numpy.ndarray`. # "actual" value is equivalent to the given expected value within the
# But that can't be a general solution, because it requires (1) numpy to be # allowed tolerance. The latter is used to show the user the expected
# installed and (2) the expected value to be a numpy array. So the general # value and tolerance, in the case that a test failed.
# solution is to delegate each type of expected value to a different class.
# #
# This has the advantage that it made it easy to support mapping types # The actual logic for making approximate comparisons can be found in
# (i.e. dict). The old code accepted mapping types, but would only compare # ApproxScalar, which is used to compare individual numbers. All of the
# their keys, which is probably not what most people would expect. # other Approx classes eventually delegate to this class. The ApproxBase
# class provides some convenient methods and overloads, but isn't really
# essential.
if _is_numpy_array(expected): __tracebackhide__ = True
cls = ApproxNumpy
if isinstance(expected, Decimal):
cls = ApproxDecimal
elif isinstance(expected, Number):
cls = ApproxScalar
elif isinstance(expected, Mapping): elif isinstance(expected, Mapping):
cls = ApproxMapping cls = ApproxMapping
elif isinstance(expected, Sequence) and not isinstance(expected, STRING_TYPES): elif isinstance(expected, Sequence) and not isinstance(expected, STRING_TYPES):
cls = ApproxSequence cls = ApproxSequence
elif isinstance(expected, Decimal): elif _is_numpy_array(expected):
cls = ApproxDecimal cls = ApproxNumpy
else: else:
cls = ApproxScalar raise _non_numeric_type_error(expected, at=None)
return cls(expected, rel, abs, nan_ok) return cls(expected, rel, abs, nan_ok)
@ -480,17 +531,11 @@ def _is_numpy_array(obj):
Return true if the given object is a numpy array. Make a special effort to Return true if the given object is a numpy array. Make a special effort to
avoid importing numpy unless it's really necessary. avoid importing numpy unless it's really necessary.
""" """
import inspect import sys
for cls in inspect.getmro(type(obj)):
if cls.__module__ == "numpy":
try:
import numpy as np
return isinstance(obj, np.ndarray)
except ImportError:
pass
np = sys.modules.get("numpy")
if np is not None:
return isinstance(obj, np.ndarray)
return False return False

View File

@ -59,17 +59,21 @@ class TestApprox(object):
), ),
) )
def test_repr_0d_array(self, plus_minus): @pytest.mark.parametrize(
"value, repr_string",
[
(5., "approx(5.0 {pm} 5.0e-06)"),
([5.], "approx([5.0 {pm} 5.0e-06])"),
([[5.]], "approx([[5.0 {pm} 5.0e-06]])"),
([[5., 6.]], "approx([[5.0 {pm} 5.0e-06, 6.0 {pm} 6.0e-06]])"),
([[5.], [6.]], "approx([[5.0 {pm} 5.0e-06], [6.0 {pm} 6.0e-06]])"),
],
)
def test_repr_nd_array(self, plus_minus, value, repr_string):
"""Make sure that arrays of all different dimensions are repr'd correctly."""
np = pytest.importorskip("numpy") np = pytest.importorskip("numpy")
np_array = np.array(5.) np_array = np.array(value)
assert approx(np_array) == 5.0 assert repr(approx(np_array)) == repr_string.format(pm=plus_minus)
string_expected = "approx([5.0 {} 5.0e-06])".format(plus_minus)
assert repr(approx(np_array)) == string_expected
np_array = np.array([5.])
assert approx(np_array) == 5.0
assert repr(approx(np_array)) == string_expected
def test_operator_overloading(self): def test_operator_overloading(self):
assert 1 == approx(1, rel=1e-6, abs=1e-12) assert 1 == approx(1, rel=1e-6, abs=1e-12)
@ -439,6 +443,21 @@ class TestApprox(object):
["*At index 0 diff: 3 != 4 * {}".format(expected), "=* 1 failed in *="] ["*At index 0 diff: 3 != 4 * {}".format(expected), "=* 1 failed in *="]
) )
@pytest.mark.parametrize(
"x",
[
pytest.param(None),
pytest.param("string"),
pytest.param(["string"], id="nested-str"),
pytest.param([[1]], id="nested-list"),
pytest.param({"key": "string"}, id="dict-with-string"),
pytest.param({"key": {"key": 1}}, id="nested-dict"),
],
)
def test_expected_value_type_error(self, x):
with pytest.raises(TypeError):
approx(x)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"op", "op",
[ [