"""jammy.utils.defaults: helpers for managing default objects, option contexts and default arguments."""

import threading
import inspect
import contextlib
import functools

from .meta import decorator_with_optional_args
from .naming import class_name_of_method

__all__ = [
    'defaults_manager', 'wrap_custom_as_default', 'gen_get_default', 'gen_set_default',
    'option_context',
    'ARGDEF', 'default_args'
]

class DefaultsManager(object):
    """Registry of the "current default" object for classes that define ``as_default``.

    Each class is keyed by the identifier ``class_name_of_method`` computes for its
    ``as_default`` method. A class's default lives either in a dict shared by all
    threads or, when registered with ``is_local=True``, in a per-thread dict.
    """

    def __init__(self):
        # identifier -> bool: whether that identifier's default is stored per thread.
        self._is_local = dict()

        # Defaults visible to every thread.
        self._defaults_global = dict()
        # Per-thread storage; each thread lazily gets its own ``defaults`` dict.
        self._defaults_local = threading.local()

    @decorator_with_optional_args(is_method=True)
    def wrap_custom_as_default(self, *, is_local=False):
        """Turn a generator method into an ``as_default`` context manager.

        While the returned context manager is active, the instance it was called on
        (``slf``) is registered as the default for its class; the previous default
        is restored on exit, including when the body raises. The wrapped generator
        method still runs as a nested context manager around the body.

        NOTE(review): ``decorator_with_optional_args`` presumably lets this be used
        both bare and as ``@wrap_custom_as_default(is_local=True)`` — confirm.

        Args:
            is_local: store the default per thread instead of globally.
        """
        def wrapper(meth):
            identifier = class_name_of_method(meth)
            meth = contextlib.contextmanager(meth)
            self._is_local[identifier] = is_local

            @contextlib.contextmanager
            @functools.wraps(meth)
            def wrapped_func(slf, *args, **kwargs):
                # Resolve the registry at call time: for thread-local defaults each
                # calling thread must use its own dict, not the dict of whichever
                # thread happened to apply this decorator.
                defaults = self._get_defaults_registry(identifier)
                backup = defaults.get(identifier, None)
                defaults[identifier] = slf
                try:
                    with meth(slf, *args, **kwargs):
                        yield
                finally:
                    # Always restore, so a failing block does not leave ``slf``
                    # installed as the default.
                    defaults[identifier] = backup

            return wrapped_func
        return wrapper

    def gen_get_default(self, cls, default_getter=None):
        """Build a ``get_default(default=None)`` function for ``cls``.

        The returned function returns the object currently registered as the default
        for ``cls``. If nothing (or None) is registered it returns ``default``; when
        ``default`` is None and ``default_getter`` is given, ``default_getter()`` is
        called to produce the fallback.
        """
        identifier = class_name_of_method(cls.as_default)

        def get_default(default=None):
            # NB(Jiayuan Mao): cannot use .get(identifier, default), because after calling as_default, the current
            #     default will be set to None.
            val = self._get_defaults_registry(identifier).get(identifier, None)
            if val is not None:
                return val
            # Only compute the fallback when it is actually needed.
            if default is None and default_getter is not None:
                default = default_getter()
            return default
        return get_default

    def gen_set_default(self, cls):
        """Build a ``set_default(default)`` function that registers ``default`` for ``cls``."""
        identifier = class_name_of_method(cls.as_default)

        def set_default(default):
            self._get_defaults_registry(identifier)[identifier] = default
        return set_default

    def set_default(self, cls, default):
        """Register ``default`` as the current default for ``cls``."""
        identifier = class_name_of_method(cls.as_default)
        self._get_defaults_registry(identifier)[identifier] = default

    def _get_defaults_registry(self, identifier):
        """Return the dict holding ``identifier``'s default for the calling context.

        Thread-local identifiers get the calling thread's dict (created on first use);
        all others share ``self._defaults_global``.
        """
        is_local = self._is_local.get(identifier, False)
        if is_local:
            if not hasattr(self._defaults_local, 'defaults'):
                self._defaults_local.defaults = dict()
            defaults = self._defaults_local.defaults
        else:
            defaults = self._defaults_global
        return defaults


# Shared module-level manager; the three helpers below are its bound methods,
# re-exported so callers can use them without touching the instance directly.
defaults_manager = DefaultsManager()
wrap_custom_as_default, gen_get_default, gen_set_default = (
    defaults_manager.wrap_custom_as_default,
    defaults_manager.gen_get_default,
    defaults_manager.gen_set_default,
)

class _LocalObjectSimulator:
    """Holder with a single ``ctx`` slot.

    Used in place of ``threading.local()`` when the stored value should be
    shared by all threads rather than kept per thread.
    """

    __slots__ = ['ctx']


def option_context(name, is_local=True, **kwargs):
    """Create a new option-context class named ``name`` with the options in ``kwargs``.

    Instances carry one attribute per key in ``kwargs``. ``with inst.as_default():``
    makes ``inst`` the active context; while active, ``get_default()`` and
    ``get_option()`` read from it, and new instances inherit its values. With no
    active context they read from a lazily created default instance built from
    ``kwargs`` (modifiable via ``set_default_option``).

    Args:
        name: ``__name__``/``__qualname__`` given to the generated class.
        is_local: if True, the active and default contexts are kept in
            ``threading.local`` objects (one per thread); otherwise a single
            ``_LocalObjectSimulator`` shared by all threads is used.
        **kwargs: option names and their initial default values.

    Returns:
        The generated ``OptionContext`` class.
    """

    class OptionContext(object):
        def __init__(self, **init_kwargs):
            # Start from the declared defaults ...
            for k, v in kwargs.items():
                setattr(self, k, v)
            # ... then inherit from the context active in this thread, if any ...
            current = self.__class__._get_current_context()
            if current is not None:
                for k in kwargs:
                    setattr(self, k, getattr(current, k))
            # ... and finally apply explicit overrides (declared options only).
            for k, v in init_kwargs.items():
                assert k in kwargs, 'Unknown option {!r} for {}.'.format(k, name)
                setattr(self, k, v)

        @classmethod
        def get_option(cls, name):
            """Return option ``name`` from the active context, or from the default one."""
            return getattr(cls.get_default(), name)

        @classmethod
        def set_default_option(cls, name, value):
            """Set option ``name`` on the default context (not on an active one)."""
            cls._create_default_context()
            setattr(cls.default_context.ctx, name, value)

        @classmethod
        def get_default(cls):
            """Return the active context if there is one, else the default context."""
            cls._create_current_context()
            if cls.current_context.ctx is not None:
                return cls.current_context.ctx
            cls._create_default_context()
            return cls.default_context.ctx

        @contextlib.contextmanager
        def as_default(self):
            """Make this instance the active context for the duration of a ``with`` block."""
            klass = self.__class__
            klass._create_current_context()
            backup = klass.current_context.ctx
            klass.current_context.ctx = self
            try:
                yield
            finally:
                # Restore the previous context even if the body raised.
                klass.current_context.ctx = backup

        @classmethod
        def _get_current_context(cls):
            """Return the active context for the calling thread, or None (creates nothing)."""
            holder = getattr(cls, 'current_context', None)
            return getattr(holder, 'ctx', None)

        @classmethod
        def _create_default_context(cls):
            """Ensure ``cls.default_context.ctx`` exists for the calling thread."""
            if not hasattr(cls, 'default_context'):
                cls.default_context = threading.local() if is_local else _LocalObjectSimulator()
            # A threading.local created by another thread has no ``ctx`` in this one yet.
            if not hasattr(cls.default_context, 'ctx'):
                cls.default_context.ctx = cls(**kwargs)

        @classmethod
        def _create_current_context(cls):
            """Ensure ``cls.current_context.ctx`` exists (initially None) for the calling thread."""
            if not hasattr(cls, 'current_context'):
                cls.current_context = threading.local() if is_local else _LocalObjectSimulator()
            # Same per-thread concern as in ``_create_default_context``.
            if not hasattr(cls.current_context, 'ctx'):
                cls.current_context.ctx = None

    OptionContext.__name__ = name
    OptionContext.__qualname__ = name
    return OptionContext
# Sentinel: pass ``ARGDEF`` for a parameter of a ``@default_args``-decorated
# function to request that parameter's declared default value.
ARGDEF = object()


def default_args(func):
    """Decorate ``func`` so that any argument passed as ``ARGDEF`` takes its declared default.

    Example::

        @default_args
        def f(a, b=2):
            return a + b

        f(1, ARGDEF)  # == 3

    Raises:
        TypeError: at call time, if the arguments do not bind to ``func``'s
            signature, or if ``ARGDEF`` is passed for a parameter that has no
            default value.
    """
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        # Snapshot the items because we assign into ``bound.arguments`` while looping.
        for param_name, value in list(bound.arguments.items()):
            if value is not ARGDEF:
                continue
            default = sig.parameters[param_name].default
            if default is inspect.Parameter.empty:
                raise TypeError(
                    'ARGDEF passed for parameter {!r} of {}(), which has no default value.'.format(
                        param_name, getattr(func, '__name__', func)))
            bound.arguments[param_name] = default
        return func(*bound.args, **bound.kwargs)

    return wrapped