Merge branch 'main' into warn-when-a-mark-is-applied-to-a-fixture

This commit is contained in:
Thomas Grainger
2023-06-25 16:08:11 +01:00
committed by GitHub
174 changed files with 10616 additions and 2941 deletions

View File

@@ -78,15 +78,15 @@ class FastFilesCompleter:
def __call__(self, prefix: str, **kwargs: Any) -> List[str]:
# Only called on non option completions.
if os.path.sep in prefix[1:]:
prefix_dir = len(os.path.dirname(prefix) + os.path.sep)
if os.sep in prefix[1:]:
prefix_dir = len(os.path.dirname(prefix) + os.sep)
else:
prefix_dir = 0
completion = []
globbed = []
if "*" not in prefix and "?" not in prefix:
# We are on unix, otherwise no bash.
if not prefix or prefix[-1] == os.path.sep:
if not prefix or prefix[-1] == os.sep:
globbed.extend(glob(prefix + ".*"))
prefix += "*"
globbed.extend(glob(prefix))

View File

@@ -1,4 +1,5 @@
import ast
import dataclasses
import inspect
import os
import re
@@ -30,9 +31,7 @@ from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from weakref import ref
import attr
import pluggy
import _pytest
@@ -50,9 +49,9 @@ from _pytest.pathlib import absolutepath
from _pytest.pathlib import bestrelpath
if TYPE_CHECKING:
from typing_extensions import Final
from typing_extensions import Literal
from typing_extensions import SupportsIndex
from weakref import ReferenceType
_TracebackStyle = Literal["long", "short", "line", "no", "native", "value", "auto"]
@@ -194,25 +193,25 @@ class Frame:
class TracebackEntry:
"""A single entry in a Traceback."""
__slots__ = ("_rawentry", "_excinfo", "_repr_style")
__slots__ = ("_rawentry", "_repr_style")
def __init__(
self,
rawentry: TracebackType,
excinfo: Optional["ReferenceType[ExceptionInfo[BaseException]]"] = None,
repr_style: Optional['Literal["short", "long"]'] = None,
) -> None:
self._rawentry = rawentry
self._excinfo = excinfo
self._repr_style: Optional['Literal["short", "long"]'] = None
self._rawentry: "Final" = rawentry
self._repr_style: "Final" = repr_style
def with_repr_style(
self, repr_style: Optional['Literal["short", "long"]']
) -> "TracebackEntry":
return TracebackEntry(self._rawentry, repr_style)
@property
def lineno(self) -> int:
return self._rawentry.tb_lineno - 1
def set_repr_style(self, mode: "Literal['short', 'long']") -> None:
assert mode in ("short", "long")
self._repr_style = mode
@property
def frame(self) -> Frame:
return Frame(self._rawentry.tb_frame)
@@ -272,7 +271,7 @@ class TracebackEntry:
source = property(getsource)
def ishidden(self) -> bool:
def ishidden(self, excinfo: Optional["ExceptionInfo[BaseException]"]) -> bool:
"""Return True if the current frame has a var __tracebackhide__
resolving to True.
@@ -296,7 +295,7 @@ class TracebackEntry:
else:
break
if tbh and callable(tbh):
return tbh(None if self._excinfo is None else self._excinfo())
return tbh(excinfo)
return tbh
def __str__(self) -> str:
@@ -329,16 +328,14 @@ class Traceback(List[TracebackEntry]):
def __init__(
self,
tb: Union[TracebackType, Iterable[TracebackEntry]],
excinfo: Optional["ReferenceType[ExceptionInfo[BaseException]]"] = None,
) -> None:
"""Initialize from given python traceback object and ExceptionInfo."""
self._excinfo = excinfo
if isinstance(tb, TracebackType):
def f(cur: TracebackType) -> Iterable[TracebackEntry]:
cur_: Optional[TracebackType] = cur
while cur_ is not None:
yield TracebackEntry(cur_, excinfo=excinfo)
yield TracebackEntry(cur_)
cur_ = cur_.tb_next
super().__init__(f(tb))
@@ -378,7 +375,7 @@ class Traceback(List[TracebackEntry]):
continue
if firstlineno is not None and x.frame.code.firstlineno != firstlineno:
continue
return Traceback(x._rawentry, self._excinfo)
return Traceback(x._rawentry)
return self
@overload
@@ -398,26 +395,27 @@ class Traceback(List[TracebackEntry]):
return super().__getitem__(key)
def filter(
self, fn: Callable[[TracebackEntry], bool] = lambda x: not x.ishidden()
self,
# TODO(py38): change to positional only.
_excinfo_or_fn: Union[
"ExceptionInfo[BaseException]",
Callable[[TracebackEntry], bool],
],
) -> "Traceback":
"""Return a Traceback instance with certain items removed
"""Return a Traceback instance with certain items removed.
fn is a function that gets a single argument, a TracebackEntry
instance, and should return True when the item should be added
to the Traceback, False when not.
If the filter is an `ExceptionInfo`, removes all the ``TracebackEntry``s
which are hidden (see ishidden() above).
By default this removes all the TracebackEntries which are hidden
(see ishidden() above).
Otherwise, the filter is a function that gets a single argument, a
``TracebackEntry`` instance, and should return True when the item should
be added to the ``Traceback``, False when not.
"""
return Traceback(filter(fn, self), self._excinfo)
def getcrashentry(self) -> TracebackEntry:
"""Return last non-hidden traceback entry that lead to the exception of a traceback."""
for i in range(-1, -len(self) - 1, -1):
entry = self[i]
if not entry.ishidden():
return entry
return self[-1]
if isinstance(_excinfo_or_fn, ExceptionInfo):
fn = lambda x: not x.ishidden(_excinfo_or_fn) # noqa: E731
else:
fn = _excinfo_or_fn
return Traceback(filter(fn, self))
def recursionindex(self) -> Optional[int]:
"""Return the index of the frame/TracebackEntry where recursion originates if
@@ -445,7 +443,7 @@ E = TypeVar("E", bound=BaseException, covariant=True)
@final
@attr.s(repr=False, init=False, auto_attribs=True)
@dataclasses.dataclass
class ExceptionInfo(Generic[E]):
"""Wraps sys.exc_info() objects and offers help for navigating the traceback."""
@@ -469,22 +467,41 @@ class ExceptionInfo(Generic[E]):
self._traceback = traceback
@classmethod
def from_exc_info(
def from_exception(
cls,
exc_info: Tuple[Type[E], E, TracebackType],
# Ignoring error: "Cannot use a covariant type variable as a parameter".
# This is OK to ignore because this class is (conceptually) readonly.
# See https://github.com/python/mypy/issues/7049.
exception: E, # type: ignore[misc]
exprinfo: Optional[str] = None,
) -> "ExceptionInfo[E]":
"""Return an ExceptionInfo for an existing exc_info tuple.
"""Return an ExceptionInfo for an existing exception.
.. warning::
Experimental API
The exception must have a non-``None`` ``__traceback__`` attribute,
otherwise this function fails with an assertion error. This means that
the exception must have been raised, or added a traceback with the
:py:meth:`~BaseException.with_traceback()` method.
:param exprinfo:
A text string helping to determine if we should strip
``AssertionError`` from the output. Defaults to the exception
message/``__str__()``.
.. versionadded:: 7.4
"""
assert (
exception.__traceback__
), "Exceptions passed to ExcInfo.from_exception(...) must have a non-None __traceback__."
exc_info = (type(exception), exception, exception.__traceback__)
return cls.from_exc_info(exc_info, exprinfo)
@classmethod
def from_exc_info(
cls,
exc_info: Tuple[Type[E], E, TracebackType],
exprinfo: Optional[str] = None,
) -> "ExceptionInfo[E]":
"""Like :func:`from_exception`, but using old-style exc_info tuple."""
_striptext = ""
if exprinfo is None and isinstance(exc_info[1], AssertionError):
exprinfo = getattr(exc_info[1], "msg", None)
@@ -563,7 +580,7 @@ class ExceptionInfo(Generic[E]):
def traceback(self) -> Traceback:
"""The traceback."""
if self._traceback is None:
self._traceback = Traceback(self.tb, excinfo=ref(self))
self._traceback = Traceback(self.tb)
return self._traceback
@traceback.setter
@@ -602,18 +619,25 @@ class ExceptionInfo(Generic[E]):
"""
return isinstance(self.value, exc)
def _getreprcrash(self) -> "ReprFileLocation":
exconly = self.exconly(tryshort=True)
entry = self.traceback.getcrashentry()
path, lineno = entry.frame.code.raw.co_filename, entry.lineno
return ReprFileLocation(path, lineno + 1, exconly)
def _getreprcrash(self) -> Optional["ReprFileLocation"]:
# Find last non-hidden traceback entry that led to the exception of the
# traceback, or None if all hidden.
for i in range(-1, -len(self.traceback) - 1, -1):
entry = self.traceback[i]
if not entry.ishidden(self):
path, lineno = entry.frame.code.raw.co_filename, entry.lineno
exconly = self.exconly(tryshort=True)
return ReprFileLocation(path, lineno + 1, exconly)
return None
def getrepr(
self,
showlocals: bool = False,
style: "_TracebackStyle" = "long",
abspath: bool = False,
tbfilter: bool = True,
tbfilter: Union[
bool, Callable[["ExceptionInfo[BaseException]"], Traceback]
] = True,
funcargs: bool = False,
truncate_locals: bool = True,
chain: bool = True,
@@ -625,14 +649,20 @@ class ExceptionInfo(Generic[E]):
Ignored if ``style=="native"``.
:param str style:
long|short|no|native|value traceback style.
long|short|line|no|native|value traceback style.
:param bool abspath:
If paths should be changed to absolute or left unchanged.
:param bool tbfilter:
Hide entries that contain a local variable ``__tracebackhide__==True``.
Ignored if ``style=="native"``.
:param tbfilter:
A filter for traceback entries.
* If false, don't hide any entries.
* If true, hide internal entries and entries that contain a local
variable ``__tracebackhide__ = True``.
* If a callable, delegates the filtering to the callable.
Ignored if ``style`` is ``"native"``.
:param bool funcargs:
Show fixtures ("funcargs" for legacy purposes) per traceback entry.
@@ -649,12 +679,14 @@ class ExceptionInfo(Generic[E]):
"""
if style == "native":
return ReprExceptionInfo(
ReprTracebackNative(
reprtraceback=ReprTracebackNative(
traceback.format_exception(
self.type, self.value, self.traceback[0]._rawentry
self.type,
self.value,
self.traceback[0]._rawentry if self.traceback else None,
)
),
self._getreprcrash(),
reprcrash=self._getreprcrash(),
)
fmt = FormattedExcinfo(
@@ -684,7 +716,7 @@ class ExceptionInfo(Generic[E]):
return True
@attr.s(auto_attribs=True)
@dataclasses.dataclass
class FormattedExcinfo:
"""Presenting information about failing Functions and Generators."""
@@ -695,12 +727,12 @@ class FormattedExcinfo:
showlocals: bool = False
style: "_TracebackStyle" = "long"
abspath: bool = True
tbfilter: bool = True
tbfilter: Union[bool, Callable[[ExceptionInfo[BaseException]], Traceback]] = True
funcargs: bool = False
truncate_locals: bool = True
chain: bool = True
astcache: Dict[Union[str, Path], ast.AST] = attr.ib(
factory=dict, init=False, repr=False
astcache: Dict[Union[str, Path], ast.AST] = dataclasses.field(
default_factory=dict, init=False, repr=False
)
def _getindent(self, source: "Source") -> int:
@@ -741,11 +773,13 @@ class FormattedExcinfo:
) -> List[str]:
"""Return formatted and marked up source lines."""
lines = []
if source is None or line_index >= len(source.lines):
if source is not None and line_index < 0:
line_index += len(source)
if source is None or line_index >= len(source.lines) or line_index < 0:
# `line_index` could still be outside `range(len(source.lines))` if
# we're processing AST with pathological position attributes.
source = Source("???")
line_index = 0
if line_index < 0:
line_index += len(source)
space_prefix = " "
if short:
lines.append(space_prefix + source.lines[line_index].strip())
@@ -805,12 +839,16 @@ class FormattedExcinfo:
def repr_traceback_entry(
self,
entry: TracebackEntry,
entry: Optional[TracebackEntry],
excinfo: Optional[ExceptionInfo[BaseException]] = None,
) -> "ReprEntry":
lines: List[str] = []
style = entry._repr_style if entry._repr_style is not None else self.style
if style in ("short", "long"):
style = (
entry._repr_style
if entry is not None and entry._repr_style is not None
else self.style
)
if style in ("short", "long") and entry is not None:
source = self._getentrysource(entry)
if source is None:
source = Source("???")
@@ -851,25 +889,31 @@ class FormattedExcinfo:
def repr_traceback(self, excinfo: ExceptionInfo[BaseException]) -> "ReprTraceback":
traceback = excinfo.traceback
if self.tbfilter:
traceback = traceback.filter()
if callable(self.tbfilter):
traceback = self.tbfilter(excinfo)
elif self.tbfilter:
traceback = traceback.filter(excinfo)
if isinstance(excinfo.value, RecursionError):
traceback, extraline = self._truncate_recursive_traceback(traceback)
else:
extraline = None
if not traceback:
if extraline is None:
extraline = "All traceback entries are hidden. Pass `--full-trace` to see hidden and internal frames."
entries = [self.repr_traceback_entry(None, excinfo)]
return ReprTraceback(entries, extraline, style=self.style)
last = traceback[-1]
entries = []
if self.style == "value":
reprentry = self.repr_traceback_entry(last, excinfo)
entries.append(reprentry)
entries = [self.repr_traceback_entry(last, excinfo)]
return ReprTraceback(entries, None, style=self.style)
for index, entry in enumerate(traceback):
einfo = (last == entry) and excinfo or None
reprentry = self.repr_traceback_entry(entry, einfo)
entries.append(reprentry)
entries = [
self.repr_traceback_entry(entry, excinfo if last == entry else None)
for entry in traceback
]
return ReprTraceback(entries, extraline, style=self.style)
def _truncate_recursive_traceback(
@@ -926,6 +970,7 @@ class FormattedExcinfo:
seen: Set[int] = set()
while e is not None and id(e) not in seen:
seen.add(id(e))
if excinfo_:
# Fall back to native traceback as a temporary workaround until
# full support for exception groups added to ExceptionInfo.
@@ -942,9 +987,7 @@ class FormattedExcinfo:
)
else:
reprtraceback = self.repr_traceback(excinfo_)
reprcrash: Optional[ReprFileLocation] = (
excinfo_._getreprcrash() if self.style != "value" else None
)
reprcrash = excinfo_._getreprcrash()
else:
# Fallback to native repr if the exception doesn't have a traceback:
# ExceptionInfo objects require a full traceback to work.
@@ -952,25 +995,17 @@ class FormattedExcinfo:
traceback.format_exception(type(e), e, None)
)
reprcrash = None
repr_chain += [(reprtraceback, reprcrash, descr)]
if e.__cause__ is not None and self.chain:
e = e.__cause__
excinfo_ = (
ExceptionInfo.from_exc_info((type(e), e, e.__traceback__))
if e.__traceback__
else None
)
excinfo_ = ExceptionInfo.from_exception(e) if e.__traceback__ else None
descr = "The above exception was the direct cause of the following exception:"
elif (
e.__context__ is not None and not e.__suppress_context__ and self.chain
):
e = e.__context__
excinfo_ = (
ExceptionInfo.from_exc_info((type(e), e, e.__traceback__))
if e.__traceback__
else None
)
excinfo_ = ExceptionInfo.from_exception(e) if e.__traceback__ else None
descr = "During handling of the above exception, another exception occurred:"
else:
e = None
@@ -978,7 +1013,7 @@ class FormattedExcinfo:
return ExceptionChainRepr(repr_chain)
@attr.s(eq=False, auto_attribs=True)
@dataclasses.dataclass(eq=False)
class TerminalRepr:
def __str__(self) -> str:
# FYI this is called from pytest-xdist's serialization of exception
@@ -996,14 +1031,14 @@ class TerminalRepr:
# This class is abstract -- only subclasses are instantiated.
@attr.s(eq=False)
@dataclasses.dataclass(eq=False)
class ExceptionRepr(TerminalRepr):
# Provided by subclasses.
reprcrash: Optional["ReprFileLocation"]
reprtraceback: "ReprTraceback"
def __attrs_post_init__(self) -> None:
self.sections: List[Tuple[str, str, str]] = []
reprcrash: Optional["ReprFileLocation"]
sections: List[Tuple[str, str, str]] = dataclasses.field(
init=False, default_factory=list
)
def addsection(self, name: str, content: str, sep: str = "-") -> None:
self.sections.append((name, content, sep))
@@ -1014,16 +1049,23 @@ class ExceptionRepr(TerminalRepr):
tw.line(content)
@attr.s(eq=False, auto_attribs=True)
@dataclasses.dataclass(eq=False)
class ExceptionChainRepr(ExceptionRepr):
chain: Sequence[Tuple["ReprTraceback", Optional["ReprFileLocation"], Optional[str]]]
def __attrs_post_init__(self) -> None:
super().__attrs_post_init__()
def __init__(
self,
chain: Sequence[
Tuple["ReprTraceback", Optional["ReprFileLocation"], Optional[str]]
],
) -> None:
# reprcrash and reprtraceback of the outermost (the newest) exception
# in the chain.
self.reprtraceback = self.chain[-1][0]
self.reprcrash = self.chain[-1][1]
super().__init__(
reprtraceback=chain[-1][0],
reprcrash=chain[-1][1],
)
self.chain = chain
def toterminal(self, tw: TerminalWriter) -> None:
for element in self.chain:
@@ -1034,17 +1076,17 @@ class ExceptionChainRepr(ExceptionRepr):
super().toterminal(tw)
@attr.s(eq=False, auto_attribs=True)
@dataclasses.dataclass(eq=False)
class ReprExceptionInfo(ExceptionRepr):
reprtraceback: "ReprTraceback"
reprcrash: "ReprFileLocation"
reprcrash: Optional["ReprFileLocation"]
def toterminal(self, tw: TerminalWriter) -> None:
self.reprtraceback.toterminal(tw)
super().toterminal(tw)
@attr.s(eq=False, auto_attribs=True)
@dataclasses.dataclass(eq=False)
class ReprTraceback(TerminalRepr):
reprentries: Sequence[Union["ReprEntry", "ReprEntryNative"]]
extraline: Optional[str]
@@ -1073,12 +1115,12 @@ class ReprTraceback(TerminalRepr):
class ReprTracebackNative(ReprTraceback):
def __init__(self, tblines: Sequence[str]) -> None:
self.style = "native"
self.reprentries = [ReprEntryNative(tblines)]
self.extraline = None
self.style = "native"
@attr.s(eq=False, auto_attribs=True)
@dataclasses.dataclass(eq=False)
class ReprEntryNative(TerminalRepr):
lines: Sequence[str]
@@ -1088,7 +1130,7 @@ class ReprEntryNative(TerminalRepr):
tw.write("".join(self.lines))
@attr.s(eq=False, auto_attribs=True)
@dataclasses.dataclass(eq=False)
class ReprEntry(TerminalRepr):
lines: Sequence[str]
reprfuncargs: Optional["ReprFuncArgs"]
@@ -1142,8 +1184,8 @@ class ReprEntry(TerminalRepr):
def toterminal(self, tw: TerminalWriter) -> None:
if self.style == "short":
assert self.reprfileloc is not None
self.reprfileloc.toterminal(tw)
if self.reprfileloc:
self.reprfileloc.toterminal(tw)
self._write_entry_lines(tw)
if self.reprlocals:
self.reprlocals.toterminal(tw, indent=" " * 8)
@@ -1168,12 +1210,15 @@ class ReprEntry(TerminalRepr):
)
@attr.s(eq=False, auto_attribs=True)
@dataclasses.dataclass(eq=False)
class ReprFileLocation(TerminalRepr):
path: str = attr.ib(converter=str)
path: str
lineno: int
message: str
def __post_init__(self) -> None:
self.path = str(self.path)
def toterminal(self, tw: TerminalWriter) -> None:
# Filename and lineno output for each entry, using an output format
# that most editors understand.
@@ -1185,7 +1230,7 @@ class ReprFileLocation(TerminalRepr):
tw.line(f":{self.lineno}: {msg}")
@attr.s(eq=False, auto_attribs=True)
@dataclasses.dataclass(eq=False)
class ReprLocals(TerminalRepr):
lines: Sequence[str]
@@ -1194,7 +1239,7 @@ class ReprLocals(TerminalRepr):
tw.line(indent + line)
@attr.s(eq=False, auto_attribs=True)
@dataclasses.dataclass(eq=False)
class ReprFuncArgs(TerminalRepr):
args: Sequence[Tuple[str, object]]

View File

109
src/_pytest/_py/error.py Normal file
View File

@@ -0,0 +1,109 @@
"""create errno-specific classes for IO or os calls."""
from __future__ import annotations
import errno
import os
import sys
from typing import Callable
from typing import TYPE_CHECKING
from typing import TypeVar
if TYPE_CHECKING:
from typing_extensions import ParamSpec
P = ParamSpec("P")
R = TypeVar("R")
class Error(EnvironmentError):
def __repr__(self) -> str:
return "{}.{} {!r}: {} ".format(
self.__class__.__module__,
self.__class__.__name__,
self.__class__.__doc__,
" ".join(map(str, self.args)),
# repr(self.args)
)
def __str__(self) -> str:
s = "[{}]: {}".format(
self.__class__.__doc__,
" ".join(map(str, self.args)),
)
return s
_winerrnomap = {
2: errno.ENOENT,
3: errno.ENOENT,
17: errno.EEXIST,
18: errno.EXDEV,
13: errno.EBUSY, # empty cd drive, but ENOMEDIUM seems unavailiable
22: errno.ENOTDIR,
20: errno.ENOTDIR,
267: errno.ENOTDIR,
5: errno.EACCES, # anything better?
}
class ErrorMaker:
"""lazily provides Exception classes for each possible POSIX errno
(as defined per the 'errno' module). All such instances
subclass EnvironmentError.
"""
_errno2class: dict[int, type[Error]] = {}
def __getattr__(self, name: str) -> type[Error]:
if name[0] == "_":
raise AttributeError(name)
eno = getattr(errno, name)
cls = self._geterrnoclass(eno)
setattr(self, name, cls)
return cls
def _geterrnoclass(self, eno: int) -> type[Error]:
try:
return self._errno2class[eno]
except KeyError:
clsname = errno.errorcode.get(eno, "UnknownErrno%d" % (eno,))
errorcls = type(
clsname,
(Error,),
{"__module__": "py.error", "__doc__": os.strerror(eno)},
)
self._errno2class[eno] = errorcls
return errorcls
def checked_call(
self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs
) -> R:
"""Call a function and raise an errno-exception if applicable."""
__tracebackhide__ = True
try:
return func(*args, **kwargs)
except Error:
raise
except OSError as value:
if not hasattr(value, "errno"):
raise
errno = value.errno
if sys.platform == "win32":
try:
cls = self._geterrnoclass(_winerrnomap[errno])
except KeyError:
raise value
else:
# we are not on Windows, or we got a proper OSError
cls = self._geterrnoclass(errno)
raise cls(f"{func.__name__}{args!r}")
_error_maker = ErrorMaker()
checked_call = _error_maker.checked_call
def __getattr__(attr: str) -> type[Error]:
return getattr(_error_maker, attr) # type: ignore[no-any-return]

1475
src/_pytest/_py/path.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -44,10 +44,20 @@ from _pytest.stash import StashKey
if TYPE_CHECKING:
from _pytest.assertion import AssertionState
if sys.version_info >= (3, 8):
namedExpr = ast.NamedExpr
astNameConstant = ast.Constant
astStr = ast.Constant
astNum = ast.Constant
else:
namedExpr = ast.Expr
astNameConstant = ast.NameConstant
astStr = ast.Str
astNum = ast.Num
assertstate_key = StashKey["AssertionState"]()
# pytest caches rewritten pycs in pycache dirs
PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
PYC_EXT = ".py" + (__debug__ and "c" or "o")
@@ -180,7 +190,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
for initial_path in self.session._initialpaths:
# Make something as c:/projects/my_project/path.py ->
# ['c:', 'projects', 'my_project', 'path.py']
parts = str(initial_path).split(os.path.sep)
parts = str(initial_path).split(os.sep)
# add 'path' to basenames to be checked.
self._basenames_to_check_rewrite.add(os.path.splitext(parts[-1])[0])
@@ -274,8 +284,12 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
return f.read()
if sys.version_info >= (3, 10):
if sys.version_info >= (3, 12):
from importlib.resources.abc import TraversableResources
else:
from importlib.abc import TraversableResources
def get_resource_reader(self, name: str) -> importlib.abc.TraversableResources: # type: ignore
def get_resource_reader(self, name: str) -> TraversableResources: # type: ignore
if sys.version_info < (3, 11):
from importlib.readers import FileReader
else:
@@ -631,8 +645,12 @@ class AssertionRewriter(ast.NodeVisitor):
.push_format_context() and .pop_format_context() which allows
to build another %-formatted string while already building one.
This state is reset on every new assert statement visited and used
by the other visitors.
:variables_overwrite: A dict filled with references to variables
that change value within an assert. This happens when a variable is
reassigned with the walrus operator
This state, except the variables_overwrite, is reset on every new assert
statement visited and used by the other visitors.
"""
def __init__(
@@ -648,6 +666,7 @@ class AssertionRewriter(ast.NodeVisitor):
else:
self.enable_assertion_pass_hook = False
self.source = source
self.variables_overwrite: Dict[str, str] = {}
def run(self, mod: ast.Module) -> None:
"""Find all assert statements in *mod* and rewrite them."""
@@ -662,14 +681,17 @@ class AssertionRewriter(ast.NodeVisitor):
if doc is not None and self.is_rewrite_disabled(doc):
return
pos = 0
lineno = 1
item = None
for item in mod.body:
if (
expect_docstring
and isinstance(item, ast.Expr)
and isinstance(item.value, ast.Str)
and isinstance(item.value, astStr)
):
doc = item.value.s
if sys.version_info >= (3, 8):
doc = item.value.value
else:
doc = item.value.s
if self.is_rewrite_disabled(doc):
return
expect_docstring = False
@@ -801,7 +823,7 @@ class AssertionRewriter(ast.NodeVisitor):
current = self.stack.pop()
if self.stack:
self.explanation_specifiers = self.stack[-1]
keys = [ast.Str(key) for key in current.keys()]
keys = [astStr(key) for key in current.keys()]
format_dict = ast.Dict(keys, list(current.values()))
form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
name = "@py_format" + str(next(self.variable_counter))
@@ -855,16 +877,16 @@ class AssertionRewriter(ast.NodeVisitor):
negation = ast.UnaryOp(ast.Not(), top_condition)
if self.enable_assertion_pass_hook: # Experimental pytest_assertion_pass hook
msg = self.pop_format_context(ast.Str(explanation))
msg = self.pop_format_context(astStr(explanation))
# Failed
if assert_.msg:
assertmsg = self.helper("_format_assertmsg", assert_.msg)
gluestr = "\n>assert "
else:
assertmsg = ast.Str("")
assertmsg = astStr("")
gluestr = "assert "
err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg)
err_explanation = ast.BinOp(astStr(gluestr), ast.Add(), msg)
err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
err_name = ast.Name("AssertionError", ast.Load())
fmt = self.helper("_format_explanation", err_msg)
@@ -880,8 +902,8 @@ class AssertionRewriter(ast.NodeVisitor):
hook_call_pass = ast.Expr(
self.helper(
"_call_assertion_pass",
ast.Num(assert_.lineno),
ast.Str(orig),
astNum(assert_.lineno),
astStr(orig),
fmt_pass,
)
)
@@ -900,7 +922,7 @@ class AssertionRewriter(ast.NodeVisitor):
variables = [
ast.Name(name, ast.Store()) for name in self.format_variables
]
clear_format = ast.Assign(variables, ast.NameConstant(None))
clear_format = ast.Assign(variables, astNameConstant(None))
self.statements.append(clear_format)
else: # Original assertion rewriting
@@ -911,9 +933,9 @@ class AssertionRewriter(ast.NodeVisitor):
assertmsg = self.helper("_format_assertmsg", assert_.msg)
explanation = "\n>assert " + explanation
else:
assertmsg = ast.Str("")
assertmsg = astStr("")
explanation = "assert " + explanation
template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
template = ast.BinOp(assertmsg, ast.Add(), astStr(explanation))
msg = self.pop_format_context(template)
fmt = self.helper("_format_explanation", msg)
err_name = ast.Name("AssertionError", ast.Load())
@@ -925,7 +947,7 @@ class AssertionRewriter(ast.NodeVisitor):
# Clear temporary variables by setting them to None.
if self.variables:
variables = [ast.Name(name, ast.Store()) for name in self.variables]
clear = ast.Assign(variables, ast.NameConstant(None))
clear = ast.Assign(variables, astNameConstant(None))
self.statements.append(clear)
# Fix locations (line numbers/column offsets).
for stmt in self.statements:
@@ -933,14 +955,26 @@ class AssertionRewriter(ast.NodeVisitor):
ast.copy_location(node, assert_)
return self.statements
def visit_NamedExpr(self, name: namedExpr) -> Tuple[namedExpr, str]:
# This method handles the 'walrus operator' repr of the target
# name if it's a local variable or _should_repr_global_name()
# thinks it's acceptable.
locs = ast.Call(self.builtin("locals"), [], [])
target_id = name.target.id # type: ignore[attr-defined]
inlocs = ast.Compare(astStr(target_id), [ast.In()], [locs])
dorepr = self.helper("_should_repr_global_name", name)
test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
expr = ast.IfExp(test, self.display(name), astStr(target_id))
return name, self.explanation_param(expr)
def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
# Display the repr of the name if it's a local variable or
# _should_repr_global_name() thinks it's acceptable.
locs = ast.Call(self.builtin("locals"), [], [])
inlocs = ast.Compare(ast.Str(name.id), [ast.In()], [locs])
inlocs = ast.Compare(astStr(name.id), [ast.In()], [locs])
dorepr = self.helper("_should_repr_global_name", name)
test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
expr = ast.IfExp(test, self.display(name), astStr(name.id))
return name, self.explanation_param(expr)
def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
@@ -959,10 +993,26 @@ class AssertionRewriter(ast.NodeVisitor):
# cond is set in a prior loop iteration below
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
self.expl_stmts = fail_inner
# Check if the left operand is a namedExpr and the value has already been visited
if (
isinstance(v, ast.Compare)
and isinstance(v.left, namedExpr)
and v.left.target.id
in [
ast_expr.id
for ast_expr in boolop.values[:i]
if hasattr(ast_expr, "id")
]
):
pytest_temp = self.variable()
self.variables_overwrite[
v.left.target.id
] = v.left # type:ignore[assignment]
v.left.target.id = pytest_temp
self.push_format_context()
res, expl = self.visit(v)
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
expl_format = self.pop_format_context(ast.Str(expl))
expl_format = self.pop_format_context(astStr(expl))
call = ast.Call(app, [expl_format], [])
self.expl_stmts.append(ast.Expr(call))
if i < levels:
@@ -974,7 +1024,7 @@ class AssertionRewriter(ast.NodeVisitor):
self.statements = body = inner
self.statements = save
self.expl_stmts = fail_save
expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or))
expl_template = self.helper("_format_boolop", expl_list, astNum(is_or))
expl = self.pop_format_context(expl_template)
return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
@@ -998,10 +1048,19 @@ class AssertionRewriter(ast.NodeVisitor):
new_args = []
new_kwargs = []
for arg in call.args:
if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite:
arg = self.variables_overwrite[arg.id] # type:ignore[assignment]
res, expl = self.visit(arg)
arg_expls.append(expl)
new_args.append(res)
for keyword in call.keywords:
if (
isinstance(keyword.value, ast.Name)
and keyword.value.id in self.variables_overwrite
):
keyword.value = self.variables_overwrite[
keyword.value.id
] # type:ignore[assignment]
res, expl = self.visit(keyword.value)
new_kwargs.append(ast.keyword(keyword.arg, res))
if keyword.arg:
@@ -1034,6 +1093,15 @@ class AssertionRewriter(ast.NodeVisitor):
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
self.push_format_context()
# We first check if we have overwritten a variable in the previous assert
if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite:
comp.left = self.variables_overwrite[
comp.left.id
] # type:ignore[assignment]
if isinstance(comp.left, namedExpr):
self.variables_overwrite[
comp.left.target.id
] = comp.left # type:ignore[assignment]
left_res, left_expl = self.visit(comp.left)
if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
left_expl = f"({left_expl})"
@@ -1045,14 +1113,23 @@ class AssertionRewriter(ast.NodeVisitor):
syms = []
results = [left_res]
for i, op, next_operand in it:
if (
isinstance(next_operand, namedExpr)
and isinstance(left_res, ast.Name)
and next_operand.target.id == left_res.id
):
next_operand.target.id = self.variable()
self.variables_overwrite[
left_res.id
] = next_operand # type:ignore[assignment]
next_res, next_expl = self.visit(next_operand)
if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
next_expl = f"({next_expl})"
results.append(next_res)
sym = BINOP_MAP[op.__class__]
syms.append(ast.Str(sym))
syms.append(astStr(sym))
expl = f"{left_expl} {sym} {next_expl}"
expls.append(ast.Str(expl))
expls.append(astStr(expl))
res_expr = ast.Compare(left_res, [op], [next_res])
self.statements.append(ast.Assign([store_names[i]], res_expr))
left_res, left_expl = next_res, next_expl
@@ -1068,6 +1145,7 @@ class AssertionRewriter(ast.NodeVisitor):
res: ast.expr = ast.BoolOp(ast.And(), load_names)
else:
res = load_names[0]
return res, self.explanation_param(self.pop_format_context(expl_call))

View File

@@ -38,9 +38,9 @@ def _truncate_explanation(
"""Truncate given list of strings that makes up the assertion explanation.
Truncates to either 8 lines, or 640 characters - whichever the input reaches
first. The remaining lines will be replaced by a usage message.
first, taking the truncation explanation into account. The remaining lines
will be replaced by a usage message.
"""
if max_lines is None:
max_lines = DEFAULT_MAX_LINES
if max_chars is None:
@@ -48,35 +48,56 @@ def _truncate_explanation(
# Check if truncation required
input_char_count = len("".join(input_lines))
if len(input_lines) <= max_lines and input_char_count <= max_chars:
# The length of the truncation explanation depends on the number of lines
# removed but is at least 68 characters:
# The real value is
# 64 (for the base message:
# '...\n...Full output truncated (1 line hidden), use '-vv' to show")'
# )
# + 1 (for plural)
# + int(math.log10(len(input_lines) - max_lines)) (number of hidden line, at least 1)
# + 3 for the '...' added to the truncated line
# But if there's more than 100 lines it's very likely that we're going to
# truncate, so we don't need the exact value using log10.
tolerable_max_chars = (
max_chars + 70 # 64 + 1 (for plural) + 2 (for '99') + 3 for '...'
)
# The truncation explanation add two lines to the output
tolerable_max_lines = max_lines + 2
if (
len(input_lines) <= tolerable_max_lines
and input_char_count <= tolerable_max_chars
):
return input_lines
# Truncate first to max_lines, and then truncate to max_chars if max_chars
# is exceeded.
# Truncate first to max_lines, and then truncate to max_chars if necessary
truncated_explanation = input_lines[:max_lines]
truncated_explanation = _truncate_by_char_count(truncated_explanation, max_chars)
# Add ellipsis to final line
truncated_explanation[-1] = truncated_explanation[-1] + "..."
# Append useful message to explanation
truncated_line_count = len(input_lines) - len(truncated_explanation)
truncated_line_count += 1 # Account for the part-truncated final line
msg = "...Full output truncated"
if truncated_line_count == 1:
msg += f" ({truncated_line_count} line hidden)"
truncated_char = True
# We reevaluate the need to truncate chars following removal of some lines
if len("".join(truncated_explanation)) > tolerable_max_chars:
truncated_explanation = _truncate_by_char_count(
truncated_explanation, max_chars
)
else:
msg += f" ({truncated_line_count} lines hidden)"
msg += f", {USAGE_MSG}"
truncated_explanation.extend(["", str(msg)])
return truncated_explanation
truncated_char = False
truncated_line_count = len(input_lines) - len(truncated_explanation)
if truncated_explanation[-1]:
# Add ellipsis and take into account part-truncated final line
truncated_explanation[-1] = truncated_explanation[-1] + "..."
if truncated_char:
# It's possible that we did not remove any char from this line
truncated_line_count += 1
else:
# Add proper ellipsis when we were able to fit a full line exactly
truncated_explanation[-1] = "..."
return truncated_explanation + [
"",
f"...Full output truncated ({truncated_line_count} line"
f"{'' if truncated_line_count == 1 else 's'} hidden), {USAGE_MSG}",
]
def _truncate_by_char_count(input_lines: List[str], max_chars: int) -> List[str]:
# Check if truncation required
if len("".join(input_lines)) <= max_chars:
return input_lines
# Find point at which input length exceeds total allowed length
iterated_char_count = 0
for iterated_index, input_line in enumerate(input_lines):

View File

@@ -1,6 +1,7 @@
"""Implementation of the cache provider."""
# This plugin was not named "cache" to avoid conflicts with the external
# pytest-cache version.
import dataclasses
import json
import os
from pathlib import Path
@@ -12,8 +13,6 @@ from typing import Optional
from typing import Set
from typing import Union
import attr
from .pathlib import resolve_from_str
from .pathlib import rm_rf
from .reports import CollectReport
@@ -28,11 +27,10 @@ from _pytest.deprecated import check_ispytest
from _pytest.fixtures import fixture
from _pytest.fixtures import FixtureRequest
from _pytest.main import Session
from _pytest.python import Module
from _pytest.nodes import File
from _pytest.python import Package
from _pytest.reports import TestReport
README_CONTENT = """\
# pytest cache directory #
@@ -53,10 +51,12 @@ Signature: 8a477f597d28d172789f06886806bc55
@final
@attr.s(init=False, auto_attribs=True)
@dataclasses.dataclass
class Cache:
_cachedir: Path = attr.ib(repr=False)
_config: Config = attr.ib(repr=False)
"""Instance of the `cache` fixture."""
_cachedir: Path = dataclasses.field(repr=False)
_config: Config = dataclasses.field(repr=False)
# Sub-directory under cache-dir for directories created by `mkdir()`.
_CACHE_PREFIX_DIRS = "d"
@@ -179,16 +179,22 @@ class Cache:
else:
cache_dir_exists_already = self._cachedir.exists()
path.parent.mkdir(exist_ok=True, parents=True)
except OSError:
self.warn("could not create cache path {path}", path=path, _ispytest=True)
except OSError as exc:
self.warn(
f"could not create cache path {path}: {exc}",
_ispytest=True,
)
return
if not cache_dir_exists_already:
self._ensure_supporting_files()
data = json.dumps(value, ensure_ascii=False, indent=2)
try:
f = path.open("w", encoding="UTF-8")
except OSError:
self.warn("cache could not write path {path}", path=path, _ispytest=True)
except OSError as exc:
self.warn(
f"cache could not write path {path}: {exc}",
_ispytest=True,
)
else:
with f:
f.write(data)
@@ -213,22 +219,30 @@ class LFPluginCollWrapper:
@hookimpl(hookwrapper=True)
def pytest_make_collect_report(self, collector: nodes.Collector):
if isinstance(collector, Session):
if isinstance(collector, (Session, Package)):
out = yield
res: CollectReport = out.get_result()
# Sort any lf-paths to the beginning.
lf_paths = self.lfplugin._last_failed_paths
# Use stable sort to priorize last failed.
def sort_key(node: Union[nodes.Item, nodes.Collector]) -> bool:
# Package.path is the __init__.py file, we need the directory.
if isinstance(node, Package):
path = node.path.parent
else:
path = node.path
return path in lf_paths
res.result = sorted(
res.result,
# use stable sort to priorize last failed
key=lambda x: x.path in lf_paths,
key=sort_key,
reverse=True,
)
return
elif isinstance(collector, Module):
elif isinstance(collector, File):
if collector.path in self.lfplugin._last_failed_paths:
out = yield
res = out.get_result()
@@ -266,10 +280,9 @@ class LFPluginCollSkipfiles:
def pytest_make_collect_report(
self, collector: nodes.Collector
) -> Optional[CollectReport]:
# Packages are Modules, but _last_failed_paths only contains
# test-bearing paths and doesn't try to include the paths of their
# packages, so don't filter them.
if isinstance(collector, Module) and not isinstance(collector, Package):
# Packages are Files, but we only want to skip test-bearing Files,
# so don't filter Packages.
if isinstance(collector, File) and not isinstance(collector, Package):
if collector.path not in self.lfplugin._last_failed_paths:
self.lfplugin._skipped_files += 1
@@ -299,9 +312,14 @@ class LFPlugin:
)
def get_last_failed_paths(self) -> Set[Path]:
"""Return a set with all Paths()s of the previously failed nodeids."""
"""Return a set with all Paths of the previously failed nodeids and
their parents."""
rootpath = self.config.rootpath
result = {rootpath / nodeid.split("::")[0] for nodeid in self.lastfailed}
result = set()
for nodeid in self.lastfailed:
path = rootpath / nodeid.split("::")[0]
result.add(path)
result.update(path.parents)
return {x for x in result if x.exists()}
def pytest_report_collectionfinish(self) -> Optional[str]:
@@ -492,7 +510,7 @@ def pytest_addoption(parser: Parser) -> None:
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
if config.option.cacheshow:
if config.option.cacheshow and not config.option.help:
from _pytest.main import wrap_session
return wrap_session(config, cacheshow)

View File

@@ -1,19 +1,26 @@
"""Per-test stdout/stderr capturing mechanism."""
import abc
import collections
import contextlib
import functools
import io
import os
import sys
from io import UnsupportedOperation
from tempfile import TemporaryFile
from types import TracebackType
from typing import Any
from typing import AnyStr
from typing import BinaryIO
from typing import Generator
from typing import Generic
from typing import Iterable
from typing import Iterator
from typing import List
from typing import NamedTuple
from typing import Optional
from typing import TextIO
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import Union
@@ -29,6 +36,7 @@ from _pytest.nodes import File
from _pytest.nodes import Item
if TYPE_CHECKING:
from typing_extensions import Final
from typing_extensions import Literal
_CaptureMethod = Literal["fd", "sys", "no", "tee-sys"]
@@ -185,19 +193,27 @@ class TeeCaptureIO(CaptureIO):
return self._other.write(s)
class DontReadFromInput:
encoding = None
class DontReadFromInput(TextIO):
@property
def encoding(self) -> str:
return sys.__stdin__.encoding
def read(self, *args):
def read(self, size: int = -1) -> str:
raise OSError(
"pytest: reading from stdin while output is captured! Consider using `-s`."
)
readline = read
readlines = read
__next__ = read
def __iter__(self):
def __next__(self) -> str:
return self.readline()
def readlines(self, hint: Optional[int] = -1) -> List[str]:
raise OSError(
"pytest: reading from stdin while output is captured! Consider using `-s`."
)
def __iter__(self) -> Iterator[str]:
return self
def fileno(self) -> int:
@@ -215,7 +231,7 @@ class DontReadFromInput:
def readable(self) -> bool:
return False
def seek(self, offset: int) -> int:
def seek(self, offset: int, whence: int = 0) -> int:
raise UnsupportedOperation("redirected stdin is pseudofile, has no seek(int)")
def seekable(self) -> bool:
@@ -224,41 +240,104 @@ class DontReadFromInput:
def tell(self) -> int:
raise UnsupportedOperation("redirected stdin is pseudofile, has no tell()")
def truncate(self, size: int) -> None:
raise UnsupportedOperation("cannont truncate stdin")
def truncate(self, size: Optional[int] = None) -> int:
raise UnsupportedOperation("cannot truncate stdin")
def write(self, *args) -> None:
def write(self, data: str) -> int:
raise UnsupportedOperation("cannot write to stdin")
def writelines(self, *args) -> None:
def writelines(self, lines: Iterable[str]) -> None:
raise UnsupportedOperation("Cannot write to stdin")
def writable(self) -> bool:
return False
@property
def buffer(self):
def __enter__(self) -> "DontReadFromInput":
return self
def __exit__(
self,
type: Optional[Type[BaseException]],
value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
pass
@property
def buffer(self) -> BinaryIO:
# The str/bytes doesn't actually matter in this type, so OK to fake.
return self # type: ignore[return-value]
# Capture classes.
class CaptureBase(abc.ABC, Generic[AnyStr]):
EMPTY_BUFFER: AnyStr
@abc.abstractmethod
def __init__(self, fd: int) -> None:
raise NotImplementedError()
@abc.abstractmethod
def start(self) -> None:
raise NotImplementedError()
@abc.abstractmethod
def done(self) -> None:
raise NotImplementedError()
@abc.abstractmethod
def suspend(self) -> None:
raise NotImplementedError()
@abc.abstractmethod
def resume(self) -> None:
raise NotImplementedError()
@abc.abstractmethod
def writeorg(self, data: AnyStr) -> None:
raise NotImplementedError()
@abc.abstractmethod
def snap(self) -> AnyStr:
raise NotImplementedError()
patchsysdict = {0: "stdin", 1: "stdout", 2: "stderr"}
class NoCapture:
EMPTY_BUFFER = None
__init__ = start = done = suspend = resume = lambda *args: None
class NoCapture(CaptureBase[str]):
EMPTY_BUFFER = ""
def __init__(self, fd: int) -> None:
pass
def start(self) -> None:
pass
def done(self) -> None:
pass
def suspend(self) -> None:
pass
def resume(self) -> None:
pass
def snap(self) -> str:
return ""
def writeorg(self, data: str) -> None:
pass
class SysCaptureBinary:
EMPTY_BUFFER = b""
def __init__(self, fd: int, tmpfile=None, *, tee: bool = False) -> None:
class SysCaptureBase(CaptureBase[AnyStr]):
def __init__(
self, fd: int, tmpfile: Optional[TextIO] = None, *, tee: bool = False
) -> None:
name = patchsysdict[fd]
self._old = getattr(sys, name)
self._old: TextIO = getattr(sys, name)
self.name = name
if tmpfile is None:
if name == "stdin":
@@ -298,14 +377,6 @@ class SysCaptureBinary:
setattr(sys, self.name, self.tmpfile)
self._state = "started"
def snap(self):
self._assert_state("snap", ("started", "suspended"))
self.tmpfile.seek(0)
res = self.tmpfile.buffer.read()
self.tmpfile.seek(0)
self.tmpfile.truncate()
return res
def done(self) -> None:
self._assert_state("done", ("initialized", "started", "suspended", "done"))
if self._state == "done":
@@ -327,36 +398,43 @@ class SysCaptureBinary:
setattr(sys, self.name, self.tmpfile)
self._state = "started"
def writeorg(self, data) -> None:
class SysCaptureBinary(SysCaptureBase[bytes]):
EMPTY_BUFFER = b""
def snap(self) -> bytes:
self._assert_state("snap", ("started", "suspended"))
self.tmpfile.seek(0)
res = self.tmpfile.buffer.read()
self.tmpfile.seek(0)
self.tmpfile.truncate()
return res
def writeorg(self, data: bytes) -> None:
self._assert_state("writeorg", ("started", "suspended"))
self._old.flush()
self._old.buffer.write(data)
self._old.buffer.flush()
class SysCapture(SysCaptureBinary):
EMPTY_BUFFER = "" # type: ignore[assignment]
class SysCapture(SysCaptureBase[str]):
EMPTY_BUFFER = ""
def snap(self):
def snap(self) -> str:
self._assert_state("snap", ("started", "suspended"))
assert isinstance(self.tmpfile, CaptureIO)
res = self.tmpfile.getvalue()
self.tmpfile.seek(0)
self.tmpfile.truncate()
return res
def writeorg(self, data):
def writeorg(self, data: str) -> None:
self._assert_state("writeorg", ("started", "suspended"))
self._old.write(data)
self._old.flush()
class FDCaptureBinary:
"""Capture IO to/from a given OS-level file descriptor.
snap() produces `bytes`.
"""
EMPTY_BUFFER = b""
class FDCaptureBase(CaptureBase[AnyStr]):
def __init__(self, targetfd: int) -> None:
self.targetfd = targetfd
@@ -382,7 +460,7 @@ class FDCaptureBinary:
if targetfd == 0:
self.tmpfile = open(os.devnull, encoding="utf-8")
self.syscapture = SysCapture(targetfd)
self.syscapture: CaptureBase[str] = SysCapture(targetfd)
else:
self.tmpfile = EncodedFile(
TemporaryFile(buffering=0),
@@ -394,7 +472,7 @@ class FDCaptureBinary:
if targetfd in patchsysdict:
self.syscapture = SysCapture(targetfd, self.tmpfile)
else:
self.syscapture = NoCapture()
self.syscapture = NoCapture(targetfd)
self._state = "initialized"
@@ -421,14 +499,6 @@ class FDCaptureBinary:
self.syscapture.start()
self._state = "started"
def snap(self):
self._assert_state("snap", ("started", "suspended"))
self.tmpfile.seek(0)
res = self.tmpfile.buffer.read()
self.tmpfile.seek(0)
self.tmpfile.truncate()
return res
def done(self) -> None:
"""Stop capturing, restore streams, return original capture file,
seeked to position zero."""
@@ -461,22 +531,38 @@ class FDCaptureBinary:
os.dup2(self.tmpfile.fileno(), self.targetfd)
self._state = "started"
def writeorg(self, data):
class FDCaptureBinary(FDCaptureBase[bytes]):
"""Capture IO to/from a given OS-level file descriptor.
snap() produces `bytes`.
"""
EMPTY_BUFFER = b""
def snap(self) -> bytes:
self._assert_state("snap", ("started", "suspended"))
self.tmpfile.seek(0)
res = self.tmpfile.buffer.read()
self.tmpfile.seek(0)
self.tmpfile.truncate()
return res
def writeorg(self, data: bytes) -> None:
"""Write to original file descriptor."""
self._assert_state("writeorg", ("started", "suspended"))
os.write(self.targetfd_save, data)
class FDCapture(FDCaptureBinary):
class FDCapture(FDCaptureBase[str]):
"""Capture IO to/from a given OS-level file descriptor.
snap() produces text.
"""
# Ignore type because it doesn't match the type in the superclass (bytes).
EMPTY_BUFFER = "" # type: ignore
EMPTY_BUFFER = ""
def snap(self):
def snap(self) -> str:
self._assert_state("snap", ("started", "suspended"))
self.tmpfile.seek(0)
res = self.tmpfile.read()
@@ -484,77 +570,49 @@ class FDCapture(FDCaptureBinary):
self.tmpfile.truncate()
return res
def writeorg(self, data):
def writeorg(self, data: str) -> None:
"""Write to original file descriptor."""
super().writeorg(data.encode("utf-8")) # XXX use encoding of original stream
self._assert_state("writeorg", ("started", "suspended"))
# XXX use encoding of original stream
os.write(self.targetfd_save, data.encode("utf-8"))
# MultiCapture
# This class was a namedtuple, but due to mypy limitation[0] it could not be
# made generic, so was replaced by a regular class which tries to emulate the
# pertinent parts of a namedtuple. If the mypy limitation is ever lifted, can
# make it a namedtuple again.
# [0]: https://github.com/python/mypy/issues/685
@final
@functools.total_ordering
class CaptureResult(Generic[AnyStr]):
"""The result of :method:`CaptureFixture.readouterr`."""
# Generic NamedTuple only supported since Python 3.11.
if sys.version_info >= (3, 11) or TYPE_CHECKING:
__slots__ = ("out", "err")
@final
class CaptureResult(NamedTuple, Generic[AnyStr]):
"""The result of :method:`CaptureFixture.readouterr`."""
def __init__(self, out: AnyStr, err: AnyStr) -> None:
self.out: AnyStr = out
self.err: AnyStr = err
out: AnyStr
err: AnyStr
def __len__(self) -> int:
return 2
else:
def __iter__(self) -> Iterator[AnyStr]:
return iter((self.out, self.err))
class CaptureResult(
collections.namedtuple("CaptureResult", ["out", "err"]), Generic[AnyStr]
):
"""The result of :method:`CaptureFixture.readouterr`."""
def __getitem__(self, item: int) -> AnyStr:
return tuple(self)[item]
def _replace(
self, *, out: Optional[AnyStr] = None, err: Optional[AnyStr] = None
) -> "CaptureResult[AnyStr]":
return CaptureResult(
out=self.out if out is None else out, err=self.err if err is None else err
)
def count(self, value: AnyStr) -> int:
return tuple(self).count(value)
def index(self, value) -> int:
return tuple(self).index(value)
def __eq__(self, other: object) -> bool:
if not isinstance(other, (CaptureResult, tuple)):
return NotImplemented
return tuple(self) == tuple(other)
def __hash__(self) -> int:
return hash(tuple(self))
def __lt__(self, other: object) -> bool:
if not isinstance(other, (CaptureResult, tuple)):
return NotImplemented
return tuple(self) < tuple(other)
def __repr__(self) -> str:
return f"CaptureResult(out={self.out!r}, err={self.err!r})"
__slots__ = ()
class MultiCapture(Generic[AnyStr]):
_state = None
_in_suspended = False
def __init__(self, in_, out, err) -> None:
self.in_ = in_
self.out = out
self.err = err
def __init__(
self,
in_: Optional[CaptureBase[AnyStr]],
out: Optional[CaptureBase[AnyStr]],
err: Optional[CaptureBase[AnyStr]],
) -> None:
self.in_: Optional[CaptureBase[AnyStr]] = in_
self.out: Optional[CaptureBase[AnyStr]] = out
self.err: Optional[CaptureBase[AnyStr]] = err
def __repr__(self) -> str:
return "<MultiCapture out={!r} err={!r} in_={!r} _state={!r} _in_suspended={!r}>".format(
@@ -578,8 +636,10 @@ class MultiCapture(Generic[AnyStr]):
"""Pop current snapshot out/err capture and flush to orig streams."""
out, err = self.readouterr()
if out:
assert self.out is not None
self.out.writeorg(out)
if err:
assert self.err is not None
self.err.writeorg(err)
return out, err
@@ -600,6 +660,7 @@ class MultiCapture(Generic[AnyStr]):
if self.err:
self.err.resume()
if self._in_suspended:
assert self.in_ is not None
self.in_.resume()
self._in_suspended = False
@@ -622,7 +683,8 @@ class MultiCapture(Generic[AnyStr]):
def readouterr(self) -> CaptureResult[AnyStr]:
out = self.out.snap() if self.out else ""
err = self.err.snap() if self.err else ""
return CaptureResult(out, err)
# TODO: This type error is real, need to fix.
return CaptureResult(out, err) # type: ignore[arg-type]
def _get_multicapture(method: "_CaptureMethod") -> MultiCapture[str]:
@@ -662,7 +724,7 @@ class CaptureManager:
"""
def __init__(self, method: "_CaptureMethod") -> None:
self._method = method
self._method: Final = method
self._global_capturing: Optional[MultiCapture[str]] = None
self._capture_fixture: Optional[CaptureFixture[Any]] = None
@@ -831,14 +893,18 @@ class CaptureFixture(Generic[AnyStr]):
:fixture:`capfd` and :fixture:`capfdbinary` fixtures."""
def __init__(
self, captureclass, request: SubRequest, *, _ispytest: bool = False
self,
captureclass: Type[CaptureBase[AnyStr]],
request: SubRequest,
*,
_ispytest: bool = False,
) -> None:
check_ispytest(_ispytest)
self.captureclass = captureclass
self.captureclass: Type[CaptureBase[AnyStr]] = captureclass
self.request = request
self._capture: Optional[MultiCapture[AnyStr]] = None
self._captured_out = self.captureclass.EMPTY_BUFFER
self._captured_err = self.captureclass.EMPTY_BUFFER
self._captured_out: AnyStr = self.captureclass.EMPTY_BUFFER
self._captured_err: AnyStr = self.captureclass.EMPTY_BUFFER
def _start(self) -> None:
if self._capture is None:
@@ -893,7 +959,9 @@ class CaptureFixture(Generic[AnyStr]):
@contextlib.contextmanager
def disabled(self) -> Generator[None, None, None]:
"""Temporarily disable capturing while inside the ``with`` block."""
capmanager = self.request.config.pluginmanager.getplugin("capturemanager")
capmanager: CaptureManager = self.request.config.pluginmanager.getplugin(
"capturemanager"
)
with capmanager.global_and_fixture_disabled():
yield
@@ -920,8 +988,8 @@ def capsys(request: SubRequest) -> Generator[CaptureFixture[str], None, None]:
captured = capsys.readouterr()
assert captured.out == "hello\n"
"""
capman = request.config.pluginmanager.getplugin("capturemanager")
capture_fixture = CaptureFixture[str](SysCapture, request, _ispytest=True)
capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager")
capture_fixture = CaptureFixture(SysCapture, request, _ispytest=True)
capman.set_fixture(capture_fixture)
capture_fixture._start()
yield capture_fixture
@@ -948,8 +1016,8 @@ def capsysbinary(request: SubRequest) -> Generator[CaptureFixture[bytes], None,
captured = capsysbinary.readouterr()
assert captured.out == b"hello\n"
"""
capman = request.config.pluginmanager.getplugin("capturemanager")
capture_fixture = CaptureFixture[bytes](SysCaptureBinary, request, _ispytest=True)
capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager")
capture_fixture = CaptureFixture(SysCaptureBinary, request, _ispytest=True)
capman.set_fixture(capture_fixture)
capture_fixture._start()
yield capture_fixture
@@ -976,8 +1044,8 @@ def capfd(request: SubRequest) -> Generator[CaptureFixture[str], None, None]:
captured = capfd.readouterr()
assert captured.out == "hello\n"
"""
capman = request.config.pluginmanager.getplugin("capturemanager")
capture_fixture = CaptureFixture[str](FDCapture, request, _ispytest=True)
capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager")
capture_fixture = CaptureFixture(FDCapture, request, _ispytest=True)
capman.set_fixture(capture_fixture)
capture_fixture._start()
yield capture_fixture
@@ -1005,8 +1073,8 @@ def capfdbinary(request: SubRequest) -> Generator[CaptureFixture[bytes], None, N
assert captured.out == b"hello\n"
"""
capman = request.config.pluginmanager.getplugin("capturemanager")
capture_fixture = CaptureFixture[bytes](FDCaptureBinary, request, _ispytest=True)
capman: CaptureManager = request.config.pluginmanager.getplugin("capturemanager")
capture_fixture = CaptureFixture(FDCaptureBinary, request, _ispytest=True)
capman.set_fixture(capture_fixture)
capture_fixture._start()
yield capture_fixture

View File

@@ -1,4 +1,7 @@
"""Python version compatibility code."""
from __future__ import annotations
import dataclasses
import enum
import functools
import inspect
@@ -11,13 +14,9 @@ from typing import Any
from typing import Callable
from typing import Generic
from typing import NoReturn
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
import attr
import py
# fmt: off
@@ -46,7 +45,7 @@ LEGACY_PATH = py.path. local
# fmt: on
def legacy_path(path: Union[str, "os.PathLike[str]"]) -> LEGACY_PATH:
def legacy_path(path: str | os.PathLike[str]) -> LEGACY_PATH:
"""Internal wrapper to prepare lazy proxies for legacy_path instances"""
return LEGACY_PATH(path)
@@ -56,7 +55,7 @@ def legacy_path(path: Union[str, "os.PathLike[str]"]) -> LEGACY_PATH:
# https://www.python.org/dev/peps/pep-0484/#support-for-singleton-types-in-unions
class NotSetType(enum.Enum):
token = 0
NOTSET: "Final" = NotSetType.token # noqa: E305
NOTSET: Final = NotSetType.token # noqa: E305
# fmt: on
if sys.version_info >= (3, 8):
@@ -94,7 +93,7 @@ def is_async_function(func: object) -> bool:
return iscoroutinefunction(func) or inspect.isasyncgenfunction(func)
def getlocation(function, curdir: Optional[str] = None) -> str:
def getlocation(function, curdir: str | None = None) -> str:
function = get_real_func(function)
fn = Path(inspect.getfile(function))
lineno = function.__code__.co_firstlineno
@@ -132,8 +131,8 @@ def getfuncargnames(
*,
name: str = "",
is_method: bool = False,
cls: Optional[type] = None,
) -> Tuple[str, ...]:
cls: type | None = None,
) -> tuple[str, ...]:
"""Return the names of a function's mandatory arguments.
Should return the names of all function arguments that:
@@ -197,7 +196,7 @@ def getfuncargnames(
return arg_names
def get_default_arg_names(function: Callable[..., Any]) -> Tuple[str, ...]:
def get_default_arg_names(function: Callable[..., Any]) -> tuple[str, ...]:
# Note: this code intentionally mirrors the code at the beginning of
# getfuncargnames, to get the arguments which were excluded from its result
# because they had default values.
@@ -228,7 +227,7 @@ def _bytes_to_ascii(val: bytes) -> str:
return val.decode("ascii", "backslashreplace")
def ascii_escaped(val: Union[bytes, str]) -> str:
def ascii_escaped(val: bytes | str) -> str:
r"""If val is pure ASCII, return it as an str, otherwise, escape
bytes objects into a sequence of escaped bytes:
@@ -252,7 +251,7 @@ def ascii_escaped(val: Union[bytes, str]) -> str:
return _translate_non_printable(ret)
@attr.s
@dataclasses.dataclass
class _PytestWrapper:
"""Dummy wrapper around a function object for internal use only.
@@ -261,7 +260,7 @@ class _PytestWrapper:
decorator to issue warnings when the fixture function is called directly.
"""
obj = attr.ib()
obj: Any
def get_real_func(obj):
@@ -355,7 +354,6 @@ else:
if sys.version_info >= (3, 8):
from functools import cached_property as cached_property
else:
from typing import Type
class cached_property(Generic[_S, _T]):
__slots__ = ("func", "__doc__")
@@ -366,12 +364,12 @@ else:
@overload
def __get__(
self, instance: None, owner: Optional[Type[_S]] = ...
) -> "cached_property[_S, _T]":
self, instance: None, owner: type[_S] | None = ...
) -> cached_property[_S, _T]:
...
@overload
def __get__(self, instance: _S, owner: Optional[Type[_S]] = ...) -> _T:
def __get__(self, instance: _S, owner: type[_S] | None = ...) -> _T:
...
def __get__(self, instance, owner=None):
@@ -381,6 +379,18 @@ else:
return value
def get_user_id() -> int | None:
"""Return the current user id, or None if we cannot get it reliably on the current platform."""
# win32 does not have a getuid() function.
# On Emscripten, getuid() is a stub that always returns 0.
if sys.platform in ("win32", "emscripten"):
return None
# getuid shouldn't fail, but cpython defines such a case.
# Let's hope for the best.
uid = os.getuid()
return uid if uid != -1 else None
# Perform exhaustiveness checking.
#
# Consider this example:

View File

@@ -2,6 +2,7 @@
import argparse
import collections.abc
import copy
import dataclasses
import enum
import glob
import inspect
@@ -34,7 +35,6 @@ from typing import Type
from typing import TYPE_CHECKING
from typing import Union
import attr
from pluggy import HookimplMarker
from pluggy import HookspecMarker
from pluggy import PluginManager
@@ -49,7 +49,7 @@ from _pytest._code import ExceptionInfo
from _pytest._code import filter_traceback
from _pytest._io import TerminalWriter
from _pytest.compat import final
from _pytest.compat import importlib_metadata
from _pytest.compat import importlib_metadata # type: ignore[attr-defined]
from _pytest.outcomes import fail
from _pytest.outcomes import Skipped
from _pytest.pathlib import absolutepath
@@ -62,7 +62,6 @@ from _pytest.warning_types import PytestConfigWarning
from _pytest.warning_types import warn_explicit_for
if TYPE_CHECKING:
from _pytest._code.code import _TracebackStyle
from _pytest.terminal import TerminalReporter
from .argparsing import Argument
@@ -527,7 +526,13 @@ class PytestPluginManager(PluginManager):
# Internal API for local conftest plugin handling.
#
def _set_initial_conftests(
self, namespace: argparse.Namespace, rootpath: Path
self,
args: Sequence[Union[str, Path]],
pyargs: bool,
noconftest: bool,
rootpath: Path,
confcutdir: Optional[Path],
importmode: Union[ImportMode, str],
) -> None:
"""Load initial conftest files given a preparsed "namespace".
@@ -537,27 +542,29 @@ class PytestPluginManager(PluginManager):
common options will not confuse our logic here.
"""
current = Path.cwd()
self._confcutdir = (
absolutepath(current / namespace.confcutdir)
if namespace.confcutdir
else None
)
self._noconftest = namespace.noconftest
self._using_pyargs = namespace.pyargs
testpaths = namespace.file_or_dir
self._confcutdir = absolutepath(current / confcutdir) if confcutdir else None
self._noconftest = noconftest
self._using_pyargs = pyargs
foundanchor = False
for testpath in testpaths:
path = str(testpath)
for intitial_path in args:
path = str(intitial_path)
# remove node-id syntax
i = path.find("::")
if i != -1:
path = path[:i]
anchor = absolutepath(current / path)
if anchor.exists(): # we found some file object
self._try_load_conftest(anchor, namespace.importmode, rootpath)
# Ensure we do not break if what appears to be an anchor
# is in fact a very long option (#10169).
try:
anchor_exists = anchor.exists()
except OSError: # pragma: no cover
anchor_exists = False
if anchor_exists:
self._try_load_conftest(anchor, importmode, rootpath)
foundanchor = True
if not foundanchor:
self._try_load_conftest(current, namespace.importmode, rootpath)
self._try_load_conftest(current, importmode, rootpath)
def _is_in_confcutdir(self, path: Path) -> bool:
"""Whether a path is within the confcutdir.
@@ -697,6 +704,7 @@ class PytestPluginManager(PluginManager):
parg = opt[2:]
else:
continue
parg = parg.strip()
if exclude_only and not parg.startswith("no:"):
continue
self.consider_pluginarg(parg)
@@ -886,10 +894,6 @@ def _iter_rewritable_modules(package_files: Iterable[str]) -> Iterator[str]:
yield from _iter_rewritable_modules(new_package_files)
def _args_converter(args: Iterable[str]) -> Tuple[str, ...]:
return tuple(args)
@final
class Config:
"""Access to configuration values, pluginmanager and plugin hooks.
@@ -903,7 +907,7 @@ class Config:
"""
@final
@attr.s(frozen=True, auto_attribs=True)
@dataclasses.dataclass(frozen=True)
class InvocationParams:
"""Holds parameters passed during :func:`pytest.main`.
@@ -919,13 +923,24 @@ class Config:
Plugins accessing ``InvocationParams`` must be aware of that.
"""
args: Tuple[str, ...] = attr.ib(converter=_args_converter)
args: Tuple[str, ...]
"""The command-line arguments as passed to :func:`pytest.main`."""
plugins: Optional[Sequence[Union[str, _PluggyPlugin]]]
"""Extra plugins, might be `None`."""
dir: Path
"""The directory from which :func:`pytest.main` was invoked."""
def __init__(
self,
*,
args: Iterable[str],
plugins: Optional[Sequence[Union[str, _PluggyPlugin]]],
dir: Path,
) -> None:
object.__setattr__(self, "args", tuple(args))
object.__setattr__(self, "plugins", plugins)
object.__setattr__(self, "dir", dir)
class ArgsSource(enum.Enum):
"""Indicates the source of the test arguments.
@@ -998,6 +1013,8 @@ class Config:
self.hook.pytest_addoption.call_historic(
kwargs=dict(parser=self._parser, pluginmanager=self.pluginmanager)
)
self.args_source = Config.ArgsSource.ARGS
self.args: List[str] = []
if TYPE_CHECKING:
from _pytest.cacheprovider import Cache
@@ -1057,7 +1074,6 @@ class Config:
try:
self.parse(args)
except UsageError:
# Handle --version and --help here in a minimal fashion.
# This gets done via helpconfig normally, but its
# pytest_cmdline_main is not called in case of errors.
@@ -1122,8 +1138,25 @@ class Config:
@hookimpl(trylast=True)
def pytest_load_initial_conftests(self, early_config: "Config") -> None:
# We haven't fully parsed the command line arguments yet, so
# early_config.args it not set yet. But we need it for
# discovering the initial conftests. So "pre-run" the logic here.
# It will be done for real in `parse()`.
args, args_source = early_config._decide_args(
args=early_config.known_args_namespace.file_or_dir,
pyargs=early_config.known_args_namespace.pyargs,
testpaths=early_config.getini("testpaths"),
invocation_dir=early_config.invocation_params.dir,
rootpath=early_config.rootpath,
warn=False,
)
self.pluginmanager._set_initial_conftests(
early_config.known_args_namespace, rootpath=early_config.rootpath
args=args,
pyargs=early_config.known_args_namespace.pyargs,
noconftest=early_config.known_args_namespace.noconftest,
rootpath=early_config.rootpath,
confcutdir=early_config.known_args_namespace.confcutdir,
importmode=early_config.known_args_namespace.importmode,
)
def _initini(self, args: Sequence[str]) -> None:
@@ -1203,6 +1236,49 @@ class Config:
return args
def _decide_args(
self,
*,
args: List[str],
pyargs: List[str],
testpaths: List[str],
invocation_dir: Path,
rootpath: Path,
warn: bool,
) -> Tuple[List[str], ArgsSource]:
"""Decide the args (initial paths/nodeids) to use given the relevant inputs.
:param warn: Whether can issue warnings.
"""
if args:
source = Config.ArgsSource.ARGS
result = args
else:
if invocation_dir == rootpath:
source = Config.ArgsSource.TESTPATHS
if pyargs:
result = testpaths
else:
result = []
for path in testpaths:
result.extend(sorted(glob.iglob(path, recursive=True)))
if testpaths and not result:
if warn:
warning_text = (
"No files were found in testpaths; "
"consider removing or adjusting your testpaths configuration. "
"Searching recursively from the current directory instead."
)
self.issue_config_time_warning(
PytestConfigWarning(warning_text), stacklevel=3
)
else:
result = []
if not result:
source = Config.ArgsSource.INCOVATION_DIR
result = [str(invocation_dir)]
return result, source
def _preparse(self, args: List[str], addopts: bool = True) -> None:
if addopts:
env_addopts = os.environ.get("PYTEST_ADDOPTS", "")
@@ -1241,8 +1317,11 @@ class Config:
_pytest.deprecated.STRICT_OPTION, stacklevel=2
)
if self.known_args_namespace.confcutdir is None and self.inipath is not None:
confcutdir = str(self.inipath.parent)
if self.known_args_namespace.confcutdir is None:
if self.inipath is not None:
confcutdir = str(self.inipath.parent)
else:
confcutdir = str(self.rootpath)
self.known_args_namespace.confcutdir = confcutdir
try:
self.hook.pytest_load_initial_conftests(
@@ -1337,8 +1416,8 @@ class Config:
def parse(self, args: List[str], addopts: bool = True) -> None:
# Parse given cmdline arguments into this config object.
assert not hasattr(
self, "args"
assert (
self.args == []
), "can only parse cmdline args at most once per Config object"
self.hook.pytest_addhooks.call_historic(
kwargs=dict(pluginmanager=self.pluginmanager)
@@ -1348,25 +1427,17 @@ class Config:
self.hook.pytest_cmdline_preparse(config=self, args=args)
self._parser.after_preparse = True # type: ignore
try:
source = Config.ArgsSource.ARGS
args = self._parser.parse_setoption(
args, self.option, namespace=self.option
)
if not args:
if self.invocation_params.dir == self.rootpath:
source = Config.ArgsSource.TESTPATHS
testpaths: List[str] = self.getini("testpaths")
if self.known_args_namespace.pyargs:
args = testpaths
else:
args = []
for path in testpaths:
args.extend(sorted(glob.iglob(path, recursive=True)))
if not args:
source = Config.ArgsSource.INCOVATION_DIR
args = [str(self.invocation_params.dir)]
self.args = args
self.args_source = source
self.args, self.args_source = self._decide_args(
args=args,
pyargs=self.known_args_namespace.pyargs,
testpaths=self.getini("testpaths"),
invocation_dir=self.invocation_params.dir,
rootpath=self.rootpath,
warn=True,
)
except PrintHelp:
pass

View File

@@ -43,7 +43,6 @@ class PathAwareHookProxy:
@_wraps(hook)
def fixed_hook(**kw):
path_value: Optional[Path] = kw.pop(path_var, None)
fspath_value: Optional[LEGACY_PATH] = kw.pop(fspath_var, None)
if fspath_value is not None:

View File

@@ -203,8 +203,7 @@ def determine_setup(
else:
cwd = Path.cwd()
rootdir = get_common_ancestor([cwd, ancestor])
is_fs_root = os.path.splitdrive(str(rootdir))[1] == "/"
if is_fs_root:
if is_fs_root(rootdir):
rootdir = ancestor
if rootdir_cmd_arg:
rootdir = absolutepath(os.path.expandvars(rootdir_cmd_arg))
@@ -216,3 +215,11 @@ def determine_setup(
)
assert rootdir is not None
return rootdir, inipath, inicfg or {}
def is_fs_root(p: Path) -> bool:
r"""
Return True if the given path is pointing to the root of the
file system ("/" on Unix and "C:\\" on Windows for example).
"""
return os.path.splitdrive(str(p))[1] == os.sep

View File

@@ -3,6 +3,7 @@ import argparse
import functools
import sys
import types
import unittest
from typing import Any
from typing import Callable
from typing import Generator
@@ -293,7 +294,9 @@ class PdbInvoke:
sys.stdout.write(out)
sys.stdout.write(err)
assert call.excinfo is not None
_enter_pdb(node, call.excinfo, report)
if not isinstance(call.excinfo.value, unittest.SkipTest):
_enter_pdb(node, call.excinfo, report)
def pytest_internalerror(self, excinfo: ExceptionInfo[BaseException]) -> None:
tb = _postmortem_traceback(excinfo)

View File

@@ -531,7 +531,6 @@ class DoctestModule(Module):
if _is_mocked(obj):
return
with _patch_unwrap_mock_aware():
# Type ignored because this is a private function.
super()._find( # type:ignore[misc]
tests, obj, name, module, source_lines, globs, seen

View File

@@ -2,7 +2,6 @@ import io
import os
import sys
from typing import Generator
from typing import TextIO
import pytest
from _pytest.config import Config
@@ -11,7 +10,7 @@ from _pytest.nodes import Item
from _pytest.stash import StashKey
fault_handler_stderr_key = StashKey[TextIO]()
fault_handler_stderr_fd_key = StashKey[int]()
fault_handler_originally_enabled_key = StashKey[bool]()
@@ -26,10 +25,9 @@ def pytest_addoption(parser: Parser) -> None:
def pytest_configure(config: Config) -> None:
import faulthandler
stderr_fd_copy = os.dup(get_stderr_fileno())
config.stash[fault_handler_stderr_key] = open(stderr_fd_copy, "w")
config.stash[fault_handler_stderr_fd_key] = os.dup(get_stderr_fileno())
config.stash[fault_handler_originally_enabled_key] = faulthandler.is_enabled()
faulthandler.enable(file=config.stash[fault_handler_stderr_key])
faulthandler.enable(file=config.stash[fault_handler_stderr_fd_key])
def pytest_unconfigure(config: Config) -> None:
@@ -37,9 +35,9 @@ def pytest_unconfigure(config: Config) -> None:
faulthandler.disable()
# Close the dup file installed during pytest_configure.
if fault_handler_stderr_key in config.stash:
config.stash[fault_handler_stderr_key].close()
del config.stash[fault_handler_stderr_key]
if fault_handler_stderr_fd_key in config.stash:
os.close(config.stash[fault_handler_stderr_fd_key])
del config.stash[fault_handler_stderr_fd_key]
if config.stash.get(fault_handler_originally_enabled_key, False):
# Re-enable the faulthandler if it was originally enabled.
faulthandler.enable(file=get_stderr_fileno())
@@ -67,10 +65,10 @@ def get_timeout_config_value(config: Config) -> float:
@pytest.hookimpl(hookwrapper=True, trylast=True)
def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]:
timeout = get_timeout_config_value(item.config)
stderr = item.config.stash[fault_handler_stderr_key]
if timeout > 0 and stderr is not None:
if timeout > 0:
import faulthandler
stderr = item.config.stash[fault_handler_stderr_fd_key]
faulthandler.dump_traceback_later(timeout, file=stderr)
try:
yield

View File

@@ -1,3 +1,4 @@
import dataclasses
import functools
import inspect
import os
@@ -28,8 +29,6 @@ from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
import attr
import _pytest
from _pytest import nodes
from _pytest._code import getfslineno
@@ -47,6 +46,7 @@ from _pytest.compat import getimfunc
from _pytest.compat import getlocation
from _pytest.compat import is_generator
from _pytest.compat import NOTSET
from _pytest.compat import NotSetType
from _pytest.compat import overload
from _pytest.compat import safe_getattr
from _pytest.config import _PluggyPlugin
@@ -59,6 +59,7 @@ from _pytest.mark import Mark
from _pytest.mark import ParameterSet
from _pytest.mark.structures import MarkDecorator
from _pytest.outcomes import fail
from _pytest.outcomes import skip
from _pytest.outcomes import TEST_OUTCOME
from _pytest.pathlib import absolutepath
from _pytest.pathlib import bestrelpath
@@ -103,7 +104,7 @@ _FixtureCachedResult = Union[
]
@attr.s(frozen=True, auto_attribs=True)
@dataclasses.dataclass(frozen=True)
class PseudoFixtureDef(Generic[FixtureValue]):
cached_result: "_FixtureCachedResult[FixtureValue]"
_scope: Scope
@@ -113,16 +114,18 @@ def pytest_sessionstart(session: "Session") -> None:
session._fixturemanager = FixtureManager(session)
def get_scope_package(node, fixturedef: "FixtureDef[object]"):
import pytest
def get_scope_package(
node: nodes.Item,
fixturedef: "FixtureDef[object]",
) -> Optional[Union[nodes.Item, nodes.Collector]]:
from _pytest.python import Package
cls = pytest.Package
current = node
current: Optional[Union[nodes.Item, nodes.Collector]] = node
fixture_package_name = "{}/{}".format(fixturedef.baseid, "__init__.py")
while current and (
type(current) is not cls or fixture_package_name != current.nodeid
not isinstance(current, Package) or fixture_package_name != current.nodeid
):
current = current.parent
current = current.parent # type: ignore[assignment]
if current is None:
return node.session
return current
@@ -350,8 +353,10 @@ def get_direct_param_fixture_func(request: "FixtureRequest") -> Any:
return request.param
@attr.s(slots=True, auto_attribs=True)
@dataclasses.dataclass
class FuncFixtureInfo:
__slots__ = ("argnames", "initialnames", "names_closure", "name2fixturedefs")
# Original function argument names.
argnames: Tuple[str, ...]
# Argnames that function immediately requires. These include argnames +
@@ -433,7 +438,23 @@ class FixtureRequest:
@property
def node(self):
"""Underlying collection node (depends on current request scope)."""
return self._getscopeitem(self._scope)
scope = self._scope
if scope is Scope.Function:
# This might also be a non-function Item despite its attribute name.
node: Optional[Union[nodes.Item, nodes.Collector]] = self._pyfuncitem
elif scope is Scope.Package:
# FIXME: _fixturedef is not defined on FixtureRequest (this class),
# but on FixtureRequest (a subclass).
node = get_scope_package(self._pyfuncitem, self._fixturedef) # type: ignore[attr-defined]
else:
node = get_scope_node(self._pyfuncitem, scope)
if node is None and scope is Scope.Class:
# Fallback to function item itself.
node = self._pyfuncitem
assert node, 'Could not obtain a node for scope "{}" for function {!r}'.format(
scope, self._pyfuncitem
)
return node
def _getnextfixturedef(self, argname: str) -> "FixtureDef[Any]":
fixturedefs = self._arg2fixturedefs.get(argname, None)
@@ -517,11 +538,7 @@ class FixtureRequest:
"""Add finalizer/teardown function to be called without arguments after
the last test within the requesting test context finished execution."""
# XXX usually this method is shadowed by fixturedef specific ones.
self._addfinalizer(finalizer, scope=self.scope)
def _addfinalizer(self, finalizer: Callable[[], object], scope) -> None:
node = self._getscopeitem(scope)
node.addfinalizer(finalizer)
self.node.addfinalizer(finalizer)
def applymarker(self, marker: Union[str, MarkDecorator]) -> None:
"""Apply a marker to a single test function invocation.
@@ -716,28 +733,6 @@ class FixtureRequest:
lines.append("%s:%d: def %s%s" % (p, lineno + 1, factory.__name__, args))
return lines
def _getscopeitem(
self, scope: Union[Scope, "_ScopeName"]
) -> Union[nodes.Item, nodes.Collector]:
if isinstance(scope, str):
scope = Scope(scope)
if scope is Scope.Function:
# This might also be a non-function Item despite its attribute name.
node: Optional[Union[nodes.Item, nodes.Collector]] = self._pyfuncitem
elif scope is Scope.Package:
# FIXME: _fixturedef is not defined on FixtureRequest (this class),
# but on FixtureRequest (a subclass).
node = get_scope_package(self._pyfuncitem, self._fixturedef) # type: ignore[attr-defined]
else:
node = get_scope_node(self._pyfuncitem, scope)
if node is None and scope is Scope.Class:
# Fallback to function item itself.
node = self._pyfuncitem
assert node, 'Could not obtain a node for scope "{}" for function {!r}'.format(
scope, self._pyfuncitem
)
return node
def __repr__(self) -> str:
return "<FixtureRequest for %r>" % (self.node)
@@ -1130,6 +1125,10 @@ def pytest_fixture_setup(
except TEST_OUTCOME:
exc_info = sys.exc_info()
assert exc_info[0] is not None
if isinstance(
exc_info[1], skip.Exception
) and not fixturefunc.__name__.startswith("xunit_setup"):
exc_info[1]._use_item_location = True # type: ignore[attr-defined]
fixturedef.cached_result = (None, my_cache_key, exc_info)
raise
fixturedef.cached_result = (result, my_cache_key, None)
@@ -1177,19 +1176,21 @@ def wrap_function_to_error_out_if_called_directly(
@final
@attr.s(frozen=True, auto_attribs=True)
@dataclasses.dataclass(frozen=True)
class FixtureFunctionMarker:
scope: "Union[_ScopeName, Callable[[str, Config], _ScopeName]]"
params: Optional[Tuple[object, ...]] = attr.ib(converter=_params_converter)
params: Optional[Tuple[object, ...]]
autouse: bool = False
ids: Optional[
Union[Tuple[Optional[object], ...], Callable[[Any], Optional[object]]]
] = attr.ib(
default=None,
converter=_ensure_immutable_ids,
)
] = None
name: Optional[str] = None
_ispytest: dataclasses.InitVar[bool] = False
def __post_init__(self, _ispytest: bool) -> None:
check_ispytest(_ispytest)
def __call__(self, function: FixtureFunction) -> FixtureFunction:
if inspect.isclass(function):
raise ValueError("class fixtures not supported (maybe in the future)")
@@ -1312,10 +1313,11 @@ def fixture( # noqa: F811
"""
fixture_marker = FixtureFunctionMarker(
scope=scope,
params=params,
params=tuple(params) if params is not None else None,
autouse=autouse,
ids=ids,
ids=None if ids is None else ids if callable(ids) else tuple(ids),
name=name,
_ispytest=True,
)
# Direct decoration.
@@ -1588,13 +1590,52 @@ class FixtureManager:
# Separate parametrized setups.
items[:] = reorder_items(items)
@overload
def parsefactories(
self, node_or_obj, nodeid=NOTSET, unittest: bool = False
self,
node_or_obj: nodes.Node,
*,
unittest: bool = ...,
) -> None:
raise NotImplementedError()
@overload
def parsefactories( # noqa: F811
self,
node_or_obj: object,
nodeid: Optional[str],
*,
unittest: bool = ...,
) -> None:
raise NotImplementedError()
def parsefactories( # noqa: F811
self,
node_or_obj: Union[nodes.Node, object],
nodeid: Union[str, NotSetType, None] = NOTSET,
*,
unittest: bool = False,
) -> None:
"""Collect fixtures from a collection node or object.
Found fixtures are parsed into `FixtureDef`s and saved.
If `node_or_object` is a collection node (with an underlying Python
object), the node's object is traversed and the node's nodeid is used to
determine the fixtures' visibilty. `nodeid` must not be specified in
this case.
If `node_or_object` is an object (e.g. a plugin), the object is
traversed and the given `nodeid` is used to determine the fixtures'
visibility. `nodeid` must be specified in this case; None and "" mean
total visibility.
"""
if nodeid is not NOTSET:
holderobj = node_or_obj
else:
holderobj = node_or_obj.obj
assert isinstance(node_or_obj, nodes.Node)
holderobj = cast(object, node_or_obj.obj) # type: ignore[attr-defined]
assert isinstance(node_or_obj.nodeid, str)
nodeid = node_or_obj.nodeid
if holderobj in self._holderobjseen:
return

View File

@@ -105,7 +105,7 @@ def pytest_cmdline_parse():
if config.option.debug:
# --debug | --debug <file.log> was provided.
path = config.option.debug
debugfile = open(path, "w")
debugfile = open(path, "w", encoding="utf-8")
debugfile.write(
"versions pytest-%s, "
"python-%s\ncwd=%s\nargs=%s\n\n"
@@ -164,7 +164,8 @@ def showhelp(config: Config) -> None:
tw.write(config._parser.optparser.format_help())
tw.line()
tw.line(
"[pytest] ini-options in the first pytest.ini|tox.ini|setup.cfg file found:"
"[pytest] ini-options in the first "
"pytest.ini|tox.ini|setup.cfg|pyproject.toml file found:"
)
tw.line()

View File

@@ -21,7 +21,7 @@ if TYPE_CHECKING:
from typing_extensions import Literal
from _pytest._code.code import ExceptionRepr
from _pytest.code import ExceptionInfo
from _pytest._code.code import ExceptionInfo
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config import PytestPluginManager
@@ -41,6 +41,7 @@ if TYPE_CHECKING:
from _pytest.reports import TestReport
from _pytest.runner import CallInfo
from _pytest.terminal import TerminalReporter
from _pytest.terminal import TestShortLogReport
from _pytest.compat import LEGACY_PATH
@@ -505,7 +506,9 @@ def pytest_runtest_logstart(
See :hook:`pytest_runtest_protocol` for a description of the runtest protocol.
:param nodeid: Full node ID of the item.
:param location: A tuple of ``(filename, lineno, testname)``.
:param location: A tuple of ``(filename, lineno, testname)``
where ``filename`` is a file path relative to ``config.rootpath``
and ``lineno`` is 0-based.
"""
@@ -517,7 +520,9 @@ def pytest_runtest_logfinish(
See :hook:`pytest_runtest_protocol` for a description of the runtest protocol.
:param nodeid: Full node ID of the item.
:param location: A tuple of ``(filename, lineno, testname)``.
:param location: A tuple of ``(filename, lineno, testname)``
where ``filename`` is a file path relative to ``config.rootpath``
and ``lineno`` is 0-based.
"""
@@ -738,7 +743,7 @@ def pytest_assertion_pass(item: "Item", lineno: int, orig: str, expl: str) -> No
# -------------------------------------------------------------------------
def pytest_report_header(
def pytest_report_header( # type:ignore[empty-body]
config: "Config", start_path: Path, startdir: "LEGACY_PATH"
) -> Union[str, List[str]]:
"""Return a string or list of strings to be displayed as header info for terminal reporting.
@@ -767,7 +772,7 @@ def pytest_report_header(
"""
def pytest_report_collectionfinish(
def pytest_report_collectionfinish( # type:ignore[empty-body]
config: "Config",
start_path: Path,
startdir: "LEGACY_PATH",
@@ -800,9 +805,9 @@ def pytest_report_collectionfinish(
@hookspec(firstresult=True)
def pytest_report_teststatus(
def pytest_report_teststatus( # type:ignore[empty-body]
report: Union["CollectReport", "TestReport"], config: "Config"
) -> Tuple[str, str, Union[str, Mapping[str, bool]]]:
) -> "TestShortLogReport | Tuple[str, str, Union[str, Tuple[str, Mapping[str, bool]]]]":
"""Return result-category, shortletter and verbose word for status
reporting.
@@ -880,7 +885,9 @@ def pytest_warning_recorded(
# -------------------------------------------------------------------------
def pytest_markeval_namespace(config: "Config") -> Dict[str, Any]:
def pytest_markeval_namespace( # type:ignore[empty-body]
config: "Config",
) -> Dict[str, Any]:
"""Called when constructing the globals dictionary used for
evaluating string conditions in xfail/skipif markers.

View File

@@ -645,8 +645,8 @@ class LogXML:
def pytest_sessionfinish(self) -> None:
dirname = os.path.dirname(os.path.abspath(self.logfile))
if not os.path.isdir(dirname):
os.makedirs(dirname)
# exist_ok avoids filesystem race conditions between checking path existence and requesting creation
os.makedirs(dirname, exist_ok=True)
with open(self.logfile, "w", encoding="utf-8") as logfile:
suite_stop_time = timing.time()

View File

@@ -1,4 +1,5 @@
"""Add backward compatibility support for the legacy py path type."""
import dataclasses
import shlex
import subprocess
from pathlib import Path
@@ -7,7 +8,6 @@ from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
import attr
from iniconfig import SectionWrapper
from _pytest.cacheprovider import Cache
@@ -268,7 +268,7 @@ class LegacyTestdirPlugin:
@final
@attr.s(init=False, auto_attribs=True)
@dataclasses.dataclass
class TempdirFactory:
"""Backward compatibility wrapper that implements :class:`py.path.local`
for :class:`TempPathFactory`.

View File

@@ -5,7 +5,11 @@ import os
import re
from contextlib import contextmanager
from contextlib import nullcontext
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from io import StringIO
from logging import LogRecord
from pathlib import Path
from typing import AbstractSet
from typing import Dict
@@ -53,7 +57,25 @@ def _remove_ansi_escape_sequences(text: str) -> str:
return _ANSI_ESCAPE_SEQ.sub("", text)
class ColoredLevelFormatter(logging.Formatter):
class DatetimeFormatter(logging.Formatter):
"""A logging formatter which formats record with
:func:`datetime.datetime.strftime` formatter instead of
:func:`time.strftime` in case of microseconds in format string.
"""
def formatTime(self, record: LogRecord, datefmt=None) -> str:
if datefmt and "%f" in datefmt:
ct = self.converter(record.created)
tz = timezone(timedelta(seconds=ct.tm_gmtoff), ct.tm_zone)
# Construct `datetime.datetime` object from `struct_time`
# and msecs information from `record`
dt = datetime(*ct[0:6], microsecond=round(record.msecs * 1000), tzinfo=tz)
return dt.strftime(datefmt)
# Use `logging.Formatter` for non-microsecond formats
return super().formatTime(record, datefmt)
class ColoredLevelFormatter(DatetimeFormatter):
"""A logging formatter which colorizes the %(levelname)..s part of the
log format passed to __init__."""
@@ -297,6 +319,13 @@ def pytest_addoption(parser: Parser) -> None:
default=None,
help="Auto-indent multiline messages passed to the logging module. Accepts true|on, false|off or an integer.",
)
group.addoption(
"--log-disable",
action="append",
default=[],
dest="logger_disable",
help="Disable a logger by name. Can be passed multiple times.",
)
_HandlerType = TypeVar("_HandlerType", bound=logging.Handler)
@@ -369,11 +398,12 @@ class LogCaptureFixture:
self._initial_handler_level: Optional[int] = None
# Dict of log name -> log level.
self._initial_logger_levels: Dict[Optional[str], int] = {}
self._initial_disabled_logging_level: Optional[int] = None
def _finalize(self) -> None:
"""Finalize the fixture.
This restores the log levels changed by :meth:`set_level`.
This restores the log levels and the disabled logging levels changed by :meth:`set_level`.
"""
# Restore log levels.
if self._initial_handler_level is not None:
@@ -381,6 +411,10 @@ class LogCaptureFixture:
for logger_name, level in self._initial_logger_levels.items():
logger = logging.getLogger(logger_name)
logger.setLevel(level)
# Disable logging at the original disabled logging level.
if self._initial_disabled_logging_level is not None:
logging.disable(self._initial_disabled_logging_level)
self._initial_disabled_logging_level = None
@property
def handler(self) -> LogCaptureHandler:
@@ -446,13 +480,51 @@ class LogCaptureFixture:
"""Reset the list of log records and the captured log text."""
self.handler.clear()
def _force_enable_logging(
self, level: Union[int, str], logger_obj: logging.Logger
) -> int:
"""Enable the desired logging level if the global level was disabled via ``logging.disabled``.
Only enables logging levels greater than or equal to the requested ``level``.
Does nothing if the desired ``level`` wasn't disabled.
:param level:
The logger level caplog should capture.
All logging is enabled if a non-standard logging level string is supplied.
Valid level strings are in :data:`logging._nameToLevel`.
:param logger_obj: The logger object to check.
:return: The original disabled logging level.
"""
original_disable_level: int = logger_obj.manager.disable # type: ignore[attr-defined]
if isinstance(level, str):
# Try to translate the level string to an int for `logging.disable()`
level = logging.getLevelName(level)
if not isinstance(level, int):
# The level provided was not valid, so just un-disable all logging.
logging.disable(logging.NOTSET)
elif not logger_obj.isEnabledFor(level):
# Each level is `10` away from other levels.
# https://docs.python.org/3/library/logging.html#logging-levels
disable_level = max(level - 10, logging.NOTSET)
logging.disable(disable_level)
return original_disable_level
def set_level(self, level: Union[int, str], logger: Optional[str] = None) -> None:
"""Set the level of a logger for the duration of a test.
"""Set the threshold level of a logger for the duration of a test.
Logging messages which are less severe than this level will not be captured.
.. versionchanged:: 3.4
The levels of the loggers changed by this function will be
restored to their initial values at the end of the test.
Will enable the requested logging level if it was disabled via :meth:`logging.disable`.
:param level: The level.
:param logger: The logger to update. If not given, the root logger.
"""
@@ -463,6 +535,9 @@ class LogCaptureFixture:
if self._initial_handler_level is None:
self._initial_handler_level = self.handler.level
self.handler.setLevel(level)
initial_disabled_logging_level = self._force_enable_logging(level, logger_obj)
if self._initial_disabled_logging_level is None:
self._initial_disabled_logging_level = initial_disabled_logging_level
@contextmanager
def at_level(
@@ -472,6 +547,8 @@ class LogCaptureFixture:
the end of the 'with' statement the level is restored to its original
value.
Will enable the requested logging level if it was disabled via :meth:`logging.disable`.
:param level: The level.
:param logger: The logger to update. If not given, the root logger.
"""
@@ -480,11 +557,13 @@ class LogCaptureFixture:
logger_obj.setLevel(level)
handler_orig_level = self.handler.level
self.handler.setLevel(level)
original_disable_level = self._force_enable_logging(level, logger_obj)
try:
yield
finally:
logger_obj.setLevel(orig_level)
self.handler.setLevel(handler_orig_level)
logging.disable(original_disable_level)
@fixture
@@ -570,7 +649,7 @@ class LoggingPlugin:
config, "log_file_date_format", "log_date_format"
)
log_file_formatter = logging.Formatter(
log_file_formatter = DatetimeFormatter(
log_file_format, datefmt=log_file_date_format
)
self.log_file_handler.setFormatter(log_file_formatter)
@@ -594,6 +673,15 @@ class LoggingPlugin:
get_option_ini(config, "log_auto_indent"),
)
self.log_cli_handler.setFormatter(log_cli_formatter)
self._disable_loggers(loggers_to_disable=config.option.logger_disable)
def _disable_loggers(self, loggers_to_disable: List[str]) -> None:
if not loggers_to_disable:
return
for name in loggers_to_disable:
logger = logging.getLogger(name)
logger.disabled = True
def _create_formatter(self, log_format, log_date_format, auto_indent):
# Color option doesn't exist if terminal plugin is disabled.
@@ -605,7 +693,7 @@ class LoggingPlugin:
create_terminal_writer(self._config), log_format, log_date_format
)
else:
formatter = logging.Formatter(log_format, log_date_format)
formatter = DatetimeFormatter(log_format, log_date_format)
formatter._style = PercentStyleMultiline(
formatter._style._fmt, auto_indent=auto_indent

View File

@@ -1,5 +1,6 @@
"""Core implementation of the testing process: init, session, runtest loop."""
import argparse
import dataclasses
import fnmatch
import functools
import importlib
@@ -19,8 +20,6 @@ from typing import Type
from typing import TYPE_CHECKING
from typing import Union
import attr
import _pytest._code
from _pytest import nodes
from _pytest.compat import final
@@ -123,11 +122,12 @@ def pytest_addoption(parser: Parser) -> None:
)
group._addoption(
"-c",
metavar="file",
"--config-file",
metavar="FILE",
type=str,
dest="inifilename",
help="Load configuration from `file` instead of trying to locate one of the "
"implicit configuration files",
help="Load configuration from `FILE` instead of trying to locate one of the "
"implicit configuration files.",
)
group._addoption(
"--continue-on-collection-errors",
@@ -400,6 +400,12 @@ def pytest_ignore_collect(collection_path: Path, config: Config) -> Optional[boo
allow_in_venv = config.getoption("collect_in_virtualenv")
if not allow_in_venv and _in_venv(collection_path):
return True
if collection_path.is_dir():
norecursepatterns = config.getini("norecursedirs")
if any(fnmatch_ex(pat, collection_path) for pat in norecursepatterns):
return True
return None
@@ -442,8 +448,10 @@ class Failed(Exception):
"""Signals a stop as failed test run."""
@attr.s(slots=True, auto_attribs=True)
@dataclasses.dataclass
class _bestrelpath_cache(Dict[Path, str]):
__slots__ = ("path",)
path: Path
def __missing__(self, path: Path) -> str:
@@ -561,9 +569,6 @@ class Session(nodes.FSCollector):
ihook = self.gethookproxy(fspath.parent)
if ihook.pytest_ignore_collect(collection_path=fspath, config=self.config):
return False
norecursepatterns = self.config.getini("norecursedirs")
if any(fnmatch_ex(pat, fspath) for pat in norecursepatterns):
return False
return True
def _collectfile(
@@ -684,8 +689,8 @@ class Session(nodes.FSCollector):
# are not collected more than once.
matchnodes_cache: Dict[Tuple[Type[nodes.Collector], str], CollectReport] = {}
# Dirnames of pkgs with dunder-init files.
pkg_roots: Dict[str, Package] = {}
# Directories of pkgs with dunder-init files.
pkg_roots: Dict[Path, Package] = {}
for argpath, names in self._initial_parts:
self.trace("processing argument", (argpath, names))
@@ -706,7 +711,7 @@ class Session(nodes.FSCollector):
col = self._collectfile(pkginit, handle_dupes=False)
if col:
if isinstance(col[0], Package):
pkg_roots[str(parent)] = col[0]
pkg_roots[parent] = col[0]
node_cache1[col[0].path] = [col[0]]
# If it's a directory argument, recurse and look for any Subpackages.
@@ -715,7 +720,7 @@ class Session(nodes.FSCollector):
assert not names, f"invalid arg {(argpath, names)!r}"
seen_dirs: Set[Path] = set()
for direntry in visit(str(argpath), self._recurse):
for direntry in visit(argpath, self._recurse):
if not direntry.is_file():
continue
@@ -730,8 +735,8 @@ class Session(nodes.FSCollector):
for x in self._collectfile(pkginit):
yield x
if isinstance(x, Package):
pkg_roots[str(dirpath)] = x
if str(dirpath) in pkg_roots:
pkg_roots[dirpath] = x
if dirpath in pkg_roots:
# Do not collect packages here.
continue
@@ -748,7 +753,7 @@ class Session(nodes.FSCollector):
if argpath in node_cache1:
col = node_cache1[argpath]
else:
collect_root = pkg_roots.get(str(argpath.parent), self)
collect_root = pkg_roots.get(argpath.parent, self)
col = collect_root._collectfile(argpath, handle_dupes=False)
if col:
node_cache1[argpath] = col

View File

@@ -1,4 +1,5 @@
"""Generic mechanism for marking and selecting python functions."""
import dataclasses
from typing import AbstractSet
from typing import Collection
from typing import List
@@ -6,8 +7,6 @@ from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
import attr
from .expression import Expression
from .expression import ParseError
from .structures import EMPTY_PARAMETERSET_OPTION
@@ -130,7 +129,7 @@ def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
return None
@attr.s(slots=True, auto_attribs=True)
@dataclasses.dataclass
class KeywordMatcher:
"""A matcher for keywords.
@@ -145,6 +144,8 @@ class KeywordMatcher:
any item, as well as names directly assigned to test functions.
"""
__slots__ = ("_names",)
_names: AbstractSet[str]
@classmethod
@@ -201,13 +202,15 @@ def deselect_by_keyword(items: "List[Item]", config: Config) -> None:
items[:] = remaining
@attr.s(slots=True, auto_attribs=True)
@dataclasses.dataclass
class MarkMatcher:
"""A matcher for markers which are present.
Tries to match on any marker names, attached to the given colitem.
"""
__slots__ = ("own_mark_names",)
own_mark_names: AbstractSet[str]
@classmethod

View File

@@ -15,8 +15,10 @@ The semantics are:
- or/and/not evaluate according to the usual boolean semantics.
"""
import ast
import dataclasses
import enum
import re
import sys
import types
from typing import Callable
from typing import Iterator
@@ -25,7 +27,10 @@ from typing import NoReturn
from typing import Optional
from typing import Sequence
import attr
if sys.version_info >= (3, 8):
astNameConstant = ast.Constant
else:
astNameConstant = ast.NameConstant
__all__ = [
@@ -44,8 +49,9 @@ class TokenType(enum.Enum):
EOF = "end of input"
@attr.s(frozen=True, slots=True, auto_attribs=True)
@dataclasses.dataclass(frozen=True)
class Token:
__slots__ = ("type", "value", "pos")
type: TokenType
value: str
pos: int
@@ -132,7 +138,7 @@ IDENT_PREFIX = "$"
def expression(s: Scanner) -> ast.Expression:
if s.accept(TokenType.EOF):
ret: ast.expr = ast.NameConstant(False)
ret: ast.expr = astNameConstant(False)
else:
ret = expr(s)
s.accept(TokenType.EOF, reject=True)

View File

@@ -1,4 +1,5 @@
import collections.abc
import dataclasses
import inspect
import warnings
from typing import Any
@@ -20,8 +21,6 @@ from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
import attr
from .._code import getfslineno
from ..compat import ascii_escaped
from ..compat import final
@@ -192,8 +191,10 @@ class ParameterSet(NamedTuple):
@final
@attr.s(frozen=True, init=False, auto_attribs=True)
@dataclasses.dataclass(frozen=True)
class Mark:
"""A pytest mark."""
#: Name of the mark.
name: str
#: Positional arguments of the mark decorator.
@@ -202,9 +203,11 @@ class Mark:
kwargs: Mapping[str, Any]
#: Source Mark for ids with parametrize Marks.
_param_ids_from: Optional["Mark"] = attr.ib(default=None, repr=False)
_param_ids_from: Optional["Mark"] = dataclasses.field(default=None, repr=False)
#: Resolved/generated ids with parametrize Marks.
_param_ids_generated: Optional[Sequence[str]] = attr.ib(default=None, repr=False)
_param_ids_generated: Optional[Sequence[str]] = dataclasses.field(
default=None, repr=False
)
def __init__(
self,
@@ -262,7 +265,7 @@ class Mark:
Markable = TypeVar("Markable", bound=Union[Callable[..., object], type])
@attr.s(init=False, auto_attribs=True)
@dataclasses.dataclass
class MarkDecorator:
"""A decorator for applying a mark on test functions and classes.
@@ -356,12 +359,35 @@ class MarkDecorator:
return self.with_args(*args, **kwargs)
def get_unpacked_marks(obj: object) -> Iterable[Mark]:
"""Obtain the unpacked marks that are stored on an object."""
mark_list = getattr(obj, "pytestmark", [])
if not isinstance(mark_list, list):
mark_list = [mark_list]
return normalize_mark_list(mark_list)
def get_unpacked_marks(
obj: Union[object, type],
*,
consider_mro: bool = True,
) -> List[Mark]:
"""Obtain the unpacked marks that are stored on an object.
If obj is a class and consider_mro is true, return marks applied to
this class and all of its super-classes in MRO order. If consider_mro
is false, only return marks applied directly to this class.
"""
if isinstance(obj, type):
if not consider_mro:
mark_lists = [obj.__dict__.get("pytestmark", [])]
else:
mark_lists = [x.__dict__.get("pytestmark", []) for x in obj.__mro__]
mark_list = []
for item in mark_lists:
if isinstance(item, list):
mark_list.extend(item)
else:
mark_list.append(item)
else:
mark_attribute = getattr(obj, "pytestmark", [])
if isinstance(mark_attribute, list):
mark_list = mark_attribute
else:
mark_list = [mark_attribute]
return list(normalize_mark_list(mark_list))
def normalize_mark_list(
@@ -395,7 +421,7 @@ def store_mark(obj, mark: Mark) -> None:
# Always reassign name to avoid updating pytestmark in a reference that
# was only borrowed.
obj.pytestmark = [*get_unpacked_marks(obj), mark]
obj.pytestmark = [*get_unpacked_marks(obj, consider_mro=False), mark]
# Typing for builtin pytest marks. This is cheating; it gives builtin marks

View File

@@ -7,6 +7,7 @@ from contextlib import contextmanager
from typing import Any
from typing import Generator
from typing import List
from typing import Mapping
from typing import MutableMapping
from typing import Optional
from typing import overload
@@ -129,7 +130,7 @@ class MonkeyPatch:
def __init__(self) -> None:
self._setattr: List[Tuple[object, str, object]] = []
self._setitem: List[Tuple[MutableMapping[Any, Any], object, object]] = []
self._setitem: List[Tuple[Mapping[Any, Any], object, object]] = []
self._cwd: Optional[str] = None
self._savesyspath: Optional[List[str]] = None
@@ -290,12 +291,13 @@ class MonkeyPatch:
self._setattr.append((target, name, oldval))
delattr(target, name)
def setitem(self, dic: MutableMapping[K, V], name: K, value: V) -> None:
def setitem(self, dic: Mapping[K, V], name: K, value: V) -> None:
"""Set dictionary entry ``name`` to value."""
self._setitem.append((dic, name, dic.get(name, notset)))
dic[name] = value
# Not all Mapping types support indexing, but MutableMapping doesn't support TypedDict
dic[name] = value # type: ignore[index]
def delitem(self, dic: MutableMapping[K, V], name: K, raising: bool = True) -> None:
def delitem(self, dic: Mapping[K, V], name: K, raising: bool = True) -> None:
"""Delete ``name`` from dict.
Raises ``KeyError`` if it doesn't exist, unless ``raising`` is set to
@@ -306,7 +308,8 @@ class MonkeyPatch:
raise KeyError(name)
else:
self._setitem.append((dic, name, dic.get(name, notset)))
del dic[name]
# Not all Mapping types support indexing, but MutableMapping doesn't support TypedDict
del dic[name] # type: ignore[attr-defined]
def setenv(self, name: str, value: str, prepend: Optional[str] = None) -> None:
"""Set environment variable ``name`` to ``value``.
@@ -401,11 +404,13 @@ class MonkeyPatch:
for dictionary, key, value in reversed(self._setitem):
if value is notset:
try:
del dictionary[key]
# Not all Mapping types support indexing, but MutableMapping doesn't support TypedDict
del dictionary[key] # type: ignore[attr-defined]
except KeyError:
pass # Was already deleted, so we have the desired state.
else:
dictionary[key] = value
# Not all Mapping types support indexing, but MutableMapping doesn't support TypedDict
dictionary[key] = value # type: ignore[index]
self._setitem[:] = []
if self._savesyspath is not None:
sys.path[:] = self._savesyspath

View File

@@ -22,6 +22,7 @@ import _pytest._code
from _pytest._code import getfslineno
from _pytest._code.code import ExceptionInfo
from _pytest._code.code import TerminalRepr
from _pytest._code.code import Traceback
from _pytest.compat import cached_property
from _pytest.compat import LEGACY_PATH
from _pytest.config import Config
@@ -432,8 +433,8 @@ class Node(metaclass=NodeMeta):
assert current is None or isinstance(current, cls)
return current
def _prunetraceback(self, excinfo: ExceptionInfo[BaseException]) -> None:
pass
def _traceback_filter(self, excinfo: ExceptionInfo[BaseException]) -> Traceback:
return excinfo.traceback
def _repr_failure_py(
self,
@@ -449,13 +450,13 @@ class Node(metaclass=NodeMeta):
style = "value"
if isinstance(excinfo.value, FixtureLookupError):
return excinfo.value.formatrepr()
tbfilter: Union[bool, Callable[[ExceptionInfo[BaseException]], Traceback]]
if self.config.getoption("fulltrace", False):
style = "long"
tbfilter = False
else:
tb = _pytest._code.Traceback([excinfo.traceback[-1]])
self._prunetraceback(excinfo)
if len(excinfo.traceback) == 0:
excinfo.traceback = tb
tbfilter = self._traceback_filter
if style == "auto":
style = "long"
# XXX should excinfo.getrepr record all data and toterminal() process it?
@@ -486,7 +487,7 @@ class Node(metaclass=NodeMeta):
abspath=abspath,
showlocals=self.config.getoption("showlocals", False),
style=style,
tbfilter=False, # pruned already, or in --fulltrace mode.
tbfilter=tbfilter,
truncate_locals=truncate_locals,
)
@@ -511,7 +512,7 @@ def get_fslocation_from_item(node: "Node") -> Tuple[Union[str, Path], Optional[i
* "obj": a Python object that the node wraps.
* "fspath": just a path
:rtype: A tuple of (str|Path, int) with filename and line number.
:rtype: A tuple of (str|Path, int) with filename and 0-based line number.
"""
# See Item.location.
location: Optional[Tuple[str, Optional[int], str]] = getattr(node, "location", None)
@@ -557,13 +558,14 @@ class Collector(Node):
return self._repr_failure_py(excinfo, style=tbstyle)
def _prunetraceback(self, excinfo: ExceptionInfo[BaseException]) -> None:
def _traceback_filter(self, excinfo: ExceptionInfo[BaseException]) -> Traceback:
if hasattr(self, "path"):
traceback = excinfo.traceback
ntraceback = traceback.cut(path=self.path)
if ntraceback == traceback:
ntraceback = ntraceback.cut(excludepath=tracebackcutdir)
excinfo.traceback = ntraceback.filter()
return excinfo.traceback.filter(excinfo)
return excinfo.traceback
def _check_initialpaths_for_relpath(session: "Session", path: Path) -> Optional[str]:
@@ -755,7 +757,7 @@ class Item(Node):
Returns a tuple with three elements:
- The path of the test (default ``self.path``)
- The line number of the test (default ``None``)
- The 0-based line number of the test (default ``None``)
- A name of the test to be shown (default ``""``)
.. seealso:: :ref:`non-python tests`
@@ -764,6 +766,11 @@ class Item(Node):
@cached_property
def location(self) -> Tuple[str, Optional[int], str]:
"""
Returns a tuple of ``(relfspath, lineno, testname)`` for this item
where ``relfspath`` is file path relative to ``config.rootpath``
and lineno is a 0-based line number.
"""
location = self.reportinfo()
path = absolutepath(os.fspath(location[0]))
relfspath = self.session._node_location_to_relpath(path)

View File

@@ -157,8 +157,12 @@ def skip(
The message to show the user as reason for the skip.
:param allow_module_level:
Allows this function to be called at module level, skipping the rest
of the module. Defaults to False.
Allows this function to be called at module level.
Raising the skip exception at module level will stop
the execution of the module and prevent the collection of all tests in the module,
even those defined before the `skip` call.
Defaults to False.
:param msg:
Same as ``reason``, but deprecated. Will be removed in a future version, use ``reason`` instead.
@@ -219,7 +223,6 @@ def _resolve_msg_to_reason(
"""
__tracebackhide__ = True
if msg is not None:
if reason:
from pytest import UsageError

View File

@@ -6,6 +6,7 @@ import itertools
import os
import shutil
import sys
import types
import uuid
import warnings
from enum import Enum
@@ -26,8 +27,11 @@ from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Type
from typing import TypeVar
from typing import Union
@@ -63,21 +67,33 @@ def get_lock_path(path: _AnyPurePath) -> _AnyPurePath:
return path.joinpath(".lock")
def on_rm_rf_error(func, path: str, exc, *, start_path: Path) -> bool:
def on_rm_rf_error(
func,
path: str,
excinfo: Union[
BaseException,
Tuple[Type[BaseException], BaseException, Optional[types.TracebackType]],
],
*,
start_path: Path,
) -> bool:
"""Handle known read-only errors during rmtree.
The returned value is used only by our own tests.
"""
exctype, excvalue = exc[:2]
if isinstance(excinfo, BaseException):
exc = excinfo
else:
exc = excinfo[1]
# Another process removed the file in the middle of the "rm_rf" (xdist for example).
# More context: https://github.com/pytest-dev/pytest/issues/5974#issuecomment-543799018
if isinstance(excvalue, FileNotFoundError):
if isinstance(exc, FileNotFoundError):
return False
if not isinstance(excvalue, PermissionError):
if not isinstance(exc, PermissionError):
warnings.warn(
PytestWarning(f"(rm_rf) error removing {path}\n{exctype}: {excvalue}")
PytestWarning(f"(rm_rf) error removing {path}\n{type(exc)}: {exc}")
)
return False
@@ -86,7 +102,7 @@ def on_rm_rf_error(func, path: str, exc, *, start_path: Path) -> bool:
warnings.warn(
PytestWarning(
"(rm_rf) unknown function {} when removing {}:\n{}: {}".format(
func, path, exctype, excvalue
func, path, type(exc), exc
)
)
)
@@ -149,7 +165,10 @@ def rm_rf(path: Path) -> None:
are read-only."""
path = ensure_extended_length_path(path)
onerror = partial(on_rm_rf_error, start_path=path)
shutil.rmtree(str(path), onerror=onerror)
if sys.version_info >= (3, 12):
shutil.rmtree(str(path), onexc=onerror)
else:
shutil.rmtree(str(path), onerror=onerror)
def find_prefixed(root: Path, prefix: str) -> Iterator[Path]:
@@ -335,15 +354,26 @@ def cleanup_candidates(root: Path, prefix: str, keep: int) -> Iterator[Path]:
yield path
def cleanup_dead_symlinks(root: Path):
for left_dir in root.iterdir():
if left_dir.is_symlink():
if not left_dir.resolve().exists():
left_dir.unlink()
def cleanup_numbered_dir(
root: Path, prefix: str, keep: int, consider_lock_dead_if_created_before: float
) -> None:
"""Cleanup for lock driven numbered directories."""
if not root.exists():
return
for path in cleanup_candidates(root, prefix, keep):
try_cleanup(path, consider_lock_dead_if_created_before)
for path in root.glob("garbage-*"):
try_cleanup(path, consider_lock_dead_if_created_before)
cleanup_dead_symlinks(root)
def make_numbered_dir_with_cleanup(
root: Path,
@@ -357,8 +387,10 @@ def make_numbered_dir_with_cleanup(
for i in range(10):
try:
p = make_numbered_dir(root, prefix, mode)
lock_path = create_cleanup_lock(p)
register_cleanup_lock_removal(lock_path)
# Only lock the current dir when keep is not 0
if keep != 0:
lock_path = create_cleanup_lock(p)
register_cleanup_lock_removal(lock_path)
except Exception as exc:
e = exc
else:
@@ -464,14 +496,14 @@ def import_path(
* `mode == ImportMode.prepend`: the directory containing the module (or package, taking
`__init__.py` files into account) will be put at the *start* of `sys.path` before
being imported with `__import__.
being imported with `importlib.import_module`.
* `mode == ImportMode.append`: same as `prepend`, but the directory will be appended
to the end of `sys.path`, if not already in `sys.path`.
* `mode == ImportMode.importlib`: uses more fine control mechanisms provided by `importlib`
to import the module, which avoids having to use `__import__` and muck with `sys.path`
at all. It effectively allows having same-named test modules in different places.
to import the module, which avoids having to muck with `sys.path` at all. It effectively
allows having same-named test modules in different places.
:param root:
Used as an anchor when mode == ImportMode.importlib to obtain
@@ -544,8 +576,8 @@ def import_path(
if module_file.endswith((".pyc", ".pyo")):
module_file = module_file[:-1]
if module_file.endswith(os.path.sep + "__init__.py"):
module_file = module_file[: -(len(os.path.sep + "__init__.py"))]
if module_file.endswith(os.sep + "__init__.py"):
module_file = module_file[: -(len(os.sep + "__init__.py"))]
try:
is_same = _is_same(str(path), module_file)
@@ -638,30 +670,38 @@ def resolve_package_path(path: Path) -> Optional[Path]:
return result
def scandir(path: Union[str, "os.PathLike[str]"]) -> List["os.DirEntry[str]"]:
"""Scan a directory recursively, in breadth-first order.
The returned entries are sorted.
"""
entries = []
with os.scandir(path) as s:
# Skip entries with symlink loops and other brokenness, so the caller
# doesn't have to deal with it.
for entry in s:
try:
entry.is_file()
except OSError as err:
if _ignore_error(err):
continue
raise
entries.append(entry)
entries.sort(key=lambda entry: entry.name)
return entries
def visit(
path: Union[str, "os.PathLike[str]"], recurse: Callable[["os.DirEntry[str]"], bool]
) -> Iterator["os.DirEntry[str]"]:
"""Walk a directory recursively, in breadth-first order.
The `recurse` predicate determines whether a directory is recursed.
Entries at each directory level are sorted.
"""
# Skip entries with symlink loops and other brokenness, so the caller doesn't
# have to deal with it.
entries = []
for entry in os.scandir(path):
try:
entry.is_file()
except OSError as err:
if _ignore_error(err):
continue
raise
entries.append(entry)
entries.sort(key=lambda entry: entry.name)
entries = scandir(path)
yield from entries
for entry in entries:
if entry.is_dir() and recurse(entry):
yield from visit(entry.path, recurse)

View File

@@ -6,6 +6,7 @@ import collections.abc
import contextlib
import gc
import importlib
import locale
import os
import platform
import re
@@ -129,6 +130,7 @@ class LsofFdLeakChecker:
stderr=subprocess.DEVNULL,
check=True,
text=True,
encoding=locale.getpreferredencoding(False),
).stdout
def isopen(line: str) -> bool:

View File

@@ -1,4 +1,5 @@
"""Python test discovery, setup and run of test functions."""
import dataclasses
import enum
import fnmatch
import inspect
@@ -27,8 +28,6 @@ from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
import attr
import _pytest
from _pytest import fixtures
from _pytest import nodes
@@ -36,6 +35,7 @@ from _pytest._code import filter_traceback
from _pytest._code import getfslineno
from _pytest._code.code import ExceptionInfo
from _pytest._code.code import TerminalRepr
from _pytest._code.code import Traceback
from _pytest._io import TerminalWriter
from _pytest._io.saferepr import saferepr
from _pytest.compat import ascii_escaped
@@ -57,7 +57,6 @@ from _pytest.config import ExitCode
from _pytest.config import hookimpl
from _pytest.config.argparsing import Parser
from _pytest.deprecated import check_ispytest
from _pytest.deprecated import FSCOLLECTOR_GETHOOKPROXY_ISINITPATH
from _pytest.deprecated import INSTANCE_COLLECTOR
from _pytest.deprecated import NOSE_SUPPORT_METHOD
from _pytest.fixtures import FuncFixtureInfo
@@ -403,8 +402,8 @@ class PyCollector(PyobjMixin, nodes.Collector):
def istestfunction(self, obj: object, name: str) -> bool:
if self.funcnamefilter(name) or self.isnosetest(obj):
if isinstance(obj, staticmethod):
# staticmethods need to be unwrapped.
if isinstance(obj, (staticmethod, classmethod)):
# staticmethods and classmethods need to be unwrapped.
obj = safe_getattr(obj, "__func__", False)
return callable(obj) and fixtures.getfixturemarker(obj) is None
else:
@@ -668,7 +667,7 @@ class Package(Module):
config=None,
session=None,
nodeid=None,
path=Optional[Path],
path: Optional[Path] = None,
) -> None:
# NOTE: Could be just the following, but kept as-is for compat.
# nodes.FSCollector.__init__(self, fspath, parent=parent)
@@ -700,14 +699,6 @@ class Package(Module):
func = partial(_call_with_optional_argument, teardown_module, self.obj)
self.addfinalizer(func)
def gethookproxy(self, fspath: "os.PathLike[str]"):
warnings.warn(FSCOLLECTOR_GETHOOKPROXY_ISINITPATH, stacklevel=2)
return self.session.gethookproxy(fspath)
def isinitpath(self, path: Union[str, "os.PathLike[str]"]) -> bool:
warnings.warn(FSCOLLECTOR_GETHOOKPROXY_ISINITPATH, stacklevel=2)
return self.session.isinitpath(path)
def _recurse(self, direntry: "os.DirEntry[str]") -> bool:
if direntry.name == "__pycache__":
return False
@@ -715,9 +706,6 @@ class Package(Module):
ihook = self.session.gethookproxy(fspath.parent)
if ihook.pytest_ignore_collect(collection_path=fspath, config=self.config):
return False
norecursepatterns = self.config.getini("norecursedirs")
if any(fnmatch_ex(pat, fspath) for pat in norecursepatterns):
return False
return True
def _collectfile(
@@ -746,11 +734,13 @@ class Package(Module):
def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]:
this_path = self.path.parent
init_module = this_path / "__init__.py"
if init_module.is_file() and path_matches_patterns(
init_module, self.config.getini("python_files")
# Always collect the __init__ first.
if self.session.isinitpath(self.path) or path_matches_patterns(
self.path, self.config.getini("python_files")
):
yield Module.from_parent(self, path=init_module)
yield Module.from_parent(self, path=self.path)
pkg_prefixes: Set[Path] = set()
for direntry in visit(str(this_path), recurse=self._recurse):
path = Path(direntry.path)
@@ -790,7 +780,8 @@ def _call_with_optional_argument(func, arg) -> None:
def _get_first_non_fixture_func(obj: object, names: Iterable[str]) -> Optional[object]:
"""Return the attribute from the given object to be used as a setup/teardown
xunit-style function, but only if not marked as a fixture to avoid calling it twice."""
xunit-style function, but only if not marked as a fixture to avoid calling it twice.
"""
for name in names:
meth: Optional[object] = getattr(obj, name, None)
if meth is not None and fixtures.getfixturemarker(meth) is None:
@@ -848,7 +839,7 @@ class Class(PyCollector):
other fixtures (#517).
"""
setup_class = _get_first_non_fixture_func(self.obj, ("setup_class",))
teardown_class = getattr(self.obj, "teardown_class", None)
teardown_class = _get_first_non_fixture_func(self.obj, ("teardown_class",))
if setup_class is None and teardown_class is None:
return
@@ -885,12 +876,12 @@ class Class(PyCollector):
emit_nose_setup_warning = True
setup_method = _get_first_non_fixture_func(self.obj, (setup_name,))
teardown_name = "teardown_method"
teardown_method = getattr(self.obj, teardown_name, None)
teardown_method = _get_first_non_fixture_func(self.obj, (teardown_name,))
emit_nose_teardown_warning = False
if teardown_method is None and has_nose:
teardown_name = "teardown"
emit_nose_teardown_warning = True
teardown_method = getattr(self.obj, teardown_name, None)
teardown_method = _get_first_non_fixture_func(self.obj, (teardown_name,))
if setup_method is None and teardown_method is None:
return
@@ -956,10 +947,20 @@ def hasnew(obj: object) -> bool:
@final
@attr.s(frozen=True, auto_attribs=True, slots=True)
@dataclasses.dataclass(frozen=True)
class IdMaker:
"""Make IDs for a parametrization."""
__slots__ = (
"argnames",
"parametersets",
"idfn",
"ids",
"config",
"nodeid",
"func_name",
)
# The argnames of the parametrization.
argnames: Sequence[str]
# The ParameterSets of the parametrization.
@@ -1109,7 +1110,7 @@ class IdMaker:
@final
@attr.s(frozen=True, slots=True, auto_attribs=True)
@dataclasses.dataclass(frozen=True)
class CallSpec2:
"""A planned parameterized invocation of a test function.
@@ -1120,18 +1121,18 @@ class CallSpec2:
# arg name -> arg value which will be passed to the parametrized test
# function (direct parameterization).
funcargs: Dict[str, object] = attr.Factory(dict)
funcargs: Dict[str, object] = dataclasses.field(default_factory=dict)
# arg name -> arg value which will be passed to a fixture of the same name
# (indirect parametrization).
params: Dict[str, object] = attr.Factory(dict)
params: Dict[str, object] = dataclasses.field(default_factory=dict)
# arg name -> arg index.
indices: Dict[str, int] = attr.Factory(dict)
indices: Dict[str, int] = dataclasses.field(default_factory=dict)
# Used for sorting parametrized resources.
_arg2scope: Dict[str, Scope] = attr.Factory(dict)
_arg2scope: Dict[str, Scope] = dataclasses.field(default_factory=dict)
# Parts which will be added to the item's name in `[..]` separated by "-".
_idlist: List[str] = attr.Factory(list)
_idlist: List[str] = dataclasses.field(default_factory=list)
# Marks which will be applied to the item.
marks: List[Mark] = attr.Factory(list)
marks: List[Mark] = dataclasses.field(default_factory=list)
def setmulti(
self,
@@ -1163,9 +1164,9 @@ class CallSpec2:
return CallSpec2(
funcargs=funcargs,
params=params,
arg2scope=arg2scope,
indices=indices,
idlist=[*self._idlist, id],
_arg2scope=arg2scope,
_idlist=[*self._idlist, id],
marks=[*self.marks, *normalize_mark_list(marks)],
)
@@ -1791,7 +1792,7 @@ class Function(PyobjMixin, nodes.Item):
def setup(self) -> None:
self._request._fillfixtures()
def _prunetraceback(self, excinfo: ExceptionInfo[BaseException]) -> None:
def _traceback_filter(self, excinfo: ExceptionInfo[BaseException]) -> Traceback:
if hasattr(self, "_obj") and not self.config.getoption("fulltrace", False):
code = _pytest._code.Code.from_function(get_real_func(self.obj))
path, firstlineno = code.path, code.firstlineno
@@ -1803,14 +1804,21 @@ class Function(PyobjMixin, nodes.Item):
ntraceback = ntraceback.filter(filter_traceback)
if not ntraceback:
ntraceback = traceback
ntraceback = ntraceback.filter(excinfo)
excinfo.traceback = ntraceback.filter()
# issue364: mark all but first and last frames to
# only show a single-line message for each frame.
if self.config.getoption("tbstyle", "auto") == "auto":
if len(excinfo.traceback) > 2:
for entry in excinfo.traceback[1:-1]:
entry.set_repr_style("short")
if len(ntraceback) > 2:
ntraceback = Traceback(
entry
if i == 0 or i == len(ntraceback) - 1
else entry.with_repr_style("short")
for i, entry in enumerate(ntraceback)
)
return ntraceback
return excinfo.traceback
# TODO: Type ignored -- breaks Liskov Substitution.
def repr_failure( # type: ignore[override]

View File

@@ -8,7 +8,7 @@ from types import TracebackType
from typing import Any
from typing import Callable
from typing import cast
from typing import Generic
from typing import ContextManager
from typing import List
from typing import Mapping
from typing import Optional
@@ -269,10 +269,16 @@ class ApproxMapping(ApproxBase):
max_abs_diff = max(
max_abs_diff, abs(approx_value.expected - other_value)
)
max_rel_diff = max(
max_rel_diff,
abs((approx_value.expected - other_value) / approx_value.expected),
)
if approx_value.expected == 0.0:
max_rel_diff = math.inf
else:
max_rel_diff = max(
max_rel_diff,
abs(
(approx_value.expected - other_value)
/ approx_value.expected
),
)
different_ids.append(approx_key)
message_data = [
@@ -801,8 +807,8 @@ def raises( # noqa: F811
r"""Assert that a code block/function call raises an exception.
:param typing.Type[E] | typing.Tuple[typing.Type[E], ...] expected_exception:
The excpected exception type, or a tuple if one of multiple possible
exception types are excepted.
The expected exception type, or a tuple if one of multiple possible
exception types are expected.
:kwparam str | typing.Pattern[str] | None match:
If specified, a string containing a regular expression,
or a regular expression object, that is tested against the string
@@ -918,10 +924,10 @@ def raises( # noqa: F811
f"any special code to say 'this should never raise an exception'."
)
if isinstance(expected_exception, type):
excepted_exceptions: Tuple[Type[E], ...] = (expected_exception,)
expected_exceptions: Tuple[Type[E], ...] = (expected_exception,)
else:
excepted_exceptions = expected_exception
for exc in excepted_exceptions:
expected_exceptions = expected_exception
for exc in expected_exceptions:
if not isinstance(exc, type) or not issubclass(exc, BaseException):
msg = "expected exception must be a BaseException type, not {}" # type: ignore[unreachable]
not_a = exc.__name__ if isinstance(exc, type) else type(exc).__name__
@@ -944,11 +950,7 @@ def raises( # noqa: F811
try:
func(*args[1:], **kwargs)
except expected_exception as e:
# We just caught the exception - there is a traceback.
assert e.__traceback__ is not None
return _pytest._code.ExceptionInfo.from_exc_info(
(type(e), e, e.__traceback__)
)
return _pytest._code.ExceptionInfo.from_exception(e)
fail(message)
@@ -957,7 +959,7 @@ raises.Exception = fail.Exception # type: ignore
@final
class RaisesContext(Generic[E]):
class RaisesContext(ContextManager[_pytest._code.ExceptionInfo[E]]):
def __init__(
self,
expected_exception: Union[Type[E], Tuple[Type[E], ...]],

View File

@@ -1,3 +1,4 @@
import dataclasses
import os
from io import StringIO
from pprint import pprint
@@ -16,8 +17,6 @@ from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
import attr
from _pytest._code.code import ExceptionChainRepr
from _pytest._code.code import ExceptionInfo
from _pytest._code.code import ExceptionRepr
@@ -263,6 +262,8 @@ class TestReport(BaseReport):
when: "Literal['setup', 'call', 'teardown']",
sections: Iterable[Tuple[str, str]] = (),
duration: float = 0,
start: float = 0,
stop: float = 0,
user_properties: Optional[Iterable[Tuple[str, object]]] = None,
**extra,
) -> None:
@@ -272,6 +273,8 @@ class TestReport(BaseReport):
#: A (filesystempath, lineno, domaininfo) tuple indicating the
#: actual location of a test item - it might be different from the
#: collected one e.g. if a method is inherited from a different module.
#: The filesystempath may be relative to ``config.rootdir``.
#: The line number is 0-based.
self.location: Tuple[str, Optional[int], str] = location
#: A name -> value dictionary containing all keywords and
@@ -300,6 +303,11 @@ class TestReport(BaseReport):
#: Time it took to run just the test.
self.duration: float = duration
#: The system time when the call started, in seconds since the epoch.
self.start: float = start
#: The system time when the call ended, in seconds since the epoch.
self.stop: float = stop
self.__dict__.update(extra)
def __repr__(self) -> str:
@@ -318,6 +326,8 @@ class TestReport(BaseReport):
# Remove "collect" from the Literal type -- only for collection calls.
assert when != "collect"
duration = call.duration
start = call.start
stop = call.stop
keywords = {x: 1 for x in item.keywords}
excinfo = call.excinfo
sections = []
@@ -337,6 +347,9 @@ class TestReport(BaseReport):
elif isinstance(excinfo.value, skip.Exception):
outcome = "skipped"
r = excinfo._getreprcrash()
assert (
r is not None
), "There should always be a traceback entry for skipping a test."
if excinfo.value._use_item_location:
path, line = item.reportinfo()[:2]
assert line is not None
@@ -362,6 +375,8 @@ class TestReport(BaseReport):
when,
sections,
duration,
start,
stop,
user_properties=item.user_properties,
)
@@ -407,7 +422,9 @@ class CollectReport(BaseReport):
self.__dict__.update(extra)
@property
def location(self):
def location( # type:ignore[override]
self,
) -> Optional[Tuple[str, Optional[int], str]]:
return (self.fspath, None, self.fspath)
def __repr__(self) -> str:
@@ -459,15 +476,15 @@ def _report_to_json(report: BaseReport) -> Dict[str, Any]:
def serialize_repr_entry(
entry: Union[ReprEntry, ReprEntryNative]
) -> Dict[str, Any]:
data = attr.asdict(entry)
data = dataclasses.asdict(entry)
for key, value in data.items():
if hasattr(value, "__dict__"):
data[key] = attr.asdict(value)
data[key] = dataclasses.asdict(value)
entry_data = {"type": type(entry).__name__, "data": data}
return entry_data
def serialize_repr_traceback(reprtraceback: ReprTraceback) -> Dict[str, Any]:
result = attr.asdict(reprtraceback)
result = dataclasses.asdict(reprtraceback)
result["reprentries"] = [
serialize_repr_entry(x) for x in reprtraceback.reprentries
]
@@ -477,7 +494,7 @@ def _report_to_json(report: BaseReport) -> Dict[str, Any]:
reprcrash: Optional[ReprFileLocation],
) -> Optional[Dict[str, Any]]:
if reprcrash is not None:
return attr.asdict(reprcrash)
return dataclasses.asdict(reprcrash)
else:
return None
@@ -573,7 +590,6 @@ def _report_kwargs_from_json(reportdict: Dict[str, Any]) -> Dict[str, Any]:
and "reprcrash" in reportdict["longrepr"]
and "reprtraceback" in reportdict["longrepr"]
):
reprtraceback = deserialize_repr_traceback(
reportdict["longrepr"]["reprtraceback"]
)
@@ -594,7 +610,10 @@ def _report_kwargs_from_json(reportdict: Dict[str, Any]) -> Dict[str, Any]:
ExceptionChainRepr, ReprExceptionInfo
] = ExceptionChainRepr(chain)
else:
exception_info = ReprExceptionInfo(reprtraceback, reprcrash)
exception_info = ReprExceptionInfo(
reprtraceback=reprtraceback,
reprcrash=reprcrash,
)
for section in reportdict["longrepr"]["sections"]:
exception_info.addsection(*section)

View File

@@ -1,5 +1,6 @@
"""Basic collect and runtest protocol implementations."""
import bdb
import dataclasses
import os
import sys
from typing import Callable
@@ -14,8 +15,6 @@ from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
import attr
from .reports import BaseReport
from .reports import CollectErrorRepr
from .reports import CollectReport
@@ -35,6 +34,9 @@ from _pytest.outcomes import OutcomeException
from _pytest.outcomes import Skipped
from _pytest.outcomes import TEST_OUTCOME
if sys.version_info[:2] < (3, 11):
from exceptiongroup import BaseExceptionGroup
if TYPE_CHECKING:
from typing_extensions import Literal
@@ -265,7 +267,7 @@ TResult = TypeVar("TResult", covariant=True)
@final
@attr.s(repr=False, init=False, auto_attribs=True)
@dataclasses.dataclass
class CallInfo(Generic[TResult]):
"""Result/Exception info of a function invocation."""
@@ -512,22 +514,29 @@ class SetupState:
stack is torn down.
"""
needed_collectors = nextitem and nextitem.listchain() or []
exc = None
exceptions: List[BaseException] = []
while self.stack:
if list(self.stack.keys()) == needed_collectors[: len(self.stack)]:
break
node, (finalizers, _) = self.stack.popitem()
these_exceptions = []
while finalizers:
fin = finalizers.pop()
try:
fin()
except TEST_OUTCOME as e:
# XXX Only first exception will be seen by user,
# ideally all should be reported.
if exc is None:
exc = e
if exc:
raise exc
these_exceptions.append(e)
if len(these_exceptions) == 1:
exceptions.extend(these_exceptions)
elif these_exceptions:
msg = f"errors while tearing down {node!r}"
exceptions.append(BaseExceptionGroup(msg, these_exceptions[::-1]))
if len(exceptions) == 1:
raise exceptions[0]
elif exceptions:
raise BaseExceptionGroup("errors during test teardown", exceptions[::-1])
if nextitem is None:
assert not self.stack

View File

@@ -1,4 +1,5 @@
"""Support for skip/xfail functions and markers."""
import dataclasses
import os
import platform
import sys
@@ -9,8 +10,6 @@ from typing import Optional
from typing import Tuple
from typing import Type
import attr
from _pytest.config import Config
from _pytest.config import hookimpl
from _pytest.config.argparsing import Parser
@@ -157,7 +156,7 @@ def evaluate_condition(item: Item, mark: Mark, condition: object) -> Tuple[bool,
return result, reason
@attr.s(slots=True, frozen=True, auto_attribs=True)
@dataclasses.dataclass(frozen=True)
class Skip:
"""The result of evaluate_skip_marks()."""
@@ -192,10 +191,12 @@ def evaluate_skip_marks(item: Item) -> Optional[Skip]:
return None
@attr.s(slots=True, frozen=True, auto_attribs=True)
@dataclasses.dataclass(frozen=True)
class Xfail:
"""The result of evaluate_xfail_marks()."""
__slots__ = ("reason", "run", "strict", "raises")
reason: str
run: bool
strict: bool

View File

@@ -48,6 +48,10 @@ def pytest_configure(config: Config) -> None:
def pytest_sessionfinish(session: Session) -> None:
if not session.config.getoption("stepwise"):
assert session.config.cache is not None
if hasattr(session.config, "workerinput"):
# Do not update cache if this process is a xdist worker to prevent
# race conditions (#10641).
return
# Clear the list of failing tests if the plugin is not active.
session.config.cache.set(STEPWISE_CACHE_DIR, [])
@@ -119,4 +123,8 @@ class StepwisePlugin:
return None
def pytest_sessionfinish(self) -> None:
if hasattr(self.config, "workerinput"):
# Do not update cache if this process is a xdist worker to prevent
# race conditions (#10641).
return
self.cache.set(STEPWISE_CACHE_DIR, self.lastfailed)

View File

@@ -3,10 +3,12 @@
This is a good source for looking at the various reporting hooks.
"""
import argparse
import dataclasses
import datetime
import inspect
import platform
import sys
import textwrap
import warnings
from collections import Counter
from functools import partial
@@ -19,6 +21,7 @@ from typing import Dict
from typing import Generator
from typing import List
from typing import Mapping
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Set
@@ -27,7 +30,6 @@ from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
import attr
import pluggy
import _pytest._version
@@ -111,6 +113,26 @@ class MoreQuietAction(argparse.Action):
namespace.quiet = getattr(namespace, "quiet", 0) + 1
class TestShortLogReport(NamedTuple):
"""Used to store the test status result category, shortletter and verbose word.
For example ``"rerun", "R", ("RERUN", {"yellow": True})``.
:ivar category:
The class of result, for example ``“passed”``, ``“skipped”``, ``“error”``, or the empty string.
:ivar letter:
The short letter shown as testing progresses, for example ``"."``, ``"s"``, ``"E"``, or the empty string.
:ivar word:
Verbose word is shown as testing progresses in verbose mode, for example ``"PASSED"``, ``"SKIPPED"``,
``"ERROR"``, or the empty string.
"""
category: str
letter: str
word: Union[str, Tuple[str, Mapping[str, bool]]]
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("terminal reporting", "Reporting", after="general")
group._addoption(
@@ -178,6 +200,12 @@ def pytest_addoption(parser: Parser) -> None:
default=False,
help="Show locals in tracebacks (disabled by default)",
)
group._addoption(
"--no-showlocals",
action="store_false",
dest="showlocals",
help="Hide locals in tracebacks (negate --showlocals passed through addopts)",
)
group._addoption(
"--tb",
metavar="style",
@@ -223,7 +251,8 @@ def pytest_addoption(parser: Parser) -> None:
parser.addini(
"console_output_style",
help='Console output: "classic", or with additional progress information '
'("progress" (percentage) | "count")',
'("progress" (percentage) | "count" | "progress-even-when-capture-no" (forces '
"progress even when capture=no)",
default="progress",
)
@@ -281,7 +310,7 @@ def pytest_report_teststatus(report: BaseReport) -> Tuple[str, str, str]:
return outcome, letter, outcome.upper()
@attr.s(auto_attribs=True)
@dataclasses.dataclass
class WarningReport:
"""Simple structure to hold warnings information captured by ``pytest_warning_recorded``.
@@ -340,14 +369,19 @@ class TerminalReporter:
def _determine_show_progress_info(self) -> "Literal['progress', 'count', False]":
"""Return whether we should display progress information based on the current config."""
# do not show progress if we are not capturing output (#3038)
if self.config.getoption("capture", "no") == "no":
# do not show progress if we are not capturing output (#3038) unless explicitly
# overridden by progress-even-when-capture-no
if (
self.config.getoption("capture", "no") == "no"
and self.config.getini("console_output_style")
!= "progress-even-when-capture-no"
):
return False
# do not show progress if we are showing fixture setup/teardown
if self.config.getoption("setupshow", False):
return False
cfg: str = self.config.getini("console_output_style")
if cfg == "progress":
if cfg == "progress" or cfg == "progress-even-when-capture-no":
return "progress"
elif cfg == "count":
return "count"
@@ -414,6 +448,28 @@ class TerminalReporter:
self._tw.line()
self.currentfspath = None
def wrap_write(
self,
content: str,
*,
flush: bool = False,
margin: int = 8,
line_sep: str = "\n",
**markup: bool,
) -> None:
"""Wrap message with margin for progress info."""
width_of_current_line = self._tw.width_of_current_line
wrapped = line_sep.join(
textwrap.wrap(
" " * width_of_current_line + content,
width=self._screen_width - margin,
drop_whitespace=True,
replace_whitespace=False,
),
)
wrapped = wrapped[width_of_current_line:]
self._tw.write(wrapped, flush=flush, **markup)
def write(self, content: str, *, flush: bool = False, **markup: bool) -> None:
self._tw.write(content, flush=flush, **markup)
@@ -513,10 +569,11 @@ class TerminalReporter:
def pytest_runtest_logreport(self, report: TestReport) -> None:
self._tests_ran = True
rep = report
res: Tuple[
str, str, Union[str, Tuple[str, Mapping[str, bool]]]
] = self.config.hook.pytest_report_teststatus(report=rep, config=self.config)
category, letter, word = res
res = TestShortLogReport(
*self.config.hook.pytest_report_teststatus(report=rep, config=self.config)
)
category, letter, word = res.category, res.letter, res.word
if not isinstance(word, tuple):
markup = None
else:
@@ -560,7 +617,7 @@ class TerminalReporter:
formatted_reason = f" ({reason})"
if reason and formatted_reason is not None:
self._tw.write(formatted_reason)
self.wrap_write(formatted_reason)
if self._show_progress_info:
self._write_progress_information_filling_space()
else:
@@ -727,16 +784,14 @@ class TerminalReporter:
self.write_line(line)
def pytest_report_header(self, config: Config) -> List[str]:
line = "rootdir: %s" % config.rootpath
result = [f"rootdir: {config.rootpath}"]
if config.inipath:
line += ", configfile: " + bestrelpath(config.rootpath, config.inipath)
result.append("configfile: " + bestrelpath(config.rootpath, config.inipath))
if config.args_source == Config.ArgsSource.TESTPATHS:
testpaths: List[str] = config.getini("testpaths")
line += ", testpaths: {}".format(", ".join(testpaths))
result = [line]
result.append("testpaths: {}".format(", ".join(testpaths)))
plugininfo = config.pluginmanager.list_plugin_distinfo()
if plugininfo:

View File

@@ -1,40 +1,66 @@
"""Support for providing temporary directories to test functions."""
import dataclasses
import os
import re
import sys
import tempfile
from pathlib import Path
from shutil import rmtree
from typing import Any
from typing import Dict
from typing import Generator
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
import attr
from _pytest.nodes import Item
from _pytest.reports import CollectReport
from _pytest.stash import StashKey
if TYPE_CHECKING:
from typing_extensions import Literal
RetentionType = Literal["all", "failed", "none"]
from _pytest.config.argparsing import Parser
from .pathlib import LOCK_TIMEOUT
from .pathlib import make_numbered_dir
from .pathlib import make_numbered_dir_with_cleanup
from .pathlib import rm_rf
from _pytest.compat import final
from .pathlib import cleanup_dead_symlinks
from _pytest.compat import final, get_user_id
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config import hookimpl
from _pytest.deprecated import check_ispytest
from _pytest.fixtures import fixture
from _pytest.fixtures import FixtureRequest
from _pytest.monkeypatch import MonkeyPatch
tmppath_result_key = StashKey[Dict[str, bool]]()
@final
@attr.s(init=False)
@dataclasses.dataclass
class TempPathFactory:
"""Factory for temporary directories under the common base temp directory.
The base directory can be configured using the ``--basetemp`` option.
"""
_given_basetemp = attr.ib(type=Optional[Path])
_trace = attr.ib()
_basetemp = attr.ib(type=Optional[Path])
_given_basetemp: Optional[Path]
# pluggy TagTracerSub, not currently exposed, so Any.
_trace: Any
_basetemp: Optional[Path]
_retention_count: int
_retention_policy: "RetentionType"
def __init__(
self,
given_basetemp: Optional[Path],
retention_count: int,
retention_policy: "RetentionType",
trace,
basetemp: Optional[Path] = None,
*,
@@ -49,6 +75,8 @@ class TempPathFactory:
# Path.absolute() exists, but it is not public (see https://bugs.python.org/issue25012).
self._given_basetemp = Path(os.path.abspath(str(given_basetemp)))
self._trace = trace
self._retention_count = retention_count
self._retention_policy = retention_policy
self._basetemp = basetemp
@classmethod
@@ -63,9 +91,23 @@ class TempPathFactory:
:meta private:
"""
check_ispytest(_ispytest)
count = int(config.getini("tmp_path_retention_count"))
if count < 0:
raise ValueError(
f"tmp_path_retention_count must be >= 0. Current input: {count}."
)
policy = config.getini("tmp_path_retention_policy")
if policy not in ("all", "failed", "none"):
raise ValueError(
f"tmp_path_retention_policy must be either all, failed, none. Current input: {policy}."
)
return cls(
given_basetemp=config.option.basetemp,
trace=config.trace.get("tmpdir"),
retention_count=count,
retention_policy=policy,
_ispytest=True,
)
@@ -133,23 +175,23 @@ class TempPathFactory:
# Also, to keep things private, fixup any world-readable temp
# rootdir's permissions. Historically 0o755 was used, so we can't
# just error out on this, at least for a while.
if sys.platform != "win32":
uid = os.getuid()
uid = get_user_id()
if uid is not None:
rootdir_stat = rootdir.stat()
# getuid shouldn't fail, but cpython defines such a case.
# Let's hope for the best.
if uid != -1:
if rootdir_stat.st_uid != uid:
raise OSError(
f"The temporary directory {rootdir} is not owned by the current user. "
"Fix this and try again."
)
if (rootdir_stat.st_mode & 0o077) != 0:
os.chmod(rootdir, rootdir_stat.st_mode & ~0o077)
if rootdir_stat.st_uid != uid:
raise OSError(
f"The temporary directory {rootdir} is not owned by the current user. "
"Fix this and try again."
)
if (rootdir_stat.st_mode & 0o077) != 0:
os.chmod(rootdir, rootdir_stat.st_mode & ~0o077)
keep = self._retention_count
if self._retention_policy == "none":
keep = 0
basetemp = make_numbered_dir_with_cleanup(
prefix="pytest-",
root=rootdir,
keep=3,
keep=keep,
lock_timeout=LOCK_TIMEOUT,
mode=0o700,
)
@@ -184,6 +226,21 @@ def pytest_configure(config: Config) -> None:
mp.setattr(config, "_tmp_path_factory", _tmp_path_factory, raising=False)
def pytest_addoption(parser: Parser) -> None:
parser.addini(
"tmp_path_retention_count",
help="How many sessions should we keep the `tmp_path` directories, according to `tmp_path_retention_policy`.",
default=3,
)
parser.addini(
"tmp_path_retention_policy",
help="Controls which directories created by the `tmp_path` fixture are kept around, based on test outcome. "
"(all/failed/none)",
default="all",
)
@fixture(scope="session")
def tmp_path_factory(request: FixtureRequest) -> TempPathFactory:
"""Return a :class:`pytest.TempPathFactory` instance for the test session."""
@@ -200,17 +257,68 @@ def _mk_tmp(request: FixtureRequest, factory: TempPathFactory) -> Path:
@fixture
def tmp_path(request: FixtureRequest, tmp_path_factory: TempPathFactory) -> Path:
def tmp_path(
request: FixtureRequest, tmp_path_factory: TempPathFactory
) -> Generator[Path, None, None]:
"""Return a temporary directory path object which is unique to each test
function invocation, created as a sub directory of the base temporary
directory.
By default, a new base temporary directory is created each test session,
and old bases are removed after 3 sessions, to aid in debugging. If
``--basetemp`` is used then it is cleared each session. See :ref:`base
and old bases are removed after 3 sessions, to aid in debugging.
This behavior can be configured with :confval:`tmp_path_retention_count` and
:confval:`tmp_path_retention_policy`.
If ``--basetemp`` is used then it is cleared each session. See :ref:`base
temporary directory`.
The returned object is a :class:`pathlib.Path` object.
"""
return _mk_tmp(request, tmp_path_factory)
path = _mk_tmp(request, tmp_path_factory)
yield path
# Remove the tmpdir if the policy is "failed" and the test passed.
tmp_path_factory: TempPathFactory = request.session.config._tmp_path_factory # type: ignore
policy = tmp_path_factory._retention_policy
result_dict = request.node.stash[tmppath_result_key]
if policy == "failed" and result_dict.get("call", True):
# We do a "best effort" to remove files, but it might not be possible due to some leaked resource,
# permissions, etc, in which case we ignore it.
rmtree(path, ignore_errors=True)
del request.node.stash[tmppath_result_key]
def pytest_sessionfinish(session, exitstatus: Union[int, ExitCode]):
"""After each session, remove base directory if all the tests passed,
the policy is "failed", and the basetemp is not specified by a user.
"""
tmp_path_factory: TempPathFactory = session.config._tmp_path_factory
basetemp = tmp_path_factory._basetemp
if basetemp is None:
return
policy = tmp_path_factory._retention_policy
if (
exitstatus == 0
and policy == "failed"
and tmp_path_factory._given_basetemp is None
):
if basetemp.is_dir():
# We do a "best effort" to remove files, but it might not be possible due to some leaked resource,
# permissions, etc, in which case we ignore it.
rmtree(basetemp, ignore_errors=True)
# Remove dead symlinks.
if basetemp.is_dir():
cleanup_dead_symlinks(basetemp)
@hookimpl(tryfirst=True, hookwrapper=True)
def pytest_runtest_makereport(item: Item, call):
outcome = yield
result: CollectReport = outcome.get_result()
empty: Dict[str, bool] = {}
item.stash.setdefault(tmppath_result_key, empty)[result.when] = result.passed

View File

@@ -298,6 +298,9 @@ class TestCaseFunction(Function):
def stopTest(self, testcase: "unittest.TestCase") -> None:
pass
def addDuration(self, testcase: "unittest.TestCase", elapsed: float) -> None:
pass
def runtest(self) -> None:
from _pytest.debugging import maybe_wrap_pytest_function_for_tracing
@@ -331,15 +334,16 @@ class TestCaseFunction(Function):
finally:
delattr(self._testcase, self.name)
def _prunetraceback(
def _traceback_filter(
self, excinfo: _pytest._code.ExceptionInfo[BaseException]
) -> None:
super()._prunetraceback(excinfo)
traceback = excinfo.traceback.filter(
lambda x: not x.frame.f_globals.get("__unittest")
) -> _pytest._code.Traceback:
traceback = super()._traceback_filter(excinfo)
ntraceback = traceback.filter(
lambda x: not x.frame.f_globals.get("__unittest"),
)
if traceback:
excinfo.traceback = traceback
if not ntraceback:
ntraceback = traceback
return ntraceback
@hookimpl(tryfirst=True)

View File

@@ -1,3 +1,4 @@
import dataclasses
import inspect
import warnings
from types import FunctionType
@@ -6,8 +7,6 @@ from typing import Generic
from typing import Type
from typing import TypeVar
import attr
from _pytest.compat import final
@@ -57,6 +56,12 @@ class PytestRemovedIn8Warning(PytestDeprecationWarning):
__module__ = "pytest"
class PytestRemovedIn9Warning(PytestDeprecationWarning):
"""Warning class for features that will be removed in pytest 9."""
__module__ = "pytest"
class PytestReturnNotNoneWarning(PytestRemovedIn8Warning):
"""Warning emitted when a test function is returning value other than None."""
@@ -130,7 +135,7 @@ _W = TypeVar("_W", bound=PytestWarning)
@final
@attr.s(auto_attribs=True)
@dataclasses.dataclass
class UnformattedWarning(Generic[_W]):
"""A warning meant to be formatted during runtime.
@@ -150,7 +155,7 @@ def warn_explicit_for(method: FunctionType, message: PytestWarning) -> None:
"""
Issue the warning :param:`message` for the definition of the given :param:`method`
this helps to log warnigns for functions defined prior to finding an issue with them
this helps to log warnings for functions defined prior to finding an issue with them
(like hook wrappers being marked in a legacy mechanism)
"""
lineno = method.__code__.co_firstlineno

View File

@@ -49,6 +49,8 @@ def catch_warnings_for_item(
warnings.filterwarnings("always", category=DeprecationWarning)
warnings.filterwarnings("always", category=PendingDeprecationWarning)
warnings.filterwarnings("error", category=pytest.PytestRemovedIn8Warning)
apply_warning_filters(config_filters, cmdline_filters)
# apply filters from "filterwarnings" marks

10
src/py.py Normal file
View File

@@ -0,0 +1,10 @@
# shim for pylib going away
# if pylib is installed this file will get skipped
# (`py/__init__.py` has higher precedence)
import sys
import _pytest._py.error as error
import _pytest._py.path as path
sys.modules["py.error"] = error
sys.modules["py.path"] = path

View File

@@ -62,6 +62,7 @@ from _pytest.reports import TestReport
from _pytest.runner import CallInfo
from _pytest.stash import Stash
from _pytest.stash import StashKey
from _pytest.terminal import TestShortLogReport
from _pytest.tmpdir import TempPathFactory
from _pytest.warning_types import PytestAssertRewriteWarning
from _pytest.warning_types import PytestCacheWarning
@@ -70,6 +71,7 @@ from _pytest.warning_types import PytestConfigWarning
from _pytest.warning_types import PytestDeprecationWarning
from _pytest.warning_types import PytestExperimentalApiWarning
from _pytest.warning_types import PytestRemovedIn8Warning
from _pytest.warning_types import PytestRemovedIn9Warning
from _pytest.warning_types import PytestReturnNotNoneWarning
from _pytest.warning_types import PytestUnhandledCoroutineWarning
from _pytest.warning_types import PytestUnhandledThreadExceptionWarning
@@ -130,6 +132,7 @@ __all__ = [
"PytestDeprecationWarning",
"PytestExperimentalApiWarning",
"PytestRemovedIn8Warning",
"PytestRemovedIn9Warning",
"PytestReturnNotNoneWarning",
"Pytester",
"PytestPluginManager",
@@ -152,6 +155,7 @@ __all__ = [
"TempPathFactory",
"Testdir",
"TestReport",
"TestShortLogReport",
"UsageError",
"WarningsRecorder",
"warns",