Add a few missing type annotations in _pytest._code
These are more "dirty" than the previous batch (that's why they were left out). The trouble is that `compile` can return either a code object or an AST depending on a flag, so we need to add an overload to make the common case Union free. But it's still worthwhile.
This commit is contained in:
		
							parent
							
								
									3e6f0f34ff
								
							
						
					
					
						commit
						0c247be769
					
				| 
						 | 
					@ -67,7 +67,7 @@ class Code:
 | 
				
			||||||
        return not self == other
 | 
					        return not self == other
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def path(self):
 | 
					    def path(self) -> Union[py.path.local, str]:
 | 
				
			||||||
        """ return a path object pointing to source code (note that it
 | 
					        """ return a path object pointing to source code (note that it
 | 
				
			||||||
        might not point to an actually existing file). """
 | 
					        might not point to an actually existing file). """
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
| 
						 | 
					@ -335,7 +335,7 @@ class Traceback(List[TracebackEntry]):
 | 
				
			||||||
                (path is None or codepath == path)
 | 
					                (path is None or codepath == path)
 | 
				
			||||||
                and (
 | 
					                and (
 | 
				
			||||||
                    excludepath is None
 | 
					                    excludepath is None
 | 
				
			||||||
                    or not hasattr(codepath, "relto")
 | 
					                    or not isinstance(codepath, py.path.local)
 | 
				
			||||||
                    or not codepath.relto(excludepath)
 | 
					                    or not codepath.relto(excludepath)
 | 
				
			||||||
                )
 | 
					                )
 | 
				
			||||||
                and (lineno is None or x.lineno == lineno)
 | 
					                and (lineno is None or x.lineno == lineno)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -6,6 +6,7 @@ import textwrap
 | 
				
			||||||
import tokenize
 | 
					import tokenize
 | 
				
			||||||
import warnings
 | 
					import warnings
 | 
				
			||||||
from bisect import bisect_right
 | 
					from bisect import bisect_right
 | 
				
			||||||
 | 
					from types import CodeType
 | 
				
			||||||
from types import FrameType
 | 
					from types import FrameType
 | 
				
			||||||
from typing import Iterator
 | 
					from typing import Iterator
 | 
				
			||||||
from typing import List
 | 
					from typing import List
 | 
				
			||||||
| 
						 | 
					@ -17,6 +18,10 @@ from typing import Union
 | 
				
			||||||
import py
 | 
					import py
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from _pytest.compat import overload
 | 
					from _pytest.compat import overload
 | 
				
			||||||
 | 
					from _pytest.compat import TYPE_CHECKING
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if TYPE_CHECKING:
 | 
				
			||||||
 | 
					    from typing_extensions import Literal
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class Source:
 | 
					class Source:
 | 
				
			||||||
| 
						 | 
					@ -120,7 +125,7 @@ class Source:
 | 
				
			||||||
        start, end = self.getstatementrange(lineno)
 | 
					        start, end = self.getstatementrange(lineno)
 | 
				
			||||||
        return self[start:end]
 | 
					        return self[start:end]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def getstatementrange(self, lineno: int):
 | 
					    def getstatementrange(self, lineno: int) -> Tuple[int, int]:
 | 
				
			||||||
        """ return (start, end) tuple which spans the minimal
 | 
					        """ return (start, end) tuple which spans the minimal
 | 
				
			||||||
            statement region which containing the given lineno.
 | 
					            statement region which containing the given lineno.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
| 
						 | 
					@ -158,14 +163,36 @@ class Source:
 | 
				
			||||||
    def __str__(self) -> str:
 | 
					    def __str__(self) -> str:
 | 
				
			||||||
        return "\n".join(self.lines)
 | 
					        return "\n".join(self.lines)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @overload
 | 
				
			||||||
    def compile(
 | 
					    def compile(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
        filename=None,
 | 
					        filename: Optional[str] = ...,
 | 
				
			||||||
        mode="exec",
 | 
					        mode: str = ...,
 | 
				
			||||||
 | 
					        flag: "Literal[0]" = ...,
 | 
				
			||||||
 | 
					        dont_inherit: int = ...,
 | 
				
			||||||
 | 
					        _genframe: Optional[FrameType] = ...,
 | 
				
			||||||
 | 
					    ) -> CodeType:
 | 
				
			||||||
 | 
					        raise NotImplementedError()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @overload  # noqa: F811
 | 
				
			||||||
 | 
					    def compile(  # noqa: F811
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        filename: Optional[str] = ...,
 | 
				
			||||||
 | 
					        mode: str = ...,
 | 
				
			||||||
 | 
					        flag: int = ...,
 | 
				
			||||||
 | 
					        dont_inherit: int = ...,
 | 
				
			||||||
 | 
					        _genframe: Optional[FrameType] = ...,
 | 
				
			||||||
 | 
					    ) -> Union[CodeType, ast.AST]:
 | 
				
			||||||
 | 
					        raise NotImplementedError()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def compile(  # noqa: F811
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        filename: Optional[str] = None,
 | 
				
			||||||
 | 
					        mode: str = "exec",
 | 
				
			||||||
        flag: int = 0,
 | 
					        flag: int = 0,
 | 
				
			||||||
        dont_inherit: int = 0,
 | 
					        dont_inherit: int = 0,
 | 
				
			||||||
        _genframe: Optional[FrameType] = None,
 | 
					        _genframe: Optional[FrameType] = None,
 | 
				
			||||||
    ):
 | 
					    ) -> Union[CodeType, ast.AST]:
 | 
				
			||||||
        """ return compiled code object. if filename is None
 | 
					        """ return compiled code object. if filename is None
 | 
				
			||||||
            invent an artificial filename which displays
 | 
					            invent an artificial filename which displays
 | 
				
			||||||
            the source/line position of the caller frame.
 | 
					            the source/line position of the caller frame.
 | 
				
			||||||
| 
						 | 
					@ -196,7 +223,9 @@ class Source:
 | 
				
			||||||
            raise newex
 | 
					            raise newex
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            if flag & ast.PyCF_ONLY_AST:
 | 
					            if flag & ast.PyCF_ONLY_AST:
 | 
				
			||||||
 | 
					                assert isinstance(co, ast.AST)
 | 
				
			||||||
                return co
 | 
					                return co
 | 
				
			||||||
 | 
					            assert isinstance(co, CodeType)
 | 
				
			||||||
            lines = [(x + "\n") for x in self.lines]
 | 
					            lines = [(x + "\n") for x in self.lines]
 | 
				
			||||||
            # Type ignored because linecache.cache is private.
 | 
					            # Type ignored because linecache.cache is private.
 | 
				
			||||||
            linecache.cache[filename] = (1, None, lines, filename)  # type: ignore
 | 
					            linecache.cache[filename] = (1, None, lines, filename)  # type: ignore
 | 
				
			||||||
| 
						 | 
					@ -208,7 +237,35 @@ class Source:
 | 
				
			||||||
#
 | 
					#
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def compile_(source, filename=None, mode="exec", flags: int = 0, dont_inherit: int = 0):
 | 
					@overload
 | 
				
			||||||
 | 
					def compile_(
 | 
				
			||||||
 | 
					    source: Union[str, bytes, ast.mod, ast.AST],
 | 
				
			||||||
 | 
					    filename: Optional[str] = ...,
 | 
				
			||||||
 | 
					    mode: str = ...,
 | 
				
			||||||
 | 
					    flags: "Literal[0]" = ...,
 | 
				
			||||||
 | 
					    dont_inherit: int = ...,
 | 
				
			||||||
 | 
					) -> CodeType:
 | 
				
			||||||
 | 
					    raise NotImplementedError()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@overload  # noqa: F811
 | 
				
			||||||
 | 
					def compile_(  # noqa: F811
 | 
				
			||||||
 | 
					    source: Union[str, bytes, ast.mod, ast.AST],
 | 
				
			||||||
 | 
					    filename: Optional[str] = ...,
 | 
				
			||||||
 | 
					    mode: str = ...,
 | 
				
			||||||
 | 
					    flags: int = ...,
 | 
				
			||||||
 | 
					    dont_inherit: int = ...,
 | 
				
			||||||
 | 
					) -> Union[CodeType, ast.AST]:
 | 
				
			||||||
 | 
					    raise NotImplementedError()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def compile_(  # noqa: F811
 | 
				
			||||||
 | 
					    source: Union[str, bytes, ast.mod, ast.AST],
 | 
				
			||||||
 | 
					    filename: Optional[str] = None,
 | 
				
			||||||
 | 
					    mode: str = "exec",
 | 
				
			||||||
 | 
					    flags: int = 0,
 | 
				
			||||||
 | 
					    dont_inherit: int = 0,
 | 
				
			||||||
 | 
					) -> Union[CodeType, ast.AST]:
 | 
				
			||||||
    """ compile the given source to a raw code object,
 | 
					    """ compile the given source to a raw code object,
 | 
				
			||||||
        and maintain an internal cache which allows later
 | 
					        and maintain an internal cache which allows later
 | 
				
			||||||
        retrieval of the source code for the code object
 | 
					        retrieval of the source code for the code object
 | 
				
			||||||
| 
						 | 
					@ -216,14 +273,16 @@ def compile_(source, filename=None, mode="exec", flags: int = 0, dont_inherit: i
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    if isinstance(source, ast.AST):
 | 
					    if isinstance(source, ast.AST):
 | 
				
			||||||
        # XXX should Source support having AST?
 | 
					        # XXX should Source support having AST?
 | 
				
			||||||
        return compile(source, filename, mode, flags, dont_inherit)
 | 
					        assert filename is not None
 | 
				
			||||||
 | 
					        co = compile(source, filename, mode, flags, dont_inherit)
 | 
				
			||||||
 | 
					        assert isinstance(co, (CodeType, ast.AST))
 | 
				
			||||||
 | 
					        return co
 | 
				
			||||||
    _genframe = sys._getframe(1)  # the caller
 | 
					    _genframe = sys._getframe(1)  # the caller
 | 
				
			||||||
    s = Source(source)
 | 
					    s = Source(source)
 | 
				
			||||||
    co = s.compile(filename, mode, flags, _genframe=_genframe)
 | 
					    return s.compile(filename, mode, flags, _genframe=_genframe)
 | 
				
			||||||
    return co
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def getfslineno(obj):
 | 
					def getfslineno(obj) -> Tuple[Union[str, py.path.local], int]:
 | 
				
			||||||
    """ Return source location (path, lineno) for the given object.
 | 
					    """ Return source location (path, lineno) for the given object.
 | 
				
			||||||
    If the source cannot be determined return ("", -1).
 | 
					    If the source cannot be determined return ("", -1).
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -4,10 +4,13 @@
 | 
				
			||||||
import ast
 | 
					import ast
 | 
				
			||||||
import inspect
 | 
					import inspect
 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
 | 
					from types import CodeType
 | 
				
			||||||
from typing import Any
 | 
					from typing import Any
 | 
				
			||||||
from typing import Dict
 | 
					from typing import Dict
 | 
				
			||||||
from typing import Optional
 | 
					from typing import Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import py
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import _pytest._code
 | 
					import _pytest._code
 | 
				
			||||||
import pytest
 | 
					import pytest
 | 
				
			||||||
from _pytest._code import Source
 | 
					from _pytest._code import Source
 | 
				
			||||||
| 
						 | 
					@ -147,6 +150,10 @@ class TestAccesses:
 | 
				
			||||||
        assert len(x.lines) == 2
 | 
					        assert len(x.lines) == 2
 | 
				
			||||||
        assert str(x) == "def f(x):\n    pass"
 | 
					        assert str(x) == "def f(x):\n    pass"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_getrange_step_not_supported(self) -> None:
 | 
				
			||||||
 | 
					        with pytest.raises(IndexError, match=r"step"):
 | 
				
			||||||
 | 
					            self.source[::2]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_getline(self) -> None:
 | 
					    def test_getline(self) -> None:
 | 
				
			||||||
        x = self.source[0]
 | 
					        x = self.source[0]
 | 
				
			||||||
        assert x == "def f(x):"
 | 
					        assert x == "def f(x):"
 | 
				
			||||||
| 
						 | 
					@ -449,6 +456,14 @@ def test_idem_compile_and_getsource() -> None:
 | 
				
			||||||
    assert src == expected
 | 
					    assert src == expected
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_compile_ast() -> None:
 | 
				
			||||||
 | 
					    # We don't necessarily want to support this.
 | 
				
			||||||
 | 
					    # This test was added just for coverage.
 | 
				
			||||||
 | 
					    stmt = ast.parse("def x(): pass")
 | 
				
			||||||
 | 
					    co = _pytest._code.compile(stmt, filename="foo.py")
 | 
				
			||||||
 | 
					    assert isinstance(co, CodeType)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_findsource_fallback() -> None:
 | 
					def test_findsource_fallback() -> None:
 | 
				
			||||||
    from _pytest._code.source import findsource
 | 
					    from _pytest._code.source import findsource
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -488,6 +503,7 @@ def test_getfslineno() -> None:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    fspath, lineno = getfslineno(f)
 | 
					    fspath, lineno = getfslineno(f)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert isinstance(fspath, py.path.local)
 | 
				
			||||||
    assert fspath.basename == "test_source.py"
 | 
					    assert fspath.basename == "test_source.py"
 | 
				
			||||||
    assert lineno == f.__code__.co_firstlineno - 1  # see findsource
 | 
					    assert lineno == f.__code__.co_firstlineno - 1  # see findsource
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue