Type-annotate pytest.warns
This commit is contained in:
parent
d7ee3dac2c
commit
2dca68b863
|
@ -1,11 +1,23 @@
|
||||||
""" recording warnings during test function execution. """
|
""" recording warnings during test function execution. """
|
||||||
import inspect
|
|
||||||
import re
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
|
from types import TracebackType
|
||||||
|
from typing import Any
|
||||||
|
from typing import Callable
|
||||||
|
from typing import Iterator
|
||||||
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
|
from typing import overload
|
||||||
|
from typing import Pattern
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
from _pytest.fixtures import yield_fixture
|
from _pytest.fixtures import yield_fixture
|
||||||
from _pytest.outcomes import fail
|
from _pytest.outcomes import fail
|
||||||
|
|
||||||
|
if False: # TYPE_CHECKING
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
|
||||||
@yield_fixture
|
@yield_fixture
|
||||||
def recwarn():
|
def recwarn():
|
||||||
|
@ -42,7 +54,32 @@ def deprecated_call(func=None, *args, **kwargs):
|
||||||
return warns((DeprecationWarning, PendingDeprecationWarning), *args, **kwargs)
|
return warns((DeprecationWarning, PendingDeprecationWarning), *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def warns(expected_warning, *args, match=None, **kwargs):
|
@overload
|
||||||
|
def warns(
|
||||||
|
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
|
||||||
|
*,
|
||||||
|
match: Optional[Union[str, Pattern]] = ...
|
||||||
|
) -> "WarningsChecker":
|
||||||
|
... # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def warns(
|
||||||
|
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
|
||||||
|
func: Callable,
|
||||||
|
*args: Any,
|
||||||
|
match: Optional[Union[str, Pattern]] = ...,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> Union[Any]:
|
||||||
|
... # pragma: no cover
|
||||||
|
|
||||||
|
|
||||||
|
def warns(
|
||||||
|
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
|
||||||
|
*args: Any,
|
||||||
|
match: Optional[Union[str, Pattern]] = None,
|
||||||
|
**kwargs: Any
|
||||||
|
) -> Union["WarningsChecker", Any]:
|
||||||
r"""Assert that code raises a particular class of warning.
|
r"""Assert that code raises a particular class of warning.
|
||||||
|
|
||||||
Specifically, the parameter ``expected_warning`` can be a warning class or
|
Specifically, the parameter ``expected_warning`` can be a warning class or
|
||||||
|
@ -101,26 +138,26 @@ class WarningsRecorder(warnings.catch_warnings):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(record=True)
|
super().__init__(record=True)
|
||||||
self._entered = False
|
self._entered = False
|
||||||
self._list = []
|
self._list = [] # type: List[warnings._Record]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def list(self):
|
def list(self) -> List["warnings._Record"]:
|
||||||
"""The list of recorded warnings."""
|
"""The list of recorded warnings."""
|
||||||
return self._list
|
return self._list
|
||||||
|
|
||||||
def __getitem__(self, i):
|
def __getitem__(self, i: int) -> "warnings._Record":
|
||||||
"""Get a recorded warning by index."""
|
"""Get a recorded warning by index."""
|
||||||
return self._list[i]
|
return self._list[i]
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self) -> Iterator["warnings._Record"]:
|
||||||
"""Iterate through the recorded warnings."""
|
"""Iterate through the recorded warnings."""
|
||||||
return iter(self._list)
|
return iter(self._list)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self) -> int:
|
||||||
"""The number of recorded warnings."""
|
"""The number of recorded warnings."""
|
||||||
return len(self._list)
|
return len(self._list)
|
||||||
|
|
||||||
def pop(self, cls=Warning):
|
def pop(self, cls: "Type[Warning]" = Warning) -> "warnings._Record":
|
||||||
"""Pop the first recorded warning, raise exception if not exists."""
|
"""Pop the first recorded warning, raise exception if not exists."""
|
||||||
for i, w in enumerate(self._list):
|
for i, w in enumerate(self._list):
|
||||||
if issubclass(w.category, cls):
|
if issubclass(w.category, cls):
|
||||||
|
@ -128,54 +165,80 @@ class WarningsRecorder(warnings.catch_warnings):
|
||||||
__tracebackhide__ = True
|
__tracebackhide__ = True
|
||||||
raise AssertionError("%r not found in warning list" % cls)
|
raise AssertionError("%r not found in warning list" % cls)
|
||||||
|
|
||||||
def clear(self):
|
def clear(self) -> None:
|
||||||
"""Clear the list of recorded warnings."""
|
"""Clear the list of recorded warnings."""
|
||||||
self._list[:] = []
|
self._list[:] = []
|
||||||
|
|
||||||
def __enter__(self):
|
# Type ignored because it doesn't exactly warnings.catch_warnings.__enter__
|
||||||
|
# -- it returns a List but we only emulate one.
|
||||||
|
def __enter__(self) -> "WarningsRecorder": # type: ignore
|
||||||
if self._entered:
|
if self._entered:
|
||||||
__tracebackhide__ = True
|
__tracebackhide__ = True
|
||||||
raise RuntimeError("Cannot enter %r twice" % self)
|
raise RuntimeError("Cannot enter %r twice" % self)
|
||||||
self._list = super().__enter__()
|
_list = super().__enter__()
|
||||||
|
# record=True means it's None.
|
||||||
|
assert _list is not None
|
||||||
|
self._list = _list
|
||||||
warnings.simplefilter("always")
|
warnings.simplefilter("always")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, *exc_info):
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: Optional["Type[BaseException]"],
|
||||||
|
exc_val: Optional[BaseException],
|
||||||
|
exc_tb: Optional[TracebackType],
|
||||||
|
) -> bool:
|
||||||
if not self._entered:
|
if not self._entered:
|
||||||
__tracebackhide__ = True
|
__tracebackhide__ = True
|
||||||
raise RuntimeError("Cannot exit %r without entering first" % self)
|
raise RuntimeError("Cannot exit %r without entering first" % self)
|
||||||
|
|
||||||
super().__exit__(*exc_info)
|
super().__exit__(exc_type, exc_val, exc_tb)
|
||||||
|
|
||||||
# Built-in catch_warnings does not reset entered state so we do it
|
# Built-in catch_warnings does not reset entered state so we do it
|
||||||
# manually here for this context manager to become reusable.
|
# manually here for this context manager to become reusable.
|
||||||
self._entered = False
|
self._entered = False
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class WarningsChecker(WarningsRecorder):
|
class WarningsChecker(WarningsRecorder):
|
||||||
def __init__(self, expected_warning=None, match_expr=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
expected_warning: Optional[
|
||||||
|
Union["Type[Warning]", Tuple["Type[Warning]", ...]]
|
||||||
|
] = None,
|
||||||
|
match_expr: Optional[Union[str, Pattern]] = None,
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
msg = "exceptions must be derived from Warning, not %s"
|
msg = "exceptions must be derived from Warning, not %s"
|
||||||
if isinstance(expected_warning, tuple):
|
if expected_warning is None:
|
||||||
|
expected_warning_tup = None
|
||||||
|
elif isinstance(expected_warning, tuple):
|
||||||
for exc in expected_warning:
|
for exc in expected_warning:
|
||||||
if not inspect.isclass(exc):
|
if not issubclass(exc, Warning):
|
||||||
raise TypeError(msg % type(exc))
|
raise TypeError(msg % type(exc))
|
||||||
elif inspect.isclass(expected_warning):
|
expected_warning_tup = expected_warning
|
||||||
expected_warning = (expected_warning,)
|
elif issubclass(expected_warning, Warning):
|
||||||
elif expected_warning is not None:
|
expected_warning_tup = (expected_warning,)
|
||||||
|
else:
|
||||||
raise TypeError(msg % type(expected_warning))
|
raise TypeError(msg % type(expected_warning))
|
||||||
|
|
||||||
self.expected_warning = expected_warning
|
self.expected_warning = expected_warning_tup
|
||||||
self.match_expr = match_expr
|
self.match_expr = match_expr
|
||||||
|
|
||||||
def __exit__(self, *exc_info):
|
def __exit__(
|
||||||
super().__exit__(*exc_info)
|
self,
|
||||||
|
exc_type: Optional["Type[BaseException]"],
|
||||||
|
exc_val: Optional[BaseException],
|
||||||
|
exc_tb: Optional[TracebackType],
|
||||||
|
) -> bool:
|
||||||
|
super().__exit__(exc_type, exc_val, exc_tb)
|
||||||
|
|
||||||
__tracebackhide__ = True
|
__tracebackhide__ = True
|
||||||
|
|
||||||
# only check if we're not currently handling an exception
|
# only check if we're not currently handling an exception
|
||||||
if all(a is None for a in exc_info):
|
if exc_type is None and exc_val is None and exc_tb is None:
|
||||||
if self.expected_warning is not None:
|
if self.expected_warning is not None:
|
||||||
if not any(issubclass(r.category, self.expected_warning) for r in self):
|
if not any(issubclass(r.category, self.expected_warning) for r in self):
|
||||||
__tracebackhide__ = True
|
__tracebackhide__ = True
|
||||||
|
@ -200,3 +263,4 @@ class WarningsChecker(WarningsRecorder):
|
||||||
[each.message for each in self],
|
[each.message for each in self],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
return False
|
||||||
|
|
Loading…
Reference in New Issue