Spaces:
Build error
Build error
from torch.utils.tensorboard import SummaryWriter | |
from PIL.JpegImagePlugin import JpegImageFile | |
import matplotlib.pyplot as plt | |
from typing import * | |
from math import * | |
import numpy as np | |
import random | |
import torch | |
import os | |
# Switch matplotlib to its grid-free gallery style for the whole process, at
# import time: every figure created afterwards (e.g. by visualize_images) uses it.
# NOTE(review): leading-underscore styles are matplotlib-internal and hidden from
# plt.style.available, so this name may change on a matplotlib upgrade — confirm.
plt.style.use(style="_mpl-gallery-nogrid")
def _to_displayable(image: Union[JpegImageFile, torch.Tensor, np.ndarray]):
    """Convert one image into an object that ``Axes.imshow`` can draw.

    PIL JPEG images are turned into numpy arrays. Torch tensors are detached and
    moved to CPU first, because numpy conversion fails for tensors that require
    grad or live on a GPU. Anything else is returned unchanged.

    Args:
        image (Union[JpegImageFile, torch.Tensor, np.ndarray]): The image to convert.

    Returns:
        The image as a numpy array when a conversion applies, otherwise the input object.
    """
    if isinstance(image, JpegImageFile):
        return np.asarray(image)
    if isinstance(image, torch.Tensor):
        # NOTE(review): no axis permutation is applied, so tensors are assumed to be
        # (H, W) or (H, W, C) already; (C, H, W) tensors must be permuted by the caller.
        return image.detach().cpu().numpy()
    return image


def visualize_images(images_dict: Dict[str, Iterable[Union[JpegImageFile, torch.Tensor, np.ndarray]]],
                     log_directory: str = "fake_face_logs",
                     n_images: int = 40,
                     figsize: Tuple[float, float] = (15, 15),
                     seed: Union[int, None] = None,
                     show: bool = True
                     ):
    """Visualize a random selection of images from a dictionary and log the figure to tensorboard.

    Every (tag, image) pair of ``images_dict`` is pooled, ``n_images`` distinct pairs are
    drawn at random, and they are plotted on a square grid. The figure is then written to
    ``<log_directory>/images`` with ``SummaryWriter.add_figure``.

    Args:
        images_dict (Dict[str, Iterable[Union[JpegImageFile, torch.Tensor, np.ndarray]]]): The dictionary of the images with key indicating the tag
        log_directory (str, optional): The tensorboard log directory. Defaults to "fake_face_logs".
        n_images (int, optional): The number of images to show, capped at the number of images available. Defaults to 40.
        figsize (tuple, optional): The figure size. Defaults to (15, 15).
        seed (Union[int, None], optional): Seed of the random selection. A private random generator is used, so the global `random` state is left untouched. Defaults to None (non-deterministic).
        show (bool): Indicate if we want to get the figure back (e.g. to display it). Defaults to True.

    Returns:
        The matplotlib figure when ``show`` is True, otherwise None.

    Raises:
        AssertionError: If ``images_dict`` is not a non-empty dict.
        ValueError: If there is nothing to display (no images at all, or ``n_images`` <= 0).
    """
    assert isinstance(images_dict, dict), "images_dict must be a dict"
    assert len(images_dict) > 0, "images_dict must not be empty"

    # Titles are only needed when several tags have to be told apart.
    add_titles = len(images_dict) > 1

    # Pool every image together with the tag it belongs to.
    tagged_images = [(key, image) for key, images in images_dict.items() for image in images]

    # We take the number of images in the pool if n_images is larger.
    n_images = min(n_images, len(tagged_images))
    if n_images <= 0:
        raise ValueError("There are no images to visualize (check images_dict and n_images).")

    # Sample without replacement so no image appears twice; a private RNG keeps the
    # selection reproducible for a given seed without reseeding the global generator.
    rng = random.Random(seed)
    selected = [(key, _to_displayable(image)) for key, image in rng.sample(tagged_images, k=n_images)]

    # Square grid just large enough for the selected images.
    n_rows = ceil(sqrt(n_images))
    # squeeze=False always yields a 2-D array of axes; for a 1x1 grid plt.subplots
    # would otherwise return a bare Axes object, which has no `.flat`/`.ravel`.
    fig, axs = plt.subplots(nrows=n_rows, ncols=n_rows, figsize=figsize, squeeze=False)
    axes = axs.ravel()

    for ax, (key, image) in zip(axes, selected):
        ax.imshow(image, interpolation="nearest")
        if add_titles:
            ax.set_title(key)
        ax.axis('off')

    # Remove the unused grid cells first so tight_layout only considers visible axes.
    for ax in axes[n_images:]:
        fig.delaxes(ax)
    fig.tight_layout()

    with SummaryWriter(os.path.join(log_directory, "images")) as writer:
        # Several tags are joined with "_"; a single tag is used as-is.
        tag = "_".join(map(str, images_dict)) if add_titles else str(next(iter(images_dict)))
        # add_figure closes the figure by default (close=True); that only detaches it
        # from pyplot, so the returned Figure can still be saved or displayed.
        writer.add_figure(tag=tag, figure=fig)

    return fig if show else None