diff --git a/src/_pytest/tmpdir.py b/src/_pytest/tmpdir.py index 52355dff7..ebdb7fe5e 100644 --- a/src/_pytest/tmpdir.py +++ b/src/_pytest/tmpdir.py @@ -13,6 +13,8 @@ from typing import Union if TYPE_CHECKING: from typing_extensions import Literal + RetentionType = Literal["all", "failed", "none"] + import attr from _pytest.config.argparsing import Parser @@ -45,13 +47,13 @@ class TempPathFactory: _trace = attr.ib() _basetemp = attr.ib(type=Optional[Path]) _retention_count = attr.ib(type=int) - _retention_policy = attr.ib(type=RetentionPolicy) + _retention_policy = attr.ib(type="RetentionType") def __init__( self, given_basetemp: Optional[Path], retention_count: int, - retention_policy: RetentionPolicy, + retention_policy: "RetentionType", trace, basetemp: Optional[Path] = None, *,