Type-annotate pytest.warns

This commit is contained in:
Ran Benita 2019-07-10 14:36:07 +03:00
parent d7ee3dac2c
commit 2dca68b863
1 changed files with 87 additions and 23 deletions

View File

@ -1,11 +1,23 @@
""" recording warnings during test function execution. """ """ recording warnings during test function execution. """
import inspect
import re import re
import warnings import warnings
from types import TracebackType
from typing import Any
from typing import Callable
from typing import Iterator
from typing import List
from typing import Optional
from typing import overload
from typing import Pattern
from typing import Tuple
from typing import Union
from _pytest.fixtures import yield_fixture from _pytest.fixtures import yield_fixture
from _pytest.outcomes import fail from _pytest.outcomes import fail
if False: # TYPE_CHECKING
from typing import Type
@yield_fixture @yield_fixture
def recwarn(): def recwarn():
@ -42,7 +54,32 @@ def deprecated_call(func=None, *args, **kwargs):
return warns((DeprecationWarning, PendingDeprecationWarning), *args, **kwargs) return warns((DeprecationWarning, PendingDeprecationWarning), *args, **kwargs)
def warns(expected_warning, *args, match=None, **kwargs): @overload
def warns(
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
*,
match: Optional[Union[str, Pattern]] = ...
) -> "WarningsChecker":
... # pragma: no cover
@overload
def warns(
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
func: Callable,
*args: Any,
match: Optional[Union[str, Pattern]] = ...,
**kwargs: Any
) -> Union[Any]:
... # pragma: no cover
def warns(
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
*args: Any,
match: Optional[Union[str, Pattern]] = None,
**kwargs: Any
) -> Union["WarningsChecker", Any]:
r"""Assert that code raises a particular class of warning. r"""Assert that code raises a particular class of warning.
Specifically, the parameter ``expected_warning`` can be a warning class or Specifically, the parameter ``expected_warning`` can be a warning class or
@ -101,26 +138,26 @@ class WarningsRecorder(warnings.catch_warnings):
def __init__(self): def __init__(self):
super().__init__(record=True) super().__init__(record=True)
self._entered = False self._entered = False
self._list = [] self._list = [] # type: List[warnings._Record]
@property @property
def list(self): def list(self) -> List["warnings._Record"]:
"""The list of recorded warnings.""" """The list of recorded warnings."""
return self._list return self._list
def __getitem__(self, i): def __getitem__(self, i: int) -> "warnings._Record":
"""Get a recorded warning by index.""" """Get a recorded warning by index."""
return self._list[i] return self._list[i]
def __iter__(self): def __iter__(self) -> Iterator["warnings._Record"]:
"""Iterate through the recorded warnings.""" """Iterate through the recorded warnings."""
return iter(self._list) return iter(self._list)
def __len__(self): def __len__(self) -> int:
"""The number of recorded warnings.""" """The number of recorded warnings."""
return len(self._list) return len(self._list)
def pop(self, cls=Warning): def pop(self, cls: "Type[Warning]" = Warning) -> "warnings._Record":
"""Pop the first recorded warning, raise exception if not exists.""" """Pop the first recorded warning, raise exception if not exists."""
for i, w in enumerate(self._list): for i, w in enumerate(self._list):
if issubclass(w.category, cls): if issubclass(w.category, cls):
@ -128,54 +165,80 @@ class WarningsRecorder(warnings.catch_warnings):
__tracebackhide__ = True __tracebackhide__ = True
raise AssertionError("%r not found in warning list" % cls) raise AssertionError("%r not found in warning list" % cls)
def clear(self): def clear(self) -> None:
"""Clear the list of recorded warnings.""" """Clear the list of recorded warnings."""
self._list[:] = [] self._list[:] = []
def __enter__(self): # Type ignored because it doesn't exactly warnings.catch_warnings.__enter__
# -- it returns a List but we only emulate one.
def __enter__(self) -> "WarningsRecorder": # type: ignore
if self._entered: if self._entered:
__tracebackhide__ = True __tracebackhide__ = True
raise RuntimeError("Cannot enter %r twice" % self) raise RuntimeError("Cannot enter %r twice" % self)
self._list = super().__enter__() _list = super().__enter__()
# record=True means it's None.
assert _list is not None
self._list = _list
warnings.simplefilter("always") warnings.simplefilter("always")
return self return self
def __exit__(self, *exc_info): def __exit__(
self,
exc_type: Optional["Type[BaseException]"],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
if not self._entered: if not self._entered:
__tracebackhide__ = True __tracebackhide__ = True
raise RuntimeError("Cannot exit %r without entering first" % self) raise RuntimeError("Cannot exit %r without entering first" % self)
super().__exit__(*exc_info) super().__exit__(exc_type, exc_val, exc_tb)
# Built-in catch_warnings does not reset entered state so we do it # Built-in catch_warnings does not reset entered state so we do it
# manually here for this context manager to become reusable. # manually here for this context manager to become reusable.
self._entered = False self._entered = False
return False
class WarningsChecker(WarningsRecorder): class WarningsChecker(WarningsRecorder):
def __init__(self, expected_warning=None, match_expr=None): def __init__(
self,
expected_warning: Optional[
Union["Type[Warning]", Tuple["Type[Warning]", ...]]
] = None,
match_expr: Optional[Union[str, Pattern]] = None,
) -> None:
super().__init__() super().__init__()
msg = "exceptions must be derived from Warning, not %s" msg = "exceptions must be derived from Warning, not %s"
if isinstance(expected_warning, tuple): if expected_warning is None:
expected_warning_tup = None
elif isinstance(expected_warning, tuple):
for exc in expected_warning: for exc in expected_warning:
if not inspect.isclass(exc): if not issubclass(exc, Warning):
raise TypeError(msg % type(exc)) raise TypeError(msg % type(exc))
elif inspect.isclass(expected_warning): expected_warning_tup = expected_warning
expected_warning = (expected_warning,) elif issubclass(expected_warning, Warning):
elif expected_warning is not None: expected_warning_tup = (expected_warning,)
else:
raise TypeError(msg % type(expected_warning)) raise TypeError(msg % type(expected_warning))
self.expected_warning = expected_warning self.expected_warning = expected_warning_tup
self.match_expr = match_expr self.match_expr = match_expr
def __exit__(self, *exc_info): def __exit__(
super().__exit__(*exc_info) self,
exc_type: Optional["Type[BaseException]"],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
super().__exit__(exc_type, exc_val, exc_tb)
__tracebackhide__ = True __tracebackhide__ = True
# only check if we're not currently handling an exception # only check if we're not currently handling an exception
if all(a is None for a in exc_info): if exc_type is None and exc_val is None and exc_tb is None:
if self.expected_warning is not None: if self.expected_warning is not None:
if not any(issubclass(r.category, self.expected_warning) for r in self): if not any(issubclass(r.category, self.expected_warning) for r in self):
__tracebackhide__ = True __tracebackhide__ = True
@ -200,3 +263,4 @@ class WarningsChecker(WarningsRecorder):
[each.message for each in self], [each.message for each in self],
) )
) )
return False