diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index d7594cbf9..7231b28af 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -1,6 +1,5 @@ import math import pprint -import re import sys from collections.abc import Collection from collections.abc import Sized @@ -12,7 +11,6 @@ from typing import Callable from typing import cast from typing import ContextManager from typing import final -from typing import Generic from typing import Iterable from typing import List from typing import Mapping @@ -995,66 +993,39 @@ def raises( # noqa: F811 raises.Exception = fail.Exception # type: ignore -class Matcher(Generic[E]): - def __init__( - self, - exception_type: Optional[Type[E]] = None, - match: Optional[Union[str, Pattern[str]]] = None, - check: Optional[Callable[[E], bool]] = None, - ): - if exception_type is None and match is None and check is None: - raise ValueError("You must specify at least one parameter to match on.") - self.exception_type = exception_type - self.match = match - self.check = check - - def matches(self, exception: E) -> "TypeGuard[E]": - if self.exception_type is not None and not isinstance( - exception, self.exception_type - ): - return False - if self.match is not None and not re.search(self.match, str(exception)): - return False - if self.check is not None and not self.check(exception): - return False - return True - - -if TYPE_CHECKING: - SuperClass = BaseExceptionGroup -else: - SuperClass = Generic - - @final -class RaisesGroup( - ContextManager[_pytest._code.ExceptionInfo[BaseExceptionGroup[E]]], SuperClass[E] -): - # My_T = TypeVar("My_T", bound=Union[Type[E], Matcher[E], "RaisesGroup[E]"]) +class RaisesGroup(ContextManager[_pytest._code.ExceptionInfo[BaseExceptionGroup[E]]]): + """Helper for catching exceptions wrapped in an ExceptionGroup. + + Similar to pytest.raises, except: + * It requires that the exception is inside an exceptiongroup + * It is only able to be used as a contextmanager + * Due to the above, is not split into a caller function and a cm class + Similar to trio.RaisesGroup, except: + * does not handle multiple levels of nested groups. + * does not have trio.Matcher, to add matching on the sub-exception + * does not handle multiple exceptions in the exceptiongroup. + + TODO: copy over docstring example usage from trio.RaisesGroup + """ + def __init__( self, - exceptions: Union[Type[E], Matcher[E], E], - *args: Union[Type[E], Matcher[E], E], - strict: bool = True, - match: Optional[Union[str, Pattern[str]]] = None, + exception: Type[E], + check: Optional[Callable[[BaseExceptionGroup[E]], bool]] = None, ): - # could add parameter `notes: Optional[Tuple[str, Pattern[str]]] = None` - self.expected_exceptions = (exceptions, *args) - self.strict = strict - self.match_expr = match - self.message = f"DID NOT RAISE ExceptionGroup{repr(self.expected_exceptions)}" # type: ignore[misc] + # copied from raises() above + if not isinstance(exception, type) or not issubclass(exception, BaseException): + msg = "expected exception must be a BaseException type, not {}" # type: ignore[unreachable] + not_a = ( + exception.__name__ + if isinstance(exception, type) + else type(exception).__name__ + ) + raise TypeError(msg.format(not_a)) - for exc in self.expected_exceptions: - if not isinstance(exc, (Matcher, RaisesGroup)) and not ( - isinstance(exc, type) and issubclass(exc, BaseException) - ): - raise ValueError( - "Invalid argument {exc} must be exception type, Matcher, or RaisesGroup." - ) - if isinstance(exc, RaisesGroup) and not strict: # type: ignore[unreachable] - raise ValueError( - "You cannot specify a nested structure inside a RaisesGroup with strict=False" - ) + self.exception = exception + self.check = check def __enter__(self) -> _pytest._code.ExceptionInfo[BaseExceptionGroup[E]]: self.excinfo: _pytest._code.ExceptionInfo[ @@ -1078,41 +1049,33 @@ class RaisesGroup( self, exc_val: Optional[BaseException], ) -> "TypeGuard[BaseExceptionGroup[E]]": - if exc_val is None: - return False - if not isinstance(exc_val, BaseExceptionGroup): - return False - if not len(exc_val.exceptions) == len(self.expected_exceptions): - return False - remaining_exceptions = list(self.expected_exceptions) - actual_exceptions: Iterable[BaseException] = exc_val.exceptions - if not self.strict: - actual_exceptions = self._unroll_exceptions(actual_exceptions) + return ( + exc_val is not None + and isinstance(exc_val, BaseExceptionGroup) + and len(exc_val.exceptions) == 1 + and isinstance(exc_val.exceptions[0], self.exception) + and (self.check is None or self.check(exc_val)) + ) + + def assert_matches( + self, + exc_val: Optional[BaseException], + ) -> "TypeGuard[BaseExceptionGroup[E]]": + assert ( + exc_val is not None + ), "Internal Error: exc_type is not None but exc_val is" + assert isinstance( + exc_val, BaseExceptionGroup + ), f"Expected an ExceptionGroup, not {type(exc_val)}" + assert ( + len(exc_val.exceptions) == 1 + ), f"Wrong number of exceptions: got {len(exc_val.exceptions)}, expected 1." + assert isinstance( + exc_val.exceptions[0], self.exception + ), f"Wrong type in group: got {type(exc_val.exceptions[0])}, expected {self.exception}" + if self.check is not None: + assert self.check(exc_val), f"Check failed on {repr(exc_val)}." - # it should be possible to get RaisesGroup.matches typed so as not to - # need these type: ignores, but I'm not sure that's possible while also having it - # transparent for the end user. - for e in actual_exceptions: - for rem_e in remaining_exceptions: - # TODO: how to print string diff on mismatch? - # Probably accumulate them, and then if fail, print them - # Further QoL would be to print how the exception structure differs on non-match - if ( - (isinstance(rem_e, type) and isinstance(e, rem_e)) - or ( - isinstance(e, BaseExceptionGroup) - and isinstance(rem_e, RaisesGroup) - and rem_e.matches(e) - ) - or ( - isinstance(rem_e, Matcher) - and rem_e.matches(e) # type: ignore[arg-type] - ) - ): - remaining_exceptions.remove(rem_e) # type: ignore[arg-type] - break - else: - return False return True def __exit__( @@ -1123,11 +1086,10 @@ class RaisesGroup( ) -> bool: __tracebackhide__ = True if exc_type is None: - fail(self.message) - assert self.excinfo is not None + fail("DID NOT RAISE ANY EXCEPTION, expected " + self.expected_type()) + assert self.excinfo is not None, "__exit__ without __enter__" - if not self.matches(exc_val): - return False + self.assert_matches(exc_val) # Cast to narrow the exception type now that it's verified. exc_info = cast( @@ -1135,13 +1097,14 @@ class RaisesGroup( (exc_type, exc_val, exc_tb), ) self.excinfo.fill_unfilled(exc_info) - if self.match_expr is not None: - self.excinfo.match(self.match_expr) return True - def __repr__(self) -> str: - # TODO: [Base]ExceptionGroup - return f"ExceptionGroup{self.expected_exceptions}" + def expected_type(self) -> str: + if not issubclass(self.exception, Exception): + base = "Base" + else: + base = "" + return f"{base}ExceptionGroup({self.exception})" @final diff --git a/src/pytest/__init__.py b/src/pytest/__init__.py index 0aa496a2f..238292992 100644 --- a/src/pytest/__init__.py +++ b/src/pytest/__init__.py @@ -57,6 +57,7 @@ from _pytest.python import Module from _pytest.python import Package from _pytest.python_api import approx from _pytest.python_api import raises +from _pytest.python_api import RaisesGroup from _pytest.recwarn import deprecated_call from _pytest.recwarn import WarningsRecorder from _pytest.recwarn import warns @@ -146,6 +147,7 @@ __all__ = [ "PytestUnraisableExceptionWarning", "PytestWarning", "raises", + "RaisesGroup", "RecordedHookCall", "register_assert_rewrite", "RunResult", diff --git a/testing/python/expected_exception_group.py b/testing/python/expected_exception_group.py index bea04acfc..b6edf3d82 100644 --- a/testing/python/expected_exception_group.py +++ b/testing/python/expected_exception_group.py @@ -1,11 +1,10 @@ +import re import sys from typing import TYPE_CHECKING import pytest -from _pytest.python_api import Matcher -from _pytest.python_api import RaisesGroup - -# TODO: make a public export +from _pytest.outcomes import Failed +from pytest import RaisesGroup if TYPE_CHECKING: from typing_extensions import assert_type @@ -14,237 +13,137 @@ if sys.version_info < (3, 11): from exceptiongroup import ExceptionGroup -class TestRaisesGroup: - def test_raises_group(self) -> None: - with pytest.raises( - ValueError, - match="^Invalid argument {exc} must be exception type, Matcher, or RaisesGroup.$", - ): - RaisesGroup(ValueError()) +def test_raises_group() -> None: + # wrong type to constructor + with pytest.raises( + TypeError, + match="^expected exception must be a BaseException type, not ValueError$", + ): + RaisesGroup(ValueError()) # type: ignore[arg-type] + # working example + with RaisesGroup(ValueError): + raise ExceptionGroup("foo", (ValueError(),)) + + with RaisesGroup(ValueError, check=lambda x: True): + raise ExceptionGroup("foo", (ValueError(),)) + + # wrong subexception + with pytest.raises( + AssertionError, + match="Wrong type in group: got , expected ", + ): with RaisesGroup(ValueError): + raise ExceptionGroup("foo", (SyntaxError(),)) + + # will error if there's excess exceptions + with pytest.raises( + AssertionError, match="Wrong number of exceptions: got 2, expected 1" + ): + with RaisesGroup(ValueError): + raise ExceptionGroup("", (ValueError(), ValueError())) + + # double nested exceptions is not (currently) supported (contrary to expect*) + with pytest.raises( + AssertionError, + match="Wrong type in group: got , expected ", + ): + with RaisesGroup(ValueError): + raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)) + + # you'd need to write + with RaisesGroup(ExceptionGroup) as excinfo: + raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)) + RaisesGroup(ValueError).assert_matches(excinfo.value.exceptions[0]) + + # unwrapped exceptions are not accepted (contrary to expect*) + with pytest.raises( + AssertionError, match="Expected an ExceptionGroup, not None: + eeg = RaisesGroup(ValueError) + # exc_val is None + assert not eeg.matches(None) + # exc_val is not an exceptiongroup + assert not eeg.matches(ValueError()) + # wrong length + assert not eeg.matches(ExceptionGroup("", (ValueError(), ValueError()))) + # wrong type + assert not eeg.matches(ExceptionGroup("", (TypeError(),))) + # check fails + assert not RaisesGroup(ValueError, check=lambda _: False).matches( + ExceptionGroup("", (ValueError(),)) + ) + # success + assert eeg.matches(ExceptionGroup("", (ValueError(),))) - # order doesn't matter - with RaisesGroup(SyntaxError, ValueError): - raise ExceptionGroup("foo", (ValueError(), SyntaxError())) - # nested exceptions - with RaisesGroup(RaisesGroup(ValueError)): - raise ExceptionGroup("foo", (ExceptionGroup("bar", (ValueError(),)),)) +def test_RaisesGroup_assert_matches() -> None: + """Check direct use of RaisesGroup.assert_matches, without a context manager""" + eeg = RaisesGroup(ValueError) + with pytest.raises(AssertionError): + eeg.assert_matches(None) + with pytest.raises(AssertionError): + eeg.assert_matches(ValueError()) + eeg.assert_matches(ExceptionGroup("", (ValueError(),))) - with RaisesGroup( - SyntaxError, - RaisesGroup(ValueError), - RaisesGroup(RuntimeError), - ): - raise ExceptionGroup( - "foo", - ( - SyntaxError(), - ExceptionGroup("bar", (ValueError(),)), - ExceptionGroup("", (RuntimeError(),)), - ), - ) - # will error if there's excess exceptions - with pytest.raises(ExceptionGroup): - with RaisesGroup(ValueError): - raise ExceptionGroup("", (ValueError(), ValueError())) +def test_message() -> None: + with pytest.raises( + Failed, + match=re.escape( + f"DID NOT RAISE ANY EXCEPTION, expected ExceptionGroup({repr(ValueError)})" + ), + ): + with RaisesGroup(ValueError): + ... - with pytest.raises(ExceptionGroup): - with RaisesGroup(ValueError): - raise ExceptionGroup("", (RuntimeError(), ValueError())) + with pytest.raises( + Failed, + match=re.escape( + f"DID NOT RAISE ANY EXCEPTION, expected BaseExceptionGroup({repr(KeyboardInterrupt)})" + ), + ): + with RaisesGroup(KeyboardInterrupt): + ... - # will error if there's missing exceptions - with pytest.raises(ExceptionGroup): - with RaisesGroup(ValueError, ValueError): - raise ExceptionGroup("", (ValueError(),)) - with pytest.raises(ExceptionGroup): - with RaisesGroup(ValueError, SyntaxError): - raise ExceptionGroup("", (ValueError(),)) +if TYPE_CHECKING: - # loose semantics, as with expect* - with RaisesGroup(ValueError, strict=False): - raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)) + def test_types_1() -> None: + with RaisesGroup(ValueError) as e: + raise ExceptionGroup("foo", (ValueError(),)) + assert_type(e.value, BaseExceptionGroup[ValueError]) - # mixed loose is possible if you want it to be at least N deep - with RaisesGroup(RaisesGroup(ValueError, strict=True)): - raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)) - with RaisesGroup(RaisesGroup(ValueError, strict=False)): - raise ExceptionGroup( - "", (ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)),) - ) + def test_types_2() -> None: + exc: ExceptionGroup[ValueError] | ValueError = ExceptionGroup( + "", (ValueError(),) + ) + if RaisesGroup(ValueError).assert_matches(exc): + assert_type(exc, BaseExceptionGroup[ValueError]) - # but not the other way around - with pytest.raises( - ValueError, - match="^You cannot specify a nested structure inside a RaisesGroup with strict=False$", - ): - RaisesGroup(RaisesGroup(ValueError), strict=False) + def test_types_3() -> None: + e: BaseExceptionGroup[KeyboardInterrupt] = BaseExceptionGroup( + "", (KeyboardInterrupt(),) + ) + if RaisesGroup(ValueError).matches(e): + assert_type(e, BaseExceptionGroup[ValueError]) - # currently not fully identical in behaviour to expect*, which would also catch an unwrapped exception - with pytest.raises(ValueError): - with RaisesGroup(ValueError, strict=False): - raise ValueError - - def test_match(self) -> None: - # supports match string - with RaisesGroup(ValueError, match="bar"): - raise ExceptionGroup("bar", (ValueError(),)) - - try: - with RaisesGroup(ValueError, match="foo"): - raise ExceptionGroup("bar", (ValueError(),)) - except AssertionError as e: - assert str(e).startswith("Regex pattern did not match.") - else: - raise AssertionError("Expected pytest.raises.Exception") - - def test_RaisesGroup_matches(self) -> None: - eeg = RaisesGroup(ValueError) - assert not eeg.matches(None) - assert not eeg.matches(ValueError()) - assert eeg.matches(ExceptionGroup("", (ValueError(),))) - - def test_message(self) -> None: - try: - with RaisesGroup(ValueError): - ... - except pytest.fail.Exception as e: - assert e.msg == f"DID NOT RAISE ExceptionGroup({repr(ValueError)},)" - else: - assert False, "Expected pytest.raises.Exception" - try: - with RaisesGroup(RaisesGroup(ValueError)): - ... - except pytest.fail.Exception as e: - assert ( - e.msg - == f"DID NOT RAISE ExceptionGroup(ExceptionGroup({repr(ValueError)},),)" - ) - else: - assert False, "Expected pytest.raises.Exception" - - def test_matcher(self) -> None: - with pytest.raises( - ValueError, match="^You must specify at least one parameter to match on.$" - ): - Matcher() - - with RaisesGroup(Matcher(ValueError)): - raise ExceptionGroup("", (ValueError(),)) - try: - with RaisesGroup(Matcher(TypeError)): - raise ExceptionGroup("", (ValueError(),)) - except ExceptionGroup: - pass - else: - assert False, "Expected pytest.raises.Exception" - - def test_matcher_match(self) -> None: - with RaisesGroup(Matcher(ValueError, "foo")): - raise ExceptionGroup("", (ValueError("foo"),)) - try: - with RaisesGroup(Matcher(ValueError, "foo")): - raise ExceptionGroup("", (ValueError("bar"),)) - except ExceptionGroup: - pass - else: - assert False, "Expected pytest.raises.Exception" - - # Can be used without specifying the type - with RaisesGroup(Matcher(match="foo")): - raise ExceptionGroup("", (ValueError("foo"),)) - try: - with RaisesGroup(Matcher(match="foo")): - raise ExceptionGroup("", (ValueError("bar"),)) - except ExceptionGroup: - pass - else: - assert False, "Expected pytest.raises.Exception" - - def test_Matcher_check(self) -> None: - def check_oserror_and_errno_is_5(e: BaseException) -> bool: - return isinstance(e, OSError) and e.errno == 5 - - with RaisesGroup(Matcher(check=check_oserror_and_errno_is_5)): - raise ExceptionGroup("", (OSError(5, ""),)) - - # specifying exception_type narrows the parameter type to the callable - def check_errno_is_5(e: OSError) -> bool: - return e.errno == 5 - - with RaisesGroup(Matcher(OSError, check=check_errno_is_5)): - raise ExceptionGroup("", (OSError(5, ""),)) - - try: - with RaisesGroup(Matcher(OSError, check=check_errno_is_5)): - raise ExceptionGroup("", (OSError(6, ""),)) - except ExceptionGroup: - pass - else: - assert False, "Expected pytest.raises.Exception" - - if TYPE_CHECKING: - # getting the typing working satisfactory is very tricky - # but with RaisesGroup being seen as a subclass of BaseExceptionGroup - # most end-user cases of checking excinfo.value.foobar should work fine now. - def test_types_0(self) -> None: - _: BaseExceptionGroup[ValueError] = RaisesGroup(ValueError) - _ = RaisesGroup(RaisesGroup(ValueError)) # type: ignore[arg-type] - a: BaseExceptionGroup[BaseExceptionGroup[ValueError]] - a = RaisesGroup(RaisesGroup(ValueError)) - a = BaseExceptionGroup("", (BaseExceptionGroup("", (ValueError(),)),)) - assert a - - def test_types_1(self) -> None: - with RaisesGroup(ValueError) as e: - raise ExceptionGroup("foo", (ValueError(),)) - assert_type(e.value, BaseExceptionGroup[ValueError]) - # assert_type(e.value, RaisesGroup[ValueError]) - - def test_types_2(self) -> None: - exc: ExceptionGroup[ValueError] | ValueError = ExceptionGroup( - "", (ValueError(),) - ) - if RaisesGroup(ValueError).matches(exc): - assert_type(exc, BaseExceptionGroup[ValueError]) - - def test_types_3(self) -> None: - e: BaseExceptionGroup[KeyboardInterrupt] = BaseExceptionGroup( - "", (KeyboardInterrupt(),) - ) - if RaisesGroup(ValueError).matches(e): - assert_type(e, BaseExceptionGroup[ValueError]) - - def test_types_4(self) -> None: - with RaisesGroup(Matcher(ValueError)) as e: - ... - _: BaseExceptionGroup[ValueError] = e.value - assert_type(e.value, BaseExceptionGroup[ValueError]) - - def test_types_5(self) -> None: - with RaisesGroup(RaisesGroup(ValueError)) as excinfo: - raise ExceptionGroup("foo", (ValueError(),)) - _: BaseExceptionGroup[BaseExceptionGroup[ValueError]] = excinfo.value - assert_type( - excinfo.value, - BaseExceptionGroup[RaisesGroup[ValueError]], - ) - print(excinfo.value.exceptions[0].exceptions[0]) - - def test_types_6(self) -> None: - exc: ExceptionGroup[ExceptionGroup[ValueError]] = ... # type: ignore[assignment] - if RaisesGroup(RaisesGroup(ValueError)).matches(exc): # type: ignore[arg-type] - # ugly - assert_type(exc, BaseExceptionGroup[RaisesGroup[ValueError]]) + def test_types_4() -> None: + e: BaseExceptionGroup[KeyboardInterrupt] = BaseExceptionGroup( + "", (KeyboardInterrupt(),) + ) + # not currently possible: https://github.com/python/typing/issues/930 + RaisesGroup(ValueError).assert_matches(e) + assert_type(e, BaseExceptionGroup[ValueError]) # type: ignore[assert-type]