Make ExceptionInfo generic in the exception type

This way, in

    with pytest.raises(ValueError) as cm:
        ...

cm.value is a ValueError and not a BaseException.
This commit is contained in:
Ran Benita
2019-07-10 20:12:41 +03:00
parent 56dcc9e1f8
commit 14bf4cdf44
2 changed files with 33 additions and 21 deletions

View File

@@ -10,10 +10,13 @@ from numbers import Number
from types import TracebackType
from typing import Any
from typing import Callable
from typing import cast
from typing import Generic
from typing import Optional
from typing import overload
from typing import Pattern
from typing import Tuple
from typing import TypeVar
from typing import Union
from more_itertools.more import always_iterable
@@ -537,33 +540,35 @@ def _is_numpy_array(obj):
# builtin pytest.raises helper
_E = TypeVar("_E", bound=BaseException)
@overload
def raises(
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
*,
match: Optional[Union[str, Pattern]] = ...
) -> "RaisesContext":
) -> "RaisesContext[_E]":
... # pragma: no cover
@overload
def raises(
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
func: Callable,
*args: Any,
match: Optional[str] = ...,
**kwargs: Any
) -> Optional[_pytest._code.ExceptionInfo]:
) -> Optional[_pytest._code.ExceptionInfo[_E]]:
... # pragma: no cover
def raises(
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
*args: Any,
match: Optional[Union[str, Pattern]] = None,
**kwargs: Any
) -> Union["RaisesContext", Optional[_pytest._code.ExceptionInfo]]:
) -> Union["RaisesContext[_E]", Optional[_pytest._code.ExceptionInfo[_E]]]:
r"""
Assert that a code block/function call raises ``expected_exception``
or raise a failure exception otherwise.
@@ -703,28 +708,30 @@ def raises(
try:
func(*args[1:], **kwargs)
except expected_exception:
return _pytest._code.ExceptionInfo.from_current()
# Cast to narrow the type to expected_exception (_E).
return cast(
_pytest._code.ExceptionInfo[_E],
_pytest._code.ExceptionInfo.from_current(),
)
fail(message)
raises.Exception = fail.Exception # type: ignore
class RaisesContext:
class RaisesContext(Generic[_E]):
def __init__(
self,
expected_exception: Union[
"Type[BaseException]", Tuple["Type[BaseException]", ...]
],
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
message: str,
match_expr: Optional[Union[str, Pattern]] = None,
) -> None:
self.expected_exception = expected_exception
self.message = message
self.match_expr = match_expr
self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo]
self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo[_E]]
def __enter__(self) -> _pytest._code.ExceptionInfo:
def __enter__(self) -> _pytest._code.ExceptionInfo[_E]:
self.excinfo = _pytest._code.ExceptionInfo.for_later()
return self.excinfo