From 8c4ca32f8d0b5248b193b8385d4a94d069638758 Mon Sep 17 00:00:00 2001 From: Manuel Jacob Date: Fri, 1 Jul 2022 13:02:02 +0200 Subject: [PATCH] WIP: Add a simple way to mark tests as xfail inside test TODO: discuss API in #9027 TODO: write documentation TODO: write changelog entry TODO: add me to AUTHORS TODO: fix linting errors TODO: write proper commit message --- src/_pytest/skipping.py | 28 +++++++++++++++++++++++ src/pytest/__init__.py | 2 ++ testing/test_skipping.py | 48 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 78 insertions(+) diff --git a/src/_pytest/skipping.py b/src/_pytest/skipping.py index b20442350..6112dc2c1 100644 --- a/src/_pytest/skipping.py +++ b/src/_pytest/skipping.py @@ -4,10 +4,12 @@ import platform import sys import traceback from collections.abc import Mapping +from contextvars import ContextVar from typing import Generator from typing import Optional from typing import Tuple from typing import Type +from typing import Union import attr @@ -15,6 +17,7 @@ from _pytest.config import Config from _pytest.config import hookimpl from _pytest.config.argparsing import Parser from _pytest.mark.structures import Mark +from _pytest.mark.structures import MARK_GEN from _pytest.nodes import Item from _pytest.outcomes import fail from _pytest.outcomes import skip @@ -294,3 +297,28 @@ def pytest_report_teststatus(report: BaseReport) -> Optional[Tuple[str, str, str elif report.passed: return "xpassed", "X", "XPASS" return None + + +current_item_var: ContextVar[Item] = ContextVar("current_item_var") + + +def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> None: + current_item_var.set(item) + + +_not_passed = object() + + +def expect_failure( + reason: str = _not_passed, + raises: Union[Type[BaseException], Tuple[Type[BaseException], ...]] = _not_passed, + strict: bool = _not_passed, +) -> None: + kwargs = {} + if reason is not _not_passed: + kwargs["reason"] = reason + if raises is not _not_passed: + kwargs["raises"] = raises + if strict is not _not_passed: + kwargs["strict"] = strict + current_item_var.get().add_marker(MARK_GEN.xfail(**kwargs)) diff --git a/src/pytest/__init__.py b/src/pytest/__init__.py index c1634e296..058f9f681 100644 --- a/src/pytest/__init__.py +++ b/src/pytest/__init__.py @@ -59,6 +59,7 @@ from _pytest.recwarn import warns from _pytest.reports import CollectReport from _pytest.reports import TestReport from _pytest.runner import CallInfo +from _pytest.skipping import expect_failure from _pytest.stash import Stash from _pytest.stash import StashKey from _pytest.tmpdir import TempPathFactory @@ -95,6 +96,7 @@ __all__ = [ "exit", "ExceptionInfo", "ExitCode", + "expect_failure", "fail", "File", "fixture", diff --git a/testing/test_skipping.py b/testing/test_skipping.py index 6415480ef..efec8de19 100644 --- a/testing/test_skipping.py +++ b/testing/test_skipping.py @@ -731,6 +731,54 @@ class TestXFail: res.stdout.fnmatch_lines(["*1 failed*"]) res.stdout.fnmatch_lines(["*1 xfailed*"]) + @pytest.mark.parametrize("strict", [True, False]) + def test_expect_failure_xfailed(self, pytester: Pytester, strict: bool) -> None: + reprec = pytester.inline_runsource( + """ + import pytest + def test_func(): + pytest.expect_failure(strict=%s) + assert 0 + """ + % strict + ) + reports = reprec.getreports("pytest_runtest_logreport") + assert len(reports) == 3 + callreport = reports[1] + assert callreport.skipped + assert callreport.wasxfail == "" + + def test_expect_failure_xpassed(self, pytester: Pytester) -> None: + reprec = pytester.inline_runsource( + """ + import pytest + def test_func(): + pytest.expect_failure(reason="expect failure") + assert 1 + """ + ) + reports = reprec.getreports("pytest_runtest_logreport") + assert len(reports) == 3 + callreport = reports[1] + assert callreport.passed + assert callreport.wasxfail == "expect failure" + + def test_expect_failure_xpassed_strict(self, pytester: Pytester) -> None: + reprec = pytester.inline_runsource( + """ + import pytest + def test_func(): + pytest.expect_failure(strict=True, reason="nope") + assert 1 + """ + ) + reports = reprec.getreports("pytest_runtest_logreport") + assert len(reports) == 3 + callreport = reports[1] + assert callreport.failed + assert str(callreport.longrepr) == "[XPASS(strict)] nope" + assert not hasattr(callreport, "wasxfail") + class TestXFailwithSetupTeardown: def test_failing_setup_issue9(self, pytester: Pytester) -> None: