From 3427d27d5a517a265502feba359df4f7c46ed611 Mon Sep 17 00:00:00 2001 From: Sviatoslav Abakumov Date: Wed, 25 Oct 2017 10:54:43 +0300 Subject: [PATCH 1/3] Try to get docstring from module node --- _pytest/assertion/rewrite.py | 11 ++++++++--- testing/test_assertrewrite.py | 34 +++++++++++++++++++++------------- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/_pytest/assertion/rewrite.py b/_pytest/assertion/rewrite.py index 992002b81..06687a0c8 100644 --- a/_pytest/assertion/rewrite.py +++ b/_pytest/assertion/rewrite.py @@ -595,15 +595,17 @@ class AssertionRewriter(ast.NodeVisitor): # docstrings and __future__ imports. aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"), ast.alias("_pytest.assertion.rewrite", "@pytest_ar")] - expect_docstring = True + doc = getattr(mod, "docstring", None) + expect_docstring = doc is None + if doc is not None and self.is_rewrite_disabled(doc): + return pos = 0 lineno = 0 for item in mod.body: if (expect_docstring and isinstance(item, ast.Expr) and isinstance(item.value, ast.Str)): doc = item.value.s - if "PYTEST_DONT_REWRITE" in doc: - # The module has disabled assertion rewriting. + if self.is_rewrite_disabled(doc): return lineno += len(doc) - 1 expect_docstring = False @@ -637,6 +639,9 @@ class AssertionRewriter(ast.NodeVisitor): not isinstance(field, ast.expr)): nodes.append(field) + def is_rewrite_disabled(self, docstring): + return "PYTEST_DONT_REWRITE" in docstring + def variable(self): """Get a new variable.""" # Use a character invalid in python identifiers to avoid clashing. diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 2d61b7440..c935a7862 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -65,13 +65,15 @@ class TestAssertionRewrite(object): def test_place_initial_imports(self): s = """'Doc string'\nother = stuff""" m = rewrite(s) - assert isinstance(m.body[0], ast.Expr) - assert isinstance(m.body[0].value, ast.Str) - for imp in m.body[1:3]: + if sys.version_info < (3, 7): + assert isinstance(m.body[0], ast.Expr) + assert isinstance(m.body[0].value, ast.Str) + del m.body[0] + for imp in m.body[0:2]: assert isinstance(imp, ast.Import) assert imp.lineno == 2 assert imp.col_offset == 0 - assert isinstance(m.body[3], ast.Assign) + assert isinstance(m.body[2], ast.Assign) s = """from __future__ import with_statement\nother_stuff""" m = rewrite(s) assert isinstance(m.body[0], ast.ImportFrom) @@ -82,14 +84,16 @@ class TestAssertionRewrite(object): assert isinstance(m.body[3], ast.Expr) s = """'doc string'\nfrom __future__ import with_statement\nother""" m = rewrite(s) - assert isinstance(m.body[0], ast.Expr) - assert isinstance(m.body[0].value, ast.Str) - assert isinstance(m.body[1], ast.ImportFrom) - for imp in m.body[2:4]: + if sys.version_info < (3, 7): + assert isinstance(m.body[0], ast.Expr) + assert isinstance(m.body[0].value, ast.Str) + del m.body[0] + assert isinstance(m.body[0], ast.ImportFrom) + for imp in m.body[1:3]: assert isinstance(imp, ast.Import) assert imp.lineno == 3 assert imp.col_offset == 0 - assert isinstance(m.body[4], ast.Expr) + assert isinstance(m.body[3], ast.Expr) s = """from . import relative\nother_stuff""" m = rewrite(s) for imp in m.body[0:2]: @@ -101,10 +105,14 @@ class TestAssertionRewrite(object): def test_dont_rewrite(self): s = """'PYTEST_DONT_REWRITE'\nassert 14""" m = rewrite(s) - assert len(m.body) == 2 - assert isinstance(m.body[0].value, ast.Str) - assert isinstance(m.body[1], ast.Assert) - assert m.body[1].msg is None + if sys.version_info < (3, 7): + assert len(m.body) == 2 + assert isinstance(m.body[0], ast.Expr) + assert isinstance(m.body[0].value, ast.Str) + del m.body[0] + else: + assert len(m.body) == 1 + assert m.body[0].msg is None def test_name(self): def f(): From fd7bfa30d0a6d2415f5669bdbe035dba1fcad5b2 Mon Sep 17 00:00:00 2001 From: Sviatoslav Abakumov Date: Wed, 25 Oct 2017 11:05:07 +0300 Subject: [PATCH 2/3] Put imports on the last line unless there are other exprs --- _pytest/assertion/rewrite.py | 5 +++-- testing/test_assertrewrite.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/_pytest/assertion/rewrite.py b/_pytest/assertion/rewrite.py index 06687a0c8..c3966340a 100644 --- a/_pytest/assertion/rewrite.py +++ b/_pytest/assertion/rewrite.py @@ -600,20 +600,21 @@ class AssertionRewriter(ast.NodeVisitor): if doc is not None and self.is_rewrite_disabled(doc): return pos = 0 - lineno = 0 + lineno = 1 for item in mod.body: if (expect_docstring and isinstance(item, ast.Expr) and isinstance(item.value, ast.Str)): doc = item.value.s if self.is_rewrite_disabled(doc): return - lineno += len(doc) - 1 expect_docstring = False elif (not isinstance(item, ast.ImportFrom) or item.level > 0 or item.module != "__future__"): lineno = item.lineno break pos += 1 + else: + lineno = item.lineno imports = [ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases] mod.body[pos:pos] = imports diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index c935a7862..45c0c7b16 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -82,6 +82,17 @@ class TestAssertionRewrite(object): assert imp.lineno == 2 assert imp.col_offset == 0 assert isinstance(m.body[3], ast.Expr) + s = """'doc string'\nfrom __future__ import with_statement""" + m = rewrite(s) + if sys.version_info < (3, 7): + assert isinstance(m.body[0], ast.Expr) + assert isinstance(m.body[0].value, ast.Str) + del m.body[0] + assert isinstance(m.body[0], ast.ImportFrom) + for imp in m.body[1:3]: + assert isinstance(imp, ast.Import) + assert imp.lineno == 2 + assert imp.col_offset == 0 s = """'doc string'\nfrom __future__ import with_statement\nother""" m = rewrite(s) if sys.version_info < (3, 7): From 27bb2eceb43f9c2c3747cf3c0a2e999292d382fc Mon Sep 17 00:00:00 2001 From: Bruno Oliveira Date: Thu, 26 Oct 2017 20:15:05 -0200 Subject: [PATCH 3/3] Add comment about why we remove docstrings on test_assertrewrite As explained in pytest-dev/pytest#2870 --- testing/test_assertrewrite.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 45c0c7b16..02270e157 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -65,6 +65,9 @@ class TestAssertionRewrite(object): def test_place_initial_imports(self): s = """'Doc string'\nother = stuff""" m = rewrite(s) + # Module docstrings in 3.7 are part of Module node, it's not in the body + # so we remove it so the following body items have the same indexes on + # all Python versions if sys.version_info < (3, 7): assert isinstance(m.body[0], ast.Expr) assert isinstance(m.body[0].value, ast.Str)