From 2023fa7de83b22997b55ede52f9e81b3f41f9ae0 Mon Sep 17 00:00:00 2001 From: jakkdl Date: Fri, 1 Dec 2023 17:45:50 +0100 Subject: [PATCH] draft implementation of RaisesGroup --- src/_pytest/python_api.py | 157 +++++++++++++ testing/python/expected_exception_group.py | 250 +++++++++++++++++++++ 2 files changed, 407 insertions(+) create mode 100644 testing/python/expected_exception_group.py diff --git a/src/_pytest/python_api.py b/src/_pytest/python_api.py index 07db0f234..d7594cbf9 100644 --- a/src/_pytest/python_api.py +++ b/src/_pytest/python_api.py @@ -1,5 +1,7 @@ import math import pprint +import re +import sys from collections.abc import Collection from collections.abc import Sized from decimal import Decimal @@ -10,6 +12,8 @@ from typing import Callable from typing import cast from typing import ContextManager from typing import final +from typing import Generic +from typing import Iterable from typing import List from typing import Mapping from typing import Optional @@ -28,6 +32,10 @@ from _pytest.outcomes import fail if TYPE_CHECKING: from numpy import ndarray + from typing_extensions import TypeGuard + +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup def _non_numeric_type_error(value, at: Optional[str]) -> TypeError: @@ -987,6 +995,155 @@ def raises( # noqa: F811 raises.Exception = fail.Exception # type: ignore +class Matcher(Generic[E]): + def __init__( + self, + exception_type: Optional[Type[E]] = None, + match: Optional[Union[str, Pattern[str]]] = None, + check: Optional[Callable[[E], bool]] = None, + ): + if exception_type is None and match is None and check is None: + raise ValueError("You must specify at least one parameter to match on.") + self.exception_type = exception_type + self.match = match + self.check = check + + def matches(self, exception: E) -> "TypeGuard[E]": + if self.exception_type is not None and not isinstance( + exception, self.exception_type + ): + return False + if self.match is not None and not re.search(self.match, str(exception)): + return False + if self.check is not None and not self.check(exception): + return False + return True + + +if TYPE_CHECKING: + SuperClass = BaseExceptionGroup +else: + SuperClass = Generic + + +@final +class RaisesGroup( + ContextManager[_pytest._code.ExceptionInfo[BaseExceptionGroup[E]]], SuperClass[E] +): + # My_T = TypeVar("My_T", bound=Union[Type[E], Matcher[E], "RaisesGroup[E]"]) + def __init__( + self, + exceptions: Union[Type[E], Matcher[E], E], + *args: Union[Type[E], Matcher[E], E], + strict: bool = True, + match: Optional[Union[str, Pattern[str]]] = None, + ): + # could add parameter `notes: Optional[Tuple[str, Pattern[str]]] = None` + self.expected_exceptions = (exceptions, *args) + self.strict = strict + self.match_expr = match + self.message = f"DID NOT RAISE ExceptionGroup{repr(self.expected_exceptions)}" # type: ignore[misc] + + for exc in self.expected_exceptions: + if not isinstance(exc, (Matcher, RaisesGroup)) and not ( + isinstance(exc, type) and issubclass(exc, BaseException) + ): + raise ValueError( + "Invalid argument {exc} must be exception type, Matcher, or RaisesGroup." + ) + if isinstance(exc, RaisesGroup) and not strict: # type: ignore[unreachable] + raise ValueError( + "You cannot specify a nested structure inside a RaisesGroup with strict=False" + ) + + def __enter__(self) -> _pytest._code.ExceptionInfo[BaseExceptionGroup[E]]: + self.excinfo: _pytest._code.ExceptionInfo[ + BaseExceptionGroup[E] + ] = _pytest._code.ExceptionInfo.for_later() + return self.excinfo + + def _unroll_exceptions( + self, exceptions: Iterable[BaseException] + ) -> Iterable[BaseException]: + res: list[BaseException] = [] + for exc in exceptions: + if isinstance(exc, BaseExceptionGroup): + res.extend(self._unroll_exceptions(exc.exceptions)) + + else: + res.append(exc) + return res + + def matches( + self, + exc_val: Optional[BaseException], + ) -> "TypeGuard[BaseExceptionGroup[E]]": + if exc_val is None: + return False + if not isinstance(exc_val, BaseExceptionGroup): + return False + if not len(exc_val.exceptions) == len(self.expected_exceptions): + return False + remaining_exceptions = list(self.expected_exceptions) + actual_exceptions: Iterable[BaseException] = exc_val.exceptions + if not self.strict: + actual_exceptions = self._unroll_exceptions(actual_exceptions) + + # it should be possible to get RaisesGroup.matches typed so as not to + # need these type: ignores, but I'm not sure that's possible while also having it + # transparent for the end user. + for e in actual_exceptions: + for rem_e in remaining_exceptions: + # TODO: how to print string diff on mismatch? + # Probably accumulate them, and then if fail, print them + # Further QoL would be to print how the exception structure differs on non-match + if ( + (isinstance(rem_e, type) and isinstance(e, rem_e)) + or ( + isinstance(e, BaseExceptionGroup) + and isinstance(rem_e, RaisesGroup) + and rem_e.matches(e) + ) + or ( + isinstance(rem_e, Matcher) + and rem_e.matches(e) # type: ignore[arg-type] + ) + ): + remaining_exceptions.remove(rem_e) # type: ignore[arg-type] + break + else: + return False + return True + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + __tracebackhide__ = True + if exc_type is None: + fail(self.message) + assert self.excinfo is not None + + if not self.matches(exc_val): + return False + + # Cast to narrow the exception type now that it's verified. + exc_info = cast( + Tuple[Type[BaseExceptionGroup[E]], BaseExceptionGroup[E], TracebackType], + (exc_type, exc_val, exc_tb), + ) + self.excinfo.fill_unfilled(exc_info) + if self.match_expr is not None: + self.excinfo.match(self.match_expr) + return True + + def __repr__(self) -> str: + # TODO: [Base]ExceptionGroup + return f"ExceptionGroup{self.expected_exceptions}" + + @final class RaisesContext(ContextManager[_pytest._code.ExceptionInfo[E]]): def __init__( diff --git a/testing/python/expected_exception_group.py b/testing/python/expected_exception_group.py new file mode 100644 index 000000000..bea04acfc --- /dev/null +++ b/testing/python/expected_exception_group.py @@ -0,0 +1,250 @@ +import sys +from typing import TYPE_CHECKING + +import pytest +from _pytest.python_api import Matcher +from _pytest.python_api import RaisesGroup + +# TODO: make a public export + +if TYPE_CHECKING: + from typing_extensions import assert_type + +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup + + +class TestRaisesGroup: + def test_raises_group(self) -> None: + with pytest.raises( + ValueError, + match="^Invalid argument {exc} must be exception type, Matcher, or RaisesGroup.$", + ): + RaisesGroup(ValueError()) + + with RaisesGroup(ValueError): + raise ExceptionGroup("foo", (ValueError(),)) + + with RaisesGroup(SyntaxError): + with RaisesGroup(ValueError): + raise ExceptionGroup("foo", (SyntaxError(),)) + + # multiple exceptions + with RaisesGroup(ValueError, SyntaxError): + raise ExceptionGroup("foo", (ValueError(), SyntaxError())) + + # order doesn't matter + with RaisesGroup(SyntaxError, ValueError): + raise ExceptionGroup("foo", (ValueError(), SyntaxError())) + + # nested exceptions + with RaisesGroup(RaisesGroup(ValueError)): + raise ExceptionGroup("foo", (ExceptionGroup("bar", (ValueError(),)),)) + + with RaisesGroup( + SyntaxError, + RaisesGroup(ValueError), + RaisesGroup(RuntimeError), + ): + raise ExceptionGroup( + "foo", + ( + SyntaxError(), + ExceptionGroup("bar", (ValueError(),)), + ExceptionGroup("", (RuntimeError(),)), + ), + ) + + # will error if there's excess exceptions + with pytest.raises(ExceptionGroup): + with RaisesGroup(ValueError): + raise ExceptionGroup("", (ValueError(), ValueError())) + + with pytest.raises(ExceptionGroup): + with RaisesGroup(ValueError): + raise ExceptionGroup("", (RuntimeError(), ValueError())) + + # will error if there's missing exceptions + with pytest.raises(ExceptionGroup): + with RaisesGroup(ValueError, ValueError): + raise ExceptionGroup("", (ValueError(),)) + + with pytest.raises(ExceptionGroup): + with RaisesGroup(ValueError, SyntaxError): + raise ExceptionGroup("", (ValueError(),)) + + # loose semantics, as with expect* + with RaisesGroup(ValueError, strict=False): + raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)) + + # mixed loose is possible if you want it to be at least N deep + with RaisesGroup(RaisesGroup(ValueError, strict=True)): + raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)) + with RaisesGroup(RaisesGroup(ValueError, strict=False)): + raise ExceptionGroup( + "", (ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)),) + ) + + # but not the other way around + with pytest.raises( + ValueError, + match="^You cannot specify a nested structure inside a RaisesGroup with strict=False$", + ): + RaisesGroup(RaisesGroup(ValueError), strict=False) + + # currently not fully identical in behaviour to expect*, which would also catch an unwrapped exception + with pytest.raises(ValueError): + with RaisesGroup(ValueError, strict=False): + raise ValueError + + def test_match(self) -> None: + # supports match string + with RaisesGroup(ValueError, match="bar"): + raise ExceptionGroup("bar", (ValueError(),)) + + try: + with RaisesGroup(ValueError, match="foo"): + raise ExceptionGroup("bar", (ValueError(),)) + except AssertionError as e: + assert str(e).startswith("Regex pattern did not match.") + else: + raise AssertionError("Expected pytest.raises.Exception") + + def test_RaisesGroup_matches(self) -> None: + eeg = RaisesGroup(ValueError) + assert not eeg.matches(None) + assert not eeg.matches(ValueError()) + assert eeg.matches(ExceptionGroup("", (ValueError(),))) + + def test_message(self) -> None: + try: + with RaisesGroup(ValueError): + ... + except pytest.fail.Exception as e: + assert e.msg == f"DID NOT RAISE ExceptionGroup({repr(ValueError)},)" + else: + assert False, "Expected pytest.raises.Exception" + try: + with RaisesGroup(RaisesGroup(ValueError)): + ... + except pytest.fail.Exception as e: + assert ( + e.msg + == f"DID NOT RAISE ExceptionGroup(ExceptionGroup({repr(ValueError)},),)" + ) + else: + assert False, "Expected pytest.raises.Exception" + + def test_matcher(self) -> None: + with pytest.raises( + ValueError, match="^You must specify at least one parameter to match on.$" + ): + Matcher() + + with RaisesGroup(Matcher(ValueError)): + raise ExceptionGroup("", (ValueError(),)) + try: + with RaisesGroup(Matcher(TypeError)): + raise ExceptionGroup("", (ValueError(),)) + except ExceptionGroup: + pass + else: + assert False, "Expected pytest.raises.Exception" + + def test_matcher_match(self) -> None: + with RaisesGroup(Matcher(ValueError, "foo")): + raise ExceptionGroup("", (ValueError("foo"),)) + try: + with RaisesGroup(Matcher(ValueError, "foo")): + raise ExceptionGroup("", (ValueError("bar"),)) + except ExceptionGroup: + pass + else: + assert False, "Expected pytest.raises.Exception" + + # Can be used without specifying the type + with RaisesGroup(Matcher(match="foo")): + raise ExceptionGroup("", (ValueError("foo"),)) + try: + with RaisesGroup(Matcher(match="foo")): + raise ExceptionGroup("", (ValueError("bar"),)) + except ExceptionGroup: + pass + else: + assert False, "Expected pytest.raises.Exception" + + def test_Matcher_check(self) -> None: + def check_oserror_and_errno_is_5(e: BaseException) -> bool: + return isinstance(e, OSError) and e.errno == 5 + + with RaisesGroup(Matcher(check=check_oserror_and_errno_is_5)): + raise ExceptionGroup("", (OSError(5, ""),)) + + # specifying exception_type narrows the parameter type to the callable + def check_errno_is_5(e: OSError) -> bool: + return e.errno == 5 + + with RaisesGroup(Matcher(OSError, check=check_errno_is_5)): + raise ExceptionGroup("", (OSError(5, ""),)) + + try: + with RaisesGroup(Matcher(OSError, check=check_errno_is_5)): + raise ExceptionGroup("", (OSError(6, ""),)) + except ExceptionGroup: + pass + else: + assert False, "Expected pytest.raises.Exception" + + if TYPE_CHECKING: + # getting the typing working satisfactory is very tricky + # but with RaisesGroup being seen as a subclass of BaseExceptionGroup + # most end-user cases of checking excinfo.value.foobar should work fine now. + def test_types_0(self) -> None: + _: BaseExceptionGroup[ValueError] = RaisesGroup(ValueError) + _ = RaisesGroup(RaisesGroup(ValueError)) # type: ignore[arg-type] + a: BaseExceptionGroup[BaseExceptionGroup[ValueError]] + a = RaisesGroup(RaisesGroup(ValueError)) + a = BaseExceptionGroup("", (BaseExceptionGroup("", (ValueError(),)),)) + assert a + + def test_types_1(self) -> None: + with RaisesGroup(ValueError) as e: + raise ExceptionGroup("foo", (ValueError(),)) + assert_type(e.value, BaseExceptionGroup[ValueError]) + # assert_type(e.value, RaisesGroup[ValueError]) + + def test_types_2(self) -> None: + exc: ExceptionGroup[ValueError] | ValueError = ExceptionGroup( + "", (ValueError(),) + ) + if RaisesGroup(ValueError).matches(exc): + assert_type(exc, BaseExceptionGroup[ValueError]) + + def test_types_3(self) -> None: + e: BaseExceptionGroup[KeyboardInterrupt] = BaseExceptionGroup( + "", (KeyboardInterrupt(),) + ) + if RaisesGroup(ValueError).matches(e): + assert_type(e, BaseExceptionGroup[ValueError]) + + def test_types_4(self) -> None: + with RaisesGroup(Matcher(ValueError)) as e: + ... + _: BaseExceptionGroup[ValueError] = e.value + assert_type(e.value, BaseExceptionGroup[ValueError]) + + def test_types_5(self) -> None: + with RaisesGroup(RaisesGroup(ValueError)) as excinfo: + raise ExceptionGroup("foo", (ValueError(),)) + _: BaseExceptionGroup[BaseExceptionGroup[ValueError]] = excinfo.value + assert_type( + excinfo.value, + BaseExceptionGroup[RaisesGroup[ValueError]], + ) + print(excinfo.value.exceptions[0].exceptions[0]) + + def test_types_6(self) -> None: + exc: ExceptionGroup[ExceptionGroup[ValueError]] = ... # type: ignore[assignment] + if RaisesGroup(RaisesGroup(ValueError)).matches(exc): # type: ignore[arg-type] + # ugly + assert_type(exc, BaseExceptionGroup[RaisesGroup[ValueError]])