File size: 3,054 Bytes
783053f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b63fd37
 
 
 
 
 
 
 
 
 
 
 
 
783053f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b63fd37
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from torch.utils.tensorboard  import SummaryWriter
from PIL.JpegImagePlugin import JpegImageFile
import matplotlib.pyplot as plt
from typing import *
from math import *
import numpy as np
import random
import torch
import os

# Use a matplotlib style without grid lines for the image grids below.
# NOTE(review): this runs at import time and changes the global matplotlib
# style for the whole process, not only for figures created in this module.
plt.style.use("_mpl-gallery-nogrid")

def _to_numpy_image(image: Union[JpegImageFile, torch.Tensor, np.ndarray]) -> np.ndarray:
    """Convert a single image to a NumPy array that ``Axes.imshow`` can display.

    Tensors are detached and moved to the CPU before conversion; PIL images and
    arrays go through ``np.asarray``. A 3-D input whose first axis has size 1, 3
    or 4 while its last axis does not is treated as channel-first (C, H, W) and
    moved to channel-last (H, W, C). Single-channel (H, W, 1) images are
    squeezed to (H, W).
    """
    if isinstance(image, torch.Tensor):
        # .numpy() refuses tensors that require grad or live on a GPU
        image = image.detach().cpu().numpy()

    array = np.asarray(image)

    # imshow expects channel-last data; heuristically fix channel-first input
    if array.ndim == 3 and array.shape[0] in (1, 3, 4) and array.shape[-1] not in (1, 3, 4):
        array = np.moveaxis(array, 0, -1)

    if array.ndim == 3 and array.shape[-1] == 1:
        array = array[..., 0]

    return array


def visualize_images(images_dict: Dict[str, Iterable[Union[JpegImageFile, torch.Tensor, np.ndarray]]], 
                     log_directory: str = "fake_face_logs",
                     n_images: int = 40,
                     figsize = (15, 15), 
                     seed: Union[int, None] = None,
                     show: bool = True
                     ):
    """Visualize a random selection of images from a dictionary and log the figure to TensorBoard.

    Images are drawn without replacement, so no image appears twice in the grid.

    Args:
        images_dict (Dict[str, Iterable[Union[JpegImageFile, torch.Tensor, np.ndarray]]]): The dictionary of the images with key indicating the tag
        log_directory (str, optional): The tensorboard log directory. Defaults to "fake_face_logs".
        n_images (int, optional): The maximum number of images to display; it is clamped to the number of available images. Defaults to 40.
        figsize (tuple, optional): The figure size. Defaults to (15, 15).
        seed (Union[int, None], optional): Seed for the image selection. It only affects a local random generator, not the global ``random`` state. Defaults to None.
        show (bool): Indicate if we want to return the figure (for display). Defaults to True.

    Returns:
        matplotlib.figure.Figure or None: The figure when ``show`` is True, otherwise None.
        Per the torch.utils.tensorboard docs, ``add_figure`` closes the figure by default.

    Raises:
        TypeError: If ``images_dict`` is not a dict.
        ValueError: If ``images_dict`` is empty, contains no images, or ``n_images`` is not positive.
    """
    if not isinstance(images_dict, dict):
        raise TypeError("images_dict must be a dict mapping tags to iterables of images")

    if len(images_dict) == 0:
        raise ValueError("images_dict must contain at least one tag")

    # a local generator keeps the selection reproducible without touching global state
    rng = random.Random(seed)

    # titles are only informative when several tags are mixed in the grid
    add_titles = len(images_dict) > 1

    # flatten the dictionary into (tag, image) pairs
    tagged_images = [(key, image) for key, images in images_dict.items() for image in images]

    if not tagged_images:
        raise ValueError("images_dict does not contain any image")

    # we take the number of images in the list if n_images is larger
    n_images = min(n_images, len(tagged_images))

    if n_images <= 0:
        raise ValueError("n_images must be a positive integer")

    # choose distinct random images
    selected = rng.sample(tagged_images, k=n_images)

    # near-square grid without fully empty rows
    n_cols = ceil(sqrt(n_images))
    n_rows = ceil(n_images / n_cols)

    # squeeze=False always yields a 2-D array of axes, even for a 1x1 grid
    fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=figsize, squeeze=False)
    axs = axs.ravel()

    # trace images
    for ax, (key, image) in zip(axs, selected):
        # cmap only affects 2-D (grayscale) data; imshow ignores it for RGB(A)
        ax.imshow(_to_numpy_image(image), interpolation="nearest", cmap="gray")

        if add_titles:
            ax.set_title(key)

        ax.axis("off")

    # delete the unused trailing axes before computing the layout
    for ax in axs[n_images:]:
        fig.delaxes(ax)

    # add padding to the figure
    fig.tight_layout()

    # add figure to tensorboard
    with SummaryWriter(os.path.join(log_directory, "images")) as writer:

        # identify the tag
        tag = "_".join(images_dict) if add_titles else next(iter(images_dict))

        writer.add_figure(tag=tag, figure=fig)

    return fig if show else None