Remove astor and reproduce the original assertion expression
This commit is contained in:
		
							parent
							
								
									3c9b46f781
								
							
						
					
					
						commit
						7ee244476a
					
				| 
						 | 
					@ -1 +0,0 @@
 | 
				
			||||||
pytest now also depends on the `astor <https://pypi.org/project/astor/>`__ package.
 | 
					 | 
				
			||||||
							
								
								
									
										1
									
								
								setup.py
								
								
								
								
							
							
						
						
									
										1
									
								
								setup.py
								
								
								
								
							| 
						 | 
					@ -13,7 +13,6 @@ INSTALL_REQUIRES = [
 | 
				
			||||||
    "pluggy>=0.12,<1.0",
 | 
					    "pluggy>=0.12,<1.0",
 | 
				
			||||||
    "importlib-metadata>=0.12",
 | 
					    "importlib-metadata>=0.12",
 | 
				
			||||||
    "wcwidth",
 | 
					    "wcwidth",
 | 
				
			||||||
    "astor",
 | 
					 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,16 +1,18 @@
 | 
				
			||||||
"""Rewrite assertion AST to produce nice error messages"""
 | 
					"""Rewrite assertion AST to produce nice error messages"""
 | 
				
			||||||
import ast
 | 
					import ast
 | 
				
			||||||
import errno
 | 
					import errno
 | 
				
			||||||
 | 
					import functools
 | 
				
			||||||
import importlib.machinery
 | 
					import importlib.machinery
 | 
				
			||||||
import importlib.util
 | 
					import importlib.util
 | 
				
			||||||
 | 
					import io
 | 
				
			||||||
import itertools
 | 
					import itertools
 | 
				
			||||||
import marshal
 | 
					import marshal
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
import struct
 | 
					import struct
 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
 | 
					import tokenize
 | 
				
			||||||
import types
 | 
					import types
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import astor
 | 
					 | 
				
			||||||
import atomicwrites
 | 
					import atomicwrites
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from _pytest._io.saferepr import saferepr
 | 
					from _pytest._io.saferepr import saferepr
 | 
				
			||||||
| 
						 | 
					@ -285,7 +287,7 @@ def _rewrite_test(fn, config):
 | 
				
			||||||
    with open(fn, "rb") as f:
 | 
					    with open(fn, "rb") as f:
 | 
				
			||||||
        source = f.read()
 | 
					        source = f.read()
 | 
				
			||||||
    tree = ast.parse(source, filename=fn)
 | 
					    tree = ast.parse(source, filename=fn)
 | 
				
			||||||
    rewrite_asserts(tree, fn, config)
 | 
					    rewrite_asserts(tree, source, fn, config)
 | 
				
			||||||
    co = compile(tree, fn, "exec", dont_inherit=True)
 | 
					    co = compile(tree, fn, "exec", dont_inherit=True)
 | 
				
			||||||
    return stat, co
 | 
					    return stat, co
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -327,9 +329,9 @@ def _read_pyc(source, pyc, trace=lambda x: None):
 | 
				
			||||||
        return co
 | 
					        return co
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def rewrite_asserts(mod, module_path=None, config=None):
 | 
					def rewrite_asserts(mod, source, module_path=None, config=None):
 | 
				
			||||||
    """Rewrite the assert statements in mod."""
 | 
					    """Rewrite the assert statements in mod."""
 | 
				
			||||||
    AssertionRewriter(module_path, config).run(mod)
 | 
					    AssertionRewriter(module_path, config, source).run(mod)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _saferepr(obj):
 | 
					def _saferepr(obj):
 | 
				
			||||||
| 
						 | 
					@ -457,6 +459,59 @@ def set_location(node, lineno, col_offset):
 | 
				
			||||||
    return node
 | 
					    return node
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def _get_assertion_exprs(src: bytes):  # -> Dict[int, str]
 | 
				
			||||||
 | 
					    """Returns a mapping from {lineno: "assertion test expression"}"""
 | 
				
			||||||
 | 
					    ret = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    depth = 0
 | 
				
			||||||
 | 
					    lines = []
 | 
				
			||||||
 | 
					    assert_lineno = None
 | 
				
			||||||
 | 
					    seen_lines = set()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _write_and_reset() -> None:
 | 
				
			||||||
 | 
					        nonlocal depth, lines, assert_lineno, seen_lines
 | 
				
			||||||
 | 
					        ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\")
 | 
				
			||||||
 | 
					        depth = 0
 | 
				
			||||||
 | 
					        lines = []
 | 
				
			||||||
 | 
					        assert_lineno = None
 | 
				
			||||||
 | 
					        seen_lines = set()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    tokens = tokenize.tokenize(io.BytesIO(src).readline)
 | 
				
			||||||
 | 
					    for tp, src, (lineno, offset), _, line in tokens:
 | 
				
			||||||
 | 
					        if tp == tokenize.NAME and src == "assert":
 | 
				
			||||||
 | 
					            assert_lineno = lineno
 | 
				
			||||||
 | 
					        elif assert_lineno is not None:
 | 
				
			||||||
 | 
					            # keep track of depth for the assert-message `,` lookup
 | 
				
			||||||
 | 
					            if tp == tokenize.OP and src in "([{":
 | 
				
			||||||
 | 
					                depth += 1
 | 
				
			||||||
 | 
					            elif tp == tokenize.OP and src in ")]}":
 | 
				
			||||||
 | 
					                depth -= 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if not lines:
 | 
				
			||||||
 | 
					                lines.append(line[offset:])
 | 
				
			||||||
 | 
					                seen_lines.add(lineno)
 | 
				
			||||||
 | 
					            # a non-nested comma separates the expression from the message
 | 
				
			||||||
 | 
					            elif depth == 0 and tp == tokenize.OP and src == ",":
 | 
				
			||||||
 | 
					                # one line assert with message
 | 
				
			||||||
 | 
					                if lineno in seen_lines and len(lines) == 1:
 | 
				
			||||||
 | 
					                    offset_in_trimmed = offset + len(lines[-1]) - len(line)
 | 
				
			||||||
 | 
					                    lines[-1] = lines[-1][:offset_in_trimmed]
 | 
				
			||||||
 | 
					                # multi-line assert with message
 | 
				
			||||||
 | 
					                elif lineno in seen_lines:
 | 
				
			||||||
 | 
					                    lines[-1] = lines[-1][:offset]
 | 
				
			||||||
 | 
					                # multi line assert with escapd newline before message
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    lines.append(line[:offset])
 | 
				
			||||||
 | 
					                _write_and_reset()
 | 
				
			||||||
 | 
					            elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}:
 | 
				
			||||||
 | 
					                _write_and_reset()
 | 
				
			||||||
 | 
					            elif lines and lineno not in seen_lines:
 | 
				
			||||||
 | 
					                lines.append(line)
 | 
				
			||||||
 | 
					                seen_lines.add(lineno)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return ret
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class AssertionRewriter(ast.NodeVisitor):
 | 
					class AssertionRewriter(ast.NodeVisitor):
 | 
				
			||||||
    """Assertion rewriting implementation.
 | 
					    """Assertion rewriting implementation.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -511,7 +566,7 @@ class AssertionRewriter(ast.NodeVisitor):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, module_path, config):
 | 
					    def __init__(self, module_path, config, source):
 | 
				
			||||||
        super().__init__()
 | 
					        super().__init__()
 | 
				
			||||||
        self.module_path = module_path
 | 
					        self.module_path = module_path
 | 
				
			||||||
        self.config = config
 | 
					        self.config = config
 | 
				
			||||||
| 
						 | 
					@ -521,6 +576,11 @@ class AssertionRewriter(ast.NodeVisitor):
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            self.enable_assertion_pass_hook = False
 | 
					            self.enable_assertion_pass_hook = False
 | 
				
			||||||
 | 
					        self.source = source
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @functools.lru_cache(maxsize=1)
 | 
				
			||||||
 | 
					    def _assert_expr_to_lineno(self):
 | 
				
			||||||
 | 
					        return _get_assertion_exprs(self.source)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def run(self, mod):
 | 
					    def run(self, mod):
 | 
				
			||||||
        """Find all assert statements in *mod* and rewrite them."""
 | 
					        """Find all assert statements in *mod* and rewrite them."""
 | 
				
			||||||
| 
						 | 
					@ -738,7 +798,7 @@ class AssertionRewriter(ast.NodeVisitor):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # Passed
 | 
					            # Passed
 | 
				
			||||||
            fmt_pass = self.helper("_format_explanation", msg)
 | 
					            fmt_pass = self.helper("_format_explanation", msg)
 | 
				
			||||||
            orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")")
 | 
					            orig = self._assert_expr_to_lineno()[assert_.lineno]
 | 
				
			||||||
            hook_call_pass = ast.Expr(
 | 
					            hook_call_pass = ast.Expr(
 | 
				
			||||||
                self.helper(
 | 
					                self.helper(
 | 
				
			||||||
                    "_call_assertion_pass",
 | 
					                    "_call_assertion_pass",
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -13,6 +13,7 @@ import py
 | 
				
			||||||
import _pytest._code
 | 
					import _pytest._code
 | 
				
			||||||
import pytest
 | 
					import pytest
 | 
				
			||||||
from _pytest.assertion import util
 | 
					from _pytest.assertion import util
 | 
				
			||||||
 | 
					from _pytest.assertion.rewrite import _get_assertion_exprs
 | 
				
			||||||
from _pytest.assertion.rewrite import AssertionRewritingHook
 | 
					from _pytest.assertion.rewrite import AssertionRewritingHook
 | 
				
			||||||
from _pytest.assertion.rewrite import PYTEST_TAG
 | 
					from _pytest.assertion.rewrite import PYTEST_TAG
 | 
				
			||||||
from _pytest.assertion.rewrite import rewrite_asserts
 | 
					from _pytest.assertion.rewrite import rewrite_asserts
 | 
				
			||||||
| 
						 | 
					@ -31,7 +32,7 @@ def teardown_module(mod):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def rewrite(src):
 | 
					def rewrite(src):
 | 
				
			||||||
    tree = ast.parse(src)
 | 
					    tree = ast.parse(src)
 | 
				
			||||||
    rewrite_asserts(tree)
 | 
					    rewrite_asserts(tree, src.encode())
 | 
				
			||||||
    return tree
 | 
					    return tree
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1292,7 +1293,7 @@ class TestEarlyRewriteBailout:
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        p = testdir.makepyfile(
 | 
					        p = testdir.makepyfile(
 | 
				
			||||||
            **{
 | 
					            **{
 | 
				
			||||||
                "tests/file.py": """
 | 
					                "tests/file.py": """\
 | 
				
			||||||
                    def test_simple_failure():
 | 
					                    def test_simple_failure():
 | 
				
			||||||
                        assert 1 + 1 == 3
 | 
					                        assert 1 + 1 == 3
 | 
				
			||||||
                """
 | 
					                """
 | 
				
			||||||
| 
						 | 
					@ -1315,7 +1316,7 @@ class TestEarlyRewriteBailout:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        testdir.makepyfile(
 | 
					        testdir.makepyfile(
 | 
				
			||||||
            **{
 | 
					            **{
 | 
				
			||||||
                "test_setup_nonexisting_cwd.py": """
 | 
					                "test_setup_nonexisting_cwd.py": """\
 | 
				
			||||||
                    import os
 | 
					                    import os
 | 
				
			||||||
                    import shutil
 | 
					                    import shutil
 | 
				
			||||||
                    import tempfile
 | 
					                    import tempfile
 | 
				
			||||||
| 
						 | 
					@ -1324,7 +1325,7 @@ class TestEarlyRewriteBailout:
 | 
				
			||||||
                    os.chdir(d)
 | 
					                    os.chdir(d)
 | 
				
			||||||
                    shutil.rmtree(d)
 | 
					                    shutil.rmtree(d)
 | 
				
			||||||
                """,
 | 
					                """,
 | 
				
			||||||
                "test_test.py": """
 | 
					                "test_test.py": """\
 | 
				
			||||||
                    def test():
 | 
					                    def test():
 | 
				
			||||||
                        pass
 | 
					                        pass
 | 
				
			||||||
                """,
 | 
					                """,
 | 
				
			||||||
| 
						 | 
					@ -1339,23 +1340,22 @@ class TestAssertionPass:
 | 
				
			||||||
        config = testdir.parseconfig()
 | 
					        config = testdir.parseconfig()
 | 
				
			||||||
        assert config.getini("enable_assertion_pass_hook") is False
 | 
					        assert config.getini("enable_assertion_pass_hook") is False
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_hook_call(self, testdir):
 | 
					    @pytest.fixture
 | 
				
			||||||
 | 
					    def flag_on(self, testdir):
 | 
				
			||||||
 | 
					        testdir.makeini("[pytest]\nenable_assertion_pass_hook = True\n")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @pytest.fixture
 | 
				
			||||||
 | 
					    def hook_on(self, testdir):
 | 
				
			||||||
        testdir.makeconftest(
 | 
					        testdir.makeconftest(
 | 
				
			||||||
            """
 | 
					            """\
 | 
				
			||||||
            def pytest_assertion_pass(item, lineno, orig, expl):
 | 
					            def pytest_assertion_pass(item, lineno, orig, expl):
 | 
				
			||||||
                raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
 | 
					                raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
 | 
				
			||||||
            """
 | 
					            """
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        testdir.makeini(
 | 
					    def test_hook_call(self, testdir, flag_on, hook_on):
 | 
				
			||||||
            """
 | 
					 | 
				
			||||||
        [pytest]
 | 
					 | 
				
			||||||
        enable_assertion_pass_hook = True
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        testdir.makepyfile(
 | 
					        testdir.makepyfile(
 | 
				
			||||||
            """
 | 
					            """\
 | 
				
			||||||
            def test_simple():
 | 
					            def test_simple():
 | 
				
			||||||
                a=1
 | 
					                a=1
 | 
				
			||||||
                b=2
 | 
					                b=2
 | 
				
			||||||
| 
						 | 
					@ -1374,7 +1374,18 @@ class TestAssertionPass:
 | 
				
			||||||
            "*Assertion Passed: a+b == c+d (1 + 2) == (3 + 0) at line 7*"
 | 
					            "*Assertion Passed: a+b == c+d (1 + 2) == (3 + 0) at line 7*"
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_hook_not_called_without_hookimpl(self, testdir, monkeypatch):
 | 
					    def test_hook_call_with_parens(self, testdir, flag_on, hook_on):
 | 
				
			||||||
 | 
					        testdir.makepyfile(
 | 
				
			||||||
 | 
					            """\
 | 
				
			||||||
 | 
					            def f(): return 1
 | 
				
			||||||
 | 
					            def test():
 | 
				
			||||||
 | 
					                assert f()
 | 
				
			||||||
 | 
					            """
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        result = testdir.runpytest()
 | 
				
			||||||
 | 
					        result.stdout.fnmatch_lines("*Assertion Passed: f() 1")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_hook_not_called_without_hookimpl(self, testdir, monkeypatch, flag_on):
 | 
				
			||||||
        """Assertion pass should not be called (and hence formatting should
 | 
					        """Assertion pass should not be called (and hence formatting should
 | 
				
			||||||
        not occur) if there is no hook declared for pytest_assertion_pass"""
 | 
					        not occur) if there is no hook declared for pytest_assertion_pass"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1385,15 +1396,8 @@ class TestAssertionPass:
 | 
				
			||||||
            _pytest.assertion.rewrite, "_call_assertion_pass", raise_on_assertionpass
 | 
					            _pytest.assertion.rewrite, "_call_assertion_pass", raise_on_assertionpass
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        testdir.makeini(
 | 
					 | 
				
			||||||
            """
 | 
					 | 
				
			||||||
        [pytest]
 | 
					 | 
				
			||||||
        enable_assertion_pass_hook = True
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        testdir.makepyfile(
 | 
					        testdir.makepyfile(
 | 
				
			||||||
            """
 | 
					            """\
 | 
				
			||||||
            def test_simple():
 | 
					            def test_simple():
 | 
				
			||||||
                a=1
 | 
					                a=1
 | 
				
			||||||
                b=2
 | 
					                b=2
 | 
				
			||||||
| 
						 | 
					@ -1418,21 +1422,14 @@ class TestAssertionPass:
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        testdir.makeconftest(
 | 
					        testdir.makeconftest(
 | 
				
			||||||
            """
 | 
					            """\
 | 
				
			||||||
            def pytest_assertion_pass(item, lineno, orig, expl):
 | 
					            def pytest_assertion_pass(item, lineno, orig, expl):
 | 
				
			||||||
                raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
 | 
					                raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
 | 
				
			||||||
            """
 | 
					            """
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        testdir.makeini(
 | 
					 | 
				
			||||||
            """
 | 
					 | 
				
			||||||
        [pytest]
 | 
					 | 
				
			||||||
        enable_assertion_pass_hook = False
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        testdir.makepyfile(
 | 
					        testdir.makepyfile(
 | 
				
			||||||
            """
 | 
					            """\
 | 
				
			||||||
            def test_simple():
 | 
					            def test_simple():
 | 
				
			||||||
                a=1
 | 
					                a=1
 | 
				
			||||||
                b=2
 | 
					                b=2
 | 
				
			||||||
| 
						 | 
					@ -1444,3 +1441,90 @@ class TestAssertionPass:
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        result = testdir.runpytest()
 | 
					        result = testdir.runpytest()
 | 
				
			||||||
        result.assert_outcomes(passed=1)
 | 
					        result.assert_outcomes(passed=1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize(
 | 
				
			||||||
 | 
					    ("src", "expected"),
 | 
				
			||||||
 | 
					    (
 | 
				
			||||||
 | 
					        # fmt: off
 | 
				
			||||||
 | 
					        pytest.param(b"", {}, id="trivial"),
 | 
				
			||||||
 | 
					        pytest.param(
 | 
				
			||||||
 | 
					            b"def x(): assert 1\n",
 | 
				
			||||||
 | 
					            {1: "1"},
 | 
				
			||||||
 | 
					            id="assert statement not on own line",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        pytest.param(
 | 
				
			||||||
 | 
					            b"def x():\n"
 | 
				
			||||||
 | 
					            b"    assert 1\n"
 | 
				
			||||||
 | 
					            b"    assert 1+2\n",
 | 
				
			||||||
 | 
					            {2: "1", 3: "1+2"},
 | 
				
			||||||
 | 
					            id="multiple assertions",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        pytest.param(
 | 
				
			||||||
 | 
					            # changes in encoding cause the byte offsets to be different
 | 
				
			||||||
 | 
					            "# -*- coding: latin1\n"
 | 
				
			||||||
 | 
					            "def ÀÀÀÀÀ(): assert 1\n".encode("latin1"),
 | 
				
			||||||
 | 
					            {2: "1"},
 | 
				
			||||||
 | 
					            id="latin1 encoded on first line\n",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        pytest.param(
 | 
				
			||||||
 | 
					            # using the default utf-8 encoding
 | 
				
			||||||
 | 
					            "def ÀÀÀÀÀ(): assert 1\n".encode(),
 | 
				
			||||||
 | 
					            {1: "1"},
 | 
				
			||||||
 | 
					            id="utf-8 encoded on first line",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        pytest.param(
 | 
				
			||||||
 | 
					            b"def x():\n"
 | 
				
			||||||
 | 
					            b"    assert (\n"
 | 
				
			||||||
 | 
					            b"        1 + 2  # comment\n"
 | 
				
			||||||
 | 
					            b"    )\n",
 | 
				
			||||||
 | 
					            {2: "(\n        1 + 2  # comment\n    )"},
 | 
				
			||||||
 | 
					            id="multi-line assertion",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        pytest.param(
 | 
				
			||||||
 | 
					            b"def x():\n"
 | 
				
			||||||
 | 
					            b"    assert y == [\n"
 | 
				
			||||||
 | 
					            b"        1, 2, 3\n"
 | 
				
			||||||
 | 
					            b"    ]\n",
 | 
				
			||||||
 | 
					            {2: "y == [\n        1, 2, 3\n    ]"},
 | 
				
			||||||
 | 
					            id="multi line assert with list continuation",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        pytest.param(
 | 
				
			||||||
 | 
					            b"def x():\n"
 | 
				
			||||||
 | 
					            b"    assert 1 + \\\n"
 | 
				
			||||||
 | 
					            b"        2\n",
 | 
				
			||||||
 | 
					            {2: "1 + \\\n        2"},
 | 
				
			||||||
 | 
					            id="backslash continuation",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        pytest.param(
 | 
				
			||||||
 | 
					            b"def x():\n"
 | 
				
			||||||
 | 
					            b"    assert x, y\n",
 | 
				
			||||||
 | 
					            {2: "x"},
 | 
				
			||||||
 | 
					            id="assertion with message",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        pytest.param(
 | 
				
			||||||
 | 
					            b"def x():\n"
 | 
				
			||||||
 | 
					            b"    assert (\n"
 | 
				
			||||||
 | 
					            b"        f(1, 2, 3)\n"
 | 
				
			||||||
 | 
					            b"    ),  'f did not work!'\n",
 | 
				
			||||||
 | 
					            {2: "(\n        f(1, 2, 3)\n    )"},
 | 
				
			||||||
 | 
					            id="assertion with message, test spanning multiple lines",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        pytest.param(
 | 
				
			||||||
 | 
					            b"def x():\n"
 | 
				
			||||||
 | 
					            b"    assert \\\n"
 | 
				
			||||||
 | 
					            b"        x\\\n"
 | 
				
			||||||
 | 
					            b"        , 'failure message'\n",
 | 
				
			||||||
 | 
					            {2: "x"},
 | 
				
			||||||
 | 
					            id="escaped newlines plus message",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        pytest.param(
 | 
				
			||||||
 | 
					            b"def x(): assert 5",
 | 
				
			||||||
 | 
					            {1: "5"},
 | 
				
			||||||
 | 
					            id="no newline at end of file",
 | 
				
			||||||
 | 
					        ),
 | 
				
			||||||
 | 
					        # fmt: on
 | 
				
			||||||
 | 
					    ),
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					def test_get_assertion_exprs(src, expected):
 | 
				
			||||||
 | 
					    assert _get_assertion_exprs(src) == expected
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue