From ac98ff571b6731d715979690792451f010f22d2b Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Sat, 9 Sep 2023 01:15:46 +0200 Subject: [PATCH 1/2] Add hash comparison for pyc cache files --- changelog/11418.improvement.rst | 1 + src/_pytest/assertion/rewrite.py | 37 +++++++++++++++++++++----------- testing/test_assertrewrite.py | 36 ++++++++++++++++++++----------- 3 files changed, 50 insertions(+), 24 deletions(-) create mode 100644 changelog/11418.improvement.rst diff --git a/changelog/11418.improvement.rst b/changelog/11418.improvement.rst new file mode 100644 index 000000000..1bdb0b606 --- /dev/null +++ b/changelog/11418.improvement.rst @@ -0,0 +1 @@ +Added hash comparison for pyc cache files. diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 678471ee9..39a57fb71 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -166,11 +166,11 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) co = _read_pyc(fn, pyc, state.trace) if co is None: state.trace(f"rewriting {fn!r}") - source_stat, co = _rewrite_test(fn, self.config) + source_stat, source_hash, co = _rewrite_test(fn, self.config) if write: self._writing_pyc = True try: - _write_pyc(state, co, source_stat, pyc) + _write_pyc(state, co, source_stat, source_hash, pyc) finally: self._writing_pyc = False else: @@ -299,7 +299,7 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) def _write_pyc_fp( - fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType + fp: IO[bytes], source_stat: os.stat_result, source_hash: bytes, co: types.CodeType ) -> None: # Technically, we don't have to have the same pyc format as # (C)Python, since these "pycs" should never be seen by builtin @@ -311,8 +311,11 @@ def _write_pyc_fp( # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903) mtime = int(source_stat.st_mtime) & 0xFFFFFFFF size = source_stat.st_size & 0xFFFFFFFF + # 64-bit source file hash + source_hash = source_hash[:8] # " bool: proc_pyc = f"{pyc}.{os.getpid()}" try: with open(proc_pyc, "wb") as fp: - _write_pyc_fp(fp, source_stat, co) + _write_pyc_fp(fp, source_stat, source_hash, co) except OSError as e: state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}") return False @@ -341,15 +345,18 @@ def _write_pyc( return True -def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]: +def _rewrite_test( + fn: Path, config: Config +) -> Tuple[os.stat_result, bytes, types.CodeType]: """Read and rewrite *fn* and return the code object.""" stat = os.stat(fn) source = fn.read_bytes() + source_hash = importlib.util.source_hash(source) strfn = str(fn) tree = ast.parse(source, filename=strfn) rewrite_asserts(tree, source, strfn, config) co = compile(tree, strfn, "exec", dont_inherit=True) - return stat, co + return stat, source_hash, co def _read_pyc( @@ -368,12 +375,12 @@ def _read_pyc( stat_result = os.stat(source) mtime = int(stat_result.st_mtime) size = stat_result.st_size - data = fp.read(16) + data = fp.read(24) except OSError as e: trace(f"_read_pyc({source}): OSError {e}") return None # Check for invalid or out of date pyc file. - if len(data) != (16): + if len(data) != (24): trace("_read_pyc(%s): invalid pyc (too short)" % source) return None if data[:4] != importlib.util.MAGIC_NUMBER: @@ -382,14 +389,20 @@ def _read_pyc( if data[4:8] != b"\x00\x00\x00\x00": trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source) return None - mtime_data = data[8:12] - if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF: - trace("_read_pyc(%s): out of date" % source) - return None size_data = data[12:16] if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF: trace("_read_pyc(%s): invalid pyc (incorrect size)" % source) return None + mtime_data = data[8:12] + if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF: + trace("_read_pyc(%s): out of date" % source) + hash = data[16:24] + source_hash = importlib.util.source_hash(source.read_bytes()) + if source_hash[:8] == hash: + trace("_read_pyc(%s): source hash match (no change detected)" % source) + else: + trace("_read_pyc(%s): hash doesn't match" % source) + return None try: co = marshal.load(fp) except Exception as e: diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 7acc8cdf1..3a87b44fe 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -4,6 +4,7 @@ import errno from functools import partial import glob import importlib +from importlib.util import source_hash import marshal import os from pathlib import Path @@ -1043,12 +1044,14 @@ class TestAssertionRewriteHookDetails: state = AssertionState(config, "rewrite") tmp_path.joinpath("source.py").touch() source_path = str(tmp_path) + source_bytes = tmp_path.joinpath("source.py").read_bytes() pycpath = tmp_path.joinpath("pyc") co = compile("1", "f.py", "single") - assert _write_pyc(state, co, os.stat(source_path), pycpath) + hash = source_hash(source_bytes) + assert _write_pyc(state, co, os.stat(source_path), hash, pycpath) with mock.patch.object(os, "replace", side_effect=OSError): - assert not _write_pyc(state, co, os.stat(source_path), pycpath) + assert not _write_pyc(state, co, os.stat(source_path), hash, pycpath) def test_resources_provider_for_loader(self, pytester: Pytester) -> None: """ @@ -1121,8 +1124,15 @@ class TestAssertionRewriteHookDetails: fn.write_text("def test(): assert True", encoding="utf-8") - source_stat, co = _rewrite_test(fn, config) - _write_pyc(state, co, source_stat, pyc) + source_stat, hash, co = _rewrite_test(fn, config) + _write_pyc(state, co, source_stat, hash, pyc) + assert _read_pyc(fn, pyc, state.trace) is not None + + # pyc read should still work if only the mtime changed + # Fallback to hash comparison + new_mtime = source_stat.st_mtime + 1.2 + os.utime(fn, (new_mtime, new_mtime)) + assert source_stat.st_mtime != os.stat(fn).st_mtime assert _read_pyc(fn, pyc, state.trace) is not None def test_read_pyc_more_invalid(self, tmp_path: Path) -> None: @@ -1143,11 +1153,13 @@ class TestAssertionRewriteHookDetails: os.utime(source, (mtime_int, mtime_int)) size = len(source_bytes).to_bytes(4, "little") + hash = source_hash(source_bytes) + hash = hash[:8] code = marshal.dumps(compile(source_bytes, str(source), "exec")) # Good header. - pyc.write_bytes(magic + flags + mtime + size + code) + pyc.write_bytes(magic + flags + mtime + size + hash + code) assert _read_pyc(source, pyc, print) is not None # Too short. @@ -1155,19 +1167,19 @@ class TestAssertionRewriteHookDetails: assert _read_pyc(source, pyc, print) is None # Bad magic. - pyc.write_bytes(b"\x12\x34\x56\x78" + flags + mtime + size + code) + pyc.write_bytes(b"\x12\x34\x56\x78" + flags + mtime + size + hash + code) assert _read_pyc(source, pyc, print) is None # Unsupported flags. - pyc.write_bytes(magic + b"\x00\xff\x00\x00" + mtime + size + code) - assert _read_pyc(source, pyc, print) is None - - # Bad mtime. - pyc.write_bytes(magic + flags + b"\x58\x3d\xb0\x5f" + size + code) + pyc.write_bytes(magic + b"\x00\xff\x00\x00" + mtime + size + hash + code) assert _read_pyc(source, pyc, print) is None # Bad size. - pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + code) + pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + hash + code) + assert _read_pyc(source, pyc, print) is None + + # Bad mtime + bad hash. + pyc.write_bytes(magic + flags + b"\x58\x3d\xb0\x5f" + size + b"\x00" * 8 + code) assert _read_pyc(source, pyc, print) is None def test_reload_is_same_and_reloads(self, pytester: Pytester) -> None: From 2f12594279b1e33e6855ee9863d77312fbe76bd4 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Sun, 10 Sep 2023 23:48:15 +0200 Subject: [PATCH 2/2] Add invalidation-mode option --- src/_pytest/assertion/__init__.py | 1 + src/_pytest/assertion/rewrite.py | 73 ++++++++++++++--------- src/_pytest/main.py | 8 +++ testing/test_assertrewrite.py | 96 ++++++++++++++++++++++++++----- 4 files changed, 136 insertions(+), 42 deletions(-) diff --git a/src/_pytest/assertion/__init__.py b/src/_pytest/assertion/__init__.py index 21dd4a4a4..42d0de250 100644 --- a/src/_pytest/assertion/__init__.py +++ b/src/_pytest/assertion/__init__.py @@ -94,6 +94,7 @@ class AssertionState: def __init__(self, config: Config, mode) -> None: self.mode = mode self.trace = config.trace.root.get("assertion") + self.invalidation_mode = config.known_args_namespace.invalidationmode self.hook: Optional[rewrite.AssertionRewritingHook] = None diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 39a57fb71..42b6096cf 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -23,6 +23,7 @@ from typing import IO from typing import Iterable from typing import Iterator from typing import List +from typing import Literal from typing import Optional from typing import Sequence from typing import Set @@ -30,6 +31,8 @@ from typing import Tuple from typing import TYPE_CHECKING from typing import Union +import _imp + from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE from _pytest._io.saferepr import saferepr from _pytest._version import version @@ -299,23 +302,31 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) def _write_pyc_fp( - fp: IO[bytes], source_stat: os.stat_result, source_hash: bytes, co: types.CodeType + fp: IO[bytes], + source_stat: os.stat_result, + source_hash: bytes, + co: types.CodeType, + invalidation_mode: Literal["timestamp", "checked-hash"], ) -> None: # Technically, we don't have to have the same pyc format as # (C)Python, since these "pycs" should never be seen by builtin # import. However, there's little reason to deviate. fp.write(importlib.util.MAGIC_NUMBER) # https://www.python.org/dev/peps/pep-0552/ - flags = b"\x00\x00\x00\x00" - fp.write(flags) - # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903) - mtime = int(source_stat.st_mtime) & 0xFFFFFFFF - size = source_stat.st_size & 0xFFFFFFFF - # 64-bit source file hash - source_hash = source_hash[:8] - # " None: help="Prepend/append to sys.path when importing test modules and conftest " "files. Default: prepend.", ) + group.addoption( + "--invalidation-mode", + default="timestamp", + choices=["timestamp", "checked-hash"], + dest="invalidationmode", + help="Pytest pyc cache invalidation mode. Default: timestamp.", + ) + parser.addini( "consider_namespace_packages", type="bool", diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 3a87b44fe..0f57264f4 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -22,6 +22,8 @@ from typing import Set from unittest import mock import zipfile +import _imp + import _pytest._code from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE from _pytest.assertion import util @@ -1128,13 +1130,37 @@ class TestAssertionRewriteHookDetails: _write_pyc(state, co, source_stat, hash, pyc) assert _read_pyc(fn, pyc, state.trace) is not None - # pyc read should still work if only the mtime changed - # Fallback to hash comparison - new_mtime = source_stat.st_mtime + 1.2 - os.utime(fn, (new_mtime, new_mtime)) - assert source_stat.st_mtime != os.stat(fn).st_mtime + pyc_bytes = pyc.read_bytes() + assert pyc_bytes[4] == 0 # timestamp flag set + + def test_read_pyc_success_hash(self, tmp_path: Path, pytester: Pytester) -> None: + from _pytest.assertion import AssertionState + from _pytest.assertion.rewrite import _read_pyc + from _pytest.assertion.rewrite import _rewrite_test + from _pytest.assertion.rewrite import _write_pyc + + config = pytester.parseconfig("--invalidation-mode=checked-hash") + state = AssertionState(config, "rewrite") + + fn = tmp_path / "source.py" + pyc = Path(str(fn) + "c") + + # Test private attribute didn't change + assert getattr(_imp, "check_hash_based_pycs", None) in { + "default", + "always", + "never", + } + + fn.write_text("def test(): assert True", encoding="utf-8") + source_stat, hash, co = _rewrite_test(fn, config) + _write_pyc(state, co, source_stat, hash, pyc) assert _read_pyc(fn, pyc, state.trace) is not None + pyc_bytes = pyc.read_bytes() + assert pyc_bytes[4] == 3 # checked-hash flag set + assert pyc_bytes[8:16] == hash + def test_read_pyc_more_invalid(self, tmp_path: Path) -> None: from _pytest.assertion.rewrite import _read_pyc @@ -1153,13 +1179,11 @@ class TestAssertionRewriteHookDetails: os.utime(source, (mtime_int, mtime_int)) size = len(source_bytes).to_bytes(4, "little") - hash = source_hash(source_bytes) - hash = hash[:8] code = marshal.dumps(compile(source_bytes, str(source), "exec")) # Good header. - pyc.write_bytes(magic + flags + mtime + size + hash + code) + pyc.write_bytes(magic + flags + mtime + size + code) assert _read_pyc(source, pyc, print) is not None # Too short. @@ -1167,20 +1191,64 @@ class TestAssertionRewriteHookDetails: assert _read_pyc(source, pyc, print) is None # Bad magic. - pyc.write_bytes(b"\x12\x34\x56\x78" + flags + mtime + size + hash + code) + pyc.write_bytes(b"\x12\x34\x56\x78" + flags + mtime + size + code) assert _read_pyc(source, pyc, print) is None # Unsupported flags. - pyc.write_bytes(magic + b"\x00\xff\x00\x00" + mtime + size + hash + code) + pyc.write_bytes(magic + b"\x00\xff\x00\x00" + mtime + size + code) + assert _read_pyc(source, pyc, print) is None + + # Bad mtime. + pyc.write_bytes(magic + flags + b"\x58\x3d\xb0\x5f" + size + code) assert _read_pyc(source, pyc, print) is None # Bad size. - pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + hash + code) + pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + code) assert _read_pyc(source, pyc, print) is None - # Bad mtime + bad hash. - pyc.write_bytes(magic + flags + b"\x58\x3d\xb0\x5f" + size + b"\x00" * 8 + code) - assert _read_pyc(source, pyc, print) is None + def test_read_pyc_more_invalid_hash(self, tmp_path: Path) -> None: + from _pytest.assertion.rewrite import _read_pyc + + source = tmp_path / "source.py" + pyc = tmp_path / "source.pyc" + + source_bytes = b"def test(): pass\n" + source.write_bytes(source_bytes) + + magic = importlib.util.MAGIC_NUMBER + + flags = b"\x00\x00\x00\x00" + flags_hash = b"\x03\x00\x00\x00" + + mtime = b"\x58\x3c\xb0\x5f" + mtime_int = int.from_bytes(mtime, "little") + os.utime(source, (mtime_int, mtime_int)) + + size = len(source_bytes).to_bytes(4, "little") + + hash = source_hash(source_bytes) + hash = hash[:8] + + code = marshal.dumps(compile(source_bytes, str(source), "exec")) + + # check_hash_based_pycs == "default" with hash based pyc file. + pyc.write_bytes(magic + flags_hash + hash + code) + assert _read_pyc(source, pyc, print) is not None + + # check_hash_based_pycs == "always" with hash based pyc file. + with mock.patch.object(_imp, "check_hash_based_pycs", "always"): + pyc.write_bytes(magic + flags_hash + hash + code) + assert _read_pyc(source, pyc, print) is not None + + # Bad hash. + with mock.patch.object(_imp, "check_hash_based_pycs", "always"): + pyc.write_bytes(magic + flags_hash + b"\x00" * 8 + code) + assert _read_pyc(source, pyc, print) is None + + # check_hash_based_pycs == "always" with timestamp based pyc file. + with mock.patch.object(_imp, "check_hash_based_pycs", "always"): + pyc.write_bytes(magic + flags + mtime + size + code) + assert _read_pyc(source, pyc, print) is None def test_reload_is_same_and_reloads(self, pytester: Pytester) -> None: """Reloading a (collected) module after change picks up the change."""