diff --git a/src/_pytest/mark/__init__.py b/src/_pytest/mark/__init__.py index 16e821aee..7bbea54d2 100644 --- a/src/_pytest/mark/__init__.py +++ b/src/_pytest/mark/__init__.py @@ -147,9 +147,9 @@ class KeywordMatcher: # Add the names of the current item and any parent items import pytest - for item in item.listchain(): - if not isinstance(item, (pytest.Instance, pytest.Session)): - mapped_names.add(item.name) + for node in item.listchain(): + if not isinstance(node, (pytest.Instance, pytest.Session)): + mapped_names.add(node.name) # Add the names added as extra keywords to current or parent items mapped_names.update(item.listextrakeywords()) diff --git a/src/_pytest/nodes.py b/src/_pytest/nodes.py index 010dce925..4fdf1df74 100644 --- a/src/_pytest/nodes.py +++ b/src/_pytest/nodes.py @@ -5,11 +5,13 @@ from typing import Any from typing import Callable from typing import Dict from typing import Iterable +from typing import Iterator from typing import List from typing import Optional from typing import Sequence from typing import Set from typing import Tuple +from typing import TypeVar from typing import Union import py @@ -20,6 +22,7 @@ from _pytest._code.code import ExceptionChainRepr from _pytest._code.code import ExceptionInfo from _pytest._code.code import ReprExceptionInfo from _pytest.compat import cached_property +from _pytest.compat import overload from _pytest.compat import TYPE_CHECKING from _pytest.config import Config from _pytest.config import ConftestImportFailure @@ -36,6 +39,8 @@ from _pytest.pathlib import Path from _pytest.store import Store if TYPE_CHECKING: + from typing import Type + # Imported here due to circular import. from _pytest.main import Session @@ -45,7 +50,7 @@ tracebackcutdir = py.path.local(_pytest.__file__).dirpath() @lru_cache(maxsize=None) -def _splitnode(nodeid): +def _splitnode(nodeid: str) -> Tuple[str, ...]: """Split a nodeid into constituent 'parts'. Node IDs are strings, and can be things like: @@ -70,7 +75,7 @@ def _splitnode(nodeid): return tuple(parts) -def ischildnode(baseid, nodeid): +def ischildnode(baseid: str, nodeid: str) -> bool: """Return True if the nodeid is a child node of the baseid. E.g. 'foo/bar::Baz' is a child of 'foo', 'foo/bar' and 'foo/bar::Baz', but not of 'foo/blorp' @@ -82,6 +87,9 @@ def ischildnode(baseid, nodeid): return node_parts[: len(base_parts)] == base_parts +_NodeType = TypeVar("_NodeType", bound="Node") + + class NodeMeta(type): def __call__(self, *k, **kw): warnings.warn(NODE_USE_FROM_PARENT.format(name=self.__name__), stacklevel=2) @@ -191,7 +199,7 @@ class Node(metaclass=NodeMeta): """ fspath sensitive hook proxy used to call pytest hooks""" return self.session.gethookproxy(self.fspath) - def __repr__(self): + def __repr__(self) -> str: return "<{} {}>".format(self.__class__.__name__, getattr(self, "name", None)) def warn(self, warning): @@ -232,16 +240,16 @@ class Node(metaclass=NodeMeta): """ a ::-separated string denoting its collection tree address. """ return self._nodeid - def __hash__(self): + def __hash__(self) -> int: return hash(self._nodeid) - def setup(self): + def setup(self) -> None: pass - def teardown(self): + def teardown(self) -> None: pass - def listchain(self): + def listchain(self) -> List["Node"]: """ return list of all parent collectors up to self, starting from root of collection tree. """ chain = [] @@ -276,7 +284,7 @@ class Node(metaclass=NodeMeta): else: self.own_markers.insert(0, marker_.mark) - def iter_markers(self, name=None): + def iter_markers(self, name: Optional[str] = None) -> Iterator[Mark]: """ :param name: if given, filter the results by the name attribute @@ -284,7 +292,9 @@ class Node(metaclass=NodeMeta): """ return (x[1] for x in self.iter_markers_with_node(name=name)) - def iter_markers_with_node(self, name=None): + def iter_markers_with_node( + self, name: Optional[str] = None + ) -> Iterator[Tuple["Node", Mark]]: """ :param name: if given, filter the results by the name attribute @@ -296,7 +306,17 @@ class Node(metaclass=NodeMeta): if name is None or getattr(mark, "name", None) == name: yield node, mark - def get_closest_marker(self, name, default=None): + @overload + def get_closest_marker(self, name: str) -> Optional[Mark]: + raise NotImplementedError() + + @overload # noqa: F811 + def get_closest_marker(self, name: str, default: Mark) -> Mark: # noqa: F811 + raise NotImplementedError() + + def get_closest_marker( # noqa: F811 + self, name: str, default: Optional[Mark] = None + ) -> Optional[Mark]: """return the first marker matching the name, from closest (for example function) to farther level (for example module level). @@ -305,14 +325,14 @@ class Node(metaclass=NodeMeta): """ return next(self.iter_markers(name=name), default) - def listextrakeywords(self): + def listextrakeywords(self) -> Set[str]: """ Return a set of all extra keywords in self and any parents.""" extra_keywords = set() # type: Set[str] for item in self.listchain(): extra_keywords.update(item.extra_keyword_matches) return extra_keywords - def listnames(self): + def listnames(self) -> List[str]: return [x.name for x in self.listchain()] def addfinalizer(self, fin: Callable[[], object]) -> None: @@ -323,12 +343,13 @@ class Node(metaclass=NodeMeta): """ self.session._setupstate.addfinalizer(fin, self) - def getparent(self, cls): + def getparent(self, cls: "Type[_NodeType]") -> Optional[_NodeType]: """ get the next parent node (including ourself) which is an instance of the given class""" current = self # type: Optional[Node] while current and not isinstance(current, cls): current = current.parent + assert current is None or isinstance(current, cls) return current def _prunetraceback(self, excinfo): @@ -479,7 +500,12 @@ class FSHookProxy: class FSCollector(Collector): def __init__( - self, fspath: py.path.local, parent=None, config=None, session=None, nodeid=None + self, + fspath: py.path.local, + parent=None, + config: Optional[Config] = None, + session: Optional["Session"] = None, + nodeid: Optional[str] = None, ) -> None: name = fspath.basename if parent is not None: @@ -579,7 +605,14 @@ class Item(Node): nextitem = None - def __init__(self, name, parent=None, config=None, session=None, nodeid=None): + def __init__( + self, + name, + parent=None, + config: Optional[Config] = None, + session: Optional["Session"] = None, + nodeid: Optional[str] = None, + ) -> None: super().__init__(name, parent, config, session, nodeid=nodeid) self._report_sections = [] # type: List[Tuple[str, str, str]] diff --git a/src/_pytest/python.py b/src/_pytest/python.py index 9b8dcf608..55ed2b164 100644 --- a/src/_pytest/python.py +++ b/src/_pytest/python.py @@ -423,7 +423,9 @@ class PyCollector(PyobjMixin, nodes.Collector): return item def _genfunctions(self, name, funcobj): - module = self.getparent(Module).obj + modulecol = self.getparent(Module) + assert modulecol is not None + module = modulecol.obj clscol = self.getparent(Class) cls = clscol and clscol.obj or None fm = self.session._fixturemanager @@ -437,7 +439,7 @@ class PyCollector(PyobjMixin, nodes.Collector): methods = [] if hasattr(module, "pytest_generate_tests"): methods.append(module.pytest_generate_tests) - if hasattr(cls, "pytest_generate_tests"): + if cls is not None and hasattr(cls, "pytest_generate_tests"): methods.append(cls().pytest_generate_tests) self.ihook.pytest_generate_tests.call_extra(methods, dict(metafunc=metafunc))