This commit is contained in:
John Litborn 2024-01-31 11:53:36 -08:00 committed by GitHub
commit bcf1791740
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 258 additions and 0 deletions

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
import math import math
import pprint import pprint
import sys
from collections.abc import Collection from collections.abc import Collection
from collections.abc import Sized from collections.abc import Sized
from decimal import Decimal from decimal import Decimal
@ -29,6 +30,10 @@ from _pytest.outcomes import fail
if TYPE_CHECKING: if TYPE_CHECKING:
from numpy import ndarray from numpy import ndarray
from typing_extensions import TypeGuard
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
def _compare_approx( def _compare_approx(
@ -975,6 +980,108 @@ def raises( # noqa: F811
raises.Exception = fail.Exception # type: ignore raises.Exception = fail.Exception # type: ignore
@final
class RaisesGroup(ContextManager[_pytest._code.ExceptionInfo[BaseExceptionGroup[E]]]):
"""Helper for catching exceptions wrapped in an ExceptionGroup.
Similar to pytest.raises, except:
* It requires that the exception is inside an exceptiongroup
* It is only able to be used as a contextmanager
* Due to the above, is not split into a caller function and a cm class
Similar to trio.RaisesGroup, except:
* does not handle multiple levels of nested groups.
* does not have trio.Matcher, to add matching on the sub-exception
* does not handle multiple exceptions in the exceptiongroup.
TODO: copy over docstring example usage from trio.RaisesGroup
"""
def __init__(
self,
exception: Type[E],
check: Optional[Callable[[BaseExceptionGroup[E]], bool]] = None,
):
# copied from raises() above
if not isinstance(exception, type) or not issubclass(exception, BaseException):
msg = "expected exception must be a BaseException type, not {}" # type: ignore[unreachable]
not_a = (
exception.__name__
if isinstance(exception, type)
else type(exception).__name__
)
raise TypeError(msg.format(not_a))
self.exception = exception
self.check = check
def __enter__(self) -> _pytest._code.ExceptionInfo[BaseExceptionGroup[E]]:
self.excinfo: _pytest._code.ExceptionInfo[
BaseExceptionGroup[E]
] = _pytest._code.ExceptionInfo.for_later()
return self.excinfo
def matches(
self,
exc_val: Optional[BaseException],
) -> "TypeGuard[BaseExceptionGroup[E]]":
return (
exc_val is not None
and isinstance(exc_val, BaseExceptionGroup)
and len(exc_val.exceptions) == 1
and isinstance(exc_val.exceptions[0], self.exception)
and (self.check is None or self.check(exc_val))
)
def assert_matches(
self,
exc_val: Optional[BaseException],
) -> "TypeGuard[BaseExceptionGroup[E]]":
assert (
exc_val is not None
), "Internal Error: exc_type is not None but exc_val is"
assert isinstance(
exc_val, BaseExceptionGroup
), f"Expected an ExceptionGroup, not {type(exc_val)}"
assert (
len(exc_val.exceptions) == 1
), f"Wrong number of exceptions: got {len(exc_val.exceptions)}, expected 1."
assert isinstance(
exc_val.exceptions[0], self.exception
), f"Wrong type in group: got {type(exc_val.exceptions[0])}, expected {self.exception}"
if self.check is not None:
assert self.check(exc_val), f"Check failed on {repr(exc_val)}."
return True
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
__tracebackhide__ = True
if exc_type is None:
fail("DID NOT RAISE ANY EXCEPTION, expected " + self.expected_type())
assert self.excinfo is not None, "__exit__ without __enter__"
self.assert_matches(exc_val)
# Cast to narrow the exception type now that it's verified.
exc_info = cast(
Tuple[Type[BaseExceptionGroup[E]], BaseExceptionGroup[E], TracebackType],
(exc_type, exc_val, exc_tb),
)
self.excinfo.fill_unfilled(exc_info)
return True
def expected_type(self) -> str:
if not issubclass(self.exception, Exception):
base = "Base"
else:
base = ""
return f"{base}ExceptionGroup({self.exception})"
@final @final
class RaisesContext(ContextManager[_pytest._code.ExceptionInfo[E]]): class RaisesContext(ContextManager[_pytest._code.ExceptionInfo[E]]):
def __init__( def __init__(

View File

@ -57,6 +57,7 @@ from _pytest.python import Module
from _pytest.python import Package from _pytest.python import Package
from _pytest.python_api import approx from _pytest.python_api import approx
from _pytest.python_api import raises from _pytest.python_api import raises
from _pytest.python_api import RaisesGroup
from _pytest.recwarn import deprecated_call from _pytest.recwarn import deprecated_call
from _pytest.recwarn import WarningsRecorder from _pytest.recwarn import WarningsRecorder
from _pytest.recwarn import warns from _pytest.recwarn import warns
@ -146,6 +147,7 @@ __all__ = [
"PytestUnraisableExceptionWarning", "PytestUnraisableExceptionWarning",
"PytestWarning", "PytestWarning",
"raises", "raises",
"RaisesGroup",
"RecordedHookCall", "RecordedHookCall",
"register_assert_rewrite", "register_assert_rewrite",
"RunResult", "RunResult",

View File

@ -0,0 +1,149 @@
import re
import sys
from typing import TYPE_CHECKING
import pytest
from _pytest.outcomes import Failed
from pytest import RaisesGroup
if TYPE_CHECKING:
from typing_extensions import assert_type
if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup
def test_raises_group() -> None:
# wrong type to constructor
with pytest.raises(
TypeError,
match="^expected exception must be a BaseException type, not ValueError$",
):
RaisesGroup(ValueError()) # type: ignore[arg-type]
# working example
with RaisesGroup(ValueError):
raise ExceptionGroup("foo", (ValueError(),))
with RaisesGroup(ValueError, check=lambda x: True):
raise ExceptionGroup("foo", (ValueError(),))
# wrong subexception
with pytest.raises(
AssertionError,
match="Wrong type in group: got <class 'SyntaxError'>, expected <class 'ValueError'>",
):
with RaisesGroup(ValueError):
raise ExceptionGroup("foo", (SyntaxError(),))
# will error if there's excess exceptions
with pytest.raises(
AssertionError, match="Wrong number of exceptions: got 2, expected 1"
):
with RaisesGroup(ValueError):
raise ExceptionGroup("", (ValueError(), ValueError()))
# double nested exceptions is not (currently) supported (contrary to expect*)
with pytest.raises(
AssertionError,
match="Wrong type in group: got <class '(exceptiongroup.)?ExceptionGroup'>, expected <class 'ValueError'>",
):
with RaisesGroup(ValueError):
raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),))
# you'd need to write
with RaisesGroup(ExceptionGroup) as excinfo:
raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),))
RaisesGroup(ValueError).assert_matches(excinfo.value.exceptions[0])
# unwrapped exceptions are not accepted (contrary to expect*)
with pytest.raises(
AssertionError, match="Expected an ExceptionGroup, not <class 'ValueError'."
):
with RaisesGroup(ValueError):
raise ValueError
with pytest.raises(
AssertionError,
match=re.escape("Check failed on ExceptionGroup('foo', (ValueError(),))."),
):
with RaisesGroup(ValueError, check=lambda x: False):
raise ExceptionGroup("foo", (ValueError(),))
def test_RaisesGroup_matches() -> None:
eeg = RaisesGroup(ValueError)
# exc_val is None
assert not eeg.matches(None)
# exc_val is not an exceptiongroup
assert not eeg.matches(ValueError())
# wrong length
assert not eeg.matches(ExceptionGroup("", (ValueError(), ValueError())))
# wrong type
assert not eeg.matches(ExceptionGroup("", (TypeError(),)))
# check fails
assert not RaisesGroup(ValueError, check=lambda _: False).matches(
ExceptionGroup("", (ValueError(),))
)
# success
assert eeg.matches(ExceptionGroup("", (ValueError(),)))
def test_RaisesGroup_assert_matches() -> None:
"""Check direct use of RaisesGroup.assert_matches, without a context manager"""
eeg = RaisesGroup(ValueError)
with pytest.raises(AssertionError):
eeg.assert_matches(None)
with pytest.raises(AssertionError):
eeg.assert_matches(ValueError())
eeg.assert_matches(ExceptionGroup("", (ValueError(),)))
def test_message() -> None:
with pytest.raises(
Failed,
match=re.escape(
f"DID NOT RAISE ANY EXCEPTION, expected ExceptionGroup({repr(ValueError)})"
),
):
with RaisesGroup(ValueError):
...
with pytest.raises(
Failed,
match=re.escape(
f"DID NOT RAISE ANY EXCEPTION, expected BaseExceptionGroup({repr(KeyboardInterrupt)})"
),
):
with RaisesGroup(KeyboardInterrupt):
...
if TYPE_CHECKING:
def test_types_1() -> None:
with RaisesGroup(ValueError) as e:
raise ExceptionGroup("foo", (ValueError(),))
assert_type(e.value, BaseExceptionGroup[ValueError])
def test_types_2() -> None:
exc: ExceptionGroup[ValueError] | ValueError = ExceptionGroup(
"", (ValueError(),)
)
if RaisesGroup(ValueError).assert_matches(exc):
assert_type(exc, BaseExceptionGroup[ValueError])
def test_types_3() -> None:
e: BaseExceptionGroup[KeyboardInterrupt] = BaseExceptionGroup(
"", (KeyboardInterrupt(),)
)
if RaisesGroup(ValueError).matches(e):
assert_type(e, BaseExceptionGroup[ValueError])
def test_types_4() -> None:
e: BaseExceptionGroup[KeyboardInterrupt] = BaseExceptionGroup(
"", (KeyboardInterrupt(),)
)
# not currently possible: https://github.com/python/typing/issues/930
RaisesGroup(ValueError).assert_matches(e)
assert_type(e, BaseExceptionGroup[ValueError]) # type: ignore[assert-type]