Add type annotations to _pytest.compat

This commit is contained in:
Ran Benita 2019-11-15 16:26:46 +02:00
parent a649f157de
commit 562d4811d5
1 changed files with 29 additions and 18 deletions

View File

@ -10,11 +10,14 @@ import sys
from contextlib import contextmanager from contextlib import contextmanager
from inspect import Parameter from inspect import Parameter
from inspect import signature from inspect import signature
from typing import Any
from typing import Callable from typing import Callable
from typing import Generic from typing import Generic
from typing import Optional from typing import Optional
from typing import overload from typing import overload
from typing import Tuple
from typing import TypeVar from typing import TypeVar
from typing import Union
import attr import attr
import py import py
@ -46,7 +49,7 @@ else:
import importlib_metadata # noqa: F401 import importlib_metadata # noqa: F401
def _format_args(func): def _format_args(func: Callable[..., Any]) -> str:
return str(signature(func)) return str(signature(func))
@ -67,12 +70,12 @@ else:
fspath = os.fspath fspath = os.fspath
def is_generator(func): def is_generator(func: object) -> bool:
genfunc = inspect.isgeneratorfunction(func) genfunc = inspect.isgeneratorfunction(func)
return genfunc and not iscoroutinefunction(func) return genfunc and not iscoroutinefunction(func)
def iscoroutinefunction(func): def iscoroutinefunction(func: object) -> bool:
""" """
Return True if func is a coroutine function (a function defined with async Return True if func is a coroutine function (a function defined with async
def syntax, and doesn't contain yield), or a function decorated with def syntax, and doesn't contain yield), or a function decorated with
@ -85,7 +88,7 @@ def iscoroutinefunction(func):
return inspect.iscoroutinefunction(func) or getattr(func, "_is_coroutine", False) return inspect.iscoroutinefunction(func) or getattr(func, "_is_coroutine", False)
def getlocation(function, curdir=None): def getlocation(function, curdir=None) -> str:
function = get_real_func(function) function = get_real_func(function)
fn = py.path.local(inspect.getfile(function)) fn = py.path.local(inspect.getfile(function))
lineno = function.__code__.co_firstlineno lineno = function.__code__.co_firstlineno
@ -94,7 +97,7 @@ def getlocation(function, curdir=None):
return "%s:%d" % (fn, lineno + 1) return "%s:%d" % (fn, lineno + 1)
def num_mock_patch_args(function): def num_mock_patch_args(function) -> int:
""" return number of arguments used up by mock arguments (if any) """ """ return number of arguments used up by mock arguments (if any) """
patchings = getattr(function, "patchings", None) patchings = getattr(function, "patchings", None)
if not patchings: if not patchings:
@ -113,7 +116,13 @@ def num_mock_patch_args(function):
) )
def getfuncargnames(function, *, name: str = "", is_method=False, cls=None): def getfuncargnames(
function: Callable[..., Any],
*,
name: str = "",
is_method: bool = False,
cls: Optional[type] = None
) -> Tuple[str, ...]:
"""Returns the names of a function's mandatory arguments. """Returns the names of a function's mandatory arguments.
This should return the names of all function arguments that: This should return the names of all function arguments that:
@ -181,7 +190,7 @@ else:
from contextlib import nullcontext # noqa from contextlib import nullcontext # noqa
def get_default_arg_names(function): def get_default_arg_names(function: Callable[..., Any]) -> Tuple[str, ...]:
# Note: this code intentionally mirrors the code at the beginning of getfuncargnames, # Note: this code intentionally mirrors the code at the beginning of getfuncargnames,
# to get the arguments which were excluded from its result because they had default values # to get the arguments which were excluded from its result because they had default values
return tuple( return tuple(
@ -200,18 +209,18 @@ _non_printable_ascii_translate_table.update(
) )
def _translate_non_printable(s): def _translate_non_printable(s: str) -> str:
return s.translate(_non_printable_ascii_translate_table) return s.translate(_non_printable_ascii_translate_table)
STRING_TYPES = bytes, str STRING_TYPES = bytes, str
def _bytes_to_ascii(val): def _bytes_to_ascii(val: bytes) -> str:
return val.decode("ascii", "backslashreplace") return val.decode("ascii", "backslashreplace")
def ascii_escaped(val): def ascii_escaped(val: Union[bytes, str]):
"""If val is pure ascii, returns it as a str(). Otherwise, escapes """If val is pure ascii, returns it as a str(). Otherwise, escapes
bytes objects into a sequence of escaped bytes: bytes objects into a sequence of escaped bytes:
@ -308,7 +317,7 @@ def getimfunc(func):
return func return func
def safe_getattr(object, name, default): def safe_getattr(object: Any, name: str, default: Any) -> Any:
""" Like getattr but return default upon any Exception or any OutcomeException. """ Like getattr but return default upon any Exception or any OutcomeException.
Attribute access can potentially fail for 'evil' Python objects. Attribute access can potentially fail for 'evil' Python objects.
@ -322,7 +331,7 @@ def safe_getattr(object, name, default):
return default return default
def safe_isclass(obj): def safe_isclass(obj: object) -> bool:
"""Ignore any exception via isinstance on Python 3.""" """Ignore any exception via isinstance on Python 3."""
try: try:
return inspect.isclass(obj) return inspect.isclass(obj)
@ -343,21 +352,23 @@ COLLECT_FAKEMODULE_ATTRIBUTES = (
) )
def _setup_collect_fakemodule(): def _setup_collect_fakemodule() -> None:
from types import ModuleType from types import ModuleType
import pytest import pytest
pytest.collect = ModuleType("pytest.collect") # Types ignored because the module is created dynamically.
pytest.collect.__all__ = [] # used for setns pytest.collect = ModuleType("pytest.collect") # type: ignore
pytest.collect.__all__ = [] # type: ignore # used for setns
for attr_name in COLLECT_FAKEMODULE_ATTRIBUTES: for attr_name in COLLECT_FAKEMODULE_ATTRIBUTES:
setattr(pytest.collect, attr_name, getattr(pytest, attr_name)) setattr(pytest.collect, attr_name, getattr(pytest, attr_name)) # type: ignore
class CaptureIO(io.TextIOWrapper): class CaptureIO(io.TextIOWrapper):
def __init__(self): def __init__(self) -> None:
super().__init__(io.BytesIO(), encoding="UTF-8", newline="", write_through=True) super().__init__(io.BytesIO(), encoding="UTF-8", newline="", write_through=True)
def getvalue(self): def getvalue(self) -> str:
assert isinstance(self.buffer, io.BytesIO)
return self.buffer.getvalue().decode("UTF-8") return self.buffer.getvalue().decode("UTF-8")