From 2023fa7de83b22997b55ede52f9e81b3f41f9ae0 Mon Sep 17 00:00:00 2001
From: jakkdl
Date: Fri, 1 Dec 2023 17:45:50 +0100
Subject: [PATCH] draft implementation of RaisesGroup
---
src/_pytest/python_api.py | 157 +++++++++++++
testing/python/expected_exception_group.py | 250 +++++++++++++++++++++
2 files changed, 407 insertions(+)
create mode 100644 testing/python/expected_exception_group.py
diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py
index 07db0f234..d7594cbf9 100644
--- a/src/_pytest/python_api.py
+++ b/src/_pytest/python_api.py
@@ -1,5 +1,7 @@
import math
import pprint
+import re
+import sys
from collections.abc import Collection
from collections.abc import Sized
from decimal import Decimal
@@ -10,6 +12,8 @@ from typing import Callable
from typing import cast
from typing import ContextManager
from typing import final
+from typing import Generic
+from typing import Iterable
from typing import List
from typing import Mapping
from typing import Optional
@@ -28,6 +32,10 @@ from _pytest.outcomes import fail
if TYPE_CHECKING:
from numpy import ndarray
+ from typing_extensions import TypeGuard
+
+if sys.version_info < (3, 11):
+ from exceptiongroup import BaseExceptionGroup
def _non_numeric_type_error(value, at: Optional[str]) -> TypeError:
@@ -987,6 +995,155 @@ def raises( # noqa: F811
raises.Exception = fail.Exception # type: ignore
+class Matcher(Generic[E]):
+ def __init__(
+ self,
+ exception_type: Optional[Type[E]] = None,
+ match: Optional[Union[str, Pattern[str]]] = None,
+ check: Optional[Callable[[E], bool]] = None,
+ ):
+ if exception_type is None and match is None and check is None:
+ raise ValueError("You must specify at least one parameter to match on.")
+ self.exception_type = exception_type
+ self.match = match
+ self.check = check
+
+ def matches(self, exception: E) -> "TypeGuard[E]":
+ if self.exception_type is not None and not isinstance(
+ exception, self.exception_type
+ ):
+ return False
+ if self.match is not None and not re.search(self.match, str(exception)):
+ return False
+ if self.check is not None and not self.check(exception):
+ return False
+ return True
+
+
+if TYPE_CHECKING:
+ SuperClass = BaseExceptionGroup
+else:
+ SuperClass = Generic
+
+
+@final
+class RaisesGroup(
+ ContextManager[_pytest._code.ExceptionInfo[BaseExceptionGroup[E]]], SuperClass[E]
+):
+ # My_T = TypeVar("My_T", bound=Union[Type[E], Matcher[E], "RaisesGroup[E]"])
+ def __init__(
+ self,
+ exceptions: Union[Type[E], Matcher[E], E],
+ *args: Union[Type[E], Matcher[E], E],
+ strict: bool = True,
+ match: Optional[Union[str, Pattern[str]]] = None,
+ ):
+ # could add parameter `notes: Optional[Tuple[str, Pattern[str]]] = None`
+ self.expected_exceptions = (exceptions, *args)
+ self.strict = strict
+ self.match_expr = match
+ self.message = f"DID NOT RAISE ExceptionGroup{repr(self.expected_exceptions)}" # type: ignore[misc]
+
+ for exc in self.expected_exceptions:
+ if not isinstance(exc, (Matcher, RaisesGroup)) and not (
+ isinstance(exc, type) and issubclass(exc, BaseException)
+ ):
+ raise ValueError(
+ "Invalid argument {exc} must be exception type, Matcher, or RaisesGroup."
+ )
+ if isinstance(exc, RaisesGroup) and not strict: # type: ignore[unreachable]
+ raise ValueError(
+ "You cannot specify a nested structure inside a RaisesGroup with strict=False"
+ )
+
+ def __enter__(self) -> _pytest._code.ExceptionInfo[BaseExceptionGroup[E]]:
+ self.excinfo: _pytest._code.ExceptionInfo[
+ BaseExceptionGroup[E]
+ ] = _pytest._code.ExceptionInfo.for_later()
+ return self.excinfo
+
+ def _unroll_exceptions(
+ self, exceptions: Iterable[BaseException]
+ ) -> Iterable[BaseException]:
+ res: list[BaseException] = []
+ for exc in exceptions:
+ if isinstance(exc, BaseExceptionGroup):
+ res.extend(self._unroll_exceptions(exc.exceptions))
+
+ else:
+ res.append(exc)
+ return res
+
+ def matches(
+ self,
+ exc_val: Optional[BaseException],
+ ) -> "TypeGuard[BaseExceptionGroup[E]]":
+ if exc_val is None:
+ return False
+ if not isinstance(exc_val, BaseExceptionGroup):
+ return False
+ if not len(exc_val.exceptions) == len(self.expected_exceptions):
+ return False
+ remaining_exceptions = list(self.expected_exceptions)
+ actual_exceptions: Iterable[BaseException] = exc_val.exceptions
+ if not self.strict:
+ actual_exceptions = self._unroll_exceptions(actual_exceptions)
+
+ # it should be possible to get RaisesGroup.matches typed so as not to
+ # need these type: ignores, but I'm not sure that's possible while also having it
+ # transparent for the end user.
+ for e in actual_exceptions:
+ for rem_e in remaining_exceptions:
+ # TODO: how to print string diff on mismatch?
+ # Probably accumulate them, and then if fail, print them
+ # Further QoL would be to print how the exception structure differs on non-match
+ if (
+ (isinstance(rem_e, type) and isinstance(e, rem_e))
+ or (
+ isinstance(e, BaseExceptionGroup)
+ and isinstance(rem_e, RaisesGroup)
+ and rem_e.matches(e)
+ )
+ or (
+ isinstance(rem_e, Matcher)
+ and rem_e.matches(e) # type: ignore[arg-type]
+ )
+ ):
+ remaining_exceptions.remove(rem_e) # type: ignore[arg-type]
+ break
+ else:
+ return False
+ return True
+
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType],
+ ) -> bool:
+ __tracebackhide__ = True
+ if exc_type is None:
+ fail(self.message)
+ assert self.excinfo is not None
+
+ if not self.matches(exc_val):
+ return False
+
+ # Cast to narrow the exception type now that it's verified.
+ exc_info = cast(
+ Tuple[Type[BaseExceptionGroup[E]], BaseExceptionGroup[E], TracebackType],
+ (exc_type, exc_val, exc_tb),
+ )
+ self.excinfo.fill_unfilled(exc_info)
+ if self.match_expr is not None:
+ self.excinfo.match(self.match_expr)
+ return True
+
+ def __repr__(self) -> str:
+ # TODO: [Base]ExceptionGroup
+ return f"ExceptionGroup{self.expected_exceptions}"
+
+
@final
class RaisesContext(ContextManager[_pytest._code.ExceptionInfo[E]]):
def __init__(
diff --git a/testing/python/expected_exception_group.py b/testing/python/expected_exception_group.py
new file mode 100644
index 000000000..bea04acfc
--- /dev/null
+++ b/testing/python/expected_exception_group.py
@@ -0,0 +1,250 @@
+import sys
+from typing import TYPE_CHECKING
+
+import pytest
+from _pytest.python_api import Matcher
+from _pytest.python_api import RaisesGroup
+
+# TODO: make a public export
+
+if TYPE_CHECKING:
+ from typing_extensions import assert_type
+
+if sys.version_info < (3, 11):
+ from exceptiongroup import ExceptionGroup
+
+
+class TestRaisesGroup:
+ def test_raises_group(self) -> None:
+ with pytest.raises(
+ ValueError,
+ match="^Invalid argument {exc} must be exception type, Matcher, or RaisesGroup.$",
+ ):
+ RaisesGroup(ValueError())
+
+ with RaisesGroup(ValueError):
+ raise ExceptionGroup("foo", (ValueError(),))
+
+ with RaisesGroup(SyntaxError):
+ with RaisesGroup(ValueError):
+ raise ExceptionGroup("foo", (SyntaxError(),))
+
+ # multiple exceptions
+ with RaisesGroup(ValueError, SyntaxError):
+ raise ExceptionGroup("foo", (ValueError(), SyntaxError()))
+
+ # order doesn't matter
+ with RaisesGroup(SyntaxError, ValueError):
+ raise ExceptionGroup("foo", (ValueError(), SyntaxError()))
+
+ # nested exceptions
+ with RaisesGroup(RaisesGroup(ValueError)):
+ raise ExceptionGroup("foo", (ExceptionGroup("bar", (ValueError(),)),))
+
+ with RaisesGroup(
+ SyntaxError,
+ RaisesGroup(ValueError),
+ RaisesGroup(RuntimeError),
+ ):
+ raise ExceptionGroup(
+ "foo",
+ (
+ SyntaxError(),
+ ExceptionGroup("bar", (ValueError(),)),
+ ExceptionGroup("", (RuntimeError(),)),
+ ),
+ )
+
+ # will error if there's excess exceptions
+ with pytest.raises(ExceptionGroup):
+ with RaisesGroup(ValueError):
+ raise ExceptionGroup("", (ValueError(), ValueError()))
+
+ with pytest.raises(ExceptionGroup):
+ with RaisesGroup(ValueError):
+ raise ExceptionGroup("", (RuntimeError(), ValueError()))
+
+ # will error if there's missing exceptions
+ with pytest.raises(ExceptionGroup):
+ with RaisesGroup(ValueError, ValueError):
+ raise ExceptionGroup("", (ValueError(),))
+
+ with pytest.raises(ExceptionGroup):
+ with RaisesGroup(ValueError, SyntaxError):
+ raise ExceptionGroup("", (ValueError(),))
+
+ # loose semantics, as with expect*
+ with RaisesGroup(ValueError, strict=False):
+ raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),))
+
+ # mixed loose is possible if you want it to be at least N deep
+ with RaisesGroup(RaisesGroup(ValueError, strict=True)):
+ raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),))
+ with RaisesGroup(RaisesGroup(ValueError, strict=False)):
+ raise ExceptionGroup(
+ "", (ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)),)
+ )
+
+ # but not the other way around
+ with pytest.raises(
+ ValueError,
+ match="^You cannot specify a nested structure inside a RaisesGroup with strict=False$",
+ ):
+ RaisesGroup(RaisesGroup(ValueError), strict=False)
+
+ # currently not fully identical in behaviour to expect*, which would also catch an unwrapped exception
+ with pytest.raises(ValueError):
+ with RaisesGroup(ValueError, strict=False):
+ raise ValueError
+
+ def test_match(self) -> None:
+ # supports match string
+ with RaisesGroup(ValueError, match="bar"):
+ raise ExceptionGroup("bar", (ValueError(),))
+
+ try:
+ with RaisesGroup(ValueError, match="foo"):
+ raise ExceptionGroup("bar", (ValueError(),))
+ except AssertionError as e:
+ assert str(e).startswith("Regex pattern did not match.")
+ else:
+ raise AssertionError("Expected pytest.raises.Exception")
+
+ def test_RaisesGroup_matches(self) -> None:
+ eeg = RaisesGroup(ValueError)
+ assert not eeg.matches(None)
+ assert not eeg.matches(ValueError())
+ assert eeg.matches(ExceptionGroup("", (ValueError(),)))
+
+ def test_message(self) -> None:
+ try:
+ with RaisesGroup(ValueError):
+ ...
+ except pytest.fail.Exception as e:
+ assert e.msg == f"DID NOT RAISE ExceptionGroup({repr(ValueError)},)"
+ else:
+ assert False, "Expected pytest.raises.Exception"
+ try:
+ with RaisesGroup(RaisesGroup(ValueError)):
+ ...
+ except pytest.fail.Exception as e:
+ assert (
+ e.msg
+ == f"DID NOT RAISE ExceptionGroup(ExceptionGroup({repr(ValueError)},),)"
+ )
+ else:
+ assert False, "Expected pytest.raises.Exception"
+
+ def test_matcher(self) -> None:
+ with pytest.raises(
+ ValueError, match="^You must specify at least one parameter to match on.$"
+ ):
+ Matcher()
+
+ with RaisesGroup(Matcher(ValueError)):
+ raise ExceptionGroup("", (ValueError(),))
+ try:
+ with RaisesGroup(Matcher(TypeError)):
+ raise ExceptionGroup("", (ValueError(),))
+ except ExceptionGroup:
+ pass
+ else:
+ assert False, "Expected pytest.raises.Exception"
+
+ def test_matcher_match(self) -> None:
+ with RaisesGroup(Matcher(ValueError, "foo")):
+ raise ExceptionGroup("", (ValueError("foo"),))
+ try:
+ with RaisesGroup(Matcher(ValueError, "foo")):
+ raise ExceptionGroup("", (ValueError("bar"),))
+ except ExceptionGroup:
+ pass
+ else:
+ assert False, "Expected pytest.raises.Exception"
+
+ # Can be used without specifying the type
+ with RaisesGroup(Matcher(match="foo")):
+ raise ExceptionGroup("", (ValueError("foo"),))
+ try:
+ with RaisesGroup(Matcher(match="foo")):
+ raise ExceptionGroup("", (ValueError("bar"),))
+ except ExceptionGroup:
+ pass
+ else:
+ assert False, "Expected pytest.raises.Exception"
+
+ def test_Matcher_check(self) -> None:
+ def check_oserror_and_errno_is_5(e: BaseException) -> bool:
+ return isinstance(e, OSError) and e.errno == 5
+
+ with RaisesGroup(Matcher(check=check_oserror_and_errno_is_5)):
+ raise ExceptionGroup("", (OSError(5, ""),))
+
+ # specifying exception_type narrows the parameter type to the callable
+ def check_errno_is_5(e: OSError) -> bool:
+ return e.errno == 5
+
+ with RaisesGroup(Matcher(OSError, check=check_errno_is_5)):
+ raise ExceptionGroup("", (OSError(5, ""),))
+
+ try:
+ with RaisesGroup(Matcher(OSError, check=check_errno_is_5)):
+ raise ExceptionGroup("", (OSError(6, ""),))
+ except ExceptionGroup:
+ pass
+ else:
+ assert False, "Expected pytest.raises.Exception"
+
+ if TYPE_CHECKING:
+ # getting the typing working satisfactory is very tricky
+ # but with RaisesGroup being seen as a subclass of BaseExceptionGroup
+ # most end-user cases of checking excinfo.value.foobar should work fine now.
+ def test_types_0(self) -> None:
+ _: BaseExceptionGroup[ValueError] = RaisesGroup(ValueError)
+ _ = RaisesGroup(RaisesGroup(ValueError)) # type: ignore[arg-type]
+ a: BaseExceptionGroup[BaseExceptionGroup[ValueError]]
+ a = RaisesGroup(RaisesGroup(ValueError))
+ a = BaseExceptionGroup("", (BaseExceptionGroup("", (ValueError(),)),))
+ assert a
+
+ def test_types_1(self) -> None:
+ with RaisesGroup(ValueError) as e:
+ raise ExceptionGroup("foo", (ValueError(),))
+ assert_type(e.value, BaseExceptionGroup[ValueError])
+ # assert_type(e.value, RaisesGroup[ValueError])
+
+ def test_types_2(self) -> None:
+ exc: ExceptionGroup[ValueError] | ValueError = ExceptionGroup(
+ "", (ValueError(),)
+ )
+ if RaisesGroup(ValueError).matches(exc):
+ assert_type(exc, BaseExceptionGroup[ValueError])
+
+ def test_types_3(self) -> None:
+ e: BaseExceptionGroup[KeyboardInterrupt] = BaseExceptionGroup(
+ "", (KeyboardInterrupt(),)
+ )
+ if RaisesGroup(ValueError).matches(e):
+ assert_type(e, BaseExceptionGroup[ValueError])
+
+ def test_types_4(self) -> None:
+ with RaisesGroup(Matcher(ValueError)) as e:
+ ...
+ _: BaseExceptionGroup[ValueError] = e.value
+ assert_type(e.value, BaseExceptionGroup[ValueError])
+
+ def test_types_5(self) -> None:
+ with RaisesGroup(RaisesGroup(ValueError)) as excinfo:
+ raise ExceptionGroup("foo", (ValueError(),))
+ _: BaseExceptionGroup[BaseExceptionGroup[ValueError]] = excinfo.value
+ assert_type(
+ excinfo.value,
+ BaseExceptionGroup[RaisesGroup[ValueError]],
+ )
+ print(excinfo.value.exceptions[0].exceptions[0])
+
+ def test_types_6(self) -> None:
+ exc: ExceptionGroup[ExceptionGroup[ValueError]] = ... # type: ignore[assignment]
+ if RaisesGroup(RaisesGroup(ValueError)).matches(exc): # type: ignore[arg-type]
+ # ugly
+ assert_type(exc, BaseExceptionGroup[RaisesGroup[ValueError]])