Type-annotate pytest.warns
This commit is contained in:
		
							parent
							
								
									d7ee3dac2c
								
							
						
					
					
						commit
						2dca68b863
					
				|  | @ -1,11 +1,23 @@ | ||||||
| """ recording warnings during test function execution. """ | """ recording warnings during test function execution. """ | ||||||
| import inspect |  | ||||||
| import re | import re | ||||||
| import warnings | import warnings | ||||||
|  | from types import TracebackType | ||||||
|  | from typing import Any | ||||||
|  | from typing import Callable | ||||||
|  | from typing import Iterator | ||||||
|  | from typing import List | ||||||
|  | from typing import Optional | ||||||
|  | from typing import overload | ||||||
|  | from typing import Pattern | ||||||
|  | from typing import Tuple | ||||||
|  | from typing import Union | ||||||
| 
 | 
 | ||||||
| from _pytest.fixtures import yield_fixture | from _pytest.fixtures import yield_fixture | ||||||
| from _pytest.outcomes import fail | from _pytest.outcomes import fail | ||||||
| 
 | 
 | ||||||
|  | if False:  # TYPE_CHECKING | ||||||
|  |     from typing import Type | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| @yield_fixture | @yield_fixture | ||||||
| def recwarn(): | def recwarn(): | ||||||
|  | @ -42,7 +54,32 @@ def deprecated_call(func=None, *args, **kwargs): | ||||||
|     return warns((DeprecationWarning, PendingDeprecationWarning), *args, **kwargs) |     return warns((DeprecationWarning, PendingDeprecationWarning), *args, **kwargs) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def warns(expected_warning, *args, match=None, **kwargs): | @overload | ||||||
|  | def warns( | ||||||
|  |     expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]], | ||||||
|  |     *, | ||||||
|  |     match: Optional[Union[str, Pattern]] = ... | ||||||
|  | ) -> "WarningsChecker": | ||||||
|  |     ...  # pragma: no cover | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @overload | ||||||
|  | def warns( | ||||||
|  |     expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]], | ||||||
|  |     func: Callable, | ||||||
|  |     *args: Any, | ||||||
|  |     match: Optional[Union[str, Pattern]] = ..., | ||||||
|  |     **kwargs: Any | ||||||
|  | ) -> Union[Any]: | ||||||
|  |     ...  # pragma: no cover | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def warns( | ||||||
|  |     expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]], | ||||||
|  |     *args: Any, | ||||||
|  |     match: Optional[Union[str, Pattern]] = None, | ||||||
|  |     **kwargs: Any | ||||||
|  | ) -> Union["WarningsChecker", Any]: | ||||||
|     r"""Assert that code raises a particular class of warning. |     r"""Assert that code raises a particular class of warning. | ||||||
| 
 | 
 | ||||||
|     Specifically, the parameter ``expected_warning`` can be a warning class or |     Specifically, the parameter ``expected_warning`` can be a warning class or | ||||||
|  | @ -101,26 +138,26 @@ class WarningsRecorder(warnings.catch_warnings): | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         super().__init__(record=True) |         super().__init__(record=True) | ||||||
|         self._entered = False |         self._entered = False | ||||||
|         self._list = [] |         self._list = []  # type: List[warnings._Record] | ||||||
| 
 | 
 | ||||||
|     @property |     @property | ||||||
|     def list(self): |     def list(self) -> List["warnings._Record"]: | ||||||
|         """The list of recorded warnings.""" |         """The list of recorded warnings.""" | ||||||
|         return self._list |         return self._list | ||||||
| 
 | 
 | ||||||
|     def __getitem__(self, i): |     def __getitem__(self, i: int) -> "warnings._Record": | ||||||
|         """Get a recorded warning by index.""" |         """Get a recorded warning by index.""" | ||||||
|         return self._list[i] |         return self._list[i] | ||||||
| 
 | 
 | ||||||
|     def __iter__(self): |     def __iter__(self) -> Iterator["warnings._Record"]: | ||||||
|         """Iterate through the recorded warnings.""" |         """Iterate through the recorded warnings.""" | ||||||
|         return iter(self._list) |         return iter(self._list) | ||||||
| 
 | 
 | ||||||
|     def __len__(self): |     def __len__(self) -> int: | ||||||
|         """The number of recorded warnings.""" |         """The number of recorded warnings.""" | ||||||
|         return len(self._list) |         return len(self._list) | ||||||
| 
 | 
 | ||||||
|     def pop(self, cls=Warning): |     def pop(self, cls: "Type[Warning]" = Warning) -> "warnings._Record": | ||||||
|         """Pop the first recorded warning, raise exception if not exists.""" |         """Pop the first recorded warning, raise exception if not exists.""" | ||||||
|         for i, w in enumerate(self._list): |         for i, w in enumerate(self._list): | ||||||
|             if issubclass(w.category, cls): |             if issubclass(w.category, cls): | ||||||
|  | @ -128,54 +165,80 @@ class WarningsRecorder(warnings.catch_warnings): | ||||||
|         __tracebackhide__ = True |         __tracebackhide__ = True | ||||||
|         raise AssertionError("%r not found in warning list" % cls) |         raise AssertionError("%r not found in warning list" % cls) | ||||||
| 
 | 
 | ||||||
|     def clear(self): |     def clear(self) -> None: | ||||||
|         """Clear the list of recorded warnings.""" |         """Clear the list of recorded warnings.""" | ||||||
|         self._list[:] = [] |         self._list[:] = [] | ||||||
| 
 | 
 | ||||||
|     def __enter__(self): |     # Type ignored because it doesn't exactly warnings.catch_warnings.__enter__ | ||||||
|  |     # -- it returns a List but we only emulate one. | ||||||
|  |     def __enter__(self) -> "WarningsRecorder":  # type: ignore | ||||||
|         if self._entered: |         if self._entered: | ||||||
|             __tracebackhide__ = True |             __tracebackhide__ = True | ||||||
|             raise RuntimeError("Cannot enter %r twice" % self) |             raise RuntimeError("Cannot enter %r twice" % self) | ||||||
|         self._list = super().__enter__() |         _list = super().__enter__() | ||||||
|  |         # record=True means it's None. | ||||||
|  |         assert _list is not None | ||||||
|  |         self._list = _list | ||||||
|         warnings.simplefilter("always") |         warnings.simplefilter("always") | ||||||
|         return self |         return self | ||||||
| 
 | 
 | ||||||
|     def __exit__(self, *exc_info): |     def __exit__( | ||||||
|  |         self, | ||||||
|  |         exc_type: Optional["Type[BaseException]"], | ||||||
|  |         exc_val: Optional[BaseException], | ||||||
|  |         exc_tb: Optional[TracebackType], | ||||||
|  |     ) -> bool: | ||||||
|         if not self._entered: |         if not self._entered: | ||||||
|             __tracebackhide__ = True |             __tracebackhide__ = True | ||||||
|             raise RuntimeError("Cannot exit %r without entering first" % self) |             raise RuntimeError("Cannot exit %r without entering first" % self) | ||||||
| 
 | 
 | ||||||
|         super().__exit__(*exc_info) |         super().__exit__(exc_type, exc_val, exc_tb) | ||||||
| 
 | 
 | ||||||
|         # Built-in catch_warnings does not reset entered state so we do it |         # Built-in catch_warnings does not reset entered state so we do it | ||||||
|         # manually here for this context manager to become reusable. |         # manually here for this context manager to become reusable. | ||||||
|         self._entered = False |         self._entered = False | ||||||
| 
 | 
 | ||||||
|  |         return False | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| class WarningsChecker(WarningsRecorder): | class WarningsChecker(WarningsRecorder): | ||||||
|     def __init__(self, expected_warning=None, match_expr=None): |     def __init__( | ||||||
|  |         self, | ||||||
|  |         expected_warning: Optional[ | ||||||
|  |             Union["Type[Warning]", Tuple["Type[Warning]", ...]] | ||||||
|  |         ] = None, | ||||||
|  |         match_expr: Optional[Union[str, Pattern]] = None, | ||||||
|  |     ) -> None: | ||||||
|         super().__init__() |         super().__init__() | ||||||
| 
 | 
 | ||||||
|         msg = "exceptions must be derived from Warning, not %s" |         msg = "exceptions must be derived from Warning, not %s" | ||||||
|         if isinstance(expected_warning, tuple): |         if expected_warning is None: | ||||||
|  |             expected_warning_tup = None | ||||||
|  |         elif isinstance(expected_warning, tuple): | ||||||
|             for exc in expected_warning: |             for exc in expected_warning: | ||||||
|                 if not inspect.isclass(exc): |                 if not issubclass(exc, Warning): | ||||||
|                     raise TypeError(msg % type(exc)) |                     raise TypeError(msg % type(exc)) | ||||||
|         elif inspect.isclass(expected_warning): |             expected_warning_tup = expected_warning | ||||||
|             expected_warning = (expected_warning,) |         elif issubclass(expected_warning, Warning): | ||||||
|         elif expected_warning is not None: |             expected_warning_tup = (expected_warning,) | ||||||
|  |         else: | ||||||
|             raise TypeError(msg % type(expected_warning)) |             raise TypeError(msg % type(expected_warning)) | ||||||
| 
 | 
 | ||||||
|         self.expected_warning = expected_warning |         self.expected_warning = expected_warning_tup | ||||||
|         self.match_expr = match_expr |         self.match_expr = match_expr | ||||||
| 
 | 
 | ||||||
|     def __exit__(self, *exc_info): |     def __exit__( | ||||||
|         super().__exit__(*exc_info) |         self, | ||||||
|  |         exc_type: Optional["Type[BaseException]"], | ||||||
|  |         exc_val: Optional[BaseException], | ||||||
|  |         exc_tb: Optional[TracebackType], | ||||||
|  |     ) -> bool: | ||||||
|  |         super().__exit__(exc_type, exc_val, exc_tb) | ||||||
| 
 | 
 | ||||||
|         __tracebackhide__ = True |         __tracebackhide__ = True | ||||||
| 
 | 
 | ||||||
|         # only check if we're not currently handling an exception |         # only check if we're not currently handling an exception | ||||||
|         if all(a is None for a in exc_info): |         if exc_type is None and exc_val is None and exc_tb is None: | ||||||
|             if self.expected_warning is not None: |             if self.expected_warning is not None: | ||||||
|                 if not any(issubclass(r.category, self.expected_warning) for r in self): |                 if not any(issubclass(r.category, self.expected_warning) for r in self): | ||||||
|                     __tracebackhide__ = True |                     __tracebackhide__ = True | ||||||
|  | @ -200,3 +263,4 @@ class WarningsChecker(WarningsRecorder): | ||||||
|                                 [each.message for each in self], |                                 [each.message for each in self], | ||||||
|                             ) |                             ) | ||||||
|                         ) |                         ) | ||||||
|  |         return False | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue