File size: 6,808 Bytes
41bd4da
 
 
 
19aa1d3
41bd4da
 
 
 
 
 
 
 
 
19aa1d3
 
 
41bd4da
 
 
19aa1d3
41bd4da
 
 
19aa1d3
41bd4da
19aa1d3
 
 
 
 
 
 
41bd4da
19aa1d3
 
cb5b79f
19aa1d3
 
 
 
 
 
41bd4da
 
 
 
 
 
 
 
19aa1d3
 
 
 
 
 
 
 
 
 
 
 
41bd4da
19aa1d3
 
41bd4da
 
 
 
 
 
 
 
 
 
735e1d8
 
41bd4da
735e1d8
 
41bd4da
 
 
 
 
735e1d8
 
 
19aa1d3
735e1d8
41bd4da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19aa1d3
 
 
381c7f9
41bd4da
735e1d8
 
41bd4da
 
 
 
 
19aa1d3
 
 
 
 
 
 
 
 
41bd4da
 
 
 
 
 
 
19aa1d3
 
 
 
 
 
 
 
 
 
41bd4da
 
19aa1d3
 
41bd4da
19aa1d3
 
41bd4da
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# gradioMisClassGradCAMimageInputter
import os
import torch
import torchvision
from torchvision import datasets, transforms, utils
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from torch.optim.lr_scheduler import OneCycleLR
from torch.optim.swa_utils import AveragedModel, update_bn
from torchmetrics.functional import accuracy
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import misclas_helper
import gradcam_helper
import lightningmodel
import trainsave_loadtest
from misclas_helper import display_cifar_misclassified_data
from gradcam_helper import display_gradcam_output
from misclas_helper import get_misclassified_data2
from misclas_helper import classify_images
from lightningmodel import LitResnet
from trainsave_loadtest import ts_lt
import numpy as np
import gradio as gr
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision import datasets, transforms, utils

# True: train a new model and save it; False: load saved weights and test.
save1_or_load0 = False

# Train-and-save vs. load-and-test, delegated to the project helper ts_lt:
#   save1_or_load0 -- decision maker for training vs. testing
#   Epochs         -- argument used when training
#   wt_fname       -- checkpoint file used when testing
model, trainer = ts_lt(save1_or_load0, Epochs=26, wt_fname="weights_92.ckpt")

targets = None
device = torch.device("cpu")
# CIFAR-10 class names; a predicted class id indexes into this tuple.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Get the misclassified data from the test dataset.
# NOTE(review): 20 presumably bounds how many samples are collected
# (it matches the UI sliders' maximum) -- confirm in misclas_helper.
misclassified_data = misclas_helper.get_misclassified_data2(model, device, 20)



################################################################################################




# Module-level name kept for compatibility; the callbacks in this file use
# their own local file-name variable and never update this one.
fileName = None

# Undo a per-channel Normalize(mean=0.50, std=0.23) for display.
# Normalize computes (x - m) / s, so m = -0.50/0.23 and s = 1/0.23 give
# (x + 0.50/0.23) * 0.23 == 0.23 * x + 0.50 on each of the 3 channels.
# NOTE(review): 0.50 / 0.23 look like rounded stand-ins; if the data pipeline
# normalizes with pl_bolts' cifar10_normalization, the exact per-channel
# statistics differ -- confirm against the datamodule.
inv_normalize = transforms.Normalize(
    mean=[-0.50 / 0.23] * 3,
    std=[1 / 0.23] * 3,
)

def hello(DoYouWantToShowMisClassifiedImages, HowManyImages):
    """Render a grid of misclassified test images for the Gradio UI.

    Args:
        DoYouWantToShowMisClassifiedImages: Free-text answer from the textbox.
            Only "yes" renders images; case and surrounding whitespace are
            ignored.
        HowManyImages: Slider value giving how many images to show.

    Returns:
        A PIL image opened from the file path returned by
        ``misclas_helper.display_cifar_misclassified_data``, or ``None`` when
        the answer is not "yes" or zero images were requested.
    """
    answer = (DoYouWantToShowMisClassifiedImages or "").strip().lower()
    if answer != "yes":
        return None

    # The helper is given a whole-number count, whatever numeric type the
    # slider delivers.
    number_of_samples = int(HowManyImages or 0)
    if number_of_samples <= 0:
        # The slider's minimum is 0: there is nothing to plot.
        return None

    file_name = misclas_helper.display_cifar_misclassified_data(
        misclassified_data, classes, inv_normalize,
        number_of_samples=number_of_samples,
    )
    return Image.open(file_name)
# Tab 1: gate on a typed "yes", then choose how many misclassified images to plot.
misClass_demo = gr.Interface(
    fn=hello,
    inputs=[
        gr.Textbox(
            label="Do you want to show misClassified images ?",
            placeholder="Yes / yes / YES",
            lines=1,
        ),
        gr.Slider(0, 20, step=5, label="How many images ?"),
    ],
    outputs=["image"],
    title="Misclassified Images",
    description="If your answer to the question is yes, then only it works !",
)


############

# Rebind the module-level names that the demo callbacks read at call time.
# The label set is the ten CIFAR-10 classes plus 'No Image' at index 10
# (presumably the label for empty input slots -- confirm in classify_images).
targets = None
device = torch.device("cpu")
classes = tuple("plane car bird cat deer dog frog horse ship truck".split()) + ("No Image",)


def _gradcam_target_layers(net, which_layer):
    """Return the Grad-CAM target-layer list for a "Which layer ?" slider value.

    -1 -> last sub-module of ``net.resNetLayer2Part2``
    -2 -> last sub-module of ``net.resNetLayer2Part1``
    -3 -> last sub-module of ``net.Layer3``

    Raises:
        ValueError: if ``which_layer`` is not equal to -1, -2 or -3.
    """
    layer_names = {-1: "resNetLayer2Part2", -2: "resNetLayer2Part1", -3: "Layer3"}
    # Dict lookup uses ==, so a float slider value such as -1.0 also matches.
    name = layer_names.get(which_layer)
    if name is None:
        raise ValueError(f"WhichLayer must be -1, -2 or -3, got {which_layer!r}")
    # Only the selected stage is looked up on net.
    return [getattr(net, name)[-1]]


def inference(DoYouWantToShowGradCAMMedImages, HowManyImages, WhichLayer, transparency):
    """Render Grad-CAM overlays on misclassified images for the Gradio UI.

    Args:
        DoYouWantToShowGradCAMMedImages: Free-text answer from the textbox.
            Only "yes" renders images; case and surrounding whitespace are
            ignored.
        HowManyImages: Slider value giving how many images to show.
        WhichLayer: -1, -2 or -3; selects the Grad-CAM target layer
            (see ``_gradcam_target_layers``).
        transparency: Overlay opacity forwarded to
            ``gradcam_helper.display_gradcam_output``. Previously this value
            was ignored and 0.70 was always used; ``None`` still falls back
            to 0.70.

    Returns:
        A PIL image opened from the file path returned by the helper, or
        ``None`` when the answer is not "yes" or zero images were requested.

    Raises:
        ValueError: if ``WhichLayer`` is not -1, -2 or -3.
    """
    answer = (DoYouWantToShowGradCAMMedImages or "").strip().lower()
    if answer != "yes":
        return None

    # The helper is given a whole-number count, whatever numeric type the
    # slider delivers.
    number_of_samples = int(HowManyImages or 0)
    if number_of_samples <= 0:
        # The slider's minimum is 0: there is nothing to plot.
        return None

    target_layers = _gradcam_target_layers(model.model, WhichLayer)
    file_name = gradcam_helper.display_gradcam_output(
        misclassified_data, classes, inv_normalize, model.model,
        target_layers, targets,
        number_of_samples=number_of_samples,
        transparency=0.70 if transparency is None else transparency,
    )
    return Image.open(file_name)

# Tab 2: Grad-CAM overlays. The input order mirrors
# inference(DoYouWantToShowGradCAMMedImages, HowManyImages, WhichLayer, transparency).
gradCAM_demo = gr.Interface(
    fn=inference,
    inputs=[
        gr.Textbox(
            label="Do you want to show gradCammed images ?",
            placeholder="Yes / yes / YES",
            lines=1,
        ),
        gr.Slider(0, 20, step=5, label="How many images ?"),
        gr.Slider(-3, -1, value=-1, step=1, label="Which layer ?"),
        gr.Slider(0, 1, value=0.7, label="Overall Opacity of the Overlay"),
    ],
    outputs=["image"],
    title="GradCammed Images",
    description="If your answer to the question is yes, then only it works !",
)


############

def ImageInputter(img0, img1, img2, img3, img4, img5, img6, img7, img8, img9):
    """Classify the ten input images and pair each one with its predicted label.

    ``classify_images`` is expected to yield ``(image, prediction)`` pairs,
    where ``prediction`` indexes into ``classes``.

    Returns:
        A flat 20-tuple ``(label0, image0, label1, image1, ..., label9, image9)``
        in the same order as the interface's Textbox/Image output components.
    """
    inputs = [img0, img1, img2, img3, img4, img5, img6, img7, img8, img9]
    pairs = [
        (image, prediction)
        for image, prediction in classify_images(inputs, model.model, device)
    ]
    flat = []
    for slot in range(10):
        image, prediction = pairs[slot]
        flat += [classes[prediction], image]
    return tuple(flat)

# Tab 3: up to ten uploaded images, each shown next to its predicted class.
imageInputter_demo = gr.Interface(
    ImageInputter,
    # One image input per ImageInputter parameter (img0 .. img9).
    ["image"] * 10,
    # Outputs alternate (prediction text, image) to match the flat tuple that
    # ImageInputter returns. Components get only a label: their first
    # positional parameter is the initial *value*, so the previous
    # gr.Textbox("text", ...) / gr.Image("image", ...) set "text"/"image" as
    # initial content rather than choosing a type.
    [
        component
        for i in range(10)
        for component in (
            gr.Textbox(label=f"pred {i}"),
            gr.Image(label=f"img {i}"),
        )
    ],
    # Each example row must supply exactly one value per input (10 values).
    # None leaves that input slot empty.
    examples=[
        ["bird.jpg", "car.jpg", "cat.jpg", "deer.jpg", "dog.jpg", "frog.jpg", "horse.jpg", "plane.jpg", "ship.jpg", None],
        [None, None, None, None, "truck.jpg", None, None, None, None, None],
    ],
    title="Max 10 images input Classifier",
    description="Here's a sample image inputter. Allows you to feed in 10 images and display them with classification result. You may drag and drop images from bottom examples to the input feeders. You may copy and paste the images from examples to the input feeders. You may double click on a row of images from examples to get them filled in input feeders.",
)


############


# Bundle the three demos into one tabbed app; tab titles pair up by position
# with the interfaces.
demo = gr.TabbedInterface(
    [misClass_demo, gradCAM_demo, imageInputter_demo],
    ["MisClassified Images", "GradCAMMed Images", "10 images inputter"],
)

# debug=True keeps the launching process attached so errors surface in the
# console or notebook output.
demo.launch(debug=True)