diff --git a/py/execnet/serializer.py b/py/execnet/serializer.py
new file mode 100755
index 000000000..d7801f6bf
--- /dev/null
+++ b/py/execnet/serializer.py
@@ -0,0 +1,358 @@
+"""
+Simple marshal format (based on pickle) designed to work across Python versions.
+"""
+
+import sys
+import struct
+
+import py
+
+# _INPY3 selects the Python flavour this module emulates; the tests override
+# it and then call _setup_version_dependent_constants().  _REALLY_PY3 records
+# the interpreter that is actually running.
+_INPY3 = _REALLY_PY3 = sys.version_info > (3, 0)
+
+class SerializeError(Exception):
+    """Base class for every error raised by this module."""
+
+class SerializationError(SerializeError):
+    """Error while serializing an object."""
+
+class UnserializableType(SerializationError):
+    """Can't serialize a type."""
+
+class UnserializationError(SerializeError):
+    """Error while unserializing an object."""
+
+class VersionMismatch(UnserializationError):
+    """Data from a previous or later format."""
+
+class Corruption(UnserializationError):
+    """The pickle format appears to have been corrupted."""
+
+# Compatibility aliases: b() builds a protocol constant from a native string,
+# _b() converts a native string to this module's ``bytes`` type.
+if _INPY3:
+    def b(s):
+        return s.encode("ascii")
+    _b = b
+    class _unicode(str):
+        pass
+    bytes = bytes
+else:
+    class bytes(str):
+        pass
+    b = str
+    _b = bytes
+    _unicode = unicode
+
+FOUR_BYTE_INT_MAX = 2147483647
+FOUR_BYTE_INT_MIN = -FOUR_BYTE_INT_MAX - 1
+
+# All integers on the wire are signed 4-byte values in network byte order.
+_int4_format = struct.Struct("!i")
+
+# Protocol constants
+VERSION_NUMBER = 1
+VERSION = b(chr(VERSION_NUMBER))
+PY2STRING = b('s')
+PY3STRING = b('t')
+UNICODE = b('u')
+BYTES = b('b')
+NEWLIST = b('l')
+BUILDTUPLE = b('T')
+SETITEM = b('m')
+NEWDICT = b('d')
+INT = b('i')
+STOP = b('S')
+
+class CrossVersionOptions(object):
+    pass
+
+class Serializer(object):
+    """Write objects to a binary stream in the cross-version format.
+
+    ``stream`` needs a ``write()`` method accepting byte strings.  The types
+    that can be saved are the keys of ``Serializer.dispatch``, which is
+    installed by ``_setup_dispatch()``.
+    """
+
+    def __init__(self, stream):
+        self.stream = stream
+
+    def save(self, obj):
+        """Write ``obj`` framed by the VERSION byte and the STOP opcode."""
+        self.stream.write(VERSION)
+        self._save(obj)
+        self.stream.write(STOP)
+
+    def _save(self, obj):
+        """Write ``obj`` using the saver registered for its exact type.
+
+        Raises UnserializableType when ``type(obj)`` has no saver.
+        """
+        tp = type(obj)
+        try:
+            dispatch = self.dispatch[tp]
+        except KeyError:
+            raise UnserializableType("can't serialize %s" % (tp,))
+        dispatch(self, obj)
+
+    def save_bytes(self, bytes_):
+        self.stream.write(BYTES)
+        self._write_byte_sequence(bytes_)
+
+    def save_unicode(self, s):
+        self.stream.write(UNICODE)
+        self._write_unicode_string(s)
+
+    def save_string(self, s):
+        """Write a native ``str``, tagged with the flavour that produced it."""
+        if _INPY3:
+            self.stream.write(PY3STRING)
+            self._write_unicode_string(s)
+        else:
+            # Case for tests
+            if _REALLY_PY3 and isinstance(s, str):
+                s = s.encode("latin-1")
+            self.stream.write(PY2STRING)
+            self._write_byte_sequence(s)
+
+    def _write_unicode_string(self, s):
+        """Write ``s`` UTF-8 encoded, as a length-prefixed byte sequence."""
+        try:
+            as_bytes = s.encode("utf-8")
+        except UnicodeEncodeError:
+            raise SerializationError("strings must be utf-8 encodable")
+        self._write_byte_sequence(as_bytes)
+
+    def _write_byte_sequence(self, bytes_):
+        """Write a 4-byte length followed by the raw bytes."""
+        self._write_int4(len(bytes_), "string is too long")
+        self.stream.write(bytes_)
+
+    def save_int(self, i):
+        self.stream.write(INT)
+        self._write_int4(i)
+
+    def _write_int4(self, i, error="int must be less than %i" %
+                    (FOUR_BYTE_INT_MAX,)):
+        """Write ``i`` as a signed 4-byte integer.
+
+        Raises SerializationError (with ``error`` as message when ``i`` is
+        too large) if ``i`` does not fit into four bytes.
+        """
+        if i > FOUR_BYTE_INT_MAX:
+            raise SerializationError(error)
+        if i < FOUR_BYTE_INT_MIN:
+            # Without this check struct.pack() would raise struct.error,
+            # which is not part of this module's exception hierarchy.
+            raise SerializationError("int must not be less than %i" %
+                                     (FOUR_BYTE_INT_MIN,))
+        self.stream.write(_int4_format.pack(i))
+
+    def save_list(self, L):
+        """Write NEWLIST plus length, then one SETITEM triple per index."""
+        self.stream.write(NEWLIST)
+        self._write_int4(len(L), "list is too long")
+        for i, item in enumerate(L):
+            self._write_setitem(i, item)
+
+    def _write_setitem(self, key, value):
+        self._save(key)
+        self._save(value)
+        self.stream.write(SETITEM)
+
+    def save_dict(self, d):
+        """Write NEWDICT, then one SETITEM triple per key/value pair."""
+        self.stream.write(NEWDICT)
+        for key, value in d.items():
+            self._write_setitem(key, value)
+
+    def save_tuple(self, tup):
+        """Write every item, then BUILDTUPLE with the number of items."""
+        for item in tup:
+            self._save(item)
+        self.stream.write(BUILDTUPLE)
+        self._write_int4(len(tup), "tuple is too long")
+
+
+class _UnserializationOptions(object):
+    pass
+
+class _Py2UnserializationOptions(_UnserializationOptions):
+
+    def __init__(self, py3_strings_as_str=False):
+        self.py3_strings_as_str = py3_strings_as_str
+
+class _Py3UnserializationOptions(_UnserializationOptions):
+
+    def __init__(self, py2_strings_as_str=False):
+        self.py2_strings_as_str = py2_strings_as_str
+
+
+_unchanging_dispatch = {}
+for tp in (dict, list, tuple, int):
+    name = "save_%s" % (tp.__name__,)
+    _unchanging_dispatch[tp] = getattr(Serializer, name)
+del tp, name
+
+def _setup_dispatch():
+    dispatch = _unchanging_dispatch.copy()
+    # This is subtle. bytes is aliased to str in 2.6, so
+    # dispatch[bytes] is overwritten. Additionally, we alias unicode
+    # to str in 3.x, so dispatch[unicode] is overwritten with
+    # save_string.
+    dispatch[bytes] = Serializer.save_bytes
+    dispatch[unicode] = Serializer.save_unicode
+    dispatch[str] = Serializer.save_string
+    Serializer.dispatch = dispatch
+
+def _setup_version_dependent_constants(leave_unicode_alone=False):
+    """(Re)bind ``unicode``, ``UnserializationOptions`` and the dispatch table.
+
+    Everything is derived from the current value of ``_INPY3``.
+    ``leave_unicode_alone`` is accepted but not used.
+    """
+    global unicode, UnserializationOptions
+    if _INPY3:
+        unicode = str
+        UnserializationOptions = _Py3UnserializationOptions
+    else:
+        UnserializationOptions = _Py2UnserializationOptions
+        unicode = _unicode
+    _setup_dispatch()
+_setup_version_dependent_constants()
+
+
+class _Stop(Exception):
+    pass
+
+class Unserializer(object):
+    """Read one object written by Serializer from a binary stream.
+
+    The loaders form a small stack machine: each opcode read from the
+    stream is looked up in ``opcodes`` and the loader pushes to or pops
+    from ``self.stack``.
+    """
+
+    def __init__(self, stream, options=None):
+        self.stream = stream
+        if options is None:
+            options = UnserializationOptions()
+        self.options = options
+
+    def load(self):
+        """Read and return one object.
+
+        Raises EOFError if the stream ends where the version byte or an
+        opcode is expected, VersionMismatch for an unexpected version
+        byte, Corruption for an unknown opcode and UnserializationError
+        if STOP does not leave exactly one object on the stack.
+        """
+        self.stack = []
+        version_byte = self.stream.read(1)
+        if not version_byte:
+            # ord() of an empty read would fail with a TypeError.
+            raise EOFError
+        version = ord(version_byte)
+        if version != VERSION_NUMBER:
+            raise VersionMismatch("%i != %i" % (version, VERSION_NUMBER))
+        # The loop is only ever left through an exception; the STOP
+        # loader raises _Stop.
+        try:
+            while True:
+                opcode = self.stream.read(1)
+                if not opcode:
+                    raise EOFError
+                try:
+                    loader = self.opcodes[opcode]
+                except KeyError:
+                    raise Corruption("unkown opcode %s" % (opcode,))
+                loader(self)
+        except _Stop:
+            if len(self.stack) != 1:
+                raise UnserializationError("internal unserialization error")
+            return self.stack[0]
+
+    opcodes = {}
+
+    def load_int(self):
+        i = self._read_int4()
+        self.stack.append(i)
+    opcodes[INT] = load_int
+
+    def _read_int4(self):
+        return _int4_format.unpack(self.stream.read(4))[0]
+
+    def _read_byte_string(self):
+        """Read a 4-byte length and return that many bytes from the stream."""
+        length = self._read_int4()
+        as_bytes = self.stream.read(length)
+        return as_bytes
+
+    def load_py3string(self):
+        as_bytes = self._read_byte_string()
+        if (not _INPY3 and self.options.py3_strings_as_str) and not _REALLY_PY3:
+            # XXX Should we try to decode into latin-1?
+            self.stack.append(as_bytes)
+        else:
+            self.stack.append(as_bytes.decode("utf-8"))
+    opcodes[PY3STRING] = load_py3string
+
+    def load_py2string(self):
+        as_bytes = self._read_byte_string()
+        if (_INPY3 and self.options.py2_strings_as_str) or \
+           (_REALLY_PY3 and not _INPY3):
+            s = as_bytes.decode("latin-1")
+        else:
+            s = as_bytes
+        self.stack.append(s)
+    opcodes[PY2STRING] = load_py2string
+
+    def load_bytes(self):
+        s = bytes(self._read_byte_string())
+        self.stack.append(s)
+    opcodes[BYTES] = load_bytes
+
+    def load_unicode(self):
+        self.stack.append(self._read_byte_string().decode("utf-8"))
+    opcodes[UNICODE] = load_unicode
+
+    def load_newlist(self):
+        length = self._read_int4()
+        self.stack.append([None] * length)
+    opcodes[NEWLIST] = load_newlist
+
+    def load_setitem(self):
+        if len(self.stack) < 3:
+            raise Corruption("not enough items for setitem")
+        value = self.stack.pop()
+        key = self.stack.pop()
+        self.stack[-1][key] = value
+    opcodes[SETITEM] = load_setitem
+
+    def load_newdict(self):
+        self.stack.append({})
+    opcodes[NEWDICT] = load_newdict
+
+    def load_buildtuple(self):
+        """Replace the top ``length`` stack items by a tuple of them.
+
+        Raises Corruption if the stack holds fewer than ``length`` items.
+        """
+        length = self._read_int4()
+        if not 0 <= length <= len(self.stack):
+            raise Corruption("not enough items for tuple")
+        # Use an explicit start index: ``self.stack[-0:]`` is the whole
+        # stack, so negative slicing would turn an empty tuple into a
+        # tuple of everything loaded so far.
+        start = len(self.stack) - length
+        tup = tuple(self.stack[start:])
+        del self.stack[start:]
+        self.stack.append(tup)
+    opcodes[BUILDTUPLE] = load_buildtuple
+
+    def load_stop(self):
+        raise _Stop
+    opcodes[STOP] = load_stop
diff --git a/testing/execnet/test_serializer.py b/testing/execnet/test_serializer.py
new file mode 100755
index 000000000..5e192da4d
--- /dev/null
+++ b/testing/execnet/test_serializer.py
@@ -0,0 +1,138 @@
+# -*- coding: utf-8 -*-
+import shutil
+import py
+from py.__.execnet import serializer
+
+def setup_module(mod):
+    mod._save_python3 = serializer._INPY3
+
+def teardown_module(mod):
+    serializer._setup_version_dependent_constants()
+
+def _dump(obj):
+    stream = py.io.BytesIO()
+    saver = serializer.Serializer(stream)
+    saver.save(obj)
+    return stream.getvalue()
+
+def _load(serialized, str_coercion):
+    stream = py.io.BytesIO(serialized)
+    opts = serializer.UnserializationOptions(str_coercion)
+    unserializer = serializer.Unserializer(stream, opts)
+    return unserializer.load()
+
+def _run_in_version(is_py3, func, *args):
+    """Call func(*args) while the serializer emulates the given flavour."""
+    serializer._INPY3 = is_py3
+    serializer._setup_version_dependent_constants()
+    try:
+        return func(*args)
+    finally:
+        # _save_python3 is put into this module by setup_module().
+        serializer._INPY3 = _save_python3
+
+def dump_py2(obj):
+    return _run_in_version(False, _dump, obj)
+
+def dump_py3(obj):
+    return _run_in_version(True, _dump, obj)
+
+def load_py2(serialized, str_coercion=False):
+    return _run_in_version(False, _load, serialized, str_coercion)
+
+def load_py3(serialized, str_coercion=False):
+    return _run_in_version(True, _load, serialized, str_coercion)
+
+try:
+    bytes
+except NameError:
+    bytes = str
+
+
+# NOTE(review): _py2_wrapper and _py3_wrapper are not defined in this
+# module as shown; confirm where they are supposed to come from.
+def pytest_funcarg__py2(request):
+    return _py2_wrapper
+
+def pytest_funcarg__py3(request):
+    return _py3_wrapper
+
+class TestSerializer:
+
+    def test_int(self):
+        for dump in dump_py2, dump_py3:
+            p = dump(4)
+            for load in load_py2, load_py3:
+                i = load(p)
+                assert isinstance(i, int)
+                assert i == 4
+            py.test.raises(serializer.SerializationError, dump, 123456678900)
+            py.test.raises(serializer.SerializationError, dump, -123456678900)
+
+    def test_bytes(self):
+        for dump in dump_py2, dump_py3:
+            p = dump(serializer._b('hi'))
+            for load in load_py2, load_py3:
+                s = load(p)
+                assert isinstance(s, serializer.bytes)
+                assert s == serializer._b('hi')
+
+    def check_sequence(self, seq):
+        for dump in dump_py2, dump_py3:
+            p = dump(seq)
+            for load in load_py2, load_py3:
+                l = load(p)
+                assert l == seq
+
+    def test_list(self):
+        self.check_sequence([1, 2, 3])
+
+    @py.test.mark.xfail
+    # I'm not sure if we need the complexity.
+    def test_recursive_list(self):
+        l = [1, 2, 3]
+        l.append(l)
+        self.check_sequence(l)
+
+    def test_tuple(self):
+        self.check_sequence((1, 2, 3))
+
+    def test_empty_tuple(self):
+        # An empty tuple must not swallow the items loaded before it.
+        self.check_sequence(())
+        self.check_sequence([(), 1])
+        self.check_sequence((1, ()))
+
+    def test_dict(self):
+        for dump in dump_py2, dump_py3:
+            p = dump({"hi" : 2, (1, 2, 3) : 32})
+            for load in load_py2, load_py3:
+                d = load(p, True)
+                assert d == {"hi" : 2, (1, 2, 3) : 32}
+
+    def test_string(self):
+        py.test.skip("will rewrite")
+        p = dump_py2("xyz")
+        s = load_py2(p)
+        assert isinstance(s, str)
+        assert s == "xyz"
+        s = load_py3(p)
+        assert isinstance(s, bytes)
+        assert s == serializer.b("xyz")
+        p = dump_py2("xyz")
+        s = load_py3(p, True)
+        assert isinstance(s, serializer._unicode)
+        assert s == serializer.unicode("xyz")
+        p = dump_py3("xyz")
+        s = load_py2(p, True)
+        assert isinstance(s, str)
+        assert s == "xyz"
+
+    def test_unicode(self):
+        py.test.skip("will rewrite")
+        for dump, uni in (dump_py2, serializer._unicode), (dump_py3, str):
+            p = dump(uni("xyz"))
+            for load in load_py2, load_py3:
+                s = load(p)
+                assert isinstance(s, serializer._unicode)
+                assert s == serializer._unicode("xyz")