Add type annotations to _pytest.assertion.util

This commit is contained in:
Ran Benita 2019-11-03 16:57:14 +02:00
parent 18d181fa77
commit 7d3ce374d2
1 changed files with 42 additions and 23 deletions

View File

@ -1,9 +1,15 @@
"""Utilities for assertion debugging""" """Utilities for assertion debugging"""
import collections.abc
import pprint import pprint
from collections.abc import Sequence from typing import AbstractSet
from typing import Any
from typing import Callable from typing import Callable
from typing import Iterable
from typing import List from typing import List
from typing import Mapping
from typing import Optional from typing import Optional
from typing import Sequence
from typing import Tuple
import _pytest._code import _pytest._code
from _pytest import outcomes from _pytest import outcomes
@ -22,7 +28,7 @@ _reprcompare = None # type: Optional[Callable[[str, object, object], Optional[s
_assertion_pass = None # type: Optional[Callable[[int, str, str], None]] _assertion_pass = None # type: Optional[Callable[[int, str, str], None]]
def format_explanation(explanation): def format_explanation(explanation: str) -> str:
"""This formats an explanation """This formats an explanation
Normally all embedded newlines are escaped, however there are Normally all embedded newlines are escaped, however there are
@ -38,7 +44,7 @@ def format_explanation(explanation):
return "\n".join(result) return "\n".join(result)
def _split_explanation(explanation): def _split_explanation(explanation: str) -> List[str]:
"""Return a list of individual lines in the explanation """Return a list of individual lines in the explanation
This will return a list of lines split on '\n{', '\n}' and '\n~'. This will return a list of lines split on '\n{', '\n}' and '\n~'.
@ -55,7 +61,7 @@ def _split_explanation(explanation):
return lines return lines
def _format_lines(lines): def _format_lines(lines: Sequence[str]) -> List[str]:
"""Format the individual lines """Format the individual lines
This will replace the '{', '}' and '~' characters of our mini This will replace the '{', '}' and '~' characters of our mini
@ -64,7 +70,7 @@ def _format_lines(lines):
Return a list of formatted lines. Return a list of formatted lines.
""" """
result = lines[:1] result = list(lines[:1])
stack = [0] stack = [0]
stackcnt = [0] stackcnt = [0]
for line in lines[1:]: for line in lines[1:]:
@ -90,31 +96,31 @@ def _format_lines(lines):
return result return result
def issequence(x): def issequence(x: Any) -> bool:
return isinstance(x, Sequence) and not isinstance(x, str) return isinstance(x, collections.abc.Sequence) and not isinstance(x, str)
def istext(x): def istext(x: Any) -> bool:
return isinstance(x, str) return isinstance(x, str)
def isdict(x): def isdict(x: Any) -> bool:
return isinstance(x, dict) return isinstance(x, dict)
def isset(x): def isset(x: Any) -> bool:
return isinstance(x, (set, frozenset)) return isinstance(x, (set, frozenset))
def isdatacls(obj): def isdatacls(obj: Any) -> bool:
return getattr(obj, "__dataclass_fields__", None) is not None return getattr(obj, "__dataclass_fields__", None) is not None
def isattrs(obj): def isattrs(obj: Any) -> bool:
return getattr(obj, "__attrs_attrs__", None) is not None return getattr(obj, "__attrs_attrs__", None) is not None
def isiterable(obj): def isiterable(obj: Any) -> bool:
try: try:
iter(obj) iter(obj)
return not istext(obj) return not istext(obj)
@ -122,7 +128,7 @@ def isiterable(obj):
return False return False
def assertrepr_compare(config, op, left, right): def assertrepr_compare(config, op: str, left: Any, right: Any) -> Optional[List[str]]:
"""Return specialised explanations for some operators/operands""" """Return specialised explanations for some operators/operands"""
verbose = config.getoption("verbose") verbose = config.getoption("verbose")
if verbose > 1: if verbose > 1:
@ -180,7 +186,7 @@ def assertrepr_compare(config, op, left, right):
return [summary] + explanation return [summary] + explanation
def _diff_text(left, right, verbose=0): def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
"""Return the explanation for the diff between text. """Return the explanation for the diff between text.
Unless --verbose is used this will skip leading and trailing Unless --verbose is used this will skip leading and trailing
@ -226,7 +232,7 @@ def _diff_text(left, right, verbose=0):
return explanation return explanation
def _compare_eq_verbose(left, right): def _compare_eq_verbose(left: Any, right: Any) -> List[str]:
keepends = True keepends = True
left_lines = repr(left).splitlines(keepends) left_lines = repr(left).splitlines(keepends)
right_lines = repr(right).splitlines(keepends) right_lines = repr(right).splitlines(keepends)
@ -238,7 +244,7 @@ def _compare_eq_verbose(left, right):
return explanation return explanation
def _surrounding_parens_on_own_lines(lines): # type: (List) -> None def _surrounding_parens_on_own_lines(lines: List[str]) -> None:
"""Move opening/closing parenthesis/bracket to own lines.""" """Move opening/closing parenthesis/bracket to own lines."""
opening = lines[0][:1] opening = lines[0][:1]
if opening in ["(", "[", "{"]: if opening in ["(", "[", "{"]:
@ -250,7 +256,9 @@ def _surrounding_parens_on_own_lines(lines): # type: (List) -> None
lines[:] = lines + [closing] lines[:] = lines + [closing]
def _compare_eq_iterable(left, right, verbose=0): def _compare_eq_iterable(
left: Iterable[Any], right: Iterable[Any], verbose: int = 0
) -> List[str]:
if not verbose: if not verbose:
return ["Use -v to get the full diff"] return ["Use -v to get the full diff"]
# dynamic import to speedup pytest # dynamic import to speedup pytest
@ -283,7 +291,9 @@ def _compare_eq_iterable(left, right, verbose=0):
return explanation return explanation
def _compare_eq_sequence(left, right, verbose=0): def _compare_eq_sequence(
left: Sequence[Any], right: Sequence[Any], verbose: int = 0
) -> List[str]:
comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes) comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes)
explanation = [] # type: List[str] explanation = [] # type: List[str]
len_left = len(left) len_left = len(left)
@ -337,7 +347,9 @@ def _compare_eq_sequence(left, right, verbose=0):
return explanation return explanation
def _compare_eq_set(left, right, verbose=0): def _compare_eq_set(
left: AbstractSet[Any], right: AbstractSet[Any], verbose: int = 0
) -> List[str]:
explanation = [] explanation = []
diff_left = left - right diff_left = left - right
diff_right = right - left diff_right = right - left
@ -352,7 +364,9 @@ def _compare_eq_set(left, right, verbose=0):
return explanation return explanation
def _compare_eq_dict(left, right, verbose=0): def _compare_eq_dict(
left: Mapping[Any, Any], right: Mapping[Any, Any], verbose: int = 0
) -> List[str]:
explanation = [] # type: List[str] explanation = [] # type: List[str]
set_left = set(left) set_left = set(left)
set_right = set(right) set_right = set(right)
@ -391,7 +405,12 @@ def _compare_eq_dict(left, right, verbose=0):
return explanation return explanation
def _compare_eq_cls(left, right, verbose, type_fns): def _compare_eq_cls(
left: Any,
right: Any,
verbose: int,
type_fns: Tuple[Callable[[Any], bool], Callable[[Any], bool]],
) -> List[str]:
isdatacls, isattrs = type_fns isdatacls, isattrs = type_fns
if isdatacls(left): if isdatacls(left):
all_fields = left.__dataclass_fields__ all_fields = left.__dataclass_fields__
@ -425,7 +444,7 @@ def _compare_eq_cls(left, right, verbose, type_fns):
return explanation return explanation
def _notin_text(term, text, verbose=0): def _notin_text(term: str, text: str, verbose: int = 0) -> List[str]:
index = text.find(term) index = text.find(term)
head = text[:index] head = text[:index]
tail = text[index + len(term) :] tail = text[index + len(term) :]