Source code for jammy.random.rng
import os
import random as sys_random
import numpy as np
import numpy.random as npr
from jammy.utils.cache import cached_result
from jammy.utils.defaults import defaults_manager
from jammy.utils.env import jam_getenv
from jammy.utils.registry import Registry
__all__ = [
"JamRandomState",
"get_default_rng",
"gen_seed",
"gen_rng",
"reset_global_seed",
"jam_rng_seed",
]
[docs]class JamRandomState(npr.RandomState):
[docs] def choice_list(self, list_, size=1, replace=False, p=None):
"""Efficiently draw an element from an list,
if the rng is given, use it instead of the system one."""
if size == 1:
if type(list_) in (list, tuple):
return list_[self.choice(len(list_), p=p)]
return self.choice(list_, p=p)
else:
if type(list_) in (list, tuple):
inds = self.choice(len(list_), size=size, replace=replace, p=p)
return [list_[i] for i in inds]
return self.choice(list_, size=size, replace=replace, p=p)
[docs] def shuffle_list(self, list_):
if isinstance(list_, list):
sys_random.shuffle(list_, random=self.random_sample)
else:
self.shuffle(list_)
[docs] def shuffle_multi(self, *arrs):
length = len(arrs[0])
for cur_a in arrs:
assert (
len(cur_a) == length
), "non-compatible length when shuffling multiple arrays"
inds = np.arange(length)
self.shuffle(inds)
return tuple(map(lambda x: x[inds], arrs))
[docs] @defaults_manager.wrap_custom_as_default(is_local=True)
def as_default(self):
yield self
[docs]@cached_result
def jam_rng_seed():
return jam_getenv("RANDOM_SEED", 8, int)
_rng = JamRandomState()
get_default_rng = defaults_manager.gen_get_default(
JamRandomState, default_getter=lambda: _rng
)
class _RngRegistry(Registry):
def register(self, entry, value):
seed = jam_rng_seed()
if seed is not None:
value()(seed)
return super().register(entry, value)
[docs]def gen_seed():
return get_default_rng().randint(4294967296)
[docs]def gen_rng(seed=None):
if seed is None:
seed = jam_rng_seed()
return JamRandomState(seed)
global_rng_registry = _RngRegistry()
global_rng_registry.register("jammy", lambda: _rng.seed)
global_rng_registry.register("numpy", lambda: npr.seed)
global_rng_registry.register("sys", lambda: sys_random.seed)
[docs]def reset_global_seed(seed=None, verbose=True):
if seed is None:
seed = gen_seed()
seed = int(seed)
for k, seed_getter in global_rng_registry.items():
if verbose:
from jammy.logging import get_logger
logger = get_logger()
logger.critical(
"Reset random seed for: {} (pid={}, seed={}).".format(
k, os.getpid(), seed
)
)
seed_getter()(seed)