diff --git a/src/_pytest/config/__init__.py b/src/_pytest/config/__init__.py index 3516b333e..ed3334e5f 100644 --- a/src/_pytest/config/__init__.py +++ b/src/_pytest/config/__init__.py @@ -48,6 +48,13 @@ if TYPE_CHECKING: from typing import Type +_PluggyPlugin = object +"""A type to represent plugin objects. +Plugins can be any namespace, so we can't narrow it down much, but we use an +alias to make the intent clear. +Ideally this type would be provided by pluggy itself.""" + + hookimpl = HookimplMarker("pytest") hookspec = HookspecMarker("pytest") diff --git a/src/_pytest/pytester.py b/src/_pytest/pytester.py index 8e7fa5e09..cfe1b9a6c 100644 --- a/src/_pytest/pytester.py +++ b/src/_pytest/pytester.py @@ -29,12 +29,17 @@ from _pytest._io.saferepr import saferepr from _pytest.capture import MultiCapture from _pytest.capture import SysCapture from _pytest.compat import TYPE_CHECKING +from _pytest.config import _PluggyPlugin from _pytest.fixtures import FixtureRequest from _pytest.main import ExitCode from _pytest.main import Session from _pytest.monkeypatch import MonkeyPatch +from _pytest.nodes import Collector +from _pytest.nodes import Item from _pytest.pathlib import Path +from _pytest.python import Module from _pytest.reports import TestReport +from _pytest.tmpdir import TempdirFactory if TYPE_CHECKING: from typing import Type @@ -534,13 +539,15 @@ class Testdir: class TimeoutExpired(Exception): pass - def __init__(self, request, tmpdir_factory): + def __init__(self, request: FixtureRequest, tmpdir_factory: TempdirFactory) -> None: self.request = request - self._mod_collections = WeakKeyDictionary() + self._mod_collections = ( + WeakKeyDictionary() + ) # type: WeakKeyDictionary[Module, List[Union[Item, Collector]]] name = request.function.__name__ self.tmpdir = tmpdir_factory.mktemp(name, numbered=True) self.test_tmproot = tmpdir_factory.mktemp("tmp-" + name, numbered=True) - self.plugins = [] + self.plugins = [] # type: List[Union[str, _PluggyPlugin]] self._cwd_snapshot = CwdSnapshot() self._sys_path_snapshot = SysPathsSnapshot() self._sys_modules_snapshot = self.__take_sys_modules_snapshot() @@ -1064,7 +1071,9 @@ class Testdir: self.config = config = self.parseconfigure(path, *configargs) return self.getnode(config, path) - def collect_by_name(self, modcol, name): + def collect_by_name( + self, modcol: Module, name: str + ) -> Optional[Union[Item, Collector]]: """Return the collection node for name from the module collection. This will search a module collection node for a collection node @@ -1073,13 +1082,13 @@ class Testdir: :param modcol: a module collection node; see :py:meth:`getmodulecol` :param name: the name of the node to return - """ if modcol not in self._mod_collections: self._mod_collections[modcol] = list(modcol.collect()) for colitem in self._mod_collections[modcol]: if colitem.name == name: return colitem + return None def popen( self, diff --git a/testing/test_collection.py b/testing/test_collection.py index 62de0b953..5073f675e 100644 --- a/testing/test_collection.py +++ b/testing/test_collection.py @@ -9,6 +9,7 @@ import pytest from _pytest.main import _in_venv from _pytest.main import ExitCode from _pytest.main import Session +from _pytest.pytester import Testdir class TestCollector: @@ -18,7 +19,7 @@ class TestCollector: assert not issubclass(Collector, Item) assert not issubclass(Item, Collector) - def test_check_equality(self, testdir): + def test_check_equality(self, testdir: Testdir) -> None: modcol = testdir.getmodulecol( """ def test_pass(): pass @@ -40,12 +41,15 @@ class TestCollector: assert fn1 != fn3 for fn in fn1, fn2, fn3: - assert fn != 3 + assert isinstance(fn, pytest.Function) + assert fn != 3 # type: ignore[comparison-overlap] # noqa: F821 assert fn != modcol - assert fn != [1, 2, 3] - assert [1, 2, 3] != fn + assert fn != [1, 2, 3] # type: ignore[comparison-overlap] # noqa: F821 + assert [1, 2, 3] != fn # type: ignore[comparison-overlap] # noqa: F821 assert modcol != fn + assert testdir.collect_by_name(modcol, "doesnotexist") is None + def test_getparent(self, testdir): modcol = testdir.getmodulecol( """