diff --git a/_pytest/assertion.py b/_pytest/assertion.py index d40981c32..2d760381f 100644 --- a/_pytest/assertion.py +++ b/_pytest/assertion.py @@ -2,9 +2,19 @@ support for presented detailed information in failing assertions. """ import py +import imp +import marshal +import struct import sys from _pytest.monkeypatch import monkeypatch +try: + from _pytest.assertrewrite import rewrite_asserts +except ImportError: + rewrite_asserts = None +else: + import ast + def pytest_addoption(parser): group = parser.getgroup("debugconfig") group._addoption('--no-assert', action="store_true", default=False, @@ -12,6 +22,7 @@ def pytest_addoption(parser): help="disable python assert expression reinterpretation."), def pytest_configure(config): + global rewrite_asserts # The _reprcompare attribute on the py.code module is used by # py._code._assertionnew to detect this plugin was loaded and in # turn call the hooks defined here as part of the @@ -29,6 +40,51 @@ def pytest_configure(config): m.setattr(py.builtin.builtins, 'AssertionError', py.code._AssertionError) m.setattr(py.code, '_reprcompare', callbinrepr) + else: + rewrite_asserts = None + +def pytest_pycollect_before_module_import(mod): + if rewrite_asserts is None: + return + # Some deep magic: load the source, rewrite the asserts, and write a + # fake pyc, so that it'll be loaded further down this function. + source = mod.fspath.read() + try: + tree = ast.parse(source) + except SyntaxError: + # Let this pop up again in the real import. + return + rewrite_asserts(tree) + try: + co = compile(tree, str(mod.fspath), "exec") + except SyntaxError: + # It's possible that this error is from some bug in the assertion + # rewriting, but I don't know of a fast way to tell. + return + if hasattr(imp, "cache_from_source"): + # Handle PEP 3147 pycs. + pyc = py.path(imp.cache_from_source(mod.fspath)) + pyc.dirname.ensure(dir=True) + else: + pyc = mod.fspath + "c" + mod._pyc = pyc + mtime = int(mod.fspath.mtime()) + fp = pyc.open("wb") + try: + fp.write(imp.get_magic()) + fp.write(struct.pack("