Merge pull request #1641 from flub/rewrite-plugins

Rewrite plugins
This commit is contained in:
Floris Bruynooghe 2016-07-14 19:39:15 +01:00 committed by GitHub
commit 24fbbbef1f
8 changed files with 355 additions and 96 deletions

View File

@ -143,6 +143,9 @@ time or change existing behaviors in order to make them less surprising/more use
**Changes** **Changes**
* Plugins now benefit from assertion rewriting. Thanks
`@sober7`_, `@nicoddemus`_ and `@flub`_ for the PR.
* Fixtures marked with ``@pytest.fixture`` can now use ``yield`` statements exactly like * Fixtures marked with ``@pytest.fixture`` can now use ``yield`` statements exactly like
those marked with the ``@pytest.yield_fixture`` decorator. This change renders those marked with the ``@pytest.yield_fixture`` decorator. This change renders
``@pytest.yield_fixture`` deprecated and makes ``@pytest.fixture`` with ``yield`` statements ``@pytest.yield_fixture`` deprecated and makes ``@pytest.fixture`` with ``yield`` statements

View File

@ -5,9 +5,8 @@ import py
import os import os
import sys import sys
from _pytest.config import hookimpl
from _pytest.monkeypatch import MonkeyPatch
from _pytest.assertion import util from _pytest.assertion import util
from _pytest.assertion import rewrite
def pytest_addoption(parser): def pytest_addoption(parser):
@ -27,6 +26,34 @@ def pytest_addoption(parser):
provide assert expression information. """) provide assert expression information. """)
def pytest_namespace():
return {'register_assert_rewrite': register_assert_rewrite}
def register_assert_rewrite(*names):
"""Register a module name to be rewritten on import.
This function will make sure that the module will get it's assert
statements rewritten when it is imported. Thus you should make
sure to call this before the module is actually imported, usually
in your __init__.py if you are a plugin using a package.
"""
for hook in sys.meta_path:
if isinstance(hook, rewrite.AssertionRewritingHook):
importhook = hook
break
else:
importhook = DummyRewriteHook()
importhook.mark_rewrite(*names)
class DummyRewriteHook(object):
"""A no-op import hook for when rewriting is disabled."""
def mark_rewrite(self, *names):
pass
class AssertionState: class AssertionState:
"""State for the assertion plugin.""" """State for the assertion plugin."""
@ -35,10 +62,7 @@ class AssertionState:
self.trace = config.trace.root.get("assertion") self.trace = config.trace.root.get("assertion")
@hookimpl(tryfirst=True) def install_importhook(config, mode):
def pytest_load_initial_conftests(early_config, parser, args):
ns, ns_unknown_args = parser.parse_known_and_unknown_args(args)
mode = ns.assertmode
if mode == "rewrite": if mode == "rewrite":
try: try:
import ast # noqa import ast # noqa
@ -51,37 +75,38 @@ def pytest_load_initial_conftests(early_config, parser, args):
sys.version_info[:3] == (2, 6, 0)): sys.version_info[:3] == (2, 6, 0)):
mode = "reinterp" mode = "reinterp"
early_config._assertstate = AssertionState(early_config, mode) config._assertstate = AssertionState(config, mode)
warn_about_missing_assertion(mode, early_config.pluginmanager)
if mode != "plain": _load_modules(mode)
_load_modules(mode) from _pytest.monkeypatch import MonkeyPatch
m = MonkeyPatch() m = MonkeyPatch()
early_config._cleanup.append(m.undo) config._cleanup.append(m.undo)
m.setattr(py.builtin.builtins, 'AssertionError', m.setattr(py.builtin.builtins, 'AssertionError',
reinterpret.AssertionError) # noqa reinterpret.AssertionError) # noqa
hook = None hook = None
if mode == "rewrite": if mode == "rewrite":
hook = rewrite.AssertionRewritingHook(early_config) # noqa hook = rewrite.AssertionRewritingHook(config) # noqa
sys.meta_path.insert(0, hook) sys.meta_path.insert(0, hook)
early_config._assertstate.hook = hook config._assertstate.hook = hook
early_config._assertstate.trace("configured with mode set to %r" % (mode,)) config._assertstate.trace("configured with mode set to %r" % (mode,))
def undo(): def undo():
hook = early_config._assertstate.hook hook = config._assertstate.hook
if hook is not None and hook in sys.meta_path: if hook is not None and hook in sys.meta_path:
sys.meta_path.remove(hook) sys.meta_path.remove(hook)
early_config.add_cleanup(undo) config.add_cleanup(undo)
return hook
def pytest_collection(session): def pytest_collection(session):
# this hook is only called when test modules are collected # this hook is only called when test modules are collected
# so for example not in the master process of pytest-xdist # so for example not in the master process of pytest-xdist
# (which does not collect test modules) # (which does not collect test modules)
hook = session.config._assertstate.hook assertstate = getattr(session.config, '_assertstate', None)
if hook is not None: if assertstate:
hook.set_session(session) if assertstate.hook is not None:
assertstate.hook.set_session(session)
def _running_on_ci(): def _running_on_ci():
@ -138,9 +163,10 @@ def pytest_runtest_teardown(item):
def pytest_sessionfinish(session): def pytest_sessionfinish(session):
hook = session.config._assertstate.hook assertstate = getattr(session.config, '_assertstate', None)
if hook is not None: if assertstate:
hook.session = None if assertstate.hook is not None:
assertstate.hook.set_session(None)
def _load_modules(mode): def _load_modules(mode):
@ -151,31 +177,5 @@ def _load_modules(mode):
from _pytest.assertion import rewrite # noqa from _pytest.assertion import rewrite # noqa
def warn_about_missing_assertion(mode, pluginmanager):
try:
assert False
except AssertionError:
pass
else:
if mode == "rewrite":
specifically = ("assertions which are not in test modules "
"will be ignored")
else:
specifically = "failing tests may report as passing"
# temporarily disable capture so we can print our warning
capman = pluginmanager.getplugin('capturemanager')
try:
out, err = capman.suspendcapture()
sys.stderr.write("WARNING: " + specifically +
" because assert statements are not executed "
"by the underlying Python interpreter "
"(are you using python -O?)\n")
finally:
capman.resumecapture()
sys.stdout.write(out)
sys.stderr.write(err)
# Expose this plugin's implementation for the pytest_assertrepr_compare hook # Expose this plugin's implementation for the pytest_assertrepr_compare hook
pytest_assertrepr_compare = util.assertrepr_compare pytest_assertrepr_compare = util.assertrepr_compare

View File

@ -51,6 +51,7 @@ class AssertionRewritingHook(object):
self.session = None self.session = None
self.modules = {} self.modules = {}
self._register_with_pkg_resources() self._register_with_pkg_resources()
self._must_rewrite = set()
def set_session(self, session): def set_session(self, session):
self.session = session self.session = session
@ -87,7 +88,7 @@ class AssertionRewritingHook(object):
fn = os.path.join(pth, name.rpartition(".")[2] + ".py") fn = os.path.join(pth, name.rpartition(".")[2] + ".py")
fn_pypath = py.path.local(fn) fn_pypath = py.path.local(fn)
if not self._should_rewrite(fn_pypath, state): if not self._should_rewrite(name, fn_pypath, state):
return None return None
# The requested module looks like a test file, so rewrite it. This is # The requested module looks like a test file, so rewrite it. This is
@ -137,7 +138,7 @@ class AssertionRewritingHook(object):
self.modules[name] = co, pyc self.modules[name] = co, pyc
return self return self
def _should_rewrite(self, fn_pypath, state): def _should_rewrite(self, name, fn_pypath, state):
# always rewrite conftest files # always rewrite conftest files
fn = str(fn_pypath) fn = str(fn_pypath)
if fn_pypath.basename == 'conftest.py': if fn_pypath.basename == 'conftest.py':
@ -161,8 +162,29 @@ class AssertionRewritingHook(object):
finally: finally:
self.session = session self.session = session
del session del session
else:
for marked in self._must_rewrite:
if marked.startswith(name):
return True
return False return False
def mark_rewrite(self, *names):
"""Mark import names as needing to be re-written.
The named module or package as well as any nested modules will
be re-written on import.
"""
already_imported = set(names).intersection(set(sys.modules))
if already_imported:
self._warn_already_imported(already_imported)
self._must_rewrite.update(names)
def _warn_already_imported(self, names):
self.config.warn(
'P1',
'Modules are already imported so can not be re-written: %s' %
','.join(names))
def load_module(self, name): def load_module(self, name):
# If there is an existing module object named 'fullname' in # If there is an existing module object named 'fullname' in
# sys.modules, the loader must use that existing module. (Otherwise, # sys.modules, the loader must use that existing module. (Otherwise,

View File

@ -5,11 +5,13 @@ import traceback
import types import types
import warnings import warnings
import pkg_resources
import py import py
# DON't import pytest here because it causes import cycle troubles # DON't import pytest here because it causes import cycle troubles
import sys, os import sys, os
import _pytest._code import _pytest._code
import _pytest.hookspec # the extension point definitions import _pytest.hookspec # the extension point definitions
import _pytest.assertion
from _pytest._pluggy import PluginManager, HookimplMarker, HookspecMarker from _pytest._pluggy import PluginManager, HookimplMarker, HookspecMarker
hookimpl = HookimplMarker("pytest") hookimpl = HookimplMarker("pytest")
@ -160,6 +162,9 @@ class PytestPluginManager(PluginManager):
self.trace.root.setwriter(err.write) self.trace.root.setwriter(err.write)
self.enable_tracing() self.enable_tracing()
# Config._consider_importhook will set a real object if required.
self.rewrite_hook = _pytest.assertion.DummyRewriteHook()
def addhooks(self, module_or_class): def addhooks(self, module_or_class):
""" """
.. deprecated:: 2.8 .. deprecated:: 2.8
@ -368,7 +373,9 @@ class PytestPluginManager(PluginManager):
self._import_plugin_specs(os.environ.get("PYTEST_PLUGINS")) self._import_plugin_specs(os.environ.get("PYTEST_PLUGINS"))
def consider_module(self, mod): def consider_module(self, mod):
self._import_plugin_specs(getattr(mod, "pytest_plugins", None)) plugins = getattr(mod, 'pytest_plugins', [])
self.rewrite_hook.mark_rewrite(*plugins)
self._import_plugin_specs(plugins)
def _import_plugin_specs(self, spec): def _import_plugin_specs(self, spec):
if spec: if spec:
@ -925,14 +932,58 @@ class Config(object):
self._parser.addini('addopts', 'extra command line options', 'args') self._parser.addini('addopts', 'extra command line options', 'args')
self._parser.addini('minversion', 'minimally required pytest version') self._parser.addini('minversion', 'minimally required pytest version')
def _consider_importhook(self, args, entrypoint_name):
"""Install the PEP 302 import hook if using assertion re-writing.
Needs to parse the --assert=<mode> option from the commandline
and find all the installed plugins to mark them for re-writing
by the importhook.
"""
ns, unknown_args = self._parser.parse_known_and_unknown_args(args)
mode = ns.assertmode
self._warn_about_missing_assertion(mode)
if mode != 'plain':
hook = _pytest.assertion.install_importhook(self, mode)
if hook:
self.pluginmanager.rewrite_hook = hook
for entrypoint in pkg_resources.iter_entry_points('pytest11'):
for entry in entrypoint.dist._get_metadata('RECORD'):
fn = entry.split(',')[0]
is_simple_module = os.sep not in fn and fn.endswith('.py')
is_package = fn.count(os.sep) == 1 and fn.endswith('__init__.py')
if is_simple_module:
module_name, ext = os.path.splitext(fn)
hook.mark_rewrite(module_name)
elif is_package:
package_name = os.path.dirname(fn)
hook.mark_rewrite(package_name)
def _warn_about_missing_assertion(self, mode):
try:
assert False
except AssertionError:
pass
else:
if mode == "rewrite":
specifically = ("assertions not in test modules or plugins"
"will be ignored")
else:
specifically = "failing tests may report as passing"
sys.stderr.write("WARNING: " + specifically +
" because assert statements are not executed "
"by the underlying Python interpreter "
"(are you using python -O?)\n")
def _preparse(self, args, addopts=True): def _preparse(self, args, addopts=True):
self._initini(args) self._initini(args)
if addopts: if addopts:
args[:] = shlex.split(os.environ.get('PYTEST_ADDOPTS', '')) + args args[:] = shlex.split(os.environ.get('PYTEST_ADDOPTS', '')) + args
args[:] = self.getini("addopts") + args args[:] = self.getini("addopts") + args
self._checkversion() self._checkversion()
entrypoint_name = 'pytest11'
self._consider_importhook(args, entrypoint_name)
self.pluginmanager.consider_preparse(args) self.pluginmanager.consider_preparse(args)
self.pluginmanager.load_setuptools_entrypoints("pytest11") self.pluginmanager.load_setuptools_entrypoints(entrypoint_name)
self.pluginmanager.consider_env() self.pluginmanager.consider_env()
self.known_args_namespace = ns = self._parser.parse_known_args(args, namespace=self.option.copy()) self.known_args_namespace = ns = self._parser.parse_known_args(args, namespace=self.option.copy())
if self.known_args_namespace.confcutdir is None and self.inifile: if self.known_args_namespace.confcutdir is None and self.inifile:

View File

@ -16,6 +16,7 @@ from _pytest._code import Source
import py import py
import pytest import pytest
from _pytest.main import Session, EXIT_OK from _pytest.main import Session, EXIT_OK
from _pytest.assertion.rewrite import AssertionRewritingHook
def pytest_addoption(parser): def pytest_addoption(parser):
@ -685,8 +686,17 @@ class Testdir:
``pytest.main()`` instance should use. ``pytest.main()`` instance should use.
:return: A :py:class:`HookRecorder` instance. :return: A :py:class:`HookRecorder` instance.
""" """
# When running py.test inline any plugins active in the main
# test process are already imported. So this disables the
# warning which will trigger to say they can no longer be
# re-written, which is fine as they are already re-written.
orig_warn = AssertionRewritingHook._warn_already_imported
def revert():
AssertionRewritingHook._warn_already_imported = orig_warn
self.request.addfinalizer(revert)
AssertionRewritingHook._warn_already_imported = lambda *a: None
rec = [] rec = []
class Collect: class Collect:
def pytest_configure(x, config): def pytest_configure(x, config):

View File

@ -26,6 +26,189 @@ def mock_config():
def interpret(expr): def interpret(expr):
return reinterpret.reinterpret(expr, _pytest._code.Frame(sys._getframe(1))) return reinterpret.reinterpret(expr, _pytest._code.Frame(sys._getframe(1)))
class TestImportHookInstallation:
@pytest.mark.parametrize('initial_conftest', [True, False])
@pytest.mark.parametrize('mode', ['plain', 'rewrite', 'reinterp'])
def test_conftest_assertion_rewrite(self, testdir, initial_conftest, mode):
"""Test that conftest files are using assertion rewrite on import.
(#1619)
"""
testdir.tmpdir.join('foo/tests').ensure(dir=1)
conftest_path = 'conftest.py' if initial_conftest else 'foo/conftest.py'
contents = {
conftest_path: """
import pytest
@pytest.fixture
def check_first():
def check(values, value):
assert values.pop(0) == value
return check
""",
'foo/tests/test_foo.py': """
def test(check_first):
check_first([10, 30], 30)
"""
}
testdir.makepyfile(**contents)
result = testdir.runpytest_subprocess('--assert=%s' % mode)
if mode == 'plain':
expected = 'E AssertionError'
elif mode == 'rewrite':
expected = '*assert 10 == 30*'
elif mode == 'reinterp':
expected = '*AssertionError:*was re-run*'
else:
assert 0
result.stdout.fnmatch_lines([expected])
@pytest.mark.parametrize('mode', ['plain', 'rewrite', 'reinterp'])
def test_pytest_plugins_rewrite(self, testdir, mode):
contents = {
'conftest.py': """
pytest_plugins = ['ham']
""",
'ham.py': """
import pytest
@pytest.fixture
def check_first():
def check(values, value):
assert values.pop(0) == value
return check
""",
'test_foo.py': """
def test_foo(check_first):
check_first([10, 30], 30)
""",
}
testdir.makepyfile(**contents)
result = testdir.runpytest_subprocess('--assert=%s' % mode)
if mode == 'plain':
expected = 'E AssertionError'
elif mode == 'rewrite':
expected = '*assert 10 == 30*'
elif mode == 'reinterp':
expected = '*AssertionError:*was re-run*'
else:
assert 0
result.stdout.fnmatch_lines([expected])
@pytest.mark.parametrize('mode', ['plain', 'rewrite', 'reinterp'])
def test_installed_plugin_rewrite(self, testdir, mode):
# Make sure the hook is installed early enough so that plugins
# installed via setuptools are re-written.
testdir.tmpdir.join('hampkg').ensure(dir=1)
contents = {
'hampkg/__init__.py': """
import pytest
@pytest.fixture
def check_first2():
def check(values, value):
assert values.pop(0) == value
return check
""",
'spamplugin.py': """
import pytest
from hampkg import check_first2
@pytest.fixture
def check_first():
def check(values, value):
assert values.pop(0) == value
return check
""",
'mainwrapper.py': """
import pytest, pkg_resources
class DummyDistInfo:
project_name = 'spam'
version = '1.0'
def _get_metadata(self, name):
return ['spamplugin.py,sha256=abc,123',
'hampkg/__init__.py,sha256=abc,123']
class DummyEntryPoint:
name = 'spam'
module_name = 'spam.py'
attrs = ()
extras = None
dist = DummyDistInfo()
def load(self, require=True, *args, **kwargs):
import spamplugin
return spamplugin
def iter_entry_points(name):
yield DummyEntryPoint()
pkg_resources.iter_entry_points = iter_entry_points
pytest.main()
""",
'test_foo.py': """
def test(check_first):
check_first([10, 30], 30)
def test2(check_first2):
check_first([10, 30], 30)
""",
}
testdir.makepyfile(**contents)
result = testdir.run(sys.executable, 'mainwrapper.py', '-s', '--assert=%s' % mode)
if mode == 'plain':
expected = 'E AssertionError'
elif mode == 'rewrite':
expected = '*assert 10 == 30*'
elif mode == 'reinterp':
expected = '*AssertionError:*was re-run*'
else:
assert 0
result.stdout.fnmatch_lines([expected])
def test_rewrite_ast(self, testdir):
testdir.tmpdir.join('pkg').ensure(dir=1)
contents = {
'pkg/__init__.py': """
import pytest
pytest.register_assert_rewrite('pkg.helper')
""",
'pkg/helper.py': """
def tool():
a, b = 2, 3
assert a == b
""",
'pkg/plugin.py': """
import pytest, pkg.helper
@pytest.fixture
def tool():
return pkg.helper.tool
""",
'pkg/other.py': """
l = [3, 2]
def tool():
assert l.pop() == 3
""",
'conftest.py': """
pytest_plugins = ['pkg.plugin']
""",
'test_pkg.py': """
import pkg.other
def test_tool(tool):
tool()
def test_other():
pkg.other.tool()
""",
}
testdir.makepyfile(**contents)
result = testdir.runpytest_subprocess('--assert=rewrite')
result.stdout.fnmatch_lines(['>*assert a == b*',
'E*assert 2 == 3*',
'>*assert l.pop() == 3*',
'E*AssertionError*re-run*'])
class TestBinReprIntegration: class TestBinReprIntegration:
def test_pytest_assertrepr_compare_called(self, testdir): def test_pytest_assertrepr_compare_called(self, testdir):

View File

@ -12,7 +12,7 @@ if sys.platform.startswith("java"):
import _pytest._code import _pytest._code
from _pytest.assertion import util from _pytest.assertion import util
from _pytest.assertion.rewrite import rewrite_asserts, PYTEST_TAG from _pytest.assertion.rewrite import rewrite_asserts, PYTEST_TAG, AssertionRewritingHook
from _pytest.main import EXIT_NOTESTSCOLLECTED from _pytest.main import EXIT_NOTESTSCOLLECTED
@ -524,6 +524,16 @@ def test_rewritten():
testdir.makepyfile("import a_package_without_init_py.module") testdir.makepyfile("import a_package_without_init_py.module")
assert testdir.runpytest().ret == EXIT_NOTESTSCOLLECTED assert testdir.runpytest().ret == EXIT_NOTESTSCOLLECTED
def test_rewrite_warning(self, pytestconfig, monkeypatch):
hook = AssertionRewritingHook(pytestconfig)
warnings = []
def mywarn(code, msg):
warnings.append((code, msg))
monkeypatch.setattr(hook.config, 'warn', mywarn)
hook.mark_rewrite('_pytest')
assert '_pytest' in warnings[0][1]
class TestAssertionRewriteHookDetails(object): class TestAssertionRewriteHookDetails(object):
def test_loader_is_package_false_for_module(self, testdir): def test_loader_is_package_false_for_module(self, testdir):
testdir.makepyfile(test_fun=""" testdir.makepyfile(test_fun="""
@ -704,40 +714,6 @@ class TestAssertionRewriteHookDetails(object):
result = testdir.runpytest() result = testdir.runpytest()
result.stdout.fnmatch_lines('*1 passed*') result.stdout.fnmatch_lines('*1 passed*')
@pytest.mark.parametrize('initial_conftest', [True, False])
@pytest.mark.parametrize('mode', ['plain', 'rewrite', 'reinterp'])
def test_conftest_assertion_rewrite(self, testdir, initial_conftest, mode):
"""Test that conftest files are using assertion rewrite on import.
(#1619)
"""
testdir.tmpdir.join('foo/tests').ensure(dir=1)
conftest_path = 'conftest.py' if initial_conftest else 'foo/conftest.py'
contents = {
conftest_path: """
import pytest
@pytest.fixture
def check_first():
def check(values, value):
assert values.pop(0) == value
return check
""",
'foo/tests/test_foo.py': """
def test(check_first):
check_first([10, 30], 30)
"""
}
testdir.makepyfile(**contents)
result = testdir.runpytest_subprocess('--assert=%s' % mode)
if mode == 'plain':
expected = 'E AssertionError'
elif mode == 'rewrite':
expected = '*assert 10 == 30*'
elif mode == 'reinterp':
expected = '*AssertionError:*was re-run*'
else:
assert 0
result.stdout.fnmatch_lines([expected])
def test_issue731(testdir): def test_issue731(testdir):
testdir.makepyfile(""" testdir.makepyfile("""

View File

@ -373,10 +373,14 @@ def test_preparse_ordering_with_setuptools(testdir, monkeypatch):
pkg_resources = pytest.importorskip("pkg_resources") pkg_resources = pytest.importorskip("pkg_resources")
def my_iter(name): def my_iter(name):
assert name == "pytest11" assert name == "pytest11"
class Dist:
project_name = 'spam'
version = '1.0'
def _get_metadata(self, name):
return ['foo.txt,sha256=abc,123']
class EntryPoint: class EntryPoint:
name = "mytestplugin" name = "mytestplugin"
class dist: dist = Dist()
pass
def load(self): def load(self):
class PseudoPlugin: class PseudoPlugin:
x = 42 x = 42
@ -396,9 +400,14 @@ def test_setuptools_importerror_issue1479(testdir, monkeypatch):
pkg_resources = pytest.importorskip("pkg_resources") pkg_resources = pytest.importorskip("pkg_resources")
def my_iter(name): def my_iter(name):
assert name == "pytest11" assert name == "pytest11"
class Dist:
project_name = 'spam'
version = '1.0'
def _get_metadata(self, name):
return ['foo.txt,sha256=abc,123']
class EntryPoint: class EntryPoint:
name = "mytestplugin" name = "mytestplugin"
dist = None dist = Dist()
def load(self): def load(self):
raise ImportError("Don't hide me!") raise ImportError("Don't hide me!")
return iter([EntryPoint()]) return iter([EntryPoint()])
@ -412,8 +421,14 @@ def test_plugin_preparse_prevents_setuptools_loading(testdir, monkeypatch):
pkg_resources = pytest.importorskip("pkg_resources") pkg_resources = pytest.importorskip("pkg_resources")
def my_iter(name): def my_iter(name):
assert name == "pytest11" assert name == "pytest11"
class Dist:
project_name = 'spam'
version = '1.0'
def _get_metadata(self, name):
return ['foo.txt,sha256=abc,123']
class EntryPoint: class EntryPoint:
name = "mytestplugin" name = "mytestplugin"
dist = Dist()
def load(self): def load(self):
assert 0, "should not arrive here" assert 0, "should not arrive here"
return iter([EntryPoint()]) return iter([EntryPoint()])
@ -505,7 +520,6 @@ def test_load_initial_conftest_last_ordering(testdir):
expected = [ expected = [
"_pytest.config", "_pytest.config",
'test_config', 'test_config',
'_pytest.assertion',
'_pytest.capture', '_pytest.capture',
] ]
assert [x.function.__module__ for x in l] == expected assert [x.function.__module__ for x in l] == expected