diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index f398902ff..a33c38700 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import math import pprint +import sys from collections.abc import Collection from collections.abc import Sized from decimal import Decimal @@ -29,6 +30,10 @@ from _pytest.outcomes import fail if TYPE_CHECKING: from numpy import ndarray + from typing_extensions import TypeGuard + +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup def _compare_approx( @@ -975,6 +980,108 @@ def raises( # noqa: F811 raises.Exception = fail.Exception # type: ignore +@final +class RaisesGroup(ContextManager[_pytest._code.ExceptionInfo[BaseExceptionGroup[E]]]): + """Helper for catching exceptions wrapped in an ExceptionGroup. + + Similar to pytest.raises, except: + * It requires that the exception is inside an exceptiongroup + * It is only able to be used as a contextmanager + * Due to the above, is not split into a caller function and a cm class + Similar to trio.RaisesGroup, except: + * does not handle multiple levels of nested groups. + * does not have trio.Matcher, to add matching on the sub-exception + * does not handle multiple exceptions in the exceptiongroup. + + TODO: copy over docstring example usage from trio.RaisesGroup + """ + + def __init__( + self, + exception: Type[E], + check: Optional[Callable[[BaseExceptionGroup[E]], bool]] = None, + ): + # copied from raises() above + if not isinstance(exception, type) or not issubclass(exception, BaseException): + msg = "expected exception must be a BaseException type, not {}" # type: ignore[unreachable] + not_a = ( + exception.__name__ + if isinstance(exception, type) + else type(exception).__name__ + ) + raise TypeError(msg.format(not_a)) + + self.exception = exception + self.check = check + + def __enter__(self) -> _pytest._code.ExceptionInfo[BaseExceptionGroup[E]]: + self.excinfo: _pytest._code.ExceptionInfo[ + BaseExceptionGroup[E] + ] = _pytest._code.ExceptionInfo.for_later() + return self.excinfo + + def matches( + self, + exc_val: Optional[BaseException], + ) -> "TypeGuard[BaseExceptionGroup[E]]": + return ( + exc_val is not None + and isinstance(exc_val, BaseExceptionGroup) + and len(exc_val.exceptions) == 1 + and isinstance(exc_val.exceptions[0], self.exception) + and (self.check is None or self.check(exc_val)) + ) + + def assert_matches( + self, + exc_val: Optional[BaseException], + ) -> "TypeGuard[BaseExceptionGroup[E]]": + assert ( + exc_val is not None + ), "Internal Error: exc_type is not None but exc_val is" + assert isinstance( + exc_val, BaseExceptionGroup + ), f"Expected an ExceptionGroup, not {type(exc_val)}" + assert ( + len(exc_val.exceptions) == 1 + ), f"Wrong number of exceptions: got {len(exc_val.exceptions)}, expected 1." + assert isinstance( + exc_val.exceptions[0], self.exception + ), f"Wrong type in group: got {type(exc_val.exceptions[0])}, expected {self.exception}" + if self.check is not None: + assert self.check(exc_val), f"Check failed on {repr(exc_val)}." + + return True + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + __tracebackhide__ = True + if exc_type is None: + fail("DID NOT RAISE ANY EXCEPTION, expected " + self.expected_type()) + assert self.excinfo is not None, "__exit__ without __enter__" + + self.assert_matches(exc_val) + + # Cast to narrow the exception type now that it's verified. + exc_info = cast( + Tuple[Type[BaseExceptionGroup[E]], BaseExceptionGroup[E], TracebackType], + (exc_type, exc_val, exc_tb), + ) + self.excinfo.fill_unfilled(exc_info) + return True + + def expected_type(self) -> str: + if not issubclass(self.exception, Exception): + base = "Base" + else: + base = "" + return f"{base}ExceptionGroup({self.exception})" + + @final class RaisesContext(ContextManager[_pytest._code.ExceptionInfo[E]]): def __init__( diff --git a/src/pytest/__init__.py b/src/pytest/__init__.py index 449cb39b8..9e84a11dc 100644 --- a/src/pytest/__init__.py +++ b/src/pytest/__init__.py @@ -57,6 +57,7 @@ from _pytest.python import Module from _pytest.python import Package from _pytest.python_api import approx from _pytest.python_api import raises +from _pytest.python_api import RaisesGroup from _pytest.recwarn import deprecated_call from _pytest.recwarn import WarningsRecorder from _pytest.recwarn import warns @@ -146,6 +147,7 @@ __all__ = [ "PytestUnraisableExceptionWarning", "PytestWarning", "raises", + "RaisesGroup", "RecordedHookCall", "register_assert_rewrite", "RunResult", diff --git a/testing/python/raises_group.py b/testing/python/raises_group.py new file mode 100644 index 000000000..b6edf3d82 --- /dev/null +++ b/testing/python/raises_group.py @@ -0,0 +1,149 @@ +import re +import sys +from typing import TYPE_CHECKING + +import pytest +from _pytest.outcomes import Failed +from pytest import RaisesGroup + +if TYPE_CHECKING: + from typing_extensions import assert_type + +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup + + +def test_raises_group() -> None: + # wrong type to constructor + with pytest.raises( + TypeError, + match="^expected exception must be a BaseException type, not ValueError$", + ): + RaisesGroup(ValueError()) # type: ignore[arg-type] + + # working example + with RaisesGroup(ValueError): + raise ExceptionGroup("foo", (ValueError(),)) + + with RaisesGroup(ValueError, check=lambda x: True): + raise ExceptionGroup("foo", (ValueError(),)) + + # wrong subexception + with pytest.raises( + AssertionError, + match="Wrong type in group: got , expected ", + ): + with RaisesGroup(ValueError): + raise ExceptionGroup("foo", (SyntaxError(),)) + + # will error if there's excess exceptions + with pytest.raises( + AssertionError, match="Wrong number of exceptions: got 2, expected 1" + ): + with RaisesGroup(ValueError): + raise ExceptionGroup("", (ValueError(), ValueError())) + + # double nested exceptions is not (currently) supported (contrary to expect*) + with pytest.raises( + AssertionError, + match="Wrong type in group: got , expected ", + ): + with RaisesGroup(ValueError): + raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)) + + # you'd need to write + with RaisesGroup(ExceptionGroup) as excinfo: + raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)) + RaisesGroup(ValueError).assert_matches(excinfo.value.exceptions[0]) + + # unwrapped exceptions are not accepted (contrary to expect*) + with pytest.raises( + AssertionError, match="Expected an ExceptionGroup, not None: + eeg = RaisesGroup(ValueError) + # exc_val is None + assert not eeg.matches(None) + # exc_val is not an exceptiongroup + assert not eeg.matches(ValueError()) + # wrong length + assert not eeg.matches(ExceptionGroup("", (ValueError(), ValueError()))) + # wrong type + assert not eeg.matches(ExceptionGroup("", (TypeError(),))) + # check fails + assert not RaisesGroup(ValueError, check=lambda _: False).matches( + ExceptionGroup("", (ValueError(),)) + ) + # success + assert eeg.matches(ExceptionGroup("", (ValueError(),))) + + +def test_RaisesGroup_assert_matches() -> None: + """Check direct use of RaisesGroup.assert_matches, without a context manager""" + eeg = RaisesGroup(ValueError) + with pytest.raises(AssertionError): + eeg.assert_matches(None) + with pytest.raises(AssertionError): + eeg.assert_matches(ValueError()) + eeg.assert_matches(ExceptionGroup("", (ValueError(),))) + + +def test_message() -> None: + with pytest.raises( + Failed, + match=re.escape( + f"DID NOT RAISE ANY EXCEPTION, expected ExceptionGroup({repr(ValueError)})" + ), + ): + with RaisesGroup(ValueError): + ... + + with pytest.raises( + Failed, + match=re.escape( + f"DID NOT RAISE ANY EXCEPTION, expected BaseExceptionGroup({repr(KeyboardInterrupt)})" + ), + ): + with RaisesGroup(KeyboardInterrupt): + ... + + +if TYPE_CHECKING: + + def test_types_1() -> None: + with RaisesGroup(ValueError) as e: + raise ExceptionGroup("foo", (ValueError(),)) + assert_type(e.value, BaseExceptionGroup[ValueError]) + + def test_types_2() -> None: + exc: ExceptionGroup[ValueError] | ValueError = ExceptionGroup( + "", (ValueError(),) + ) + if RaisesGroup(ValueError).assert_matches(exc): + assert_type(exc, BaseExceptionGroup[ValueError]) + + def test_types_3() -> None: + e: BaseExceptionGroup[KeyboardInterrupt] = BaseExceptionGroup( + "", (KeyboardInterrupt(),) + ) + if RaisesGroup(ValueError).matches(e): + assert_type(e, BaseExceptionGroup[ValueError]) + + def test_types_4() -> None: + e: BaseExceptionGroup[KeyboardInterrupt] = BaseExceptionGroup( + "", (KeyboardInterrupt(),) + ) + # not currently possible: https://github.com/python/typing/issues/930 + RaisesGroup(ValueError).assert_matches(e) + assert_type(e, BaseExceptionGroup[ValueError]) # type: ignore[assert-type]