Add hash comparison for pyc cache files

This commit is contained in:
Marc Mueller 2023-09-09 01:15:46 +02:00
parent fafab1dbfd
commit ac98ff571b
3 changed files with 50 additions and 24 deletions

View File

@ -0,0 +1 @@
Added hash comparison for pyc cache files.

View File

@ -166,11 +166,11 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
co = _read_pyc(fn, pyc, state.trace)
if co is None:
state.trace(f"rewriting {fn!r}")
source_stat, co = _rewrite_test(fn, self.config)
source_stat, source_hash, co = _rewrite_test(fn, self.config)
if write:
self._writing_pyc = True
try:
_write_pyc(state, co, source_stat, pyc)
_write_pyc(state, co, source_stat, source_hash, pyc)
finally:
self._writing_pyc = False
else:
@ -299,7 +299,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader)
def _write_pyc_fp(
fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType
fp: IO[bytes], source_stat: os.stat_result, source_hash: bytes, co: types.CodeType
) -> None:
# Technically, we don't have to have the same pyc format as
# (C)Python, since these "pycs" should never be seen by builtin
@ -311,8 +311,11 @@ def _write_pyc_fp(
# as of now, bytecode header expects 32-bit numbers for size and mtime (#4903)
mtime = int(source_stat.st_mtime) & 0xFFFFFFFF
size = source_stat.st_size & 0xFFFFFFFF
# 64-bit source file hash
source_hash = source_hash[:8]
# "<LL" stands for 2 unsigned longs, little-endian.
fp.write(struct.pack("<LL", mtime, size))
fp.write(source_hash)
fp.write(marshal.dumps(co))
@ -320,12 +323,13 @@ def _write_pyc(
state: "AssertionState",
co: types.CodeType,
source_stat: os.stat_result,
source_hash: bytes,
pyc: Path,
) -> bool:
proc_pyc = f"{pyc}.{os.getpid()}"
try:
with open(proc_pyc, "wb") as fp:
_write_pyc_fp(fp, source_stat, co)
_write_pyc_fp(fp, source_stat, source_hash, co)
except OSError as e:
state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}")
return False
@ -341,15 +345,18 @@ def _write_pyc(
return True
def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]:
def _rewrite_test(
fn: Path, config: Config
) -> Tuple[os.stat_result, bytes, types.CodeType]:
"""Read and rewrite *fn* and return the code object."""
stat = os.stat(fn)
source = fn.read_bytes()
source_hash = importlib.util.source_hash(source)
strfn = str(fn)
tree = ast.parse(source, filename=strfn)
rewrite_asserts(tree, source, strfn, config)
co = compile(tree, strfn, "exec", dont_inherit=True)
return stat, co
return stat, source_hash, co
def _read_pyc(
@ -368,12 +375,12 @@ def _read_pyc(
stat_result = os.stat(source)
mtime = int(stat_result.st_mtime)
size = stat_result.st_size
data = fp.read(16)
data = fp.read(24)
except OSError as e:
trace(f"_read_pyc({source}): OSError {e}")
return None
# Check for invalid or out of date pyc file.
if len(data) != (16):
if len(data) != (24):
trace("_read_pyc(%s): invalid pyc (too short)" % source)
return None
if data[:4] != importlib.util.MAGIC_NUMBER:
@ -382,14 +389,20 @@ def _read_pyc(
if data[4:8] != b"\x00\x00\x00\x00":
trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source)
return None
mtime_data = data[8:12]
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
trace("_read_pyc(%s): out of date" % source)
return None
size_data = data[12:16]
if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
trace("_read_pyc(%s): invalid pyc (incorrect size)" % source)
return None
mtime_data = data[8:12]
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
trace("_read_pyc(%s): out of date" % source)
hash = data[16:24]
source_hash = importlib.util.source_hash(source.read_bytes())
if source_hash[:8] == hash:
trace("_read_pyc(%s): source hash match (no change detected)" % source)
else:
trace("_read_pyc(%s): hash doesn't match" % source)
return None
try:
co = marshal.load(fp)
except Exception as e:

View File

@ -4,6 +4,7 @@ import errno
from functools import partial
import glob
import importlib
from importlib.util import source_hash
import marshal
import os
from pathlib import Path
@ -1043,12 +1044,14 @@ class TestAssertionRewriteHookDetails:
state = AssertionState(config, "rewrite")
tmp_path.joinpath("source.py").touch()
source_path = str(tmp_path)
source_bytes = tmp_path.joinpath("source.py").read_bytes()
pycpath = tmp_path.joinpath("pyc")
co = compile("1", "f.py", "single")
assert _write_pyc(state, co, os.stat(source_path), pycpath)
hash = source_hash(source_bytes)
assert _write_pyc(state, co, os.stat(source_path), hash, pycpath)
with mock.patch.object(os, "replace", side_effect=OSError):
assert not _write_pyc(state, co, os.stat(source_path), pycpath)
assert not _write_pyc(state, co, os.stat(source_path), hash, pycpath)
def test_resources_provider_for_loader(self, pytester: Pytester) -> None:
"""
@ -1121,8 +1124,15 @@ class TestAssertionRewriteHookDetails:
fn.write_text("def test(): assert True", encoding="utf-8")
source_stat, co = _rewrite_test(fn, config)
_write_pyc(state, co, source_stat, pyc)
source_stat, hash, co = _rewrite_test(fn, config)
_write_pyc(state, co, source_stat, hash, pyc)
assert _read_pyc(fn, pyc, state.trace) is not None
# pyc read should still work if only the mtime changed
# Fallback to hash comparison
new_mtime = source_stat.st_mtime + 1.2
os.utime(fn, (new_mtime, new_mtime))
assert source_stat.st_mtime != os.stat(fn).st_mtime
assert _read_pyc(fn, pyc, state.trace) is not None
def test_read_pyc_more_invalid(self, tmp_path: Path) -> None:
@ -1143,11 +1153,13 @@ class TestAssertionRewriteHookDetails:
os.utime(source, (mtime_int, mtime_int))
size = len(source_bytes).to_bytes(4, "little")
hash = source_hash(source_bytes)
hash = hash[:8]
code = marshal.dumps(compile(source_bytes, str(source), "exec"))
# Good header.
pyc.write_bytes(magic + flags + mtime + size + code)
pyc.write_bytes(magic + flags + mtime + size + hash + code)
assert _read_pyc(source, pyc, print) is not None
# Too short.
@ -1155,19 +1167,19 @@ class TestAssertionRewriteHookDetails:
assert _read_pyc(source, pyc, print) is None
# Bad magic.
pyc.write_bytes(b"\x12\x34\x56\x78" + flags + mtime + size + code)
pyc.write_bytes(b"\x12\x34\x56\x78" + flags + mtime + size + hash + code)
assert _read_pyc(source, pyc, print) is None
# Unsupported flags.
pyc.write_bytes(magic + b"\x00\xff\x00\x00" + mtime + size + code)
assert _read_pyc(source, pyc, print) is None
# Bad mtime.
pyc.write_bytes(magic + flags + b"\x58\x3d\xb0\x5f" + size + code)
pyc.write_bytes(magic + b"\x00\xff\x00\x00" + mtime + size + hash + code)
assert _read_pyc(source, pyc, print) is None
# Bad size.
pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + code)
pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + hash + code)
assert _read_pyc(source, pyc, print) is None
# Bad mtime + bad hash.
pyc.write_bytes(magic + flags + b"\x58\x3d\xb0\x5f" + size + b"\x00" * 8 + code)
assert _read_pyc(source, pyc, print) is None
def test_reload_is_same_and_reloads(self, pytester: Pytester) -> None: