From 13459c35e8d92232f8169a06dcedf24085b6acf6 Mon Sep 17 00:00:00 2001 From: Benjamin Schubert Date: Mon, 23 Oct 2023 19:31:12 +0100 Subject: [PATCH] fixup! Color the full diff that pytest shows as a diff --- src/_pytest/assertion/util.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/src/_pytest/assertion/util.py b/src/_pytest/assertion/util.py index 2d8067294..14e2fa7ea 100644 --- a/src/_pytest/assertion/util.py +++ b/src/_pytest/assertion/util.py @@ -7,8 +7,10 @@ from typing import Any from typing import Callable from typing import Iterable from typing import List +from typing import Literal from typing import Mapping from typing import Optional +from typing import Protocol from typing import Sequence from unicodedata import normalize @@ -17,7 +19,6 @@ from _pytest import outcomes from _pytest._io.saferepr import _pformat_dispatch from _pytest._io.saferepr import saferepr from _pytest._io.saferepr import saferepr_unlimited -from _pytest._io.terminalwriter import TerminalWriter from _pytest.config import Config # The _reprcompare attribute on the util module is used by the new assertion @@ -34,6 +35,11 @@ _assertion_pass: Optional[Callable[[int, str, str], None]] = None _config: Optional[Config] = None +class _HighlightFunc(Protocol): + def __call__(self, source: str, lexer: Literal["diff", "python"] = "python") -> str: + """Apply highlighting to the given source.""" + + def format_explanation(explanation: str) -> str: r"""Format an explanation. @@ -228,7 +234,7 @@ def assertrepr_compare( def _compare_eq_any( - left: Any, right: Any, writer: TerminalWriter, verbose: int = 0 + left: Any, right: Any, highlighter: _HighlightFunc, verbose: int = 0 ) -> List[str]: explanation = [] if istext(left) and istext(right): @@ -249,7 +255,7 @@ def _compare_eq_any( # field values, not the type or field names. But this branch # intentionally only handles the same-type case, which was often # used in older code bases before dataclasses/attrs were available. - explanation = _compare_eq_cls(left, right, writer, verbose) + explanation = _compare_eq_cls(left, right, highlighter, verbose) elif issequence(left) and issequence(right): explanation = _compare_eq_sequence(left, right, verbose) elif isset(left) and isset(right): @@ -258,7 +264,7 @@ def _compare_eq_any( explanation = _compare_eq_dict(left, right, verbose) if isiterable(left) and isiterable(right): - expl = _compare_eq_iterable(left, right, writer, verbose) + expl = _compare_eq_iterable(left, right, highlighter, verbose) explanation.extend(expl) return explanation @@ -327,7 +333,7 @@ def _surrounding_parens_on_own_lines(lines: List[str]) -> None: def _compare_eq_iterable( left: Iterable[Any], right: Iterable[Any], - writer: TerminalWriter, + highligher: _HighlightFunc, verbose: int = 0, ) -> List[str]: if verbose <= 0 and not running_on_ci(): @@ -353,7 +359,7 @@ def _compare_eq_iterable( # "right" is the expected base against which we compare "left", # see https://github.com/pytest-dev/pytest/issues/3333 explanation.extend( - writer._highlight( + highligher( "\n".join( line.rstrip() for line in difflib.ndiff(right_formatting, left_formatting) @@ -510,7 +516,7 @@ def _compare_eq_dict( def _compare_eq_cls( - left: Any, right: Any, writer: TerminalWriter, verbose: int + left: Any, right: Any, highlighter: _HighlightFunc, verbose: int ) -> List[str]: if not has_default_eq(left): return [] @@ -557,7 +563,9 @@ def _compare_eq_cls( ] explanation += [ indent + line - for line in _compare_eq_any(field_left, field_right, writer, verbose) + for line in _compare_eq_any( + field_left, field_right, highlighter, verbose + ) ] return explanation