diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 2f011bba8..61f2ebf25 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -657,6 +657,7 @@ class AssertionRewriter(ast.NodeVisitor): else: self.enable_assertion_pass_hook = False self.source = source + self.overwrite: Dict[str, str] = {} def run(self, mod: ast.Module) -> None: """Find all assert statements in *mod* and rewrite them.""" @@ -943,8 +944,8 @@ class AssertionRewriter(ast.NodeVisitor): return self.statements def visit_NamedExpr(self, name: namedExpr) -> Tuple[namedExpr, str]: - # Display the repr of the target name if it's a local variable or - # _should_repr_global_name() thinks it's acceptable. + # This method handles the 'walrus operator' repr of the target + # name if it's a local variable or _should_repr_global_name() thinks it's acceptable. locs = ast.Call(self.builtin("locals"), [], []) target_id = name.target.id # type: ignore[attr-defined] inlocs = ast.Compare(ast.Str(target_id), [ast.In()], [locs]) @@ -973,12 +974,32 @@ class AssertionRewriter(ast.NodeVisitor): levels = len(boolop.values) - 1 self.push_format_context() # Process each operand, short-circuiting if needed. + pytest_temp = None for i, v in enumerate(boolop.values): if i: fail_inner: List[ast.stmt] = [] # cond is set in a prior loop iteration below self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa self.expl_stmts = fail_inner + if isinstance(v, ast.Compare): + if isinstance(v.left, ast.NamedExpr) and ( + v.left.target.id + in [ + ast_expr.id + for ast_expr in boolop.values[:i] + if hasattr(ast_expr, "id") + ] + or v.left.target.id == pytest_temp + ): + pytest_temp = f"pytest_{v.left.target.id}_temp" + self.overwrite[v.left.target.id] = pytest_temp + v.left.target.id = pytest_temp + + elif isinstance(v.left, ast.Name) and ( + pytest_temp is not None + and v.left.id == pytest_temp.lstrip("pytest_").rstrip("_temp") + ): + v.left.id = pytest_temp self.push_format_context() res, expl = self.visit(v) body.append(ast.Assign([ast.Name(res_var, ast.Store())], res)) @@ -1054,6 +1075,8 @@ class AssertionRewriter(ast.NodeVisitor): def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: self.push_format_context() + if isinstance(comp.left, ast.Name) and comp.left.id in self.overwrite: + comp.left.id = self.overwrite[comp.left.id] left_res, left_expl = self.visit(comp.left) if isinstance(comp.left, (ast.Compare, ast.BoolOp)): left_expl = f"({left_expl})" @@ -1065,6 +1088,13 @@ class AssertionRewriter(ast.NodeVisitor): syms = [] results = [left_res] for i, op, next_operand in it: + if ( + isinstance(next_operand, ast.NamedExpr) + and isinstance(left_res, ast.Name) + and next_operand.target.id == left_res.id + ): + next_operand.target.id = f"pytest_{left_res.id}_temp" + self.overwrite[left_res.id] = next_operand.target.id next_res, next_expl = self.visit(next_operand) if isinstance(next_operand, (ast.Compare, ast.BoolOp)): next_expl = f"({next_expl})" diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 624947d74..fcede242f 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1374,8 +1374,7 @@ class TestIssue10743: ) result = pytester.runpytest() assert result.ret == 1 - # This is not optimal as error message but it depends on how the rewrite is structured - result.stdout.fnmatch_lines(["*AssertionError: assert 'hello' == 'hello'"]) + result.stdout.fnmatch_lines(["*AssertionError: assert 'Hello' == 'hello'"]) def test_assertion_walrus_operator_boolean_composite( self, pytester: Pytester @@ -1385,7 +1384,7 @@ class TestIssue10743: def test_walrus_operator_change_boolean_value(): a = True assert a and True and ((a := False) is False) and (a is False) and ((a := None) is None) - + assert a is None """ ) result = pytester.runpytest() @@ -1403,8 +1402,7 @@ class TestIssue10743: ) result = pytester.runpytest() assert result.ret == 1 - # This is not optimal as error message but it depends on how the rewrite is structured - result.stdout.fnmatch_lines(["*assert not (False)"]) + result.stdout.fnmatch_lines(["*assert not (True and False is False)"]) def test_assertion_walrus_operator_boolean_none_fails( self, pytester: Pytester @@ -1418,8 +1416,7 @@ class TestIssue10743: ) result = pytester.runpytest() assert result.ret == 1 - # This is not optimal as error message but it depends on how the rewrite is structured - result.stdout.fnmatch_lines(["*assert not (None)"]) + result.stdout.fnmatch_lines(["*assert not (True and None is None)"]) @pytest.mark.skipif(