Spaces:
Runtime error
Runtime error
update app.py
Browse files
app.py
CHANGED
@@ -1,13 +1,8 @@
|
|
1 |
# gradioMisClassGradCAMimageInputter
|
2 |
import os
|
3 |
-
import math
|
4 |
-
import numpy as np
|
5 |
-
import pandas as pd
|
6 |
import torch
|
7 |
-
import torch.nn as nn
|
8 |
-
import torch.nn.functional as F
|
9 |
import torchvision
|
10 |
-
import
|
11 |
from pl_bolts.datamodules import CIFAR10DataModule
|
12 |
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
|
13 |
from pytorch_lightning import LightningModule, Trainer, seed_everything
|
@@ -17,40 +12,71 @@ from pytorch_lightning.loggers import CSVLogger
|
|
17 |
from torch.optim.lr_scheduler import OneCycleLR
|
18 |
from torch.optim.swa_utils import AveragedModel, update_bn
|
19 |
from torchmetrics.functional import accuracy
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
from
|
25 |
-
import gradio as gr
|
26 |
import misclas_helper
|
27 |
import gradcam_helper
|
28 |
import lightningmodel
|
|
|
29 |
from misclas_helper import display_cifar_misclassified_data
|
30 |
from gradcam_helper import display_gradcam_output
|
31 |
from misclas_helper import get_misclassified_data2
|
|
|
32 |
from lightningmodel import LitResnet
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
-
|
|
|
|
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
targets = None
|
37 |
device = torch.device("cpu")
|
38 |
classes = ('plane', 'car', 'bird', 'cat', 'deer',
|
39 |
'dog', 'frog', 'horse', 'ship', 'truck')
|
40 |
|
41 |
-
model = LitResnet(lr=0.05).load_from_checkpoint("weights_92.ckpt")
|
42 |
|
43 |
device = torch.device("cpu")
|
44 |
|
45 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
inv_normalize = transforms.Normalize(
|
47 |
-
|
48 |
-
|
49 |
)
|
50 |
|
51 |
-
# Get the misclassified data from test dataset
|
52 |
-
misclassified_data = get_misclassified_data2(model, device, 20)
|
53 |
-
|
54 |
def hello(DoYouWantToShowMisClassifiedImages, HowManyImages):
|
55 |
if(DoYouWantToShowMisClassifiedImages.lower() == "yes"):
|
56 |
fileName = misclas_helper.display_cifar_misclassified_data(misclassified_data, classes, inv_normalize, number_of_samples=HowManyImages)
|
@@ -72,7 +98,7 @@ misClass_demo = gr.Interface(
|
|
72 |
targets = None
|
73 |
device = torch.device("cpu")
|
74 |
classes = ('plane', 'car', 'bird', 'cat', 'deer',
|
75 |
-
'dog', 'frog', 'horse', 'ship', 'truck')
|
76 |
|
77 |
|
78 |
def inference(DoYouWantToShowGradCAMMedImages, HowManyImages, WhichLayer, transparency):
|
@@ -89,9 +115,9 @@ def inference(DoYouWantToShowGradCAMMedImages, HowManyImages, WhichLayer, transp
|
|
89 |
gradCAM_demo = gr.Interface(
|
90 |
fn=inference,
|
91 |
#DoYouWantToShowGradCAMMedImages, HowManyImages, WhichLayer, transparency
|
92 |
-
inputs=[ gr.Textbox(label="Do you want to show
|
93 |
-
gr.Slider(0, 20, step=5, label = "How many images ?"),
|
94 |
-
gr.Slider(-3, -1, value = -1, step=1, label = "Which layer ?"),
|
95 |
gr.Slider(0, 1, value = 0.7, label = "Overall Opacity of the Overlay")],
|
96 |
outputs=['image'],
|
97 |
title="GradCammed Images",
|
@@ -101,8 +127,15 @@ gradCAM_demo = gr.Interface(
|
|
101 |
|
102 |
############
|
103 |
|
104 |
-
def ImageInputter(img1, img2, img3, img4, img5, img6, img7, img8, img9
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
imageInputter_demo = gr.Interface(
|
108 |
ImageInputter,
|
@@ -110,25 +143,23 @@ imageInputter_demo = gr.Interface(
|
|
110 |
"image","image","image","image","image","image","image","image","image","image"
|
111 |
],
|
112 |
[
|
113 |
-
gr.Image("image", label = "
|
114 |
-
gr.Image("image", label = "
|
115 |
-
gr.Image("image", label = "
|
116 |
-
gr.Image("image", label = "
|
117 |
-
gr.Image("image", label = "
|
118 |
-
gr.Image("image", label = "
|
119 |
-
gr.Image("image", label = "
|
120 |
-
gr.Image("image", label = "
|
121 |
-
gr.Image("image", label = "
|
122 |
-
gr.Image("image", label = "
|
123 |
],
|
124 |
examples=[
|
125 |
-
["bird.jpg", "car.jpg", "cat.jpg"],
|
126 |
-
[
|
127 |
-
["horse.jpg", "plane.jpg", "ship.jpg"],
|
128 |
-
[None, "truck.jpg", None],
|
129 |
],
|
130 |
-
title="Max 10 images input",
|
131 |
-
description="Here's a sample image inputter. Allows you to feed in 10 images and display them. You may drag and drop images from bottom examples to the input feeders",
|
132 |
)
|
133 |
|
134 |
|
|
|
1 |
# gradioMisClassGradCAMimageInputter
|
2 |
import os
|
|
|
|
|
|
|
3 |
import torch
|
|
|
|
|
4 |
import torchvision
|
5 |
+
from torchvision import datasets, transforms, utils
|
6 |
from pl_bolts.datamodules import CIFAR10DataModule
|
7 |
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
|
8 |
from pytorch_lightning import LightningModule, Trainer, seed_everything
|
|
|
12 |
from torch.optim.lr_scheduler import OneCycleLR
|
13 |
from torch.optim.swa_utils import AveragedModel, update_bn
|
14 |
from torchmetrics.functional import accuracy
|
15 |
+
import pandas as pd
|
16 |
+
import seaborn as sn
|
17 |
+
import torch.nn as nn
|
18 |
+
import torch.nn.functional as F
|
19 |
+
# from IPython.core.display import display
|
|
|
20 |
import misclas_helper
|
21 |
import gradcam_helper
|
22 |
import lightningmodel
|
23 |
+
import trainsave_loadtest
|
24 |
from misclas_helper import display_cifar_misclassified_data
|
25 |
from gradcam_helper import display_gradcam_output
|
26 |
from misclas_helper import get_misclassified_data2
|
27 |
+
from misclas_helper import classify_images
|
28 |
from lightningmodel import LitResnet
|
29 |
+
from trainsave_loadtest import ts_lt
|
30 |
+
import numpy as np
|
31 |
+
import gradio as gr
|
32 |
+
from PIL import Image
|
33 |
+
from pytorch_grad_cam import GradCAM
|
34 |
+
from pytorch_grad_cam.utils.image import show_cam_on_image
|
35 |
+
from torchvision import datasets, transforms, utils
|
36 |
|
37 |
+
save1_or_load0 = False
|
38 |
+
|
39 |
+
!mkdir -p /content # This path will be used for saving the weights file after training
|
40 |
|
41 |
+
!git clone https://github.com/rajayourfriend/EraV2/
|
42 |
+
!cp EraV2/S14/*.py .
|
43 |
+
!cp EraV2/S14/*.jpg .
|
44 |
+
if save1_or_load0 == False:
|
45 |
+
!cp /content/EraV2/S14/weights_92.ckpt /content/weights.ckpt
|
46 |
+
wt_fname = "/content/weights.ckpt" # weights file name to load
|
47 |
+
|
48 |
+
model, trainer = ts_lt(save1_or_load0, Epochs = 26, wt_fname = "/content/weights.ckpt") # Train and Save Vs Load and Test
|
49 |
+
'''
|
50 |
+
ts_lt(save1_or_load0, # decision maker for training Vs testing
|
51 |
+
Epochs = 1, # argument for training
|
52 |
+
wt_fname = "/content/weights.ckpt" # argument for testing
|
53 |
+
)
|
54 |
+
'''
|
55 |
targets = None
|
56 |
device = torch.device("cpu")
|
57 |
classes = ('plane', 'car', 'bird', 'cat', 'deer',
|
58 |
'dog', 'frog', 'horse', 'ship', 'truck')
|
59 |
|
|
|
60 |
|
61 |
device = torch.device("cpu")
|
62 |
|
63 |
+
# Get the misclassified data from test dataset
|
64 |
+
misclassified_data = misclas_helper.get_misclassified_data2(model, device, 20)
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
################################################################################################
|
69 |
+
|
70 |
+
|
71 |
+
|
72 |
+
|
73 |
+
fileName = None
|
74 |
+
|
75 |
inv_normalize = transforms.Normalize(
|
76 |
+
mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
|
77 |
+
std=[1/0.23, 1/0.23, 1/0.23]
|
78 |
)
|
79 |
|
|
|
|
|
|
|
80 |
def hello(DoYouWantToShowMisClassifiedImages, HowManyImages):
|
81 |
if(DoYouWantToShowMisClassifiedImages.lower() == "yes"):
|
82 |
fileName = misclas_helper.display_cifar_misclassified_data(misclassified_data, classes, inv_normalize, number_of_samples=HowManyImages)
|
|
|
98 |
targets = None
|
99 |
device = torch.device("cpu")
|
100 |
classes = ('plane', 'car', 'bird', 'cat', 'deer',
|
101 |
+
'dog', 'frog', 'horse', 'ship', 'truck', 'No Image')
|
102 |
|
103 |
|
104 |
def inference(DoYouWantToShowGradCAMMedImages, HowManyImages, WhichLayer, transparency):
|
|
|
115 |
gradCAM_demo = gr.Interface(
|
116 |
fn=inference,
|
117 |
#DoYouWantToShowGradCAMMedImages, HowManyImages, WhichLayer, transparency
|
118 |
+
inputs=[ gr.Textbox(label="Do you want to show gradCammed images ?", placeholder="Yes / yes / YES", lines=1),
|
119 |
+
gr.Slider(0, 20, step=5, label = "How many images ?"),
|
120 |
+
gr.Slider(-3, -1, value = -1, step=1, label = "Which layer ?"),
|
121 |
gr.Slider(0, 1, value = 0.7, label = "Overall Opacity of the Overlay")],
|
122 |
outputs=['image'],
|
123 |
title="GradCammed Images",
|
|
|
127 |
|
128 |
############
|
129 |
|
130 |
+
def ImageInputter(img0, img1, img2, img3, img4, img5, img6, img7, img8, img9):
|
131 |
+
list_images = [img0, img1, img2, img3, img4, img5, img6, img7, img8, img9]
|
132 |
+
classified_data = classify_images(list_images, model.model, device)
|
133 |
+
img_out = []
|
134 |
+
pred_out = []
|
135 |
+
for img, pred in classified_data:
|
136 |
+
img_out.append(img)
|
137 |
+
pred_out.append(pred)
|
138 |
+
return classes[pred_out[0]], img_out[0], classes[pred_out[1]], img_out[1], classes[pred_out[2]], img_out[2], classes[pred_out[3]], img_out[3], classes[pred_out[4]], img_out[4], classes[pred_out[5]], img_out[5], classes[pred_out[6]], img_out[6], classes[pred_out[7]], img_out[7], classes[pred_out[8]], img_out[8], classes[pred_out[9]], img_out[9]
|
139 |
|
140 |
imageInputter_demo = gr.Interface(
|
141 |
ImageInputter,
|
|
|
143 |
"image","image","image","image","image","image","image","image","image","image"
|
144 |
],
|
145 |
[
|
146 |
+
gr.Textbox("text", label = "pred 0"), gr.Image("image", label = "img 0"),
|
147 |
+
gr.Textbox("text", label = "pred 1"), gr.Image("image", label = "img 1"),
|
148 |
+
gr.Textbox("text", label = "pred 2"), gr.Image("image", label = "img 2"),
|
149 |
+
gr.Textbox("text", label = "pred 3"), gr.Image("image", label = "img 3"),
|
150 |
+
gr.Textbox("text", label = "pred 4"), gr.Image("image", label = "img 4"),
|
151 |
+
gr.Textbox("text", label = "pred 5"), gr.Image("image", label = "img 5"),
|
152 |
+
gr.Textbox("text", label = "pred 6"), gr.Image("image", label = "img 6"),
|
153 |
+
gr.Textbox("text", label = "pred 7"), gr.Image("image", label = "img 7"),
|
154 |
+
gr.Textbox("text", label = "pred 8"), gr.Image("image", label = "img 8"),
|
155 |
+
gr.Textbox("text", label = "pred 9"), gr.Image("image", label = "img 9")
|
156 |
],
|
157 |
examples=[
|
158 |
+
["bird.jpg", "car.jpg", "cat.jpg", "deer.jpg", "dog.jpg", "frog.jpg", "horse.jpg", "plane.jpg", "ship.jpg"],
|
159 |
+
[None, None, None, None, "truck.jpg", None, None, None, None],
|
|
|
|
|
160 |
],
|
161 |
+
title="Max 10 images input Classifier",
|
162 |
+
description="Here's a sample image inputter. Allows you to feed in 10 images and display them with classification result. You may drag and drop images from bottom examples to the input feeders. You may copy and paste the images from examples to the input feeders. You may double click on a row of images from examples to get them filled in input feeders.",
|
163 |
)
|
164 |
|
165 |
|