GabrielML committed on
Commit 8850972 · 1 Parent(s): c426221
app.py CHANGED
@@ -1,49 +1,80 @@
 import copy
 import os
 import sys
 sys.path.append('src')
 from collections import defaultdict
 from functools import lru_cache

 import gradio as gr
-import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import torch
 from deep_translator import GoogleTranslator
 from Nets import CustomResNet18
 from PIL import Image
-from torchcam.methods import GradCAM, GradCAMpp, SmoothGradCAMpp, XGradCAM
-from torchcam.utils import overlay_mask
-from torchvision.transforms.functional import to_pil_image
 from tqdm import tqdm
-from util import transform
-from gradio_blocks import build_video_to_camvideo
-import cv2
-import ffmpeg
-import shutil
-import mediapy

 ffmpeg_path = shutil.which('ffmpeg')
 mediapy.set_ffmpeg(ffmpeg_path)

 IMAGE_PATH = os.path.join(os.getcwd(), 'src/examples')
-IMAGES_PER_ROW = 10

 MAXIMAL_FRAMES = 1000
-BATCHES_TO_PROCESS = 15
 OUTPUT_FPS = 10
-MAX_OUT_FRAMES = 70

 CAM_METHODS = {
     "GradCAM": GradCAM,
-    "GradCAM++": GradCAMpp,
     "XGradCAM": XGradCAM,
-    "SmoothGradCAM++": SmoothGradCAMpp,
 }

-model = CustomResNet18(90).eval()
-model.load_state_dict(torch.load('src/results/models/best_model.pth', map_location=torch.device('cpu')))
-cam_model = copy.deepcopy(model)
 data_df = pd.read_csv('src/cache/val_df.csv')

 C_NUM_TO_NAME = data_df[['encoded_target', 'target']].drop_duplicates().sort_values('encoded_target').set_index('encoded_target')['target'].to_dict()
@@ -58,16 +89,19 @@ def get_class_idx(name):

 @lru_cache(maxsize=100)
 def get_translated(to_translate):
-    # return "ssss"
     return GoogleTranslator(source="en", target="de").translate(to_translate)
 for idx in range(90): get_translated(get_class_name(idx))

-def infer_image(image, image_sketch):
-    image = image if image is not None else image_sketch
     image = transform(image)
     image = image.unsqueeze(0)
     with torch.no_grad():
-        output = model(image)
     distribution = torch.nn.functional.softmax(output, dim=1)
     ret = defaultdict(float)
     for idx, prob in enumerate(distribution[0]):
@@ -75,32 +109,51 @@ def infer_image(image, image_sketch):
         ret[animal] = prob.item()
     return ret

-def gradcam(image, image_sketch=None, alpha=0.5, cam_method=GradCAM, layer=None, specific_class="Predicted Class"):
-    image = image if image is not None else image_sketch
-    if layer == 'layer1': layers = [model.resnet.layer1]
-    elif layer == 'layer2': layers = [model.resnet.layer2]
-    elif layer == 'layer3': layers = [model.resnet.layer3]
-    elif layer == 'layer4': layers = [model.resnet.layer4]
-    else: layers = [model.resnet.layer1, model.resnet.layer2, model.resnet.layer3, model.resnet.layer4]
-
-    model.eval()
-    img_tensor = transform(image).unsqueeze(0)
-    cam = CAM_METHODS[cam_method](model, target_layer=layers)
-    output = model(img_tensor)
-    class_to_explain = output.squeeze(0).argmax().item() if specific_class == "Predicted Class" else get_class_idx(specific_class)
-    activation_map = cam(class_to_explain, output)
-    result = overlay_mask(image, to_pil_image(activation_map[0].squeeze(0), mode='F'), alpha=alpha)
-    cam.remove_hooks()
-
-    # # height maximal 300px
-    # if result.size[1] > 300:
-    #     ratio = 300 / result.size[1]
-    #     result = result.resize((int(result.size[0] * ratio), 300))
-    return result


-def gradcam_video(video, alpha=0.5, cam_method=GradCAM, layer=None, specific_class="Predicted Class"):
     global OUTPUT_FPS, MAXIMAL_FRAMES, BATCHES_TO_PROCESS, MAX_OUT_FRAMES
     video = cv2.VideoCapture(video)
     fps = int(video.get(cv2.CAP_PROP_FPS))
     if OUTPUT_FPS == -1: OUTPUT_FPS = fps
@@ -127,36 +180,32 @@ def gradcam_video(video, alpha=0.5, cam_method=GradCAM, layer=None, specific_cla
     print(f'Frames to process: {len(frames)}')

     processed = [Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) for frame in frames]
-    # generate lists in lists for the images for batch processing. 10 images per inner list..
     batched = [processed[i:i + BATCHES_TO_PROCESS] for i in range(0, len(processed), BATCHES_TO_PROCESS)]

-    model.eval()
-    if layer == 'layer1': layers = [model.resnet.layer1]
-    elif layer == 'layer2': layers = [model.resnet.layer2]
-    elif layer == 'layer3': layers = [model.resnet.layer3]
-    elif layer == 'layer4': layers = [model.resnet.layer4]
-    else: layers = [model.resnet.layer1, model.resnet.layer2, model.resnet.layer3, model.resnet.layer4]
-    cam = CAM_METHODS[cam_method](model, target_layer=layers)
     results = list()
-    for i, batch in enumerate(tqdm(batched)):
-        images_tensor = torch.stack([transform(image) for image in batch])
-        outputs = model(images_tensor)
-        out_classes = [output.argmax().item() for output in outputs]
-        classes_to_explain = out_classes if specific_class == "Predicted Class" else [get_class_idx(specific_class)] * len(out_classes)
-        activation_maps = cam(classes_to_explain, outputs)
-        for j, activation_map in enumerate(activation_maps[0]):
-            result = overlay_mask(batch[j], to_pil_image(activation_map, mode='F'), alpha=alpha)
-            results.append(cv2.cvtColor(np.array(result), cv2.COLOR_RGB2BGR))
-    cam.remove_hooks()

     # save video
-    # fourcc = cv2.VideoWriter_fourcc(*'AVC1')
-    # fourcc = cv2.VideoWriter_fourcc(*'MP4V')
-    # fourcc = cv2.VideoWriter_fourcc(*'XVID')
-    # size = (results[0].shape[1], results[0].shape[0])
-    # video = cv2.VideoWriter('src/results/gradcam_video.mp4', fourcc, OUTPUT_FPS, size)
-    # for frame in results:
-    #     video.write(frame)
     mediapy.write_video('src/results/gradcam_video.mp4', results, fps=OUTPUT_FPS)
     video.release()
     return 'src/results/gradcam_video.mp4'
@@ -190,10 +239,15 @@ def load_examples():
     for j in range(IMAGES_PER_ROW):
         if i * IMAGES_PER_ROW + j >= len(images_to_load): break
         image = images_to_load[i * IMAGES_PER_ROW + j]
         loaded_images[image_type].append(
             gr.Image(
-                value=os.path.join(full_path, image),
-                label=f"image ({get_translated(image.split('.')[0])})",
                 type="pil",
                 interactive=False,
                 elem_classes=["selectable_images"],
@@ -224,22 +278,13 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
     # INPUT IMAGE
     # -------------------------------------------
     with gr.Row():
-        with gr.Tab("Upload Image"):
-            with gr.Row(variant="panel", equal_height=True):
-                user_image = gr.Image(
-                    type="pil",
-                    label="Upload Your Own Image",
-                    info="You can also upload your own image for prediction.",
-                )
-        with gr.Tab("Draw Image"):
-            with gr.Row(variant="panel", equal_height=True):
-                user_image_sketched = gr.Image(
-                    type="pil",
-                    source="canvas",
-                    tool="color-sketch",
-                    label="Draw Your Own Image",
-                    info="You can also draw your own image for prediction.",
-                )

     # -------------------------------------------
     # TOOLS
@@ -257,7 +302,7 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
         scale=5,
     )
     predict_mode_button = gr.Button(value="Predict Animal", label="Predict", info="Click to make a prediction.", scale=1)
-    predict_mode_button.click(fn=infer_image, inputs=[user_image, user_image_sketched], outputs=output, queue=True)

     # -------------------------------------------
     # EXPLAIN
@@ -265,16 +310,20 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
     with gr.Tab("Explain Image"):
         with gr.Row():
             with gr.Column():
                 cam_method = gr.Radio(
                     list(CAM_METHODS.keys()),
                     label="GradCAM Method",
                     value="GradCAM",
                     interactive=True,
                     scale=2,
                 )
-                cam_method.description = "Here you can choose the GradCAM method."
-                cam_method.description_place = "left"

                 alpha = gr.Slider(
                     minimum=.1,
                     maximum=.9,
@@ -283,46 +332,99 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
                     step=.1,
                     label="Alpha",
                     scale=1,
                 )
-                alpha.description = "Here you can choose the alpha value."
-                alpha.description_place = "left"

                 layer = gr.Radio(
-                    ["layer1", "layer2", "layer3", "layer4", "all"],
                     label="Layer",
                     value="layer4",
                     interactive=True,
                     scale=2,
                 )
-                layer.description = "Here you can choose the layer to visualize."
-                layer.description_place = "left"

                 animal_to_explain = gr.Dropdown(
                     choices=["Predicted Class"] + ALL_CLASSES,
                     label="Animal",
                     value="Predicted Class",
                     interactive=True,
                     scale=2,
                 )
-                animal_to_explain.description = "Here you can choose the animal to explain. If you choose 'Predicted Class' the method will explain the predicted class."
-                animal_to_explain.description_place = "center"

             with gr.Column():
                 output_cam = gr.Image(
                     type="pil",
                     label="GradCAM",
-                    info="GradCAM visualization"
-
                 )

         gradcam_mode_button = gr.Button(value="Show GradCAM", label="GradCAM", info="Click to make a prediction.", scale=1)
-        gradcam_mode_button.click(fn=gradcam, inputs=[user_image, user_image_sketched, alpha, cam_method, layer, animal_to_explain], outputs=output_cam, queue=True)

     # -------------------------------------------
     # Video CAM
     # -------------------------------------------
     with gr.Tab("Explain Video"):
-        build_video_to_camvideo(CAM_METHODS, ALL_CLASSES, gradcam_video)

     # -------------------------------------------
     # EXAMPLES
 
 import copy
 import os
 import sys
+
 sys.path.append('src')
+import shutil
 from collections import defaultdict
 from functools import lru_cache
+
+import cv2
 import gradio as gr
+import mediapy
 import numpy as np
 import pandas as pd
 import torch
 from deep_translator import GoogleTranslator
+from gradio_blocks import build_video_to_camvideo
 from Nets import CustomResNet18
 from PIL import Image
+
+from pytorch_grad_cam import GradCAM, HiResCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
+from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+from pytorch_grad_cam.utils.image import show_cam_on_image
+
 from tqdm import tqdm
+import util
+from util import transform, CustomImageCache, imageCacheWrapper

+util.ImageCache = CustomImageCache(60, False)
 ffmpeg_path = shutil.which('ffmpeg')
 mediapy.set_ffmpeg(ffmpeg_path)

 IMAGE_PATH = os.path.join(os.getcwd(), 'src/examples')
+IMAGES_PER_ROW = 5

 MAXIMAL_FRAMES = 1000
+BATCHES_TO_PROCESS = 20
 OUTPUT_FPS = 10
+MAX_OUT_FRAMES = 60
+
+MODEL = CustomResNet18(90).eval()
+MODEL.load_state_dict(torch.load('src/results/models/best_model.pth', map_location=torch.device('cpu')))

 CAM_METHODS = {
     "GradCAM": GradCAM,
+    "GradCAM++": GradCAMPlusPlus,
     "XGradCAM": XGradCAM,
+    "HiResCAM": HiResCAM,
+    "EigenCAM": EigenCAM
+}
+
+# every entry is a list, since pytorch_grad_cam expects target_layers as a list of modules
+LAYERS = {
+    'layer1': [MODEL.resnet.layer1],
+    'layer2': [MODEL.resnet.layer2],
+    'layer3': [MODEL.resnet.layer3],
+    'layer4': [MODEL.resnet.layer4],
+    'all': [MODEL.resnet.layer1, MODEL.resnet.layer2, MODEL.resnet.layer3, MODEL.resnet.layer4],
+    'layer3+4': [MODEL.resnet.layer3, MODEL.resnet.layer4]
+}
+
+CV2_COLORMAPS = {
+    "Autumn": cv2.COLORMAP_AUTUMN,
+    "Bone": cv2.COLORMAP_BONE,
+    "Jet": cv2.COLORMAP_JET,
+    "Winter": cv2.COLORMAP_WINTER,
+    "Rainbow": cv2.COLORMAP_RAINBOW,
+    "Ocean": cv2.COLORMAP_OCEAN,
+    "Summer": cv2.COLORMAP_SUMMER,
+    "Pink": cv2.COLORMAP_PINK,
+    "Hot": cv2.COLORMAP_HOT,
+    "Magma": cv2.COLORMAP_MAGMA,
+    "Inferno": cv2.COLORMAP_INFERNO,
+    "Plasma": cv2.COLORMAP_PLASMA,
+    "Twilight": cv2.COLORMAP_TWILIGHT,
 }

+# cam_model = copy.deepcopy(model)
 data_df = pd.read_csv('src/cache/val_df.csv')

 C_NUM_TO_NAME = data_df[['encoded_target', 'target']].drop_duplicates().sort_values('encoded_target').set_index('encoded_target')['target'].to_dict()

 @lru_cache(maxsize=100)
 def get_translated(to_translate):
     return GoogleTranslator(source="en", target="de").translate(to_translate)
 for idx in range(90): get_translated(get_class_name(idx))

+@imageCacheWrapper
+def infer_image(image):
+    if isinstance(image, dict):
+        # it's the image plus a drawn mask (both PIL) -> blend them into one image
+        image = Image.blend(image["image"], image["mask"], alpha=0.5)
+    image.save('src/results/infer_image.png')
     image = transform(image)
     image = image.unsqueeze(0)
     with torch.no_grad():
+        output = MODEL(image)
     distribution = torch.nn.functional.softmax(output, dim=1)
     ret = defaultdict(float)
     for idx, prob in enumerate(distribution[0]):

         ret[animal] = prob.item()
     return ret

+def gradcam(image, colormap="Jet", use_eigen_smooth=False, use_aug_smooth=False, BWHighlight=False, alpha=0.5, cam_method=GradCAM, layer=None, specific_class="Predicted Class"):
+    if image is None:
+        raise gr.Error("Please upload an image.")
+
+    if isinstance(image, dict):
+        # it's the image plus a drawn mask (both PIL) -> blend them into one image
+        image = Image.blend(image["image"], image["mask"], alpha=0.5)
+
+    if colormap not in CV2_COLORMAPS.keys():
+        raise gr.Error(f"Colormap {colormap} not found in {list(CV2_COLORMAPS.keys())}.")
+    else:
+        colormap = CV2_COLORMAPS[colormap]
+
+    image_width, image_height = image.size
+    if image_width > 4000 or image_height > 4000:
+        raise gr.Error("The image is too big. The maximal size is 4000x4000.")
+
+    MODEL.eval()
+    layers = LAYERS[layer]
+
+    image_tensor = transform(image).unsqueeze(0)
+    targets = [ClassifierOutputTarget(get_class_idx(specific_class))] if specific_class != "Predicted Class" else None
+
+    with CAM_METHODS[cam_method](model=MODEL, target_layers=layers) as cam:
+        grayscale_cam = cam(input_tensor=image_tensor, targets=targets, aug_smooth=use_aug_smooth, eigen_smooth=use_eigen_smooth)
+
+    grayscale_cam = grayscale_cam[0, :]
+    grayscale_cam = cv2.resize(grayscale_cam, (image_width, image_height), interpolation=cv2.INTER_CUBIC)
+    image = np.float32(image)
+    visualization = None
+    if BWHighlight:
+        # keep the original pixels, scaled by CAM importance
+        image = image * grayscale_cam[..., np.newaxis]
+        visualization = image.astype(np.uint8)
+    else:
+        image = image / 255
+        visualization = show_cam_on_image(image, grayscale_cam, use_rgb=True, image_weight=alpha, colormap=colormap)
+    return Image.fromarray(visualization)

+def gradcam_video(video, colormap="Jet", use_eigen_smooth=False, BWHighlight=False, alpha=0.5, cam_method=GradCAM, layer=None, specific_class="Predicted Class"):
     global OUTPUT_FPS, MAXIMAL_FRAMES, BATCHES_TO_PROCESS, MAX_OUT_FRAMES
+    if colormap not in CV2_COLORMAPS.keys():
+        raise gr.Error(f"Colormap {colormap} not found in {list(CV2_COLORMAPS.keys())}.")
+    else:
+        colormap = CV2_COLORMAPS[colormap]
     video = cv2.VideoCapture(video)
     fps = int(video.get(cv2.CAP_PROP_FPS))
     if OUTPUT_FPS == -1: OUTPUT_FPS = fps
 
     print(f'Frames to process: {len(frames)}')

     processed = [Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) for frame in frames]
+    # generate lists in lists for the images for batch processing: BATCHES_TO_PROCESS images per inner list
     batched = [processed[i:i + BATCHES_TO_PROCESS] for i in range(0, len(processed), BATCHES_TO_PROCESS)]

+    MODEL.eval()
+    layers = LAYERS[layer]
     results = list()
+    targets = [ClassifierOutputTarget(get_class_idx(specific_class))] if specific_class != "Predicted Class" else None
+    with CAM_METHODS[cam_method](model=MODEL, target_layers=layers) as cam:
+        for i, batch in enumerate(tqdm(batched)):
+            images_tensor = torch.stack([transform(image) for image in batch])
+
+            grayscale_cam = cam(input_tensor=images_tensor, targets=targets, aug_smooth=False, eigen_smooth=use_eigen_smooth)
+            for j, image in enumerate(batch):
+                _grayscale_cam = grayscale_cam[j, :]
+                _grayscale_cam = cv2.resize(_grayscale_cam, (width, height), interpolation=cv2.INTER_LINEAR)
+                image = np.float32(image)
+                visualization = None
+                if BWHighlight:
+                    image = image * _grayscale_cam[..., np.newaxis]
+                    visualization = image.astype(np.uint8)
+                else:
+                    image = image / 255
+                    visualization = show_cam_on_image(image, _grayscale_cam, use_rgb=True, image_weight=alpha, colormap=colormap)
+                results.append(visualization)

     # save video
     mediapy.write_video('src/results/gradcam_video.mp4', results, fps=OUTPUT_FPS)
     video.release()
     return 'src/results/gradcam_video.mp4'
 
     for j in range(IMAGES_PER_ROW):
         if i * IMAGES_PER_ROW + j >= len(images_to_load): break
         image = images_to_load[i * IMAGES_PER_ROW + j]
+        name = f"{image.split('.')[0]} ({get_translated(image.split('.')[0])})"
+        image = Image.open(os.path.join(full_path, image))
+        # scale so that the longest side is 600px
+        scale = 600 / max(image.size)
+        image = image.resize((int(image.size[0] * scale), int(image.size[1] * scale)))
         loaded_images[image_type].append(
             gr.Image(
+                value=image,
+                label=name,
                 type="pil",
                 interactive=False,
                 elem_classes=["selectable_images"],
 
     # INPUT IMAGE
     # -------------------------------------------
     with gr.Row():
+        with gr.Row(variant="panel", equal_height=True):
+            user_image = gr.Image(
+                type="pil",
+                label="Upload Your Own Image",
+                tool="sketch",
+                interactive=True,
+            )

     # -------------------------------------------
     # TOOLS

         scale=5,
     )
     predict_mode_button = gr.Button(value="Predict Animal", label="Predict", info="Click to make a prediction.", scale=1)
+    predict_mode_button.click(fn=infer_image, inputs=[user_image], outputs=output, queue=True)

     # -------------------------------------------
     # EXPLAIN

     with gr.Tab("Explain Image"):
         with gr.Row():
             with gr.Column():
+                _info = "There are different GradCAM methods. You can read more about them here: https://github.com/jacobgil/pytorch-grad-cam#references"
                 cam_method = gr.Radio(
                     list(CAM_METHODS.keys()),
                     label="GradCAM Method",
+                    info=_info,
                     value="GradCAM",
                     interactive=True,
                     scale=2,
                 )

+                _info = """
+                The alpha value blends the original image with the GradCAM visualization. At 0.5 both are weighted equally;
+                at 0.1 the original image is barely visible, and at 0.9 the GradCAM visualization is barely visible.
+                """
                 alpha = gr.Slider(
                     minimum=.1,
                     maximum=.9,

                     step=.1,
                     label="Alpha",
                     scale=1,
+                    info=_info
                 )

+                _info = """
+                The layer determines which layer of the ResNet backbone the GradCAM visualization is based on.
+                The last layer (layer4) is usually the best choice: it carries the most information right before the final prediction, which makes the visualization the most meaningful.
+                If several layers are chosen, the GradCAM visualization is averaged over them.
+                """
                 layer = gr.Radio(
+                    list(LAYERS.keys()),
                     label="Layer",
                     value="layer4",
                     interactive=True,
                     scale=2,
+                    info=_info
                 )

+                _info = """
+                Here you can choose the animal to "explain". If you choose "Predicted Class" the GradCAM visualization is based on the predicted class;
+                if you choose a specific class, it is based on that class.
+                For example, on an image with both a dog and a cat you can select either Cat or Dog and see whether the model focuses on the correct animal.
+                """
                 animal_to_explain = gr.Dropdown(
                     choices=["Predicted Class"] + ALL_CLASSES,
                     label="Animal",
                     value="Predicted Class",
                     interactive=True,
                     scale=2,
+                    info=_info
                 )
+
+                with gr.Row():
+                    _info = """
+                    Here you can choose the colormap. Instead of a colormap you can also choose "BW Highlight" to keep the original image and only emphasize its important parts.
+                    If you select "BW Highlight" the colormap is ignored.
+                    """
+                    colormap = gr.Dropdown(
+                        choices=list(CV2_COLORMAPS.keys()),
+                        label="Colormap",
+                        value="Jet",
+                        interactive=True,
+                        scale=2,
+                        info=_info
+                    )
+
+                    bw_highlight = gr.Checkbox(
+                        label="BW Highlight",
+                        value=False,
+                        interactive=True,
+                        scale=1,
+                        info="Keep the original image and highlight only its important parts instead of overlaying a colormap."
+                    )
+
+                with gr.Row():
+                    _info = """
+                    Eigen Smooth reduces noise in the GradCAM visualization.
+                    """
+                    use_eigen_smooth = gr.Checkbox(
+                        label="Eigen Smooth",
+                        value=False,
+                        interactive=True,
+                        scale=1,
+                        info=_info
+                    )
+                    _info = """
+                    Aug Smooth also smooths the GradCAM visualization, but it needs a lot of compute and is therefore slow.
+                    """
+                    use_aug_smooth = gr.Checkbox(
+                        label="Aug Smooth",
+                        value=False,
+                        interactive=True,
+                        scale=1,
+                        info=_info
+                    )

             with gr.Column():
                 output_cam = gr.Image(
                     type="pil",
                     label="GradCAM",
+                    info="GradCAM visualization",
+                    scale=5,
                 )

         gradcam_mode_button = gr.Button(value="Show GradCAM", label="GradCAM", info="Click to make a prediction.", scale=1)
+        gradcam_mode_button.click(fn=gradcam, inputs=[user_image, colormap, use_eigen_smooth, use_aug_smooth, bw_highlight, alpha, cam_method, layer, animal_to_explain], outputs=output_cam, queue=True)

     # -------------------------------------------
     # Video CAM
     # -------------------------------------------
     with gr.Tab("Explain Video"):
+        build_video_to_camvideo(CAM_METHODS, CV2_COLORMAPS, LAYERS, ALL_CLASSES, gradcam_video)

     # -------------------------------------------
     # EXAMPLES
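For orientation, the rewritten `gradcam()` above follows the standard pytorch-grad-cam recipe. Here is a minimal, self-contained sketch of that recipe; the torchvision `resnet18`, the image path, and the ImageNet class index are stand-ins of mine, not values from this commit:

```python
import numpy as np
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision import transforms
from torchvision.models import ResNet18_Weights, resnet18

model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1).eval()
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

image = Image.open("cat.jpg").convert("RGB")   # any RGB test image (placeholder path)
input_tensor = preprocess(image).unsqueeze(0)  # shape (1, 3, 224, 224)

# targets=None explains the model's own top prediction; a ClassifierOutputTarget
# pins the explanation to one class, like the "Animal" dropdown does.
targets = [ClassifierOutputTarget(281)]        # 281 = "tabby cat" in ImageNet

with GradCAM(model=model, target_layers=[model.layer4]) as cam:
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]  # (224, 224) floats in [0, 1]

# blend the heatmap over the resized, [0, 1]-scaled image; image_weight is the Alpha slider
rgb = np.float32(image.resize((224, 224))) / 255
overlay = show_cam_on_image(rgb, grayscale_cam, use_rgb=True, image_weight=0.5)
Image.fromarray(overlay).save("cam_overlay.png")
```

The context manager releases the CAM's hooks on exit, which is why the new code no longer needs the explicit `cam.remove_hooks()` call from the torchcam version.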
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ

requirements_old.txt ADDED
Binary file (4.01 kB).
src/example_videos/jellyfish_-_110877 (360p).mp4 DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:53e45f6bd34aaeefdabfeddc9d70a4d7670a183137f2469db42f4f90e73ea296
-size 797977

src/example_videos/monarch_-_327 (360p).mp4 DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d0c163fd1ca2ec280c13430f8606d14e8c490980d66b63610ccf3e5af581138e
-size 428449

src/example_videos/pexels-zlatin-georgiev-5607745 (240p).mp4 DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:02d82cff5c3c4eb51edd6f07e64e33116cd6889ad45dfb9521aa89406022c539
-size 470993

src/example_videos/pexels-zlatin-georgiev-7173031 (240p).mp4 DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:6931230df61b43d063905f96e313ea47ac3b3ca16fd5d749bcd167b07d83ee69
-size 268559

src/example_videos/pexels_videos_2556839 (240p).mp4 DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:065667e0d791af82df46c2b7667f17d5a0687fce606fc7e5c216a0e9c3045f76
-size 402447

src/example_videos/squirrel_on_a_wood (360p).mp4 DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:9af343138b0a589ccee2c9469f97ce55cbee720ffec9084429f33ae1e37e4f12
-size 2006082
src/examples/AI_Generated/goat (2).png DELETED
Git LFS Details
  • SHA256: e1346e0c71f880f274de273ed6eee5e8ad2bc7b2d767560459192d5de9bec8b8
  • Pointer size: 132 Bytes
  • Size of remote file: 1.99 MB

src/examples/AI_Generated/koala.png DELETED
Git LFS Details
  • SHA256: 9ac3ad3beab6456614f297254589507af1df3191db8bfcbad26924c5d996e831
  • Pointer size: 132 Bytes
  • Size of remote file: 1.65 MB

src/examples/AI_Generated/rabbit.png DELETED
Git LFS Details
  • SHA256: d376c67d72e403102393e685459084d45d6b804caac42348aa712dc646075796
  • Pointer size: 132 Bytes
  • Size of remote file: 2.36 MB

src/examples/AI_Generated/rhinoceros.png DELETED
Git LFS Details
  • SHA256: 2aac1990b1b6a835100ae9f75824fb84f4a8e322aa8e7f73d67f89506c3330d2
  • Pointer size: 132 Bytes
  • Size of remote file: 2.75 MB

src/examples/AI_Generated/swan.png DELETED
Git LFS Details
  • SHA256: 5b0aa528e07a8a78f3bc4c610a40f46b769e7b540c9f228d4027d3288a2edda1
  • Pointer size: 132 Bytes
  • Size of remote file: 2.1 MB

src/examples/AI_Generated/woodpecker.png DELETED
Git LFS Details
  • SHA256: 98eae49bd0c23630df30be211d899b1cc39fd6877c7dcb8dfc8d1cf1b6f1e8b9
  • Pointer size: 132 Bytes
  • Size of remote file: 2.16 MB

src/examples/false_predicted/boar.jpg DELETED
Git LFS Details
  • SHA256: a5ebd92e69975a55c6568f50f642fef14609efaad4fc3e32978b2057ced39f96
  • Pointer size: 130 Bytes
  • Size of remote file: 21.6 kB

src/examples/false_predicted/dolphin.jpg DELETED
Git LFS Details
  • SHA256: cc0c0d4ba9df1df29c21e04b642f2d5cb597489bcb62b90c3b73bed01286de2c
  • Pointer size: 129 Bytes
  • Size of remote file: 3.39 kB

src/examples/false_predicted/horse.jpg DELETED
Git LFS Details
  • SHA256: ab44868a2558e57b36c83af01cecf2c427b9403f9eec550898f7c767a7c4af1c
  • Pointer size: 131 Bytes
  • Size of remote file: 138 kB

src/examples/false_predicted/sparrow.jpg DELETED
Git LFS Details
  • SHA256: 4b09f1b5809696b10847e0c8b5c0e2d4febbe8cc8eb81f4493316c7c7e91d048
  • Pointer size: 129 Bytes
  • Size of remote file: 7.96 kB

src/examples/others/Tiger-fuers-Wohnzimmer-In-Hybridkatzen-steckt-ein-Stueck-Wildnis-2.jpg DELETED
Git LFS Details
  • SHA256: a94bef5b6320925e4d750dc7109b608ca5eb512736078dfaeeac48f6cb070216
  • Pointer size: 131 Bytes
  • Size of remote file: 193 kB

src/examples/true_predicted/dragonfly.jpg DELETED
Git LFS Details
  • SHA256: d98a75d0aa449d80f180f561bc075b03b61b7a2a22d071e100144012431032f6
  • Pointer size: 131 Bytes
  • Size of remote file: 104 kB

src/examples/true_predicted/goat.jpg DELETED
Git LFS Details
  • SHA256: 6c718b96f73fc30732401210e5b0b586ccb22d366c8d623a57e3bc3eccbceb86
  • Pointer size: 129 Bytes
  • Size of remote file: 9.4 kB

src/examples/true_predicted/panda.jpg DELETED
Git LFS Details
  • SHA256: 66e8398d083eb047a6b92eefee2bbd113786e6858a2f5b5e491e820603ad8b8d
  • Pointer size: 130 Bytes
  • Size of remote file: 14.5 kB

src/examples/true_predicted/rat.jpg DELETED
Git LFS Details
  • SHA256: a77b75dcd4151cd85fb5c57e134cb45c03632d83e6517d81862f95659786e437
  • Pointer size: 130 Bytes
  • Size of remote file: 71.4 kB

src/examples/true_predicted/wombat.jpg DELETED
Git LFS Details
  • SHA256: aed677c70afef65f65950969e652f7fc1ef9ccc10e413719a5a4e65f8f153e6e
  • Pointer size: 130 Bytes
  • Size of remote file: 14.8 kB
src/gradio_blocks.py CHANGED
@@ -3,12 +3,12 @@ import os

 VIDEOS_PER_ROW = 3
 VIDEO_EXAMPLES_PATH = "src/example_videos"
-def build_video_to_camvideo(CAM_METHODS, ALL_CLASSES, gradcam_video):
     with gr.Row():
         with gr.Column(scale=2):
             gr.Markdown("### Video to GradCAM-Video")
             gr.Markdown("Here you can upload a video and visualize the GradCAM.")
-            gr.Markdown("Please note that this can take a while. Also currently only a maximum of 70 frames can be processed. The video will be cut to 70 frames if it is longer. Furthermore, the video can only consist of a maximum of 1000.")
             gr.Markdown("The more frames and fps the video has, the longer it takes to process and the result will be more choppy.")
             video_cam_method = gr.Radio(
                 ["GradCAM", "GradCAM++"],
@@ -17,8 +17,6 @@ def build_video_to_camvideo(CAM_METHODS, ALL_CLASSES, gradcam_video):
                 interactive=True,
                 scale=2,
             )
-            video_cam_method.description = "Here you can choose the GradCAM method."
-            video_cam_method.description_place = "left"

             video_alpha = gr.Slider(
                 minimum=.1,
@@ -29,35 +27,54 @@ def build_video_to_camvideo(CAM_METHODS, ALL_CLASSES, gradcam_video):
                 label="Alpha",
                 scale=1,
             )
-            video_alpha.description = "Here you can choose the alpha value."
-            video_alpha.description_place = "left"

             video_layer = gr.Radio(
-                ["layer1", "layer2", "layer3", "layer4", "all"],
-                label="Layer",
-                value="layer4",
                 interactive=True,
                 scale=2,
             )
-            video_layer.description = "Here you can choose the layer to visualize."
-            video_layer.description_place = "left"

-            video_animal_to_explain = gr.Dropdown(
-                choices=["Predicted Class"] + ALL_CLASSES,
-                label="Animal",
-                value="Predicted Class",
                 interactive=True,
-                scale=2,
             )
-            video_animal_to_explain.description = "Here you can choose the animal to explain. If you choose 'Predicted Class' the method will explain the predicted class."
-            video_animal_to_explain.description_place = "center"

         with gr.Column(scale=1):
             with gr.Column():
                 video_in = gr.Video(autoplay=False, include_audio=False)
                 video_out = gr.Video(autoplay=False, include_audio=False)

         gif_cam_mode_button = gr.Button(value="Show GradCAM-Video", label="GradCAM", scale=1)
-        gif_cam_mode_button.click(fn=gradcam_video, inputs=[video_in, video_alpha, video_cam_method, video_layer, video_animal_to_explain], outputs=[video_out], queue=True)

         with gr.Row():
             with gr.Column():
 

 VIDEOS_PER_ROW = 3
 VIDEO_EXAMPLES_PATH = "src/example_videos"
+def build_video_to_camvideo(CAM_METHODS, CV2_COLORMAPS, LAYERS, ALL_CLASSES, gradcam_video):
     with gr.Row():
         with gr.Column(scale=2):
             gr.Markdown("### Video to GradCAM-Video")
             gr.Markdown("Here you can upload a video and visualize the GradCAM.")
+            gr.Markdown("Please note that this can take a while. Also, currently only a maximum of 60 frames can be processed; longer videos will be cut to 60 frames. Furthermore, the video may consist of at most 1000 frames in total.")
             gr.Markdown("The more frames and fps the video has, the longer it takes to process and the result will be more choppy.")
             video_cam_method = gr.Radio(
                 ["GradCAM", "GradCAM++"],

                 interactive=True,
                 scale=2,
             )

             video_alpha = gr.Slider(
                 minimum=.1,

                 label="Alpha",
                 scale=1,
             )

             video_layer = gr.Radio(
+                list(LAYERS.keys()),
+                label="Layer",
+                value="layer4",
+                interactive=True,
+                scale=2,
+            )
+
+            video_animal_to_explain = gr.Dropdown(
+                choices=["Predicted Class"] + ALL_CLASSES,
+                label="Animal",
+                value="Predicted Class",
+                interactive=True,
+                scale=2,
+            )
+
+            with gr.Row():
+                colormap = gr.Dropdown(
+                    choices=list(CV2_COLORMAPS.keys()),
+                    label="Colormap",
+                    value="Jet",
                     interactive=True,
                     scale=2,
                 )

+                bw_highlight = gr.Checkbox(
+                    label="BW Highlight",
+                    value=False,
                     interactive=True,
+                    scale=1,
                 )
+
+            with gr.Row():
+                use_eigen_smooth = gr.Checkbox(
+                    label="Eigen Smooth",
+                    value=False,
+                    interactive=True,
+                    scale=1,
+                )
+
         with gr.Column(scale=1):
             with gr.Column():
                 video_in = gr.Video(autoplay=False, include_audio=False)
                 video_out = gr.Video(autoplay=False, include_audio=False)

         gif_cam_mode_button = gr.Button(value="Show GradCAM-Video", label="GradCAM", scale=1)
+        gif_cam_mode_button.click(fn=gradcam_video, inputs=[video_in, colormap, use_eigen_smooth, bw_highlight, video_alpha, video_cam_method, video_layer, video_animal_to_explain], outputs=[video_out], queue=True)

         with gr.Row():
             with gr.Column():
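One detail worth knowing about the `.click(...)` wiring used throughout these files: Gradio passes component values to the callback positionally, in the order listed in `inputs`, so the callback signature must line up with that list exactly. A toy sketch of the same pattern (components and names here are illustrative, not from this repo):

```python
import gradio as gr

def greet(name, excited):
    # values arrive positionally: name from the Textbox, excited from the Checkbox
    return f"Hello {name}{'!' if excited else '.'}"

with gr.Blocks() as demo:
    name = gr.Textbox(label="Name")
    excited = gr.Checkbox(label="Excited")
    out = gr.Textbox(label="Greeting")
    btn = gr.Button("Greet")
    btn.click(fn=greet, inputs=[name, excited], outputs=out, queue=True)

# demo.launch()
```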
src/header.md CHANGED
@@ -2,15 +2,17 @@

 This project was created by [Ilyesse](https://github.com/ilyii) and [Gabriel](https://github.com/Gabriel9753) as part of the Explainable Machine Learning module at the [University of Applied Sciences Karlsruhe](https://www.h-ka.de/).

-The dataset used in this project is the [Animal Image Dataset](https://www.kaggle.com/datasets/iamsouravbanerjee/animal-image-dataset-90-different-animals) from Kaggle, comprising 90 different animal species that needed to be classified.

-The employed model is ResNet18, which was trained on the dataset using transfer learning techniques.
 Translation of animal names by [deep-translator](https://pypi.org/project/deep-translator/).

 ## Usage 🦎

-**Predict:** In the "Predict" tab, the model can be applied to high-resolution images to predict the species among the 90 different animals.

-**Explain:** Under the "Explain" tab, the model can be applied to high-resolution images to obtain an explanation for the prediction. This explanation is generated using the [Grad-CAM](https://github.com/frgfm/torch-cam.git) method.

-**Example Images:** The "Example Images" section allows users to view sample images from the dataset. These images can be utilized as input by simply dragging and dropping them onto the interface. It is important to note that these example images were not part of the training data used for the model.

 This project was created by [Ilyesse](https://github.com/ilyii) and [Gabriel](https://github.com/Gabriel9753) as part of the Explainable Machine Learning module at the [University of Applied Sciences Karlsruhe](https://www.h-ka.de/).

+The dataset used in this project is the [Animal Image Dataset](https://www.kaggle.com/datasets/iamsouravbanerjee/animal-image-dataset-90-different-animals) from Kaggle, comprising 90 different animal species that needed to be classified. We also added approx. 1000 AI-generated images across all classes to make the dataset more diverse and to improve the performance of the model.

+The employed model is ResNet50, which was trained on the dataset using transfer learning techniques.
 Translation of animal names by [deep-translator](https://pypi.org/project/deep-translator/).

 ## Usage 🦎

+**Predict:** In the "Predict" tab, the model can be applied to the uploaded image to obtain a prediction. This is also useful for finding the animal class to use in the explanation that follows.

+**Explain Image:** Under the "Explain Image" tab, you can get an explanation of the prediction in the form of a generated heatmap. We are using [this](https://github.com/jacobgil/pytorch-grad-cam) cool implementation of Grad-CAM to generate the heatmaps!

+**Explain Video:** The same as above, but for short videos. The video is split into frames and the model is applied to each frame. The resulting heatmaps are then combined into a video again.
+
+**Example Images:** The "Example Images" section allows users to view sample images from the dataset and other sources. Some of the images and videos are from [1](https://www.pexels.com/), [2](https://pixabay.com/) and [3](https://www.bing.com/create).
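The frame-split/recombine loop described under "Explain Video" reduces to a few lines. A minimal sketch, assuming OpenCV and mediapy as used in app.py; `explain_frame()` is a hypothetical stand-in for the per-frame Grad-CAM overlay, not a function from this repo:

```python
import cv2
import mediapy

def explain_frame(rgb_frame):
    # hypothetical stand-in for the per-frame Grad-CAM overlay in app.py
    return rgb_frame

video = cv2.VideoCapture("input.mp4")  # any short clip (placeholder path)
fps = int(video.get(cv2.CAP_PROP_FPS))
frames = []
while True:
    ok, frame = video.read()           # frames arrive in BGR order
    if not ok:
        break
    frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
video.release()

results = [explain_frame(f) for f in frames]          # one heatmap overlay per frame
mediapy.write_video("output.mp4", results, fps=fps)   # recombine into a video
```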
src/results/gradcam_video.mp4 CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:935c594a0ecbc14565723ad3896989aaaa6021232d368bf1cda5f8e9c0bf9e74
-size 922461
+oid sha256:a9617d53ad717194350c99f6b1d2a172f01e712e4109c76b16fe3f70f32c4570
+size 772080
src/{example_videos/butterfly_-_38947 (360p).mp4 → results/infer_image.png} RENAMED
File without changes
src/util.py CHANGED
@@ -4,7 +4,8 @@ from sklearn.preprocessing import LabelEncoder
 from tqdm import tqdm
 from PIL import Image
 import torch
-

 class AnimalDataset(Dataset):
     def __init__(self, df, transform=None):
@@ -41,3 +42,42 @@ transform = transforms.Compose([
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 ])

 from tqdm import tqdm
 from PIL import Image
 import torch
+import imagehash
+ImageCache = None

 class AnimalDataset(Dataset):
     def __init__(self, df, transform=None):

     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 ])

+class CustomImageCache:
+    def __init__(self, cache_size=50, debug=False):
+        self.cache = dict()
+        self.cache_size = cache_size
+        self.debug = debug
+        self.cache_hits = 0
+        self.cache_misses = 0
+
+    def __getitem__(self, image):
+        if isinstance(image, dict):
+            # it's the image plus a drawn mask (both PIL) -> blend them into one image
+            image = Image.blend(image["image"], image["mask"], alpha=0.5)
+        key = imagehash.average_hash(image)
+
+        if key in self.cache:
+            if self.debug: print("Cache hit!")
+            self.cache_hits += 1
+            return self.cache[key]
+        else:
+            if self.debug: print("Cache miss!")
+            self.cache_misses += 1
+            if len(self.cache.keys()) >= self.cache_size:
+                if self.debug: print("Cache full, popping item!")
+                self.cache.popitem()
+            self.cache[key] = image
+            return self.cache[key]
+
+    def __len__(self):
+        return len(self.cache.keys())
+
+    def print_info(self):
+        print(f"Cache size: {len(self)}")
+        print(f"Cache hits: {self.cache_hits}")
+        print(f"Cache misses: {self.cache_misses}")
+
+def imageCacheWrapper(fn):
+    def wrapper(image):
+        return fn(ImageCache[image])
+    return wrapper
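A minimal usage sketch for the new cache, assuming `imagehash` and Pillow are installed; `describe` is a hypothetical stand-in for `infer_image`. Note the design choice: the cache deduplicates the input image, keyed by its perceptual hash, rather than memoizing model outputs, so inference still runs on every call:

```python
import util
from util import CustomImageCache, imageCacheWrapper
from PIL import Image

# app.py assigns the module-level cache before any wrapped function runs
util.ImageCache = CustomImageCache(cache_size=60, debug=True)

@imageCacheWrapper
def describe(image):
    # hypothetical stand-in for infer_image
    return image.size

img = Image.new("RGB", (224, 224), "white")
print(describe(img))          # "Cache miss!" -> (224, 224)
print(describe(img))          # "Cache hit!"  -> same cached image
util.ImageCache.print_info()  # size 1, 1 hit, 1 miss
```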