diff --git a/src/_pytest/logging.py b/src/_pytest/logging.py index 92046ed51..ce3a18f03 100644 --- a/src/_pytest/logging.py +++ b/src/_pytest/logging.py @@ -11,18 +11,24 @@ from typing import Generator from typing import List from typing import Mapping from typing import Optional +from typing import Tuple +from typing import TypeVar from typing import Union import pytest from _pytest import nodes +from _pytest._io import TerminalWriter +from _pytest.capture import CaptureManager from _pytest.compat import nullcontext from _pytest.config import _strtobool from _pytest.config import Config from _pytest.config import create_terminal_writer from _pytest.config.argparsing import Parser +from _pytest.fixtures import FixtureRequest from _pytest.main import Session from _pytest.pathlib import Path from _pytest.store import StoreKey +from _pytest.terminal import TerminalReporter DEFAULT_LOG_FORMAT = "%(levelname)-8s %(name)s:%(filename)s:%(lineno)d %(message)s" @@ -32,7 +38,7 @@ catch_log_handler_key = StoreKey["LogCaptureHandler"]() catch_log_records_key = StoreKey[Dict[str, List[logging.LogRecord]]]() -def _remove_ansi_escape_sequences(text): +def _remove_ansi_escape_sequences(text: str) -> str: return _ANSI_ESCAPE_SEQ.sub("", text) @@ -52,7 +58,7 @@ class ColoredLevelFormatter(logging.Formatter): } # type: Mapping[int, AbstractSet[str]] LEVELNAME_FMT_REGEX = re.compile(r"%\(levelname\)([+-.]?\d*s)") - def __init__(self, terminalwriter, *args, **kwargs) -> None: + def __init__(self, terminalwriter: TerminalWriter, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._original_fmt = self._style._fmt self._level_to_fmt_mapping = {} # type: Dict[int, str] @@ -77,7 +83,7 @@ class ColoredLevelFormatter(logging.Formatter): colorized_formatted_levelname, self._fmt ) - def format(self, record): + def format(self, record: logging.LogRecord) -> str: fmt = self._level_to_fmt_mapping.get(record.levelno, self._original_fmt) self._style._fmt = fmt return super().format(record) @@ -90,18 +96,20 @@ class PercentStyleMultiline(logging.PercentStyle): formats the message as if each line were logged separately. """ - def __init__(self, fmt, auto_indent): + def __init__(self, fmt: str, auto_indent: Union[int, str, bool]) -> None: super().__init__(fmt) self._auto_indent = self._get_auto_indent(auto_indent) @staticmethod - def _update_message(record_dict, message): + def _update_message( + record_dict: Dict[str, object], message: str + ) -> Dict[str, object]: tmp = record_dict.copy() tmp["message"] = message return tmp @staticmethod - def _get_auto_indent(auto_indent_option) -> int: + def _get_auto_indent(auto_indent_option: Union[int, str, bool]) -> int: """Determines the current auto indentation setting Specify auto indent behavior (on/off/fixed) by passing in @@ -149,11 +157,11 @@ class PercentStyleMultiline(logging.PercentStyle): return 0 - def format(self, record): + def format(self, record: logging.LogRecord) -> str: if "\n" in record.message: if hasattr(record, "auto_indent"): # passed in from the "extra={}" kwarg on the call to logging.log() - auto_indent = self._get_auto_indent(record.auto_indent) + auto_indent = self._get_auto_indent(record.auto_indent) # type: ignore[attr-defined] # noqa: F821 else: auto_indent = self._auto_indent @@ -173,7 +181,7 @@ class PercentStyleMultiline(logging.PercentStyle): return self._fmt % record.__dict__ -def get_option_ini(config, *names): +def get_option_ini(config: Config, *names: str): for name in names: ret = config.getoption(name) # 'default' arg won't work as expected if ret is None: @@ -268,13 +276,16 @@ def pytest_addoption(parser: Parser) -> None: ) +_HandlerType = TypeVar("_HandlerType", bound=logging.Handler) + + # Not using @contextmanager for performance reasons. class catching_logs: """Context manager that prepares the whole logging machinery properly.""" __slots__ = ("handler", "level", "orig_level") - def __init__(self, handler, level=None): + def __init__(self, handler: _HandlerType, level: Optional[int] = None) -> None: self.handler = handler self.level = level @@ -330,7 +341,7 @@ class LogCaptureFixture: """Creates a new funcarg.""" self._item = item # dict of log name -> log level - self._initial_log_levels = {} # type: Dict[str, int] + self._initial_log_levels = {} # type: Dict[Optional[str], int] def _finalize(self) -> None: """Finalizes the fixture. @@ -364,17 +375,17 @@ class LogCaptureFixture: return self._item._store[catch_log_records_key].get(when, []) @property - def text(self): + def text(self) -> str: """Returns the formatted log text.""" return _remove_ansi_escape_sequences(self.handler.stream.getvalue()) @property - def records(self): + def records(self) -> List[logging.LogRecord]: """Returns the list of log records.""" return self.handler.records @property - def record_tuples(self): + def record_tuples(self) -> List[Tuple[str, int, str]]: """Returns a list of a stripped down version of log records intended for use in assertion comparison. @@ -385,7 +396,7 @@ class LogCaptureFixture: return [(r.name, r.levelno, r.getMessage()) for r in self.records] @property - def messages(self): + def messages(self) -> List[str]: """Returns a list of format-interpolated log messages. Unlike 'records', which contains the format string and parameters for interpolation, log messages in this list @@ -400,11 +411,11 @@ class LogCaptureFixture: """ return [r.getMessage() for r in self.records] - def clear(self): + def clear(self) -> None: """Reset the list of log records and the captured log text.""" self.handler.reset() - def set_level(self, level, logger=None): + def set_level(self, level: Union[int, str], logger: Optional[str] = None) -> None: """Sets the level for capturing of logs. The level will be restored to its previous value at the end of the test. @@ -415,31 +426,32 @@ class LogCaptureFixture: The levels of the loggers changed by this function will be restored to their initial values at the end of the test. """ - logger_name = logger - logger = logging.getLogger(logger_name) + logger_obj = logging.getLogger(logger) # save the original log-level to restore it during teardown - self._initial_log_levels.setdefault(logger_name, logger.level) - logger.setLevel(level) + self._initial_log_levels.setdefault(logger, logger_obj.level) + logger_obj.setLevel(level) @contextmanager - def at_level(self, level, logger=None): + def at_level( + self, level: int, logger: Optional[str] = None + ) -> Generator[None, None, None]: """Context manager that sets the level for capturing of logs. After the end of the 'with' statement the level is restored to its original value. :param int level: the logger to level. :param str logger: the logger to update the level. If not given, the root logger level is updated. """ - logger = logging.getLogger(logger) - orig_level = logger.level - logger.setLevel(level) + logger_obj = logging.getLogger(logger) + orig_level = logger_obj.level + logger_obj.setLevel(level) try: yield finally: - logger.setLevel(orig_level) + logger_obj.setLevel(orig_level) @pytest.fixture -def caplog(request): +def caplog(request: FixtureRequest) -> Generator[LogCaptureFixture, None, None]: """Access and control log capturing. Captured logs are available through the following properties/methods:: @@ -557,7 +569,7 @@ class LoggingPlugin: return formatter - def set_log_path(self, fname): + def set_log_path(self, fname: str) -> None: """Public method, which can set filename parameter for Logging.FileHandler(). Also creates parent directory if it does not exist. @@ -565,15 +577,15 @@ class LoggingPlugin: .. warning:: Please considered as an experimental API. """ - fname = Path(fname) + fpath = Path(fname) - if not fname.is_absolute(): - fname = Path(self._config.rootdir, fname) + if not fpath.is_absolute(): + fpath = Path(self._config.rootdir, fpath) - if not fname.parent.exists(): - fname.parent.mkdir(exist_ok=True, parents=True) + if not fpath.parent.exists(): + fpath.parent.mkdir(exist_ok=True, parents=True) - stream = fname.open(mode="w", encoding="UTF-8") + stream = fpath.open(mode="w", encoding="UTF-8") if sys.version_info >= (3, 7): old_stream = self.log_file_handler.setStream(stream) else: @@ -715,29 +727,35 @@ class _LiveLoggingStreamHandler(logging.StreamHandler): and won't appear in the terminal. """ - def __init__(self, terminal_reporter, capture_manager): + # Officially stream needs to be a IO[str], but TerminalReporter + # isn't. So force it. + stream = None # type: TerminalReporter # type: ignore + + def __init__( + self, terminal_reporter: TerminalReporter, capture_manager: CaptureManager + ) -> None: """ :param _pytest.terminal.TerminalReporter terminal_reporter: :param _pytest.capture.CaptureManager capture_manager: """ - logging.StreamHandler.__init__(self, stream=terminal_reporter) + logging.StreamHandler.__init__(self, stream=terminal_reporter) # type: ignore[arg-type] # noqa: F821 self.capture_manager = capture_manager self.reset() self.set_when(None) self._test_outcome_written = False - def reset(self): + def reset(self) -> None: """Reset the handler; should be called before the start of each test""" self._first_record_emitted = False - def set_when(self, when): + def set_when(self, when: Optional[str]) -> None: """Prepares for the given test phase (setup/call/teardown)""" self._when = when self._section_name_shown = False if when == "start": self._test_outcome_written = False - def emit(self, record): + def emit(self, record: logging.LogRecord) -> None: ctx_manager = ( self.capture_manager.global_and_fixture_disabled() if self.capture_manager @@ -764,10 +782,10 @@ class _LiveLoggingStreamHandler(logging.StreamHandler): class _LiveLoggingNullHandler(logging.NullHandler): """A handler used when live logging is disabled.""" - def reset(self): + def reset(self) -> None: pass - def set_when(self, when): + def set_when(self, when: str) -> None: pass def handleError(self, record: logging.LogRecord) -> None: