From 0267b25c66875617ed69132445820a6f82e6e2fa Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Fri, 13 Sep 2019 23:09:08 +0300 Subject: [PATCH] Fix some check_untyped_defs mypy errors in terminal --- src/_pytest/terminal.py | 42 +++++++++++++++++++++++++---------------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/src/_pytest/terminal.py b/src/_pytest/terminal.py index fd30d8572..35f6d324b 100644 --- a/src/_pytest/terminal.py +++ b/src/_pytest/terminal.py @@ -9,6 +9,12 @@ import platform import sys import time from functools import partial +from typing import Callable +from typing import Dict +from typing import List +from typing import Mapping +from typing import Optional +from typing import Set import attr import pluggy @@ -195,8 +201,8 @@ class WarningReport: file system location of the source of the warning (see ``get_location``). """ - message = attr.ib() - nodeid = attr.ib(default=None) + message = attr.ib(type=str) + nodeid = attr.ib(type=Optional[str], default=None) fslocation = attr.ib(default=None) count_towards_summary = True @@ -240,7 +246,7 @@ class TerminalReporter: self.reportchars = getreportopt(config) self.hasmarkup = self._tw.hasmarkup self.isatty = file.isatty() - self._progress_nodeids_reported = set() + self._progress_nodeids_reported = set() # type: Set[str] self._show_progress_info = self._determine_show_progress_info() self._collect_report_last_write = None @@ -619,7 +625,7 @@ class TerminalReporter: # because later versions are going to get rid of them anyway if self.config.option.verbose < 0: if self.config.option.verbose < -1: - counts = {} + counts = {} # type: Dict[str, int] for item in items: name = item.nodeid.split("::", 1)[0] counts[name] = counts.get(name, 0) + 1 @@ -750,7 +756,9 @@ class TerminalReporter: def summary_warnings(self): if self.hasopt("w"): - all_warnings = self.stats.get("warnings") + all_warnings = self.stats.get( + "warnings" + ) # type: Optional[List[WarningReport]] if not all_warnings: return @@ -763,7 +771,9 @@ class TerminalReporter: if not warning_reports: return - reports_grouped_by_message = collections.OrderedDict() + reports_grouped_by_message = ( + collections.OrderedDict() + ) # type: collections.OrderedDict[str, List[WarningReport]] for wr in warning_reports: reports_grouped_by_message.setdefault(wr.message, []).append(wr) @@ -900,11 +910,11 @@ class TerminalReporter: else: self.write_line(msg, **main_markup) - def short_test_summary(self): + def short_test_summary(self) -> None: if not self.reportchars: return - def show_simple(stat, lines): + def show_simple(stat, lines: List[str]) -> None: failed = self.stats.get(stat, []) if not failed: return @@ -914,7 +924,7 @@ class TerminalReporter: line = _get_line_with_reprcrash_message(config, rep, termwidth) lines.append(line) - def show_xfailed(lines): + def show_xfailed(lines: List[str]) -> None: xfailed = self.stats.get("xfailed", []) for rep in xfailed: verbose_word = rep._get_verbose_word(self.config) @@ -924,7 +934,7 @@ class TerminalReporter: if reason: lines.append(" " + str(reason)) - def show_xpassed(lines): + def show_xpassed(lines: List[str]) -> None: xpassed = self.stats.get("xpassed", []) for rep in xpassed: verbose_word = rep._get_verbose_word(self.config) @@ -932,7 +942,7 @@ class TerminalReporter: reason = rep.wasxfail lines.append("{} {} {}".format(verbose_word, pos, reason)) - def show_skipped(lines): + def show_skipped(lines: List[str]) -> None: skipped = self.stats.get("skipped", []) fskips = _folded_skips(skipped) if skipped else [] if not fskips: @@ -958,9 +968,9 @@ class TerminalReporter: "S": show_skipped, "p": partial(show_simple, "passed"), "E": partial(show_simple, "error"), - } + } # type: Mapping[str, Callable[[List[str]], None]] - lines = [] + lines = [] # type: List[str] for char in self.reportchars: action = REPORTCHAR_ACTIONS.get(char) if action: # skipping e.g. "P" (passed with output) here. @@ -1084,8 +1094,8 @@ def build_summary_stats_line(stats): return parts, main_color -def _plugin_nameversions(plugininfo): - values = [] +def _plugin_nameversions(plugininfo) -> List[str]: + values = [] # type: List[str] for plugin, dist in plugininfo: # gets us name and version! name = "{dist.project_name}-{dist.version}".format(dist=dist) @@ -1099,7 +1109,7 @@ def _plugin_nameversions(plugininfo): return values -def format_session_duration(seconds): +def format_session_duration(seconds: float) -> str: """Format the given seconds in a human readable manner to show in the final summary""" if seconds < 60: return "{:.2f}s".format(seconds)