# Source code for jammy.concurrency.packing

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : packing.py
# Author : Jiayuan Mao, Qsh.Zh
# Email  : qsh.zh27@gmail.com
# Date   : 11/17/2021
#
# Modified by Qinsheng, based on Jacinle.
# Distributed under terms of the MIT license.

import functools
import os

from jammy.utils.enum import JamEnum
from jammy.utils.registry import CallbackRegistry, RegistryGroup

# Public API, grouped as: per-backend check/load/dump helpers (pickle,
# msgpack, pyarrow), the generic dispatching entry points, and the
# default-backend accessors.
__all__ = [
    "check_pickle", "loadb_pickle", "dumpb_pickle",
    "check_msgpack", "loadb_msgpack", "dumpb_msgpack",
    "check_pyarrow", "loadb_pyarrow", "dumpb_pyarrow",
    "loadb", "dumpb",
    "get_available_backends", "get_default_backend", "set_default_backend",
]

import pickle

# The pickle backend is always available: it is the stdlib codec itself.
# NOTE: pickle.loads can execute arbitrary code; never feed it untrusted bytes.
loadb_pickle, dumpb_pickle = pickle.loads, pickle.dumps

# Optional msgpack backend. msgpack_numpy.patch() installs numpy support into
# msgpack's module-level (de)serializers. If either package is missing, the
# backend is disabled by binding both functions to None.
try:
    import msgpack
    import msgpack_numpy

    msgpack_numpy.patch()
except ImportError:
    dumpb_msgpack = loadb_msgpack = None
else:
    loadb_msgpack = msgpack.loads
    # use_bin_type=True keeps bytes and str distinct on the wire.
    dumpb_msgpack = functools.partial(msgpack.dumps, use_bin_type=True)

# pylint: disable=unnecessary-lambda, invalid-envvar-default

# Optional pyarrow backend, built on the legacy pyarrow.serialize /
# pyarrow.deserialize API. That API was deprecated in pyarrow 2.0 and removed
# in later releases. A pyarrow installation without it is treated like a
# missing package, so check_pyarrow() does not advertise a backend whose every
# call would fail with AttributeError.
try:
    import pyarrow
except ImportError:
    dumpb_pyarrow = loadb_pyarrow = None
else:
    if hasattr(pyarrow, "serialize") and hasattr(pyarrow, "deserialize"):

        def dumpb_pyarrow(obj):
            """Serialize ``obj`` with ``pyarrow.serialize`` and return its buffer."""
            return pyarrow.serialize(obj).to_buffer()

        def loadb_pyarrow(buffer):
            """Deserialize a buffer produced by :func:`dumpb_pyarrow`."""
            return pyarrow.deserialize(buffer)

    else:
        dumpb_pyarrow = loadb_pyarrow = None


class _PackingFunctionRegistryGroup(RegistryGroup):
    """A group of named :class:`CallbackRegistry` instances.

    ``dispatch`` selects one registry by name, then invokes the callback that
    registry holds for ``entry``.
    """

    __base_class__ = CallbackRegistry

    def dispatch(self, registry_name, entry, *args, **kwargs):
        """Call the ``entry`` callback of registry ``registry_name``.

        Extra positional and keyword arguments are forwarded to the callback,
        and its return value is returned unchanged.
        """
        registry = self[registry_name]
        return registry.dispatch(entry, *args, **kwargs)


# Module-wide registry group holding the per-backend callbacks.
_packing_function_registry = _PackingFunctionRegistryGroup()


def check_pickle():
    """Return True: the stdlib pickle backend is always available."""
    return True
def check_msgpack():
    """Return whether the msgpack backend could be set up at import time.

    The backend functions are bound to None when msgpack or msgpack_numpy
    failed to import.
    """
    return dumpb_msgpack is not None
def check_pyarrow():
    """Return whether the pyarrow backend could be set up at import time.

    The backend functions are bound to None when pyarrow is not usable.
    """
    return dumpb_pyarrow is not None
class _PackingBackend(JamEnum):
    """Serialization backends understood by :func:`loadb` / :func:`dumpb`."""

    PICKLE = "pickle"
    MSGPACK = "msgpack"
    PYARROW = "pyarrow"


# Wire every backend into the three registries used by this module:
#   "check" -> availability probe,
#   "loadb" -> deserializer,
#   "dumpb" -> serializer.
# An optional backend whose dependency is missing registers None callables.
# set_default_backend refuses to select such a backend, because its "check"
# callback returns False.
_packing_function_registry.register("check", _PackingBackend.PICKLE, check_pickle)
_packing_function_registry.register("check", _PackingBackend.MSGPACK, check_msgpack)
_packing_function_registry.register("check", _PackingBackend.PYARROW, check_pyarrow)

_packing_function_registry.register("loadb", _PackingBackend.PICKLE, loadb_pickle)
_packing_function_registry.register("dumpb", _PackingBackend.PICKLE, dumpb_pickle)
_packing_function_registry.register("loadb", _PackingBackend.MSGPACK, loadb_msgpack)
_packing_function_registry.register("dumpb", _PackingBackend.MSGPACK, dumpb_msgpack)
_packing_function_registry.register("loadb", _PackingBackend.PYARROW, loadb_pyarrow)
_packing_function_registry.register("dumpb", _PackingBackend.PYARROW, dumpb_pyarrow)

# Module-wide default backend. It is (re)set at import time from the
# JAM_PACKING_BACKEND environment variable by _initialize_backend().
_default_packing_backend = _PackingBackend.PICKLE
def get_default_backend():
    """Return the name of the current default packing backend.

    This is the ``.name`` of the ``_PackingBackend`` member, presumably the
    upper-case member name such as ``"PICKLE"`` (assumes JamEnum follows
    ``enum.Enum`` naming; confirm).
    """
    return _default_packing_backend.name
def get_available_backends():
    """Return the names of the packing backends usable on this machine.

    A backend is listed when its "check" callback returns a truthy value.
    Names are ``_PackingBackend`` member ``.name`` values, in the order given
    by ``_PackingBackend.choice_objs()``.
    """
    return [
        obj.name
        for obj in _PackingBackend.choice_objs()
        if _packing_function_registry.dispatch("check", obj)
    ]
def set_default_backend(backend):
    """Set the module-wide default packing backend.

    Args:
        backend: a ``_PackingBackend`` member, or any value accepted by
            ``_PackingBackend.from_string``.

    Raises:
        AssertionError: if the backend is not available on this machine.
            The previous default is left unchanged in that case.
    """
    global _default_packing_backend  # pylint: disable=global-statement

    new_backend = _PackingBackend.from_string(backend)
    # Validate before assigning so that a rejected backend never becomes the
    # default. Raise explicitly (not via `assert`) so the check also runs
    # under `python -O`. AssertionError is kept for backward compatibility.
    if new_backend.name not in get_available_backends():
        raise AssertionError(
            'Unsupported backend on your machine: "{}".'.format(new_backend.name)
        )
    _default_packing_backend = new_backend
def loadb(bstr, *args, backend=None, **kwargs):
    """Deserialize ``bstr`` with a packing backend.

    Args:
        bstr: the serialized payload.
        *args, **kwargs: forwarded to the backend's load function.
        backend: a ``_PackingBackend`` member, or a value accepted by
            ``_PackingBackend.from_string``. A falsy value (the default
            ``None``) selects the module-wide default backend.

    Returns:
        Whatever the selected backend's load function returns.

    NOTE: the default backend is pickle, which must not be used on untrusted
    input.
    """
    if backend:
        # Registry entries are keyed by _PackingBackend members. Normalize the
        # argument the same way set_default_backend does, so backend names
        # work here too.
        backend = _PackingBackend.from_string(backend)
    else:
        backend = _default_packing_backend
    return _packing_function_registry.dispatch("loadb", backend, bstr, *args, **kwargs)
def dumpb(obj, *args, backend=None, **kwargs):
    """Serialize ``obj`` with a packing backend.

    Args:
        obj: the object to serialize.
        *args, **kwargs: forwarded to the backend's dump function.
        backend: a ``_PackingBackend`` member, or a value accepted by
            ``_PackingBackend.from_string``. A falsy value (the default
            ``None``) selects the module-wide default backend.

    Returns:
        Whatever the selected backend's dump function returns.
    """
    if backend:
        # Registry entries are keyed by _PackingBackend members. Normalize the
        # argument the same way set_default_backend does, so backend names
        # work here too.
        backend = _PackingBackend.from_string(backend)
    else:
        backend = _default_packing_backend
    return _packing_function_registry.dispatch("dumpb", backend, obj, *args, **kwargs)
def _initialize_backend():
    """Select the default backend from $JAM_PACKING_BACKEND (pickle if unset)."""
    requested = os.getenv("JAM_PACKING_BACKEND")
    if requested is None:
        # Variable absent: fall back to the enum member itself, exactly as
        # the original getenv default did.
        requested = _PackingBackend.PICKLE
    set_default_backend(requested)


_initialize_backend()