add test on variables that have been overwritten are cleared after each test

This commit is contained in:
Alessio Izzo 2023-03-01 01:06:43 +01:00
parent 78f3a963cc
commit 52f818d2be
No known key found for this signature in database
GPG Key ID: 2B5983EE2D924936
2 changed files with 24 additions and 7 deletions

View File

@ -657,7 +657,7 @@ class AssertionRewriter(ast.NodeVisitor):
else:
self.enable_assertion_pass_hook = False
self.source = source
self.overwrite: Dict[str, str] = {}
self.variables_overwrite: Dict[str, str] = {}
def run(self, mod: ast.Module) -> None:
"""Find all assert statements in *mod* and rewrite them."""
@ -982,7 +982,7 @@ class AssertionRewriter(ast.NodeVisitor):
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
self.expl_stmts = fail_inner
if isinstance(v, ast.Compare):
if isinstance(v.left, ast.NamedExpr) and (
if isinstance(v.left, namedExpr) and (
v.left.target.id
in [
ast_expr.id
@ -992,7 +992,7 @@ class AssertionRewriter(ast.NodeVisitor):
or v.left.target.id == pytest_temp
):
pytest_temp = f"pytest_{v.left.target.id}_temp"
self.overwrite[v.left.target.id] = pytest_temp
self.variables_overwrite[v.left.target.id] = pytest_temp
v.left.target.id = pytest_temp
elif isinstance(v.left, ast.Name) and (
@ -1075,8 +1075,8 @@ class AssertionRewriter(ast.NodeVisitor):
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
self.push_format_context()
if isinstance(comp.left, ast.Name) and comp.left.id in self.overwrite:
comp.left.id = self.overwrite[comp.left.id]
if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite:
comp.left.id = self.variables_overwrite[comp.left.id]
left_res, left_expl = self.visit(comp.left)
if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
left_expl = f"({left_expl})"
@ -1089,12 +1089,12 @@ class AssertionRewriter(ast.NodeVisitor):
results = [left_res]
for i, op, next_operand in it:
if (
isinstance(next_operand, ast.NamedExpr)
isinstance(next_operand, namedExpr)
and isinstance(left_res, ast.Name)
and next_operand.target.id == left_res.id
):
next_operand.target.id = f"pytest_{left_res.id}_temp"
self.overwrite[left_res.id] = next_operand.target.id
self.variables_overwrite[left_res.id] = next_operand.target.id
next_res, next_expl = self.visit(next_operand)
if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
next_expl = f"({next_expl})"

View File

@ -1418,6 +1418,23 @@ class TestIssue10743:
assert result.ret == 1
result.stdout.fnmatch_lines(["*assert not (True and None is None)"])
def test_assertion_walrus_operator_value_changes_cleared_after_each_test(
self, pytester: Pytester
) -> None:
pytester.makepyfile(
"""
def test_walrus_operator_change_value():
a = True
assert (a := None) is None
def test_walrus_operator_not_override_value():
a = True
assert a is True
"""
)
result = pytester.runpytest()
assert result.ret == 0
@pytest.mark.skipif(
sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems"