diff --git a/_pytest/assertion/__init__.py b/_pytest/assertion/__init__.py index 746c810ee..271a2e7d5 100644 --- a/_pytest/assertion/__init__.py +++ b/_pytest/assertion/__init__.py @@ -7,6 +7,7 @@ import sys from _pytest.monkeypatch import monkeypatch from _pytest.assertion import util +from _pytest.assertion import rewrite def pytest_addoption(parser): @@ -26,6 +27,34 @@ def pytest_addoption(parser): provide assert expression information. """) +def pytest_namespace(): + return {'register_assert_rewrite': register_assert_rewrite} + + +def register_assert_rewrite(*names): + """Register a module name to be rewritten on import. + + This function will make sure that the module will get it's assert + statements rewritten when it is imported. Thus you should make + sure to call this before the module is actually imported, usually + in your __init__.py if you are a plugin using a package. + """ + for hook in sys.meta_path: + if isinstance(hook, rewrite.AssertionRewritingHook): + importhook = hook + break + else: + importhook = DummyRewriteHook() + importhook.mark_rewrite(*names) + + +class DummyRewriteHook(object): + """A no-op import hook for when rewriting is disabled.""" + + def mark_rewrite(self, *names): + pass + + class AssertionState: """State for the assertion plugin.""" diff --git a/_pytest/assertion/rewrite.py b/_pytest/assertion/rewrite.py index aa33f1352..50d8062ae 100644 --- a/_pytest/assertion/rewrite.py +++ b/_pytest/assertion/rewrite.py @@ -163,9 +163,9 @@ class AssertionRewritingHook(object): self.session = session del session else: - toplevel_name = name.split('.', 1)[0] - if toplevel_name in self._must_rewrite: - return True + for marked in self._must_rewrite: + if marked.startswith(name): + return True return False def mark_rewrite(self, *names): diff --git a/_pytest/config.py b/_pytest/config.py index 5ac120ab3..536c2fb34 100644 --- a/_pytest/config.py +++ b/_pytest/config.py @@ -11,6 +11,7 @@ import py import sys, os import _pytest._code import _pytest.hookspec # the extension point definitions +import _pytest.assertion from _pytest._pluggy import PluginManager, HookimplMarker, HookspecMarker hookimpl = HookimplMarker("pytest") @@ -154,6 +155,9 @@ class PytestPluginManager(PluginManager): self.trace.root.setwriter(err.write) self.enable_tracing() + # Config._consider_importhook will set a real object if required. + self.rewrite_hook = _pytest.assertion.DummyRewriteHook() + def addhooks(self, module_or_class): """ .. deprecated:: 2.8 @@ -362,7 +366,9 @@ class PytestPluginManager(PluginManager): self._import_plugin_specs(os.environ.get("PYTEST_PLUGINS")) def consider_module(self, mod): - self._import_plugin_specs(getattr(mod, "pytest_plugins", None)) + plugins = getattr(mod, 'pytest_plugins', []) + self.rewrite_hook.mark_rewrite(*plugins) + self._import_plugin_specs(plugins) def _import_plugin_specs(self, spec): if spec: @@ -926,15 +932,13 @@ class Config(object): and find all the installed plugins to mark them for re-writing by the importhook. """ - import _pytest.assertion ns, unknown_args = self._parser.parse_known_and_unknown_args(args) mode = ns.assertmode - if ns.noassert or ns.nomagic: - mode = "plain" self._warn_about_missing_assertion(mode) if mode != 'plain': hook = _pytest.assertion.install_importhook(self, mode) if hook: + self.pluginmanager.rewrite_hook = hook for entrypoint in pkg_resources.iter_entry_points('pytest11'): for entry in entrypoint.dist._get_metadata('RECORD'): fn = entry.split(',')[0] diff --git a/testing/test_assertion.py b/testing/test_assertion.py index 0346cb9a9..215d3e419 100644 --- a/testing/test_assertion.py +++ b/testing/test_assertion.py @@ -63,22 +63,53 @@ class TestImportHookInstallation: assert 0 result.stdout.fnmatch_lines([expected]) + @pytest.mark.parametrize('mode', ['plain', 'rewrite', 'reinterp']) + def test_pytest_plugins_rewrite(self, testdir, mode): + contents = { + 'conftest.py': """ + pytest_plugins = ['ham'] + """, + 'ham.py': """ + import pytest + @pytest.fixture + def check_first(): + def check(values, value): + assert values.pop(0) == value + return check + """, + 'test_foo.py': """ + def test_foo(check_first): + check_first([10, 30], 30) + """, + } + testdir.makepyfile(**contents) + result = testdir.runpytest_subprocess('--assert=%s' % mode) + if mode == 'plain': + expected = 'E AssertionError' + elif mode == 'rewrite': + expected = '*assert 10 == 30*' + elif mode == 'reinterp': + expected = '*AssertionError:*was re-run*' + else: + assert 0 + result.stdout.fnmatch_lines([expected]) + @pytest.mark.parametrize('mode', ['plain', 'rewrite', 'reinterp']) def test_installed_plugin_rewrite(self, testdir, mode): # Make sure the hook is installed early enough so that plugins # installed via setuptools are re-written. - ham = testdir.tmpdir.join('hampkg').ensure(dir=1) - ham.join('__init__.py').write(""" -import pytest + testdir.tmpdir.join('hampkg').ensure(dir=1) + contents = { + 'hampkg/__init__.py': """ + import pytest -@pytest.fixture -def check_first2(): - def check(values, value): - assert values.pop(0) == value - return check - """) - testdir.makepyfile( - spamplugin=""" + @pytest.fixture + def check_first2(): + def check(values, value): + assert values.pop(0) == value + return check + """, + 'spamplugin.py': """ import pytest from hampkg import check_first2 @@ -88,7 +119,7 @@ def check_first2(): assert values.pop(0) == value return check """, - mainwrapper=""" + 'mainwrapper.py': """ import pytest, pkg_resources class DummyDistInfo: @@ -116,14 +147,15 @@ def check_first2(): pkg_resources.iter_entry_points = iter_entry_points pytest.main() """, - test_foo=""" + 'test_foo.py': """ def test(check_first): check_first([10, 30], 30) def test2(check_first2): check_first([10, 30], 30) """, - ) + } + testdir.makepyfile(**contents) result = testdir.run(sys.executable, 'mainwrapper.py', '-s', '--assert=%s' % mode) if mode == 'plain': expected = 'E AssertionError' @@ -135,6 +167,47 @@ def check_first2(): assert 0 result.stdout.fnmatch_lines([expected]) + def test_rewrite_ast(self, testdir): + testdir.tmpdir.join('pkg').ensure(dir=1) + contents = { + 'pkg/__init__.py': """ + import pytest + pytest.register_assert_rewrite('pkg.helper') + """, + 'pkg/helper.py': """ + def tool(): + a, b = 2, 3 + assert a == b + """, + 'pkg/plugin.py': """ + import pytest, pkg.helper + @pytest.fixture + def tool(): + return pkg.helper.tool + """, + 'pkg/other.py': """ + l = [3, 2] + def tool(): + assert l.pop() == 3 + """, + 'conftest.py': """ + pytest_plugins = ['pkg.plugin'] + """, + 'test_pkg.py': """ + import pkg.other + def test_tool(tool): + tool() + def test_other(): + pkg.other.tool() + """, + } + testdir.makepyfile(**contents) + result = testdir.runpytest_subprocess('--assert=rewrite') + result.stdout.fnmatch_lines(['>*assert a == b*', + 'E*assert 2 == 3*', + '>*assert l.pop() == 3*', + 'E*AssertionError*re-run*']) + class TestBinReprIntegration: