Remove astor and reproduce the original assertion expression
This commit is contained in:
@@ -1,16 +1,18 @@
|
||||
"""Rewrite assertion AST to produce nice error messages"""
|
||||
import ast
|
||||
import errno
|
||||
import functools
|
||||
import importlib.machinery
|
||||
import importlib.util
|
||||
import io
|
||||
import itertools
|
||||
import marshal
|
||||
import os
|
||||
import struct
|
||||
import sys
|
||||
import tokenize
|
||||
import types
|
||||
|
||||
import astor
|
||||
import atomicwrites
|
||||
|
||||
from _pytest._io.saferepr import saferepr
|
||||
@@ -285,7 +287,7 @@ def _rewrite_test(fn, config):
|
||||
with open(fn, "rb") as f:
|
||||
source = f.read()
|
||||
tree = ast.parse(source, filename=fn)
|
||||
rewrite_asserts(tree, fn, config)
|
||||
rewrite_asserts(tree, source, fn, config)
|
||||
co = compile(tree, fn, "exec", dont_inherit=True)
|
||||
return stat, co
|
||||
|
||||
@@ -327,9 +329,9 @@ def _read_pyc(source, pyc, trace=lambda x: None):
|
||||
return co
|
||||
|
||||
|
||||
def rewrite_asserts(mod, module_path=None, config=None):
|
||||
def rewrite_asserts(mod, source, module_path=None, config=None):
|
||||
"""Rewrite the assert statements in mod."""
|
||||
AssertionRewriter(module_path, config).run(mod)
|
||||
AssertionRewriter(module_path, config, source).run(mod)
|
||||
|
||||
|
||||
def _saferepr(obj):
|
||||
@@ -457,6 +459,59 @@ def set_location(node, lineno, col_offset):
|
||||
return node
|
||||
|
||||
|
||||
def _get_assertion_exprs(src: bytes): # -> Dict[int, str]
|
||||
"""Returns a mapping from {lineno: "assertion test expression"}"""
|
||||
ret = {}
|
||||
|
||||
depth = 0
|
||||
lines = []
|
||||
assert_lineno = None
|
||||
seen_lines = set()
|
||||
|
||||
def _write_and_reset() -> None:
|
||||
nonlocal depth, lines, assert_lineno, seen_lines
|
||||
ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\")
|
||||
depth = 0
|
||||
lines = []
|
||||
assert_lineno = None
|
||||
seen_lines = set()
|
||||
|
||||
tokens = tokenize.tokenize(io.BytesIO(src).readline)
|
||||
for tp, src, (lineno, offset), _, line in tokens:
|
||||
if tp == tokenize.NAME and src == "assert":
|
||||
assert_lineno = lineno
|
||||
elif assert_lineno is not None:
|
||||
# keep track of depth for the assert-message `,` lookup
|
||||
if tp == tokenize.OP and src in "([{":
|
||||
depth += 1
|
||||
elif tp == tokenize.OP and src in ")]}":
|
||||
depth -= 1
|
||||
|
||||
if not lines:
|
||||
lines.append(line[offset:])
|
||||
seen_lines.add(lineno)
|
||||
# a non-nested comma separates the expression from the message
|
||||
elif depth == 0 and tp == tokenize.OP and src == ",":
|
||||
# one line assert with message
|
||||
if lineno in seen_lines and len(lines) == 1:
|
||||
offset_in_trimmed = offset + len(lines[-1]) - len(line)
|
||||
lines[-1] = lines[-1][:offset_in_trimmed]
|
||||
# multi-line assert with message
|
||||
elif lineno in seen_lines:
|
||||
lines[-1] = lines[-1][:offset]
|
||||
# multi line assert with escapd newline before message
|
||||
else:
|
||||
lines.append(line[:offset])
|
||||
_write_and_reset()
|
||||
elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}:
|
||||
_write_and_reset()
|
||||
elif lines and lineno not in seen_lines:
|
||||
lines.append(line)
|
||||
seen_lines.add(lineno)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
class AssertionRewriter(ast.NodeVisitor):
|
||||
"""Assertion rewriting implementation.
|
||||
|
||||
@@ -511,7 +566,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, module_path, config):
|
||||
def __init__(self, module_path, config, source):
|
||||
super().__init__()
|
||||
self.module_path = module_path
|
||||
self.config = config
|
||||
@@ -521,6 +576,11 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
)
|
||||
else:
|
||||
self.enable_assertion_pass_hook = False
|
||||
self.source = source
|
||||
|
||||
@functools.lru_cache(maxsize=1)
|
||||
def _assert_expr_to_lineno(self):
|
||||
return _get_assertion_exprs(self.source)
|
||||
|
||||
def run(self, mod):
|
||||
"""Find all assert statements in *mod* and rewrite them."""
|
||||
@@ -738,7 +798,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||
|
||||
# Passed
|
||||
fmt_pass = self.helper("_format_explanation", msg)
|
||||
orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")")
|
||||
orig = self._assert_expr_to_lineno()[assert_.lineno]
|
||||
hook_call_pass = ast.Expr(
|
||||
self.helper(
|
||||
"_call_assertion_pass",
|
||||
|
||||
Reference in New Issue
Block a user