fixup! Color the full diff that pytest shows as a diff

This commit is contained in:
Benjamin Schubert 2023-10-23 19:31:12 +01:00
parent bc11201dff
commit 13459c35e8
1 changed files with 16 additions and 8 deletions

View File

@ -7,8 +7,10 @@ from typing import Any
from typing import Callable from typing import Callable
from typing import Iterable from typing import Iterable
from typing import List from typing import List
from typing import Literal
from typing import Mapping from typing import Mapping
from typing import Optional from typing import Optional
from typing import Protocol
from typing import Sequence from typing import Sequence
from unicodedata import normalize from unicodedata import normalize
@ -17,7 +19,6 @@ from _pytest import outcomes
from _pytest._io.saferepr import _pformat_dispatch from _pytest._io.saferepr import _pformat_dispatch
from _pytest._io.saferepr import saferepr from _pytest._io.saferepr import saferepr
from _pytest._io.saferepr import saferepr_unlimited from _pytest._io.saferepr import saferepr_unlimited
from _pytest._io.terminalwriter import TerminalWriter
from _pytest.config import Config from _pytest.config import Config
# The _reprcompare attribute on the util module is used by the new assertion # The _reprcompare attribute on the util module is used by the new assertion
@ -34,6 +35,11 @@ _assertion_pass: Optional[Callable[[int, str, str], None]] = None
_config: Optional[Config] = None _config: Optional[Config] = None
class _HighlightFunc(Protocol):
def __call__(self, source: str, lexer: Literal["diff", "python"] = "python") -> str:
"""Apply highlighting to the given source."""
def format_explanation(explanation: str) -> str: def format_explanation(explanation: str) -> str:
r"""Format an explanation. r"""Format an explanation.
@ -228,7 +234,7 @@ def assertrepr_compare(
def _compare_eq_any( def _compare_eq_any(
left: Any, right: Any, writer: TerminalWriter, verbose: int = 0 left: Any, right: Any, highlighter: _HighlightFunc, verbose: int = 0
) -> List[str]: ) -> List[str]:
explanation = [] explanation = []
if istext(left) and istext(right): if istext(left) and istext(right):
@ -249,7 +255,7 @@ def _compare_eq_any(
# field values, not the type or field names. But this branch # field values, not the type or field names. But this branch
# intentionally only handles the same-type case, which was often # intentionally only handles the same-type case, which was often
# used in older code bases before dataclasses/attrs were available. # used in older code bases before dataclasses/attrs were available.
explanation = _compare_eq_cls(left, right, writer, verbose) explanation = _compare_eq_cls(left, right, highlighter, verbose)
elif issequence(left) and issequence(right): elif issequence(left) and issequence(right):
explanation = _compare_eq_sequence(left, right, verbose) explanation = _compare_eq_sequence(left, right, verbose)
elif isset(left) and isset(right): elif isset(left) and isset(right):
@ -258,7 +264,7 @@ def _compare_eq_any(
explanation = _compare_eq_dict(left, right, verbose) explanation = _compare_eq_dict(left, right, verbose)
if isiterable(left) and isiterable(right): if isiterable(left) and isiterable(right):
expl = _compare_eq_iterable(left, right, writer, verbose) expl = _compare_eq_iterable(left, right, highlighter, verbose)
explanation.extend(expl) explanation.extend(expl)
return explanation return explanation
@ -327,7 +333,7 @@ def _surrounding_parens_on_own_lines(lines: List[str]) -> None:
def _compare_eq_iterable( def _compare_eq_iterable(
left: Iterable[Any], left: Iterable[Any],
right: Iterable[Any], right: Iterable[Any],
writer: TerminalWriter, highligher: _HighlightFunc,
verbose: int = 0, verbose: int = 0,
) -> List[str]: ) -> List[str]:
if verbose <= 0 and not running_on_ci(): if verbose <= 0 and not running_on_ci():
@ -353,7 +359,7 @@ def _compare_eq_iterable(
# "right" is the expected base against which we compare "left", # "right" is the expected base against which we compare "left",
# see https://github.com/pytest-dev/pytest/issues/3333 # see https://github.com/pytest-dev/pytest/issues/3333
explanation.extend( explanation.extend(
writer._highlight( highligher(
"\n".join( "\n".join(
line.rstrip() line.rstrip()
for line in difflib.ndiff(right_formatting, left_formatting) for line in difflib.ndiff(right_formatting, left_formatting)
@ -510,7 +516,7 @@ def _compare_eq_dict(
def _compare_eq_cls( def _compare_eq_cls(
left: Any, right: Any, writer: TerminalWriter, verbose: int left: Any, right: Any, highlighter: _HighlightFunc, verbose: int
) -> List[str]: ) -> List[str]:
if not has_default_eq(left): if not has_default_eq(left):
return [] return []
@ -557,7 +563,9 @@ def _compare_eq_cls(
] ]
explanation += [ explanation += [
indent + line indent + line
for line in _compare_eq_any(field_left, field_right, writer, verbose) for line in _compare_eq_any(
field_left, field_right, highlighter, verbose
)
] ]
return explanation return explanation