Merge pull request #5468 from asottile/switch_importlib_to_imp
Switch from deprecated imp to importlib
This commit is contained in:
		
						commit
						61dcb84f0d
					
				| 
						 | 
				
			
			@ -0,0 +1 @@
 | 
			
		|||
Switch from ``imp`` to ``importlib``.
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1 @@
 | 
			
		|||
Honor PEP 235 on case-insensitive file systems.
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1 @@
 | 
			
		|||
Test module is no longer double-imported when using ``--pyargs``.
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1 @@
 | 
			
		|||
Prevent "already imported" warnings from assertion rewriter when invoking pytest in-process multiple times.
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1 @@
 | 
			
		|||
Fix assertion rewriting in packages (``__init__.py``).
 | 
			
		||||
| 
						 | 
				
			
			@ -1,18 +1,16 @@
 | 
			
		|||
"""Rewrite assertion AST to produce nice error messages"""
 | 
			
		||||
import ast
 | 
			
		||||
import errno
 | 
			
		||||
import imp
 | 
			
		||||
import importlib.machinery
 | 
			
		||||
import importlib.util
 | 
			
		||||
import itertools
 | 
			
		||||
import marshal
 | 
			
		||||
import os
 | 
			
		||||
import re
 | 
			
		||||
import struct
 | 
			
		||||
import sys
 | 
			
		||||
import types
 | 
			
		||||
from importlib.util import spec_from_file_location
 | 
			
		||||
 | 
			
		||||
import atomicwrites
 | 
			
		||||
import py
 | 
			
		||||
 | 
			
		||||
from _pytest._io.saferepr import saferepr
 | 
			
		||||
from _pytest.assertion import util
 | 
			
		||||
| 
						 | 
				
			
			@ -23,23 +21,13 @@ from _pytest.pathlib import fnmatch_ex
 | 
			
		|||
from _pytest.pathlib import PurePath
 | 
			
		||||
 | 
			
		||||
# pytest caches rewritten pycs in __pycache__.
 | 
			
		||||
if hasattr(imp, "get_tag"):
 | 
			
		||||
    PYTEST_TAG = imp.get_tag() + "-PYTEST"
 | 
			
		||||
else:
 | 
			
		||||
    if hasattr(sys, "pypy_version_info"):
 | 
			
		||||
        impl = "pypy"
 | 
			
		||||
    else:
 | 
			
		||||
        impl = "cpython"
 | 
			
		||||
    ver = sys.version_info
 | 
			
		||||
    PYTEST_TAG = "{}-{}{}-PYTEST".format(impl, ver[0], ver[1])
 | 
			
		||||
    del ver, impl
 | 
			
		||||
 | 
			
		||||
PYTEST_TAG = "{}-PYTEST".format(sys.implementation.cache_tag)
 | 
			
		||||
PYC_EXT = ".py" + (__debug__ and "c" or "o")
 | 
			
		||||
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AssertionRewritingHook:
 | 
			
		||||
    """PEP302 Import hook which rewrites asserts."""
 | 
			
		||||
    """PEP302/PEP451 import hook which rewrites asserts."""
 | 
			
		||||
 | 
			
		||||
    def __init__(self, config):
 | 
			
		||||
        self.config = config
 | 
			
		||||
| 
						 | 
				
			
			@ -48,7 +36,6 @@ class AssertionRewritingHook:
 | 
			
		|||
        except ValueError:
 | 
			
		||||
            self.fnpats = ["test_*.py", "*_test.py"]
 | 
			
		||||
        self.session = None
 | 
			
		||||
        self.modules = {}
 | 
			
		||||
        self._rewritten_names = set()
 | 
			
		||||
        self._must_rewrite = set()
 | 
			
		||||
        # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
 | 
			
		||||
| 
						 | 
				
			
			@ -62,55 +49,51 @@ class AssertionRewritingHook:
 | 
			
		|||
        self.session = session
 | 
			
		||||
        self._session_paths_checked = False
 | 
			
		||||
 | 
			
		||||
    def _imp_find_module(self, name, path=None):
 | 
			
		||||
        """Indirection so we can mock calls to find_module originated from the hook during testing"""
 | 
			
		||||
        return imp.find_module(name, path)
 | 
			
		||||
    # Indirection so we can mock calls to find_spec originated from the hook during testing
 | 
			
		||||
    _find_spec = importlib.machinery.PathFinder.find_spec
 | 
			
		||||
 | 
			
		||||
    def find_module(self, name, path=None):
 | 
			
		||||
    def find_spec(self, name, path=None, target=None):
 | 
			
		||||
        if self._writing_pyc:
 | 
			
		||||
            return None
 | 
			
		||||
        state = self.config._assertstate
 | 
			
		||||
        if self._early_rewrite_bailout(name, state):
 | 
			
		||||
            return None
 | 
			
		||||
        state.trace("find_module called for: %s" % name)
 | 
			
		||||
        names = name.rsplit(".", 1)
 | 
			
		||||
        lastname = names[-1]
 | 
			
		||||
        pth = None
 | 
			
		||||
        if path is not None:
 | 
			
		||||
            # Starting with Python 3.3, path is a _NamespacePath(), which
 | 
			
		||||
            # causes problems if not converted to list.
 | 
			
		||||
            path = list(path)
 | 
			
		||||
            if len(path) == 1:
 | 
			
		||||
                pth = path[0]
 | 
			
		||||
        if pth is None:
 | 
			
		||||
            try:
 | 
			
		||||
                fd, fn, desc = self._imp_find_module(lastname, path)
 | 
			
		||||
            except ImportError:
 | 
			
		||||
                return None
 | 
			
		||||
            if fd is not None:
 | 
			
		||||
                fd.close()
 | 
			
		||||
            tp = desc[2]
 | 
			
		||||
            if tp == imp.PY_COMPILED:
 | 
			
		||||
                if hasattr(imp, "source_from_cache"):
 | 
			
		||||
                    try:
 | 
			
		||||
                        fn = imp.source_from_cache(fn)
 | 
			
		||||
                    except ValueError:
 | 
			
		||||
                        # Python 3 doesn't like orphaned but still-importable
 | 
			
		||||
                        # .pyc files.
 | 
			
		||||
                        fn = fn[:-1]
 | 
			
		||||
                else:
 | 
			
		||||
                    fn = fn[:-1]
 | 
			
		||||
            elif tp != imp.PY_SOURCE:
 | 
			
		||||
                # Don't know what this is.
 | 
			
		||||
 | 
			
		||||
        spec = self._find_spec(name, path)
 | 
			
		||||
        if (
 | 
			
		||||
            # the import machinery could not find a file to import
 | 
			
		||||
            spec is None
 | 
			
		||||
            # this is a namespace package (without `__init__.py`)
 | 
			
		||||
            # there's nothing to rewrite there
 | 
			
		||||
            # python3.5 - python3.6: `namespace`
 | 
			
		||||
            # python3.7+: `None`
 | 
			
		||||
            or spec.origin in {None, "namespace"}
 | 
			
		||||
            # if the file doesn't exist, we can't rewrite it
 | 
			
		||||
            or not os.path.exists(spec.origin)
 | 
			
		||||
        ):
 | 
			
		||||
            return None
 | 
			
		||||
        else:
 | 
			
		||||
            fn = os.path.join(pth, name.rpartition(".")[2] + ".py")
 | 
			
		||||
            fn = spec.origin
 | 
			
		||||
 | 
			
		||||
        fn_pypath = py.path.local(fn)
 | 
			
		||||
        if not self._should_rewrite(name, fn_pypath, state):
 | 
			
		||||
        if not self._should_rewrite(name, fn, state):
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        self._rewritten_names.add(name)
 | 
			
		||||
        return importlib.util.spec_from_file_location(
 | 
			
		||||
            name,
 | 
			
		||||
            fn,
 | 
			
		||||
            loader=self,
 | 
			
		||||
            submodule_search_locations=spec.submodule_search_locations,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def create_module(self, spec):
 | 
			
		||||
        return None  # default behaviour is fine
 | 
			
		||||
 | 
			
		||||
    def exec_module(self, module):
 | 
			
		||||
        fn = module.__spec__.origin
 | 
			
		||||
        state = self.config._assertstate
 | 
			
		||||
 | 
			
		||||
        self._rewritten_names.add(module.__name__)
 | 
			
		||||
 | 
			
		||||
        # The requested module looks like a test file, so rewrite it. This is
 | 
			
		||||
        # the most magical part of the process: load the source, rewrite the
 | 
			
		||||
| 
						 | 
				
			
			@ -121,7 +104,7 @@ class AssertionRewritingHook:
 | 
			
		|||
        # cached pyc is always a complete, valid pyc. Operations on it must be
 | 
			
		||||
        # atomic. POSIX's atomic rename comes in handy.
 | 
			
		||||
        write = not sys.dont_write_bytecode
 | 
			
		||||
        cache_dir = os.path.join(fn_pypath.dirname, "__pycache__")
 | 
			
		||||
        cache_dir = os.path.join(os.path.dirname(fn), "__pycache__")
 | 
			
		||||
        if write:
 | 
			
		||||
            try:
 | 
			
		||||
                os.mkdir(cache_dir)
 | 
			
		||||
| 
						 | 
				
			
			@ -132,26 +115,23 @@ class AssertionRewritingHook:
 | 
			
		|||
                    # common case) or it's blocked by a non-dir node. In the
 | 
			
		||||
                    # latter case, we'll ignore it in _write_pyc.
 | 
			
		||||
                    pass
 | 
			
		||||
                elif e in [errno.ENOENT, errno.ENOTDIR]:
 | 
			
		||||
                elif e in {errno.ENOENT, errno.ENOTDIR}:
 | 
			
		||||
                    # One of the path components was not a directory, likely
 | 
			
		||||
                    # because we're in a zip file.
 | 
			
		||||
                    write = False
 | 
			
		||||
                elif e in [errno.EACCES, errno.EROFS, errno.EPERM]:
 | 
			
		||||
                    state.trace("read only directory: %r" % fn_pypath.dirname)
 | 
			
		||||
                elif e in {errno.EACCES, errno.EROFS, errno.EPERM}:
 | 
			
		||||
                    state.trace("read only directory: %r" % os.path.dirname(fn))
 | 
			
		||||
                    write = False
 | 
			
		||||
                else:
 | 
			
		||||
                    raise
 | 
			
		||||
        cache_name = fn_pypath.basename[:-3] + PYC_TAIL
 | 
			
		||||
        cache_name = os.path.basename(fn)[:-3] + PYC_TAIL
 | 
			
		||||
        pyc = os.path.join(cache_dir, cache_name)
 | 
			
		||||
        # Notice that even if we're in a read-only directory, I'm going
 | 
			
		||||
        # to check for a cached pyc. This may not be optimal...
 | 
			
		||||
        co = _read_pyc(fn_pypath, pyc, state.trace)
 | 
			
		||||
        co = _read_pyc(fn, pyc, state.trace)
 | 
			
		||||
        if co is None:
 | 
			
		||||
            state.trace("rewriting {!r}".format(fn))
 | 
			
		||||
            source_stat, co = _rewrite_test(self.config, fn_pypath)
 | 
			
		||||
            if co is None:
 | 
			
		||||
                # Probably a SyntaxError in the test.
 | 
			
		||||
                return None
 | 
			
		||||
            source_stat, co = _rewrite_test(fn)
 | 
			
		||||
            if write:
 | 
			
		||||
                self._writing_pyc = True
 | 
			
		||||
                try:
 | 
			
		||||
| 
						 | 
				
			
			@ -160,13 +140,11 @@ class AssertionRewritingHook:
 | 
			
		|||
                    self._writing_pyc = False
 | 
			
		||||
        else:
 | 
			
		||||
            state.trace("found cached rewritten pyc for {!r}".format(fn))
 | 
			
		||||
        self.modules[name] = co, pyc
 | 
			
		||||
        return self
 | 
			
		||||
        exec(co, module.__dict__)
 | 
			
		||||
 | 
			
		||||
    def _early_rewrite_bailout(self, name, state):
 | 
			
		||||
        """
 | 
			
		||||
        This is a fast way to get out of rewriting modules. Profiling has
 | 
			
		||||
        shown that the call to imp.find_module (inside of the find_module
 | 
			
		||||
        """This is a fast way to get out of rewriting modules. Profiling has
 | 
			
		||||
        shown that the call to PathFinder.find_spec (inside of the find_spec
 | 
			
		||||
        from this class) is a major slowdown, so, this method tries to
 | 
			
		||||
        filter what we're sure won't be rewritten before getting to it.
 | 
			
		||||
        """
 | 
			
		||||
| 
						 | 
				
			
			@ -201,10 +179,9 @@ class AssertionRewritingHook:
 | 
			
		|||
        state.trace("early skip of rewriting module: {}".format(name))
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    def _should_rewrite(self, name, fn_pypath, state):
 | 
			
		||||
    def _should_rewrite(self, name, fn, state):
 | 
			
		||||
        # always rewrite conftest files
 | 
			
		||||
        fn = str(fn_pypath)
 | 
			
		||||
        if fn_pypath.basename == "conftest.py":
 | 
			
		||||
        if os.path.basename(fn) == "conftest.py":
 | 
			
		||||
            state.trace("rewriting conftest file: {!r}".format(fn))
 | 
			
		||||
            return True
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -217,8 +194,9 @@ class AssertionRewritingHook:
 | 
			
		|||
 | 
			
		||||
        # modules not passed explicitly on the command line are only
 | 
			
		||||
        # rewritten if they match the naming convention for test files
 | 
			
		||||
        fn_path = PurePath(fn)
 | 
			
		||||
        for pat in self.fnpats:
 | 
			
		||||
            if fn_pypath.fnmatch(pat):
 | 
			
		||||
            if fnmatch_ex(pat, fn_path):
 | 
			
		||||
                state.trace("matched test file {!r}".format(fn))
 | 
			
		||||
                return True
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -249,9 +227,10 @@ class AssertionRewritingHook:
 | 
			
		|||
            set(names).intersection(sys.modules).difference(self._rewritten_names)
 | 
			
		||||
        )
 | 
			
		||||
        for name in already_imported:
 | 
			
		||||
            mod = sys.modules[name]
 | 
			
		||||
            if not AssertionRewriter.is_rewrite_disabled(
 | 
			
		||||
                sys.modules[name].__doc__ or ""
 | 
			
		||||
            ):
 | 
			
		||||
                mod.__doc__ or ""
 | 
			
		||||
            ) and not isinstance(mod.__loader__, type(self)):
 | 
			
		||||
                self._warn_already_imported(name)
 | 
			
		||||
        self._must_rewrite.update(names)
 | 
			
		||||
        self._marked_for_rewrite_cache.clear()
 | 
			
		||||
| 
						 | 
				
			
			@ -268,45 +247,8 @@ class AssertionRewritingHook:
 | 
			
		|||
            stacklevel=5,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def load_module(self, name):
 | 
			
		||||
        co, pyc = self.modules.pop(name)
 | 
			
		||||
        if name in sys.modules:
 | 
			
		||||
            # If there is an existing module object named 'fullname' in
 | 
			
		||||
            # sys.modules, the loader must use that existing module. (Otherwise,
 | 
			
		||||
            # the reload() builtin will not work correctly.)
 | 
			
		||||
            mod = sys.modules[name]
 | 
			
		||||
        else:
 | 
			
		||||
            # I wish I could just call imp.load_compiled here, but __file__ has to
 | 
			
		||||
            # be set properly. In Python 3.2+, this all would be handled correctly
 | 
			
		||||
            # by load_compiled.
 | 
			
		||||
            mod = sys.modules[name] = imp.new_module(name)
 | 
			
		||||
        try:
 | 
			
		||||
            mod.__file__ = co.co_filename
 | 
			
		||||
            # Normally, this attribute is 3.2+.
 | 
			
		||||
            mod.__cached__ = pyc
 | 
			
		||||
            mod.__loader__ = self
 | 
			
		||||
            # Normally, this attribute is 3.4+
 | 
			
		||||
            mod.__spec__ = spec_from_file_location(name, co.co_filename, loader=self)
 | 
			
		||||
            exec(co, mod.__dict__)
 | 
			
		||||
        except:  # noqa
 | 
			
		||||
            if name in sys.modules:
 | 
			
		||||
                del sys.modules[name]
 | 
			
		||||
            raise
 | 
			
		||||
        return sys.modules[name]
 | 
			
		||||
 | 
			
		||||
    def is_package(self, name):
 | 
			
		||||
        try:
 | 
			
		||||
            fd, fn, desc = self._imp_find_module(name)
 | 
			
		||||
        except ImportError:
 | 
			
		||||
            return False
 | 
			
		||||
        if fd is not None:
 | 
			
		||||
            fd.close()
 | 
			
		||||
        tp = desc[2]
 | 
			
		||||
        return tp == imp.PKG_DIRECTORY
 | 
			
		||||
 | 
			
		||||
    def get_data(self, pathname):
 | 
			
		||||
        """Optional PEP302 get_data API.
 | 
			
		||||
        """
 | 
			
		||||
        """Optional PEP302 get_data API."""
 | 
			
		||||
        with open(pathname, "rb") as f:
 | 
			
		||||
            return f.read()
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -314,15 +256,13 @@ class AssertionRewritingHook:
 | 
			
		|||
def _write_pyc(state, co, source_stat, pyc):
 | 
			
		||||
    # Technically, we don't have to have the same pyc format as
 | 
			
		||||
    # (C)Python, since these "pycs" should never be seen by builtin
 | 
			
		||||
    # import. However, there's little reason deviate, and I hope
 | 
			
		||||
    # sometime to be able to use imp.load_compiled to load them. (See
 | 
			
		||||
    # the comment in load_module above.)
 | 
			
		||||
    # import. However, there's little reason deviate.
 | 
			
		||||
    try:
 | 
			
		||||
        with atomicwrites.atomic_write(pyc, mode="wb", overwrite=True) as fp:
 | 
			
		||||
            fp.write(imp.get_magic())
 | 
			
		||||
            fp.write(importlib.util.MAGIC_NUMBER)
 | 
			
		||||
            # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
 | 
			
		||||
            mtime = int(source_stat.mtime) & 0xFFFFFFFF
 | 
			
		||||
            size = source_stat.size & 0xFFFFFFFF
 | 
			
		||||
            mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
 | 
			
		||||
            size = source_stat.st_size & 0xFFFFFFFF
 | 
			
		||||
            # "<LL" stands for 2 unsigned longs, little-ending
 | 
			
		||||
            fp.write(struct.pack("<LL", mtime, size))
 | 
			
		||||
            fp.write(marshal.dumps(co))
 | 
			
		||||
| 
						 | 
				
			
			@ -335,35 +275,14 @@ def _write_pyc(state, co, source_stat, pyc):
 | 
			
		|||
    return True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
RN = b"\r\n"
 | 
			
		||||
N = b"\n"
 | 
			
		||||
 | 
			
		||||
cookie_re = re.compile(r"^[ \t\f]*#.*coding[:=][ \t]*[-\w.]+")
 | 
			
		||||
BOM_UTF8 = "\xef\xbb\xbf"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _rewrite_test(config, fn):
 | 
			
		||||
    """Try to read and rewrite *fn* and return the code object."""
 | 
			
		||||
    state = config._assertstate
 | 
			
		||||
    try:
 | 
			
		||||
        stat = fn.stat()
 | 
			
		||||
        source = fn.read("rb")
 | 
			
		||||
    except EnvironmentError:
 | 
			
		||||
        return None, None
 | 
			
		||||
    try:
 | 
			
		||||
        tree = ast.parse(source, filename=fn.strpath)
 | 
			
		||||
    except SyntaxError:
 | 
			
		||||
        # Let this pop up again in the real import.
 | 
			
		||||
        state.trace("failed to parse: {!r}".format(fn))
 | 
			
		||||
        return None, None
 | 
			
		||||
    rewrite_asserts(tree, fn, config)
 | 
			
		||||
    try:
 | 
			
		||||
        co = compile(tree, fn.strpath, "exec", dont_inherit=True)
 | 
			
		||||
    except SyntaxError:
 | 
			
		||||
        # It's possible that this error is from some bug in the
 | 
			
		||||
        # assertion rewriting, but I don't know of a fast way to tell.
 | 
			
		||||
        state.trace("failed to compile: {!r}".format(fn))
 | 
			
		||||
        return None, None
 | 
			
		||||
def _rewrite_test(fn):
 | 
			
		||||
    """read and rewrite *fn* and return the code object."""
 | 
			
		||||
    stat = os.stat(fn)
 | 
			
		||||
    with open(fn, "rb") as f:
 | 
			
		||||
        source = f.read()
 | 
			
		||||
    tree = ast.parse(source, filename=fn)
 | 
			
		||||
    rewrite_asserts(tree, fn)
 | 
			
		||||
    co = compile(tree, fn, "exec", dont_inherit=True)
 | 
			
		||||
    return stat, co
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -378,8 +297,9 @@ def _read_pyc(source, pyc, trace=lambda x: None):
 | 
			
		|||
        return None
 | 
			
		||||
    with fp:
 | 
			
		||||
        try:
 | 
			
		||||
            mtime = int(source.mtime())
 | 
			
		||||
            size = source.size()
 | 
			
		||||
            stat_result = os.stat(source)
 | 
			
		||||
            mtime = int(stat_result.st_mtime)
 | 
			
		||||
            size = stat_result.st_size
 | 
			
		||||
            data = fp.read(12)
 | 
			
		||||
        except EnvironmentError as e:
 | 
			
		||||
            trace("_read_pyc({}): EnvironmentError {}".format(source, e))
 | 
			
		||||
| 
						 | 
				
			
			@ -387,7 +307,7 @@ def _read_pyc(source, pyc, trace=lambda x: None):
 | 
			
		|||
        # Check for invalid or out of date pyc file.
 | 
			
		||||
        if (
 | 
			
		||||
            len(data) != 12
 | 
			
		||||
            or data[:4] != imp.get_magic()
 | 
			
		||||
            or data[:4] != importlib.util.MAGIC_NUMBER
 | 
			
		||||
            or struct.unpack("<LL", data[4:]) != (mtime & 0xFFFFFFFF, size & 0xFFFFFFFF)
 | 
			
		||||
        ):
 | 
			
		||||
            trace("_read_pyc(%s): invalid or out of date pyc" % source)
 | 
			
		||||
| 
						 | 
				
			
			@ -403,9 +323,9 @@ def _read_pyc(source, pyc, trace=lambda x: None):
 | 
			
		|||
        return co
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def rewrite_asserts(mod, module_path=None, config=None):
 | 
			
		||||
def rewrite_asserts(mod, module_path=None):
 | 
			
		||||
    """Rewrite the assert statements in mod."""
 | 
			
		||||
    AssertionRewriter(module_path, config).run(mod)
 | 
			
		||||
    AssertionRewriter(module_path).run(mod)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _saferepr(obj):
 | 
			
		||||
| 
						 | 
				
			
			@ -586,10 +506,9 @@ class AssertionRewriter(ast.NodeVisitor):
 | 
			
		|||
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, module_path, config):
 | 
			
		||||
    def __init__(self, module_path):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.module_path = module_path
 | 
			
		||||
        self.config = config
 | 
			
		||||
 | 
			
		||||
    def run(self, mod):
 | 
			
		||||
        """Find all assert statements in *mod* and rewrite them."""
 | 
			
		||||
| 
						 | 
				
			
			@ -758,7 +677,7 @@ class AssertionRewriter(ast.NodeVisitor):
 | 
			
		|||
                    "assertion is always true, perhaps remove parentheses?"
 | 
			
		||||
                ),
 | 
			
		||||
                category=None,
 | 
			
		||||
                filename=str(self.module_path),
 | 
			
		||||
                filename=self.module_path,
 | 
			
		||||
                lineno=assert_.lineno,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -817,7 +736,7 @@ class AssertionRewriter(ast.NodeVisitor):
 | 
			
		|||
        AST_NONE = ast.parse("None").body[0].value
 | 
			
		||||
        val_is_none = ast.Compare(node, [ast.Is()], [AST_NONE])
 | 
			
		||||
        send_warning = ast.parse(
 | 
			
		||||
            """
 | 
			
		||||
            """\
 | 
			
		||||
from _pytest.warning_types import PytestAssertRewriteWarning
 | 
			
		||||
from warnings import warn_explicit
 | 
			
		||||
warn_explicit(
 | 
			
		||||
| 
						 | 
				
			
			@ -827,7 +746,7 @@ warn_explicit(
 | 
			
		|||
    lineno={lineno},
 | 
			
		||||
)
 | 
			
		||||
            """.format(
 | 
			
		||||
                filename=module_path.strpath, lineno=lineno
 | 
			
		||||
                filename=module_path, lineno=lineno
 | 
			
		||||
            )
 | 
			
		||||
        ).body
 | 
			
		||||
        return ast.If(val_is_none, send_warning, [])
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -2,8 +2,8 @@
 | 
			
		|||
import enum
 | 
			
		||||
import fnmatch
 | 
			
		||||
import functools
 | 
			
		||||
import importlib
 | 
			
		||||
import os
 | 
			
		||||
import pkgutil
 | 
			
		||||
import sys
 | 
			
		||||
import warnings
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -630,21 +630,15 @@ class Session(nodes.FSCollector):
 | 
			
		|||
    def _tryconvertpyarg(self, x):
 | 
			
		||||
        """Convert a dotted module name to path."""
 | 
			
		||||
        try:
 | 
			
		||||
            loader = pkgutil.find_loader(x)
 | 
			
		||||
        except ImportError:
 | 
			
		||||
            spec = importlib.util.find_spec(x)
 | 
			
		||||
        except (ValueError, ImportError):
 | 
			
		||||
            return x
 | 
			
		||||
        if loader is None:
 | 
			
		||||
        if spec is None or spec.origin in {None, "namespace"}:
 | 
			
		||||
            return x
 | 
			
		||||
        # This method is sometimes invoked when AssertionRewritingHook, which
 | 
			
		||||
        # does not define a get_filename method, is already in place:
 | 
			
		||||
        try:
 | 
			
		||||
            path = loader.get_filename(x)
 | 
			
		||||
        except AttributeError:
 | 
			
		||||
            # Retrieve path from AssertionRewritingHook:
 | 
			
		||||
            path = loader.modules[x][0].co_filename
 | 
			
		||||
        if loader.is_package(x):
 | 
			
		||||
            path = os.path.dirname(path)
 | 
			
		||||
        return path
 | 
			
		||||
        elif spec.submodule_search_locations:
 | 
			
		||||
            return os.path.dirname(spec.origin)
 | 
			
		||||
        else:
 | 
			
		||||
            return spec.origin
 | 
			
		||||
 | 
			
		||||
    def _parsearg(self, arg):
 | 
			
		||||
        """ return (fspath, names) tuple after checking the file exists. """
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -294,6 +294,8 @@ def fnmatch_ex(pattern, path):
 | 
			
		|||
        name = path.name
 | 
			
		||||
    else:
 | 
			
		||||
        name = str(path)
 | 
			
		||||
        if path.is_absolute() and not os.path.isabs(pattern):
 | 
			
		||||
            pattern = "*{}{}".format(os.sep, pattern)
 | 
			
		||||
    return fnmatch.fnmatch(name, pattern)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,5 +1,6 @@
 | 
			
		|||
"""(disabled by default) support for testing pytest and pytest plugins."""
 | 
			
		||||
import gc
 | 
			
		||||
import importlib
 | 
			
		||||
import os
 | 
			
		||||
import platform
 | 
			
		||||
import re
 | 
			
		||||
| 
						 | 
				
			
			@ -16,7 +17,6 @@ import py
 | 
			
		|||
import pytest
 | 
			
		||||
from _pytest._code import Source
 | 
			
		||||
from _pytest._io.saferepr import saferepr
 | 
			
		||||
from _pytest.assertion.rewrite import AssertionRewritingHook
 | 
			
		||||
from _pytest.capture import MultiCapture
 | 
			
		||||
from _pytest.capture import SysCapture
 | 
			
		||||
from _pytest.main import ExitCode
 | 
			
		||||
| 
						 | 
				
			
			@ -787,6 +787,11 @@ class Testdir:
 | 
			
		|||
 | 
			
		||||
        :return: a :py:class:`HookRecorder` instance
 | 
			
		||||
        """
 | 
			
		||||
        # (maybe a cpython bug?) the importlib cache sometimes isn't updated
 | 
			
		||||
        # properly between file creation and inline_run (especially if imports
 | 
			
		||||
        # are interspersed with file creation)
 | 
			
		||||
        importlib.invalidate_caches()
 | 
			
		||||
 | 
			
		||||
        plugins = list(plugins)
 | 
			
		||||
        finalizers = []
 | 
			
		||||
        try:
 | 
			
		||||
| 
						 | 
				
			
			@ -796,18 +801,6 @@ class Testdir:
 | 
			
		|||
                mp_run.setenv(k, v)
 | 
			
		||||
            finalizers.append(mp_run.undo)
 | 
			
		||||
 | 
			
		||||
            # When running pytest inline any plugins active in the main test
 | 
			
		||||
            # process are already imported.  So this disables the warning which
 | 
			
		||||
            # will trigger to say they can no longer be rewritten, which is
 | 
			
		||||
            # fine as they have already been rewritten.
 | 
			
		||||
            orig_warn = AssertionRewritingHook._warn_already_imported
 | 
			
		||||
 | 
			
		||||
            def revert_warn_already_imported():
 | 
			
		||||
                AssertionRewritingHook._warn_already_imported = orig_warn
 | 
			
		||||
 | 
			
		||||
            finalizers.append(revert_warn_already_imported)
 | 
			
		||||
            AssertionRewritingHook._warn_already_imported = lambda *a: None
 | 
			
		||||
 | 
			
		||||
            # Any sys.module or sys.path changes done while running pytest
 | 
			
		||||
            # inline should be reverted after the test run completes to avoid
 | 
			
		||||
            # clashing with later inline tests run within the same pytest test,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -633,6 +633,19 @@ class TestInvocationVariants:
 | 
			
		|||
 | 
			
		||||
        result.stdout.fnmatch_lines(["collected*0*items*/*1*errors"])
 | 
			
		||||
 | 
			
		||||
    def test_pyargs_only_imported_once(self, testdir):
 | 
			
		||||
        pkg = testdir.mkpydir("foo")
 | 
			
		||||
        pkg.join("test_foo.py").write("print('hello from test_foo')\ndef test(): pass")
 | 
			
		||||
        pkg.join("conftest.py").write(
 | 
			
		||||
            "def pytest_configure(config): print('configuring')"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        result = testdir.runpytest("--pyargs", "foo.test_foo", "-s", syspathinsert=True)
 | 
			
		||||
        # should only import once
 | 
			
		||||
        assert result.outlines.count("hello from test_foo") == 1
 | 
			
		||||
        # should only configure once
 | 
			
		||||
        assert result.outlines.count("configuring") == 1
 | 
			
		||||
 | 
			
		||||
    def test_cmdline_python_package(self, testdir, monkeypatch):
 | 
			
		||||
        import warnings
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -137,8 +137,8 @@ class TestImportHookInstallation:
 | 
			
		|||
            "hamster.py": "",
 | 
			
		||||
            "test_foo.py": """\
 | 
			
		||||
                def test_foo(pytestconfig):
 | 
			
		||||
                    assert pytestconfig.pluginmanager.rewrite_hook.find_module('ham') is not None
 | 
			
		||||
                    assert pytestconfig.pluginmanager.rewrite_hook.find_module('hamster') is None
 | 
			
		||||
                    assert pytestconfig.pluginmanager.rewrite_hook.find_spec('ham') is not None
 | 
			
		||||
                    assert pytestconfig.pluginmanager.rewrite_hook.find_spec('hamster') is None
 | 
			
		||||
            """,
 | 
			
		||||
        }
 | 
			
		||||
        testdir.makepyfile(**contents)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,5 +1,6 @@
 | 
			
		|||
import ast
 | 
			
		||||
import glob
 | 
			
		||||
import importlib
 | 
			
		||||
import os
 | 
			
		||||
import py_compile
 | 
			
		||||
import stat
 | 
			
		||||
| 
						 | 
				
			
			@ -117,6 +118,37 @@ class TestAssertionRewrite:
 | 
			
		|||
        result = testdir.runpytest_subprocess()
 | 
			
		||||
        assert "warnings" not in "".join(result.outlines)
 | 
			
		||||
 | 
			
		||||
    def test_rewrites_plugin_as_a_package(self, testdir):
 | 
			
		||||
        pkgdir = testdir.mkpydir("plugin")
 | 
			
		||||
        pkgdir.join("__init__.py").write(
 | 
			
		||||
            "import pytest\n"
 | 
			
		||||
            "@pytest.fixture\n"
 | 
			
		||||
            "def special_asserter():\n"
 | 
			
		||||
            "    def special_assert(x, y):\n"
 | 
			
		||||
            "        assert x == y\n"
 | 
			
		||||
            "    return special_assert\n"
 | 
			
		||||
        )
 | 
			
		||||
        testdir.makeconftest('pytest_plugins = ["plugin"]')
 | 
			
		||||
        testdir.makepyfile("def test(special_asserter): special_asserter(1, 2)\n")
 | 
			
		||||
        result = testdir.runpytest()
 | 
			
		||||
        result.stdout.fnmatch_lines(["*assert 1 == 2*"])
 | 
			
		||||
 | 
			
		||||
    def test_honors_pep_235(self, testdir, monkeypatch):
 | 
			
		||||
        # note: couldn't make it fail on macos with a single `sys.path` entry
 | 
			
		||||
        # note: these modules are named `test_*` to trigger rewriting
 | 
			
		||||
        testdir.tmpdir.join("test_y.py").write("x = 1")
 | 
			
		||||
        xdir = testdir.tmpdir.join("x").ensure_dir()
 | 
			
		||||
        xdir.join("test_Y").ensure_dir().join("__init__.py").write("x = 2")
 | 
			
		||||
        testdir.makepyfile(
 | 
			
		||||
            "import test_y\n"
 | 
			
		||||
            "import test_Y\n"
 | 
			
		||||
            "def test():\n"
 | 
			
		||||
            "    assert test_y.x == 1\n"
 | 
			
		||||
            "    assert test_Y.x == 2\n"
 | 
			
		||||
        )
 | 
			
		||||
        monkeypatch.syspath_prepend(xdir)
 | 
			
		||||
        testdir.runpytest().assert_outcomes(passed=1)
 | 
			
		||||
 | 
			
		||||
    def test_name(self, request):
 | 
			
		||||
        def f():
 | 
			
		||||
            assert False
 | 
			
		||||
| 
						 | 
				
			
			@ -831,8 +863,9 @@ def test_rewritten():
 | 
			
		|||
        monkeypatch.setattr(
 | 
			
		||||
            hook, "_warn_already_imported", lambda code, msg: warnings.append(msg)
 | 
			
		||||
        )
 | 
			
		||||
        hook.find_module("test_remember_rewritten_modules")
 | 
			
		||||
        hook.load_module("test_remember_rewritten_modules")
 | 
			
		||||
        spec = hook.find_spec("test_remember_rewritten_modules")
 | 
			
		||||
        module = importlib.util.module_from_spec(spec)
 | 
			
		||||
        hook.exec_module(module)
 | 
			
		||||
        hook.mark_rewrite("test_remember_rewritten_modules")
 | 
			
		||||
        hook.mark_rewrite("test_remember_rewritten_modules")
 | 
			
		||||
        assert warnings == []
 | 
			
		||||
| 
						 | 
				
			
			@ -872,33 +905,6 @@ def test_rewritten():
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class TestAssertionRewriteHookDetails:
 | 
			
		||||
    def test_loader_is_package_false_for_module(self, testdir):
 | 
			
		||||
        testdir.makepyfile(
 | 
			
		||||
            test_fun="""
 | 
			
		||||
            def test_loader():
 | 
			
		||||
                assert not __loader__.is_package(__name__)
 | 
			
		||||
            """
 | 
			
		||||
        )
 | 
			
		||||
        result = testdir.runpytest()
 | 
			
		||||
        result.stdout.fnmatch_lines(["* 1 passed*"])
 | 
			
		||||
 | 
			
		||||
    def test_loader_is_package_true_for_package(self, testdir):
 | 
			
		||||
        testdir.makepyfile(
 | 
			
		||||
            test_fun="""
 | 
			
		||||
            def test_loader():
 | 
			
		||||
                assert not __loader__.is_package(__name__)
 | 
			
		||||
 | 
			
		||||
            def test_fun():
 | 
			
		||||
                assert __loader__.is_package('fun')
 | 
			
		||||
 | 
			
		||||
            def test_missing():
 | 
			
		||||
                assert not __loader__.is_package('pytest_not_there')
 | 
			
		||||
            """
 | 
			
		||||
        )
 | 
			
		||||
        testdir.mkpydir("fun")
 | 
			
		||||
        result = testdir.runpytest()
 | 
			
		||||
        result.stdout.fnmatch_lines(["* 3 passed*"])
 | 
			
		||||
 | 
			
		||||
    def test_sys_meta_path_munged(self, testdir):
 | 
			
		||||
        testdir.makepyfile(
 | 
			
		||||
            """
 | 
			
		||||
| 
						 | 
				
			
			@ -917,7 +923,7 @@ class TestAssertionRewriteHookDetails:
 | 
			
		|||
        state = AssertionState(config, "rewrite")
 | 
			
		||||
        source_path = tmpdir.ensure("source.py")
 | 
			
		||||
        pycpath = tmpdir.join("pyc").strpath
 | 
			
		||||
        assert _write_pyc(state, [1], source_path.stat(), pycpath)
 | 
			
		||||
        assert _write_pyc(state, [1], os.stat(source_path.strpath), pycpath)
 | 
			
		||||
 | 
			
		||||
        @contextmanager
 | 
			
		||||
        def atomic_write_failed(fn, mode="r", overwrite=False):
 | 
			
		||||
| 
						 | 
				
			
			@ -979,7 +985,7 @@ class TestAssertionRewriteHookDetails:
 | 
			
		|||
        assert len(contents) > strip_bytes
 | 
			
		||||
        pyc.write(contents[:strip_bytes], mode="wb")
 | 
			
		||||
 | 
			
		||||
        assert _read_pyc(source, str(pyc)) is None  # no error
 | 
			
		||||
        assert _read_pyc(str(source), str(pyc)) is None  # no error
 | 
			
		||||
 | 
			
		||||
    def test_reload_is_same(self, testdir):
 | 
			
		||||
        # A file that will be picked up during collecting.
 | 
			
		||||
| 
						 | 
				
			
			@ -1186,14 +1192,17 @@ def test_rewrite_infinite_recursion(testdir, pytestconfig, monkeypatch):
 | 
			
		|||
        # make a note that we have called _write_pyc
 | 
			
		||||
        write_pyc_called.append(True)
 | 
			
		||||
        # try to import a module at this point: we should not try to rewrite this module
 | 
			
		||||
        assert hook.find_module("test_bar") is None
 | 
			
		||||
        assert hook.find_spec("test_bar") is None
 | 
			
		||||
        return original_write_pyc(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    monkeypatch.setattr(rewrite, "_write_pyc", spy_write_pyc)
 | 
			
		||||
    monkeypatch.setattr(sys, "dont_write_bytecode", False)
 | 
			
		||||
 | 
			
		||||
    hook = AssertionRewritingHook(pytestconfig)
 | 
			
		||||
    assert hook.find_module("test_foo") is not None
 | 
			
		||||
    spec = hook.find_spec("test_foo")
 | 
			
		||||
    assert spec is not None
 | 
			
		||||
    module = importlib.util.module_from_spec(spec)
 | 
			
		||||
    hook.exec_module(module)
 | 
			
		||||
    assert len(write_pyc_called) == 1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -1201,11 +1210,11 @@ class TestEarlyRewriteBailout:
 | 
			
		|||
    @pytest.fixture
 | 
			
		||||
    def hook(self, pytestconfig, monkeypatch, testdir):
 | 
			
		||||
        """Returns a patched AssertionRewritingHook instance so we can configure its initial paths and track
 | 
			
		||||
        if imp.find_module has been called.
 | 
			
		||||
        if PathFinder.find_spec has been called.
 | 
			
		||||
        """
 | 
			
		||||
        import imp
 | 
			
		||||
        import importlib.machinery
 | 
			
		||||
 | 
			
		||||
        self.find_module_calls = []
 | 
			
		||||
        self.find_spec_calls = []
 | 
			
		||||
        self.initial_paths = set()
 | 
			
		||||
 | 
			
		||||
        class StubSession:
 | 
			
		||||
| 
						 | 
				
			
			@ -1214,22 +1223,22 @@ class TestEarlyRewriteBailout:
 | 
			
		|||
            def isinitpath(self, p):
 | 
			
		||||
                return p in self._initialpaths
 | 
			
		||||
 | 
			
		||||
        def spy_imp_find_module(name, path):
 | 
			
		||||
            self.find_module_calls.append(name)
 | 
			
		||||
            return imp.find_module(name, path)
 | 
			
		||||
        def spy_find_spec(name, path):
 | 
			
		||||
            self.find_spec_calls.append(name)
 | 
			
		||||
            return importlib.machinery.PathFinder.find_spec(name, path)
 | 
			
		||||
 | 
			
		||||
        hook = AssertionRewritingHook(pytestconfig)
 | 
			
		||||
        # use default patterns, otherwise we inherit pytest's testing config
 | 
			
		||||
        hook.fnpats[:] = ["test_*.py", "*_test.py"]
 | 
			
		||||
        monkeypatch.setattr(hook, "_imp_find_module", spy_imp_find_module)
 | 
			
		||||
        monkeypatch.setattr(hook, "_find_spec", spy_find_spec)
 | 
			
		||||
        hook.set_session(StubSession())
 | 
			
		||||
        testdir.syspathinsert()
 | 
			
		||||
        return hook
 | 
			
		||||
 | 
			
		||||
    def test_basic(self, testdir, hook):
 | 
			
		||||
        """
 | 
			
		||||
        Ensure we avoid calling imp.find_module when we know for sure a certain module will not be rewritten
 | 
			
		||||
        to optimize assertion rewriting (#3918).
 | 
			
		||||
        Ensure we avoid calling PathFinder.find_spec when we know for sure a certain
 | 
			
		||||
        module will not be rewritten to optimize assertion rewriting (#3918).
 | 
			
		||||
        """
 | 
			
		||||
        testdir.makeconftest(
 | 
			
		||||
            """
 | 
			
		||||
| 
						 | 
				
			
			@ -1244,24 +1253,24 @@ class TestEarlyRewriteBailout:
 | 
			
		|||
        self.initial_paths.add(foobar_path)
 | 
			
		||||
 | 
			
		||||
        # conftest files should always be rewritten
 | 
			
		||||
        assert hook.find_module("conftest") is not None
 | 
			
		||||
        assert self.find_module_calls == ["conftest"]
 | 
			
		||||
        assert hook.find_spec("conftest") is not None
 | 
			
		||||
        assert self.find_spec_calls == ["conftest"]
 | 
			
		||||
 | 
			
		||||
        # files matching "python_files" mask should always be rewritten
 | 
			
		||||
        assert hook.find_module("test_foo") is not None
 | 
			
		||||
        assert self.find_module_calls == ["conftest", "test_foo"]
 | 
			
		||||
        assert hook.find_spec("test_foo") is not None
 | 
			
		||||
        assert self.find_spec_calls == ["conftest", "test_foo"]
 | 
			
		||||
 | 
			
		||||
        # file does not match "python_files": early bailout
 | 
			
		||||
        assert hook.find_module("bar") is None
 | 
			
		||||
        assert self.find_module_calls == ["conftest", "test_foo"]
 | 
			
		||||
        assert hook.find_spec("bar") is None
 | 
			
		||||
        assert self.find_spec_calls == ["conftest", "test_foo"]
 | 
			
		||||
 | 
			
		||||
        # file is an initial path (passed on the command-line): should be rewritten
 | 
			
		||||
        assert hook.find_module("foobar") is not None
 | 
			
		||||
        assert self.find_module_calls == ["conftest", "test_foo", "foobar"]
 | 
			
		||||
        assert hook.find_spec("foobar") is not None
 | 
			
		||||
        assert self.find_spec_calls == ["conftest", "test_foo", "foobar"]
 | 
			
		||||
 | 
			
		||||
    def test_pattern_contains_subdirectories(self, testdir, hook):
 | 
			
		||||
        """If one of the python_files patterns contain subdirectories ("tests/**.py") we can't bailout early
 | 
			
		||||
        because we need to match with the full path, which can only be found by calling imp.find_module.
 | 
			
		||||
        because we need to match with the full path, which can only be found by calling PathFinder.find_spec
 | 
			
		||||
        """
 | 
			
		||||
        p = testdir.makepyfile(
 | 
			
		||||
            **{
 | 
			
		||||
| 
						 | 
				
			
			@ -1273,8 +1282,8 @@ class TestEarlyRewriteBailout:
 | 
			
		|||
        )
 | 
			
		||||
        testdir.syspathinsert(p.dirpath())
 | 
			
		||||
        hook.fnpats[:] = ["tests/**.py"]
 | 
			
		||||
        assert hook.find_module("file") is not None
 | 
			
		||||
        assert self.find_module_calls == ["file"]
 | 
			
		||||
        assert hook.find_spec("file") is not None
 | 
			
		||||
        assert self.find_spec_calls == ["file"]
 | 
			
		||||
 | 
			
		||||
    @pytest.mark.skipif(
 | 
			
		||||
        sys.platform.startswith("win32"), reason="cannot remove cwd on Windows"
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,3 +1,4 @@
 | 
			
		|||
import os.path
 | 
			
		||||
import sys
 | 
			
		||||
 | 
			
		||||
import py
 | 
			
		||||
| 
						 | 
				
			
			@ -53,6 +54,10 @@ class TestPort:
 | 
			
		|||
    def test_matching(self, match, pattern, path):
 | 
			
		||||
        assert match(pattern, path)
 | 
			
		||||
 | 
			
		||||
    def test_matching_abspath(self, match):
 | 
			
		||||
        abspath = os.path.abspath(os.path.join("tests/foo.py"))
 | 
			
		||||
        assert match("tests/foo.py", abspath)
 | 
			
		||||
 | 
			
		||||
    @pytest.mark.parametrize(
 | 
			
		||||
        "pattern, path",
 | 
			
		||||
        [
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue