correctly handle multiple asserts

This commit is contained in:
Benjamin Peterson 2011-05-19 18:56:48 -05:00
parent 9ac818fb5c
commit aae89cd021
2 changed files with 13 additions and 5 deletions

View File

@ -99,19 +99,22 @@ class AssertionRewriter(ast.NodeVisitor):
node = nodes.popleft() node = nodes.popleft()
for name, field in ast.iter_fields(node): for name, field in ast.iter_fields(node):
if isinstance(field, list): if isinstance(field, list):
new = []
for i, child in enumerate(field): for i, child in enumerate(field):
if isinstance(child, ast.Assert): if isinstance(child, ast.Assert):
# Transform assert.
new.extend(self.visit(child))
asserts.append((field, i, child)) asserts.append((field, i, child))
elif isinstance(child, ast.AST): else:
new.append(child)
if isinstance(child, ast.AST):
nodes.append(child) nodes.append(child)
setattr(node, name, new)
elif (isinstance(field, ast.AST) and elif (isinstance(field, ast.AST) and
# Don't recurse into expressions as they can't contain # Don't recurse into expressions as they can't contain
# asserts. # asserts.
not isinstance(field, ast.expr)): not isinstance(field, ast.expr)):
nodes.append(field) nodes.append(field)
# Transform asserts.
for parent, pos, assert_ in asserts:
parent[pos:pos + 1] = self.visit(assert_)
def variable(self): def variable(self):
"""Get a new variable.""" """Get a new variable."""

View File

@ -196,6 +196,11 @@ class TestAssertionRewrite:
a, b, c = range(3) a, b, c = range(3)
assert a < b <= c assert a < b <= c
getmsg(f, must_pass=True) getmsg(f, must_pass=True)
def f():
a, b, c = range(3)
assert a < b
assert b < c
getmsg(f, must_pass=True)
def test_len(self): def test_len(self):
def f(): def f():