mawady commited on
Commit
ebde07d
·
1 Parent(s): 9a6ce60

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -0
app.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ warnings.filterwarnings("ignore")
4
+ import torch
5
+ import cv2
6
+ import numpy as np
7
+ from torchvision import transforms
8
+ from torchvision.models import resnet18, ResNet18_Weights
9
+ import urllib.request
10
+ from pytorch_grad_cam import GradCAMPlusPlus
11
+ from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
12
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
13
+ import gradio as gr
14
+
15
+ IMG_SIZE = 224
16
+ CLASSES = ResNet18_Weights.IMAGENET1K_V1.meta["categories"]
17
+ TOP_NUM_CLASSES = 3
18
+
19
+ url = (
20
+ "https://upload.wikimedia.org/wikipedia/commons/3/38/Adorable-animal-cat-20787.jpg"
21
+ )
22
+ path_input = "./cat.jpg"
23
+ urllib.request.urlretrieve(url, filename=path_input)
24
+
25
+
26
+ url = "https://upload.wikimedia.org/wikipedia/commons/4/43/Cute_dog.jpg"
27
+ path_input = "./dog.jpg"
28
+ urllib.request.urlretrieve(url, filename=path_input)
29
+
30
+ device = "cpu"
31
+ if torch.cuda.is_available():
32
+ device = "cuda"
33
+
34
+ model = resnet18(pretrained=True)
35
+
36
+ data_transforms = transforms.Compose(
37
+ [
38
+ transforms.Resize(IMG_SIZE),
39
+ transforms.ToTensor(),
40
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
41
+ ]
42
+ )
43
+
44
+
45
+ def grad_campp(img, cls_ids):
46
+ img_rz = cv2.resize(np.array(img), (IMG_SIZE, IMG_SIZE))
47
+ img = np.float32(img_rz) / 255
48
+ input_tensor = preprocess_image(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).to(
49
+ device
50
+ )
51
+ # mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
52
+
53
+ # Set target layers
54
+ target_layers = [model.layer4[-1]]
55
+
56
+ # Set target classes
57
+ # targets = [ClassifierOutputTarget(cls_id) for cls_id in cls_ids]
58
+
59
+ # GradCAM++
60
+ gradcampp = GradCAMPlusPlus(model=model, target_layers=target_layers)
61
+
62
+ lst_gradcam = []
63
+ for i in range(TOP_NUM_CLASSES):
64
+ targets = [ClassifierOutputTarget(cls_ids[i])]
65
+ grayscale_gradcampp = gradcampp(
66
+ input_tensor=input_tensor,
67
+ targets=targets,
68
+ eigen_smooth=False,
69
+ aug_smooth=False,
70
+ )
71
+ grayscale_gradcampp = grayscale_gradcampp[0, :]
72
+ gradcampp_image = show_cam_on_image(img, grayscale_gradcampp, use_rgb=True)
73
+ lst_gradcam.append(gradcampp_image)
74
+
75
+ return img_rz, lst_gradcam
76
+
77
+
78
+ def do_inference(img):
79
+ img_t = data_transforms(img)
80
+ batch_t = torch.unsqueeze(img_t, 0)
81
+ model.eval()
82
+ # We don't need gradients for test, so wrap in
83
+ # no_grad to save memory
84
+ with torch.no_grad():
85
+ batch_t = batch_t.to(device)
86
+ # forward propagation
87
+ output = model(batch_t)
88
+ # get prediction
89
+ probs = torch.nn.functional.softmax(output, dim=1)
90
+ cls_ids = (
91
+ torch.argsort(probs, dim=1, descending=True).cpu().numpy()[0].astype(int)
92
+ )[:TOP_NUM_CLASSES]
93
+ probs = probs.cpu().numpy()[0]
94
+ probs = probs[cls_ids]
95
+ labels = np.array(CLASSES)[cls_ids]
96
+ img_rz, lst_gradcam = grad_campp(img, cls_ids)
97
+ return (
98
+ {labels[i]: round(float(probs[i]), 2) for i in range(len(labels))},
99
+ img_rz,
100
+ lst_gradcam[0],
101
+ lst_gradcam[1],
102
+ lst_gradcam[2],
103
+ )
104
+
105
+
106
+ im = gr.inputs.Image(
107
+ shape=None, image_mode="RGB", invert_colors=False, source="upload", type="pil"
108
+ )
109
+
110
+ title = "Explainable AI - PyTorch"
111
+ description = "Playground: GradCam Inferernce of Object Classification using ResNet18 model. Libraries: PyTorch, Gradio, Grad-Cam"
112
+ examples = [["./cat.jpg"], ["./dog.jpg"]]
113
+ article = "<p style='text-align: center'><a href='https://github.com/mawady' target='_blank'>By Dr. Mohamed Elawady</a></p>"
114
+ iface = gr.Interface(
115
+ do_inference,
116
+ im,
117
+ outputs=[
118
+ gr.outputs.Label(num_top_classes=TOP_NUM_CLASSES),
119
+ gr.outputs.Image(label="Output image", type="pil"),
120
+ gr.outputs.Image(label="Output image", type="pil"),
121
+ gr.outputs.Image(label="Output image", type="pil"),
122
+ gr.outputs.Image(label="Output image", type="pil"),
123
+ ],
124
+ live=False,
125
+ interpretation=None,
126
+ title=title,
127
+ description=description,
128
+ examples=examples,
129
+ )
130
+
131
+ # iface.test_launch()
132
+
133
+ iface.launch()