diff --git a/src/_pytest/_code/source.py b/src/_pytest/_code/source.py index db78bbd0d..1e9dd5031 100644 --- a/src/_pytest/_code/source.py +++ b/src/_pytest/_code/source.py @@ -7,10 +7,17 @@ import tokenize import warnings from ast import PyCF_ONLY_AST as _AST_FLAG from bisect import bisect_right +from types import FrameType from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union import py +from _pytest.compat import overload + class Source: """ an immutable object holding a source code fragment, @@ -19,7 +26,7 @@ class Source: _compilecounter = 0 - def __init__(self, *parts, **kwargs): + def __init__(self, *parts, **kwargs) -> None: self.lines = lines = [] # type: List[str] de = kwargs.get("deindent", True) for part in parts: @@ -48,7 +55,15 @@ class Source: # Ignore type because of https://github.com/python/mypy/issues/4266. __hash__ = None # type: ignore - def __getitem__(self, key): + @overload + def __getitem__(self, key: int) -> str: + raise NotImplementedError() + + @overload # noqa: F811 + def __getitem__(self, key: slice) -> "Source": + raise NotImplementedError() + + def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]: # noqa: F811 if isinstance(key, int): return self.lines[key] else: @@ -58,10 +73,10 @@ class Source: newsource.lines = self.lines[key.start : key.stop] return newsource - def __len__(self): + def __len__(self) -> int: return len(self.lines) - def strip(self): + def strip(self) -> "Source": """ return new source object with trailing and leading blank lines removed. """ @@ -74,18 +89,20 @@ class Source: source.lines[:] = self.lines[start:end] return source - def putaround(self, before="", after="", indent=" " * 4): + def putaround( + self, before: str = "", after: str = "", indent: str = " " * 4 + ) -> "Source": """ return a copy of the source object with 'before' and 'after' wrapped around it. """ - before = Source(before) - after = Source(after) + beforesource = Source(before) + aftersource = Source(after) newsource = Source() lines = [(indent + line) for line in self.lines] - newsource.lines = before.lines + lines + after.lines + newsource.lines = beforesource.lines + lines + aftersource.lines return newsource - def indent(self, indent=" " * 4): + def indent(self, indent: str = " " * 4) -> "Source": """ return a copy of the source object with all lines indented by the given indent-string. """ @@ -93,14 +110,14 @@ class Source: newsource.lines = [(indent + line) for line in self.lines] return newsource - def getstatement(self, lineno): + def getstatement(self, lineno: int) -> "Source": """ return Source statement which contains the given linenumber (counted from 0). """ start, end = self.getstatementrange(lineno) return self[start:end] - def getstatementrange(self, lineno): + def getstatementrange(self, lineno: int): """ return (start, end) tuple which spans the minimal statement region which containing the given lineno. """ @@ -109,13 +126,13 @@ class Source: ast, start, end = getstatementrange_ast(lineno, self) return start, end - def deindent(self): + def deindent(self) -> "Source": """return a new source object deindented.""" newsource = Source() newsource.lines[:] = deindent(self.lines) return newsource - def isparseable(self, deindent=True): + def isparseable(self, deindent: bool = True) -> bool: """ return True if source is parseable, heuristically deindenting it by default. """ @@ -135,11 +152,16 @@ class Source: else: return True - def __str__(self): + def __str__(self) -> str: return "\n".join(self.lines) def compile( - self, filename=None, mode="exec", flag=0, dont_inherit=0, _genframe=None + self, + filename=None, + mode="exec", + flag: int = 0, + dont_inherit: int = 0, + _genframe: Optional[FrameType] = None, ): """ return compiled code object. if filename is None invent an artificial filename which displays @@ -183,7 +205,7 @@ class Source: # -def compile_(source, filename=None, mode="exec", flags=0, dont_inherit=0): +def compile_(source, filename=None, mode="exec", flags: int = 0, dont_inherit: int = 0): """ compile the given source to a raw code object, and maintain an internal cache which allows later retrieval of the source code for the code object @@ -233,7 +255,7 @@ def getfslineno(obj): # -def findsource(obj): +def findsource(obj) -> Tuple[Optional[Source], int]: try: sourcelines, lineno = inspect.findsource(obj) except Exception: @@ -243,7 +265,7 @@ def findsource(obj): return source, lineno -def getsource(obj, **kwargs): +def getsource(obj, **kwargs) -> Source: from .code import getrawcode obj = getrawcode(obj) @@ -255,21 +277,21 @@ def getsource(obj, **kwargs): return Source(strsrc, **kwargs) -def deindent(lines): +def deindent(lines: Sequence[str]) -> List[str]: return textwrap.dedent("\n".join(lines)).splitlines() -def get_statement_startend2(lineno, node): +def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[int]]: import ast # flatten all statements and except handlers into one lineno-list # AST's line numbers start indexing at 1 - values = [] + values = [] # type: List[int] for x in ast.walk(node): if isinstance(x, (ast.stmt, ast.ExceptHandler)): values.append(x.lineno - 1) for name in ("finalbody", "orelse"): - val = getattr(x, name, None) + val = getattr(x, name, None) # type: Optional[List[ast.stmt]] if val: # treat the finally/orelse part as its own statement values.append(val[0].lineno - 1 - 1) @@ -283,7 +305,12 @@ def get_statement_startend2(lineno, node): return start, end -def getstatementrange_ast(lineno, source: Source, assertion=False, astnode=None): +def getstatementrange_ast( + lineno: int, + source: Source, + assertion: bool = False, + astnode: Optional[ast.AST] = None, +) -> Tuple[ast.AST, int, int]: if astnode is None: content = str(source) # See #4260: diff --git a/src/_pytest/pytester.py b/src/_pytest/pytester.py index 14db8409e..6b45e077b 100644 --- a/src/_pytest/pytester.py +++ b/src/_pytest/pytester.py @@ -1,4 +1,5 @@ """(disabled by default) support for testing pytest and pytest plugins.""" +import collections.abc import gc import importlib import os @@ -8,9 +9,15 @@ import subprocess import sys import time import traceback -from collections.abc import Sequence from fnmatch import fnmatch from io import StringIO +from typing import Callable +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple from typing import Union from weakref import WeakKeyDictionary @@ -21,10 +28,16 @@ from _pytest._code import Source from _pytest._io.saferepr import saferepr from _pytest.capture import MultiCapture from _pytest.capture import SysCapture +from _pytest.fixtures import FixtureRequest from _pytest.main import ExitCode from _pytest.main import Session from _pytest.monkeypatch import MonkeyPatch from _pytest.pathlib import Path +from _pytest.reports import TestReport + +if False: # TYPE_CHECKING + from typing import Type + IGNORE_PAM = [ # filenames added when obtaining details about the current user "/var/lib/sss/mc/passwd" @@ -142,7 +155,7 @@ class LsofFdLeakChecker: @pytest.fixture -def _pytest(request): +def _pytest(request: FixtureRequest) -> "PytestArg": """Return a helper which offers a gethookrecorder(hook) method which returns a HookRecorder instance which helps to make assertions about called hooks. @@ -152,10 +165,10 @@ def _pytest(request): class PytestArg: - def __init__(self, request): + def __init__(self, request: FixtureRequest) -> None: self.request = request - def gethookrecorder(self, hook): + def gethookrecorder(self, hook) -> "HookRecorder": hookrecorder = HookRecorder(hook._pm) self.request.addfinalizer(hookrecorder.finish_recording) return hookrecorder @@ -176,6 +189,11 @@ class ParsedCall: del d["_name"] return "".format(self._name, d) + if False: # TYPE_CHECKING + # The class has undetermined attributes, this tells mypy about it. + def __getattr__(self, key): + raise NotImplementedError() + class HookRecorder: """Record all hooks called in a plugin manager. @@ -185,27 +203,27 @@ class HookRecorder: """ - def __init__(self, pluginmanager): + def __init__(self, pluginmanager) -> None: self._pluginmanager = pluginmanager - self.calls = [] + self.calls = [] # type: List[ParsedCall] - def before(hook_name, hook_impls, kwargs): + def before(hook_name: str, hook_impls, kwargs) -> None: self.calls.append(ParsedCall(hook_name, kwargs)) - def after(outcome, hook_name, hook_impls, kwargs): + def after(outcome, hook_name: str, hook_impls, kwargs) -> None: pass self._undo_wrapping = pluginmanager.add_hookcall_monitoring(before, after) - def finish_recording(self): + def finish_recording(self) -> None: self._undo_wrapping() - def getcalls(self, names): + def getcalls(self, names: Union[str, Iterable[str]]) -> List[ParsedCall]: if isinstance(names, str): names = names.split() return [call for call in self.calls if call._name in names] - def assert_contains(self, entries): + def assert_contains(self, entries) -> None: __tracebackhide__ = True i = 0 entries = list(entries) @@ -226,7 +244,7 @@ class HookRecorder: else: pytest.fail("could not find {!r} check {!r}".format(name, check)) - def popcall(self, name): + def popcall(self, name: str) -> ParsedCall: __tracebackhide__ = True for i, call in enumerate(self.calls): if call._name == name: @@ -236,20 +254,27 @@ class HookRecorder: lines.extend([" %s" % x for x in self.calls]) pytest.fail("\n".join(lines)) - def getcall(self, name): + def getcall(self, name: str) -> ParsedCall: values = self.getcalls(name) assert len(values) == 1, (name, values) return values[0] # functionality for test reports - def getreports(self, names="pytest_runtest_logreport pytest_collectreport"): + def getreports( + self, + names: Union[ + str, Iterable[str] + ] = "pytest_runtest_logreport pytest_collectreport", + ) -> List[TestReport]: return [x.report for x in self.getcalls(names)] def matchreport( self, - inamepart="", - names="pytest_runtest_logreport pytest_collectreport", + inamepart: str = "", + names: Union[ + str, Iterable[str] + ] = "pytest_runtest_logreport pytest_collectreport", when=None, ): """return a testreport whose dotted import path matches""" @@ -275,13 +300,20 @@ class HookRecorder: ) return values[0] - def getfailures(self, names="pytest_runtest_logreport pytest_collectreport"): + def getfailures( + self, + names: Union[ + str, Iterable[str] + ] = "pytest_runtest_logreport pytest_collectreport", + ) -> List[TestReport]: return [rep for rep in self.getreports(names) if rep.failed] - def getfailedcollections(self): + def getfailedcollections(self) -> List[TestReport]: return self.getfailures("pytest_collectreport") - def listoutcomes(self): + def listoutcomes( + self + ) -> Tuple[List[TestReport], List[TestReport], List[TestReport]]: passed = [] skipped = [] failed = [] @@ -296,31 +328,31 @@ class HookRecorder: failed.append(rep) return passed, skipped, failed - def countoutcomes(self): + def countoutcomes(self) -> List[int]: return [len(x) for x in self.listoutcomes()] - def assertoutcome(self, passed=0, skipped=0, failed=0): + def assertoutcome(self, passed: int = 0, skipped: int = 0, failed: int = 0) -> None: realpassed, realskipped, realfailed = self.listoutcomes() assert passed == len(realpassed) assert skipped == len(realskipped) assert failed == len(realfailed) - def clear(self): + def clear(self) -> None: self.calls[:] = [] @pytest.fixture -def linecomp(request): +def linecomp(request: FixtureRequest) -> "LineComp": return LineComp() @pytest.fixture(name="LineMatcher") -def LineMatcher_fixture(request): +def LineMatcher_fixture(request: FixtureRequest) -> "Type[LineMatcher]": return LineMatcher @pytest.fixture -def testdir(request, tmpdir_factory): +def testdir(request: FixtureRequest, tmpdir_factory) -> "Testdir": return Testdir(request, tmpdir_factory) @@ -363,7 +395,13 @@ class RunResult: :ivar duration: duration in seconds """ - def __init__(self, ret: Union[int, ExitCode], outlines, errlines, duration) -> None: + def __init__( + self, + ret: Union[int, ExitCode], + outlines: Sequence[str], + errlines: Sequence[str], + duration: float, + ) -> None: try: self.ret = pytest.ExitCode(ret) # type: Union[int, ExitCode] except ValueError: @@ -374,13 +412,13 @@ class RunResult: self.stderr = LineMatcher(errlines) self.duration = duration - def __repr__(self): + def __repr__(self) -> str: return ( "" % (self.ret, len(self.stdout.lines), len(self.stderr.lines), self.duration) ) - def parseoutcomes(self): + def parseoutcomes(self) -> Dict[str, int]: """Return a dictionary of outcomestring->num from parsing the terminal output that the test process produced. @@ -393,8 +431,14 @@ class RunResult: raise ValueError("Pytest terminal summary report not found") def assert_outcomes( - self, passed=0, skipped=0, failed=0, error=0, xpassed=0, xfailed=0 - ): + self, + passed: int = 0, + skipped: int = 0, + failed: int = 0, + error: int = 0, + xpassed: int = 0, + xfailed: int = 0, + ) -> None: """Assert that the specified outcomes appear with the respective numbers (0 means it didn't occur) in the text output from a test run. @@ -420,19 +464,19 @@ class RunResult: class CwdSnapshot: - def __init__(self): + def __init__(self) -> None: self.__saved = os.getcwd() - def restore(self): + def restore(self) -> None: os.chdir(self.__saved) class SysModulesSnapshot: - def __init__(self, preserve=None): + def __init__(self, preserve: Optional[Callable[[str], bool]] = None): self.__preserve = preserve self.__saved = dict(sys.modules) - def restore(self): + def restore(self) -> None: if self.__preserve: self.__saved.update( (k, m) for k, m in sys.modules.items() if self.__preserve(k) @@ -442,10 +486,10 @@ class SysModulesSnapshot: class SysPathsSnapshot: - def __init__(self): + def __init__(self) -> None: self.__saved = list(sys.path), list(sys.meta_path) - def restore(self): + def restore(self) -> None: sys.path[:], sys.meta_path[:] = self.__saved @@ -1357,7 +1401,7 @@ class LineMatcher: :param str match_nickname: the nickname for the match function that will be logged to stdout when a match occurs """ - assert isinstance(lines2, Sequence) + assert isinstance(lines2, collections.abc.Sequence) lines2 = self._getlines(lines2) lines1 = self.lines[:] nextline = None diff --git a/src/_pytest/warning_types.py b/src/_pytest/warning_types.py index 80353ccbc..22cb17dba 100644 --- a/src/_pytest/warning_types.py +++ b/src/_pytest/warning_types.py @@ -1,6 +1,14 @@ +from typing import Any +from typing import Generic +from typing import TypeVar + import attr +if False: # TYPE_CHECKING + from typing import Type # noqa: F401 (used in type string) + + class PytestWarning(UserWarning): """ Bases: :class:`UserWarning`. @@ -72,7 +80,7 @@ class PytestExperimentalApiWarning(PytestWarning, FutureWarning): __module__ = "pytest" @classmethod - def simple(cls, apiname): + def simple(cls, apiname: str) -> "PytestExperimentalApiWarning": return cls( "{apiname} is an experimental api that may change over time".format( apiname=apiname @@ -103,17 +111,20 @@ class PytestUnknownMarkWarning(PytestWarning): __module__ = "pytest" +_W = TypeVar("_W", bound=PytestWarning) + + @attr.s -class UnformattedWarning: +class UnformattedWarning(Generic[_W]): """Used to hold warnings that need to format their message at runtime, as opposed to a direct message. Using this class avoids to keep all the warning types and messages in this module, avoiding misuse. """ - category = attr.ib() - template = attr.ib() + category = attr.ib(type="Type[_W]") + template = attr.ib(type=str) - def format(self, **kwargs): + def format(self, **kwargs: Any) -> _W: """Returns an instance of the warning category, formatted with given kwargs""" return self.category(self.template.format(**kwargs))