This commit is contained in:
Marc Mueller 2024-04-29 19:18:24 +02:00 committed by GitHub
commit 40109d7cd1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 146 additions and 26 deletions

View File

@ -0,0 +1 @@
Added hash comparison for pyc cache files.

View File

@ -94,6 +94,7 @@ class AssertionState:
def __init__(self, config: Config, mode) -> None:
self.mode = mode
self.trace = config.trace.root.get("assertion")
self.invalidation_mode = config.known_args_namespace.invalidationmode
self.hook: Optional[rewrite.AssertionRewritingHook] = None

View File

@ -23,6 +23,7 @@ from typing import IO
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Literal
from typing import Optional
from typing import Sequence
from typing import Set
@ -30,6 +31,8 @@ from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
import _imp
from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
from _pytest._io.saferepr import saferepr
from _pytest._version import version
@ -166,11 +169,11 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
co = _read_pyc(fn, pyc, state.trace)
if co is None:
state.trace(f"rewriting {fn!r}")
source_stat, co = _rewrite_test(fn, self.config)
source_stat, source_hash, co = _rewrite_test(fn, self.config)
if write:
self._writing_pyc = True
try:
_write_pyc(state, co, source_stat, pyc)
_write_pyc(state, co, source_stat, source_hash, pyc)
finally:
self._writing_pyc = False
else:
@ -299,13 +302,18 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
def _write_pyc_fp(
fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType
fp: IO[bytes],
source_stat: os.stat_result,
source_hash: bytes,
co: types.CodeType,
invalidation_mode: Literal["timestamp", "checked-hash"],
) -> None:
# Technically, we don't have to have the same pyc format as
# (C)Python, since these "pycs" should never be seen by builtin
# import. However, there's little reason to deviate.
fp.write(importlib.util.MAGIC_NUMBER)
# https://www.python.org/dev/peps/pep-0552/
if invalidation_mode == "timestamp":
flags = b"\x00\x00\x00\x00"
fp.write(flags)
# as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
@ -313,6 +321,12 @@ def _write_pyc_fp(
size = source_stat.st_size & 0xFFFFFFFF
# "<LL" stands for 2 unsigned longs, little-endian.
fp.write(struct.pack("<LL", mtime, size))
elif invalidation_mode == "checked-hash":
flags = b"\x03\x00\x00\x00"
fp.write(flags)
# 64-bit source file hash
source_hash = source_hash[:8]
fp.write(source_hash)
fp.write(marshal.dumps(co))
@ -320,12 +334,13 @@ def _write_pyc(
state: "AssertionState",
co: types.CodeType,
source_stat: os.stat_result,
source_hash: bytes,
pyc: Path,
) -> bool:
proc_pyc = f"{pyc}.{os.getpid()}"
try:
with open(proc_pyc, "wb") as fp:
_write_pyc_fp(fp, source_stat, co)
_write_pyc_fp(fp, source_stat, source_hash, co, state.invalidation_mode)
except OSError as e:
state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}")
return False
@ -341,15 +356,18 @@ def _write_pyc(
return True
def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
def _rewrite_test(
fn: Path, config: Config
) -> Tuple[os.stat_result, bytes, types.CodeType]:
"""Read and rewrite *fn* and return the code object."""
stat = os.stat(fn)
source = fn.read_bytes()
source_hash = importlib.util.source_hash(source)
strfn = str(fn)
tree = ast.parse(source, filename=strfn)
rewrite_asserts(tree, source, strfn, config)
co = compile(tree, strfn, "exec", dont_inherit=True)
return stat, co
return stat, source_hash, co
def _read_pyc(
@ -379,9 +397,10 @@ def _read_pyc(
if data[:4] != importlib.util.MAGIC_NUMBER:
trace("_read_pyc(%s): invalid pyc (bad magic number)" % source)
return None
if data[4:8] != b"\x00\x00\x00\x00":
trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
return None
hash_based = getattr(_imp, "check_hash_based_pycs", "default") == "always"
if data[4:8] == b"\x00\x00\x00\x00" and not hash_based:
trace("_read_pyc(%s): timestamp based" % source)
mtime_data = data[8:12]
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
trace("_read_pyc(%s): out of date" % source)
@ -390,6 +409,17 @@ def _read_pyc(
if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
return None
elif data[4:8] == b"\x03\x00\x00\x00":
trace("_read_pyc(%s): hash based" % source)
hash = data[8:16]
source_hash = importlib.util.source_hash(source.read_bytes())
if source_hash[:8] != hash:
trace("_read_pyc(%s): hash doesn't match" % source)
return None
else:
trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
return None
try:
co = marshal.load(fp)
except Exception as e:

View File

@ -223,6 +223,14 @@ def pytest_addoption(parser: Parser) -> None:
help="Prepend/append to sys.path when importing test modules and conftest "
"files. Default: prepend.",
)
group.addoption(
"--invalidation-mode",
default="timestamp",
choices=["timestamp", "checked-hash"],
dest="invalidationmode",
help="Pytest pyc cache invalidation mode. Default: timestamp.",
)
parser.addini(
"consider_namespace_packages",
type="bool",

View File

@ -4,6 +4,7 @@ import errno
from functools import partial
import glob
import importlib
from importlib.util import source_hash
import marshal
import os
from pathlib import Path
@ -21,6 +22,8 @@ from typing import Set
from unittest import mock
import zipfile
import _imp
import _pytest._code
from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
from _pytest.assertion import util
@ -1043,12 +1046,14 @@ class TestAssertionRewriteHookDetails:
state = AssertionState(config, "rewrite")
tmp_path.joinpath("source.py").touch()
source_path = str(tmp_path)
source_bytes = tmp_path.joinpath("source.py").read_bytes()
pycpath = tmp_path.joinpath("pyc")
co = compile("1", "f.py", "single")
assert _write_pyc(state, co, os.stat(source_path), pycpath)
hash = source_hash(source_bytes)
assert _write_pyc(state, co, os.stat(source_path), hash, pycpath)
with mock.patch.object(os, "replace", side_effect=OSError):
assert not _write_pyc(state, co, os.stat(source_path), pycpath)
assert not _write_pyc(state, co, os.stat(source_path), hash, pycpath)
def test_resources_provider_for_loader(self, pytester: Pytester) -> None:
"""
@ -1121,10 +1126,41 @@ class TestAssertionRewriteHookDetails:
fn.write_text("def test(): assert True", encoding="utf-8")
source_stat, co = _rewrite_test(fn, config)
_write_pyc(state, co, source_stat, pyc)
source_stat, hash, co = _rewrite_test(fn, config)
_write_pyc(state, co, source_stat, hash, pyc)
assert _read_pyc(fn, pyc, state.trace) is not None
pyc_bytes = pyc.read_bytes()
assert pyc_bytes[4] == 0 # timestamp flag set
def test_read_pyc_success_hash(self, tmp_path: Path, pytester: Pytester) -> None:
from _pytest.assertion import AssertionState
from _pytest.assertion.rewrite import _read_pyc
from _pytest.assertion.rewrite import _rewrite_test
from _pytest.assertion.rewrite import _write_pyc
config = pytester.parseconfig("--invalidation-mode=checked-hash")
state = AssertionState(config, "rewrite")
fn = tmp_path / "source.py"
pyc = Path(str(fn) + "c")
# Test private attribute didn't change
assert getattr(_imp, "check_hash_based_pycs", None) in {
"default",
"always",
"never",
}
fn.write_text("def test(): assert True", encoding="utf-8")
source_stat, hash, co = _rewrite_test(fn, config)
_write_pyc(state, co, source_stat, hash, pyc)
assert _read_pyc(fn, pyc, state.trace) is not None
pyc_bytes = pyc.read_bytes()
assert pyc_bytes[4] == 3 # checked-hash flag set
assert pyc_bytes[8:16] == hash
def test_read_pyc_more_invalid(self, tmp_path: Path) -> None:
from _pytest.assertion.rewrite import _read_pyc
@ -1170,6 +1206,50 @@ class TestAssertionRewriteHookDetails:
pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + code)
assert _read_pyc(source, pyc, print) is None
def test_read_pyc_more_invalid_hash(self, tmp_path: Path) -> None:
from _pytest.assertion.rewrite import _read_pyc
source = tmp_path / "source.py"
pyc = tmp_path / "source.pyc"
source_bytes = b"def test(): pass\n"
source.write_bytes(source_bytes)
magic = importlib.util.MAGIC_NUMBER
flags = b"\x00\x00\x00\x00"
flags_hash = b"\x03\x00\x00\x00"
mtime = b"\x58\x3c\xb0\x5f"
mtime_int = int.from_bytes(mtime, "little")
os.utime(source, (mtime_int, mtime_int))
size = len(source_bytes).to_bytes(4, "little")
hash = source_hash(source_bytes)
hash = hash[:8]
code = marshal.dumps(compile(source_bytes, str(source), "exec"))
# check_hash_based_pycs == "default" with hash based pyc file.
pyc.write_bytes(magic + flags_hash + hash + code)
assert _read_pyc(source, pyc, print) is not None
# check_hash_based_pycs == "always" with hash based pyc file.
with mock.patch.object(_imp, "check_hash_based_pycs", "always"):
pyc.write_bytes(magic + flags_hash + hash + code)
assert _read_pyc(source, pyc, print) is not None
# Bad hash.
with mock.patch.object(_imp, "check_hash_based_pycs", "always"):
pyc.write_bytes(magic + flags_hash + b"\x00" * 8 + code)
assert _read_pyc(source, pyc, print) is None
# check_hash_based_pycs == "always" with timestamp based pyc file.
with mock.patch.object(_imp, "check_hash_based_pycs", "always"):
pyc.write_bytes(magic + flags + mtime + size + code)
assert _read_pyc(source, pyc, print) is None
def test_reload_is_same_and_reloads(self, pytester: Pytester) -> None:
"""Reloading a (collected) module after change picks up the change."""
pytester.makeini(