rewrite: fixup end_lineno, end_col_offset of rewritten asserts

These are new additions in Python 3.8:
https://docs.python.org/3/whatsnew/3.8.html#ast
I'm not sure what's using them but we should set them anyway.
This commit is contained in:
Ran Benita 2021-10-05 09:52:09 +03:00 committed by Taiju Yamada
parent fef682a53b
commit f205b1e104
3 changed files with 32 additions and 15 deletions

View File

@ -0,0 +1 @@
The end line number and end column offset are now properly set for rewritten assert statements.

View File

@ -575,19 +575,12 @@ else:
return ast.Name(str(c), ast.Load()) return ast.Name(str(c), ast.Load())
def set_location(node, lineno, col_offset): def traverse_node(node):
"""Set node location information recursively.""" """Recursively yield node and all its children in depth-first order."""
yield node
def _fix(node, lineno, col_offset):
if "lineno" in node._attributes:
node.lineno = lineno
if "col_offset" in node._attributes:
node.col_offset = col_offset
for child in ast.iter_child_nodes(node): for child in ast.iter_child_nodes(node):
_fix(child, lineno, col_offset) for descendant in traverse_node(child):
yield descendant
_fix(node, lineno, col_offset)
return node
class AssertionRewriter(ast.NodeVisitor): class AssertionRewriter(ast.NodeVisitor):
@ -868,9 +861,10 @@ class AssertionRewriter(ast.NodeVisitor):
variables = [ast.Name(name, ast.Store()) for name in self.variables] variables = [ast.Name(name, ast.Store()) for name in self.variables]
clear = ast.Assign(variables, _NameConstant(None)) clear = ast.Assign(variables, _NameConstant(None))
self.statements.append(clear) self.statements.append(clear)
# Fix line numbers. # Fix locations (line numbers/column offsets).
for stmt in self.statements: for stmt in self.statements:
set_location(stmt, assert_.lineno, assert_.col_offset) for node in traverse_node(stmt):
ast.copy_location(node, assert_)
return self.statements return self.statements
def warn_about_none_ast(self, node, module_path, lineno): def warn_about_none_ast(self, node, module_path, lineno):

View File

@ -111,6 +111,28 @@ class TestAssertionRewrite(object):
assert imp.col_offset == 0 assert imp.col_offset == 0
assert isinstance(m.body[3], ast.Expr) assert isinstance(m.body[3], ast.Expr)
def test_location_is_set(self):
s = textwrap.dedent(
"""
assert False, (
"Ouch"
)
"""
)
m = rewrite(s)
for node in m.body:
if isinstance(node, ast.Import):
continue
for n in [node, *ast.iter_child_nodes(node)]:
assert n.lineno == 3
assert n.col_offset == 0
if sys.version_info >= (3, 8):
assert n.end_lineno == 6
assert n.end_col_offset == 3
def test_dont_rewrite(self): def test_dont_rewrite(self):
s = """'PYTEST_DONT_REWRITE'\nassert 14""" s = """'PYTEST_DONT_REWRITE'\nassert 14"""
m = rewrite(s) m = rewrite(s)