From 5341b9cd67dad64d1d96dd0851861ff5f7f874dc Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Sat, 9 Sep 2023 14:09:31 +0200 Subject: [PATCH 1/2] Fix assert rewriting with assignment expressions (#11414) Fixes #11239 (cherry picked from commit 7259e8db9844f6f973c1d0c0ce46cc68c8248abb) --- AUTHORS | 1 + changelog/11239.bugfix.rst | 1 + src/_pytest/assertion/rewrite.py | 54 +++++++++++++++++++++++--------- testing/test_assertrewrite.py | 21 +++++++++++++ 4 files changed, 63 insertions(+), 14 deletions(-) create mode 100644 changelog/11239.bugfix.rst diff --git a/AUTHORS b/AUTHORS index e8456d92b..6d860575f 100644 --- a/AUTHORS +++ b/AUTHORS @@ -231,6 +231,7 @@ Maho Maik Figura Mandeep Bhutani Manuel Krebber +Marc Mueller Marc Schlaich Marcelo Duarte Trevisani Marcin Bachry diff --git a/changelog/11239.bugfix.rst b/changelog/11239.bugfix.rst new file mode 100644 index 000000000..a486224cd --- /dev/null +++ b/changelog/11239.bugfix.rst @@ -0,0 +1 @@ +Fixed ``:=`` in asserts impacting unrelated test cases. diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index ab83fee32..fd2355297 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -13,6 +13,7 @@ import struct import sys import tokenize import types +from collections import defaultdict from pathlib import Path from pathlib import PurePath from typing import Callable @@ -56,6 +57,10 @@ else: astNum = ast.Num +class Sentinel: + pass + + assertstate_key = StashKey["AssertionState"]() # pytest caches rewritten pycs in pycache dirs @@ -63,6 +68,9 @@ PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}" PYC_EXT = ".py" + (__debug__ and "c" or "o") PYC_TAIL = "." + PYTEST_TAG + PYC_EXT +# Special marker that denotes we have just left a scope definition +_SCOPE_END_MARKER = Sentinel() + class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader): """PEP302/PEP451 import hook which rewrites asserts.""" @@ -645,6 +653,8 @@ class AssertionRewriter(ast.NodeVisitor): .push_format_context() and .pop_format_context() which allows to build another %-formatted string while already building one. + :scope: A tuple containing the current scope used for variables_overwrite. + :variables_overwrite: A dict filled with references to variables that change value within an assert. This happens when a variable is reassigned with the walrus operator @@ -666,7 +676,10 @@ class AssertionRewriter(ast.NodeVisitor): else: self.enable_assertion_pass_hook = False self.source = source - self.variables_overwrite: Dict[str, str] = {} + self.scope: tuple[ast.AST, ...] = () + self.variables_overwrite: defaultdict[ + tuple[ast.AST, ...], Dict[str, str] + ] = defaultdict(dict) def run(self, mod: ast.Module) -> None: """Find all assert statements in *mod* and rewrite them.""" @@ -732,9 +745,17 @@ class AssertionRewriter(ast.NodeVisitor): mod.body[pos:pos] = imports # Collect asserts. - nodes: List[ast.AST] = [mod] + self.scope = (mod,) + nodes: List[Union[ast.AST, Sentinel]] = [mod] while nodes: node = nodes.pop() + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + self.scope = tuple((*self.scope, node)) + nodes.append(_SCOPE_END_MARKER) + if node == _SCOPE_END_MARKER: + self.scope = self.scope[:-1] + continue + assert isinstance(node, ast.AST) for name, field in ast.iter_fields(node): if isinstance(field, list): new: List[ast.AST] = [] @@ -1005,7 +1026,7 @@ class AssertionRewriter(ast.NodeVisitor): ] ): pytest_temp = self.variable() - self.variables_overwrite[ + self.variables_overwrite[self.scope][ v.left.target.id ] = v.left # type:ignore[assignment] v.left.target.id = pytest_temp @@ -1048,17 +1069,20 @@ class AssertionRewriter(ast.NodeVisitor): new_args = [] new_kwargs = [] for arg in call.args: - if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite: - arg = self.variables_overwrite[arg.id] # type:ignore[assignment] + if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite.get( + self.scope, {} + ): + arg = self.variables_overwrite[self.scope][ + arg.id + ] # type:ignore[assignment] res, expl = self.visit(arg) arg_expls.append(expl) new_args.append(res) for keyword in call.keywords: - if ( - isinstance(keyword.value, ast.Name) - and keyword.value.id in self.variables_overwrite - ): - keyword.value = self.variables_overwrite[ + if isinstance( + keyword.value, ast.Name + ) and keyword.value.id in self.variables_overwrite.get(self.scope, {}): + keyword.value = self.variables_overwrite[self.scope][ keyword.value.id ] # type:ignore[assignment] res, expl = self.visit(keyword.value) @@ -1094,12 +1118,14 @@ class AssertionRewriter(ast.NodeVisitor): def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: self.push_format_context() # We first check if we have overwritten a variable in the previous assert - if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite: - comp.left = self.variables_overwrite[ + if isinstance( + comp.left, ast.Name + ) and comp.left.id in self.variables_overwrite.get(self.scope, {}): + comp.left = self.variables_overwrite[self.scope][ comp.left.id ] # type:ignore[assignment] if isinstance(comp.left, namedExpr): - self.variables_overwrite[ + self.variables_overwrite[self.scope][ comp.left.target.id ] = comp.left # type:ignore[assignment] left_res, left_expl = self.visit(comp.left) @@ -1119,7 +1145,7 @@ class AssertionRewriter(ast.NodeVisitor): and next_operand.target.id == left_res.id ): next_operand.target.id = self.variable() - self.variables_overwrite[ + self.variables_overwrite[self.scope][ left_res.id ] = next_operand # type:ignore[assignment] next_res, next_expl = self.visit(next_operand) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 778f843e6..63353438c 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1531,6 +1531,27 @@ class TestIssue11028: result.stdout.fnmatch_lines(["*assert 4 > 5", "*where 5 = add_one(4)"]) +class TestIssue11239: + def test_assertion_walrus_different_test_cases(self, pytester: Pytester) -> None: + """Regression for (#11239) + + Walrus operator rewriting would leak to separate test cases if they used the same variables. + """ + pytester.makepyfile( + """ + def test_1(): + state = {"x": 2}.get("x") + assert state is not None + + def test_2(): + db = {"x": 2} + assert (state := db.get("x")) is not None + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + @pytest.mark.skipif( sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems" ) From 721a0881fb99c8f749701d1301cf19a948a9554f Mon Sep 17 00:00:00 2001 From: Bruno Oliveira Date: Sat, 9 Sep 2023 09:38:10 -0300 Subject: [PATCH 2/2] Skip test_assertion_walrus_different_test_cases on Python 3.7 --- testing/test_assertrewrite.py | 1 + 1 file changed, 1 insertion(+) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 63353438c..fbf285495 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1532,6 +1532,7 @@ class TestIssue11028: class TestIssue11239: + @pytest.mark.skipif(sys.version_info[:2] <= (3, 7), reason="Only Python 3.8+") def test_assertion_walrus_different_test_cases(self, pytester: Pytester) -> None: """Regression for (#11239)