add check and replace on already visited variables that have been changing by the walrus operator
This commit is contained in:
parent
3e90bf573f
commit
78f3a963cc
|
@ -657,6 +657,7 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
else:
|
else:
|
||||||
self.enable_assertion_pass_hook = False
|
self.enable_assertion_pass_hook = False
|
||||||
self.source = source
|
self.source = source
|
||||||
|
self.overwrite: Dict[str, str] = {}
|
||||||
|
|
||||||
def run(self, mod: ast.Module) -> None:
|
def run(self, mod: ast.Module) -> None:
|
||||||
"""Find all assert statements in *mod* and rewrite them."""
|
"""Find all assert statements in *mod* and rewrite them."""
|
||||||
|
@ -943,8 +944,8 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
return self.statements
|
return self.statements
|
||||||
|
|
||||||
def visit_NamedExpr(self, name: namedExpr) -> Tuple[namedExpr, str]:
|
def visit_NamedExpr(self, name: namedExpr) -> Tuple[namedExpr, str]:
|
||||||
# Display the repr of the target name if it's a local variable or
|
# This method handles the 'walrus operator' repr of the target
|
||||||
# _should_repr_global_name() thinks it's acceptable.
|
# name if it's a local variable or _should_repr_global_name() thinks it's acceptable.
|
||||||
locs = ast.Call(self.builtin("locals"), [], [])
|
locs = ast.Call(self.builtin("locals"), [], [])
|
||||||
target_id = name.target.id # type: ignore[attr-defined]
|
target_id = name.target.id # type: ignore[attr-defined]
|
||||||
inlocs = ast.Compare(ast.Str(target_id), [ast.In()], [locs])
|
inlocs = ast.Compare(ast.Str(target_id), [ast.In()], [locs])
|
||||||
|
@ -973,12 +974,32 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
levels = len(boolop.values) - 1
|
levels = len(boolop.values) - 1
|
||||||
self.push_format_context()
|
self.push_format_context()
|
||||||
# Process each operand, short-circuiting if needed.
|
# Process each operand, short-circuiting if needed.
|
||||||
|
pytest_temp = None
|
||||||
for i, v in enumerate(boolop.values):
|
for i, v in enumerate(boolop.values):
|
||||||
if i:
|
if i:
|
||||||
fail_inner: List[ast.stmt] = []
|
fail_inner: List[ast.stmt] = []
|
||||||
# cond is set in a prior loop iteration below
|
# cond is set in a prior loop iteration below
|
||||||
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
|
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
|
||||||
self.expl_stmts = fail_inner
|
self.expl_stmts = fail_inner
|
||||||
|
if isinstance(v, ast.Compare):
|
||||||
|
if isinstance(v.left, ast.NamedExpr) and (
|
||||||
|
v.left.target.id
|
||||||
|
in [
|
||||||
|
ast_expr.id
|
||||||
|
for ast_expr in boolop.values[:i]
|
||||||
|
if hasattr(ast_expr, "id")
|
||||||
|
]
|
||||||
|
or v.left.target.id == pytest_temp
|
||||||
|
):
|
||||||
|
pytest_temp = f"pytest_{v.left.target.id}_temp"
|
||||||
|
self.overwrite[v.left.target.id] = pytest_temp
|
||||||
|
v.left.target.id = pytest_temp
|
||||||
|
|
||||||
|
elif isinstance(v.left, ast.Name) and (
|
||||||
|
pytest_temp is not None
|
||||||
|
and v.left.id == pytest_temp.lstrip("pytest_").rstrip("_temp")
|
||||||
|
):
|
||||||
|
v.left.id = pytest_temp
|
||||||
self.push_format_context()
|
self.push_format_context()
|
||||||
res, expl = self.visit(v)
|
res, expl = self.visit(v)
|
||||||
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
|
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
|
||||||
|
@ -1054,6 +1075,8 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
|
|
||||||
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
|
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
|
||||||
self.push_format_context()
|
self.push_format_context()
|
||||||
|
if isinstance(comp.left, ast.Name) and comp.left.id in self.overwrite:
|
||||||
|
comp.left.id = self.overwrite[comp.left.id]
|
||||||
left_res, left_expl = self.visit(comp.left)
|
left_res, left_expl = self.visit(comp.left)
|
||||||
if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
|
if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
|
||||||
left_expl = f"({left_expl})"
|
left_expl = f"({left_expl})"
|
||||||
|
@ -1065,6 +1088,13 @@ class AssertionRewriter(ast.NodeVisitor):
|
||||||
syms = []
|
syms = []
|
||||||
results = [left_res]
|
results = [left_res]
|
||||||
for i, op, next_operand in it:
|
for i, op, next_operand in it:
|
||||||
|
if (
|
||||||
|
isinstance(next_operand, ast.NamedExpr)
|
||||||
|
and isinstance(left_res, ast.Name)
|
||||||
|
and next_operand.target.id == left_res.id
|
||||||
|
):
|
||||||
|
next_operand.target.id = f"pytest_{left_res.id}_temp"
|
||||||
|
self.overwrite[left_res.id] = next_operand.target.id
|
||||||
next_res, next_expl = self.visit(next_operand)
|
next_res, next_expl = self.visit(next_operand)
|
||||||
if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
|
if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
|
||||||
next_expl = f"({next_expl})"
|
next_expl = f"({next_expl})"
|
||||||
|
|
|
@ -1374,8 +1374,7 @@ class TestIssue10743:
|
||||||
)
|
)
|
||||||
result = pytester.runpytest()
|
result = pytester.runpytest()
|
||||||
assert result.ret == 1
|
assert result.ret == 1
|
||||||
# This is not optimal as error message but it depends on how the rewrite is structured
|
result.stdout.fnmatch_lines(["*AssertionError: assert 'Hello' == 'hello'"])
|
||||||
result.stdout.fnmatch_lines(["*AssertionError: assert 'hello' == 'hello'"])
|
|
||||||
|
|
||||||
def test_assertion_walrus_operator_boolean_composite(
|
def test_assertion_walrus_operator_boolean_composite(
|
||||||
self, pytester: Pytester
|
self, pytester: Pytester
|
||||||
|
@ -1385,7 +1384,7 @@ class TestIssue10743:
|
||||||
def test_walrus_operator_change_boolean_value():
|
def test_walrus_operator_change_boolean_value():
|
||||||
a = True
|
a = True
|
||||||
assert a and True and ((a := False) is False) and (a is False) and ((a := None) is None)
|
assert a and True and ((a := False) is False) and (a is False) and ((a := None) is None)
|
||||||
|
assert a is None
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
result = pytester.runpytest()
|
result = pytester.runpytest()
|
||||||
|
@ -1403,8 +1402,7 @@ class TestIssue10743:
|
||||||
)
|
)
|
||||||
result = pytester.runpytest()
|
result = pytester.runpytest()
|
||||||
assert result.ret == 1
|
assert result.ret == 1
|
||||||
# This is not optimal as error message but it depends on how the rewrite is structured
|
result.stdout.fnmatch_lines(["*assert not (True and False is False)"])
|
||||||
result.stdout.fnmatch_lines(["*assert not (False)"])
|
|
||||||
|
|
||||||
def test_assertion_walrus_operator_boolean_none_fails(
|
def test_assertion_walrus_operator_boolean_none_fails(
|
||||||
self, pytester: Pytester
|
self, pytester: Pytester
|
||||||
|
@ -1418,8 +1416,7 @@ class TestIssue10743:
|
||||||
)
|
)
|
||||||
result = pytester.runpytest()
|
result = pytester.runpytest()
|
||||||
assert result.ret == 1
|
assert result.ret == 1
|
||||||
# This is not optimal as error message but it depends on how the rewrite is structured
|
result.stdout.fnmatch_lines(["*assert not (True and None is None)"])
|
||||||
result.stdout.fnmatch_lines(["*assert not (None)"])
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
|
|
Loading…
Reference in New Issue