From 9c4f6791e5a545c2fc53f68ef9e3ce031fa6842b Mon Sep 17 00:00:00 2001 From: Benjamin Peterson Date: Tue, 24 May 2011 17:21:58 -0500 Subject: [PATCH] give initial imports a reasonable lineno --- _pytest/assertrewrite.py | 7 +++++-- testing/test_assertrewrite.py | 6 ++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/_pytest/assertrewrite.py b/_pytest/assertrewrite.py index f30e817d9..be49b2266 100644 --- a/_pytest/assertrewrite.py +++ b/_pytest/assertrewrite.py @@ -92,18 +92,21 @@ class AssertionRewriter(ast.NodeVisitor): aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"), ast.alias("py", "@pylib"), ast.alias("_pytest.assertrewrite", "@pytest_ar")] - imports = [ast.Import([alias], lineno=0, col_offset=0) - for alias in aliases] expect_docstring = True pos = 0 + lineno = 0 for item in mod.body: if (expect_docstring and isinstance(item, ast.Expr) and isinstance(item.value, ast.Str)): + lineno += len(item.value.s.splitlines()) - 1 expect_docstring = False elif (not isinstance(item, ast.ImportFrom) or item.level > 0 and item.identifier != "__future__"): + lineno = item.lineno break pos += 1 + imports = [ast.Import([alias], lineno=lineno, col_offset=0) + for alias in aliases] mod.body[pos:pos] = imports # Collect asserts. nodes = collections.deque([mod]) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index f6b74d97e..a3d831b22 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -54,12 +54,16 @@ class TestAssertionRewrite: assert isinstance(m.body[0].value, ast.Str) for imp in m.body[1:4]: assert isinstance(imp, ast.Import) + assert imp.lineno == 2 + assert imp.col_offset == 0 assert isinstance(m.body[4], ast.Assign) s = """from __future__ import with_statement\nother_stuff""" m = rewrite(s) assert isinstance(m.body[0], ast.ImportFrom) for imp in m.body[1:4]: assert isinstance(imp, ast.Import) + assert imp.lineno == 2 + assert imp.col_offset == 0 assert isinstance(m.body[4], ast.Expr) s = """'doc string'\nfrom __future__ import with_statement\nother""" m = rewrite(s) @@ -68,6 +72,8 @@ class TestAssertionRewrite: assert isinstance(m.body[1], ast.ImportFrom) for imp in m.body[2:5]: assert isinstance(imp, ast.Import) + assert imp.lineno == 3 + assert imp.col_offset == 0 assert isinstance(m.body[5], ast.Expr) def test_name(self):