diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 42664add4..548a3fce3 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -666,7 +666,7 @@ class AssertionRewriter(ast.NodeVisitor): if doc is not None and self.is_rewrite_disabled(doc): return pos = 0 - lineno = 1 + item = None for item in mod.body: if ( expect_docstring @@ -937,6 +937,17 @@ class AssertionRewriter(ast.NodeVisitor): ast.copy_location(node, assert_) return self.statements + def visit_NamedExpr(self, name: ast.NamedExpr) -> Tuple[ast.NamedExpr, str]: + # Display the repr of the target name if it's a local variable or + # _should_repr_global_name() thinks it's acceptable. + locs = ast.Call(self.builtin("locals"), [], []) + target_id = name.target.id # type: ignore[attr-defined] + inlocs = ast.Compare(ast.Str(target_id), [ast.In()], [locs]) + dorepr = self.helper("_should_repr_global_name", name) + test = ast.BoolOp(ast.Or(), [inlocs, dorepr]) + expr = ast.IfExp(test, self.display(name), ast.Str(target_id)) + return name, self.explanation_param(expr) + def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]: # Display the repr of the name if it's a local variable or # _should_repr_global_name() thinks it's acceptable. @@ -1050,7 +1061,7 @@ class AssertionRewriter(ast.NodeVisitor): results = [left_res] for i, op, next_operand in it: next_res, next_expl = self.visit(next_operand) - if isinstance(next_operand, (ast.Compare, ast.BoolOp)): + if isinstance(next_operand, (ast.Compare, ast.BoolOp, ast.NamedExpr)): next_expl = f"({next_expl})" results.append(next_res) sym = BINOP_MAP[op.__class__] @@ -1072,6 +1083,7 @@ class AssertionRewriter(ast.NodeVisitor): res: ast.expr = ast.BoolOp(ast.And(), load_names) else: res = load_names[0] + return res, self.explanation_param(self.pop_format_context(expl_call)) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 3c98392ed..89f2f8f6e 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1265,6 +1265,160 @@ class TestIssue2121: result.stdout.fnmatch_lines(["*E*assert (1 + 1) == 3"]) +class TestIssue10743: + def test_assertion_walrus_operator(self, pytester: Pytester) -> None: + pytester.makepyfile( + """ + def my_func(before, after): + return before == after + + def change_value(value): + return value.lower() + + def test_walrus_conversion(): + a = "Hello" + assert not my_func(a, a := change_value(a)) + assert a == "hello" + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_assertion_walrus_operator_dont_rewrite(self, pytester: Pytester) -> None: + pytester.makepyfile( + """ + 'PYTEST_DONT_REWRITE' + def my_func(before, after): + return before == after + + def change_value(value): + return value.lower() + + def test_walrus_conversion_dont_rewrite(): + a = "Hello" + assert not my_func(a, a := change_value(a)) + assert a == "hello" + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_assertion_inline_walrus_operator(self, pytester: Pytester) -> None: + pytester.makepyfile( + """ + def my_func(before, after): + return before == after + + def test_walrus_conversion_inline(): + a = "Hello" + assert not my_func(a, a := a.lower()) + assert a == "hello" + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_assertion_inline_walrus_operator_reverse(self, pytester: Pytester) -> None: + pytester.makepyfile( + """ + def my_func(before, after): + return before == after + + def test_walrus_conversion_reverse(): + a = "Hello" + assert my_func(a := a.lower(), a) + assert a == 'hello' + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_assertion_walrus_no_variable_name_conflict( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + def test_walrus_conversion_no_conflict(): + a = "Hello" + assert a == (b := a.lower()) + """ + ) + result = pytester.runpytest() + assert result.ret == 1 + result.stdout.fnmatch_lines(["*AssertionError: assert 'Hello' == 'hello'"]) + + def test_assertion_walrus_operator_true_assertion_and_changes_variable_value( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + def test_walrus_conversion_succeed(): + a = "Hello" + assert a != (a := a.lower()) + assert a == 'hello' + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_assertion_walrus_operator_fail_assertion(self, pytester: Pytester) -> None: + pytester.makepyfile( + """ + def test_walrus_conversion_fails(): + a = "Hello" + assert a == (a := a.lower()) + """ + ) + result = pytester.runpytest() + assert result.ret == 1 + # This is not optimal as error message but it depends on how the rewrite is structured + result.stdout.fnmatch_lines(["*AssertionError: assert 'hello' == 'hello'"]) + + def test_assertion_walrus_operator_boolean_composite( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + def test_walrus_operator_change_boolean_value(): + a = True + assert a and True and ((a := False) is False) and (a is False) and ((a := None) is None) + + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_assertion_walrus_operator_compare_boolean_fails( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + def test_walrus_operator_change_boolean_value(): + a = True + assert not (a and ((a := False) is False)) + """ + ) + result = pytester.runpytest() + assert result.ret == 1 + # This is not optimal as error message but it depends on how the rewrite is structured + result.stdout.fnmatch_lines(["*assert not (False)"]) + + def test_assertion_walrus_operator_boolean_none_fails( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + def test_walrus_operator_change_boolean_value(): + a = True + assert not (a and ((a := None) is None)) + """ + ) + result = pytester.runpytest() + assert result.ret == 1 + # This is not optimal as error message but it depends on how the rewrite is structured + result.stdout.fnmatch_lines(["*assert not (None)"]) + + @pytest.mark.skipif( sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems" )