Add invalidation-mode option

This commit is contained in:
Marc Mueller 2023-09-10 23:48:15 +02:00
parent ac98ff571b
commit 2f12594279
4 changed files with 136 additions and 42 deletions

View File

@ -94,6 +94,7 @@ class AssertionState:
def __init__(self, config: Config, mode) -> None:
self.mode = mode
self.trace = config.trace.root.get("assertion")
self.invalidation_mode = config.known_args_namespace.invalidationmode
self.hook: Optional[rewrite.AssertionRewritingHook] = None

View File

@ -23,6 +23,7 @@ from typing import IO
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Literal
from typing import Optional
from typing import Sequence
from typing import Set
@ -30,6 +31,8 @@ from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
import _imp
from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
from _pytest._io.saferepr import saferepr
from _pytest._version import version
@ -299,23 +302,31 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
def _write_pyc_fp(
fp: IO[bytes], source_stat: os.stat_result, source_hash: bytes, co: types.CodeType
fp: IO[bytes],
source_stat: os.stat_result,
source_hash: bytes,
co: types.CodeType,
invalidation_mode: Literal["timestamp", "checked-hash"],
) -> None:
# Technically, we don't have to have the same pyc format as
# (C)Python, since these "pycs" should never be seen by builtin
# import. However, there's little reason to deviate.
fp.write(importlib.util.MAGIC_NUMBER)
# https://www.python.org/dev/peps/pep-0552/
flags = b"\x00\x00\x00\x00"
fp.write(flags)
# as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
size = source_stat.st_size & 0xFFFFFFFF
# 64-bit source file hash
source_hash = source_hash[:8]
# "<LL" stands for 2 unsigned longs, little-endian.
fp.write(struct.pack("<LL", mtime, size))
fp.write(source_hash)
if invalidation_mode == "timestamp":
flags = b"\x00\x00\x00\x00"
fp.write(flags)
# as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
size = source_stat.st_size & 0xFFFFFFFF
# "<LL" stands for 2 unsigned longs, little-endian.
fp.write(struct.pack("<LL", mtime, size))
elif invalidation_mode == "checked-hash":
flags = b"\x03\x00\x00\x00"
fp.write(flags)
# 64-bit source file hash
source_hash = source_hash[:8]
fp.write(source_hash)
fp.write(marshal.dumps(co))
@ -329,7 +340,7 @@ def _write_pyc(
proc_pyc = f"{pyc}.{os.getpid()}"
try:
with open(proc_pyc, "wb") as fp:
_write_pyc_fp(fp, source_stat, source_hash, co)
_write_pyc_fp(fp, source_stat, source_hash, co, state.invalidation_mode)
except OSError as e:
state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}")
return False
@ -375,34 +386,40 @@ def _read_pyc(
stat_result = os.stat(source)
mtime = int(stat_result.st_mtime)
size = stat_result.st_size
data = fp.read(24)
data = fp.read(16)
except OSError as e:
trace(f"_read_pyc({source}): OSError {e}")
return None
# Check for invalid or out of date pyc file.
if len(data) != (24):
if len(data) != (16):
trace("_read_pyc(%s): invalid pyc (too short)" % source)
return None
if data[:4] != importlib.util.MAGIC_NUMBER:
trace("_read_pyc(%s): invalid pyc (bad magic number)" % source)
return None
if data[4:8] != b"\x00\x00\x00\x00":
trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
return None
size_data = data[12:16]
if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
return None
mtime_data = data[8:12]
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
trace("_read_pyc(%s): out of date" % source)
hash = data[16:24]
hash_based = getattr(_imp, "check_hash_based_pycs", "default") == "always"
if data[4:8] == b"\x00\x00\x00\x00" and not hash_based:
trace("_read_pyc(%s): timestamp based" % source)
mtime_data = data[8:12]
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
trace("_read_pyc(%s): out of date" % source)
return None
size_data = data[12:16]
if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
return None
elif data[4:8] == b"\x03\x00\x00\x00":
trace("_read_pyc(%s): hash based" % source)
hash = data[8:16]
source_hash = importlib.util.source_hash(source.read_bytes())
if source_hash[:8] == hash:
trace("_read_pyc(%s): source hash match (no change detected)" % source)
else:
if source_hash[:8] != hash:
trace("_read_pyc(%s): hash doesn't match" % source)
return None
else:
trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
return None
try:
co = marshal.load(fp)
except Exception as e:

View File

@ -223,6 +223,14 @@ def pytest_addoption(parser: Parser) -> None:
help="Prepend/append to sys.path when importing test modules and conftest "
"files. Default: prepend.",
)
group.addoption(
"--invalidation-mode",
default="timestamp",
choices=["timestamp", "checked-hash"],
dest="invalidationmode",
help="Pytest pyc cache invalidation mode. Default: timestamp.",
)
parser.addini(
"consider_namespace_packages",
type="bool",

View File

@ -22,6 +22,8 @@ from typing import Set
from unittest import mock
import zipfile
import _imp
import _pytest._code
from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
from _pytest.assertion import util
@ -1128,13 +1130,37 @@ class TestAssertionRewriteHookDetails:
_write_pyc(state, co, source_stat, hash, pyc)
assert _read_pyc(fn, pyc, state.trace) is not None
# pyc read should still work if only the mtime changed
# Fallback to hash comparison
new_mtime = source_stat.st_mtime + 1.2
os.utime(fn, (new_mtime, new_mtime))
assert source_stat.st_mtime != os.stat(fn).st_mtime
pyc_bytes = pyc.read_bytes()
assert pyc_bytes[4] == 0 # timestamp flag set
def test_read_pyc_success_hash(self, tmp_path: Path, pytester: Pytester) -> None:
from _pytest.assertion import AssertionState
from _pytest.assertion.rewrite import _read_pyc
from _pytest.assertion.rewrite import _rewrite_test
from _pytest.assertion.rewrite import _write_pyc
config = pytester.parseconfig("--invalidation-mode=checked-hash")
state = AssertionState(config, "rewrite")
fn = tmp_path / "source.py"
pyc = Path(str(fn) + "c")
# Test private attribute didn't change
assert getattr(_imp, "check_hash_based_pycs", None) in {
"default",
"always",
"never",
}
fn.write_text("def test(): assert True", encoding="utf-8")
source_stat, hash, co = _rewrite_test(fn, config)
_write_pyc(state, co, source_stat, hash, pyc)
assert _read_pyc(fn, pyc, state.trace) is not None
pyc_bytes = pyc.read_bytes()
assert pyc_bytes[4] == 3 # checked-hash flag set
assert pyc_bytes[8:16] == hash
def test_read_pyc_more_invalid(self, tmp_path: Path) -> None:
from _pytest.assertion.rewrite import _read_pyc
@ -1153,13 +1179,11 @@ class TestAssertionRewriteHookDetails:
os.utime(source, (mtime_int, mtime_int))
size = len(source_bytes).to_bytes(4, "little")
hash = source_hash(source_bytes)
hash = hash[:8]
code = marshal.dumps(compile(source_bytes, str(source), "exec"))
# Good header.
pyc.write_bytes(magic + flags + mtime + size + hash + code)
pyc.write_bytes(magic + flags + mtime + size + code)
assert _read_pyc(source, pyc, print) is not None
# Too short.
@ -1167,20 +1191,64 @@ class TestAssertionRewriteHookDetails:
assert _read_pyc(source, pyc, print) is None
# Bad magic.
pyc.write_bytes(b"\x12\x34\x56\x78" + flags + mtime + size + hash + code)
pyc.write_bytes(b"\x12\x34\x56\x78" + flags + mtime + size + code)
assert _read_pyc(source, pyc, print) is None
# Unsupported flags.
pyc.write_bytes(magic + b"\x00\xff\x00\x00" + mtime + size + hash + code)
pyc.write_bytes(magic + b"\x00\xff\x00\x00" + mtime + size + code)
assert _read_pyc(source, pyc, print) is None
# Bad mtime.
pyc.write_bytes(magic + flags + b"\x58\x3d\xb0\x5f" + size + code)
assert _read_pyc(source, pyc, print) is None
# Bad size.
pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + hash + code)
pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + code)
assert _read_pyc(source, pyc, print) is None
# Bad mtime + bad hash.
pyc.write_bytes(magic + flags + b"\x58\x3d\xb0\x5f" + size + b"\x00" * 8 + code)
assert _read_pyc(source, pyc, print) is None
def test_read_pyc_more_invalid_hash(self, tmp_path: Path) -> None:
from _pytest.assertion.rewrite import _read_pyc
source = tmp_path / "source.py"
pyc = tmp_path / "source.pyc"
source_bytes = b"def test(): pass\n"
source.write_bytes(source_bytes)
magic = importlib.util.MAGIC_NUMBER
flags = b"\x00\x00\x00\x00"
flags_hash = b"\x03\x00\x00\x00"
mtime = b"\x58\x3c\xb0\x5f"
mtime_int = int.from_bytes(mtime, "little")
os.utime(source, (mtime_int, mtime_int))
size = len(source_bytes).to_bytes(4, "little")
hash = source_hash(source_bytes)
hash = hash[:8]
code = marshal.dumps(compile(source_bytes, str(source), "exec"))
# check_hash_based_pycs == "default" with hash based pyc file.
pyc.write_bytes(magic + flags_hash + hash + code)
assert _read_pyc(source, pyc, print) is not None
# check_hash_based_pycs == "always" with hash based pyc file.
with mock.patch.object(_imp, "check_hash_based_pycs", "always"):
pyc.write_bytes(magic + flags_hash + hash + code)
assert _read_pyc(source, pyc, print) is not None
# Bad hash.
with mock.patch.object(_imp, "check_hash_based_pycs", "always"):
pyc.write_bytes(magic + flags_hash + b"\x00" * 8 + code)
assert _read_pyc(source, pyc, print) is None
# check_hash_based_pycs == "always" with timestamp based pyc file.
with mock.patch.object(_imp, "check_hash_based_pycs", "always"):
pyc.write_bytes(magic + flags + mtime + size + code)
assert _read_pyc(source, pyc, print) is None
def test_reload_is_same_and_reloads(self, pytester: Pytester) -> None:
"""Reloading a (collected) module after change picks up the change."""