Make ExceptionInfo generic in the exception type
This way, in
with pytest.raises(ValueError) as cm:
...
cm.value is a ValueError and not a BaseException.
This commit is contained in:
@@ -10,10 +10,13 @@ from numbers import Number
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
from typing import Callable
|
||||
from typing import cast
|
||||
from typing import Generic
|
||||
from typing import Optional
|
||||
from typing import overload
|
||||
from typing import Pattern
|
||||
from typing import Tuple
|
||||
from typing import TypeVar
|
||||
from typing import Union
|
||||
|
||||
from more_itertools.more import always_iterable
|
||||
@@ -537,33 +540,35 @@ def _is_numpy_array(obj):
|
||||
|
||||
# builtin pytest.raises helper
|
||||
|
||||
_E = TypeVar("_E", bound=BaseException)
|
||||
|
||||
|
||||
@overload
|
||||
def raises(
|
||||
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
|
||||
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
|
||||
*,
|
||||
match: Optional[Union[str, Pattern]] = ...
|
||||
) -> "RaisesContext":
|
||||
) -> "RaisesContext[_E]":
|
||||
... # pragma: no cover
|
||||
|
||||
|
||||
@overload
|
||||
def raises(
|
||||
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
|
||||
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
|
||||
func: Callable,
|
||||
*args: Any,
|
||||
match: Optional[str] = ...,
|
||||
**kwargs: Any
|
||||
) -> Optional[_pytest._code.ExceptionInfo]:
|
||||
) -> Optional[_pytest._code.ExceptionInfo[_E]]:
|
||||
... # pragma: no cover
|
||||
|
||||
|
||||
def raises(
|
||||
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
|
||||
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
|
||||
*args: Any,
|
||||
match: Optional[Union[str, Pattern]] = None,
|
||||
**kwargs: Any
|
||||
) -> Union["RaisesContext", Optional[_pytest._code.ExceptionInfo]]:
|
||||
) -> Union["RaisesContext[_E]", Optional[_pytest._code.ExceptionInfo[_E]]]:
|
||||
r"""
|
||||
Assert that a code block/function call raises ``expected_exception``
|
||||
or raise a failure exception otherwise.
|
||||
@@ -703,28 +708,30 @@ def raises(
|
||||
try:
|
||||
func(*args[1:], **kwargs)
|
||||
except expected_exception:
|
||||
return _pytest._code.ExceptionInfo.from_current()
|
||||
# Cast to narrow the type to expected_exception (_E).
|
||||
return cast(
|
||||
_pytest._code.ExceptionInfo[_E],
|
||||
_pytest._code.ExceptionInfo.from_current(),
|
||||
)
|
||||
fail(message)
|
||||
|
||||
|
||||
raises.Exception = fail.Exception # type: ignore
|
||||
|
||||
|
||||
class RaisesContext:
|
||||
class RaisesContext(Generic[_E]):
|
||||
def __init__(
|
||||
self,
|
||||
expected_exception: Union[
|
||||
"Type[BaseException]", Tuple["Type[BaseException]", ...]
|
||||
],
|
||||
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
|
||||
message: str,
|
||||
match_expr: Optional[Union[str, Pattern]] = None,
|
||||
) -> None:
|
||||
self.expected_exception = expected_exception
|
||||
self.message = message
|
||||
self.match_expr = match_expr
|
||||
self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo]
|
||||
self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo[_E]]
|
||||
|
||||
def __enter__(self) -> _pytest._code.ExceptionInfo:
|
||||
def __enter__(self) -> _pytest._code.ExceptionInfo[_E]:
|
||||
self.excinfo = _pytest._code.ExceptionInfo.for_later()
|
||||
return self.excinfo
|
||||
|
||||
|
||||
Reference in New Issue
Block a user