draft implementation of RaisesGroup

This commit is contained in:
jakkdl 2023-12-01 17:45:50 +01:00
parent 5689d806cf
commit 2023fa7de8
2 changed files with 407 additions and 0 deletions

View File

@ -1,5 +1,7 @@
import math
import pprint
import re
import sys
from collections.abc import Collection
from collections.abc import Sized
from decimal import Decimal
@ -10,6 +12,8 @@ from typing import Callable
from typing import cast
from typing import ContextManager
from typing import final
from typing import Generic
from typing import Iterable
from typing import List
from typing import Mapping
from typing import Optional
@ -28,6 +32,10 @@ from _pytest.outcomes import fail
if TYPE_CHECKING:
from numpy import ndarray
from typing_extensions import TypeGuard
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
def _non_numeric_type_error(value, at: Optional[str]) -> TypeError:
@ -987,6 +995,155 @@ def raises( # noqa: F811
raises.Exception = fail.Exception # type: ignore
class Matcher(Generic[E]):
def __init__(
self,
exception_type: Optional[Type[E]] = None,
match: Optional[Union[str, Pattern[str]]] = None,
check: Optional[Callable[[E], bool]] = None,
):
if exception_type is None and match is None and check is None:
raise ValueError("You must specify at least one parameter to match on.")
self.exception_type = exception_type
self.match = match
self.check = check
def matches(self, exception: E) -> "TypeGuard[E]":
if self.exception_type is not None and not isinstance(
exception, self.exception_type
):
return False
if self.match is not None and not re.search(self.match, str(exception)):
return False
if self.check is not None and not self.check(exception):
return False
return True
if TYPE_CHECKING:
SuperClass = BaseExceptionGroup
else:
SuperClass = Generic
@final
class RaisesGroup(
ContextManager[_pytest._code.ExceptionInfo[BaseExceptionGroup[E]]], SuperClass[E]
):
# My_T = TypeVar("My_T", bound=Union[Type[E], Matcher[E], "RaisesGroup[E]"])
def __init__(
self,
exceptions: Union[Type[E], Matcher[E], E],
*args: Union[Type[E], Matcher[E], E],
strict: bool = True,
match: Optional[Union[str, Pattern[str]]] = None,
):
# could add parameter `notes: Optional[Tuple[str, Pattern[str]]] = None`
self.expected_exceptions = (exceptions, *args)
self.strict = strict
self.match_expr = match
self.message = f"DID NOT RAISE ExceptionGroup{repr(self.expected_exceptions)}" # type: ignore[misc]
for exc in self.expected_exceptions:
if not isinstance(exc, (Matcher, RaisesGroup)) and not (
isinstance(exc, type) and issubclass(exc, BaseException)
):
raise ValueError(
"Invalid argument {exc} must be exception type, Matcher, or RaisesGroup."
)
if isinstance(exc, RaisesGroup) and not strict: # type: ignore[unreachable]
raise ValueError(
"You cannot specify a nested structure inside a RaisesGroup with strict=False"
)
def __enter__(self) -> _pytest._code.ExceptionInfo[BaseExceptionGroup[E]]:
self.excinfo: _pytest._code.ExceptionInfo[
BaseExceptionGroup[E]
] = _pytest._code.ExceptionInfo.for_later()
return self.excinfo
def _unroll_exceptions(
self, exceptions: Iterable[BaseException]
) -> Iterable[BaseException]:
res: list[BaseException] = []
for exc in exceptions:
if isinstance(exc, BaseExceptionGroup):
res.extend(self._unroll_exceptions(exc.exceptions))
else:
res.append(exc)
return res
def matches(
self,
exc_val: Optional[BaseException],
) -> "TypeGuard[BaseExceptionGroup[E]]":
if exc_val is None:
return False
if not isinstance(exc_val, BaseExceptionGroup):
return False
if not len(exc_val.exceptions) == len(self.expected_exceptions):
return False
remaining_exceptions = list(self.expected_exceptions)
actual_exceptions: Iterable[BaseException] = exc_val.exceptions
if not self.strict:
actual_exceptions = self._unroll_exceptions(actual_exceptions)
# it should be possible to get RaisesGroup.matches typed so as not to
# need these type: ignores, but I'm not sure that's possible while also having it
# transparent for the end user.
for e in actual_exceptions:
for rem_e in remaining_exceptions:
# TODO: how to print string diff on mismatch?
# Probably accumulate them, and then if fail, print them
# Further QoL would be to print how the exception structure differs on non-match
if (
(isinstance(rem_e, type) and isinstance(e, rem_e))
or (
isinstance(e, BaseExceptionGroup)
and isinstance(rem_e, RaisesGroup)
and rem_e.matches(e)
)
or (
isinstance(rem_e, Matcher)
and rem_e.matches(e) # type: ignore[arg-type]
)
):
remaining_exceptions.remove(rem_e) # type: ignore[arg-type]
break
else:
return False
return True
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
__tracebackhide__ = True
if exc_type is None:
fail(self.message)
assert self.excinfo is not None
if not self.matches(exc_val):
return False
# Cast to narrow the exception type now that it's verified.
exc_info = cast(
Tuple[Type[BaseExceptionGroup[E]], BaseExceptionGroup[E], TracebackType],
(exc_type, exc_val, exc_tb),
)
self.excinfo.fill_unfilled(exc_info)
if self.match_expr is not None:
self.excinfo.match(self.match_expr)
return True
def __repr__(self) -> str:
# TODO: [Base]ExceptionGroup
return f"ExceptionGroup{self.expected_exceptions}"
@final
class RaisesContext(ContextManager[_pytest._code.ExceptionInfo[E]]):
def __init__(

View File

@ -0,0 +1,250 @@
import sys
from typing import TYPE_CHECKING
import pytest
from _pytest.python_api import Matcher
from _pytest.python_api import RaisesGroup
# TODO: make a public export
if TYPE_CHECKING:
from typing_extensions import assert_type
if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup
class TestRaisesGroup:
def test_raises_group(self) -> None:
with pytest.raises(
ValueError,
match="^Invalid argument {exc} must be exception type, Matcher, or RaisesGroup.$",
):
RaisesGroup(ValueError())
with RaisesGroup(ValueError):
raise ExceptionGroup("foo", (ValueError(),))
with RaisesGroup(SyntaxError):
with RaisesGroup(ValueError):
raise ExceptionGroup("foo", (SyntaxError(),))
# multiple exceptions
with RaisesGroup(ValueError, SyntaxError):
raise ExceptionGroup("foo", (ValueError(), SyntaxError()))
# order doesn't matter
with RaisesGroup(SyntaxError, ValueError):
raise ExceptionGroup("foo", (ValueError(), SyntaxError()))
# nested exceptions
with RaisesGroup(RaisesGroup(ValueError)):
raise ExceptionGroup("foo", (ExceptionGroup("bar", (ValueError(),)),))
with RaisesGroup(
SyntaxError,
RaisesGroup(ValueError),
RaisesGroup(RuntimeError),
):
raise ExceptionGroup(
"foo",
(
SyntaxError(),
ExceptionGroup("bar", (ValueError(),)),
ExceptionGroup("", (RuntimeError(),)),
),
)
# will error if there's excess exceptions
with pytest.raises(ExceptionGroup):
with RaisesGroup(ValueError):
raise ExceptionGroup("", (ValueError(), ValueError()))
with pytest.raises(ExceptionGroup):
with RaisesGroup(ValueError):
raise ExceptionGroup("", (RuntimeError(), ValueError()))
# will error if there's missing exceptions
with pytest.raises(ExceptionGroup):
with RaisesGroup(ValueError, ValueError):
raise ExceptionGroup("", (ValueError(),))
with pytest.raises(ExceptionGroup):
with RaisesGroup(ValueError, SyntaxError):
raise ExceptionGroup("", (ValueError(),))
# loose semantics, as with expect*
with RaisesGroup(ValueError, strict=False):
raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),))
# mixed loose is possible if you want it to be at least N deep
with RaisesGroup(RaisesGroup(ValueError, strict=True)):
raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),))
with RaisesGroup(RaisesGroup(ValueError, strict=False)):
raise ExceptionGroup(
"", (ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)),)
)
# but not the other way around
with pytest.raises(
ValueError,
match="^You cannot specify a nested structure inside a RaisesGroup with strict=False$",
):
RaisesGroup(RaisesGroup(ValueError), strict=False)
# currently not fully identical in behaviour to expect*, which would also catch an unwrapped exception
with pytest.raises(ValueError):
with RaisesGroup(ValueError, strict=False):
raise ValueError
def test_match(self) -> None:
# supports match string
with RaisesGroup(ValueError, match="bar"):
raise ExceptionGroup("bar", (ValueError(),))
try:
with RaisesGroup(ValueError, match="foo"):
raise ExceptionGroup("bar", (ValueError(),))
except AssertionError as e:
assert str(e).startswith("Regex pattern did not match.")
else:
raise AssertionError("Expected pytest.raises.Exception")
def test_RaisesGroup_matches(self) -> None:
eeg = RaisesGroup(ValueError)
assert not eeg.matches(None)
assert not eeg.matches(ValueError())
assert eeg.matches(ExceptionGroup("", (ValueError(),)))
def test_message(self) -> None:
try:
with RaisesGroup(ValueError):
...
except pytest.fail.Exception as e:
assert e.msg == f"DID NOT RAISE ExceptionGroup({repr(ValueError)},)"
else:
assert False, "Expected pytest.raises.Exception"
try:
with RaisesGroup(RaisesGroup(ValueError)):
...
except pytest.fail.Exception as e:
assert (
e.msg
== f"DID NOT RAISE ExceptionGroup(ExceptionGroup({repr(ValueError)},),)"
)
else:
assert False, "Expected pytest.raises.Exception"
def test_matcher(self) -> None:
with pytest.raises(
ValueError, match="^You must specify at least one parameter to match on.$"
):
Matcher()
with RaisesGroup(Matcher(ValueError)):
raise ExceptionGroup("", (ValueError(),))
try:
with RaisesGroup(Matcher(TypeError)):
raise ExceptionGroup("", (ValueError(),))
except ExceptionGroup:
pass
else:
assert False, "Expected pytest.raises.Exception"
def test_matcher_match(self) -> None:
with RaisesGroup(Matcher(ValueError, "foo")):
raise ExceptionGroup("", (ValueError("foo"),))
try:
with RaisesGroup(Matcher(ValueError, "foo")):
raise ExceptionGroup("", (ValueError("bar"),))
except ExceptionGroup:
pass
else:
assert False, "Expected pytest.raises.Exception"
# Can be used without specifying the type
with RaisesGroup(Matcher(match="foo")):
raise ExceptionGroup("", (ValueError("foo"),))
try:
with RaisesGroup(Matcher(match="foo")):
raise ExceptionGroup("", (ValueError("bar"),))
except ExceptionGroup:
pass
else:
assert False, "Expected pytest.raises.Exception"
def test_Matcher_check(self) -> None:
def check_oserror_and_errno_is_5(e: BaseException) -> bool:
return isinstance(e, OSError) and e.errno == 5
with RaisesGroup(Matcher(check=check_oserror_and_errno_is_5)):
raise ExceptionGroup("", (OSError(5, ""),))
# specifying exception_type narrows the parameter type to the callable
def check_errno_is_5(e: OSError) -> bool:
return e.errno == 5
with RaisesGroup(Matcher(OSError, check=check_errno_is_5)):
raise ExceptionGroup("", (OSError(5, ""),))
try:
with RaisesGroup(Matcher(OSError, check=check_errno_is_5)):
raise ExceptionGroup("", (OSError(6, ""),))
except ExceptionGroup:
pass
else:
assert False, "Expected pytest.raises.Exception"
if TYPE_CHECKING:
# getting the typing working satisfactory is very tricky
# but with RaisesGroup being seen as a subclass of BaseExceptionGroup
# most end-user cases of checking excinfo.value.foobar should work fine now.
def test_types_0(self) -> None:
_: BaseExceptionGroup[ValueError] = RaisesGroup(ValueError)
_ = RaisesGroup(RaisesGroup(ValueError)) # type: ignore[arg-type]
a: BaseExceptionGroup[BaseExceptionGroup[ValueError]]
a = RaisesGroup(RaisesGroup(ValueError))
a = BaseExceptionGroup("", (BaseExceptionGroup("", (ValueError(),)),))
assert a
def test_types_1(self) -> None:
with RaisesGroup(ValueError) as e:
raise ExceptionGroup("foo", (ValueError(),))
assert_type(e.value, BaseExceptionGroup[ValueError])
# assert_type(e.value, RaisesGroup[ValueError])
def test_types_2(self) -> None:
exc: ExceptionGroup[ValueError] | ValueError = ExceptionGroup(
"", (ValueError(),)
)
if RaisesGroup(ValueError).matches(exc):
assert_type(exc, BaseExceptionGroup[ValueError])
def test_types_3(self) -> None:
e: BaseExceptionGroup[KeyboardInterrupt] = BaseExceptionGroup(
"", (KeyboardInterrupt(),)
)
if RaisesGroup(ValueError).matches(e):
assert_type(e, BaseExceptionGroup[ValueError])
def test_types_4(self) -> None:
with RaisesGroup(Matcher(ValueError)) as e:
...
_: BaseExceptionGroup[ValueError] = e.value
assert_type(e.value, BaseExceptionGroup[ValueError])
def test_types_5(self) -> None:
with RaisesGroup(RaisesGroup(ValueError)) as excinfo:
raise ExceptionGroup("foo", (ValueError(),))
_: BaseExceptionGroup[BaseExceptionGroup[ValueError]] = excinfo.value
assert_type(
excinfo.value,
BaseExceptionGroup[RaisesGroup[ValueError]],
)
print(excinfo.value.exceptions[0].exceptions[0])
def test_types_6(self) -> None:
exc: ExceptionGroup[ExceptionGroup[ValueError]] = ... # type: ignore[assignment]
if RaisesGroup(RaisesGroup(ValueError)).matches(exc): # type: ignore[arg-type]
# ugly
assert_type(exc, BaseExceptionGroup[RaisesGroup[ValueError]])