Add type annotations to _pytest.compat
This commit is contained in:
parent
a649f157de
commit
562d4811d5
|
@ -10,11 +10,14 @@ import sys
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from inspect import Parameter
|
from inspect import Parameter
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
|
from typing import Any
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
from typing import Generic
|
from typing import Generic
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing import overload
|
from typing import overload
|
||||||
|
from typing import Tuple
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import py
|
import py
|
||||||
|
@ -46,7 +49,7 @@ else:
|
||||||
import importlib_metadata # noqa: F401
|
import importlib_metadata # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
def _format_args(func):
|
def _format_args(func: Callable[..., Any]) -> str:
|
||||||
return str(signature(func))
|
return str(signature(func))
|
||||||
|
|
||||||
|
|
||||||
|
@ -67,12 +70,12 @@ else:
|
||||||
fspath = os.fspath
|
fspath = os.fspath
|
||||||
|
|
||||||
|
|
||||||
def is_generator(func):
|
def is_generator(func: object) -> bool:
|
||||||
genfunc = inspect.isgeneratorfunction(func)
|
genfunc = inspect.isgeneratorfunction(func)
|
||||||
return genfunc and not iscoroutinefunction(func)
|
return genfunc and not iscoroutinefunction(func)
|
||||||
|
|
||||||
|
|
||||||
def iscoroutinefunction(func):
|
def iscoroutinefunction(func: object) -> bool:
|
||||||
"""
|
"""
|
||||||
Return True if func is a coroutine function (a function defined with async
|
Return True if func is a coroutine function (a function defined with async
|
||||||
def syntax, and doesn't contain yield), or a function decorated with
|
def syntax, and doesn't contain yield), or a function decorated with
|
||||||
|
@ -85,7 +88,7 @@ def iscoroutinefunction(func):
|
||||||
return inspect.iscoroutinefunction(func) or getattr(func, "_is_coroutine", False)
|
return inspect.iscoroutinefunction(func) or getattr(func, "_is_coroutine", False)
|
||||||
|
|
||||||
|
|
||||||
def getlocation(function, curdir=None):
|
def getlocation(function, curdir=None) -> str:
|
||||||
function = get_real_func(function)
|
function = get_real_func(function)
|
||||||
fn = py.path.local(inspect.getfile(function))
|
fn = py.path.local(inspect.getfile(function))
|
||||||
lineno = function.__code__.co_firstlineno
|
lineno = function.__code__.co_firstlineno
|
||||||
|
@ -94,7 +97,7 @@ def getlocation(function, curdir=None):
|
||||||
return "%s:%d" % (fn, lineno + 1)
|
return "%s:%d" % (fn, lineno + 1)
|
||||||
|
|
||||||
|
|
||||||
def num_mock_patch_args(function):
|
def num_mock_patch_args(function) -> int:
|
||||||
""" return number of arguments used up by mock arguments (if any) """
|
""" return number of arguments used up by mock arguments (if any) """
|
||||||
patchings = getattr(function, "patchings", None)
|
patchings = getattr(function, "patchings", None)
|
||||||
if not patchings:
|
if not patchings:
|
||||||
|
@ -113,7 +116,13 @@ def num_mock_patch_args(function):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def getfuncargnames(function, *, name: str = "", is_method=False, cls=None):
|
def getfuncargnames(
|
||||||
|
function: Callable[..., Any],
|
||||||
|
*,
|
||||||
|
name: str = "",
|
||||||
|
is_method: bool = False,
|
||||||
|
cls: Optional[type] = None
|
||||||
|
) -> Tuple[str, ...]:
|
||||||
"""Returns the names of a function's mandatory arguments.
|
"""Returns the names of a function's mandatory arguments.
|
||||||
|
|
||||||
This should return the names of all function arguments that:
|
This should return the names of all function arguments that:
|
||||||
|
@ -181,7 +190,7 @@ else:
|
||||||
from contextlib import nullcontext # noqa
|
from contextlib import nullcontext # noqa
|
||||||
|
|
||||||
|
|
||||||
def get_default_arg_names(function):
|
def get_default_arg_names(function: Callable[..., Any]) -> Tuple[str, ...]:
|
||||||
# Note: this code intentionally mirrors the code at the beginning of getfuncargnames,
|
# Note: this code intentionally mirrors the code at the beginning of getfuncargnames,
|
||||||
# to get the arguments which were excluded from its result because they had default values
|
# to get the arguments which were excluded from its result because they had default values
|
||||||
return tuple(
|
return tuple(
|
||||||
|
@ -200,18 +209,18 @@ _non_printable_ascii_translate_table.update(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _translate_non_printable(s):
|
def _translate_non_printable(s: str) -> str:
|
||||||
return s.translate(_non_printable_ascii_translate_table)
|
return s.translate(_non_printable_ascii_translate_table)
|
||||||
|
|
||||||
|
|
||||||
STRING_TYPES = bytes, str
|
STRING_TYPES = bytes, str
|
||||||
|
|
||||||
|
|
||||||
def _bytes_to_ascii(val):
|
def _bytes_to_ascii(val: bytes) -> str:
|
||||||
return val.decode("ascii", "backslashreplace")
|
return val.decode("ascii", "backslashreplace")
|
||||||
|
|
||||||
|
|
||||||
def ascii_escaped(val):
|
def ascii_escaped(val: Union[bytes, str]):
|
||||||
"""If val is pure ascii, returns it as a str(). Otherwise, escapes
|
"""If val is pure ascii, returns it as a str(). Otherwise, escapes
|
||||||
bytes objects into a sequence of escaped bytes:
|
bytes objects into a sequence of escaped bytes:
|
||||||
|
|
||||||
|
@ -308,7 +317,7 @@ def getimfunc(func):
|
||||||
return func
|
return func
|
||||||
|
|
||||||
|
|
||||||
def safe_getattr(object, name, default):
|
def safe_getattr(object: Any, name: str, default: Any) -> Any:
|
||||||
""" Like getattr but return default upon any Exception or any OutcomeException.
|
""" Like getattr but return default upon any Exception or any OutcomeException.
|
||||||
|
|
||||||
Attribute access can potentially fail for 'evil' Python objects.
|
Attribute access can potentially fail for 'evil' Python objects.
|
||||||
|
@ -322,7 +331,7 @@ def safe_getattr(object, name, default):
|
||||||
return default
|
return default
|
||||||
|
|
||||||
|
|
||||||
def safe_isclass(obj):
|
def safe_isclass(obj: object) -> bool:
|
||||||
"""Ignore any exception via isinstance on Python 3."""
|
"""Ignore any exception via isinstance on Python 3."""
|
||||||
try:
|
try:
|
||||||
return inspect.isclass(obj)
|
return inspect.isclass(obj)
|
||||||
|
@ -343,21 +352,23 @@ COLLECT_FAKEMODULE_ATTRIBUTES = (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _setup_collect_fakemodule():
|
def _setup_collect_fakemodule() -> None:
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
pytest.collect = ModuleType("pytest.collect")
|
# Types ignored because the module is created dynamically.
|
||||||
pytest.collect.__all__ = [] # used for setns
|
pytest.collect = ModuleType("pytest.collect") # type: ignore
|
||||||
|
pytest.collect.__all__ = [] # type: ignore # used for setns
|
||||||
for attr_name in COLLECT_FAKEMODULE_ATTRIBUTES:
|
for attr_name in COLLECT_FAKEMODULE_ATTRIBUTES:
|
||||||
setattr(pytest.collect, attr_name, getattr(pytest, attr_name))
|
setattr(pytest.collect, attr_name, getattr(pytest, attr_name)) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class CaptureIO(io.TextIOWrapper):
|
class CaptureIO(io.TextIOWrapper):
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
super().__init__(io.BytesIO(), encoding="UTF-8", newline="", write_through=True)
|
super().__init__(io.BytesIO(), encoding="UTF-8", newline="", write_through=True)
|
||||||
|
|
||||||
def getvalue(self):
|
def getvalue(self) -> str:
|
||||||
|
assert isinstance(self.buffer, io.BytesIO)
|
||||||
return self.buffer.getvalue().decode("UTF-8")
|
return self.buffer.getvalue().decode("UTF-8")
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue