#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : packing.py
# Author : Jiayuan Mao, Qsh.Zh
# Email : qsh.zh27@gmail.com
# Date : 11/17/2021
#
# Modified by Qinsheng, based on Jacinle.
# Distributed under terms of the MIT license.
import functools
import os
import warnings

from jammy.utils.enum import JamEnum
from jammy.utils.registry import CallbackRegistry, RegistryGroup
# Public API, grouped by backend; the generic entry points come last.
__all__ = [
    # pickle (stdlib, always available)
    "check_pickle", "loadb_pickle", "dumpb_pickle",
    # msgpack (optional dependency)
    "check_msgpack", "loadb_msgpack", "dumpb_msgpack",
    # pyarrow (optional dependency)
    "check_pyarrow", "loadb_pyarrow", "dumpb_pyarrow",
    # backend-agnostic helpers
    "loadb", "dumpb",
    "get_available_backends", "get_default_backend", "set_default_backend",
]
import pickle

# The pickle backend needs no availability guard: it is part of the stdlib.
# SECURITY: pickle.loads can execute arbitrary code while unpickling -- never
# pass bytes from an untrusted source to loadb_pickle.
dumpb_pickle = pickle.dumps
loadb_pickle = pickle.loads
# Optional msgpack backend. It is enabled only when both msgpack and
# msgpack_numpy import; msgpack_numpy.patch() is applied before the
# (de)serializers are bound.
try:
    import msgpack
    import msgpack_numpy

    msgpack_numpy.patch()
    loadb_msgpack = msgpack.loads
    dumpb_msgpack = functools.partial(msgpack.dumps, use_bin_type=True)
except ImportError:
    # Either package missing -> the backend is reported as unavailable.
    loadb_msgpack = None
    dumpb_msgpack = None
# pylint: disable=invalid-envvar-default
# Optional pyarrow backend.
#
# A successful ``import pyarrow`` is not enough. ``pyarrow.serialize`` and
# ``pyarrow.deserialize`` were deprecated in pyarrow 2.0 and have since been
# removed. Registering wrappers unconditionally made check_pyarrow() report
# True on recent installs even though every call failed with AttributeError.
# So probe for the API once and only expose functions that can work.
try:
    import pyarrow
except ImportError:
    dumpb_pyarrow = loadb_pyarrow = None
else:
    with warnings.catch_warnings():
        # NOTE(review): on pyarrow versions that still ship these names,
        # merely touching them may emit a deprecation warning -- confirm.
        # Keep the probe quiet; the real calls below are not silenced.
        warnings.simplefilter("ignore")
        _pyarrow_has_serialization = hasattr(pyarrow, "serialize") and hasattr(
            pyarrow, "deserialize"
        )

    if _pyarrow_has_serialization:

        def dumpb_pyarrow(obj):
            """Serialize ``obj`` with ``pyarrow.serialize`` and return its buffer."""
            return pyarrow.serialize(obj).to_buffer()

        def loadb_pyarrow(buffer):
            """Rebuild an object from a buffer produced by ``dumpb_pyarrow``."""
            return pyarrow.deserialize(buffer)

    else:
        dumpb_pyarrow = loadb_pyarrow = None
class _PackingFunctionRegistryGroup(RegistryGroup):
    """Group of named callback registries keyed by packing backend.

    ``__base_class__`` presumably tells ``RegistryGroup`` which registry type
    to create for each name (here ``CallbackRegistry``).
    """

    __base_class__ = CallbackRegistry

    def dispatch(self, registry_name, entry, *args, **kwargs):
        """Call the callback stored under ``entry`` in registry ``registry_name``.

        Extra positional and keyword arguments are forwarded to the callback,
        and its return value is returned unchanged.
        """
        registry = self[registry_name]
        return registry.dispatch(entry, *args, **kwargs)
# Module-wide registry group. The code below registers per-backend callbacks
# under the "check", "loadb" and "dumpb" registry names.
_packing_function_registry = _PackingFunctionRegistryGroup()
def check_pickle():
    """Report whether the pickle backend can be used.

    pickle is part of the standard library, so the answer is always True.
    """
    available = True
    return available
def check_msgpack():
    """Report whether the msgpack backend can be used.

    True exactly when the optional msgpack/msgpack_numpy import succeeded
    and ``dumpb_msgpack`` was bound to a real serializer.
    """
    return not (dumpb_msgpack is None)
def check_pyarrow():
    """Report whether the pyarrow backend can be used.

    True exactly when ``dumpb_pyarrow`` was bound to a real serializer.
    """
    missing = dumpb_pyarrow is None
    return not missing
class _PackingBackend(JamEnum):
    """Identifiers of the supported packing backends.

    Members are used as keys in ``_packing_function_registry``. The string
    values are presumably what ``set_default_backend`` receives as input
    (NOTE(review): depends on ``JamEnum.from_string`` -- confirm).
    """

    PICKLE = "pickle"  # stdlib; always available
    MSGPACK = "msgpack"  # requires msgpack and msgpack_numpy
    PYARROW = "pyarrow"  # requires pyarrow
# Wire each backend's availability check and (de)serializers into the
# registry group. Backends whose optional import failed register ``None``
# callbacks; get_available_backends() filters backends through the "check"
# entries. PICKLE uses the public check_pickle() rather than a duplicate
# ``lambda: True``, so all three backends are registered the same way.
for _backend, _check, _loadb, _dumpb in (
    (_PackingBackend.PICKLE, check_pickle, loadb_pickle, dumpb_pickle),
    (_PackingBackend.MSGPACK, check_msgpack, loadb_msgpack, dumpb_msgpack),
    (_PackingBackend.PYARROW, check_pyarrow, loadb_pyarrow, dumpb_pyarrow),
):
    _packing_function_registry.register("check", _backend, _check)
    _packing_function_registry.register("loadb", _backend, _loadb)
    _packing_function_registry.register("dumpb", _backend, _dumpb)
# Do not leak the loop variables into the module namespace.
del _backend, _check, _loadb, _dumpb

# pickle needs no optional dependency, so it is the initial default.
_default_packing_backend = _PackingBackend.PICKLE
def get_default_backend():
    """Return the ``.name`` of the backend used when ``loadb``/``dumpb`` get none.

    This is the enum member's name (presumably e.g. ``"PICKLE"``), not its
    value.
    """
    current = _default_packing_backend
    return current.name
def get_available_backends():
    """List the ``.name`` of every backend whose "check" callback returns truthy.

    Backends are visited in the order ``_PackingBackend.choice_objs()`` yields.
    """
    available = []
    for candidate in _PackingBackend.choice_objs():
        usable = _packing_function_registry.dispatch("check", candidate)
        if usable:
            available.append(candidate.name)
    return available
def set_default_backend(backend):
    """Make ``backend`` the default used by ``loadb`` / ``dumpb``.

    Args:
        backend: anything ``_PackingBackend.from_string`` accepts. That
            includes a ``_PackingBackend`` member, and presumably its string
            value such as ``"msgpack"``.

    Raises:
        AssertionError: if the backend is not in ``get_available_backends()``
            (its optional dependency did not import). It is raised explicitly
            rather than via ``assert`` so the check survives ``python -O``;
            the exception type is kept for existing callers. The previous
            default is left unchanged in that case.
    """
    global _default_packing_backend  # pylint: disable=global-statement
    candidate = _PackingBackend.from_string(backend)
    # Validate before assigning so a rejected backend never becomes the default.
    if candidate.name not in get_available_backends():
        raise AssertionError(
            'Unsupported backend on your machine: "{}".'.format(candidate.name)
        )
    _default_packing_backend = candidate
def _resolve_backend(backend):
    """Map a ``backend`` argument to the ``_PackingBackend`` member to dispatch on.

    A falsy value (``None``, ``""``) selects the current default backend.
    Anything else goes through ``_PackingBackend.from_string``, the same
    conversion ``set_default_backend`` applies. Callers can therefore pass
    enum members as before, and also the string spellings accepted there.
    Previously a string never matched the enum-keyed registry.
    """
    if not backend:
        return _default_packing_backend
    return _PackingBackend.from_string(backend)


def loadb(bstr, *args, backend=None, **kwargs):
    """Deserialize ``bstr`` with ``backend`` (default: the current default backend).

    Extra positional/keyword arguments are forwarded to the backend's loader.
    With the pickle backend, never pass untrusted bytes: unpickling can
    execute arbitrary code.
    """
    return _packing_function_registry.dispatch(
        "loadb", _resolve_backend(backend), bstr, *args, **kwargs
    )


def dumpb(obj, *args, backend=None, **kwargs):
    """Serialize ``obj`` with ``backend`` (default: the current default backend).

    Extra positional/keyword arguments are forwarded to the backend's dumper.
    """
    return _packing_function_registry.dispatch(
        "dumpb", _resolve_backend(backend), obj, *args, **kwargs
    )
def _initialize_backend():
    """Set the default backend from ``$JAM_PACKING_BACKEND``.

    Falls back to pickle when the variable is unset or empty.
    """
    # Using ``or`` instead of getenv's default argument also treats an
    # exported-but-empty variable as unset; otherwise "" would be handed to
    # from_string at import time. It also avoids a non-str getenv default.
    requested = os.getenv("JAM_PACKING_BACKEND") or _PackingBackend.PICKLE
    set_default_backend(requested)
# Apply the environment-selected default backend once, when the module is imported.
_initialize_backend()