Merge 2f12594279
into feaae2fb35
This commit is contained in:
commit
40109d7cd1
|
@ -0,0 +1 @@
|
||||||
|
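Added the ``--invalidation-mode`` option (``timestamp`` or ``checked-hash``) so that cached assertion-rewritten ``.pyc`` files can be validated by comparing the source file's hash instead of its modification time and size.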
|
|
@ -94,6 +94,7 @@ class AssertionState:
|
||||||
def __init__(self, config: Config, mode) -> None:
|
def __init__(self, config: Config, mode) -> None:
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
self.trace = config.trace.root.get("assertion")
|
self.trace = config.trace.root.get("assertion")
|
||||||
|
self.invalidation_mode = config.known_args_namespace.invalidationmode
|
||||||
self.hook: Optional[rewrite.AssertionRewritingHook] = None
|
self.hook: Optional[rewrite.AssertionRewritingHook] = None
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,7 @@ from typing import IO
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
from typing import Iterator
|
from typing import Iterator
|
||||||
from typing import List
|
from typing import List
|
||||||
|
from typing import Literal
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
from typing import Set
|
from typing import Set
|
||||||
|
@ -30,6 +31,8 @@ from typing import Tuple
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
import _imp
|
||||||
|
|
||||||
from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
|
from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
|
||||||
from _pytest._io.saferepr import saferepr
|
from _pytest._io.saferepr import saferepr
|
||||||
from _pytest._version import version
|
from _pytest._version import version
|
||||||
|
@ -166,11 +169,11 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
||||||
co = _read_pyc(fn, pyc, state.trace)
|
co = _read_pyc(fn, pyc, state.trace)
|
||||||
if co is None:
|
if co is None:
|
||||||
state.trace(f"rewriting {fn!r}")
|
state.trace(f"rewriting {fn!r}")
|
||||||
source_stat, co = _rewrite_test(fn, self.config)
|
source_stat, source_hash, co = _rewrite_test(fn, self.config)
|
||||||
if write:
|
if write:
|
||||||
self._writing_pyc = True
|
self._writing_pyc = True
|
||||||
try:
|
try:
|
||||||
_write_pyc(state, co, source_stat, pyc)
|
_write_pyc(state, co, source_stat, source_hash, pyc)
|
||||||
finally:
|
finally:
|
||||||
self._writing_pyc = False
|
self._writing_pyc = False
|
||||||
else:
|
else:
|
||||||
|
@ -299,20 +302,31 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
|
||||||
|
|
||||||
|
|
||||||
def _write_pyc_fp(
|
def _write_pyc_fp(
|
||||||
fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType
|
fp: IO[bytes],
|
||||||
|
source_stat: os.stat_result,
|
||||||
|
source_hash: bytes,
|
||||||
|
co: types.CodeType,
|
||||||
|
invalidation_mode: Literal["timestamp", "checked-hash"],
|
||||||
) -> None:
|
) -> None:
|
||||||
# Technically, we don't have to have the same pyc format as
|
# Technically, we don't have to have the same pyc format as
|
||||||
# (C)Python, since these "pycs" should never be seen by builtin
|
# (C)Python, since these "pycs" should never be seen by builtin
|
||||||
# import. However, there's little reason to deviate.
|
# import. However, there's little reason to deviate.
|
||||||
fp.write(importlib.util.MAGIC_NUMBER)
|
fp.write(importlib.util.MAGIC_NUMBER)
|
||||||
# https://www.python.org/dev/peps/pep-0552/
|
# https://www.python.org/dev/peps/pep-0552/
|
||||||
flags = b"\x00\x00\x00\x00"
|
if invalidation_mode == "timestamp":
|
||||||
fp.write(flags)
|
flags = b"\x00\x00\x00\x00"
|
||||||
# as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
|
fp.write(flags)
|
||||||
mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
|
# as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
|
||||||
size = source_stat.st_size & 0xFFFFFFFF
|
mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
|
||||||
# "<LL" stands for 2 unsigned longs, little-endian.
|
size = source_stat.st_size & 0xFFFFFFFF
|
||||||
fp.write(struct.pack("<LL", mtime, size))
|
# "<LL" stands for 2 unsigned longs, little-endian.
|
||||||
|
fp.write(struct.pack("<LL", mtime, size))
|
||||||
|
elif invalidation_mode == "checked-hash":
|
||||||
|
flags = b"\x03\x00\x00\x00"
|
||||||
|
fp.write(flags)
|
||||||
|
# 64-bit source file hash
|
||||||
|
source_hash = source_hash[:8]
|
||||||
|
fp.write(source_hash)
|
||||||
fp.write(marshal.dumps(co))
|
fp.write(marshal.dumps(co))
|
||||||
|
|
||||||
|
|
||||||
|
@ -320,12 +334,13 @@ def _write_pyc(
|
||||||
state: "AssertionState",
|
state: "AssertionState",
|
||||||
co: types.CodeType,
|
co: types.CodeType,
|
||||||
source_stat: os.stat_result,
|
source_stat: os.stat_result,
|
||||||
|
source_hash: bytes,
|
||||||
pyc: Path,
|
pyc: Path,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
proc_pyc = f"{pyc}.{os.getpid()}"
|
proc_pyc = f"{pyc}.{os.getpid()}"
|
||||||
try:
|
try:
|
||||||
with open(proc_pyc, "wb") as fp:
|
with open(proc_pyc, "wb") as fp:
|
||||||
_write_pyc_fp(fp, source_stat, co)
|
_write_pyc_fp(fp, source_stat, source_hash, co, state.invalidation_mode)
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}")
|
state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}")
|
||||||
return False
|
return False
|
||||||
|
@ -341,15 +356,18 @@ def _write_pyc(
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
|
def _rewrite_test(
|
||||||
|
fn: Path, config: Config
|
||||||
|
) -> Tuple[os.stat_result, bytes, types.CodeType]:
|
||||||
"""Read and rewrite *fn* and return the code object."""
|
"""Read and rewrite *fn* and return the code object."""
|
||||||
stat = os.stat(fn)
|
stat = os.stat(fn)
|
||||||
source = fn.read_bytes()
|
source = fn.read_bytes()
|
||||||
|
source_hash = importlib.util.source_hash(source)
|
||||||
strfn = str(fn)
|
strfn = str(fn)
|
||||||
tree = ast.parse(source, filename=strfn)
|
tree = ast.parse(source, filename=strfn)
|
||||||
rewrite_asserts(tree, source, strfn, config)
|
rewrite_asserts(tree, source, strfn, config)
|
||||||
co = compile(tree, strfn, "exec", dont_inherit=True)
|
co = compile(tree, strfn, "exec", dont_inherit=True)
|
||||||
return stat, co
|
return stat, source_hash, co
|
||||||
|
|
||||||
|
|
||||||
def _read_pyc(
|
def _read_pyc(
|
||||||
|
@ -379,17 +397,29 @@ def _read_pyc(
|
||||||
if data[:4] != importlib.util.MAGIC_NUMBER:
|
if data[:4] != importlib.util.MAGIC_NUMBER:
|
||||||
trace("_read_pyc(%s): invalid pyc (bad magic number)" % source)
|
trace("_read_pyc(%s): invalid pyc (bad magic number)" % source)
|
||||||
return None
|
return None
|
||||||
if data[4:8] != b"\x00\x00\x00\x00":
|
|
||||||
|
hash_based = getattr(_imp, "check_hash_based_pycs", "default") == "always"
|
||||||
|
if data[4:8] == b"\x00\x00\x00\x00" and not hash_based:
|
||||||
|
trace("_read_pyc(%s): timestamp based" % source)
|
||||||
|
mtime_data = data[8:12]
|
||||||
|
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
|
||||||
|
trace("_read_pyc(%s): out of date" % source)
|
||||||
|
return None
|
||||||
|
size_data = data[12:16]
|
||||||
|
if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
|
||||||
|
trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
|
||||||
|
return None
|
||||||
|
elif data[4:8] == b"\x03\x00\x00\x00":
|
||||||
|
trace("_read_pyc(%s): hash based" % source)
|
||||||
|
hash = data[8:16]
|
||||||
|
source_hash = importlib.util.source_hash(source.read_bytes())
|
||||||
|
if source_hash[:8] != hash:
|
||||||
|
trace("_read_pyc(%s): hash doesn't match" % source)
|
||||||
|
return None
|
||||||
|
else:
|
||||||
trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
|
trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
|
||||||
return None
|
return None
|
||||||
mtime_data = data[8:12]
|
|
||||||
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
|
|
||||||
trace("_read_pyc(%s): out of date" % source)
|
|
||||||
return None
|
|
||||||
size_data = data[12:16]
|
|
||||||
if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
|
|
||||||
trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
|
|
||||||
return None
|
|
||||||
try:
|
try:
|
||||||
co = marshal.load(fp)
|
co = marshal.load(fp)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -223,6 +223,14 @@ def pytest_addoption(parser: Parser) -> None:
|
||||||
help="Prepend/append to sys.path when importing test modules and conftest "
|
help="Prepend/append to sys.path when importing test modules and conftest "
|
||||||
"files. Default: prepend.",
|
"files. Default: prepend.",
|
||||||
)
|
)
|
||||||
|
group.addoption(
|
||||||
|
"--invalidation-mode",
|
||||||
|
default="timestamp",
|
||||||
|
choices=["timestamp", "checked-hash"],
|
||||||
|
dest="invalidationmode",
|
||||||
|
help="Pytest pyc cache invalidation mode. Default: timestamp.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.addini(
|
parser.addini(
|
||||||
"consider_namespace_packages",
|
"consider_namespace_packages",
|
||||||
type="bool",
|
type="bool",
|
||||||
|
|
|
@ -4,6 +4,7 @@ import errno
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import glob
|
import glob
|
||||||
import importlib
|
import importlib
|
||||||
|
from importlib.util import source_hash
|
||||||
import marshal
|
import marshal
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -21,6 +22,8 @@ from typing import Set
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
import zipfile
|
import zipfile
|
||||||
|
|
||||||
|
import _imp
|
||||||
|
|
||||||
import _pytest._code
|
import _pytest._code
|
||||||
from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
|
from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
|
||||||
from _pytest.assertion import util
|
from _pytest.assertion import util
|
||||||
|
@ -1043,12 +1046,14 @@ class TestAssertionRewriteHookDetails:
|
||||||
state = AssertionState(config, "rewrite")
|
state = AssertionState(config, "rewrite")
|
||||||
tmp_path.joinpath("source.py").touch()
|
tmp_path.joinpath("source.py").touch()
|
||||||
source_path = str(tmp_path)
|
source_path = str(tmp_path)
|
||||||
|
source_bytes = tmp_path.joinpath("source.py").read_bytes()
|
||||||
pycpath = tmp_path.joinpath("pyc")
|
pycpath = tmp_path.joinpath("pyc")
|
||||||
co = compile("1", "f.py", "single")
|
co = compile("1", "f.py", "single")
|
||||||
assert _write_pyc(state, co, os.stat(source_path), pycpath)
|
hash = source_hash(source_bytes)
|
||||||
|
assert _write_pyc(state, co, os.stat(source_path), hash, pycpath)
|
||||||
|
|
||||||
with mock.patch.object(os, "replace", side_effect=OSError):
|
with mock.patch.object(os, "replace", side_effect=OSError):
|
||||||
assert not _write_pyc(state, co, os.stat(source_path), pycpath)
|
assert not _write_pyc(state, co, os.stat(source_path), hash, pycpath)
|
||||||
|
|
||||||
def test_resources_provider_for_loader(self, pytester: Pytester) -> None:
|
def test_resources_provider_for_loader(self, pytester: Pytester) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -1121,10 +1126,41 @@ class TestAssertionRewriteHookDetails:
|
||||||
|
|
||||||
fn.write_text("def test(): assert True", encoding="utf-8")
|
fn.write_text("def test(): assert True", encoding="utf-8")
|
||||||
|
|
||||||
source_stat, co = _rewrite_test(fn, config)
|
source_stat, hash, co = _rewrite_test(fn, config)
|
||||||
_write_pyc(state, co, source_stat, pyc)
|
_write_pyc(state, co, source_stat, hash, pyc)
|
||||||
assert _read_pyc(fn, pyc, state.trace) is not None
|
assert _read_pyc(fn, pyc, state.trace) is not None
|
||||||
|
|
||||||
|
pyc_bytes = pyc.read_bytes()
|
||||||
|
assert pyc_bytes[4] == 0 # timestamp flag set
|
||||||
|
|
||||||
|
def test_read_pyc_success_hash(self, tmp_path: Path, pytester: Pytester) -> None:
|
||||||
|
from _pytest.assertion import AssertionState
|
||||||
|
from _pytest.assertion.rewrite import _read_pyc
|
||||||
|
from _pytest.assertion.rewrite import _rewrite_test
|
||||||
|
from _pytest.assertion.rewrite import _write_pyc
|
||||||
|
|
||||||
|
config = pytester.parseconfig("--invalidation-mode=checked-hash")
|
||||||
|
state = AssertionState(config, "rewrite")
|
||||||
|
|
||||||
|
fn = tmp_path / "source.py"
|
||||||
|
pyc = Path(str(fn) + "c")
|
||||||
|
|
||||||
|
# Sanity check: the private ``_imp.check_hash_based_pycs`` attribute still exists and holds one of its known values.
|
||||||
|
assert getattr(_imp, "check_hash_based_pycs", None) in {
|
||||||
|
"default",
|
||||||
|
"always",
|
||||||
|
"never",
|
||||||
|
}
|
||||||
|
|
||||||
|
fn.write_text("def test(): assert True", encoding="utf-8")
|
||||||
|
source_stat, hash, co = _rewrite_test(fn, config)
|
||||||
|
_write_pyc(state, co, source_stat, hash, pyc)
|
||||||
|
assert _read_pyc(fn, pyc, state.trace) is not None
|
||||||
|
|
||||||
|
pyc_bytes = pyc.read_bytes()
|
||||||
|
assert pyc_bytes[4] == 3 # checked-hash flag set
|
||||||
|
assert pyc_bytes[8:16] == hash
|
||||||
|
|
||||||
def test_read_pyc_more_invalid(self, tmp_path: Path) -> None:
|
def test_read_pyc_more_invalid(self, tmp_path: Path) -> None:
|
||||||
from _pytest.assertion.rewrite import _read_pyc
|
from _pytest.assertion.rewrite import _read_pyc
|
||||||
|
|
||||||
|
@ -1170,6 +1206,50 @@ class TestAssertionRewriteHookDetails:
|
||||||
pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + code)
|
pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + code)
|
||||||
assert _read_pyc(source, pyc, print) is None
|
assert _read_pyc(source, pyc, print) is None
|
||||||
|
|
||||||
|
def test_read_pyc_more_invalid_hash(self, tmp_path: Path) -> None:
|
||||||
|
from _pytest.assertion.rewrite import _read_pyc
|
||||||
|
|
||||||
|
source = tmp_path / "source.py"
|
||||||
|
pyc = tmp_path / "source.pyc"
|
||||||
|
|
||||||
|
source_bytes = b"def test(): pass\n"
|
||||||
|
source.write_bytes(source_bytes)
|
||||||
|
|
||||||
|
magic = importlib.util.MAGIC_NUMBER
|
||||||
|
|
||||||
|
flags = b"\x00\x00\x00\x00"
|
||||||
|
flags_hash = b"\x03\x00\x00\x00"
|
||||||
|
|
||||||
|
mtime = b"\x58\x3c\xb0\x5f"
|
||||||
|
mtime_int = int.from_bytes(mtime, "little")
|
||||||
|
os.utime(source, (mtime_int, mtime_int))
|
||||||
|
|
||||||
|
size = len(source_bytes).to_bytes(4, "little")
|
||||||
|
|
||||||
|
hash = source_hash(source_bytes)
|
||||||
|
hash = hash[:8]
|
||||||
|
|
||||||
|
code = marshal.dumps(compile(source_bytes, str(source), "exec"))
|
||||||
|
|
||||||
|
# check_hash_based_pycs == "default" with hash based pyc file.
|
||||||
|
pyc.write_bytes(magic + flags_hash + hash + code)
|
||||||
|
assert _read_pyc(source, pyc, print) is not None
|
||||||
|
|
||||||
|
# check_hash_based_pycs == "always" with hash based pyc file.
|
||||||
|
with mock.patch.object(_imp, "check_hash_based_pycs", "always"):
|
||||||
|
pyc.write_bytes(magic + flags_hash + hash + code)
|
||||||
|
assert _read_pyc(source, pyc, print) is not None
|
||||||
|
|
||||||
|
# Bad hash.
|
||||||
|
with mock.patch.object(_imp, "check_hash_based_pycs", "always"):
|
||||||
|
pyc.write_bytes(magic + flags_hash + b"\x00" * 8 + code)
|
||||||
|
assert _read_pyc(source, pyc, print) is None
|
||||||
|
|
||||||
|
# check_hash_based_pycs == "always" with timestamp based pyc file.
|
||||||
|
with mock.patch.object(_imp, "check_hash_based_pycs", "always"):
|
||||||
|
pyc.write_bytes(magic + flags + mtime + size + code)
|
||||||
|
assert _read_pyc(source, pyc, print) is None
|
||||||
|
|
||||||
def test_reload_is_same_and_reloads(self, pytester: Pytester) -> None:
|
def test_reload_is_same_and_reloads(self, pytester: Pytester) -> None:
|
||||||
"""Reloading a (collected) module after change picks up the change."""
|
"""Reloading a (collected) module after change picks up the change."""
|
||||||
pytester.makeini(
|
pytester.makeini(
|
||||||
|
|
Loading…
Reference in New Issue