add visit_NamedExpr in assertrewrite

This commit is contained in:
Alessio Izzo 2023-02-21 00:02:15 +01:00
parent 9ccae9a8e3
commit 948f5e83ea
No known key found for this signature in database
GPG Key ID: 2B5983EE2D924936
2 changed files with 168 additions and 2 deletions

View File

@ -666,7 +666,7 @@ class AssertionRewriter(ast.NodeVisitor):
if doc is not None and self.is_rewrite_disabled(doc):
return
pos = 0
lineno = 1
item = None
for item in mod.body:
if (
expect_docstring
@ -937,6 +937,17 @@ class AssertionRewriter(ast.NodeVisitor):
ast.copy_location(node, assert_)
return self.statements
def visit_NamedExpr(self, name: ast.NamedExpr) -> Tuple[ast.NamedExpr, str]:
# Display the repr of the target name if it's a local variable or
# _should_repr_global_name() thinks it's acceptable.
locs = ast.Call(self.builtin("locals"), [], [])
target_id = name.target.id # type: ignore[attr-defined]
inlocs = ast.Compare(ast.Str(target_id), [ast.In()], [locs])
dorepr = self.helper("_should_repr_global_name", name)
test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
expr = ast.IfExp(test, self.display(name), ast.Str(target_id))
return name, self.explanation_param(expr)
def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
# Display the repr of the name if it's a local variable or
# _should_repr_global_name() thinks it's acceptable.
@ -1050,7 +1061,7 @@ class AssertionRewriter(ast.NodeVisitor):
results = [left_res]
for i, op, next_operand in it:
next_res, next_expl = self.visit(next_operand)
if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
if isinstance(next_operand, (ast.Compare, ast.BoolOp, ast.NamedExpr)):
next_expl = f"({next_expl})"
results.append(next_res)
sym = BINOP_MAP[op.__class__]
@ -1072,6 +1083,7 @@ class AssertionRewriter(ast.NodeVisitor):
res: ast.expr = ast.BoolOp(ast.And(), load_names)
else:
res = load_names[0]
return res, self.explanation_param(self.pop_format_context(expl_call))

View File

@ -1265,6 +1265,160 @@ class TestIssue2121:
result.stdout.fnmatch_lines(["*E*assert (1 + 1) == 3"])
class TestIssue10743:
def test_assertion_walrus_operator(self, pytester: Pytester) -> None:
pytester.makepyfile(
"""
def my_func(before, after):
return before == after
def change_value(value):
return value.lower()
def test_walrus_conversion():
a = "Hello"
assert not my_func(a, a := change_value(a))
assert a == "hello"
"""
)
result = pytester.runpytest()
assert result.ret == 0
def test_assertion_walrus_operator_dont_rewrite(self, pytester: Pytester) -> None:
pytester.makepyfile(
"""
'PYTEST_DONT_REWRITE'
def my_func(before, after):
return before == after
def change_value(value):
return value.lower()
def test_walrus_conversion_dont_rewrite():
a = "Hello"
assert not my_func(a, a := change_value(a))
assert a == "hello"
"""
)
result = pytester.runpytest()
assert result.ret == 0
def test_assertion_inline_walrus_operator(self, pytester: Pytester) -> None:
pytester.makepyfile(
"""
def my_func(before, after):
return before == after
def test_walrus_conversion_inline():
a = "Hello"
assert not my_func(a, a := a.lower())
assert a == "hello"
"""
)
result = pytester.runpytest()
assert result.ret == 0
def test_assertion_inline_walrus_operator_reverse(self, pytester: Pytester) -> None:
pytester.makepyfile(
"""
def my_func(before, after):
return before == after
def test_walrus_conversion_reverse():
a = "Hello"
assert my_func(a := a.lower(), a)
assert a == 'hello'
"""
)
result = pytester.runpytest()
assert result.ret == 0
def test_assertion_walrus_no_variable_name_conflict(
self, pytester: Pytester
) -> None:
pytester.makepyfile(
"""
def test_walrus_conversion_no_conflict():
a = "Hello"
assert a == (b := a.lower())
"""
)
result = pytester.runpytest()
assert result.ret == 1
result.stdout.fnmatch_lines(["*AssertionError: assert 'Hello' == 'hello'"])
def test_assertion_walrus_operator_true_assertion_and_changes_variable_value(
self, pytester: Pytester
) -> None:
pytester.makepyfile(
"""
def test_walrus_conversion_succeed():
a = "Hello"
assert a != (a := a.lower())
assert a == 'hello'
"""
)
result = pytester.runpytest()
assert result.ret == 0
def test_assertion_walrus_operator_fail_assertion(self, pytester: Pytester) -> None:
pytester.makepyfile(
"""
def test_walrus_conversion_fails():
a = "Hello"
assert a == (a := a.lower())
"""
)
result = pytester.runpytest()
assert result.ret == 1
# This is not optimal as error message but it depends on how the rewrite is structured
result.stdout.fnmatch_lines(["*AssertionError: assert 'hello' == 'hello'"])
def test_assertion_walrus_operator_boolean_composite(
self, pytester: Pytester
) -> None:
pytester.makepyfile(
"""
def test_walrus_operator_change_boolean_value():
a = True
assert a and True and ((a := False) is False) and (a is False) and ((a := None) is None)
"""
)
result = pytester.runpytest()
assert result.ret == 0
def test_assertion_walrus_operator_compare_boolean_fails(
self, pytester: Pytester
) -> None:
pytester.makepyfile(
"""
def test_walrus_operator_change_boolean_value():
a = True
assert not (a and ((a := False) is False))
"""
)
result = pytester.runpytest()
assert result.ret == 1
# This is not optimal as error message but it depends on how the rewrite is structured
result.stdout.fnmatch_lines(["*assert not (False)"])
def test_assertion_walrus_operator_boolean_none_fails(
self, pytester: Pytester
) -> None:
pytester.makepyfile(
"""
def test_walrus_operator_change_boolean_value():
a = True
assert not (a and ((a := None) is None))
"""
)
result = pytester.runpytest()
assert result.ret == 1
# This is not optimal as error message but it depends on how the rewrite is structured
result.stdout.fnmatch_lines(["*assert not (None)"])
@pytest.mark.skipif(
sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems"
)