From f0eb82f7d40281c28b94239e85b95918d1d7aeb9 Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Sat, 1 Aug 2020 10:49:51 +0300 Subject: [PATCH] pytester: improve type annotations --- src/_pytest/pytester.py | 111 +++++++++++++++++++++++++++++++--------- 1 file changed, 86 insertions(+), 25 deletions(-) diff --git a/src/_pytest/pytester.py b/src/_pytest/pytester.py index e0e7b1fbc..83c525fd8 100644 --- a/src/_pytest/pytester.py +++ b/src/_pytest/pytester.py @@ -28,6 +28,7 @@ import pytest from _pytest import timing from _pytest._code import Source from _pytest.capture import _get_multicapture +from _pytest.compat import overload from _pytest.compat import TYPE_CHECKING from _pytest.config import _PluggyPlugin from _pytest.config import Config @@ -42,11 +43,13 @@ from _pytest.nodes import Item from _pytest.pathlib import make_numbered_dir from _pytest.pathlib import Path from _pytest.python import Module +from _pytest.reports import CollectReport from _pytest.reports import TestReport from _pytest.tmpdir import TempdirFactory if TYPE_CHECKING: from typing import Type + from typing_extensions import Literal import pexpect @@ -180,24 +183,24 @@ class PytestArg: return hookrecorder -def get_public_names(values): +def get_public_names(values: Iterable[str]) -> List[str]: """Only return names from iterator values without a leading underscore.""" return [x for x in values if x[0] != "_"] class ParsedCall: - def __init__(self, name, kwargs): + def __init__(self, name: str, kwargs) -> None: self.__dict__.update(kwargs) self._name = name - def __repr__(self): + def __repr__(self) -> str: d = self.__dict__.copy() del d["_name"] return "".format(self._name, d) if TYPE_CHECKING: # The class has undetermined attributes, this tells mypy about it. - def __getattr__(self, key): + def __getattr__(self, key: str): raise NotImplementedError() @@ -211,6 +214,7 @@ class HookRecorder: def __init__(self, pluginmanager: PytestPluginManager) -> None: self._pluginmanager = pluginmanager self.calls = [] # type: List[ParsedCall] + self.ret = None # type: Optional[Union[int, ExitCode]] def before(hook_name: str, hook_impls, kwargs) -> None: self.calls.append(ParsedCall(hook_name, kwargs)) @@ -228,7 +232,7 @@ class HookRecorder: names = names.split() return [call for call in self.calls if call._name in names] - def assert_contains(self, entries) -> None: + def assert_contains(self, entries: Sequence[Tuple[str, str]]) -> None: __tracebackhide__ = True i = 0 entries = list(entries) @@ -266,22 +270,46 @@ class HookRecorder: # functionality for test reports + @overload def getreports( + self, names: "Literal['pytest_collectreport']", + ) -> Sequence[CollectReport]: + raise NotImplementedError() + + @overload # noqa: F811 + def getreports( # noqa: F811 + self, names: "Literal['pytest_runtest_logreport']", + ) -> Sequence[TestReport]: + raise NotImplementedError() + + @overload # noqa: F811 + def getreports( # noqa: F811 self, - names: Union[ - str, Iterable[str] - ] = "pytest_runtest_logreport pytest_collectreport", - ) -> List[TestReport]: + names: Union[str, Iterable[str]] = ( + "pytest_collectreport", + "pytest_runtest_logreport", + ), + ) -> Sequence[Union[CollectReport, TestReport]]: + raise NotImplementedError() + + def getreports( # noqa: F811 + self, + names: Union[str, Iterable[str]] = ( + "pytest_collectreport", + "pytest_runtest_logreport", + ), + ) -> Sequence[Union[CollectReport, TestReport]]: return [x.report for x in self.getcalls(names)] def matchreport( self, inamepart: str = "", - names: Union[ - str, Iterable[str] - ] = "pytest_runtest_logreport pytest_collectreport", - when=None, - ): + names: Union[str, Iterable[str]] = ( + "pytest_runtest_logreport", + "pytest_collectreport", + ), + when: Optional[str] = None, + ) -> Union[CollectReport, TestReport]: """Return a testreport whose dotted import path matches.""" values = [] for rep in self.getreports(names=names): @@ -305,26 +333,56 @@ class HookRecorder: ) return values[0] + @overload def getfailures( + self, names: "Literal['pytest_collectreport']", + ) -> Sequence[CollectReport]: + raise NotImplementedError() + + @overload # noqa: F811 + def getfailures( # noqa: F811 + self, names: "Literal['pytest_runtest_logreport']", + ) -> Sequence[TestReport]: + raise NotImplementedError() + + @overload # noqa: F811 + def getfailures( # noqa: F811 self, - names: Union[ - str, Iterable[str] - ] = "pytest_runtest_logreport pytest_collectreport", - ) -> List[TestReport]: + names: Union[str, Iterable[str]] = ( + "pytest_collectreport", + "pytest_runtest_logreport", + ), + ) -> Sequence[Union[CollectReport, TestReport]]: + raise NotImplementedError() + + def getfailures( # noqa: F811 + self, + names: Union[str, Iterable[str]] = ( + "pytest_collectreport", + "pytest_runtest_logreport", + ), + ) -> Sequence[Union[CollectReport, TestReport]]: return [rep for rep in self.getreports(names) if rep.failed] - def getfailedcollections(self) -> List[TestReport]: + def getfailedcollections(self) -> Sequence[CollectReport]: return self.getfailures("pytest_collectreport") def listoutcomes( self, - ) -> Tuple[List[TestReport], List[TestReport], List[TestReport]]: + ) -> Tuple[ + Sequence[TestReport], + Sequence[Union[CollectReport, TestReport]], + Sequence[Union[CollectReport, TestReport]], + ]: passed = [] skipped = [] failed = [] - for rep in self.getreports("pytest_collectreport pytest_runtest_logreport"): + for rep in self.getreports( + ("pytest_collectreport", "pytest_runtest_logreport") + ): if rep.passed: if rep.when == "call": + assert isinstance(rep, TestReport) passed.append(rep) elif rep.skipped: skipped.append(rep) @@ -879,7 +937,7 @@ class Testdir: runner = testclassinstance.getrunner() return runner(item) - def inline_runsource(self, source, *cmdlineargs): + def inline_runsource(self, source, *cmdlineargs) -> HookRecorder: """Run a test module in process using ``pytest.main()``. This run writes "source" into a temporary file and runs @@ -896,7 +954,7 @@ class Testdir: values = list(cmdlineargs) + [p] return self.inline_run(*values) - def inline_genitems(self, *args): + def inline_genitems(self, *args) -> Tuple[List[Item], HookRecorder]: """Run ``pytest.main(['--collectonly'])`` in-process. Runs the :py:func:`pytest.main` function to run all of pytest inside @@ -907,7 +965,9 @@ class Testdir: items = [x.item for x in rec.getcalls("pytest_itemcollected")] return items, rec - def inline_run(self, *args, plugins=(), no_reraise_ctrlc: bool = False): + def inline_run( + self, *args, plugins=(), no_reraise_ctrlc: bool = False + ) -> HookRecorder: """Run ``pytest.main()`` in-process, returning a HookRecorder. Runs the :py:func:`pytest.main` function to run all of pytest inside @@ -962,7 +1022,7 @@ class Testdir: class reprec: # type: ignore pass - reprec.ret = ret # type: ignore[attr-defined] + reprec.ret = ret # Typically we reraise keyboard interrupts from the child run # because it's our user requesting interruption of the testing. @@ -1010,6 +1070,7 @@ class Testdir: sys.stdout.write(out) sys.stderr.write(err) + assert reprec.ret is not None res = RunResult( reprec.ret, out.splitlines(), err.splitlines(), timing.time() - now )