391 lines
12 KiB
Python
391 lines
12 KiB
Python
import ast
|
|
import inspect
|
|
import linecache
|
|
import sys
|
|
import textwrap
|
|
import tokenize
|
|
import warnings
|
|
from bisect import bisect_right
|
|
from types import CodeType
|
|
from types import FrameType
|
|
from typing import Iterator
|
|
from typing import List
|
|
from typing import Optional
|
|
from typing import Sequence
|
|
from typing import Tuple
|
|
from typing import Union
|
|
|
|
import py
|
|
|
|
from _pytest.compat import overload
|
|
from _pytest.compat import TYPE_CHECKING
|
|
|
|
if TYPE_CHECKING:
|
|
from typing_extensions import Literal
|
|
|
|
|
|
class Source:
    """ an immutable object holding a source code fragment,
        possibly deindenting it.

    The fragment is kept as a list of strings, one per line, in ``lines``.
    Instances define ``__eq__`` and are explicitly unhashable.
    """

    # Class-wide counter used to build unique artificial filenames in compile().
    _compilecounter = 0

    def __init__(self, *parts, **kwargs) -> None:
        """ build the fragment by concatenating the lines of all ``parts``.

        Each part may be a falsy value (contributes nothing), another
        ``Source``, a list/tuple of lines, a string, or any other object
        whose source is looked up with ``getsource()``.

        The only keyword read is ``deindent`` (default True); when true each
        part is passed through ``deindent()`` before being appended.
        """
        self.lines = lines = []  # type: List[str]
        de = kwargs.get("deindent", True)
        for part in parts:
            if not part:
                partlines = []  # type: List[str]
            elif isinstance(part, Source):
                partlines = part.lines
            elif isinstance(part, (tuple, list)):
                partlines = [x.rstrip("\n") for x in part]
            elif isinstance(part, str):
                partlines = part.split("\n")
            else:
                partlines = getsource(part, deindent=de).lines
            if de:
                partlines = deindent(partlines)
            lines.extend(partlines)

    def __eq__(self, other):
        """ compare line lists; a plain string is compared against str(self). """
        try:
            return self.lines == other.lines
        except AttributeError:
            if isinstance(other, str):
                return str(self) == other
            return False

    # Ignore type because of https://github.com/python/mypy/issues/4266.
    __hash__ = None  # type: ignore

    @overload
    def __getitem__(self, key: int) -> str:
        raise NotImplementedError()

    @overload  # noqa: F811
    def __getitem__(self, key: slice) -> "Source":  # noqa: F811
        raise NotImplementedError()

    def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]:  # noqa: F811
        """ return a single line for an int key, a new Source for a slice.

        Raises IndexError for slices whose step is neither None nor 1.
        """
        if isinstance(key, int):
            return self.lines[key]
        if key.step not in (None, 1):
            raise IndexError("cannot slice a Source with a step")
        newsource = Source()
        newsource.lines = self.lines[key.start : key.stop]
        return newsource

    def __iter__(self) -> Iterator[str]:
        return iter(self.lines)

    def __len__(self) -> int:
        return len(self.lines)

    def strip(self) -> "Source":
        """ return new source object with trailing
            and leading blank lines removed.
        """
        start, end = 0, len(self)
        while start < end and not self.lines[start].strip():
            start += 1
        while end > start and not self.lines[end - 1].strip():
            end -= 1
        source = Source()
        source.lines[:] = self.lines[start:end]
        return source

    def putaround(
        self, before: str = "", after: str = "", indent: str = " " * 4
    ) -> "Source":
        """ return a copy of the source object with
            'before' and 'after' wrapped around it.

        The existing lines are prefixed with ``indent``; the lines of
        'before' and 'after' are not.
        """
        beforesource = Source(before)
        aftersource = Source(after)
        newsource = Source()
        lines = [(indent + line) for line in self.lines]
        newsource.lines = beforesource.lines + lines + aftersource.lines
        return newsource

    def indent(self, indent: str = " " * 4) -> "Source":
        """ return a copy of the source object with
            all lines indented by the given indent-string.
        """
        newsource = Source()
        newsource.lines = [(indent + line) for line in self.lines]
        return newsource

    def getstatement(self, lineno: int) -> "Source":
        """ return Source statement which contains the
            given linenumber (counted from 0).
        """
        start, end = self.getstatementrange(lineno)
        return self[start:end]

    def getstatementrange(self, lineno: int) -> Tuple[int, int]:
        """ return (start, end) tuple which spans the minimal
            statement region containing the given lineno.

        Raises IndexError if ``lineno`` is not a valid index into the lines.
        """
        if not (0 <= lineno < len(self)):
            raise IndexError("lineno out of range")
        # The parsed tree is not needed here; avoid shadowing the ast module.
        _, start, end = getstatementrange_ast(lineno, self)
        return start, end

    def deindent(self) -> "Source":
        """return a new source object deindented."""
        newsource = Source()
        newsource.lines[:] = deindent(self.lines)
        return newsource

    def isparseable(self, deindent: bool = True) -> bool:
        """ return True if source is parseable, heuristically
            deindenting it by default.
        """
        if deindent:
            source = str(self.deindent())
        else:
            source = str(self)
        try:
            ast.parse(source)
        except (SyntaxError, ValueError, TypeError):
            return False
        else:
            return True

    def __str__(self) -> str:
        return "\n".join(self.lines)

    @overload
    def compile(
        self,
        filename: Optional[str] = ...,
        mode: str = ...,
        flag: "Literal[0]" = ...,
        dont_inherit: int = ...,
        _genframe: Optional[FrameType] = ...,
    ) -> CodeType:
        raise NotImplementedError()

    @overload  # noqa: F811
    def compile(  # noqa: F811
        self,
        filename: Optional[str] = ...,
        mode: str = ...,
        flag: int = ...,
        dont_inherit: int = ...,
        _genframe: Optional[FrameType] = ...,
    ) -> Union[CodeType, ast.AST]:
        raise NotImplementedError()

    def compile(  # noqa: F811
        self,
        filename: Optional[str] = None,
        mode: str = "exec",
        flag: int = 0,
        dont_inherit: int = 0,
        _genframe: Optional[FrameType] = None,
    ) -> Union[CodeType, ast.AST]:
        """ return compiled code object. if filename is None
            invent an artificial filename which displays
            the source/line position of the caller frame.

        ``mode``, ``flag`` and ``dont_inherit`` are handed to the builtin
        ``compile()``. When ``flag`` includes ``ast.PyCF_ONLY_AST`` the AST
        is returned and nothing is registered with ``linecache``; otherwise
        the source lines are stored in ``linecache.cache`` under the
        (possibly invented) filename.

        A SyntaxError from the builtin is re-raised as a new SyntaxError
        whose message shows the source up to the failing line.
        """
        if not filename or py.path.local(filename).check(file=0):
            if _genframe is None:
                _genframe = sys._getframe(1)  # the caller
            fn, lineno = _genframe.f_code.co_filename, _genframe.f_lineno
            base = "<%d-codegen " % self._compilecounter
            self.__class__._compilecounter += 1
            if not filename:
                filename = base + "%s:%d>" % (fn, lineno)
            else:
                filename = base + "%r %s:%d>" % (filename, fn, lineno)
        source = "\n".join(self.lines) + "\n"
        try:
            # dont_inherit used to be accepted but silently dropped here.
            co = compile(source, filename, mode, flag, dont_inherit)
        except SyntaxError as ex:
            # re-represent syntax errors from parsing python strings
            msglines = self.lines[: ex.lineno]
            if ex.offset:
                msglines.append(" " * ex.offset + "^")
            msglines.append("(code was compiled probably from here: %s)" % filename)
            newex = SyntaxError("\n".join(msglines))
            newex.offset = ex.offset
            newex.lineno = ex.lineno
            newex.text = ex.text
            raise newex from ex
        else:
            if flag & ast.PyCF_ONLY_AST:
                assert isinstance(co, ast.AST)
                return co
            assert isinstance(co, CodeType)
            lines = [(x + "\n") for x in self.lines]
            # Type ignored because linecache.cache is private.
            linecache.cache[filename] = (1, None, lines, filename)  # type: ignore
            return co
|
|
|
|
|
|
#
|
|
# public API shortcut functions
|
|
#
|
|
|
|
|
|
@overload
def compile_(
    source: Union[str, bytes, ast.mod, ast.AST],
    filename: Optional[str] = ...,
    mode: str = ...,
    flags: "Literal[0]" = ...,
    dont_inherit: int = ...,
) -> CodeType:
    raise NotImplementedError()


@overload  # noqa: F811
def compile_(  # noqa: F811
    source: Union[str, bytes, ast.mod, ast.AST],
    filename: Optional[str] = ...,
    mode: str = ...,
    flags: int = ...,
    dont_inherit: int = ...,
) -> Union[CodeType, ast.AST]:
    raise NotImplementedError()


def compile_(  # noqa: F811
    source: Union[str, bytes, ast.mod, ast.AST],
    filename: Optional[str] = None,
    mode: str = "exec",
    flags: int = 0,
    dont_inherit: int = 0,
) -> Union[CodeType, ast.AST]:
    """ compile the given source to a raw code object,
        and maintain an internal cache which allows later
        retrieval of the source code for the code object
        and any recursively created code objects.

    An ``ast.AST`` source is handed directly to the builtin ``compile()``
    and requires ``filename``. Any other source is wrapped in ``Source``
    and compiled via ``Source.compile()``, using this function's caller
    as the frame from which an artificial filename is derived.
    """
    if isinstance(source, ast.AST):
        # XXX should Source support having AST?
        assert filename is not None
        co = compile(source, filename, mode, flags, dont_inherit)
        assert isinstance(co, (CodeType, ast.AST))
        return co
    _genframe = sys._getframe(1)  # the caller
    s = Source(source)
    # Forward dont_inherit as well; it used to be dropped on this path.
    return s.compile(filename, mode, flags, dont_inherit, _genframe=_genframe)
|
|
|
|
|
|
#
|
|
# helper functions
|
|
#
|
|
|
|
|
|
def findsource(obj) -> Tuple[Optional[Source], int]:
    """ look up the source containing ``obj`` via ``inspect.findsource``.

    Returns ``(source, lineno)`` where ``source`` holds the right-stripped
    lines reported by ``inspect.findsource`` and ``lineno`` is the line
    number it reported.  If the lookup raises any ``Exception`` the result
    is ``(None, -1)``.
    """
    try:
        raw_lines, first_lineno = inspect.findsource(obj)
    except Exception:
        # best effort: a failed lookup is reported as "no source"
        return None, -1
    found = Source()
    found.lines = list(map(str.rstrip, raw_lines))
    return found, first_lineno
|
|
|
|
|
|
def getrawcode(obj, trycall: bool = True):
    """ return code object for given function.

    Objects with ``__code__`` yield that attribute directly.  Otherwise
    ``f_code`` and then ``__code__`` are followed where present.  If the
    result still lacks ``co_firstlineno``, ``trycall`` is true and the
    object is a non-class with ``__call__``, the code of ``__call__`` is
    tried.  When nothing code-like is found the (possibly unwrapped)
    object itself is returned.
    """
    try:
        return obj.__code__
    except AttributeError:
        candidate = getattr(obj, "f_code", obj)
        candidate = getattr(candidate, "__code__", candidate)
        # Already code-like, or the caller asked us not to look at __call__.
        if not trycall or hasattr(candidate, "co_firstlineno"):
            return candidate
        # Only non-class callables are worth a second look.
        if not hasattr(candidate, "__call__") or inspect.isclass(candidate):
            return candidate
        via_call = getrawcode(candidate.__call__, trycall=False)
        if hasattr(via_call, "co_firstlineno"):
            return via_call
        return candidate
|
|
|
|
|
|
def getsource(obj, **kwargs) -> Source:
    """ return a ``Source`` built from the source text of ``obj``.

    ``obj`` is first reduced with ``getrawcode()``; its text is then read
    with ``inspect.getsource``.  Extra keyword arguments are passed on to
    the ``Source`` constructor.  An ``IndentationError`` during lookup is
    replaced by a placeholder string literal.
    """
    code = getrawcode(obj)
    try:
        text = inspect.getsource(code)
    except IndentationError:
        text = '"Buggy python version consider upgrading, cannot get source"'
    assert isinstance(text, str)
    return Source(text, **kwargs)
|
|
|
|
|
|
def deindent(lines: Sequence[str]) -> List[str]:
    """ return ``lines`` with their common leading whitespace removed.

    The lines are joined with newlines, run through ``textwrap.dedent``
    and split back into a list with ``str.splitlines``.
    """
    joined = "\n".join(lines)
    dedented = textwrap.dedent(joined)
    return dedented.splitlines()
|
|
|
|
|
|
def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[int]]:
    """ return the 0-based (start, end) statement boundaries around ``lineno``.

    ``start`` is the greatest collected statement start that is <= lineno.
    ``end`` is the next collected start after it, or None when there is
    no later one.
    """
    # flatten all statements and except handlers into one lineno-list
    # AST's line numbers start indexing at 1
    values = []  # type: List[int]
    for x in ast.walk(node):
        if isinstance(x, (ast.stmt, ast.ExceptHandler)):
            # On Python 3.8+ the lineno of a decorated def/class points at
            # the def/class line, so decorator lines must be recorded
            # explicitly or they get attributed to the preceding statement.
            if isinstance(x, (ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)):
                for decorator in x.decorator_list:
                    values.append(decorator.lineno - 1)
            values.append(x.lineno - 1)
            for name in ("finalbody", "orelse"):
                val = getattr(x, name, None)  # type: Optional[List[ast.stmt]]
                if val:
                    # treat the finally/orelse part as its own statement
                    values.append(val[0].lineno - 1 - 1)
    values.sort()
    insert_index = bisect_right(values, lineno)
    # NOTE(review): when lineno precedes every collected start, insert_index
    # is 0 and values[-1] wraps to the last entry; with no statements at all
    # this raises IndexError. Left as-is pending a decision on the contract.
    start = values[insert_index - 1]
    if insert_index >= len(values):
        end = None
    else:
        end = values[insert_index]
    return start, end
|
|
|
|
|
|
def getstatementrange_ast(
    lineno: int,
    source: Source,
    assertion: bool = False,
    astnode: Optional[ast.AST] = None,
) -> Tuple[ast.AST, int, int]:
    """ return (astnode, start, end) for the statement containing ``lineno``.

    ``astnode`` is parsed from ``str(source)`` when not supplied.  The range
    comes from ``get_statement_startend2`` and ``end`` is then pulled back
    so it neither spans a differently indented block nor ends on comment or
    blank lines.  ``assertion`` is accepted but not used by this function.
    """
    if astnode is None:
        content = str(source)
        # See #4260:
        # don't produce duplicate warnings when compiling source to find ast
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            astnode = ast.parse(content, "source", "exec")

    start, end = get_statement_startend2(lineno, astnode)
    # we need to correct the end:
    # - ast-parsing strips comments
    # - there might be empty lines
    # - we might have lesser indented code blocks at the end
    if end is None:
        end = len(source.lines)

    if end > start + 1:
        # make sure we don't span differently indented code blocks
        # by using the BlockFinder helper used which inspect.getsource() uses itself
        block_finder = inspect.BlockFinder()
        # if we start with an indented line, put blockfinder to "started" mode
        # (an empty first line has no first character and counts as unindented)
        first_line = source.lines[start]
        block_finder.started = bool(first_line) and first_line[0].isspace()
        it = ((x + "\n") for x in source.lines[start:end])
        try:
            for tok in tokenize.generate_tokens(lambda: next(it)):
                block_finder.tokeneater(*tok)
        except (inspect.EndOfBlock, IndentationError):
            end = block_finder.last + start
        except Exception:
            # best effort: on any other tokenizing problem keep the AST-based end
            pass

    # the end might still point to a comment or empty line, correct it
    while end:
        line = source.lines[end - 1].lstrip()
        if line.startswith("#") or not line:
            end -= 1
        else:
            break
    return astnode, start, end
|