diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index f661fe947..33e2ef6cc 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -685,8 +685,8 @@ class AssertionRewriter(ast.NodeVisitor): # Nothing to do. return - # Insert some special imports at the top of the module but after any - # docstrings and __future__ imports. + # We'll insert some special imports at the top of the module, but after any + # docstrings and __future__ imports, so first figure out where that is. doc = getattr(mod, "docstring", None) expect_docstring = doc is None if doc is not None and self.is_rewrite_disabled(doc): @@ -718,6 +718,7 @@ class AssertionRewriter(ast.NodeVisitor): lineno = item.decorator_list[0].lineno else: lineno = item.lineno + # Now actually insert the special imports. if sys.version_info >= (3, 10): aliases = [ ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0), @@ -737,6 +738,7 @@ class AssertionRewriter(ast.NodeVisitor): ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases ] mod.body[pos:pos] = imports + # Collect asserts. nodes: List[ast.AST] = [mod] while nodes: