Add type annotations to _pytest._code.source

At least most of it.
This commit is contained in:
Ran Benita 2019-11-03 23:05:42 +02:00
parent b2537b22d7
commit 58f2849bf6
1 changed files with 50 additions and 23 deletions

View File

@ -7,10 +7,17 @@ import tokenize
import warnings import warnings
from ast import PyCF_ONLY_AST as _AST_FLAG from ast import PyCF_ONLY_AST as _AST_FLAG
from bisect import bisect_right from bisect import bisect_right
from types import FrameType
from typing import List from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import Union
import py import py
from _pytest.compat import overload
class Source: class Source:
""" an immutable object holding a source code fragment, """ an immutable object holding a source code fragment,
@ -19,7 +26,7 @@ class Source:
_compilecounter = 0 _compilecounter = 0
def __init__(self, *parts, **kwargs): def __init__(self, *parts, **kwargs) -> None:
self.lines = lines = [] # type: List[str] self.lines = lines = [] # type: List[str]
de = kwargs.get("deindent", True) de = kwargs.get("deindent", True)
for part in parts: for part in parts:
@ -48,7 +55,15 @@ class Source:
# Ignore type because of https://github.com/python/mypy/issues/4266. # Ignore type because of https://github.com/python/mypy/issues/4266.
__hash__ = None # type: ignore __hash__ = None # type: ignore
def __getitem__(self, key): @overload
def __getitem__(self, key: int) -> str:
raise NotImplementedError()
@overload # noqa: F811
def __getitem__(self, key: slice) -> "Source":
raise NotImplementedError()
def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]: # noqa: F811
if isinstance(key, int): if isinstance(key, int):
return self.lines[key] return self.lines[key]
else: else:
@ -58,10 +73,10 @@ class Source:
newsource.lines = self.lines[key.start : key.stop] newsource.lines = self.lines[key.start : key.stop]
return newsource return newsource
def __len__(self): def __len__(self) -> int:
return len(self.lines) return len(self.lines)
def strip(self): def strip(self) -> "Source":
""" return new source object with trailing """ return new source object with trailing
and leading blank lines removed. and leading blank lines removed.
""" """
@ -74,18 +89,20 @@ class Source:
source.lines[:] = self.lines[start:end] source.lines[:] = self.lines[start:end]
return source return source
def putaround(self, before="", after="", indent=" " * 4): def putaround(
self, before: str = "", after: str = "", indent: str = " " * 4
) -> "Source":
""" return a copy of the source object with """ return a copy of the source object with
'before' and 'after' wrapped around it. 'before' and 'after' wrapped around it.
""" """
before = Source(before) beforesource = Source(before)
after = Source(after) aftersource = Source(after)
newsource = Source() newsource = Source()
lines = [(indent + line) for line in self.lines] lines = [(indent + line) for line in self.lines]
newsource.lines = before.lines + lines + after.lines newsource.lines = beforesource.lines + lines + aftersource.lines
return newsource return newsource
def indent(self, indent=" " * 4): def indent(self, indent: str = " " * 4) -> "Source":
""" return a copy of the source object with """ return a copy of the source object with
all lines indented by the given indent-string. all lines indented by the given indent-string.
""" """
@ -93,14 +110,14 @@ class Source:
newsource.lines = [(indent + line) for line in self.lines] newsource.lines = [(indent + line) for line in self.lines]
return newsource return newsource
def getstatement(self, lineno): def getstatement(self, lineno: int) -> "Source":
""" return Source statement which contains the """ return Source statement which contains the
given linenumber (counted from 0). given linenumber (counted from 0).
""" """
start, end = self.getstatementrange(lineno) start, end = self.getstatementrange(lineno)
return self[start:end] return self[start:end]
def getstatementrange(self, lineno): def getstatementrange(self, lineno: int):
""" return (start, end) tuple which spans the minimal """ return (start, end) tuple which spans the minimal
statement region which containing the given lineno. statement region which containing the given lineno.
""" """
@ -109,13 +126,13 @@ class Source:
ast, start, end = getstatementrange_ast(lineno, self) ast, start, end = getstatementrange_ast(lineno, self)
return start, end return start, end
def deindent(self): def deindent(self) -> "Source":
"""return a new source object deindented.""" """return a new source object deindented."""
newsource = Source() newsource = Source()
newsource.lines[:] = deindent(self.lines) newsource.lines[:] = deindent(self.lines)
return newsource return newsource
def isparseable(self, deindent=True): def isparseable(self, deindent: bool = True) -> bool:
""" return True if source is parseable, heuristically """ return True if source is parseable, heuristically
deindenting it by default. deindenting it by default.
""" """
@ -135,11 +152,16 @@ class Source:
else: else:
return True return True
def __str__(self): def __str__(self) -> str:
return "\n".join(self.lines) return "\n".join(self.lines)
def compile( def compile(
self, filename=None, mode="exec", flag=0, dont_inherit=0, _genframe=None self,
filename=None,
mode="exec",
flag: int = 0,
dont_inherit: int = 0,
_genframe: Optional[FrameType] = None,
): ):
""" return compiled code object. if filename is None """ return compiled code object. if filename is None
invent an artificial filename which displays invent an artificial filename which displays
@ -183,7 +205,7 @@ class Source:
# #
def compile_(source, filename=None, mode="exec", flags=0, dont_inherit=0): def compile_(source, filename=None, mode="exec", flags: int = 0, dont_inherit: int = 0):
""" compile the given source to a raw code object, """ compile the given source to a raw code object,
and maintain an internal cache which allows later and maintain an internal cache which allows later
retrieval of the source code for the code object retrieval of the source code for the code object
@ -233,7 +255,7 @@ def getfslineno(obj):
# #
def findsource(obj): def findsource(obj) -> Tuple[Optional[Source], int]:
try: try:
sourcelines, lineno = inspect.findsource(obj) sourcelines, lineno = inspect.findsource(obj)
except Exception: except Exception:
@ -243,7 +265,7 @@ def findsource(obj):
return source, lineno return source, lineno
def getsource(obj, **kwargs): def getsource(obj, **kwargs) -> Source:
from .code import getrawcode from .code import getrawcode
obj = getrawcode(obj) obj = getrawcode(obj)
@ -255,21 +277,21 @@ def getsource(obj, **kwargs):
return Source(strsrc, **kwargs) return Source(strsrc, **kwargs)
def deindent(lines): def deindent(lines: Sequence[str]) -> List[str]:
return textwrap.dedent("\n".join(lines)).splitlines() return textwrap.dedent("\n".join(lines)).splitlines()
def get_statement_startend2(lineno, node): def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[int]]:
import ast import ast
# flatten all statements and except handlers into one lineno-list # flatten all statements and except handlers into one lineno-list
# AST's line numbers start indexing at 1 # AST's line numbers start indexing at 1
values = [] values = [] # type: List[int]
for x in ast.walk(node): for x in ast.walk(node):
if isinstance(x, (ast.stmt, ast.ExceptHandler)): if isinstance(x, (ast.stmt, ast.ExceptHandler)):
values.append(x.lineno - 1) values.append(x.lineno - 1)
for name in ("finalbody", "orelse"): for name in ("finalbody", "orelse"):
val = getattr(x, name, None) val = getattr(x, name, None) # type: Optional[List[ast.stmt]]
if val: if val:
# treat the finally/orelse part as its own statement # treat the finally/orelse part as its own statement
values.append(val[0].lineno - 1 - 1) values.append(val[0].lineno - 1 - 1)
@ -283,7 +305,12 @@ def get_statement_startend2(lineno, node):
return start, end return start, end
def getstatementrange_ast(lineno, source: Source, assertion=False, astnode=None): def getstatementrange_ast(
lineno: int,
source: Source,
assertion: bool = False,
astnode: Optional[ast.AST] = None,
) -> Tuple[ast.AST, int, int]:
if astnode is None: if astnode is None:
content = str(source) content = str(source)
# See #4260: # See #4260: