Source code for jammy.io.fs

import contextlib
import glob
import gzip
import json
import os
import os.path as osp
import pickle
import platform
import shutil
from zipfile import ZIP_DEFLATED, ZipFile

import numpy as np
import scipy.io as sio
import six
import yaml
from filelock import FileLock

from jammy.image import imread, imwrite
from jammy.logging import get_logger
from jammy.utils.enum import JamEnum
from jammy.utils.registry import CallbackRegistry, RegistryGroup

from .common import get_ext

# Module-level logger: used for the optional verbose I/O messages and for the
# lock-failure warning emitted by ``safe_dump``.
logger = get_logger()

# Public API of this module. Each name is listed exactly once (the original
# list repeated "extract").
__all__ = [
    "as_file_descriptor",
    "fs_verbose",
    "set_fs_verbose",
    "open",
    "open_txt",
    "open_h5",
    "open_gz",
    "load",
    "load_txt",
    "load_h5",
    "load_pkl",
    "load_pklgz",
    "load_npy",
    "load_npz",
    "load_mat",
    "load_pth",
    "load_yaml",
    "load_json",
    "load_img",
    "dump",
    "dump_pkl",
    "dump_pklgz",
    "dump_npy",
    "dump_npz",
    "dump_mat",
    "dump_pth",
    "dump_json",
    "dump_img",
    "safe_dump",
    "compress",
    "compress_zip",
    "extract",
    "link",
    "mkdir",
    "lsdir",
    "remove",
    "locate_newest_file",
    "move",
    "copy",
    "replace",
    "io_function_registry",
    "latest_time",
]

# Keep a handle on the builtin ``open``: this module defines its own ``open``
# function further down, which shadows the builtin at module scope.
sys_open = open


def as_file_descriptor(fd_or_fname, mode="r"):
    """Return an open file object for ``fd_or_fname``.

    Args:
        fd_or_fname: a filesystem path (``str`` or ``os.PathLike``) or an
            already-open file-like object.
        mode: mode passed to the builtin ``open`` when a path is given.

    Returns:
        A newly opened file when a path was given (the caller is responsible
        for closing it); otherwise ``fd_or_fname`` itself, unchanged.
    """
    # Path-like objects (e.g. ``pathlib.Path``) are opened just like strings;
    # previously they were returned unopened.
    if isinstance(fd_or_fname, (str, os.PathLike)):
        return sys_open(fd_or_fname, mode)  # pylint: disable=consider-using-with
    return fd_or_fname
def open_h5(file, mode, **kwargs):
    """Open ``file`` as an ``h5py.File`` in the given ``mode``.

    Extra keyword arguments are forwarded to ``h5py.File``.
    """
    # h5py is imported at call time rather than at module import.
    import h5py

    h5_handle = h5py.File(file, mode, **kwargs)
    return h5_handle
def open_txt(file, mode, **kwargs):
    """Open ``file`` with the builtin ``open`` and return the file object.

    Extra keyword arguments are forwarded to the builtin ``open``.
    """
    handle = sys_open(file, mode, **kwargs)  # pylint: disable=consider-using-with
    return handle
def open_gz(file, mode):
    """Open a gzip-compressed ``file`` in ``mode`` via ``gzip.open``."""
    gz_handle = gzip.open(file, mode)
    return gz_handle
def extract_zip(file, *args, **kwargs):
    """Extract every member of the zip archive ``file``.

    Positional and keyword arguments are forwarded to
    ``ZipFile.extractall`` (e.g. the destination directory).
    """
    with ZipFile(file, "r") as archive:
        archive.extractall(*args, **kwargs)
def load_pkl(file, **kwargs):
    """Unpickle and return the object stored in ``file``.

    Args:
        file: a path or an open binary file object.
        **kwargs: forwarded to ``pickle.load``.

    Returns:
        The unpickled object.

    Raises:
        UnicodeDecodeError: if decoding fails and the caller explicitly
            supplied ``encoding``.

    If the first attempt fails with ``UnicodeDecodeError`` and no
    ``encoding`` was supplied, the load is retried with
    ``encoding="latin1"``.

    NOTE: ``pickle`` can execute arbitrary code; only load trusted files.
    """
    with as_file_descriptor(file, "rb") as f:
        # Remember where the pickle starts: the failed first attempt consumes
        # part of the stream, so the retry must rewind to read it again.
        can_rewind = getattr(f, "seekable", lambda: False)()
        start = f.tell() if can_rewind else None
        try:
            return pickle.load(f, **kwargs)
        except UnicodeDecodeError:
            if "encoding" in kwargs:
                raise
            if start is not None:
                f.seek(start)
            return pickle.load(f, encoding="latin1", **kwargs)
# pylint: disable=unused-argument
def load_pklgz(file, **kwargs):
    """Load a gzip-compressed pickle from ``file``.

    ``kwargs`` is accepted for signature uniformity but is not forwarded.
    """
    with open_gz(file, "rb") as gz_stream:
        payload = load_pkl(gz_stream)
    return payload
def load_h5(file, **kwargs):
    """Open the HDF5 ``file`` read-only and return the handle from ``open_h5``."""
    read_mode = "r"
    return open_h5(file, read_mode, **kwargs)
def load_txt(file, **kwargs):
    """Read a text file and return its lines as a list.

    Lines are returned as produced by ``readlines()``. Keyword arguments are
    forwarded to the builtin ``open``.
    """
    with sys_open(file, "r", **kwargs) as stream:
        lines = stream.readlines()
    return lines
def load_npy(file, **kwargs):
    """Load ``file`` with ``numpy.load``; ``kwargs`` are forwarded."""
    loaded = np.load(file, **kwargs)
    return loaded
def load_npz(file, **kwargs):
    """Load an ``.npz`` archive with ``numpy.load``; ``kwargs`` are forwarded."""
    archive = np.load(file, **kwargs)
    return archive
def load_mat(file, **kwargs):
    """Load a MATLAB file with ``scipy.io.loadmat``; ``kwargs`` are forwarded."""
    contents = sio.loadmat(file, **kwargs)
    return contents
def load_pth(file, **kwargs):
    """Load ``file`` with ``torch.load``; ``kwargs`` are forwarded."""
    # torch is imported at call time rather than at module import.
    import torch

    restored = torch.load(file, **kwargs)
    return restored
def load_json(file, **kwargs):
    """Parse the JSON document at path ``file`` and return the result.

    ``kwargs`` is accepted for signature uniformity but is not used.
    """
    with sys_open(file, "r") as stream:
        parsed = json.load(stream)
    return parsed
def load_img(file, **kwargs):
    """Read an image through ``jammy.image.imread`` and return its result.

    ``kwargs`` is accepted for signature uniformity but is not used.
    """
    image = imread(file)
    return image
def load_yaml(file, **kwargs):
    """Parse the YAML document at path ``file`` and return the result.

    ``kwargs`` is accepted for signature uniformity but is not used.

    NOTE: the loader used here is not intended for untrusted input; use
    ``yaml.safe_load`` directly for files from unknown sources.
    """
    # ``yaml.load`` requires an explicit ``Loader`` on PyYAML >= 6 (calling it
    # without one raises TypeError). ``FullLoader`` was the implicit default
    # on 5.x; fall back to ``Loader`` on releases that predate it.
    loader = getattr(yaml, "FullLoader", yaml.Loader)
    with sys_open(file, "r") as yamlfile:
        return yaml.load(yamlfile, Loader=loader)
def dump_pkl(file, obj, **kwargs):
    """Pickle ``obj`` into ``file`` (a path or an open binary file object).

    ``kwargs`` are forwarded to ``pickle.dump``. Returns whatever
    ``pickle.dump`` returns.
    """
    with as_file_descriptor(file, "wb") as sink:
        outcome = pickle.dump(obj, sink, **kwargs)
    return outcome
def dump_pklgz(file, obj, **kwargs):
    """Pickle ``obj`` into the gzip-compressed ``file``.

    ``kwargs`` is accepted for signature uniformity but is not forwarded.
    """
    with open_gz(file, "wb") as gz_sink:
        outcome = pickle.dump(obj, gz_sink)
    return outcome
def dump_yaml(file, obj, **kwargs):
    """Serialize ``obj`` as YAML into the file at path ``file``.

    ``kwargs`` is accepted for signature uniformity but is not used.
    """
    with sys_open(file, "w") as sink:
        outcome = yaml.dump(obj, sink)
    return outcome
def dump_json(file, obj, **kwargs):
    """Serialize ``obj`` as JSON into the file at path ``file``.

    ``kwargs`` are forwarded to ``json.dump``.
    """
    with sys_open(file, "w") as sink:
        outcome = json.dump(obj, sink, **kwargs)
    return outcome
def dump_img(file, obj, **kwargs):
    """Write image ``obj`` to ``file`` through ``jammy.image.imwrite``.

    ``kwargs`` is accepted for signature uniformity but is not used.
    Returns ``None``.
    """
    imwrite(file, obj)
    return None
def dump_npy(file, obj, **kwargs):
    """Save ``obj`` to ``file`` with ``numpy.save``.

    ``kwargs`` is accepted for signature uniformity but is not used.
    """
    outcome = np.save(file, obj)
    return outcome
def dump_npz(file, obj, **kwargs):
    """Save ``obj`` to ``file`` with ``numpy.savez`` as a single positional array.

    ``kwargs`` is accepted for signature uniformity but is not used.
    """
    outcome = np.savez(file, obj)
    return outcome
def dump_mat(file, obj, **kwargs):
    """Save ``obj`` to ``file`` with ``scipy.io.savemat``; ``kwargs`` are forwarded."""
    outcome = sio.savemat(file, obj, **kwargs)
    return outcome
def dump_pth(file, obj, **kwargs):
    """Save ``obj`` to ``file`` with ``torch.save``.

    ``kwargs`` is accepted for signature uniformity but is not used.
    """
    # torch is imported at call time rather than at module import.
    import torch

    outcome = torch.save(obj, file)
    return outcome
def compress_zip(  # pylint: disable=inconsistent-return-statements
    file, file_list, verbose=True, **kwargs
):
    """Write every path in ``file_list`` into a new deflated zip at ``file``.

    When a listed path does not exist the user is asked (via
    ``jammy.cli.yes_or_no``) whether to continue; a negative answer stops
    the function immediately. ``verbose`` and ``kwargs`` are accepted but
    not used. Always returns ``None``.
    """
    from jammy.cli import yes_or_no

    with ZipFile(file, "w", ZIP_DEFLATED) as archive:
        for member in file_list:
            try:
                archive.write(member)
            except FileNotFoundError:
                keep_going = yes_or_no(f"Missing {member}, continue?")
                if not keep_going:
                    return None
class _IOFunctionRegistryGroup(RegistryGroup):
    """Registry group mapping (operation, file extension) to an I/O callback."""

    __base_class__ = CallbackRegistry

    def dispatch(self, registry_name, file, *args, **kwargs):
        """Call the callback registered for ``file``'s extension.

        Args:
            registry_name: operation name, e.g. ``"open"``, ``"load"``,
                ``"dump"``, ``"extract"`` or ``"compress"``.
            file: the file argument; its extension (from ``get_ext``) selects
                the callback.
            *args, **kwargs: forwarded to the callback after ``file``.

        Returns:
            Whatever the selected callback returns.
        """
        entry = get_ext(file)
        # NOTE(review): ``fallback=True`` presumably makes the lookup try the
        # "__fallback__" entry before ``default`` -- confirm in RegistryGroup.
        callback = self.lookup(
            registry_name, entry, fallback=True, default=_default_io_fallback
        )
        return callback(file, *args, **kwargs)


def _default_io_fallback(file, *args, **kwargs):
    """Default callback: always raises ``ValueError`` naming ``file``."""
    raise ValueError('Unknown file extension: "{}".'.format(file))


io_function_registry = _IOFunctionRegistryGroup()

io_function_registry.register("open", ".txt", open_txt)
io_function_registry.register("open", ".h5", open_h5)
io_function_registry.register("open", ".gz", open_gz)
io_function_registry.register("open", "__fallback__", sys_open)

io_function_registry.register("load", ".pkl", load_pkl)
io_function_registry.register("load", ".pklgz", load_pklgz)
io_function_registry.register("load", ".txt", load_txt)
io_function_registry.register("load", ".h5", load_h5)
io_function_registry.register("load", ".npy", load_npy)
io_function_registry.register("load", ".npz", load_npz)
io_function_registry.register("load", ".mat", load_mat)
io_function_registry.register("load", ".pth", load_pth)
io_function_registry.register("load", ".pt", load_pth)
io_function_registry.register("load", ".ckpt", load_pth)
io_function_registry.register("load", ".cfg", load_pkl)
io_function_registry.register("load", ".yaml", load_yaml)
io_function_registry.register("load", ".yml", load_yaml)
io_function_registry.register("load", ".json", load_json)
io_function_registry.register("load", ".jpg", load_img)
io_function_registry.register("load", ".png", load_img)
io_function_registry.register("load", ".jpeg", load_img)
# ".jepg" is a historical misspelling of ".jpeg"; kept so existing callers
# that rely on it keep working.
io_function_registry.register("load", ".jepg", load_img)

io_function_registry.register("dump", ".pkl", dump_pkl)
io_function_registry.register("dump", ".pklgz", dump_pklgz)
io_function_registry.register("dump", ".npy", dump_npy)
io_function_registry.register("dump", ".npz", dump_npz)
io_function_registry.register("dump", ".mat", dump_mat)
io_function_registry.register("dump", ".pth", dump_pth)
io_function_registry.register("dump", ".pt", dump_pth)
io_function_registry.register("dump", ".ckpt", dump_pth)
io_function_registry.register("dump", ".cfg", dump_pkl)
io_function_registry.register("dump", ".yaml", dump_yaml)
io_function_registry.register("dump", ".yml", dump_yaml)
io_function_registry.register("dump", ".json", dump_json)
io_function_registry.register("dump", ".jpg", dump_img)
io_function_registry.register("dump", ".png", dump_img)
io_function_registry.register("dump", ".jpeg", dump_img)
# Same historical misspelling as above, kept for backward compatibility.
io_function_registry.register("dump", ".jepg", dump_img)

io_function_registry.register("extract", ".zip", extract_zip)
io_function_registry.register("compress", ".zip", compress_zip)

# Module-wide switch: when true, the dispatching wrappers below log each call.
_fs_verbose = False
# pylint: disable=global-statement
@contextlib.contextmanager
def fs_verbose(mode=True):
    """Temporarily set the module's verbose flag inside a ``with`` block.

    Args:
        mode: value the flag takes while the block runs.

    The previous value is restored when the block exits, including when it
    exits through an exception.
    """
    global _fs_verbose
    previous, _fs_verbose = _fs_verbose, mode
    try:
        yield
    finally:
        # Restore even if the body raised, so a failure inside the block
        # cannot leave verbose mode permanently changed.
        _fs_verbose = previous
def set_fs_verbose(mode=True):
    """Set the module-wide verbose flag to ``mode``."""
    globals()["_fs_verbose"] = mode
def open(file, mode, **kwargs):  # pylint: disable=redefined-builtin
    """Open ``file`` using the opener registered for its extension.

    Logs the call first when verbose mode is on and ``file`` is a string.
    ``mode`` and ``kwargs`` are forwarded to the selected opener.
    """
    should_log = _fs_verbose and isinstance(file, six.string_types)
    if should_log:
        logger.info(f'Opening file: "{file}", mode={mode}.')
    return io_function_registry.dispatch("open", file, mode, **kwargs)
def load(file, **kwargs):
    """Load ``file`` using the loader registered for its extension.

    Logs the call first when verbose mode is on and ``file`` is a string.
    ``kwargs`` are forwarded to the selected loader.
    """
    should_log = _fs_verbose and isinstance(file, six.string_types)
    if should_log:
        logger.info(f'Loading data from file: "{file}".')
    return io_function_registry.dispatch("load", file, **kwargs)
def dump(file, obj, **kwargs):
    """Write ``obj`` to ``file`` using the dumper registered for its extension.

    Logs the call first when verbose mode is on and ``file`` is a string.
    ``kwargs`` are forwarded to the selected dumper.
    """
    should_log = _fs_verbose and isinstance(file, six.string_types)
    if should_log:
        logger.info(f'Dumping data to file: "{file}".')
    return io_function_registry.dispatch("dump", file, obj, **kwargs)
def compress(file, obj, **kwargs):
    """Compress ``obj`` into ``file`` using the handler for its extension.

    Logs the call first when verbose mode is on and ``file`` is a string.
    ``kwargs`` are forwarded to the selected handler.
    """
    should_log = _fs_verbose and isinstance(file, six.string_types)
    if should_log:
        logger.info(f'compress data to file: "{file}".')
    return io_function_registry.dispatch("compress", file, obj, **kwargs)
def extract(file, **kwargs):
    """Extract the archive ``file`` using the handler for its extension.

    Logs the call first when verbose mode is on and ``file`` is a string.
    ``kwargs`` are forwarded to the selected handler.
    """
    should_log = _fs_verbose and isinstance(file, six.string_types)
    if should_log:
        logger.info(f'extract data to file: "{file}".')
    return io_function_registry.dispatch("extract", file, **kwargs)
def safe_dump(fname, data, use_lock=True, use_temp=True, lock_timeout=10):
    """Dump ``data`` to ``fname``, optionally under a file lock and via a temp file.

    Args:
        fname: destination path; its extension selects the dumper.
        data: object to write.
        use_lock: guard the write with a ``FileLock`` placed next to ``fname``.
        use_temp: write to a sibling temp file first, then ``os.replace`` it
            onto ``fname``.
        lock_timeout: timeout passed to ``FileLock``.

    Returns:
        ``True`` after a temp-file write; the return value of ``dump`` when
        ``use_temp`` is false; ``False`` if the lock is not held.
    """
    # Prefix the *basename*, not the whole path: "temp." + "dir/x.pkl" would
    # point into a non-existent "temp.dir" directory. A prefix (rather than a
    # suffix) keeps the trailing extension unchanged for the temp file.
    dirname, basename = osp.split(fname)
    temp_fname = osp.join(dirname, "temp." + basename)
    lock_fname = osp.join(dirname, "lock." + basename)

    def safe_dump_inner():
        if use_temp:
            dump(temp_fname, data)
            os.replace(temp_fname, fname)
            return True
        # Without a temp file, write straight to the destination (the
        # original wrote to the temp name and never produced ``fname``).
        return dump(fname, data)

    if not use_lock:
        return safe_dump_inner()

    with FileLock(lock_fname, lock_timeout) as flock:
        if flock.is_locked:
            return safe_dump_inner()
        logger.warning("Cannot lock the file: {}.".format(fname))
        return False
def mkdir(path):
    """Create directory ``path`` and any missing parents.

    An already-existing directory is not an error. Returns ``None``.
    """
    outcome = os.makedirs(path, exist_ok=True)
    return outcome
class LSDirectoryReturnType(JamEnum):
    """Output formats accepted by ``lsdir`` through its ``return_type`` argument."""

    # Base name of each entry.
    BASE = "base"
    # Base name of each entry with the extension stripped.
    NAME = "name"
    # Path of each entry relative to the listed directory.
    REL = "rel"
    # Entries exactly as produced by the listing.
    FULL = "full"
    # Real (symlink-resolved) path of each entry.
    REAL = "real"
def lsdir(dirname, pattern=None, return_type="full"):
    """List the entries of a directory, optionally filtered by a glob.

    Args:
        dirname: a directory, or a glob expression containing ``*`` or ``?``.
        pattern: optional glob joined onto ``dirname`` (recursive ``**`` is
            enabled).
        return_type: one of ``"base"``, ``"name"``, ``"rel"``, ``"full"``,
            ``"real"`` (see ``LSDirectoryReturnType``).

    Returns:
        A list of strings formatted according to ``return_type``.

    Raises:
        AssertionError: if ``dirname`` is neither a glob nor an existing
            directory, or if ``"rel"`` is requested for a glob ``dirname``.
        ValueError: for an unrecognised return type.

    NOTE: with ``return_type="full"`` and no ``pattern``/glob, the result is
    the bare names from ``os.listdir`` (not joined with ``dirname``); this
    long-standing behaviour is kept for compatibility.
    """
    has_wildcard = "*" in dirname or "?" in dirname
    assert has_wildcard or osp.isdir(dirname)

    return_type = LSDirectoryReturnType.from_string(return_type)

    # ``bare_names`` records whether ``files`` holds names relative to
    # ``dirname`` (os.listdir) or paths that already include it (glob).
    bare_names = False
    if pattern is not None:
        files = glob.glob(osp.join(dirname, pattern), recursive=True)
    elif has_wildcard:
        # Both "*" and "?" mark a glob; previously a "?"-only expression fell
        # through to os.listdir and failed.
        files = glob.glob(dirname)
    else:
        files = os.listdir(dirname)
        bare_names = True

    if return_type is LSDirectoryReturnType.BASE:
        return [osp.basename(f) for f in files]
    elif return_type is LSDirectoryReturnType.NAME:
        return [osp.splitext(osp.basename(f))[0] for f in files]
    elif return_type is LSDirectoryReturnType.REL:
        assert not has_wildcard, "Cannot use * or ? for relative paths."
        if bare_names:
            # os.listdir names are already relative to ``dirname``.
            return list(files)
        return [osp.relpath(f, dirname) for f in files]
    elif return_type is LSDirectoryReturnType.FULL:
        return files
    elif return_type is LSDirectoryReturnType.REAL:
        if bare_names:
            return [osp.realpath(osp.join(dirname, f)) for f in files]
        # Glob results already contain ``dirname``; joining it again would
        # duplicate the prefix for relative directories.
        return [osp.realpath(f) for f in files]
    else:
        raise ValueError("Unknown lsdir return type: {}.".format(return_type))
def remove(file):
    """Delete ``file`` if it exists; directories are removed recursively.

    Errors during recursive directory removal are ignored. A missing path
    is a no-op.
    """
    if not osp.exists(file):
        return
    if osp.isdir(file):
        shutil.rmtree(file, ignore_errors=True)
    if osp.isfile(file):
        os.remove(file)
def copy(src, dst):
    """Copy ``src`` to ``dst`` if ``src`` exists.

    Directories are copied with ``shutil.copytree``; everything else goes
    through ``shutil.copyfile``. A missing ``src`` is a no-op.

    Raises:
        Whatever ``shutil.copytree`` / ``shutil.copyfile`` raise, e.g.
        ``shutil.SpecialFileError`` for a named pipe.
    """
    if not osp.exists(src):
        return
    # Anything that is not a directory is handed to copyfile, so special
    # files produce shutil's own error instead of an UnboundLocalError.
    copier = shutil.copytree if osp.isdir(src) else shutil.copyfile
    copier(src, dst)
def move(src, dst):
    """Rename ``src`` to ``dst`` with ``os.rename`` if ``src`` exists.

    A missing ``src`` is a no-op.
    """
    if not osp.exists(src):
        return
    os.rename(src, dst)
def replace(src, dst):
    """Move ``src`` onto ``dst``, deleting any existing ``dst`` first.

    An existing ``dst`` (file or directory) is removed with ``remove``
    before ``os.replace`` is called. A missing ``src`` is a no-op and
    leaves ``dst`` untouched.
    """
    if not osp.exists(src):
        return
    if osp.exists(dst):
        remove(dst)
    os.replace(src, dst)
def locate_newest_file(dirname, pattern):
    """Return the most recently modified match of ``pattern`` under ``dirname``.

    Candidates come from ``lsdir`` and are compared by ``os.path.getmtime``.
    Returns ``None`` when nothing matches.
    """
    candidates = lsdir(dirname, pattern, return_type="full")
    if not candidates:
        return None
    return max(candidates, key=osp.getmtime)
def latest_time(fname):
    """Return a timestamp of ``fname`` as a naive local ``datetime``.

    On Windows this is ``os.path.getctime``. Elsewhere it is the file's
    ``st_birthtime`` when the platform's stat result provides one, and
    ``st_mtime`` otherwise.
    """
    import datetime

    if platform.system() == "Windows":
        timestamp = os.path.getctime(fname)
    else:
        info = os.stat(fname)
        # Not every platform exposes st_birthtime (probably absent on Linux).
        timestamp = getattr(info, "st_birthtime", info.st_mtime)
    return datetime.datetime.fromtimestamp(timestamp)