draft implementation of ExpectedExceptionGroup

This commit is contained in:
jakkdl 2023-12-01 17:45:50 +01:00
parent 5689d806cf
commit 9a48173a6a
3 changed files with 496 additions and 16 deletions

View File

@ -1,5 +1,6 @@
import math
import pprint
import re
from collections.abc import Collection
from collections.abc import Sized
from decimal import Decimal
@ -10,6 +11,8 @@ from typing import Callable
from typing import cast
from typing import ContextManager
from typing import final
from typing import Generic
from typing import Iterable
from typing import List
from typing import Mapping
from typing import Optional
@ -22,6 +25,10 @@ from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
if TYPE_CHECKING:
from typing_extensions import TypeAlias, TypeGuard
import _pytest._code
from _pytest.compat import STRING_TYPES
from _pytest.outcomes import fail
@ -780,6 +787,149 @@ def _as_numpy_array(obj: object) -> Optional["ndarray"]:
# builtin pytest.raises helper
E = TypeVar("E", bound=BaseException)
E2 = TypeVar("E2", bound=BaseException)
class Matcher(Generic[E]):
def __init__(
self,
exception_type: Optional[type[E]] = None,
match: Optional[Union[str, Pattern[str]]] = None,
check: Optional[Callable[[E], bool]] = None,
):
if exception_type is None and match is None and check is None:
raise ValueError("You must specify at least one parameter to match on.")
self.exception_type = exception_type
self.match = match
self.check = check
def matches(self, exception: E) -> "TypeGuard[E]":
if self.exception_type is not None and not isinstance(
exception, self.exception_type
):
return False
if self.match is not None and not re.search(self.match, str(exception)):
return False
if self.check is not None and not self.check(exception):
return False
return True
# TODO: rename if kept, EEE[E] looks like gibberish
EEE: "TypeAlias" = Union[Matcher[E], Type[E], "ExpectedExceptionGroup[E]"]
if TYPE_CHECKING:
SuperClass = BaseExceptionGroup
else:
SuperClass = Generic
# it's unclear if
# `ExpectedExceptionGroup(ValueError, strict=False).matches(ValueError())`
# should return True. It matches the behaviour of expect*, but is maybe better handled
# by the end user doing pytest.raises((ValueError, ExpectedExceptionGroup(ValueError)))
@final
class ExpectedExceptionGroup(SuperClass[E]):
# TODO: overload to disallow nested exceptiongroup with strict=False
# @overload
# def __init__(self, exceptions: Union[Matcher[E], Type[E]], *args: Union[Matcher[E],
# Type[E]], strict: Literal[False]): ...
# @overload
# def __init__(self, exceptions: EEE[E], *args: EEE[E], strict: bool = True): ...
def __init__(
self,
exceptions: Union[Type[E], E, Matcher[E]],
*args: Union[Type[E], E, Matcher[E]],
strict: bool = True,
):
# could add parameter `notes: Optional[Tuple[str, Pattern[str]]] = None`
self.expected_exceptions = (exceptions, *args)
self.strict = strict
for exc in self.expected_exceptions:
if not isinstance(exc, (Matcher, ExpectedExceptionGroup)) and not (
isinstance(exc, type) and issubclass(exc, BaseException)
):
raise ValueError(
"Invalid argument {exc} must be exception type, Matcher, or ExpectedExceptionGroup."
)
if isinstance(exc, ExpectedExceptionGroup) and not strict:
raise ValueError(
"You cannot specify a nested structure inside an ExpectedExceptionGroup with strict=False"
)
def _unroll_exceptions(
self, exceptions: Iterable[BaseException]
) -> Iterable[BaseException]:
res: list[BaseException] = []
for exc in exceptions:
if isinstance(exc, BaseExceptionGroup):
res.extend(self._unroll_exceptions(exc.exceptions))
else:
res.append(exc)
return res
def matches(
self,
exc_val: Optional[BaseException],
) -> "TypeGuard[BaseExceptionGroup[E]]":
if exc_val is None:
return False
if not isinstance(exc_val, BaseExceptionGroup):
return False
if not len(exc_val.exceptions) == len(self.expected_exceptions):
return False
remaining_exceptions = list(self.expected_exceptions)
actual_exceptions: Iterable[BaseException] = exc_val.exceptions
if not self.strict:
actual_exceptions = self._unroll_exceptions(actual_exceptions)
# it should be possible to get ExpectedExceptionGroup.matches typed so as not to
# need these type: ignores, but I'm not sure that's possible while also having it
# transparent for the end user.
for e in actual_exceptions:
for rem_e in remaining_exceptions:
# TODO: how to print string diff on mismatch?
# Probably accumulate them, and then if fail, print them
# Further QoL would be to print how the exception structure differs on non-match
if (
(isinstance(rem_e, type) and isinstance(e, rem_e))
or (
isinstance(e, BaseExceptionGroup)
and isinstance(rem_e, ExpectedExceptionGroup)
and rem_e.matches(e)
)
or (
isinstance(rem_e, Matcher)
and rem_e.matches(e) # type: ignore[arg-type]
)
):
remaining_exceptions.remove(rem_e) # type: ignore[arg-type]
break
else:
return False
return True
# def __str__(self) -> str:
# return f"ExceptionGroup{self.expected_exceptions}"
# str(tuple(...)) seems to call repr
def __repr__(self) -> str:
# TODO: [Base]ExceptionGroup
return f"ExceptionGroup{self.expected_exceptions}"
@overload
def raises(
expected_exception: Union[
ExpectedExceptionGroup[E], Tuple[ExpectedExceptionGroup[E], ...]
],
*,
match: Optional[Union[str, Pattern[str]]] = ...,
) -> "RaisesContext[ExpectedExceptionGroup[E]]":
...
@overload
@ -791,6 +941,17 @@ def raises(
...
#
#
# @overload
# def raises(
# expected_exception: Tuple[Union[Type[E], ExpectedExceptionGroup[E2]], ...],
# *,
# match: Optional[Union[str, Pattern[str]]] = ...,
# ) -> "RaisesContext[Union[E, BaseExceptionGroup[E2]]]":
# ...
@overload
def raises( # noqa: F811
expected_exception: Union[Type[E], Tuple[Type[E], ...]],
@ -801,9 +962,20 @@ def raises( # noqa: F811
...
def raises( # noqa: F811
expected_exception: Union[Type[E], Tuple[Type[E], ...]], *args: Any, **kwargs: Any
) -> Union["RaisesContext[E]", _pytest._code.ExceptionInfo[E]]:
def raises(
expected_exception: Union[
Type[E],
ExpectedExceptionGroup[E2],
Tuple[Union[Type[E], ExpectedExceptionGroup[E2]], ...],
],
*args: Any,
**kwargs: Any,
) -> Union[
"RaisesContext[E]",
"RaisesContext[BaseExceptionGroup[E2]]",
"RaisesContext[Union[E, BaseExceptionGroup[E2]]]",
_pytest._code.ExceptionInfo[E],
]:
r"""Assert that a code block/function call raises an exception type, or one of its subclasses.
:param typing.Type[E] | typing.Tuple[typing.Type[E], ...] expected_exception:
@ -952,13 +1124,20 @@ def raises( # noqa: F811
f"Raising exceptions is already understood as failing the test, so you don't need "
f"any special code to say 'this should never raise an exception'."
)
if isinstance(expected_exception, type):
expected_exceptions: Tuple[Type[E], ...] = (expected_exception,)
if isinstance(expected_exception, (type, ExpectedExceptionGroup)):
expected_exception_tuple: Tuple[
Union[Type[E], ExpectedExceptionGroup[E2]], ...
] = (expected_exception,)
else:
expected_exceptions = expected_exception
for exc in expected_exceptions:
if not isinstance(exc, type) or not issubclass(exc, BaseException):
msg = "expected exception must be a BaseException type, not {}" # type: ignore[unreachable]
expected_exception_tuple = expected_exception
for exc in expected_exception_tuple:
if (
not isinstance(exc, type) or not issubclass(exc, BaseException)
) and not isinstance(exc, ExpectedExceptionGroup):
msg = ( # type: ignore[unreachable]
"expected exception must be a BaseException "
"type or ExpectedExceptionGroup instance, not {}"
)
not_a = exc.__name__ if isinstance(exc, type) else type(exc).__name__
raise TypeError(msg.format(not_a))
@ -971,14 +1150,23 @@ def raises( # noqa: F811
msg += ", ".join(sorted(kwargs))
msg += "\nUse context-manager form instead?"
raise TypeError(msg)
return RaisesContext(expected_exception, message, match)
# the ExpectedExceptionGroup -> BaseExceptionGroup swap necessitates an ignore
return RaisesContext(expected_exception, message, match) # type: ignore[misc]
else:
func = args[0]
for exc in expected_exception_tuple:
if isinstance(exc, ExpectedExceptionGroup):
raise TypeError(
"Only contextmanager form is supported for ExpectedExceptionGroup"
)
if not callable(func):
raise TypeError(f"{func!r} object (type: {type(func)}) must be callable")
try:
func(*args[1:], **kwargs)
except expected_exception as e:
except expected_exception as e: # type: ignore[misc] # TypeError raised for any ExpectedExceptionGroup
return _pytest._code.ExceptionInfo.from_exception(e)
fail(message)
@ -987,11 +1175,14 @@ def raises( # noqa: F811
raises.Exception = fail.Exception # type: ignore
EE: "TypeAlias" = Union[Type[E], "ExpectedExceptionGroup[E]"]
@final
class RaisesContext(ContextManager[_pytest._code.ExceptionInfo[E]]):
def __init__(
self,
expected_exception: Union[Type[E], Tuple[Type[E], ...]],
expected_exception: Union[EE[E], Tuple[EE[E], ...]],
message: str,
match_expr: Optional[Union[str, Pattern[str]]] = None,
) -> None:
@ -1014,8 +1205,26 @@ class RaisesContext(ContextManager[_pytest._code.ExceptionInfo[E]]):
if exc_type is None:
fail(self.message)
assert self.excinfo is not None
if not issubclass(exc_type, self.expected_exception):
if isinstance(self.expected_exception, ExpectedExceptionGroup):
if not self.expected_exception.matches(exc_val):
return False
elif isinstance(self.expected_exception, tuple):
for expected_exc in self.expected_exception:
if (
isinstance(expected_exc, ExpectedExceptionGroup)
and expected_exc.matches(exc_val)
) or (
isinstance(expected_exc, type)
and issubclass(exc_type, expected_exc)
):
break
else: # pragma: no cover
# this would've been caught on initialization of pytest.raises()
return False
elif not issubclass(exc_type, self.expected_exception):
return False
# Cast to narrow the exception type now that it's verified.
exc_info = cast(Tuple[Type[E], E, TracebackType], (exc_type, exc_val, exc_tb))
self.excinfo.fill_unfilled(exc_info)

View File

@ -0,0 +1,262 @@
from typing import TYPE_CHECKING
import pytest
from _pytest.python_api import ExpectedExceptionGroup
from _pytest.python_api import Matcher
# TODO: make a public export
if TYPE_CHECKING:
from typing_extensions import assert_type
class TestExpectedExceptionGroup:
def test_expected_exception_group(self) -> None:
with pytest.raises(
ValueError,
match="^Invalid argument {exc} must be exception type, Matcher, or ExpectedExceptionGroup.$",
):
ExpectedExceptionGroup(ValueError())
with pytest.raises(ExpectedExceptionGroup(ValueError)):
raise ExceptionGroup("foo", (ValueError(),))
with pytest.raises(ExpectedExceptionGroup(SyntaxError)):
with pytest.raises(ExpectedExceptionGroup(ValueError)):
raise ExceptionGroup("foo", (SyntaxError(),))
# multiple exceptions
with pytest.raises(ExpectedExceptionGroup(ValueError, SyntaxError)):
raise ExceptionGroup("foo", (ValueError(), SyntaxError()))
# order doesn't matter
with pytest.raises(ExpectedExceptionGroup(SyntaxError, ValueError)):
raise ExceptionGroup("foo", (ValueError(), SyntaxError()))
# nested exceptions
with pytest.raises(ExpectedExceptionGroup(ExpectedExceptionGroup(ValueError))):
raise ExceptionGroup("foo", (ExceptionGroup("bar", (ValueError(),)),))
with pytest.raises(
ExpectedExceptionGroup(
SyntaxError,
ExpectedExceptionGroup(ValueError),
ExpectedExceptionGroup(RuntimeError),
)
):
raise ExceptionGroup(
"foo",
(
SyntaxError(),
ExceptionGroup("bar", (ValueError(),)),
ExceptionGroup("", (RuntimeError(),)),
),
)
# will error if there's excess exceptions
with pytest.raises(ExceptionGroup):
with pytest.raises(ExpectedExceptionGroup(ValueError)):
raise ExceptionGroup("", (ValueError(), ValueError()))
with pytest.raises(ExceptionGroup):
with pytest.raises(ExpectedExceptionGroup(ValueError)):
raise ExceptionGroup("", (RuntimeError(), ValueError()))
# will error if there's missing exceptions
with pytest.raises(ExceptionGroup):
with pytest.raises(ExpectedExceptionGroup(ValueError, ValueError)):
raise ExceptionGroup("", (ValueError(),))
with pytest.raises(ExceptionGroup):
with pytest.raises(ExpectedExceptionGroup(ValueError, SyntaxError)):
raise ExceptionGroup("", (ValueError(),))
# loose semantics, as with expect*
with pytest.raises(ExpectedExceptionGroup(ValueError, strict=False)):
raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),))
# mixed loose is possible if you want it to be at least N deep
with pytest.raises(
ExpectedExceptionGroup(ExpectedExceptionGroup(ValueError, strict=True))
):
raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),))
with pytest.raises(
ExpectedExceptionGroup(ExpectedExceptionGroup(ValueError, strict=False))
):
raise ExceptionGroup(
"", (ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)),)
)
# but not the other way around
with pytest.raises(
ValueError,
match="^You cannot specify a nested structure inside an ExpectedExceptionGroup with strict=False$",
):
ExpectedExceptionGroup(ExpectedExceptionGroup(ValueError), strict=False)
# currently not fully identical in behaviour to expect*, which would also catch an unwrapped exception
with pytest.raises(ValueError):
with pytest.raises(ExpectedExceptionGroup(ValueError, strict=False)):
raise ValueError
def test_match(self) -> None:
# supports match string
with pytest.raises(ExpectedExceptionGroup(ValueError), match="bar"):
raise ExceptionGroup("bar", (ValueError(),))
try:
with pytest.raises(ExpectedExceptionGroup(ValueError), match="foo"):
raise ExceptionGroup("bar", (ValueError(),))
except AssertionError as e:
assert str(e).startswith("Regex pattern did not match.")
else:
raise AssertionError("Expected pytest.raises.Exception")
def test_ExpectedExceptionGroup_matches(self) -> None:
eeg = ExpectedExceptionGroup(ValueError)
assert not eeg.matches(None)
assert not eeg.matches(ValueError())
assert eeg.matches(ExceptionGroup("", (ValueError(),)))
def test_message(self) -> None:
try:
with pytest.raises(ExpectedExceptionGroup(ValueError)):
...
except pytest.fail.Exception as e:
assert e.msg == f"DID NOT RAISE ExceptionGroup({repr(ValueError)},)"
else:
assert False, "Expected pytest.raises.Exception"
try:
with pytest.raises(
ExpectedExceptionGroup(ExpectedExceptionGroup(ValueError))
):
...
except pytest.fail.Exception as e:
assert (
e.msg
== f"DID NOT RAISE ExceptionGroup(ExceptionGroup({repr(ValueError)},),)"
)
else:
assert False, "Expected pytest.raises.Exception"
def test_matcher(self) -> None:
with pytest.raises(
ValueError, match="^You must specify at least one parameter to match on.$"
):
Matcher()
with pytest.raises(ExpectedExceptionGroup(Matcher(ValueError))):
raise ExceptionGroup("", (ValueError(),))
try:
with pytest.raises(ExpectedExceptionGroup(Matcher(TypeError))):
raise ExceptionGroup("", (ValueError(),))
except ExceptionGroup:
pass
else:
assert False, "Expected pytest.raises.Exception"
def test_matcher_match(self) -> None:
with pytest.raises(ExpectedExceptionGroup(Matcher(ValueError, "foo"))):
raise ExceptionGroup("", (ValueError("foo"),))
try:
with pytest.raises(ExpectedExceptionGroup(Matcher(ValueError, "foo"))):
raise ExceptionGroup("", (ValueError("bar"),))
except ExceptionGroup:
pass
else:
assert False, "Expected pytest.raises.Exception"
# Can be used without specifying the type
with pytest.raises(ExpectedExceptionGroup(Matcher(match="foo"))):
raise ExceptionGroup("", (ValueError("foo"),))
try:
with pytest.raises(ExpectedExceptionGroup(Matcher(match="foo"))):
raise ExceptionGroup("", (ValueError("bar"),))
except ExceptionGroup:
pass
else:
assert False, "Expected pytest.raises.Exception"
def test_Matcher_check(self) -> None:
def check_oserror_and_errno_is_5(e: BaseException) -> bool:
return isinstance(e, OSError) and e.errno == 5
with pytest.raises(
ExpectedExceptionGroup(Matcher(check=check_oserror_and_errno_is_5))
):
raise ExceptionGroup("", (OSError(5, ""),))
# specifying exception_type narrows the parameter type to the callable
def check_errno_is_5(e: OSError) -> bool:
return e.errno == 5
with pytest.raises(
ExpectedExceptionGroup(Matcher(OSError, check=check_errno_is_5))
):
raise ExceptionGroup("", (OSError(5, ""),))
try:
with pytest.raises(
ExpectedExceptionGroup(Matcher(OSError, check=check_errno_is_5))
):
raise ExceptionGroup("", (OSError(6, ""),))
except ExceptionGroup:
pass
else:
assert False, "Expected pytest.raises.Exception"
if TYPE_CHECKING:
# getting the typing working satisfactory is very tricky
# but with ExpectedExceptionGroup being seen as a subclass of BaseExceptionGroup
# most end-user cases of checking excinfo.value.foobar should work fine now.
def test_types_0(self) -> None:
_: BaseExceptionGroup[ValueError] = ExpectedExceptionGroup(ValueError)
_ = ExpectedExceptionGroup(ExpectedExceptionGroup(ValueError)) # type: ignore[arg-type]
a: BaseExceptionGroup[BaseExceptionGroup[ValueError]]
a = ExpectedExceptionGroup(ExpectedExceptionGroup(ValueError))
a = BaseExceptionGroup("", (BaseExceptionGroup("", (ValueError(),)),))
assert a
def test_types_1(self) -> None:
with pytest.raises(ExpectedExceptionGroup(ValueError)) as e:
raise ExceptionGroup("foo", (ValueError(),))
assert_type(e.value, ExpectedExceptionGroup[ValueError])
def test_types_2(self) -> None:
exc: ExceptionGroup[ValueError] | ValueError = ExceptionGroup(
"", (ValueError(),)
)
if ExpectedExceptionGroup(ValueError).matches(exc):
assert_type(exc, BaseExceptionGroup[ValueError])
def test_types_3(self) -> None:
e: BaseExceptionGroup[KeyboardInterrupt] = BaseExceptionGroup(
"", (KeyboardInterrupt(),)
)
if ExpectedExceptionGroup(ValueError).matches(e):
assert_type(e, BaseExceptionGroup[ValueError])
def test_types_4(self) -> None:
with pytest.raises(ExpectedExceptionGroup(Matcher(ValueError))) as e:
...
_: BaseExceptionGroup[ValueError] = e.value
assert_type(e.value, ExpectedExceptionGroup[ValueError])
def test_types_5(self) -> None:
with pytest.raises(
ExpectedExceptionGroup(ExpectedExceptionGroup(ValueError))
) as excinfo:
raise ExceptionGroup("foo", (ValueError(),))
_: BaseExceptionGroup[BaseExceptionGroup[ValueError]] = excinfo.value
assert_type(
excinfo.value,
ExpectedExceptionGroup[ExpectedExceptionGroup[ValueError]],
)
print(excinfo.value.exceptions[0].exceptions[0])
def test_types_6(self) -> None:
exc: ExceptionGroup[ExceptionGroup[ValueError]] = ... # type: ignore[assignment]
if ExpectedExceptionGroup(ExpectedExceptionGroup(ValueError)).matches(exc): # type: ignore[arg-type]
# ugly
assert_type(exc, BaseExceptionGroup[ExpectedExceptionGroup[ValueError]])

View File

@ -287,7 +287,10 @@ class TestRaises:
with pytest.raises(TypeError) as excinfo:
with pytest.raises("hello"): # type: ignore[call-overload]
pass # pragma: no cover
assert "must be a BaseException type, not str" in str(excinfo.value)
assert (
"must be a BaseException type or ExpectedExceptionGroup instance, not str"
in str(excinfo.value)
)
class NotAnException:
pass
@ -295,9 +298,15 @@ class TestRaises:
with pytest.raises(TypeError) as excinfo:
with pytest.raises(NotAnException): # type: ignore[type-var]
pass # pragma: no cover
assert "must be a BaseException type, not NotAnException" in str(excinfo.value)
assert (
"must be a BaseException type or ExpectedExceptionGroup instance, not NotAnException"
in str(excinfo.value)
)
with pytest.raises(TypeError) as excinfo:
with pytest.raises(("hello", NotAnException)): # type: ignore[arg-type]
pass # pragma: no cover
assert "must be a BaseException type, not str" in str(excinfo.value)
assert (
"must be a BaseException type or ExpectedExceptionGroup instance, not str"
in str(excinfo.value)
)