Merge pull request #6205 from bluetech/type-annotations-8

Add type annotations to _pytest.compat and _pytest._code.code
This commit is contained in:
Ran Benita 2019-11-17 09:45:32 +02:00 committed by GitHub
commit fa578d7329
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 380 additions and 325 deletions

View File

@ -7,13 +7,17 @@ from inspect import CO_VARKEYWORDS
from io import StringIO from io import StringIO
from traceback import format_exception_only from traceback import format_exception_only
from types import CodeType from types import CodeType
from types import FrameType
from types import TracebackType from types import TracebackType
from typing import Any from typing import Any
from typing import Callable
from typing import Dict from typing import Dict
from typing import Generic from typing import Generic
from typing import Iterable
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Pattern from typing import Pattern
from typing import Sequence
from typing import Set from typing import Set
from typing import Tuple from typing import Tuple
from typing import TypeVar from typing import TypeVar
@ -27,9 +31,16 @@ import py
import _pytest import _pytest
from _pytest._io.saferepr import safeformat from _pytest._io.saferepr import safeformat
from _pytest._io.saferepr import saferepr from _pytest._io.saferepr import saferepr
from _pytest.compat import overload
if False: # TYPE_CHECKING if False: # TYPE_CHECKING
from typing import Type from typing import Type
from typing_extensions import Literal
from weakref import ReferenceType # noqa: F401
from _pytest._code import Source
_TracebackStyle = Literal["long", "short", "no", "native"]
class Code: class Code:
@ -38,13 +49,12 @@ class Code:
def __init__(self, rawcode) -> None: def __init__(self, rawcode) -> None:
if not hasattr(rawcode, "co_filename"): if not hasattr(rawcode, "co_filename"):
rawcode = getrawcode(rawcode) rawcode = getrawcode(rawcode)
try: if not isinstance(rawcode, CodeType):
self.filename = rawcode.co_filename
self.firstlineno = rawcode.co_firstlineno - 1
self.name = rawcode.co_name
except AttributeError:
raise TypeError("not a code object: {!r}".format(rawcode)) raise TypeError("not a code object: {!r}".format(rawcode))
self.raw = rawcode # type: CodeType self.filename = rawcode.co_filename
self.firstlineno = rawcode.co_firstlineno - 1
self.name = rawcode.co_name
self.raw = rawcode
def __eq__(self, other): def __eq__(self, other):
return self.raw == other.raw return self.raw == other.raw
@ -72,7 +82,7 @@ class Code:
return p return p
@property @property
def fullsource(self): def fullsource(self) -> Optional["Source"]:
""" return a _pytest._code.Source object for the full source file of the code """ return a _pytest._code.Source object for the full source file of the code
""" """
from _pytest._code import source from _pytest._code import source
@ -80,7 +90,7 @@ class Code:
full, _ = source.findsource(self.raw) full, _ = source.findsource(self.raw)
return full return full
def source(self): def source(self) -> "Source":
""" return a _pytest._code.Source object for the code object's source only """ return a _pytest._code.Source object for the code object's source only
""" """
# return source only for that part of code # return source only for that part of code
@ -88,7 +98,7 @@ class Code:
return _pytest._code.Source(self.raw) return _pytest._code.Source(self.raw)
def getargs(self, var=False): def getargs(self, var: bool = False) -> Tuple[str, ...]:
""" return a tuple with the argument names for the code object """ return a tuple with the argument names for the code object
if 'var' is set True also return the names of the variable and if 'var' is set True also return the names of the variable and
@ -107,7 +117,7 @@ class Frame:
"""Wrapper around a Python frame holding f_locals and f_globals """Wrapper around a Python frame holding f_locals and f_globals
in which expressions can be evaluated.""" in which expressions can be evaluated."""
def __init__(self, frame): def __init__(self, frame: FrameType) -> None:
self.lineno = frame.f_lineno - 1 self.lineno = frame.f_lineno - 1
self.f_globals = frame.f_globals self.f_globals = frame.f_globals
self.f_locals = frame.f_locals self.f_locals = frame.f_locals
@ -115,7 +125,7 @@ class Frame:
self.code = Code(frame.f_code) self.code = Code(frame.f_code)
@property @property
def statement(self): def statement(self) -> "Source":
""" statement this frame is at """ """ statement this frame is at """
import _pytest._code import _pytest._code
@ -134,7 +144,7 @@ class Frame:
f_locals.update(vars) f_locals.update(vars)
return eval(code, self.f_globals, f_locals) return eval(code, self.f_globals, f_locals)
def exec_(self, code, **vars): def exec_(self, code, **vars) -> None:
""" exec 'code' in the frame """ exec 'code' in the frame
'vars' are optional; additional local variables 'vars' are optional; additional local variables
@ -143,7 +153,7 @@ class Frame:
f_locals.update(vars) f_locals.update(vars)
exec(code, self.f_globals, f_locals) exec(code, self.f_globals, f_locals)
def repr(self, object): def repr(self, object: object) -> str:
""" return a 'safe' (non-recursive, one-line) string repr for 'object' """ return a 'safe' (non-recursive, one-line) string repr for 'object'
""" """
return saferepr(object) return saferepr(object)
@ -151,7 +161,7 @@ class Frame:
def is_true(self, object): def is_true(self, object):
return object return object
def getargs(self, var=False): def getargs(self, var: bool = False):
""" return a list of tuples (name, value) for all arguments """ return a list of tuples (name, value) for all arguments
if 'var' is set True also include the variable and keyword if 'var' is set True also include the variable and keyword
@ -169,35 +179,34 @@ class Frame:
class TracebackEntry: class TracebackEntry:
""" a single entry in a traceback """ """ a single entry in a traceback """
_repr_style = None _repr_style = None # type: Optional[Literal["short", "long"]]
exprinfo = None exprinfo = None
def __init__(self, rawentry, excinfo=None): def __init__(self, rawentry: TracebackType, excinfo=None) -> None:
self._excinfo = excinfo self._excinfo = excinfo
self._rawentry = rawentry self._rawentry = rawentry
self.lineno = rawentry.tb_lineno - 1 self.lineno = rawentry.tb_lineno - 1
def set_repr_style(self, mode): def set_repr_style(self, mode: "Literal['short', 'long']") -> None:
assert mode in ("short", "long") assert mode in ("short", "long")
self._repr_style = mode self._repr_style = mode
@property @property
def frame(self): def frame(self) -> Frame:
import _pytest._code return Frame(self._rawentry.tb_frame)
return _pytest._code.Frame(self._rawentry.tb_frame)
@property @property
def relline(self): def relline(self) -> int:
return self.lineno - self.frame.code.firstlineno return self.lineno - self.frame.code.firstlineno
def __repr__(self): def __repr__(self) -> str:
return "<TracebackEntry %s:%d>" % (self.frame.code.path, self.lineno + 1) return "<TracebackEntry %s:%d>" % (self.frame.code.path, self.lineno + 1)
@property @property
def statement(self): def statement(self) -> "Source":
""" _pytest._code.Source object for the current statement """ """ _pytest._code.Source object for the current statement """
source = self.frame.code.fullsource source = self.frame.code.fullsource
assert source is not None
return source.getstatement(self.lineno) return source.getstatement(self.lineno)
@property @property
@ -206,14 +215,14 @@ class TracebackEntry:
return self.frame.code.path return self.frame.code.path
@property @property
def locals(self): def locals(self) -> Dict[str, Any]:
""" locals of underlying frame """ """ locals of underlying frame """
return self.frame.f_locals return self.frame.f_locals
def getfirstlinesource(self): def getfirstlinesource(self) -> int:
return self.frame.code.firstlineno return self.frame.code.firstlineno
def getsource(self, astcache=None): def getsource(self, astcache=None) -> Optional["Source"]:
""" return failing source code. """ """ return failing source code. """
# we use the passed in astcache to not reparse asttrees # we use the passed in astcache to not reparse asttrees
# within exception info printing # within exception info printing
@ -258,7 +267,7 @@ class TracebackEntry:
return tbh(None if self._excinfo is None else self._excinfo()) return tbh(None if self._excinfo is None else self._excinfo())
return tbh return tbh
def __str__(self): def __str__(self) -> str:
try: try:
fn = str(self.path) fn = str(self.path)
except py.error.Error: except py.error.Error:
@ -273,33 +282,42 @@ class TracebackEntry:
return " File %r:%d in %s\n %s\n" % (fn, self.lineno + 1, name, line) return " File %r:%d in %s\n %s\n" % (fn, self.lineno + 1, name, line)
@property @property
def name(self): def name(self) -> str:
""" co_name of underlying code """ """ co_name of underlying code """
return self.frame.code.raw.co_name return self.frame.code.raw.co_name
class Traceback(list): class Traceback(List[TracebackEntry]):
""" Traceback objects encapsulate and offer higher level """ Traceback objects encapsulate and offer higher level
access to Traceback entries. access to Traceback entries.
""" """
Entry = TracebackEntry def __init__(
self,
def __init__(self, tb, excinfo=None): tb: Union[TracebackType, Iterable[TracebackEntry]],
excinfo: Optional["ReferenceType[ExceptionInfo]"] = None,
) -> None:
""" initialize from given python traceback object and ExceptionInfo """ """ initialize from given python traceback object and ExceptionInfo """
self._excinfo = excinfo self._excinfo = excinfo
if hasattr(tb, "tb_next"): if isinstance(tb, TracebackType):
def f(cur): def f(cur: TracebackType) -> Iterable[TracebackEntry]:
while cur is not None: cur_ = cur # type: Optional[TracebackType]
yield self.Entry(cur, excinfo=excinfo) while cur_ is not None:
cur = cur.tb_next yield TracebackEntry(cur_, excinfo=excinfo)
cur_ = cur_.tb_next
list.__init__(self, f(tb)) super().__init__(f(tb))
else: else:
list.__init__(self, tb) super().__init__(tb)
def cut(self, path=None, lineno=None, firstlineno=None, excludepath=None): def cut(
self,
path=None,
lineno: Optional[int] = None,
firstlineno: Optional[int] = None,
excludepath=None,
) -> "Traceback":
""" return a Traceback instance wrapping part of this Traceback """ return a Traceback instance wrapping part of this Traceback
by providing any combination of path, lineno and firstlineno, the by providing any combination of path, lineno and firstlineno, the
@ -325,13 +343,25 @@ class Traceback(list):
return Traceback(x._rawentry, self._excinfo) return Traceback(x._rawentry, self._excinfo)
return self return self
def __getitem__(self, key): @overload
val = super().__getitem__(key) def __getitem__(self, key: int) -> TracebackEntry:
if isinstance(key, type(slice(0))): raise NotImplementedError()
val = self.__class__(val)
return val
def filter(self, fn=lambda x: not x.ishidden()): @overload # noqa: F811
def __getitem__(self, key: slice) -> "Traceback": # noqa: F811
raise NotImplementedError()
def __getitem__( # noqa: F811
self, key: Union[int, slice]
) -> Union[TracebackEntry, "Traceback"]:
if isinstance(key, slice):
return self.__class__(super().__getitem__(key))
else:
return super().__getitem__(key)
def filter(
self, fn: Callable[[TracebackEntry], bool] = lambda x: not x.ishidden()
) -> "Traceback":
""" return a Traceback instance with certain items removed """ return a Traceback instance with certain items removed
fn is a function that gets a single argument, a TracebackEntry fn is a function that gets a single argument, a TracebackEntry
@ -343,7 +373,7 @@ class Traceback(list):
""" """
return Traceback(filter(fn, self), self._excinfo) return Traceback(filter(fn, self), self._excinfo)
def getcrashentry(self): def getcrashentry(self) -> TracebackEntry:
""" return last non-hidden traceback entry that lead """ return last non-hidden traceback entry that lead
to the exception of a traceback. to the exception of a traceback.
""" """
@ -353,7 +383,7 @@ class Traceback(list):
return entry return entry
return self[-1] return self[-1]
def recursionindex(self): def recursionindex(self) -> Optional[int]:
""" return the index of the frame/TracebackEntry where recursion """ return the index of the frame/TracebackEntry where recursion
originates if appropriate, None if no recursion occurred originates if appropriate, None if no recursion occurred
""" """
@ -543,7 +573,7 @@ class ExceptionInfo(Generic[_E]):
def getrepr( def getrepr(
self, self,
showlocals: bool = False, showlocals: bool = False,
style: str = "long", style: "_TracebackStyle" = "long",
abspath: bool = False, abspath: bool = False,
tbfilter: bool = True, tbfilter: bool = True,
funcargs: bool = False, funcargs: bool = False,
@ -621,16 +651,16 @@ class FormattedExcinfo:
flow_marker = ">" flow_marker = ">"
fail_marker = "E" fail_marker = "E"
showlocals = attr.ib(default=False) showlocals = attr.ib(type=bool, default=False)
style = attr.ib(default="long") style = attr.ib(type="_TracebackStyle", default="long")
abspath = attr.ib(default=True) abspath = attr.ib(type=bool, default=True)
tbfilter = attr.ib(default=True) tbfilter = attr.ib(type=bool, default=True)
funcargs = attr.ib(default=False) funcargs = attr.ib(type=bool, default=False)
truncate_locals = attr.ib(default=True) truncate_locals = attr.ib(type=bool, default=True)
chain = attr.ib(default=True) chain = attr.ib(type=bool, default=True)
astcache = attr.ib(default=attr.Factory(dict), init=False, repr=False) astcache = attr.ib(default=attr.Factory(dict), init=False, repr=False)
def _getindent(self, source): def _getindent(self, source: "Source") -> int:
# figure out indent for given source # figure out indent for given source
try: try:
s = str(source.getstatement(len(source) - 1)) s = str(source.getstatement(len(source) - 1))
@ -645,20 +675,27 @@ class FormattedExcinfo:
return 0 return 0
return 4 + (len(s) - len(s.lstrip())) return 4 + (len(s) - len(s.lstrip()))
def _getentrysource(self, entry): def _getentrysource(self, entry: TracebackEntry) -> Optional["Source"]:
source = entry.getsource(self.astcache) source = entry.getsource(self.astcache)
if source is not None: if source is not None:
source = source.deindent() source = source.deindent()
return source return source
def repr_args(self, entry): def repr_args(self, entry: TracebackEntry) -> Optional["ReprFuncArgs"]:
if self.funcargs: if self.funcargs:
args = [] args = []
for argname, argvalue in entry.frame.getargs(var=True): for argname, argvalue in entry.frame.getargs(var=True):
args.append((argname, saferepr(argvalue))) args.append((argname, saferepr(argvalue)))
return ReprFuncArgs(args) return ReprFuncArgs(args)
return None
def get_source(self, source, line_index=-1, excinfo=None, short=False) -> List[str]: def get_source(
self,
source: "Source",
line_index: int = -1,
excinfo: Optional[ExceptionInfo] = None,
short: bool = False,
) -> List[str]:
""" return formatted and marked up source lines. """ """ return formatted and marked up source lines. """
import _pytest._code import _pytest._code
@ -682,19 +719,21 @@ class FormattedExcinfo:
lines.extend(self.get_exconly(excinfo, indent=indent, markall=True)) lines.extend(self.get_exconly(excinfo, indent=indent, markall=True))
return lines return lines
def get_exconly(self, excinfo, indent=4, markall=False): def get_exconly(
self, excinfo: ExceptionInfo, indent: int = 4, markall: bool = False
) -> List[str]:
lines = [] lines = []
indent = " " * indent indentstr = " " * indent
# get the real exception information out # get the real exception information out
exlines = excinfo.exconly(tryshort=True).split("\n") exlines = excinfo.exconly(tryshort=True).split("\n")
failindent = self.fail_marker + indent[1:] failindent = self.fail_marker + indentstr[1:]
for line in exlines: for line in exlines:
lines.append(failindent + line) lines.append(failindent + line)
if not markall: if not markall:
failindent = indent failindent = indentstr
return lines return lines
def repr_locals(self, locals): def repr_locals(self, locals: Dict[str, object]) -> Optional["ReprLocals"]:
if self.showlocals: if self.showlocals:
lines = [] lines = []
keys = [loc for loc in locals if loc[0] != "@"] keys = [loc for loc in locals if loc[0] != "@"]
@ -719,8 +758,11 @@ class FormattedExcinfo:
# # XXX # # XXX
# pprint.pprint(value, stream=self.excinfowriter) # pprint.pprint(value, stream=self.excinfowriter)
return ReprLocals(lines) return ReprLocals(lines)
return None
def repr_traceback_entry(self, entry, excinfo=None): def repr_traceback_entry(
self, entry: TracebackEntry, excinfo: Optional[ExceptionInfo] = None
) -> "ReprEntry":
import _pytest._code import _pytest._code
source = self._getentrysource(entry) source = self._getentrysource(entry)
@ -731,9 +773,7 @@ class FormattedExcinfo:
line_index = entry.lineno - entry.getfirstlinesource() line_index = entry.lineno - entry.getfirstlinesource()
lines = [] # type: List[str] lines = [] # type: List[str]
style = entry._repr_style style = entry._repr_style if entry._repr_style is not None else self.style
if style is None:
style = self.style
if style in ("short", "long"): if style in ("short", "long"):
short = style == "short" short = style == "short"
reprargs = self.repr_args(entry) if not short else None reprargs = self.repr_args(entry) if not short else None
@ -763,7 +803,7 @@ class FormattedExcinfo:
path = np path = np
return path return path
def repr_traceback(self, excinfo): def repr_traceback(self, excinfo: ExceptionInfo) -> "ReprTraceback":
traceback = excinfo.traceback traceback = excinfo.traceback
if self.tbfilter: if self.tbfilter:
traceback = traceback.filter() traceback = traceback.filter()
@ -781,7 +821,9 @@ class FormattedExcinfo:
entries.append(reprentry) entries.append(reprentry)
return ReprTraceback(entries, extraline, style=self.style) return ReprTraceback(entries, extraline, style=self.style)
def _truncate_recursive_traceback(self, traceback): def _truncate_recursive_traceback(
self, traceback: Traceback
) -> Tuple[Traceback, Optional[str]]:
""" """
Truncate the given recursive traceback trying to find the starting point Truncate the given recursive traceback trying to find the starting point
of the recursion. of the recursion.
@ -808,7 +850,9 @@ class FormattedExcinfo:
max_frames=max_frames, max_frames=max_frames,
total=len(traceback), total=len(traceback),
) # type: Optional[str] ) # type: Optional[str]
traceback = traceback[:max_frames] + traceback[-max_frames:] # Type ignored because adding two instaces of a List subtype
# currently incorrectly has type List instead of the subtype.
traceback = traceback[:max_frames] + traceback[-max_frames:] # type: ignore
else: else:
if recursionindex is not None: if recursionindex is not None:
extraline = "!!! Recursion detected (same locals & position)" extraline = "!!! Recursion detected (same locals & position)"
@ -865,7 +909,7 @@ class FormattedExcinfo:
class TerminalRepr: class TerminalRepr:
def __str__(self): def __str__(self) -> str:
# FYI this is called from pytest-xdist's serialization of exception # FYI this is called from pytest-xdist's serialization of exception
# information. # information.
io = StringIO() io = StringIO()
@ -873,7 +917,7 @@ class TerminalRepr:
self.toterminal(tw) self.toterminal(tw)
return io.getvalue().strip() return io.getvalue().strip()
def __repr__(self): def __repr__(self) -> str:
return "<{} instance at {:0x}>".format(self.__class__, id(self)) return "<{} instance at {:0x}>".format(self.__class__, id(self))
def toterminal(self, tw) -> None: def toterminal(self, tw) -> None:
@ -884,7 +928,7 @@ class ExceptionRepr(TerminalRepr):
def __init__(self) -> None: def __init__(self) -> None:
self.sections = [] # type: List[Tuple[str, str, str]] self.sections = [] # type: List[Tuple[str, str, str]]
def addsection(self, name, content, sep="-"): def addsection(self, name: str, content: str, sep: str = "-") -> None:
self.sections.append((name, content, sep)) self.sections.append((name, content, sep))
def toterminal(self, tw) -> None: def toterminal(self, tw) -> None:
@ -894,7 +938,12 @@ class ExceptionRepr(TerminalRepr):
class ExceptionChainRepr(ExceptionRepr): class ExceptionChainRepr(ExceptionRepr):
def __init__(self, chain): def __init__(
self,
chain: Sequence[
Tuple["ReprTraceback", Optional["ReprFileLocation"], Optional[str]]
],
) -> None:
super().__init__() super().__init__()
self.chain = chain self.chain = chain
# reprcrash and reprtraceback of the outermost (the newest) exception # reprcrash and reprtraceback of the outermost (the newest) exception
@ -912,7 +961,9 @@ class ExceptionChainRepr(ExceptionRepr):
class ReprExceptionInfo(ExceptionRepr): class ReprExceptionInfo(ExceptionRepr):
def __init__(self, reprtraceback, reprcrash): def __init__(
self, reprtraceback: "ReprTraceback", reprcrash: "ReprFileLocation"
) -> None:
super().__init__() super().__init__()
self.reprtraceback = reprtraceback self.reprtraceback = reprtraceback
self.reprcrash = reprcrash self.reprcrash = reprcrash
@ -925,7 +976,12 @@ class ReprExceptionInfo(ExceptionRepr):
class ReprTraceback(TerminalRepr): class ReprTraceback(TerminalRepr):
entrysep = "_ " entrysep = "_ "
def __init__(self, reprentries, extraline, style): def __init__(
self,
reprentries: Sequence[Union["ReprEntry", "ReprEntryNative"]],
extraline: Optional[str],
style: "_TracebackStyle",
) -> None:
self.reprentries = reprentries self.reprentries = reprentries
self.extraline = extraline self.extraline = extraline
self.style = style self.style = style
@ -950,16 +1006,16 @@ class ReprTraceback(TerminalRepr):
class ReprTracebackNative(ReprTraceback): class ReprTracebackNative(ReprTraceback):
def __init__(self, tblines): def __init__(self, tblines: Sequence[str]) -> None:
self.style = "native" self.style = "native"
self.reprentries = [ReprEntryNative(tblines)] self.reprentries = [ReprEntryNative(tblines)]
self.extraline = None self.extraline = None
class ReprEntryNative(TerminalRepr): class ReprEntryNative(TerminalRepr):
style = "native" style = "native" # type: _TracebackStyle
def __init__(self, tblines): def __init__(self, tblines: Sequence[str]) -> None:
self.lines = tblines self.lines = tblines
def toterminal(self, tw) -> None: def toterminal(self, tw) -> None:
@ -967,7 +1023,14 @@ class ReprEntryNative(TerminalRepr):
class ReprEntry(TerminalRepr): class ReprEntry(TerminalRepr):
def __init__(self, lines, reprfuncargs, reprlocals, filelocrepr, style): def __init__(
self,
lines: Sequence[str],
reprfuncargs: Optional["ReprFuncArgs"],
reprlocals: Optional["ReprLocals"],
filelocrepr: Optional["ReprFileLocation"],
style: "_TracebackStyle",
) -> None:
self.lines = lines self.lines = lines
self.reprfuncargs = reprfuncargs self.reprfuncargs = reprfuncargs
self.reprlocals = reprlocals self.reprlocals = reprlocals
@ -976,6 +1039,7 @@ class ReprEntry(TerminalRepr):
def toterminal(self, tw) -> None: def toterminal(self, tw) -> None:
if self.style == "short": if self.style == "short":
assert self.reprfileloc is not None
self.reprfileloc.toterminal(tw) self.reprfileloc.toterminal(tw)
for line in self.lines: for line in self.lines:
red = line.startswith("E ") red = line.startswith("E ")
@ -994,14 +1058,14 @@ class ReprEntry(TerminalRepr):
tw.line("") tw.line("")
self.reprfileloc.toterminal(tw) self.reprfileloc.toterminal(tw)
def __str__(self): def __str__(self) -> str:
return "{}\n{}\n{}".format( return "{}\n{}\n{}".format(
"\n".join(self.lines), self.reprlocals, self.reprfileloc "\n".join(self.lines), self.reprlocals, self.reprfileloc
) )
class ReprFileLocation(TerminalRepr): class ReprFileLocation(TerminalRepr):
def __init__(self, path, lineno, message): def __init__(self, path, lineno: int, message: str) -> None:
self.path = str(path) self.path = str(path)
self.lineno = lineno self.lineno = lineno
self.message = message self.message = message
@ -1018,7 +1082,7 @@ class ReprFileLocation(TerminalRepr):
class ReprLocals(TerminalRepr): class ReprLocals(TerminalRepr):
def __init__(self, lines): def __init__(self, lines: Sequence[str]) -> None:
self.lines = lines self.lines = lines
def toterminal(self, tw) -> None: def toterminal(self, tw) -> None:
@ -1027,7 +1091,7 @@ class ReprLocals(TerminalRepr):
class ReprFuncArgs(TerminalRepr): class ReprFuncArgs(TerminalRepr):
def __init__(self, args): def __init__(self, args: Sequence[Tuple[str, object]]) -> None:
self.args = args self.args = args
def toterminal(self, tw) -> None: def toterminal(self, tw) -> None:
@ -1049,13 +1113,11 @@ class ReprFuncArgs(TerminalRepr):
tw.line("") tw.line("")
def getrawcode(obj, trycall=True): def getrawcode(obj, trycall: bool = True):
""" return code object for given function. """ """ return code object for given function. """
try: try:
return obj.__code__ return obj.__code__
except AttributeError: except AttributeError:
obj = getattr(obj, "im_func", obj)
obj = getattr(obj, "func_code", obj)
obj = getattr(obj, "f_code", obj) obj = getattr(obj, "f_code", obj)
obj = getattr(obj, "__code__", obj) obj = getattr(obj, "__code__", obj)
if trycall and not hasattr(obj, "co_firstlineno"): if trycall and not hasattr(obj, "co_firstlineno"):
@ -1079,7 +1141,7 @@ _PYTEST_DIR = py.path.local(_pytest.__file__).dirpath()
_PY_DIR = py.path.local(py.__file__).dirpath() _PY_DIR = py.path.local(py.__file__).dirpath()
def filter_traceback(entry): def filter_traceback(entry: TracebackEntry) -> bool:
"""Return True if a TracebackEntry instance should be removed from tracebacks: """Return True if a TracebackEntry instance should be removed from tracebacks:
* dynamically generated code (no code to show up for it); * dynamically generated code (no code to show up for it);
* internal traceback from pytest or its internal libraries, py and pluggy. * internal traceback from pytest or its internal libraries, py and pluggy.

View File

@ -8,6 +8,7 @@ import warnings
from ast import PyCF_ONLY_AST as _AST_FLAG from ast import PyCF_ONLY_AST as _AST_FLAG
from bisect import bisect_right from bisect import bisect_right
from types import FrameType from types import FrameType
from typing import Iterator
from typing import List from typing import List
from typing import Optional from typing import Optional
from typing import Sequence from typing import Sequence
@ -60,7 +61,7 @@ class Source:
raise NotImplementedError() raise NotImplementedError()
@overload # noqa: F811 @overload # noqa: F811
def __getitem__(self, key: slice) -> "Source": def __getitem__(self, key: slice) -> "Source": # noqa: F811
raise NotImplementedError() raise NotImplementedError()
def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]: # noqa: F811 def __getitem__(self, key: Union[int, slice]) -> Union[str, "Source"]: # noqa: F811
@ -73,6 +74,9 @@ class Source:
newsource.lines = self.lines[key.start : key.stop] newsource.lines = self.lines[key.start : key.stop]
return newsource return newsource
def __iter__(self) -> Iterator[str]:
return iter(self.lines)
def __len__(self) -> int: def __len__(self) -> int:
return len(self.lines) return len(self.lines)

View File

@ -1074,13 +1074,14 @@ def try_makedirs(cache_dir) -> bool:
def get_cache_dir(file_path: Path) -> Path: def get_cache_dir(file_path: Path) -> Path:
"""Returns the cache directory to write .pyc files for the given .py file path""" """Returns the cache directory to write .pyc files for the given .py file path"""
if sys.version_info >= (3, 8) and sys.pycache_prefix: # Type ignored until added in next mypy release.
if sys.version_info >= (3, 8) and sys.pycache_prefix: # type: ignore
# given: # given:
# prefix = '/tmp/pycs' # prefix = '/tmp/pycs'
# path = '/home/user/proj/test_app.py' # path = '/home/user/proj/test_app.py'
# we want: # we want:
# '/tmp/pycs/home/user/proj' # '/tmp/pycs/home/user/proj'
return Path(sys.pycache_prefix) / Path(*file_path.parts[1:-1]) return Path(sys.pycache_prefix) / Path(*file_path.parts[1:-1]) # type: ignore
else: else:
# classic pycache directory # classic pycache directory
return file_path.parent / "__pycache__" return file_path.parent / "__pycache__"

View File

@ -10,11 +10,14 @@ import sys
from contextlib import contextmanager from contextlib import contextmanager
from inspect import Parameter from inspect import Parameter
from inspect import signature from inspect import signature
from typing import Any
from typing import Callable from typing import Callable
from typing import Generic from typing import Generic
from typing import Optional from typing import Optional
from typing import overload from typing import overload
from typing import Tuple
from typing import TypeVar from typing import TypeVar
from typing import Union
import attr import attr
import py import py
@ -40,12 +43,13 @@ MODULE_NOT_FOUND_ERROR = (
if sys.version_info >= (3, 8): if sys.version_info >= (3, 8):
from importlib import metadata as importlib_metadata # noqa: F401 # Type ignored until next mypy release.
from importlib import metadata as importlib_metadata # type: ignore
else: else:
import importlib_metadata # noqa: F401 import importlib_metadata # noqa: F401
def _format_args(func): def _format_args(func: Callable[..., Any]) -> str:
return str(signature(func)) return str(signature(func))
@ -66,12 +70,12 @@ else:
fspath = os.fspath fspath = os.fspath
def is_generator(func): def is_generator(func: object) -> bool:
genfunc = inspect.isgeneratorfunction(func) genfunc = inspect.isgeneratorfunction(func)
return genfunc and not iscoroutinefunction(func) return genfunc and not iscoroutinefunction(func)
def iscoroutinefunction(func): def iscoroutinefunction(func: object) -> bool:
""" """
Return True if func is a coroutine function (a function defined with async Return True if func is a coroutine function (a function defined with async
def syntax, and doesn't contain yield), or a function decorated with def syntax, and doesn't contain yield), or a function decorated with
@ -84,7 +88,7 @@ def iscoroutinefunction(func):
return inspect.iscoroutinefunction(func) or getattr(func, "_is_coroutine", False) return inspect.iscoroutinefunction(func) or getattr(func, "_is_coroutine", False)
def getlocation(function, curdir=None): def getlocation(function, curdir=None) -> str:
function = get_real_func(function) function = get_real_func(function)
fn = py.path.local(inspect.getfile(function)) fn = py.path.local(inspect.getfile(function))
lineno = function.__code__.co_firstlineno lineno = function.__code__.co_firstlineno
@ -93,7 +97,7 @@ def getlocation(function, curdir=None):
return "%s:%d" % (fn, lineno + 1) return "%s:%d" % (fn, lineno + 1)
def num_mock_patch_args(function): def num_mock_patch_args(function) -> int:
""" return number of arguments used up by mock arguments (if any) """ """ return number of arguments used up by mock arguments (if any) """
patchings = getattr(function, "patchings", None) patchings = getattr(function, "patchings", None)
if not patchings: if not patchings:
@ -112,7 +116,13 @@ def num_mock_patch_args(function):
) )
def getfuncargnames(function, *, name: str = "", is_method=False, cls=None): def getfuncargnames(
function: Callable[..., Any],
*,
name: str = "",
is_method: bool = False,
cls: Optional[type] = None
) -> Tuple[str, ...]:
"""Returns the names of a function's mandatory arguments. """Returns the names of a function's mandatory arguments.
This should return the names of all function arguments that: This should return the names of all function arguments that:
@ -180,7 +190,7 @@ else:
from contextlib import nullcontext # noqa from contextlib import nullcontext # noqa
def get_default_arg_names(function): def get_default_arg_names(function: Callable[..., Any]) -> Tuple[str, ...]:
# Note: this code intentionally mirrors the code at the beginning of getfuncargnames, # Note: this code intentionally mirrors the code at the beginning of getfuncargnames,
# to get the arguments which were excluded from its result because they had default values # to get the arguments which were excluded from its result because they had default values
return tuple( return tuple(
@ -199,18 +209,18 @@ _non_printable_ascii_translate_table.update(
) )
def _translate_non_printable(s): def _translate_non_printable(s: str) -> str:
return s.translate(_non_printable_ascii_translate_table) return s.translate(_non_printable_ascii_translate_table)
STRING_TYPES = bytes, str STRING_TYPES = bytes, str
def _bytes_to_ascii(val): def _bytes_to_ascii(val: bytes) -> str:
return val.decode("ascii", "backslashreplace") return val.decode("ascii", "backslashreplace")
def ascii_escaped(val): def ascii_escaped(val: Union[bytes, str]):
"""If val is pure ascii, returns it as a str(). Otherwise, escapes """If val is pure ascii, returns it as a str(). Otherwise, escapes
bytes objects into a sequence of escaped bytes: bytes objects into a sequence of escaped bytes:
@ -307,7 +317,7 @@ def getimfunc(func):
return func return func
def safe_getattr(object, name, default): def safe_getattr(object: Any, name: str, default: Any) -> Any:
""" Like getattr but return default upon any Exception or any OutcomeException. """ Like getattr but return default upon any Exception or any OutcomeException.
Attribute access can potentially fail for 'evil' Python objects. Attribute access can potentially fail for 'evil' Python objects.
@ -321,7 +331,7 @@ def safe_getattr(object, name, default):
return default return default
def safe_isclass(obj): def safe_isclass(obj: object) -> bool:
"""Ignore any exception via isinstance on Python 3.""" """Ignore any exception via isinstance on Python 3."""
try: try:
return inspect.isclass(obj) return inspect.isclass(obj)
@ -342,39 +352,26 @@ COLLECT_FAKEMODULE_ATTRIBUTES = (
) )
def _setup_collect_fakemodule(): def _setup_collect_fakemodule() -> None:
from types import ModuleType from types import ModuleType
import pytest import pytest
pytest.collect = ModuleType("pytest.collect") # Types ignored because the module is created dynamically.
pytest.collect.__all__ = [] # used for setns pytest.collect = ModuleType("pytest.collect") # type: ignore
pytest.collect.__all__ = [] # type: ignore # used for setns
for attr_name in COLLECT_FAKEMODULE_ATTRIBUTES: for attr_name in COLLECT_FAKEMODULE_ATTRIBUTES:
setattr(pytest.collect, attr_name, getattr(pytest, attr_name)) setattr(pytest.collect, attr_name, getattr(pytest, attr_name)) # type: ignore
class CaptureIO(io.TextIOWrapper): class CaptureIO(io.TextIOWrapper):
def __init__(self): def __init__(self) -> None:
super().__init__(io.BytesIO(), encoding="UTF-8", newline="", write_through=True) super().__init__(io.BytesIO(), encoding="UTF-8", newline="", write_through=True)
def getvalue(self): def getvalue(self) -> str:
assert isinstance(self.buffer, io.BytesIO)
return self.buffer.getvalue().decode("UTF-8") return self.buffer.getvalue().decode("UTF-8")
class FuncargnamesCompatAttr:
""" helper class so that Metafunc, Function and FixtureRequest
don't need to each define the "funcargnames" compatibility attribute.
"""
@property
def funcargnames(self):
""" alias attribute for ``fixturenames`` for pre-2.3 compatibility"""
import warnings
from _pytest.deprecated import FUNCARGNAMES
warnings.warn(FUNCARGNAMES, stacklevel=2)
return self.fixturenames
if sys.version_info < (3, 5, 2): # pragma: no cover if sys.version_info < (3, 5, 2): # pragma: no cover
def overload(f): # noqa: F811 def overload(f): # noqa: F811
@ -407,7 +404,9 @@ else:
raise NotImplementedError() raise NotImplementedError()
@overload # noqa: F811 @overload # noqa: F811
def __get__(self, instance: _S, owner: Optional["Type[_S]"] = ...) -> _T: def __get__( # noqa: F811
self, instance: _S, owner: Optional["Type[_S]"] = ...
) -> _T:
raise NotImplementedError() raise NotImplementedError()
def __get__(self, instance, owner=None): # noqa: F811 def __get__(self, instance, owner=None): # noqa: F811

View File

@ -18,7 +18,6 @@ from _pytest._code.code import FormattedExcinfo
from _pytest._code.code import TerminalRepr from _pytest._code.code import TerminalRepr
from _pytest.compat import _format_args from _pytest.compat import _format_args
from _pytest.compat import _PytestWrapper from _pytest.compat import _PytestWrapper
from _pytest.compat import FuncargnamesCompatAttr
from _pytest.compat import get_real_func from _pytest.compat import get_real_func
from _pytest.compat import get_real_method from _pytest.compat import get_real_method
from _pytest.compat import getfslineno from _pytest.compat import getfslineno
@ -29,6 +28,7 @@ from _pytest.compat import is_generator
from _pytest.compat import NOTSET from _pytest.compat import NOTSET
from _pytest.compat import safe_getattr from _pytest.compat import safe_getattr
from _pytest.deprecated import FIXTURE_POSITIONAL_ARGUMENTS from _pytest.deprecated import FIXTURE_POSITIONAL_ARGUMENTS
from _pytest.deprecated import FUNCARGNAMES
from _pytest.outcomes import fail from _pytest.outcomes import fail
from _pytest.outcomes import TEST_OUTCOME from _pytest.outcomes import TEST_OUTCOME
@ -336,7 +336,7 @@ class FuncFixtureInfo:
self.names_closure[:] = sorted(closure, key=self.names_closure.index) self.names_closure[:] = sorted(closure, key=self.names_closure.index)
class FixtureRequest(FuncargnamesCompatAttr): class FixtureRequest:
""" A request for a fixture from a test or fixture function. """ A request for a fixture from a test or fixture function.
A request object gives access to the requesting test context A request object gives access to the requesting test context
@ -363,6 +363,12 @@ class FixtureRequest(FuncargnamesCompatAttr):
result.extend(set(self._fixture_defs).difference(result)) result.extend(set(self._fixture_defs).difference(result))
return result return result
@property
def funcargnames(self):
""" alias attribute for ``fixturenames`` for pre-2.3 compatibility"""
warnings.warn(FUNCARGNAMES, stacklevel=2)
return self.fixturenames
@property @property
def node(self): def node(self):
""" underlying collection node (depends on current request scope)""" """ underlying collection node (depends on current request scope)"""

View File

@ -31,6 +31,7 @@ from _pytest.compat import safe_getattr
from _pytest.compat import safe_isclass from _pytest.compat import safe_isclass
from _pytest.compat import STRING_TYPES from _pytest.compat import STRING_TYPES
from _pytest.config import hookimpl from _pytest.config import hookimpl
from _pytest.deprecated import FUNCARGNAMES
from _pytest.main import FSHookProxy from _pytest.main import FSHookProxy
from _pytest.mark import MARK_GEN from _pytest.mark import MARK_GEN
from _pytest.mark.structures import get_unpacked_marks from _pytest.mark.structures import get_unpacked_marks
@ -882,7 +883,7 @@ class CallSpec2:
self.marks.extend(normalize_mark_list(marks)) self.marks.extend(normalize_mark_list(marks))
class Metafunc(fixtures.FuncargnamesCompatAttr): class Metafunc:
""" """
Metafunc objects are passed to the :func:`pytest_generate_tests <_pytest.hookspec.pytest_generate_tests>` hook. Metafunc objects are passed to the :func:`pytest_generate_tests <_pytest.hookspec.pytest_generate_tests>` hook.
They help to inspect a test function and to generate tests according to They help to inspect a test function and to generate tests according to
@ -916,6 +917,12 @@ class Metafunc(fixtures.FuncargnamesCompatAttr):
self._ids = set() self._ids = set()
self._arg2fixturedefs = fixtureinfo.name2fixturedefs self._arg2fixturedefs = fixtureinfo.name2fixturedefs
@property
def funcargnames(self):
""" alias attribute for ``fixturenames`` for pre-2.3 compatibility"""
warnings.warn(FUNCARGNAMES, stacklevel=2)
return self.fixturenames
def parametrize(self, argnames, argvalues, indirect=False, ids=None, scope=None): def parametrize(self, argnames, argvalues, indirect=False, ids=None, scope=None):
""" Add new invocations to the underlying test function using the list """ Add new invocations to the underlying test function using the list
of argvalues for the given argnames. Parametrization is performed of argvalues for the given argnames. Parametrization is performed
@ -1333,7 +1340,7 @@ def write_docstring(tw, doc, indent=" "):
tw.write(indent + line + "\n") tw.write(indent + line + "\n")
class Function(FunctionMixin, nodes.Item, fixtures.FuncargnamesCompatAttr): class Function(FunctionMixin, nodes.Item):
""" a Function Item is responsible for setting up and executing a """ a Function Item is responsible for setting up and executing a
Python test function. Python test function.
""" """
@ -1420,6 +1427,12 @@ class Function(FunctionMixin, nodes.Item, fixtures.FuncargnamesCompatAttr):
"(compatonly) for code expecting pytest-2.2 style request objects" "(compatonly) for code expecting pytest-2.2 style request objects"
return self return self
@property
def funcargnames(self):
""" alias attribute for ``fixturenames`` for pre-2.3 compatibility"""
warnings.warn(FUNCARGNAMES, stacklevel=2)
return self.fixturenames
def runtest(self): def runtest(self):
""" execute the underlying test function. """ """ execute the underlying test function. """
self.ihook.pytest_pyfunc_call(pyfuncitem=self) self.ihook.pytest_pyfunc_call(pyfuncitem=self)

View File

@ -552,7 +552,7 @@ def raises(
@overload # noqa: F811 @overload # noqa: F811
def raises( def raises( # noqa: F811
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
func: Callable, func: Callable,
*args: Any, *args: Any,

View File

@ -60,18 +60,18 @@ def warns(
*, *,
match: "Optional[Union[str, Pattern]]" = ... match: "Optional[Union[str, Pattern]]" = ...
) -> "WarningsChecker": ) -> "WarningsChecker":
... # pragma: no cover raise NotImplementedError()
@overload # noqa: F811 @overload # noqa: F811
def warns( def warns( # noqa: F811
expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]], expected_warning: Union["Type[Warning]", Tuple["Type[Warning]", ...]],
func: Callable, func: Callable,
*args: Any, *args: Any,
match: Optional[Union[str, "Pattern"]] = ..., match: Optional[Union[str, "Pattern"]] = ...,
**kwargs: Any **kwargs: Any
) -> Union[Any]: ) -> Union[Any]:
... # pragma: no cover raise NotImplementedError()
def warns( # noqa: F811 def warns( # noqa: F811

View File

@ -1,18 +1,19 @@
import sys import sys
from types import FrameType
from unittest import mock from unittest import mock
import _pytest._code import _pytest._code
import pytest import pytest
def test_ne(): def test_ne() -> None:
code1 = _pytest._code.Code(compile('foo = "bar"', "", "exec")) code1 = _pytest._code.Code(compile('foo = "bar"', "", "exec"))
assert code1 == code1 assert code1 == code1
code2 = _pytest._code.Code(compile('foo = "baz"', "", "exec")) code2 = _pytest._code.Code(compile('foo = "baz"', "", "exec"))
assert code2 != code1 assert code2 != code1
def test_code_gives_back_name_for_not_existing_file(): def test_code_gives_back_name_for_not_existing_file() -> None:
name = "abc-123" name = "abc-123"
co_code = compile("pass\n", name, "exec") co_code = compile("pass\n", name, "exec")
assert co_code.co_filename == name assert co_code.co_filename == name
@ -21,68 +22,67 @@ def test_code_gives_back_name_for_not_existing_file():
assert code.fullsource is None assert code.fullsource is None
def test_code_with_class(): def test_code_with_class() -> None:
class A: class A:
pass pass
pytest.raises(TypeError, _pytest._code.Code, A) pytest.raises(TypeError, _pytest._code.Code, A)
def x(): def x() -> None:
raise NotImplementedError() raise NotImplementedError()
def test_code_fullsource(): def test_code_fullsource() -> None:
code = _pytest._code.Code(x) code = _pytest._code.Code(x)
full = code.fullsource full = code.fullsource
assert "test_code_fullsource()" in str(full) assert "test_code_fullsource()" in str(full)
def test_code_source(): def test_code_source() -> None:
code = _pytest._code.Code(x) code = _pytest._code.Code(x)
src = code.source() src = code.source()
expected = """def x(): expected = """def x() -> None:
raise NotImplementedError()""" raise NotImplementedError()"""
assert str(src) == expected assert str(src) == expected
def test_frame_getsourcelineno_myself(): def test_frame_getsourcelineno_myself() -> None:
def func(): def func() -> FrameType:
return sys._getframe(0) return sys._getframe(0)
f = func() f = _pytest._code.Frame(func())
f = _pytest._code.Frame(f)
source, lineno = f.code.fullsource, f.lineno source, lineno = f.code.fullsource, f.lineno
assert source is not None
assert source[lineno].startswith(" return sys._getframe(0)") assert source[lineno].startswith(" return sys._getframe(0)")
def test_getstatement_empty_fullsource(): def test_getstatement_empty_fullsource() -> None:
def func(): def func() -> FrameType:
return sys._getframe(0) return sys._getframe(0)
f = func() f = _pytest._code.Frame(func())
f = _pytest._code.Frame(f)
with mock.patch.object(f.code.__class__, "fullsource", None): with mock.patch.object(f.code.__class__, "fullsource", None):
assert f.statement == "" assert f.statement == ""
def test_code_from_func(): def test_code_from_func() -> None:
co = _pytest._code.Code(test_frame_getsourcelineno_myself) co = _pytest._code.Code(test_frame_getsourcelineno_myself)
assert co.firstlineno assert co.firstlineno
assert co.path assert co.path
def test_unicode_handling(): def test_unicode_handling() -> None:
value = "ąć".encode() value = "ąć".encode()
def f(): def f() -> None:
raise Exception(value) raise Exception(value)
excinfo = pytest.raises(Exception, f) excinfo = pytest.raises(Exception, f)
str(excinfo) str(excinfo)
def test_code_getargs(): def test_code_getargs() -> None:
def f1(x): def f1(x):
raise NotImplementedError() raise NotImplementedError()
@ -108,26 +108,26 @@ def test_code_getargs():
assert c4.getargs(var=True) == ("x", "y", "z") assert c4.getargs(var=True) == ("x", "y", "z")
def test_frame_getargs(): def test_frame_getargs() -> None:
def f1(x): def f1(x) -> FrameType:
return sys._getframe(0) return sys._getframe(0)
fr1 = _pytest._code.Frame(f1("a")) fr1 = _pytest._code.Frame(f1("a"))
assert fr1.getargs(var=True) == [("x", "a")] assert fr1.getargs(var=True) == [("x", "a")]
def f2(x, *y): def f2(x, *y) -> FrameType:
return sys._getframe(0) return sys._getframe(0)
fr2 = _pytest._code.Frame(f2("a", "b", "c")) fr2 = _pytest._code.Frame(f2("a", "b", "c"))
assert fr2.getargs(var=True) == [("x", "a"), ("y", ("b", "c"))] assert fr2.getargs(var=True) == [("x", "a"), ("y", ("b", "c"))]
def f3(x, **z): def f3(x, **z) -> FrameType:
return sys._getframe(0) return sys._getframe(0)
fr3 = _pytest._code.Frame(f3("a", b="c")) fr3 = _pytest._code.Frame(f3("a", b="c"))
assert fr3.getargs(var=True) == [("x", "a"), ("z", {"b": "c"})] assert fr3.getargs(var=True) == [("x", "a"), ("z", {"b": "c"})]
def f4(x, *y, **z): def f4(x, *y, **z) -> FrameType:
return sys._getframe(0) return sys._getframe(0)
fr4 = _pytest._code.Frame(f4("a", "b", c="d")) fr4 = _pytest._code.Frame(f4("a", "b", c="d"))
@ -135,7 +135,7 @@ def test_frame_getargs():
class TestExceptionInfo: class TestExceptionInfo:
def test_bad_getsource(self): def test_bad_getsource(self) -> None:
try: try:
if False: if False:
pass pass
@ -145,13 +145,13 @@ class TestExceptionInfo:
exci = _pytest._code.ExceptionInfo.from_current() exci = _pytest._code.ExceptionInfo.from_current()
assert exci.getrepr() assert exci.getrepr()
def test_from_current_with_missing(self): def test_from_current_with_missing(self) -> None:
with pytest.raises(AssertionError, match="no current exception"): with pytest.raises(AssertionError, match="no current exception"):
_pytest._code.ExceptionInfo.from_current() _pytest._code.ExceptionInfo.from_current()
class TestTracebackEntry: class TestTracebackEntry:
def test_getsource(self): def test_getsource(self) -> None:
try: try:
if False: if False:
pass pass
@ -161,12 +161,13 @@ class TestTracebackEntry:
exci = _pytest._code.ExceptionInfo.from_current() exci = _pytest._code.ExceptionInfo.from_current()
entry = exci.traceback[0] entry = exci.traceback[0]
source = entry.getsource() source = entry.getsource()
assert source is not None
assert len(source) == 6 assert len(source) == 6
assert "assert False" in source[5] assert "assert False" in source[5]
class TestReprFuncArgs: class TestReprFuncArgs:
def test_not_raise_exception_with_mixed_encoding(self, tw_mock): def test_not_raise_exception_with_mixed_encoding(self, tw_mock) -> None:
from _pytest._code.code import ReprFuncArgs from _pytest._code.code import ReprFuncArgs
args = [("unicode_string", "São Paulo"), ("utf8_string", b"S\xc3\xa3o Paulo")] args = [("unicode_string", "São Paulo"), ("utf8_string", b"S\xc3\xa3o Paulo")]

View File

@ -3,6 +3,7 @@ import os
import queue import queue
import sys import sys
import textwrap import textwrap
from typing import Union
import py import py
@ -59,9 +60,9 @@ def test_excinfo_getstatement():
except ValueError: except ValueError:
excinfo = _pytest._code.ExceptionInfo.from_current() excinfo = _pytest._code.ExceptionInfo.from_current()
linenumbers = [ linenumbers = [
_pytest._code.getrawcode(f).co_firstlineno - 1 + 4, f.__code__.co_firstlineno - 1 + 4,
_pytest._code.getrawcode(f).co_firstlineno - 1 + 1, f.__code__.co_firstlineno - 1 + 1,
_pytest._code.getrawcode(g).co_firstlineno - 1 + 1, g.__code__.co_firstlineno - 1 + 1,
] ]
values = list(excinfo.traceback) values = list(excinfo.traceback)
foundlinenumbers = [x.lineno for x in values] foundlinenumbers = [x.lineno for x in values]
@ -224,23 +225,25 @@ class TestTraceback_f_g_h:
repr = excinfo.getrepr() repr = excinfo.getrepr()
assert "RuntimeError: hello" in str(repr.reprcrash) assert "RuntimeError: hello" in str(repr.reprcrash)
def test_traceback_no_recursion_index(self): def test_traceback_no_recursion_index(self) -> None:
def do_stuff(): def do_stuff() -> None:
raise RuntimeError raise RuntimeError
def reraise_me(): def reraise_me() -> None:
import sys import sys
exc, val, tb = sys.exc_info() exc, val, tb = sys.exc_info()
assert val is not None
raise val.with_traceback(tb) raise val.with_traceback(tb)
def f(n): def f(n: int) -> None:
try: try:
do_stuff() do_stuff()
except: # noqa except: # noqa
reraise_me() reraise_me()
excinfo = pytest.raises(RuntimeError, f, 8) excinfo = pytest.raises(RuntimeError, f, 8)
assert excinfo is not None
traceback = excinfo.traceback traceback = excinfo.traceback
recindex = traceback.recursionindex() recindex = traceback.recursionindex()
assert recindex is None assert recindex is None
@ -502,65 +505,18 @@ raise ValueError()
assert repr.reprtraceback.reprentries[1].lines[0] == "> ???" assert repr.reprtraceback.reprentries[1].lines[0] == "> ???"
assert repr.chain[0][0].reprentries[1].lines[0] == "> ???" assert repr.chain[0][0].reprentries[1].lines[0] == "> ???"
def test_repr_source_failing_fullsource(self): def test_repr_source_failing_fullsource(self, monkeypatch) -> None:
pr = FormattedExcinfo() pr = FormattedExcinfo()
class FakeCode: try:
class raw: 1 / 0
co_filename = "?" except ZeroDivisionError:
excinfo = ExceptionInfo.from_current()
path = "?" with monkeypatch.context() as m:
firstlineno = 5 m.setattr(_pytest._code.Code, "fullsource", property(lambda self: None))
repr = pr.repr_excinfo(excinfo)
def fullsource(self):
return None
fullsource = property(fullsource)
class FakeFrame:
code = FakeCode()
f_locals = {}
f_globals = {}
class FakeTracebackEntry(_pytest._code.Traceback.Entry):
def __init__(self, tb, excinfo=None):
self.lineno = 5 + 3
@property
def frame(self):
return FakeFrame()
class Traceback(_pytest._code.Traceback):
Entry = FakeTracebackEntry
class FakeExcinfo(_pytest._code.ExceptionInfo):
typename = "Foo"
value = Exception()
def __init__(self):
pass
def exconly(self, tryshort):
return "EXC"
def errisinstance(self, cls):
return False
excinfo = FakeExcinfo()
class FakeRawTB:
tb_next = None
tb = FakeRawTB()
excinfo.traceback = Traceback(tb)
fail = IOError()
repr = pr.repr_excinfo(excinfo)
assert repr.reprtraceback.reprentries[0].lines[0] == "> ???"
assert repr.chain[0][0].reprentries[0].lines[0] == "> ???"
fail = py.error.ENOENT # noqa
repr = pr.repr_excinfo(excinfo)
assert repr.reprtraceback.reprentries[0].lines[0] == "> ???" assert repr.reprtraceback.reprentries[0].lines[0] == "> ???"
assert repr.chain[0][0].reprentries[0].lines[0] == "> ???" assert repr.chain[0][0].reprentries[0].lines[0] == "> ???"
@ -643,7 +599,6 @@ raise ValueError()
assert lines[3] == "E world" assert lines[3] == "E world"
assert not lines[4:] assert not lines[4:]
loc = repr_entry.reprlocals is not None
loc = repr_entry.reprfileloc loc = repr_entry.reprfileloc
assert loc.path == mod.__file__ assert loc.path == mod.__file__
assert loc.lineno == 3 assert loc.lineno == 3
@ -1333,9 +1288,10 @@ raise ValueError()
@pytest.mark.parametrize("style", ["short", "long"]) @pytest.mark.parametrize("style", ["short", "long"])
@pytest.mark.parametrize("encoding", [None, "utf8", "utf16"]) @pytest.mark.parametrize("encoding", [None, "utf8", "utf16"])
def test_repr_traceback_with_unicode(style, encoding): def test_repr_traceback_with_unicode(style, encoding):
msg = "" if encoding is None:
if encoding is not None: msg = "" # type: Union[str, bytes]
msg = msg.encode(encoding) else:
msg = "".encode(encoding)
try: try:
raise RuntimeError(msg) raise RuntimeError(msg)
except RuntimeError: except RuntimeError:

View File

@ -4,13 +4,16 @@
import ast import ast
import inspect import inspect
import sys import sys
from typing import Any
from typing import Dict
from typing import Optional
import _pytest._code import _pytest._code
import pytest import pytest
from _pytest._code import Source from _pytest._code import Source
def test_source_str_function(): def test_source_str_function() -> None:
x = Source("3") x = Source("3")
assert str(x) == "3" assert str(x) == "3"
@ -25,7 +28,7 @@ def test_source_str_function():
assert str(x) == "\n3" assert str(x) == "\n3"
def test_unicode(): def test_unicode() -> None:
x = Source("4") x = Source("4")
assert str(x) == "4" assert str(x) == "4"
co = _pytest._code.compile('"å"', mode="eval") co = _pytest._code.compile('"å"', mode="eval")
@ -33,12 +36,12 @@ def test_unicode():
assert isinstance(val, str) assert isinstance(val, str)
def test_source_from_function(): def test_source_from_function() -> None:
source = _pytest._code.Source(test_source_str_function) source = _pytest._code.Source(test_source_str_function)
assert str(source).startswith("def test_source_str_function():") assert str(source).startswith("def test_source_str_function() -> None:")
def test_source_from_method(): def test_source_from_method() -> None:
class TestClass: class TestClass:
def test_method(self): def test_method(self):
pass pass
@ -47,13 +50,13 @@ def test_source_from_method():
assert source.lines == ["def test_method(self):", " pass"] assert source.lines == ["def test_method(self):", " pass"]
def test_source_from_lines(): def test_source_from_lines() -> None:
lines = ["a \n", "b\n", "c"] lines = ["a \n", "b\n", "c"]
source = _pytest._code.Source(lines) source = _pytest._code.Source(lines)
assert source.lines == ["a ", "b", "c"] assert source.lines == ["a ", "b", "c"]
def test_source_from_inner_function(): def test_source_from_inner_function() -> None:
def f(): def f():
pass pass
@ -63,7 +66,7 @@ def test_source_from_inner_function():
assert str(source).startswith("def f():") assert str(source).startswith("def f():")
def test_source_putaround_simple(): def test_source_putaround_simple() -> None:
source = Source("raise ValueError") source = Source("raise ValueError")
source = source.putaround( source = source.putaround(
"try:", "try:",
@ -85,7 +88,7 @@ else:
) )
def test_source_putaround(): def test_source_putaround() -> None:
source = Source() source = Source()
source = source.putaround( source = source.putaround(
""" """
@ -96,28 +99,29 @@ def test_source_putaround():
assert str(source).strip() == "if 1:\n x=1" assert str(source).strip() == "if 1:\n x=1"
def test_source_strips(): def test_source_strips() -> None:
source = Source("") source = Source("")
assert source == Source() assert source == Source()
assert str(source) == "" assert str(source) == ""
assert source.strip() == source assert source.strip() == source
def test_source_strip_multiline(): def test_source_strip_multiline() -> None:
source = Source() source = Source()
source.lines = ["", " hello", " "] source.lines = ["", " hello", " "]
source2 = source.strip() source2 = source.strip()
assert source2.lines == [" hello"] assert source2.lines == [" hello"]
def test_syntaxerror_rerepresentation(): def test_syntaxerror_rerepresentation() -> None:
ex = pytest.raises(SyntaxError, _pytest._code.compile, "xyz xyz") ex = pytest.raises(SyntaxError, _pytest._code.compile, "xyz xyz")
assert ex is not None
assert ex.value.lineno == 1 assert ex.value.lineno == 1
assert ex.value.offset in {5, 7} # cpython: 7, pypy3.6 7.1.1: 5 assert ex.value.offset in {5, 7} # cpython: 7, pypy3.6 7.1.1: 5
assert ex.value.text.strip(), "x x" assert ex.value.text == "xyz xyz\n"
def test_isparseable(): def test_isparseable() -> None:
assert Source("hello").isparseable() assert Source("hello").isparseable()
assert Source("if 1:\n pass").isparseable() assert Source("if 1:\n pass").isparseable()
assert Source(" \nif 1:\n pass").isparseable() assert Source(" \nif 1:\n pass").isparseable()
@ -127,7 +131,7 @@ def test_isparseable():
class TestAccesses: class TestAccesses:
def setup_class(self): def setup_class(self) -> None:
self.source = Source( self.source = Source(
"""\ """\
def f(x): def f(x):
@ -137,26 +141,26 @@ class TestAccesses:
""" """
) )
def test_getrange(self): def test_getrange(self) -> None:
x = self.source[0:2] x = self.source[0:2]
assert x.isparseable() assert x.isparseable()
assert len(x.lines) == 2 assert len(x.lines) == 2
assert str(x) == "def f(x):\n pass" assert str(x) == "def f(x):\n pass"
def test_getline(self): def test_getline(self) -> None:
x = self.source[0] x = self.source[0]
assert x == "def f(x):" assert x == "def f(x):"
def test_len(self): def test_len(self) -> None:
assert len(self.source) == 4 assert len(self.source) == 4
def test_iter(self): def test_iter(self) -> None:
values = [x for x in self.source] values = [x for x in self.source]
assert len(values) == 4 assert len(values) == 4
class TestSourceParsingAndCompiling: class TestSourceParsingAndCompiling:
def setup_class(self): def setup_class(self) -> None:
self.source = Source( self.source = Source(
"""\ """\
def f(x): def f(x):
@ -166,19 +170,19 @@ class TestSourceParsingAndCompiling:
""" """
).strip() ).strip()
def test_compile(self): def test_compile(self) -> None:
co = _pytest._code.compile("x=3") co = _pytest._code.compile("x=3")
d = {} d = {} # type: Dict[str, Any]
exec(co, d) exec(co, d)
assert d["x"] == 3 assert d["x"] == 3
def test_compile_and_getsource_simple(self): def test_compile_and_getsource_simple(self) -> None:
co = _pytest._code.compile("x=3") co = _pytest._code.compile("x=3")
exec(co) exec(co)
source = _pytest._code.Source(co) source = _pytest._code.Source(co)
assert str(source) == "x=3" assert str(source) == "x=3"
def test_compile_and_getsource_through_same_function(self): def test_compile_and_getsource_through_same_function(self) -> None:
def gensource(source): def gensource(source):
return _pytest._code.compile(source) return _pytest._code.compile(source)
@ -199,7 +203,7 @@ class TestSourceParsingAndCompiling:
source2 = inspect.getsource(co2) source2 = inspect.getsource(co2)
assert "ValueError" in source2 assert "ValueError" in source2
def test_getstatement(self): def test_getstatement(self) -> None:
# print str(self.source) # print str(self.source)
ass = str(self.source[1:]) ass = str(self.source[1:])
for i in range(1, 4): for i in range(1, 4):
@ -208,7 +212,7 @@ class TestSourceParsingAndCompiling:
# x = s.deindent() # x = s.deindent()
assert str(s) == ass assert str(s) == ass
def test_getstatementrange_triple_quoted(self): def test_getstatementrange_triple_quoted(self) -> None:
# print str(self.source) # print str(self.source)
source = Source( source = Source(
"""hello(''' """hello('''
@ -219,7 +223,7 @@ class TestSourceParsingAndCompiling:
s = source.getstatement(1) s = source.getstatement(1)
assert s == str(source) assert s == str(source)
def test_getstatementrange_within_constructs(self): def test_getstatementrange_within_constructs(self) -> None:
source = Source( source = Source(
"""\ """\
try: try:
@ -241,7 +245,7 @@ class TestSourceParsingAndCompiling:
# assert source.getstatementrange(5) == (0, 7) # assert source.getstatementrange(5) == (0, 7)
assert source.getstatementrange(6) == (6, 7) assert source.getstatementrange(6) == (6, 7)
def test_getstatementrange_bug(self): def test_getstatementrange_bug(self) -> None:
source = Source( source = Source(
"""\ """\
try: try:
@ -255,7 +259,7 @@ class TestSourceParsingAndCompiling:
assert len(source) == 6 assert len(source) == 6
assert source.getstatementrange(2) == (1, 4) assert source.getstatementrange(2) == (1, 4)
def test_getstatementrange_bug2(self): def test_getstatementrange_bug2(self) -> None:
source = Source( source = Source(
"""\ """\
assert ( assert (
@ -272,7 +276,7 @@ class TestSourceParsingAndCompiling:
assert len(source) == 9 assert len(source) == 9
assert source.getstatementrange(5) == (0, 9) assert source.getstatementrange(5) == (0, 9)
def test_getstatementrange_ast_issue58(self): def test_getstatementrange_ast_issue58(self) -> None:
source = Source( source = Source(
"""\ """\
@ -286,38 +290,44 @@ class TestSourceParsingAndCompiling:
assert getstatement(2, source).lines == source.lines[2:3] assert getstatement(2, source).lines == source.lines[2:3]
assert getstatement(3, source).lines == source.lines[3:4] assert getstatement(3, source).lines == source.lines[3:4]
def test_getstatementrange_out_of_bounds_py3(self): def test_getstatementrange_out_of_bounds_py3(self) -> None:
source = Source("if xxx:\n from .collections import something") source = Source("if xxx:\n from .collections import something")
r = source.getstatementrange(1) r = source.getstatementrange(1)
assert r == (1, 2) assert r == (1, 2)
def test_getstatementrange_with_syntaxerror_issue7(self): def test_getstatementrange_with_syntaxerror_issue7(self) -> None:
source = Source(":") source = Source(":")
pytest.raises(SyntaxError, lambda: source.getstatementrange(0)) pytest.raises(SyntaxError, lambda: source.getstatementrange(0))
def test_compile_to_ast(self): def test_compile_to_ast(self) -> None:
source = Source("x = 4") source = Source("x = 4")
mod = source.compile(flag=ast.PyCF_ONLY_AST) mod = source.compile(flag=ast.PyCF_ONLY_AST)
assert isinstance(mod, ast.Module) assert isinstance(mod, ast.Module)
compile(mod, "<filename>", "exec") compile(mod, "<filename>", "exec")
def test_compile_and_getsource(self): def test_compile_and_getsource(self) -> None:
co = self.source.compile() co = self.source.compile()
exec(co, globals()) exec(co, globals())
f(7) f(7) # type: ignore
excinfo = pytest.raises(AssertionError, f, 6) excinfo = pytest.raises(AssertionError, f, 6) # type: ignore
assert excinfo is not None
frame = excinfo.traceback[-1].frame frame = excinfo.traceback[-1].frame
assert isinstance(frame.code.fullsource, Source)
stmt = frame.code.fullsource.getstatement(frame.lineno) stmt = frame.code.fullsource.getstatement(frame.lineno)
assert str(stmt).strip().startswith("assert") assert str(stmt).strip().startswith("assert")
@pytest.mark.parametrize("name", ["", None, "my"]) @pytest.mark.parametrize("name", ["", None, "my"])
def test_compilefuncs_and_path_sanity(self, name): def test_compilefuncs_and_path_sanity(self, name: Optional[str]) -> None:
def check(comp, name): def check(comp, name):
co = comp(self.source, name) co = comp(self.source, name)
if not name: if not name:
expected = "codegen %s:%d>" % (mypath, mylineno + 2 + 2) expected = "codegen %s:%d>" % (mypath, mylineno + 2 + 2) # type: ignore
else: else:
expected = "codegen %r %s:%d>" % (name, mypath, mylineno + 2 + 2) expected = "codegen %r %s:%d>" % (
name,
mypath, # type: ignore
mylineno + 2 + 2, # type: ignore
) # type: ignore
fn = co.co_filename fn = co.co_filename
assert fn.endswith(expected) assert fn.endswith(expected)
@ -332,9 +342,9 @@ class TestSourceParsingAndCompiling:
pytest.raises(SyntaxError, _pytest._code.compile, "lambda a,a: 0", mode="eval") pytest.raises(SyntaxError, _pytest._code.compile, "lambda a,a: 0", mode="eval")
def test_getstartingblock_singleline(): def test_getstartingblock_singleline() -> None:
class A: class A:
def __init__(self, *args): def __init__(self, *args) -> None:
frame = sys._getframe(1) frame = sys._getframe(1)
self.source = _pytest._code.Frame(frame).statement self.source = _pytest._code.Frame(frame).statement
@ -344,22 +354,22 @@ def test_getstartingblock_singleline():
assert len(values) == 1 assert len(values) == 1
def test_getline_finally(): def test_getline_finally() -> None:
def c(): def c() -> None:
pass pass
with pytest.raises(TypeError) as excinfo: with pytest.raises(TypeError) as excinfo:
teardown = None teardown = None
try: try:
c(1) c(1) # type: ignore
finally: finally:
if teardown: if teardown:
teardown() teardown()
source = excinfo.traceback[-1].statement source = excinfo.traceback[-1].statement
assert str(source).strip() == "c(1)" assert str(source).strip() == "c(1) # type: ignore"
def test_getfuncsource_dynamic(): def test_getfuncsource_dynamic() -> None:
source = """ source = """
def f(): def f():
raise ValueError raise ValueError
@ -368,11 +378,13 @@ def test_getfuncsource_dynamic():
""" """
co = _pytest._code.compile(source) co = _pytest._code.compile(source)
exec(co, globals()) exec(co, globals())
assert str(_pytest._code.Source(f)).strip() == "def f():\n raise ValueError" f_source = _pytest._code.Source(f) # type: ignore
assert str(_pytest._code.Source(g)).strip() == "def g(): pass" g_source = _pytest._code.Source(g) # type: ignore
assert str(f_source).strip() == "def f():\n raise ValueError"
assert str(g_source).strip() == "def g(): pass"
def test_getfuncsource_with_multine_string(): def test_getfuncsource_with_multine_string() -> None:
def f(): def f():
c = """while True: c = """while True:
pass pass
@ -387,7 +399,7 @@ def test_getfuncsource_with_multine_string():
assert str(_pytest._code.Source(f)) == expected.rstrip() assert str(_pytest._code.Source(f)) == expected.rstrip()
def test_deindent(): def test_deindent() -> None:
from _pytest._code.source import deindent as deindent from _pytest._code.source import deindent as deindent
assert deindent(["\tfoo", "\tbar"]) == ["foo", "bar"] assert deindent(["\tfoo", "\tbar"]) == ["foo", "bar"]
@ -401,7 +413,7 @@ def test_deindent():
assert lines == ["def f():", " def g():", " pass"] assert lines == ["def f():", " def g():", " pass"]
def test_source_of_class_at_eof_without_newline(tmpdir, _sys_snapshot): def test_source_of_class_at_eof_without_newline(tmpdir, _sys_snapshot) -> None:
# this test fails because the implicit inspect.getsource(A) below # this test fails because the implicit inspect.getsource(A) below
# does not return the "x = 1" last line. # does not return the "x = 1" last line.
source = _pytest._code.Source( source = _pytest._code.Source(
@ -423,7 +435,7 @@ if True:
pass pass
def test_getsource_fallback(): def test_getsource_fallback() -> None:
from _pytest._code.source import getsource from _pytest._code.source import getsource
expected = """def x(): expected = """def x():
@ -432,7 +444,7 @@ def test_getsource_fallback():
assert src == expected assert src == expected
def test_idem_compile_and_getsource(): def test_idem_compile_and_getsource() -> None:
from _pytest._code.source import getsource from _pytest._code.source import getsource
expected = "def x(): pass" expected = "def x(): pass"
@ -441,15 +453,16 @@ def test_idem_compile_and_getsource():
assert src == expected assert src == expected
def test_findsource_fallback(): def test_findsource_fallback() -> None:
from _pytest._code.source import findsource from _pytest._code.source import findsource
src, lineno = findsource(x) src, lineno = findsource(x)
assert src is not None
assert "test_findsource_simple" in str(src) assert "test_findsource_simple" in str(src)
assert src[lineno] == " def x():" assert src[lineno] == " def x():"
def test_findsource(): def test_findsource() -> None:
from _pytest._code.source import findsource from _pytest._code.source import findsource
co = _pytest._code.compile( co = _pytest._code.compile(
@ -460,25 +473,27 @@ def test_findsource():
) )
src, lineno = findsource(co) src, lineno = findsource(co)
assert src is not None
assert "if 1:" in str(src) assert "if 1:" in str(src)
d = {} d = {} # type: Dict[str, Any]
eval(co, d) eval(co, d)
src, lineno = findsource(d["x"]) src, lineno = findsource(d["x"])
assert src is not None
assert "if 1:" in str(src) assert "if 1:" in str(src)
assert src[lineno] == " def x():" assert src[lineno] == " def x():"
def test_getfslineno(): def test_getfslineno() -> None:
from _pytest._code import getfslineno from _pytest._code import getfslineno
def f(x): def f(x) -> None:
pass pass
fspath, lineno = getfslineno(f) fspath, lineno = getfslineno(f)
assert fspath.basename == "test_source.py" assert fspath.basename == "test_source.py"
assert lineno == _pytest._code.getrawcode(f).co_firstlineno - 1 # see findsource assert lineno == f.__code__.co_firstlineno - 1 # see findsource
class A: class A:
pass pass
@ -498,40 +513,40 @@ def test_getfslineno():
assert getfslineno(B)[1] == -1 assert getfslineno(B)[1] == -1
def test_code_of_object_instance_with_call(): def test_code_of_object_instance_with_call() -> None:
class A: class A:
pass pass
pytest.raises(TypeError, lambda: _pytest._code.Source(A())) pytest.raises(TypeError, lambda: _pytest._code.Source(A()))
class WithCall: class WithCall:
def __call__(self): def __call__(self) -> None:
pass pass
code = _pytest._code.Code(WithCall()) code = _pytest._code.Code(WithCall())
assert "pass" in str(code.source()) assert "pass" in str(code.source())
class Hello: class Hello:
def __call__(self): def __call__(self) -> None:
pass pass
pytest.raises(TypeError, lambda: _pytest._code.Code(Hello)) pytest.raises(TypeError, lambda: _pytest._code.Code(Hello))
def getstatement(lineno, source): def getstatement(lineno: int, source) -> Source:
from _pytest._code.source import getstatementrange_ast from _pytest._code.source import getstatementrange_ast
source = _pytest._code.Source(source, deindent=False) src = _pytest._code.Source(source, deindent=False)
ast, start, end = getstatementrange_ast(lineno, source) ast, start, end = getstatementrange_ast(lineno, src)
return source[start:end] return src[start:end]
def test_oneline(): def test_oneline() -> None:
source = getstatement(0, "raise ValueError") source = getstatement(0, "raise ValueError")
assert str(source) == "raise ValueError" assert str(source) == "raise ValueError"
def test_comment_and_no_newline_at_end(): def test_comment_and_no_newline_at_end() -> None:
from _pytest._code.source import getstatementrange_ast from _pytest._code.source import getstatementrange_ast
source = Source( source = Source(
@ -545,12 +560,12 @@ def test_comment_and_no_newline_at_end():
assert end == 2 assert end == 2
def test_oneline_and_comment(): def test_oneline_and_comment() -> None:
source = getstatement(0, "raise ValueError\n#hello") source = getstatement(0, "raise ValueError\n#hello")
assert str(source) == "raise ValueError" assert str(source) == "raise ValueError"
def test_comments(): def test_comments() -> None:
source = '''def test(): source = '''def test():
"comment 1" "comment 1"
x = 1 x = 1
@ -576,7 +591,7 @@ comment 4
assert str(getstatement(line, source)) == '"""\ncomment 4\n"""' assert str(getstatement(line, source)) == '"""\ncomment 4\n"""'
def test_comment_in_statement(): def test_comment_in_statement() -> None:
source = """test(foo=1, source = """test(foo=1,
# comment 1 # comment 1
bar=2) bar=2)
@ -588,17 +603,17 @@ def test_comment_in_statement():
) )
def test_single_line_else(): def test_single_line_else() -> None:
source = getstatement(1, "if False: 2\nelse: 3") source = getstatement(1, "if False: 2\nelse: 3")
assert str(source) == "else: 3" assert str(source) == "else: 3"
def test_single_line_finally(): def test_single_line_finally() -> None:
source = getstatement(1, "try: 1\nfinally: 3") source = getstatement(1, "try: 1\nfinally: 3")
assert str(source) == "finally: 3" assert str(source) == "finally: 3"
def test_issue55(): def test_issue55() -> None:
source = ( source = (
"def round_trip(dinp):\n assert 1 == dinp\n" "def round_trip(dinp):\n assert 1 == dinp\n"
'def test_rt():\n round_trip("""\n""")\n' 'def test_rt():\n round_trip("""\n""")\n'
@ -607,7 +622,7 @@ def test_issue55():
assert str(s) == ' round_trip("""\n""")' assert str(s) == ' round_trip("""\n""")'
def test_multiline(): def test_multiline() -> None:
source = getstatement( source = getstatement(
0, 0,
"""\ """\
@ -621,7 +636,7 @@ x = 3
class TestTry: class TestTry:
def setup_class(self): def setup_class(self) -> None:
self.source = """\ self.source = """\
try: try:
raise ValueError raise ValueError
@ -631,25 +646,25 @@ else:
raise KeyError() raise KeyError()
""" """
def test_body(self): def test_body(self) -> None:
source = getstatement(1, self.source) source = getstatement(1, self.source)
assert str(source) == " raise ValueError" assert str(source) == " raise ValueError"
def test_except_line(self): def test_except_line(self) -> None:
source = getstatement(2, self.source) source = getstatement(2, self.source)
assert str(source) == "except Something:" assert str(source) == "except Something:"
def test_except_body(self): def test_except_body(self) -> None:
source = getstatement(3, self.source) source = getstatement(3, self.source)
assert str(source) == " raise IndexError(1)" assert str(source) == " raise IndexError(1)"
def test_else(self): def test_else(self) -> None:
source = getstatement(5, self.source) source = getstatement(5, self.source)
assert str(source) == " raise KeyError()" assert str(source) == " raise KeyError()"
class TestTryFinally: class TestTryFinally:
def setup_class(self): def setup_class(self) -> None:
self.source = """\ self.source = """\
try: try:
raise ValueError raise ValueError
@ -657,17 +672,17 @@ finally:
raise IndexError(1) raise IndexError(1)
""" """
def test_body(self): def test_body(self) -> None:
source = getstatement(1, self.source) source = getstatement(1, self.source)
assert str(source) == " raise ValueError" assert str(source) == " raise ValueError"
def test_finally(self): def test_finally(self) -> None:
source = getstatement(3, self.source) source = getstatement(3, self.source)
assert str(source) == " raise IndexError(1)" assert str(source) == " raise IndexError(1)"
class TestIf: class TestIf:
def setup_class(self): def setup_class(self) -> None:
self.source = """\ self.source = """\
if 1: if 1:
y = 3 y = 3
@ -677,24 +692,24 @@ else:
y = 7 y = 7
""" """
def test_body(self): def test_body(self) -> None:
source = getstatement(1, self.source) source = getstatement(1, self.source)
assert str(source) == " y = 3" assert str(source) == " y = 3"
def test_elif_clause(self): def test_elif_clause(self) -> None:
source = getstatement(2, self.source) source = getstatement(2, self.source)
assert str(source) == "elif False:" assert str(source) == "elif False:"
def test_elif(self): def test_elif(self) -> None:
source = getstatement(3, self.source) source = getstatement(3, self.source)
assert str(source) == " y = 5" assert str(source) == " y = 5"
def test_else(self): def test_else(self) -> None:
source = getstatement(5, self.source) source = getstatement(5, self.source)
assert str(source) == " y = 7" assert str(source) == " y = 7"
def test_semicolon(): def test_semicolon() -> None:
s = """\ s = """\
hello ; pytest.skip() hello ; pytest.skip()
""" """
@ -702,7 +717,7 @@ hello ; pytest.skip()
assert str(source) == s.strip() assert str(source) == s.strip()
def test_def_online(): def test_def_online() -> None:
s = """\ s = """\
def func(): raise ValueError(42) def func(): raise ValueError(42)
@ -713,7 +728,7 @@ def something():
assert str(source) == "def func(): raise ValueError(42)" assert str(source) == "def func(): raise ValueError(42)"
def XXX_test_expression_multiline(): def XXX_test_expression_multiline() -> None:
source = """\ source = """\
something something
''' '''
@ -722,7 +737,7 @@ something
assert str(result) == "'''\n'''" assert str(result) == "'''\n'''"
def test_getstartingblock_multiline(): def test_getstartingblock_multiline() -> None:
class A: class A:
def __init__(self, *args): def __init__(self, *args):
frame = sys._getframe(1) frame = sys._getframe(1)

View File

@ -92,8 +92,6 @@ class TestCaptureManager:
@pytest.mark.parametrize("method", ["fd", "sys"]) @pytest.mark.parametrize("method", ["fd", "sys"])
def test_capturing_unicode(testdir, method): def test_capturing_unicode(testdir, method):
if hasattr(sys, "pypy_version_info") and sys.pypy_version_info < (2, 2):
pytest.xfail("does not work on pypy < 2.2")
obj = "'b\u00f6y'" obj = "'b\u00f6y'"
testdir.makepyfile( testdir.makepyfile(
"""\ """\