Add type annotations to _pytest._code.code
This commit is contained in:
		
							parent
							
								
									562d4811d5
								
							
						
					
					
						commit
						eaa34a9df0
					
				| 
						 | 
				
			
			@ -7,13 +7,17 @@ from inspect import CO_VARKEYWORDS
 | 
			
		|||
from io import StringIO
 | 
			
		||||
from traceback import format_exception_only
 | 
			
		||||
from types import CodeType
 | 
			
		||||
from types import FrameType
 | 
			
		||||
from types import TracebackType
 | 
			
		||||
from typing import Any
 | 
			
		||||
from typing import Callable
 | 
			
		||||
from typing import Dict
 | 
			
		||||
from typing import Generic
 | 
			
		||||
from typing import Iterable
 | 
			
		||||
from typing import List
 | 
			
		||||
from typing import Optional
 | 
			
		||||
from typing import Pattern
 | 
			
		||||
from typing import Sequence
 | 
			
		||||
from typing import Set
 | 
			
		||||
from typing import Tuple
 | 
			
		||||
from typing import TypeVar
 | 
			
		||||
| 
						 | 
				
			
			@ -27,9 +31,16 @@ import py
 | 
			
		|||
import _pytest
 | 
			
		||||
from _pytest._io.saferepr import safeformat
 | 
			
		||||
from _pytest._io.saferepr import saferepr
 | 
			
		||||
from _pytest.compat import overload
 | 
			
		||||
 | 
			
		||||
if False:  # TYPE_CHECKING
 | 
			
		||||
    from typing import Type
 | 
			
		||||
    from typing_extensions import Literal
 | 
			
		||||
    from weakref import ReferenceType  # noqa: F401
 | 
			
		||||
 | 
			
		||||
    from _pytest._code import Source
 | 
			
		||||
 | 
			
		||||
    _TracebackStyle = Literal["long", "short", "no", "native"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Code:
 | 
			
		||||
| 
						 | 
				
			
			@ -38,13 +49,12 @@ class Code:
 | 
			
		|||
    def __init__(self, rawcode) -> None:
 | 
			
		||||
        if not hasattr(rawcode, "co_filename"):
 | 
			
		||||
            rawcode = getrawcode(rawcode)
 | 
			
		||||
        try:
 | 
			
		||||
        if not isinstance(rawcode, CodeType):
 | 
			
		||||
            raise TypeError("not a code object: {!r}".format(rawcode))
 | 
			
		||||
        self.filename = rawcode.co_filename
 | 
			
		||||
        self.firstlineno = rawcode.co_firstlineno - 1
 | 
			
		||||
        self.name = rawcode.co_name
 | 
			
		||||
        except AttributeError:
 | 
			
		||||
            raise TypeError("not a code object: {!r}".format(rawcode))
 | 
			
		||||
        self.raw = rawcode  # type: CodeType
 | 
			
		||||
        self.raw = rawcode
 | 
			
		||||
 | 
			
		||||
    def __eq__(self, other):
 | 
			
		||||
        return self.raw == other.raw
 | 
			
		||||
| 
						 | 
				
			
			@ -72,7 +82,7 @@ class Code:
 | 
			
		|||
        return p
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def fullsource(self):
 | 
			
		||||
    def fullsource(self) -> Optional["Source"]:
 | 
			
		||||
        """ return a _pytest._code.Source object for the full source file of the code
 | 
			
		||||
        """
 | 
			
		||||
        from _pytest._code import source
 | 
			
		||||
| 
						 | 
				
			
			@ -80,7 +90,7 @@ class Code:
 | 
			
		|||
        full, _ = source.findsource(self.raw)
 | 
			
		||||
        return full
 | 
			
		||||
 | 
			
		||||
    def source(self):
 | 
			
		||||
    def source(self) -> "Source":
 | 
			
		||||
        """ return a _pytest._code.Source object for the code object's source only
 | 
			
		||||
        """
 | 
			
		||||
        # return source only for that part of code
 | 
			
		||||
| 
						 | 
				
			
			@ -88,7 +98,7 @@ class Code:
 | 
			
		|||
 | 
			
		||||
        return _pytest._code.Source(self.raw)
 | 
			
		||||
 | 
			
		||||
    def getargs(self, var=False):
 | 
			
		||||
    def getargs(self, var: bool = False) -> Tuple[str, ...]:
 | 
			
		||||
        """ return a tuple with the argument names for the code object
 | 
			
		||||
 | 
			
		||||
            if 'var' is set True also return the names of the variable and
 | 
			
		||||
| 
						 | 
				
			
			@ -107,7 +117,7 @@ class Frame:
 | 
			
		|||
    """Wrapper around a Python frame holding f_locals and f_globals
 | 
			
		||||
    in which expressions can be evaluated."""
 | 
			
		||||
 | 
			
		||||
    def __init__(self, frame):
 | 
			
		||||
    def __init__(self, frame: FrameType) -> None:
 | 
			
		||||
        self.lineno = frame.f_lineno - 1
 | 
			
		||||
        self.f_globals = frame.f_globals
 | 
			
		||||
        self.f_locals = frame.f_locals
 | 
			
		||||
| 
						 | 
				
			
			@ -115,7 +125,7 @@ class Frame:
 | 
			
		|||
        self.code = Code(frame.f_code)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def statement(self):
 | 
			
		||||
    def statement(self) -> "Source":
 | 
			
		||||
        """ statement this frame is at """
 | 
			
		||||
        import _pytest._code
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -134,7 +144,7 @@ class Frame:
 | 
			
		|||
        f_locals.update(vars)
 | 
			
		||||
        return eval(code, self.f_globals, f_locals)
 | 
			
		||||
 | 
			
		||||
    def exec_(self, code, **vars):
 | 
			
		||||
    def exec_(self, code, **vars) -> None:
 | 
			
		||||
        """ exec 'code' in the frame
 | 
			
		||||
 | 
			
		||||
            'vars' are optional; additional local variables
 | 
			
		||||
| 
						 | 
				
			
			@ -143,7 +153,7 @@ class Frame:
 | 
			
		|||
        f_locals.update(vars)
 | 
			
		||||
        exec(code, self.f_globals, f_locals)
 | 
			
		||||
 | 
			
		||||
    def repr(self, object):
 | 
			
		||||
    def repr(self, object: object) -> str:
 | 
			
		||||
        """ return a 'safe' (non-recursive, one-line) string repr for 'object'
 | 
			
		||||
        """
 | 
			
		||||
        return saferepr(object)
 | 
			
		||||
| 
						 | 
				
			
			@ -151,7 +161,7 @@ class Frame:
 | 
			
		|||
    def is_true(self, object):
 | 
			
		||||
        return object
 | 
			
		||||
 | 
			
		||||
    def getargs(self, var=False):
 | 
			
		||||
    def getargs(self, var: bool = False):
 | 
			
		||||
        """ return a list of tuples (name, value) for all arguments
 | 
			
		||||
 | 
			
		||||
            if 'var' is set True also include the variable and keyword
 | 
			
		||||
| 
						 | 
				
			
			@ -169,35 +179,34 @@ class Frame:
 | 
			
		|||
class TracebackEntry:
 | 
			
		||||
    """ a single entry in a traceback """
 | 
			
		||||
 | 
			
		||||
    _repr_style = None
 | 
			
		||||
    _repr_style = None  # type: Optional[Literal["short", "long"]]
 | 
			
		||||
    exprinfo = None
 | 
			
		||||
 | 
			
		||||
    def __init__(self, rawentry, excinfo=None):
 | 
			
		||||
    def __init__(self, rawentry: TracebackType, excinfo=None) -> None:
 | 
			
		||||
        self._excinfo = excinfo
 | 
			
		||||
        self._rawentry = rawentry
 | 
			
		||||
        self.lineno = rawentry.tb_lineno - 1
 | 
			
		||||
 | 
			
		||||
    def set_repr_style(self, mode):
 | 
			
		||||
    def set_repr_style(self, mode: "Literal['short', 'long']") -> None:
 | 
			
		||||
        assert mode in ("short", "long")
 | 
			
		||||
        self._repr_style = mode
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def frame(self):
 | 
			
		||||
        import _pytest._code
 | 
			
		||||
 | 
			
		||||
        return _pytest._code.Frame(self._rawentry.tb_frame)
 | 
			
		||||
    def frame(self) -> Frame:
 | 
			
		||||
        return Frame(self._rawentry.tb_frame)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def relline(self):
 | 
			
		||||
    def relline(self) -> int:
 | 
			
		||||
        return self.lineno - self.frame.code.firstlineno
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
    def __repr__(self) -> str:
 | 
			
		||||
        return "<TracebackEntry %s:%d>" % (self.frame.code.path, self.lineno + 1)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def statement(self):
 | 
			
		||||
    def statement(self) -> "Source":
 | 
			
		||||
        """ _pytest._code.Source object for the current statement """
 | 
			
		||||
        source = self.frame.code.fullsource
 | 
			
		||||
        assert source is not None
 | 
			
		||||
        return source.getstatement(self.lineno)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
| 
						 | 
				
			
			@ -206,14 +215,14 @@ class TracebackEntry:
 | 
			
		|||
        return self.frame.code.path
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def locals(self):
 | 
			
		||||
    def locals(self) -> Dict[str, Any]:
 | 
			
		||||
        """ locals of underlying frame """
 | 
			
		||||
        return self.frame.f_locals
 | 
			
		||||
 | 
			
		||||
    def getfirstlinesource(self):
 | 
			
		||||
    def getfirstlinesource(self) -> int:
 | 
			
		||||
        return self.frame.code.firstlineno
 | 
			
		||||
 | 
			
		||||
    def getsource(self, astcache=None):
 | 
			
		||||
    def getsource(self, astcache=None) -> Optional["Source"]:
 | 
			
		||||
        """ return failing source code. """
 | 
			
		||||
        # we use the passed in astcache to not reparse asttrees
 | 
			
		||||
        # within exception info printing
 | 
			
		||||
| 
						 | 
				
			
			@ -258,7 +267,7 @@ class TracebackEntry:
 | 
			
		|||
            return tbh(None if self._excinfo is None else self._excinfo())
 | 
			
		||||
        return tbh
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
    def __str__(self) -> str:
 | 
			
		||||
        try:
 | 
			
		||||
            fn = str(self.path)
 | 
			
		||||
        except py.error.Error:
 | 
			
		||||
| 
						 | 
				
			
			@ -273,31 +282,42 @@ class TracebackEntry:
 | 
			
		|||
        return "  File %r:%d in %s\n  %s\n" % (fn, self.lineno + 1, name, line)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def name(self):
 | 
			
		||||
    def name(self) -> str:
 | 
			
		||||
        """ co_name of underlying code """
 | 
			
		||||
        return self.frame.code.raw.co_name
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Traceback(list):
 | 
			
		||||
class Traceback(List[TracebackEntry]):
 | 
			
		||||
    """ Traceback objects encapsulate and offer higher level
 | 
			
		||||
        access to Traceback entries.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, tb, excinfo=None):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        tb: Union[TracebackType, Iterable[TracebackEntry]],
 | 
			
		||||
        excinfo: Optional["ReferenceType[ExceptionInfo]"] = None,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        """ initialize from given python traceback object and ExceptionInfo """
 | 
			
		||||
        self._excinfo = excinfo
 | 
			
		||||
        if hasattr(tb, "tb_next"):
 | 
			
		||||
        if isinstance(tb, TracebackType):
 | 
			
		||||
 | 
			
		||||
            def f(cur):
 | 
			
		||||
                while cur is not None:
 | 
			
		||||
                    yield TracebackEntry(cur, excinfo=excinfo)
 | 
			
		||||
                    cur = cur.tb_next
 | 
			
		||||
            def f(cur: TracebackType) -> Iterable[TracebackEntry]:
 | 
			
		||||
                cur_ = cur  # type: Optional[TracebackType]
 | 
			
		||||
                while cur_ is not None:
 | 
			
		||||
                    yield TracebackEntry(cur_, excinfo=excinfo)
 | 
			
		||||
                    cur_ = cur_.tb_next
 | 
			
		||||
 | 
			
		||||
            list.__init__(self, f(tb))
 | 
			
		||||
            super().__init__(f(tb))
 | 
			
		||||
        else:
 | 
			
		||||
            list.__init__(self, tb)
 | 
			
		||||
            super().__init__(tb)
 | 
			
		||||
 | 
			
		||||
    def cut(self, path=None, lineno=None, firstlineno=None, excludepath=None):
 | 
			
		||||
    def cut(
 | 
			
		||||
        self,
 | 
			
		||||
        path=None,
 | 
			
		||||
        lineno: Optional[int] = None,
 | 
			
		||||
        firstlineno: Optional[int] = None,
 | 
			
		||||
        excludepath=None,
 | 
			
		||||
    ) -> "Traceback":
 | 
			
		||||
        """ return a Traceback instance wrapping part of this Traceback
 | 
			
		||||
 | 
			
		||||
            by providing any combination of path, lineno and firstlineno, the
 | 
			
		||||
| 
						 | 
				
			
			@ -323,13 +343,25 @@ class Traceback(list):
 | 
			
		|||
                return Traceback(x._rawentry, self._excinfo)
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, key):
 | 
			
		||||
        val = super().__getitem__(key)
 | 
			
		||||
        if isinstance(key, type(slice(0))):
 | 
			
		||||
            val = self.__class__(val)
 | 
			
		||||
        return val
 | 
			
		||||
    @overload
 | 
			
		||||
    def __getitem__(self, key: int) -> TracebackEntry:
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    def filter(self, fn=lambda x: not x.ishidden()):
 | 
			
		||||
    @overload  # noqa: F811
 | 
			
		||||
    def __getitem__(self, key: slice) -> "Traceback":  # noqa: F811
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    def __getitem__(  # noqa: F811
 | 
			
		||||
        self, key: Union[int, slice]
 | 
			
		||||
    ) -> Union[TracebackEntry, "Traceback"]:
 | 
			
		||||
        if isinstance(key, slice):
 | 
			
		||||
            return self.__class__(super().__getitem__(key))
 | 
			
		||||
        else:
 | 
			
		||||
            return super().__getitem__(key)
 | 
			
		||||
 | 
			
		||||
    def filter(
 | 
			
		||||
        self, fn: Callable[[TracebackEntry], bool] = lambda x: not x.ishidden()
 | 
			
		||||
    ) -> "Traceback":
 | 
			
		||||
        """ return a Traceback instance with certain items removed
 | 
			
		||||
 | 
			
		||||
            fn is a function that gets a single argument, a TracebackEntry
 | 
			
		||||
| 
						 | 
				
			
			@ -341,7 +373,7 @@ class Traceback(list):
 | 
			
		|||
        """
 | 
			
		||||
        return Traceback(filter(fn, self), self._excinfo)
 | 
			
		||||
 | 
			
		||||
    def getcrashentry(self):
 | 
			
		||||
    def getcrashentry(self) -> TracebackEntry:
 | 
			
		||||
        """ return last non-hidden traceback entry that lead
 | 
			
		||||
        to the exception of a traceback.
 | 
			
		||||
        """
 | 
			
		||||
| 
						 | 
				
			
			@ -351,7 +383,7 @@ class Traceback(list):
 | 
			
		|||
                return entry
 | 
			
		||||
        return self[-1]
 | 
			
		||||
 | 
			
		||||
    def recursionindex(self):
 | 
			
		||||
    def recursionindex(self) -> Optional[int]:
 | 
			
		||||
        """ return the index of the frame/TracebackEntry where recursion
 | 
			
		||||
            originates if appropriate, None if no recursion occurred
 | 
			
		||||
        """
 | 
			
		||||
| 
						 | 
				
			
			@ -541,7 +573,7 @@ class ExceptionInfo(Generic[_E]):
 | 
			
		|||
    def getrepr(
 | 
			
		||||
        self,
 | 
			
		||||
        showlocals: bool = False,
 | 
			
		||||
        style: str = "long",
 | 
			
		||||
        style: "_TracebackStyle" = "long",
 | 
			
		||||
        abspath: bool = False,
 | 
			
		||||
        tbfilter: bool = True,
 | 
			
		||||
        funcargs: bool = False,
 | 
			
		||||
| 
						 | 
				
			
			@ -619,16 +651,16 @@ class FormattedExcinfo:
 | 
			
		|||
    flow_marker = ">"
 | 
			
		||||
    fail_marker = "E"
 | 
			
		||||
 | 
			
		||||
    showlocals = attr.ib(default=False)
 | 
			
		||||
    style = attr.ib(default="long")
 | 
			
		||||
    abspath = attr.ib(default=True)
 | 
			
		||||
    tbfilter = attr.ib(default=True)
 | 
			
		||||
    funcargs = attr.ib(default=False)
 | 
			
		||||
    truncate_locals = attr.ib(default=True)
 | 
			
		||||
    chain = attr.ib(default=True)
 | 
			
		||||
    showlocals = attr.ib(type=bool, default=False)
 | 
			
		||||
    style = attr.ib(type="_TracebackStyle", default="long")
 | 
			
		||||
    abspath = attr.ib(type=bool, default=True)
 | 
			
		||||
    tbfilter = attr.ib(type=bool, default=True)
 | 
			
		||||
    funcargs = attr.ib(type=bool, default=False)
 | 
			
		||||
    truncate_locals = attr.ib(type=bool, default=True)
 | 
			
		||||
    chain = attr.ib(type=bool, default=True)
 | 
			
		||||
    astcache = attr.ib(default=attr.Factory(dict), init=False, repr=False)
 | 
			
		||||
 | 
			
		||||
    def _getindent(self, source):
 | 
			
		||||
    def _getindent(self, source: "Source") -> int:
 | 
			
		||||
        # figure out indent for given source
 | 
			
		||||
        try:
 | 
			
		||||
            s = str(source.getstatement(len(source) - 1))
 | 
			
		||||
| 
						 | 
				
			
			@ -643,20 +675,27 @@ class FormattedExcinfo:
 | 
			
		|||
                return 0
 | 
			
		||||
        return 4 + (len(s) - len(s.lstrip()))
 | 
			
		||||
 | 
			
		||||
    def _getentrysource(self, entry):
 | 
			
		||||
    def _getentrysource(self, entry: TracebackEntry) -> Optional["Source"]:
 | 
			
		||||
        source = entry.getsource(self.astcache)
 | 
			
		||||
        if source is not None:
 | 
			
		||||
            source = source.deindent()
 | 
			
		||||
        return source
 | 
			
		||||
 | 
			
		||||
    def repr_args(self, entry):
 | 
			
		||||
    def repr_args(self, entry: TracebackEntry) -> Optional["ReprFuncArgs"]:
 | 
			
		||||
        if self.funcargs:
 | 
			
		||||
            args = []
 | 
			
		||||
            for argname, argvalue in entry.frame.getargs(var=True):
 | 
			
		||||
                args.append((argname, saferepr(argvalue)))
 | 
			
		||||
            return ReprFuncArgs(args)
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    def get_source(self, source, line_index=-1, excinfo=None, short=False) -> List[str]:
 | 
			
		||||
    def get_source(
 | 
			
		||||
        self,
 | 
			
		||||
        source: "Source",
 | 
			
		||||
        line_index: int = -1,
 | 
			
		||||
        excinfo: Optional[ExceptionInfo] = None,
 | 
			
		||||
        short: bool = False,
 | 
			
		||||
    ) -> List[str]:
 | 
			
		||||
        """ return formatted and marked up source lines. """
 | 
			
		||||
        import _pytest._code
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -680,19 +719,21 @@ class FormattedExcinfo:
 | 
			
		|||
            lines.extend(self.get_exconly(excinfo, indent=indent, markall=True))
 | 
			
		||||
        return lines
 | 
			
		||||
 | 
			
		||||
    def get_exconly(self, excinfo, indent=4, markall=False):
 | 
			
		||||
    def get_exconly(
 | 
			
		||||
        self, excinfo: ExceptionInfo, indent: int = 4, markall: bool = False
 | 
			
		||||
    ) -> List[str]:
 | 
			
		||||
        lines = []
 | 
			
		||||
        indent = " " * indent
 | 
			
		||||
        indentstr = " " * indent
 | 
			
		||||
        # get the real exception information out
 | 
			
		||||
        exlines = excinfo.exconly(tryshort=True).split("\n")
 | 
			
		||||
        failindent = self.fail_marker + indent[1:]
 | 
			
		||||
        failindent = self.fail_marker + indentstr[1:]
 | 
			
		||||
        for line in exlines:
 | 
			
		||||
            lines.append(failindent + line)
 | 
			
		||||
            if not markall:
 | 
			
		||||
                failindent = indent
 | 
			
		||||
                failindent = indentstr
 | 
			
		||||
        return lines
 | 
			
		||||
 | 
			
		||||
    def repr_locals(self, locals):
 | 
			
		||||
    def repr_locals(self, locals: Dict[str, object]) -> Optional["ReprLocals"]:
 | 
			
		||||
        if self.showlocals:
 | 
			
		||||
            lines = []
 | 
			
		||||
            keys = [loc for loc in locals if loc[0] != "@"]
 | 
			
		||||
| 
						 | 
				
			
			@ -717,8 +758,11 @@ class FormattedExcinfo:
 | 
			
		|||
                    #    # XXX
 | 
			
		||||
                    #    pprint.pprint(value, stream=self.excinfowriter)
 | 
			
		||||
            return ReprLocals(lines)
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    def repr_traceback_entry(self, entry, excinfo=None):
 | 
			
		||||
    def repr_traceback_entry(
 | 
			
		||||
        self, entry: TracebackEntry, excinfo: Optional[ExceptionInfo] = None
 | 
			
		||||
    ) -> "ReprEntry":
 | 
			
		||||
        import _pytest._code
 | 
			
		||||
 | 
			
		||||
        source = self._getentrysource(entry)
 | 
			
		||||
| 
						 | 
				
			
			@ -729,9 +773,7 @@ class FormattedExcinfo:
 | 
			
		|||
            line_index = entry.lineno - entry.getfirstlinesource()
 | 
			
		||||
 | 
			
		||||
        lines = []  # type: List[str]
 | 
			
		||||
        style = entry._repr_style
 | 
			
		||||
        if style is None:
 | 
			
		||||
            style = self.style
 | 
			
		||||
        style = entry._repr_style if entry._repr_style is not None else self.style
 | 
			
		||||
        if style in ("short", "long"):
 | 
			
		||||
            short = style == "short"
 | 
			
		||||
            reprargs = self.repr_args(entry) if not short else None
 | 
			
		||||
| 
						 | 
				
			
			@ -761,7 +803,7 @@ class FormattedExcinfo:
 | 
			
		|||
                path = np
 | 
			
		||||
        return path
 | 
			
		||||
 | 
			
		||||
    def repr_traceback(self, excinfo):
 | 
			
		||||
    def repr_traceback(self, excinfo: ExceptionInfo) -> "ReprTraceback":
 | 
			
		||||
        traceback = excinfo.traceback
 | 
			
		||||
        if self.tbfilter:
 | 
			
		||||
            traceback = traceback.filter()
 | 
			
		||||
| 
						 | 
				
			
			@ -779,7 +821,9 @@ class FormattedExcinfo:
 | 
			
		|||
            entries.append(reprentry)
 | 
			
		||||
        return ReprTraceback(entries, extraline, style=self.style)
 | 
			
		||||
 | 
			
		||||
    def _truncate_recursive_traceback(self, traceback):
 | 
			
		||||
    def _truncate_recursive_traceback(
 | 
			
		||||
        self, traceback: Traceback
 | 
			
		||||
    ) -> Tuple[Traceback, Optional[str]]:
 | 
			
		||||
        """
 | 
			
		||||
        Truncate the given recursive traceback trying to find the starting point
 | 
			
		||||
        of the recursion.
 | 
			
		||||
| 
						 | 
				
			
			@ -806,7 +850,9 @@ class FormattedExcinfo:
 | 
			
		|||
                max_frames=max_frames,
 | 
			
		||||
                total=len(traceback),
 | 
			
		||||
            )  # type: Optional[str]
 | 
			
		||||
            traceback = traceback[:max_frames] + traceback[-max_frames:]
 | 
			
		||||
            # Type ignored because adding two instaces of a List subtype
 | 
			
		||||
            # currently incorrectly has type List instead of the subtype.
 | 
			
		||||
            traceback = traceback[:max_frames] + traceback[-max_frames:]  # type: ignore
 | 
			
		||||
        else:
 | 
			
		||||
            if recursionindex is not None:
 | 
			
		||||
                extraline = "!!! Recursion detected (same locals & position)"
 | 
			
		||||
| 
						 | 
				
			
			@ -863,7 +909,7 @@ class FormattedExcinfo:
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class TerminalRepr:
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
    def __str__(self) -> str:
 | 
			
		||||
        # FYI this is called from pytest-xdist's serialization of exception
 | 
			
		||||
        # information.
 | 
			
		||||
        io = StringIO()
 | 
			
		||||
| 
						 | 
				
			
			@ -871,7 +917,7 @@ class TerminalRepr:
 | 
			
		|||
        self.toterminal(tw)
 | 
			
		||||
        return io.getvalue().strip()
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
    def __repr__(self) -> str:
 | 
			
		||||
        return "<{} instance at {:0x}>".format(self.__class__, id(self))
 | 
			
		||||
 | 
			
		||||
    def toterminal(self, tw) -> None:
 | 
			
		||||
| 
						 | 
				
			
			@ -882,7 +928,7 @@ class ExceptionRepr(TerminalRepr):
 | 
			
		|||
    def __init__(self) -> None:
 | 
			
		||||
        self.sections = []  # type: List[Tuple[str, str, str]]
 | 
			
		||||
 | 
			
		||||
    def addsection(self, name, content, sep="-"):
 | 
			
		||||
    def addsection(self, name: str, content: str, sep: str = "-") -> None:
 | 
			
		||||
        self.sections.append((name, content, sep))
 | 
			
		||||
 | 
			
		||||
    def toterminal(self, tw) -> None:
 | 
			
		||||
| 
						 | 
				
			
			@ -892,7 +938,12 @@ class ExceptionRepr(TerminalRepr):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class ExceptionChainRepr(ExceptionRepr):
 | 
			
		||||
    def __init__(self, chain):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        chain: Sequence[
 | 
			
		||||
            Tuple["ReprTraceback", Optional["ReprFileLocation"], Optional[str]]
 | 
			
		||||
        ],
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.chain = chain
 | 
			
		||||
        # reprcrash and reprtraceback of the outermost (the newest) exception
 | 
			
		||||
| 
						 | 
				
			
			@ -910,7 +961,9 @@ class ExceptionChainRepr(ExceptionRepr):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class ReprExceptionInfo(ExceptionRepr):
 | 
			
		||||
    def __init__(self, reprtraceback, reprcrash):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self, reprtraceback: "ReprTraceback", reprcrash: "ReprFileLocation"
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.reprtraceback = reprtraceback
 | 
			
		||||
        self.reprcrash = reprcrash
 | 
			
		||||
| 
						 | 
				
			
			@ -923,7 +976,12 @@ class ReprExceptionInfo(ExceptionRepr):
 | 
			
		|||
class ReprTraceback(TerminalRepr):
 | 
			
		||||
    entrysep = "_ "
 | 
			
		||||
 | 
			
		||||
    def __init__(self, reprentries, extraline, style):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        reprentries: Sequence[Union["ReprEntry", "ReprEntryNative"]],
 | 
			
		||||
        extraline: Optional[str],
 | 
			
		||||
        style: "_TracebackStyle",
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        self.reprentries = reprentries
 | 
			
		||||
        self.extraline = extraline
 | 
			
		||||
        self.style = style
 | 
			
		||||
| 
						 | 
				
			
			@ -948,16 +1006,16 @@ class ReprTraceback(TerminalRepr):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class ReprTracebackNative(ReprTraceback):
 | 
			
		||||
    def __init__(self, tblines):
 | 
			
		||||
    def __init__(self, tblines: Sequence[str]) -> None:
 | 
			
		||||
        self.style = "native"
 | 
			
		||||
        self.reprentries = [ReprEntryNative(tblines)]
 | 
			
		||||
        self.extraline = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ReprEntryNative(TerminalRepr):
 | 
			
		||||
    style = "native"
 | 
			
		||||
    style = "native"  # type: _TracebackStyle
 | 
			
		||||
 | 
			
		||||
    def __init__(self, tblines):
 | 
			
		||||
    def __init__(self, tblines: Sequence[str]) -> None:
 | 
			
		||||
        self.lines = tblines
 | 
			
		||||
 | 
			
		||||
    def toterminal(self, tw) -> None:
 | 
			
		||||
| 
						 | 
				
			
			@ -965,7 +1023,14 @@ class ReprEntryNative(TerminalRepr):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class ReprEntry(TerminalRepr):
 | 
			
		||||
    def __init__(self, lines, reprfuncargs, reprlocals, filelocrepr, style):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        lines: Sequence[str],
 | 
			
		||||
        reprfuncargs: Optional["ReprFuncArgs"],
 | 
			
		||||
        reprlocals: Optional["ReprLocals"],
 | 
			
		||||
        filelocrepr: Optional["ReprFileLocation"],
 | 
			
		||||
        style: "_TracebackStyle",
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        self.lines = lines
 | 
			
		||||
        self.reprfuncargs = reprfuncargs
 | 
			
		||||
        self.reprlocals = reprlocals
 | 
			
		||||
| 
						 | 
				
			
			@ -974,6 +1039,7 @@ class ReprEntry(TerminalRepr):
 | 
			
		|||
 | 
			
		||||
    def toterminal(self, tw) -> None:
 | 
			
		||||
        if self.style == "short":
 | 
			
		||||
            assert self.reprfileloc is not None
 | 
			
		||||
            self.reprfileloc.toterminal(tw)
 | 
			
		||||
            for line in self.lines:
 | 
			
		||||
                red = line.startswith("E   ")
 | 
			
		||||
| 
						 | 
				
			
			@ -992,14 +1058,14 @@ class ReprEntry(TerminalRepr):
 | 
			
		|||
                tw.line("")
 | 
			
		||||
            self.reprfileloc.toterminal(tw)
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
    def __str__(self) -> str:
 | 
			
		||||
        return "{}\n{}\n{}".format(
 | 
			
		||||
            "\n".join(self.lines), self.reprlocals, self.reprfileloc
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ReprFileLocation(TerminalRepr):
 | 
			
		||||
    def __init__(self, path, lineno, message):
 | 
			
		||||
    def __init__(self, path, lineno: int, message: str) -> None:
 | 
			
		||||
        self.path = str(path)
 | 
			
		||||
        self.lineno = lineno
 | 
			
		||||
        self.message = message
 | 
			
		||||
| 
						 | 
				
			
			@ -1016,7 +1082,7 @@ class ReprFileLocation(TerminalRepr):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class ReprLocals(TerminalRepr):
 | 
			
		||||
    def __init__(self, lines):
 | 
			
		||||
    def __init__(self, lines: Sequence[str]) -> None:
 | 
			
		||||
        self.lines = lines
 | 
			
		||||
 | 
			
		||||
    def toterminal(self, tw) -> None:
 | 
			
		||||
| 
						 | 
				
			
			@ -1025,7 +1091,7 @@ class ReprLocals(TerminalRepr):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class ReprFuncArgs(TerminalRepr):
 | 
			
		||||
    def __init__(self, args):
 | 
			
		||||
    def __init__(self, args: Sequence[Tuple[str, object]]) -> None:
 | 
			
		||||
        self.args = args
 | 
			
		||||
 | 
			
		||||
    def toterminal(self, tw) -> None:
 | 
			
		||||
| 
						 | 
				
			
			@ -1047,7 +1113,7 @@ class ReprFuncArgs(TerminalRepr):
 | 
			
		|||
            tw.line("")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def getrawcode(obj, trycall=True):
 | 
			
		||||
def getrawcode(obj, trycall: bool = True):
 | 
			
		||||
    """ return code object for given function. """
 | 
			
		||||
    try:
 | 
			
		||||
        return obj.__code__
 | 
			
		||||
| 
						 | 
				
			
			@ -1075,7 +1141,7 @@ _PYTEST_DIR = py.path.local(_pytest.__file__).dirpath()
 | 
			
		|||
_PY_DIR = py.path.local(py.__file__).dirpath()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def filter_traceback(entry):
 | 
			
		||||
def filter_traceback(entry: TracebackEntry) -> bool:
 | 
			
		||||
    """Return True if a TracebackEntry instance should be removed from tracebacks:
 | 
			
		||||
    * dynamically generated code (no code to show up for it);
 | 
			
		||||
    * internal traceback from pytest or its internal libraries, py and pluggy.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,18 +1,19 @@
 | 
			
		|||
import sys
 | 
			
		||||
from types import FrameType
 | 
			
		||||
from unittest import mock
 | 
			
		||||
 | 
			
		||||
import _pytest._code
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_ne():
 | 
			
		||||
def test_ne() -> None:
 | 
			
		||||
    code1 = _pytest._code.Code(compile('foo = "bar"', "", "exec"))
 | 
			
		||||
    assert code1 == code1
 | 
			
		||||
    code2 = _pytest._code.Code(compile('foo = "baz"', "", "exec"))
 | 
			
		||||
    assert code2 != code1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_code_gives_back_name_for_not_existing_file():
 | 
			
		||||
def test_code_gives_back_name_for_not_existing_file() -> None:
 | 
			
		||||
    name = "abc-123"
 | 
			
		||||
    co_code = compile("pass\n", name, "exec")
 | 
			
		||||
    assert co_code.co_filename == name
 | 
			
		||||
| 
						 | 
				
			
			@ -21,68 +22,67 @@ def test_code_gives_back_name_for_not_existing_file():
 | 
			
		|||
    assert code.fullsource is None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_code_with_class():
 | 
			
		||||
def test_code_with_class() -> None:
 | 
			
		||||
    class A:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    pytest.raises(TypeError, _pytest._code.Code, A)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def x():
 | 
			
		||||
def x() -> None:
 | 
			
		||||
    raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_code_fullsource():
 | 
			
		||||
def test_code_fullsource() -> None:
 | 
			
		||||
    code = _pytest._code.Code(x)
 | 
			
		||||
    full = code.fullsource
 | 
			
		||||
    assert "test_code_fullsource()" in str(full)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_code_source():
 | 
			
		||||
def test_code_source() -> None:
 | 
			
		||||
    code = _pytest._code.Code(x)
 | 
			
		||||
    src = code.source()
 | 
			
		||||
    expected = """def x():
 | 
			
		||||
    expected = """def x() -> None:
 | 
			
		||||
    raise NotImplementedError()"""
 | 
			
		||||
    assert str(src) == expected
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_frame_getsourcelineno_myself():
 | 
			
		||||
    def func():
 | 
			
		||||
def test_frame_getsourcelineno_myself() -> None:
 | 
			
		||||
    def func() -> FrameType:
 | 
			
		||||
        return sys._getframe(0)
 | 
			
		||||
 | 
			
		||||
    f = func()
 | 
			
		||||
    f = _pytest._code.Frame(f)
 | 
			
		||||
    f = _pytest._code.Frame(func())
 | 
			
		||||
    source, lineno = f.code.fullsource, f.lineno
 | 
			
		||||
    assert source is not None
 | 
			
		||||
    assert source[lineno].startswith("        return sys._getframe(0)")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_getstatement_empty_fullsource():
 | 
			
		||||
    def func():
 | 
			
		||||
def test_getstatement_empty_fullsource() -> None:
 | 
			
		||||
    def func() -> FrameType:
 | 
			
		||||
        return sys._getframe(0)
 | 
			
		||||
 | 
			
		||||
    f = func()
 | 
			
		||||
    f = _pytest._code.Frame(f)
 | 
			
		||||
    f = _pytest._code.Frame(func())
 | 
			
		||||
    with mock.patch.object(f.code.__class__, "fullsource", None):
 | 
			
		||||
        assert f.statement == ""
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_code_from_func():
 | 
			
		||||
def test_code_from_func() -> None:
 | 
			
		||||
    co = _pytest._code.Code(test_frame_getsourcelineno_myself)
 | 
			
		||||
    assert co.firstlineno
 | 
			
		||||
    assert co.path
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_unicode_handling():
 | 
			
		||||
def test_unicode_handling() -> None:
 | 
			
		||||
    value = "ąć".encode()
 | 
			
		||||
 | 
			
		||||
    def f():
 | 
			
		||||
    def f() -> None:
 | 
			
		||||
        raise Exception(value)
 | 
			
		||||
 | 
			
		||||
    excinfo = pytest.raises(Exception, f)
 | 
			
		||||
    str(excinfo)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_code_getargs():
 | 
			
		||||
def test_code_getargs() -> None:
 | 
			
		||||
    def f1(x):
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -108,26 +108,26 @@ def test_code_getargs():
 | 
			
		|||
    assert c4.getargs(var=True) == ("x", "y", "z")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_frame_getargs():
 | 
			
		||||
    def f1(x):
 | 
			
		||||
def test_frame_getargs() -> None:
 | 
			
		||||
    def f1(x) -> FrameType:
 | 
			
		||||
        return sys._getframe(0)
 | 
			
		||||
 | 
			
		||||
    fr1 = _pytest._code.Frame(f1("a"))
 | 
			
		||||
    assert fr1.getargs(var=True) == [("x", "a")]
 | 
			
		||||
 | 
			
		||||
    def f2(x, *y):
 | 
			
		||||
    def f2(x, *y) -> FrameType:
 | 
			
		||||
        return sys._getframe(0)
 | 
			
		||||
 | 
			
		||||
    fr2 = _pytest._code.Frame(f2("a", "b", "c"))
 | 
			
		||||
    assert fr2.getargs(var=True) == [("x", "a"), ("y", ("b", "c"))]
 | 
			
		||||
 | 
			
		||||
    def f3(x, **z):
 | 
			
		||||
    def f3(x, **z) -> FrameType:
 | 
			
		||||
        return sys._getframe(0)
 | 
			
		||||
 | 
			
		||||
    fr3 = _pytest._code.Frame(f3("a", b="c"))
 | 
			
		||||
    assert fr3.getargs(var=True) == [("x", "a"), ("z", {"b": "c"})]
 | 
			
		||||
 | 
			
		||||
    def f4(x, *y, **z):
 | 
			
		||||
    def f4(x, *y, **z) -> FrameType:
 | 
			
		||||
        return sys._getframe(0)
 | 
			
		||||
 | 
			
		||||
    fr4 = _pytest._code.Frame(f4("a", "b", c="d"))
 | 
			
		||||
| 
						 | 
				
			
			@ -135,7 +135,7 @@ def test_frame_getargs():
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class TestExceptionInfo:
 | 
			
		||||
    def test_bad_getsource(self):
 | 
			
		||||
    def test_bad_getsource(self) -> None:
 | 
			
		||||
        try:
 | 
			
		||||
            if False:
 | 
			
		||||
                pass
 | 
			
		||||
| 
						 | 
				
			
			@ -145,13 +145,13 @@ class TestExceptionInfo:
 | 
			
		|||
            exci = _pytest._code.ExceptionInfo.from_current()
 | 
			
		||||
        assert exci.getrepr()
 | 
			
		||||
 | 
			
		||||
    def test_from_current_with_missing(self):
 | 
			
		||||
    def test_from_current_with_missing(self) -> None:
 | 
			
		||||
        with pytest.raises(AssertionError, match="no current exception"):
 | 
			
		||||
            _pytest._code.ExceptionInfo.from_current()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestTracebackEntry:
 | 
			
		||||
    def test_getsource(self):
 | 
			
		||||
    def test_getsource(self) -> None:
 | 
			
		||||
        try:
 | 
			
		||||
            if False:
 | 
			
		||||
                pass
 | 
			
		||||
| 
						 | 
				
			
			@ -161,12 +161,13 @@ class TestTracebackEntry:
 | 
			
		|||
            exci = _pytest._code.ExceptionInfo.from_current()
 | 
			
		||||
        entry = exci.traceback[0]
 | 
			
		||||
        source = entry.getsource()
 | 
			
		||||
        assert source is not None
 | 
			
		||||
        assert len(source) == 6
 | 
			
		||||
        assert "assert False" in source[5]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestReprFuncArgs:
 | 
			
		||||
    def test_not_raise_exception_with_mixed_encoding(self, tw_mock):
 | 
			
		||||
    def test_not_raise_exception_with_mixed_encoding(self, tw_mock) -> None:
 | 
			
		||||
        from _pytest._code.code import ReprFuncArgs
 | 
			
		||||
 | 
			
		||||
        args = [("unicode_string", "São Paulo"), ("utf8_string", b"S\xc3\xa3o Paulo")]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,6 +3,7 @@ import os
 | 
			
		|||
import queue
 | 
			
		||||
import sys
 | 
			
		||||
import textwrap
 | 
			
		||||
from typing import Union
 | 
			
		||||
 | 
			
		||||
import py
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -224,23 +225,25 @@ class TestTraceback_f_g_h:
 | 
			
		|||
        repr = excinfo.getrepr()
 | 
			
		||||
        assert "RuntimeError: hello" in str(repr.reprcrash)
 | 
			
		||||
 | 
			
		||||
    def test_traceback_no_recursion_index(self):
 | 
			
		||||
        def do_stuff():
 | 
			
		||||
    def test_traceback_no_recursion_index(self) -> None:
 | 
			
		||||
        def do_stuff() -> None:
 | 
			
		||||
            raise RuntimeError
 | 
			
		||||
 | 
			
		||||
        def reraise_me():
 | 
			
		||||
        def reraise_me() -> None:
 | 
			
		||||
            import sys
 | 
			
		||||
 | 
			
		||||
            exc, val, tb = sys.exc_info()
 | 
			
		||||
            assert val is not None
 | 
			
		||||
            raise val.with_traceback(tb)
 | 
			
		||||
 | 
			
		||||
        def f(n):
 | 
			
		||||
        def f(n: int) -> None:
 | 
			
		||||
            try:
 | 
			
		||||
                do_stuff()
 | 
			
		||||
            except:  # noqa
 | 
			
		||||
                reraise_me()
 | 
			
		||||
 | 
			
		||||
        excinfo = pytest.raises(RuntimeError, f, 8)
 | 
			
		||||
        assert excinfo is not None
 | 
			
		||||
        traceback = excinfo.traceback
 | 
			
		||||
        recindex = traceback.recursionindex()
 | 
			
		||||
        assert recindex is None
 | 
			
		||||
| 
						 | 
				
			
			@ -596,7 +599,6 @@ raise ValueError()
 | 
			
		|||
        assert lines[3] == "E       world"
 | 
			
		||||
        assert not lines[4:]
 | 
			
		||||
 | 
			
		||||
        loc = repr_entry.reprlocals is not None
 | 
			
		||||
        loc = repr_entry.reprfileloc
 | 
			
		||||
        assert loc.path == mod.__file__
 | 
			
		||||
        assert loc.lineno == 3
 | 
			
		||||
| 
						 | 
				
			
			@ -1286,9 +1288,10 @@ raise ValueError()
 | 
			
		|||
@pytest.mark.parametrize("style", ["short", "long"])
 | 
			
		||||
@pytest.mark.parametrize("encoding", [None, "utf8", "utf16"])
 | 
			
		||||
def test_repr_traceback_with_unicode(style, encoding):
 | 
			
		||||
    msg = "☹"
 | 
			
		||||
    if encoding is not None:
 | 
			
		||||
        msg = msg.encode(encoding)
 | 
			
		||||
    if encoding is None:
 | 
			
		||||
        msg = "☹"  # type: Union[str, bytes]
 | 
			
		||||
    else:
 | 
			
		||||
        msg = "☹".encode(encoding)
 | 
			
		||||
    try:
 | 
			
		||||
        raise RuntimeError(msg)
 | 
			
		||||
    except RuntimeError:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -4,13 +4,16 @@
 | 
			
		|||
import ast
 | 
			
		||||
import inspect
 | 
			
		||||
import sys
 | 
			
		||||
from typing import Any
 | 
			
		||||
from typing import Dict
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
import _pytest._code
 | 
			
		||||
import pytest
 | 
			
		||||
from _pytest._code import Source
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_source_str_function():
 | 
			
		||||
def test_source_str_function() -> None:
 | 
			
		||||
    x = Source("3")
 | 
			
		||||
    assert str(x) == "3"
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -25,7 +28,7 @@ def test_source_str_function():
 | 
			
		|||
    assert str(x) == "\n3"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_unicode():
 | 
			
		||||
def test_unicode() -> None:
 | 
			
		||||
    x = Source("4")
 | 
			
		||||
    assert str(x) == "4"
 | 
			
		||||
    co = _pytest._code.compile('"å"', mode="eval")
 | 
			
		||||
| 
						 | 
				
			
			@ -33,12 +36,12 @@ def test_unicode():
 | 
			
		|||
    assert isinstance(val, str)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_source_from_function():
 | 
			
		||||
def test_source_from_function() -> None:
 | 
			
		||||
    source = _pytest._code.Source(test_source_str_function)
 | 
			
		||||
    assert str(source).startswith("def test_source_str_function():")
 | 
			
		||||
    assert str(source).startswith("def test_source_str_function() -> None:")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_source_from_method():
 | 
			
		||||
def test_source_from_method() -> None:
 | 
			
		||||
    class TestClass:
 | 
			
		||||
        def test_method(self):
 | 
			
		||||
            pass
 | 
			
		||||
| 
						 | 
				
			
			@ -47,13 +50,13 @@ def test_source_from_method():
 | 
			
		|||
    assert source.lines == ["def test_method(self):", "    pass"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_source_from_lines():
 | 
			
		||||
def test_source_from_lines() -> None:
 | 
			
		||||
    lines = ["a \n", "b\n", "c"]
 | 
			
		||||
    source = _pytest._code.Source(lines)
 | 
			
		||||
    assert source.lines == ["a ", "b", "c"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_source_from_inner_function():
 | 
			
		||||
def test_source_from_inner_function() -> None:
 | 
			
		||||
    def f():
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -63,7 +66,7 @@ def test_source_from_inner_function():
 | 
			
		|||
    assert str(source).startswith("def f():")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_source_putaround_simple():
 | 
			
		||||
def test_source_putaround_simple() -> None:
 | 
			
		||||
    source = Source("raise ValueError")
 | 
			
		||||
    source = source.putaround(
 | 
			
		||||
        "try:",
 | 
			
		||||
| 
						 | 
				
			
			@ -85,7 +88,7 @@ else:
 | 
			
		|||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_source_putaround():
 | 
			
		||||
def test_source_putaround() -> None:
 | 
			
		||||
    source = Source()
 | 
			
		||||
    source = source.putaround(
 | 
			
		||||
        """
 | 
			
		||||
| 
						 | 
				
			
			@ -96,28 +99,29 @@ def test_source_putaround():
 | 
			
		|||
    assert str(source).strip() == "if 1:\n    x=1"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_source_strips():
 | 
			
		||||
def test_source_strips() -> None:
 | 
			
		||||
    source = Source("")
 | 
			
		||||
    assert source == Source()
 | 
			
		||||
    assert str(source) == ""
 | 
			
		||||
    assert source.strip() == source
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_source_strip_multiline():
 | 
			
		||||
def test_source_strip_multiline() -> None:
 | 
			
		||||
    source = Source()
 | 
			
		||||
    source.lines = ["", " hello", "  "]
 | 
			
		||||
    source2 = source.strip()
 | 
			
		||||
    assert source2.lines == [" hello"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_syntaxerror_rerepresentation():
 | 
			
		||||
def test_syntaxerror_rerepresentation() -> None:
 | 
			
		||||
    ex = pytest.raises(SyntaxError, _pytest._code.compile, "xyz xyz")
 | 
			
		||||
    assert ex is not None
 | 
			
		||||
    assert ex.value.lineno == 1
 | 
			
		||||
    assert ex.value.offset in {5, 7}  # cpython: 7, pypy3.6 7.1.1: 5
 | 
			
		||||
    assert ex.value.text.strip(), "x x"
 | 
			
		||||
    assert ex.value.text == "xyz xyz\n"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_isparseable():
 | 
			
		||||
def test_isparseable() -> None:
 | 
			
		||||
    assert Source("hello").isparseable()
 | 
			
		||||
    assert Source("if 1:\n  pass").isparseable()
 | 
			
		||||
    assert Source(" \nif 1:\n  pass").isparseable()
 | 
			
		||||
| 
						 | 
				
			
			@ -127,7 +131,7 @@ def test_isparseable():
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class TestAccesses:
 | 
			
		||||
    def setup_class(self):
 | 
			
		||||
    def setup_class(self) -> None:
 | 
			
		||||
        self.source = Source(
 | 
			
		||||
            """\
 | 
			
		||||
            def f(x):
 | 
			
		||||
| 
						 | 
				
			
			@ -137,26 +141,26 @@ class TestAccesses:
 | 
			
		|||
        """
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_getrange(self):
 | 
			
		||||
    def test_getrange(self) -> None:
 | 
			
		||||
        x = self.source[0:2]
 | 
			
		||||
        assert x.isparseable()
 | 
			
		||||
        assert len(x.lines) == 2
 | 
			
		||||
        assert str(x) == "def f(x):\n    pass"
 | 
			
		||||
 | 
			
		||||
    def test_getline(self):
 | 
			
		||||
    def test_getline(self) -> None:
 | 
			
		||||
        x = self.source[0]
 | 
			
		||||
        assert x == "def f(x):"
 | 
			
		||||
 | 
			
		||||
    def test_len(self):
 | 
			
		||||
    def test_len(self) -> None:
 | 
			
		||||
        assert len(self.source) == 4
 | 
			
		||||
 | 
			
		||||
    def test_iter(self):
 | 
			
		||||
    def test_iter(self) -> None:
 | 
			
		||||
        values = [x for x in self.source]
 | 
			
		||||
        assert len(values) == 4
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestSourceParsingAndCompiling:
 | 
			
		||||
    def setup_class(self):
 | 
			
		||||
    def setup_class(self) -> None:
 | 
			
		||||
        self.source = Source(
 | 
			
		||||
            """\
 | 
			
		||||
            def f(x):
 | 
			
		||||
| 
						 | 
				
			
			@ -166,19 +170,19 @@ class TestSourceParsingAndCompiling:
 | 
			
		|||
        """
 | 
			
		||||
        ).strip()
 | 
			
		||||
 | 
			
		||||
    def test_compile(self):
 | 
			
		||||
    def test_compile(self) -> None:
 | 
			
		||||
        co = _pytest._code.compile("x=3")
 | 
			
		||||
        d = {}
 | 
			
		||||
        d = {}  # type: Dict[str, Any]
 | 
			
		||||
        exec(co, d)
 | 
			
		||||
        assert d["x"] == 3
 | 
			
		||||
 | 
			
		||||
    def test_compile_and_getsource_simple(self):
 | 
			
		||||
    def test_compile_and_getsource_simple(self) -> None:
 | 
			
		||||
        co = _pytest._code.compile("x=3")
 | 
			
		||||
        exec(co)
 | 
			
		||||
        source = _pytest._code.Source(co)
 | 
			
		||||
        assert str(source) == "x=3"
 | 
			
		||||
 | 
			
		||||
    def test_compile_and_getsource_through_same_function(self):
 | 
			
		||||
    def test_compile_and_getsource_through_same_function(self) -> None:
 | 
			
		||||
        def gensource(source):
 | 
			
		||||
            return _pytest._code.compile(source)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -199,7 +203,7 @@ class TestSourceParsingAndCompiling:
 | 
			
		|||
        source2 = inspect.getsource(co2)
 | 
			
		||||
        assert "ValueError" in source2
 | 
			
		||||
 | 
			
		||||
    def test_getstatement(self):
 | 
			
		||||
    def test_getstatement(self) -> None:
 | 
			
		||||
        # print str(self.source)
 | 
			
		||||
        ass = str(self.source[1:])
 | 
			
		||||
        for i in range(1, 4):
 | 
			
		||||
| 
						 | 
				
			
			@ -208,7 +212,7 @@ class TestSourceParsingAndCompiling:
 | 
			
		|||
            # x = s.deindent()
 | 
			
		||||
            assert str(s) == ass
 | 
			
		||||
 | 
			
		||||
    def test_getstatementrange_triple_quoted(self):
 | 
			
		||||
    def test_getstatementrange_triple_quoted(self) -> None:
 | 
			
		||||
        # print str(self.source)
 | 
			
		||||
        source = Source(
 | 
			
		||||
            """hello('''
 | 
			
		||||
| 
						 | 
				
			
			@ -219,7 +223,7 @@ class TestSourceParsingAndCompiling:
 | 
			
		|||
        s = source.getstatement(1)
 | 
			
		||||
        assert s == str(source)
 | 
			
		||||
 | 
			
		||||
    def test_getstatementrange_within_constructs(self):
 | 
			
		||||
    def test_getstatementrange_within_constructs(self) -> None:
 | 
			
		||||
        source = Source(
 | 
			
		||||
            """\
 | 
			
		||||
            try:
 | 
			
		||||
| 
						 | 
				
			
			@ -241,7 +245,7 @@ class TestSourceParsingAndCompiling:
 | 
			
		|||
        # assert source.getstatementrange(5) == (0, 7)
 | 
			
		||||
        assert source.getstatementrange(6) == (6, 7)
 | 
			
		||||
 | 
			
		||||
    def test_getstatementrange_bug(self):
 | 
			
		||||
    def test_getstatementrange_bug(self) -> None:
 | 
			
		||||
        source = Source(
 | 
			
		||||
            """\
 | 
			
		||||
            try:
 | 
			
		||||
| 
						 | 
				
			
			@ -255,7 +259,7 @@ class TestSourceParsingAndCompiling:
 | 
			
		|||
        assert len(source) == 6
 | 
			
		||||
        assert source.getstatementrange(2) == (1, 4)
 | 
			
		||||
 | 
			
		||||
    def test_getstatementrange_bug2(self):
 | 
			
		||||
    def test_getstatementrange_bug2(self) -> None:
 | 
			
		||||
        source = Source(
 | 
			
		||||
            """\
 | 
			
		||||
            assert (
 | 
			
		||||
| 
						 | 
				
			
			@ -272,7 +276,7 @@ class TestSourceParsingAndCompiling:
 | 
			
		|||
        assert len(source) == 9
 | 
			
		||||
        assert source.getstatementrange(5) == (0, 9)
 | 
			
		||||
 | 
			
		||||
    def test_getstatementrange_ast_issue58(self):
 | 
			
		||||
    def test_getstatementrange_ast_issue58(self) -> None:
 | 
			
		||||
        source = Source(
 | 
			
		||||
            """\
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -286,38 +290,44 @@ class TestSourceParsingAndCompiling:
 | 
			
		|||
        assert getstatement(2, source).lines == source.lines[2:3]
 | 
			
		||||
        assert getstatement(3, source).lines == source.lines[3:4]
 | 
			
		||||
 | 
			
		||||
    def test_getstatementrange_out_of_bounds_py3(self):
 | 
			
		||||
    def test_getstatementrange_out_of_bounds_py3(self) -> None:
 | 
			
		||||
        source = Source("if xxx:\n   from .collections import something")
 | 
			
		||||
        r = source.getstatementrange(1)
 | 
			
		||||
        assert r == (1, 2)
 | 
			
		||||
 | 
			
		||||
    def test_getstatementrange_with_syntaxerror_issue7(self):
 | 
			
		||||
    def test_getstatementrange_with_syntaxerror_issue7(self) -> None:
 | 
			
		||||
        source = Source(":")
 | 
			
		||||
        pytest.raises(SyntaxError, lambda: source.getstatementrange(0))
 | 
			
		||||
 | 
			
		||||
    def test_compile_to_ast(self):
 | 
			
		||||
    def test_compile_to_ast(self) -> None:
 | 
			
		||||
        source = Source("x = 4")
 | 
			
		||||
        mod = source.compile(flag=ast.PyCF_ONLY_AST)
 | 
			
		||||
        assert isinstance(mod, ast.Module)
 | 
			
		||||
        compile(mod, "<filename>", "exec")
 | 
			
		||||
 | 
			
		||||
    def test_compile_and_getsource(self):
 | 
			
		||||
    def test_compile_and_getsource(self) -> None:
 | 
			
		||||
        co = self.source.compile()
 | 
			
		||||
        exec(co, globals())
 | 
			
		||||
        f(7)
 | 
			
		||||
        excinfo = pytest.raises(AssertionError, f, 6)
 | 
			
		||||
        f(7)  # type: ignore
 | 
			
		||||
        excinfo = pytest.raises(AssertionError, f, 6)  # type: ignore
 | 
			
		||||
        assert excinfo is not None
 | 
			
		||||
        frame = excinfo.traceback[-1].frame
 | 
			
		||||
        assert isinstance(frame.code.fullsource, Source)
 | 
			
		||||
        stmt = frame.code.fullsource.getstatement(frame.lineno)
 | 
			
		||||
        assert str(stmt).strip().startswith("assert")
 | 
			
		||||
 | 
			
		||||
    @pytest.mark.parametrize("name", ["", None, "my"])
 | 
			
		||||
    def test_compilefuncs_and_path_sanity(self, name):
 | 
			
		||||
    def test_compilefuncs_and_path_sanity(self, name: Optional[str]) -> None:
 | 
			
		||||
        def check(comp, name):
 | 
			
		||||
            co = comp(self.source, name)
 | 
			
		||||
            if not name:
 | 
			
		||||
                expected = "codegen %s:%d>" % (mypath, mylineno + 2 + 2)
 | 
			
		||||
                expected = "codegen %s:%d>" % (mypath, mylineno + 2 + 2)  # type: ignore
 | 
			
		||||
            else:
 | 
			
		||||
                expected = "codegen %r %s:%d>" % (name, mypath, mylineno + 2 + 2)
 | 
			
		||||
                expected = "codegen %r %s:%d>" % (
 | 
			
		||||
                    name,
 | 
			
		||||
                    mypath,  # type: ignore
 | 
			
		||||
                    mylineno + 2 + 2,  # type: ignore
 | 
			
		||||
                )  # type: ignore
 | 
			
		||||
            fn = co.co_filename
 | 
			
		||||
            assert fn.endswith(expected)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -332,9 +342,9 @@ class TestSourceParsingAndCompiling:
 | 
			
		|||
        pytest.raises(SyntaxError, _pytest._code.compile, "lambda a,a: 0", mode="eval")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_getstartingblock_singleline():
 | 
			
		||||
def test_getstartingblock_singleline() -> None:
 | 
			
		||||
    class A:
 | 
			
		||||
        def __init__(self, *args):
 | 
			
		||||
        def __init__(self, *args) -> None:
 | 
			
		||||
            frame = sys._getframe(1)
 | 
			
		||||
            self.source = _pytest._code.Frame(frame).statement
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -344,22 +354,22 @@ def test_getstartingblock_singleline():
 | 
			
		|||
    assert len(values) == 1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_getline_finally():
 | 
			
		||||
    def c():
 | 
			
		||||
def test_getline_finally() -> None:
 | 
			
		||||
    def c() -> None:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    with pytest.raises(TypeError) as excinfo:
 | 
			
		||||
        teardown = None
 | 
			
		||||
        try:
 | 
			
		||||
            c(1)
 | 
			
		||||
            c(1)  # type: ignore
 | 
			
		||||
        finally:
 | 
			
		||||
            if teardown:
 | 
			
		||||
                teardown()
 | 
			
		||||
    source = excinfo.traceback[-1].statement
 | 
			
		||||
    assert str(source).strip() == "c(1)"
 | 
			
		||||
    assert str(source).strip() == "c(1)  # type: ignore"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_getfuncsource_dynamic():
 | 
			
		||||
def test_getfuncsource_dynamic() -> None:
 | 
			
		||||
    source = """
 | 
			
		||||
        def f():
 | 
			
		||||
            raise ValueError
 | 
			
		||||
| 
						 | 
				
			
			@ -368,11 +378,13 @@ def test_getfuncsource_dynamic():
 | 
			
		|||
    """
 | 
			
		||||
    co = _pytest._code.compile(source)
 | 
			
		||||
    exec(co, globals())
 | 
			
		||||
    assert str(_pytest._code.Source(f)).strip() == "def f():\n    raise ValueError"
 | 
			
		||||
    assert str(_pytest._code.Source(g)).strip() == "def g(): pass"
 | 
			
		||||
    f_source = _pytest._code.Source(f)  # type: ignore
 | 
			
		||||
    g_source = _pytest._code.Source(g)  # type: ignore
 | 
			
		||||
    assert str(f_source).strip() == "def f():\n    raise ValueError"
 | 
			
		||||
    assert str(g_source).strip() == "def g(): pass"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_getfuncsource_with_multine_string():
 | 
			
		||||
def test_getfuncsource_with_multine_string() -> None:
 | 
			
		||||
    def f():
 | 
			
		||||
        c = """while True:
 | 
			
		||||
    pass
 | 
			
		||||
| 
						 | 
				
			
			@ -387,7 +399,7 @@ def test_getfuncsource_with_multine_string():
 | 
			
		|||
    assert str(_pytest._code.Source(f)) == expected.rstrip()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_deindent():
 | 
			
		||||
def test_deindent() -> None:
 | 
			
		||||
    from _pytest._code.source import deindent as deindent
 | 
			
		||||
 | 
			
		||||
    assert deindent(["\tfoo", "\tbar"]) == ["foo", "bar"]
 | 
			
		||||
| 
						 | 
				
			
			@ -401,7 +413,7 @@ def test_deindent():
 | 
			
		|||
    assert lines == ["def f():", "    def g():", "        pass"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_source_of_class_at_eof_without_newline(tmpdir, _sys_snapshot):
 | 
			
		||||
def test_source_of_class_at_eof_without_newline(tmpdir, _sys_snapshot) -> None:
 | 
			
		||||
    # this test fails because the implicit inspect.getsource(A) below
 | 
			
		||||
    # does not return the "x = 1" last line.
 | 
			
		||||
    source = _pytest._code.Source(
 | 
			
		||||
| 
						 | 
				
			
			@ -423,7 +435,7 @@ if True:
 | 
			
		|||
        pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_getsource_fallback():
 | 
			
		||||
def test_getsource_fallback() -> None:
 | 
			
		||||
    from _pytest._code.source import getsource
 | 
			
		||||
 | 
			
		||||
    expected = """def x():
 | 
			
		||||
| 
						 | 
				
			
			@ -432,7 +444,7 @@ def test_getsource_fallback():
 | 
			
		|||
    assert src == expected
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_idem_compile_and_getsource():
 | 
			
		||||
def test_idem_compile_and_getsource() -> None:
 | 
			
		||||
    from _pytest._code.source import getsource
 | 
			
		||||
 | 
			
		||||
    expected = "def x(): pass"
 | 
			
		||||
| 
						 | 
				
			
			@ -441,15 +453,16 @@ def test_idem_compile_and_getsource():
 | 
			
		|||
    assert src == expected
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_findsource_fallback():
 | 
			
		||||
def test_findsource_fallback() -> None:
 | 
			
		||||
    from _pytest._code.source import findsource
 | 
			
		||||
 | 
			
		||||
    src, lineno = findsource(x)
 | 
			
		||||
    assert src is not None
 | 
			
		||||
    assert "test_findsource_simple" in str(src)
 | 
			
		||||
    assert src[lineno] == "    def x():"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_findsource():
 | 
			
		||||
def test_findsource() -> None:
 | 
			
		||||
    from _pytest._code.source import findsource
 | 
			
		||||
 | 
			
		||||
    co = _pytest._code.compile(
 | 
			
		||||
| 
						 | 
				
			
			@ -460,19 +473,21 @@ def test_findsource():
 | 
			
		|||
    )
 | 
			
		||||
 | 
			
		||||
    src, lineno = findsource(co)
 | 
			
		||||
    assert src is not None
 | 
			
		||||
    assert "if 1:" in str(src)
 | 
			
		||||
 | 
			
		||||
    d = {}
 | 
			
		||||
    d = {}  # type: Dict[str, Any]
 | 
			
		||||
    eval(co, d)
 | 
			
		||||
    src, lineno = findsource(d["x"])
 | 
			
		||||
    assert src is not None
 | 
			
		||||
    assert "if 1:" in str(src)
 | 
			
		||||
    assert src[lineno] == "    def x():"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_getfslineno():
 | 
			
		||||
def test_getfslineno() -> None:
 | 
			
		||||
    from _pytest._code import getfslineno
 | 
			
		||||
 | 
			
		||||
    def f(x):
 | 
			
		||||
    def f(x) -> None:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    fspath, lineno = getfslineno(f)
 | 
			
		||||
| 
						 | 
				
			
			@ -498,40 +513,40 @@ def test_getfslineno():
 | 
			
		|||
    assert getfslineno(B)[1] == -1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_code_of_object_instance_with_call():
 | 
			
		||||
def test_code_of_object_instance_with_call() -> None:
 | 
			
		||||
    class A:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    pytest.raises(TypeError, lambda: _pytest._code.Source(A()))
 | 
			
		||||
 | 
			
		||||
    class WithCall:
 | 
			
		||||
        def __call__(self):
 | 
			
		||||
        def __call__(self) -> None:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
    code = _pytest._code.Code(WithCall())
 | 
			
		||||
    assert "pass" in str(code.source())
 | 
			
		||||
 | 
			
		||||
    class Hello:
 | 
			
		||||
        def __call__(self):
 | 
			
		||||
        def __call__(self) -> None:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
    pytest.raises(TypeError, lambda: _pytest._code.Code(Hello))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def getstatement(lineno, source):
 | 
			
		||||
def getstatement(lineno: int, source) -> Source:
 | 
			
		||||
    from _pytest._code.source import getstatementrange_ast
 | 
			
		||||
 | 
			
		||||
    source = _pytest._code.Source(source, deindent=False)
 | 
			
		||||
    ast, start, end = getstatementrange_ast(lineno, source)
 | 
			
		||||
    return source[start:end]
 | 
			
		||||
    src = _pytest._code.Source(source, deindent=False)
 | 
			
		||||
    ast, start, end = getstatementrange_ast(lineno, src)
 | 
			
		||||
    return src[start:end]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_oneline():
 | 
			
		||||
def test_oneline() -> None:
 | 
			
		||||
    source = getstatement(0, "raise ValueError")
 | 
			
		||||
    assert str(source) == "raise ValueError"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_comment_and_no_newline_at_end():
 | 
			
		||||
def test_comment_and_no_newline_at_end() -> None:
 | 
			
		||||
    from _pytest._code.source import getstatementrange_ast
 | 
			
		||||
 | 
			
		||||
    source = Source(
 | 
			
		||||
| 
						 | 
				
			
			@ -545,12 +560,12 @@ def test_comment_and_no_newline_at_end():
 | 
			
		|||
    assert end == 2
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_oneline_and_comment():
 | 
			
		||||
def test_oneline_and_comment() -> None:
 | 
			
		||||
    source = getstatement(0, "raise ValueError\n#hello")
 | 
			
		||||
    assert str(source) == "raise ValueError"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_comments():
 | 
			
		||||
def test_comments() -> None:
 | 
			
		||||
    source = '''def test():
 | 
			
		||||
    "comment 1"
 | 
			
		||||
    x = 1
 | 
			
		||||
| 
						 | 
				
			
			@ -576,7 +591,7 @@ comment 4
 | 
			
		|||
        assert str(getstatement(line, source)) == '"""\ncomment 4\n"""'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_comment_in_statement():
 | 
			
		||||
def test_comment_in_statement() -> None:
 | 
			
		||||
    source = """test(foo=1,
 | 
			
		||||
    # comment 1
 | 
			
		||||
    bar=2)
 | 
			
		||||
| 
						 | 
				
			
			@ -588,17 +603,17 @@ def test_comment_in_statement():
 | 
			
		|||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_single_line_else():
 | 
			
		||||
def test_single_line_else() -> None:
 | 
			
		||||
    source = getstatement(1, "if False: 2\nelse: 3")
 | 
			
		||||
    assert str(source) == "else: 3"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_single_line_finally():
 | 
			
		||||
def test_single_line_finally() -> None:
 | 
			
		||||
    source = getstatement(1, "try: 1\nfinally: 3")
 | 
			
		||||
    assert str(source) == "finally: 3"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_issue55():
 | 
			
		||||
def test_issue55() -> None:
 | 
			
		||||
    source = (
 | 
			
		||||
        "def round_trip(dinp):\n  assert 1 == dinp\n"
 | 
			
		||||
        'def test_rt():\n  round_trip("""\n""")\n'
 | 
			
		||||
| 
						 | 
				
			
			@ -607,7 +622,7 @@ def test_issue55():
 | 
			
		|||
    assert str(s) == '  round_trip("""\n""")'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_multiline():
 | 
			
		||||
def test_multiline() -> None:
 | 
			
		||||
    source = getstatement(
 | 
			
		||||
        0,
 | 
			
		||||
        """\
 | 
			
		||||
| 
						 | 
				
			
			@ -621,7 +636,7 @@ x = 3
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class TestTry:
 | 
			
		||||
    def setup_class(self):
 | 
			
		||||
    def setup_class(self) -> None:
 | 
			
		||||
        self.source = """\
 | 
			
		||||
try:
 | 
			
		||||
    raise ValueError
 | 
			
		||||
| 
						 | 
				
			
			@ -631,25 +646,25 @@ else:
 | 
			
		|||
    raise KeyError()
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
    def test_body(self):
 | 
			
		||||
    def test_body(self) -> None:
 | 
			
		||||
        source = getstatement(1, self.source)
 | 
			
		||||
        assert str(source) == "    raise ValueError"
 | 
			
		||||
 | 
			
		||||
    def test_except_line(self):
 | 
			
		||||
    def test_except_line(self) -> None:
 | 
			
		||||
        source = getstatement(2, self.source)
 | 
			
		||||
        assert str(source) == "except Something:"
 | 
			
		||||
 | 
			
		||||
    def test_except_body(self):
 | 
			
		||||
    def test_except_body(self) -> None:
 | 
			
		||||
        source = getstatement(3, self.source)
 | 
			
		||||
        assert str(source) == "    raise IndexError(1)"
 | 
			
		||||
 | 
			
		||||
    def test_else(self):
 | 
			
		||||
    def test_else(self) -> None:
 | 
			
		||||
        source = getstatement(5, self.source)
 | 
			
		||||
        assert str(source) == "    raise KeyError()"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestTryFinally:
 | 
			
		||||
    def setup_class(self):
 | 
			
		||||
    def setup_class(self) -> None:
 | 
			
		||||
        self.source = """\
 | 
			
		||||
try:
 | 
			
		||||
    raise ValueError
 | 
			
		||||
| 
						 | 
				
			
			@ -657,17 +672,17 @@ finally:
 | 
			
		|||
    raise IndexError(1)
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
    def test_body(self):
 | 
			
		||||
    def test_body(self) -> None:
 | 
			
		||||
        source = getstatement(1, self.source)
 | 
			
		||||
        assert str(source) == "    raise ValueError"
 | 
			
		||||
 | 
			
		||||
    def test_finally(self):
 | 
			
		||||
    def test_finally(self) -> None:
 | 
			
		||||
        source = getstatement(3, self.source)
 | 
			
		||||
        assert str(source) == "    raise IndexError(1)"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestIf:
 | 
			
		||||
    def setup_class(self):
 | 
			
		||||
    def setup_class(self) -> None:
 | 
			
		||||
        self.source = """\
 | 
			
		||||
if 1:
 | 
			
		||||
    y = 3
 | 
			
		||||
| 
						 | 
				
			
			@ -677,24 +692,24 @@ else:
 | 
			
		|||
    y = 7
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
    def test_body(self):
 | 
			
		||||
    def test_body(self) -> None:
 | 
			
		||||
        source = getstatement(1, self.source)
 | 
			
		||||
        assert str(source) == "    y = 3"
 | 
			
		||||
 | 
			
		||||
    def test_elif_clause(self):
 | 
			
		||||
    def test_elif_clause(self) -> None:
 | 
			
		||||
        source = getstatement(2, self.source)
 | 
			
		||||
        assert str(source) == "elif False:"
 | 
			
		||||
 | 
			
		||||
    def test_elif(self):
 | 
			
		||||
    def test_elif(self) -> None:
 | 
			
		||||
        source = getstatement(3, self.source)
 | 
			
		||||
        assert str(source) == "    y = 5"
 | 
			
		||||
 | 
			
		||||
    def test_else(self):
 | 
			
		||||
    def test_else(self) -> None:
 | 
			
		||||
        source = getstatement(5, self.source)
 | 
			
		||||
        assert str(source) == "    y = 7"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_semicolon():
 | 
			
		||||
def test_semicolon() -> None:
 | 
			
		||||
    s = """\
 | 
			
		||||
hello ; pytest.skip()
 | 
			
		||||
"""
 | 
			
		||||
| 
						 | 
				
			
			@ -702,7 +717,7 @@ hello ; pytest.skip()
 | 
			
		|||
    assert str(source) == s.strip()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_def_online():
 | 
			
		||||
def test_def_online() -> None:
 | 
			
		||||
    s = """\
 | 
			
		||||
def func(): raise ValueError(42)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -713,7 +728,7 @@ def something():
 | 
			
		|||
    assert str(source) == "def func(): raise ValueError(42)"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def XXX_test_expression_multiline():
 | 
			
		||||
def XXX_test_expression_multiline() -> None:
 | 
			
		||||
    source = """\
 | 
			
		||||
something
 | 
			
		||||
'''
 | 
			
		||||
| 
						 | 
				
			
			@ -722,7 +737,7 @@ something
 | 
			
		|||
    assert str(result) == "'''\n'''"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_getstartingblock_multiline():
 | 
			
		||||
def test_getstartingblock_multiline() -> None:
 | 
			
		||||
    class A:
 | 
			
		||||
        def __init__(self, *args):
 | 
			
		||||
            frame = sys._getframe(1)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue