Add type annotations to _pytest._code.source
At least most of it.
This commit is contained in:
parent
b2537b22d7
commit
58f2849bf6
|
@ -7,10 +7,17 @@ import tokenize
|
||||||
import warnings
|
import warnings
|
||||||
from ast import PyCF_ONLY_AST as _AST_FLAG
|
from ast import PyCF_ONLY_AST as _AST_FLAG
|
||||||
from bisect import bisect_right
|
from bisect import bisect_right
|
||||||
|
from types import FrameType
|
||||||
from typing import List
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Sequence
|
||||||
|
from typing import Tuple
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import py
|
import py
|
||||||
|
|
||||||
|
from _pytest.compat import overload
|
||||||
|
|
||||||
|
|
||||||
class Source:
|
class Source:
|
||||||
""" an immutable object holding a source code fragment,
|
""" an immutable object holding a source code fragment,
|
||||||
|
@ -19,7 +26,7 @@ class Source:
|
||||||
|
|
||||||
_compilecounter = 0
|
_compilecounter = 0
|
||||||
|
|
||||||
def __init__(self, *parts, **kwargs):
|
def __init__(self, *parts, **kwargs) -> None:
|
||||||
self.lines = lines = [] # type: List[str]
|
self.lines = lines = [] # type: List[str]
|
||||||
de = kwargs.get("deindent", True)
|
de = kwargs.get("deindent", True)
|
||||||
for part in parts:
|
for part in parts:
|
||||||
|
@ -48,7 +55,15 @@ class Source:
|
||||||
# Ignore type because of https://github.com/python/mypy/issues/4266.
|
# Ignore type because of https://github.com/python/mypy/issues/4266.
|
||||||
__hash__ = None # type: ignore
|
__hash__ = None # type: ignore
|
||||||
|
|
||||||
def __getitem__(self, key):
|
@overload
|
||||||
|
def __getitem__(self, key: int) -> str:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@overload # noqa: F811
|
||||||
|
def __getitem__(self, key: slice) -> "Source":
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]: # noqa: F811
|
||||||
if isinstance(key, int):
|
if isinstance(key, int):
|
||||||
return self.lines[key]
|
return self.lines[key]
|
||||||
else:
|
else:
|
||||||
|
@ -58,10 +73,10 @@ class Source:
|
||||||
newsource.lines = self.lines[key.start : key.stop]
|
newsource.lines = self.lines[key.start : key.stop]
|
||||||
return newsource
|
return newsource
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self) -> int:
|
||||||
return len(self.lines)
|
return len(self.lines)
|
||||||
|
|
||||||
def strip(self):
|
def strip(self) -> "Source":
|
||||||
""" return new source object with trailing
|
""" return new source object with trailing
|
||||||
and leading blank lines removed.
|
and leading blank lines removed.
|
||||||
"""
|
"""
|
||||||
|
@ -74,18 +89,20 @@ class Source:
|
||||||
source.lines[:] = self.lines[start:end]
|
source.lines[:] = self.lines[start:end]
|
||||||
return source
|
return source
|
||||||
|
|
||||||
def putaround(self, before="", after="", indent=" " * 4):
|
def putaround(
|
||||||
|
self, before: str = "", after: str = "", indent: str = " " * 4
|
||||||
|
) -> "Source":
|
||||||
""" return a copy of the source object with
|
""" return a copy of the source object with
|
||||||
'before' and 'after' wrapped around it.
|
'before' and 'after' wrapped around it.
|
||||||
"""
|
"""
|
||||||
before = Source(before)
|
beforesource = Source(before)
|
||||||
after = Source(after)
|
aftersource = Source(after)
|
||||||
newsource = Source()
|
newsource = Source()
|
||||||
lines = [(indent + line) for line in self.lines]
|
lines = [(indent + line) for line in self.lines]
|
||||||
newsource.lines = before.lines + lines + after.lines
|
newsource.lines = beforesource.lines + lines + aftersource.lines
|
||||||
return newsource
|
return newsource
|
||||||
|
|
||||||
def indent(self, indent=" " * 4):
|
def indent(self, indent: str = " " * 4) -> "Source":
|
||||||
""" return a copy of the source object with
|
""" return a copy of the source object with
|
||||||
all lines indented by the given indent-string.
|
all lines indented by the given indent-string.
|
||||||
"""
|
"""
|
||||||
|
@ -93,14 +110,14 @@ class Source:
|
||||||
newsource.lines = [(indent + line) for line in self.lines]
|
newsource.lines = [(indent + line) for line in self.lines]
|
||||||
return newsource
|
return newsource
|
||||||
|
|
||||||
def getstatement(self, lineno):
|
def getstatement(self, lineno: int) -> "Source":
|
||||||
""" return Source statement which contains the
|
""" return Source statement which contains the
|
||||||
given linenumber (counted from 0).
|
given linenumber (counted from 0).
|
||||||
"""
|
"""
|
||||||
start, end = self.getstatementrange(lineno)
|
start, end = self.getstatementrange(lineno)
|
||||||
return self[start:end]
|
return self[start:end]
|
||||||
|
|
||||||
def getstatementrange(self, lineno):
|
def getstatementrange(self, lineno: int):
|
||||||
""" return (start, end) tuple which spans the minimal
|
""" return (start, end) tuple which spans the minimal
|
||||||
statement region which containing the given lineno.
|
statement region which containing the given lineno.
|
||||||
"""
|
"""
|
||||||
|
@ -109,13 +126,13 @@ class Source:
|
||||||
ast, start, end = getstatementrange_ast(lineno, self)
|
ast, start, end = getstatementrange_ast(lineno, self)
|
||||||
return start, end
|
return start, end
|
||||||
|
|
||||||
def deindent(self):
|
def deindent(self) -> "Source":
|
||||||
"""return a new source object deindented."""
|
"""return a new source object deindented."""
|
||||||
newsource = Source()
|
newsource = Source()
|
||||||
newsource.lines[:] = deindent(self.lines)
|
newsource.lines[:] = deindent(self.lines)
|
||||||
return newsource
|
return newsource
|
||||||
|
|
||||||
def isparseable(self, deindent=True):
|
def isparseable(self, deindent: bool = True) -> bool:
|
||||||
""" return True if source is parseable, heuristically
|
""" return True if source is parseable, heuristically
|
||||||
deindenting it by default.
|
deindenting it by default.
|
||||||
"""
|
"""
|
||||||
|
@ -135,11 +152,16 @@ class Source:
|
||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
return "\n".join(self.lines)
|
return "\n".join(self.lines)
|
||||||
|
|
||||||
def compile(
|
def compile(
|
||||||
self, filename=None, mode="exec", flag=0, dont_inherit=0, _genframe=None
|
self,
|
||||||
|
filename=None,
|
||||||
|
mode="exec",
|
||||||
|
flag: int = 0,
|
||||||
|
dont_inherit: int = 0,
|
||||||
|
_genframe: Optional[FrameType] = None,
|
||||||
):
|
):
|
||||||
""" return compiled code object. if filename is None
|
""" return compiled code object. if filename is None
|
||||||
invent an artificial filename which displays
|
invent an artificial filename which displays
|
||||||
|
@ -183,7 +205,7 @@ class Source:
|
||||||
#
|
#
|
||||||
|
|
||||||
|
|
||||||
def compile_(source, filename=None, mode="exec", flags=0, dont_inherit=0):
|
def compile_(source, filename=None, mode="exec", flags: int = 0, dont_inherit: int = 0):
|
||||||
""" compile the given source to a raw code object,
|
""" compile the given source to a raw code object,
|
||||||
and maintain an internal cache which allows later
|
and maintain an internal cache which allows later
|
||||||
retrieval of the source code for the code object
|
retrieval of the source code for the code object
|
||||||
|
@ -233,7 +255,7 @@ def getfslineno(obj):
|
||||||
#
|
#
|
||||||
|
|
||||||
|
|
||||||
def findsource(obj):
|
def findsource(obj) -> Tuple[Optional[Source], int]:
|
||||||
try:
|
try:
|
||||||
sourcelines, lineno = inspect.findsource(obj)
|
sourcelines, lineno = inspect.findsource(obj)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
@ -243,7 +265,7 @@ def findsource(obj):
|
||||||
return source, lineno
|
return source, lineno
|
||||||
|
|
||||||
|
|
||||||
def getsource(obj, **kwargs):
|
def getsource(obj, **kwargs) -> Source:
|
||||||
from .code import getrawcode
|
from .code import getrawcode
|
||||||
|
|
||||||
obj = getrawcode(obj)
|
obj = getrawcode(obj)
|
||||||
|
@ -255,21 +277,21 @@ def getsource(obj, **kwargs):
|
||||||
return Source(strsrc, **kwargs)
|
return Source(strsrc, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def deindent(lines):
|
def deindent(lines: Sequence[str]) -> List[str]:
|
||||||
return textwrap.dedent("\n".join(lines)).splitlines()
|
return textwrap.dedent("\n".join(lines)).splitlines()
|
||||||
|
|
||||||
|
|
||||||
def get_statement_startend2(lineno, node):
|
def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[int]]:
|
||||||
import ast
|
import ast
|
||||||
|
|
||||||
# flatten all statements and except handlers into one lineno-list
|
# flatten all statements and except handlers into one lineno-list
|
||||||
# AST's line numbers start indexing at 1
|
# AST's line numbers start indexing at 1
|
||||||
values = []
|
values = [] # type: List[int]
|
||||||
for x in ast.walk(node):
|
for x in ast.walk(node):
|
||||||
if isinstance(x, (ast.stmt, ast.ExceptHandler)):
|
if isinstance(x, (ast.stmt, ast.ExceptHandler)):
|
||||||
values.append(x.lineno - 1)
|
values.append(x.lineno - 1)
|
||||||
for name in ("finalbody", "orelse"):
|
for name in ("finalbody", "orelse"):
|
||||||
val = getattr(x, name, None)
|
val = getattr(x, name, None) # type: Optional[List[ast.stmt]]
|
||||||
if val:
|
if val:
|
||||||
# treat the finally/orelse part as its own statement
|
# treat the finally/orelse part as its own statement
|
||||||
values.append(val[0].lineno - 1 - 1)
|
values.append(val[0].lineno - 1 - 1)
|
||||||
|
@ -283,7 +305,12 @@ def get_statement_startend2(lineno, node):
|
||||||
return start, end
|
return start, end
|
||||||
|
|
||||||
|
|
||||||
def getstatementrange_ast(lineno, source: Source, assertion=False, astnode=None):
|
def getstatementrange_ast(
|
||||||
|
lineno: int,
|
||||||
|
source: Source,
|
||||||
|
assertion: bool = False,
|
||||||
|
astnode: Optional[ast.AST] = None,
|
||||||
|
) -> Tuple[ast.AST, int, int]:
|
||||||
if astnode is None:
|
if astnode is None:
|
||||||
content = str(source)
|
content = str(source)
|
||||||
# See #4260:
|
# See #4260:
|
||||||
|
|
Loading…
Reference in New Issue