GabrielML committed

Commit b426c59 · 1 Parent(s): 6b09016

Add new classes and features
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import copy
2
  import os
3
  import sys
@@ -16,31 +17,43 @@ import torch
16
  from deep_translator import GoogleTranslator
17
  from gradio_blocks import build_video_to_camvideo
18
  from Nets import CustomResNet18
19
- from PIL import Image
20
 
21
  from pytorch_grad_cam import GradCAM, HiResCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
22
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
23
  from pytorch_grad_cam.utils.image import show_cam_on_image
24
 
25
  from tqdm import tqdm
26
- import util
27
- from util import transform, CustomImageCache, imageCacheWrapper
 
28
 
29
- util.ImageCache = CustomImageCache(60, False)
30
  ffmpeg_path = shutil.which('ffmpeg')
31
  mediapy.set_ffmpeg(ffmpeg_path)
32
 
33
  IMAGE_PATH = os.path.join(os.getcwd(), 'src/examples')
34
  IMAGES_PER_ROW = 5
35
 
36
- MAXIMAL_FRAMES = 1000
37
- BATCHES_TO_PROCESS = 10
38
  OUTPUT_FPS = 10
39
- MAX_OUT_FRAMES = 60
40
 
41
- MODEL = CustomResNet18(90).eval()
42
  MODEL.load_state_dict(torch.load('src/results/models/best_model.pth', map_location=torch.device('cpu')))
43
44
  CAM_METHODS = {
45
  "GradCAM": GradCAM,
46
  "GradCAM++": GradCAMPlusPlus,
@@ -87,16 +100,21 @@ def get_class_name(idx):
87
  def get_class_idx(name):
88
  return C_NAME_TO_NUM[name]
89
 
90
- @lru_cache(maxsize=100)
91
- def get_translated(to_translate):
92
- return GoogleTranslator(source="en", target="de").translate(to_translate)
93
- for idx in range(90): get_translated(get_class_name(idx))
94
 
95
- @imageCacheWrapper
96
- def infer_image(image):
97
- if isinstance(image, dict):
98
- # Its the image and a mask as pillow both -> Combine them to one image
99
- image = Image.blend(image["image"], image["mask"], alpha=0.5)
100
  image.save('src/results/infer_image.png')
101
  image = transform(image)
102
  image = image.unsqueeze(0)
@@ -105,11 +123,13 @@ def infer_image(image):
105
  distribution = torch.nn.functional.softmax(output, dim=1)
106
  ret = defaultdict(float)
107
  for idx, prob in enumerate(distribution[0]):
108
- animal = f'{get_class_name(idx)} ({get_translated(get_class_name(idx))})'
 
 
109
  ret[animal] = prob.item()
110
  return ret
111
 
112
- def gradcam(image, colormap="Jet", use_eigen_smooth=False, use_aug_smooth=False, BWHighlight=False, alpha=0.5, cam_method=GradCAM, layer=None, specific_class="Predicted Class"):
113
  if image is None:
114
  raise gr.Error("Please upload an image.")
115
 
@@ -123,8 +143,8 @@ def gradcam(image, colormap="Jet", use_eigen_smooth=False, use_aug_smooth=False,
123
  colormap = CV2_COLORMAPS[colormap]
124
 
125
  image_width, image_height = image.size
126
- if image_width > 4000 or image_height > 4000:
127
- raise gr.Error("The image is too big. The maximal size is 4000x4000.")
128
 
129
 
130
  MODEL.eval()
@@ -135,6 +155,8 @@ def gradcam(image, colormap="Jet", use_eigen_smooth=False, use_aug_smooth=False,
135
 
136
  with CAM_METHODS[cam_method](model=MODEL, target_layers=layers) as cam:
137
  grayscale_cam = cam(input_tensor=image_tensor, targets=targets, aug_smooth=use_aug_smooth, eigen_smooth=use_eigen_smooth)
 
 
138
 
139
  grayscale_cam = grayscale_cam[0, :]
140
  grayscale_cam = cv2.resize(grayscale_cam, (image_width, image_height), interpolation=cv2.INTER_CUBIC)
@@ -146,10 +168,25 @@ def gradcam(image, colormap="Jet", use_eigen_smooth=False, use_aug_smooth=False,
146
  else:
147
  image = image / 255
148
  visualization = show_cam_on_image(image, grayscale_cam, use_rgb=True, image_weight=alpha, colormap=colormap)
149
- return Image.fromarray(visualization)
150
 
151
  def gradcam_video(video, colormap="Jet", use_eigen_smooth=False, BWHighlight=False, alpha=0.5, cam_method=GradCAM, layer=None, specific_class="Predicted Class"):
152
  global OUTPUT_FPS, MAXIMAL_FRAMES, BATCHES_TO_PROCESS, MAX_OUT_FRAMES
 
153
  if colormap not in CV2_COLORMAPS.keys():
154
  raise gr.Error(f"Colormap {colormap} not found in {list(CV2_COLORMAPS.keys())}.")
155
  else:
@@ -159,8 +196,8 @@ def gradcam_video(video, colormap="Jet", use_eigen_smooth=False, BWHighlight=Fal
159
  if OUTPUT_FPS == -1: OUTPUT_FPS = fps
160
  width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
161
  height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
162
- if width > 3000 or height > 3000:
163
- raise gr.Error("The video is too big. The maximal size is 3000x3000.")
164
  print(f'FPS: {fps}, Width: {width}, Height: {height}')
165
 
166
  frames = list()
@@ -213,21 +250,21 @@ def gradcam_video(video, colormap="Jet", use_eigen_smooth=False, BWHighlight=Fal
213
  def load_examples():
214
  folder_name_to_header = {
215
  "AI_Generated": "AI Generated Images",
216
- "true_predicted": "True Predicted Images (Validation Set)",
217
- "false_predicted": "False Predicted Images (Validation Set)",
218
  "others": "Other interesting images from the internet"
219
  }
220
 
221
  images_description = {
222
  "AI_Generated": "These images are generated by Dalle3 and Stable Diffusion. All of them are not real images and because of that it is interesting to see how the model predicts them.",
223
- "true_predicted": "These images are from the validation set and the model predicted them correctly.",
224
- "false_predicted": "These images are from the validation set and the model predicted them incorrectly. Maybe you can see why the model predicted them incorrectly using the GradCAM visualization. :)",
225
  "others": "These images are from the internet and are not part of the validation set. They are interesting because most of them show different animals."
226
  }
227
 
228
  loaded_images = defaultdict(list)
229
 
230
- for image_type in ["AI_Generated", "true_predicted", "false_predicted", "others"]:
231
  # for image_type in os.listdir(IMAGE_PATH):
232
  full_path = os.path.join(IMAGE_PATH, image_type).replace('\\', '/').replace('//', '/')
233
  gr.Markdown(f'## {folder_name_to_header[image_type]}')
@@ -239,7 +276,7 @@ def load_examples():
239
  for j in range(IMAGES_PER_ROW):
240
  if i * IMAGES_PER_ROW + j >= len(images_to_load): break
241
  image = images_to_load[i * IMAGES_PER_ROW + j]
242
- name = f"{image.split('.')[0]} ({get_translated(image.split('.')[0])})"
243
  image = Image.open(os.path.join(full_path, image))
244
  # scale so that the longest side is 600px
245
  scale = 600 / max(image.size)
@@ -273,7 +310,15 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
273
  with gr.Column(scale=1):
274
  pil_logo = Image.open('animals.png')
275
  logo = gr.Image(value=pil_logo, scale=2, interactive=False, show_download_button=False, show_label=False, container=False, elem_id="logo")
276
-
277
  # -------------------------------------------
278
  # INPUT IMAGE
279
  # -------------------------------------------
@@ -282,7 +327,6 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
282
  user_image = gr.Image(
283
  type="pil",
284
  label="Upload Your Own Image",
285
- tool="sketch",
286
  interactive=True,
287
  )
288
 
@@ -301,8 +345,9 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
301
  info="Top three predicted classes and their confidences.",
302
  scale=5,
303
  )
304
- predict_mode_button = gr.Button(value="Predict Animal", label="Predict", info="Click to make a prediction.", scale=1)
305
- predict_mode_button.click(fn=infer_image, inputs=[user_image], outputs=output, queue=True)
 
306
 
307
  # -------------------------------------------
308
  # EXPLAIN
@@ -348,20 +393,28 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
348
  scale=2,
349
  info=_info
350
  )
 
351
 
352
- _info = """
353
- Here you can choose the animal to "explain". If you choose "Predicted Class" the GradCAM visualization will be based on the predicted class.
354
- If you choose a specific class the GradCAM visualization will be based on this class.
355
- For example if you have an image with a dog and a cat, you can select either Cat or Dog and see if the model can focus on the correct animal.
356
- """
357
- animal_to_explain = gr.Dropdown(
358
- choices=["Predicted Class"] + ALL_CLASSES,
359
- label="Animal",
360
- value="Predicted Class",
361
- interactive=True,
362
- scale=2,
363
- info=_info
364
- )
365
 
366
  with gr.Row():
367
  _info = """
@@ -371,7 +424,7 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
371
  colormap = gr.Dropdown(
372
  choices=list(CV2_COLORMAPS.keys()),
373
  label="Colormap",
374
- value="Jet",
375
  interactive=True,
376
  scale=2,
377
  info=_info
@@ -410,15 +463,16 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
410
 
411
 
412
  with gr.Column():
 
413
  output_cam = gr.Image(
414
  type="pil",
415
  label="GradCAM",
416
  info="GradCAM visualization",
417
- scale=5,
 
418
  )
419
-
420
- gradcam_mode_button = gr.Button(value="Show GradCAM", label="GradCAM", info="Click to make a prediction.", scale=1)
421
- gradcam_mode_button.click(fn=gradcam, inputs=[user_image, colormap, use_eigen_smooth, use_aug_smooth, bw_highlight, alpha, cam_method, layer, animal_to_explain], outputs=output_cam, queue=True)
422
 
423
  # -------------------------------------------
424
  # Video CAM
@@ -434,11 +488,9 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
434
  loaded_images = load_examples()
435
  for k in loaded_images.keys():
436
  for image in loaded_images[k]:
437
- image.select(fn=lambda x: x, inputs=[image], outputs=[user_image])
438
-
439
-
440
-
441
 
442
  if __name__ == "__main__":
443
  demo.queue()
444
- demo.launch()
 
 
1
+ from concurrent.futures import ThreadPoolExecutor
2
  import copy
3
  import os
4
  import sys
 
17
  from deep_translator import GoogleTranslator
18
  from gradio_blocks import build_video_to_camvideo
19
  from Nets import CustomResNet18
20
+ from PIL import Image, ImageDraw, ImageFont
21
 
22
  from pytorch_grad_cam import GradCAM, HiResCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
23
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
24
  from pytorch_grad_cam.utils.image import show_cam_on_image
25
 
26
  from tqdm import tqdm
27
+ from util import transform
28
+
29
+ font = ImageFont.truetype("src/Roboto-Regular.ttf", 16)
30
 
 
31
  ffmpeg_path = shutil.which('ffmpeg')
32
  mediapy.set_ffmpeg(ffmpeg_path)
33
 
34
  IMAGE_PATH = os.path.join(os.getcwd(), 'src/examples')
35
  IMAGES_PER_ROW = 5
36
 
37
+ MAXIMAL_FRAMES = 700
38
+ BATCHES_TO_PROCESS = 20
39
  OUTPUT_FPS = 10
40
+ MAX_OUT_FRAMES = 70
41
 
42
+ MODEL = CustomResNet18(111).eval()
43
  MODEL.load_state_dict(torch.load('src/results/models/best_model.pth', map_location=torch.device('cpu')))
44
 
45
+ LANGUAGES_TO_SELECT = {
46
+ "None": None,
47
+ "German": "de",
48
+ "French": "fr",
49
+ "Spanish": "es",
50
+ "Italian": "it",
51
+ "Finnish": "fi",
52
+ "Ukrainian": "uk",
53
+ "Japanese": "ja",
54
+ "Hebrew": "iw"
55
+ }
56
+
57
  CAM_METHODS = {
58
  "GradCAM": GradCAM,
59
  "GradCAM++": GradCAMPlusPlus,
 
100
  def get_class_idx(name):
101
  return C_NAME_TO_NUM[name]
102
 
103
+ @lru_cache(maxsize=len(LANGUAGES_TO_SELECT.keys())*111)
104
+ def get_translated(to_translate, target_language="German"):
105
+ target_language = LANGUAGES_TO_SELECT[target_language] if target_language in LANGUAGES_TO_SELECT else target_language
106
+ if target_language == "en": return to_translate
107
+ if target_language not in LANGUAGES_TO_SELECT.values(): raise gr.Error(f'Language {target_language} not found.')
108
+ return GoogleTranslator(source="en", target=target_language).translate(to_translate)
109
+ # for idx in range(111): get_translated(get_class_name(idx))
110
+ with ThreadPoolExecutor(max_workers=30) as executor:
111
+ # hand the executor the full list of class names together with the target language
112
+ # and let it map get_translated over them, preloading the translation cache
113
+ for language in tqdm(LANGUAGES_TO_SELECT.keys(), desc='Preloading translations'):
114
+ executor.map(get_translated, ALL_CLASSES, [language] * len(ALL_CLASSES))
115
 
116
+ def infer_image(image, target_language):
117
+ if image is None: raise gr.Error("Please upload an image.")
 
 
 
118
  image.save('src/results/infer_image.png')
119
  image = transform(image)
120
  image = image.unsqueeze(0)
 
123
  distribution = torch.nn.functional.softmax(output, dim=1)
124
  ret = defaultdict(float)
125
  for idx, prob in enumerate(distribution[0]):
126
+ animal = f'{get_class_name(idx)}'
127
+ if target_language is not None and target_language != "None":
128
+ animal += f' ({get_translated(get_class_name(idx), target_language)})'
129
  ret[animal] = prob.item()
130
  return ret
131
 
132
+ def gradcam(image, colormap="Jet", use_eigen_smooth=False, use_aug_smooth=False, BWHighlight=False, alpha=0.5, cam_method=GradCAM, layer=None, specific_class="Predicted Class", label_image=True, target_lang="German"):
133
  if image is None:
134
  raise gr.Error("Please upload an image.")
135
 
 
143
  colormap = CV2_COLORMAPS[colormap]
144
 
145
  image_width, image_height = image.size
146
+ if image_width > 6000 or image_height > 6000:
147
+ raise gr.Error("The image is too big. The maximal size is 6000x6000.")
148
 
149
 
150
  MODEL.eval()
 
155
 
156
  with CAM_METHODS[cam_method](model=MODEL, target_layers=layers) as cam:
157
  grayscale_cam = cam(input_tensor=image_tensor, targets=targets, aug_smooth=use_aug_smooth, eigen_smooth=use_eigen_smooth)
158
+ if label_image:
159
+ predicted_animal = get_class_name(np.argmax(cam.outputs.cpu().data.numpy(), axis=-1)[0])
160
 
161
  grayscale_cam = grayscale_cam[0, :]
162
  grayscale_cam = cv2.resize(grayscale_cam, (image_width, image_height), interpolation=cv2.INTER_CUBIC)
 
168
  else:
169
  image = image / 255
170
  visualization = show_cam_on_image(image, grayscale_cam, use_rgb=True, image_weight=alpha, colormap=colormap)
171
+
172
+ if label_image:
173
+ # add alpha channel to visualization
174
+ visualization = np.concatenate([visualization, np.ones((image_height, image_width, 1), dtype=np.uint8) * 255], axis=-1)
175
+ plt_image = Image.fromarray(visualization, mode="RGBA")
176
+ draw = ImageDraw.Draw(plt_image)
177
+ draw.rectangle((5, 5, 150, 30), fill=(10, 10, 10, 100))
178
+ animal = predicted_animal.capitalize()
179
+ if target_lang is not None and target_lang != "None":
180
+ animal += f' ({get_translated(animal, target_lang)})'
181
+ draw.text((10, 7), animal, font=font, fill=(255, 125, 0, 255))
182
+ visualization = np.array(plt_image)
183
+
184
+ out_image = Image.fromarray(visualization)
185
+ return out_image
186
 
187
  def gradcam_video(video, colormap="Jet", use_eigen_smooth=False, BWHighlight=False, alpha=0.5, cam_method=GradCAM, layer=None, specific_class="Predicted Class"):
188
  global OUTPUT_FPS, MAXIMAL_FRAMES, BATCHES_TO_PROCESS, MAX_OUT_FRAMES
189
+ if video is None: raise gr.Error("Please upload a video.")
190
  if colormap not in CV2_COLORMAPS.keys():
191
  raise gr.Error(f"Colormap {colormap} not found in {list(CV2_COLORMAPS.keys())}.")
192
  else:
 
196
  if OUTPUT_FPS == -1: OUTPUT_FPS = fps
197
  width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
198
  height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
199
+ if width > 2000 or height > 2000:
200
+ raise gr.Error("The video is too big. The maximal size is 2000x2000.")
201
  print(f'FPS: {fps}, Width: {width}, Height: {height}')
202
 
203
  frames = list()
 
250
  def load_examples():
251
  folder_name_to_header = {
252
  "AI_Generated": "AI Generated Images",
253
+ "true": "True Predicted Images (Validation Set)",
254
+ "false": "False Predicted Images (Validation Set)",
255
  "others": "Other interesting images from the internet"
256
  }
257
 
258
  images_description = {
259
  "AI_Generated": "These images are generated by Dalle3 and Stable Diffusion. All of them are not real images and because of that it is interesting to see how the model predicts them.",
260
+ "true": "These images are from the validation set and the model predicted them correctly.",
261
+ "false": "These images are from the validation set and the model predicted them incorrectly. Maybe you can see why the model predicted them incorrectly using the GradCAM visualization. :)",
262
  "others": "These images are from the internet and are not part of the validation set. They are interesting because most of them show different animals."
263
  }
264
 
265
  loaded_images = defaultdict(list)
266
 
267
+ for image_type in ["AI_Generated", "true", "false", "others"]:
268
  # for image_type in os.listdir(IMAGE_PATH):
269
  full_path = os.path.join(IMAGE_PATH, image_type).replace('\\', '/').replace('//', '/')
270
  gr.Markdown(f'## {folder_name_to_header[image_type]}')
 
276
  for j in range(IMAGES_PER_ROW):
277
  if i * IMAGES_PER_ROW + j >= len(images_to_load): break
278
  image = images_to_load[i * IMAGES_PER_ROW + j]
279
+ name = f"{image.split('.')[0]}"
280
  image = Image.open(os.path.join(full_path, image))
281
  # scale so that the longest side is 600px
282
  scale = 600 / max(image.size)
 
310
  with gr.Column(scale=1):
311
  pil_logo = Image.open('animals.png')
312
  logo = gr.Image(value=pil_logo, scale=2, interactive=False, show_download_button=False, show_label=False, container=False, elem_id="logo")
313
+
314
+ animal_translation_target_language = gr.Dropdown(
315
+ choices=LANGUAGES_TO_SELECT.keys(),
316
+ label="Translation language for animals",
317
+ value="German",
318
+ interactive=True,
319
+ scale=2,
320
+ )
321
+
322
  # -------------------------------------------
323
  # INPUT IMAGE
324
  # -------------------------------------------
 
327
  user_image = gr.Image(
328
  type="pil",
329
  label="Upload Your Own Image",
 
330
  interactive=True,
331
  )
332
 
 
345
  info="Top three predicted classes and their confidences.",
346
  scale=5,
347
  )
348
+ with gr.Row():
349
+ predict_mode_button = gr.Button(value="Predict Animal", label="Predict", info="Click to make a prediction.", scale=6)
350
+ predict_mode_button.click(fn=infer_image, inputs=[user_image, animal_translation_target_language], outputs=output, queue=True)
351
 
352
  # -------------------------------------------
353
  # EXPLAIN
 
393
  scale=2,
394
  info=_info
395
  )
396
+ with gr.Row():
397
 
398
+ _info = """
399
+ Here you can choose the animal to "explain". If you choose "Predicted Class" the GradCAM visualization will be based on the predicted class.
400
+ If you choose a specific class the GradCAM visualization will be based on this class.
401
+ For example if you have an image with a dog and a cat, you can select either Cat or Dog and see if the model can focus on the correct animal.
402
+ """
403
+ animal_to_explain = gr.Dropdown(
404
+ choices=["Predicted Class"] + ALL_CLASSES,
405
+ label="Animal",
406
+ value="Predicted Class",
407
+ interactive=True,
408
+ scale=4,
409
+ info=_info
410
+ )
411
+
412
+ show_predicted_class = gr.Checkbox(
413
+ label="Show Predicted Class",
414
+ value=True,
415
+ interactive=True,
416
+ scale=1,
417
+ )
418
 
419
  with gr.Row():
420
  _info = """
 
424
  colormap = gr.Dropdown(
425
  choices=list(CV2_COLORMAPS.keys()),
426
  label="Colormap",
427
+ value="Inferno",
428
  interactive=True,
429
  scale=2,
430
  info=_info
 
463
 
464
 
465
  with gr.Column():
466
+ gradcam_mode_button = gr.Button(value="Show GradCAM", label="GradCAM", info="Click to compute the GradCAM visualization.", scale=1)
467
  output_cam = gr.Image(
468
  type="pil",
469
  label="GradCAM",
470
  info="GradCAM visualization",
471
+ show_label=False,
472
+ scale=7,
473
  )
474
+ _inputs = [user_image, colormap, use_eigen_smooth, use_aug_smooth, bw_highlight, alpha, cam_method, layer, animal_to_explain, show_predicted_class, animal_translation_target_language]
475
+ gradcam_mode_button.click(fn=gradcam, inputs=_inputs, outputs=output_cam, queue=True)
 
476
 
477
  # -------------------------------------------
478
  # Video CAM
 
488
  loaded_images = load_examples()
489
  for k in loaded_images.keys():
490
  for image in loaded_images[k]:
491
+ image.select(fn=lambda x: x, inputs=[image], outputs=[user_image], queue=True, scroll_to_output=True)
 
 
 
492
 
493
  if __name__ == "__main__":
494
  demo.queue()
495
+ print("Starting Gradio server...")
496
+ demo.launch(show_tips=True)
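Note on the translation changes above: get_translated is now cached with functools.lru_cache and warmed at startup by a ThreadPoolExecutor that maps it over every class name for every selectable language, so later UI calls are answered from the cache. A minimal, self-contained sketch of that pattern, assuming a shortened language table and placeholder class names, with the deep_translator call replaced by a stub so the snippet runs offline:

    from concurrent.futures import ThreadPoolExecutor
    from functools import lru_cache

    LANGUAGES = {"None": None, "German": "de", "French": "fr"}  # stand-in for LANGUAGES_TO_SELECT
    CLASS_NAMES = ["cat", "dog", "horse"]                       # stand-in for ALL_CLASSES

    @lru_cache(maxsize=len(LANGUAGES) * len(CLASS_NAMES))
    def get_translated(name, target_language="German"):
        code = LANGUAGES.get(target_language, target_language)
        if code is None or code == "en":
            return name
        # placeholder for GoogleTranslator(source="en", target=code).translate(name)
        return f"{name} [{code}]"

    # warm the cache; Executor.map submits every call immediately, the results are not needed here
    with ThreadPoolExecutor(max_workers=4) as executor:
        for language in LANGUAGES:
            executor.map(get_translated, CLASS_NAMES, [language] * len(CLASS_NAMES))

    print(get_translated("cat", "German"))  # served from the warmed cache

Because lru_cache keys on the call arguments, the warm-up and the later lookups must pass the language the same way (positionally here) for the cache to hit.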
src/Nets.py CHANGED
@@ -1,47 +1,10 @@
1
- import torch
2
  import torch.nn as nn
3
- import torch.nn.functional as F
4
  from torchvision import models
5
 
6
- class SimpleCNN(nn.Module):
7
- def __init__(self, k_size=3, pool_size=2, num_classes=1):
8
- super(SimpleCNN, self).__init__()
9
- self.relu = nn.ReLU()
10
- # First Convolutional Layer
11
- self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=k_size, padding=1)
12
- self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=k_size, stride=1, padding=1)
13
- self.pool1 = nn.MaxPool2d(kernel_size=pool_size)
14
-
15
- # Second Convolutional Layer
16
- self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=k_size, stride=1, padding=1)
17
- self.conv4 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=k_size, stride=1, padding=1)
18
- self.pool2 = nn.MaxPool2d(kernel_size=pool_size)
19
-
20
- self.conv5 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=k_size, stride=1, padding=1)
21
- self.conv6 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=k_size, stride=1, padding=1)
22
- self.pool3 = nn.MaxPool2d(kernel_size=pool_size)
23
-
24
- self.conv7 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=k_size, stride=1, padding=1)
25
- self.conv8 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=k_size, stride=1, padding=1)
26
- self.pool4 = nn.MaxPool2d(kernel_size=pool_size)
27
-
28
- # Fully Connected Layers
29
- self.fc = nn.Linear(64*14*14, num_classes) # Adjust the input features based on your input image size
30
-
31
- def forward(self, x):
32
- x = self.pool1(self.relu(self.conv2(self.relu(self.conv1(x)))))
33
- x = self.pool2(self.relu(self.conv4(self.relu(self.conv3(x)))))
34
- x = self.pool3(self.relu(self.conv6(self.relu(self.conv5(x)))))
35
- x = self.pool4(self.relu(self.conv8(self.relu(self.conv7(x)))))
36
- # print(x.shape)
37
- x = x.view(x.size(0), -1)
38
- x = self.fc(x)
39
- return x
40
-
41
  class CustomResNet18(nn.Module):
42
  def __init__(self, num_classes=11):
43
  super(CustomResNet18, self).__init__()
44
- self.resnet = models.resnet50(pretrained=True)
45
  num_features = self.resnet.fc.in_features
46
  self.resnet.fc = nn.Linear(num_features, num_classes)
47
 
 
 
1
  import torch.nn as nn
 
2
from torchvision import models
3
4
  class CustomResNet18(nn.Module):
5
  def __init__(self, num_classes=11):
6
  super(CustomResNet18, self).__init__()
7
+ self.resnet = models.resnet18(pretrained=True)
8
  num_features = self.resnet.fc.in_features
9
  self.resnet.fc = nn.Linear(num_features, num_classes)
10
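For orientation, the change above swaps the ResNet-50 backbone for ResNet-18 while keeping the replaced classification head. A short sketch of how the updated class is loaded for CPU inference, mirroring the app.py changes in this commit (111 classes, checkpoint path from the repo); the dummy tensor only illustrates the 224x224 input the transform pipeline produces:

    import torch
    from Nets import CustomResNet18

    model = CustomResNet18(111).eval()
    state = torch.load('src/results/models/best_model.pth', map_location=torch.device('cpu'))
    model.load_state_dict(state)

    with torch.no_grad():
        dummy = torch.randn(1, 3, 224, 224)                        # one normalized RGB image
        probs = torch.nn.functional.softmax(model(dummy), dim=1)   # shape (1, 111)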
 
src/Roboto-Regular.ttf ADDED
Binary file (515 kB). View file
 
src/cache/val_df.csv CHANGED
The diff for this file is too large to render. See raw diff
 
src/examples/{false_predicted/squirrel.jpg → false/bee.jpg} RENAMED
File without changes
src/examples/{false_predicted/chimpanzee.jpg → false/coyote.jpg} RENAMED
File without changes
src/examples/{true_predicted/cat.jpg → false/donkey.jpg} RENAMED
File without changes
src/examples/false/goat.jpg ADDED

Git LFS Details

  • SHA256: 5407753c8df3d0a5c901215b2ebcf378d4209f332e6a33cdb30a5006bfbf8d09
  • Pointer size: 131 Bytes
  • Size of remote file: 457 kB
src/examples/false/hornbill.jpg ADDED

Git LFS Details

  • SHA256: 3ccfc55aa247b4eff0483adb1683d2d1d4dd0790dcaff81b3e243e4659dd1bf0
  • Pointer size: 130 Bytes
  • Size of remote file: 83.5 kB
src/examples/false_predicted/starfish.jpg DELETED

Git LFS Details

  • SHA256: d87e919ecb6d94c51affd428927457c26683b7ab32140418337403ebe26a0d45
  • Pointer size: 131 Bytes
  • Size of remote file: 179 kB
src/examples/true/dolphin.jpg ADDED

Git LFS Details

  • SHA256: 729e4bfab228c912f14733ef32107583ed4cdaa2e0f197ff259f6981f24772ac
  • Pointer size: 131 Bytes
  • Size of remote file: 118 kB
src/examples/true/dragonfly.jpg ADDED

Git LFS Details

  • SHA256: 8a33acb02f7e9686f4642f3a878cbf7728c875cedda9e1d3ff4db0831189d5df
  • Pointer size: 131 Bytes
  • Size of remote file: 139 kB
src/examples/{false_predicted → true}/koala.jpg RENAMED
File without changes
src/examples/{false_predicted → true}/sheep.jpg RENAMED
File without changes
src/examples/true/squid.jpg ADDED

Git LFS Details

  • SHA256: 1be527b0a94f05e7d5b5178b187ec95bd2d3fb992ecc38af50afbfef775d65cb
  • Pointer size: 130 Bytes
  • Size of remote file: 19.8 kB
src/examples/true_predicted/cockroach.jpg DELETED

Git LFS Details

  • SHA256: 04e21d254a0e49c8a868b47de902e7eb6571ea28d9413940f817e175a87f3275
  • Pointer size: 131 Bytes
  • Size of remote file: 107 kB
src/examples/true_predicted/flamingo.jpg DELETED

Git LFS Details

  • SHA256: e3c90ae9176e11b1dcc73f5d9c81b94433bd9d8228919340db80f404e9f6ced4
  • Pointer size: 131 Bytes
  • Size of remote file: 619 kB
src/examples/true_predicted/gorilla.jpg DELETED

Git LFS Details

  • SHA256: 27ce1f1437356309406de0341c4680ea9d6f72d90fd49133db3843c0af272fc8
  • Pointer size: 130 Bytes
  • Size of remote file: 12.1 kB
src/examples/true_predicted/grasshopper.jpg DELETED

Git LFS Details

  • SHA256: 76c372432e4bc478c35b157444cb7923d7c3827f8d3fb7cea8f5625a4b94ac51
  • Pointer size: 130 Bytes
  • Size of remote file: 10.7 kB
src/gradio_blocks.py CHANGED
@@ -29,7 +29,7 @@ def build_video_to_camvideo(CAM_METHODS, CV2_COLORMAPS, LAYERS, ALL_CLASSES, gra
29
  )
30
 
31
  video_layer = gr.Radio(
32
- LAYERS.keys(),
33
  label="Layer",
34
  value="layer4",
35
  interactive=True,
@@ -48,7 +48,7 @@ def build_video_to_camvideo(CAM_METHODS, CV2_COLORMAPS, LAYERS, ALL_CLASSES, gra
48
  colormap = gr.Dropdown(
49
  choices=list(CV2_COLORMAPS.keys()),
50
  label="Colormap",
51
- value="Jet",
52
  interactive=True,
53
  scale=2,
54
  )
 
29
  )
30
 
31
  video_layer = gr.Radio(
32
+ [f"layer{i}" for i in range(1, 5)],
33
  label="Layer",
34
  value="layer4",
35
  interactive=True,
 
48
  colormap = gr.Dropdown(
49
  choices=list(CV2_COLORMAPS.keys()),
50
  label="Colormap",
51
+ value="Inferno",
52
  interactive=True,
53
  scale=2,
54
  )
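A note on the hard-coded radio choices above: layer1 through layer4 are the block names of the wrapped torchvision ResNet, so the selected string can be resolved to a CAM target layer roughly as sketched below (the helper name is hypothetical; it assumes the backbone is reachable as model.resnet, as in CustomResNet18):

    def resolve_target_layers(model, layer_name="layer4"):
        # getattr maps the radio string (e.g. "layer4") to the ResNet block of the same name
        return [getattr(model.resnet, layer_name)]

    # layers = resolve_target_layers(MODEL, chosen_layer)
    # with GradCAM(model=MODEL, target_layers=layers) as cam:
    #     grayscale_cam = cam(input_tensor=image_tensor, targets=None)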
src/header.md CHANGED
@@ -2,9 +2,9 @@
2
 
3
  This project was created by [Ilyesse](https://github.com/ilyii) and [Gabriel](https://github.com/Gabriel9753) as part of the Explainable Machine Learning module at the [University of Applied Sciences Karlsruhe](https://www.h-ka.de/).
4
 
5
- The dataset used in this project is the [Animal Image Dataset](https://www.kaggle.com/datasets/iamsouravbanerjee/animal-image-dataset-90-different-animals) from Kaggle, comprising 90 different animal species that needed to be classified. We also added approx. 1000 AI generated images for all classes to get a more diverse dataset and also improve the performance of the model.
6
 
7
- The employed model is ResNet50, which was trained on the dataset using transfer learning techniques.
8
  Translation of animal names by [deep-translator](https://pypi.org/project/deep-translator/).
9
 
10
  ## Usage 🦎
 
2
 
3
  This project was created by [Ilyesse](https://github.com/ilyii) and [Gabriel](https://github.com/Gabriel9753) as part of the Explainable Machine Learning module at the [University of Applied Sciences Karlsruhe](https://www.h-ka.de/).
4
 
5
+ The dataset used in this project is the [Animal Image Dataset](https://www.kaggle.com/datasets/iamsouravbanerjee/animal-image-dataset-90-different-animals) from Kaggle, comprising 90 different animal species to be classified. To broaden the data we added 21 further unique classes, giving us our own 111-animal dataset. We also added approx. 1000 AI-generated images across the classes to make the dataset more diverse and improve the model's performance.
6
 
7
+ The employed model is ResNet18, which was trained on the dataset using transfer learning techniques.
8
  Translation of animal names by [deep-translator](https://pypi.org/project/deep-translator/).
9
 
10
  ## Usage 🦎
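The training code itself is not part of this commit; for context, a minimal sketch of the transfer-learning setup the header describes (ImageNet-pretrained ResNet-18 with a new 111-way head). The optimizer, learning rate, and the one-batch stand-in loader are illustrative only:

    import torch
    import torch.nn as nn
    from torchvision import models

    model = models.resnet18(pretrained=True)              # ImageNet-pretrained backbone
    model.fc = nn.Linear(model.fc.in_features, 111)       # new head for the 111 animal classes

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)   # illustrative hyperparameters

    train_loader = [(torch.randn(4, 3, 224, 224), torch.randint(0, 111, (4,)))]  # stand-in batch

    model.train()
    for images, targets in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(images), targets)
        loss.backward()
        optimizer.step()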
src/results/gradcam_video.mp4 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a9617d53ad717194350c99f6b1d2a172f01e712e4109c76b16fe3f70f32c4570
3
- size 772080
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d88ec14ff35116bf5d8bd65454616aba242d8f79bde4dcbd717aabbcc910670a
3
+ size 917687
src/results/infer_image.png CHANGED

Git LFS Details

  • SHA256: 8a1d8cf8974330c3e6fe91b98860ca140fb46edfb6a1f5c8448c8d5e2ed479c7
  • Pointer size: 131 Bytes
  • Size of remote file: 339 kB

Git LFS Details

  • SHA256: 5fb27d68a14ee2dd5d2f99e5b24cda08ea7245ffb06731108036937eed56b9b5
  • Pointer size: 131 Bytes
  • Size of remote file: 424 kB
src/results/models/best_model.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:bd7be6abdcf8f64be68324d3b6d82cc4f5e02a12e6462b63b2c190d5a0a4182a
3
- size 95091582
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e6a3f852efacebef8dee4ba74c0a73a7f33bf2180c4272dbf233a5c6157d7531
3
+ size 45015274
src/util.py CHANGED
@@ -1,83 +1,9 @@
1
  import torchvision.transforms as transforms
2
- from torch.utils.data import DataLoader, Dataset
3
- from sklearn.preprocessing import LabelEncoder
4
- from tqdm import tqdm
5
- from PIL import Image
6
  import torch
7
- import imagehash
8
- ImageCache = None
9
 
10
- class AnimalDataset(Dataset):
11
- def __init__(self, df, transform=None):
12
- self.paths = df["path"].values
13
- self.targets = df["target"].values
14
- self.encoded_target = df['encoded_target'].values
15
- self.transform = transform
16
- self.images = []
17
- for path in tqdm(self.paths):
18
- self.images.append(Image.open(path).convert("RGB").resize((224, 224)))
19
-
20
- def __len__(self):
21
- return len(self.paths)
22
-
23
- def __getitem__(self, idx):
24
- img = self.images[idx]
25
- if self.transform:
26
- img = self.transform(img)
27
- target = self.targets[idx]
28
- encoded_target = torch.tensor(self.encoded_target[idx]).type(torch.LongTensor)
29
- return img, encoded_target, target
30
-
31
- train_transform = transforms.Compose([
32
- transforms.Resize((224,224)),
33
- transforms.RandomHorizontalFlip(),
34
- transforms.RandomRotation(10),
35
- transforms.ToTensor(),
36
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
37
- ])
38
  # Define the transformation pipeline
39
  transform = transforms.Compose([
40
  transforms.Resize((224,224)),
41
  transforms.ToTensor(), # Convert the images to PyTorch tensors
42
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
43
- ])
44
-
45
- class CustomImageCache:
46
- def __init__(self, cache_size=50, debug=False):
47
- self.cache = dict()
48
- self.cache_size = 50
49
- self.debug = debug
50
- self.cache_hits = 0
51
- self.cache_misses = 0
52
-
53
- def __getitem__(self, image):
54
- if isinstance(image, dict):
55
- # Its the image and a mask as pillow both -> Combine them to one image
56
- image = Image.blend(image["image"], image["mask"], alpha=0.5)
57
- key = imagehash.average_hash(image)
58
-
59
- if key in self.cache:
60
- if self.debug: print("Cache hit!")
61
- self.cache_hits += 1
62
- return self.cache[key]
63
- else:
64
- if self.debug: print("Cache miss!")
65
- self.cache_misses += 1
66
- if len(self.cache.keys()) >= self.cache_size:
67
- if self.debug: print("Cache full, popping item!")
68
- self.cache.popitem()
69
- self.cache[key] = image
70
- return self.cache[key]
71
-
72
- def __len__(self):
73
- return len(self.cache.keys())
74
-
75
- def print_info(self):
76
- print(f"Cache size: {len(self)}")
77
- print(f"Cache hits: {self.cache_hits}")
78
- print(f"Cache misses: {self.cache_misses}")
79
-
80
- def imageCacheWrapper(fn):
81
- def wrapper(image):
82
- return fn(ImageCache[image])
83
- return wrapper
 
1
import torchvision.transforms as transforms
2
import torch
3
4
  # Define the transformation pipeline
5
  transform = transforms.Compose([
6
  transforms.Resize((224,224)),
7
  transforms.ToTensor(), # Convert the images to PyTorch tensors
8
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
9
+ ])
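After this cleanup util.py only exports the inference transform. A short usage sketch; the example path is one of the images added in this commit:

    from PIL import Image
    from util import transform

    img = Image.open('src/examples/true/dolphin.jpg').convert('RGB')
    tensor = transform(img).unsqueeze(0)   # shape (1, 3, 224, 224), normalized with ImageNet statistics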