raja5259 commited on
Commit
19aa1d3
·
verified ·
1 Parent(s): c0502a3

update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -42
app.py CHANGED
@@ -1,13 +1,8 @@
1
  # gradioMisClassGradCAMimageInputter
2
  import os
3
- import math
4
- import numpy as np
5
- import pandas as pd
6
  import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
  import torchvision
10
- import matplotlib.pyplot as plt
11
  from pl_bolts.datamodules import CIFAR10DataModule
12
  from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
13
  from pytorch_lightning import LightningModule, Trainer, seed_everything
@@ -17,40 +12,71 @@ from pytorch_lightning.loggers import CSVLogger
17
  from torch.optim.lr_scheduler import OneCycleLR
18
  from torch.optim.swa_utils import AveragedModel, update_bn
19
  from torchmetrics.functional import accuracy
20
- from pytorch_lightning.callbacks import ModelCheckpoint
21
- from torchvision import datasets, transforms, utils
22
- from PIL import Image
23
- from pytorch_grad_cam import GradCAM
24
- from pytorch_grad_cam.utils.image import show_cam_on_image
25
- import gradio as gr
26
  import misclas_helper
27
  import gradcam_helper
28
  import lightningmodel
 
29
  from misclas_helper import display_cifar_misclassified_data
30
  from gradcam_helper import display_gradcam_output
31
  from misclas_helper import get_misclassified_data2
 
32
  from lightningmodel import LitResnet
 
 
 
 
 
 
 
33
 
34
- fileName = None
 
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  targets = None
37
  device = torch.device("cpu")
38
  classes = ('plane', 'car', 'bird', 'cat', 'deer',
39
  'dog', 'frog', 'horse', 'ship', 'truck')
40
 
41
- model = LitResnet(lr=0.05).load_from_checkpoint("weights_92.ckpt")
42
 
43
  device = torch.device("cpu")
44
 
45
- # Denormalize the data using test mean and std deviation
 
 
 
 
 
 
 
 
 
 
 
46
  inv_normalize = transforms.Normalize(
47
- mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
48
- std=[1/0.23, 1/0.23, 1/0.23]
49
  )
50
 
51
- # Get the misclassified data from test dataset
52
- misclassified_data = get_misclassified_data2(model, device, 20)
53
-
54
  def hello(DoYouWantToShowMisClassifiedImages, HowManyImages):
55
  if(DoYouWantToShowMisClassifiedImages.lower() == "yes"):
56
  fileName = misclas_helper.display_cifar_misclassified_data(misclassified_data, classes, inv_normalize, number_of_samples=HowManyImages)
@@ -72,7 +98,7 @@ misClass_demo = gr.Interface(
72
  targets = None
73
  device = torch.device("cpu")
74
  classes = ('plane', 'car', 'bird', 'cat', 'deer',
75
- 'dog', 'frog', 'horse', 'ship', 'truck')
76
 
77
 
78
  def inference(DoYouWantToShowGradCAMMedImages, HowManyImages, WhichLayer, transparency):
@@ -89,9 +115,9 @@ def inference(DoYouWantToShowGradCAMMedImages, HowManyImages, WhichLayer, transp
89
  gradCAM_demo = gr.Interface(
90
  fn=inference,
91
  #DoYouWantToShowGradCAMMedImages, HowManyImages, WhichLayer, transparency
92
- inputs=[ gr.Textbox(label="Do you want to show GradCammed images ?", placeholder="Yes / yes / YES", lines=1),
93
- gr.Slider(0, 20, step=5, label = "How many images ?"),
94
- gr.Slider(-3, -1, value = -1, step=1, label = "Which layer ?"),
95
  gr.Slider(0, 1, value = 0.7, label = "Overall Opacity of the Overlay")],
96
  outputs=['image'],
97
  title="GradCammed Images",
@@ -101,8 +127,15 @@ gradCAM_demo = gr.Interface(
101
 
102
  ############
103
 
104
- def ImageInputter(img1, img2, img3, img4, img5, img6, img7, img8, img9, img10):
105
- return img1, img2, img3, img4, img5, img6, img7, img8, img9, img10
 
 
 
 
 
 
 
106
 
107
  imageInputter_demo = gr.Interface(
108
  ImageInputter,
@@ -110,25 +143,23 @@ imageInputter_demo = gr.Interface(
110
  "image","image","image","image","image","image","image","image","image","image"
111
  ],
112
  [
113
- gr.Image("image", label = "output 1"),
114
- gr.Image("image", label = "output 2"),
115
- gr.Image("image", label = "output 3"),
116
- gr.Image("image", label = "output 4"),
117
- gr.Image("image", label = "output 5"),
118
- gr.Image("image", label = "output 6"),
119
- gr.Image("image", label = "output 7"),
120
- gr.Image("image", label = "output 8"),
121
- gr.Image("image", label = "output 9"),
122
- gr.Image("image", label = "output 10")
123
  ],
124
  examples=[
125
- ["bird.jpg", "car.jpg", "cat.jpg"],
126
- ["deer.jpg", "dog.jpg", "frog.jpg"],
127
- ["horse.jpg", "plane.jpg", "ship.jpg"],
128
- [None, "truck.jpg", None],
129
  ],
130
- title="Max 10 images input",
131
- description="Here's a sample image inputter. Allows you to feed in 10 images and display them. You may drag and drop images from bottom examples to the input feeders",
132
  )
133
 
134
 
 
1
  # gradioMisClassGradCAMimageInputter
2
  import os
 
 
 
3
  import torch
 
 
4
  import torchvision
5
+ from torchvision import datasets, transforms, utils
6
  from pl_bolts.datamodules import CIFAR10DataModule
7
  from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
8
  from pytorch_lightning import LightningModule, Trainer, seed_everything
 
12
  from torch.optim.lr_scheduler import OneCycleLR
13
  from torch.optim.swa_utils import AveragedModel, update_bn
14
  from torchmetrics.functional import accuracy
15
+ import pandas as pd
16
+ import seaborn as sn
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ # from IPython.core.display import display
 
20
  import misclas_helper
21
  import gradcam_helper
22
  import lightningmodel
23
+ import trainsave_loadtest
24
  from misclas_helper import display_cifar_misclassified_data
25
  from gradcam_helper import display_gradcam_output
26
  from misclas_helper import get_misclassified_data2
27
+ from misclas_helper import classify_images
28
  from lightningmodel import LitResnet
29
+ from trainsave_loadtest import ts_lt
30
+ import numpy as np
31
+ import gradio as gr
32
+ from PIL import Image
33
+ from pytorch_grad_cam import GradCAM
34
+ from pytorch_grad_cam.utils.image import show_cam_on_image
35
+ from torchvision import datasets, transforms, utils
36
 
37
+ save1_or_load0 = False
38
+
39
+ !mkdir -p /content # This path will be used for saving the weights file after training
40
 
41
+ !git clone https://github.com/rajayourfriend/EraV2/
42
+ !cp EraV2/S14/*.py .
43
+ !cp EraV2/S14/*.jpg .
44
+ if save1_or_load0 == False:
45
+ !cp /content/EraV2/S14/weights_92.ckpt /content/weights.ckpt
46
+ wt_fname = "/content/weights.ckpt" # weights file name to load
47
+
48
+ model, trainer = ts_lt(save1_or_load0, Epochs = 26, wt_fname = "/content/weights.ckpt") # Train and Save Vs Load and Test
49
+ '''
50
+ ts_lt(save1_or_load0, # decision maker for training Vs testing
51
+ Epochs = 1, # argument for training
52
+ wt_fname = "/content/weights.ckpt" # argument for testing
53
+ )
54
+ '''
55
  targets = None
56
  device = torch.device("cpu")
57
  classes = ('plane', 'car', 'bird', 'cat', 'deer',
58
  'dog', 'frog', 'horse', 'ship', 'truck')
59
 
 
60
 
61
  device = torch.device("cpu")
62
 
63
+ # Get the misclassified data from test dataset
64
+ misclassified_data = misclas_helper.get_misclassified_data2(model, device, 20)
65
+
66
+
67
+
68
+ ################################################################################################
69
+
70
+
71
+
72
+
73
+ fileName = None
74
+
75
  inv_normalize = transforms.Normalize(
76
+ mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
77
+ std=[1/0.23, 1/0.23, 1/0.23]
78
  )
79
 
 
 
 
80
  def hello(DoYouWantToShowMisClassifiedImages, HowManyImages):
81
  if(DoYouWantToShowMisClassifiedImages.lower() == "yes"):
82
  fileName = misclas_helper.display_cifar_misclassified_data(misclassified_data, classes, inv_normalize, number_of_samples=HowManyImages)
 
98
  targets = None
99
  device = torch.device("cpu")
100
  classes = ('plane', 'car', 'bird', 'cat', 'deer',
101
+ 'dog', 'frog', 'horse', 'ship', 'truck', 'No Image')
102
 
103
 
104
  def inference(DoYouWantToShowGradCAMMedImages, HowManyImages, WhichLayer, transparency):
 
115
  gradCAM_demo = gr.Interface(
116
  fn=inference,
117
  #DoYouWantToShowGradCAMMedImages, HowManyImages, WhichLayer, transparency
118
+ inputs=[ gr.Textbox(label="Do you want to show gradCammed images ?", placeholder="Yes / yes / YES", lines=1),
119
+ gr.Slider(0, 20, step=5, label = "How many images ?"),
120
+ gr.Slider(-3, -1, value = -1, step=1, label = "Which layer ?"),
121
  gr.Slider(0, 1, value = 0.7, label = "Overall Opacity of the Overlay")],
122
  outputs=['image'],
123
  title="GradCammed Images",
 
127
 
128
  ############
129
 
130
+ def ImageInputter(img0, img1, img2, img3, img4, img5, img6, img7, img8, img9):
131
+ list_images = [img0, img1, img2, img3, img4, img5, img6, img7, img8, img9]
132
+ classified_data = classify_images(list_images, model.model, device)
133
+ img_out = []
134
+ pred_out = []
135
+ for img, pred in classified_data:
136
+ img_out.append(img)
137
+ pred_out.append(pred)
138
+ return classes[pred_out[0]], img_out[0], classes[pred_out[1]], img_out[1], classes[pred_out[2]], img_out[2], classes[pred_out[3]], img_out[3], classes[pred_out[4]], img_out[4], classes[pred_out[5]], img_out[5], classes[pred_out[6]], img_out[6], classes[pred_out[7]], img_out[7], classes[pred_out[8]], img_out[8], classes[pred_out[9]], img_out[9]
139
 
140
  imageInputter_demo = gr.Interface(
141
  ImageInputter,
 
143
  "image","image","image","image","image","image","image","image","image","image"
144
  ],
145
  [
146
+ gr.Textbox("text", label = "pred 0"), gr.Image("image", label = "img 0"),
147
+ gr.Textbox("text", label = "pred 1"), gr.Image("image", label = "img 1"),
148
+ gr.Textbox("text", label = "pred 2"), gr.Image("image", label = "img 2"),
149
+ gr.Textbox("text", label = "pred 3"), gr.Image("image", label = "img 3"),
150
+ gr.Textbox("text", label = "pred 4"), gr.Image("image", label = "img 4"),
151
+ gr.Textbox("text", label = "pred 5"), gr.Image("image", label = "img 5"),
152
+ gr.Textbox("text", label = "pred 6"), gr.Image("image", label = "img 6"),
153
+ gr.Textbox("text", label = "pred 7"), gr.Image("image", label = "img 7"),
154
+ gr.Textbox("text", label = "pred 8"), gr.Image("image", label = "img 8"),
155
+ gr.Textbox("text", label = "pred 9"), gr.Image("image", label = "img 9")
156
  ],
157
  examples=[
158
+ ["bird.jpg", "car.jpg", "cat.jpg", "deer.jpg", "dog.jpg", "frog.jpg", "horse.jpg", "plane.jpg", "ship.jpg"],
159
+ [None, None, None, None, "truck.jpg", None, None, None, None],
 
 
160
  ],
161
+ title="Max 10 images input Classifier",
162
+ description="Here's a sample image inputter. Allows you to feed in 10 images and display them with classification result. You may drag and drop images from bottom examples to the input feeders. You may copy and paste the images from examples to the input feeders. You may double click on a row of images from examples to get them filled in input feeders.",
163
  )
164
 
165