File size: 1,129 Bytes
06a7cdc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import matplotlib.pyplot as plt
from torchvision import transforms


def plot_incorrect_preds(incorrect, classes, num_imgs: int, *, mean=None, std=None):
    """Plot a random sample of misclassified images in a 5-column grid.

    Each subplot shows one un-normalized image titled ``target|predicted``.

    Args:
        incorrect: Sequence of misclassified samples. Each entry is indexed
            as ``entry[0]`` (normalized image tensor), ``entry[1]`` (target
            label) and ``entry[2]`` (predicted label); both labels must
            support ``.item()``. Any further elements are ignored.
        classes: Lookup from label index to the name shown in the title.
        num_imgs: Number of samples to draw. Must be a positive multiple of 5
            (the grid always has 5 columns) and at most ``len(incorrect)``.
        mean: Optional per-channel mean that was used to normalize the
            images. Must be given together with ``std``.
        std: Optional per-channel standard deviation that was used to
            normalize the images. Must be given together with ``mean``.
            When both are omitted, the built-in inverse constants are used;
            they correspond to mean ~ (0.4914, 0.4822, 0.4465) and
            std ~ (0.247, 0.243, 0.261).
            NOTE(review): these look like CIFAR-10 statistics -- confirm.

    Returns:
        The matplotlib figure containing the grid.

    Raises:
        ValueError: If ``num_imgs`` is not a positive multiple of 5, if fewer
            than ``num_imgs`` samples are available, if only one of
            ``mean``/``std`` is given, if their lengths differ, or if ``std``
            contains a zero.
    """
    import random

    # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
    if num_imgs <= 0 or num_imgs % 5 != 0:
        raise ValueError(
            f"num_imgs must be a positive multiple of 5, got {num_imgs}"
        )
    if len(incorrect) < num_imgs:
        raise ValueError(
            f"need at least {num_imgs} incorrect samples, got {len(incorrect)}"
        )
    if (mean is None) != (std is None):
        raise ValueError("mean and std must be provided together")

    if mean is None:
        # Precomputed (-mean / std) and (1 / std) for the default statistics.
        inv_mean = (-1.98947368, -1.98436214, -1.71072797)
        inv_std = (4.048583, 4.11522634, 3.83141762)
    else:
        mean_vals = [float(m) for m in mean]
        std_vals = [float(s) for s in std]
        if len(mean_vals) != len(std_vals):
            raise ValueError("mean and std must have the same length")
        if any(s == 0.0 for s in std_vals):
            raise ValueError("std must not contain zeros")
        # Normalize computes (x - m') / s'; with m' = -mean/std and
        # s' = 1/std this yields x * std + mean, undoing the normalization.
        inv_mean = tuple(-m / s for m, s in zip(mean_vals, std_vals))
        inv_std = tuple(1.0 / s for s in std_vals)

    # Build the transforms once rather than on every loop iteration.
    unnormalize = transforms.Normalize(inv_mean, inv_std)
    to_pil = transforms.ToPILImage()

    chosen = random.sample(range(len(incorrect)), num_imgs)
    n_rows = num_imgs // 5

    # Entries of ``incorrect`` are (data, target, pred, output).
    fig = plt.figure(figsize=(10, num_imgs // 2))
    fig.suptitle("Target | Predicted Label")
    for position, sample_idx in enumerate(chosen, start=1):
        sample = incorrect[sample_idx]
        # Draw on the figure/axes objects directly so the result does not
        # depend on which figure pyplot currently considers "current".
        ax = fig.add_subplot(n_rows, 5, position, aspect="auto")
        ax.imshow(to_pil(unnormalize(sample[0])))
        ax.set_title(f"{classes[sample[1].item()]}|{classes[sample[2].item()]}")
        ax.set_xticks([])
        ax.set_yticks([])
    fig.tight_layout()
    return fig