Type annotate Metafunc

This commit is contained in:
Ran Benita
2020-02-14 10:50:45 +02:00
parent 56a5dbe252
commit 5945c3fe88
3 changed files with 296 additions and 211 deletions

View File

@@ -2,14 +2,18 @@
import enum
import fnmatch
import inspect
import itertools
import os
import sys
import typing
import warnings
from collections import Counter
from collections import defaultdict
from collections.abc import Sequence
from functools import partial
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple
@@ -23,6 +27,8 @@ from _pytest import nodes
from _pytest._code import filter_traceback
from _pytest._code.code import ExceptionInfo
from _pytest._code.source import getfslineno
from _pytest._io import TerminalWriter
from _pytest._io.saferepr import saferepr
from _pytest.compat import ascii_escaped
from _pytest.compat import get_default_arg_names
from _pytest.compat import get_real_func
@@ -35,8 +41,10 @@ from _pytest.compat import REGEX_TYPE
from _pytest.compat import safe_getattr
from _pytest.compat import safe_isclass
from _pytest.compat import STRING_TYPES
from _pytest.config import Config
from _pytest.config import hookimpl
from _pytest.deprecated import FUNCARGNAMES
from _pytest.fixtures import FuncFixtureInfo
from _pytest.mark import MARK_GEN
from _pytest.mark import ParameterSet
from _pytest.mark.structures import get_unpacked_marks
@@ -125,9 +133,10 @@ def pytest_cmdline_main(config):
return 0
def pytest_generate_tests(metafunc):
def pytest_generate_tests(metafunc: "Metafunc") -> None:
for marker in metafunc.definition.iter_markers(name="parametrize"):
metafunc.parametrize(*marker.args, **marker.kwargs, _param_mark=marker)
# TODO: Fix this type-ignore (overlapping kwargs).
metafunc.parametrize(*marker.args, **marker.kwargs, _param_mark=marker) # type: ignore[misc] # noqa: F821
def pytest_configure(config):
@@ -839,8 +848,8 @@ class Metafunc:
def __init__(
self,
definition: "FunctionDefinition",
fixtureinfo,
config,
fixtureinfo: fixtures.FuncFixtureInfo,
config: Config,
cls=None,
module=None,
) -> None:
@@ -872,14 +881,19 @@ class Metafunc:
def parametrize(
self,
argnames,
argvalues,
indirect=False,
ids=None,
scope=None,
argnames: Union[str, List[str], Tuple[str, ...]],
argvalues: Iterable[Union[ParameterSet, typing.Sequence[object], object]],
indirect: Union[bool, typing.Sequence[str]] = False,
ids: Optional[
Union[
Iterable[Union[None, str, float, int, bool]],
Callable[[object], Optional[object]],
]
] = None,
scope: "Optional[str]" = None,
*,
_param_mark: Optional[Mark] = None
):
) -> None:
""" Add new invocations to the underlying test function using the list
of argvalues for the given argnames. Parametrization is performed
during the collection phase. If you need to setup expensive resources
@@ -988,8 +1002,17 @@ class Metafunc:
self._calls = newcalls
def _resolve_arg_ids(
self, argnames: List[str], ids, parameters: List[ParameterSet], item: nodes.Item
):
self,
argnames: typing.Sequence[str],
ids: Optional[
Union[
Iterable[Union[None, str, float, int, bool]],
Callable[[object], Optional[object]],
]
],
parameters: typing.Sequence[ParameterSet],
item,
) -> List[str]:
"""Resolves the actual ids for the given argnames, based on the ``ids`` parameter given
to ``parametrize``.
@@ -1000,49 +1023,56 @@ class Metafunc:
:rtype: List[str]
:return: the list of ids for each argname given
"""
idfn = None
if callable(ids):
if ids is None:
idfn = None
ids_ = None
elif callable(ids):
idfn = ids
ids = None
if ids:
func_name = self.function.__name__
ids = self._validate_ids(ids, parameters, func_name)
ids = idmaker(argnames, parameters, idfn, ids, self.config, item=item)
return ids
ids_ = None
else:
idfn = None
ids_ = self._validate_ids(ids, parameters, self.function.__name__)
return idmaker(argnames, parameters, idfn, ids_, self.config, item=item)
def _validate_ids(self, ids, parameters, func_name):
def _validate_ids(
self,
ids: Iterable[Union[None, str, float, int, bool]],
parameters: typing.Sequence[ParameterSet],
func_name: str,
) -> List[Union[None, str]]:
try:
len(ids)
num_ids = len(ids) # type: ignore[arg-type] # noqa: F821
except TypeError:
try:
it = iter(ids)
iter(ids)
except TypeError:
raise TypeError("ids must be a callable, sequence or generator")
else:
import itertools
raise TypeError("ids must be a callable or an iterable")
num_ids = len(parameters)
new_ids = list(itertools.islice(it, len(parameters)))
else:
new_ids = list(ids)
if len(new_ids) != len(parameters):
# num_ids == 0 is a special case: https://github.com/pytest-dev/pytest/issues/1849
if num_ids != len(parameters) and num_ids != 0:
msg = "In {}: {} parameter sets specified, with different number of ids: {}"
fail(msg.format(func_name, len(parameters), len(ids)), pytrace=False)
for idx, id_value in enumerate(new_ids):
if id_value is not None:
if isinstance(id_value, (float, int, bool)):
new_ids[idx] = str(id_value)
elif not isinstance(id_value, str):
from _pytest._io.saferepr import saferepr
fail(msg.format(func_name, len(parameters), num_ids), pytrace=False)
msg = "In {}: ids must be list of string/float/int/bool, found: {} (type: {!r}) at index {}"
fail(
msg.format(func_name, saferepr(id_value), type(id_value), idx),
pytrace=False,
)
new_ids = []
for idx, id_value in enumerate(itertools.islice(ids, num_ids)):
if id_value is None or isinstance(id_value, str):
new_ids.append(id_value)
elif isinstance(id_value, (float, int, bool)):
new_ids.append(str(id_value))
else:
msg = "In {}: ids must be list of string/float/int/bool, found: {} (type: {!r}) at index {}"
fail(
msg.format(func_name, saferepr(id_value), type(id_value), idx),
pytrace=False,
)
return new_ids
def _resolve_arg_value_types(self, argnames: List[str], indirect) -> Dict[str, str]:
def _resolve_arg_value_types(
self,
argnames: typing.Sequence[str],
indirect: Union[bool, typing.Sequence[str]],
) -> Dict[str, str]:
"""Resolves if each parametrized argument must be considered a parameter to a fixture or a "funcarg"
to the function, based on the ``indirect`` parameter of the parametrized() call.
@@ -1075,7 +1105,11 @@ class Metafunc:
)
return valtypes
def _validate_if_using_arg_names(self, argnames, indirect):
def _validate_if_using_arg_names(
self,
argnames: typing.Sequence[str],
indirect: Union[bool, typing.Sequence[str]],
) -> None:
"""
Check if all argnames are being used, by default values, or directly/indirectly.
@@ -1095,7 +1129,7 @@ class Metafunc:
pytrace=False,
)
else:
if isinstance(indirect, (tuple, list)):
if isinstance(indirect, Sequence):
name = "fixture" if arg in indirect else "argument"
else:
name = "fixture" if indirect else "argument"
@@ -1104,7 +1138,11 @@ class Metafunc:
pytrace=False,
)
def _validate_explicit_parameters(self, argnames, indirect):
def _validate_explicit_parameters(
self,
argnames: typing.Sequence[str],
indirect: Union[bool, typing.Sequence[str]],
) -> None:
"""
The argnames in *parametrize* should either be declared explicitly via
indirect list or in the function signature
@@ -1113,17 +1151,15 @@ class Metafunc:
:param indirect: same ``indirect`` parameter of ``parametrize()``.
:raise ValueError: if validation fails
"""
if isinstance(indirect, bool) and indirect is True:
return
parametrized_argnames = list()
funcargnames = _pytest.compat.getfuncargnames(self.function)
if isinstance(indirect, Sequence):
for arg in argnames:
if arg not in indirect:
parametrized_argnames.append(arg)
elif indirect is False:
parametrized_argnames = argnames
if isinstance(indirect, bool):
parametrized_argnames = [] if indirect else argnames
else:
parametrized_argnames = [arg for arg in argnames if arg not in indirect]
if not parametrized_argnames:
return
funcargnames = _pytest.compat.getfuncargnames(self.function)
usefixtures = fixtures.get_use_fixtures_for_node(self.definition)
for arg in parametrized_argnames:
@@ -1169,17 +1205,27 @@ def _find_parametrized_scope(argnames, arg2fixturedefs, indirect):
return "function"
def _ascii_escaped_by_config(val, config):
def _ascii_escaped_by_config(val: Union[str, bytes], config: Optional[Config]) -> str:
if config is None:
escape_option = False
else:
escape_option = config.getini(
"disable_test_id_escaping_and_forfeit_all_rights_to_community_support"
)
return val if escape_option else ascii_escaped(val)
# TODO: If escaping is turned off and the user passes bytes,
# will return a bytes. For now we ignore this but the
# code *probably* doesn't handle this case.
return val if escape_option else ascii_escaped(val) # type: ignore
def _idval(val, argname, idx, idfn, item, config):
def _idval(
val: object,
argname: str,
idx: int,
idfn: Optional[Callable[[object], Optional[object]]],
item,
config: Optional[Config],
) -> str:
if idfn:
try:
generated_id = idfn(val)
@@ -1192,7 +1238,7 @@ def _idval(val, argname, idx, idfn, item, config):
elif config:
hook_id = config.hook.pytest_make_parametrize_id(
config=config, val=val, argname=argname
)
) # type: Optional[str]
if hook_id:
return hook_id
@@ -1204,48 +1250,65 @@ def _idval(val, argname, idx, idfn, item, config):
return ascii_escaped(val.pattern)
elif isinstance(val, enum.Enum):
return str(val)
elif hasattr(val, "__name__") and isinstance(val.__name__, str):
elif isinstance(getattr(val, "__name__", None), str):
# name of a class, function, module, etc.
return val.__name__
name = getattr(val, "__name__") # type: str
return name
return str(argname) + str(idx)
def _idvalset(idx, parameterset, argnames, idfn, ids, item, config):
def _idvalset(
idx: int,
parameterset: ParameterSet,
argnames: Iterable[str],
idfn: Optional[Callable[[object], Optional[object]]],
ids: Optional[List[Union[None, str]]],
item,
config: Optional[Config],
):
if parameterset.id is not None:
return parameterset.id
if ids is None or (idx >= len(ids) or ids[idx] is None):
id = None if ids is None or idx >= len(ids) else ids[idx]
if id is None:
this_id = [
_idval(val, argname, idx, idfn, item=item, config=config)
for val, argname in zip(parameterset.values, argnames)
]
return "-".join(this_id)
else:
return _ascii_escaped_by_config(ids[idx], config)
return _ascii_escaped_by_config(id, config)
def idmaker(argnames, parametersets, idfn=None, ids=None, config=None, item=None):
ids = [
def idmaker(
argnames: Iterable[str],
parametersets: Iterable[ParameterSet],
idfn: Optional[Callable[[object], Optional[object]]] = None,
ids: Optional[List[Union[None, str]]] = None,
config: Optional[Config] = None,
item=None,
) -> List[str]:
resolved_ids = [
_idvalset(valindex, parameterset, argnames, idfn, ids, config=config, item=item)
for valindex, parameterset in enumerate(parametersets)
]
# All IDs must be unique!
unique_ids = set(ids)
if len(unique_ids) != len(ids):
unique_ids = set(resolved_ids)
if len(unique_ids) != len(resolved_ids):
# Record the number of occurrences of each test ID
test_id_counts = Counter(ids)
test_id_counts = Counter(resolved_ids)
# Map the test ID to its next suffix
test_id_suffixes = defaultdict(int)
test_id_suffixes = defaultdict(int) # type: Dict[str, int]
# Suffix non-unique IDs to make them unique
for index, test_id in enumerate(ids):
for index, test_id in enumerate(resolved_ids):
if test_id_counts[test_id] > 1:
ids[index] = "{}{}".format(test_id, test_id_suffixes[test_id])
resolved_ids[index] = "{}{}".format(test_id, test_id_suffixes[test_id])
test_id_suffixes[test_id] += 1
return ids
return resolved_ids
def show_fixtures_per_test(config):
@@ -1369,7 +1432,7 @@ def _showfixtures_main(config, session):
tw.line()
def write_docstring(tw, doc, indent=" "):
def write_docstring(tw: TerminalWriter, doc: str, indent: str = " ") -> None:
for line in doc.split("\n"):
tw.write(indent + line + "\n")
@@ -1388,13 +1451,13 @@ class Function(PyobjMixin, nodes.Item):
parent,
args=None,
config=None,
callspec=None,
callspec: Optional[CallSpec2] = None,
callobj=NOTSET,
keywords=None,
session=None,
fixtureinfo=None,
fixtureinfo: Optional[FuncFixtureInfo] = None,
originalname=None,
):
) -> None:
super().__init__(name, parent, config=config, session=session)
self._args = args
if callobj is not NOTSET:
@@ -1430,7 +1493,7 @@ class Function(PyobjMixin, nodes.Item):
fixtureinfo = self.session._fixturemanager.getfixtureinfo(
self, self.obj, self.cls, funcargs=True
)
self._fixtureinfo = fixtureinfo
self._fixtureinfo = fixtureinfo # type: FuncFixtureInfo
self.fixturenames = fixtureinfo.names_closure
self._initrequest()