Type annotate main.py and some parts related to collection

This commit is contained in:
Ran Benita
2020-05-01 14:40:15 +03:00
parent f8de424241
commit be00e12d47
9 changed files with 175 additions and 75 deletions

View File

@@ -7,9 +7,11 @@ import sys
from typing import Callable
from typing import Dict
from typing import FrozenSet
from typing import Iterator
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Union
@@ -18,12 +20,14 @@ import py
import _pytest._code
from _pytest import nodes
from _pytest.compat import overload
from _pytest.compat import TYPE_CHECKING
from _pytest.config import Config
from _pytest.config import directory_arg
from _pytest.config import ExitCode
from _pytest.config import hookimpl
from _pytest.config import UsageError
from _pytest.config.argparsing import Parser
from _pytest.fixtures import FixtureManager
from _pytest.outcomes import exit
from _pytest.reports import CollectReport
@@ -38,7 +42,7 @@ if TYPE_CHECKING:
from _pytest.python import Package
def pytest_addoption(parser):
def pytest_addoption(parser: Parser) -> None:
parser.addini(
"norecursedirs",
"directory patterns to avoid for recursion",
@@ -241,7 +245,7 @@ def wrap_session(
return session.exitstatus
def pytest_cmdline_main(config):
def pytest_cmdline_main(config: Config) -> Union[int, ExitCode]:
return wrap_session(config, _main)
@@ -258,11 +262,11 @@ def _main(config: Config, session: "Session") -> Optional[Union[int, ExitCode]]:
return None
def pytest_collection(session):
def pytest_collection(session: "Session") -> Sequence[nodes.Item]:
return session.perform_collect()
def pytest_runtestloop(session):
def pytest_runtestloop(session: "Session") -> bool:
if session.testsfailed and not session.config.option.continue_on_collection_errors:
raise session.Interrupted(
"%d error%s during collection"
@@ -282,7 +286,7 @@ def pytest_runtestloop(session):
return True
def _in_venv(path):
def _in_venv(path: py.path.local) -> bool:
"""Attempts to detect if ``path`` is the root of a Virtual Environment by
checking for the existence of the appropriate activate script"""
bindir = path.join("Scripts" if sys.platform.startswith("win") else "bin")
@@ -328,7 +332,7 @@ def pytest_ignore_collect(
return None
def pytest_collection_modifyitems(items, config):
def pytest_collection_modifyitems(items, config: Config) -> None:
deselect_prefixes = tuple(config.getoption("deselect") or [])
if not deselect_prefixes:
return
@@ -385,8 +389,8 @@ class Session(nodes.FSCollector):
)
self.testsfailed = 0
self.testscollected = 0
self.shouldstop = False
self.shouldfail = False
self.shouldstop = False # type: Union[bool, str]
self.shouldfail = False # type: Union[bool, str]
self.trace = config.trace.root.get("collection")
self.startdir = config.invocation_dir
self._initialpaths = frozenset() # type: FrozenSet[py.path.local]
@@ -412,10 +416,11 @@ class Session(nodes.FSCollector):
self.config.pluginmanager.register(self, name="session")
@classmethod
def from_config(cls, config):
return cls._create(config)
def from_config(cls, config: Config) -> "Session":
session = cls._create(config) # type: Session
return session
def __repr__(self):
def __repr__(self) -> str:
return "<%s %s exitstatus=%r testsfailed=%d testscollected=%d>" % (
self.__class__.__name__,
self.name,
@@ -429,14 +434,14 @@ class Session(nodes.FSCollector):
return self._bestrelpathcache[node_path]
@hookimpl(tryfirst=True)
def pytest_collectstart(self):
def pytest_collectstart(self) -> None:
if self.shouldfail:
raise self.Failed(self.shouldfail)
if self.shouldstop:
raise self.Interrupted(self.shouldstop)
@hookimpl(tryfirst=True)
def pytest_runtest_logreport(self, report):
def pytest_runtest_logreport(self, report) -> None:
if report.failed and not hasattr(report, "wasxfail"):
self.testsfailed += 1
maxfail = self.config.getvalue("maxfail")
@@ -445,13 +450,27 @@ class Session(nodes.FSCollector):
pytest_collectreport = pytest_runtest_logreport
def isinitpath(self, path):
def isinitpath(self, path: py.path.local) -> bool:
return path in self._initialpaths
def gethookproxy(self, fspath: py.path.local):
return super()._gethookproxy(fspath)
def perform_collect(self, args=None, genitems=True):
@overload
def perform_collect(
self, args: Optional[Sequence[str]] = ..., genitems: "Literal[True]" = ...
) -> Sequence[nodes.Item]:
raise NotImplementedError()
@overload # noqa: F811
def perform_collect( # noqa: F811
self, args: Optional[Sequence[str]] = ..., genitems: bool = ...
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
raise NotImplementedError()
def perform_collect( # noqa: F811
self, args: Optional[Sequence[str]] = None, genitems: bool = True
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
hook = self.config.hook
try:
items = self._perform_collect(args, genitems)
@@ -464,15 +483,29 @@ class Session(nodes.FSCollector):
self.testscollected = len(items)
return items
def _perform_collect(self, args, genitems):
@overload
def _perform_collect(
self, args: Optional[Sequence[str]], genitems: "Literal[True]"
) -> Sequence[nodes.Item]:
raise NotImplementedError()
@overload # noqa: F811
def _perform_collect( # noqa: F811
self, args: Optional[Sequence[str]], genitems: bool
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
raise NotImplementedError()
def _perform_collect( # noqa: F811
self, args: Optional[Sequence[str]], genitems: bool
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
if args is None:
args = self.config.args
self.trace("perform_collect", self, args)
self.trace.root.indent += 1
self._notfound = []
self._notfound = [] # type: List[Tuple[str, NoMatch]]
initialpaths = [] # type: List[py.path.local]
self._initial_parts = [] # type: List[Tuple[py.path.local, List[str]]]
self.items = items = []
self.items = items = [] # type: List[nodes.Item]
for arg in args:
fspath, parts = self._parsearg(arg)
self._initial_parts.append((fspath, parts))
@@ -495,7 +528,7 @@ class Session(nodes.FSCollector):
self.items.extend(self.genitems(node))
return items
def collect(self):
def collect(self) -> Iterator[Union[nodes.Item, nodes.Collector]]:
for fspath, parts in self._initial_parts:
self.trace("processing argument", (fspath, parts))
self.trace.root.indent += 1
@@ -513,7 +546,9 @@ class Session(nodes.FSCollector):
self._collection_node_cache3.clear()
self._collection_pkg_roots.clear()
def _collect(self, argpath, names):
def _collect(
self, argpath: py.path.local, names: List[str]
) -> Iterator[Union[nodes.Item, nodes.Collector]]:
from _pytest.python import Package
# Start with a Session root, and delve to argpath item (dir or file)
@@ -541,7 +576,7 @@ class Session(nodes.FSCollector):
if argpath.check(dir=1):
assert not names, "invalid arg {!r}".format((argpath, names))
seen_dirs = set()
seen_dirs = set() # type: Set[py.path.local]
for path in argpath.visit(
fil=self._visit_filter, rec=self._recurse, bf=True, sort=True
):
@@ -582,8 +617,9 @@ class Session(nodes.FSCollector):
# Module itself, so just use that. If this special case isn't taken, then all
# the files in the package will be yielded.
if argpath.basename == "__init__.py":
assert isinstance(m[0], nodes.Collector)
try:
yield next(m[0].collect())
yield next(iter(m[0].collect()))
except StopIteration:
# The package collects nothing with only an __init__.py
# file in it, which gets ignored by the default
@@ -593,10 +629,11 @@ class Session(nodes.FSCollector):
yield from m
@staticmethod
def _visit_filter(f):
return f.check(file=1)
def _visit_filter(f: py.path.local) -> bool:
# TODO: Remove type: ignore once `py` is typed.
return f.check(file=1) # type: ignore
def _tryconvertpyarg(self, x):
def _tryconvertpyarg(self, x: str) -> str:
"""Convert a dotted module name to path."""
try:
spec = importlib.util.find_spec(x)
@@ -605,14 +642,14 @@ class Session(nodes.FSCollector):
# ValueError: not a module name
except (AttributeError, ImportError, ValueError):
return x
if spec is None or spec.origin in {None, "namespace"}:
if spec is None or spec.origin is None or spec.origin == "namespace":
return x
elif spec.submodule_search_locations:
return os.path.dirname(spec.origin)
else:
return spec.origin
def _parsearg(self, arg):
def _parsearg(self, arg: str) -> Tuple[py.path.local, List[str]]:
""" return (fspath, names) tuple after checking the file exists. """
strpath, *parts = str(arg).split("::")
if self.config.option.pyargs:
@@ -628,7 +665,9 @@ class Session(nodes.FSCollector):
fspath = fspath.realpath()
return (fspath, parts)
def matchnodes(self, matching, names):
def matchnodes(
self, matching: Sequence[Union[nodes.Item, nodes.Collector]], names: List[str],
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
self.trace("matchnodes", matching, names)
self.trace.root.indent += 1
nodes = self._matchnodes(matching, names)
@@ -639,13 +678,15 @@ class Session(nodes.FSCollector):
raise NoMatch(matching, names[:1])
return nodes
def _matchnodes(self, matching, names):
def _matchnodes(
self, matching: Sequence[Union[nodes.Item, nodes.Collector]], names: List[str],
) -> Sequence[Union[nodes.Item, nodes.Collector]]:
if not matching or not names:
return matching
name = names[0]
assert name
nextnames = names[1:]
resultnodes = []
resultnodes = [] # type: List[Union[nodes.Item, nodes.Collector]]
for node in matching:
if isinstance(node, nodes.Item):
if not names:
@@ -676,7 +717,9 @@ class Session(nodes.FSCollector):
node.ihook.pytest_collectreport(report=rep)
return resultnodes
def genitems(self, node):
def genitems(
self, node: Union[nodes.Item, nodes.Collector]
) -> Iterator[nodes.Item]:
self.trace("genitems", node)
if isinstance(node, nodes.Item):
node.ihook.pytest_itemcollected(item=node)