# Source code for jammy.utils.gpu

import gpustat
import numpy as np


def is_gpu_free(ids=0, mem_thres=0.2, util_thres=0.3):
    """Return True if every selected GPU is below both load thresholds.

    Args:
        ids: GPU position(s) to check, as indices into ``gpustat.new_query()``:
            an ``int`` (NumPy integers are accepted too), a list/tuple of such
            selectors (nesting allowed), or the string ``"all"``.
        mem_thres: a GPU passes when ``memory_used / memory_total`` is strictly
            below this fraction.
        util_thres: a GPU passes when its utilization, converted from gpustat's
            percent scale to a fraction, is strictly below this value. A GPU
            that does not report utilization is conservatively treated as busy.

    Returns:
        True if all selected GPUs pass both checks (an empty list/tuple is
        vacuously True), otherwise False.

    Raises:
        RuntimeError: if ``ids`` (or a nested item) is not a supported selector.
    """
    snapshot = []  # lazily holds a single gpustat query shared by every check

    def _query():
        # Query once per call so all GPUs are judged from one consistent
        # snapshot instead of re-querying NVML for every id in a list.
        if not snapshot:
            snapshot.append(gpustat.new_query())
        return snapshot[0]

    def _check(sel):
        if isinstance(sel, (list, tuple)):
            return all(_check(item) for item in sel)
        if isinstance(sel, (int, np.integer)):
            gpu = _query()[sel]
            used_mem = 1.0 * gpu.memory_used / gpu.memory_total
            util = gpu.utilization
            if util is None:
                # Utilization unavailable: we cannot confirm the GPU is idle.
                return False
            # gpustat reports utilization in percent; rescale so util_thres is
            # a fraction like mem_thres.
            return used_mem < mem_thres and util / 100.0 < util_thres
        if isinstance(sel, str) and sel == "all":
            return _check(list(range(len(_query()))))
        raise RuntimeError(f"{sel} not supported")

    return _check(ids)


def gpu_by_weight(mem_prior=1.0):
    """Rank visible GPUs by a weighted blend of memory usage and utilization.

    Each GPU is scored as
    ``mem_frac * mem_prior + util_frac * (1 - mem_prior)``, where
    ``mem_frac = memory_used / memory_total`` and ``util_frac`` is gpustat's
    percent utilization divided by 100, so both terms lie in [0, 1]. A GPU
    that does not report utilization is treated as fully busy
    (``util_frac = 1``). Lower scores (less loaded GPUs) come first.

    Args:
        mem_prior: weight of the memory term, clipped to [0, 1]; ``1.0`` ranks
            purely by memory, ``0.0`` purely by utilization.

    Returns:
        With exactly one visible GPU, that GPU's ``entry["index"]`` as a single
        value; otherwise a list of ``entry["index"]`` values ordered from least
        to most loaded.

    Raises:
        RuntimeError: if no GPU is visible.
    """
    mem_prior = np.clip(mem_prior, 0.0, 1.0)
    query = gpustat.new_query()
    if len(query) == 0:
        raise RuntimeError("gpu not available")
    if len(query) == 1:
        # NOTE(review): scalar here vs. list below is inconsistent; kept as-is
        # so existing single-GPU callers keep working.
        return query[0].entry["index"]
    # Read every metric from this one snapshot so scores line up with `query`.
    mem = np.array([1.0 * gpu.memory_used / gpu.memory_total for gpu in query])
    # gpustat utilization is in percent; rescale so it is commensurate with
    # the memory fraction (otherwise it dominates the weighted sum).
    utils = np.array(
        [1.0 if gpu.utilization is None else gpu.utilization / 100.0 for gpu in query]
    )
    weight = mem * mem_prior + utils * (1 - mem_prior)
    # Stable sort: equally weighted GPUs keep their enumeration order.
    order = np.argsort(weight, kind="stable")
    return [query[int(pos)].entry["index"] for pos in order]


def gpu_by_util():
    """Return GPU indices ordered from least to most utilized.

    Returns:
        A list of ``entry["index"]`` values, one per visible GPU, sorted by
        gpustat's utilization in ascending order. A GPU that does not report
        utilization is treated as fully busy and sorted after the others.
        The list is empty when no GPU is visible.
    """
    query = gpustat.new_query()
    # Take utilization from the same snapshot as `query` so positions line up;
    # None (unavailable) would otherwise break the sort.
    utils = [100 if gpu.utilization is None else gpu.utilization for gpu in query]
    # Stable sort: equally utilized GPUs keep their enumeration order.
    order = np.argsort(utils, kind="stable")
    return [query[int(pos)].entry["index"] for pos in order]


def get_mem_util():
    """Snapshot per-GPU memory usage and utilization from one gpustat query.

    Returns:
        A tuple ``(mem_fractions, utilizations)`` of equal-length lists in
        gpustat enumeration order. ``mem_fractions`` holds
        ``memory_used / memory_total`` for each GPU. ``utilizations`` holds each
        GPU's ``utilization`` attribute unmodified. Per gpustat's docs, that
        value is a percent, or None when unavailable.
    """
    mem_fractions = []
    utilizations = []
    for gpu in gpustat.new_query():
        mem_fractions.append(1.0 * gpu.memory_used / gpu.memory_total)
        utilizations.append(gpu.utilization)
    return mem_fractions, utilizations