diff --git a/src/_pytest/mark/structures.py b/src/_pytest/mark/structures.py index d6e426567..37ae1fd9c 100644 --- a/src/_pytest/mark/structures.py +++ b/src/_pytest/mark/structures.py @@ -461,7 +461,9 @@ if TYPE_CHECKING: *conditions: Union[str, bool], reason: str = ..., run: bool = ..., - raises: Union[Type[BaseException], Tuple[Type[BaseException], ...]] = ..., + raises: Optional[ + Union[Type[BaseException], Tuple[Type[BaseException], ...]] + ] = ..., strict: bool = ..., ) -> MarkDecorator: ... diff --git a/src/_pytest/skipping.py b/src/_pytest/skipping.py index 0c5c38f5f..868bf011a 100644 --- a/src/_pytest/skipping.py +++ b/src/_pytest/skipping.py @@ -5,15 +5,19 @@ import platform import sys import traceback from collections.abc import Mapping +from contextvars import ContextVar from typing import Generator +from typing import MutableMapping from typing import Optional from typing import Tuple from typing import Type +from typing import Union from _pytest.config import Config from _pytest.config import hookimpl from _pytest.config.argparsing import Parser from _pytest.mark.structures import Mark +from _pytest.mark.structures import MARK_GEN from _pytest.nodes import Item from _pytest.outcomes import fail from _pytest.outcomes import skip @@ -299,3 +303,32 @@ def pytest_report_teststatus(report: BaseReport) -> Optional[Tuple[str, str, str elif report.passed: return "xpassed", "X", "XPASS" return None + + +current_item_var: ContextVar[Item] = ContextVar("current_item_var") + + +def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> None: + current_item_var.set(item) + + +class _NotPassed: + pass + + +_not_passed = _NotPassed() + + +def expect_failure( + reason: str = "", + raises: Optional[ + Union[Type[BaseException], Tuple[Type[BaseException], ...]] + ] = None, + strict: Union[bool, _NotPassed] = _not_passed, +) -> None: + kwargs: MutableMapping[str, bool] = {} + if not isinstance(strict, _NotPassed): + kwargs["strict"] = strict + current_item_var.get().add_marker( + MARK_GEN.xfail(reason=reason, raises=raises, **kwargs) + ) diff --git a/src/pytest/__init__.py b/src/pytest/__init__.py index 831ede1fa..d030569d3 100644 --- a/src/pytest/__init__.py +++ b/src/pytest/__init__.py @@ -61,6 +61,7 @@ from _pytest.recwarn import warns from _pytest.reports import CollectReport from _pytest.reports import TestReport from _pytest.runner import CallInfo +from _pytest.skipping import expect_failure from _pytest.stash import Stash from _pytest.stash import StashKey from _pytest.terminal import TestShortLogReport @@ -100,6 +101,7 @@ __all__ = [ "exit", "ExceptionInfo", "ExitCode", + "expect_failure", "fail", "File", "fixture", diff --git a/testing/test_skipping.py b/testing/test_skipping.py index b7e448df3..96469202e 100644 --- a/testing/test_skipping.py +++ b/testing/test_skipping.py @@ -738,6 +738,54 @@ class TestXFail: res.stdout.fnmatch_lines(["*1 failed*"]) res.stdout.fnmatch_lines(["*1 xfailed*"]) + @pytest.mark.parametrize("strict", [True, False]) + def test_expect_failure_xfailed(self, pytester: Pytester, strict: bool) -> None: + reprec = pytester.inline_runsource( + """ + import pytest + def test_func(): + pytest.expect_failure(strict=%s) + assert 0 + """ + % strict + ) + reports = reprec.getreports("pytest_runtest_logreport") + assert len(reports) == 3 + callreport = reports[1] + assert callreport.skipped + assert callreport.wasxfail == "" + + def test_expect_failure_xpassed(self, pytester: Pytester) -> None: + reprec = pytester.inline_runsource( + """ + import pytest + def test_func(): + pytest.expect_failure(reason="expect failure") + assert 1 + """ + ) + reports = reprec.getreports("pytest_runtest_logreport") + assert len(reports) == 3 + callreport = reports[1] + assert callreport.passed + assert callreport.wasxfail == "expect failure" + + def test_expect_failure_xpassed_strict(self, pytester: Pytester) -> None: + reprec = pytester.inline_runsource( + """ + import pytest + def test_func(): + pytest.expect_failure(strict=True, reason="nope") + assert 1 + """ + ) + reports = reprec.getreports("pytest_runtest_logreport") + assert len(reports) == 3 + callreport = reports[1] + assert callreport.failed + assert str(callreport.longrepr) == "[XPASS(strict)] nope" + assert not hasattr(callreport, "wasxfail") + class TestXFailwithSetupTeardown: def test_failing_setup_issue9(self, pytester: Pytester) -> None: