Add new `ExceptionInfo.group_contains` assertion helper method

Tests if a captured exception group contains an expected exception.
Will raise `AssertionError` if the wrapped exception is not an exception group.
Supports recursive search into nested exception groups.
This commit is contained in:
Mihail Milushev 2023-09-10 12:24:18 +01:00
parent 6c2feb75d2
commit ab8f5ce7f4
4 changed files with 121 additions and 6 deletions

View File

@ -266,6 +266,7 @@ Michal Wajszczuk
Michał Zięba Michał Zięba
Mickey Pashov Mickey Pashov
Mihai Capotă Mihai Capotă
Mihail Milushev
Mike Hoyle (hoylemd) Mike Hoyle (hoylemd)
Mike Lundy Mike Lundy
Milan Lesnek Milan Lesnek

View File

@ -0,0 +1,2 @@
Added :func:`ExceptionInfo.group_contains() <pytest.ExceptionInfo.group_contains>`, an assertion
helper that tests if an `ExceptionGroup` contains a matching exception.

View File

@ -697,6 +697,14 @@ class ExceptionInfo(Generic[E]):
) )
return fmt.repr_excinfo(self) return fmt.repr_excinfo(self)
def _stringify_exception(self, exc: BaseException) -> str:
return "\n".join(
[
str(exc),
*getattr(exc, "__notes__", []),
]
)
def match(self, regexp: Union[str, Pattern[str]]) -> "Literal[True]": def match(self, regexp: Union[str, Pattern[str]]) -> "Literal[True]":
"""Check whether the regular expression `regexp` matches the string """Check whether the regular expression `regexp` matches the string
representation of the exception using :func:`python:re.search`. representation of the exception using :func:`python:re.search`.
@ -704,12 +712,7 @@ class ExceptionInfo(Generic[E]):
If it matches `True` is returned, otherwise an `AssertionError` is raised. If it matches `True` is returned, otherwise an `AssertionError` is raised.
""" """
__tracebackhide__ = True __tracebackhide__ = True
value = "\n".join( value = self._stringify_exception(self.value)
[
str(self.value),
*getattr(self.value, "__notes__", []),
]
)
msg = f"Regex pattern did not match.\n Regex: {regexp!r}\n Input: {value!r}" msg = f"Regex pattern did not match.\n Regex: {regexp!r}\n Input: {value!r}"
if regexp == value: if regexp == value:
msg += "\n Did you mean to `re.escape()` the regex?" msg += "\n Did you mean to `re.escape()` the regex?"
@ -717,6 +720,56 @@ class ExceptionInfo(Generic[E]):
# Return True to allow for "assert excinfo.match()". # Return True to allow for "assert excinfo.match()".
return True return True
def _group_contains(
self,
exc_group: BaseExceptionGroup[BaseException],
expected_exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]],
match: Union[str, Pattern[str], None],
recursive: bool = False,
) -> bool:
"""Return `True` if a `BaseExceptionGroup` contains a matching exception."""
for exc in exc_group.exceptions:
if recursive and isinstance(exc, BaseExceptionGroup):
if self._group_contains(exc, expected_exception, match, recursive):
return True
if not isinstance(exc, expected_exception):
continue
if match is not None:
value = self._stringify_exception(exc)
if not re.search(match, value):
continue
return True
return False
def group_contains(
self,
expected_exception: Union[Type[BaseException], Tuple[Type[BaseException], ...]],
match: Union[str, Pattern[str], None] = None,
recursive: bool = False,
) -> bool:
"""Check whether a captured exception group contains a matching exception.
:param Type[BaseException] | Tuple[Type[BaseException]] expected_exception:
The expected exception type, or a tuple if one of multiple possible
exception types are expected.
:param str | Pattern[str] | None match:
If specified, a string containing a regular expression,
or a regular expression object, that is tested against the string
representation of the exception and its `PEP-678 <https://peps.python.org/pep-0678/>` `__notes__`
using :func:`re.search`.
To match a literal string that may contain :ref:`special characters
<re-syntax>`, the pattern can first be escaped with :func:`re.escape`.
:param bool recursive:
If `True`, search will descend recursively into any nested exception groups.
If `False`, only the top exception group will be searched.
"""
msg = "Captured exception is not an instance of `BaseExceptionGroup`"
assert isinstance(self.value, BaseExceptionGroup), msg
return self._group_contains(self.value, expected_exception, match, recursive)
@dataclasses.dataclass @dataclasses.dataclass
class FormattedExcinfo: class FormattedExcinfo:

View File

@ -27,6 +27,9 @@ from _pytest.pytester import Pytester
if TYPE_CHECKING: if TYPE_CHECKING:
from _pytest._code.code import _TracebackStyle from _pytest._code.code import _TracebackStyle
if sys.version_info[:2] < (3, 11):
from exceptiongroup import ExceptionGroup
@pytest.fixture @pytest.fixture
def limited_recursion_depth(): def limited_recursion_depth():
@ -444,6 +447,62 @@ def test_match_raises_error(pytester: Pytester) -> None:
result.stdout.re_match_lines([r".*__tracebackhide__ = True.*", *match]) result.stdout.re_match_lines([r".*__tracebackhide__ = True.*", *match])
class TestGroupContains:
def test_contains_exception_type(self) -> None:
exc_group = ExceptionGroup("", [RuntimeError()])
with pytest.raises(ExceptionGroup) as exc_info:
raise exc_group
assert exc_info.group_contains(RuntimeError)
def test_doesnt_contain_exception_type(self) -> None:
exc_group = ExceptionGroup("", [ValueError()])
with pytest.raises(ExceptionGroup) as exc_info:
raise exc_group
assert not exc_info.group_contains(RuntimeError)
def test_contains_exception_match(self) -> None:
exc_group = ExceptionGroup("", [RuntimeError("exception message")])
with pytest.raises(ExceptionGroup) as exc_info:
raise exc_group
assert exc_info.group_contains(RuntimeError, match=r"^exception message$")
def test_doesnt_contain_exception_match(self) -> None:
exc_group = ExceptionGroup("", [RuntimeError("message that will not match")])
with pytest.raises(ExceptionGroup) as exc_info:
raise exc_group
assert not exc_info.group_contains(RuntimeError, match=r"^exception message$")
def test_contains_exception_type_recursive(self) -> None:
exc_group = ExceptionGroup("", [ExceptionGroup("", [RuntimeError()])])
with pytest.raises(ExceptionGroup) as exc_info:
raise exc_group
assert exc_info.group_contains(RuntimeError, recursive=True)
def test_doesnt_contain_exception_type_nonrecursive(self) -> None:
exc_group = ExceptionGroup("", [ExceptionGroup("", [RuntimeError()])])
with pytest.raises(ExceptionGroup) as exc_info:
raise exc_group
assert not exc_info.group_contains(RuntimeError)
def test_contains_exception_match_recursive(self) -> None:
exc_group = ExceptionGroup(
"", [ExceptionGroup("", [RuntimeError("exception message")])]
)
with pytest.raises(ExceptionGroup) as exc_info:
raise exc_group
assert exc_info.group_contains(
RuntimeError, match=r"^exception message$", recursive=True
)
def test_doesnt_contain_exception_match_nonrecursive(self) -> None:
exc_group = ExceptionGroup(
"", [ExceptionGroup("", [RuntimeError("message that will not match")])]
)
with pytest.raises(ExceptionGroup) as exc_info:
raise exc_group
assert not exc_info.group_contains(RuntimeError, match=r"^exception message$")
class TestFormattedExcinfo: class TestFormattedExcinfo:
@pytest.fixture @pytest.fixture
def importasmod(self, tmp_path: Path, _sys_snapshot): def importasmod(self, tmp_path: Path, _sys_snapshot):