From 52f818d2bec133cfdd77261e8415aeeed53811a3 Mon Sep 17 00:00:00 2001 From: Alessio Izzo Date: Wed, 1 Mar 2023 01:06:43 +0100 Subject: [PATCH] add test on variables that have been overwritten are cleared after each test --- src/_pytest/assertion/rewrite.py | 14 +++++++------- testing/test_assertrewrite.py | 17 +++++++++++++++++ 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 61f2ebf25..ded658ba2 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -657,7 +657,7 @@ class AssertionRewriter(ast.NodeVisitor): else: self.enable_assertion_pass_hook = False self.source = source - self.overwrite: Dict[str, str] = {} + self.variables_overwrite: Dict[str, str] = {} def run(self, mod: ast.Module) -> None: """Find all assert statements in *mod* and rewrite them.""" @@ -982,7 +982,7 @@ class AssertionRewriter(ast.NodeVisitor): self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa self.expl_stmts = fail_inner if isinstance(v, ast.Compare): - if isinstance(v.left, ast.NamedExpr) and ( + if isinstance(v.left, namedExpr) and ( v.left.target.id in [ ast_expr.id @@ -992,7 +992,7 @@ class AssertionRewriter(ast.NodeVisitor): or v.left.target.id == pytest_temp ): pytest_temp = f"pytest_{v.left.target.id}_temp" - self.overwrite[v.left.target.id] = pytest_temp + self.variables_overwrite[v.left.target.id] = pytest_temp v.left.target.id = pytest_temp elif isinstance(v.left, ast.Name) and ( @@ -1075,8 +1075,8 @@ class AssertionRewriter(ast.NodeVisitor): def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: self.push_format_context() - if isinstance(comp.left, ast.Name) and comp.left.id in self.overwrite: - comp.left.id = self.overwrite[comp.left.id] + if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite: + comp.left.id = self.variables_overwrite[comp.left.id] left_res, left_expl = self.visit(comp.left) if isinstance(comp.left, (ast.Compare, ast.BoolOp)): left_expl = f"({left_expl})" @@ -1089,12 +1089,12 @@ class AssertionRewriter(ast.NodeVisitor): results = [left_res] for i, op, next_operand in it: if ( - isinstance(next_operand, ast.NamedExpr) + isinstance(next_operand, namedExpr) and isinstance(left_res, ast.Name) and next_operand.target.id == left_res.id ): next_operand.target.id = f"pytest_{left_res.id}_temp" - self.overwrite[left_res.id] = next_operand.target.id + self.variables_overwrite[left_res.id] = next_operand.target.id next_res, next_expl = self.visit(next_operand) if isinstance(next_operand, (ast.Compare, ast.BoolOp)): next_expl = f"({next_expl})" diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index fcede242f..8d9441403 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1418,6 +1418,23 @@ class TestIssue10743: assert result.ret == 1 result.stdout.fnmatch_lines(["*assert not (True and None is None)"]) + def test_assertion_walrus_operator_value_changes_cleared_after_each_test( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + def test_walrus_operator_change_value(): + a = True + assert (a := None) is None + + def test_walrus_operator_not_override_value(): + a = True + assert a is True + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + @pytest.mark.skipif( sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems"