This commit is contained in:
John Litborn 2024-01-31 11:53:36 -08:00 committed by GitHub
commit bcf1791740
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 258 additions and 0 deletions

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import math
import pprint
import sys
from collections.abc import Collection
from collections.abc import Sized
from decimal import Decimal
@ -29,6 +30,10 @@ from _pytest.outcomes import fail
if TYPE_CHECKING:
from numpy import ndarray
from typing_extensions import TypeGuard
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
def _compare_approx(
@ -975,6 +980,108 @@ def raises( # noqa: F811
raises.Exception = fail.Exception # type: ignore
@final
class RaisesGroup(ContextManager[_pytest._code.ExceptionInfo[BaseExceptionGroup[E]]]):
"""Helper for catching exceptions wrapped in an ExceptionGroup.
Similar to pytest.raises, except:
* It requires that the exception is inside an exceptiongroup
* It is only able to be used as a contextmanager
* Due to the above, is not split into a caller function and a cm class
Similar to trio.RaisesGroup, except:
* does not handle multiple levels of nested groups.
* does not have trio.Matcher, to add matching on the sub-exception
* does not handle multiple exceptions in the exceptiongroup.
TODO: copy over docstring example usage from trio.RaisesGroup
"""
def __init__(
self,
exception: Type[E],
check: Optional[Callable[[BaseExceptionGroup[E]], bool]] = None,
):
# copied from raises() above
if not isinstance(exception, type) or not issubclass(exception, BaseException):
msg = "expected exception must be a BaseException type, not {}" # type: ignore[unreachable]
not_a = (
exception.__name__
if isinstance(exception, type)
else type(exception).__name__
)
raise TypeError(msg.format(not_a))
self.exception = exception
self.check = check
def __enter__(self) -> _pytest._code.ExceptionInfo[BaseExceptionGroup[E]]:
self.excinfo: _pytest._code.ExceptionInfo[
BaseExceptionGroup[E]
] = _pytest._code.ExceptionInfo.for_later()
return self.excinfo
def matches(
self,
exc_val: Optional[BaseException],
) -> "TypeGuard[BaseExceptionGroup[E]]":
return (
exc_val is not None
and isinstance(exc_val, BaseExceptionGroup)
and len(exc_val.exceptions) == 1
and isinstance(exc_val.exceptions[0], self.exception)
and (self.check is None or self.check(exc_val))
)
def assert_matches(
self,
exc_val: Optional[BaseException],
) -> "TypeGuard[BaseExceptionGroup[E]]":
assert (
exc_val is not None
), "Internal Error: exc_type is not None but exc_val is"
assert isinstance(
exc_val, BaseExceptionGroup
), f"Expected an ExceptionGroup, not {type(exc_val)}"
assert (
len(exc_val.exceptions) == 1
), f"Wrong number of exceptions: got {len(exc_val.exceptions)}, expected 1."
assert isinstance(
exc_val.exceptions[0], self.exception
), f"Wrong type in group: got {type(exc_val.exceptions[0])}, expected {self.exception}"
if self.check is not None:
assert self.check(exc_val), f"Check failed on {repr(exc_val)}."
return True
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
__tracebackhide__ = True
if exc_type is None:
fail("DID NOT RAISE ANY EXCEPTION, expected " + self.expected_type())
assert self.excinfo is not None, "__exit__ without __enter__"
self.assert_matches(exc_val)
# Cast to narrow the exception type now that it's verified.
exc_info = cast(
Tuple[Type[BaseExceptionGroup[E]], BaseExceptionGroup[E], TracebackType],
(exc_type, exc_val, exc_tb),
)
self.excinfo.fill_unfilled(exc_info)
return True
def expected_type(self) -> str:
if not issubclass(self.exception, Exception):
base = "Base"
else:
base = ""
return f"{base}ExceptionGroup({self.exception})"
@final
class RaisesContext(ContextManager[_pytest._code.ExceptionInfo[E]]):
def __init__(

View File

@ -57,6 +57,7 @@ from _pytest.python import Module
from _pytest.python import Package
from _pytest.python_api import approx
from _pytest.python_api import raises
from _pytest.python_api import RaisesGroup
from _pytest.recwarn import deprecated_call
from _pytest.recwarn import WarningsRecorder
from _pytest.recwarn import warns
@ -146,6 +147,7 @@ __all__ = [
"PytestUnraisableExceptionWarning",
"PytestWarning",
"raises",
"RaisesGroup",
"RecordedHookCall",
"register_assert_rewrite",
"RunResult",

View File

@ -0,0 +1,149 @@
import re
import sys
from typing import TYPE_CHECKING
import pytest
from _pytest.outcomes import Failed
from pytest import RaisesGroup
if TYPE_CHECKING:
from typing_extensions import assert_type
if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup
def test_raises_group() -> None:
# wrong type to constructor
with pytest.raises(
TypeError,
match="^expected exception must be a BaseException type, not ValueError$",
):
RaisesGroup(ValueError()) # type: ignore[arg-type]
# working example
with RaisesGroup(ValueError):
raise ExceptionGroup("foo", (ValueError(),))
with RaisesGroup(ValueError, check=lambda x: True):
raise ExceptionGroup("foo", (ValueError(),))
# wrong subexception
with pytest.raises(
AssertionError,
match="Wrong type in group: got <class 'SyntaxError'>, expected <class 'ValueError'>",
):
with RaisesGroup(ValueError):
raise ExceptionGroup("foo", (SyntaxError(),))
# will error if there's excess exceptions
with pytest.raises(
AssertionError, match="Wrong number of exceptions: got 2, expected 1"
):
with RaisesGroup(ValueError):
raise ExceptionGroup("", (ValueError(), ValueError()))
# double nested exceptions is not (currently) supported (contrary to expect*)
with pytest.raises(
AssertionError,
match="Wrong type in group: got <class '(exceptiongroup.)?ExceptionGroup'>, expected <class 'ValueError'>",
):
with RaisesGroup(ValueError):
raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),))
# you'd need to write
with RaisesGroup(ExceptionGroup) as excinfo:
raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),))
RaisesGroup(ValueError).assert_matches(excinfo.value.exceptions[0])
# unwrapped exceptions are not accepted (contrary to expect*)
with pytest.raises(
AssertionError, match="Expected an ExceptionGroup, not <class 'ValueError'."
):
with RaisesGroup(ValueError):
raise ValueError
with pytest.raises(
AssertionError,
match=re.escape("Check failed on ExceptionGroup('foo', (ValueError(),))."),
):
with RaisesGroup(ValueError, check=lambda x: False):
raise ExceptionGroup("foo", (ValueError(),))
def test_RaisesGroup_matches() -> None:
eeg = RaisesGroup(ValueError)
# exc_val is None
assert not eeg.matches(None)
# exc_val is not an exceptiongroup
assert not eeg.matches(ValueError())
# wrong length
assert not eeg.matches(ExceptionGroup("", (ValueError(), ValueError())))
# wrong type
assert not eeg.matches(ExceptionGroup("", (TypeError(),)))
# check fails
assert not RaisesGroup(ValueError, check=lambda _: False).matches(
ExceptionGroup("", (ValueError(),))
)
# success
assert eeg.matches(ExceptionGroup("", (ValueError(),)))
def test_RaisesGroup_assert_matches() -> None:
"""Check direct use of RaisesGroup.assert_matches, without a context manager"""
eeg = RaisesGroup(ValueError)
with pytest.raises(AssertionError):
eeg.assert_matches(None)
with pytest.raises(AssertionError):
eeg.assert_matches(ValueError())
eeg.assert_matches(ExceptionGroup("", (ValueError(),)))
def test_message() -> None:
with pytest.raises(
Failed,
match=re.escape(
f"DID NOT RAISE ANY EXCEPTION, expected ExceptionGroup({repr(ValueError)})"
),
):
with RaisesGroup(ValueError):
...
with pytest.raises(
Failed,
match=re.escape(
f"DID NOT RAISE ANY EXCEPTION, expected BaseExceptionGroup({repr(KeyboardInterrupt)})"
),
):
with RaisesGroup(KeyboardInterrupt):
...
if TYPE_CHECKING:
def test_types_1() -> None:
with RaisesGroup(ValueError) as e:
raise ExceptionGroup("foo", (ValueError(),))
assert_type(e.value, BaseExceptionGroup[ValueError])
def test_types_2() -> None:
exc: ExceptionGroup[ValueError] | ValueError = ExceptionGroup(
"", (ValueError(),)
)
if RaisesGroup(ValueError).assert_matches(exc):
assert_type(exc, BaseExceptionGroup[ValueError])
def test_types_3() -> None:
e: BaseExceptionGroup[KeyboardInterrupt] = BaseExceptionGroup(
"", (KeyboardInterrupt(),)
)
if RaisesGroup(ValueError).matches(e):
assert_type(e, BaseExceptionGroup[ValueError])
def test_types_4() -> None:
e: BaseExceptionGroup[KeyboardInterrupt] = BaseExceptionGroup(
"", (KeyboardInterrupt(),)
)
# not currently possible: https://github.com/python/typing/issues/930
RaisesGroup(ValueError).assert_matches(e)
assert_type(e, BaseExceptionGroup[ValueError]) # type: ignore[assert-type]