Merge 7d90cb64b7
into 4546d5445a
This commit is contained in:
commit
bcf1791740
|
@ -1,6 +1,7 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import math
|
||||
import pprint
|
||||
import sys
|
||||
from collections.abc import Collection
|
||||
from collections.abc import Sized
|
||||
from decimal import Decimal
|
||||
|
@ -29,6 +30,10 @@ from _pytest.outcomes import fail
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from numpy import ndarray
|
||||
from typing_extensions import TypeGuard
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from exceptiongroup import BaseExceptionGroup
|
||||
|
||||
|
||||
def _compare_approx(
|
||||
|
@ -975,6 +980,108 @@ def raises( # noqa: F811
|
|||
raises.Exception = fail.Exception # type: ignore
|
||||
|
||||
|
||||
@final
|
||||
class RaisesGroup(ContextManager[_pytest._code.ExceptionInfo[BaseExceptionGroup[E]]]):
|
||||
"""Helper for catching exceptions wrapped in an ExceptionGroup.
|
||||
|
||||
Similar to pytest.raises, except:
|
||||
* It requires that the exception is inside an exceptiongroup
|
||||
* It is only able to be used as a contextmanager
|
||||
* Due to the above, is not split into a caller function and a cm class
|
||||
Similar to trio.RaisesGroup, except:
|
||||
* does not handle multiple levels of nested groups.
|
||||
* does not have trio.Matcher, to add matching on the sub-exception
|
||||
* does not handle multiple exceptions in the exceptiongroup.
|
||||
|
||||
TODO: copy over docstring example usage from trio.RaisesGroup
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
exception: Type[E],
|
||||
check: Optional[Callable[[BaseExceptionGroup[E]], bool]] = None,
|
||||
):
|
||||
# copied from raises() above
|
||||
if not isinstance(exception, type) or not issubclass(exception, BaseException):
|
||||
msg = "expected exception must be a BaseException type, not {}" # type: ignore[unreachable]
|
||||
not_a = (
|
||||
exception.__name__
|
||||
if isinstance(exception, type)
|
||||
else type(exception).__name__
|
||||
)
|
||||
raise TypeError(msg.format(not_a))
|
||||
|
||||
self.exception = exception
|
||||
self.check = check
|
||||
|
||||
def __enter__(self) -> _pytest._code.ExceptionInfo[BaseExceptionGroup[E]]:
|
||||
self.excinfo: _pytest._code.ExceptionInfo[
|
||||
BaseExceptionGroup[E]
|
||||
] = _pytest._code.ExceptionInfo.for_later()
|
||||
return self.excinfo
|
||||
|
||||
def matches(
|
||||
self,
|
||||
exc_val: Optional[BaseException],
|
||||
) -> "TypeGuard[BaseExceptionGroup[E]]":
|
||||
return (
|
||||
exc_val is not None
|
||||
and isinstance(exc_val, BaseExceptionGroup)
|
||||
and len(exc_val.exceptions) == 1
|
||||
and isinstance(exc_val.exceptions[0], self.exception)
|
||||
and (self.check is None or self.check(exc_val))
|
||||
)
|
||||
|
||||
def assert_matches(
|
||||
self,
|
||||
exc_val: Optional[BaseException],
|
||||
) -> "TypeGuard[BaseExceptionGroup[E]]":
|
||||
assert (
|
||||
exc_val is not None
|
||||
), "Internal Error: exc_type is not None but exc_val is"
|
||||
assert isinstance(
|
||||
exc_val, BaseExceptionGroup
|
||||
), f"Expected an ExceptionGroup, not {type(exc_val)}"
|
||||
assert (
|
||||
len(exc_val.exceptions) == 1
|
||||
), f"Wrong number of exceptions: got {len(exc_val.exceptions)}, expected 1."
|
||||
assert isinstance(
|
||||
exc_val.exceptions[0], self.exception
|
||||
), f"Wrong type in group: got {type(exc_val.exceptions[0])}, expected {self.exception}"
|
||||
if self.check is not None:
|
||||
assert self.check(exc_val), f"Check failed on {repr(exc_val)}."
|
||||
|
||||
return True
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[TracebackType],
|
||||
) -> bool:
|
||||
__tracebackhide__ = True
|
||||
if exc_type is None:
|
||||
fail("DID NOT RAISE ANY EXCEPTION, expected " + self.expected_type())
|
||||
assert self.excinfo is not None, "__exit__ without __enter__"
|
||||
|
||||
self.assert_matches(exc_val)
|
||||
|
||||
# Cast to narrow the exception type now that it's verified.
|
||||
exc_info = cast(
|
||||
Tuple[Type[BaseExceptionGroup[E]], BaseExceptionGroup[E], TracebackType],
|
||||
(exc_type, exc_val, exc_tb),
|
||||
)
|
||||
self.excinfo.fill_unfilled(exc_info)
|
||||
return True
|
||||
|
||||
def expected_type(self) -> str:
|
||||
if not issubclass(self.exception, Exception):
|
||||
base = "Base"
|
||||
else:
|
||||
base = ""
|
||||
return f"{base}ExceptionGroup({self.exception})"
|
||||
|
||||
|
||||
@final
|
||||
class RaisesContext(ContextManager[_pytest._code.ExceptionInfo[E]]):
|
||||
def __init__(
|
||||
|
|
|
@ -57,6 +57,7 @@ from _pytest.python import Module
|
|||
from _pytest.python import Package
|
||||
from _pytest.python_api import approx
|
||||
from _pytest.python_api import raises
|
||||
from _pytest.python_api import RaisesGroup
|
||||
from _pytest.recwarn import deprecated_call
|
||||
from _pytest.recwarn import WarningsRecorder
|
||||
from _pytest.recwarn import warns
|
||||
|
@ -146,6 +147,7 @@ __all__ = [
|
|||
"PytestUnraisableExceptionWarning",
|
||||
"PytestWarning",
|
||||
"raises",
|
||||
"RaisesGroup",
|
||||
"RecordedHookCall",
|
||||
"register_assert_rewrite",
|
||||
"RunResult",
|
||||
|
|
|
@ -0,0 +1,149 @@
|
|||
import re
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from _pytest.outcomes import Failed
|
||||
from pytest import RaisesGroup
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import assert_type
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from exceptiongroup import ExceptionGroup
|
||||
|
||||
|
||||
def test_raises_group() -> None:
|
||||
# wrong type to constructor
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match="^expected exception must be a BaseException type, not ValueError$",
|
||||
):
|
||||
RaisesGroup(ValueError()) # type: ignore[arg-type]
|
||||
|
||||
# working example
|
||||
with RaisesGroup(ValueError):
|
||||
raise ExceptionGroup("foo", (ValueError(),))
|
||||
|
||||
with RaisesGroup(ValueError, check=lambda x: True):
|
||||
raise ExceptionGroup("foo", (ValueError(),))
|
||||
|
||||
# wrong subexception
|
||||
with pytest.raises(
|
||||
AssertionError,
|
||||
match="Wrong type in group: got <class 'SyntaxError'>, expected <class 'ValueError'>",
|
||||
):
|
||||
with RaisesGroup(ValueError):
|
||||
raise ExceptionGroup("foo", (SyntaxError(),))
|
||||
|
||||
# will error if there's excess exceptions
|
||||
with pytest.raises(
|
||||
AssertionError, match="Wrong number of exceptions: got 2, expected 1"
|
||||
):
|
||||
with RaisesGroup(ValueError):
|
||||
raise ExceptionGroup("", (ValueError(), ValueError()))
|
||||
|
||||
# double nested exceptions is not (currently) supported (contrary to expect*)
|
||||
with pytest.raises(
|
||||
AssertionError,
|
||||
match="Wrong type in group: got <class '(exceptiongroup.)?ExceptionGroup'>, expected <class 'ValueError'>",
|
||||
):
|
||||
with RaisesGroup(ValueError):
|
||||
raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),))
|
||||
|
||||
# you'd need to write
|
||||
with RaisesGroup(ExceptionGroup) as excinfo:
|
||||
raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),))
|
||||
RaisesGroup(ValueError).assert_matches(excinfo.value.exceptions[0])
|
||||
|
||||
# unwrapped exceptions are not accepted (contrary to expect*)
|
||||
with pytest.raises(
|
||||
AssertionError, match="Expected an ExceptionGroup, not <class 'ValueError'."
|
||||
):
|
||||
with RaisesGroup(ValueError):
|
||||
raise ValueError
|
||||
|
||||
with pytest.raises(
|
||||
AssertionError,
|
||||
match=re.escape("Check failed on ExceptionGroup('foo', (ValueError(),))."),
|
||||
):
|
||||
with RaisesGroup(ValueError, check=lambda x: False):
|
||||
raise ExceptionGroup("foo", (ValueError(),))
|
||||
|
||||
|
||||
def test_RaisesGroup_matches() -> None:
|
||||
eeg = RaisesGroup(ValueError)
|
||||
# exc_val is None
|
||||
assert not eeg.matches(None)
|
||||
# exc_val is not an exceptiongroup
|
||||
assert not eeg.matches(ValueError())
|
||||
# wrong length
|
||||
assert not eeg.matches(ExceptionGroup("", (ValueError(), ValueError())))
|
||||
# wrong type
|
||||
assert not eeg.matches(ExceptionGroup("", (TypeError(),)))
|
||||
# check fails
|
||||
assert not RaisesGroup(ValueError, check=lambda _: False).matches(
|
||||
ExceptionGroup("", (ValueError(),))
|
||||
)
|
||||
# success
|
||||
assert eeg.matches(ExceptionGroup("", (ValueError(),)))
|
||||
|
||||
|
||||
def test_RaisesGroup_assert_matches() -> None:
|
||||
"""Check direct use of RaisesGroup.assert_matches, without a context manager"""
|
||||
eeg = RaisesGroup(ValueError)
|
||||
with pytest.raises(AssertionError):
|
||||
eeg.assert_matches(None)
|
||||
with pytest.raises(AssertionError):
|
||||
eeg.assert_matches(ValueError())
|
||||
eeg.assert_matches(ExceptionGroup("", (ValueError(),)))
|
||||
|
||||
|
||||
def test_message() -> None:
|
||||
with pytest.raises(
|
||||
Failed,
|
||||
match=re.escape(
|
||||
f"DID NOT RAISE ANY EXCEPTION, expected ExceptionGroup({repr(ValueError)})"
|
||||
),
|
||||
):
|
||||
with RaisesGroup(ValueError):
|
||||
...
|
||||
|
||||
with pytest.raises(
|
||||
Failed,
|
||||
match=re.escape(
|
||||
f"DID NOT RAISE ANY EXCEPTION, expected BaseExceptionGroup({repr(KeyboardInterrupt)})"
|
||||
),
|
||||
):
|
||||
with RaisesGroup(KeyboardInterrupt):
|
||||
...
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
def test_types_1() -> None:
|
||||
with RaisesGroup(ValueError) as e:
|
||||
raise ExceptionGroup("foo", (ValueError(),))
|
||||
assert_type(e.value, BaseExceptionGroup[ValueError])
|
||||
|
||||
def test_types_2() -> None:
|
||||
exc: ExceptionGroup[ValueError] | ValueError = ExceptionGroup(
|
||||
"", (ValueError(),)
|
||||
)
|
||||
if RaisesGroup(ValueError).assert_matches(exc):
|
||||
assert_type(exc, BaseExceptionGroup[ValueError])
|
||||
|
||||
def test_types_3() -> None:
|
||||
e: BaseExceptionGroup[KeyboardInterrupt] = BaseExceptionGroup(
|
||||
"", (KeyboardInterrupt(),)
|
||||
)
|
||||
if RaisesGroup(ValueError).matches(e):
|
||||
assert_type(e, BaseExceptionGroup[ValueError])
|
||||
|
||||
def test_types_4() -> None:
|
||||
e: BaseExceptionGroup[KeyboardInterrupt] = BaseExceptionGroup(
|
||||
"", (KeyboardInterrupt(),)
|
||||
)
|
||||
# not currently possible: https://github.com/python/typing/issues/930
|
||||
RaisesGroup(ValueError).assert_matches(e)
|
||||
assert_type(e, BaseExceptionGroup[ValueError]) # type: ignore[assert-type]
|
Loading…
Reference in New Issue