Discuss alternative float comparison algorithms.

Kale Kundert 2016-03-11 15:59:48 -08:00
parent 42a7e0488d
commit 078448008c
1 changed files with 102 additions and 50 deletions


@ -1344,50 +1344,55 @@ class RaisesContext(object):
class approx(object): class approx(object):
""" """
Assert that two numbers (or two sets of numbers) are equal to each Assert that two numbers (or two sets of numbers) are equal to each other
other within some margin. within some tolerance.
Due to the intricacies of floating-point arithmetic, numbers that we would Due to the `intricacies of floating-point arithmetic`__, numbers that we
intuitively expect to be the same are not always so:: would intuitively expect to be equal are not always so::
>>> 0.1 + 0.2 == 0.3 >>> 0.1 + 0.2 == 0.3
False False
__ https://docs.python.org/3/tutorial/floatingpoint.html
This problem is commonly encountered when writing tests, e.g. when making This problem is commonly encountered when writing tests, e.g. when making
sure that floating-point values are what you expect them to be. One way to sure that floating-point values are what you expect them to be. One way to
deal with this problem is to assert that two floating-point numbers are deal with this problem is to assert that two floating-point numbers are
equal to within some appropriate margin:: equal to within some appropriate tolerance::
>>> abs((0.1 + 0.2) - 0.3) < 1e-6 >>> abs((0.1 + 0.2) - 0.3) < 1e-6
True True
However, comparisons like this are tedious to write and difficult to However, comparisons like this are tedious to write and difficult to
understand. Furthermore, absolute comparisons like the one above are understand. Furthermore, absolute comparisons like the one above are
usually discouraged in favor of relative comparisons, which can't even be usually discouraged because there's no tolerance that works well for all
easily written on one line. The ``approx`` class provides a way to make situations. ``1e-6`` is good for numbers around ``1``, but too small for
floating-point comparisons that solves both these problems:: very big numbers and too big for very small ones. It's better to express
the tolerance as a fraction of the expected value, but relative comparisons
like that are even more difficult to write correctly and concisely.
The ``approx`` class performs floating-point comparisons using a syntax
that's as intuitive as possible::
>>> from pytest import approx >>> from pytest import approx
>>> 0.1 + 0.2 == approx(0.3) >>> 0.1 + 0.2 == approx(0.3)
True True
``approx`` also makes it easy to compare ordered sets of numbers, which The same syntax also works on sequences of numbers::
would otherwise be very tedious::
>>> (0.1 + 0.2, 0.2 + 0.4) == approx((0.3, 0.6)) >>> (0.1 + 0.2, 0.2 + 0.4) == approx((0.3, 0.6))
True True
By default, ``approx`` considers two numbers to be equal if the relative By default, ``approx`` considers numbers within a relative tolerance of
error between them is less than one part in a million (e.g. ``1e-6``). ``1e-6`` (i.e. one part in a million) of its expected value to be equal.
Relative error is defined as ``abs(x - a) / x`` where ``x`` is the value This treatment would lead to surprising results if the expected value was
you're expecting and ``a`` is the value you're comparing to. This ``0.0``, because nothing but ``0.0`` itself is relatively close to ``0.0``.
definition breaks down when the numbers being compared get very close to To handle this case less surprisingly, ``approx`` also considers numbers
zero, so ``approx`` will also consider two numbers to be equal if the within an absolute tolerance of ``1e-12`` of its expected value to be
absolute difference between them is less than one part in a trillion (e.g. equal. Infinite numbers are another special case. They are only
``1e-12``). considered equal to themselves, regardless of the relative tolerance. Both
the relative and absolute tolerances can be changed by passing arguments to
Both the relative and absolute error thresholds can be changed by passing the ``approx`` constructor::
arguments to the ``approx`` constructor::
>>> 1.0001 == approx(1) >>> 1.0001 == approx(1)
False False
@ -1396,12 +1401,12 @@ class approx(object):
>>> 1.0001 == approx(1, abs=1e-3) >>> 1.0001 == approx(1, abs=1e-3)
True True
Note that if you specify ``abs`` but not ``rel``, the comparison will not If you specify ``abs`` but not ``rel``, the comparison will not consider
consider the relative error between the two values at all. In other words, the relative tolerance at all. In other words, two numbers that are within
two numbers that are within the default relative error threshold of 1e-6 the default relative tolerance of ``1e-6`` will still be considered unequal
will still be considered unequal if they exceed the specified absolute if they exceed the specified absolute tolerance. If you specify both
error threshold. If you specify both ``abs`` and ``rel``, the numbers will ``abs`` and ``rel``, the numbers will be considered equal if either
be considered equal if either threshold is met:: tolerance is met::
>>> 1 + 1e-8 == approx(1) >>> 1 + 1e-8 == approx(1)
True True
@ -1409,6 +1414,46 @@ class approx(object):
False False
>>> 1 + 1e-8 == approx(1, rel=1e-6, abs=1e-12) >>> 1 + 1e-8 == approx(1, rel=1e-6, abs=1e-12)
True True
If you're thinking about using ``approx``, then you might want to know how
it compares to other good ways of comparing floating-point numbers. All of
these algorithms are based on relative and absolute tolerances, but they do
have meaningful differences:
- ``math.isclose(a, b, rel_tol=1e-9, abs_tol=0.0)``: True if the relative
tolerance is met w.r.t. either ``a`` or ``b`` or if the absolute
tolerance is met. Because the relative tolerance is calculated w.r.t.
both ``a`` and ``b``, this test is symmetric (i.e. neither ``a`` nor
``b`` is a "reference value"). You have to specify an absolute tolerance
if you want to compare to ``0.0`` because there is no tolerance by
default. Only available in Python 3.5 and later. `More information...`__
__ https://docs.python.org/3/library/math.html#math.isclose
- ``numpy.isclose(a, b, rtol=1e-5, atol=1e-8)``: True if the difference
between ``a`` and ``b`` is less than the sum of the relative tolerance
w.r.t. ``b`` and the absolute tolerance. Because the relative tolerance
is only calculated w.r.t. ``b``, this test is asymmetric and you can
think of ``b`` as the reference value. Support for comparing sequences
is provided by ``numpy.allclose``. `More information...`__
__ http://docs.scipy.org/doc/numpy-1.10.0/reference/generated/numpy.isclose.html
- ``unittest.TestCase.assertAlmostEqual(a, b)``: True if ``a`` and ``b``
are within an absolute tolerance of roughly ``1e-7`` (the difference is
rounded to 7 decimal places by default; this can be changed with the
``places`` or ``delta`` arguments). No relative tolerance is considered,
so this function is not appropriate for very large or very small numbers.
Also, it's only available in subclasses of ``unittest.TestCase`` and its
name doesn't follow PEP8. `More information...`__
__ https://docs.python.org/3/library/unittest.html#unittest.TestCase.assertAlmostEqual
- ``a == pytest.approx(b, rel=1e-6, abs=1e-12)``: True if the relative
tolerance is met w.r.t. ``b`` or if the absolute tolerance is met.
Because the relative tolerance is only calculated w.r.t. ``b``, this test
is asymmetric and you can think of ``b`` as the reference value. In the
special case that you explicitly specify an absolute tolerance but not a
relative tolerance, only the absolute tolerance is considered.
""" """
def __init__(self, expected, rel=None, abs=None): def __init__(self, expected, rel=None, abs=None):
@ -1427,6 +1472,9 @@ class approx(object):
@property @property
def expected(self): def expected(self):
# Regardless of whether the user-specified expected value is a number
# or a sequence of numbers, return a list of ApproxNonIterable objects
# that can be compared against.
from collections import Iterable from collections import Iterable
approx_non_iter = lambda x: ApproxNonIterable(x, self.rel, self.abs) approx_non_iter = lambda x: ApproxNonIterable(x, self.rel, self.abs)
if isinstance(self._expected, Iterable): if isinstance(self._expected, Iterable):
@ -1437,14 +1485,19 @@ class approx(object):
@expected.setter @expected.setter
def expected(self, expected): def expected(self, expected):
self._expected = expected self._expected = expected
class ApproxNonIterable(object): class ApproxNonIterable(object):
""" """
Perform approximate comparisons for single numbers only. Perform approximate comparisons for single numbers only.
This class contains most of the In other words, the ``expected`` attribute for objects of this class must
be some sort of number. This is in contrast to the ``approx`` class, where
the ``expected`` attribute can either be a number or a sequence of numbers.
This class is responsible for making comparisons, while ``approx`` is
responsible for abstracting the difference between numbers and sequences of
numbers. Although this class can stand on its own, it's only meant to be
used within ``approx``.
""" """
def __init__(self, expected, rel=None, abs=None): def __init__(self, expected, rel=None, abs=None):
@ -1453,12 +1506,12 @@ class ApproxNonIterable(object):
self.rel = rel self.rel = rel
def __repr__(self): def __repr__(self):
# Infinities aren't compared using tolerances, so don't show a # Infinities aren't compared using tolerances, so don't show a
# tolerance. # tolerance.
if math.isinf(self.expected): if math.isinf(self.expected):
return str(self.expected) return str(self.expected)
# If a sensible tolerance can't be calculated, self.tolerance will # If a sensible tolerance can't be calculated, self.tolerance will
# raise a ValueError. In this case, display '???'. # raise a ValueError. In this case, display '???'.
try: try:
vetted_tolerance = '{:.1e}'.format(self.tolerance) vetted_tolerance = '{:.1e}'.format(self.tolerance)
@ -1467,9 +1520,9 @@ class ApproxNonIterable(object):
repr = u'{0} \u00b1 {1}'.format(self.expected, vetted_tolerance) repr = u'{0} \u00b1 {1}'.format(self.expected, vetted_tolerance)
# In python2, __repr__() must return a string (i.e. not a unicode # In python2, __repr__() must return a string (i.e. not a unicode
# object). In python3, __repr__() must return a unicode object # object). In python3, __repr__() must return a unicode object
# (although now strings are unicode objects and bytes are what # (although now strings are unicode objects and bytes are what
# strings were). # strings were).
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
return repr.encode('utf-8') return repr.encode('utf-8')
@ -1481,11 +1534,11 @@ class ApproxNonIterable(object):
if actual == self.expected: if actual == self.expected:
return True return True
# Infinity shouldn't be approximately equal to anything but itself, but # Infinity shouldn't be approximately equal to anything but itself, but
# if there's a relative tolerance, it will be infinite and infinity # if there's a relative tolerance, it will be infinite and infinity
# will seem approximately equal to everything. The equal-to-itself # will seem approximately equal to everything. The equal-to-itself
# case would have been short circuited above, so here we can just # case would have been short circuited above, so here we can just
# return false if the expected value is infinite. The abs() call is # return false if the expected value is infinite. The abs() call is
# for compatibility with complex numbers. # for compatibility with complex numbers.
if math.isinf(abs(self.expected)): if math.isinf(abs(self.expected)):
return False return False
@ -1497,7 +1550,7 @@ class ApproxNonIterable(object):
def tolerance(self): def tolerance(self):
set_default = lambda x, default: x if x is not None else default set_default = lambda x, default: x if x is not None else default
# Figure out what the absolute tolerance should be. ``self.abs`` is # Figure out what the absolute tolerance should be. ``self.abs`` is
# either None or a value specified by the user. # either None or a value specified by the user.
absolute_tolerance = set_default(self.abs, 1e-12) absolute_tolerance = set_default(self.abs, 1e-12)
@ -1505,17 +1558,17 @@ class ApproxNonIterable(object):
raise ValueError("absolute tolerance can't be negative: {}".format(absolute_tolerance)) raise ValueError("absolute tolerance can't be negative: {}".format(absolute_tolerance))
if math.isnan(absolute_tolerance): if math.isnan(absolute_tolerance):
raise ValueError("absolute tolerance can't be NaN.") raise ValueError("absolute tolerance can't be NaN.")
# If the user specified an absolute tolerance but not a relative one, # If the user specified an absolute tolerance but not a relative one,
# just return the absolute tolerance. # just return the absolute tolerance.
if self.rel is None: if self.rel is None:
if self.abs is not None: if self.abs is not None:
return absolute_tolerance return absolute_tolerance
# Figure out what the relative tolerance should be. ``self.rel`` is # Figure out what the relative tolerance should be. ``self.rel`` is
# either None or a value specified by the user. This is done after # either None or a value specified by the user. This is done after
# we've made sure the user didn't ask for an absolute tolerance only, # we've made sure the user didn't ask for an absolute tolerance only,
# because we don't want to raise errors about the relative tolerance if # because we don't want to raise errors about the relative tolerance if
# it isn't even being used. # it isn't even being used.
relative_tolerance = set_default(self.rel, 1e-6) * abs(self.expected) relative_tolerance = set_default(self.rel, 1e-6) * abs(self.expected)
@ -1523,12 +1576,11 @@ class ApproxNonIterable(object):
raise ValueError("relative tolerance can't be negative: {}".format(relative_tolerance)) raise ValueError("relative tolerance can't be negative: {}".format(relative_tolerance))
if math.isnan(relative_tolerance): if math.isnan(relative_tolerance):
raise ValueError("relative tolerance can't be NaN.") raise ValueError("relative tolerance can't be NaN.")
# Return the larger of the relative and absolute tolerances. # Return the larger of the relative and absolute tolerances.
return max(relative_tolerance, absolute_tolerance) return max(relative_tolerance, absolute_tolerance)
# #
# the basic pytest Function item # the basic pytest Function item
# #
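
For readers who want to experiment with the tolerance rule discussed in the docstring above, here is a minimal standalone sketch. It is not the code from this commit: ``approx_tolerance`` and ``approx_equal`` are hypothetical helper names, the ``abs_tol`` parameter stands in for the ``abs`` argument, and the final ``<=`` comparison is an assumption because that line is not shown in the diff. The defaults (``1e-6`` relative, ``1e-12`` absolute) are taken from the docstring.

```python
import math


def approx_tolerance(expected, rel=None, abs_tol=None):
    """Hypothetical helper: the tolerance rule described in the docstring."""
    absolute = 1e-12 if abs_tol is None else abs_tol
    if absolute < 0:
        raise ValueError("absolute tolerance can't be negative: {}".format(absolute))
    if math.isnan(absolute):
        raise ValueError("absolute tolerance can't be NaN.")

    # An explicit absolute tolerance without a relative one disables the
    # relative check entirely.
    if rel is None and abs_tol is not None:
        return absolute

    # The relative tolerance is a fraction of the expected value only,
    # which is what makes the comparison asymmetric.
    relative = (1e-6 if rel is None else rel) * abs(expected)
    if relative < 0:
        raise ValueError("relative tolerance can't be negative: {}".format(relative))
    if math.isnan(relative):
        raise ValueError("relative tolerance can't be NaN.")

    return max(relative, absolute)


def approx_equal(actual, expected, rel=None, abs_tol=None):
    """Hypothetical helper: compare ``actual`` against a reference ``expected``."""
    if actual == expected:
        return True
    # Infinity is only equal to itself; that case was handled just above.
    if math.isinf(abs(expected)):
        return False
    # Assumption: the comparison is inclusive of the tolerance.
    return abs(actual - expected) <= approx_tolerance(expected, rel, abs_tol)


# Examples mirroring the docstring.
print(approx_equal(0.1 + 0.2, 0.3))
print(approx_equal(1.0001, 1, abs_tol=1e-3))
print(approx_equal(1 + 1e-8, 1, abs_tol=1e-12))

# Asymmetry: the result depends on which argument is the reference value,
# whereas math.isclose (Python 3.5+) gives the same answer either way.
print(approx_equal(1.0, 1.1, rel=0.095), approx_equal(1.1, 1.0, rel=0.095))
print(math.isclose(1.0, 1.1, rel_tol=0.095), math.isclose(1.1, 1.0, rel_tol=0.095))
```

The last two lines illustrate the design difference called out in the docstring: a rule that scales the tolerance by the expected value treats that value as a reference, while ``math.isclose`` scales by the larger of the two magnitudes and therefore has no reference value.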