Add invalidation-mode option

This commit is contained in:
Marc Mueller 2023-09-10 23:48:15 +02:00
parent ac98ff571b
commit 2f12594279
4 changed files with 136 additions and 42 deletions

View File

@ -94,6 +94,7 @@ class AssertionState:
def __init__(self, config: Config, mode) -> None: def __init__(self, config: Config, mode) -> None:
self.mode = mode self.mode = mode
self.trace = config.trace.root.get("assertion") self.trace = config.trace.root.get("assertion")
self.invalidation_mode = config.known_args_namespace.invalidationmode
self.hook: Optional[rewrite.AssertionRewritingHook] = None self.hook: Optional[rewrite.AssertionRewritingHook] = None

View File

@ -23,6 +23,7 @@ from typing import IO
from typing import Iterable from typing import Iterable
from typing import Iterator from typing import Iterator
from typing import List from typing import List
from typing import Literal
from typing import Optional from typing import Optional
from typing import Sequence from typing import Sequence
from typing import Set from typing import Set
@ -30,6 +31,8 @@ from typing import Tuple
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Union from typing import Union
import _imp
from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
from _pytest._io.saferepr import saferepr from _pytest._io.saferepr import saferepr
from _pytest._version import version from _pytest._version import version
@ -299,22 +302,30 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
def _write_pyc_fp(
    fp: IO[bytes],
    source_stat: os.stat_result,
    source_hash: bytes,
    co: types.CodeType,
    invalidation_mode: Literal["timestamp", "checked-hash"],
) -> None:
    """Write a PEP 552 style pyc (16-byte header + marshalled code) to ``fp``.

    :param fp: Binary file object the pyc is written to.
    :param source_stat: Stat result of the source file; its mtime and size
        are stored in the header in ``"timestamp"`` mode.
    :param source_hash: Hash of the source file; its first 8 bytes are
        stored in the header in ``"checked-hash"`` mode.
    :param co: Code object to serialize after the header.
    :param invalidation_mode: Which validation data the header carries.
    :raises ValueError: If ``invalidation_mode`` is not a supported mode.
        The mode is validated before the first write, so ``fp`` is left
        untouched in that case.
    """
    # Technically, we don't have to have the same pyc format as
    # (C)Python, since these "pycs" should never be seen by builtin
    # import. However, there's little reason to deviate.
    # https://www.python.org/dev/peps/pep-0552/
    if invalidation_mode == "timestamp":
        flags = b"\x00\x00\x00\x00"
        # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
        mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
        size = source_stat.st_size & 0xFFFFFFFF
        # "<LL" stands for 2 unsigned longs, little-endian.
        validation = struct.pack("<LL", mtime, size)
    elif invalidation_mode == "checked-hash":
        # Flag value 3 marks the pyc as hash-based with the check-source bit set.
        flags = b"\x03\x00\x00\x00"
        # 64-bit source file hash
        validation = source_hash[:8]
    else:
        # Without this branch an unknown mode would emit the magic number
        # followed directly by the code, i.e. a pyc without a usable header.
        raise ValueError(f"unsupported invalidation mode: {invalidation_mode!r}")

    fp.write(importlib.util.MAGIC_NUMBER)
    fp.write(flags)
    fp.write(validation)
    fp.write(marshal.dumps(co))
@ -329,7 +340,7 @@ def _write_pyc(
proc_pyc = f"{pyc}.{os.getpid()}" proc_pyc = f"{pyc}.{os.getpid()}"
try: try:
with open(proc_pyc, "wb") as fp: with open(proc_pyc, "wb") as fp:
_write_pyc_fp(fp, source_stat, source_hash, co) _write_pyc_fp(fp, source_stat, source_hash, co, state.invalidation_mode)
except OSError as e: except OSError as e:
state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}") state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}")
return False return False
@ -375,34 +386,40 @@ def _read_pyc(
stat_result = os.stat(source) stat_result = os.stat(source)
mtime = int(stat_result.st_mtime) mtime = int(stat_result.st_mtime)
size = stat_result.st_size size = stat_result.st_size
data = fp.read(24) data = fp.read(16)
except OSError as e: except OSError as e:
trace(f"_read_pyc({source}): OSError {e}") trace(f"_read_pyc({source}): OSError {e}")
return None return None
# Check for invalid or out of date pyc file. # Check for invalid or out of date pyc file.
if len(data) != (24): if len(data) != (16):
trace("_read_pyc(%s): invalid pyc (too short)" % source) trace("_read_pyc(%s): invalid pyc (too short)" % source)
return None return None
if data[:4] != importlib.util.MAGIC_NUMBER: if data[:4] != importlib.util.MAGIC_NUMBER:
trace("_read_pyc(%s): invalid pyc (bad magic number)" % source) trace("_read_pyc(%s): invalid pyc (bad magic number)" % source)
return None return None
if data[4:8] != b"\x00\x00\x00\x00":
trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source) hash_based = getattr(_imp, "check_hash_based_pycs", "default") == "always"
if data[4:8] == b"\x00\x00\x00\x00" and not hash_based:
trace("_read_pyc(%s): timestamp based" % source)
mtime_data = data[8:12]
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
trace("_read_pyc(%s): out of date" % source)
return None return None
size_data = data[12:16] size_data = data[12:16]
if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF: if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
trace("_read_pyc(%s): invalid pyc (incorrect size)" % source) trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
return None return None
mtime_data = data[8:12] elif data[4:8] == b"\x03\x00\x00\x00":
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF: trace("_read_pyc(%s): hash based" % source)
trace("_read_pyc(%s): out of date" % source) hash = data[8:16]
hash = data[16:24]
source_hash = importlib.util.source_hash(source.read_bytes()) source_hash = importlib.util.source_hash(source.read_bytes())
if source_hash[:8] == hash: if source_hash[:8] != hash:
trace("_read_pyc(%s): source hash match (no change detected)" % source)
else:
trace("_read_pyc(%s): hash doesn't match" % source) trace("_read_pyc(%s): hash doesn't match" % source)
return None return None
else:
trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
return None
try: try:
co = marshal.load(fp) co = marshal.load(fp)
except Exception as e: except Exception as e:

View File

@ -223,6 +223,14 @@ def pytest_addoption(parser: Parser) -> None:
help="Prepend/append to sys.path when importing test modules and conftest " help="Prepend/append to sys.path when importing test modules and conftest "
"files. Default: prepend.", "files. Default: prepend.",
) )
group.addoption(
"--invalidation-mode",
default="timestamp",
choices=["timestamp", "checked-hash"],
dest="invalidationmode",
help="Pytest pyc cache invalidation mode. Default: timestamp.",
)
parser.addini( parser.addini(
"consider_namespace_packages", "consider_namespace_packages",
type="bool", type="bool",

View File

@ -22,6 +22,8 @@ from typing import Set
from unittest import mock from unittest import mock
import zipfile import zipfile
import _imp
import _pytest._code import _pytest._code
from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
from _pytest.assertion import util from _pytest.assertion import util
@ -1128,13 +1130,37 @@ class TestAssertionRewriteHookDetails:
_write_pyc(state, co, source_stat, hash, pyc) _write_pyc(state, co, source_stat, hash, pyc)
assert _read_pyc(fn, pyc, state.trace) is not None assert _read_pyc(fn, pyc, state.trace) is not None
# pyc read should still work if only the mtime changed pyc_bytes = pyc.read_bytes()
# Fallback to hash comparison assert pyc_bytes[4] == 0 # timestamp flag set
new_mtime = source_stat.st_mtime + 1.2
def test_read_pyc_success_hash(self, tmp_path: Path, pytester: Pytester) -> None:
    """A pyc written in ``checked-hash`` mode is readable and hash-flagged."""
    from _pytest.assertion import AssertionState
    from _pytest.assertion.rewrite import _read_pyc
    from _pytest.assertion.rewrite import _rewrite_test
    from _pytest.assertion.rewrite import _write_pyc

    config = pytester.parseconfig("--invalidation-mode=checked-hash")
    state = AssertionState(config, "rewrite")
    source_path = tmp_path / "source.py"
    pyc_path = Path(f"{source_path}c")

    # Test private attribute didn't change
    known_settings = {"default", "always", "never"}
    assert getattr(_imp, "check_hash_based_pycs", None) in known_settings

    source_path.write_text("def test(): assert True", encoding="utf-8")
    source_stat, source_digest, code = _rewrite_test(source_path, config)
    _write_pyc(state, code, source_stat, source_digest, pyc_path)
    assert _read_pyc(source_path, pyc_path, state.trace) is not None

    written = pyc_path.read_bytes()
    assert written[4] == 3  # checked-hash flag set
    assert written[8:16] == source_digest
def test_read_pyc_more_invalid(self, tmp_path: Path) -> None: def test_read_pyc_more_invalid(self, tmp_path: Path) -> None:
from _pytest.assertion.rewrite import _read_pyc from _pytest.assertion.rewrite import _read_pyc
@ -1153,13 +1179,11 @@ class TestAssertionRewriteHookDetails:
os.utime(source, (mtime_int, mtime_int)) os.utime(source, (mtime_int, mtime_int))
size = len(source_bytes).to_bytes(4, "little") size = len(source_bytes).to_bytes(4, "little")
hash = source_hash(source_bytes)
hash = hash[:8]
code = marshal.dumps(compile(source_bytes, str(source), "exec")) code = marshal.dumps(compile(source_bytes, str(source), "exec"))
# Good header. # Good header.
pyc.write_bytes(magic + flags + mtime + size + hash + code) pyc.write_bytes(magic + flags + mtime + size + code)
assert _read_pyc(source, pyc, print) is not None assert _read_pyc(source, pyc, print) is not None
# Too short. # Too short.
@ -1167,19 +1191,63 @@ class TestAssertionRewriteHookDetails:
assert _read_pyc(source, pyc, print) is None assert _read_pyc(source, pyc, print) is None
# Bad magic. # Bad magic.
pyc.write_bytes(b"\x12\x34\x56\x78" + flags + mtime + size + hash + code) pyc.write_bytes(b"\x12\x34\x56\x78" + flags + mtime + size + code)
assert _read_pyc(source, pyc, print) is None assert _read_pyc(source, pyc, print) is None
# Unsupported flags. # Unsupported flags.
pyc.write_bytes(magic + b"\x00\xff\x00\x00" + mtime + size + hash + code) pyc.write_bytes(magic + b"\x00\xff\x00\x00" + mtime + size + code)
assert _read_pyc(source, pyc, print) is None
# Bad mtime.
pyc.write_bytes(magic + flags + b"\x58\x3d\xb0\x5f" + size + code)
assert _read_pyc(source, pyc, print) is None assert _read_pyc(source, pyc, print) is None
# Bad size. # Bad size.
pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + hash + code) pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + code)
assert _read_pyc(source, pyc, print) is None assert _read_pyc(source, pyc, print) is None
def test_read_pyc_more_invalid_hash(self, tmp_path: Path) -> None:
    """Exercise ``_read_pyc`` on hash-based pycs and on the "always" setting."""
    from _pytest.assertion.rewrite import _read_pyc

    source = tmp_path / "source.py"
    pyc = tmp_path / "source.pyc"
    source_bytes = b"def test(): pass\n"
    source.write_bytes(source_bytes)

    # Pin the source mtime so the timestamp header below matches it.
    mtime = b"\x58\x3c\xb0\x5f"
    mtime_int = int.from_bytes(mtime, "little")
    os.utime(source, (mtime_int, mtime_int))
    size = len(source_bytes).to_bytes(4, "little")
    digest = source_hash(source_bytes)[:8]
    code = marshal.dumps(compile(source_bytes, str(source), "exec"))

    magic = importlib.util.MAGIC_NUMBER
    timestamp_pyc = magic + b"\x00\x00\x00\x00" + mtime + size + code
    hash_pyc = magic + b"\x03\x00\x00\x00" + digest + code
    bad_hash_pyc = magic + b"\x03\x00\x00\x00" + b"\x00" * 8 + code

    # check_hash_based_pycs == "default" with hash based pyc file.
    pyc.write_bytes(hash_pyc)
    assert _read_pyc(source, pyc, print) is not None

    with mock.patch.object(_imp, "check_hash_based_pycs", "always"):
        # check_hash_based_pycs == "always" with hash based pyc file.
        pyc.write_bytes(hash_pyc)
        assert _read_pyc(source, pyc, print) is not None

        # Bad hash.
        pyc.write_bytes(bad_hash_pyc)
        assert _read_pyc(source, pyc, print) is None

        # check_hash_based_pycs == "always" with timestamp based pyc file.
        pyc.write_bytes(timestamp_pyc)
        assert _read_pyc(source, pyc, print) is None
def test_reload_is_same_and_reloads(self, pytester: Pytester) -> None: def test_reload_is_same_and_reloads(self, pytester: Pytester) -> None: