diff --git a/_pytest/assertion/rewrite.py b/_pytest/assertion/rewrite.py index 6091186ca..27133fa02 100644 --- a/_pytest/assertion/rewrite.py +++ b/_pytest/assertion/rewrite.py @@ -17,13 +17,8 @@ def rewrite_asserts(mod): _saferepr = py.io.saferepr from _pytest.assertion.util import format_explanation as _format_explanation -def _format_boolop(operands, explanations, is_or): - show_explanations = [] - for operand, expl in zip(operands, explanations): - show_explanations.append(expl) - if operand == is_or: - break - return "(" + (is_or and " or " or " and ").join(show_explanations) + ")" +def _format_boolop(explanations, is_or): + return "(" + (is_or and " or " or " and ").join(explanations) + ")" def _call_reprcompare(ops, results, expls, each_obj): for i, res, expl in zip(range(len(ops)), results, expls): @@ -143,7 +138,7 @@ class AssertionRewriter(ast.NodeVisitor): """Get a new variable.""" # Use a character invalid in python identifiers to avoid clashing. name = "@py_assert" + str(next(self.variable_counter)) - self.variables.add(name) + self.variables[self.cond_chain].add(name) return name def assign(self, expr): @@ -172,6 +167,13 @@ class AssertionRewriter(ast.NodeVisitor): self.explanation_specifiers[specifier] = expr return "%(" + specifier + ")s" + def enter_cond(self, cond, body): + self.statements.append(ast.If(cond, body, [])) + self.cond_chain += cond, + + def leave_cond(self, n=1): + self.cond_chain = self.cond_chain[:-n] + def push_format_context(self): self.explanation_specifiers = {} self.stack.append(self.explanation_specifiers) @@ -198,7 +200,8 @@ class AssertionRewriter(ast.NodeVisitor): # There's already a message. Don't mess with it. return [assert_] self.statements = [] - self.variables = set() + self.cond_chain = () + self.variables = collections.defaultdict(set) self.variable_counter = itertools.count() self.stack = [] self.on_failure = [] @@ -220,11 +223,22 @@ class AssertionRewriter(ast.NodeVisitor): else: raise_ = ast.Raise(exc, None, None) body.append(raise_) - # Delete temporary variables. - names = [ast.Name(name, ast.Del()) for name in self.variables] - if names: - delete = ast.Delete(names) - self.statements.append(delete) + # Delete temporary variables. This requires a bit cleverness about the + # order, so we don't delete variables that are themselves conditions for + # later variables. + for chain in sorted(self.variables, key=len, reverse=True): + if chain: + where = [] + if len(chain) > 1: + cond = ast.Boolop(ast.And(), chain) + else: + cond = chain[0] + self.statements.append(ast.If(cond, where, [])) + else: + where = self.statements + v = self.variables[chain] + names = [ast.Name(name, ast.Del()) for name in v] + where.append(ast.Delete(names)) # Fix line numbers. for stmt in self.statements: set_location(stmt, assert_.lineno, assert_.col_offset) @@ -240,21 +254,32 @@ class AssertionRewriter(ast.NodeVisitor): return name, self.explanation_param(expr) def visit_BoolOp(self, boolop): - operands = [] - explanations = [] + res_var = self.variable() + expl_list = self.assign(ast.List([], ast.Load())) + app = ast.Attribute(expl_list, "append", ast.Load()) + is_or = isinstance(boolop.op, ast.Or) + body = save = self.statements + levels = len(boolop.values) - 1 self.push_format_context() - for operand in boolop.values: - res, explanation = self.visit(operand) - operands.append(res) - explanations.append(explanation) - expls = ast.Tuple([ast.Str(expl) for expl in explanations], ast.Load()) - is_or = ast.Num(isinstance(boolop.op, ast.Or)) - expl_template = self.helper("format_boolop", - ast.Tuple(operands, ast.Load()), expls, - is_or) + # Process each operand, short-circuting if needed. + for i, v in enumerate(boolop.values): + res, expl = self.visit(v) + body.append(ast.Assign([ast.Name(res_var, ast.Store())], res)) + call = ast.Call(app, [ast.Str(expl)], [], None, None) + body.append(ast.Expr(call)) + if i < levels: + inner = [] + cond = res + if is_or: + cond = ast.UnaryOp(ast.Not(), cond) + self.enter_cond(cond, inner) + self.statements = body = inner + # Leave all conditions. + self.leave_cond(levels) + self.statements = save + expl_template = self.helper("format_boolop", expl_list, ast.Num(is_or)) expl = self.pop_format_context(expl_template) - res = self.assign(ast.BoolOp(boolop.op, operands)) - return res, self.explanation_param(expl) + return ast.Name(res_var, ast.Load()), self.explanation_param(expl) def visit_UnaryOp(self, unary): pattern = unary_map[unary.op.__class__] diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index ca38ff2bf..62a256db5 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -128,6 +128,10 @@ class TestAssertionRewrite: f = g = False assert f or g assert getmsg(f) == "assert (False or False)" + def f(): + f = g = False + assert not f and not g + getmsg(f, must_pass=True) def f(): f = True g = False @@ -135,10 +139,13 @@ class TestAssertionRewrite: getmsg(f, must_pass=True) def test_short_circut_evaluation(self): - pytest.xfail("complicated fix; I'm not sure if it's important") def f(): assert True or explode getmsg(f, must_pass=True) + def f(): + x = 1 + assert x == 1 or x == 2 + getmsg(f, must_pass=True) def test_unary_op(self): def f():