=
adding the best model to hugging face
b63fd37
from torch.utils.tensorboard import SummaryWriter
from PIL.JpegImagePlugin import JpegImageFile
import matplotlib.pyplot as plt
from typing import *
from math import *
import numpy as np
import random
import torch
import os
# Switch matplotlib to its grid-free gallery style as soon as this module is imported.
# The style is applied through pyplot's global rcParams, so it also affects every
# figure created afterwards in this process, not only the ones produced here.
plt.style.use(style="_mpl-gallery-nogrid")
def _to_displayable(image):
    """Return ``image`` in a form that ``matplotlib.axes.Axes.imshow`` accepts.

    - Torch tensors are detached and copied to the CPU as NumPy arrays.
      ``imshow`` cannot read CUDA tensors or tensors that require grad.
    - 3-D arrays that are unambiguously channel-first are moved to channel-last.
      Channel-first means shape ``(C, H, W)`` with ``C`` in {1, 3, 4} and a last
      axis that is not a valid channel count. ``imshow`` rejects that layout.
    - A single trailing channel is dropped so the image is drawn as a 2-D array.

    Anything else (e.g. PIL images) is returned unchanged, since ``imshow``
    already handles it.
    """
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()
    if isinstance(image, np.ndarray) and image.ndim == 3:
        # Only rearrange layouts that imshow would otherwise refuse.
        if image.shape[0] in (1, 3, 4) and image.shape[-1] not in (1, 3, 4):
            image = np.moveaxis(image, 0, -1)
        if image.shape[-1] == 1:
            image = image[..., 0]
    return image


def visualize_images(images_dict: Dict[str, Iterable[Union[JpegImageFile, torch.Tensor, np.ndarray]]],
                     log_directory: str = "fake_face_logs",
                     n_images: int = 40,
                     figsize: Tuple[float, float] = (15, 15),
                     seed: Union[int, None] = None,
                     show: bool = True
                     ):
    """Plot a random grid of images from a dictionary and log it to TensorBoard.

    The images of every entry of ``images_dict`` are pooled together. Up to
    ``n_images`` distinct images are then drawn at random, without replacement,
    and plotted on a square grid. When the dictionary has more than one key,
    each image is titled with the key it came from. The figure is written to
    TensorBoard under ``<log_directory>/images``.

    Args:
        images_dict (Dict[str, Iterable[Union[JpegImageFile, torch.Tensor, np.ndarray]]]):
            The images to sample from, grouped by tag (the dictionary key).
        log_directory (str, optional): The TensorBoard log directory. Defaults to "fake_face_logs".
        n_images (int, optional): The maximum number of images to display. It is capped
            at the number of available images. Defaults to 40.
        figsize (tuple, optional): The figure size. Defaults to (15, 15).
        seed (Union[int, None], optional): Seed of the private random generator used to
            pick the images. The global ``random`` state is left untouched.
            Defaults to None (non-deterministic).
        show (bool): If True, return the figure so it can be displayed (e.g. as the last
            expression of a notebook cell). Defaults to True.

    Returns:
        The matplotlib figure if ``show`` is True, otherwise None.

    Raises:
        AssertionError: If ``images_dict`` is not a non-empty dict.
        ValueError: If ``n_images`` is smaller than 1 or ``images_dict`` holds no image.
    """
    assert isinstance(images_dict, dict), "images_dict must be a dictionary"
    assert len(images_dict) > 0, "images_dict must not be empty"
    if n_images < 1:
        raise ValueError(f"n_images must be at least 1, got {n_images}")

    # A private generator keeps the draw reproducible without reseeding the
    # global `random` module that the caller may depend on.
    rng = random.Random(seed)

    # Titles are only needed to tell apart images coming from different keys.
    add_titles = len(images_dict) > 1

    # Flatten the dictionary into (tag, image) pairs.
    tagged_images = [(key, image) for key, images in images_dict.items() for image in images]
    if not tagged_images:
        raise ValueError("images_dict does not contain any image")

    # Pick distinct images (no duplicates) and never more than are available.
    n_images = min(n_images, len(tagged_images))
    selected = rng.sample(tagged_images, k=n_images)

    # Square grid big enough for every selected image. `squeeze=False` keeps
    # `axs` a 2-D array even for a 1x1 grid, so `.flat` always works.
    n_rows = ceil(sqrt(n_images))
    fig, axs = plt.subplots(nrows=n_rows, ncols=n_rows, figsize=figsize, squeeze=False)
    axs = list(axs.flat)

    for ax, (tag, image) in zip(axs, selected):
        ax.imshow(_to_displayable(image), interpolation="nearest")
        if add_titles:
            ax.set_title(tag)
        ax.axis("off")

    # Remove the unused cells of the grid first, so the layout ignores them.
    for ax in axs[n_images:]:
        fig.delaxes(ax)
    fig.tight_layout()

    # Log the figure to TensorBoard.
    with SummaryWriter(os.path.join(log_directory, "images")) as writer:
        # Several keys -> keys joined with "_"; a single key -> that key.
        tag = "_".join(images_dict) if add_titles else next(iter(images_dict))
        writer.add_figure(tag=tag, figure=fig)

    return fig if show else None