python_api: type annotate some parts of pytest.approx()
This commit is contained in:
		
							parent
							
								
									142d8963e6
								
							
						
					
					
						commit
						8f8f472379
					
				| 
						 | 
					@ -33,7 +33,7 @@ if TYPE_CHECKING:
 | 
				
			||||||
BASE_TYPE = (type, STRING_TYPES)
 | 
					BASE_TYPE = (type, STRING_TYPES)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _non_numeric_type_error(value, at):
 | 
					def _non_numeric_type_error(value, at: Optional[str]) -> TypeError:
 | 
				
			||||||
    at_str = " at {}".format(at) if at else ""
 | 
					    at_str = " at {}".format(at) if at else ""
 | 
				
			||||||
    return TypeError(
 | 
					    return TypeError(
 | 
				
			||||||
        "cannot make approximate comparisons to non-numeric values: {!r} {}".format(
 | 
					        "cannot make approximate comparisons to non-numeric values: {!r} {}".format(
 | 
				
			||||||
| 
						 | 
					@ -55,7 +55,7 @@ class ApproxBase:
 | 
				
			||||||
    __array_ufunc__ = None
 | 
					    __array_ufunc__ = None
 | 
				
			||||||
    __array_priority__ = 100
 | 
					    __array_priority__ = 100
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, expected, rel=None, abs=None, nan_ok=False):
 | 
					    def __init__(self, expected, rel=None, abs=None, nan_ok: bool = False) -> None:
 | 
				
			||||||
        __tracebackhide__ = True
 | 
					        __tracebackhide__ = True
 | 
				
			||||||
        self.expected = expected
 | 
					        self.expected = expected
 | 
				
			||||||
        self.abs = abs
 | 
					        self.abs = abs
 | 
				
			||||||
| 
						 | 
					@ -63,10 +63,10 @@ class ApproxBase:
 | 
				
			||||||
        self.nan_ok = nan_ok
 | 
					        self.nan_ok = nan_ok
 | 
				
			||||||
        self._check_type()
 | 
					        self._check_type()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __repr__(self):
 | 
					    def __repr__(self) -> str:
 | 
				
			||||||
        raise NotImplementedError
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __eq__(self, actual):
 | 
					    def __eq__(self, actual) -> bool:
 | 
				
			||||||
        return all(
 | 
					        return all(
 | 
				
			||||||
            a == self._approx_scalar(x) for a, x in self._yield_comparisons(actual)
 | 
					            a == self._approx_scalar(x) for a, x in self._yield_comparisons(actual)
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
| 
						 | 
					@ -74,10 +74,10 @@ class ApproxBase:
 | 
				
			||||||
    # Ignore type because of https://github.com/python/mypy/issues/4266.
 | 
					    # Ignore type because of https://github.com/python/mypy/issues/4266.
 | 
				
			||||||
    __hash__ = None  # type: ignore
 | 
					    __hash__ = None  # type: ignore
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __ne__(self, actual):
 | 
					    def __ne__(self, actual) -> bool:
 | 
				
			||||||
        return not (actual == self)
 | 
					        return not (actual == self)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _approx_scalar(self, x):
 | 
					    def _approx_scalar(self, x) -> "ApproxScalar":
 | 
				
			||||||
        return ApproxScalar(x, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok)
 | 
					        return ApproxScalar(x, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _yield_comparisons(self, actual):
 | 
					    def _yield_comparisons(self, actual):
 | 
				
			||||||
| 
						 | 
					@ -87,7 +87,7 @@ class ApproxBase:
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        raise NotImplementedError
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _check_type(self):
 | 
					    def _check_type(self) -> None:
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Raise a TypeError if the expected value is not a valid type.
 | 
					        Raise a TypeError if the expected value is not a valid type.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
| 
						 | 
					@ -111,11 +111,11 @@ class ApproxNumpy(ApproxBase):
 | 
				
			||||||
    Perform approximate comparisons where the expected value is numpy array.
 | 
					    Perform approximate comparisons where the expected value is numpy array.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __repr__(self):
 | 
					    def __repr__(self) -> str:
 | 
				
			||||||
        list_scalars = _recursive_list_map(self._approx_scalar, self.expected.tolist())
 | 
					        list_scalars = _recursive_list_map(self._approx_scalar, self.expected.tolist())
 | 
				
			||||||
        return "approx({!r})".format(list_scalars)
 | 
					        return "approx({!r})".format(list_scalars)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __eq__(self, actual):
 | 
					    def __eq__(self, actual) -> bool:
 | 
				
			||||||
        import numpy as np
 | 
					        import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # self.expected is supposed to always be an array here
 | 
					        # self.expected is supposed to always be an array here
 | 
				
			||||||
| 
						 | 
					@ -154,12 +154,12 @@ class ApproxMapping(ApproxBase):
 | 
				
			||||||
    numeric values (the keys can be anything).
 | 
					    numeric values (the keys can be anything).
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __repr__(self):
 | 
					    def __repr__(self) -> str:
 | 
				
			||||||
        return "approx({!r})".format(
 | 
					        return "approx({!r})".format(
 | 
				
			||||||
            {k: self._approx_scalar(v) for k, v in self.expected.items()}
 | 
					            {k: self._approx_scalar(v) for k, v in self.expected.items()}
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __eq__(self, actual):
 | 
					    def __eq__(self, actual) -> bool:
 | 
				
			||||||
        if set(actual.keys()) != set(self.expected.keys()):
 | 
					        if set(actual.keys()) != set(self.expected.keys()):
 | 
				
			||||||
            return False
 | 
					            return False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -169,7 +169,7 @@ class ApproxMapping(ApproxBase):
 | 
				
			||||||
        for k in self.expected.keys():
 | 
					        for k in self.expected.keys():
 | 
				
			||||||
            yield actual[k], self.expected[k]
 | 
					            yield actual[k], self.expected[k]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _check_type(self):
 | 
					    def _check_type(self) -> None:
 | 
				
			||||||
        __tracebackhide__ = True
 | 
					        __tracebackhide__ = True
 | 
				
			||||||
        for key, value in self.expected.items():
 | 
					        for key, value in self.expected.items():
 | 
				
			||||||
            if isinstance(value, type(self.expected)):
 | 
					            if isinstance(value, type(self.expected)):
 | 
				
			||||||
| 
						 | 
					@ -185,7 +185,7 @@ class ApproxSequencelike(ApproxBase):
 | 
				
			||||||
    numbers.
 | 
					    numbers.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __repr__(self):
 | 
					    def __repr__(self) -> str:
 | 
				
			||||||
        seq_type = type(self.expected)
 | 
					        seq_type = type(self.expected)
 | 
				
			||||||
        if seq_type not in (tuple, list, set):
 | 
					        if seq_type not in (tuple, list, set):
 | 
				
			||||||
            seq_type = list
 | 
					            seq_type = list
 | 
				
			||||||
| 
						 | 
					@ -193,7 +193,7 @@ class ApproxSequencelike(ApproxBase):
 | 
				
			||||||
            seq_type(self._approx_scalar(x) for x in self.expected)
 | 
					            seq_type(self._approx_scalar(x) for x in self.expected)
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __eq__(self, actual):
 | 
					    def __eq__(self, actual) -> bool:
 | 
				
			||||||
        if len(actual) != len(self.expected):
 | 
					        if len(actual) != len(self.expected):
 | 
				
			||||||
            return False
 | 
					            return False
 | 
				
			||||||
        return ApproxBase.__eq__(self, actual)
 | 
					        return ApproxBase.__eq__(self, actual)
 | 
				
			||||||
| 
						 | 
					@ -201,7 +201,7 @@ class ApproxSequencelike(ApproxBase):
 | 
				
			||||||
    def _yield_comparisons(self, actual):
 | 
					    def _yield_comparisons(self, actual):
 | 
				
			||||||
        return zip(actual, self.expected)
 | 
					        return zip(actual, self.expected)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _check_type(self):
 | 
					    def _check_type(self) -> None:
 | 
				
			||||||
        __tracebackhide__ = True
 | 
					        __tracebackhide__ = True
 | 
				
			||||||
        for index, x in enumerate(self.expected):
 | 
					        for index, x in enumerate(self.expected):
 | 
				
			||||||
            if isinstance(x, type(self.expected)):
 | 
					            if isinstance(x, type(self.expected)):
 | 
				
			||||||
| 
						 | 
					@ -223,7 +223,7 @@ class ApproxScalar(ApproxBase):
 | 
				
			||||||
    DEFAULT_ABSOLUTE_TOLERANCE = 1e-12  # type: Union[float, Decimal]
 | 
					    DEFAULT_ABSOLUTE_TOLERANCE = 1e-12  # type: Union[float, Decimal]
 | 
				
			||||||
    DEFAULT_RELATIVE_TOLERANCE = 1e-6  # type: Union[float, Decimal]
 | 
					    DEFAULT_RELATIVE_TOLERANCE = 1e-6  # type: Union[float, Decimal]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __repr__(self):
 | 
					    def __repr__(self) -> str:
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Return a string communicating both the expected value and the tolerance
 | 
					        Return a string communicating both the expected value and the tolerance
 | 
				
			||||||
        for the comparison being made, e.g. '1.0 ± 1e-6', '(3+4j) ± 5e-6 ∠ ±180°'.
 | 
					        for the comparison being made, e.g. '1.0 ± 1e-6', '(3+4j) ± 5e-6 ∠ ±180°'.
 | 
				
			||||||
| 
						 | 
					@ -245,7 +245,7 @@ class ApproxScalar(ApproxBase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return "{} ± {}".format(self.expected, vetted_tolerance)
 | 
					        return "{} ± {}".format(self.expected, vetted_tolerance)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __eq__(self, actual):
 | 
					    def __eq__(self, actual) -> bool:
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Return true if the given value is equal to the expected value within
 | 
					        Return true if the given value is equal to the expected value within
 | 
				
			||||||
        the pre-specified tolerance.
 | 
					        the pre-specified tolerance.
 | 
				
			||||||
| 
						 | 
					@ -275,7 +275,8 @@ class ApproxScalar(ApproxBase):
 | 
				
			||||||
            return False
 | 
					            return False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Return true if the two numbers are within the tolerance.
 | 
					        # Return true if the two numbers are within the tolerance.
 | 
				
			||||||
        return abs(self.expected - actual) <= self.tolerance
 | 
					        result = abs(self.expected - actual) <= self.tolerance  # type: bool
 | 
				
			||||||
 | 
					        return result
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Ignore type because of https://github.com/python/mypy/issues/4266.
 | 
					    # Ignore type because of https://github.com/python/mypy/issues/4266.
 | 
				
			||||||
    __hash__ = None  # type: ignore
 | 
					    __hash__ = None  # type: ignore
 | 
				
			||||||
| 
						 | 
					@ -337,7 +338,7 @@ class ApproxDecimal(ApproxScalar):
 | 
				
			||||||
    DEFAULT_RELATIVE_TOLERANCE = Decimal("1e-6")
 | 
					    DEFAULT_RELATIVE_TOLERANCE = Decimal("1e-6")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def approx(expected, rel=None, abs=None, nan_ok=False):
 | 
					def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Assert that two numbers (or two sets of numbers) are equal to each other
 | 
					    Assert that two numbers (or two sets of numbers) are equal to each other
 | 
				
			||||||
    within some tolerance.
 | 
					    within some tolerance.
 | 
				
			||||||
| 
						 | 
					@ -527,7 +528,7 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
 | 
				
			||||||
    return cls(expected, rel, abs, nan_ok)
 | 
					    return cls(expected, rel, abs, nan_ok)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _is_numpy_array(obj):
 | 
					def _is_numpy_array(obj: object) -> bool:
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Return true if the given object is a numpy array.  Make a special effort to
 | 
					    Return true if the given object is a numpy array.  Make a special effort to
 | 
				
			||||||
    avoid importing numpy unless it's really necessary.
 | 
					    avoid importing numpy unless it's really necessary.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -3,6 +3,7 @@ from decimal import Decimal
 | 
				
			||||||
from fractions import Fraction
 | 
					from fractions import Fraction
 | 
				
			||||||
from operator import eq
 | 
					from operator import eq
 | 
				
			||||||
from operator import ne
 | 
					from operator import ne
 | 
				
			||||||
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import pytest
 | 
					import pytest
 | 
				
			||||||
from pytest import approx
 | 
					from pytest import approx
 | 
				
			||||||
| 
						 | 
					@ -121,18 +122,22 @@ class TestApprox:
 | 
				
			||||||
            assert a == approx(x, rel=5e-1, abs=0.0)
 | 
					            assert a == approx(x, rel=5e-1, abs=0.0)
 | 
				
			||||||
            assert a != approx(x, rel=5e-2, abs=0.0)
 | 
					            assert a != approx(x, rel=5e-2, abs=0.0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_negative_tolerance(self):
 | 
					    @pytest.mark.parametrize(
 | 
				
			||||||
 | 
					        ("rel", "abs"),
 | 
				
			||||||
 | 
					        [
 | 
				
			||||||
 | 
					            (-1e100, None),
 | 
				
			||||||
 | 
					            (None, -1e100),
 | 
				
			||||||
 | 
					            (1e100, -1e100),
 | 
				
			||||||
 | 
					            (-1e100, 1e100),
 | 
				
			||||||
 | 
					            (-1e100, -1e100),
 | 
				
			||||||
 | 
					        ],
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    def test_negative_tolerance(
 | 
				
			||||||
 | 
					        self, rel: Optional[float], abs: Optional[float]
 | 
				
			||||||
 | 
					    ) -> None:
 | 
				
			||||||
        # Negative tolerances are not allowed.
 | 
					        # Negative tolerances are not allowed.
 | 
				
			||||||
        illegal_kwargs = [
 | 
					 | 
				
			||||||
            dict(rel=-1e100),
 | 
					 | 
				
			||||||
            dict(abs=-1e100),
 | 
					 | 
				
			||||||
            dict(rel=1e100, abs=-1e100),
 | 
					 | 
				
			||||||
            dict(rel=-1e100, abs=1e100),
 | 
					 | 
				
			||||||
            dict(rel=-1e100, abs=-1e100),
 | 
					 | 
				
			||||||
        ]
 | 
					 | 
				
			||||||
        for kwargs in illegal_kwargs:
 | 
					 | 
				
			||||||
        with pytest.raises(ValueError):
 | 
					        with pytest.raises(ValueError):
 | 
				
			||||||
                1.1 == approx(1, **kwargs)
 | 
					            1.1 == approx(1, rel, abs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_inf_tolerance(self):
 | 
					    def test_inf_tolerance(self):
 | 
				
			||||||
        # Everything should be equal if the tolerance is infinite.
 | 
					        # Everything should be equal if the tolerance is infinite.
 | 
				
			||||||
| 
						 | 
					@ -143,19 +148,21 @@ class TestApprox:
 | 
				
			||||||
            assert a == approx(x, rel=0.0, abs=inf)
 | 
					            assert a == approx(x, rel=0.0, abs=inf)
 | 
				
			||||||
            assert a == approx(x, rel=inf, abs=inf)
 | 
					            assert a == approx(x, rel=inf, abs=inf)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_inf_tolerance_expecting_zero(self):
 | 
					    def test_inf_tolerance_expecting_zero(self) -> None:
 | 
				
			||||||
        # If the relative tolerance is zero but the expected value is infinite,
 | 
					        # If the relative tolerance is zero but the expected value is infinite,
 | 
				
			||||||
        # the actual tolerance is a NaN, which should be an error.
 | 
					        # the actual tolerance is a NaN, which should be an error.
 | 
				
			||||||
        illegal_kwargs = [dict(rel=inf, abs=0.0), dict(rel=inf, abs=inf)]
 | 
					 | 
				
			||||||
        for kwargs in illegal_kwargs:
 | 
					 | 
				
			||||||
        with pytest.raises(ValueError):
 | 
					        with pytest.raises(ValueError):
 | 
				
			||||||
                1 == approx(0, **kwargs)
 | 
					            1 == approx(0, rel=inf, abs=0.0)
 | 
				
			||||||
 | 
					        with pytest.raises(ValueError):
 | 
				
			||||||
 | 
					            1 == approx(0, rel=inf, abs=inf)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_nan_tolerance(self):
 | 
					    def test_nan_tolerance(self) -> None:
 | 
				
			||||||
        illegal_kwargs = [dict(rel=nan), dict(abs=nan), dict(rel=nan, abs=nan)]
 | 
					 | 
				
			||||||
        for kwargs in illegal_kwargs:
 | 
					 | 
				
			||||||
        with pytest.raises(ValueError):
 | 
					        with pytest.raises(ValueError):
 | 
				
			||||||
                1.1 == approx(1, **kwargs)
 | 
					            1.1 == approx(1, rel=nan)
 | 
				
			||||||
 | 
					        with pytest.raises(ValueError):
 | 
				
			||||||
 | 
					            1.1 == approx(1, abs=nan)
 | 
				
			||||||
 | 
					        with pytest.raises(ValueError):
 | 
				
			||||||
 | 
					            1.1 == approx(1, rel=nan, abs=nan)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_reasonable_defaults(self):
 | 
					    def test_reasonable_defaults(self):
 | 
				
			||||||
        # Whatever the defaults are, they should work for numbers close to 1
 | 
					        # Whatever the defaults are, they should work for numbers close to 1
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue