diff --git a/src/_pytest/_code/source.py b/src/_pytest/_code/source.py index db78bbd0d..1e9dd5031 100644 --- a/src/_pytest/_code/source.py +++ b/src/_pytest/_code/source.py @@ -7,10 +7,17 @@ import tokenize import warnings from ast import PyCF_ONLY_AST as _AST_FLAG from bisect import bisect_right +from types import FrameType from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union import py +from _pytest.compat import overload + class Source: """ an immutable object holding a source code fragment, @@ -19,7 +26,7 @@ class Source: _compilecounter = 0 - def __init__(self, *parts, **kwargs): + def __init__(self, *parts, **kwargs) -> None: self.lines = lines = [] # type: List[str] de = kwargs.get("deindent", True) for part in parts: @@ -48,7 +55,15 @@ class Source: # Ignore type because of https://github.com/python/mypy/issues/4266. __hash__ = None # type: ignore - def __getitem__(self, key): + @overload + def __getitem__(self, key: int) -> str: + raise NotImplementedError() + + @overload # noqa: F811 + def __getitem__(self, key: slice) -> "Source": + raise NotImplementedError() + + def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]: # noqa: F811 if isinstance(key, int): return self.lines[key] else: @@ -58,10 +73,10 @@ class Source: newsource.lines = self.lines[key.start : key.stop] return newsource - def __len__(self): + def __len__(self) -> int: return len(self.lines) - def strip(self): + def strip(self) -> "Source": """ return new source object with trailing and leading blank lines removed. """ @@ -74,18 +89,20 @@ class Source: source.lines[:] = self.lines[start:end] return source - def putaround(self, before="", after="", indent=" " * 4): + def putaround( + self, before: str = "", after: str = "", indent: str = " " * 4 + ) -> "Source": """ return a copy of the source object with 'before' and 'after' wrapped around it. """ - before = Source(before) - after = Source(after) + beforesource = Source(before) + aftersource = Source(after) newsource = Source() lines = [(indent + line) for line in self.lines] - newsource.lines = before.lines + lines + after.lines + newsource.lines = beforesource.lines + lines + aftersource.lines return newsource - def indent(self, indent=" " * 4): + def indent(self, indent: str = " " * 4) -> "Source": """ return a copy of the source object with all lines indented by the given indent-string. """ @@ -93,14 +110,14 @@ class Source: newsource.lines = [(indent + line) for line in self.lines] return newsource - def getstatement(self, lineno): + def getstatement(self, lineno: int) -> "Source": """ return Source statement which contains the given linenumber (counted from 0). """ start, end = self.getstatementrange(lineno) return self[start:end] - def getstatementrange(self, lineno): + def getstatementrange(self, lineno: int): """ return (start, end) tuple which spans the minimal statement region which containing the given lineno. """ @@ -109,13 +126,13 @@ class Source: ast, start, end = getstatementrange_ast(lineno, self) return start, end - def deindent(self): + def deindent(self) -> "Source": """return a new source object deindented.""" newsource = Source() newsource.lines[:] = deindent(self.lines) return newsource - def isparseable(self, deindent=True): + def isparseable(self, deindent: bool = True) -> bool: """ return True if source is parseable, heuristically deindenting it by default. """ @@ -135,11 +152,16 @@ class Source: else: return True - def __str__(self): + def __str__(self) -> str: return "\n".join(self.lines) def compile( - self, filename=None, mode="exec", flag=0, dont_inherit=0, _genframe=None + self, + filename=None, + mode="exec", + flag: int = 0, + dont_inherit: int = 0, + _genframe: Optional[FrameType] = None, ): """ return compiled code object. if filename is None invent an artificial filename which displays @@ -183,7 +205,7 @@ class Source: # -def compile_(source, filename=None, mode="exec", flags=0, dont_inherit=0): +def compile_(source, filename=None, mode="exec", flags: int = 0, dont_inherit: int = 0): """ compile the given source to a raw code object, and maintain an internal cache which allows later retrieval of the source code for the code object @@ -233,7 +255,7 @@ def getfslineno(obj): # -def findsource(obj): +def findsource(obj) -> Tuple[Optional[Source], int]: try: sourcelines, lineno = inspect.findsource(obj) except Exception: @@ -243,7 +265,7 @@ def findsource(obj): return source, lineno -def getsource(obj, **kwargs): +def getsource(obj, **kwargs) -> Source: from .code import getrawcode obj = getrawcode(obj) @@ -255,21 +277,21 @@ def getsource(obj, **kwargs): return Source(strsrc, **kwargs) -def deindent(lines): +def deindent(lines: Sequence[str]) -> List[str]: return textwrap.dedent("\n".join(lines)).splitlines() -def get_statement_startend2(lineno, node): +def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[int]]: import ast # flatten all statements and except handlers into one lineno-list # AST's line numbers start indexing at 1 - values = [] + values = [] # type: List[int] for x in ast.walk(node): if isinstance(x, (ast.stmt, ast.ExceptHandler)): values.append(x.lineno - 1) for name in ("finalbody", "orelse"): - val = getattr(x, name, None) + val = getattr(x, name, None) # type: Optional[List[ast.stmt]] if val: # treat the finally/orelse part as its own statement values.append(val[0].lineno - 1 - 1) @@ -283,7 +305,12 @@ def get_statement_startend2(lineno, node): return start, end -def getstatementrange_ast(lineno, source: Source, assertion=False, astnode=None): +def getstatementrange_ast( + lineno: int, + source: Source, + assertion: bool = False, + astnode: Optional[ast.AST] = None, +) -> Tuple[ast.AST, int, int]: if astnode is None: content = str(source) # See #4260: