Merge pull request #7142 from bluetech/typing

Add more type annotations
This commit is contained in:
Ran Benita 2020-06-05 11:55:28 +03:00 committed by GitHub
commit cc283cfe79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
80 changed files with 2697 additions and 1528 deletions

View File

@ -18,7 +18,7 @@ repos:
args: [--remove]
- id: check-yaml
- id: debug-statements
exclude: _pytest/debugging.py
exclude: _pytest/(debugging|hookspec).py
language_version: python3
- repo: https://gitlab.com/pycqa/flake8
rev: 3.8.2

View File

@ -91,6 +91,7 @@ formats = sdist.tgz,bdist_wheel
[mypy]
mypy_path = src
check_untyped_defs = True
ignore_missing_imports = True
no_implicit_optional = True
show_error_codes = True

View File

@ -15,6 +15,7 @@ from typing import Dict
from typing import Generic
from typing import Iterable
from typing import List
from typing import Mapping
from typing import Optional
from typing import Pattern
from typing import Sequence
@ -46,7 +47,7 @@ if TYPE_CHECKING:
from typing_extensions import Literal
from weakref import ReferenceType
_TracebackStyle = Literal["long", "short", "line", "no", "native", "value"]
_TracebackStyle = Literal["long", "short", "line", "no", "native", "value", "auto"]
class Code:
@ -728,7 +729,7 @@ class FormattedExcinfo:
failindent = indentstr
return lines
def repr_locals(self, locals: Dict[str, object]) -> Optional["ReprLocals"]:
def repr_locals(self, locals: Mapping[str, object]) -> Optional["ReprLocals"]:
if self.showlocals:
lines = []
keys = [loc for loc in locals if loc[0] != "@"]

View File

@ -1,9 +1,12 @@
import pprint
import reprlib
from typing import Any
from typing import Dict
from typing import IO
from typing import Optional
def _try_repr_or_str(obj):
def _try_repr_or_str(obj: object) -> str:
try:
return repr(obj)
except (KeyboardInterrupt, SystemExit):
@ -12,7 +15,7 @@ def _try_repr_or_str(obj):
return '{}("{}")'.format(type(obj).__name__, obj)
def _format_repr_exception(exc: BaseException, obj: Any) -> str:
def _format_repr_exception(exc: BaseException, obj: object) -> str:
try:
exc_info = _try_repr_or_str(exc)
except (KeyboardInterrupt, SystemExit):
@ -42,7 +45,7 @@ class SafeRepr(reprlib.Repr):
self.maxstring = maxsize
self.maxsize = maxsize
def repr(self, x: Any) -> str:
def repr(self, x: object) -> str:
try:
s = super().repr(x)
except (KeyboardInterrupt, SystemExit):
@ -51,7 +54,7 @@ class SafeRepr(reprlib.Repr):
s = _format_repr_exception(exc, x)
return _ellipsize(s, self.maxsize)
def repr_instance(self, x: Any, level: int) -> str:
def repr_instance(self, x: object, level: int) -> str:
try:
s = repr(x)
except (KeyboardInterrupt, SystemExit):
@ -61,7 +64,7 @@ class SafeRepr(reprlib.Repr):
return _ellipsize(s, self.maxsize)
def safeformat(obj: Any) -> str:
def safeformat(obj: object) -> str:
"""return a pretty printed string for the given object.
Failing __repr__ functions of user instances will be represented
with a short exception info.
@ -72,7 +75,7 @@ def safeformat(obj: Any) -> str:
return _format_repr_exception(exc, obj)
def saferepr(obj: Any, maxsize: int = 240) -> str:
def saferepr(obj: object, maxsize: int = 240) -> str:
"""return a size-limited safe repr-string for the given object.
Failing __repr__ functions of user instances will be represented
with a short exception info and 'saferepr' generally takes
@ -85,19 +88,39 @@ def saferepr(obj: Any, maxsize: int = 240) -> str:
class AlwaysDispatchingPrettyPrinter(pprint.PrettyPrinter):
"""PrettyPrinter that always dispatches (regardless of width)."""
def _format(self, object, stream, indent, allowance, context, level):
p = self._dispatch.get(type(object).__repr__, None)
def _format(
self,
object: object,
stream: IO[str],
indent: int,
allowance: int,
context: Dict[int, Any],
level: int,
) -> None:
# Type ignored because _dispatch is private.
p = self._dispatch.get(type(object).__repr__, None) # type: ignore[attr-defined] # noqa: F821
objid = id(object)
if objid in context or p is None:
return super()._format(object, stream, indent, allowance, context, level)
# Type ignored because _format is private.
super()._format( # type: ignore[misc] # noqa: F821
object, stream, indent, allowance, context, level,
)
return
context[objid] = 1
p(self, object, stream, indent, allowance, context, level + 1)
del context[objid]
def _pformat_dispatch(object, indent=1, width=80, depth=None, *, compact=False):
def _pformat_dispatch(
object: object,
indent: int = 1,
width: int = 80,
depth: Optional[int] = None,
*,
compact: bool = False
) -> str:
return AlwaysDispatchingPrettyPrinter(
indent=indent, width=width, depth=depth, compact=compact
).pformat(object)

View File

@ -3,6 +3,7 @@ support for presenting detailed information in failing assertions.
"""
import sys
from typing import Any
from typing import Generator
from typing import List
from typing import Optional
@ -13,12 +14,14 @@ from _pytest.assertion.rewrite import assertstate_key
from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.config import hookimpl
from _pytest.config.argparsing import Parser
from _pytest.nodes import Item
if TYPE_CHECKING:
from _pytest.main import Session
def pytest_addoption(parser):
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig")
group.addoption(
"--assert",
@ -43,7 +46,7 @@ def pytest_addoption(parser):
)
def register_assert_rewrite(*names) -> None:
def register_assert_rewrite(*names: str) -> None:
"""Register one or more module names to be rewritten on import.
This function will make sure that this module or all modules inside
@ -72,27 +75,27 @@ def register_assert_rewrite(*names) -> None:
class DummyRewriteHook:
"""A no-op import hook for when rewriting is disabled."""
def mark_rewrite(self, *names):
def mark_rewrite(self, *names: str) -> None:
pass
class AssertionState:
"""State for the assertion plugin."""
def __init__(self, config, mode):
def __init__(self, config: Config, mode) -> None:
self.mode = mode
self.trace = config.trace.root.get("assertion")
self.hook = None # type: Optional[rewrite.AssertionRewritingHook]
def install_importhook(config):
def install_importhook(config: Config) -> rewrite.AssertionRewritingHook:
"""Try to install the rewrite hook, raise SystemError if it fails."""
config._store[assertstate_key] = AssertionState(config, "rewrite")
config._store[assertstate_key].hook = hook = rewrite.AssertionRewritingHook(config)
sys.meta_path.insert(0, hook)
config._store[assertstate_key].trace("installed rewrite import hook")
def undo():
def undo() -> None:
hook = config._store[assertstate_key].hook
if hook is not None and hook in sys.meta_path:
sys.meta_path.remove(hook)
@ -112,7 +115,7 @@ def pytest_collection(session: "Session") -> None:
@hookimpl(tryfirst=True, hookwrapper=True)
def pytest_runtest_protocol(item):
def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]:
"""Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks
The rewrite module will use util._reprcompare if
@ -121,8 +124,7 @@ def pytest_runtest_protocol(item):
comparison for the test.
"""
def callbinrepr(op, left, right):
# type: (str, object, object) -> Optional[str]
def callbinrepr(op, left: object, right: object) -> Optional[str]:
"""Call the pytest_assertrepr_compare hook and prepare the result
This uses the first result from the hook and then ensures the
@ -155,7 +157,7 @@ def pytest_runtest_protocol(item):
if item.ihook.pytest_assertion_pass.get_hookimpls():
def call_assertion_pass_hook(lineno, orig, expl):
def call_assertion_pass_hook(lineno: int, orig: str, expl: str) -> None:
item.ihook.pytest_assertion_pass(
item=item, lineno=lineno, orig=orig, expl=expl
)
@ -167,7 +169,7 @@ def pytest_runtest_protocol(item):
util._reprcompare, util._assertion_pass = saved_assert_hooks
def pytest_sessionfinish(session):
def pytest_sessionfinish(session: "Session") -> None:
assertstate = session.config._store.get(assertstate_key, None)
if assertstate:
if assertstate.hook is not None:

View File

@ -13,11 +13,15 @@ import struct
import sys
import tokenize
import types
from typing import Callable
from typing import Dict
from typing import IO
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Union
from _pytest._io.saferepr import saferepr
from _pytest._version import version
@ -27,6 +31,8 @@ from _pytest.assertion.util import ( # noqa: F401
)
from _pytest.compat import fspath
from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.main import Session
from _pytest.pathlib import fnmatch_ex
from _pytest.pathlib import Path
from _pytest.pathlib import PurePath
@ -48,13 +54,13 @@ PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
"""PEP302/PEP451 import hook which rewrites asserts."""
def __init__(self, config):
def __init__(self, config: Config) -> None:
self.config = config
try:
self.fnpats = config.getini("python_files")
except ValueError:
self.fnpats = ["test_*.py", "*_test.py"]
self.session = None
self.session = None # type: Optional[Session]
self._rewritten_names = set() # type: Set[str]
self._must_rewrite = set() # type: Set[str]
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
@ -64,14 +70,19 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
self._marked_for_rewrite_cache = {} # type: Dict[str, bool]
self._session_paths_checked = False
def set_session(self, session):
def set_session(self, session: Optional[Session]) -> None:
self.session = session
self._session_paths_checked = False
# Indirection so we can mock calls to find_spec originated from the hook during testing
_find_spec = importlib.machinery.PathFinder.find_spec
def find_spec(self, name, path=None, target=None):
def find_spec(
self,
name: str,
path: Optional[Sequence[Union[str, bytes]]] = None,
target: Optional[types.ModuleType] = None,
) -> Optional[importlib.machinery.ModuleSpec]:
if self._writing_pyc:
return None
state = self.config._store[assertstate_key]
@ -79,7 +90,8 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
return None
state.trace("find_module called for: %s" % name)
spec = self._find_spec(name, path)
# Type ignored because mypy is confused about the `self` binding here.
spec = self._find_spec(name, path) # type: ignore
if (
# the import machinery could not find a file to import
spec is None
@ -108,10 +120,14 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
submodule_search_locations=spec.submodule_search_locations,
)
def create_module(self, spec):
def create_module(
self, spec: importlib.machinery.ModuleSpec
) -> Optional[types.ModuleType]:
return None # default behaviour is fine
def exec_module(self, module):
def exec_module(self, module: types.ModuleType) -> None:
assert module.__spec__ is not None
assert module.__spec__.origin is not None
fn = Path(module.__spec__.origin)
state = self.config._store[assertstate_key]
@ -151,7 +167,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
state.trace("found cached rewritten pyc for {}".format(fn))
exec(co, module.__dict__)
def _early_rewrite_bailout(self, name, state):
def _early_rewrite_bailout(self, name: str, state: "AssertionState") -> bool:
"""This is a fast way to get out of rewriting modules.
Profiling has shown that the call to PathFinder.find_spec (inside of
@ -190,7 +206,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
state.trace("early skip of rewriting module: {}".format(name))
return True
def _should_rewrite(self, name, fn, state):
def _should_rewrite(self, name: str, fn: str, state: "AssertionState") -> bool:
# always rewrite conftest files
if os.path.basename(fn) == "conftest.py":
state.trace("rewriting conftest file: {!r}".format(fn))
@ -213,7 +229,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
return self._is_marked_for_rewrite(name, state)
def _is_marked_for_rewrite(self, name: str, state):
def _is_marked_for_rewrite(self, name: str, state: "AssertionState") -> bool:
try:
return self._marked_for_rewrite_cache[name]
except KeyError:
@ -246,7 +262,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
self._must_rewrite.update(names)
self._marked_for_rewrite_cache.clear()
def _warn_already_imported(self, name):
def _warn_already_imported(self, name: str) -> None:
from _pytest.warning_types import PytestAssertRewriteWarning
from _pytest.warnings import _issue_warning_captured
@ -258,13 +274,15 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
stacklevel=5,
)
def get_data(self, pathname):
def get_data(self, pathname: Union[str, bytes]) -> bytes:
"""Optional PEP302 get_data API."""
with open(pathname, "rb") as f:
return f.read()
def _write_pyc_fp(fp, source_stat, co):
def _write_pyc_fp(
fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType
) -> None:
# Technically, we don't have to have the same pyc format as
# (C)Python, since these "pycs" should never be seen by builtin
# import. However, there's little reason deviate.
@ -280,7 +298,12 @@ def _write_pyc_fp(fp, source_stat, co):
if sys.platform == "win32":
from atomicwrites import atomic_write
def _write_pyc(state, co, source_stat, pyc):
def _write_pyc(
state: "AssertionState",
co: types.CodeType,
source_stat: os.stat_result,
pyc: Path,
) -> bool:
try:
with atomic_write(fspath(pyc), mode="wb", overwrite=True) as fp:
_write_pyc_fp(fp, source_stat, co)
@ -295,7 +318,12 @@ if sys.platform == "win32":
else:
def _write_pyc(state, co, source_stat, pyc):
def _write_pyc(
state: "AssertionState",
co: types.CodeType,
source_stat: os.stat_result,
pyc: Path,
) -> bool:
proc_pyc = "{}.{}".format(pyc, os.getpid())
try:
fp = open(proc_pyc, "wb")
@ -319,19 +347,21 @@ else:
return True
def _rewrite_test(fn, config):
def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
"""read and rewrite *fn* and return the code object."""
fn = fspath(fn)
stat = os.stat(fn)
with open(fn, "rb") as f:
fn_ = fspath(fn)
stat = os.stat(fn_)
with open(fn_, "rb") as f:
source = f.read()
tree = ast.parse(source, filename=fn)
rewrite_asserts(tree, source, fn, config)
co = compile(tree, fn, "exec", dont_inherit=True)
tree = ast.parse(source, filename=fn_)
rewrite_asserts(tree, source, fn_, config)
co = compile(tree, fn_, "exec", dont_inherit=True)
return stat, co
def _read_pyc(source, pyc, trace=lambda x: None):
def _read_pyc(
source: Path, pyc: Path, trace: Callable[[str], None] = lambda x: None
) -> Optional[types.CodeType]:
"""Possibly read a pytest pyc containing rewritten code.
Return rewritten code if successful or None if not.
@ -368,12 +398,17 @@ def _read_pyc(source, pyc, trace=lambda x: None):
return co
def rewrite_asserts(mod, source, module_path=None, config=None):
def rewrite_asserts(
mod: ast.Module,
source: bytes,
module_path: Optional[str] = None,
config: Optional[Config] = None,
) -> None:
"""Rewrite the assert statements in mod."""
AssertionRewriter(module_path, config, source).run(mod)
def _saferepr(obj):
def _saferepr(obj: object) -> str:
"""Get a safe repr of an object for assertion error messages.
The assertion formatting (util.format_explanation()) requires
@ -387,7 +422,7 @@ def _saferepr(obj):
return saferepr(obj).replace("\n", "\\n")
def _format_assertmsg(obj):
def _format_assertmsg(obj: object) -> str:
"""Format the custom assertion message given.
For strings this simply replaces newlines with '\n~' so that
@ -410,7 +445,7 @@ def _format_assertmsg(obj):
return obj
def _should_repr_global_name(obj):
def _should_repr_global_name(obj: object) -> bool:
if callable(obj):
return False
@ -420,7 +455,7 @@ def _should_repr_global_name(obj):
return True
def _format_boolop(explanations, is_or):
def _format_boolop(explanations, is_or: bool):
explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")"
if isinstance(explanation, str):
return explanation.replace("%", "%%")
@ -428,8 +463,12 @@ def _format_boolop(explanations, is_or):
return explanation.replace(b"%", b"%%")
def _call_reprcompare(ops, results, expls, each_obj):
# type: (Tuple[str, ...], Tuple[bool, ...], Tuple[str, ...], Tuple[object, ...]) -> str
def _call_reprcompare(
ops: Sequence[str],
results: Sequence[bool],
expls: Sequence[str],
each_obj: Sequence[object],
) -> str:
for i, res, expl in zip(range(len(ops)), results, expls):
try:
done = not res
@ -444,14 +483,12 @@ def _call_reprcompare(ops, results, expls, each_obj):
return expl
def _call_assertion_pass(lineno, orig, expl):
# type: (int, str, str) -> None
def _call_assertion_pass(lineno: int, orig: str, expl: str) -> None:
if util._assertion_pass is not None:
util._assertion_pass(lineno, orig, expl)
def _check_if_assertion_pass_impl():
# type: () -> bool
def _check_if_assertion_pass_impl() -> bool:
"""Checks if any plugins implement the pytest_assertion_pass hook
in order not to generate explanation unecessarily (might be expensive)"""
return True if util._assertion_pass else False
@ -609,7 +646,9 @@ class AssertionRewriter(ast.NodeVisitor):
"""
def __init__(self, module_path, config, source):
def __init__(
self, module_path: Optional[str], config: Optional[Config], source: bytes
) -> None:
super().__init__()
self.module_path = module_path
self.config = config
@ -622,7 +661,7 @@ class AssertionRewriter(ast.NodeVisitor):
self.source = source
@functools.lru_cache(maxsize=1)
def _assert_expr_to_lineno(self):
def _assert_expr_to_lineno(self) -> Dict[int, str]:
return _get_assertion_exprs(self.source)
def run(self, mod: ast.Module) -> None:
@ -691,38 +730,38 @@ class AssertionRewriter(ast.NodeVisitor):
nodes.append(field)
@staticmethod
def is_rewrite_disabled(docstring):
def is_rewrite_disabled(docstring: str) -> bool:
return "PYTEST_DONT_REWRITE" in docstring
def variable(self):
def variable(self) -> str:
"""Get a new variable."""
# Use a character invalid in python identifiers to avoid clashing.
name = "@py_assert" + str(next(self.variable_counter))
self.variables.append(name)
return name
def assign(self, expr):
def assign(self, expr: ast.expr) -> ast.Name:
"""Give *expr* a name."""
name = self.variable()
self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
return ast.Name(name, ast.Load())
def display(self, expr):
def display(self, expr: ast.expr) -> ast.expr:
"""Call saferepr on the expression."""
return self.helper("_saferepr", expr)
def helper(self, name, *args):
def helper(self, name: str, *args: ast.expr) -> ast.expr:
"""Call a helper in this module."""
py_name = ast.Name("@pytest_ar", ast.Load())
attr = ast.Attribute(py_name, name, ast.Load())
return ast.Call(attr, list(args), [])
def builtin(self, name):
def builtin(self, name: str) -> ast.Attribute:
"""Return the builtin called *name*."""
builtin_name = ast.Name("@py_builtins", ast.Load())
return ast.Attribute(builtin_name, name, ast.Load())
def explanation_param(self, expr):
def explanation_param(self, expr: ast.expr) -> str:
"""Return a new named %-formatting placeholder for expr.
This creates a %-formatting placeholder for expr in the
@ -735,7 +774,7 @@ class AssertionRewriter(ast.NodeVisitor):
self.explanation_specifiers[specifier] = expr
return "%(" + specifier + ")s"
def push_format_context(self):
def push_format_context(self) -> None:
"""Create a new formatting context.
The format context is used for when an explanation wants to
@ -749,10 +788,10 @@ class AssertionRewriter(ast.NodeVisitor):
self.explanation_specifiers = {} # type: Dict[str, ast.expr]
self.stack.append(self.explanation_specifiers)
def pop_format_context(self, expl_expr):
def pop_format_context(self, expl_expr: ast.expr) -> ast.Name:
"""Format the %-formatted string with current format context.
The expl_expr should be an ast.Str instance constructed from
The expl_expr should be an str ast.expr instance constructed from
the %-placeholders created by .explanation_param(). This will
add the required code to format said string to .expl_stmts and
return the ast.Name instance of the formatted string.
@ -770,13 +809,13 @@ class AssertionRewriter(ast.NodeVisitor):
self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
return ast.Name(name, ast.Load())
def generic_visit(self, node):
def generic_visit(self, node: ast.AST) -> Tuple[ast.Name, str]:
"""Handle expressions we don't have custom code for."""
assert isinstance(node, ast.expr)
res = self.assign(node)
return res, self.explanation_param(self.display(res))
def visit_Assert(self, assert_):
def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
"""Return the AST statements to replace the ast.Assert instance.
This rewrites the test of an assertion to provide
@ -789,6 +828,8 @@ class AssertionRewriter(ast.NodeVisitor):
from _pytest.warning_types import PytestAssertRewriteWarning
import warnings
# TODO: This assert should not be needed.
assert self.module_path is not None
warnings.warn_explicit(
PytestAssertRewriteWarning(
"assertion is always true, perhaps remove parentheses?"
@ -891,7 +932,7 @@ class AssertionRewriter(ast.NodeVisitor):
set_location(stmt, assert_.lineno, assert_.col_offset)
return self.statements
def visit_Name(self, name):
def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
# Display the repr of the name if it's a local variable or
# _should_repr_global_name() thinks it's acceptable.
locs = ast.Call(self.builtin("locals"), [], [])
@ -901,7 +942,7 @@ class AssertionRewriter(ast.NodeVisitor):
expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
return name, self.explanation_param(expr)
def visit_BoolOp(self, boolop):
def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
res_var = self.variable()
expl_list = self.assign(ast.List([], ast.Load()))
app = ast.Attribute(expl_list, "append", ast.Load())
@ -936,13 +977,13 @@ class AssertionRewriter(ast.NodeVisitor):
expl = self.pop_format_context(expl_template)
return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
def visit_UnaryOp(self, unary):
def visit_UnaryOp(self, unary: ast.UnaryOp) -> Tuple[ast.Name, str]:
pattern = UNARY_MAP[unary.op.__class__]
operand_res, operand_expl = self.visit(unary.operand)
res = self.assign(ast.UnaryOp(unary.op, operand_res))
return res, pattern % (operand_expl,)
def visit_BinOp(self, binop):
def visit_BinOp(self, binop: ast.BinOp) -> Tuple[ast.Name, str]:
symbol = BINOP_MAP[binop.op.__class__]
left_expr, left_expl = self.visit(binop.left)
right_expr, right_expl = self.visit(binop.right)
@ -950,7 +991,7 @@ class AssertionRewriter(ast.NodeVisitor):
res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
return res, explanation
def visit_Call(self, call):
def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
"""
visit `ast.Call` nodes
"""
@ -977,13 +1018,13 @@ class AssertionRewriter(ast.NodeVisitor):
outer_expl = "{}\n{{{} = {}\n}}".format(res_expl, res_expl, expl)
return res, outer_expl
def visit_Starred(self, starred):
def visit_Starred(self, starred: ast.Starred) -> Tuple[ast.Starred, str]:
# From Python 3.5, a Starred node can appear in a function call
res, expl = self.visit(starred.value)
new_starred = ast.Starred(res, starred.ctx)
return new_starred, "*" + expl
def visit_Attribute(self, attr):
def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
if not isinstance(attr.ctx, ast.Load):
return self.generic_visit(attr)
value, value_expl = self.visit(attr.value)
@ -993,7 +1034,7 @@ class AssertionRewriter(ast.NodeVisitor):
expl = pat % (res_expl, res_expl, value_expl, attr.attr)
return res, expl
def visit_Compare(self, comp: ast.Compare):
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
self.push_format_context()
left_res, left_expl = self.visit(comp.left)
if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
@ -1032,7 +1073,7 @@ class AssertionRewriter(ast.NodeVisitor):
return res, self.explanation_param(self.pop_format_context(expl_call))
def try_makedirs(cache_dir) -> bool:
def try_makedirs(cache_dir: Path) -> bool:
"""Attempts to create the given directory and sub-directories exist, returns True if
successful or it already exists"""
try:

View File

@ -5,13 +5,20 @@ Current default behaviour is to truncate assertion explanations at
~8 terminal lines, unless running in "-vv" mode or running on CI.
"""
import os
from typing import List
from typing import Optional
from _pytest.nodes import Item
DEFAULT_MAX_LINES = 8
DEFAULT_MAX_CHARS = 8 * 80
USAGE_MSG = "use '-vv' to show"
def truncate_if_required(explanation, item, max_length=None):
def truncate_if_required(
explanation: List[str], item: Item, max_length: Optional[int] = None
) -> List[str]:
"""
Truncate this assertion explanation if the given test item is eligible.
"""
@ -20,7 +27,7 @@ def truncate_if_required(explanation, item, max_length=None):
return explanation
def _should_truncate_item(item):
def _should_truncate_item(item: Item) -> bool:
"""
Whether or not this test item is eligible for truncation.
"""
@ -28,13 +35,17 @@ def _should_truncate_item(item):
return verbose < 2 and not _running_on_ci()
def _running_on_ci():
def _running_on_ci() -> bool:
"""Check if we're currently running on a CI system."""
env_vars = ["CI", "BUILD_NUMBER"]
return any(var in os.environ for var in env_vars)
def _truncate_explanation(input_lines, max_lines=None, max_chars=None):
def _truncate_explanation(
input_lines: List[str],
max_lines: Optional[int] = None,
max_chars: Optional[int] = None,
) -> List[str]:
"""
Truncate given list of strings that makes up the assertion explanation.
@ -73,7 +84,7 @@ def _truncate_explanation(input_lines, max_lines=None, max_chars=None):
return truncated_explanation
def _truncate_by_char_count(input_lines, max_chars):
def _truncate_by_char_count(input_lines: List[str], max_chars: int) -> List[str]:
# Check if truncation required
if len("".join(input_lines)) <= max_chars:
return input_lines

View File

@ -8,9 +8,11 @@ import json
import os
from typing import Dict
from typing import Generator
from typing import Iterable
from typing import List
from typing import Optional
from typing import Set
from typing import Union
import attr
import py
@ -24,8 +26,13 @@ from _pytest import nodes
from _pytest._io import TerminalWriter
from _pytest.compat import order_preserving_dict
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureRequest
from _pytest.main import Session
from _pytest.python import Module
from _pytest.reports import TestReport
README_CONTENT = """\
# pytest cache directory #
@ -48,8 +55,8 @@ Signature: 8a477f597d28d172789f06886806bc55
@attr.s
class Cache:
_cachedir = attr.ib(repr=False)
_config = attr.ib(repr=False)
_cachedir = attr.ib(type=Path, repr=False)
_config = attr.ib(type=Config, repr=False)
# sub-directory under cache-dir for directories created by "makedir"
_CACHE_PREFIX_DIRS = "d"
@ -58,14 +65,14 @@ class Cache:
_CACHE_PREFIX_VALUES = "v"
@classmethod
def for_config(cls, config):
def for_config(cls, config: Config) -> "Cache":
cachedir = cls.cache_dir_from_config(config)
if config.getoption("cacheclear") and cachedir.is_dir():
cls.clear_cache(cachedir)
return cls(cachedir, config)
@classmethod
def clear_cache(cls, cachedir: Path):
def clear_cache(cls, cachedir: Path) -> None:
"""Clears the sub-directories used to hold cached directories and values."""
for prefix in (cls._CACHE_PREFIX_DIRS, cls._CACHE_PREFIX_VALUES):
d = cachedir / prefix
@ -73,10 +80,10 @@ class Cache:
rm_rf(d)
@staticmethod
def cache_dir_from_config(config):
def cache_dir_from_config(config: Config):
return resolve_from_str(config.getini("cache_dir"), config.rootdir)
def warn(self, fmt, **args):
def warn(self, fmt: str, **args: object) -> None:
import warnings
from _pytest.warning_types import PytestCacheWarning
@ -86,7 +93,7 @@ class Cache:
stacklevel=3,
)
def makedir(self, name):
def makedir(self, name: str) -> py.path.local:
""" return a directory path object with the given name. If the
directory does not yet exist, it will be created. You can use it
to manage files likes e. g. store/retrieve database
@ -96,14 +103,14 @@ class Cache:
Make sure the name contains your plugin or application
identifiers to prevent clashes with other cache users.
"""
name = Path(name)
if len(name.parts) > 1:
path = Path(name)
if len(path.parts) > 1:
raise ValueError("name is not allowed to contain path separators")
res = self._cachedir.joinpath(self._CACHE_PREFIX_DIRS, name)
res = self._cachedir.joinpath(self._CACHE_PREFIX_DIRS, path)
res.mkdir(exist_ok=True, parents=True)
return py.path.local(res)
def _getvaluepath(self, key):
def _getvaluepath(self, key: str) -> Path:
return self._cachedir.joinpath(self._CACHE_PREFIX_VALUES, Path(key))
def get(self, key, default):
@ -124,7 +131,7 @@ class Cache:
except (ValueError, OSError):
return default
def set(self, key, value):
def set(self, key, value) -> None:
""" save value for the given key.
:param key: must be a ``/`` separated value. Usually the first
@ -154,7 +161,7 @@ class Cache:
with f:
f.write(data)
def _ensure_supporting_files(self):
def _ensure_supporting_files(self) -> None:
"""Create supporting files in the cache dir that are not really part of the cache."""
readme_path = self._cachedir / "README.md"
readme_path.write_text(README_CONTENT)
@ -168,12 +175,12 @@ class Cache:
class LFPluginCollWrapper:
def __init__(self, lfplugin: "LFPlugin"):
def __init__(self, lfplugin: "LFPlugin") -> None:
self.lfplugin = lfplugin
self._collected_at_least_one_failure = False
@pytest.hookimpl(hookwrapper=True)
def pytest_make_collect_report(self, collector) -> Generator:
def pytest_make_collect_report(self, collector: nodes.Collector) -> Generator:
if isinstance(collector, Session):
out = yield
res = out.get_result() # type: CollectReport
@ -216,11 +223,13 @@ class LFPluginCollWrapper:
class LFPluginCollSkipfiles:
def __init__(self, lfplugin: "LFPlugin"):
def __init__(self, lfplugin: "LFPlugin") -> None:
self.lfplugin = lfplugin
@pytest.hookimpl
def pytest_make_collect_report(self, collector) -> Optional[CollectReport]:
def pytest_make_collect_report(
self, collector: nodes.Collector
) -> Optional[CollectReport]:
if isinstance(collector, Module):
if Path(str(collector.fspath)) not in self.lfplugin._last_failed_paths:
self.lfplugin._skipped_files += 1
@ -258,17 +267,18 @@ class LFPlugin:
result = {rootpath / nodeid.split("::")[0] for nodeid in self.lastfailed}
return {x for x in result if x.exists()}
def pytest_report_collectionfinish(self):
def pytest_report_collectionfinish(self) -> Optional[str]:
if self.active and self.config.getoption("verbose") >= 0:
return "run-last-failure: %s" % self._report_status
return None
def pytest_runtest_logreport(self, report):
def pytest_runtest_logreport(self, report: TestReport) -> None:
if (report.when == "call" and report.passed) or report.skipped:
self.lastfailed.pop(report.nodeid, None)
elif report.failed:
self.lastfailed[report.nodeid] = True
def pytest_collectreport(self, report):
def pytest_collectreport(self, report: CollectReport) -> None:
passed = report.outcome in ("passed", "skipped")
if passed:
if report.nodeid in self.lastfailed:
@ -329,11 +339,12 @@ class LFPlugin:
else:
self._report_status += "not deselecting items."
def pytest_sessionfinish(self, session):
def pytest_sessionfinish(self, session: Session) -> None:
config = self.config
if config.getoption("cacheshow") or hasattr(config, "slaveinput"):
return
assert config.cache is not None
saved_lastfailed = config.cache.get("cache/lastfailed", {})
if saved_lastfailed != self.lastfailed:
config.cache.set("cache/lastfailed", self.lastfailed)
@ -342,9 +353,10 @@ class LFPlugin:
class NFPlugin:
""" Plugin which implements the --nf (run new-first) option """
def __init__(self, config):
def __init__(self, config: Config) -> None:
self.config = config
self.active = config.option.newfirst
assert config.cache is not None
self.cached_nodeids = set(config.cache.get("cache/nodeids", []))
@pytest.hookimpl(hookwrapper=True, tryfirst=True)
@ -369,7 +381,7 @@ class NFPlugin:
else:
self.cached_nodeids.update(item.nodeid for item in items)
def _get_increasing_order(self, items):
def _get_increasing_order(self, items: Iterable[nodes.Item]) -> List[nodes.Item]:
return sorted(items, key=lambda item: item.fspath.mtime(), reverse=True)
def pytest_sessionfinish(self) -> None:
@ -379,10 +391,12 @@ class NFPlugin:
if config.getoption("collectonly"):
return
assert config.cache is not None
config.cache.set("cache/nodeids", sorted(self.cached_nodeids))
def pytest_addoption(parser):
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group.addoption(
"--lf",
@ -440,22 +454,24 @@ def pytest_addoption(parser):
)
def pytest_cmdline_main(config):
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
if config.option.cacheshow:
from _pytest.main import wrap_session
return wrap_session(config, cacheshow)
return None
@pytest.hookimpl(tryfirst=True)
def pytest_configure(config: Config) -> None:
config.cache = Cache.for_config(config)
# Type ignored: pending mechanism to store typed objects scoped to config.
config.cache = Cache.for_config(config) # type: ignore # noqa: F821
config.pluginmanager.register(LFPlugin(config), "lfplugin")
config.pluginmanager.register(NFPlugin(config), "nfplugin")
@pytest.fixture
def cache(request):
def cache(request: FixtureRequest) -> Cache:
"""
Return a cache object that can persist state between testing sessions.
@ -467,12 +483,14 @@ def cache(request):
Values can be any object handled by the json stdlib module.
"""
assert request.config.cache is not None
return request.config.cache
def pytest_report_header(config):
def pytest_report_header(config: Config) -> Optional[str]:
"""Display cachedir with --cache-show and if non-default."""
if config.option.verbose > 0 or config.getini("cache_dir") != ".pytest_cache":
assert config.cache is not None
cachedir = config.cache._cachedir
# TODO: evaluate generating upward relative paths
# starting with .., ../.. if sensible
@ -482,11 +500,14 @@ def pytest_report_header(config):
except ValueError:
displaypath = cachedir
return "cachedir: {}".format(displaypath)
return None
def cacheshow(config, session):
def cacheshow(config: Config, session: Session) -> int:
from pprint import pformat
assert config.cache is not None
tw = TerminalWriter()
tw.line("cachedir: " + str(config.cache._cachedir))
if not config.cache._cachedir.is_dir():

View File

@ -9,13 +9,19 @@ import os
import sys
from io import UnsupportedOperation
from tempfile import TemporaryFile
from typing import Generator
from typing import Optional
from typing import TextIO
from typing import Tuple
from typing import Union
import pytest
from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.fixtures import SubRequest
from _pytest.nodes import Collector
from _pytest.nodes import Item
if TYPE_CHECKING:
from typing_extensions import Literal
@ -23,7 +29,7 @@ if TYPE_CHECKING:
_CaptureMethod = Literal["fd", "sys", "no", "tee-sys"]
def pytest_addoption(parser):
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group._addoption(
"--capture",
@ -42,7 +48,7 @@ def pytest_addoption(parser):
)
def _colorama_workaround():
def _colorama_workaround() -> None:
"""
Ensure colorama is imported so that it attaches to the correct stdio
handles on Windows.
@ -58,7 +64,7 @@ def _colorama_workaround():
pass
def _readline_workaround():
def _readline_workaround() -> None:
"""
Ensure readline is imported so that it attaches to the correct stdio
handles on Windows.
@ -83,7 +89,7 @@ def _readline_workaround():
pass
def _py36_windowsconsoleio_workaround(stream):
def _py36_windowsconsoleio_workaround(stream: TextIO) -> None:
"""
Python 3.6 implemented unicode console handling for Windows. This works
by reading/writing to the raw console handle using
@ -198,7 +204,7 @@ class TeeCaptureIO(CaptureIO):
self._other = other
super().__init__()
def write(self, s) -> int:
def write(self, s: str) -> int:
super().write(s)
return self._other.write(s)
@ -218,13 +224,13 @@ class DontReadFromInput:
def __iter__(self):
return self
def fileno(self):
def fileno(self) -> int:
raise UnsupportedOperation("redirected stdin is pseudofile, has no fileno()")
def isatty(self):
def isatty(self) -> bool:
return False
def close(self):
def close(self) -> None:
pass
@property
@ -247,7 +253,7 @@ class SysCaptureBinary:
EMPTY_BUFFER = b""
def __init__(self, fd, tmpfile=None, *, tee=False):
def __init__(self, fd: int, tmpfile=None, *, tee: bool = False) -> None:
name = patchsysdict[fd]
self._old = getattr(sys, name)
self.name = name
@ -284,7 +290,7 @@ class SysCaptureBinary:
op, self._state, ", ".join(states)
)
def start(self):
def start(self) -> None:
self._assert_state("start", ("initialized",))
setattr(sys, self.name, self.tmpfile)
self._state = "started"
@ -297,7 +303,7 @@ class SysCaptureBinary:
self.tmpfile.truncate()
return res
def done(self):
def done(self) -> None:
self._assert_state("done", ("initialized", "started", "suspended", "done"))
if self._state == "done":
return
@ -306,19 +312,19 @@ class SysCaptureBinary:
self.tmpfile.close()
self._state = "done"
def suspend(self):
def suspend(self) -> None:
self._assert_state("suspend", ("started", "suspended"))
setattr(sys, self.name, self._old)
self._state = "suspended"
def resume(self):
def resume(self) -> None:
self._assert_state("resume", ("started", "suspended"))
if self._state == "started":
return
setattr(sys, self.name, self.tmpfile)
self._state = "started"
def writeorg(self, data):
def writeorg(self, data) -> None:
self._assert_state("writeorg", ("started", "suspended"))
self._old.flush()
self._old.buffer.write(data)
@ -348,7 +354,7 @@ class FDCaptureBinary:
EMPTY_BUFFER = b""
def __init__(self, targetfd):
def __init__(self, targetfd: int) -> None:
self.targetfd = targetfd
try:
@ -365,7 +371,9 @@ class FDCaptureBinary:
# Further complications are the need to support suspend() and the
# possibility of FD reuse (e.g. the tmpfile getting the very same
# target FD). The following approach is robust, I believe.
self.targetfd_invalid = os.open(os.devnull, os.O_RDWR)
self.targetfd_invalid = os.open(
os.devnull, os.O_RDWR
) # type: Optional[int]
os.dup2(self.targetfd_invalid, targetfd)
else:
self.targetfd_invalid = None
@ -376,7 +384,8 @@ class FDCaptureBinary:
self.syscapture = SysCapture(targetfd)
else:
self.tmpfile = EncodedFile(
TemporaryFile(buffering=0),
# TODO: Remove type ignore, fixed in next mypy release.
TemporaryFile(buffering=0), # type: ignore[arg-type]
encoding="utf-8",
errors="replace",
write_through=True,
@ -388,7 +397,7 @@ class FDCaptureBinary:
self._state = "initialized"
def __repr__(self):
def __repr__(self) -> str:
return "<{} {} oldfd={} _state={!r} tmpfile={!r}>".format(
self.__class__.__name__,
self.targetfd,
@ -404,7 +413,7 @@ class FDCaptureBinary:
op, self._state, ", ".join(states)
)
def start(self):
def start(self) -> None:
""" Start capturing on targetfd using memorized tmpfile. """
self._assert_state("start", ("initialized",))
os.dup2(self.tmpfile.fileno(), self.targetfd)
@ -419,7 +428,7 @@ class FDCaptureBinary:
self.tmpfile.truncate()
return res
def done(self):
def done(self) -> None:
""" stop capturing, restore streams, return original capture file,
seeked to position zero. """
self._assert_state("done", ("initialized", "started", "suspended", "done"))
@ -435,7 +444,7 @@ class FDCaptureBinary:
self.tmpfile.close()
self._state = "done"
def suspend(self):
def suspend(self) -> None:
self._assert_state("suspend", ("started", "suspended"))
if self._state == "suspended":
return
@ -443,7 +452,7 @@ class FDCaptureBinary:
os.dup2(self.targetfd_save, self.targetfd)
self._state = "suspended"
def resume(self):
def resume(self) -> None:
self._assert_state("resume", ("started", "suspended"))
if self._state == "started":
return
@ -493,12 +502,12 @@ class MultiCapture:
self.out = out
self.err = err
def __repr__(self):
def __repr__(self) -> str:
return "<MultiCapture out={!r} err={!r} in_={!r} _state={!r} _in_suspended={!r}>".format(
self.out, self.err, self.in_, self._state, self._in_suspended,
)
def start_capturing(self):
def start_capturing(self) -> None:
self._state = "started"
if self.in_:
self.in_.start()
@ -510,13 +519,14 @@ class MultiCapture:
def pop_outerr_to_orig(self):
""" pop current snapshot out/err capture and flush to orig streams. """
out, err = self.readouterr()
# TODO: Fix type ignores.
if out:
self.out.writeorg(out)
self.out.writeorg(out) # type: ignore[union-attr] # noqa: F821
if err:
self.err.writeorg(err)
self.err.writeorg(err) # type: ignore[union-attr] # noqa: F821
return out, err
def suspend_capturing(self, in_=False):
def suspend_capturing(self, in_: bool = False) -> None:
self._state = "suspended"
if self.out:
self.out.suspend()
@ -526,17 +536,18 @@ class MultiCapture:
self.in_.suspend()
self._in_suspended = True
def resume_capturing(self):
def resume_capturing(self) -> None:
self._state = "resumed"
if self.out:
self.out.resume()
if self.err:
self.err.resume()
if self._in_suspended:
self.in_.resume()
# TODO: Fix type ignore.
self.in_.resume() # type: ignore[union-attr] # noqa: F821
self._in_suspended = False
def stop_capturing(self):
def stop_capturing(self) -> None:
""" stop capturing and reset capturing streams """
if self._state == "stopped":
raise ValueError("was already stopped")
@ -592,15 +603,15 @@ class CaptureManager:
def __init__(self, method: "_CaptureMethod") -> None:
self._method = method
self._global_capturing = None
self._global_capturing = None # type: Optional[MultiCapture]
self._capture_fixture = None # type: Optional[CaptureFixture]
def __repr__(self):
def __repr__(self) -> str:
return "<CaptureManager _method={!r} _global_capturing={!r} _capture_fixture={!r}>".format(
self._method, self._global_capturing, self._capture_fixture
)
def is_capturing(self):
def is_capturing(self) -> Union[str, bool]:
if self.is_globally_capturing():
return "global"
if self._capture_fixture:
@ -609,40 +620,41 @@ class CaptureManager:
# Global capturing control
def is_globally_capturing(self):
def is_globally_capturing(self) -> bool:
return self._method != "no"
def start_global_capturing(self):
def start_global_capturing(self) -> None:
assert self._global_capturing is None
self._global_capturing = _get_multicapture(self._method)
self._global_capturing.start_capturing()
def stop_global_capturing(self):
def stop_global_capturing(self) -> None:
if self._global_capturing is not None:
self._global_capturing.pop_outerr_to_orig()
self._global_capturing.stop_capturing()
self._global_capturing = None
def resume_global_capture(self):
def resume_global_capture(self) -> None:
# During teardown of the python process, and on rare occasions, capture
# attributes can be `None` while trying to resume global capture.
if self._global_capturing is not None:
self._global_capturing.resume_capturing()
def suspend_global_capture(self, in_=False):
def suspend_global_capture(self, in_: bool = False) -> None:
if self._global_capturing is not None:
self._global_capturing.suspend_capturing(in_=in_)
def suspend(self, in_=False):
def suspend(self, in_: bool = False) -> None:
# Need to undo local capsys-et-al if it exists before disabling global capture.
self.suspend_fixture()
self.suspend_global_capture(in_)
def resume(self):
def resume(self) -> None:
self.resume_global_capture()
self.resume_fixture()
def read_global_capture(self):
assert self._global_capturing is not None
return self._global_capturing.readouterr()
# Fixture Control
@ -661,30 +673,30 @@ class CaptureManager:
def unset_fixture(self) -> None:
self._capture_fixture = None
def activate_fixture(self):
def activate_fixture(self) -> None:
"""If the current item is using ``capsys`` or ``capfd``, activate them so they take precedence over
the global capture.
"""
if self._capture_fixture:
self._capture_fixture._start()
def deactivate_fixture(self):
def deactivate_fixture(self) -> None:
"""Deactivates the ``capsys`` or ``capfd`` fixture of this item, if any."""
if self._capture_fixture:
self._capture_fixture.close()
def suspend_fixture(self):
def suspend_fixture(self) -> None:
if self._capture_fixture:
self._capture_fixture._suspend()
def resume_fixture(self):
def resume_fixture(self) -> None:
if self._capture_fixture:
self._capture_fixture._resume()
# Helper context managers
@contextlib.contextmanager
def global_and_fixture_disabled(self):
def global_and_fixture_disabled(self) -> Generator[None, None, None]:
"""Context manager to temporarily disable global and current fixture capturing."""
self.suspend()
try:
@ -693,7 +705,7 @@ class CaptureManager:
self.resume()
@contextlib.contextmanager
def item_capture(self, when, item):
def item_capture(self, when: str, item: Item) -> Generator[None, None, None]:
self.resume_global_capture()
self.activate_fixture()
try:
@ -709,7 +721,7 @@ class CaptureManager:
# Hooks
@pytest.hookimpl(hookwrapper=True)
def pytest_make_collect_report(self, collector):
def pytest_make_collect_report(self, collector: Collector):
if isinstance(collector, pytest.File):
self.resume_global_capture()
outcome = yield
@ -724,17 +736,17 @@ class CaptureManager:
yield
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_setup(self, item):
def pytest_runtest_setup(self, item: Item) -> Generator[None, None, None]:
with self.item_capture("setup", item):
yield
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(self, item):
def pytest_runtest_call(self, item: Item) -> Generator[None, None, None]:
with self.item_capture("call", item):
yield
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_teardown(self, item):
def pytest_runtest_teardown(self, item: Item) -> Generator[None, None, None]:
with self.item_capture("teardown", item):
yield
@ -753,21 +765,21 @@ class CaptureFixture:
fixtures.
"""
def __init__(self, captureclass, request):
def __init__(self, captureclass, request: SubRequest) -> None:
self.captureclass = captureclass
self.request = request
self._capture = None
self._capture = None # type: Optional[MultiCapture]
self._captured_out = self.captureclass.EMPTY_BUFFER
self._captured_err = self.captureclass.EMPTY_BUFFER
def _start(self):
def _start(self) -> None:
if self._capture is None:
self._capture = MultiCapture(
in_=None, out=self.captureclass(1), err=self.captureclass(2),
)
self._capture.start_capturing()
def close(self):
def close(self) -> None:
if self._capture is not None:
out, err = self._capture.pop_outerr_to_orig()
self._captured_out += out
@ -789,18 +801,18 @@ class CaptureFixture:
self._captured_err = self.captureclass.EMPTY_BUFFER
return CaptureResult(captured_out, captured_err)
def _suspend(self):
def _suspend(self) -> None:
"""Suspends this fixture's own capturing temporarily."""
if self._capture is not None:
self._capture.suspend_capturing()
def _resume(self):
def _resume(self) -> None:
"""Resumes this fixture's own capturing temporarily."""
if self._capture is not None:
self._capture.resume_capturing()
@contextlib.contextmanager
def disabled(self):
def disabled(self) -> Generator[None, None, None]:
"""Temporarily disables capture while inside the 'with' block."""
capmanager = self.request.config.pluginmanager.getplugin("capturemanager")
with capmanager.global_and_fixture_disabled():
@ -811,7 +823,7 @@ class CaptureFixture:
@pytest.fixture
def capsys(request):
def capsys(request: SubRequest):
"""Enable text capturing of writes to ``sys.stdout`` and ``sys.stderr``.
The captured output is made available via ``capsys.readouterr()`` method
@ -828,7 +840,7 @@ def capsys(request):
@pytest.fixture
def capsysbinary(request):
def capsysbinary(request: SubRequest):
"""Enable bytes capturing of writes to ``sys.stdout`` and ``sys.stderr``.
The captured output is made available via ``capsysbinary.readouterr()``
@ -845,7 +857,7 @@ def capsysbinary(request):
@pytest.fixture
def capfd(request):
def capfd(request: SubRequest):
"""Enable text capturing of writes to file descriptors ``1`` and ``2``.
The captured output is made available via ``capfd.readouterr()`` method
@ -862,7 +874,7 @@ def capfd(request):
@pytest.fixture
def capfdbinary(request):
def capfdbinary(request: SubRequest):
"""Enable bytes capturing of writes to file descriptors ``1`` and ``2``.
The captured output is made available via ``capfd.readouterr()`` method

View File

@ -1,6 +1,7 @@
"""
python version compatibility code
"""
import enum
import functools
import inspect
import os
@ -33,13 +34,20 @@ else:
if TYPE_CHECKING:
from typing import Type
from typing_extensions import Final
_T = TypeVar("_T")
_S = TypeVar("_S")
NOTSET = object()
# fmt: off
# Singleton type for NOTSET, as described in:
# https://www.python.org/dev/peps/pep-0484/#support-for-singleton-types-in-unions
class NotSetType(enum.Enum):
token = 0
NOTSET = NotSetType.token # type: Final # noqa: E305
# fmt: on
MODULE_NOT_FOUND_ERROR = (
"ModuleNotFoundError" if sys.version_info[:2] >= (3, 6) else "ImportError"

View File

@ -14,6 +14,7 @@ from types import TracebackType
from typing import Any
from typing import Callable
from typing import Dict
from typing import IO
from typing import List
from typing import Optional
from typing import Sequence
@ -295,7 +296,7 @@ class PytestPluginManager(PluginManager):
* ``conftest.py`` loading during start-up;
"""
def __init__(self):
def __init__(self) -> None:
import _pytest.assertion
super().__init__("pytest")
@ -315,7 +316,7 @@ class PytestPluginManager(PluginManager):
self.add_hookspecs(_pytest.hookspec)
self.register(self)
if os.environ.get("PYTEST_DEBUG"):
err = sys.stderr
err = sys.stderr # type: IO[str]
encoding = getattr(err, "encoding", "utf8")
try:
err = open(
@ -377,7 +378,7 @@ class PytestPluginManager(PluginManager):
}
return opts
def register(self, plugin, name=None):
def register(self, plugin: _PluggyPlugin, name: Optional[str] = None):
if name in _pytest.deprecated.DEPRECATED_EXTERNAL_PLUGINS:
warnings.warn(
PytestConfigWarning(
@ -406,7 +407,7 @@ class PytestPluginManager(PluginManager):
"""Return True if the plugin with the given name is registered."""
return bool(self.get_plugin(name))
def pytest_configure(self, config):
def pytest_configure(self, config: "Config") -> None:
# XXX now that the pluginmanager exposes hookimpl(tryfirst...)
# we should remove tryfirst/trylast as markers
config.addinivalue_line(
@ -552,7 +553,7 @@ class PytestPluginManager(PluginManager):
#
#
def consider_preparse(self, args, *, exclude_only=False):
def consider_preparse(self, args, *, exclude_only: bool = False) -> None:
i = 0
n = len(args)
while i < n:
@ -573,7 +574,7 @@ class PytestPluginManager(PluginManager):
continue
self.consider_pluginarg(parg)
def consider_pluginarg(self, arg):
def consider_pluginarg(self, arg) -> None:
if arg.startswith("no:"):
name = arg[3:]
if name in essential_plugins:
@ -598,13 +599,13 @@ class PytestPluginManager(PluginManager):
del self._name2plugin["pytest_" + name]
self.import_plugin(arg, consider_entry_points=True)
def consider_conftest(self, conftestmodule):
def consider_conftest(self, conftestmodule) -> None:
self.register(conftestmodule, name=conftestmodule.__file__)
def consider_env(self):
def consider_env(self) -> None:
self._import_plugin_specs(os.environ.get("PYTEST_PLUGINS"))
def consider_module(self, mod):
def consider_module(self, mod: types.ModuleType) -> None:
self._import_plugin_specs(getattr(mod, "pytest_plugins", []))
def _import_plugin_specs(self, spec):
@ -612,7 +613,7 @@ class PytestPluginManager(PluginManager):
for import_spec in plugins:
self.import_plugin(import_spec)
def import_plugin(self, modname, consider_entry_points=False):
def import_plugin(self, modname: str, consider_entry_points: bool = False) -> None:
"""
Imports a plugin with ``modname``. If ``consider_entry_points`` is True, entry point
names are also considered to find a plugin.
@ -839,23 +840,23 @@ class Config:
self.cache = None # type: Optional[Cache]
@property
def invocation_dir(self):
def invocation_dir(self) -> py.path.local:
"""Backward compatibility"""
return py.path.local(str(self.invocation_params.dir))
def add_cleanup(self, func):
def add_cleanup(self, func) -> None:
""" Add a function to be called when the config object gets out of
use (usually coninciding with pytest_unconfigure)."""
self._cleanup.append(func)
def _do_configure(self):
def _do_configure(self) -> None:
assert not self._configured
self._configured = True
with warnings.catch_warnings():
warnings.simplefilter("default")
self.hook.pytest_configure.call_historic(kwargs=dict(config=self))
def _ensure_unconfigure(self):
def _ensure_unconfigure(self) -> None:
if self._configured:
self._configured = False
self.hook.pytest_unconfigure(config=self)
@ -867,7 +868,9 @@ class Config:
def get_terminal_writer(self):
return self.pluginmanager.get_plugin("terminalreporter")._tw
def pytest_cmdline_parse(self, pluginmanager, args):
def pytest_cmdline_parse(
self, pluginmanager: PytestPluginManager, args: List[str]
) -> object:
try:
self.parse(args)
except UsageError:
@ -971,7 +974,7 @@ class Config:
self._mark_plugins_for_rewrite(hook)
_warn_about_missing_assertion(mode)
def _mark_plugins_for_rewrite(self, hook):
def _mark_plugins_for_rewrite(self, hook) -> None:
"""
Given an importhook, mark for rewrite any top-level
modules or packages in the distribution package for
@ -986,7 +989,9 @@ class Config:
package_files = (
str(file)
for dist in importlib_metadata.distributions()
if any(ep.group == "pytest11" for ep in dist.entry_points)
# Type ignored due to missing stub:
# https://github.com/python/typeshed/pull/3795
if any(ep.group == "pytest11" for ep in dist.entry_points) # type: ignore
for file in dist.files or []
)

View File

@ -2,14 +2,27 @@
import argparse
import functools
import sys
from typing import Generator
from typing import Tuple
from typing import Union
from _pytest import outcomes
from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.config import ConftestImportFailure
from _pytest.config import hookimpl
from _pytest.config import PytestPluginManager
from _pytest.config.argparsing import Parser
from _pytest.config.exceptions import UsageError
from _pytest.nodes import Node
from _pytest.reports import BaseReport
if TYPE_CHECKING:
from _pytest.capture import CaptureManager
from _pytest.runner import CallInfo
def _validate_usepdb_cls(value):
def _validate_usepdb_cls(value: str) -> Tuple[str, str]:
"""Validate syntax of --pdbcls option."""
try:
modname, classname = value.split(":")
@ -20,7 +33,7 @@ def _validate_usepdb_cls(value):
return (modname, classname)
def pytest_addoption(parser):
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group._addoption(
"--pdb",
@ -44,7 +57,7 @@ def pytest_addoption(parser):
)
def pytest_configure(config):
def pytest_configure(config: Config) -> None:
import pdb
if config.getvalue("trace"):
@ -61,7 +74,7 @@ def pytest_configure(config):
# NOTE: not using pytest_unconfigure, since it might get called although
# pytest_configure was not (if another plugin raises UsageError).
def fin():
def fin() -> None:
(
pdb.set_trace,
pytestPDB._pluginmanager,
@ -74,20 +87,20 @@ def pytest_configure(config):
class pytestPDB:
""" Pseudo PDB that defers to the real pdb. """
_pluginmanager = None
_config = None
_pluginmanager = None # type: PytestPluginManager
_config = None # type: Config
_saved = [] # type: list
_recursive_debug = 0
_wrapped_pdb_cls = None
@classmethod
def _is_capturing(cls, capman):
def _is_capturing(cls, capman: "CaptureManager") -> Union[str, bool]:
if capman:
return capman.is_capturing()
return False
@classmethod
def _import_pdb_cls(cls, capman):
def _import_pdb_cls(cls, capman: "CaptureManager"):
if not cls._config:
import pdb
@ -126,10 +139,12 @@ class pytestPDB:
return wrapped_cls
@classmethod
def _get_pdb_wrapper_class(cls, pdb_cls, capman):
def _get_pdb_wrapper_class(cls, pdb_cls, capman: "CaptureManager"):
import _pytest.config
class PytestPdbWrapper(pdb_cls):
# Type ignored because mypy doesn't support "dynamic"
# inheritance like this.
class PytestPdbWrapper(pdb_cls): # type: ignore[valid-type,misc] # noqa: F821
_pytest_capman = capman
_continued = False
@ -248,7 +263,7 @@ class pytestPDB:
return _pdb
@classmethod
def set_trace(cls, *args, **kwargs):
def set_trace(cls, *args, **kwargs) -> None:
"""Invoke debugging via ``Pdb.set_trace``, dropping any IO capturing."""
frame = sys._getframe().f_back
_pdb = cls._init_pdb("set_trace", *args, **kwargs)
@ -256,7 +271,9 @@ class pytestPDB:
class PdbInvoke:
def pytest_exception_interact(self, node, call, report):
def pytest_exception_interact(
self, node: Node, call: "CallInfo", report: BaseReport
) -> None:
capman = node.config.pluginmanager.getplugin("capturemanager")
if capman:
capman.suspend_global_capture(in_=True)
@ -265,14 +282,14 @@ class PdbInvoke:
sys.stdout.write(err)
_enter_pdb(node, call.excinfo, report)
def pytest_internalerror(self, excrepr, excinfo):
def pytest_internalerror(self, excrepr, excinfo) -> None:
tb = _postmortem_traceback(excinfo)
post_mortem(tb)
class PdbTrace:
@hookimpl(hookwrapper=True)
def pytest_pyfunc_call(self, pyfuncitem):
def pytest_pyfunc_call(self, pyfuncitem) -> Generator[None, None, None]:
wrap_pytest_function_for_tracing(pyfuncitem)
yield
@ -303,7 +320,7 @@ def maybe_wrap_pytest_function_for_tracing(pyfuncitem):
wrap_pytest_function_for_tracing(pyfuncitem)
def _enter_pdb(node, excinfo, rep):
def _enter_pdb(node: Node, excinfo, rep: BaseReport) -> BaseReport:
# XXX we re-use the TerminalReporter's terminalwriter
# because this seems to avoid some encoding related troubles
# for not completely clear reasons.
@ -327,7 +344,7 @@ def _enter_pdb(node, excinfo, rep):
rep.toterminal(tw)
tw.sep(">", "entering PDB")
tb = _postmortem_traceback(excinfo)
rep._pdbshown = True
rep._pdbshown = True # type: ignore[attr-defined] # noqa: F821
post_mortem(tb)
return rep
@ -347,7 +364,7 @@ def _postmortem_traceback(excinfo):
return excinfo._excinfo[2]
def post_mortem(t):
def post_mortem(t) -> None:
p = pytestPDB._init_pdb("post_mortem")
p.reset()
p.interaction(None, t)

View File

@ -4,11 +4,17 @@ import inspect
import platform
import sys
import traceback
import types
import warnings
from contextlib import contextmanager
from typing import Any
from typing import Callable
from typing import Dict
from typing import Generator
from typing import Iterable
from typing import List
from typing import Optional
from typing import Pattern
from typing import Sequence
from typing import Tuple
from typing import Union
@ -23,6 +29,8 @@ from _pytest._code.code import TerminalRepr
from _pytest._io import TerminalWriter
from _pytest.compat import safe_getattr
from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureRequest
from _pytest.outcomes import OutcomeException
from _pytest.python_api import approx
@ -52,7 +60,7 @@ RUNNER_CLASS = None
CHECKER_CLASS = None # type: Optional[Type[doctest.OutputChecker]]
def pytest_addoption(parser):
def pytest_addoption(parser: Parser) -> None:
parser.addini(
"doctest_optionflags",
"option flags for doctests",
@ -102,19 +110,24 @@ def pytest_addoption(parser):
)
def pytest_unconfigure():
def pytest_unconfigure() -> None:
global RUNNER_CLASS
RUNNER_CLASS = None
def pytest_collect_file(path: py.path.local, parent):
def pytest_collect_file(
path: py.path.local, parent
) -> Optional[Union["DoctestModule", "DoctestTextfile"]]:
config = parent.config
if path.ext == ".py":
if config.option.doctestmodules and not _is_setup_py(path):
return DoctestModule.from_parent(parent, fspath=path)
mod = DoctestModule.from_parent(parent, fspath=path) # type: DoctestModule
return mod
elif _is_doctest(config, path, parent):
return DoctestTextfile.from_parent(parent, fspath=path)
txt = DoctestTextfile.from_parent(parent, fspath=path) # type: DoctestTextfile
return txt
return None
def _is_setup_py(path: py.path.local) -> bool:
@ -124,7 +137,7 @@ def _is_setup_py(path: py.path.local) -> bool:
return b"setuptools" in contents or b"distutils" in contents
def _is_doctest(config, path, parent):
def _is_doctest(config: Config, path: py.path.local, parent) -> bool:
if path.ext in (".txt", ".rst") and parent.session.isinitpath(path):
return True
globs = config.getoption("doctestglob") or ["test*.txt"]
@ -137,7 +150,7 @@ def _is_doctest(config, path, parent):
class ReprFailDoctest(TerminalRepr):
def __init__(
self, reprlocation_lines: Sequence[Tuple[ReprFileLocation, Sequence[str]]]
):
) -> None:
self.reprlocation_lines = reprlocation_lines
def toterminal(self, tw: TerminalWriter) -> None:
@ -148,7 +161,7 @@ class ReprFailDoctest(TerminalRepr):
class MultipleDoctestFailures(Exception):
def __init__(self, failures):
def __init__(self, failures: "Sequence[doctest.DocTestFailure]") -> None:
super().__init__()
self.failures = failures
@ -163,21 +176,33 @@ def _init_runner_class() -> "Type[doctest.DocTestRunner]":
"""
def __init__(
self, checker=None, verbose=None, optionflags=0, continue_on_failure=True
):
self,
checker: Optional[doctest.OutputChecker] = None,
verbose: Optional[bool] = None,
optionflags: int = 0,
continue_on_failure: bool = True,
) -> None:
doctest.DebugRunner.__init__(
self, checker=checker, verbose=verbose, optionflags=optionflags
)
self.continue_on_failure = continue_on_failure
def report_failure(self, out, test, example, got):
def report_failure(
self, out, test: "doctest.DocTest", example: "doctest.Example", got: str,
) -> None:
failure = doctest.DocTestFailure(test, example, got)
if self.continue_on_failure:
out.append(failure)
else:
raise failure
def report_unexpected_exception(self, out, test, example, exc_info):
def report_unexpected_exception(
self,
out,
test: "doctest.DocTest",
example: "doctest.Example",
exc_info: "Tuple[Type[BaseException], BaseException, types.TracebackType]",
) -> None:
if isinstance(exc_info[1], OutcomeException):
raise exc_info[1]
if isinstance(exc_info[1], bdb.BdbQuit):
@ -212,16 +237,27 @@ def _get_runner(
class DoctestItem(pytest.Item):
def __init__(self, name, parent, runner=None, dtest=None):
def __init__(
self,
name: str,
parent: "Union[DoctestTextfile, DoctestModule]",
runner: Optional["doctest.DocTestRunner"] = None,
dtest: Optional["doctest.DocTest"] = None,
) -> None:
super().__init__(name, parent)
self.runner = runner
self.dtest = dtest
self.obj = None
self.fixture_request = None
self.fixture_request = None # type: Optional[FixtureRequest]
@classmethod
def from_parent( # type: ignore
cls, parent: "Union[DoctestTextfile, DoctestModule]", *, name, runner, dtest
cls,
parent: "Union[DoctestTextfile, DoctestModule]",
*,
name: str,
runner: "doctest.DocTestRunner",
dtest: "doctest.DocTest"
):
# incompatible signature due to to imposed limits on sublcass
"""
@ -229,7 +265,7 @@ class DoctestItem(pytest.Item):
"""
return super().from_parent(name=name, parent=parent, runner=runner, dtest=dtest)
def setup(self):
def setup(self) -> None:
if self.dtest is not None:
self.fixture_request = _setup_fixtures(self)
globs = dict(getfixture=self.fixture_request.getfixturevalue)
@ -240,14 +276,18 @@ class DoctestItem(pytest.Item):
self.dtest.globs.update(globs)
def runtest(self) -> None:
assert self.dtest is not None
assert self.runner is not None
_check_all_skipped(self.dtest)
self._disable_output_capturing_for_darwin()
failures = [] # type: List[doctest.DocTestFailure]
self.runner.run(self.dtest, out=failures)
# Type ignored because we change the type of `out` from what
# doctest expects.
self.runner.run(self.dtest, out=failures) # type: ignore[arg-type] # noqa: F821
if failures:
raise MultipleDoctestFailures(failures)
def _disable_output_capturing_for_darwin(self):
def _disable_output_capturing_for_darwin(self) -> None:
"""
Disable output capturing. Otherwise, stdout is lost to doctest (#985)
"""
@ -260,15 +300,20 @@ class DoctestItem(pytest.Item):
sys.stdout.write(out)
sys.stderr.write(err)
def repr_failure(self, excinfo):
# TODO: Type ignored -- breaks Liskov Substitution.
def repr_failure( # type: ignore[override] # noqa: F821
self, excinfo: ExceptionInfo[BaseException],
) -> Union[str, TerminalRepr]:
import doctest
failures = (
None
) # type: Optional[List[Union[doctest.DocTestFailure, doctest.UnexpectedException]]]
if excinfo.errisinstance((doctest.DocTestFailure, doctest.UnexpectedException)):
) # type: Optional[Sequence[Union[doctest.DocTestFailure, doctest.UnexpectedException]]]
if isinstance(
excinfo.value, (doctest.DocTestFailure, doctest.UnexpectedException)
):
failures = [excinfo.value]
elif excinfo.errisinstance(MultipleDoctestFailures):
elif isinstance(excinfo.value, MultipleDoctestFailures):
failures = excinfo.value.failures
if failures is not None:
@ -282,7 +327,8 @@ class DoctestItem(pytest.Item):
else:
lineno = test.lineno + example.lineno + 1
message = type(failure).__name__
reprlocation = ReprFileLocation(filename, lineno, message)
# TODO: ReprFileLocation doesn't expect a None lineno.
reprlocation = ReprFileLocation(filename, lineno, message) # type: ignore[arg-type] # noqa: F821
checker = _get_checker()
report_choice = _get_report_choice(
self.config.getoption("doctestreport")
@ -322,7 +368,8 @@ class DoctestItem(pytest.Item):
else:
return super().repr_failure(excinfo)
def reportinfo(self) -> Tuple[py.path.local, int, str]:
def reportinfo(self):
assert self.dtest is not None
return self.fspath, self.dtest.lineno, "[doctest] %s" % self.name
@ -364,7 +411,7 @@ def _get_continue_on_failure(config):
class DoctestTextfile(pytest.Module):
obj = None
def collect(self):
def collect(self) -> Iterable[DoctestItem]:
import doctest
# inspired by doctest.testfile; ideally we would use it directly,
@ -392,7 +439,7 @@ class DoctestTextfile(pytest.Module):
)
def _check_all_skipped(test):
def _check_all_skipped(test: "doctest.DocTest") -> None:
"""raises pytest.skip() if all examples in the given DocTest have the SKIP
option set.
"""
@ -403,7 +450,7 @@ def _check_all_skipped(test):
pytest.skip("all tests skipped by +SKIP option")
def _is_mocked(obj):
def _is_mocked(obj: object) -> bool:
"""
returns if a object is possibly a mock object by checking the existence of a highly improbable attribute
"""
@ -414,23 +461,26 @@ def _is_mocked(obj):
@contextmanager
def _patch_unwrap_mock_aware():
def _patch_unwrap_mock_aware() -> Generator[None, None, None]:
"""
contextmanager which replaces ``inspect.unwrap`` with a version
that's aware of mock objects and doesn't recurse on them
"""
real_unwrap = inspect.unwrap
def _mock_aware_unwrap(obj, stop=None):
def _mock_aware_unwrap(
func: Callable[..., Any], *, stop: Optional[Callable[[Any], Any]] = None
) -> Any:
try:
if stop is None or stop is _is_mocked:
return real_unwrap(obj, stop=_is_mocked)
return real_unwrap(obj, stop=lambda obj: _is_mocked(obj) or stop(obj))
return real_unwrap(func, stop=_is_mocked)
_stop = stop
return real_unwrap(func, stop=lambda obj: _is_mocked(obj) or _stop(func))
except Exception as e:
warnings.warn(
"Got %r when unwrapping %r. This is usually caused "
"by a violation of Python's object protocol; see e.g. "
"https://github.com/pytest-dev/pytest/issues/5080" % (e, obj),
"https://github.com/pytest-dev/pytest/issues/5080" % (e, func),
PytestWarning,
)
raise
@ -443,7 +493,7 @@ def _patch_unwrap_mock_aware():
class DoctestModule(pytest.Module):
def collect(self):
def collect(self) -> Iterable[DoctestItem]:
import doctest
class MockAwareDocTestFinder(doctest.DocTestFinder):
@ -462,7 +512,10 @@ class DoctestModule(pytest.Module):
"""
if isinstance(obj, property):
obj = getattr(obj, "fget", obj)
return doctest.DocTestFinder._find_lineno(self, obj, source_lines)
# Type ignored because this is a private function.
return doctest.DocTestFinder._find_lineno( # type: ignore
self, obj, source_lines,
)
def _find(
self, tests, obj, name, module, source_lines, globs, seen
@ -503,17 +556,17 @@ class DoctestModule(pytest.Module):
)
def _setup_fixtures(doctest_item):
def _setup_fixtures(doctest_item: DoctestItem) -> FixtureRequest:
"""
Used by DoctestTextfile and DoctestItem to setup fixture information.
"""
def func():
def func() -> None:
pass
doctest_item.funcargs = {}
doctest_item.funcargs = {} # type: ignore[attr-defined] # noqa: F821
fm = doctest_item.session._fixturemanager
doctest_item._fixtureinfo = fm.getfixtureinfo(
doctest_item._fixtureinfo = fm.getfixtureinfo( # type: ignore[attr-defined] # noqa: F821
node=doctest_item, func=func, cls=None, funcargs=False
)
fixture_request = FixtureRequest(doctest_item)
@ -557,7 +610,7 @@ def _init_checker_class() -> "Type[doctest.OutputChecker]":
re.VERBOSE,
)
def check_output(self, want, got, optionflags):
def check_output(self, want: str, got: str, optionflags: int) -> bool:
if doctest.OutputChecker.check_output(self, want, got, optionflags):
return True
@ -568,7 +621,7 @@ def _init_checker_class() -> "Type[doctest.OutputChecker]":
if not allow_unicode and not allow_bytes and not allow_number:
return False
def remove_prefixes(regex, txt):
def remove_prefixes(regex: Pattern[str], txt: str) -> str:
return re.sub(regex, r"\1\2", txt)
if allow_unicode:
@ -584,7 +637,7 @@ def _init_checker_class() -> "Type[doctest.OutputChecker]":
return doctest.OutputChecker.check_output(self, want, got, optionflags)
def _remove_unwanted_precision(self, want, got):
def _remove_unwanted_precision(self, want: str, got: str) -> str:
wants = list(self._number_re.finditer(want))
gots = list(self._number_re.finditer(got))
if len(wants) != len(gots):
@ -679,7 +732,7 @@ def _get_report_choice(key: str) -> int:
@pytest.fixture(scope="session")
def doctest_namespace():
def doctest_namespace() -> Dict[str, Any]:
"""
Fixture that returns a :py:class:`dict` that will be injected into the namespace of doctests.
"""

View File

@ -1,16 +1,20 @@
import io
import os
import sys
from typing import Generator
from typing import TextIO
import pytest
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.nodes import Item
from _pytest.store import StoreKey
fault_handler_stderr_key = StoreKey[TextIO]()
def pytest_addoption(parser):
def pytest_addoption(parser: Parser) -> None:
help = (
"Dump the traceback of all threads if a test takes "
"more than TIMEOUT seconds to finish."
@ -18,7 +22,7 @@ def pytest_addoption(parser):
parser.addini("faulthandler_timeout", help, default=0.0)
def pytest_configure(config):
def pytest_configure(config: Config) -> None:
import faulthandler
if not faulthandler.is_enabled():
@ -46,14 +50,14 @@ class FaultHandlerHooks:
"""Implements hooks that will actually install fault handler before tests execute,
as well as correctly handle pdb and internal errors."""
def pytest_configure(self, config):
def pytest_configure(self, config: Config) -> None:
import faulthandler
stderr_fd_copy = os.dup(self._get_stderr_fileno())
config._store[fault_handler_stderr_key] = open(stderr_fd_copy, "w")
faulthandler.enable(file=config._store[fault_handler_stderr_key])
def pytest_unconfigure(self, config):
def pytest_unconfigure(self, config: Config) -> None:
import faulthandler
faulthandler.disable()
@ -80,7 +84,7 @@ class FaultHandlerHooks:
return float(config.getini("faulthandler_timeout") or 0.0)
@pytest.hookimpl(hookwrapper=True, trylast=True)
def pytest_runtest_protocol(self, item):
def pytest_runtest_protocol(self, item: Item) -> Generator[None, None, None]:
timeout = self.get_timeout_config_value(item.config)
stderr = item.config._store[fault_handler_stderr_key]
if timeout > 0 and stderr is not None:
@ -95,7 +99,7 @@ class FaultHandlerHooks:
yield
@pytest.hookimpl(tryfirst=True)
def pytest_enter_pdb(self):
def pytest_enter_pdb(self) -> None:
"""Cancel any traceback dumping due to timeout before entering pdb.
"""
import faulthandler
@ -103,7 +107,7 @@ class FaultHandlerHooks:
faulthandler.cancel_dump_traceback_later()
@pytest.hookimpl(tryfirst=True)
def pytest_exception_interact(self):
def pytest_exception_interact(self) -> None:
"""Cancel any traceback dumping due to an interactive exception being
raised.
"""

File diff suppressed because it is too large Load Diff

View File

@ -2,11 +2,17 @@
import os
import sys
from argparse import Action
from typing import List
from typing import Optional
from typing import Union
import py
import pytest
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config import PrintHelp
from _pytest.config.argparsing import Parser
class HelpAction(Action):
@ -36,7 +42,7 @@ class HelpAction(Action):
raise PrintHelp
def pytest_addoption(parser):
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig")
group.addoption(
"--version",
@ -109,7 +115,7 @@ def pytest_cmdline_parse():
undo_tracing = config.pluginmanager.enable_tracing()
sys.stderr.write("writing pytestdebug information to %s\n" % path)
def unset_tracing():
def unset_tracing() -> None:
debugfile.close()
sys.stderr.write("wrote pytestdebug information to %s\n" % debugfile.name)
config.trace.root.setwriter(None)
@ -133,7 +139,7 @@ def showversion(config):
sys.stderr.write("pytest {}\n".format(pytest.__version__))
def pytest_cmdline_main(config):
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
if config.option.version > 0:
showversion(config)
return 0
@ -142,9 +148,10 @@ def pytest_cmdline_main(config):
showhelp(config)
config._ensure_unconfigure()
return 0
return None
def showhelp(config):
def showhelp(config: Config) -> None:
import textwrap
reporter = config.pluginmanager.get_plugin("terminalreporter")
@ -229,7 +236,7 @@ def getpluginversioninfo(config):
return lines
def pytest_report_header(config):
def pytest_report_header(config: Config) -> List[str]:
lines = []
if config.option.debug or config.option.traceconfig:
lines.append(

View File

@ -1,10 +1,13 @@
""" hook specifications for pytest plugins, invoked from main.py and builtin plugins. """
from typing import Any
from typing import List
from typing import Mapping
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import py.path
from pluggy import HookspecMarker
from .deprecated import COLLECT_DIRECTORY_HOOK
@ -12,10 +15,30 @@ from .deprecated import WARNING_CAPTURED_HOOK
from _pytest.compat import TYPE_CHECKING
if TYPE_CHECKING:
import pdb
import warnings
from typing_extensions import Literal
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config import PytestPluginManager
from _pytest.config import _PluggyPlugin
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureDef
from _pytest.fixtures import SubRequest
from _pytest.main import Session
from _pytest.nodes import Collector
from _pytest.nodes import Item
from _pytest.nodes import Node
from _pytest.python import Function
from _pytest.python import Metafunc
from _pytest.python import Module
from _pytest.python import PyCollector
from _pytest.reports import BaseReport
from _pytest.reports import CollectReport
from _pytest.reports import TestReport
from _pytest.runner import CallInfo
from _pytest.terminal import TerminalReporter
hookspec = HookspecMarker("pytest")
@ -26,7 +49,7 @@ hookspec = HookspecMarker("pytest")
@hookspec(historic=True)
def pytest_addhooks(pluginmanager):
def pytest_addhooks(pluginmanager: "PytestPluginManager") -> None:
"""called at plugin registration time to allow adding new hooks via a call to
``pluginmanager.add_hookspecs(module_or_class, prefix)``.
@ -39,7 +62,9 @@ def pytest_addhooks(pluginmanager):
@hookspec(historic=True)
def pytest_plugin_registered(plugin, manager):
def pytest_plugin_registered(
plugin: "_PluggyPlugin", manager: "PytestPluginManager"
) -> None:
""" a new pytest plugin got registered.
:param plugin: the plugin module or instance
@ -51,7 +76,7 @@ def pytest_plugin_registered(plugin, manager):
@hookspec(historic=True)
def pytest_addoption(parser, pluginmanager):
def pytest_addoption(parser: "Parser", pluginmanager: "PytestPluginManager") -> None:
"""register argparse-style options and ini-style config values,
called once at the beginning of a test run.
@ -89,7 +114,7 @@ def pytest_addoption(parser, pluginmanager):
@hookspec(historic=True)
def pytest_configure(config):
def pytest_configure(config: "Config") -> None:
"""
Allows plugins and conftest files to perform initial configuration.
@ -113,7 +138,9 @@ def pytest_configure(config):
@hookspec(firstresult=True)
def pytest_cmdline_parse(pluginmanager, args):
def pytest_cmdline_parse(
pluginmanager: "PytestPluginManager", args: List[str]
) -> Optional[object]:
"""return initialized config object, parsing the specified args.
Stops at first non-None result, see :ref:`firstresult`
@ -127,7 +154,7 @@ def pytest_cmdline_parse(pluginmanager, args):
"""
def pytest_cmdline_preparse(config, args):
def pytest_cmdline_preparse(config: "Config", args: List[str]) -> None:
"""(**Deprecated**) modify command line arguments before option parsing.
This hook is considered deprecated and will be removed in a future pytest version. Consider
@ -142,7 +169,7 @@ def pytest_cmdline_preparse(config, args):
@hookspec(firstresult=True)
def pytest_cmdline_main(config):
def pytest_cmdline_main(config: "Config") -> "Optional[Union[ExitCode, int]]":
""" called for performing the main command line action. The default
implementation will invoke the configure hooks and runtest_mainloop.
@ -155,7 +182,9 @@ def pytest_cmdline_main(config):
"""
def pytest_load_initial_conftests(early_config, parser, args):
def pytest_load_initial_conftests(
early_config: "Config", parser: "Parser", args: List[str]
) -> None:
""" implements the loading of initial conftest files ahead
of command line option parsing.
@ -198,7 +227,9 @@ def pytest_collection(session: "Session") -> Optional[Any]:
"""
def pytest_collection_modifyitems(session, config, items):
def pytest_collection_modifyitems(
session: "Session", config: "Config", items: List["Item"]
) -> None:
""" called after collection has been performed, may filter or re-order
the items in-place.
@ -208,7 +239,7 @@ def pytest_collection_modifyitems(session, config, items):
"""
def pytest_collection_finish(session):
def pytest_collection_finish(session: "Session"):
""" called after collection has been performed and modified.
:param _pytest.main.Session session: the pytest session object
@ -216,7 +247,7 @@ def pytest_collection_finish(session):
@hookspec(firstresult=True)
def pytest_ignore_collect(path, config):
def pytest_ignore_collect(path, config: "Config"):
""" return True to prevent considering this path for collection.
This hook is consulted for all files and directories prior to calling
more specific hooks.
@ -238,7 +269,7 @@ def pytest_collect_directory(path, parent):
"""
def pytest_collect_file(path, parent):
def pytest_collect_file(path: py.path.local, parent) -> "Optional[Collector]":
""" return collection Node or None for the given path. Any new node
needs to have the specified ``parent`` as a parent.
@ -249,7 +280,7 @@ def pytest_collect_file(path, parent):
# logging hooks for collection
def pytest_collectstart(collector):
def pytest_collectstart(collector: "Collector") -> None:
""" collector starts collecting. """
@ -257,7 +288,7 @@ def pytest_itemcollected(item):
""" we just collected a test item. """
def pytest_collectreport(report):
def pytest_collectreport(report: "CollectReport") -> None:
""" collector finished collecting. """
@ -266,7 +297,7 @@ def pytest_deselected(items):
@hookspec(firstresult=True)
def pytest_make_collect_report(collector):
def pytest_make_collect_report(collector: "Collector") -> "Optional[CollectReport]":
""" perform ``collector.collect()`` and return a CollectReport.
Stops at first non-None result, see :ref:`firstresult` """
@ -278,7 +309,7 @@ def pytest_make_collect_report(collector):
@hookspec(firstresult=True)
def pytest_pycollect_makemodule(path, parent):
def pytest_pycollect_makemodule(path: py.path.local, parent) -> "Optional[Module]":
""" return a Module collector or None for the given path.
This hook will be called for each matching test module path.
The pytest_collect_file hook needs to be used if you want to
@ -291,25 +322,29 @@ def pytest_pycollect_makemodule(path, parent):
@hookspec(firstresult=True)
def pytest_pycollect_makeitem(collector, name, obj):
def pytest_pycollect_makeitem(
collector: "PyCollector", name: str, obj
) -> "Union[None, Item, Collector, List[Union[Item, Collector]]]":
""" return custom item/collector for a python object in a module, or None.
Stops at first non-None result, see :ref:`firstresult` """
@hookspec(firstresult=True)
def pytest_pyfunc_call(pyfuncitem):
def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
""" call underlying test function.
Stops at first non-None result, see :ref:`firstresult` """
def pytest_generate_tests(metafunc):
def pytest_generate_tests(metafunc: "Metafunc") -> None:
""" generate (multiple) parametrized calls to a test function."""
@hookspec(firstresult=True)
def pytest_make_parametrize_id(config, val, argname):
def pytest_make_parametrize_id(
config: "Config", val: object, argname: str
) -> Optional[str]:
"""Return a user-friendly string representation of the given ``val`` that will be used
by @pytest.mark.parametrize calls. Return None if the hook doesn't know about ``val``.
The parameter name is available as ``argname``, if required.
@ -328,7 +363,7 @@ def pytest_make_parametrize_id(config, val, argname):
@hookspec(firstresult=True)
def pytest_runtestloop(session):
def pytest_runtestloop(session: "Session") -> Optional[object]:
""" called for performing the main runtest loop
(after collection finished).
@ -339,7 +374,9 @@ def pytest_runtestloop(session):
@hookspec(firstresult=True)
def pytest_runtest_protocol(item, nextitem):
def pytest_runtest_protocol(
item: "Item", nextitem: "Optional[Item]"
) -> Optional[object]:
""" implements the runtest_setup/call/teardown protocol for
the given test item, including capturing exceptions and calling
reporting hooks.
@ -378,15 +415,15 @@ def pytest_runtest_logfinish(nodeid, location):
"""
def pytest_runtest_setup(item):
def pytest_runtest_setup(item: "Item") -> None:
""" called before ``pytest_runtest_call(item)``. """
def pytest_runtest_call(item):
def pytest_runtest_call(item: "Item") -> None:
""" called to execute the test ``item``. """
def pytest_runtest_teardown(item, nextitem):
def pytest_runtest_teardown(item: "Item", nextitem: "Optional[Item]") -> None:
""" called after ``pytest_runtest_call``.
:arg nextitem: the scheduled-to-be-next test item (None if no further
@ -397,7 +434,7 @@ def pytest_runtest_teardown(item, nextitem):
@hookspec(firstresult=True)
def pytest_runtest_makereport(item, call):
def pytest_runtest_makereport(item: "Item", call: "CallInfo[None]") -> Optional[object]:
""" return a :py:class:`_pytest.runner.TestReport` object
for the given :py:class:`pytest.Item <_pytest.main.Item>` and
:py:class:`_pytest.runner.CallInfo`.
@ -405,13 +442,13 @@ def pytest_runtest_makereport(item, call):
Stops at first non-None result, see :ref:`firstresult` """
def pytest_runtest_logreport(report):
def pytest_runtest_logreport(report: "TestReport") -> None:
""" process a test setup/call/teardown report relating to
the respective phase of executing a test. """
@hookspec(firstresult=True)
def pytest_report_to_serializable(config, report):
def pytest_report_to_serializable(config: "Config", report: "BaseReport"):
"""
Serializes the given report object into a data structure suitable for sending
over the wire, e.g. converted to JSON.
@ -419,7 +456,7 @@ def pytest_report_to_serializable(config, report):
@hookspec(firstresult=True)
def pytest_report_from_serializable(config, data):
def pytest_report_from_serializable(config: "Config", data):
"""
Restores a report object previously serialized with pytest_report_to_serializable().
"""
@ -431,7 +468,9 @@ def pytest_report_from_serializable(config, data):
@hookspec(firstresult=True)
def pytest_fixture_setup(fixturedef, request):
def pytest_fixture_setup(
fixturedef: "FixtureDef", request: "SubRequest"
) -> Optional[object]:
""" performs fixture setup execution.
:return: The return value of the call to the fixture function
@ -445,7 +484,9 @@ def pytest_fixture_setup(fixturedef, request):
"""
def pytest_fixture_post_finalizer(fixturedef, request):
def pytest_fixture_post_finalizer(
fixturedef: "FixtureDef", request: "SubRequest"
) -> None:
"""Called after fixture teardown, but before the cache is cleared, so
the fixture result ``fixturedef.cached_result`` is still available (not
``None``)."""
@ -456,7 +497,7 @@ def pytest_fixture_post_finalizer(fixturedef, request):
# -------------------------------------------------------------------------
def pytest_sessionstart(session):
def pytest_sessionstart(session: "Session") -> None:
""" called after the ``Session`` object has been created and before performing collection
and entering the run test loop.
@ -464,7 +505,9 @@ def pytest_sessionstart(session):
"""
def pytest_sessionfinish(session, exitstatus):
def pytest_sessionfinish(
session: "Session", exitstatus: "Union[int, ExitCode]"
) -> None:
""" called after whole test run finished, right before returning the exit status to the system.
:param _pytest.main.Session session: the pytest session object
@ -472,7 +515,7 @@ def pytest_sessionfinish(session, exitstatus):
"""
def pytest_unconfigure(config):
def pytest_unconfigure(config: "Config") -> None:
""" called before test process is exited.
:param _pytest.config.Config config: pytest config object
@ -484,7 +527,9 @@ def pytest_unconfigure(config):
# -------------------------------------------------------------------------
def pytest_assertrepr_compare(config, op, left, right):
def pytest_assertrepr_compare(
config: "Config", op: str, left: object, right: object
) -> Optional[List[str]]:
"""return explanation for comparisons in failing assert expressions.
Return None for no custom explanation, otherwise return a list
@ -496,7 +541,7 @@ def pytest_assertrepr_compare(config, op, left, right):
"""
def pytest_assertion_pass(item, lineno, orig, expl):
def pytest_assertion_pass(item, lineno: int, orig: str, expl: str) -> None:
"""
**(Experimental)**
@ -539,7 +584,9 @@ def pytest_assertion_pass(item, lineno, orig, expl):
# -------------------------------------------------------------------------
def pytest_report_header(config, startdir):
def pytest_report_header(
config: "Config", startdir: py.path.local
) -> Union[str, List[str]]:
""" return a string or list of strings to be displayed as header info for terminal reporting.
:param _pytest.config.Config config: pytest config object
@ -560,7 +607,9 @@ def pytest_report_header(config, startdir):
"""
def pytest_report_collectionfinish(config, startdir, items):
def pytest_report_collectionfinish(
config: "Config", startdir: py.path.local, items: "Sequence[Item]"
) -> Union[str, List[str]]:
"""
.. versionadded:: 3.2
@ -610,7 +659,9 @@ def pytest_report_teststatus(
"""
def pytest_terminal_summary(terminalreporter, exitstatus, config):
def pytest_terminal_summary(
terminalreporter: "TerminalReporter", exitstatus: "ExitCode", config: "Config",
) -> None:
"""Add a section to terminal summary reporting.
:param _pytest.terminal.TerminalReporter terminalreporter: the internal terminal reporter object
@ -625,8 +676,8 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
@hookspec(historic=True, warn_on_impl=WARNING_CAPTURED_HOOK)
def pytest_warning_captured(
warning_message: "warnings.WarningMessage",
when: str,
item,
when: "Literal['config', 'collect', 'runtest']",
item: "Optional[Item]",
location: Optional[Tuple[str, int, str]],
) -> None:
"""(**Deprecated**) Process a warning captured by the internal pytest warnings plugin.
@ -660,7 +711,7 @@ def pytest_warning_captured(
@hookspec(historic=True)
def pytest_warning_recorded(
warning_message: "warnings.WarningMessage",
when: str,
when: "Literal['config', 'collect', 'runtest']",
nodeid: str,
location: Optional[Tuple[str, int, str]],
) -> None:
@ -714,7 +765,9 @@ def pytest_keyboard_interrupt(excinfo):
""" called for keyboard interrupt. """
def pytest_exception_interact(node, call, report):
def pytest_exception_interact(
node: "Node", call: "CallInfo[object]", report: "Union[CollectReport, TestReport]"
) -> None:
"""called when an exception was raised which can potentially be
interactively handled.
@ -723,7 +776,7 @@ def pytest_exception_interact(node, call, report):
"""
def pytest_enter_pdb(config, pdb):
def pytest_enter_pdb(config: "Config", pdb: "pdb.Pdb") -> None:
""" called upon pdb.set_trace(), can be used by plugins to take special
action just before the python debugger enters in interactive mode.
@ -732,7 +785,7 @@ def pytest_enter_pdb(config, pdb):
"""
def pytest_leave_pdb(config, pdb):
def pytest_leave_pdb(config: "Config", pdb: "pdb.Pdb") -> None:
""" called when leaving pdb (e.g. with continue after pdb.set_trace()).
Can be used by plugins to take special action just after the python

View File

@ -14,6 +14,11 @@ import platform
import re
import sys
from datetime import datetime
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import py
@ -21,10 +26,19 @@ import pytest
from _pytest import deprecated
from _pytest import nodes
from _pytest import timing
from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.config import filename_arg
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureRequest
from _pytest.reports import TestReport
from _pytest.store import StoreKey
from _pytest.terminal import TerminalReporter
from _pytest.warnings import _issue_warning_captured
if TYPE_CHECKING:
from typing import Type
xml_key = StoreKey["LogXML"]()
@ -54,8 +68,8 @@ del _legal_xml_re
_py_ext_re = re.compile(r"\.py$")
def bin_xml_escape(arg):
def repl(matchobj):
def bin_xml_escape(arg: str) -> py.xml.raw:
def repl(matchobj: "re.Match[str]") -> str:
i = ord(matchobj.group())
if i <= 0xFF:
return "#x%02X" % i
@ -65,7 +79,7 @@ def bin_xml_escape(arg):
return py.xml.raw(illegal_xml_re.sub(repl, py.xml.escape(arg)))
def merge_family(left, right):
def merge_family(left, right) -> None:
result = {}
for kl, vl in left.items():
for kr, vr in right.items():
@ -88,28 +102,27 @@ families["xunit2"] = families["_base"]
class _NodeReporter:
def __init__(self, nodeid, xml):
def __init__(self, nodeid: Union[str, TestReport], xml: "LogXML") -> None:
self.id = nodeid
self.xml = xml
self.add_stats = self.xml.add_stats
self.family = self.xml.family
self.duration = 0
self.properties = []
self.nodes = []
self.testcase = None
self.attrs = {}
self.properties = [] # type: List[Tuple[str, py.xml.raw]]
self.nodes = [] # type: List[py.xml.Tag]
self.attrs = {} # type: Dict[str, Union[str, py.xml.raw]]
def append(self, node):
def append(self, node: py.xml.Tag) -> None:
self.xml.add_stats(type(node).__name__)
self.nodes.append(node)
def add_property(self, name, value):
def add_property(self, name: str, value: str) -> None:
self.properties.append((str(name), bin_xml_escape(value)))
def add_attribute(self, name, value):
def add_attribute(self, name: str, value: str) -> None:
self.attrs[str(name)] = bin_xml_escape(value)
def make_properties_node(self):
def make_properties_node(self) -> Union[py.xml.Tag, str]:
"""Return a Junit node containing custom properties, if any.
"""
if self.properties:
@ -121,8 +134,7 @@ class _NodeReporter:
)
return ""
def record_testreport(self, testreport):
assert not self.testcase
def record_testreport(self, testreport: TestReport) -> None:
names = mangle_test_address(testreport.nodeid)
existing_attrs = self.attrs
classnames = names[:-1]
@ -132,9 +144,9 @@ class _NodeReporter:
"classname": ".".join(classnames),
"name": bin_xml_escape(names[-1]),
"file": testreport.location[0],
}
} # type: Dict[str, Union[str, py.xml.raw]]
if testreport.location[1] is not None:
attrs["line"] = testreport.location[1]
attrs["line"] = str(testreport.location[1])
if hasattr(testreport, "url"):
attrs["url"] = testreport.url
self.attrs = attrs
@ -152,19 +164,19 @@ class _NodeReporter:
temp_attrs[key] = self.attrs[key]
self.attrs = temp_attrs
def to_xml(self):
def to_xml(self) -> py.xml.Tag:
testcase = Junit.testcase(time="%.3f" % self.duration, **self.attrs)
testcase.append(self.make_properties_node())
for node in self.nodes:
testcase.append(node)
return testcase
def _add_simple(self, kind, message, data=None):
def _add_simple(self, kind: "Type[py.xml.Tag]", message: str, data=None) -> None:
data = bin_xml_escape(data)
node = kind(data, message=message)
self.append(node)
def write_captured_output(self, report):
def write_captured_output(self, report: TestReport) -> None:
if not self.xml.log_passing_tests and report.passed:
return
@ -187,21 +199,22 @@ class _NodeReporter:
if content_all:
self._write_content(report, content_all, "system-out")
def _prepare_content(self, content, header):
def _prepare_content(self, content: str, header: str) -> str:
return "\n".join([header.center(80, "-"), content, ""])
def _write_content(self, report, content, jheader):
def _write_content(self, report: TestReport, content: str, jheader: str) -> None:
tag = getattr(Junit, jheader)
self.append(tag(bin_xml_escape(content)))
def append_pass(self, report):
def append_pass(self, report: TestReport) -> None:
self.add_stats("passed")
def append_failure(self, report):
def append_failure(self, report: TestReport) -> None:
# msg = str(report.longrepr.reprtraceback.extraline)
if hasattr(report, "wasxfail"):
self._add_simple(Junit.skipped, "xfail-marked test passes unexpectedly")
else:
assert report.longrepr is not None
if getattr(report.longrepr, "reprcrash", None) is not None:
message = report.longrepr.reprcrash.message
else:
@ -211,23 +224,24 @@ class _NodeReporter:
fail.append(bin_xml_escape(report.longrepr))
self.append(fail)
def append_collect_error(self, report):
def append_collect_error(self, report: TestReport) -> None:
# msg = str(report.longrepr.reprtraceback.extraline)
assert report.longrepr is not None
self.append(
Junit.error(bin_xml_escape(report.longrepr), message="collection failure")
)
def append_collect_skipped(self, report):
def append_collect_skipped(self, report: TestReport) -> None:
self._add_simple(Junit.skipped, "collection skipped", report.longrepr)
def append_error(self, report):
def append_error(self, report: TestReport) -> None:
if report.when == "teardown":
msg = "test teardown failure"
else:
msg = "test setup failure"
self._add_simple(Junit.error, msg, report.longrepr)
def append_skipped(self, report):
def append_skipped(self, report: TestReport) -> None:
if hasattr(report, "wasxfail"):
xfailreason = report.wasxfail
if xfailreason.startswith("reason: "):
@ -238,6 +252,7 @@ class _NodeReporter:
)
)
else:
assert report.longrepr is not None
filename, lineno, skipreason = report.longrepr
if skipreason.startswith("Skipped: "):
skipreason = skipreason[9:]
@ -252,13 +267,17 @@ class _NodeReporter:
)
self.write_captured_output(report)
def finalize(self):
def finalize(self) -> None:
data = self.to_xml().unicode(indent=0)
self.__dict__.clear()
self.to_xml = lambda: py.xml.raw(data)
# Type ignored becuase mypy doesn't like overriding a method.
# Also the return value doesn't match...
self.to_xml = lambda: py.xml.raw(data) # type: ignore # noqa: F821
def _warn_incompatibility_with_xunit2(request, fixture_name):
def _warn_incompatibility_with_xunit2(
request: FixtureRequest, fixture_name: str
) -> None:
"""Emits a PytestWarning about the given fixture being incompatible with newer xunit revisions"""
from _pytest.warning_types import PytestWarning
@ -274,7 +293,7 @@ def _warn_incompatibility_with_xunit2(request, fixture_name):
@pytest.fixture
def record_property(request):
def record_property(request: FixtureRequest):
"""Add an extra properties the calling test.
User properties become part of the test report and are available to the
configured reporters, like JUnit XML.
@ -288,14 +307,14 @@ def record_property(request):
"""
_warn_incompatibility_with_xunit2(request, "record_property")
def append_property(name, value):
def append_property(name: str, value: object) -> None:
request.node.user_properties.append((name, value))
return append_property
@pytest.fixture
def record_xml_attribute(request):
def record_xml_attribute(request: FixtureRequest):
"""Add extra xml attributes to the tag for the calling test.
The fixture is callable with ``(name, value)``, with value being
automatically xml-encoded
@ -309,7 +328,7 @@ def record_xml_attribute(request):
_warn_incompatibility_with_xunit2(request, "record_xml_attribute")
# Declare noop
def add_attr_noop(name, value):
def add_attr_noop(name: str, value: str) -> None:
pass
attr_func = add_attr_noop
@ -322,7 +341,7 @@ def record_xml_attribute(request):
return attr_func
def _check_record_param_type(param, v):
def _check_record_param_type(param: str, v: str) -> None:
"""Used by record_testsuite_property to check that the given parameter name is of the proper
type"""
__tracebackhide__ = True
@ -332,7 +351,7 @@ def _check_record_param_type(param, v):
@pytest.fixture(scope="session")
def record_testsuite_property(request):
def record_testsuite_property(request: FixtureRequest):
"""
Records a new ``<property>`` tag as child of the root ``<testsuite>``. This is suitable to
writing global information regarding the entire test suite, and is compatible with ``xunit2`` JUnit family.
@ -350,7 +369,7 @@ def record_testsuite_property(request):
__tracebackhide__ = True
def record_func(name, value):
def record_func(name: str, value: str):
"""noop function in case --junitxml was not passed in the command-line"""
__tracebackhide__ = True
_check_record_param_type("name", name)
@ -361,7 +380,7 @@ def record_testsuite_property(request):
return record_func
def pytest_addoption(parser):
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("terminal reporting")
group.addoption(
"--junitxml",
@ -406,7 +425,7 @@ def pytest_addoption(parser):
)
def pytest_configure(config):
def pytest_configure(config: Config) -> None:
xmlpath = config.option.xmlpath
# prevent opening xmllog on slave nodes (xdist)
if xmlpath and not hasattr(config, "slaveinput"):
@ -426,14 +445,14 @@ def pytest_configure(config):
config.pluginmanager.register(config._store[xml_key])
def pytest_unconfigure(config):
def pytest_unconfigure(config: Config) -> None:
xml = config._store.get(xml_key, None)
if xml:
del config._store[xml_key]
config.pluginmanager.unregister(xml)
def mangle_test_address(address):
def mangle_test_address(address: str) -> List[str]:
path, possible_open_bracket, params = address.partition("[")
names = path.split("::")
try:
@ -452,13 +471,13 @@ class LogXML:
def __init__(
self,
logfile,
prefix,
suite_name="pytest",
logging="no",
report_duration="total",
prefix: Optional[str],
suite_name: str = "pytest",
logging: str = "no",
report_duration: str = "total",
family="xunit1",
log_passing_tests=True,
):
log_passing_tests: bool = True,
) -> None:
logfile = os.path.expanduser(os.path.expandvars(logfile))
self.logfile = os.path.normpath(os.path.abspath(logfile))
self.prefix = prefix
@ -467,20 +486,24 @@ class LogXML:
self.log_passing_tests = log_passing_tests
self.report_duration = report_duration
self.family = family
self.stats = dict.fromkeys(["error", "passed", "failure", "skipped"], 0)
self.node_reporters = {} # nodeid -> _NodeReporter
self.node_reporters_ordered = []
self.global_properties = []
self.stats = dict.fromkeys(
["error", "passed", "failure", "skipped"], 0
) # type: Dict[str, int]
self.node_reporters = (
{}
) # type: Dict[Tuple[Union[str, TestReport], object], _NodeReporter]
self.node_reporters_ordered = [] # type: List[_NodeReporter]
self.global_properties = [] # type: List[Tuple[str, py.xml.raw]]
# List of reports that failed on call but teardown is pending.
self.open_reports = []
self.open_reports = [] # type: List[TestReport]
self.cnt_double_fail_tests = 0
# Replaces convenience family with real family
if self.family == "legacy":
self.family = "xunit1"
def finalize(self, report):
def finalize(self, report: TestReport) -> None:
nodeid = getattr(report, "nodeid", report)
# local hack to handle xdist report order
slavenode = getattr(report, "node", None)
@ -488,8 +511,8 @@ class LogXML:
if reporter is not None:
reporter.finalize()
def node_reporter(self, report):
nodeid = getattr(report, "nodeid", report)
def node_reporter(self, report: Union[TestReport, str]) -> _NodeReporter:
nodeid = getattr(report, "nodeid", report) # type: Union[str, TestReport]
# local hack to handle xdist report order
slavenode = getattr(report, "node", None)
@ -506,16 +529,16 @@ class LogXML:
return reporter
def add_stats(self, key):
def add_stats(self, key: str) -> None:
if key in self.stats:
self.stats[key] += 1
def _opentestcase(self, report):
def _opentestcase(self, report: TestReport) -> _NodeReporter:
reporter = self.node_reporter(report)
reporter.record_testreport(report)
return reporter
def pytest_runtest_logreport(self, report):
def pytest_runtest_logreport(self, report: TestReport) -> None:
"""handle a setup/call/teardown report, generating the appropriate
xml tags as necessary.
@ -583,7 +606,7 @@ class LogXML:
reporter.write_captured_output(report)
for propname, propvalue in report.user_properties:
reporter.add_property(propname, propvalue)
reporter.add_property(propname, str(propvalue))
self.finalize(report)
report_wid = getattr(report, "worker_id", None)
@ -603,7 +626,7 @@ class LogXML:
if close_report:
self.open_reports.remove(close_report)
def update_testcase_duration(self, report):
def update_testcase_duration(self, report: TestReport) -> None:
"""accumulates total duration for nodeid from given report and updates
the Junit.testcase with the new total if already created.
"""
@ -611,7 +634,7 @@ class LogXML:
reporter = self.node_reporter(report)
reporter.duration += getattr(report, "duration", 0.0)
def pytest_collectreport(self, report):
def pytest_collectreport(self, report: TestReport) -> None:
if not report.passed:
reporter = self._opentestcase(report)
if report.failed:
@ -619,15 +642,15 @@ class LogXML:
else:
reporter.append_collect_skipped(report)
def pytest_internalerror(self, excrepr):
def pytest_internalerror(self, excrepr) -> None:
reporter = self.node_reporter("internal")
reporter.attrs.update(classname="pytest", name="internal")
reporter._add_simple(Junit.error, "internal error", excrepr)
def pytest_sessionstart(self):
def pytest_sessionstart(self) -> None:
self.suite_start_time = timing.time()
def pytest_sessionfinish(self):
def pytest_sessionfinish(self) -> None:
dirname = os.path.dirname(os.path.abspath(self.logfile))
if not os.path.isdir(dirname):
os.makedirs(dirname)
@ -648,10 +671,10 @@ class LogXML:
self._get_global_properties_node(),
[x.to_xml() for x in self.node_reporters_ordered],
name=self.suite_name,
errors=self.stats["error"],
failures=self.stats["failure"],
skipped=self.stats["skipped"],
tests=numtests,
errors=str(self.stats["error"]),
failures=str(self.stats["failure"]),
skipped=str(self.stats["skipped"]),
tests=str(numtests),
time="%.3f" % suite_time_delta,
timestamp=datetime.fromtimestamp(self.suite_start_time).isoformat(),
hostname=platform.node(),
@ -659,15 +682,15 @@ class LogXML:
logfile.write(Junit.testsuites([suite_node]).unicode(indent=0))
logfile.close()
def pytest_terminal_summary(self, terminalreporter):
def pytest_terminal_summary(self, terminalreporter: TerminalReporter) -> None:
terminalreporter.write_sep("-", "generated xml file: {}".format(self.logfile))
def add_global_property(self, name, value):
def add_global_property(self, name: str, value: str) -> None:
__tracebackhide__ = True
_check_record_param_type("name", name)
self.global_properties.append((name, bin_xml_escape(value)))
def _get_global_properties_node(self):
def _get_global_properties_node(self) -> Union[py.xml.Tag, str]:
"""Return a Junit node containing custom properties, if any.
"""
if self.global_properties:

View File

@ -11,16 +11,24 @@ from typing import Generator
from typing import List
from typing import Mapping
from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import Union
import pytest
from _pytest import nodes
from _pytest._io import TerminalWriter
from _pytest.capture import CaptureManager
from _pytest.compat import nullcontext
from _pytest.config import _strtobool
from _pytest.config import Config
from _pytest.config import create_terminal_writer
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureRequest
from _pytest.main import Session
from _pytest.pathlib import Path
from _pytest.store import StoreKey
from _pytest.terminal import TerminalReporter
DEFAULT_LOG_FORMAT = "%(levelname)-8s %(name)s:%(filename)s:%(lineno)d %(message)s"
@ -30,7 +38,7 @@ catch_log_handler_key = StoreKey["LogCaptureHandler"]()
catch_log_records_key = StoreKey[Dict[str, List[logging.LogRecord]]]()
def _remove_ansi_escape_sequences(text):
def _remove_ansi_escape_sequences(text: str) -> str:
return _ANSI_ESCAPE_SEQ.sub("", text)
@ -50,7 +58,7 @@ class ColoredLevelFormatter(logging.Formatter):
} # type: Mapping[int, AbstractSet[str]]
LEVELNAME_FMT_REGEX = re.compile(r"%\(levelname\)([+-.]?\d*s)")
def __init__(self, terminalwriter, *args, **kwargs) -> None:
def __init__(self, terminalwriter: TerminalWriter, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._original_fmt = self._style._fmt
self._level_to_fmt_mapping = {} # type: Dict[int, str]
@ -75,7 +83,7 @@ class ColoredLevelFormatter(logging.Formatter):
colorized_formatted_levelname, self._fmt
)
def format(self, record):
def format(self, record: logging.LogRecord) -> str:
fmt = self._level_to_fmt_mapping.get(record.levelno, self._original_fmt)
self._style._fmt = fmt
return super().format(record)
@ -88,18 +96,20 @@ class PercentStyleMultiline(logging.PercentStyle):
formats the message as if each line were logged separately.
"""
def __init__(self, fmt, auto_indent):
def __init__(self, fmt: str, auto_indent: Union[int, str, bool, None]) -> None:
super().__init__(fmt)
self._auto_indent = self._get_auto_indent(auto_indent)
@staticmethod
def _update_message(record_dict, message):
def _update_message(
record_dict: Dict[str, object], message: str
) -> Dict[str, object]:
tmp = record_dict.copy()
tmp["message"] = message
return tmp
@staticmethod
def _get_auto_indent(auto_indent_option) -> int:
def _get_auto_indent(auto_indent_option: Union[int, str, bool, None]) -> int:
"""Determines the current auto indentation setting
Specify auto indent behavior (on/off/fixed) by passing in
@ -129,7 +139,9 @@ class PercentStyleMultiline(logging.PercentStyle):
>0 (explicitly set indentation position).
"""
if type(auto_indent_option) is int:
if auto_indent_option is None:
return 0
elif type(auto_indent_option) is int:
return int(auto_indent_option)
elif type(auto_indent_option) is str:
try:
@ -147,11 +159,11 @@ class PercentStyleMultiline(logging.PercentStyle):
return 0
def format(self, record):
def format(self, record: logging.LogRecord) -> str:
if "\n" in record.message:
if hasattr(record, "auto_indent"):
# passed in from the "extra={}" kwarg on the call to logging.log()
auto_indent = self._get_auto_indent(record.auto_indent)
auto_indent = self._get_auto_indent(record.auto_indent) # type: ignore[attr-defined] # noqa: F821
else:
auto_indent = self._auto_indent
@ -171,7 +183,7 @@ class PercentStyleMultiline(logging.PercentStyle):
return self._fmt % record.__dict__
def get_option_ini(config, *names):
def get_option_ini(config: Config, *names: str):
for name in names:
ret = config.getoption(name) # 'default' arg won't work as expected
if ret is None:
@ -180,7 +192,7 @@ def get_option_ini(config, *names):
return ret
def pytest_addoption(parser):
def pytest_addoption(parser: Parser) -> None:
"""Add options to control log capturing."""
group = parser.getgroup("logging")
@ -266,13 +278,16 @@ def pytest_addoption(parser):
)
_HandlerType = TypeVar("_HandlerType", bound=logging.Handler)
# Not using @contextmanager for performance reasons.
class catching_logs:
"""Context manager that prepares the whole logging machinery properly."""
__slots__ = ("handler", "level", "orig_level")
def __init__(self, handler, level=None):
def __init__(self, handler: _HandlerType, level: Optional[int] = None) -> None:
self.handler = handler
self.level = level
@ -328,7 +343,7 @@ class LogCaptureFixture:
"""Creates a new funcarg."""
self._item = item
# dict of log name -> log level
self._initial_log_levels = {} # type: Dict[str, int]
self._initial_log_levels = {} # type: Dict[Optional[str], int]
def _finalize(self) -> None:
"""Finalizes the fixture.
@ -362,17 +377,17 @@ class LogCaptureFixture:
return self._item._store[catch_log_records_key].get(when, [])
@property
def text(self):
def text(self) -> str:
"""Returns the formatted log text."""
return _remove_ansi_escape_sequences(self.handler.stream.getvalue())
@property
def records(self):
def records(self) -> List[logging.LogRecord]:
"""Returns the list of log records."""
return self.handler.records
@property
def record_tuples(self):
def record_tuples(self) -> List[Tuple[str, int, str]]:
"""Returns a list of a stripped down version of log records intended
for use in assertion comparison.
@ -383,7 +398,7 @@ class LogCaptureFixture:
return [(r.name, r.levelno, r.getMessage()) for r in self.records]
@property
def messages(self):
def messages(self) -> List[str]:
"""Returns a list of format-interpolated log messages.
Unlike 'records', which contains the format string and parameters for interpolation, log messages in this list
@ -398,11 +413,11 @@ class LogCaptureFixture:
"""
return [r.getMessage() for r in self.records]
def clear(self):
def clear(self) -> None:
"""Reset the list of log records and the captured log text."""
self.handler.reset()
def set_level(self, level, logger=None):
def set_level(self, level: Union[int, str], logger: Optional[str] = None) -> None:
"""Sets the level for capturing of logs. The level will be restored to its previous value at the end of
the test.
@ -413,31 +428,32 @@ class LogCaptureFixture:
The levels of the loggers changed by this function will be restored to their initial values at the
end of the test.
"""
logger_name = logger
logger = logging.getLogger(logger_name)
logger_obj = logging.getLogger(logger)
# save the original log-level to restore it during teardown
self._initial_log_levels.setdefault(logger_name, logger.level)
logger.setLevel(level)
self._initial_log_levels.setdefault(logger, logger_obj.level)
logger_obj.setLevel(level)
@contextmanager
def at_level(self, level, logger=None):
def at_level(
self, level: int, logger: Optional[str] = None
) -> Generator[None, None, None]:
"""Context manager that sets the level for capturing of logs. After the end of the 'with' statement the
level is restored to its original value.
:param int level: the logger to level.
:param str logger: the logger to update the level. If not given, the root logger level is updated.
"""
logger = logging.getLogger(logger)
orig_level = logger.level
logger.setLevel(level)
logger_obj = logging.getLogger(logger)
orig_level = logger_obj.level
logger_obj.setLevel(level)
try:
yield
finally:
logger.setLevel(orig_level)
logger_obj.setLevel(orig_level)
@pytest.fixture
def caplog(request):
def caplog(request: FixtureRequest) -> Generator[LogCaptureFixture, None, None]:
"""Access and control log capturing.
Captured logs are available through the following properties/methods::
@ -478,7 +494,7 @@ def get_log_level_for_setting(config: Config, *setting_names: str) -> Optional[i
# run after terminalreporter/capturemanager are configured
@pytest.hookimpl(trylast=True)
def pytest_configure(config):
def pytest_configure(config: Config) -> None:
config.pluginmanager.register(LoggingPlugin(config), "logging-plugin")
@ -555,7 +571,7 @@ class LoggingPlugin:
return formatter
def set_log_path(self, fname):
def set_log_path(self, fname: str) -> None:
"""Public method, which can set filename parameter for
Logging.FileHandler(). Also creates parent directory if
it does not exist.
@ -563,15 +579,15 @@ class LoggingPlugin:
.. warning::
Please considered as an experimental API.
"""
fname = Path(fname)
fpath = Path(fname)
if not fname.is_absolute():
fname = Path(self._config.rootdir, fname)
if not fpath.is_absolute():
fpath = Path(self._config.rootdir, fpath)
if not fname.parent.exists():
fname.parent.mkdir(exist_ok=True, parents=True)
if not fpath.parent.exists():
fpath.parent.mkdir(exist_ok=True, parents=True)
stream = fname.open(mode="w", encoding="UTF-8")
stream = fpath.open(mode="w", encoding="UTF-8")
if sys.version_info >= (3, 7):
old_stream = self.log_file_handler.setStream(stream)
else:
@ -601,7 +617,7 @@ class LoggingPlugin:
return True
@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_sessionstart(self):
def pytest_sessionstart(self) -> Generator[None, None, None]:
self.log_cli_handler.set_when("sessionstart")
with catching_logs(self.log_cli_handler, level=self.log_cli_level):
@ -617,7 +633,7 @@ class LoggingPlugin:
yield
@pytest.hookimpl(hookwrapper=True)
def pytest_runtestloop(self, session):
def pytest_runtestloop(self, session: Session) -> Generator[None, None, None]:
"""Runs all collected test items."""
if session.config.option.collectonly:
@ -654,20 +670,21 @@ class LoggingPlugin:
item.add_report_section(when, "log", log)
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_setup(self, item):
def pytest_runtest_setup(self, item: nodes.Item) -> Generator[None, None, None]:
self.log_cli_handler.set_when("setup")
item._store[catch_log_records_key] = {}
empty = {} # type: Dict[str, List[logging.LogRecord]]
item._store[catch_log_records_key] = empty
yield from self._runtest_for(item, "setup")
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(self, item):
def pytest_runtest_call(self, item: nodes.Item) -> Generator[None, None, None]:
self.log_cli_handler.set_when("call")
yield from self._runtest_for(item, "call")
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_teardown(self, item):
def pytest_runtest_teardown(self, item: nodes.Item) -> Generator[None, None, None]:
self.log_cli_handler.set_when("teardown")
yield from self._runtest_for(item, "teardown")
@ -675,11 +692,11 @@ class LoggingPlugin:
del item._store[catch_log_handler_key]
@pytest.hookimpl
def pytest_runtest_logfinish(self):
def pytest_runtest_logfinish(self) -> None:
self.log_cli_handler.set_when("finish")
@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_sessionfinish(self):
def pytest_sessionfinish(self) -> Generator[None, None, None]:
self.log_cli_handler.set_when("sessionfinish")
with catching_logs(self.log_cli_handler, level=self.log_cli_level):
@ -687,7 +704,7 @@ class LoggingPlugin:
yield
@pytest.hookimpl
def pytest_unconfigure(self):
def pytest_unconfigure(self) -> None:
# Close the FileHandler explicitly.
# (logging.shutdown might have lost the weakref?!)
self.log_file_handler.close()
@ -712,29 +729,37 @@ class _LiveLoggingStreamHandler(logging.StreamHandler):
and won't appear in the terminal.
"""
def __init__(self, terminal_reporter, capture_manager):
# Officially stream needs to be a IO[str], but TerminalReporter
# isn't. So force it.
stream = None # type: TerminalReporter # type: ignore
def __init__(
self,
terminal_reporter: TerminalReporter,
capture_manager: Optional[CaptureManager],
) -> None:
"""
:param _pytest.terminal.TerminalReporter terminal_reporter:
:param _pytest.capture.CaptureManager capture_manager:
"""
logging.StreamHandler.__init__(self, stream=terminal_reporter)
logging.StreamHandler.__init__(self, stream=terminal_reporter) # type: ignore[arg-type] # noqa: F821
self.capture_manager = capture_manager
self.reset()
self.set_when(None)
self._test_outcome_written = False
def reset(self):
def reset(self) -> None:
"""Reset the handler; should be called before the start of each test"""
self._first_record_emitted = False
def set_when(self, when):
def set_when(self, when: Optional[str]) -> None:
"""Prepares for the given test phase (setup/call/teardown)"""
self._when = when
self._section_name_shown = False
if when == "start":
self._test_outcome_written = False
def emit(self, record):
def emit(self, record: logging.LogRecord) -> None:
ctx_manager = (
self.capture_manager.global_and_fixture_disabled()
if self.capture_manager
@ -761,10 +786,10 @@ class _LiveLoggingStreamHandler(logging.StreamHandler):
class _LiveLoggingNullHandler(logging.NullHandler):
"""A handler used when live logging is disabled."""
def reset(self):
def reset(self) -> None:
pass
def set_when(self, when):
def set_when(self, when: str) -> None:
pass
def handleError(self, record: logging.LogRecord) -> None:

View File

@ -7,9 +7,11 @@ import sys
from typing import Callable
from typing import Dict
from typing import FrozenSet
from typing import Iterator
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Union
@ -18,15 +20,18 @@ import py
import _pytest._code
from _pytest import nodes
from _pytest.compat import overload
from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.config import directory_arg
from _pytest.config import ExitCode
from _pytest.config import hookimpl
from _pytest.config import UsageError
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureManager
from _pytest.outcomes import exit
from _pytest.reports import CollectReport
from _pytest.reports import TestReport
from _pytest.runner import collect_one_node
from _pytest.runner import SetupState
@ -38,7 +43,7 @@ if TYPE_CHECKING:
from _pytest.python import Package
def pytest_addoption(parser):
def pytest_addoption(parser: Parser) -> None:
parser.addini(
"norecursedirs",
"directory patterns to avoid for recursion",
@ -241,7 +246,7 @@ def wrap_session(
return session.exitstatus
def pytest_cmdline_main(config):
def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]:
return wrap_session(config, _main)
@ -258,11 +263,11 @@ def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
return None
def pytest_collection(session):
def pytest_collection(session: "Session") -> Sequence[nodes.Item]:
return session.perform_collect()
def pytest_runtestloop(session):
def pytest_runtestloop(session: "Session") -> bool:
if session.testsfailed and not session.config.option.continue_on_collection_errors:
raise session.Interrupted(
"%d error%s during collection"
@ -282,7 +287,7 @@ def pytest_runtestloop(session):
return True
def _in_venv(path):
def _in_venv(path: py.path.local) -> bool:
"""Attempts to detect if ``path`` is the root of a Virtual Environment by
checking for the existence of the appropriate activate script"""
bindir = path.join("Scripts" if sys.platform.startswith("win") else "bin")
@ -328,7 +333,7 @@ def pytest_ignore_collect(
return None
def pytest_collection_modifyitems(items, config):
def pytest_collection_modifyitems(items: List[nodes.Item], config: Config) -> None:
deselect_prefixes = tuple(config.getoption("deselect") or [])
if not deselect_prefixes:
return
@ -385,8 +390,8 @@ class Session(nodes.FSCollector):
)
self.testsfailed = 0
self.testscollected = 0
self.shouldstop = False
self.shouldfail = False
self.shouldstop = False # type: Union[bool, str]
self.shouldfail = False # type: Union[bool, str]
self.trace = config.trace.root.get("collection")
self.startdir = config.invocation_dir
self._initialpaths = frozenset() # type: FrozenSet[py.path.local]
@ -412,10 +417,11 @@ class Session(nodes.FSCollector):
self.config.pluginmanager.register(self, name="session")
@classmethod
def from_config(cls, config):
return cls._create(config)
def from_config(cls, config: Config) -> "Session":
session = cls._create(config) # type: Session
return session
def __repr__(self):
def __repr__(self) -> str:
return "<%s %s exitstatus=%r testsfailed=%d testscollected=%d>" % (
self.__class__.__name__,
self.name,
@ -429,14 +435,16 @@ class Session(nodes.FSCollector):
return self._bestrelpathcache[node_path]
@hookimpl(tryfirst=True)
def pytest_collectstart(self):
def pytest_collectstart(self) -> None:
if self.shouldfail:
raise self.Failed(self.shouldfail)
if self.shouldstop:
raise self.Interrupted(self.shouldstop)
@hookimpl(tryfirst=True)
def pytest_runtest_logreport(self, report):
def pytest_runtest_logreport(
self, report: Union[TestReport, CollectReport]
) -> None:
if report.failed and not hasattr(report, "wasxfail"):
self.testsfailed += 1
maxfail = self.config.getvalue("maxfail")
@ -445,13 +453,27 @@ class Session(nodes.FSCollector):
pytest_collectreport = pytest_runtest_logreport
def isinitpath(self, path):
def isinitpath(self, path: py.path.local) -> bool:
return path in self._initialpaths
def gethookproxy(self, fspath: py.path.local):
return super()._gethookproxy(fspath)
def perform_collect(self, args=None, genitems=True):
@overload
def perform_collect(
self, args: Optional[Sequence[str]] = ..., genitems: "Literal[True]" = ...
) -> Sequence[nodes.Item]:
raise NotImplementedError()
@overload # noqa: F811
def perform_collect( # noqa: F811
self, args: Optional[Sequence[str]] = ..., genitems: bool = ...
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
raise NotImplementedError()
def perform_collect( # noqa: F811
self, args: Optional[Sequence[str]] = None, genitems: bool = True
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
hook = self.config.hook
try:
items = self._perform_collect(args, genitems)
@ -464,15 +486,29 @@ class Session(nodes.FSCollector):
self.testscollected = len(items)
return items
def _perform_collect(self, args, genitems):
@overload
def _perform_collect(
self, args: Optional[Sequence[str]], genitems: "Literal[True]"
) -> List[nodes.Item]:
raise NotImplementedError()
@overload # noqa: F811
def _perform_collect( # noqa: F811
self, args: Optional[Sequence[str]], genitems: bool
) -> Union[List[Union[nodes.Item]], List[Union[nodes.Item, nodes.Collector]]]:
raise NotImplementedError()
def _perform_collect( # noqa: F811
self, args: Optional[Sequence[str]], genitems: bool
) -> Union[List[Union[nodes.Item]], List[Union[nodes.Item, nodes.Collector]]]:
if args is None:
args = self.config.args
self.trace("perform_collect", self, args)
self.trace.root.indent += 1
self._notfound = []
self._notfound = [] # type: List[Tuple[str, NoMatch]]
initialpaths = [] # type: List[py.path.local]
self._initial_parts = [] # type: List[Tuple[py.path.local, List[str]]]
self.items = items = []
self.items = items = [] # type: List[nodes.Item]
for arg in args:
fspath, parts = self._parsearg(arg)
self._initial_parts.append((fspath, parts))
@ -495,7 +531,7 @@ class Session(nodes.FSCollector):
self.items.extend(self.genitems(node))
return items
def collect(self):
def collect(self) -> Iterator[Union[nodes.Item, nodes.Collector]]:
for fspath, parts in self._initial_parts:
self.trace("processing argument", (fspath, parts))
self.trace.root.indent += 1
@ -513,7 +549,9 @@ class Session(nodes.FSCollector):
self._collection_node_cache3.clear()
self._collection_pkg_roots.clear()
def _collect(self, argpath, names):
def _collect(
self, argpath: py.path.local, names: List[str]
) -> Iterator[Union[nodes.Item, nodes.Collector]]:
from _pytest.python import Package
# Start with a Session root, and delve to argpath item (dir or file)
@ -541,7 +579,7 @@ class Session(nodes.FSCollector):
if argpath.check(dir=1):
assert not names, "invalid arg {!r}".format((argpath, names))
seen_dirs = set()
seen_dirs = set() # type: Set[py.path.local]
for path in argpath.visit(
fil=self._visit_filter, rec=self._recurse, bf=True, sort=True
):
@ -582,8 +620,9 @@ class Session(nodes.FSCollector):
# Module itself, so just use that. If this special case isn't taken, then all
# the files in the package will be yielded.
if argpath.basename == "__init__.py":
assert isinstance(m[0], nodes.Collector)
try:
yield next(m[0].collect())
yield next(iter(m[0].collect()))
except StopIteration:
# The package collects nothing with only an __init__.py
# file in it, which gets ignored by the default
@ -593,10 +632,11 @@ class Session(nodes.FSCollector):
yield from m
@staticmethod
def _visit_filter(f):
return f.check(file=1)
def _visit_filter(f: py.path.local) -> bool:
# TODO: Remove type: ignore once `py` is typed.
return f.check(file=1) # type: ignore
def _tryconvertpyarg(self, x):
def _tryconvertpyarg(self, x: str) -> str:
"""Convert a dotted module name to path."""
try:
spec = importlib.util.find_spec(x)
@ -605,14 +645,14 @@ class Session(nodes.FSCollector):
# ValueError: not a module name
except (AttributeError, ImportError, ValueError):
return x
if spec is None or spec.origin in {None, "namespace"}:
if spec is None or spec.origin is None or spec.origin == "namespace":
return x
elif spec.submodule_search_locations:
return os.path.dirname(spec.origin)
else:
return spec.origin
def _parsearg(self, arg):
def _parsearg(self, arg: str) -> Tuple[py.path.local, List[str]]:
""" return (fspath, names) tuple after checking the file exists. """
strpath, *parts = str(arg).split("::")
if self.config.option.pyargs:
@ -628,7 +668,9 @@ class Session(nodes.FSCollector):
fspath = fspath.realpath()
return (fspath, parts)
def matchnodes(self, matching, names):
def matchnodes(
self, matching: Sequence[Union[nodes.Item, nodes.Collector]], names: List[str],
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
self.trace("matchnodes", matching, names)
self.trace.root.indent += 1
nodes = self._matchnodes(matching, names)
@ -639,13 +681,15 @@ class Session(nodes.FSCollector):
raise NoMatch(matching, names[:1])
return nodes
def _matchnodes(self, matching, names):
def _matchnodes(
self, matching: Sequence[Union[nodes.Item, nodes.Collector]], names: List[str],
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
if not matching or not names:
return matching
name = names[0]
assert name
nextnames = names[1:]
resultnodes = []
resultnodes = [] # type: List[Union[nodes.Item, nodes.Collector]]
for node in matching:
if isinstance(node, nodes.Item):
if not names:
@ -676,7 +720,9 @@ class Session(nodes.FSCollector):
node.ihook.pytest_collectreport(report=rep)
return resultnodes
def genitems(self, node):
def genitems(
self, node: Union[nodes.Item, nodes.Collector]
) -> Iterator[nodes.Item]:
self.trace("genitems", node)
if isinstance(node, nodes.Item):
node.ihook.pytest_itemcollected(item=node)

View File

@ -1,7 +1,10 @@
""" generic mechanism for marking and selecting python functions. """
import typing
import warnings
from typing import AbstractSet
from typing import List
from typing import Optional
from typing import Union
import attr
@ -16,8 +19,10 @@ from .structures import MarkGenerator
from .structures import ParameterSet
from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config import hookimpl
from _pytest.config import UsageError
from _pytest.config.argparsing import Parser
from _pytest.deprecated import MINUS_K_COLON
from _pytest.deprecated import MINUS_K_DASH
from _pytest.store import StoreKey
@ -25,13 +30,18 @@ from _pytest.store import StoreKey
if TYPE_CHECKING:
from _pytest.nodes import Item
__all__ = ["Mark", "MarkDecorator", "MarkGenerator", "get_empty_parameterset_mark"]
old_mark_config_key = StoreKey[Optional[Config]]()
def param(*values, **kw):
def param(
*values: object,
marks: "Union[MarkDecorator, typing.Collection[Union[MarkDecorator, Mark]]]" = (),
id: Optional[str] = None
) -> ParameterSet:
"""Specify a parameter in `pytest.mark.parametrize`_ calls or
:ref:`parametrized fixtures <fixture-parametrize-marks>`.
@ -48,10 +58,10 @@ def param(*values, **kw):
:keyword marks: a single mark or a list of marks to be applied to this parameter set.
:keyword str id: the id to attribute to this parameter set.
"""
return ParameterSet.param(*values, **kw)
return ParameterSet.param(*values, marks=marks, id=id)
def pytest_addoption(parser):
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group._addoption(
"-k",
@ -94,7 +104,7 @@ def pytest_addoption(parser):
@hookimpl(tryfirst=True)
def pytest_cmdline_main(config):
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
import _pytest.config
if config.option.markers:
@ -110,6 +120,8 @@ def pytest_cmdline_main(config):
config._ensure_unconfigure()
return 0
return None
@attr.s(slots=True)
class KeywordMatcher:
@ -135,9 +147,9 @@ class KeywordMatcher:
# Add the names of the current item and any parent items
import pytest
for item in item.listchain():
if not isinstance(item, (pytest.Instance, pytest.Session)):
mapped_names.add(item.name)
for node in item.listchain():
if not isinstance(node, (pytest.Instance, pytest.Session)):
mapped_names.add(node.name)
# Add the names added as extra keywords to current or parent items
mapped_names.update(item.listextrakeywords())
@ -162,7 +174,7 @@ class KeywordMatcher:
return False
def deselect_by_keyword(items, config):
def deselect_by_keyword(items: "List[Item]", config: Config) -> None:
keywordexpr = config.option.keyword.lstrip()
if not keywordexpr:
return
@ -218,7 +230,7 @@ class MarkMatcher:
return name in self.own_mark_names
def deselect_by_mark(items, config):
def deselect_by_mark(items: "List[Item]", config: Config) -> None:
matchexpr = config.option.markexpr
if not matchexpr:
return
@ -243,12 +255,12 @@ def deselect_by_mark(items, config):
items[:] = remaining
def pytest_collection_modifyitems(items, config):
def pytest_collection_modifyitems(items: "List[Item]", config: Config) -> None:
deselect_by_keyword(items, config)
deselect_by_mark(items, config)
def pytest_configure(config):
def pytest_configure(config: Config) -> None:
config._store[old_mark_config_key] = MARK_GEN._config
MARK_GEN._config = config
@ -261,5 +273,5 @@ def pytest_configure(config):
)
def pytest_unconfigure(config):
def pytest_unconfigure(config: Config) -> None:
MARK_GEN._config = config._store.get(old_mark_config_key, None)

View File

@ -4,10 +4,14 @@ import sys
import traceback
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from ..outcomes import fail
from ..outcomes import TEST_OUTCOME
from .structures import Mark
from _pytest.config import Config
from _pytest.nodes import Item
from _pytest.store import StoreKey
@ -28,29 +32,29 @@ def cached_eval(config: Config, expr: str, d: Dict[str, object]) -> Any:
class MarkEvaluator:
def __init__(self, item, name):
def __init__(self, item: Item, name: str) -> None:
self.item = item
self._marks = None
self._mark = None
self._marks = None # type: Optional[List[Mark]]
self._mark = None # type: Optional[Mark]
self._mark_name = name
def __bool__(self):
def __bool__(self) -> bool:
# don't cache here to prevent staleness
return bool(self._get_marks())
def wasvalid(self):
def wasvalid(self) -> bool:
return not hasattr(self, "exc")
def _get_marks(self):
def _get_marks(self) -> List[Mark]:
return list(self.item.iter_markers(name=self._mark_name))
def invalidraise(self, exc):
def invalidraise(self, exc) -> Optional[bool]:
raises = self.get("raises")
if not raises:
return
return None
return not isinstance(exc, raises)
def istrue(self):
def istrue(self) -> bool:
try:
return self._istrue()
except TEST_OUTCOME:
@ -69,25 +73,26 @@ class MarkEvaluator:
pytrace=False,
)
def _getglobals(self):
def _getglobals(self) -> Dict[str, object]:
d = {"os": os, "sys": sys, "platform": platform, "config": self.item.config}
if hasattr(self.item, "obj"):
d.update(self.item.obj.__globals__)
d.update(self.item.obj.__globals__) # type: ignore[attr-defined] # noqa: F821
return d
def _istrue(self):
def _istrue(self) -> bool:
if hasattr(self, "result"):
return self.result
result = getattr(self, "result") # type: bool
return result
self._marks = self._get_marks()
if self._marks:
self.result = False
for mark in self._marks:
self._mark = mark
if "condition" in mark.kwargs:
args = (mark.kwargs["condition"],)
else:
if "condition" not in mark.kwargs:
args = mark.args
else:
args = (mark.kwargs["condition"],)
for expr in args:
self.expr = expr

View File

@ -1,15 +1,18 @@
import collections.abc
import inspect
import typing
import warnings
from collections import namedtuple
from collections.abc import MutableMapping
from typing import Any
from typing import Callable
from typing import Iterable
from typing import List
from typing import Mapping
from typing import NamedTuple
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import TypeVar
from typing import Union
import attr
@ -17,20 +20,30 @@ import attr
from .._code import getfslineno
from ..compat import ascii_escaped
from ..compat import NOTSET
from ..compat import NotSetType
from ..compat import overload
from ..compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.outcomes import fail
from _pytest.warning_types import PytestUnknownMarkWarning
if TYPE_CHECKING:
from _pytest.python import FunctionDefinition
EMPTY_PARAMETERSET_OPTION = "empty_parameter_set_mark"
def istestfunc(func):
def istestfunc(func) -> bool:
return (
hasattr(func, "__call__")
and getattr(func, "__name__", "<lambda>") != "<lambda>"
)
def get_empty_parameterset_mark(config, argnames, func):
def get_empty_parameterset_mark(
config: Config, argnames: Sequence[str], func
) -> "MarkDecorator":
from ..nodes import Collector
requested_mark = config.getini(EMPTY_PARAMETERSET_OPTION)
@ -53,16 +66,33 @@ def get_empty_parameterset_mark(config, argnames, func):
fs,
lineno,
)
return mark(reason=reason)
# Type ignored because MarkDecorator.__call__() is a bit tough to
# annotate ATM.
return mark(reason=reason) # type: ignore[no-any-return] # noqa: F723
class ParameterSet(namedtuple("ParameterSet", "values, marks, id")):
class ParameterSet(
NamedTuple(
"ParameterSet",
[
("values", Sequence[Union[object, NotSetType]]),
("marks", "typing.Collection[Union[MarkDecorator, Mark]]"),
("id", Optional[str]),
],
)
):
@classmethod
def param(cls, *values, marks=(), id=None):
def param(
cls,
*values: object,
marks: "Union[MarkDecorator, typing.Collection[Union[MarkDecorator, Mark]]]" = (),
id: Optional[str] = None
) -> "ParameterSet":
if isinstance(marks, MarkDecorator):
marks = (marks,)
else:
assert isinstance(marks, (tuple, list, set))
# TODO(py36): Change to collections.abc.Collection.
assert isinstance(marks, (collections.abc.Sequence, set))
if id is not None:
if not isinstance(id, str):
@ -73,7 +103,11 @@ class ParameterSet(namedtuple("ParameterSet", "values, marks, id")):
return cls(values, marks, id)
@classmethod
def extract_from(cls, parameterset, force_tuple=False):
def extract_from(
cls,
parameterset: Union["ParameterSet", Sequence[object], object],
force_tuple: bool = False,
) -> "ParameterSet":
"""
:param parameterset:
a legacy style parameterset that may or may not be a tuple,
@ -89,10 +123,20 @@ class ParameterSet(namedtuple("ParameterSet", "values, marks, id")):
if force_tuple:
return cls.param(parameterset)
else:
return cls(parameterset, marks=[], id=None)
# TODO: Refactor to fix this type-ignore. Currently the following
# type-checks but crashes:
#
# @pytest.mark.parametrize(('x', 'y'), [1, 2])
# def test_foo(x, y): pass
return cls(parameterset, marks=[], id=None) # type: ignore[arg-type] # noqa: F821
@staticmethod
def _parse_parametrize_args(argnames, argvalues, *args, **kwargs):
def _parse_parametrize_args(
argnames: Union[str, List[str], Tuple[str, ...]],
argvalues: Iterable[Union["ParameterSet", Sequence[object], object]],
*args,
**kwargs
) -> Tuple[Union[List[str], Tuple[str, ...]], bool]:
if not isinstance(argnames, (tuple, list)):
argnames = [x.strip() for x in argnames.split(",") if x.strip()]
force_tuple = len(argnames) == 1
@ -101,13 +145,23 @@ class ParameterSet(namedtuple("ParameterSet", "values, marks, id")):
return argnames, force_tuple
@staticmethod
def _parse_parametrize_parameters(argvalues, force_tuple):
def _parse_parametrize_parameters(
argvalues: Iterable[Union["ParameterSet", Sequence[object], object]],
force_tuple: bool,
) -> List["ParameterSet"]:
return [
ParameterSet.extract_from(x, force_tuple=force_tuple) for x in argvalues
]
@classmethod
def _for_parametrize(cls, argnames, argvalues, func, config, function_definition):
def _for_parametrize(
cls,
argnames: Union[str, List[str], Tuple[str, ...]],
argvalues: Iterable[Union["ParameterSet", Sequence[object], object]],
func,
config: Config,
function_definition: "FunctionDefinition",
) -> Tuple[Union[List[str], Tuple[str, ...]], List["ParameterSet"]]:
argnames, force_tuple = cls._parse_parametrize_args(argnames, argvalues)
parameters = cls._parse_parametrize_parameters(argvalues, force_tuple)
del argvalues
@ -189,6 +243,12 @@ class Mark:
)
# A generic parameter designating an object to which a Mark may
# be applied -- a test function (callable) or class.
# Note: a lambda is not allowed, but this can't be represented.
_Markable = TypeVar("_Markable", bound=Union[Callable[..., object], type])
@attr.s
class MarkDecorator:
"""A decorator for applying a mark on test functions and classes.
@ -260,7 +320,20 @@ class MarkDecorator:
mark = Mark(self.name, args, kwargs)
return self.__class__(self.mark.combined_with(mark))
def __call__(self, *args: object, **kwargs: object):
# Type ignored because the overloads overlap with an incompatible
# return type. Not much we can do about that. Thankfully mypy picks
# the first match so it works out even if we break the rules.
@overload
def __call__(self, arg: _Markable) -> _Markable: # type: ignore[misc] # noqa: F821
raise NotImplementedError()
@overload # noqa: F811
def __call__( # noqa: F811
self, *args: object, **kwargs: object
) -> "MarkDecorator":
raise NotImplementedError()
def __call__(self, *args: object, **kwargs: object): # noqa: F811
"""Call the MarkDecorator."""
if args and not kwargs:
func = args[0]
@ -271,7 +344,7 @@ class MarkDecorator:
return self.with_args(*args, **kwargs)
def get_unpacked_marks(obj):
def get_unpacked_marks(obj) -> List[Mark]:
"""
obtain the unpacked marks that are stored on an object
"""
@ -323,7 +396,7 @@ class MarkGenerator:
applies a 'slowtest' :class:`Mark` on ``test_function``.
"""
_config = None
_config = None # type: Optional[Config]
_markers = set() # type: Set[str]
def __getattr__(self, name: str) -> MarkDecorator:
@ -370,7 +443,7 @@ class MarkGenerator:
MARK_GEN = MarkGenerator()
class NodeKeywords(MutableMapping):
class NodeKeywords(collections.abc.MutableMapping):
def __init__(self, node):
self.node = node
self.parent = node.parent
@ -400,8 +473,8 @@ class NodeKeywords(MutableMapping):
seen.update(self.parent.keywords)
return seen
def __len__(self):
def __len__(self) -> int:
return len(self._seen())
def __repr__(self):
def __repr__(self) -> str:
return "<NodeKeywords for node {}>".format(self.node)

View File

@ -1,22 +1,26 @@
import os
import warnings
from functools import lru_cache
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import TypeVar
from typing import Union
import py
import _pytest._code
from _pytest._code import getfslineno
from _pytest._code.code import ExceptionChainRepr
from _pytest._code.code import ExceptionInfo
from _pytest._code.code import ReprExceptionInfo
from _pytest._code.code import TerminalRepr
from _pytest.compat import cached_property
from _pytest.compat import overload
from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.config import ConftestImportFailure
@ -24,7 +28,6 @@ from _pytest.config import PytestPluginManager
from _pytest.deprecated import NODE_USE_FROM_PARENT
from _pytest.fixtures import FixtureDef
from _pytest.fixtures import FixtureLookupError
from _pytest.fixtures import FixtureLookupErrorRepr
from _pytest.mark.structures import Mark
from _pytest.mark.structures import MarkDecorator
from _pytest.mark.structures import NodeKeywords
@ -33,8 +36,13 @@ from _pytest.pathlib import Path
from _pytest.store import Store
if TYPE_CHECKING:
from typing import Type
# Imported here due to circular import.
from _pytest.main import Session
from _pytest.warning_types import PytestWarning
from _pytest._code.code import _TracebackStyle
SEP = "/"
@ -42,7 +50,7 @@ tracebackcutdir = py.path.local(_pytest.__file__).dirpath()
@lru_cache(maxsize=None)
def _splitnode(nodeid):
def _splitnode(nodeid: str) -> Tuple[str, ...]:
"""Split a nodeid into constituent 'parts'.
Node IDs are strings, and can be things like:
@ -67,7 +75,7 @@ def _splitnode(nodeid):
return tuple(parts)
def ischildnode(baseid, nodeid):
def ischildnode(baseid: str, nodeid: str) -> bool:
"""Return True if the nodeid is a child node of the baseid.
E.g. 'foo/bar::Baz' is a child of 'foo', 'foo/bar' and 'foo/bar::Baz', but not of 'foo/blorp'
@ -79,6 +87,9 @@ def ischildnode(baseid, nodeid):
return node_parts[: len(base_parts)] == base_parts
_NodeType = TypeVar("_NodeType", bound="Node")
class NodeMeta(type):
def __call__(self, *k, **kw):
warnings.warn(NODE_USE_FROM_PARENT.format(name=self.__name__), stacklevel=2)
@ -108,9 +119,9 @@ class Node(metaclass=NodeMeta):
def __init__(
self,
name: str,
parent: Optional["Node"] = None,
parent: "Optional[Node]" = None,
config: Optional[Config] = None,
session: Optional["Session"] = None,
session: "Optional[Session]" = None,
fspath: Optional[py.path.local] = None,
nodeid: Optional[str] = None,
) -> None:
@ -122,7 +133,7 @@ class Node(metaclass=NodeMeta):
#: the pytest config object
if config:
self.config = config
self.config = config # type: Config
else:
if not parent:
raise TypeError("config or parent must be provided")
@ -188,10 +199,10 @@ class Node(metaclass=NodeMeta):
""" fspath sensitive hook proxy used to call pytest hooks"""
return self.session.gethookproxy(self.fspath)
def __repr__(self):
def __repr__(self) -> str:
return "<{} {}>".format(self.__class__.__name__, getattr(self, "name", None))
def warn(self, warning):
def warn(self, warning: "PytestWarning") -> None:
"""Issue a warning for this item.
Warnings will be displayed after the test session, unless explicitly suppressed
@ -216,29 +227,27 @@ class Node(metaclass=NodeMeta):
)
)
path, lineno = get_fslocation_from_item(self)
assert lineno is not None
warnings.warn_explicit(
warning,
category=None,
filename=str(path),
lineno=lineno + 1 if lineno is not None else None,
warning, category=None, filename=str(path), lineno=lineno + 1,
)
# methods for ordering nodes
@property
def nodeid(self):
def nodeid(self) -> str:
""" a ::-separated string denoting its collection tree address. """
return self._nodeid
def __hash__(self):
def __hash__(self) -> int:
return hash(self._nodeid)
def setup(self):
def setup(self) -> None:
pass
def teardown(self):
def teardown(self) -> None:
pass
def listchain(self):
def listchain(self) -> List["Node"]:
""" return list of all parent collectors up to self,
starting from root of collection tree. """
chain = []
@ -273,7 +282,7 @@ class Node(metaclass=NodeMeta):
else:
self.own_markers.insert(0, marker_.mark)
def iter_markers(self, name=None):
def iter_markers(self, name: Optional[str] = None) -> Iterator[Mark]:
"""
:param name: if given, filter the results by the name attribute
@ -281,7 +290,9 @@ class Node(metaclass=NodeMeta):
"""
return (x[1] for x in self.iter_markers_with_node(name=name))
def iter_markers_with_node(self, name=None):
def iter_markers_with_node(
self, name: Optional[str] = None
) -> Iterator[Tuple["Node", Mark]]:
"""
:param name: if given, filter the results by the name attribute
@ -293,7 +304,17 @@ class Node(metaclass=NodeMeta):
if name is None or getattr(mark, "name", None) == name:
yield node, mark
def get_closest_marker(self, name, default=None):
@overload
def get_closest_marker(self, name: str) -> Optional[Mark]:
raise NotImplementedError()
@overload # noqa: F811
def get_closest_marker(self, name: str, default: Mark) -> Mark: # noqa: F811
raise NotImplementedError()
def get_closest_marker( # noqa: F811
self, name: str, default: Optional[Mark] = None
) -> Optional[Mark]:
"""return the first marker matching the name, from closest (for example function) to farther level (for example
module level).
@ -302,17 +323,17 @@ class Node(metaclass=NodeMeta):
"""
return next(self.iter_markers(name=name), default)
def listextrakeywords(self):
def listextrakeywords(self) -> Set[str]:
""" Return a set of all extra keywords in self and any parents."""
extra_keywords = set() # type: Set[str]
for item in self.listchain():
extra_keywords.update(item.extra_keyword_matches)
return extra_keywords
def listnames(self):
def listnames(self) -> List[str]:
return [x.name for x in self.listchain()]
def addfinalizer(self, fin):
def addfinalizer(self, fin: Callable[[], object]) -> None:
""" register a function to be called when this node is finalized.
This method can only be called when this node is active
@ -320,20 +341,23 @@ class Node(metaclass=NodeMeta):
"""
self.session._setupstate.addfinalizer(fin, self)
def getparent(self, cls):
def getparent(self, cls: "Type[_NodeType]") -> Optional[_NodeType]:
""" get the next parent node (including ourself)
which is an instance of the given class"""
current = self # type: Optional[Node]
while current and not isinstance(current, cls):
current = current.parent
assert current is None or isinstance(current, cls)
return current
def _prunetraceback(self, excinfo):
pass
def _repr_failure_py(
self, excinfo: ExceptionInfo[BaseException], style=None,
) -> Union[str, ReprExceptionInfo, ExceptionChainRepr, FixtureLookupErrorRepr]:
self,
excinfo: ExceptionInfo[BaseException],
style: "Optional[_TracebackStyle]" = None,
) -> TerminalRepr:
if isinstance(excinfo.value, ConftestImportFailure):
excinfo = ExceptionInfo(excinfo.value.excinfo)
if isinstance(excinfo.value, fail.Exception):
@ -383,8 +407,10 @@ class Node(metaclass=NodeMeta):
)
def repr_failure(
self, excinfo, style=None
) -> Union[str, ReprExceptionInfo, ExceptionChainRepr, FixtureLookupErrorRepr]:
self,
excinfo: ExceptionInfo[BaseException],
style: "Optional[_TracebackStyle]" = None,
) -> Union[str, TerminalRepr]:
"""
Return a representation of a collection or test failure.
@ -394,24 +420,26 @@ class Node(metaclass=NodeMeta):
def get_fslocation_from_item(
item: "Item",
node: "Node",
) -> Tuple[Union[str, py.path.local], Optional[int]]:
"""Tries to extract the actual location from an item, depending on available attributes:
"""Tries to extract the actual location from a node, depending on available attributes:
* "fslocation": a pair (path, lineno)
* "obj": a Python object that the item wraps.
* "location": a pair (path, lineno)
* "obj": a Python object that the node wraps.
* "fspath": just a path
:rtype: a tuple of (str|LocalPath, int) with filename and line number.
"""
try:
return item.location[:2]
except AttributeError:
pass
obj = getattr(item, "obj", None)
# See Item.location.
location = getattr(
node, "location", None
) # type: Optional[Tuple[str, Optional[int], str]]
if location is not None:
return location[:2]
obj = getattr(node, "obj", None)
if obj is not None:
return getfslineno(obj)
return getattr(item, "fspath", "unknown location"), -1
return getattr(node, "fspath", "unknown location"), -1
class Collector(Node):
@ -422,19 +450,22 @@ class Collector(Node):
class CollectError(Exception):
""" an error during collection, contains a custom message. """
def collect(self):
def collect(self) -> Iterable[Union["Item", "Collector"]]:
""" returns a list of children (items and collectors)
for this collection node.
"""
raise NotImplementedError("abstract")
def repr_failure(self, excinfo):
# TODO: This omits the style= parameter which breaks Liskov Substitution.
def repr_failure( # type: ignore[override] # noqa: F821
self, excinfo: ExceptionInfo[BaseException]
) -> Union[str, TerminalRepr]:
"""
Return a representation of a collection failure.
:param excinfo: Exception information for the failure.
"""
if excinfo.errisinstance(self.CollectError) and not self.config.getoption(
if isinstance(excinfo.value, self.CollectError) and not self.config.getoption(
"fulltrace", False
):
exc = excinfo.value
@ -476,7 +507,12 @@ class FSHookProxy:
class FSCollector(Collector):
def __init__(
self, fspath: py.path.local, parent=None, config=None, session=None, nodeid=None
self,
fspath: py.path.local,
parent=None,
config: Optional[Config] = None,
session: Optional["Session"] = None,
nodeid: Optional[str] = None,
) -> None:
name = fspath.basename
if parent is not None:
@ -521,6 +557,9 @@ class FSCollector(Collector):
proxy = self.config.hook
return proxy
def gethookproxy(self, fspath: py.path.local):
raise NotImplementedError()
def _recurse(self, dirpath: py.path.local) -> bool:
if dirpath.basename == "__pycache__":
return False
@ -534,7 +573,12 @@ class FSCollector(Collector):
ihook.pytest_collect_directory(path=dirpath, parent=self)
return True
def _collectfile(self, path, handle_dupes=True):
def isinitpath(self, path: py.path.local) -> bool:
raise NotImplementedError()
def _collectfile(
self, path: py.path.local, handle_dupes: bool = True
) -> Sequence[Collector]:
assert (
path.isfile()
), "{!r} is not a file (isdir={!r}, exists={!r}, islink={!r})".format(
@ -554,7 +598,7 @@ class FSCollector(Collector):
else:
duplicate_paths.add(path)
return ihook.pytest_collect_file(path=path, parent=self)
return ihook.pytest_collect_file(path=path, parent=self) # type: ignore[no-any-return] # noqa: F723
class File(FSCollector):
@ -568,13 +612,20 @@ class Item(Node):
nextitem = None
def __init__(self, name, parent=None, config=None, session=None, nodeid=None):
def __init__(
self,
name,
parent=None,
config: Optional[Config] = None,
session: Optional["Session"] = None,
nodeid: Optional[str] = None,
) -> None:
super().__init__(name, parent, config, session, nodeid=nodeid)
self._report_sections = [] # type: List[Tuple[str, str, str]]
#: user properties is a list of tuples (name, value) that holds user
#: defined properties for this test.
self.user_properties = [] # type: List[Tuple[str, Any]]
self.user_properties = [] # type: List[Tuple[str, object]]
def runtest(self) -> None:
raise NotImplementedError("runtest must be implemented by Item subclass")

View File

@ -2,6 +2,7 @@
from _pytest import python
from _pytest import unittest
from _pytest.config import hookimpl
from _pytest.nodes import Item
@hookimpl(trylast=True)
@ -20,7 +21,7 @@ def teardown_nose(item):
call_optional(item.parent.obj, "teardown")
def is_potential_nosetest(item):
def is_potential_nosetest(item: Item) -> bool:
# extra check needed since we do not do nose style setup/teardown
# on direct unittest style classes
return isinstance(item, python.Function) and not isinstance(

View File

@ -2,15 +2,20 @@
import tempfile
from io import StringIO
from typing import IO
from typing import Union
import pytest
from _pytest.config import Config
from _pytest.config import create_terminal_writer
from _pytest.config.argparsing import Parser
from _pytest.store import StoreKey
from _pytest.terminal import TerminalReporter
pastebinfile_key = StoreKey[IO[bytes]]()
def pytest_addoption(parser):
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("terminal reporting")
group._addoption(
"--pastebin",
@ -24,7 +29,7 @@ def pytest_addoption(parser):
@pytest.hookimpl(trylast=True)
def pytest_configure(config):
def pytest_configure(config: Config) -> None:
if config.option.pastebin == "all":
tr = config.pluginmanager.getplugin("terminalreporter")
# if no terminal reporter plugin is present, nothing we can do here;
@ -44,7 +49,7 @@ def pytest_configure(config):
tr._tw.write = tee_write
def pytest_unconfigure(config):
def pytest_unconfigure(config: Config) -> None:
if pastebinfile_key in config._store:
pastebinfile = config._store[pastebinfile_key]
# get terminal contents and delete file
@ -61,11 +66,11 @@ def pytest_unconfigure(config):
tr.write_line("pastebin session-log: %s\n" % pastebinurl)
def create_new_paste(contents):
def create_new_paste(contents: Union[str, bytes]) -> str:
"""
Creates a new paste using bpaste.net service.
:contents: paste contents as utf-8 encoded bytes
:contents: paste contents string
:returns: url to the pasted contents or error message
"""
import re
@ -77,7 +82,7 @@ def create_new_paste(contents):
try:
response = (
urlopen(url, data=urlencode(params).encode("ascii")).read().decode("utf-8")
)
) # type: str
except OSError as exc_info: # urllib errors
return "bad response: %s" % exc_info
m = re.search(r'href="/raw/(\w+)"', response)
@ -87,23 +92,20 @@ def create_new_paste(contents):
return "bad response: invalid format ('" + response + "')"
def pytest_terminal_summary(terminalreporter):
import _pytest.config
def pytest_terminal_summary(terminalreporter: TerminalReporter) -> None:
if terminalreporter.config.option.pastebin != "failed":
return
tr = terminalreporter
if "failed" in tr.stats:
if "failed" in terminalreporter.stats:
terminalreporter.write_sep("=", "Sending information to Paste Service")
for rep in terminalreporter.stats.get("failed"):
for rep in terminalreporter.stats["failed"]:
try:
msg = rep.longrepr.reprtraceback.reprentries[-1].reprfileloc
except AttributeError:
msg = tr._getfailureheadline(rep)
msg = terminalreporter._getfailureheadline(rep)
file = StringIO()
tw = _pytest.config.create_terminal_writer(terminalreporter.config, file)
tw = create_terminal_writer(terminalreporter.config, file)
rep.toterminal(tw)
s = file.getvalue()
assert len(s)
pastebinurl = create_new_paste(s)
tr.write_line("{} --> {}".format(msg, pastebinurl))
terminalreporter.write_line("{} --> {}".format(msg, pastebinurl))

View File

@ -348,7 +348,7 @@ def make_numbered_dir_with_cleanup(
raise e
def resolve_from_str(input, root):
def resolve_from_str(input: str, root):
assert not isinstance(input, Path), "would break on py2"
root = Path(root)
input = expanduser(input)

View File

@ -12,6 +12,7 @@ from fnmatch import fnmatch
from io import StringIO
from typing import Callable
from typing import Dict
from typing import Generator
from typing import Iterable
from typing import List
from typing import Optional
@ -31,6 +32,7 @@ from _pytest.compat import TYPE_CHECKING
from _pytest.config import _PluggyPlugin
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureRequest
from _pytest.main import Session
from _pytest.monkeypatch import MonkeyPatch
@ -53,7 +55,7 @@ IGNORE_PAM = [ # filenames added when obtaining details about the current user
]
def pytest_addoption(parser):
def pytest_addoption(parser: Parser) -> None:
parser.addoption(
"--lsof",
action="store_true",
@ -78,7 +80,7 @@ def pytest_addoption(parser):
)
def pytest_configure(config):
def pytest_configure(config: Config) -> None:
if config.getvalue("lsof"):
checker = LsofFdLeakChecker()
if checker.matching_platform():
@ -137,7 +139,7 @@ class LsofFdLeakChecker:
return True
@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_runtest_protocol(self, item):
def pytest_runtest_protocol(self, item: Item) -> Generator[None, None, None]:
lines1 = self.get_open_files()
yield
if hasattr(sys, "pypy_version_info"):
@ -399,7 +401,7 @@ def _sys_snapshot():
@pytest.fixture
def _config_for_test():
def _config_for_test() -> Generator[Config, None, None]:
from _pytest.config import get_config
config = get_config()
@ -645,8 +647,8 @@ class Testdir:
for basename, value in items:
p = self.tmpdir.join(basename).new(ext=ext)
p.dirpath().ensure_dir()
source = Source(value)
source = "\n".join(to_text(line) for line in source.lines)
source_ = Source(value)
source = "\n".join(to_text(line) for line in source_.lines)
p.write(source.strip().encode(encoding), "wb")
if ret is None:
ret = p
@ -837,7 +839,7 @@ class Testdir:
config.hook.pytest_sessionfinish(session=session, exitstatus=ExitCode.OK)
return res
def genitems(self, colitems):
def genitems(self, colitems: List[Union[Item, Collector]]) -> List[Item]:
"""Generate all test items from a collection node.
This recurses into the collection node and returns a list of all the
@ -845,7 +847,7 @@ class Testdir:
"""
session = colitems[0].session
result = []
result = [] # type: List[Item]
for colitem in colitems:
result.extend(session.genitems(colitem))
return result
@ -938,7 +940,7 @@ class Testdir:
rec = []
class Collect:
def pytest_configure(x, config):
def pytest_configure(x, config: Config) -> None:
rec.append(self.make_hook_recorder(config.pluginmanager))
plugins.append(Collect())
@ -1167,8 +1169,10 @@ class Testdir:
popen = subprocess.Popen(cmdargs, stdout=stdout, stderr=stderr, **kw)
if stdin is Testdir.CLOSE_STDIN:
assert popen.stdin is not None
popen.stdin.close()
elif isinstance(stdin, bytes):
assert popen.stdin is not None
popen.stdin.write(stdin)
return popen

View File

@ -15,7 +15,9 @@ from typing import Callable
from typing import Dict
from typing import Iterable
from typing import List
from typing import Mapping
from typing import Optional
from typing import Set
from typing import Tuple
from typing import Union
@ -41,23 +43,34 @@ from _pytest.compat import REGEX_TYPE
from _pytest.compat import safe_getattr
from _pytest.compat import safe_isclass
from _pytest.compat import STRING_TYPES
from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config import hookimpl
from _pytest.config.argparsing import Parser
from _pytest.deprecated import FUNCARGNAMES
from _pytest.fixtures import FuncFixtureInfo
from _pytest.main import Session
from _pytest.mark import MARK_GEN
from _pytest.mark import ParameterSet
from _pytest.mark.structures import get_unpacked_marks
from _pytest.mark.structures import Mark
from _pytest.mark.structures import MarkDecorator
from _pytest.mark.structures import normalize_mark_list
from _pytest.outcomes import fail
from _pytest.outcomes import skip
from _pytest.pathlib import parts
from _pytest.reports import TerminalRepr
from _pytest.warning_types import PytestCollectionWarning
from _pytest.warning_types import PytestUnhandledCoroutineWarning
if TYPE_CHECKING:
from typing import Type
from typing_extensions import Literal
from _pytest.fixtures import _Scope
def pytest_addoption(parser):
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group.addoption(
"--fixtures",
@ -112,13 +125,14 @@ def pytest_addoption(parser):
)
def pytest_cmdline_main(config):
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
if config.option.showfixtures:
showfixtures(config)
return 0
if config.option.show_fixtures_per_test:
show_fixtures_per_test(config)
return 0
return None
def pytest_generate_tests(metafunc: "Metafunc") -> None:
@ -127,7 +141,7 @@ def pytest_generate_tests(metafunc: "Metafunc") -> None:
metafunc.parametrize(*marker.args, **marker.kwargs, _param_mark=marker) # type: ignore[misc]
def pytest_configure(config):
def pytest_configure(config: Config) -> None:
config.addinivalue_line(
"markers",
"parametrize(argnames, argvalues): call a test function multiple "
@ -161,7 +175,7 @@ def async_warn_and_skip(nodeid: str) -> None:
@hookimpl(trylast=True)
def pytest_pyfunc_call(pyfuncitem: "Function"):
def pytest_pyfunc_call(pyfuncitem: "Function") -> Optional[object]:
testfunction = pyfuncitem.obj
if is_async_function(testfunction):
async_warn_and_skip(pyfuncitem.nodeid)
@ -173,16 +187,20 @@ def pytest_pyfunc_call(pyfuncitem: "Function"):
return True
def pytest_collect_file(path, parent):
def pytest_collect_file(path: py.path.local, parent) -> Optional["Module"]:
ext = path.ext
if ext == ".py":
if not parent.session.isinitpath(path):
if not path_matches_patterns(
path, parent.config.getini("python_files") + ["__init__.py"]
):
return
return None
ihook = parent.session.gethookproxy(path)
return ihook.pytest_pycollect_makemodule(path=path, parent=parent)
module = ihook.pytest_pycollect_makemodule(
path=path, parent=parent
) # type: Module
return module
return None
def path_matches_patterns(path, patterns):
@ -190,14 +208,16 @@ def path_matches_patterns(path, patterns):
return any(path.fnmatch(pattern) for pattern in patterns)
def pytest_pycollect_makemodule(path, parent):
def pytest_pycollect_makemodule(path: py.path.local, parent) -> "Module":
if path.basename == "__init__.py":
return Package.from_parent(parent, fspath=path)
return Module.from_parent(parent, fspath=path)
pkg = Package.from_parent(parent, fspath=path) # type: Package
return pkg
mod = Module.from_parent(parent, fspath=path) # type: Module
return mod
@hookimpl(hookwrapper=True)
def pytest_pycollect_makeitem(collector, name, obj):
def pytest_pycollect_makeitem(collector: "PyCollector", name: str, obj):
outcome = yield
res = outcome.get_result()
if res is not None:
@ -238,6 +258,18 @@ def pytest_pycollect_makeitem(collector, name, obj):
class PyobjMixin:
_ALLOW_MARKERS = True
# Function and attributes that the mixin needs (for type-checking only).
if TYPE_CHECKING:
name = "" # type: str
parent = None # type: Optional[nodes.Node]
own_markers = [] # type: List[Mark]
def getparent(self, cls: Type[nodes._NodeType]) -> Optional[nodes._NodeType]:
...
def listchain(self) -> List[nodes.Node]:
...
@property
def module(self):
"""Python module object this node was collected from (can be None)."""
@ -274,7 +306,10 @@ class PyobjMixin:
def _getobj(self):
"""Gets the underlying Python object. May be overwritten by subclasses."""
return getattr(self.parent.obj, self.name)
# TODO: Improve the type of `parent` such that assert/ignore aren't needed.
assert self.parent is not None
obj = self.parent.obj # type: ignore[attr-defined] # noqa: F821
return getattr(obj, self.name)
def getmodpath(self, stopatmodule=True, includemodule=False):
""" return python path relative to the containing module. """
@ -361,7 +396,7 @@ class PyCollector(PyobjMixin, nodes.Collector):
return True
return False
def collect(self):
def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]:
if not getattr(self.obj, "__test__", True):
return []
@ -370,8 +405,8 @@ class PyCollector(PyobjMixin, nodes.Collector):
dicts = [getattr(self.obj, "__dict__", {})]
for basecls in self.obj.__class__.__mro__:
dicts.append(basecls.__dict__)
seen = set()
values = []
seen = set() # type: Set[str]
values = [] # type: List[Union[nodes.Item, nodes.Collector]]
for dic in dicts:
# Note: seems like the dict can change during iteration -
# be careful not to remove the list() without consideration.
@ -393,12 +428,21 @@ class PyCollector(PyobjMixin, nodes.Collector):
values.sort(key=sort_key)
return values
def _makeitem(self, name, obj):
def _makeitem(
self, name: str, obj
) -> Union[
None, nodes.Item, nodes.Collector, List[Union[nodes.Item, nodes.Collector]]
]:
# assert self.ihook.fspath == self.fspath, self
return self.ihook.pytest_pycollect_makeitem(collector=self, name=name, obj=obj)
item = self.ihook.pytest_pycollect_makeitem(
collector=self, name=name, obj=obj
) # type: Union[None, nodes.Item, nodes.Collector, List[Union[nodes.Item, nodes.Collector]]]
return item
def _genfunctions(self, name, funcobj):
module = self.getparent(Module).obj
modulecol = self.getparent(Module)
assert modulecol is not None
module = modulecol.obj
clscol = self.getparent(Class)
cls = clscol and clscol.obj or None
fm = self.session._fixturemanager
@ -412,7 +456,7 @@ class PyCollector(PyobjMixin, nodes.Collector):
methods = []
if hasattr(module, "pytest_generate_tests"):
methods.append(module.pytest_generate_tests)
if hasattr(cls, "pytest_generate_tests"):
if cls is not None and hasattr(cls, "pytest_generate_tests"):
methods.append(cls().pytest_generate_tests)
self.ihook.pytest_generate_tests.call_extra(methods, dict(metafunc=metafunc))
@ -447,7 +491,7 @@ class Module(nodes.File, PyCollector):
def _getobj(self):
return self._importtestmodule()
def collect(self):
def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]:
self._inject_setup_module_fixture()
self._inject_setup_function_fixture()
self.session._fixturemanager.parsefactories(self)
@ -592,17 +636,17 @@ class Package(Module):
def gethookproxy(self, fspath: py.path.local):
return super()._gethookproxy(fspath)
def isinitpath(self, path):
def isinitpath(self, path: py.path.local) -> bool:
return path in self.session._initialpaths
def collect(self):
def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]:
this_path = self.fspath.dirpath()
init_module = this_path.join("__init__.py")
if init_module.check(file=1) and path_matches_patterns(
init_module, self.config.getini("python_files")
):
yield Module.from_parent(self, fspath=init_module)
pkg_prefixes = set()
pkg_prefixes = set() # type: Set[py.path.local]
for path in this_path.visit(rec=self._recurse, bf=True, sort=True):
# We will visit our own __init__.py file, in which case we skip it.
is_file = path.isfile()
@ -659,10 +703,11 @@ class Class(PyCollector):
"""
return super().from_parent(name=name, parent=parent)
def collect(self):
def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]:
if not safe_getattr(self.obj, "__test__", True):
return []
if hasinit(self.obj):
assert self.parent is not None
self.warn(
PytestCollectionWarning(
"cannot collect test class %r because it has a "
@ -672,6 +717,7 @@ class Class(PyCollector):
)
return []
elif hasnew(self.obj):
assert self.parent is not None
self.warn(
PytestCollectionWarning(
"cannot collect test class %r because it has a "
@ -743,9 +789,12 @@ class Instance(PyCollector):
# can be removed at node structure reorganization time
def _getobj(self):
return self.parent.obj()
# TODO: Improve the type of `parent` such that assert/ignore aren't needed.
assert self.parent is not None
obj = self.parent.obj # type: ignore[attr-defined] # noqa: F821
return obj()
def collect(self):
def collect(self) -> Iterable[Union[nodes.Item, nodes.Collector]]:
self.session._fixturemanager.parsefactories(self)
return super().collect()
@ -767,16 +816,17 @@ def hasnew(obj):
class CallSpec2:
def __init__(self, metafunc):
def __init__(self, metafunc: "Metafunc") -> None:
self.metafunc = metafunc
self.funcargs = {}
self._idlist = []
self.params = {}
self._arg2scopenum = {} # used for sorting parametrized resources
self.marks = []
self.indices = {}
self.funcargs = {} # type: Dict[str, object]
self._idlist = [] # type: List[str]
self.params = {} # type: Dict[str, object]
# Used for sorting parametrized resources.
self._arg2scopenum = {} # type: Dict[str, int]
self.marks = [] # type: List[Mark]
self.indices = {} # type: Dict[str, int]
def copy(self):
def copy(self) -> "CallSpec2":
cs = CallSpec2(self.metafunc)
cs.funcargs.update(self.funcargs)
cs.params.update(self.params)
@ -786,25 +836,39 @@ class CallSpec2:
cs._idlist = list(self._idlist)
return cs
def _checkargnotcontained(self, arg):
def _checkargnotcontained(self, arg: str) -> None:
if arg in self.params or arg in self.funcargs:
raise ValueError("duplicate {!r}".format(arg))
def getparam(self, name):
def getparam(self, name: str) -> object:
try:
return self.params[name]
except KeyError:
raise ValueError(name)
@property
def id(self):
def id(self) -> str:
return "-".join(map(str, self._idlist))
def setmulti2(self, valtypes, argnames, valset, id, marks, scopenum, param_index):
def setmulti2(
self,
valtypes: "Mapping[str, Literal['params', 'funcargs']]",
argnames: typing.Sequence[str],
valset: Iterable[object],
id: str,
marks: Iterable[Union[Mark, MarkDecorator]],
scopenum: int,
param_index: int,
) -> None:
for arg, val in zip(argnames, valset):
self._checkargnotcontained(arg)
valtype_for_arg = valtypes[arg]
getattr(self, valtype_for_arg)[arg] = val
if valtype_for_arg == "params":
self.params[arg] = val
elif valtype_for_arg == "funcargs":
self.funcargs[arg] = val
else: # pragma: no cover
assert False, "Unhandled valtype for arg: {}".format(valtype_for_arg)
self.indices[arg] = param_index
self._arg2scopenum[arg] = scopenum
self._idlist.append(id)
@ -864,7 +928,7 @@ class Metafunc:
Callable[[object], Optional[object]],
]
] = None,
scope: "Optional[str]" = None,
scope: "Optional[_Scope]" = None,
*,
_param_mark: Optional[Mark] = None
) -> None:
@ -1044,7 +1108,7 @@ class Metafunc:
self,
argnames: typing.Sequence[str],
indirect: Union[bool, typing.Sequence[str]],
) -> Dict[str, str]:
) -> Dict[str, "Literal['params', 'funcargs']"]:
"""Resolves if each parametrized argument must be considered a parameter to a fixture or a "funcarg"
to the function, based on the ``indirect`` parameter of the parametrized() call.
@ -1056,7 +1120,9 @@ class Metafunc:
* "funcargs" if the argname should be a parameter to the parametrized test function.
"""
if isinstance(indirect, bool):
valtypes = dict.fromkeys(argnames, "params" if indirect else "funcargs")
valtypes = dict.fromkeys(
argnames, "params" if indirect else "funcargs"
) # type: Dict[str, Literal["params", "funcargs"]]
elif isinstance(indirect, Sequence):
valtypes = dict.fromkeys(argnames, "funcargs")
for arg in indirect:
@ -1308,13 +1374,13 @@ def _show_fixtures_per_test(config, session):
write_item(session_item)
def showfixtures(config):
def showfixtures(config: Config) -> Union[int, ExitCode]:
from _pytest.main import wrap_session
return wrap_session(config, _showfixtures_main)
def _showfixtures_main(config, session):
def _showfixtures_main(config: Config, session: Session) -> None:
import _pytest.config
session.perform_collect()
@ -1325,7 +1391,7 @@ def _showfixtures_main(config, session):
fm = session._fixturemanager
available = []
seen = set()
seen = set() # type: Set[Tuple[str, str]]
for argname, fixturedefs in fm._arg2fixturedefs.items():
assert fixturedefs is not None
@ -1481,7 +1547,8 @@ class Function(PyobjMixin, nodes.Item):
return getimfunc(self.obj)
def _getobj(self):
return getattr(self.parent.obj, self.originalname)
assert self.parent is not None
return getattr(self.parent.obj, self.originalname) # type: ignore[attr-defined]
@property
def _pyfuncitem(self):
@ -1525,7 +1592,10 @@ class Function(PyobjMixin, nodes.Item):
for entry in excinfo.traceback[1:-1]:
entry.set_repr_style("short")
def repr_failure(self, excinfo, outerr=None):
# TODO: Type ignored -- breaks Liskov Substitution.
def repr_failure( # type: ignore[override] # noqa: F821
self, excinfo: ExceptionInfo[BaseException], outerr: None = None
) -> Union[str, TerminalRepr]:
assert outerr is None, "XXX outerr usage is deprecated"
style = self.config.getoption("tbstyle", "auto")
if style == "auto":

View File

@ -508,7 +508,7 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
__tracebackhide__ = True
if isinstance(expected, Decimal):
cls = ApproxDecimal
cls = ApproxDecimal # type: Type[ApproxBase]
elif isinstance(expected, Number):
cls = ApproxScalar
elif isinstance(expected, Mapping):
@ -534,7 +534,7 @@ def _is_numpy_array(obj):
"""
import sys
np = sys.modules.get("numpy")
np = sys.modules.get("numpy") # type: Any
if np is not None:
return isinstance(obj, np.ndarray)
return False
@ -712,6 +712,7 @@ def raises( # noqa: F811
fail(message)
# This doesn't work with mypy for now. Use fail.Exception instead.
raises.Exception = fail.Exception # type: ignore

View File

@ -136,8 +136,9 @@ class WarningsRecorder(warnings.catch_warnings):
Adapted from `warnings.catch_warnings`.
"""
def __init__(self):
super().__init__(record=True)
def __init__(self) -> None:
# Type ignored due to the way typeshed handles warnings.catch_warnings.
super().__init__(record=True) # type: ignore[call-arg] # noqa: F821
self._entered = False
self._list = [] # type: List[warnings.WarningMessage]

View File

@ -1,9 +1,12 @@
from io import StringIO
from pprint import pprint
from typing import Any
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import Union
import attr
@ -21,10 +24,19 @@ from _pytest._code.code import ReprTraceback
from _pytest._code.code import TerminalRepr
from _pytest._io import TerminalWriter
from _pytest.compat import TYPE_CHECKING
from _pytest.nodes import Node
from _pytest.config import Config
from _pytest.nodes import Collector
from _pytest.nodes import Item
from _pytest.outcomes import skip
from _pytest.pathlib import Path
if TYPE_CHECKING:
from typing import NoReturn
from typing_extensions import Type
from typing_extensions import Literal
from _pytest.runner import CallInfo
def getslaveinfoline(node):
try:
@ -38,10 +50,14 @@ def getslaveinfoline(node):
return s
_R = TypeVar("_R", bound="BaseReport")
class BaseReport:
when = None # type: Optional[str]
location = None # type: Optional[Tuple[str, Optional[int], str]]
longrepr = None
# TODO: Improve this Any.
longrepr = None # type: Optional[Any]
sections = [] # type: List[Tuple[str, str]]
nodeid = None # type: str
@ -69,13 +85,13 @@ class BaseReport:
except UnicodeEncodeError:
out.line("<unprintable longrepr>")
def get_sections(self, prefix):
def get_sections(self, prefix: str) -> Iterator[Tuple[str, str]]:
for name, content in self.sections:
if name.startswith(prefix):
yield prefix, content
@property
def longreprtext(self):
def longreprtext(self) -> str:
"""
Read-only property that returns the full string representation
of ``longrepr``.
@ -90,7 +106,7 @@ class BaseReport:
return exc.strip()
@property
def caplog(self):
def caplog(self) -> str:
"""Return captured log lines, if log capturing is enabled
.. versionadded:: 3.5
@ -100,7 +116,7 @@ class BaseReport:
)
@property
def capstdout(self):
def capstdout(self) -> str:
"""Return captured text from stdout, if capturing is enabled
.. versionadded:: 3.0
@ -110,7 +126,7 @@ class BaseReport:
)
@property
def capstderr(self):
def capstderr(self) -> str:
"""Return captured text from stderr, if capturing is enabled
.. versionadded:: 3.0
@ -128,7 +144,7 @@ class BaseReport:
return self.nodeid.split("::")[0]
@property
def count_towards_summary(self):
def count_towards_summary(self) -> bool:
"""
**Experimental**
@ -143,7 +159,7 @@ class BaseReport:
return True
@property
def head_line(self):
def head_line(self) -> Optional[str]:
"""
**Experimental**
@ -163,8 +179,9 @@ class BaseReport:
if self.location is not None:
fspath, lineno, domain = self.location
return domain
return None
def _get_verbose_word(self, config):
def _get_verbose_word(self, config: Config):
_category, _short, verbose = config.hook.pytest_report_teststatus(
report=self, config=config
)
@ -182,7 +199,7 @@ class BaseReport:
return _report_to_json(self)
@classmethod
def _from_json(cls, reportdict):
def _from_json(cls: "Type[_R]", reportdict) -> _R:
"""
This was originally the serialize_report() function from xdist (ca03269).
@ -195,7 +212,9 @@ class BaseReport:
return cls(**kwargs)
def _report_unserialization_failure(type_name, report_class, reportdict):
def _report_unserialization_failure(
type_name: str, report_class: "Type[BaseReport]", reportdict
) -> "NoReturn":
url = "https://github.com/pytest-dev/pytest/issues"
stream = StringIO()
pprint("-" * 100, stream=stream)
@ -216,15 +235,15 @@ class TestReport(BaseReport):
def __init__(
self,
nodeid,
nodeid: str,
location: Tuple[str, Optional[int], str],
keywords,
outcome,
outcome: "Literal['passed', 'failed', 'skipped']",
longrepr,
when,
sections=(),
duration=0,
user_properties=None,
when: "Literal['setup', 'call', 'teardown']",
sections: Iterable[Tuple[str, str]] = (),
duration: float = 0,
user_properties: Optional[Iterable[Tuple[str, object]]] = None,
**extra
) -> None:
#: normalized collection node id
@ -263,24 +282,27 @@ class TestReport(BaseReport):
self.__dict__.update(extra)
def __repr__(self):
def __repr__(self) -> str:
return "<{} {!r} when={!r} outcome={!r}>".format(
self.__class__.__name__, self.nodeid, self.when, self.outcome
)
@classmethod
def from_item_and_call(cls, item, call) -> "TestReport":
def from_item_and_call(cls, item: Item, call: "CallInfo[None]") -> "TestReport":
"""
Factory method to create and fill a TestReport with standard item and call info.
"""
when = call.when
# Remove "collect" from the Literal type -- only for collection calls.
assert when != "collect"
duration = call.duration
keywords = {x: 1 for x in item.keywords}
excinfo = call.excinfo
sections = []
if not call.excinfo:
outcome = "passed"
longrepr = None
outcome = "passed" # type: Literal["passed", "failed", "skipped"]
# TODO: Improve this Any.
longrepr = None # type: Optional[Any]
else:
if not isinstance(excinfo, ExceptionInfo):
outcome = "failed"
@ -316,7 +338,13 @@ class CollectReport(BaseReport):
when = "collect"
def __init__(
self, nodeid: str, outcome, longrepr, result: List[Node], sections=(), **extra
self,
nodeid: str,
outcome: "Literal['passed', 'skipped', 'failed']",
longrepr,
result: Optional[List[Union[Item, Collector]]],
sections: Iterable[Tuple[str, str]] = (),
**extra
) -> None:
self.nodeid = nodeid
self.outcome = outcome
@ -329,28 +357,29 @@ class CollectReport(BaseReport):
def location(self):
return (self.fspath, None, self.fspath)
def __repr__(self):
def __repr__(self) -> str:
return "<CollectReport {!r} lenresult={} outcome={!r}>".format(
self.nodeid, len(self.result), self.outcome
)
class CollectErrorRepr(TerminalRepr):
def __init__(self, msg):
def __init__(self, msg) -> None:
self.longrepr = msg
def toterminal(self, out) -> None:
out.line(self.longrepr, red=True)
def pytest_report_to_serializable(report):
def pytest_report_to_serializable(report: BaseReport):
if isinstance(report, (TestReport, CollectReport)):
data = report._to_json()
data["$report_type"] = report.__class__.__name__
return data
return None
def pytest_report_from_serializable(data):
def pytest_report_from_serializable(data) -> Optional[BaseReport]:
if "$report_type" in data:
if data["$report_type"] == "TestReport":
return TestReport._from_json(data)
@ -359,9 +388,10 @@ def pytest_report_from_serializable(data):
assert False, "Unknown report_type unserialize data: {}".format(
data["$report_type"]
)
return None
def _report_to_json(report):
def _report_to_json(report: BaseReport):
"""
This was originally the serialize_report() function from xdist (ca03269).
@ -369,11 +399,12 @@ def _report_to_json(report):
serialization.
"""
def serialize_repr_entry(entry):
entry_data = {"type": type(entry).__name__, "data": attr.asdict(entry)}
for key, value in entry_data["data"].items():
def serialize_repr_entry(entry: Union[ReprEntry, ReprEntryNative]):
data = attr.asdict(entry)
for key, value in data.items():
if hasattr(value, "__dict__"):
entry_data["data"][key] = attr.asdict(value)
data[key] = attr.asdict(value)
entry_data = {"type": type(entry).__name__, "data": data}
return entry_data
def serialize_repr_traceback(reprtraceback: ReprTraceback):

View File

@ -5,13 +5,17 @@ import os
import py
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.reports import CollectReport
from _pytest.reports import TestReport
from _pytest.store import StoreKey
resultlog_key = StoreKey["ResultLog"]()
def pytest_addoption(parser):
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("terminal reporting", "resultlog plugin options")
group.addoption(
"--resultlog",
@ -23,7 +27,7 @@ def pytest_addoption(parser):
)
def pytest_configure(config):
def pytest_configure(config: Config) -> None:
resultlog = config.option.resultlog
# prevent opening resultlog on slave nodes (xdist)
if resultlog and not hasattr(config, "slaveinput"):
@ -40,7 +44,7 @@ def pytest_configure(config):
_issue_warning_captured(RESULT_LOG, config.hook, stacklevel=2)
def pytest_unconfigure(config):
def pytest_unconfigure(config: Config) -> None:
resultlog = config._store.get(resultlog_key, None)
if resultlog:
resultlog.logfile.close()
@ -64,7 +68,7 @@ class ResultLog:
testpath = report.fspath
self.write_log_entry(testpath, lettercode, longrepr)
def pytest_runtest_logreport(self, report):
def pytest_runtest_logreport(self, report: TestReport) -> None:
if report.when != "call" and report.passed:
return
res = self.config.hook.pytest_report_teststatus(
@ -78,12 +82,13 @@ class ResultLog:
elif report.passed:
longrepr = ""
elif report.skipped:
assert report.longrepr is not None
longrepr = str(report.longrepr[2])
else:
longrepr = str(report.longrepr)
self.log_outcome(report, code, longrepr)
def pytest_collectreport(self, report):
def pytest_collectreport(self, report: CollectReport) -> None:
if not report.passed:
if report.failed:
code = "F"
@ -91,7 +96,7 @@ class ResultLog:
else:
assert report.skipped
code = "S"
longrepr = "%s:%d: %s" % report.longrepr
longrepr = "%s:%d: %s" % report.longrepr # type: ignore
self.log_outcome(report, code, longrepr)
def pytest_internalerror(self, excrepr):

View File

@ -2,14 +2,20 @@
import bdb
import os
import sys
from typing import Any
from typing import Callable
from typing import cast
from typing import Dict
from typing import Generic
from typing import List
from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import Union
import attr
from .reports import BaseReport
from .reports import CollectErrorRepr
from .reports import CollectReport
from .reports import TestReport
@ -17,7 +23,9 @@ from _pytest import timing
from _pytest._code.code import ExceptionChainRepr
from _pytest._code.code import ExceptionInfo
from _pytest.compat import TYPE_CHECKING
from _pytest.config.argparsing import Parser
from _pytest.nodes import Collector
from _pytest.nodes import Item
from _pytest.nodes import Node
from _pytest.outcomes import Exit
from _pytest.outcomes import Skipped
@ -27,11 +35,14 @@ if TYPE_CHECKING:
from typing import Type
from typing_extensions import Literal
from _pytest.main import Session
from _pytest.terminal import TerminalReporter
#
# pytest plugin hooks
def pytest_addoption(parser):
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("terminal reporting", "reporting", after="general")
group.addoption(
"--durations",
@ -43,7 +54,7 @@ def pytest_addoption(parser):
)
def pytest_terminal_summary(terminalreporter):
def pytest_terminal_summary(terminalreporter: "TerminalReporter") -> None:
durations = terminalreporter.config.option.durations
verbose = terminalreporter.config.getvalue("verbose")
if durations is None:
@ -75,25 +86,27 @@ def pytest_terminal_summary(terminalreporter):
tr.write_line("{:02.2f}s {:<8} {}".format(rep.duration, rep.when, rep.nodeid))
def pytest_sessionstart(session):
def pytest_sessionstart(session: "Session") -> None:
session._setupstate = SetupState()
def pytest_sessionfinish(session):
def pytest_sessionfinish(session: "Session") -> None:
session._setupstate.teardown_all()
def pytest_runtest_protocol(item, nextitem):
def pytest_runtest_protocol(item: Item, nextitem: Optional[Item]) -> bool:
item.ihook.pytest_runtest_logstart(nodeid=item.nodeid, location=item.location)
runtestprotocol(item, nextitem=nextitem)
item.ihook.pytest_runtest_logfinish(nodeid=item.nodeid, location=item.location)
return True
def runtestprotocol(item, log=True, nextitem=None):
def runtestprotocol(
item: Item, log: bool = True, nextitem: Optional[Item] = None
) -> List[TestReport]:
hasrequest = hasattr(item, "_request")
if hasrequest and not item._request:
item._initrequest()
if hasrequest and not item._request: # type: ignore[attr-defined] # noqa: F821
item._initrequest() # type: ignore[attr-defined] # noqa: F821
rep = call_and_report(item, "setup", log)
reports = [rep]
if rep.passed:
@ -105,12 +118,12 @@ def runtestprotocol(item, log=True, nextitem=None):
# after all teardown hooks have been called
# want funcargs and request info to go away
if hasrequest:
item._request = False
item.funcargs = None
item._request = False # type: ignore[attr-defined] # noqa: F821
item.funcargs = None # type: ignore[attr-defined] # noqa: F821
return reports
def show_test_item(item):
def show_test_item(item: Item) -> None:
"""Show test function, parameters and the fixtures of the test item."""
tw = item.config.get_terminal_writer()
tw.line()
@ -122,12 +135,12 @@ def show_test_item(item):
tw.flush()
def pytest_runtest_setup(item):
def pytest_runtest_setup(item: Item) -> None:
_update_current_test_var(item, "setup")
item.session._setupstate.prepare(item)
def pytest_runtest_call(item):
def pytest_runtest_call(item: Item) -> None:
_update_current_test_var(item, "call")
try:
del sys.last_type
@ -147,13 +160,15 @@ def pytest_runtest_call(item):
raise e
def pytest_runtest_teardown(item, nextitem):
def pytest_runtest_teardown(item: Item, nextitem: Optional[Item]) -> None:
_update_current_test_var(item, "teardown")
item.session._setupstate.teardown_exact(item, nextitem)
_update_current_test_var(item, None)
def _update_current_test_var(item, when):
def _update_current_test_var(
item: Item, when: Optional["Literal['setup', 'call', 'teardown']"]
) -> None:
"""
Update :envvar:`PYTEST_CURRENT_TEST` to reflect the current item and stage.
@ -169,7 +184,7 @@ def _update_current_test_var(item, when):
os.environ.pop(var_name)
def pytest_report_teststatus(report):
def pytest_report_teststatus(report: BaseReport) -> Optional[Tuple[str, str, str]]:
if report.when in ("setup", "teardown"):
if report.failed:
# category, shortletter, verbose-word
@ -178,6 +193,7 @@ def pytest_report_teststatus(report):
return "skipped", "s", "SKIPPED"
else:
return "", "", ""
return None
#
@ -185,11 +201,11 @@ def pytest_report_teststatus(report):
def call_and_report(
item, when: "Literal['setup', 'call', 'teardown']", log=True, **kwds
):
item: Item, when: "Literal['setup', 'call', 'teardown']", log: bool = True, **kwds
) -> TestReport:
call = call_runtest_hook(item, when, **kwds)
hook = item.ihook
report = hook.pytest_runtest_makereport(item=item, call=call)
report = hook.pytest_runtest_makereport(item=item, call=call) # type: TestReport
if log:
hook.pytest_runtest_logreport(report=report)
if check_interactive_exception(call, report):
@ -197,17 +213,19 @@ def call_and_report(
return report
def check_interactive_exception(call, report):
return call.excinfo and not (
def check_interactive_exception(call: "CallInfo", report: BaseReport) -> bool:
return call.excinfo is not None and not (
hasattr(report, "wasxfail")
or call.excinfo.errisinstance(Skipped)
or call.excinfo.errisinstance(bdb.BdbQuit)
)
def call_runtest_hook(item, when: "Literal['setup', 'call', 'teardown']", **kwds):
def call_runtest_hook(
item: Item, when: "Literal['setup', 'call', 'teardown']", **kwds
) -> "CallInfo[None]":
if when == "setup":
ihook = item.ihook.pytest_runtest_setup
ihook = item.ihook.pytest_runtest_setup # type: Callable[..., None]
elif when == "call":
ihook = item.ihook.pytest_runtest_call
elif when == "teardown":
@ -222,11 +240,14 @@ def call_runtest_hook(item, when: "Literal['setup', 'call', 'teardown']", **kwds
)
_T = TypeVar("_T")
@attr.s(repr=False)
class CallInfo:
class CallInfo(Generic[_T]):
""" Result/Exception info a function invocation.
:param result: The return value of the call, if it didn't raise. Can only be accessed
:param T result: The return value of the call, if it didn't raise. Can only be accessed
if excinfo is None.
:param Optional[ExceptionInfo] excinfo: The captured exception of the call, if it raised.
:param float start: The system time when the call started, in seconds since the epoch.
@ -235,28 +256,34 @@ class CallInfo:
:param str when: The context of invocation: "setup", "call", "teardown", ...
"""
_result = attr.ib()
excinfo = attr.ib(type=Optional[ExceptionInfo])
_result = attr.ib(type="Optional[_T]")
excinfo = attr.ib(type=Optional[ExceptionInfo[BaseException]])
start = attr.ib(type=float)
stop = attr.ib(type=float)
duration = attr.ib(type=float)
when = attr.ib(type=str)
when = attr.ib(type="Literal['collect', 'setup', 'call', 'teardown']")
@property
def result(self):
def result(self) -> _T:
if self.excinfo is not None:
raise AttributeError("{!r} has no valid result".format(self))
return self._result
# The cast is safe because an exception wasn't raised, hence
# _result has the expected function return type (which may be
# None, that's why a cast and not an assert).
return cast(_T, self._result)
@classmethod
def from_call(cls, func, when, reraise=None) -> "CallInfo":
#: context of invocation: one of "setup", "call",
#: "teardown", "memocollect"
def from_call(
cls,
func: "Callable[[], _T]",
when: "Literal['collect', 'setup', 'call', 'teardown']",
reraise: "Optional[Union[Type[BaseException], Tuple[Type[BaseException], ...]]]" = None,
) -> "CallInfo[_T]":
excinfo = None
start = timing.time()
precise_start = timing.perf_counter()
try:
result = func()
result = func() # type: Optional[_T]
except BaseException:
excinfo = ExceptionInfo.from_current()
if reraise is not None and excinfo.errisinstance(reraise):
@ -275,21 +302,22 @@ class CallInfo:
excinfo=excinfo,
)
def __repr__(self):
def __repr__(self) -> str:
if self.excinfo is None:
return "<CallInfo when={!r} result: {!r}>".format(self.when, self._result)
return "<CallInfo when={!r} excinfo={!r}>".format(self.when, self.excinfo)
def pytest_runtest_makereport(item, call):
def pytest_runtest_makereport(item: Item, call: CallInfo[None]) -> TestReport:
return TestReport.from_item_and_call(item, call)
def pytest_make_collect_report(collector: Collector) -> CollectReport:
call = CallInfo.from_call(lambda: list(collector.collect()), "collect")
longrepr = None
# TODO: Better typing for longrepr.
longrepr = None # type: Optional[Any]
if not call.excinfo:
outcome = "passed"
outcome = "passed" # type: Literal["passed", "skipped", "failed"]
else:
skip_exceptions = [Skipped]
unittest = sys.modules.get("unittest")
@ -309,9 +337,8 @@ def pytest_make_collect_report(collector: Collector) -> CollectReport:
if not hasattr(errorinfo, "toterminal"):
errorinfo = CollectErrorRepr(errorinfo)
longrepr = errorinfo
rep = CollectReport(
collector.nodeid, outcome, longrepr, getattr(call, "result", None)
)
result = call.result if not call.excinfo else None
rep = CollectReport(collector.nodeid, outcome, longrepr, result)
rep.call = call # type: ignore # see collect_one_node
return rep
@ -321,9 +348,9 @@ class SetupState:
def __init__(self):
self.stack = [] # type: List[Node]
self._finalizers = {} # type: Dict[Node, List[Callable[[], None]]]
self._finalizers = {} # type: Dict[Node, List[Callable[[], object]]]
def addfinalizer(self, finalizer, colitem):
def addfinalizer(self, finalizer: Callable[[], object], colitem) -> None:
""" attach a finalizer to the given colitem. """
assert colitem and not isinstance(colitem, tuple)
assert callable(finalizer)
@ -334,7 +361,7 @@ class SetupState:
colitem = self.stack.pop()
self._teardown_with_finalization(colitem)
def _callfinalizers(self, colitem):
def _callfinalizers(self, colitem) -> None:
finalizers = self._finalizers.pop(colitem, None)
exc = None
while finalizers:
@ -349,24 +376,24 @@ class SetupState:
if exc:
raise exc
def _teardown_with_finalization(self, colitem):
def _teardown_with_finalization(self, colitem) -> None:
self._callfinalizers(colitem)
colitem.teardown()
for colitem in self._finalizers:
assert colitem in self.stack
def teardown_all(self):
def teardown_all(self) -> None:
while self.stack:
self._pop_and_teardown()
for key in list(self._finalizers):
self._teardown_with_finalization(key)
assert not self._finalizers
def teardown_exact(self, item, nextitem):
def teardown_exact(self, item, nextitem) -> None:
needed_collectors = nextitem and nextitem.listchain() or []
self._teardown_towards(needed_collectors)
def _teardown_towards(self, needed_collectors):
def _teardown_towards(self, needed_collectors) -> None:
exc = None
while self.stack:
if self.stack == needed_collectors[: len(self.stack)]:
@ -381,7 +408,7 @@ class SetupState:
if exc:
raise exc
def prepare(self, colitem):
def prepare(self, colitem) -> None:
""" setup objects along the collector chain to the test-method
and teardown previously setup objects."""
needed_collectors = colitem.listchain()
@ -390,21 +417,21 @@ class SetupState:
# check if the last collection node has raised an error
for col in self.stack:
if hasattr(col, "_prepare_exc"):
exc = col._prepare_exc
exc = col._prepare_exc # type: ignore[attr-defined] # noqa: F821
raise exc
for col in needed_collectors[len(self.stack) :]:
self.stack.append(col)
try:
col.setup()
except TEST_OUTCOME as e:
col._prepare_exc = e
col._prepare_exc = e # type: ignore[attr-defined] # noqa: F821
raise e
def collect_one_node(collector):
def collect_one_node(collector: Collector) -> CollectReport:
ihook = collector.ihook
ihook.pytest_collectstart(collector=collector)
rep = ihook.pytest_make_collect_report(collector=collector)
rep = ihook.pytest_make_collect_report(collector=collector) # type: CollectReport
call = rep.__dict__.pop("call", None)
if call and check_interactive_exception(call, rep):
ihook.pytest_exception_interact(node=collector, call=call, report=rep)

View File

@ -1,8 +1,17 @@
from typing import Generator
from typing import Optional
from typing import Union
import pytest
from _pytest._io.saferepr import saferepr
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureDef
from _pytest.fixtures import SubRequest
def pytest_addoption(parser):
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig")
group.addoption(
"--setuponly",
@ -19,7 +28,9 @@ def pytest_addoption(parser):
@pytest.hookimpl(hookwrapper=True)
def pytest_fixture_setup(fixturedef, request):
def pytest_fixture_setup(
fixturedef: FixtureDef, request: SubRequest
) -> Generator[None, None, None]:
yield
if request.config.option.setupshow:
if hasattr(request, "param"):
@ -27,24 +38,25 @@ def pytest_fixture_setup(fixturedef, request):
# display it now and during the teardown (in .finish()).
if fixturedef.ids:
if callable(fixturedef.ids):
fixturedef.cached_param = fixturedef.ids(request.param)
param = fixturedef.ids(request.param)
else:
fixturedef.cached_param = fixturedef.ids[request.param_index]
param = fixturedef.ids[request.param_index]
else:
fixturedef.cached_param = request.param
param = request.param
fixturedef.cached_param = param # type: ignore[attr-defined] # noqa: F821
_show_fixture_action(fixturedef, "SETUP")
def pytest_fixture_post_finalizer(fixturedef) -> None:
def pytest_fixture_post_finalizer(fixturedef: FixtureDef) -> None:
if fixturedef.cached_result is not None:
config = fixturedef._fixturemanager.config
if config.option.setupshow:
_show_fixture_action(fixturedef, "TEARDOWN")
if hasattr(fixturedef, "cached_param"):
del fixturedef.cached_param
del fixturedef.cached_param # type: ignore[attr-defined] # noqa: F821
def _show_fixture_action(fixturedef, msg):
def _show_fixture_action(fixturedef: FixtureDef, msg: str) -> None:
config = fixturedef._fixturemanager.config
capman = config.pluginmanager.getplugin("capturemanager")
if capman:
@ -67,7 +79,7 @@ def _show_fixture_action(fixturedef, msg):
tw.write(" (fixtures used: {})".format(", ".join(deps)))
if hasattr(fixturedef, "cached_param"):
tw.write("[{}]".format(saferepr(fixturedef.cached_param, maxsize=42)))
tw.write("[{}]".format(saferepr(fixturedef.cached_param, maxsize=42))) # type: ignore[attr-defined]
tw.flush()
@ -76,6 +88,7 @@ def _show_fixture_action(fixturedef, msg):
@pytest.hookimpl(tryfirst=True)
def pytest_cmdline_main(config):
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
if config.option.setuponly:
config.option.setupshow = True
return None

View File

@ -1,7 +1,15 @@
from typing import Optional
from typing import Union
import pytest
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureDef
from _pytest.fixtures import SubRequest
def pytest_addoption(parser):
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("debugconfig")
group.addoption(
"--setupplan",
@ -13,16 +21,20 @@ def pytest_addoption(parser):
@pytest.hookimpl(tryfirst=True)
def pytest_fixture_setup(fixturedef, request):
def pytest_fixture_setup(
fixturedef: FixtureDef, request: SubRequest
) -> Optional[object]:
# Will return a dummy fixture if the setuponly option is provided.
if request.config.option.setupplan:
my_cache_key = fixturedef.cache_key(request)
fixturedef.cached_result = (None, my_cache_key, None)
return fixturedef.cached_result
return None
@pytest.hookimpl(tryfirst=True)
def pytest_cmdline_main(config):
def pytest_cmdline_main(config: Config) -> Optional[Union[int, ExitCode]]:
if config.option.setupplan:
config.option.setuponly = True
config.option.setupshow = True
return None

View File

@ -1,9 +1,18 @@
""" support for skip/xfail functions and markers. """
from typing import Optional
from typing import Tuple
from _pytest.config import Config
from _pytest.config import hookimpl
from _pytest.config.argparsing import Parser
from _pytest.mark.evaluate import MarkEvaluator
from _pytest.nodes import Item
from _pytest.outcomes import fail
from _pytest.outcomes import skip
from _pytest.outcomes import xfail
from _pytest.python import Function
from _pytest.reports import BaseReport
from _pytest.runner import CallInfo
from _pytest.store import StoreKey
@ -12,7 +21,7 @@ evalxfail_key = StoreKey[MarkEvaluator]()
unexpectedsuccess_key = StoreKey[str]()
def pytest_addoption(parser):
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group.addoption(
"--runxfail",
@ -31,7 +40,7 @@ def pytest_addoption(parser):
)
def pytest_configure(config):
def pytest_configure(config: Config) -> None:
if config.option.runxfail:
# yay a hack
import pytest
@ -42,7 +51,7 @@ def pytest_configure(config):
def nop(*args, **kwargs):
pass
nop.Exception = xfail.Exception
nop.Exception = xfail.Exception # type: ignore[attr-defined] # noqa: F821
setattr(pytest, "xfail", nop)
config.addinivalue_line(
@ -72,7 +81,7 @@ def pytest_configure(config):
@hookimpl(tryfirst=True)
def pytest_runtest_setup(item):
def pytest_runtest_setup(item: Item) -> None:
# Check if skip or skipif are specified as pytest marks
item._store[skipped_by_mark_key] = False
eval_skipif = MarkEvaluator(item, "skipif")
@ -94,7 +103,7 @@ def pytest_runtest_setup(item):
@hookimpl(hookwrapper=True)
def pytest_pyfunc_call(pyfuncitem):
def pytest_pyfunc_call(pyfuncitem: Function):
check_xfail_no_run(pyfuncitem)
outcome = yield
passed = outcome.excinfo is None
@ -102,7 +111,7 @@ def pytest_pyfunc_call(pyfuncitem):
check_strict_xfail(pyfuncitem)
def check_xfail_no_run(item):
def check_xfail_no_run(item: Item) -> None:
"""check xfail(run=False)"""
if not item.config.option.runxfail:
evalxfail = item._store[evalxfail_key]
@ -111,7 +120,7 @@ def check_xfail_no_run(item):
xfail("[NOTRUN] " + evalxfail.getexplanation())
def check_strict_xfail(pyfuncitem):
def check_strict_xfail(pyfuncitem: Function) -> None:
"""check xfail(strict=True) for the given PASSING test"""
evalxfail = pyfuncitem._store[evalxfail_key]
if evalxfail.istrue():
@ -124,7 +133,7 @@ def check_strict_xfail(pyfuncitem):
@hookimpl(hookwrapper=True)
def pytest_runtest_makereport(item, call):
def pytest_runtest_makereport(item: Item, call: CallInfo[None]):
outcome = yield
rep = outcome.get_result()
evalxfail = item._store.get(evalxfail_key, None)
@ -139,7 +148,8 @@ def pytest_runtest_makereport(item, call):
elif item.config.option.runxfail:
pass # don't interfere
elif call.excinfo and call.excinfo.errisinstance(xfail.Exception):
elif call.excinfo and isinstance(call.excinfo.value, xfail.Exception):
assert call.excinfo.value.msg is not None
rep.wasxfail = "reason: " + call.excinfo.value.msg
rep.outcome = "skipped"
elif evalxfail and not rep.skipped and evalxfail.wasvalid() and evalxfail.istrue():
@ -169,15 +179,17 @@ def pytest_runtest_makereport(item, call):
# the location of where the skip exception was raised within pytest
_, _, reason = rep.longrepr
filename, line = item.reportinfo()[:2]
assert line is not None
rep.longrepr = str(filename), line + 1, reason
# called by terminalreporter progress reporting
def pytest_report_teststatus(report):
def pytest_report_teststatus(report: BaseReport) -> Optional[Tuple[str, str, str]]:
if hasattr(report, "wasxfail"):
if report.skipped:
return "xfailed", "x", "XFAIL"
elif report.passed:
return "xpassed", "X", "XPASS"
return None

View File

@ -1,7 +1,15 @@
from typing import List
from typing import Optional
import pytest
from _pytest import nodes
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.main import Session
from _pytest.reports import TestReport
def pytest_addoption(parser):
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("general")
group.addoption(
"--sw",
@ -19,25 +27,28 @@ def pytest_addoption(parser):
@pytest.hookimpl
def pytest_configure(config):
def pytest_configure(config: Config) -> None:
config.pluginmanager.register(StepwisePlugin(config), "stepwiseplugin")
class StepwisePlugin:
def __init__(self, config):
def __init__(self, config: Config) -> None:
self.config = config
self.active = config.getvalue("stepwise")
self.session = None
self.session = None # type: Optional[Session]
self.report_status = ""
if self.active:
assert config.cache is not None
self.lastfailed = config.cache.get("cache/stepwise", None)
self.skip = config.getvalue("stepwise_skip")
def pytest_sessionstart(self, session):
def pytest_sessionstart(self, session: Session) -> None:
self.session = session
def pytest_collection_modifyitems(self, session, config, items):
def pytest_collection_modifyitems(
self, session: Session, config: Config, items: List[nodes.Item]
) -> None:
if not self.active:
return
if not self.lastfailed:
@ -70,7 +81,7 @@ class StepwisePlugin:
config.hook.pytest_deselected(items=already_passed)
def pytest_runtest_logreport(self, report):
def pytest_runtest_logreport(self, report: TestReport) -> None:
if not self.active:
return
@ -85,6 +96,7 @@ class StepwisePlugin:
else:
# Mark test as the last failing and interrupt the test session.
self.lastfailed = report.nodeid
assert self.session is not None
self.session.shouldstop = (
"Test failed, continuing from this test next run."
)
@ -96,11 +108,13 @@ class StepwisePlugin:
if report.nodeid == self.lastfailed:
self.lastfailed = None
def pytest_report_collectionfinish(self):
def pytest_report_collectionfinish(self) -> Optional[str]:
if self.active and self.config.getoption("verbose") >= 0 and self.report_status:
return "stepwise: %s" % self.report_status
return None
def pytest_sessionfinish(self, session):
def pytest_sessionfinish(self, session: Session) -> None:
assert self.config.cache is not None
if self.active:
self.config.cache.set("cache/stepwise", self.lastfailed)
else:

View File

@ -12,11 +12,15 @@ from functools import partial
from typing import Any
from typing import Callable
from typing import Dict
from typing import Generator
from typing import List
from typing import Mapping
from typing import Optional
from typing import Sequence
from typing import Set
from typing import TextIO
from typing import Tuple
from typing import Union
import attr
import pluggy
@ -29,13 +33,24 @@ from _pytest import timing
from _pytest._io import TerminalWriter
from _pytest._io.wcwidth import wcswidth
from _pytest.compat import order_preserving_dict
from _pytest.compat import TYPE_CHECKING
from _pytest.config import _PluggyPlugin
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.config.argparsing import Parser
from _pytest.deprecated import TERMINALWRITER_WRITER
from _pytest.main import Session
from _pytest.nodes import Item
from _pytest.nodes import Node
from _pytest.reports import BaseReport
from _pytest.reports import CollectReport
from _pytest.reports import TestReport
if TYPE_CHECKING:
from typing_extensions import Literal
from _pytest.main import Session
REPORT_COLLECTING_RESOLUTION = 0.5
KNOWN_TYPES = (
@ -60,7 +75,14 @@ class MoreQuietAction(argparse.Action):
used to unify verbosity handling
"""
def __init__(self, option_strings, dest, default=None, required=False, help=None):
def __init__(
self,
option_strings: Sequence[str],
dest: str,
default: object = None,
required: bool = False,
help: Optional[str] = None,
) -> None:
super().__init__(
option_strings=option_strings,
dest=dest,
@ -70,14 +92,20 @@ class MoreQuietAction(argparse.Action):
help=help,
)
def __call__(self, parser, namespace, values, option_string=None):
def __call__(
self,
parser: argparse.ArgumentParser,
namespace: argparse.Namespace,
values: Union[str, Sequence[object], None],
option_string: Optional[str] = None,
) -> None:
new_count = getattr(namespace, self.dest, 0) - 1
setattr(namespace, self.dest, new_count)
# todo Deprecate config.quiet
namespace.quiet = getattr(namespace, "quiet", 0) + 1
def pytest_addoption(parser):
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("terminal reporting", "reporting", after="general")
group._addoption(
"-v",
@ -185,7 +213,7 @@ def pytest_configure(config: Config) -> None:
def getreportopt(config: Config) -> str:
reportchars = config.option.reportchars
reportchars = config.option.reportchars # type: str
old_aliases = {"F", "S"}
reportopts = ""
@ -210,14 +238,14 @@ def getreportopt(config: Config) -> str:
@pytest.hookimpl(trylast=True) # after _pytest.runner
def pytest_report_teststatus(report: TestReport) -> Tuple[str, str, str]:
def pytest_report_teststatus(report: BaseReport) -> Tuple[str, str, str]:
letter = "F"
if report.passed:
letter = "."
elif report.skipped:
letter = "s"
outcome = report.outcome
outcome = report.outcome # type: str
if report.when in ("collect", "setup", "teardown") and outcome == "failed":
outcome = "error"
letter = "E"
@ -238,10 +266,12 @@ class WarningReport:
message = attr.ib(type=str)
nodeid = attr.ib(type=Optional[str], default=None)
fslocation = attr.ib(default=None)
fslocation = attr.ib(
type=Optional[Union[Tuple[str, int], py.path.local]], default=None
)
count_towards_summary = True
def get_location(self, config):
def get_location(self, config: Config) -> Optional[str]:
"""
Returns the more user-friendly information about the location
of a warning, or None.
@ -261,13 +291,13 @@ class WarningReport:
class TerminalReporter:
def __init__(self, config: Config, file=None) -> None:
def __init__(self, config: Config, file: Optional[TextIO] = None) -> None:
import _pytest.config
self.config = config
self._numcollected = 0
self._session = None # type: Optional[Session]
self._showfspath = None
self._showfspath = None # type: Optional[bool]
self.stats = {} # type: Dict[str, List[Any]]
self._main_color = None # type: Optional[str]
@ -284,6 +314,7 @@ class TerminalReporter:
self._progress_nodeids_reported = set() # type: Set[str]
self._show_progress_info = self._determine_show_progress_info()
self._collect_report_last_write = None # type: Optional[float]
self._already_displayed_warnings = None # type: Optional[int]
@property
def writer(self) -> TerminalWriter:
@ -291,11 +322,11 @@ class TerminalReporter:
return self._tw
@writer.setter
def writer(self, value: TerminalWriter):
def writer(self, value: TerminalWriter) -> None:
warnings.warn(TERMINALWRITER_WRITER, stacklevel=2)
self._tw = value
def _determine_show_progress_info(self):
def _determine_show_progress_info(self) -> "Literal['progress', 'count', False]":
"""Return True if we should display progress information based on the current config"""
# do not show progress if we are not capturing output (#3038)
if self.config.getoption("capture", "no") == "no":
@ -303,38 +334,42 @@ class TerminalReporter:
# do not show progress if we are showing fixture setup/teardown
if self.config.getoption("setupshow", False):
return False
cfg = self.config.getini("console_output_style")
if cfg in ("progress", "count"):
return cfg
return False
cfg = self.config.getini("console_output_style") # type: str
if cfg == "progress":
return "progress"
elif cfg == "count":
return "count"
else:
return False
@property
def verbosity(self):
return self.config.option.verbose
def verbosity(self) -> int:
verbosity = self.config.option.verbose # type: int
return verbosity
@property
def showheader(self):
def showheader(self) -> bool:
return self.verbosity >= 0
@property
def showfspath(self):
def showfspath(self) -> bool:
if self._showfspath is None:
return self.verbosity >= 0
return self._showfspath
@showfspath.setter
def showfspath(self, value):
def showfspath(self, value: Optional[bool]) -> None:
self._showfspath = value
@property
def showlongtestinfo(self):
def showlongtestinfo(self) -> bool:
return self.verbosity > 0
def hasopt(self, char):
def hasopt(self, char: str) -> bool:
char = {"xfailed": "x", "skipped": "s"}.get(char, char)
return char in self.reportchars
def write_fspath_result(self, nodeid, res, **markup):
def write_fspath_result(self, nodeid: str, res, **markup: bool) -> None:
fspath = self.config.rootdir.join(nodeid.split("::")[0])
# NOTE: explicitly check for None to work around py bug, and for less
# overhead in general (https://github.com/pytest-dev/py/pull/207).
@ -347,7 +382,7 @@ class TerminalReporter:
self._tw.write(fspath + " ")
self._tw.write(res, flush=True, **markup)
def write_ensure_prefix(self, prefix, extra="", **kwargs):
def write_ensure_prefix(self, prefix, extra: str = "", **kwargs) -> None:
if self.currentfspath != prefix:
self._tw.line()
self.currentfspath = prefix
@ -356,7 +391,7 @@ class TerminalReporter:
self._tw.write(extra, **kwargs)
self.currentfspath = -2
def ensure_newline(self):
def ensure_newline(self) -> None:
if self.currentfspath:
self._tw.line()
self.currentfspath = None
@ -367,13 +402,13 @@ class TerminalReporter:
def flush(self) -> None:
self._tw.flush()
def write_line(self, line, **markup):
def write_line(self, line: Union[str, bytes], **markup: bool) -> None:
if not isinstance(line, str):
line = str(line, errors="replace")
self.ensure_newline()
self._tw.line(line, **markup)
def rewrite(self, line, **markup):
def rewrite(self, line: str, **markup: bool) -> None:
"""
Rewinds the terminal cursor to the beginning and writes the given line.
@ -391,14 +426,20 @@ class TerminalReporter:
line = str(line)
self._tw.write("\r" + line + fill, **markup)
def write_sep(self, sep, title=None, **markup):
def write_sep(
self,
sep: str,
title: Optional[str] = None,
fullwidth: Optional[int] = None,
**markup: bool
) -> None:
self.ensure_newline()
self._tw.sep(sep, title, **markup)
self._tw.sep(sep, title, fullwidth, **markup)
def section(self, title, sep="=", **kw):
def section(self, title: str, sep: str = "=", **kw: bool) -> None:
self._tw.sep(sep, title, **kw)
def line(self, msg, **kw):
def line(self, msg: str, **kw: bool) -> None:
self._tw.line(msg, **kw)
def _add_stats(self, category: str, items: List) -> None:
@ -412,7 +453,9 @@ class TerminalReporter:
self.write_line("INTERNALERROR> " + line)
return 1
def pytest_warning_recorded(self, warning_message, nodeid):
def pytest_warning_recorded(
self, warning_message: warnings.WarningMessage, nodeid: str,
) -> None:
from _pytest.warnings import warning_record_to_str
fslocation = warning_message.filename, warning_message.lineno
@ -423,7 +466,7 @@ class TerminalReporter:
)
self._add_stats("warnings", [warning_report])
def pytest_plugin_registered(self, plugin):
def pytest_plugin_registered(self, plugin: _PluggyPlugin) -> None:
if self.config.option.traceconfig:
msg = "PLUGIN registered: {}".format(plugin)
# XXX this event may happen during setup/teardown time
@ -431,10 +474,10 @@ class TerminalReporter:
# which garbles our output if we use self.write_line
self.write_line(msg)
def pytest_deselected(self, items):
def pytest_deselected(self, items) -> None:
self._add_stats("deselected", items)
def pytest_runtest_logstart(self, nodeid, location):
def pytest_runtest_logstart(self, nodeid, location) -> None:
# ensure that the path is printed before the
# 1st test of a module starts running
if self.showlongtestinfo:
@ -448,7 +491,9 @@ class TerminalReporter:
def pytest_runtest_logreport(self, report: TestReport) -> None:
self._tests_ran = True
rep = report
res = self.config.hook.pytest_report_teststatus(report=rep, config=self.config)
res = self.config.hook.pytest_report_teststatus(
report=rep, config=self.config
) # type: Tuple[str, str, str]
category, letter, word = res
if isinstance(word, tuple):
word, markup = word
@ -495,10 +540,11 @@ class TerminalReporter:
self.flush()
@property
def _is_last_item(self):
def _is_last_item(self) -> bool:
assert self._session is not None
return len(self._progress_nodeids_reported) == self._session.testscollected
def pytest_runtest_logfinish(self, nodeid):
def pytest_runtest_logfinish(self, nodeid) -> None:
assert self._session
if self.verbosity <= 0 and self._show_progress_info:
if self._show_progress_info == "count":
@ -536,7 +582,7 @@ class TerminalReporter:
)
return " [100%]"
def _write_progress_information_filling_space(self):
def _write_progress_information_filling_space(self) -> None:
color, _ = self._get_main_color()
msg = self._get_progress_information_message()
w = self._width_of_current_line
@ -544,7 +590,7 @@ class TerminalReporter:
self.write(msg.rjust(fill), flush=True, **{color: True})
@property
def _width_of_current_line(self):
def _width_of_current_line(self) -> int:
"""Return the width of current line, using the superior implementation of py-1.6 when available"""
return self._tw.width_of_current_line
@ -566,7 +612,7 @@ class TerminalReporter:
if self.isatty:
self.report_collect()
def report_collect(self, final=False):
def report_collect(self, final: bool = False) -> None:
if self.config.option.verbose < 0:
return
@ -607,7 +653,7 @@ class TerminalReporter:
self.write_line(line)
@pytest.hookimpl(trylast=True)
def pytest_sessionstart(self, session: Session) -> None:
def pytest_sessionstart(self, session: "Session") -> None:
self._session = session
self._sessionstarttime = timing.time()
if not self.showheader:
@ -634,12 +680,14 @@ class TerminalReporter:
)
self._write_report_lines_from_hooks(lines)
def _write_report_lines_from_hooks(self, lines):
def _write_report_lines_from_hooks(
self, lines: List[Union[str, List[str]]]
) -> None:
lines.reverse()
for line in collapse(lines):
self.write_line(line)
def pytest_report_header(self, config):
def pytest_report_header(self, config: Config) -> List[str]:
line = "rootdir: %s" % config.rootdir
if config.inifile:
@ -656,7 +704,7 @@ class TerminalReporter:
result.append("plugins: %s" % ", ".join(_plugin_nameversions(plugininfo)))
return result
def pytest_collection_finish(self, session):
def pytest_collection_finish(self, session: "Session") -> None:
self.report_collect(True)
lines = self.config.hook.pytest_report_collectionfinish(
@ -676,7 +724,7 @@ class TerminalReporter:
for rep in failed:
rep.toterminal(self._tw)
def _printcollecteditems(self, items):
def _printcollecteditems(self, items: Sequence[Item]) -> None:
# to print out items and their parent collectors
# we take care to leave out Instances aka ()
# because later versions are going to get rid of them anyway
@ -692,7 +740,7 @@ class TerminalReporter:
for item in items:
self._tw.line(item.nodeid)
return
stack = []
stack = [] # type: List[Node]
indent = ""
for item in items:
needed_collectors = item.listchain()[1:] # strip root node
@ -707,17 +755,16 @@ class TerminalReporter:
indent = (len(stack) - 1) * " "
self._tw.line("{}{}".format(indent, col))
if self.config.option.verbose >= 1:
try:
obj = col.obj # type: ignore
except AttributeError:
continue
doc = inspect.getdoc(obj)
obj = getattr(col, "obj", None)
doc = inspect.getdoc(obj) if obj else None
if doc:
for line in doc.splitlines():
self._tw.line("{}{}".format(indent + " ", line))
@pytest.hookimpl(hookwrapper=True)
def pytest_sessionfinish(self, session: Session, exitstatus: ExitCode):
def pytest_sessionfinish(
self, session: "Session", exitstatus: Union[int, ExitCode]
):
outcome = yield
outcome.get_result()
self._tw.line("")
@ -733,16 +780,16 @@ class TerminalReporter:
terminalreporter=self, exitstatus=exitstatus, config=self.config
)
if session.shouldfail:
self.write_sep("!", session.shouldfail, red=True)
self.write_sep("!", str(session.shouldfail), red=True)
if exitstatus == ExitCode.INTERRUPTED:
self._report_keyboardinterrupt()
del self._keyboardinterrupt_memo
elif session.shouldstop:
self.write_sep("!", session.shouldstop, red=True)
self.write_sep("!", str(session.shouldstop), red=True)
self.summary_stats()
@pytest.hookimpl(hookwrapper=True)
def pytest_terminal_summary(self):
def pytest_terminal_summary(self) -> Generator[None, None, None]:
self.summary_errors()
self.summary_failures()
self.summary_warnings()
@ -752,14 +799,14 @@ class TerminalReporter:
# Display any extra warnings from teardown here (if any).
self.summary_warnings()
def pytest_keyboard_interrupt(self, excinfo):
def pytest_keyboard_interrupt(self, excinfo) -> None:
self._keyboardinterrupt_memo = excinfo.getrepr(funcargs=True)
def pytest_unconfigure(self):
def pytest_unconfigure(self) -> None:
if hasattr(self, "_keyboardinterrupt_memo"):
self._report_keyboardinterrupt()
def _report_keyboardinterrupt(self):
def _report_keyboardinterrupt(self) -> None:
excrepr = self._keyboardinterrupt_memo
msg = excrepr.reprcrash.message
self.write_sep("!", msg)
@ -813,14 +860,14 @@ class TerminalReporter:
#
# summaries for sessionfinish
#
def getreports(self, name):
def getreports(self, name: str):
values = []
for x in self.stats.get(name, []):
if not hasattr(x, "_pdbshown"):
values.append(x)
return values
def summary_warnings(self):
def summary_warnings(self) -> None:
if self.hasopt("w"):
all_warnings = self.stats.get(
"warnings"
@ -828,7 +875,7 @@ class TerminalReporter:
if not all_warnings:
return
final = hasattr(self, "_already_displayed_warnings")
final = self._already_displayed_warnings is not None
if final:
warning_reports = all_warnings[self._already_displayed_warnings :]
else:
@ -843,7 +890,7 @@ class TerminalReporter:
for wr in warning_reports:
reports_grouped_by_message.setdefault(wr.message, []).append(wr)
def collapsed_location_report(reports: List[WarningReport]):
def collapsed_location_report(reports: List[WarningReport]) -> str:
locations = []
for w in reports:
location = w.get_location(self.config)
@ -877,10 +924,10 @@ class TerminalReporter:
self._tw.line()
self._tw.line("-- Docs: https://docs.pytest.org/en/latest/warnings.html")
def summary_passes(self):
def summary_passes(self) -> None:
if self.config.option.tbstyle != "no":
if self.hasopt("P"):
reports = self.getreports("passed")
reports = self.getreports("passed") # type: List[TestReport]
if not reports:
return
self.write_sep("=", "PASSES")
@ -892,9 +939,10 @@ class TerminalReporter:
self._handle_teardown_sections(rep.nodeid)
def _get_teardown_reports(self, nodeid: str) -> List[TestReport]:
reports = self.getreports("")
return [
report
for report in self.getreports("")
for report in reports
if report.when == "teardown" and report.nodeid == nodeid
]
@ -915,9 +963,9 @@ class TerminalReporter:
content = content[:-1]
self._tw.line(content)
def summary_failures(self):
def summary_failures(self) -> None:
if self.config.option.tbstyle != "no":
reports = self.getreports("failed")
reports = self.getreports("failed") # type: List[BaseReport]
if not reports:
return
self.write_sep("=", "FAILURES")
@ -932,9 +980,9 @@ class TerminalReporter:
self._outrep_summary(rep)
self._handle_teardown_sections(rep.nodeid)
def summary_errors(self):
def summary_errors(self) -> None:
if self.config.option.tbstyle != "no":
reports = self.getreports("error")
reports = self.getreports("error") # type: List[BaseReport]
if not reports:
return
self.write_sep("=", "ERRORS")
@ -947,7 +995,7 @@ class TerminalReporter:
self.write_sep("_", msg, red=True, bold=True)
self._outrep_summary(rep)
def _outrep_summary(self, rep):
def _outrep_summary(self, rep: BaseReport) -> None:
rep.toterminal(self._tw)
showcapture = self.config.option.showcapture
if showcapture == "no":
@ -960,7 +1008,7 @@ class TerminalReporter:
content = content[:-1]
self._tw.line(content)
def summary_stats(self):
def summary_stats(self) -> None:
if self.verbosity < -1:
return
@ -1030,7 +1078,7 @@ class TerminalReporter:
lines.append("{} {} {}".format(verbose_word, pos, reason))
def show_skipped(lines: List[str]) -> None:
skipped = self.stats.get("skipped", [])
skipped = self.stats.get("skipped", []) # type: List[CollectReport]
fskips = _folded_skips(self.startdir, skipped) if skipped else []
if not fskips:
return
@ -1114,12 +1162,14 @@ class TerminalReporter:
return parts, main_color
def _get_pos(config, rep):
def _get_pos(config: Config, rep: BaseReport):
nodeid = config.cwd_relative_nodeid(rep.nodeid)
return nodeid
def _get_line_with_reprcrash_message(config, rep, termwidth):
def _get_line_with_reprcrash_message(
config: Config, rep: BaseReport, termwidth: int
) -> str:
"""Get summary line for a report, trying to add reprcrash message."""
verbose_word = rep._get_verbose_word(config)
pos = _get_pos(config, rep)
@ -1132,7 +1182,8 @@ def _get_line_with_reprcrash_message(config, rep, termwidth):
return line
try:
msg = rep.longrepr.reprcrash.message
# Type ignored intentionally -- possible AttributeError expected.
msg = rep.longrepr.reprcrash.message # type: ignore[union-attr] # noqa: F821
except AttributeError:
pass
else:
@ -1155,9 +1206,12 @@ def _get_line_with_reprcrash_message(config, rep, termwidth):
return line
def _folded_skips(startdir, skipped):
d = {}
def _folded_skips(
startdir: py.path.local, skipped: Sequence[CollectReport],
) -> List[Tuple[int, str, Optional[int], str]]:
d = {} # type: Dict[Tuple[str, Optional[int], str], List[CollectReport]]
for event in skipped:
assert event.longrepr is not None
assert len(event.longrepr) == 3, (event, event.longrepr)
fspath, lineno, reason = event.longrepr
# For consistency, report all fspaths in relative form.
@ -1171,13 +1225,13 @@ def _folded_skips(startdir, skipped):
and "skip" in keywords
and "pytestmark" not in keywords
):
key = (fspath, None, reason)
key = (fspath, None, reason) # type: Tuple[str, Optional[int], str]
else:
key = (fspath, lineno, reason)
d.setdefault(key, []).append(event)
values = []
values = [] # type: List[Tuple[int, str, Optional[int], str]]
for key, events in d.items():
values.append((len(events),) + key)
values.append((len(events), *key))
return values
@ -1190,7 +1244,7 @@ _color_for_type = {
_color_for_type_default = "yellow"
def _make_plural(count, noun):
def _make_plural(count: int, noun: str) -> Tuple[int, str]:
# No need to pluralize words such as `failed` or `passed`.
if noun not in ["error", "warnings"]:
return count, noun

View File

@ -1,32 +1,62 @@
""" discovery and running of std-library "unittest" style tests. """
import sys
import traceback
import types
from typing import Any
from typing import Callable
from typing import Generator
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import _pytest._code
import pytest
from _pytest.compat import getimfunc
from _pytest.compat import is_async_function
from _pytest.compat import TYPE_CHECKING
from _pytest.config import hookimpl
from _pytest.fixtures import FixtureRequest
from _pytest.nodes import Collector
from _pytest.nodes import Item
from _pytest.outcomes import exit
from _pytest.outcomes import fail
from _pytest.outcomes import skip
from _pytest.outcomes import xfail
from _pytest.python import Class
from _pytest.python import Function
from _pytest.python import PyCollector
from _pytest.runner import CallInfo
from _pytest.skipping import skipped_by_mark_key
from _pytest.skipping import unexpectedsuccess_key
if TYPE_CHECKING:
import unittest
from typing import Type
def pytest_pycollect_makeitem(collector, name, obj):
from _pytest.fixtures import _Scope
_SysExcInfoType = Union[
Tuple[Type[BaseException], BaseException, types.TracebackType],
Tuple[None, None, None],
]
def pytest_pycollect_makeitem(
collector: PyCollector, name: str, obj
) -> Optional["UnitTestCase"]:
# has unittest been imported and is obj a subclass of its TestCase?
try:
if not issubclass(obj, sys.modules["unittest"].TestCase):
return
ut = sys.modules["unittest"]
# Type ignored because `ut` is an opaque module.
if not issubclass(obj, ut.TestCase): # type: ignore
return None
except Exception:
return
return None
# yes, so let's collect it
return UnitTestCase.from_parent(collector, name=name, obj=obj)
item = UnitTestCase.from_parent(collector, name=name, obj=obj) # type: UnitTestCase
return item
class UnitTestCase(Class):
@ -34,7 +64,7 @@ class UnitTestCase(Class):
# to declare that our children do not support funcargs
nofuncargs = True
def collect(self):
def collect(self) -> Iterable[Union[Item, Collector]]:
from unittest import TestLoader
cls = self.obj
@ -61,34 +91,36 @@ class UnitTestCase(Class):
runtest = getattr(self.obj, "runTest", None)
if runtest is not None:
ut = sys.modules.get("twisted.trial.unittest", None)
if ut is None or runtest != ut.TestCase.runTest:
# TODO: callobj consistency
# Type ignored because `ut` is an opaque module.
if ut is None or runtest != ut.TestCase.runTest: # type: ignore
yield TestCaseFunction.from_parent(self, name="runTest")
def _inject_setup_teardown_fixtures(self, cls):
def _inject_setup_teardown_fixtures(self, cls: type) -> None:
"""Injects a hidden auto-use fixture to invoke setUpClass/setup_method and corresponding
teardown functions (#517)"""
class_fixture = _make_xunit_fixture(
cls, "setUpClass", "tearDownClass", scope="class", pass_self=False
)
if class_fixture:
cls.__pytest_class_setup = class_fixture
cls.__pytest_class_setup = class_fixture # type: ignore[attr-defined] # noqa: F821
method_fixture = _make_xunit_fixture(
cls, "setup_method", "teardown_method", scope="function", pass_self=True
)
if method_fixture:
cls.__pytest_method_setup = method_fixture
cls.__pytest_method_setup = method_fixture # type: ignore[attr-defined] # noqa: F821
def _make_xunit_fixture(obj, setup_name, teardown_name, scope, pass_self):
def _make_xunit_fixture(
obj: type, setup_name: str, teardown_name: str, scope: "_Scope", pass_self: bool
):
setup = getattr(obj, setup_name, None)
teardown = getattr(obj, teardown_name, None)
if setup is None and teardown is None:
return None
@pytest.fixture(scope=scope, autouse=True)
def fixture(self, request):
def fixture(self, request: FixtureRequest) -> Generator[None, None, None]:
if _is_skipped(self):
reason = self.__unittest_skip_why__
pytest.skip(reason)
@ -109,32 +141,33 @@ def _make_xunit_fixture(obj, setup_name, teardown_name, scope, pass_self):
class TestCaseFunction(Function):
nofuncargs = True
_excinfo = None
_testcase = None
_excinfo = None # type: Optional[List[_pytest._code.ExceptionInfo]]
_testcase = None # type: Optional[unittest.TestCase]
def setup(self):
def setup(self) -> None:
# a bound method to be called during teardown() if set (see 'runtest()')
self._explicit_tearDown = None
self._testcase = self.parent.obj(self.name)
self._explicit_tearDown = None # type: Optional[Callable[[], None]]
assert self.parent is not None
self._testcase = self.parent.obj(self.name) # type: ignore[attr-defined] # noqa: F821
self._obj = getattr(self._testcase, self.name)
if hasattr(self, "_request"):
self._request._fillfixtures()
def teardown(self):
def teardown(self) -> None:
if self._explicit_tearDown is not None:
self._explicit_tearDown()
self._explicit_tearDown = None
self._testcase = None
self._obj = None
def startTest(self, testcase):
def startTest(self, testcase: "unittest.TestCase") -> None:
pass
def _addexcinfo(self, rawexcinfo):
def _addexcinfo(self, rawexcinfo: "_SysExcInfoType") -> None:
# unwrap potential exception info (see twisted trial support below)
rawexcinfo = getattr(rawexcinfo, "_rawexcinfo", rawexcinfo)
try:
excinfo = _pytest._code.ExceptionInfo(rawexcinfo)
excinfo = _pytest._code.ExceptionInfo(rawexcinfo) # type: ignore[arg-type] # noqa: F821
# invoke the attributes to trigger storing the traceback
# trial causes some issue there
excinfo.value
@ -163,7 +196,9 @@ class TestCaseFunction(Function):
excinfo = _pytest._code.ExceptionInfo.from_current()
self.__dict__.setdefault("_excinfo", []).append(excinfo)
def addError(self, testcase, rawexcinfo):
def addError(
self, testcase: "unittest.TestCase", rawexcinfo: "_SysExcInfoType"
) -> None:
try:
if isinstance(rawexcinfo[1], exit.Exception):
exit(rawexcinfo[1].msg)
@ -171,29 +206,38 @@ class TestCaseFunction(Function):
pass
self._addexcinfo(rawexcinfo)
def addFailure(self, testcase, rawexcinfo):
def addFailure(
self, testcase: "unittest.TestCase", rawexcinfo: "_SysExcInfoType"
) -> None:
self._addexcinfo(rawexcinfo)
def addSkip(self, testcase, reason):
def addSkip(self, testcase: "unittest.TestCase", reason: str) -> None:
try:
skip(reason)
except skip.Exception:
self._store[skipped_by_mark_key] = True
self._addexcinfo(sys.exc_info())
def addExpectedFailure(self, testcase, rawexcinfo, reason=""):
def addExpectedFailure(
self,
testcase: "unittest.TestCase",
rawexcinfo: "_SysExcInfoType",
reason: str = "",
) -> None:
try:
xfail(str(reason))
except xfail.Exception:
self._addexcinfo(sys.exc_info())
def addUnexpectedSuccess(self, testcase, reason=""):
def addUnexpectedSuccess(
self, testcase: "unittest.TestCase", reason: str = ""
) -> None:
self._store[unexpectedsuccess_key] = reason
def addSuccess(self, testcase):
def addSuccess(self, testcase: "unittest.TestCase") -> None:
pass
def stopTest(self, testcase):
def stopTest(self, testcase: "unittest.TestCase") -> None:
pass
def _expecting_failure(self, test_method) -> bool:
@ -205,14 +249,17 @@ class TestCaseFunction(Function):
expecting_failure_class = getattr(self, "__unittest_expecting_failure__", False)
return bool(expecting_failure_class or expecting_failure_method)
def runtest(self):
def runtest(self) -> None:
from _pytest.debugging import maybe_wrap_pytest_function_for_tracing
assert self._testcase is not None
maybe_wrap_pytest_function_for_tracing(self)
# let the unittest framework handle async functions
if is_async_function(self.obj):
self._testcase(self)
# Type ignored because self acts as the TestResult, but is not actually one.
self._testcase(result=self) # type: ignore[arg-type] # noqa: F821
else:
# when --pdb is given, we want to postpone calling tearDown() otherwise
# when entering the pdb prompt, tearDown() would have probably cleaned up
@ -228,11 +275,11 @@ class TestCaseFunction(Function):
# wrap_pytest_function_for_tracing replaces self.obj by a wrapper
setattr(self._testcase, self.name, self.obj)
try:
self._testcase(result=self)
self._testcase(result=self) # type: ignore[arg-type] # noqa: F821
finally:
delattr(self._testcase, self.name)
def _prunetraceback(self, excinfo):
def _prunetraceback(self, excinfo: _pytest._code.ExceptionInfo) -> None:
Function._prunetraceback(self, excinfo)
traceback = excinfo.traceback.filter(
lambda x: not x.frame.f_globals.get("__unittest")
@ -242,7 +289,7 @@ class TestCaseFunction(Function):
@hookimpl(tryfirst=True)
def pytest_runtest_makereport(item, call):
def pytest_runtest_makereport(item: Item, call: CallInfo[None]) -> None:
if isinstance(item, TestCaseFunction):
if item._excinfo:
call.excinfo = item._excinfo.pop(0)
@ -252,10 +299,17 @@ def pytest_runtest_makereport(item, call):
pass
unittest = sys.modules.get("unittest")
if unittest and call.excinfo and call.excinfo.errisinstance(unittest.SkipTest):
if (
unittest
and call.excinfo
and call.excinfo.errisinstance(
unittest.SkipTest # type: ignore[attr-defined] # noqa: F821
)
):
excinfo = call.excinfo
# let's substitute the excinfo with a pytest.skip one
call2 = CallInfo.from_call(
lambda: pytest.skip(str(call.excinfo.value)), call.when
call2 = CallInfo[None].from_call(
lambda: pytest.skip(str(excinfo.value)), call.when
)
call.excinfo = call2.excinfo
@ -264,9 +318,9 @@ def pytest_runtest_makereport(item, call):
@hookimpl(hookwrapper=True)
def pytest_runtest_protocol(item):
def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]:
if isinstance(item, TestCaseFunction) and "twisted.trial.unittest" in sys.modules:
ut = sys.modules["twisted.python.failure"]
ut = sys.modules["twisted.python.failure"] # type: Any
Failure__init__ = ut.Failure.__init__
check_testcase_implements_trial_reporter()
@ -293,7 +347,7 @@ def pytest_runtest_protocol(item):
yield
def check_testcase_implements_trial_reporter(done=[]):
def check_testcase_implements_trial_reporter(done: List[int] = []) -> None:
if done:
return
from zope.interface import classImplements

View File

@ -4,14 +4,20 @@ import warnings
from contextlib import contextmanager
from functools import lru_cache
from typing import Generator
from typing import Optional
from typing import Tuple
import pytest
from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from _pytest.main import Session
from _pytest.nodes import Item
from _pytest.terminal import TerminalReporter
if TYPE_CHECKING:
from typing_extensions import Type
from typing import Type
from typing_extensions import Literal
@lru_cache(maxsize=50)
@ -49,7 +55,7 @@ def _parse_filter(
return (action, message, category, module, lineno)
def pytest_addoption(parser):
def pytest_addoption(parser: Parser) -> None:
group = parser.getgroup("pytest-warnings")
group.addoption(
"-W",
@ -66,7 +72,7 @@ def pytest_addoption(parser):
)
def pytest_configure(config):
def pytest_configure(config: Config) -> None:
config.addinivalue_line(
"markers",
"filterwarnings(warning): add a warning filter to the given test. "
@ -75,7 +81,12 @@ def pytest_configure(config):
@contextmanager
def catch_warnings_for_item(config, ihook, when, item):
def catch_warnings_for_item(
config: Config,
ihook,
when: "Literal['config', 'collect', 'runtest']",
item: Optional[Item],
) -> Generator[None, None, None]:
"""
Context manager that catches warnings generated in the contained execution block.
@ -129,11 +140,11 @@ def catch_warnings_for_item(config, ihook, when, item):
)
def warning_record_to_str(warning_message):
def warning_record_to_str(warning_message: warnings.WarningMessage) -> str:
"""Convert a warnings.WarningMessage to a string."""
warn_msg = warning_message.message
msg = warnings.formatwarning(
warn_msg,
str(warn_msg),
warning_message.category,
warning_message.filename,
warning_message.lineno,
@ -143,7 +154,7 @@ def warning_record_to_str(warning_message):
@pytest.hookimpl(hookwrapper=True, tryfirst=True)
def pytest_runtest_protocol(item):
def pytest_runtest_protocol(item: Item) -> Generator[None, None, None]:
with catch_warnings_for_item(
config=item.config, ihook=item.ihook, when="runtest", item=item
):
@ -160,7 +171,9 @@ def pytest_collection(session: Session) -> Generator[None, None, None]:
@pytest.hookimpl(hookwrapper=True)
def pytest_terminal_summary(terminalreporter):
def pytest_terminal_summary(
terminalreporter: TerminalReporter,
) -> Generator[None, None, None]:
config = terminalreporter.config
with catch_warnings_for_item(
config=config, ihook=config.hook, when="config", item=None
@ -169,7 +182,7 @@ def pytest_terminal_summary(terminalreporter):
@pytest.hookimpl(hookwrapper=True)
def pytest_sessionfinish(session):
def pytest_sessionfinish(session: Session) -> Generator[None, None, None]:
config = session.config
with catch_warnings_for_item(
config=config, ihook=config.hook, when="config", item=None
@ -177,7 +190,7 @@ def pytest_sessionfinish(session):
yield
def _issue_warning_captured(warning, hook, stacklevel):
def _issue_warning_captured(warning: Warning, hook, stacklevel: int) -> None:
"""
This function should be used instead of calling ``warnings.warn`` directly when we are in the "configure" stage:
at this point the actual options might not have been set, so we manually trigger the pytest_warning_recorded
@ -190,8 +203,6 @@ def _issue_warning_captured(warning, hook, stacklevel):
with warnings.catch_warnings(record=True) as records:
warnings.simplefilter("always", type(warning))
warnings.warn(warning, stacklevel=stacklevel)
# Mypy can't infer that record=True means records is not None; help it.
assert records is not None
frame = sys._getframe(stacklevel - 1)
location = frame.f_code.co_filename, frame.f_lineno, frame.f_code.co_name
hook.pytest_warning_captured.call_historic(

View File

@ -4,6 +4,7 @@ import os
import queue
import sys
import textwrap
from typing import Tuple
from typing import Union
import py
@ -14,6 +15,7 @@ from _pytest._code.code import ExceptionChainRepr
from _pytest._code.code import ExceptionInfo
from _pytest._code.code import FormattedExcinfo
from _pytest._io import TerminalWriter
from _pytest.compat import TYPE_CHECKING
from _pytest.pytester import LineMatcher
try:
@ -23,6 +25,9 @@ except ImportError:
else:
invalidate_import_caches = getattr(importlib, "invalidate_caches", None)
if TYPE_CHECKING:
from _pytest._code.code import _TracebackStyle
@pytest.fixture
def limited_recursion_depth():
@ -40,10 +45,11 @@ def test_excinfo_simple() -> None:
assert info.type == ValueError
def test_excinfo_from_exc_info_simple():
def test_excinfo_from_exc_info_simple() -> None:
try:
raise ValueError
except ValueError as e:
assert e.__traceback__ is not None
info = _pytest._code.ExceptionInfo.from_exc_info((type(e), e, e.__traceback__))
assert info.type == ValueError
@ -317,25 +323,25 @@ def test_excinfo_exconly():
assert msg.endswith("world")
def test_excinfo_repr_str():
excinfo = pytest.raises(ValueError, h)
assert repr(excinfo) == "<ExceptionInfo ValueError() tblen=4>"
assert str(excinfo) == "<ExceptionInfo ValueError() tblen=4>"
def test_excinfo_repr_str() -> None:
excinfo1 = pytest.raises(ValueError, h)
assert repr(excinfo1) == "<ExceptionInfo ValueError() tblen=4>"
assert str(excinfo1) == "<ExceptionInfo ValueError() tblen=4>"
class CustomException(Exception):
def __repr__(self):
return "custom_repr"
def raises():
def raises() -> None:
raise CustomException()
excinfo = pytest.raises(CustomException, raises)
assert repr(excinfo) == "<ExceptionInfo custom_repr tblen=2>"
assert str(excinfo) == "<ExceptionInfo custom_repr tblen=2>"
excinfo2 = pytest.raises(CustomException, raises)
assert repr(excinfo2) == "<ExceptionInfo custom_repr tblen=2>"
assert str(excinfo2) == "<ExceptionInfo custom_repr tblen=2>"
def test_excinfo_for_later():
e = ExceptionInfo.for_later()
def test_excinfo_for_later() -> None:
e = ExceptionInfo[BaseException].for_later()
assert "for raises" in repr(e)
assert "for raises" in str(e)
@ -463,7 +469,7 @@ class TestFormattedExcinfo:
assert lines[0] == "| def f(x):"
assert lines[1] == " pass"
def test_repr_source_excinfo(self):
def test_repr_source_excinfo(self) -> None:
""" check if indentation is right """
pr = FormattedExcinfo()
excinfo = self.excinfo_from_exec(
@ -475,6 +481,7 @@ class TestFormattedExcinfo:
)
pr = FormattedExcinfo()
source = pr._getentrysource(excinfo.traceback[-1])
assert source is not None
lines = pr.get_source(source, 1, excinfo)
assert lines == [" def f():", "> assert 0", "E AssertionError"]
@ -522,17 +529,18 @@ raise ValueError()
assert repr.reprtraceback.reprentries[0].lines[0] == "> ???"
assert repr.chain[0][0].reprentries[0].lines[0] == "> ???"
def test_repr_local(self):
def test_repr_local(self) -> None:
p = FormattedExcinfo(showlocals=True)
loc = {"y": 5, "z": 7, "x": 3, "@x": 2, "__builtins__": {}}
reprlocals = p.repr_locals(loc)
assert reprlocals is not None
assert reprlocals.lines
assert reprlocals.lines[0] == "__builtins__ = <builtins>"
assert reprlocals.lines[1] == "x = 3"
assert reprlocals.lines[2] == "y = 5"
assert reprlocals.lines[3] == "z = 7"
def test_repr_local_with_error(self):
def test_repr_local_with_error(self) -> None:
class ObjWithErrorInRepr:
def __repr__(self):
raise NotImplementedError
@ -540,11 +548,12 @@ raise ValueError()
p = FormattedExcinfo(showlocals=True, truncate_locals=False)
loc = {"x": ObjWithErrorInRepr(), "__builtins__": {}}
reprlocals = p.repr_locals(loc)
assert reprlocals is not None
assert reprlocals.lines
assert reprlocals.lines[0] == "__builtins__ = <builtins>"
assert "[NotImplementedError() raised in repr()]" in reprlocals.lines[1]
def test_repr_local_with_exception_in_class_property(self):
def test_repr_local_with_exception_in_class_property(self) -> None:
class ExceptionWithBrokenClass(Exception):
# Type ignored because it's bypassed intentionally.
@property # type: ignore
@ -558,23 +567,26 @@ raise ValueError()
p = FormattedExcinfo(showlocals=True, truncate_locals=False)
loc = {"x": ObjWithErrorInRepr(), "__builtins__": {}}
reprlocals = p.repr_locals(loc)
assert reprlocals is not None
assert reprlocals.lines
assert reprlocals.lines[0] == "__builtins__ = <builtins>"
assert "[ExceptionWithBrokenClass() raised in repr()]" in reprlocals.lines[1]
def test_repr_local_truncated(self):
def test_repr_local_truncated(self) -> None:
loc = {"l": [i for i in range(10)]}
p = FormattedExcinfo(showlocals=True)
truncated_reprlocals = p.repr_locals(loc)
assert truncated_reprlocals is not None
assert truncated_reprlocals.lines
assert truncated_reprlocals.lines[0] == "l = [0, 1, 2, 3, 4, 5, ...]"
q = FormattedExcinfo(showlocals=True, truncate_locals=False)
full_reprlocals = q.repr_locals(loc)
assert full_reprlocals is not None
assert full_reprlocals.lines
assert full_reprlocals.lines[0] == "l = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]"
def test_repr_tracebackentry_lines(self, importasmod):
def test_repr_tracebackentry_lines(self, importasmod) -> None:
mod = importasmod(
"""
def func1():
@ -602,11 +614,12 @@ raise ValueError()
assert not lines[4:]
loc = repr_entry.reprfileloc
assert loc is not None
assert loc.path == mod.__file__
assert loc.lineno == 3
# assert loc.message == "ValueError: hello"
def test_repr_tracebackentry_lines2(self, importasmod, tw_mock):
def test_repr_tracebackentry_lines2(self, importasmod, tw_mock) -> None:
mod = importasmod(
"""
def func1(m, x, y, z):
@ -618,6 +631,7 @@ raise ValueError()
entry = excinfo.traceback[-1]
p = FormattedExcinfo(funcargs=True)
reprfuncargs = p.repr_args(entry)
assert reprfuncargs is not None
assert reprfuncargs.args[0] == ("m", repr("m" * 90))
assert reprfuncargs.args[1] == ("x", "5")
assert reprfuncargs.args[2] == ("y", "13")
@ -625,13 +639,14 @@ raise ValueError()
p = FormattedExcinfo(funcargs=True)
repr_entry = p.repr_traceback_entry(entry)
assert repr_entry.reprfuncargs is not None
assert repr_entry.reprfuncargs.args == reprfuncargs.args
repr_entry.toterminal(tw_mock)
assert tw_mock.lines[0] == "m = " + repr("m" * 90)
assert tw_mock.lines[1] == "x = 5, y = 13"
assert tw_mock.lines[2] == "z = " + repr("z" * 120)
def test_repr_tracebackentry_lines_var_kw_args(self, importasmod, tw_mock):
def test_repr_tracebackentry_lines_var_kw_args(self, importasmod, tw_mock) -> None:
mod = importasmod(
"""
def func1(x, *y, **z):
@ -643,17 +658,19 @@ raise ValueError()
entry = excinfo.traceback[-1]
p = FormattedExcinfo(funcargs=True)
reprfuncargs = p.repr_args(entry)
assert reprfuncargs is not None
assert reprfuncargs.args[0] == ("x", repr("a"))
assert reprfuncargs.args[1] == ("y", repr(("b",)))
assert reprfuncargs.args[2] == ("z", repr({"c": "d"}))
p = FormattedExcinfo(funcargs=True)
repr_entry = p.repr_traceback_entry(entry)
assert repr_entry.reprfuncargs
assert repr_entry.reprfuncargs.args == reprfuncargs.args
repr_entry.toterminal(tw_mock)
assert tw_mock.lines[0] == "x = 'a', y = ('b',), z = {'c': 'd'}"
def test_repr_tracebackentry_short(self, importasmod):
def test_repr_tracebackentry_short(self, importasmod) -> None:
mod = importasmod(
"""
def func1():
@ -668,6 +685,7 @@ raise ValueError()
lines = reprtb.lines
basename = py.path.local(mod.__file__).basename
assert lines[0] == " func1()"
assert reprtb.reprfileloc is not None
assert basename in str(reprtb.reprfileloc.path)
assert reprtb.reprfileloc.lineno == 5
@ -677,6 +695,7 @@ raise ValueError()
lines = reprtb.lines
assert lines[0] == ' raise ValueError("hello")'
assert lines[1] == "E ValueError: hello"
assert reprtb.reprfileloc is not None
assert basename in str(reprtb.reprfileloc.path)
assert reprtb.reprfileloc.lineno == 3
@ -716,7 +735,7 @@ raise ValueError()
reprtb = p.repr_traceback(excinfo)
assert len(reprtb.reprentries) == 3
def test_traceback_short_no_source(self, importasmod, monkeypatch):
def test_traceback_short_no_source(self, importasmod, monkeypatch) -> None:
mod = importasmod(
"""
def func1():
@ -729,7 +748,7 @@ raise ValueError()
from _pytest._code.code import Code
monkeypatch.setattr(Code, "path", "bogus")
excinfo.traceback[0].frame.code.path = "bogus"
excinfo.traceback[0].frame.code.path = "bogus" # type: ignore[misc] # noqa: F821
p = FormattedExcinfo(style="short")
reprtb = p.repr_traceback_entry(excinfo.traceback[-2])
lines = reprtb.lines
@ -742,7 +761,7 @@ raise ValueError()
assert last_lines[0] == ' raise ValueError("hello")'
assert last_lines[1] == "E ValueError: hello"
def test_repr_traceback_and_excinfo(self, importasmod):
def test_repr_traceback_and_excinfo(self, importasmod) -> None:
mod = importasmod(
"""
def f(x):
@ -753,7 +772,8 @@ raise ValueError()
)
excinfo = pytest.raises(ValueError, mod.entry)
for style in ("long", "short"):
styles = ("long", "short") # type: Tuple[_TracebackStyle, ...]
for style in styles:
p = FormattedExcinfo(style=style)
reprtb = p.repr_traceback(excinfo)
assert len(reprtb.reprentries) == 2
@ -765,10 +785,11 @@ raise ValueError()
assert repr.chain[0][0]
assert len(repr.chain[0][0].reprentries) == len(reprtb.reprentries)
assert repr.reprcrash is not None
assert repr.reprcrash.path.endswith("mod.py")
assert repr.reprcrash.message == "ValueError: 0"
def test_repr_traceback_with_invalid_cwd(self, importasmod, monkeypatch):
def test_repr_traceback_with_invalid_cwd(self, importasmod, monkeypatch) -> None:
mod = importasmod(
"""
def f(x):
@ -787,7 +808,9 @@ raise ValueError()
def raiseos():
nonlocal raised
if sys._getframe().f_back.f_code.co_name == "checked_call":
upframe = sys._getframe().f_back
assert upframe is not None
if upframe.f_code.co_name == "checked_call":
# Only raise with expected calls, but not via e.g. inspect for
# py38-windows.
raised += 1
@ -831,7 +854,7 @@ raise ValueError()
assert tw_mock.lines[-1] == "content"
assert tw_mock.lines[-2] == ("-", "title")
def test_repr_excinfo_reprcrash(self, importasmod):
def test_repr_excinfo_reprcrash(self, importasmod) -> None:
mod = importasmod(
"""
def entry():
@ -840,6 +863,7 @@ raise ValueError()
)
excinfo = pytest.raises(ValueError, mod.entry)
repr = excinfo.getrepr()
assert repr.reprcrash is not None
assert repr.reprcrash.path.endswith("mod.py")
assert repr.reprcrash.lineno == 3
assert repr.reprcrash.message == "ValueError"
@ -864,7 +888,7 @@ raise ValueError()
assert reprtb.extraline == "!!! Recursion detected (same locals & position)"
assert str(reprtb)
def test_reprexcinfo_getrepr(self, importasmod):
def test_reprexcinfo_getrepr(self, importasmod) -> None:
mod = importasmod(
"""
def f(x):
@ -875,14 +899,15 @@ raise ValueError()
)
excinfo = pytest.raises(ValueError, mod.entry)
for style in ("short", "long", "no"):
styles = ("short", "long", "no") # type: Tuple[_TracebackStyle, ...]
for style in styles:
for showlocals in (True, False):
repr = excinfo.getrepr(style=style, showlocals=showlocals)
assert repr.reprtraceback.style == style
assert isinstance(repr, ExceptionChainRepr)
for repr in repr.chain:
assert repr[0].style == style
for r in repr.chain:
assert r[0].style == style
def test_reprexcinfo_unicode(self):
from _pytest._code.code import TerminalRepr

View File

@ -103,7 +103,7 @@ def test_warn_about_imminent_junit_family_default_change(testdir, junit_family):
result.stdout.fnmatch_lines([warning_msg])
def test_node_direct_ctor_warning():
def test_node_direct_ctor_warning() -> None:
class MockConfig:
pass
@ -112,8 +112,8 @@ def test_node_direct_ctor_warning():
DeprecationWarning,
match="Direct construction of .* has been deprecated, please use .*.from_parent.*",
) as w:
nodes.Node(name="test", config=ms, session=ms, nodeid="None")
assert w[0].lineno == inspect.currentframe().f_lineno - 1
nodes.Node(name="test", config=ms, session=ms, nodeid="None") # type: ignore
assert w[0].lineno == inspect.currentframe().f_lineno - 1 # type: ignore
assert w[0].filename == __file__

View File

@ -2,11 +2,11 @@ from dataclasses import dataclass
from dataclasses import field
def test_dataclasses():
def test_dataclasses() -> None:
@dataclass
class SimpleDataObject:
field_a: int = field()
field_b: int = field()
field_b: str = field()
left = SimpleDataObject(1, "b")
right = SimpleDataObject(1, "c")

View File

@ -2,11 +2,11 @@ from dataclasses import dataclass
from dataclasses import field
def test_dataclasses_with_attribute_comparison_off():
def test_dataclasses_with_attribute_comparison_off() -> None:
@dataclass
class SimpleDataObject:
field_a: int = field()
field_b: int = field(compare=False)
field_b: str = field(compare=False)
left = SimpleDataObject(1, "b")
right = SimpleDataObject(1, "c")

View File

@ -2,11 +2,11 @@ from dataclasses import dataclass
from dataclasses import field
def test_dataclasses_verbose():
def test_dataclasses_verbose() -> None:
@dataclass
class SimpleDataObject:
field_a: int = field()
field_b: int = field()
field_b: str = field()
left = SimpleDataObject(1, "b")
right = SimpleDataObject(1, "c")

View File

@ -2,18 +2,18 @@ from dataclasses import dataclass
from dataclasses import field
def test_comparing_two_different_data_classes():
def test_comparing_two_different_data_classes() -> None:
@dataclass
class SimpleDataObjectOne:
field_a: int = field()
field_b: int = field()
field_b: str = field()
@dataclass
class SimpleDataObjectTwo:
field_a: int = field()
field_b: int = field()
field_b: str = field()
left = SimpleDataObjectOne(1, "b")
right = SimpleDataObjectTwo(1, "c")
assert left != right
assert left != right # type: ignore[comparison-overlap] # noqa: F821

View File

@ -1,4 +1,6 @@
import pprint
from typing import List
from typing import Tuple
import pytest
@ -13,7 +15,7 @@ def pytest_generate_tests(metafunc):
@pytest.fixture(scope="session")
def checked_order():
order = []
order = [] # type: List[Tuple[str, str, str]]
yield order
pprint.pprint(order)

View File

@ -1,7 +1,8 @@
from typing import List
from unittest import IsolatedAsyncioTestCase # type: ignore
teardowns = []
teardowns = [] # type: List[None]
class AsyncArguments(IsolatedAsyncioTestCase):

View File

@ -1,10 +1,11 @@
"""Issue #7110"""
import asyncio
from typing import List
import asynctest
teardowns = []
teardowns = [] # type: List[None]
class Test(asynctest.TestCase):

View File

@ -25,7 +25,7 @@ def test_maxsize_error_on_instance():
assert s[0] == "(" and s[-1] == ")"
def test_exceptions():
def test_exceptions() -> None:
class BrokenRepr:
def __init__(self, ex):
self.ex = ex
@ -34,8 +34,8 @@ def test_exceptions():
raise self.ex
class BrokenReprException(Exception):
__str__ = None
__repr__ = None
__str__ = None # type: ignore[assignment] # noqa: F821
__repr__ = None # type: ignore[assignment] # noqa: F821
assert "Exception" in saferepr(BrokenRepr(Exception("broken")))
s = saferepr(BrokenReprException("really broken"))
@ -44,7 +44,7 @@ def test_exceptions():
none = None
try:
none()
none() # type: ignore[misc] # noqa: F821
except BaseException as exc:
exp_exc = repr(exc)
obj = BrokenRepr(BrokenReprException("omg even worse"))
@ -136,10 +136,10 @@ def test_big_repr():
assert len(saferepr(range(1000))) <= len("[" + SafeRepr(0).maxlist * "1000" + "]")
def test_repr_on_newstyle():
def test_repr_on_newstyle() -> None:
class Function:
def __repr__(self):
return "<%s>" % (self.name)
return "<%s>" % (self.name) # type: ignore[attr-defined] # noqa: F821
assert saferepr(Function())

View File

@ -1,10 +1,11 @@
import logging
from typing import Any
from _pytest._io import TerminalWriter
from _pytest.logging import ColoredLevelFormatter
def test_coloredlogformatter():
def test_coloredlogformatter() -> None:
logfmt = "%(filename)-25s %(lineno)4d %(levelname)-8s %(message)s"
record = logging.LogRecord(
@ -14,7 +15,7 @@ def test_coloredlogformatter():
lineno=10,
msg="Test Message",
args=(),
exc_info=False,
exc_info=None,
)
class ColorConfig:
@ -35,7 +36,7 @@ def test_coloredlogformatter():
assert output == ("dummypath 10 INFO Test Message")
def test_multiline_message():
def test_multiline_message() -> None:
from _pytest.logging import PercentStyleMultiline
logfmt = "%(filename)-25s %(lineno)4d %(levelname)-8s %(message)s"
@ -47,8 +48,8 @@ def test_multiline_message():
lineno=10,
msg="Test Message line1\nline2",
args=(),
exc_info=False,
)
exc_info=None,
) # type: Any
# this is called by logging.Formatter.format
record.message = record.getMessage()
@ -124,7 +125,7 @@ def test_multiline_message():
)
def test_colored_short_level():
def test_colored_short_level() -> None:
logfmt = "%(levelname).1s %(message)s"
record = logging.LogRecord(
@ -134,7 +135,7 @@ def test_colored_short_level():
lineno=10,
msg="Test Message",
args=(),
exc_info=False,
exc_info=None,
)
class ColorConfig:

View File

@ -1,9 +1,12 @@
import io
import os
import re
from typing import cast
import pytest
from _pytest.capture import CaptureManager
from _pytest.pytester import Testdir
from _pytest.terminal import TerminalReporter
def test_nothing_logged(testdir):
@ -808,7 +811,7 @@ def test_log_file_unicode(testdir):
@pytest.mark.parametrize("has_capture_manager", [True, False])
def test_live_logging_suspends_capture(has_capture_manager, request):
def test_live_logging_suspends_capture(has_capture_manager: bool, request) -> None:
"""Test that capture manager is suspended when we emitting messages for live logging.
This tests the implementation calls instead of behavior because it is difficult/impossible to do it using
@ -835,8 +838,10 @@ def test_live_logging_suspends_capture(has_capture_manager, request):
def section(self, *args, **kwargs):
pass
out_file = DummyTerminal()
capture_manager = MockCaptureManager() if has_capture_manager else None
out_file = cast(TerminalReporter, DummyTerminal())
capture_manager = (
cast(CaptureManager, MockCaptureManager()) if has_capture_manager else None
)
handler = _LiveLoggingStreamHandler(out_file, capture_manager)
handler.set_when("call")
@ -849,7 +854,7 @@ def test_live_logging_suspends_capture(has_capture_manager, request):
assert MockCaptureManager.calls == ["enter disabled", "exit disabled"]
else:
assert MockCaptureManager.calls == []
assert out_file.getvalue() == "\nsome message\n"
assert cast(io.StringIO, out_file).getvalue() == "\nsome message\n"
def test_collection_live_logging(testdir):

View File

@ -428,10 +428,11 @@ class TestApprox:
assert a12 != approx(a21)
assert a21 != approx(a12)
def test_doctests(self, mocked_doctest_runner):
def test_doctests(self, mocked_doctest_runner) -> None:
import doctest
parser = doctest.DocTestParser()
assert approx.__doc__ is not None
test = parser.get_doctest(
approx.__doc__, {"approx": approx}, approx.__name__, None, None
)

View File

@ -1,6 +1,8 @@
import os
import sys
import textwrap
from typing import Any
from typing import Dict
import _pytest._code
import pytest
@ -698,7 +700,7 @@ class TestFunction:
class TestSorting:
def test_check_equality(self, testdir):
def test_check_equality(self, testdir) -> None:
modcol = testdir.getmodulecol(
"""
def test_pass(): pass
@ -720,10 +722,10 @@ class TestSorting:
assert fn1 != fn3
for fn in fn1, fn2, fn3:
assert fn != 3
assert fn != 3 # type: ignore[comparison-overlap] # noqa: F821
assert fn != modcol
assert fn != [1, 2, 3]
assert [1, 2, 3] != fn
assert fn != [1, 2, 3] # type: ignore[comparison-overlap] # noqa: F821
assert [1, 2, 3] != fn # type: ignore[comparison-overlap] # noqa: F821
assert modcol != fn
def test_allow_sane_sorting_for_decorators(self, testdir):
@ -1006,7 +1008,7 @@ class TestTracebackCutting:
assert "INTERNALERROR>" not in out
result.stdout.fnmatch_lines(["*ValueError: fail me*", "* 1 error in *"])
def test_filter_traceback_generated_code(self):
def test_filter_traceback_generated_code(self) -> None:
"""test that filter_traceback() works with the fact that
_pytest._code.code.Code.path attribute might return an str object.
In this case, one of the entries on the traceback was produced by
@ -1017,17 +1019,18 @@ class TestTracebackCutting:
from _pytest.python import filter_traceback
try:
ns = {}
ns = {} # type: Dict[str, Any]
exec("def foo(): raise ValueError", ns)
ns["foo"]()
except ValueError:
_, _, tb = sys.exc_info()
tb = _pytest._code.Traceback(tb)
assert isinstance(tb[-1].path, str)
assert not filter_traceback(tb[-1])
assert tb is not None
traceback = _pytest._code.Traceback(tb)
assert isinstance(traceback[-1].path, str)
assert not filter_traceback(traceback[-1])
def test_filter_traceback_path_no_longer_valid(self, testdir):
def test_filter_traceback_path_no_longer_valid(self, testdir) -> None:
"""test that filter_traceback() works with the fact that
_pytest._code.code.Code.path attribute might return an str object.
In this case, one of the files in the traceback no longer exists.
@ -1049,10 +1052,11 @@ class TestTracebackCutting:
except ValueError:
_, _, tb = sys.exc_info()
assert tb is not None
testdir.tmpdir.join("filter_traceback_entry_as_str.py").remove()
tb = _pytest._code.Traceback(tb)
assert isinstance(tb[-1].path, str)
assert filter_traceback(tb[-1])
traceback = _pytest._code.Traceback(tb)
assert isinstance(traceback[-1].path, str)
assert filter_traceback(traceback[-1])
class TestReportInfo:

View File

@ -3799,7 +3799,7 @@ class TestScopeOrdering:
request = FixtureRequest(items[0])
assert request.fixturenames == "m1 f1".split()
def test_func_closure_with_native_fixtures(self, testdir, monkeypatch):
def test_func_closure_with_native_fixtures(self, testdir, monkeypatch) -> None:
"""Sanity check that verifies the order returned by the closures and the actual fixture execution order:
The execution order may differ because of fixture inter-dependencies.
"""
@ -3849,9 +3849,8 @@ class TestScopeOrdering:
)
testdir.runpytest()
# actual fixture execution differs: dependent fixtures must be created first ("my_tmpdir")
assert (
pytest.FIXTURE_ORDER == "s1 my_tmpdir_factory p1 m1 my_tmpdir f1 f2".split()
)
FIXTURE_ORDER = pytest.FIXTURE_ORDER # type: ignore[attr-defined] # noqa: F821
assert FIXTURE_ORDER == "s1 my_tmpdir_factory p1 m1 my_tmpdir f1 f2".split()
def test_func_closure_module(self, testdir):
testdir.makepyfile(
@ -4159,7 +4158,7 @@ def test_fixture_duplicated_arguments() -> None:
"""Raise error if there are positional and keyword arguments for the same parameter (#1682)."""
with pytest.raises(TypeError) as excinfo:
@pytest.fixture("session", scope="session")
@pytest.fixture("session", scope="session") # type: ignore[call-overload] # noqa: F821
def arg(arg):
pass
@ -4171,7 +4170,7 @@ def test_fixture_duplicated_arguments() -> None:
with pytest.raises(TypeError) as excinfo:
@pytest.fixture(
@pytest.fixture( # type: ignore[call-overload] # noqa: F821
"function",
["p1"],
True,
@ -4199,7 +4198,7 @@ def test_fixture_with_positionals() -> None:
with pytest.warns(pytest.PytestDeprecationWarning) as warnings:
@pytest.fixture("function", [0], True)
@pytest.fixture("function", [0], True) # type: ignore[call-overload] # noqa: F821
def fixture_with_positionals():
pass
@ -4213,7 +4212,7 @@ def test_fixture_with_positionals() -> None:
def test_fixture_with_too_many_positionals() -> None:
with pytest.raises(TypeError) as excinfo:
@pytest.fixture("function", [0], True, ["id"], "name", "extra")
@pytest.fixture("function", [0], True, ["id"], "name", "extra") # type: ignore[call-overload] # noqa: F821
def fixture_with_positionals():
pass

View File

@ -1,10 +1,14 @@
from typing import Any
import pytest
from _pytest import python
from _pytest import runner
class TestOEJSKITSpecials:
def test_funcarg_non_pycollectobj(self, testdir, recwarn): # rough jstests usage
def test_funcarg_non_pycollectobj(
self, testdir, recwarn
) -> None: # rough jstests usage
testdir.makeconftest(
"""
import pytest
@ -28,13 +32,14 @@ class TestOEJSKITSpecials:
)
# this hook finds funcarg factories
rep = runner.collect_one_node(collector=modcol)
clscol = rep.result[0]
# TODO: Don't treat as Any.
clscol = rep.result[0] # type: Any
clscol.obj = lambda arg1: None
clscol.funcargs = {}
pytest._fillfuncargs(clscol)
assert clscol.funcargs["arg1"] == 42
def test_autouse_fixture(self, testdir, recwarn): # rough jstests usage
def test_autouse_fixture(self, testdir, recwarn) -> None: # rough jstests usage
testdir.makeconftest(
"""
import pytest
@ -61,20 +66,21 @@ class TestOEJSKITSpecials:
)
# this hook finds funcarg factories
rep = runner.collect_one_node(modcol)
clscol = rep.result[0]
# TODO: Don't treat as Any.
clscol = rep.result[0] # type: Any
clscol.obj = lambda: None
clscol.funcargs = {}
pytest._fillfuncargs(clscol)
assert not clscol.funcargs
def test_wrapped_getfslineno():
def test_wrapped_getfslineno() -> None:
def func():
pass
def wrap(f):
func.__wrapped__ = f
func.patchings = ["qwe"]
func.__wrapped__ = f # type: ignore
func.patchings = ["qwe"] # type: ignore
return func
@wrap
@ -87,14 +93,14 @@ def test_wrapped_getfslineno():
class TestMockDecoration:
def test_wrapped_getfuncargnames(self):
def test_wrapped_getfuncargnames(self) -> None:
from _pytest.compat import getfuncargnames
def wrap(f):
def func():
pass
func.__wrapped__ = f
func.__wrapped__ = f # type: ignore
return func
@wrap
@ -322,10 +328,11 @@ class TestReRunTests:
)
def test_pytestconfig_is_session_scoped():
def test_pytestconfig_is_session_scoped() -> None:
from _pytest.fixtures import pytestconfig
assert pytestconfig._pytestfixturefunction.scope == "session"
marker = pytestconfig._pytestfixturefunction # type: ignore
assert marker.scope == "session"
class TestNoselikeTestAttribute:

View File

@ -113,7 +113,7 @@ class TestMetafunc:
fail.Exception,
match=r"parametrize\(\) call in func got an unexpected scope value 'doggy'",
):
metafunc.parametrize("x", [1], scope="doggy")
metafunc.parametrize("x", [1], scope="doggy") # type: ignore[arg-type] # noqa: F821
def test_parametrize_request_name(self, testdir: Testdir) -> None:
"""Show proper error when 'request' is used as a parameter name in parametrize (#6183)"""

View File

@ -6,9 +6,9 @@ from _pytest.outcomes import Failed
class TestRaises:
def test_check_callable(self):
def test_check_callable(self) -> None:
with pytest.raises(TypeError, match=r".* must be callable"):
pytest.raises(RuntimeError, "int('qwe')")
pytest.raises(RuntimeError, "int('qwe')") # type: ignore[call-overload] # noqa: F821
def test_raises(self):
excinfo = pytest.raises(ValueError, int, "qwe")
@ -18,19 +18,19 @@ class TestRaises:
excinfo = pytest.raises(ValueError, int, "hello")
assert "invalid literal" in str(excinfo.value)
def test_raises_callable_no_exception(self):
def test_raises_callable_no_exception(self) -> None:
class A:
def __call__(self):
pass
try:
pytest.raises(ValueError, A())
except pytest.raises.Exception:
except pytest.fail.Exception:
pass
def test_raises_falsey_type_error(self):
def test_raises_falsey_type_error(self) -> None:
with pytest.raises(TypeError):
with pytest.raises(AssertionError, match=0):
with pytest.raises(AssertionError, match=0): # type: ignore[call-overload] # noqa: F821
raise AssertionError("ohai")
def test_raises_repr_inflight(self):
@ -126,23 +126,23 @@ class TestRaises:
result = testdir.runpytest()
result.stdout.fnmatch_lines(["*2 failed*"])
def test_noclass(self):
def test_noclass(self) -> None:
with pytest.raises(TypeError):
pytest.raises("wrong", lambda: None)
pytest.raises("wrong", lambda: None) # type: ignore[call-overload] # noqa: F821
def test_invalid_arguments_to_raises(self):
def test_invalid_arguments_to_raises(self) -> None:
with pytest.raises(TypeError, match="unknown"):
with pytest.raises(TypeError, unknown="bogus"):
with pytest.raises(TypeError, unknown="bogus"): # type: ignore[call-overload] # noqa: F821
raise ValueError()
def test_tuple(self):
with pytest.raises((KeyError, ValueError)):
raise KeyError("oops")
def test_no_raise_message(self):
def test_no_raise_message(self) -> None:
try:
pytest.raises(ValueError, int, "0")
except pytest.raises.Exception as e:
except pytest.fail.Exception as e:
assert e.msg == "DID NOT RAISE {}".format(repr(ValueError))
else:
assert False, "Expected pytest.raises.Exception"
@ -150,7 +150,7 @@ class TestRaises:
try:
with pytest.raises(ValueError):
pass
except pytest.raises.Exception as e:
except pytest.fail.Exception as e:
assert e.msg == "DID NOT RAISE {}".format(repr(ValueError))
else:
assert False, "Expected pytest.raises.Exception"
@ -252,7 +252,7 @@ class TestRaises:
):
pytest.raises(ClassLooksIterableException, lambda: None)
def test_raises_with_raising_dunder_class(self):
def test_raises_with_raising_dunder_class(self) -> None:
"""Test current behavior with regard to exceptions via __class__ (#4284)."""
class CrappyClass(Exception):
@ -262,12 +262,12 @@ class TestRaises:
assert False, "via __class__"
with pytest.raises(AssertionError) as excinfo:
with pytest.raises(CrappyClass()):
with pytest.raises(CrappyClass()): # type: ignore[call-overload] # noqa: F821
pass
assert "via __class__" in excinfo.value.args[0]
def test_raises_context_manager_with_kwargs(self):
with pytest.raises(TypeError) as excinfo:
with pytest.raises(Exception, foo="bar"):
with pytest.raises(Exception, foo="bar"): # type: ignore[call-overload] # noqa: F821
pass
assert "Unexpected keyword arguments" in str(excinfo.value)

View File

@ -279,9 +279,9 @@ class TestImportHookInstallation:
]
)
def test_register_assert_rewrite_checks_types(self):
def test_register_assert_rewrite_checks_types(self) -> None:
with pytest.raises(TypeError):
pytest.register_assert_rewrite(["pytest_tests_internal_non_existing"])
pytest.register_assert_rewrite(["pytest_tests_internal_non_existing"]) # type: ignore
pytest.register_assert_rewrite(
"pytest_tests_internal_non_existing", "pytest_tests_internal_non_existing2"
)
@ -326,8 +326,10 @@ class TestAssert_reprcompare:
def test_different_types(self):
assert callequal([0, 1], "foo") is None
def test_summary(self):
summary = callequal([0, 1], [0, 2])[0]
def test_summary(self) -> None:
lines = callequal([0, 1], [0, 2])
assert lines is not None
summary = lines[0]
assert len(summary) < 65
def test_text_diff(self):
@ -337,21 +339,24 @@ class TestAssert_reprcompare:
"+ spam",
]
def test_text_skipping(self):
def test_text_skipping(self) -> None:
lines = callequal("a" * 50 + "spam", "a" * 50 + "eggs")
assert lines is not None
assert "Skipping" in lines[1]
for line in lines:
assert "a" * 50 not in line
def test_text_skipping_verbose(self):
def test_text_skipping_verbose(self) -> None:
lines = callequal("a" * 50 + "spam", "a" * 50 + "eggs", verbose=1)
assert lines is not None
assert "- " + "a" * 50 + "eggs" in lines
assert "+ " + "a" * 50 + "spam" in lines
def test_multiline_text_diff(self):
def test_multiline_text_diff(self) -> None:
left = "foo\nspam\nbar"
right = "foo\neggs\nbar"
diff = callequal(left, right)
assert diff is not None
assert "- eggs" in diff
assert "+ spam" in diff
@ -376,8 +381,9 @@ class TestAssert_reprcompare:
"+ b'spam'",
]
def test_list(self):
def test_list(self) -> None:
expl = callequal([0, 1], [0, 2])
assert expl is not None
assert len(expl) > 1
@pytest.mark.parametrize(
@ -421,21 +427,25 @@ class TestAssert_reprcompare:
),
],
)
def test_iterable_full_diff(self, left, right, expected):
def test_iterable_full_diff(self, left, right, expected) -> None:
"""Test the full diff assertion failure explanation.
When verbose is False, then just a -v notice to get the diff is rendered,
when verbose is True, then ndiff of the pprint is returned.
"""
expl = callequal(left, right, verbose=0)
assert expl is not None
assert expl[-1] == "Use -v to get the full diff"
expl = "\n".join(callequal(left, right, verbose=1))
assert expl.endswith(textwrap.dedent(expected).strip())
verbose_expl = callequal(left, right, verbose=1)
assert verbose_expl is not None
assert "\n".join(verbose_expl).endswith(textwrap.dedent(expected).strip())
def test_list_different_lengths(self):
def test_list_different_lengths(self) -> None:
expl = callequal([0, 1], [0, 1, 2])
assert expl is not None
assert len(expl) > 1
expl = callequal([0, 1, 2], [0, 1])
assert expl is not None
assert len(expl) > 1
def test_list_wrap_for_multiple_lines(self):
@ -545,27 +555,31 @@ class TestAssert_reprcompare:
" }",
]
def test_dict(self):
def test_dict(self) -> None:
expl = callequal({"a": 0}, {"a": 1})
assert expl is not None
assert len(expl) > 1
def test_dict_omitting(self):
def test_dict_omitting(self) -> None:
lines = callequal({"a": 0, "b": 1}, {"a": 1, "b": 1})
assert lines is not None
assert lines[1].startswith("Omitting 1 identical item")
assert "Common items" not in lines
for line in lines[1:]:
assert "b" not in line
def test_dict_omitting_with_verbosity_1(self):
def test_dict_omitting_with_verbosity_1(self) -> None:
""" Ensure differing items are visible for verbosity=1 (#1512) """
lines = callequal({"a": 0, "b": 1}, {"a": 1, "b": 1}, verbose=1)
assert lines is not None
assert lines[1].startswith("Omitting 1 identical item")
assert lines[2].startswith("Differing items")
assert lines[3] == "{'a': 0} != {'a': 1}"
assert "Common items" not in lines
def test_dict_omitting_with_verbosity_2(self):
def test_dict_omitting_with_verbosity_2(self) -> None:
lines = callequal({"a": 0, "b": 1}, {"a": 1, "b": 1}, verbose=2)
assert lines is not None
assert lines[1].startswith("Common items:")
assert "Omitting" not in lines[1]
assert lines[2] == "{'b': 1}"
@ -614,15 +628,17 @@ class TestAssert_reprcompare:
"+ (1, 2, 3)",
]
def test_set(self):
def test_set(self) -> None:
expl = callequal({0, 1}, {0, 2})
assert expl is not None
assert len(expl) > 1
def test_frozenzet(self):
def test_frozenzet(self) -> None:
expl = callequal(frozenset([0, 1]), {0, 2})
assert expl is not None
assert len(expl) > 1
def test_Sequence(self):
def test_Sequence(self) -> None:
# Test comparing with a Sequence subclass.
class TestSequence(collections.abc.MutableSequence):
def __init__(self, iterable):
@ -644,15 +660,18 @@ class TestAssert_reprcompare:
pass
expl = callequal(TestSequence([0, 1]), list([0, 2]))
assert expl is not None
assert len(expl) > 1
def test_list_tuples(self):
def test_list_tuples(self) -> None:
expl = callequal([], [(1, 2)])
assert expl is not None
assert len(expl) > 1
expl = callequal([(1, 2)], [])
assert expl is not None
assert len(expl) > 1
def test_repr_verbose(self):
def test_repr_verbose(self) -> None:
class Nums:
def __init__(self, nums):
self.nums = nums
@ -669,21 +688,25 @@ class TestAssert_reprcompare:
assert callequal(nums_x, nums_y) is None
expl = callequal(nums_x, nums_y, verbose=1)
assert expl is not None
assert "+" + repr(nums_x) in expl
assert "-" + repr(nums_y) in expl
expl = callequal(nums_x, nums_y, verbose=2)
assert expl is not None
assert "+" + repr(nums_x) in expl
assert "-" + repr(nums_y) in expl
def test_list_bad_repr(self):
def test_list_bad_repr(self) -> None:
class A:
def __repr__(self):
raise ValueError(42)
expl = callequal([], [A()])
assert expl is not None
assert "ValueError" in "".join(expl)
expl = callequal({}, {"1": A()}, verbose=2)
assert expl is not None
assert expl[0].startswith("{} == <[ValueError")
assert "raised in repr" in expl[0]
assert expl[1:] == [
@ -707,9 +730,10 @@ class TestAssert_reprcompare:
expl = callequal(A(), "")
assert not expl
def test_repr_no_exc(self):
expl = " ".join(callequal("foo", "bar"))
assert "raised in repr()" not in expl
def test_repr_no_exc(self) -> None:
expl = callequal("foo", "bar")
assert expl is not None
assert "raised in repr()" not in " ".join(expl)
def test_unicode(self):
assert callequal("£€", "£") == [
@ -734,11 +758,12 @@ class TestAssert_reprcompare:
def test_format_nonascii_explanation(self):
assert util.format_explanation("λ")
def test_mojibake(self):
def test_mojibake(self) -> None:
# issue 429
left = b"e"
right = b"\xc3\xa9"
expl = callequal(left, right)
assert expl is not None
for line in expl:
assert isinstance(line, str)
msg = "\n".join(expl)
@ -791,7 +816,7 @@ class TestAssert_reprcompare_dataclass:
class TestAssert_reprcompare_attrsclass:
def test_attrs(self):
def test_attrs(self) -> None:
@attr.s
class SimpleDataObject:
field_a = attr.ib()
@ -801,12 +826,13 @@ class TestAssert_reprcompare_attrsclass:
right = SimpleDataObject(1, "c")
lines = callequal(left, right)
assert lines is not None
assert lines[1].startswith("Omitting 1 identical item")
assert "Matching attributes" not in lines
for line in lines[1:]:
assert "field_a" not in line
def test_attrs_verbose(self):
def test_attrs_verbose(self) -> None:
@attr.s
class SimpleDataObject:
field_a = attr.ib()
@ -816,6 +842,7 @@ class TestAssert_reprcompare_attrsclass:
right = SimpleDataObject(1, "c")
lines = callequal(left, right, verbose=2)
assert lines is not None
assert lines[1].startswith("Matching attributes:")
assert "Omitting" not in lines[1]
assert lines[2] == "['field_a']"
@ -824,12 +851,13 @@ class TestAssert_reprcompare_attrsclass:
@attr.s
class SimpleDataObject:
field_a = attr.ib()
field_b = attr.ib(**{ATTRS_EQ_FIELD: False})
field_b = attr.ib(**{ATTRS_EQ_FIELD: False}) # type: ignore
left = SimpleDataObject(1, "b")
right = SimpleDataObject(1, "b")
lines = callequal(left, right, verbose=2)
assert lines is not None
assert lines[1].startswith("Matching attributes:")
assert "Omitting" not in lines[1]
assert lines[2] == "['field_a']"
@ -946,8 +974,8 @@ class TestTruncateExplanation:
# to calculate that results have the expected length.
LINES_IN_TRUNCATION_MSG = 2
def test_doesnt_truncate_when_input_is_empty_list(self):
expl = []
def test_doesnt_truncate_when_input_is_empty_list(self) -> None:
expl = [] # type: List[str]
result = truncate._truncate_explanation(expl, max_lines=8, max_chars=100)
assert result == expl

View File

@ -9,6 +9,13 @@ import sys
import textwrap
import zipfile
from functools import partial
from typing import Dict
from typing import List
from typing import Mapping
from typing import Optional
from typing import Set
import py
import _pytest._code
import pytest
@ -25,24 +32,26 @@ from _pytest.pathlib import Path
from _pytest.pytester import Testdir
def rewrite(src):
def rewrite(src: str) -> ast.Module:
tree = ast.parse(src)
rewrite_asserts(tree, src.encode())
return tree
def getmsg(f, extra_ns=None, must_pass=False):
def getmsg(
f, extra_ns: Optional[Mapping[str, object]] = None, *, must_pass: bool = False
) -> Optional[str]:
"""Rewrite the assertions in f, run it, and get the failure message."""
src = "\n".join(_pytest._code.Code(f).source().lines)
mod = rewrite(src)
code = compile(mod, "<test>", "exec")
ns = {}
ns = {} # type: Dict[str, object]
if extra_ns is not None:
ns.update(extra_ns)
exec(code, ns)
func = ns[f.__name__]
try:
func()
func() # type: ignore[operator] # noqa: F821
except AssertionError:
if must_pass:
pytest.fail("shouldn't have raised")
@ -53,6 +62,7 @@ def getmsg(f, extra_ns=None, must_pass=False):
else:
if not must_pass:
pytest.fail("function didn't raise at all")
return None
class TestAssertionRewrite:
@ -98,10 +108,11 @@ class TestAssertionRewrite:
assert imp.col_offset == 0
assert isinstance(m.body[3], ast.Expr)
def test_dont_rewrite(self):
def test_dont_rewrite(self) -> None:
s = """'PYTEST_DONT_REWRITE'\nassert 14"""
m = rewrite(s)
assert len(m.body) == 2
assert isinstance(m.body[1], ast.Assert)
assert m.body[1].msg is None
def test_dont_rewrite_plugin(self, testdir):
@ -145,28 +156,28 @@ class TestAssertionRewrite:
monkeypatch.syspath_prepend(xdir)
testdir.runpytest().assert_outcomes(passed=1)
def test_name(self, request):
def f():
def test_name(self, request) -> None:
def f1() -> None:
assert False
assert getmsg(f) == "assert False"
assert getmsg(f1) == "assert False"
def f():
def f2() -> None:
f = False
assert f
assert getmsg(f) == "assert False"
assert getmsg(f2) == "assert False"
def f():
assert a_global # noqa
def f3() -> None:
assert a_global # type: ignore[name-defined] # noqa
assert getmsg(f, {"a_global": False}) == "assert False"
assert getmsg(f3, {"a_global": False}) == "assert False"
def f():
assert sys == 42
def f4() -> None:
assert sys == 42 # type: ignore[comparison-overlap] # noqa: F821
verbose = request.config.getoption("verbose")
msg = getmsg(f, {"sys": sys})
msg = getmsg(f4, {"sys": sys})
if verbose > 0:
assert msg == (
"assert <module 'sys' (built-in)> == 42\n"
@ -176,64 +187,74 @@ class TestAssertionRewrite:
else:
assert msg == "assert sys == 42"
def f():
assert cls == 42 # noqa: F821
def f5() -> None:
assert cls == 42 # type: ignore[name-defined] # noqa: F821
class X:
pass
msg = getmsg(f, {"cls": X}).splitlines()
msg = getmsg(f5, {"cls": X})
assert msg is not None
lines = msg.splitlines()
if verbose > 1:
assert msg == ["assert {!r} == 42".format(X), " +{!r}".format(X), " -42"]
assert lines == [
"assert {!r} == 42".format(X),
" +{!r}".format(X),
" -42",
]
elif verbose > 0:
assert msg == [
assert lines == [
"assert <class 'test_...e.<locals>.X'> == 42",
" +{!r}".format(X),
" -42",
]
else:
assert msg == ["assert cls == 42"]
assert lines == ["assert cls == 42"]
def test_assertrepr_compare_same_width(self, request):
def test_assertrepr_compare_same_width(self, request) -> None:
"""Should use same width/truncation with same initial width."""
def f():
def f() -> None:
assert "1234567890" * 5 + "A" == "1234567890" * 5 + "B"
msg = getmsg(f).splitlines()[0]
msg = getmsg(f)
assert msg is not None
line = msg.splitlines()[0]
if request.config.getoption("verbose") > 1:
assert msg == (
assert line == (
"assert '12345678901234567890123456789012345678901234567890A' "
"== '12345678901234567890123456789012345678901234567890B'"
)
else:
assert msg == (
assert line == (
"assert '123456789012...901234567890A' "
"== '123456789012...901234567890B'"
)
def test_dont_rewrite_if_hasattr_fails(self, request):
def test_dont_rewrite_if_hasattr_fails(self, request) -> None:
class Y:
""" A class whos getattr fails, but not with `AttributeError` """
def __getattr__(self, attribute_name):
raise KeyError()
def __repr__(self):
def __repr__(self) -> str:
return "Y"
def __init__(self):
def __init__(self) -> None:
self.foo = 3
def f():
assert cls().foo == 2 # noqa
def f() -> None:
assert cls().foo == 2 # type: ignore[name-defined] # noqa: F821
# XXX: looks like the "where" should also be there in verbose mode?!
message = getmsg(f, {"cls": Y}).splitlines()
msg = getmsg(f, {"cls": Y})
assert msg is not None
lines = msg.splitlines()
if request.config.getoption("verbose") > 0:
assert message == ["assert 3 == 2", " +3", " -2"]
assert lines == ["assert 3 == 2", " +3", " -2"]
else:
assert message == [
assert lines == [
"assert 3 == 2",
" + where 3 = Y.foo",
" + where Y = cls()",
@ -314,145 +335,145 @@ class TestAssertionRewrite:
assert result.ret == 1
result.stdout.fnmatch_lines(["*AssertionError: b'ohai!'", "*assert False"])
def test_boolop(self):
def f():
def test_boolop(self) -> None:
def f1() -> None:
f = g = False
assert f and g
assert getmsg(f) == "assert (False)"
assert getmsg(f1) == "assert (False)"
def f():
def f2() -> None:
f = True
g = False
assert f and g
assert getmsg(f) == "assert (True and False)"
assert getmsg(f2) == "assert (True and False)"
def f():
def f3() -> None:
f = False
g = True
assert f and g
assert getmsg(f) == "assert (False)"
assert getmsg(f3) == "assert (False)"
def f():
def f4() -> None:
f = g = False
assert f or g
assert getmsg(f) == "assert (False or False)"
assert getmsg(f4) == "assert (False or False)"
def f():
def f5() -> None:
f = g = False
assert not f and not g
getmsg(f, must_pass=True)
getmsg(f5, must_pass=True)
def x():
def x() -> bool:
return False
def f():
def f6() -> None:
assert x() and x()
assert (
getmsg(f, {"x": x})
getmsg(f6, {"x": x})
== """assert (False)
+ where False = x()"""
)
def f():
def f7() -> None:
assert False or x()
assert (
getmsg(f, {"x": x})
getmsg(f7, {"x": x})
== """assert (False or False)
+ where False = x()"""
)
def f():
def f8() -> None:
assert 1 in {} and 2 in {}
assert getmsg(f) == "assert (1 in {})"
assert getmsg(f8) == "assert (1 in {})"
def f():
def f9() -> None:
x = 1
y = 2
assert x in {1: None} and y in {}
assert getmsg(f) == "assert (1 in {1: None} and 2 in {})"
assert getmsg(f9) == "assert (1 in {1: None} and 2 in {})"
def f():
def f10() -> None:
f = True
g = False
assert f or g
getmsg(f, must_pass=True)
getmsg(f10, must_pass=True)
def f():
def f11() -> None:
f = g = h = lambda: True
assert f() and g() and h()
getmsg(f, must_pass=True)
getmsg(f11, must_pass=True)
def test_short_circuit_evaluation(self):
def f():
assert True or explode # noqa
def test_short_circuit_evaluation(self) -> None:
def f1() -> None:
assert True or explode # type: ignore[name-defined] # noqa: F821
getmsg(f, must_pass=True)
getmsg(f1, must_pass=True)
def f():
def f2() -> None:
x = 1
assert x == 1 or x == 2
getmsg(f, must_pass=True)
getmsg(f2, must_pass=True)
def test_unary_op(self):
def f():
def test_unary_op(self) -> None:
def f1() -> None:
x = True
assert not x
assert getmsg(f) == "assert not True"
assert getmsg(f1) == "assert not True"
def f():
def f2() -> None:
x = 0
assert ~x + 1
assert getmsg(f) == "assert (~0 + 1)"
assert getmsg(f2) == "assert (~0 + 1)"
def f():
def f3() -> None:
x = 3
assert -x + x
assert getmsg(f) == "assert (-3 + 3)"
assert getmsg(f3) == "assert (-3 + 3)"
def f():
def f4() -> None:
x = 0
assert +x + x
assert getmsg(f) == "assert (+0 + 0)"
assert getmsg(f4) == "assert (+0 + 0)"
def test_binary_op(self):
def f():
def test_binary_op(self) -> None:
def f1() -> None:
x = 1
y = -1
assert x + y
assert getmsg(f) == "assert (1 + -1)"
assert getmsg(f1) == "assert (1 + -1)"
def f():
def f2() -> None:
assert not 5 % 4
assert getmsg(f) == "assert not (5 % 4)"
assert getmsg(f2) == "assert not (5 % 4)"
def test_boolop_percent(self):
def f():
def test_boolop_percent(self) -> None:
def f1() -> None:
assert 3 % 2 and False
assert getmsg(f) == "assert ((3 % 2) and False)"
assert getmsg(f1) == "assert ((3 % 2) and False)"
def f():
def f2() -> None:
assert False or 4 % 2
assert getmsg(f) == "assert (False or (4 % 2))"
assert getmsg(f2) == "assert (False or (4 % 2))"
def test_at_operator_issue1290(self, testdir):
testdir.makepyfile(
@ -480,133 +501,133 @@ class TestAssertionRewrite:
)
testdir.runpytest().assert_outcomes(passed=1)
def test_call(self):
def g(a=42, *args, **kwargs):
def test_call(self) -> None:
def g(a=42, *args, **kwargs) -> bool:
return False
ns = {"g": g}
def f():
def f1() -> None:
assert g()
assert (
getmsg(f, ns)
getmsg(f1, ns)
== """assert False
+ where False = g()"""
)
def f():
def f2() -> None:
assert g(1)
assert (
getmsg(f, ns)
getmsg(f2, ns)
== """assert False
+ where False = g(1)"""
)
def f():
def f3() -> None:
assert g(1, 2)
assert (
getmsg(f, ns)
getmsg(f3, ns)
== """assert False
+ where False = g(1, 2)"""
)
def f():
def f4() -> None:
assert g(1, g=42)
assert (
getmsg(f, ns)
getmsg(f4, ns)
== """assert False
+ where False = g(1, g=42)"""
)
def f():
def f5() -> None:
assert g(1, 3, g=23)
assert (
getmsg(f, ns)
getmsg(f5, ns)
== """assert False
+ where False = g(1, 3, g=23)"""
)
def f():
def f6() -> None:
seq = [1, 2, 3]
assert g(*seq)
assert (
getmsg(f, ns)
getmsg(f6, ns)
== """assert False
+ where False = g(*[1, 2, 3])"""
)
def f():
def f7() -> None:
x = "a"
assert g(**{x: 2})
assert (
getmsg(f, ns)
getmsg(f7, ns)
== """assert False
+ where False = g(**{'a': 2})"""
)
def test_attribute(self):
def test_attribute(self) -> None:
class X:
g = 3
ns = {"x": X}
def f():
assert not x.g # noqa
def f1() -> None:
assert not x.g # type: ignore[name-defined] # noqa: F821
assert (
getmsg(f, ns)
getmsg(f1, ns)
== """assert not 3
+ where 3 = x.g"""
)
def f():
x.a = False # noqa
assert x.a # noqa
def f2() -> None:
x.a = False # type: ignore[name-defined] # noqa: F821
assert x.a # type: ignore[name-defined] # noqa: F821
assert (
getmsg(f, ns)
getmsg(f2, ns)
== """assert False
+ where False = x.a"""
)
def test_comparisons(self):
def f():
def test_comparisons(self) -> None:
def f1() -> None:
a, b = range(2)
assert b < a
assert getmsg(f) == """assert 1 < 0"""
assert getmsg(f1) == """assert 1 < 0"""
def f():
def f2() -> None:
a, b, c = range(3)
assert a > b > c
assert getmsg(f) == """assert 0 > 1"""
assert getmsg(f2) == """assert 0 > 1"""
def f():
def f3() -> None:
a, b, c = range(3)
assert a < b > c
assert getmsg(f) == """assert 1 > 2"""
assert getmsg(f3) == """assert 1 > 2"""
def f():
def f4() -> None:
a, b, c = range(3)
assert a < b <= c
getmsg(f, must_pass=True)
getmsg(f4, must_pass=True)
def f():
def f5() -> None:
a, b, c = range(3)
assert a < b
assert b < c
getmsg(f, must_pass=True)
getmsg(f5, must_pass=True)
def test_len(self, request):
def f():
@ -619,29 +640,29 @@ class TestAssertionRewrite:
else:
assert msg == "assert 10 == 11\n + where 10 = len([0, 1, 2, 3, 4, 5, ...])"
def test_custom_reprcompare(self, monkeypatch):
def my_reprcompare(op, left, right):
def test_custom_reprcompare(self, monkeypatch) -> None:
def my_reprcompare1(op, left, right) -> str:
return "42"
monkeypatch.setattr(util, "_reprcompare", my_reprcompare)
monkeypatch.setattr(util, "_reprcompare", my_reprcompare1)
def f():
def f1() -> None:
assert 42 < 3
assert getmsg(f) == "assert 42"
assert getmsg(f1) == "assert 42"
def my_reprcompare(op, left, right):
def my_reprcompare2(op, left, right) -> str:
return "{} {} {}".format(left, op, right)
monkeypatch.setattr(util, "_reprcompare", my_reprcompare)
monkeypatch.setattr(util, "_reprcompare", my_reprcompare2)
def f():
def f2() -> None:
assert 1 < 3 < 5 <= 4 < 7
assert getmsg(f) == "assert 5 <= 4"
assert getmsg(f2) == "assert 5 <= 4"
def test_assert_raising__bool__in_comparison(self):
def f():
def test_assert_raising__bool__in_comparison(self) -> None:
def f() -> None:
class A:
def __bool__(self):
raise ValueError(42)
@ -652,21 +673,25 @@ class TestAssertionRewrite:
def __repr__(self):
return "<MY42 object>"
def myany(x):
def myany(x) -> bool:
return False
assert myany(A() < 0)
assert "<MY42 object> < 0" in getmsg(f)
msg = getmsg(f)
assert msg is not None
assert "<MY42 object> < 0" in msg
def test_formatchar(self):
def f():
assert "%test" == "test"
def test_formatchar(self) -> None:
def f() -> None:
assert "%test" == "test" # type: ignore[comparison-overlap] # noqa: F821
assert getmsg(f).startswith("assert '%test' == 'test'")
msg = getmsg(f)
assert msg is not None
assert msg.startswith("assert '%test' == 'test'")
def test_custom_repr(self, request):
def f():
def test_custom_repr(self, request) -> None:
def f() -> None:
class Foo:
a = 1
@ -676,14 +701,16 @@ class TestAssertionRewrite:
f = Foo()
assert 0 == f.a
lines = util._format_lines([getmsg(f)])
msg = getmsg(f)
assert msg is not None
lines = util._format_lines([msg])
if request.config.getoption("verbose") > 0:
assert lines == ["assert 0 == 1\n +0\n -1"]
else:
assert lines == ["assert 0 == 1\n + where 1 = \\n{ \\n~ \\n}.a"]
def test_custom_repr_non_ascii(self):
def f():
def test_custom_repr_non_ascii(self) -> None:
def f() -> None:
class A:
name = "ä"
@ -694,6 +721,7 @@ class TestAssertionRewrite:
assert not a.name
msg = getmsg(f)
assert msg is not None
assert "UnicodeDecodeError" not in msg
assert "UnicodeEncodeError" not in msg
@ -895,6 +923,7 @@ def test_rewritten():
hook, "_warn_already_imported", lambda code, msg: warnings.append(msg)
)
spec = hook.find_spec("test_remember_rewritten_modules")
assert spec is not None
module = importlib.util.module_from_spec(spec)
hook.exec_module(module)
hook.mark_rewrite("test_remember_rewritten_modules")
@ -952,7 +981,8 @@ class TestAssertionRewriteHookDetails:
state = AssertionState(config, "rewrite")
source_path = str(tmpdir.ensure("source.py"))
pycpath = tmpdir.join("pyc").strpath
assert _write_pyc(state, [1], os.stat(source_path), pycpath)
co = compile("1", "f.py", "single")
assert _write_pyc(state, co, os.stat(source_path), pycpath)
if sys.platform == "win32":
from contextlib import contextmanager
@ -974,7 +1004,7 @@ class TestAssertionRewriteHookDetails:
monkeypatch.setattr("os.rename", raise_oserror)
assert not _write_pyc(state, [1], os.stat(source_path), pycpath)
assert not _write_pyc(state, co, os.stat(source_path), pycpath)
def test_resources_provider_for_loader(self, testdir):
"""
@ -1006,7 +1036,7 @@ class TestAssertionRewriteHookDetails:
result = testdir.runpytest_subprocess()
result.assert_outcomes(passed=1)
def test_read_pyc(self, tmpdir):
def test_read_pyc(self, tmp_path: Path) -> None:
"""
Ensure that the `_read_pyc` can properly deal with corrupted pyc files.
In those circumstances it should just give up instead of generating
@ -1015,18 +1045,18 @@ class TestAssertionRewriteHookDetails:
import py_compile
from _pytest.assertion.rewrite import _read_pyc
source = tmpdir.join("source.py")
pyc = source + "c"
source = tmp_path / "source.py"
pyc = Path(str(source) + "c")
source.write("def test(): pass")
source.write_text("def test(): pass")
py_compile.compile(str(source), str(pyc))
contents = pyc.read(mode="rb")
contents = pyc.read_bytes()
strip_bytes = 20 # header is around 8 bytes, strip a little more
assert len(contents) > strip_bytes
pyc.write(contents[:strip_bytes], mode="wb")
pyc.write_bytes(contents[:strip_bytes])
assert _read_pyc(str(source), str(pyc)) is None # no error
assert _read_pyc(source, pyc) is None # no error
def test_reload_is_same_and_reloads(self, testdir: Testdir) -> None:
"""Reloading a (collected) module after change picks up the change."""
@ -1177,17 +1207,17 @@ def test_source_mtime_long_long(testdir, offset):
assert result.ret == 0
def test_rewrite_infinite_recursion(testdir, pytestconfig, monkeypatch):
def test_rewrite_infinite_recursion(testdir, pytestconfig, monkeypatch) -> None:
"""Fix infinite recursion when writing pyc files: if an import happens to be triggered when writing the pyc
file, this would cause another call to the hook, which would trigger another pyc writing, which could
trigger another import, and so on. (#3506)"""
from _pytest.assertion import rewrite
from _pytest.assertion import rewrite as rewritemod
testdir.syspathinsert()
testdir.makepyfile(test_foo="def test_foo(): pass")
testdir.makepyfile(test_bar="def test_bar(): pass")
original_write_pyc = rewrite._write_pyc
original_write_pyc = rewritemod._write_pyc
write_pyc_called = []
@ -1198,7 +1228,7 @@ def test_rewrite_infinite_recursion(testdir, pytestconfig, monkeypatch):
assert hook.find_spec("test_bar") is None
return original_write_pyc(*args, **kwargs)
monkeypatch.setattr(rewrite, "_write_pyc", spy_write_pyc)
monkeypatch.setattr(rewritemod, "_write_pyc", spy_write_pyc)
monkeypatch.setattr(sys, "dont_write_bytecode", False)
hook = AssertionRewritingHook(pytestconfig)
@ -1211,14 +1241,14 @@ def test_rewrite_infinite_recursion(testdir, pytestconfig, monkeypatch):
class TestEarlyRewriteBailout:
@pytest.fixture
def hook(self, pytestconfig, monkeypatch, testdir):
def hook(self, pytestconfig, monkeypatch, testdir) -> AssertionRewritingHook:
"""Returns a patched AssertionRewritingHook instance so we can configure its initial paths and track
if PathFinder.find_spec has been called.
"""
import importlib.machinery
self.find_spec_calls = []
self.initial_paths = set()
self.find_spec_calls = [] # type: List[str]
self.initial_paths = set() # type: Set[py.path.local]
class StubSession:
_initialpaths = self.initial_paths
@ -1228,17 +1258,17 @@ class TestEarlyRewriteBailout:
def spy_find_spec(name, path):
self.find_spec_calls.append(name)
return importlib.machinery.PathFinder.find_spec(name, path)
return importlib.machinery.PathFinder.find_spec(name, path) # type: ignore
hook = AssertionRewritingHook(pytestconfig)
# use default patterns, otherwise we inherit pytest's testing config
hook.fnpats[:] = ["test_*.py", "*_test.py"]
monkeypatch.setattr(hook, "_find_spec", spy_find_spec)
hook.set_session(StubSession())
hook.set_session(StubSession()) # type: ignore[arg-type] # noqa: F821
testdir.syspathinsert()
return hook
def test_basic(self, testdir, hook):
def test_basic(self, testdir, hook: AssertionRewritingHook) -> None:
"""
Ensure we avoid calling PathFinder.find_spec when we know for sure a certain
module will not be rewritten to optimize assertion rewriting (#3918).
@ -1271,7 +1301,9 @@ class TestEarlyRewriteBailout:
assert hook.find_spec("foobar") is not None
assert self.find_spec_calls == ["conftest", "test_foo", "foobar"]
def test_pattern_contains_subdirectories(self, testdir, hook):
def test_pattern_contains_subdirectories(
self, testdir, hook: AssertionRewritingHook
) -> None:
"""If one of the python_files patterns contain subdirectories ("tests/**.py") we can't bailout early
because we need to match with the full path, which can only be found by calling PathFinder.find_spec
"""
@ -1514,17 +1546,17 @@ def test_get_assertion_exprs(src, expected):
assert _get_assertion_exprs(src) == expected
def test_try_makedirs(monkeypatch, tmp_path):
def test_try_makedirs(monkeypatch, tmp_path: Path) -> None:
from _pytest.assertion.rewrite import try_makedirs
p = tmp_path / "foo"
# create
assert try_makedirs(str(p))
assert try_makedirs(p)
assert p.is_dir()
# already exist
assert try_makedirs(str(p))
assert try_makedirs(p)
# monkeypatch to simulate all error situations
def fake_mkdir(p, exist_ok=False, *, exc):
@ -1532,25 +1564,25 @@ def test_try_makedirs(monkeypatch, tmp_path):
raise exc
monkeypatch.setattr(os, "makedirs", partial(fake_mkdir, exc=FileNotFoundError()))
assert not try_makedirs(str(p))
assert not try_makedirs(p)
monkeypatch.setattr(os, "makedirs", partial(fake_mkdir, exc=NotADirectoryError()))
assert not try_makedirs(str(p))
assert not try_makedirs(p)
monkeypatch.setattr(os, "makedirs", partial(fake_mkdir, exc=PermissionError()))
assert not try_makedirs(str(p))
assert not try_makedirs(p)
err = OSError()
err.errno = errno.EROFS
monkeypatch.setattr(os, "makedirs", partial(fake_mkdir, exc=err))
assert not try_makedirs(str(p))
assert not try_makedirs(p)
# unhandled OSError should raise
err = OSError()
err.errno = errno.ECHILD
monkeypatch.setattr(os, "makedirs", partial(fake_mkdir, exc=err))
with pytest.raises(OSError) as exc_info:
try_makedirs(str(p))
try_makedirs(p)
assert exc_info.value.errno == errno.ECHILD

View File

@ -6,7 +6,9 @@ import sys
import textwrap
from io import UnsupportedOperation
from typing import BinaryIO
from typing import cast
from typing import Generator
from typing import TextIO
import pytest
from _pytest import capture
@ -1351,7 +1353,7 @@ def test_error_attribute_issue555(testdir):
not sys.platform.startswith("win") and sys.version_info[:2] >= (3, 6),
reason="only py3.6+ on windows",
)
def test_py36_windowsconsoleio_workaround_non_standard_streams():
def test_py36_windowsconsoleio_workaround_non_standard_streams() -> None:
"""
Ensure _py36_windowsconsoleio_workaround function works with objects that
do not implement the full ``io``-based stream protocol, for example execnet channels (#2666).
@ -1362,7 +1364,7 @@ def test_py36_windowsconsoleio_workaround_non_standard_streams():
def write(self, s):
pass
stream = DummyStream()
stream = cast(TextIO, DummyStream())
_py36_windowsconsoleio_workaround(stream)

View File

@ -634,13 +634,14 @@ class TestSession:
class Test_getinitialnodes:
def test_global_file(self, testdir, tmpdir):
def test_global_file(self, testdir, tmpdir) -> None:
x = tmpdir.ensure("x.py")
with tmpdir.as_cwd():
config = testdir.parseconfigure(x)
col = testdir.getnode(config, x)
assert isinstance(col, pytest.Module)
assert col.name == "x.py"
assert col.parent is not None
assert col.parent.parent is None
for col in col.listchain():
assert col.config is config

View File

@ -2,6 +2,9 @@ import os
import re
import sys
import textwrap
from typing import Dict
from typing import List
from typing import Sequence
import py.path
@ -264,9 +267,9 @@ class TestConfigCmdlineParsing:
class TestConfigAPI:
def test_config_trace(self, testdir):
def test_config_trace(self, testdir) -> None:
config = testdir.parseconfig()
values = []
values = [] # type: List[str]
config.trace.root.setwriter(values.append)
config.trace("hello")
assert len(values) == 1
@ -519,9 +522,9 @@ class TestConfigFromdictargs:
assert config.option.capture == "no"
assert config.args == args
def test_invocation_params_args(self, _sys_snapshot):
def test_invocation_params_args(self, _sys_snapshot) -> None:
"""Show that fromdictargs can handle args in their "orig" format"""
option_dict = {}
option_dict = {} # type: Dict[str, object]
args = ["-vvvv", "-s", "a", "b"]
config = Config.fromdictargs(option_dict, args)
@ -566,8 +569,8 @@ class TestConfigFromdictargs:
assert config.inicfg.get("should_not_be_set") is None
def test_options_on_small_file_do_not_blow_up(testdir):
def runfiletest(opts):
def test_options_on_small_file_do_not_blow_up(testdir) -> None:
def runfiletest(opts: Sequence[str]) -> None:
reprec = testdir.inline_run(*opts)
passed, skipped, failed = reprec.countoutcomes()
assert failed == 2
@ -580,19 +583,16 @@ def test_options_on_small_file_do_not_blow_up(testdir):
"""
)
for opts in (
[],
["-l"],
["-s"],
["--tb=no"],
["--tb=short"],
["--tb=long"],
["--fulltrace"],
["--traceconfig"],
["-v"],
["-v", "-v"],
):
runfiletest(opts + [path])
runfiletest([path])
runfiletest(["-l", path])
runfiletest(["-s", path])
runfiletest(["--tb=no", path])
runfiletest(["--tb=short", path])
runfiletest(["--tb=long", path])
runfiletest(["--fulltrace", path])
runfiletest(["--traceconfig", path])
runfiletest(["-v", path])
runfiletest(["-v", "-v", path])
def test_preparse_ordering_with_setuptools(testdir, monkeypatch):
@ -1360,7 +1360,7 @@ def test_invocation_args(testdir):
# args cannot be None
with pytest.raises(TypeError):
Config.InvocationParams(args=None, plugins=None, dir=Path())
Config.InvocationParams(args=None, plugins=None, dir=Path()) # type: ignore[arg-type] # noqa: F821
@pytest.mark.parametrize(

View File

@ -50,7 +50,7 @@ def custom_pdb_calls():
def interaction(self, *args):
called.append("interaction")
_pytest._CustomPdb = _CustomPdb
_pytest._CustomPdb = _CustomPdb # type: ignore
return called
@ -73,9 +73,9 @@ def custom_debugger_hook():
print("**CustomDebugger**")
called.append("set_trace")
_pytest._CustomDebugger = _CustomDebugger
_pytest._CustomDebugger = _CustomDebugger # type: ignore
yield called
del _pytest._CustomDebugger
del _pytest._CustomDebugger # type: ignore
class TestPDB:
@ -895,7 +895,7 @@ class TestDebuggingBreakpoints:
if sys.version_info >= (3, 7):
assert SUPPORTS_BREAKPOINT_BUILTIN is True
if sys.version_info.major == 3 and sys.version_info.minor == 5:
assert SUPPORTS_BREAKPOINT_BUILTIN is False
assert SUPPORTS_BREAKPOINT_BUILTIN is False # type: ignore[comparison-overlap]
@pytest.mark.skipif(
not SUPPORTS_BREAKPOINT_BUILTIN, reason="Requires breakpoint() builtin"

View File

@ -1,5 +1,7 @@
import inspect
import textwrap
from typing import Callable
from typing import Optional
import pytest
from _pytest.compat import MODULE_NOT_FOUND_ERROR
@ -1051,7 +1053,7 @@ class TestLiterals:
("1e3", "999"),
# The current implementation doesn't understand that numbers inside
# strings shouldn't be treated as numbers:
pytest.param("'3.1416'", "'3.14'", marks=pytest.mark.xfail),
pytest.param("'3.1416'", "'3.14'", marks=pytest.mark.xfail), # type: ignore
],
)
def test_number_non_matches(self, testdir, expression, output):
@ -1477,7 +1479,9 @@ class Broken:
@pytest.mark.parametrize( # pragma: no branch (lambdas are not called)
"stop", [None, _is_mocked, lambda f: None, lambda f: False, lambda f: True]
)
def test_warning_on_unwrap_of_broken_object(stop):
def test_warning_on_unwrap_of_broken_object(
stop: Optional[Callable[[object], object]]
) -> None:
bad_instance = Broken()
assert inspect.unwrap.__module__ == "inspect"
with _patch_unwrap_mock_aware():
@ -1486,7 +1490,7 @@ def test_warning_on_unwrap_of_broken_object(stop):
pytest.PytestWarning, match="^Got KeyError.* when unwrapping"
):
with pytest.raises(KeyError):
inspect.unwrap(bad_instance, stop=stop)
inspect.unwrap(bad_instance, stop=stop) # type: ignore[arg-type] # noqa: F821
assert inspect.unwrap.__module__ == "inspect"

View File

@ -1,16 +1,22 @@
import os
import platform
from datetime import datetime
from typing import cast
from typing import List
from typing import Tuple
from xml.dom import minidom
import py
import xmlschema
import pytest
from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.junitxml import bin_xml_escape
from _pytest.junitxml import LogXML
from _pytest.pathlib import Path
from _pytest.reports import BaseReport
from _pytest.reports import TestReport
from _pytest.store import Store
@ -860,10 +866,13 @@ def test_mangle_test_address():
assert newnames == ["a.my.py.thing", "Class", "method", "[a-1-::]"]
def test_dont_configure_on_slaves(tmpdir):
gotten = []
def test_dont_configure_on_slaves(tmpdir) -> None:
gotten = [] # type: List[object]
class FakeConfig:
if TYPE_CHECKING:
slaveinput = None
def __init__(self):
self.pluginmanager = self
self.option = self
@ -877,7 +886,7 @@ def test_dont_configure_on_slaves(tmpdir):
xmlpath = str(tmpdir.join("junix.xml"))
register = gotten.append
fake_config = FakeConfig()
fake_config = cast(Config, FakeConfig())
from _pytest import junitxml
junitxml.pytest_configure(fake_config)
@ -1089,18 +1098,18 @@ def test_double_colon_split_method_issue469(testdir, run_and_parse):
node.assert_attr(name="test_func[double::colon]")
def test_unicode_issue368(testdir):
def test_unicode_issue368(testdir) -> None:
path = testdir.tmpdir.join("test.xml")
log = LogXML(str(path), None)
ustr = "ВНИ!"
class Report(BaseReport):
longrepr = ustr
sections = []
sections = [] # type: List[Tuple[str, str]]
nodeid = "something"
location = "tests/filename.py", 42, "TestClass.method"
test_report = Report()
test_report = cast(TestReport, Report())
# hopefully this is not too brittle ...
log.pytest_sessionstart()
@ -1113,7 +1122,7 @@ def test_unicode_issue368(testdir):
node_reporter.append_skipped(test_report)
test_report.longrepr = "filename", 1, "Skipped: 卡嘣嘣"
node_reporter.append_skipped(test_report)
test_report.wasxfail = ustr
test_report.wasxfail = ustr # type: ignore[attr-defined] # noqa: F821
node_reporter.append_skipped(test_report)
log.pytest_sessionfinish()
@ -1363,17 +1372,17 @@ def test_fancy_items_regression(testdir, run_and_parse):
@parametrize_families
def test_global_properties(testdir, xunit_family):
def test_global_properties(testdir, xunit_family) -> None:
path = testdir.tmpdir.join("test_global_properties.xml")
log = LogXML(str(path), None, family=xunit_family)
class Report(BaseReport):
sections = []
sections = [] # type: List[Tuple[str, str]]
nodeid = "test_node_id"
log.pytest_sessionstart()
log.add_global_property("foo", 1)
log.add_global_property("bar", 2)
log.add_global_property("foo", "1")
log.add_global_property("bar", "2")
log.pytest_sessionfinish()
dom = minidom.parse(str(path))
@ -1397,19 +1406,19 @@ def test_global_properties(testdir, xunit_family):
assert actual == expected
def test_url_property(testdir):
def test_url_property(testdir) -> None:
test_url = "http://www.github.com/pytest-dev"
path = testdir.tmpdir.join("test_url_property.xml")
log = LogXML(str(path), None)
class Report(BaseReport):
longrepr = "FooBarBaz"
sections = []
sections = [] # type: List[Tuple[str, str]]
nodeid = "something"
location = "tests/filename.py", 42, "TestClass.method"
url = test_url
test_report = Report()
test_report = cast(TestReport, Report())
log.pytest_sessionstart()
node_reporter = log._opentestcase(test_report)

View File

@ -13,14 +13,14 @@ from _pytest.nodes import Node
class TestMark:
@pytest.mark.parametrize("attr", ["mark", "param"])
@pytest.mark.parametrize("modulename", ["py.test", "pytest"])
def test_pytest_exists_in_namespace_all(self, attr, modulename):
def test_pytest_exists_in_namespace_all(self, attr: str, modulename: str) -> None:
module = sys.modules[modulename]
assert attr in module.__all__
assert attr in module.__all__ # type: ignore
def test_pytest_mark_notcallable(self):
def test_pytest_mark_notcallable(self) -> None:
mark = Mark()
with pytest.raises(TypeError):
mark()
mark() # type: ignore[operator] # noqa: F821
def test_mark_with_param(self):
def some_function(abc):
@ -30,10 +30,11 @@ class TestMark:
pass
assert pytest.mark.foo(some_function) is some_function
assert pytest.mark.foo.with_args(some_function) is not some_function
marked_with_args = pytest.mark.foo.with_args(some_function)
assert marked_with_args is not some_function # type: ignore[comparison-overlap] # noqa: F821
assert pytest.mark.foo(SomeClass) is SomeClass
assert pytest.mark.foo.with_args(SomeClass) is not SomeClass
assert pytest.mark.foo.with_args(SomeClass) is not SomeClass # type: ignore[comparison-overlap] # noqa: F821
def test_pytest_mark_name_starts_with_underscore(self):
mark = Mark()
@ -1044,9 +1045,9 @@ def test_markers_from_parametrize(testdir):
result.assert_outcomes(passed=4)
def test_pytest_param_id_requires_string():
def test_pytest_param_id_requires_string() -> None:
with pytest.raises(TypeError) as excinfo:
pytest.param(id=True)
pytest.param(id=True) # type: ignore[arg-type] # noqa: F821
(msg,) = excinfo.value.args
assert msg == "Expected id to be a string, got <class 'bool'>: True"

View File

@ -2,6 +2,8 @@ import os
import re
import sys
import textwrap
from typing import Dict
from typing import Generator
import pytest
from _pytest.compat import TYPE_CHECKING
@ -12,7 +14,7 @@ if TYPE_CHECKING:
@pytest.fixture
def mp():
def mp() -> Generator[MonkeyPatch, None, None]:
cwd = os.getcwd()
sys_path = list(sys.path)
yield MonkeyPatch()
@ -20,14 +22,14 @@ def mp():
os.chdir(cwd)
def test_setattr():
def test_setattr() -> None:
class A:
x = 1
monkeypatch = MonkeyPatch()
pytest.raises(AttributeError, monkeypatch.setattr, A, "notexists", 2)
monkeypatch.setattr(A, "y", 2, raising=False)
assert A.y == 2
assert A.y == 2 # type: ignore
monkeypatch.undo()
assert not hasattr(A, "y")
@ -49,17 +51,17 @@ class TestSetattrWithImportPath:
monkeypatch.setattr("os.path.abspath", lambda x: "hello2")
assert os.path.abspath("123") == "hello2"
def test_string_expression_class(self, monkeypatch):
def test_string_expression_class(self, monkeypatch: MonkeyPatch) -> None:
monkeypatch.setattr("_pytest.config.Config", 42)
import _pytest
assert _pytest.config.Config == 42
assert _pytest.config.Config == 42 # type: ignore
def test_unicode_string(self, monkeypatch):
def test_unicode_string(self, monkeypatch: MonkeyPatch) -> None:
monkeypatch.setattr("_pytest.config.Config", 42)
import _pytest
assert _pytest.config.Config == 42
assert _pytest.config.Config == 42 # type: ignore
monkeypatch.delattr("_pytest.config.Config")
def test_wrong_target(self, monkeypatch):
@ -73,10 +75,10 @@ class TestSetattrWithImportPath:
AttributeError, lambda: monkeypatch.setattr("os.path.qweqwe", None)
)
def test_unknown_attr_non_raising(self, monkeypatch):
def test_unknown_attr_non_raising(self, monkeypatch: MonkeyPatch) -> None:
# https://github.com/pytest-dev/pytest/issues/746
monkeypatch.setattr("os.path.qweqwe", 42, raising=False)
assert os.path.qweqwe == 42
assert os.path.qweqwe == 42 # type: ignore
def test_delattr(self, monkeypatch):
monkeypatch.delattr("os.path.abspath")
@ -123,8 +125,8 @@ def test_setitem():
assert d["x"] == 5
def test_setitem_deleted_meanwhile():
d = {}
def test_setitem_deleted_meanwhile() -> None:
d = {} # type: Dict[str, object]
monkeypatch = MonkeyPatch()
monkeypatch.setitem(d, "x", 2)
del d["x"]
@ -148,8 +150,8 @@ def test_setenv_deleted_meanwhile(before):
assert key not in os.environ
def test_delitem():
d = {"x": 1}
def test_delitem() -> None:
d = {"x": 1} # type: Dict[str, object]
monkeypatch = MonkeyPatch()
monkeypatch.delitem(d, "x")
assert "x" not in d
@ -241,7 +243,7 @@ def test_monkeypatch_plugin(testdir):
assert tuple(res) == (1, 0, 0), res
def test_syspath_prepend(mp):
def test_syspath_prepend(mp: MonkeyPatch):
old = list(sys.path)
mp.syspath_prepend("world")
mp.syspath_prepend("hello")
@ -253,7 +255,7 @@ def test_syspath_prepend(mp):
assert sys.path == old
def test_syspath_prepend_double_undo(mp):
def test_syspath_prepend_double_undo(mp: MonkeyPatch):
old_syspath = sys.path[:]
try:
mp.syspath_prepend("hello world")
@ -265,24 +267,24 @@ def test_syspath_prepend_double_undo(mp):
sys.path[:] = old_syspath
def test_chdir_with_path_local(mp, tmpdir):
def test_chdir_with_path_local(mp: MonkeyPatch, tmpdir):
mp.chdir(tmpdir)
assert os.getcwd() == tmpdir.strpath
def test_chdir_with_str(mp, tmpdir):
def test_chdir_with_str(mp: MonkeyPatch, tmpdir):
mp.chdir(tmpdir.strpath)
assert os.getcwd() == tmpdir.strpath
def test_chdir_undo(mp, tmpdir):
def test_chdir_undo(mp: MonkeyPatch, tmpdir):
cwd = os.getcwd()
mp.chdir(tmpdir)
mp.undo()
assert os.getcwd() == cwd
def test_chdir_double_undo(mp, tmpdir):
def test_chdir_double_undo(mp: MonkeyPatch, tmpdir):
mp.chdir(tmpdir.strpath)
mp.undo()
tmpdir.chdir()

View File

@ -2,6 +2,7 @@ import py
import pytest
from _pytest import nodes
from _pytest.pytester import Testdir
@pytest.mark.parametrize(
@ -17,19 +18,19 @@ from _pytest import nodes
("foo/bar", "foo/bar::TestBop", True),
),
)
def test_ischildnode(baseid, nodeid, expected):
def test_ischildnode(baseid: str, nodeid: str, expected: bool) -> None:
result = nodes.ischildnode(baseid, nodeid)
assert result is expected
def test_node_from_parent_disallowed_arguments():
def test_node_from_parent_disallowed_arguments() -> None:
with pytest.raises(TypeError, match="session is"):
nodes.Node.from_parent(None, session=None)
nodes.Node.from_parent(None, session=None) # type: ignore[arg-type] # noqa: F821
with pytest.raises(TypeError, match="config is"):
nodes.Node.from_parent(None, config=None)
nodes.Node.from_parent(None, config=None) # type: ignore[arg-type] # noqa: F821
def test_std_warn_not_pytestwarning(testdir):
def test_std_warn_not_pytestwarning(testdir: Testdir) -> None:
items = testdir.getitems(
"""
def test():
@ -40,24 +41,24 @@ def test_std_warn_not_pytestwarning(testdir):
items[0].warn(UserWarning("some warning"))
def test__check_initialpaths_for_relpath():
def test__check_initialpaths_for_relpath() -> None:
"""Ensure that it handles dirs, and does not always use dirname."""
cwd = py.path.local()
class FakeSession:
class FakeSession1:
_initialpaths = [cwd]
assert nodes._check_initialpaths_for_relpath(FakeSession, cwd) == ""
assert nodes._check_initialpaths_for_relpath(FakeSession1, cwd) == ""
sub = cwd.join("file")
class FakeSession:
class FakeSession2:
_initialpaths = [cwd]
assert nodes._check_initialpaths_for_relpath(FakeSession, sub) == "file"
assert nodes._check_initialpaths_for_relpath(FakeSession2, sub) == "file"
outside = py.path.local("/outside")
assert nodes._check_initialpaths_for_relpath(FakeSession, outside) is None
assert nodes._check_initialpaths_for_relpath(FakeSession2, outside) is None
def test_failure_with_changed_cwd(testdir):

View File

@ -1,10 +1,13 @@
from typing import List
from typing import Union
import pytest
class TestPasteCapture:
@pytest.fixture
def pastebinlist(self, monkeypatch, request):
pastebinlist = []
def pastebinlist(self, monkeypatch, request) -> List[Union[str, bytes]]:
pastebinlist = [] # type: List[Union[str, bytes]]
plugin = request.config.pluginmanager.getplugin("pastebin")
monkeypatch.setattr(plugin, "create_new_paste", pastebinlist.append)
return pastebinlist

View File

@ -1,6 +1,7 @@
import os
import sys
import types
from typing import List
import pytest
from _pytest.config import ExitCode
@ -10,7 +11,7 @@ from _pytest.main import Session
@pytest.fixture
def pytestpm():
def pytestpm() -> PytestPluginManager:
return PytestPluginManager()
@ -86,7 +87,7 @@ class TestPytestPluginInteractions:
config.pluginmanager.register(A())
assert len(values) == 2
def test_hook_tracing(self, _config_for_test):
def test_hook_tracing(self, _config_for_test) -> None:
pytestpm = _config_for_test.pluginmanager # fully initialized with plugins
saveindent = []
@ -99,7 +100,7 @@ class TestPytestPluginInteractions:
saveindent.append(pytestpm.trace.root.indent)
raise ValueError()
values = []
values = [] # type: List[str]
pytestpm.trace.root.setwriter(values.append)
undo = pytestpm.enable_tracing()
try:
@ -215,20 +216,20 @@ class TestPytestPluginManager:
assert pm.get_plugin("pytest_xyz") == mod
assert pm.is_registered(mod)
def test_consider_module(self, testdir, pytestpm):
def test_consider_module(self, testdir, pytestpm: PytestPluginManager) -> None:
testdir.syspathinsert()
testdir.makepyfile(pytest_p1="#")
testdir.makepyfile(pytest_p2="#")
mod = types.ModuleType("temp")
mod.pytest_plugins = ["pytest_p1", "pytest_p2"]
mod.__dict__["pytest_plugins"] = ["pytest_p1", "pytest_p2"]
pytestpm.consider_module(mod)
assert pytestpm.get_plugin("pytest_p1").__name__ == "pytest_p1"
assert pytestpm.get_plugin("pytest_p2").__name__ == "pytest_p2"
def test_consider_module_import_module(self, testdir, _config_for_test):
def test_consider_module_import_module(self, testdir, _config_for_test) -> None:
pytestpm = _config_for_test.pluginmanager
mod = types.ModuleType("x")
mod.pytest_plugins = "pytest_a"
mod.__dict__["pytest_plugins"] = "pytest_a"
aplugin = testdir.makepyfile(pytest_a="#")
reprec = testdir.make_hook_recorder(pytestpm)
testdir.syspathinsert(aplugin.dirpath())

View File

@ -32,7 +32,7 @@ class TestReportSerialization:
assert test_b_call.outcome == "passed"
assert test_b_call._to_json()["longrepr"] is None
def test_xdist_report_longrepr_reprcrash_130(self, testdir):
def test_xdist_report_longrepr_reprcrash_130(self, testdir) -> None:
"""Regarding issue pytest-xdist#130
This test came originally from test_remote.py in xdist (ca03269).
@ -50,6 +50,7 @@ class TestReportSerialization:
rep.longrepr.sections.append(added_section)
d = rep._to_json()
a = TestReport._from_json(d)
assert a.longrepr is not None
# Check assembled == rep
assert a.__dict__.keys() == rep.__dict__.keys()
for key in rep.__dict__.keys():
@ -67,7 +68,7 @@ class TestReportSerialization:
# Missing section attribute PR171
assert added_section in a.longrepr.sections
def test_reprentries_serialization_170(self, testdir):
def test_reprentries_serialization_170(self, testdir) -> None:
"""Regarding issue pytest-xdist#170
This test came originally from test_remote.py in xdist (ca03269).
@ -87,6 +88,7 @@ class TestReportSerialization:
rep = reports[1]
d = rep._to_json()
a = TestReport._from_json(d)
assert a.longrepr is not None
rep_entries = rep.longrepr.reprtraceback.reprentries
a_entries = a.longrepr.reprtraceback.reprentries
@ -102,7 +104,7 @@ class TestReportSerialization:
assert rep_entries[i].reprlocals.lines == a_entries[i].reprlocals.lines
assert rep_entries[i].style == a_entries[i].style
def test_reprentries_serialization_196(self, testdir):
def test_reprentries_serialization_196(self, testdir) -> None:
"""Regarding issue pytest-xdist#196
This test came originally from test_remote.py in xdist (ca03269).
@ -122,6 +124,7 @@ class TestReportSerialization:
rep = reports[1]
d = rep._to_json()
a = TestReport._from_json(d)
assert a.longrepr is not None
rep_entries = rep.longrepr.reprtraceback.reprentries
a_entries = a.longrepr.reprtraceback.reprentries
@ -157,6 +160,7 @@ class TestReportSerialization:
assert newrep.failed == rep.failed
assert newrep.skipped == rep.skipped
if newrep.skipped and not hasattr(newrep, "wasxfail"):
assert newrep.longrepr is not None
assert len(newrep.longrepr) == 3
assert newrep.outcome == rep.outcome
assert newrep.when == rep.when
@ -316,7 +320,7 @@ class TestReportSerialization:
# elsewhere and we do check the contents of the longrepr object after loading it.
loaded_report.longrepr.toterminal(tw_mock)
def test_chained_exceptions_no_reprcrash(self, testdir, tw_mock):
def test_chained_exceptions_no_reprcrash(self, testdir, tw_mock) -> None:
"""Regression test for tracebacks without a reprcrash (#5971)
This happens notably on exceptions raised by multiprocess.pool: the exception transfer
@ -367,7 +371,7 @@ class TestReportSerialization:
reports = reprec.getreports("pytest_runtest_logreport")
def check_longrepr(longrepr):
def check_longrepr(longrepr) -> None:
assert isinstance(longrepr, ExceptionChainRepr)
assert len(longrepr.chain) == 2
entry1, entry2 = longrepr.chain
@ -378,6 +382,7 @@ class TestReportSerialization:
assert "ValueError: value error" in str(tb2)
assert fileloc1 is None
assert fileloc2 is not None
assert fileloc2.message == "ValueError: value error"
# 3 reports: setup/call/teardown: get the call report
@ -394,6 +399,7 @@ class TestReportSerialization:
check_longrepr(loaded_report.longrepr)
# for same reasons as previous test, ensure we don't blow up here
assert loaded_report.longrepr is not None
loaded_report.longrepr.toterminal(tw_mock)
def test_report_prevent_ConftestImportFailure_hiding_exception(self, testdir):

View File

@ -465,27 +465,27 @@ def test_report_extra_parameters(reporttype: "Type[reports.BaseReport]") -> None
def test_callinfo() -> None:
ci = runner.CallInfo.from_call(lambda: 0, "123")
assert ci.when == "123"
ci = runner.CallInfo.from_call(lambda: 0, "collect")
assert ci.when == "collect"
assert ci.result == 0
assert "result" in repr(ci)
assert repr(ci) == "<CallInfo when='123' result: 0>"
assert str(ci) == "<CallInfo when='123' result: 0>"
assert repr(ci) == "<CallInfo when='collect' result: 0>"
assert str(ci) == "<CallInfo when='collect' result: 0>"
ci = runner.CallInfo.from_call(lambda: 0 / 0, "123")
assert ci.when == "123"
assert not hasattr(ci, "result")
assert repr(ci) == "<CallInfo when='123' excinfo={!r}>".format(ci.excinfo)
assert str(ci) == repr(ci)
assert ci.excinfo
ci2 = runner.CallInfo.from_call(lambda: 0 / 0, "collect")
assert ci2.when == "collect"
assert not hasattr(ci2, "result")
assert repr(ci2) == "<CallInfo when='collect' excinfo={!r}>".format(ci2.excinfo)
assert str(ci2) == repr(ci2)
assert ci2.excinfo
# Newlines are escaped.
def raise_assertion():
assert 0, "assert_msg"
ci = runner.CallInfo.from_call(raise_assertion, "call")
assert repr(ci) == "<CallInfo when='call' excinfo={!r}>".format(ci.excinfo)
assert "\n" not in repr(ci)
ci3 = runner.CallInfo.from_call(raise_assertion, "call")
assert repr(ci3) == "<CallInfo when='call' excinfo={!r}>".format(ci3.excinfo)
assert "\n" not in repr(ci3)
# design question: do we want general hooks in python files?
@ -884,7 +884,7 @@ def test_store_except_info_on_error() -> None:
raise IndexError("TEST")
try:
runner.pytest_runtest_call(ItemMightRaise())
runner.pytest_runtest_call(ItemMightRaise()) # type: ignore[arg-type] # noqa: F821
except IndexError:
pass
# Check that exception info is stored on sys
@ -895,7 +895,7 @@ def test_store_except_info_on_error() -> None:
# The next run should clear the exception info stored by the previous run
ItemMightRaise.raise_error = False
runner.pytest_runtest_call(ItemMightRaise())
runner.pytest_runtest_call(ItemMightRaise()) # type: ignore[arg-type] # noqa: F821
assert not hasattr(sys, "last_type")
assert not hasattr(sys, "last_value")
assert not hasattr(sys, "last_traceback")

View File

@ -2,6 +2,8 @@
test correct setup/teardowns at
module, class, and instance level
"""
from typing import List
import pytest
@ -242,12 +244,12 @@ def test_setup_funcarg_setup_when_outer_scope_fails(testdir):
@pytest.mark.parametrize("arg", ["", "arg"])
def test_setup_teardown_function_level_with_optional_argument(
testdir, monkeypatch, arg
):
testdir, monkeypatch, arg: str,
) -> None:
"""parameter to setup/teardown xunit-style functions parameter is now optional (#1728)."""
import sys
trace_setups_teardowns = []
trace_setups_teardowns = [] # type: List[str]
monkeypatch.setattr(
sys, "trace_setups_teardowns", trace_setups_teardowns, raising=False
)

View File

@ -98,7 +98,7 @@ class TestEvaluator:
expl = ev.getexplanation()
assert expl == "condition: not hasattr(os, 'murks')"
def test_marked_skip_with_not_string(self, testdir):
def test_marked_skip_with_not_string(self, testdir) -> None:
item = testdir.getitem(
"""
import pytest
@ -109,6 +109,7 @@ class TestEvaluator:
)
ev = MarkEvaluator(item, "skipif")
exc = pytest.raises(pytest.fail.Exception, ev.istrue)
assert exc.value.msg is not None
assert (
"""Failed: you need to specify reason=STRING when using booleans as conditions."""
in exc.value.msg
@ -869,7 +870,7 @@ def test_reportchars_all_error(testdir):
result.stdout.fnmatch_lines(["ERROR*test_foo*"])
def test_errors_in_xfail_skip_expressions(testdir):
def test_errors_in_xfail_skip_expressions(testdir) -> None:
testdir.makepyfile(
"""
import pytest
@ -886,7 +887,8 @@ def test_errors_in_xfail_skip_expressions(testdir):
)
result = testdir.runpytest()
markline = " ^"
if hasattr(sys, "pypy_version_info") and sys.pypy_version_info < (6,):
pypy_version_info = getattr(sys, "pypy_version_info", None)
if pypy_version_info is not None and pypy_version_info < (6,):
markline = markline[5:]
elif sys.version_info >= (3, 8) or hasattr(sys, "pypy_version_info"):
markline = markline[4:]

View File

@ -6,6 +6,7 @@ import os
import sys
import textwrap
from io import StringIO
from typing import cast
from typing import Dict
from typing import List
from typing import Tuple
@ -17,9 +18,11 @@ import _pytest.config
import _pytest.terminal
import pytest
from _pytest._io.wcwidth import wcswidth
from _pytest.config import Config
from _pytest.config import ExitCode
from _pytest.pytester import Testdir
from _pytest.reports import BaseReport
from _pytest.reports import CollectReport
from _pytest.terminal import _folded_skips
from _pytest.terminal import _get_line_with_reprcrash_message
from _pytest.terminal import _plugin_nameversions
@ -1043,17 +1046,17 @@ def test_color_yes_collection_on_non_atty(testdir, verbose):
assert "collected 10 items" in result.stdout.str()
def test_getreportopt():
def test_getreportopt() -> None:
from _pytest.terminal import _REPORTCHARS_DEFAULT
class Config:
class FakeConfig:
class Option:
reportchars = _REPORTCHARS_DEFAULT
disable_warnings = False
option = Option()
config = Config()
config = cast(Config, FakeConfig())
assert _REPORTCHARS_DEFAULT == "fE"
@ -1994,7 +1997,7 @@ class TestProgressWithTeardown:
output.stdout.re_match_lines([r"[\.E]{40} \s+ \[100%\]"])
def test_skip_reasons_folding():
def test_skip_reasons_folding() -> None:
path = "xyz"
lineno = 3
message = "justso"
@ -2003,28 +2006,28 @@ def test_skip_reasons_folding():
class X:
pass
ev1 = X()
ev1 = cast(CollectReport, X())
ev1.when = "execute"
ev1.skipped = True
ev1.longrepr = longrepr
ev2 = X()
ev2 = cast(CollectReport, X())
ev2.when = "execute"
ev2.longrepr = longrepr
ev2.skipped = True
# ev3 might be a collection report
ev3 = X()
ev3 = cast(CollectReport, X())
ev3.when = "collect"
ev3.longrepr = longrepr
ev3.skipped = True
values = _folded_skips(py.path.local(), [ev1, ev2, ev3])
assert len(values) == 1
num, fspath, lineno, reason = values[0]
num, fspath, lineno_, reason = values[0]
assert num == 3
assert fspath == path
assert lineno == lineno
assert lineno_ == lineno
assert reason == message
@ -2052,8 +2055,8 @@ def test_line_with_reprcrash(monkeypatch):
def check(msg, width, expected):
__tracebackhide__ = True
if msg:
rep.longrepr.reprcrash.message = msg
actual = _get_line_with_reprcrash_message(config, rep(), width)
rep.longrepr.reprcrash.message = msg # type: ignore
actual = _get_line_with_reprcrash_message(config, rep(), width) # type: ignore
assert actual == expected
if actual != "{} {}".format(mocked_verbose_word, mocked_pos):

View File

@ -1,6 +1,8 @@
import os
import stat
import sys
from typing import Callable
from typing import List
import attr
@ -263,10 +265,10 @@ class TestNumberedDir:
lockfile.unlink()
def test_lock_register_cleanup_removal(self, tmp_path):
def test_lock_register_cleanup_removal(self, tmp_path: Path) -> None:
lock = create_cleanup_lock(tmp_path)
registry = []
registry = [] # type: List[Callable[..., None]]
register_cleanup_lock_removal(lock, register=registry.append)
(cleanup_func,) = registry
@ -285,7 +287,7 @@ class TestNumberedDir:
assert not lock.exists()
def _do_cleanup(self, tmp_path):
def _do_cleanup(self, tmp_path: Path) -> None:
self.test_make(tmp_path)
cleanup_numbered_dir(
root=tmp_path,
@ -367,7 +369,7 @@ class TestRmRf:
assert not adir.is_dir()
def test_on_rm_rf_error(self, tmp_path):
def test_on_rm_rf_error(self, tmp_path: Path) -> None:
adir = tmp_path / "dir"
adir.mkdir()
@ -377,32 +379,32 @@ class TestRmRf:
# unknown exception
with pytest.warns(pytest.PytestWarning):
exc_info = (None, RuntimeError(), None)
on_rm_rf_error(os.unlink, str(fn), exc_info, start_path=tmp_path)
exc_info1 = (None, RuntimeError(), None)
on_rm_rf_error(os.unlink, str(fn), exc_info1, start_path=tmp_path)
assert fn.is_file()
# we ignore FileNotFoundError
exc_info = (None, FileNotFoundError(), None)
assert not on_rm_rf_error(None, str(fn), exc_info, start_path=tmp_path)
exc_info2 = (None, FileNotFoundError(), None)
assert not on_rm_rf_error(None, str(fn), exc_info2, start_path=tmp_path)
# unknown function
with pytest.warns(
pytest.PytestWarning,
match=r"^\(rm_rf\) unknown function None when removing .*foo.txt:\nNone: ",
):
exc_info = (None, PermissionError(), None)
on_rm_rf_error(None, str(fn), exc_info, start_path=tmp_path)
exc_info3 = (None, PermissionError(), None)
on_rm_rf_error(None, str(fn), exc_info3, start_path=tmp_path)
assert fn.is_file()
# ignored function
with pytest.warns(None) as warninfo:
exc_info = (None, PermissionError(), None)
on_rm_rf_error(os.open, str(fn), exc_info, start_path=tmp_path)
exc_info4 = (None, PermissionError(), None)
on_rm_rf_error(os.open, str(fn), exc_info4, start_path=tmp_path)
assert fn.is_file()
assert not [x.message for x in warninfo]
exc_info = (None, PermissionError(), None)
on_rm_rf_error(os.unlink, str(fn), exc_info, start_path=tmp_path)
exc_info5 = (None, PermissionError(), None)
on_rm_rf_error(os.unlink, str(fn), exc_info5, start_path=tmp_path)
assert not fn.is_file()

View File

@ -1,4 +1,5 @@
import gc
from typing import List
import pytest
from _pytest.config import ExitCode
@ -1158,13 +1159,13 @@ def test_trace(testdir, monkeypatch):
assert result.ret == 0
def test_pdb_teardown_called(testdir, monkeypatch):
def test_pdb_teardown_called(testdir, monkeypatch) -> None:
"""Ensure tearDown() is always called when --pdb is given in the command-line.
We delay the normal tearDown() calls when --pdb is given, so this ensures we are calling
tearDown() eventually to avoid memory leaks when using --pdb.
"""
teardowns = []
teardowns = [] # type: List[str]
monkeypatch.setattr(
pytest, "test_pdb_teardown_called_teardowns", teardowns, raising=False
)
@ -1194,11 +1195,11 @@ def test_pdb_teardown_called(testdir, monkeypatch):
@pytest.mark.parametrize("mark", ["@unittest.skip", "@pytest.mark.skip"])
def test_pdb_teardown_skipped(testdir, monkeypatch, mark):
def test_pdb_teardown_skipped(testdir, monkeypatch, mark: str) -> None:
"""
With --pdb, setUp and tearDown should not be called for skipped tests.
"""
tracked = []
tracked = [] # type: List[str]
monkeypatch.setattr(pytest, "test_pdb_teardown_skipped", tracked, raising=False)
testdir.makepyfile(

View File

@ -1,5 +1,8 @@
import os
import warnings
from typing import List
from typing import Optional
from typing import Tuple
import pytest
from _pytest.fixtures import FixtureRequest
@ -661,7 +664,9 @@ class TestStackLevel:
@pytest.fixture
def capwarn(self, testdir):
class CapturedWarnings:
captured = []
captured = (
[]
) # type: List[Tuple[warnings.WarningMessage, Optional[Tuple[str, int, str]]]]
@classmethod
def pytest_warning_recorded(cls, warning_message, when, nodeid, location):