correctly handle multiple asserts
This commit is contained in:
parent
9ac818fb5c
commit
aae89cd021
|
@ -99,19 +99,22 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
node = nodes.popleft()
|
node = nodes.popleft()
|
||||||
for name, field in ast.iter_fields(node):
|
for name, field in ast.iter_fields(node):
|
||||||
if isinstance(field, list):
|
if isinstance(field, list):
|
||||||
|
new = []
|
||||||
for i, child in enumerate(field):
|
for i, child in enumerate(field):
|
||||||
if isinstance(child, ast.Assert):
|
if isinstance(child, ast.Assert):
|
||||||
|
# Transform assert.
|
||||||
|
new.extend(self.visit(child))
|
||||||
asserts.append((field, i, child))
|
asserts.append((field, i, child))
|
||||||
elif isinstance(child, ast.AST):
|
else:
|
||||||
nodes.append(child)
|
new.append(child)
|
||||||
|
if isinstance(child, ast.AST):
|
||||||
|
nodes.append(child)
|
||||||
|
setattr(node, name, new)
|
||||||
elif (isinstance(field, ast.AST) and
|
elif (isinstance(field, ast.AST) and
|
||||||
# Don't recurse into expressions as they can't contain
|
# Don't recurse into expressions as they can't contain
|
||||||
# asserts.
|
# asserts.
|
||||||
not isinstance(field, ast.expr)):
|
not isinstance(field, ast.expr)):
|
||||||
nodes.append(field)
|
nodes.append(field)
|
||||||
# Transform asserts.
|
|
||||||
for parent, pos, assert_ in asserts:
|
|
||||||
parent[pos:pos + 1] = self.visit(assert_)
|
|
||||||
|
|
||||||
def variable(self):
|
def variable(self):
|
||||||
"""Get a new variable."""
|
"""Get a new variable."""
|
||||||
|
|
|
@ -196,6 +196,11 @@ class TestAssertionRewrite:
|
||||||
a, b, c = range(3)
|
a, b, c = range(3)
|
||||||
assert a < b <= c
|
assert a < b <= c
|
||||||
getmsg(f, must_pass=True)
|
getmsg(f, must_pass=True)
|
||||||
|
def f():
|
||||||
|
a, b, c = range(3)
|
||||||
|
assert a < b
|
||||||
|
assert b < c
|
||||||
|
getmsg(f, must_pass=True)
|
||||||
|
|
||||||
def test_len(self):
|
def test_len(self):
|
||||||
def f():
|
def f():
|
||||||
|
|
Loading…
Reference in New Issue