Merge branch 'main' into patch-1

This commit is contained in:
Zac Hatfield-Dodds
2023-07-04 10:17:06 -07:00
committed by GitHub
65 changed files with 659 additions and 570 deletions

View File

@@ -17,18 +17,21 @@ from typing import Any
from typing import Callable
from typing import ClassVar
from typing import Dict
from typing import Final
from typing import final
from typing import Generic
from typing import Iterable
from typing import List
from typing import Literal
from typing import Mapping
from typing import Optional
from typing import overload
from typing import Pattern
from typing import Sequence
from typing import Set
from typing import SupportsIndex
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
@@ -42,22 +45,16 @@ from _pytest._code.source import Source
from _pytest._io import TerminalWriter
from _pytest._io.saferepr import safeformat
from _pytest._io.saferepr import saferepr
from _pytest.compat import final
from _pytest.compat import get_real_func
from _pytest.deprecated import check_ispytest
from _pytest.pathlib import absolutepath
from _pytest.pathlib import bestrelpath
if TYPE_CHECKING:
from typing_extensions import Final
from typing_extensions import Literal
from typing_extensions import SupportsIndex
_TracebackStyle = Literal["long", "short", "line", "no", "native", "value", "auto"]
if sys.version_info[:2] < (3, 11):
from exceptiongroup import BaseExceptionGroup
_TracebackStyle = Literal["long", "short", "line", "no", "native", "value", "auto"]
class Code:
"""Wrapper around Python code objects."""
@@ -633,7 +630,7 @@ class ExceptionInfo(Generic[E]):
def getrepr(
self,
showlocals: bool = False,
style: "_TracebackStyle" = "long",
style: _TracebackStyle = "long",
abspath: bool = False,
tbfilter: Union[
bool, Callable[["ExceptionInfo[BaseException]"], Traceback]
@@ -725,7 +722,7 @@ class FormattedExcinfo:
fail_marker: ClassVar = "E"
showlocals: bool = False
style: "_TracebackStyle" = "long"
style: _TracebackStyle = "long"
abspath: bool = True
tbfilter: Union[bool, Callable[[ExceptionInfo[BaseException]], Traceback]] = True
funcargs: bool = False
@@ -1090,7 +1087,7 @@ class ReprExceptionInfo(ExceptionRepr):
class ReprTraceback(TerminalRepr):
reprentries: Sequence[Union["ReprEntry", "ReprEntryNative"]]
extraline: Optional[str]
style: "_TracebackStyle"
style: _TracebackStyle
entrysep: ClassVar = "_ "
@@ -1124,7 +1121,7 @@ class ReprTracebackNative(ReprTraceback):
class ReprEntryNative(TerminalRepr):
lines: Sequence[str]
style: ClassVar["_TracebackStyle"] = "native"
style: ClassVar[_TracebackStyle] = "native"
def toterminal(self, tw: TerminalWriter) -> None:
tw.write("".join(self.lines))
@@ -1136,7 +1133,7 @@ class ReprEntry(TerminalRepr):
reprfuncargs: Optional["ReprFuncArgs"]
reprlocals: Optional["ReprLocals"]
reprfileloc: Optional["ReprFileLocation"]
style: "_TracebackStyle"
style: _TracebackStyle
def _write_entry_lines(self, tw: TerminalWriter) -> None:
"""Write the source code portions of a list of traceback entries with syntax highlighting.

View File

@@ -149,8 +149,7 @@ def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[i
values: List[int] = []
for x in ast.walk(node):
if isinstance(x, (ast.stmt, ast.ExceptHandler)):
# Before Python 3.8, the lineno of a decorated class or function pointed at the decorator.
# Since Python 3.8, the lineno points to the class/def, so need to include the decorators.
# The lineno points to the class/def, so need to include the decorators.
if isinstance(x, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
for d in x.decorator_list:
values.append(d.lineno - 1)

View File

@@ -2,12 +2,12 @@
import os
import shutil
import sys
from typing import final
from typing import Optional
from typing import Sequence
from typing import TextIO
from .wcwidth import wcswidth
from _pytest.compat import final
# This code was initially copied from py 1.8.1, file _io/terminalwriter.py.

View File

@@ -44,17 +44,6 @@ from _pytest.stash import StashKey
if TYPE_CHECKING:
from _pytest.assertion import AssertionState
if sys.version_info >= (3, 8):
namedExpr = ast.NamedExpr
astNameConstant = ast.Constant
astStr = ast.Constant
astNum = ast.Constant
else:
namedExpr = ast.Expr
astNameConstant = ast.NameConstant
astStr = ast.Str
astNum = ast.Num
assertstate_key = StashKey["AssertionState"]()
@@ -686,14 +675,10 @@ class AssertionRewriter(ast.NodeVisitor):
if (
expect_docstring
and isinstance(item, ast.Expr)
and isinstance(item.value, astStr)
and isinstance(item.value, ast.Constant)
and isinstance(item.value.value, str)
):
if sys.version_info >= (3, 8):
doc = item.value.value
else:
doc = item.value.s
if not isinstance(doc, str):
return
doc = item.value.value
if self.is_rewrite_disabled(doc):
return
expect_docstring = False
@@ -825,7 +810,7 @@ class AssertionRewriter(ast.NodeVisitor):
current = self.stack.pop()
if self.stack:
self.explanation_specifiers = self.stack[-1]
keys = [astStr(key) for key in current.keys()]
keys = [ast.Constant(key) for key in current.keys()]
format_dict = ast.Dict(keys, list(current.values()))
form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
name = "@py_format" + str(next(self.variable_counter))
@@ -879,16 +864,16 @@ class AssertionRewriter(ast.NodeVisitor):
negation = ast.UnaryOp(ast.Not(), top_condition)
if self.enable_assertion_pass_hook: # Experimental pytest_assertion_pass hook
msg = self.pop_format_context(astStr(explanation))
msg = self.pop_format_context(ast.Constant(explanation))
# Failed
if assert_.msg:
assertmsg = self.helper("_format_assertmsg", assert_.msg)
gluestr = "\n>assert "
else:
assertmsg = astStr("")
assertmsg = ast.Constant("")
gluestr = "assert "
err_explanation = ast.BinOp(astStr(gluestr), ast.Add(), msg)
err_explanation = ast.BinOp(ast.Constant(gluestr), ast.Add(), msg)
err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
err_name = ast.Name("AssertionError", ast.Load())
fmt = self.helper("_format_explanation", err_msg)
@@ -904,8 +889,8 @@ class AssertionRewriter(ast.NodeVisitor):
hook_call_pass = ast.Expr(
self.helper(
"_call_assertion_pass",
astNum(assert_.lineno),
astStr(orig),
ast.Constant(assert_.lineno),
ast.Constant(orig),
fmt_pass,
)
)
@@ -924,7 +909,7 @@ class AssertionRewriter(ast.NodeVisitor):
variables = [
ast.Name(name, ast.Store()) for name in self.format_variables
]
clear_format = ast.Assign(variables, astNameConstant(None))
clear_format = ast.Assign(variables, ast.Constant(None))
self.statements.append(clear_format)
else: # Original assertion rewriting
@@ -935,9 +920,9 @@ class AssertionRewriter(ast.NodeVisitor):
assertmsg = self.helper("_format_assertmsg", assert_.msg)
explanation = "\n>assert " + explanation
else:
assertmsg = astStr("")
assertmsg = ast.Constant("")
explanation = "assert " + explanation
template = ast.BinOp(assertmsg, ast.Add(), astStr(explanation))
template = ast.BinOp(assertmsg, ast.Add(), ast.Constant(explanation))
msg = self.pop_format_context(template)
fmt = self.helper("_format_explanation", msg)
err_name = ast.Name("AssertionError", ast.Load())
@@ -949,7 +934,7 @@ class AssertionRewriter(ast.NodeVisitor):
# Clear temporary variables by setting them to None.
if self.variables:
variables = [ast.Name(name, ast.Store()) for name in self.variables]
clear = ast.Assign(variables, astNameConstant(None))
clear = ast.Assign(variables, ast.Constant(None))
self.statements.append(clear)
# Fix locations (line numbers/column offsets).
for stmt in self.statements:
@@ -957,26 +942,26 @@ class AssertionRewriter(ast.NodeVisitor):
ast.copy_location(node, assert_)
return self.statements
def visit_NamedExpr(self, name: namedExpr) -> Tuple[namedExpr, str]:
def visit_NamedExpr(self, name: ast.NamedExpr) -> Tuple[ast.NamedExpr, str]:
# This method handles the 'walrus operator' repr of the target
# name if it's a local variable or _should_repr_global_name()
# thinks it's acceptable.
locs = ast.Call(self.builtin("locals"), [], [])
target_id = name.target.id # type: ignore[attr-defined]
inlocs = ast.Compare(astStr(target_id), [ast.In()], [locs])
inlocs = ast.Compare(ast.Constant(target_id), [ast.In()], [locs])
dorepr = self.helper("_should_repr_global_name", name)
test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
expr = ast.IfExp(test, self.display(name), astStr(target_id))
expr = ast.IfExp(test, self.display(name), ast.Constant(target_id))
return name, self.explanation_param(expr)
def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
# Display the repr of the name if it's a local variable or
# _should_repr_global_name() thinks it's acceptable.
locs = ast.Call(self.builtin("locals"), [], [])
inlocs = ast.Compare(astStr(name.id), [ast.In()], [locs])
inlocs = ast.Compare(ast.Constant(name.id), [ast.In()], [locs])
dorepr = self.helper("_should_repr_global_name", name)
test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
expr = ast.IfExp(test, self.display(name), astStr(name.id))
expr = ast.IfExp(test, self.display(name), ast.Constant(name.id))
return name, self.explanation_param(expr)
def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
@@ -995,10 +980,10 @@ class AssertionRewriter(ast.NodeVisitor):
# cond is set in a prior loop iteration below
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
self.expl_stmts = fail_inner
# Check if the left operand is a namedExpr and the value has already been visited
# Check if the left operand is a ast.NamedExpr and the value has already been visited
if (
isinstance(v, ast.Compare)
and isinstance(v.left, namedExpr)
and isinstance(v.left, ast.NamedExpr)
and v.left.target.id
in [
ast_expr.id
@@ -1014,7 +999,7 @@ class AssertionRewriter(ast.NodeVisitor):
self.push_format_context()
res, expl = self.visit(v)
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
expl_format = self.pop_format_context(astStr(expl))
expl_format = self.pop_format_context(ast.Constant(expl))
call = ast.Call(app, [expl_format], [])
self.expl_stmts.append(ast.Expr(call))
if i < levels:
@@ -1026,7 +1011,7 @@ class AssertionRewriter(ast.NodeVisitor):
self.statements = body = inner
self.statements = save
self.expl_stmts = fail_save
expl_template = self.helper("_format_boolop", expl_list, astNum(is_or))
expl_template = self.helper("_format_boolop", expl_list, ast.Constant(is_or))
expl = self.pop_format_context(expl_template)
return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
@@ -1100,7 +1085,7 @@ class AssertionRewriter(ast.NodeVisitor):
comp.left = self.variables_overwrite[
comp.left.id
] # type:ignore[assignment]
if isinstance(comp.left, namedExpr):
if isinstance(comp.left, ast.NamedExpr):
self.variables_overwrite[
comp.left.target.id
] = comp.left # type:ignore[assignment]
@@ -1116,7 +1101,7 @@ class AssertionRewriter(ast.NodeVisitor):
results = [left_res]
for i, op, next_operand in it:
if (
isinstance(next_operand, namedExpr)
isinstance(next_operand, ast.NamedExpr)
and isinstance(left_res, ast.Name)
and next_operand.target.id == left_res.id
):
@@ -1129,9 +1114,9 @@ class AssertionRewriter(ast.NodeVisitor):
next_expl = f"({next_expl})"
results.append(next_res)
sym = BINOP_MAP[op.__class__]
syms.append(astStr(sym))
syms.append(ast.Constant(sym))
expl = f"{left_expl} {sym} {next_expl}"
expls.append(astStr(expl))
expls.append(ast.Constant(expl))
res_expr = ast.Compare(left_res, [op], [next_res])
self.statements.append(ast.Assign([store_names[i]], res_expr))
left_res, left_expl = next_res, next_expl
@@ -1175,7 +1160,7 @@ def try_makedirs(cache_dir: Path) -> bool:
def get_cache_dir(file_path: Path) -> Path:
"""Return the cache directory to write .pyc files for the given .py file path."""
if sys.version_info >= (3, 8) and sys.pycache_prefix:
if sys.pycache_prefix:
# given:
# prefix = '/tmp/pycs'
# path = '/home/user/proj/test_app.py'

View File

@@ -6,6 +6,7 @@ import json
import os
from pathlib import Path
from typing import Dict
from typing import final
from typing import Generator
from typing import Iterable
from typing import List
@@ -18,7 +19,6 @@ from .pathlib import rm_rf
from .reports import CollectReport
from _pytest import nodes
from _pytest._io import TerminalWriter
from _pytest.compat import final
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config import hookimpl

View File

@@ -11,11 +11,14 @@ from types import TracebackType
from typing import Any
from typing import AnyStr
from typing import BinaryIO
from typing import Final
from typing import final
from typing import Generator
from typing import Generic
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Literal
from typing import NamedTuple
from typing import Optional
from typing import TextIO
@@ -24,7 +27,6 @@ from typing import Type
from typing import TYPE_CHECKING
from typing import Union
from _pytest.compat import final
from _pytest.config import Config
from _pytest.config import hookimpl
from _pytest.config.argparsing import Parser
@@ -35,11 +37,7 @@ from _pytest.nodes import Collector
from _pytest.nodes import File
from _pytest.nodes import Item
if TYPE_CHECKING:
from typing_extensions import Final
from typing_extensions import Literal
_CaptureMethod = Literal["fd", "sys", "no", "tee-sys"]
_CaptureMethod = Literal["fd", "sys", "no", "tee-sys"]
def pytest_addoption(parser: Parser) -> None:
@@ -687,7 +685,7 @@ class MultiCapture(Generic[AnyStr]):
return CaptureResult(out, err) # type: ignore[arg-type]
def _get_multicapture(method: "_CaptureMethod") -> MultiCapture[str]:
def _get_multicapture(method: _CaptureMethod) -> MultiCapture[str]:
if method == "fd":
return MultiCapture(in_=FDCapture(0), out=FDCapture(1), err=FDCapture(2))
elif method == "sys":
@@ -723,7 +721,7 @@ class CaptureManager:
needed to ensure the fixtures take precedence over the global capture.
"""
def __init__(self, method: "_CaptureMethod") -> None:
def __init__(self, method: _CaptureMethod) -> None:
self._method: Final = method
self._global_capturing: Optional[MultiCapture[str]] = None
self._capture_fixture: Optional[CaptureFixture[Any]] = None

View File

@@ -12,26 +12,12 @@ from inspect import signature
from pathlib import Path
from typing import Any
from typing import Callable
from typing import Generic
from typing import Final
from typing import NoReturn
from typing import TYPE_CHECKING
from typing import TypeVar
import py
# fmt: off
# Workaround for https://github.com/sphinx-doc/sphinx/issues/10351.
# If `overload` is imported from `compat` instead of from `typing`,
# Sphinx doesn't recognize it as `overload` and the API docs for
# overloaded functions look good again. But type checkers handle
# it fine.
# fmt: on
if True:
from typing import overload as overload
if TYPE_CHECKING:
from typing_extensions import Final
_T = TypeVar("_T")
_S = TypeVar("_S")
@@ -58,17 +44,6 @@ class NotSetType(enum.Enum):
NOTSET: Final = NotSetType.token # noqa: E305
# fmt: on
if sys.version_info >= (3, 8):
import importlib.metadata
importlib_metadata = importlib.metadata
else:
import importlib_metadata as importlib_metadata # noqa: F401
def _format_args(func: Callable[..., Any]) -> str:
return str(signature(func))
def is_generator(func: object) -> bool:
genfunc = inspect.isgeneratorfunction(func)
@@ -338,47 +313,6 @@ def safe_isclass(obj: object) -> bool:
return False
if TYPE_CHECKING:
if sys.version_info >= (3, 8):
from typing import final as final
else:
from typing_extensions import final as final
elif sys.version_info >= (3, 8):
from typing import final as final
else:
def final(f):
return f
if sys.version_info >= (3, 8):
from functools import cached_property as cached_property
else:
class cached_property(Generic[_S, _T]):
__slots__ = ("func", "__doc__")
def __init__(self, func: Callable[[_S], _T]) -> None:
self.func = func
self.__doc__ = func.__doc__
@overload
def __get__(
self, instance: None, owner: type[_S] | None = ...
) -> cached_property[_S, _T]:
...
@overload
def __get__(self, instance: _S, owner: type[_S] | None = ...) -> _T:
...
def __get__(self, instance, owner=None):
if instance is None:
return self
value = instance.__dict__[self.func.__name__] = self.func(instance)
return value
def get_user_id() -> int | None:
"""Return the current user id, or None if we cannot get it reliably on the current platform."""
# win32 does not have a getuid() function.

View File

@@ -5,6 +5,7 @@ import copy
import dataclasses
import enum
import glob
import importlib.metadata
import inspect
import os
import re
@@ -21,6 +22,7 @@ from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
from typing import final
from typing import Generator
from typing import IO
from typing import Iterable
@@ -48,8 +50,6 @@ from .findpaths import determine_setup
from _pytest._code import ExceptionInfo
from _pytest._code import filter_traceback
from _pytest._io import TerminalWriter
from _pytest.compat import final
from _pytest.compat import importlib_metadata # type: ignore[attr-defined]
from _pytest.outcomes import fail
from _pytest.outcomes import Skipped
from _pytest.pathlib import absolutepath
@@ -257,7 +257,8 @@ default_plugins = essential_plugins + (
"logging",
"reports",
"python_path",
*(["unraisableexception", "threadexception"] if sys.version_info >= (3, 8) else []),
"unraisableexception",
"threadexception",
"faulthandler",
)
@@ -1216,7 +1217,7 @@ class Config:
package_files = (
str(file)
for dist in importlib_metadata.distributions()
for dist in importlib.metadata.distributions()
if any(ep.group == "pytest11" for ep in dist.entry_points)
for file in dist.files or []
)

View File

@@ -7,6 +7,7 @@ from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
from typing import final
from typing import List
from typing import Mapping
from typing import NoReturn
@@ -17,7 +18,6 @@ from typing import TYPE_CHECKING
from typing import Union
import _pytest._io
from _pytest.compat import final
from _pytest.config.exceptions import UsageError
from _pytest.deprecated import ARGUMENT_PERCENT_DEFAULT
from _pytest.deprecated import ARGUMENT_TYPE_STR

View File

@@ -1,4 +1,4 @@
from _pytest.compat import final
from typing import final
@final

View File

@@ -13,6 +13,7 @@ from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
from typing import final
from typing import Generator
from typing import Generic
from typing import Iterable
@@ -21,6 +22,7 @@ from typing import List
from typing import MutableMapping
from typing import NoReturn
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Set
from typing import Tuple
@@ -35,10 +37,8 @@ from _pytest._code import getfslineno
from _pytest._code.code import FormattedExcinfo
from _pytest._code.code import TerminalRepr
from _pytest._io import TerminalWriter
from _pytest.compat import _format_args
from _pytest.compat import _PytestWrapper
from _pytest.compat import assert_never
from _pytest.compat import final
from _pytest.compat import get_real_func
from _pytest.compat import get_real_method
from _pytest.compat import getfuncargnames
@@ -47,7 +47,6 @@ from _pytest.compat import getlocation
from _pytest.compat import is_generator
from _pytest.compat import NOTSET
from _pytest.compat import NotSetType
from _pytest.compat import overload
from _pytest.compat import safe_getattr
from _pytest.config import _PluggyPlugin
from _pytest.config import Config
@@ -729,8 +728,10 @@ class FixtureRequest:
p = bestrelpath(session.path, fs)
else:
p = fs
args = _format_args(factory)
lines.append("%s:%d: def %s%s" % (p, lineno + 1, factory.__name__, args))
lines.append(
"%s:%d: def %s%s"
% (p, lineno + 1, factory.__name__, inspect.signature(factory))
)
return lines
def __repr__(self) -> str:

View File

@@ -3,6 +3,8 @@ import dataclasses
import shlex
import subprocess
from pathlib import Path
from typing import Final
from typing import final
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
@@ -11,7 +13,6 @@ from typing import Union
from iniconfig import SectionWrapper
from _pytest.cacheprovider import Cache
from _pytest.compat import final
from _pytest.compat import LEGACY_PATH
from _pytest.compat import legacy_path
from _pytest.config import Config
@@ -32,8 +33,6 @@ from _pytest.terminal import TerminalReporter
from _pytest.tmpdir import TempPathFactory
if TYPE_CHECKING:
from typing_extensions import Final
import pexpect

View File

@@ -13,6 +13,7 @@ from logging import LogRecord
from pathlib import Path
from typing import AbstractSet
from typing import Dict
from typing import final
from typing import Generator
from typing import List
from typing import Mapping
@@ -25,7 +26,6 @@ from typing import Union
from _pytest import nodes
from _pytest._io import TerminalWriter
from _pytest.capture import CaptureManager
from _pytest.compat import final
from _pytest.config import _strtobool
from _pytest.config import Config
from _pytest.config import create_terminal_writer

View File

@@ -9,10 +9,12 @@ import sys
from pathlib import Path
from typing import Callable
from typing import Dict
from typing import final
from typing import FrozenSet
from typing import Iterator
from typing import List
from typing import Optional
from typing import overload
from typing import Sequence
from typing import Set
from typing import Tuple
@@ -22,8 +24,6 @@ from typing import Union
import _pytest._code
from _pytest import nodes
from _pytest.compat import final
from _pytest.compat import overload
from _pytest.config import Config
from _pytest.config import directory_arg
from _pytest.config import ExitCode

View File

@@ -18,7 +18,6 @@ import ast
import dataclasses
import enum
import re
import sys
import types
from typing import Callable
from typing import Iterator
@@ -27,12 +26,6 @@ from typing import NoReturn
from typing import Optional
from typing import Sequence
if sys.version_info >= (3, 8):
astNameConstant = ast.Constant
else:
astNameConstant = ast.NameConstant
__all__ = [
"Expression",
"ParseError",
@@ -138,7 +131,7 @@ IDENT_PREFIX = "$"
def expression(s: Scanner) -> ast.Expression:
if s.accept(TokenType.EOF):
ret: ast.expr = astNameConstant(False)
ret: ast.expr = ast.Constant(False)
else:
ret = expr(s)
s.accept(TokenType.EOF, reject=True)

View File

@@ -5,6 +5,7 @@ import warnings
from typing import Any
from typing import Callable
from typing import Collection
from typing import final
from typing import Iterable
from typing import Iterator
from typing import List
@@ -23,7 +24,6 @@ from typing import Union
from .._code import getfslineno
from ..compat import ascii_escaped
from ..compat import final
from ..compat import NOTSET
from ..compat import NotSetType
from _pytest.config import Config
@@ -374,7 +374,9 @@ def get_unpacked_marks(
if not consider_mro:
mark_lists = [obj.__dict__.get("pytestmark", [])]
else:
mark_lists = [x.__dict__.get("pytestmark", []) for x in obj.__mro__]
mark_lists = [
x.__dict__.get("pytestmark", []) for x in reversed(obj.__mro__)
]
mark_list = []
for item in mark_lists:
if isinstance(item, list):

View File

@@ -5,6 +5,7 @@ import sys
import warnings
from contextlib import contextmanager
from typing import Any
from typing import final
from typing import Generator
from typing import List
from typing import Mapping
@@ -15,7 +16,6 @@ from typing import Tuple
from typing import TypeVar
from typing import Union
from _pytest.compat import final
from _pytest.fixtures import fixture
from _pytest.warning_types import PytestWarning

View File

@@ -1,5 +1,6 @@
import os
import warnings
from functools import cached_property
from inspect import signature
from pathlib import Path
from typing import Any
@@ -23,7 +24,6 @@ from _pytest._code import getfslineno
from _pytest._code.code import ExceptionInfo
from _pytest._code.code import TerminalRepr
from _pytest._code.code import Traceback
from _pytest.compat import cached_property
from _pytest.compat import LEGACY_PATH
from _pytest.config import Config
from _pytest.config import ConftestImportFailure

View File

@@ -7,23 +7,12 @@ from typing import Callable
from typing import cast
from typing import NoReturn
from typing import Optional
from typing import Protocol
from typing import Type
from typing import TypeVar
from _pytest.deprecated import KEYWORD_MSG_ARG
TYPE_CHECKING = False # Avoid circular import through compat.
if TYPE_CHECKING:
from typing_extensions import Protocol
else:
# typing.Protocol is only available starting from Python 3.8. It is also
# available from typing_extensions, but we don't want a runtime dependency
# on that. So use a dummy runtime implementation.
from typing import Generic
Protocol = Generic
class OutcomeException(BaseException):
"""OutcomeException and its subclass instances indicate and contain info

View File

@@ -523,6 +523,8 @@ def import_path(
if mode is ImportMode.importlib:
module_name = module_name_from_path(path, root)
with contextlib.suppress(KeyError):
return sys.modules[module_name]
for meta_importer in sys.meta_path:
spec = meta_importer.find_spec(module_name, [str(path.parent)])
@@ -633,6 +635,9 @@ def insert_missing_modules(modules: Dict[str, ModuleType], module_name: str) ->
otherwise "src.tests.test_foo" is not importable by ``__import__``.
"""
module_parts = module_name.split(".")
child_module: Union[ModuleType, None] = None
module: Union[ModuleType, None] = None
child_name: str = ""
while module_name:
if module_name not in modules:
try:
@@ -642,13 +647,22 @@ def insert_missing_modules(modules: Dict[str, ModuleType], module_name: str) ->
# ourselves to fall back to creating a dummy module.
if not sys.meta_path:
raise ModuleNotFoundError
importlib.import_module(module_name)
module = importlib.import_module(module_name)
except ModuleNotFoundError:
module = ModuleType(
module_name,
doc="Empty module created by pytest's importmode=importlib.",
)
else:
module = modules[module_name]
if child_module:
# Add child attribute to the parent that can reference the child
# modules.
if not hasattr(module, child_name):
setattr(module, child_name, child_module)
modules[module_name] = module
# Keep track of the child module while moving up the tree.
child_module, child_name = module, module_name.rpartition(".")[-1]
module_parts.pop(-1)
module_name = ".".join(module_parts)

View File

@@ -20,10 +20,13 @@ from pathlib import Path
from typing import Any
from typing import Callable
from typing import Dict
from typing import Final
from typing import final
from typing import Generator
from typing import IO
from typing import Iterable
from typing import List
from typing import Literal
from typing import Optional
from typing import overload
from typing import Sequence
@@ -40,7 +43,6 @@ from iniconfig import SectionWrapper
from _pytest import timing
from _pytest._code import Source
from _pytest.capture import _get_multicapture
from _pytest.compat import final
from _pytest.compat import NOTSET
from _pytest.compat import NotSetType
from _pytest.config import _PluggyPlugin
@@ -68,11 +70,7 @@ from _pytest.reports import TestReport
from _pytest.tmpdir import TempPathFactory
from _pytest.warning_types import PytestWarning
if TYPE_CHECKING:
from typing_extensions import Final
from typing_extensions import Literal
import pexpect

View File

@@ -15,6 +15,7 @@ from pathlib import Path
from typing import Any
from typing import Callable
from typing import Dict
from typing import final
from typing import Generator
from typing import Iterable
from typing import Iterator
@@ -40,7 +41,6 @@ from _pytest._io import TerminalWriter
from _pytest._io.saferepr import saferepr
from _pytest.compat import ascii_escaped
from _pytest.compat import assert_never
from _pytest.compat import final
from _pytest.compat import get_default_arg_names
from _pytest.compat import get_real_func
from _pytest.compat import getimfunc

View File

@@ -9,9 +9,11 @@ from typing import Any
from typing import Callable
from typing import cast
from typing import ContextManager
from typing import final
from typing import List
from typing import Mapping
from typing import Optional
from typing import overload
from typing import Pattern
from typing import Sequence
from typing import Tuple
@@ -20,17 +22,14 @@ from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
import _pytest._code
from _pytest.compat import STRING_TYPES
from _pytest.outcomes import fail
if TYPE_CHECKING:
from numpy import ndarray
import _pytest._code
from _pytest.compat import final
from _pytest.compat import STRING_TYPES
from _pytest.compat import overload
from _pytest.outcomes import fail
def _non_numeric_type_error(value, at: Optional[str]) -> TypeError:
at_str = f" at {at}" if at else ""
return TypeError(

View File

@@ -5,18 +5,18 @@ from pprint import pformat
from types import TracebackType
from typing import Any
from typing import Callable
from typing import final
from typing import Generator
from typing import Iterator
from typing import List
from typing import Optional
from typing import overload
from typing import Pattern
from typing import Tuple
from typing import Type
from typing import TypeVar
from typing import Union
from _pytest.compat import final
from _pytest.compat import overload
from _pytest.deprecated import check_ispytest
from _pytest.deprecated import WARNS_NONE_ARG
from _pytest.fixtures import fixture
@@ -117,10 +117,10 @@ def warns( # noqa: F811
warning of that class or classes.
This helper produces a list of :class:`warnings.WarningMessage` objects, one for
each warning raised (regardless of whether it is an ``expected_warning`` or not).
each warning emitted (regardless of whether it is an ``expected_warning`` or not).
Since pytest 8.0, unmatched warnings are also re-emitted when the context closes.
This function can be used as a context manager, which will capture all the raised
warnings inside it::
This function can be used as a context manager::
>>> import pytest
>>> with pytest.warns(RuntimeWarning):
@@ -135,8 +135,9 @@ def warns( # noqa: F811
>>> with pytest.warns(UserWarning, match=r'must be \d+$'):
... warnings.warn("value must be 42", UserWarning)
>>> with pytest.warns(UserWarning, match=r'must be \d+$'):
... warnings.warn("this is not here", UserWarning)
>>> with pytest.warns(UserWarning): # catch re-emitted warning
... with pytest.warns(UserWarning, match=r'must be \d+$'):
... warnings.warn("this is not here", UserWarning)
Traceback (most recent call last):
...
Failed: DID NOT WARN. No warnings of type ...UserWarning... were emitted...
@@ -277,6 +278,12 @@ class WarningsChecker(WarningsRecorder):
self.expected_warning = expected_warning_tup
self.match_expr = match_expr
def matches(self, warning: warnings.WarningMessage) -> bool:
assert self.expected_warning is not None
return issubclass(warning.category, self.expected_warning) and bool(
self.match_expr is None or re.search(self.match_expr, str(warning.message))
)
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
@@ -287,27 +294,34 @@ class WarningsChecker(WarningsRecorder):
__tracebackhide__ = True
if self.expected_warning is None:
# nothing to do in this deprecated case, see WARNS_NONE_ARG above
return
def found_str():
return pformat([record.message for record in self], indent=2)
# only check if we're not currently handling an exception
if exc_type is None and exc_val is None and exc_tb is None:
if self.expected_warning is not None:
if not any(issubclass(r.category, self.expected_warning) for r in self):
__tracebackhide__ = True
fail(
f"DID NOT WARN. No warnings of type {self.expected_warning} were emitted.\n"
f"The list of emitted warnings is: {found_str()}."
try:
if not any(issubclass(w.category, self.expected_warning) for w in self):
fail(
f"DID NOT WARN. No warnings of type {self.expected_warning} were emitted.\n"
f" Emitted warnings: {found_str()}."
)
elif not any(self.matches(w) for w in self):
fail(
f"DID NOT WARN. No warnings of type {self.expected_warning} matching the regex were emitted.\n"
f" Regex: {self.match_expr}\n"
f" Emitted warnings: {found_str()}."
)
finally:
# Whether or not any warnings matched, we want to re-emit all unmatched warnings.
for w in self:
if not self.matches(w):
warnings.warn_explicit(
str(w.message),
w.message.__class__, # type: ignore[arg-type]
w.filename,
w.lineno,
module=w.__module__,
source=w.source,
)
elif self.match_expr is not None:
for r in self:
if issubclass(r.category, self.expected_warning):
if re.compile(self.match_expr).search(str(r.message)):
break
else:
fail(
f"""\
DID NOT WARN. No warnings of type {self.expected_warning} matching the regex were emitted.
Regex: {self.match_expr}
Emitted warnings: {found_str()}"""
)

View File

@@ -5,6 +5,7 @@ from pprint import pprint
from typing import Any
from typing import cast
from typing import Dict
from typing import final
from typing import Iterable
from typing import Iterator
from typing import List
@@ -29,7 +30,6 @@ from _pytest._code.code import ReprLocals
from _pytest._code.code import ReprTraceback
from _pytest._code.code import TerminalRepr
from _pytest._io import TerminalWriter
from _pytest.compat import final
from _pytest.config import Config
from _pytest.nodes import Collector
from _pytest.nodes import Item

View File

@@ -6,6 +6,7 @@ import sys
from typing import Callable
from typing import cast
from typing import Dict
from typing import final
from typing import Generic
from typing import List
from typing import Optional
@@ -23,7 +24,6 @@ from _pytest import timing
from _pytest._code.code import ExceptionChainRepr
from _pytest._code.code import ExceptionInfo
from _pytest._code.code import TerminalRepr
from _pytest.compat import final
from _pytest.config.argparsing import Parser
from _pytest.deprecated import check_ispytest
from _pytest.nodes import Collector

View File

@@ -18,6 +18,7 @@ from typing import Callable
from typing import cast
from typing import ClassVar
from typing import Dict
from typing import final
from typing import Generator
from typing import List
from typing import Mapping
@@ -40,7 +41,6 @@ from _pytest._code.code import ExceptionRepr
from _pytest._io import TerminalWriter
from _pytest._io.wcwidth import wcswidth
from _pytest.assertion.util import running_on_ci
from _pytest.compat import final
from _pytest.config import _PluggyPlugin
from _pytest.config import Config
from _pytest.config import ExitCode

View File

@@ -7,38 +7,32 @@ from pathlib import Path
from shutil import rmtree
from typing import Any
from typing import Dict
from typing import final
from typing import Generator
from typing import Literal
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from _pytest.nodes import Item
from _pytest.reports import CollectReport
from _pytest.stash import StashKey
if TYPE_CHECKING:
from typing_extensions import Literal
RetentionType = Literal["all", "failed", "none"]
from _pytest.config.argparsing import Parser
from .pathlib import cleanup_dead_symlinks
from .pathlib import LOCK_TIMEOUT
from .pathlib import make_numbered_dir
from .pathlib import make_numbered_dir_with_cleanup
from .pathlib import rm_rf
from .pathlib import cleanup_dead_symlinks
from _pytest.compat import final, get_user_id
from _pytest.compat import get_user_id
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config import hookimpl
from _pytest.config.argparsing import Parser
from _pytest.deprecated import check_ispytest
from _pytest.fixtures import fixture
from _pytest.fixtures import FixtureRequest
from _pytest.monkeypatch import MonkeyPatch
from _pytest.nodes import Item
from _pytest.reports import CollectReport
from _pytest.stash import StashKey
tmppath_result_key = StashKey[Dict[str, bool]]()
RetentionType = Literal["all", "failed", "none"]
@final

View File

@@ -3,12 +3,11 @@ import inspect
import warnings
from types import FunctionType
from typing import Any
from typing import final
from typing import Generic
from typing import Type
from typing import TypeVar
from _pytest.compat import final
class PytestWarning(UserWarning):
"""Base class for all warnings emitted by pytest."""