Introduce pytest.register_assert_rewrite()
Plugins can now explicitly mark modules to be re-written. By default only the modules containing the plugin entrypoint are re-written.
This commit is contained in:
		
							parent
							
								
									944da5b98a
								
							
						
					
					
						commit
						743f59afb2
					
				| 
						 | 
				
			
			@ -7,6 +7,7 @@ import sys
 | 
			
		|||
 | 
			
		||||
from _pytest.monkeypatch import monkeypatch
 | 
			
		||||
from _pytest.assertion import util
 | 
			
		||||
from _pytest.assertion import rewrite
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def pytest_addoption(parser):
 | 
			
		||||
| 
						 | 
				
			
			@ -26,6 +27,34 @@ def pytest_addoption(parser):
 | 
			
		|||
                            provide assert expression information. """)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def pytest_namespace():
 | 
			
		||||
    return {'register_assert_rewrite': register_assert_rewrite}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def register_assert_rewrite(*names):
 | 
			
		||||
    """Register a module name to be rewritten on import.
 | 
			
		||||
 | 
			
		||||
    This function will make sure that the module will get it's assert
 | 
			
		||||
    statements rewritten when it is imported.  Thus you should make
 | 
			
		||||
    sure to call this before the module is actually imported, usually
 | 
			
		||||
    in your __init__.py if you are a plugin using a package.
 | 
			
		||||
    """
 | 
			
		||||
    for hook in sys.meta_path:
 | 
			
		||||
        if isinstance(hook, rewrite.AssertionRewritingHook):
 | 
			
		||||
            importhook = hook
 | 
			
		||||
            break
 | 
			
		||||
    else:
 | 
			
		||||
        importhook = DummyRewriteHook()
 | 
			
		||||
    importhook.mark_rewrite(*names)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DummyRewriteHook(object):
 | 
			
		||||
    """A no-op import hook for when rewriting is disabled."""
 | 
			
		||||
 | 
			
		||||
    def mark_rewrite(self, *names):
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AssertionState:
 | 
			
		||||
    """State for the assertion plugin."""
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -163,8 +163,8 @@ class AssertionRewritingHook(object):
 | 
			
		|||
                    self.session = session
 | 
			
		||||
                    del session
 | 
			
		||||
        else:
 | 
			
		||||
            toplevel_name = name.split('.', 1)[0]
 | 
			
		||||
            if toplevel_name in self._must_rewrite:
 | 
			
		||||
            for marked in self._must_rewrite:
 | 
			
		||||
                if marked.startswith(name):
 | 
			
		||||
                    return True
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -11,6 +11,7 @@ import py
 | 
			
		|||
import sys, os
 | 
			
		||||
import _pytest._code
 | 
			
		||||
import _pytest.hookspec  # the extension point definitions
 | 
			
		||||
import _pytest.assertion
 | 
			
		||||
from _pytest._pluggy import PluginManager, HookimplMarker, HookspecMarker
 | 
			
		||||
 | 
			
		||||
hookimpl = HookimplMarker("pytest")
 | 
			
		||||
| 
						 | 
				
			
			@ -154,6 +155,9 @@ class PytestPluginManager(PluginManager):
 | 
			
		|||
            self.trace.root.setwriter(err.write)
 | 
			
		||||
            self.enable_tracing()
 | 
			
		||||
 | 
			
		||||
        # Config._consider_importhook will set a real object if required.
 | 
			
		||||
        self.rewrite_hook = _pytest.assertion.DummyRewriteHook()
 | 
			
		||||
 | 
			
		||||
    def addhooks(self, module_or_class):
 | 
			
		||||
        """
 | 
			
		||||
        .. deprecated:: 2.8
 | 
			
		||||
| 
						 | 
				
			
			@ -362,7 +366,9 @@ class PytestPluginManager(PluginManager):
 | 
			
		|||
        self._import_plugin_specs(os.environ.get("PYTEST_PLUGINS"))
 | 
			
		||||
 | 
			
		||||
    def consider_module(self, mod):
 | 
			
		||||
        self._import_plugin_specs(getattr(mod, "pytest_plugins", None))
 | 
			
		||||
        plugins = getattr(mod, 'pytest_plugins', [])
 | 
			
		||||
        self.rewrite_hook.mark_rewrite(*plugins)
 | 
			
		||||
        self._import_plugin_specs(plugins)
 | 
			
		||||
 | 
			
		||||
    def _import_plugin_specs(self, spec):
 | 
			
		||||
        if spec:
 | 
			
		||||
| 
						 | 
				
			
			@ -926,15 +932,13 @@ class Config(object):
 | 
			
		|||
        and find all the installed plugins to mark them for re-writing
 | 
			
		||||
        by the importhook.
 | 
			
		||||
        """
 | 
			
		||||
        import _pytest.assertion
 | 
			
		||||
        ns, unknown_args = self._parser.parse_known_and_unknown_args(args)
 | 
			
		||||
        mode = ns.assertmode
 | 
			
		||||
        if ns.noassert or ns.nomagic:
 | 
			
		||||
            mode = "plain"
 | 
			
		||||
        self._warn_about_missing_assertion(mode)
 | 
			
		||||
        if mode != 'plain':
 | 
			
		||||
            hook = _pytest.assertion.install_importhook(self, mode)
 | 
			
		||||
            if hook:
 | 
			
		||||
                self.pluginmanager.rewrite_hook = hook
 | 
			
		||||
                for entrypoint in pkg_resources.iter_entry_points('pytest11'):
 | 
			
		||||
                    for entry in entrypoint.dist._get_metadata('RECORD'):
 | 
			
		||||
                        fn = entry.split(',')[0]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -63,12 +63,44 @@ class TestImportHookInstallation:
 | 
			
		|||
            assert 0
 | 
			
		||||
        result.stdout.fnmatch_lines([expected])
 | 
			
		||||
 | 
			
		||||
    @pytest.mark.parametrize('mode', ['plain', 'rewrite', 'reinterp'])
 | 
			
		||||
    def test_pytest_plugins_rewrite(self, testdir, mode):
 | 
			
		||||
        contents = {
 | 
			
		||||
            'conftest.py': """
 | 
			
		||||
                pytest_plugins = ['ham']
 | 
			
		||||
            """,
 | 
			
		||||
            'ham.py': """
 | 
			
		||||
                import pytest
 | 
			
		||||
                @pytest.fixture
 | 
			
		||||
                def check_first():
 | 
			
		||||
                    def check(values, value):
 | 
			
		||||
                        assert values.pop(0) == value
 | 
			
		||||
                    return check
 | 
			
		||||
            """,
 | 
			
		||||
            'test_foo.py': """
 | 
			
		||||
                def test_foo(check_first):
 | 
			
		||||
                    check_first([10, 30], 30)
 | 
			
		||||
            """,
 | 
			
		||||
        }
 | 
			
		||||
        testdir.makepyfile(**contents)
 | 
			
		||||
        result = testdir.runpytest_subprocess('--assert=%s' % mode)
 | 
			
		||||
        if mode == 'plain':
 | 
			
		||||
            expected = 'E       AssertionError'
 | 
			
		||||
        elif mode == 'rewrite':
 | 
			
		||||
            expected = '*assert 10 == 30*'
 | 
			
		||||
        elif mode == 'reinterp':
 | 
			
		||||
            expected = '*AssertionError:*was re-run*'
 | 
			
		||||
        else:
 | 
			
		||||
            assert 0
 | 
			
		||||
        result.stdout.fnmatch_lines([expected])
 | 
			
		||||
 | 
			
		||||
    @pytest.mark.parametrize('mode', ['plain', 'rewrite', 'reinterp'])
 | 
			
		||||
    def test_installed_plugin_rewrite(self, testdir, mode):
 | 
			
		||||
        # Make sure the hook is installed early enough so that plugins
 | 
			
		||||
        # installed via setuptools are re-written.
 | 
			
		||||
        ham = testdir.tmpdir.join('hampkg').ensure(dir=1)
 | 
			
		||||
        ham.join('__init__.py').write("""
 | 
			
		||||
        testdir.tmpdir.join('hampkg').ensure(dir=1)
 | 
			
		||||
        contents = {
 | 
			
		||||
            'hampkg/__init__.py': """
 | 
			
		||||
                import pytest
 | 
			
		||||
 | 
			
		||||
                @pytest.fixture
 | 
			
		||||
| 
						 | 
				
			
			@ -76,9 +108,8 @@ def check_first2():
 | 
			
		|||
                    def check(values, value):
 | 
			
		||||
                        assert values.pop(0) == value
 | 
			
		||||
                    return check
 | 
			
		||||
        """)
 | 
			
		||||
        testdir.makepyfile(
 | 
			
		||||
            spamplugin="""
 | 
			
		||||
            """,
 | 
			
		||||
            'spamplugin.py': """
 | 
			
		||||
            import pytest
 | 
			
		||||
            from hampkg import check_first2
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -88,7 +119,7 @@ def check_first2():
 | 
			
		|||
                    assert values.pop(0) == value
 | 
			
		||||
                return check
 | 
			
		||||
            """,
 | 
			
		||||
            mainwrapper="""
 | 
			
		||||
            'mainwrapper.py': """
 | 
			
		||||
            import pytest, pkg_resources
 | 
			
		||||
 | 
			
		||||
            class DummyDistInfo:
 | 
			
		||||
| 
						 | 
				
			
			@ -116,14 +147,15 @@ def check_first2():
 | 
			
		|||
            pkg_resources.iter_entry_points = iter_entry_points
 | 
			
		||||
            pytest.main()
 | 
			
		||||
            """,
 | 
			
		||||
            test_foo="""
 | 
			
		||||
            'test_foo.py': """
 | 
			
		||||
            def test(check_first):
 | 
			
		||||
                check_first([10, 30], 30)
 | 
			
		||||
 | 
			
		||||
            def test2(check_first2):
 | 
			
		||||
                check_first([10, 30], 30)
 | 
			
		||||
            """,
 | 
			
		||||
        )
 | 
			
		||||
        }
 | 
			
		||||
        testdir.makepyfile(**contents)
 | 
			
		||||
        result = testdir.run(sys.executable, 'mainwrapper.py', '-s', '--assert=%s' % mode)
 | 
			
		||||
        if mode == 'plain':
 | 
			
		||||
            expected = 'E       AssertionError'
 | 
			
		||||
| 
						 | 
				
			
			@ -135,6 +167,47 @@ def check_first2():
 | 
			
		|||
            assert 0
 | 
			
		||||
        result.stdout.fnmatch_lines([expected])
 | 
			
		||||
 | 
			
		||||
    def test_rewrite_ast(self, testdir):
 | 
			
		||||
        testdir.tmpdir.join('pkg').ensure(dir=1)
 | 
			
		||||
        contents = {
 | 
			
		||||
            'pkg/__init__.py': """
 | 
			
		||||
                import pytest
 | 
			
		||||
                pytest.register_assert_rewrite('pkg.helper')
 | 
			
		||||
            """,
 | 
			
		||||
            'pkg/helper.py': """
 | 
			
		||||
                def tool():
 | 
			
		||||
                    a, b = 2, 3
 | 
			
		||||
                    assert a == b
 | 
			
		||||
            """,
 | 
			
		||||
            'pkg/plugin.py': """
 | 
			
		||||
                import pytest, pkg.helper
 | 
			
		||||
                @pytest.fixture
 | 
			
		||||
                def tool():
 | 
			
		||||
                    return pkg.helper.tool
 | 
			
		||||
            """,
 | 
			
		||||
            'pkg/other.py': """
 | 
			
		||||
                l = [3, 2]
 | 
			
		||||
                def tool():
 | 
			
		||||
                    assert l.pop() == 3
 | 
			
		||||
            """,
 | 
			
		||||
            'conftest.py': """
 | 
			
		||||
                pytest_plugins = ['pkg.plugin']
 | 
			
		||||
            """,
 | 
			
		||||
            'test_pkg.py': """
 | 
			
		||||
                import pkg.other
 | 
			
		||||
                def test_tool(tool):
 | 
			
		||||
                    tool()
 | 
			
		||||
                def test_other():
 | 
			
		||||
                    pkg.other.tool()
 | 
			
		||||
            """,
 | 
			
		||||
        }
 | 
			
		||||
        testdir.makepyfile(**contents)
 | 
			
		||||
        result = testdir.runpytest_subprocess('--assert=rewrite')
 | 
			
		||||
        result.stdout.fnmatch_lines(['>*assert a == b*',
 | 
			
		||||
                                     'E*assert 2 == 3*',
 | 
			
		||||
                                     '>*assert l.pop() == 3*',
 | 
			
		||||
                                     'E*AssertionError*re-run*'])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestBinReprIntegration:
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue