diff --git a/changelog/11418.improvement.rst b/changelog/11418.improvement.rst new file mode 100644 index 000000000..1bdb0b606 --- /dev/null +++ b/changelog/11418.improvement.rst @@ -0,0 +1 @@ +Added hash comparison for pyc cache files. diff --git a/src/_pytest/assertion/__init__.py b/src/_pytest/assertion/__init__.py index 21dd4a4a4..42d0de250 100644 --- a/src/_pytest/assertion/__init__.py +++ b/src/_pytest/assertion/__init__.py @@ -94,6 +94,7 @@ class AssertionState: def __init__(self, config: Config, mode) -> None: self.mode = mode self.trace = config.trace.root.get("assertion") + self.invalidation_mode = config.known_args_namespace.invalidationmode self.hook: Optional[rewrite.AssertionRewritingHook] = None diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 678471ee9..42b6096cf 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -23,6 +23,7 @@ from typing import IO from typing import Iterable from typing import Iterator from typing import List +from typing import Literal from typing import Optional from typing import Sequence from typing import Set @@ -30,6 +31,8 @@ from typing import Tuple from typing import TYPE_CHECKING from typing import Union +import _imp + from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE from _pytest._io.saferepr import saferepr from _pytest._version import version @@ -166,11 +169,11 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) co = _read_pyc(fn, pyc, state.trace) if co is None: state.trace(f"rewriting {fn!r}") - source_stat, co = _rewrite_test(fn, self.config) + source_stat, source_hash, co = _rewrite_test(fn, self.config) if write: self._writing_pyc = True try: - _write_pyc(state, co, source_stat, pyc) + _write_pyc(state, co, source_stat, source_hash, pyc) finally: self._writing_pyc = False else: @@ -299,20 +302,31 @@ class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader) def _write_pyc_fp( - fp: IO[bytes], source_stat: os.stat_result, co: types.CodeType + fp: IO[bytes], + source_stat: os.stat_result, + source_hash: bytes, + co: types.CodeType, + invalidation_mode: Literal["timestamp", "checked-hash"], ) -> None: # Technically, we don't have to have the same pyc format as # (C)Python, since these "pycs" should never be seen by builtin # import. However, there's little reason to deviate. fp.write(importlib.util.MAGIC_NUMBER) # https://www.python.org/dev/peps/pep-0552/ - flags = b"\x00\x00\x00\x00" - fp.write(flags) - # as of now, bytecode header expects 32-bit numbers for size and mtime (#4903) - mtime = int(source_stat.st_mtime) & 0xFFFFFFFF - size = source_stat.st_size & 0xFFFFFFFF - # " bool: proc_pyc = f"{pyc}.{os.getpid()}" try: with open(proc_pyc, "wb") as fp: - _write_pyc_fp(fp, source_stat, co) + _write_pyc_fp(fp, source_stat, source_hash, co, state.invalidation_mode) except OSError as e: state.trace(f"error writing pyc file at {proc_pyc}: errno={e.errno}") return False @@ -341,15 +356,18 @@ def _write_pyc( return True -def _rewrite_test(fn: Path, config: Config) -> Tuple[os.stat_result, types.CodeType]: +def _rewrite_test( + fn: Path, config: Config +) -> Tuple[os.stat_result, bytes, types.CodeType]: """Read and rewrite *fn* and return the code object.""" stat = os.stat(fn) source = fn.read_bytes() + source_hash = importlib.util.source_hash(source) strfn = str(fn) tree = ast.parse(source, filename=strfn) rewrite_asserts(tree, source, strfn, config) co = compile(tree, strfn, "exec", dont_inherit=True) - return stat, co + return stat, source_hash, co def _read_pyc( @@ -379,17 +397,29 @@ def _read_pyc( if data[:4] != importlib.util.MAGIC_NUMBER: trace("_read_pyc(%s): invalid pyc (bad magic number)" % source) return None - if data[4:8] != b"\x00\x00\x00\x00": + + hash_based = getattr(_imp, "check_hash_based_pycs", "default") == "always" + if data[4:8] == b"\x00\x00\x00\x00" and not hash_based: + trace("_read_pyc(%s): timestamp based" % source) + mtime_data = data[8:12] + if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF: + trace("_read_pyc(%s): out of date" % source) + return None + size_data = data[12:16] + if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF: + trace("_read_pyc(%s): invalid pyc (incorrect size)" % source) + return None + elif data[4:8] == b"\x03\x00\x00\x00": + trace("_read_pyc(%s): hash based" % source) + hash = data[8:16] + source_hash = importlib.util.source_hash(source.read_bytes()) + if source_hash[:8] != hash: + trace("_read_pyc(%s): hash doesn't match" % source) + return None + else: trace("_read_pyc(%s): invalid pyc (unsupported flags)" % source) return None - mtime_data = data[8:12] - if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF: - trace("_read_pyc(%s): out of date" % source) - return None - size_data = data[12:16] - if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF: - trace("_read_pyc(%s): invalid pyc (incorrect size)" % source) - return None + try: co = marshal.load(fp) except Exception as e: diff --git a/src/_pytest/main.py b/src/_pytest/main.py index 716d5cf78..56716eb59 100644 --- a/src/_pytest/main.py +++ b/src/_pytest/main.py @@ -223,6 +223,14 @@ def pytest_addoption(parser: Parser) -> None: help="Prepend/append to sys.path when importing test modules and conftest " "files. Default: prepend.", ) + group.addoption( + "--invalidation-mode", + default="timestamp", + choices=["timestamp", "checked-hash"], + dest="invalidationmode", + help="Pytest pyc cache invalidation mode. Default: timestamp.", + ) + parser.addini( "consider_namespace_packages", type="bool", diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 7acc8cdf1..0f57264f4 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -4,6 +4,7 @@ import errno from functools import partial import glob import importlib +from importlib.util import source_hash import marshal import os from pathlib import Path @@ -21,6 +22,8 @@ from typing import Set from unittest import mock import zipfile +import _imp + import _pytest._code from _pytest._io.saferepr import DEFAULT_REPR_MAX_SIZE from _pytest.assertion import util @@ -1043,12 +1046,14 @@ class TestAssertionRewriteHookDetails: state = AssertionState(config, "rewrite") tmp_path.joinpath("source.py").touch() source_path = str(tmp_path) + source_bytes = tmp_path.joinpath("source.py").read_bytes() pycpath = tmp_path.joinpath("pyc") co = compile("1", "f.py", "single") - assert _write_pyc(state, co, os.stat(source_path), pycpath) + hash = source_hash(source_bytes) + assert _write_pyc(state, co, os.stat(source_path), hash, pycpath) with mock.patch.object(os, "replace", side_effect=OSError): - assert not _write_pyc(state, co, os.stat(source_path), pycpath) + assert not _write_pyc(state, co, os.stat(source_path), hash, pycpath) def test_resources_provider_for_loader(self, pytester: Pytester) -> None: """ @@ -1121,10 +1126,41 @@ class TestAssertionRewriteHookDetails: fn.write_text("def test(): assert True", encoding="utf-8") - source_stat, co = _rewrite_test(fn, config) - _write_pyc(state, co, source_stat, pyc) + source_stat, hash, co = _rewrite_test(fn, config) + _write_pyc(state, co, source_stat, hash, pyc) assert _read_pyc(fn, pyc, state.trace) is not None + pyc_bytes = pyc.read_bytes() + assert pyc_bytes[4] == 0 # timestamp flag set + + def test_read_pyc_success_hash(self, tmp_path: Path, pytester: Pytester) -> None: + from _pytest.assertion import AssertionState + from _pytest.assertion.rewrite import _read_pyc + from _pytest.assertion.rewrite import _rewrite_test + from _pytest.assertion.rewrite import _write_pyc + + config = pytester.parseconfig("--invalidation-mode=checked-hash") + state = AssertionState(config, "rewrite") + + fn = tmp_path / "source.py" + pyc = Path(str(fn) + "c") + + # Test private attribute didn't change + assert getattr(_imp, "check_hash_based_pycs", None) in { + "default", + "always", + "never", + } + + fn.write_text("def test(): assert True", encoding="utf-8") + source_stat, hash, co = _rewrite_test(fn, config) + _write_pyc(state, co, source_stat, hash, pyc) + assert _read_pyc(fn, pyc, state.trace) is not None + + pyc_bytes = pyc.read_bytes() + assert pyc_bytes[4] == 3 # checked-hash flag set + assert pyc_bytes[8:16] == hash + def test_read_pyc_more_invalid(self, tmp_path: Path) -> None: from _pytest.assertion.rewrite import _read_pyc @@ -1170,6 +1206,50 @@ class TestAssertionRewriteHookDetails: pyc.write_bytes(magic + flags + mtime + b"\x99\x00\x00\x00" + code) assert _read_pyc(source, pyc, print) is None + def test_read_pyc_more_invalid_hash(self, tmp_path: Path) -> None: + from _pytest.assertion.rewrite import _read_pyc + + source = tmp_path / "source.py" + pyc = tmp_path / "source.pyc" + + source_bytes = b"def test(): pass\n" + source.write_bytes(source_bytes) + + magic = importlib.util.MAGIC_NUMBER + + flags = b"\x00\x00\x00\x00" + flags_hash = b"\x03\x00\x00\x00" + + mtime = b"\x58\x3c\xb0\x5f" + mtime_int = int.from_bytes(mtime, "little") + os.utime(source, (mtime_int, mtime_int)) + + size = len(source_bytes).to_bytes(4, "little") + + hash = source_hash(source_bytes) + hash = hash[:8] + + code = marshal.dumps(compile(source_bytes, str(source), "exec")) + + # check_hash_based_pycs == "default" with hash based pyc file. + pyc.write_bytes(magic + flags_hash + hash + code) + assert _read_pyc(source, pyc, print) is not None + + # check_hash_based_pycs == "always" with hash based pyc file. + with mock.patch.object(_imp, "check_hash_based_pycs", "always"): + pyc.write_bytes(magic + flags_hash + hash + code) + assert _read_pyc(source, pyc, print) is not None + + # Bad hash. + with mock.patch.object(_imp, "check_hash_based_pycs", "always"): + pyc.write_bytes(magic + flags_hash + b"\x00" * 8 + code) + assert _read_pyc(source, pyc, print) is None + + # check_hash_based_pycs == "always" with timestamp based pyc file. + with mock.patch.object(_imp, "check_hash_based_pycs", "always"): + pyc.write_bytes(magic + flags + mtime + size + code) + assert _read_pyc(source, pyc, print) is None + def test_reload_is_same_and_reloads(self, pytester: Pytester) -> None: """Reloading a (collected) module after change picks up the change.""" pytester.makeini(