Add type annotations to _pytest._code.code

This commit is contained in:
Ran Benita
2019-11-15 23:02:55 +02:00
parent 562d4811d5
commit eaa34a9df0
4 changed files with 299 additions and 214 deletions

View File

@@ -7,13 +7,17 @@ from inspect import CO_VARKEYWORDS
from io import StringIO
from traceback import format_exception_only
from types import CodeType
from types import FrameType
from types import TracebackType
from typing import Any
from typing import Callable
from typing import Dict
from typing import Generic
from typing import Iterable
from typing import List
from typing import Optional
from typing import Pattern
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import TypeVar
@@ -27,9 +31,16 @@ import py
import _pytest
from _pytest._io.saferepr import safeformat
from _pytest._io.saferepr import saferepr
from _pytest.compat import overload
if False: # TYPE_CHECKING
from typing import Type
from typing_extensions import Literal
from weakref import ReferenceType # noqa: F401
from _pytest._code import Source
_TracebackStyle = Literal["long", "short", "no", "native"]
class Code:
@@ -38,13 +49,12 @@ class Code:
def __init__(self, rawcode) -> None:
if not hasattr(rawcode, "co_filename"):
rawcode = getrawcode(rawcode)
try:
self.filename = rawcode.co_filename
self.firstlineno = rawcode.co_firstlineno - 1
self.name = rawcode.co_name
except AttributeError:
if not isinstance(rawcode, CodeType):
raise TypeError("not a code object: {!r}".format(rawcode))
self.raw = rawcode # type: CodeType
self.filename = rawcode.co_filename
self.firstlineno = rawcode.co_firstlineno - 1
self.name = rawcode.co_name
self.raw = rawcode
def __eq__(self, other):
return self.raw == other.raw
@@ -72,7 +82,7 @@ class Code:
return p
@property
def fullsource(self):
def fullsource(self) -> Optional["Source"]:
""" return a _pytest._code.Source object for the full source file of the code
"""
from _pytest._code import source
@@ -80,7 +90,7 @@ class Code:
full, _ = source.findsource(self.raw)
return full
def source(self):
def source(self) -> "Source":
""" return a _pytest._code.Source object for the code object's source only
"""
# return source only for that part of code
@@ -88,7 +98,7 @@ class Code:
return _pytest._code.Source(self.raw)
def getargs(self, var=False):
def getargs(self, var: bool = False) -> Tuple[str, ...]:
""" return a tuple with the argument names for the code object
if 'var' is set True also return the names of the variable and
@@ -107,7 +117,7 @@ class Frame:
"""Wrapper around a Python frame holding f_locals and f_globals
in which expressions can be evaluated."""
def __init__(self, frame):
def __init__(self, frame: FrameType) -> None:
self.lineno = frame.f_lineno - 1
self.f_globals = frame.f_globals
self.f_locals = frame.f_locals
@@ -115,7 +125,7 @@ class Frame:
self.code = Code(frame.f_code)
@property
def statement(self):
def statement(self) -> "Source":
""" statement this frame is at """
import _pytest._code
@@ -134,7 +144,7 @@ class Frame:
f_locals.update(vars)
return eval(code, self.f_globals, f_locals)
def exec_(self, code, **vars):
def exec_(self, code, **vars) -> None:
""" exec 'code' in the frame
'vars' are optional; additional local variables
@@ -143,7 +153,7 @@ class Frame:
f_locals.update(vars)
exec(code, self.f_globals, f_locals)
def repr(self, object):
def repr(self, object: object) -> str:
""" return a 'safe' (non-recursive, one-line) string repr for 'object'
"""
return saferepr(object)
@@ -151,7 +161,7 @@ class Frame:
def is_true(self, object):
return object
def getargs(self, var=False):
def getargs(self, var: bool = False):
""" return a list of tuples (name, value) for all arguments
if 'var' is set True also include the variable and keyword
@@ -169,35 +179,34 @@ class Frame:
class TracebackEntry:
""" a single entry in a traceback """
_repr_style = None
_repr_style = None # type: Optional[Literal["short", "long"]]
exprinfo = None
def __init__(self, rawentry, excinfo=None):
def __init__(self, rawentry: TracebackType, excinfo=None) -> None:
self._excinfo = excinfo
self._rawentry = rawentry
self.lineno = rawentry.tb_lineno - 1
def set_repr_style(self, mode):
def set_repr_style(self, mode: "Literal['short', 'long']") -> None:
assert mode in ("short", "long")
self._repr_style = mode
@property
def frame(self):
import _pytest._code
return _pytest._code.Frame(self._rawentry.tb_frame)
def frame(self) -> Frame:
return Frame(self._rawentry.tb_frame)
@property
def relline(self):
def relline(self) -> int:
return self.lineno - self.frame.code.firstlineno
def __repr__(self):
def __repr__(self) -> str:
return "<TracebackEntry %s:%d>" % (self.frame.code.path, self.lineno + 1)
@property
def statement(self):
def statement(self) -> "Source":
""" _pytest._code.Source object for the current statement """
source = self.frame.code.fullsource
assert source is not None
return source.getstatement(self.lineno)
@property
@@ -206,14 +215,14 @@ class TracebackEntry:
return self.frame.code.path
@property
def locals(self):
def locals(self) -> Dict[str, Any]:
""" locals of underlying frame """
return self.frame.f_locals
def getfirstlinesource(self):
def getfirstlinesource(self) -> int:
return self.frame.code.firstlineno
def getsource(self, astcache=None):
def getsource(self, astcache=None) -> Optional["Source"]:
""" return failing source code. """
# we use the passed in astcache to not reparse asttrees
# within exception info printing
@@ -258,7 +267,7 @@ class TracebackEntry:
return tbh(None if self._excinfo is None else self._excinfo())
return tbh
def __str__(self):
def __str__(self) -> str:
try:
fn = str(self.path)
except py.error.Error:
@@ -273,31 +282,42 @@ class TracebackEntry:
return " File %r:%d in %s\n %s\n" % (fn, self.lineno + 1, name, line)
@property
def name(self):
def name(self) -> str:
""" co_name of underlying code """
return self.frame.code.raw.co_name
class Traceback(list):
class Traceback(List[TracebackEntry]):
""" Traceback objects encapsulate and offer higher level
access to Traceback entries.
"""
def __init__(self, tb, excinfo=None):
def __init__(
self,
tb: Union[TracebackType, Iterable[TracebackEntry]],
excinfo: Optional["ReferenceType[ExceptionInfo]"] = None,
) -> None:
""" initialize from given python traceback object and ExceptionInfo """
self._excinfo = excinfo
if hasattr(tb, "tb_next"):
if isinstance(tb, TracebackType):
def f(cur):
while cur is not None:
yield TracebackEntry(cur, excinfo=excinfo)
cur = cur.tb_next
def f(cur: TracebackType) -> Iterable[TracebackEntry]:
cur_ = cur # type: Optional[TracebackType]
while cur_ is not None:
yield TracebackEntry(cur_, excinfo=excinfo)
cur_ = cur_.tb_next
list.__init__(self, f(tb))
super().__init__(f(tb))
else:
list.__init__(self, tb)
super().__init__(tb)
def cut(self, path=None, lineno=None, firstlineno=None, excludepath=None):
def cut(
self,
path=None,
lineno: Optional[int] = None,
firstlineno: Optional[int] = None,
excludepath=None,
) -> "Traceback":
""" return a Traceback instance wrapping part of this Traceback
by providing any combination of path, lineno and firstlineno, the
@@ -323,13 +343,25 @@ class Traceback(list):
return Traceback(x._rawentry, self._excinfo)
return self
def __getitem__(self, key):
val = super().__getitem__(key)
if isinstance(key, type(slice(0))):
val = self.__class__(val)
return val
@overload
def __getitem__(self, key: int) -> TracebackEntry:
raise NotImplementedError()
def filter(self, fn=lambda x: not x.ishidden()):
@overload # noqa: F811
def __getitem__(self, key: slice) -> "Traceback": # noqa: F811
raise NotImplementedError()
def __getitem__( # noqa: F811
self, key: Union[int, slice]
) -> Union[TracebackEntry, "Traceback"]:
if isinstance(key, slice):
return self.__class__(super().__getitem__(key))
else:
return super().__getitem__(key)
def filter(
self, fn: Callable[[TracebackEntry], bool] = lambda x: not x.ishidden()
) -> "Traceback":
""" return a Traceback instance with certain items removed
fn is a function that gets a single argument, a TracebackEntry
@@ -341,7 +373,7 @@ class Traceback(list):
"""
return Traceback(filter(fn, self), self._excinfo)
def getcrashentry(self):
def getcrashentry(self) -> TracebackEntry:
""" return last non-hidden traceback entry that lead
to the exception of a traceback.
"""
@@ -351,7 +383,7 @@ class Traceback(list):
return entry
return self[-1]
def recursionindex(self):
def recursionindex(self) -> Optional[int]:
""" return the index of the frame/TracebackEntry where recursion
originates if appropriate, None if no recursion occurred
"""
@@ -541,7 +573,7 @@ class ExceptionInfo(Generic[_E]):
def getrepr(
self,
showlocals: bool = False,
style: str = "long",
style: "_TracebackStyle" = "long",
abspath: bool = False,
tbfilter: bool = True,
funcargs: bool = False,
@@ -619,16 +651,16 @@ class FormattedExcinfo:
flow_marker = ">"
fail_marker = "E"
showlocals = attr.ib(default=False)
style = attr.ib(default="long")
abspath = attr.ib(default=True)
tbfilter = attr.ib(default=True)
funcargs = attr.ib(default=False)
truncate_locals = attr.ib(default=True)
chain = attr.ib(default=True)
showlocals = attr.ib(type=bool, default=False)
style = attr.ib(type="_TracebackStyle", default="long")
abspath = attr.ib(type=bool, default=True)
tbfilter = attr.ib(type=bool, default=True)
funcargs = attr.ib(type=bool, default=False)
truncate_locals = attr.ib(type=bool, default=True)
chain = attr.ib(type=bool, default=True)
astcache = attr.ib(default=attr.Factory(dict), init=False, repr=False)
def _getindent(self, source):
def _getindent(self, source: "Source") -> int:
# figure out indent for given source
try:
s = str(source.getstatement(len(source) - 1))
@@ -643,20 +675,27 @@ class FormattedExcinfo:
return 0
return 4 + (len(s) - len(s.lstrip()))
def _getentrysource(self, entry):
def _getentrysource(self, entry: TracebackEntry) -> Optional["Source"]:
source = entry.getsource(self.astcache)
if source is not None:
source = source.deindent()
return source
def repr_args(self, entry):
def repr_args(self, entry: TracebackEntry) -> Optional["ReprFuncArgs"]:
if self.funcargs:
args = []
for argname, argvalue in entry.frame.getargs(var=True):
args.append((argname, saferepr(argvalue)))
return ReprFuncArgs(args)
return None
def get_source(self, source, line_index=-1, excinfo=None, short=False) -> List[str]:
def get_source(
self,
source: "Source",
line_index: int = -1,
excinfo: Optional[ExceptionInfo] = None,
short: bool = False,
) -> List[str]:
""" return formatted and marked up source lines. """
import _pytest._code
@@ -680,19 +719,21 @@ class FormattedExcinfo:
lines.extend(self.get_exconly(excinfo, indent=indent, markall=True))
return lines
def get_exconly(self, excinfo, indent=4, markall=False):
def get_exconly(
self, excinfo: ExceptionInfo, indent: int = 4, markall: bool = False
) -> List[str]:
lines = []
indent = " " * indent
indentstr = " " * indent
# get the real exception information out
exlines = excinfo.exconly(tryshort=True).split("\n")
failindent = self.fail_marker + indent[1:]
failindent = self.fail_marker + indentstr[1:]
for line in exlines:
lines.append(failindent + line)
if not markall:
failindent = indent
failindent = indentstr
return lines
def repr_locals(self, locals):
def repr_locals(self, locals: Dict[str, object]) -> Optional["ReprLocals"]:
if self.showlocals:
lines = []
keys = [loc for loc in locals if loc[0] != "@"]
@@ -717,8 +758,11 @@ class FormattedExcinfo:
# # XXX
# pprint.pprint(value, stream=self.excinfowriter)
return ReprLocals(lines)
return None
def repr_traceback_entry(self, entry, excinfo=None):
def repr_traceback_entry(
self, entry: TracebackEntry, excinfo: Optional[ExceptionInfo] = None
) -> "ReprEntry":
import _pytest._code
source = self._getentrysource(entry)
@@ -729,9 +773,7 @@ class FormattedExcinfo:
line_index = entry.lineno - entry.getfirstlinesource()
lines = [] # type: List[str]
style = entry._repr_style
if style is None:
style = self.style
style = entry._repr_style if entry._repr_style is not None else self.style
if style in ("short", "long"):
short = style == "short"
reprargs = self.repr_args(entry) if not short else None
@@ -761,7 +803,7 @@ class FormattedExcinfo:
path = np
return path
def repr_traceback(self, excinfo):
def repr_traceback(self, excinfo: ExceptionInfo) -> "ReprTraceback":
traceback = excinfo.traceback
if self.tbfilter:
traceback = traceback.filter()
@@ -779,7 +821,9 @@ class FormattedExcinfo:
entries.append(reprentry)
return ReprTraceback(entries, extraline, style=self.style)
def _truncate_recursive_traceback(self, traceback):
def _truncate_recursive_traceback(
self, traceback: Traceback
) -> Tuple[Traceback, Optional[str]]:
"""
Truncate the given recursive traceback trying to find the starting point
of the recursion.
@@ -806,7 +850,9 @@ class FormattedExcinfo:
max_frames=max_frames,
total=len(traceback),
) # type: Optional[str]
traceback = traceback[:max_frames] + traceback[-max_frames:]
# Type ignored because adding two instaces of a List subtype
# currently incorrectly has type List instead of the subtype.
traceback = traceback[:max_frames] + traceback[-max_frames:] # type: ignore
else:
if recursionindex is not None:
extraline = "!!! Recursion detected (same locals & position)"
@@ -863,7 +909,7 @@ class FormattedExcinfo:
class TerminalRepr:
def __str__(self):
def __str__(self) -> str:
# FYI this is called from pytest-xdist's serialization of exception
# information.
io = StringIO()
@@ -871,7 +917,7 @@ class TerminalRepr:
self.toterminal(tw)
return io.getvalue().strip()
def __repr__(self):
def __repr__(self) -> str:
return "<{} instance at {:0x}>".format(self.__class__, id(self))
def toterminal(self, tw) -> None:
@@ -882,7 +928,7 @@ class ExceptionRepr(TerminalRepr):
def __init__(self) -> None:
self.sections = [] # type: List[Tuple[str, str, str]]
def addsection(self, name, content, sep="-"):
def addsection(self, name: str, content: str, sep: str = "-") -> None:
self.sections.append((name, content, sep))
def toterminal(self, tw) -> None:
@@ -892,7 +938,12 @@ class ExceptionRepr(TerminalRepr):
class ExceptionChainRepr(ExceptionRepr):
def __init__(self, chain):
def __init__(
self,
chain: Sequence[
Tuple["ReprTraceback", Optional["ReprFileLocation"], Optional[str]]
],
) -> None:
super().__init__()
self.chain = chain
# reprcrash and reprtraceback of the outermost (the newest) exception
@@ -910,7 +961,9 @@ class ExceptionChainRepr(ExceptionRepr):
class ReprExceptionInfo(ExceptionRepr):
def __init__(self, reprtraceback, reprcrash):
def __init__(
self, reprtraceback: "ReprTraceback", reprcrash: "ReprFileLocation"
) -> None:
super().__init__()
self.reprtraceback = reprtraceback
self.reprcrash = reprcrash
@@ -923,7 +976,12 @@ class ReprExceptionInfo(ExceptionRepr):
class ReprTraceback(TerminalRepr):
entrysep = "_ "
def __init__(self, reprentries, extraline, style):
def __init__(
self,
reprentries: Sequence[Union["ReprEntry", "ReprEntryNative"]],
extraline: Optional[str],
style: "_TracebackStyle",
) -> None:
self.reprentries = reprentries
self.extraline = extraline
self.style = style
@@ -948,16 +1006,16 @@ class ReprTraceback(TerminalRepr):
class ReprTracebackNative(ReprTraceback):
def __init__(self, tblines):
def __init__(self, tblines: Sequence[str]) -> None:
self.style = "native"
self.reprentries = [ReprEntryNative(tblines)]
self.extraline = None
class ReprEntryNative(TerminalRepr):
style = "native"
style = "native" # type: _TracebackStyle
def __init__(self, tblines):
def __init__(self, tblines: Sequence[str]) -> None:
self.lines = tblines
def toterminal(self, tw) -> None:
@@ -965,7 +1023,14 @@ class ReprEntryNative(TerminalRepr):
class ReprEntry(TerminalRepr):
def __init__(self, lines, reprfuncargs, reprlocals, filelocrepr, style):
def __init__(
self,
lines: Sequence[str],
reprfuncargs: Optional["ReprFuncArgs"],
reprlocals: Optional["ReprLocals"],
filelocrepr: Optional["ReprFileLocation"],
style: "_TracebackStyle",
) -> None:
self.lines = lines
self.reprfuncargs = reprfuncargs
self.reprlocals = reprlocals
@@ -974,6 +1039,7 @@ class ReprEntry(TerminalRepr):
def toterminal(self, tw) -> None:
if self.style == "short":
assert self.reprfileloc is not None
self.reprfileloc.toterminal(tw)
for line in self.lines:
red = line.startswith("E ")
@@ -992,14 +1058,14 @@ class ReprEntry(TerminalRepr):
tw.line("")
self.reprfileloc.toterminal(tw)
def __str__(self):
def __str__(self) -> str:
return "{}\n{}\n{}".format(
"\n".join(self.lines), self.reprlocals, self.reprfileloc
)
class ReprFileLocation(TerminalRepr):
def __init__(self, path, lineno, message):
def __init__(self, path, lineno: int, message: str) -> None:
self.path = str(path)
self.lineno = lineno
self.message = message
@@ -1016,7 +1082,7 @@ class ReprFileLocation(TerminalRepr):
class ReprLocals(TerminalRepr):
def __init__(self, lines):
def __init__(self, lines: Sequence[str]) -> None:
self.lines = lines
def toterminal(self, tw) -> None:
@@ -1025,7 +1091,7 @@ class ReprLocals(TerminalRepr):
class ReprFuncArgs(TerminalRepr):
def __init__(self, args):
def __init__(self, args: Sequence[Tuple[str, object]]) -> None:
self.args = args
def toterminal(self, tw) -> None:
@@ -1047,7 +1113,7 @@ class ReprFuncArgs(TerminalRepr):
tw.line("")
def getrawcode(obj, trycall=True):
def getrawcode(obj, trycall: bool = True):
""" return code object for given function. """
try:
return obj.__code__
@@ -1075,7 +1141,7 @@ _PYTEST_DIR = py.path.local(_pytest.__file__).dirpath()
_PY_DIR = py.path.local(py.__file__).dirpath()
def filter_traceback(entry):
def filter_traceback(entry: TracebackEntry) -> bool:
"""Return True if a TracebackEntry instance should be removed from tracebacks:
* dynamically generated code (no code to show up for it);
* internal traceback from pytest or its internal libraries, py and pluggy.