Source code for jammy.image.imgio

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : imgio.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 01/19/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.


import os as os
import os.path as osp
import numpy as np
import matplotlib.pyplot as plt

from . import backend
from .imgproc import dimshuffle


# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "imread",
    "imwrite",
    "imshow",
    "plt2pil",
    "plt2nd",
    "nd2pil",
    "pil2nd",
    "imgstack",
    "savefig",
    "ndimgs_in_row",
]


def imread(path, *, shuffle=False):
    """Read an image from ``path`` using the configured image backend.

    Args:
        path: file path to read.
        shuffle: if True, the loaded image is passed through
            ``dimshuffle(..., "channel_first")`` before being returned
            (presumably reordering to channel-first layout; see ``imgproc``).

    Returns:
        The image returned by ``backend.imread``, or ``None`` when ``path``
        does not exist or the backend itself returns ``None``.
    """
    if osp.exists(path):
        image = backend.imread(path)
        # Only reorder axes when the backend actually produced an image.
        if image is not None and shuffle:
            image = dimshuffle(image, "channel_first")
        return image
    return None
def imwrite(path, img, *, shuffle=False):
    """Write ``img`` to ``path`` using the configured image backend.

    Args:
        path: destination file path.
        img: image to write.
        shuffle: if True, ``img`` is first passed through
            ``dimshuffle(..., "channel_last")`` (presumably converting a
            channel-first image back to channel-last; see ``imgproc``).
    """
    to_write = dimshuffle(img, "channel_last") if shuffle else img
    backend.imwrite(path, to_write)
def imshow(title, img, *, shuffle=False):
    """Display ``img`` under ``title`` using the configured image backend.

    Args:
        title: window/figure title forwarded to ``backend.imshow``.
        img: image to display.
        shuffle: if True, ``img`` is first passed through
            ``dimshuffle(..., "channel_last")`` (presumably converting a
            channel-first image back to channel-last; see ``imgproc``).
    """
    to_show = dimshuffle(img, "channel_last") if shuffle else img
    backend.imshow(title, to_show)
def plt2pil(fig):
    """Convert a Matplotlib figure to a PIL Image and return it.

    The figure is rendered as PNG into an in-memory buffer, and that buffer
    is then opened with PIL.

    Args:
        fig: a Matplotlib figure (anything whose ``savefig`` accepts a
            file-like object and a ``format`` keyword).

    Returns:
        A ``PIL.Image.Image`` read from the in-memory PNG data.
    """
    import io

    from PIL import Image

    buf = io.BytesIO()
    # Force PNG explicitly. Without ``format``, savefig on a file-like object
    # falls back to rcParams["savefig.format"]. A style or matplotlibrc may set
    # that to a vector format (pdf/svg/eps), which PIL cannot open.
    fig.savefig(buf, format="png")
    buf.seek(0)
    return Image.open(buf)
def plt2nd(fig):
    """Render a Matplotlib figure and return it as a NumPy array.

    The figure goes through ``plt2pil`` and the resulting PIL image is
    copied into a new ``np.ndarray``.
    """
    pil_image = plt2pil(fig)
    return np.array(pil_image)
def savefig(fig, fig_name):
    """Save ``fig`` to ``fig_name``, creating any missing parent directories.

    Args:
        fig: a Matplotlib figure (anything with a ``savefig(path)`` method).
        fig_name: destination file path (``str`` or path-like). It is passed
            unchanged to ``fig.savefig``.
    """
    # osp.dirname handles absolute paths ("/plot.png" -> "/") and the
    # platform's path separators. An empty result means the file goes into
    # the current directory, so there is nothing to create.
    save_dir = osp.dirname(os.fspath(fig_name))
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    fig.savefig(fig_name)
def imgstack(imgs, dpi=128):
    """Stack images vertically, one per row, in a single Matplotlib figure.

    Args:
        imgs: non-empty sequence of images. ``imgs[0].size`` is read as
            ``(width, height)`` in pixels to size the figure. This matches PIL
            images; a NumPy array's ``.size`` is a scalar and will not work.
        dpi: dots per inch used to convert pixel sizes into figure inches.

    Returns:
        The created Matplotlib figure.
    """
    num_img = len(imgs)
    img_size = imgs[0].size
    # NOTE(review): "img" is a style name this module assumes is registered
    # elsewhere in the package; confirm where it is installed.
    with plt.style.context("img"):
        # squeeze=False always returns a 2-D array of Axes. A single image
        # therefore still yields an iterable column rather than a bare Axes.
        fig, axs = plt.subplots(
            num_img,
            1,
            figsize=(img_size[0] / dpi, num_img * img_size[1] / dpi),
            squeeze=False,
        )
        for cur_img, cur_ax in zip(imgs, axs[:, 0]):
            cur_ax.imshow(cur_img)
    return fig
def ndimgs_in_row(list_x, n, dpi=128, img_size=400):
    """Show ``n`` evenly spaced elements of ``list_x`` side by side in one row.

    Args:
        list_x: indexable sequence of elements that ``Axes.imshow`` can
            display.
        n: number of elements (subplots) in the row.
        dpi: figure DPI, also used to convert ``img_size`` to inches.
        img_size: width and height of each subplot, in pixels.

    Returns:
        The created Matplotlib figure. Each subplot is titled with the index
        of the element it shows.
    """
    length = len(list_x)
    # n indices spread evenly over [0, length - 1]. dtype=int rounds the float
    # positions down, so indices repeat when n > length.
    idxes = np.linspace(0, length - 1, n, dtype=int)
    # NOTE(review): "img" is a style name this module assumes is registered
    # elsewhere in the package; confirm where it is installed.
    with plt.style.context("img"):
        # squeeze=False keeps axs a 2-D array even when n == 1.
        fig, axs = plt.subplots(
            1,
            n,
            figsize=(n * img_size / dpi, img_size / dpi),
            dpi=dpi,
            squeeze=False,
        )
        for cur_ax, idx in zip(axs[0], idxes):
            cur_ax.imshow(list_x[idx])
            cur_ax.set_title(idx)
    return fig
# Conversion helpers re-exported from the image backend. The names suggest
# ndarray -> PIL and PIL -> ndarray respectively; see ``backend`` to confirm.
nd2pil, pil2nd = backend.pil_nd2img, backend.pil_img2nd