This commit is contained in:
Marc Mueller 2024-04-29 19:18:24 +02:00 committed by GitHub
commit 40109d7cd1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 146 additions and 26 deletions

View File

@ -0,0 +1 @@
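Added the ``--invalidation-mode`` command-line option (``timestamp`` or ``checked-hash``) so that pytest's rewritten ``pyc`` cache files can be validated by comparing a hash of the source file instead of its mtime and size.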

View File

@ -94,6 +94,7 @@ class AssertionState:
def __init__(self, config: Config, mode) -> None: def __init__(self, config: Config, mode) -> None:
self.mode = mode self.mode = mode
self.trace = config.trace.root.get("assertion") self.trace = config.trace.root.get("assertion")
self.invalidation_mode = config.known_args_namespace.invalidationmode
self.hook: Optional[rewrite.AssertionRewritingHook] = None self.hook: Optional[rewrite.AssertionRewritingHook] = None

View File

@ -23,6 +23,7 @@ from typing import IO
from typing import Iterable from typing import Iterable
from typing import Iterator from typing import Iterator
from typing import List from typing import List
from typing import Literal
from typing import Optional from typing import Optional
from typing import Sequence from typing import Sequence
from typing import Set from typing import Set
@ -30,6 +31,8 @@ from typing import Tuple
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import Union from typing import Union
import _imp
from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
from _pytest._io.saferepr import saferepr from _pytest._io.saferepr import saferepr
from _pytest._version import version from _pytest._version import version
@ -166,11 +169,11 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
co = _read_pyc(fn, pyc, state.trace) co = _read_pyc(fn, pyc, state.trace)
if co is None: if co is None:
state.trace(f"rewriting {fn!r}") state.trace(f"rewriting {fn!r}")
source_stat, co = _rewrite_test(fn, self.config) source_stat, source_hash, co = _rewrite_test(fn, self.config)
if write: if write:
self._writing_pyc = True self._writing_pyc = True
try: try:
_write_pyc(state, co, source_stat, pyc) _write_pyc(state, co, source_stat, source_hash, pyc)
finally: finally:
self._writing_pyc = False self._writing_pyc = False
else: else:
@ -299,20 +302,31 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
def _write_pyc_fp(
    fp: IO[bytes],
    source_stat: os.stat_result,
    source_hash: bytes,
    co: types.CodeType,
    invalidation_mode: Literal["timestamp", "checked-hash"],
) -> None:
    """Write a pyc file (magic number, 12-byte header tail, marshalled code) to *fp*.

    :param fp: Binary file object the pyc is written to.
    :param source_stat: Stat result of the source file; ``st_mtime`` and
        ``st_size`` are stored in ``"timestamp"`` mode.
    :param source_hash: Hash of the source file; its first 8 bytes are stored
        in ``"checked-hash"`` mode.
    :param co: Code object to marshal after the header.
    :param invalidation_mode: Which validation data the header carries.
    :raises ValueError: If *invalidation_mode* is not one of the two supported
        values, or if *source_hash* is shorter than 8 bytes in
        ``"checked-hash"`` mode. Nothing is written to *fp* in that case.
    """
    # Technically, we don't have to have the same pyc format as
    # (C)Python, since these "pycs" should never be seen by builtin
    # import. However, there's little reason to deviate.
    # https://www.python.org/dev/peps/pep-0552/
    #
    # The flags and validation fields are assembled before anything is
    # written, so that invalid arguments cannot leave a pyc without them.
    if invalidation_mode == "timestamp":
        flags = b"\x00\x00\x00\x00"
        # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
        mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
        size = source_stat.st_size & 0xFFFFFFFF
        # "<LL" stands for 2 unsigned longs, little-endian.
        validation = struct.pack("<LL", mtime, size)
    elif invalidation_mode == "checked-hash":
        flags = b"\x03\x00\x00\x00"
        if len(source_hash) < 8:
            raise ValueError(
                f"source_hash must be at least 8 bytes, got {len(source_hash)}"
            )
        # 64-bit source file hash
        validation = source_hash[:8]
    else:
        raise ValueError(f"unsupported invalidation mode: {invalidation_mode!r}")

    fp.write(importlib.util.MAGIC_NUMBER)
    fp.write(flags)
    fp.write(validation)
    fp.write(marshal.dumps(co))
@ -320,12 +334,13 @@ def _write_pyc(
state: "AssertionState", state: "AssertionState",
co: types.CodeType, co: types.CodeType,
source_stat: os.stat_result, source_stat: os.stat_result,
source_hash: bytes,
pyc: Path, pyc: Path,
) -> bool: ) -> bool:
proc_pyc = f"{pyc}.{os.getpid()}" proc_pyc = f"{pyc}.{os.getpid()}"
try: try:
with open(proc_pyc, "wb") as fp: with open(proc_pyc, "wb") as fp:
_write_pyc_fp(fp, source_stat, co) _write_pyc_fp(fp, source_stat, source_hash, co, state.invalidation_mode)
except OSError as e: except OSError as e:
state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}") state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}")
return False return False
@ -341,15 +356,18 @@ def _write_pyc(
return True return True
def _rewrite_test(
    fn: Path, config: Config
) -> Tuple[os.stat_result, bytes, types.CodeType]:
    """Read *fn*, rewrite its asserts and compile it.

    Returns a 3-tuple of the source file's stat result (taken before the
    file is read), the ``importlib.util.source_hash`` of its bytes, and the
    compiled code object of the rewritten module.
    """
    source_stat = os.stat(fn)
    raw_source = fn.read_bytes()
    filename = str(fn)
    module_ast = ast.parse(raw_source, filename=filename)
    rewrite_asserts(module_ast, raw_source, filename, config)
    code = compile(module_ast, filename, "exec", dont_inherit=True)
    return source_stat, importlib.util.source_hash(raw_source), code
def _read_pyc( def _read_pyc(
@ -379,17 +397,29 @@ def _read_pyc(
if data[:4] != importlib.util.MAGIC_NUMBER: if data[:4] != importlib.util.MAGIC_NUMBER:
trace("_read_pyc(%s): invalid pyc (bad magic number)" % source) trace("_read_pyc(%s): invalid pyc (bad magic number)" % source)
return None return None
if data[4:8] != b"\x00\x00\x00\x00":
hash_based = getattr(_imp, "check_hash_based_pycs", "default") == "always"
if data[4:8] == b"\x00\x00\x00\x00" and not hash_based:
trace("_read_pyc(%s): timestamp based" % source)
mtime_data = data[8:12]
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
trace("_read_pyc(%s): out of date" % source)
return None
size_data = data[12:16]
if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
return None
elif data[4:8] == b"\x03\x00\x00\x00":
trace("_read_pyc(%s): hash based" % source)
hash = data[8:16]
source_hash = importlib.util.source_hash(source.read_bytes())
if source_hash[:8] != hash:
trace("_read_pyc(%s): hash doesn't match" % source)
return None
else:
trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source) trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
return None return None
mtime_data = data[8:12]
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
trace("_read_pyc(%s): out of date" % source)
return None
size_data = data[12:16]
if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
return None
try: try:
co = marshal.load(fp) co = marshal.load(fp)
except Exception as e: except Exception as e:

View File

@ -223,6 +223,14 @@ def pytest_addoption(parser: Parser) -> None:
help="Prepend/append to sys.path when importing test modules and conftest " help="Prepend/append to sys.path when importing test modules and conftest "
"files. Default: prepend.", "files. Default: prepend.",
) )
group.addoption(
"--invalidation-mode",
default="timestamp",
choices=["timestamp", "checked-hash"],
dest="invalidationmode",
help="Pytest pyc cache invalidation mode. Default: timestamp.",
)
parser.addini( parser.addini(
"consider_namespace_packages", "consider_namespace_packages",
type="bool", type="bool",

View File

@ -4,6 +4,7 @@ import errno
from functools import partial from functools import partial
import glob import glob
import importlib import importlib
from importlib.util import source_hash
import marshal import marshal
import os import os
from pathlib import Path from pathlib import Path
@ -21,6 +22,8 @@ from typing import Set
from unittest import mock from unittest import mock
import zipfile import zipfile
import _imp
import _pytest._code import _pytest._code
from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE
from _pytest.assertion import util from _pytest.assertion import util
@ -1043,12 +1046,14 @@ class TestAssertionRewriteHookDetails:
state = AssertionState(config, "rewrite") state = AssertionState(config, "rewrite")
tmp_path.joinpath("source.py").touch() tmp_path.joinpath("source.py").touch()
source_path = str(tmp_path) source_path = str(tmp_path)
source_bytes = tmp_path.joinpath("source.py").read_bytes()
pycpath = tmp_path.joinpath("pyc") pycpath = tmp_path.joinpath("pyc")
co = compile("1", "f.py", "single") co = compile("1", "f.py", "single")
assert _write_pyc(state, co, os.stat(source_path), pycpath) hash = source_hash(source_bytes)
assert _write_pyc(state, co, os.stat(source_path), hash, pycpath)
with mock.patch.object(os, "replace", side_effect=OSError): with mock.patch.object(os, "replace", side_effect=OSError):
assert not _write_pyc(state, co, os.stat(source_path), pycpath) assert not _write_pyc(state, co, os.stat(source_path), hash, pycpath)
def test_resources_provider_for_loader(self, pytester: Pytester) -> None: def test_resources_provider_for_loader(self, pytester: Pytester) -> None:
""" """
@ -1121,10 +1126,41 @@ class TestAssertionRewriteHookDetails:
fn.write_text("def test(): assert True", encoding="utf-8") fn.write_text("def test(): assert True", encoding="utf-8")
source_stat, co = _rewrite_test(fn, config) source_stat, hash, co = _rewrite_test(fn, config)
_write_pyc(state, co, source_stat, pyc) _write_pyc(state, co, source_stat, hash, pyc)
assert _read_pyc(fn, pyc, state.trace) is not None assert _read_pyc(fn, pyc, state.trace) is not None
pyc_bytes = pyc.read_bytes()
assert pyc_bytes[4] == 0 # timestamp flag set
def test_read_pyc_success_hash(self, tmp_path: Path, pytester: Pytester) -> None:
    """A pyc written with ``--invalidation-mode=checked-hash`` can be read back.

    Also checks that the flags field is 3 and that bytes 8-16 of the pyc
    hold the hash returned by ``_rewrite_test``.
    """
    from _pytest.assertion import AssertionState
    from _pytest.assertion.rewrite import _read_pyc
    from _pytest.assertion.rewrite import _rewrite_test
    from _pytest.assertion.rewrite import _write_pyc

    # Test private attribute didn't change
    known_settings = {"default", "always", "never"}
    assert getattr(_imp, "check_hash_based_pycs", None) in known_settings

    config = pytester.parseconfig("--invalidation-mode=checked-hash")
    state = AssertionState(config, "rewrite")

    source_file = tmp_path / "source.py"
    pyc_file = Path(f"{source_file}c")
    source_file.write_text("def test(): assert True", encoding="utf-8")

    source_stat, expected_hash, co = _rewrite_test(source_file, config)
    _write_pyc(state, co, source_stat, expected_hash, pyc_file)
    assert _read_pyc(source_file, pyc_file, state.trace) is not None

    written = pyc_file.read_bytes()
    assert written[4] == 3  # checked-hash flag set
    assert written[8:16] == expected_hash
def test_read_pyc_more_invalid(self, tmp_path: Path) -> None: def test_read_pyc_more_invalid(self, tmp_path: Path) -> None:
from _pytest.assertion.rewrite import _read_pyc from _pytest.assertion.rewrite import _read_pyc
@ -1170,6 +1206,50 @@ class TestAssertionRewriteHookDetails:
pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + code) pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + code)
assert _read_pyc(source, pyc, print) is None assert _read_pyc(source, pyc, print) is None
def test_read_pyc_more_invalid_hash(self, tmp_path: Path) -> None:
    """Exercise ``_read_pyc`` with hash-based and timestamp-based pyc files.

    A pyc with a matching hash is accepted, both with the interpreter's own
    ``check_hash_based_pycs`` setting and with it patched to ``"always"``.
    With ``"always"``, a pyc with a wrong hash and a timestamp-based pyc
    are both rejected.
    """
    from _pytest.assertion.rewrite import _read_pyc

    source = tmp_path / "source.py"
    pyc = tmp_path / "source.pyc"

    source_bytes = b"def test(): pass\n"
    source.write_bytes(source_bytes)

    # Pin the source mtime so the timestamp header below matches the file.
    mtime = b"\x58\x3c\xb0\x5f"
    mtime_int = int.from_bytes(mtime, "little")
    os.utime(source, (mtime_int, mtime_int))

    magic = importlib.util.MAGIC_NUMBER
    code = marshal.dumps(compile(source_bytes, str(source), "exec"))

    size = len(source_bytes).to_bytes(4, "little")
    timestamp_fields = b"\x00\x00\x00\x00" + mtime + size
    matching_hash_fields = b"\x03\x00\x00\x00" + source_hash(source_bytes)[:8]
    wrong_hash_fields = b"\x03\x00\x00\x00" + b"\x00" * 8

    def read_back(fields: bytes):
        # Write a pyc with the given flags + validation fields, then read it.
        pyc.write_bytes(magic + fields + code)
        return _read_pyc(source, pyc, print)

    # check_hash_based_pycs == "default" with hash based pyc file.
    assert read_back(matching_hash_fields) is not None

    with mock.patch.object(_imp, "check_hash_based_pycs", "always"):
        # check_hash_based_pycs == "always" with hash based pyc file.
        assert read_back(matching_hash_fields) is not None

        # Bad hash.
        assert read_back(wrong_hash_fields) is None

        # check_hash_based_pycs == "always" with timestamp based pyc file.
        assert read_back(timestamp_fields) is None
def test_reload_is_same_and_reloads(self, pytester: Pytester) -> None: def test_reload_is_same_and_reloads(self, pytester: Pytester) -> None:
"""Reloading a (collected) module after change picks up the change.""" """Reloading a (collected) module after change picks up the change."""
pytester.makeini( pytester.makeini(