GabrielML commited on
Commit
79acef0
·
1 Parent(s): 8717d3d

Video feature

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .gitignore +2 -0
  3. README.md +1 -1
  4. animals.png +3 -0
  5. app.py +257 -105
  6. requirements.txt +0 -0
  7. src/Nets.py +1 -1
  8. src/cache/val_df.csv +0 -0
  9. src/examples/AI_Generated/crow.png +3 -0
  10. src/examples/AI_Generated/donkey.png +3 -0
  11. src/examples/AI_Generated/eagle.png +3 -0
  12. src/examples/{dragonfly/353bd2bd65.jpg → AI_Generated/elephant (2).png} +2 -2
  13. src/examples/AI_Generated/elephant.png +3 -0
  14. src/examples/AI_Generated/fox.png +3 -0
  15. src/examples/AI_Generated/goat (2).png +3 -0
  16. src/examples/AI_Generated/goat.png +3 -0
  17. src/examples/AI_Generated/goldfish.png +3 -0
  18. src/examples/AI_Generated/jellyfish.png +3 -0
  19. src/examples/AI_Generated/koala.png +3 -0
  20. src/examples/AI_Generated/otter.png +3 -0
  21. src/examples/AI_Generated/panda.png +3 -0
  22. src/examples/AI_Generated/penguin.png +3 -0
  23. src/examples/AI_Generated/pigeon.png +3 -0
  24. src/examples/AI_Generated/rabbit.png +3 -0
  25. src/examples/AI_Generated/rhinoceros (2).png +3 -0
  26. src/examples/AI_Generated/rhinoceros.png +3 -0
  27. src/examples/AI_Generated/snake.png +3 -0
  28. src/examples/AI_Generated/swan.png +3 -0
  29. src/examples/AI_Generated/woodpecker.png +3 -0
  30. src/examples/antelope/1d556456dc.jpg +0 -3
  31. src/examples/badger/0836f4eb45.jpg +0 -3
  32. src/examples/badger/23bfad16a7.jpg +0 -3
  33. src/examples/badger/4c273d12a9.jpg +0 -3
  34. src/examples/badger/5bffbd51cf.jpg +0 -3
  35. src/examples/badger/87d1db4af3.jpg +0 -3
  36. src/examples/badger/89a8316cd4.jpg +0 -3
  37. src/examples/badger/99e296bf48.jpg +0 -3
  38. src/examples/bat/16f6af0091.jpg +0 -3
  39. src/examples/bat/1dd514de63.jpg +0 -3
  40. src/examples/bat/1fd53c0b98.jpg +0 -3
  41. src/examples/bat/2d028b789d.jpg +0 -3
  42. src/examples/bat/2f7c6c7cd5.jpg +0 -3
  43. src/examples/bat/330e4a8053.jpg +0 -3
  44. src/examples/bat/47d2c91d9b.jpg +0 -3
  45. src/examples/bat/513bb906a6.jpg +0 -3
  46. src/examples/bat/5e85312fa8.jpg +0 -3
  47. src/examples/bat/6b4b95f0c4.jpg +0 -3
  48. src/examples/bat/6da14f603d.jpg +0 -3
  49. src/examples/bat/741fa84ed0.jpg +0 -3
  50. src/examples/bear/116d9b7f88.jpg +0 -3
.gitattributes CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.jpg filter=lfs diff=lfs merge=lfs -text
37
  *.png filter=lfs diff=lfs merge=lfs -text
 
 
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  *.jpg filter=lfs diff=lfs merge=lfs -text
37
  *.png filter=lfs diff=lfs merge=lfs -text
38
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -158,3 +158,5 @@ src/results/plots/
158
  src/train_resnet.py
159
  src/visualize_gradcam.ipynb
160
  src/cache/data.csv
 
 
 
158
  src/train_resnet.py
159
  src/visualize_gradcam.ipynb
160
  src/cache/data.csv
161
+ .vscode/settings.json
162
+ src/backup
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: Explain Animal CNN
3
- emoji: 💻
4
  colorFrom: pink
5
  colorTo: gray
6
  sdk: gradio
 
1
  ---
2
  title: Explain Animal CNN
3
+ emoji: 🐬
4
  colorFrom: pink
5
  colorTo: gray
6
  sdk: gradio
animals.png ADDED

Git LFS Details

  • SHA256: e7367d776ea2a31f2e41a228272cdc50ca49ee81b6ca8e0cc9b61fddd8313896
  • Pointer size: 132 Bytes
  • Size of remote file: 6.67 MB
app.py CHANGED
@@ -4,7 +4,6 @@ import sys
4
  sys.path.append('src')
5
  from collections import defaultdict
6
  from functools import lru_cache
7
-
8
  import gradio as gr
9
  import matplotlib.pyplot as plt
10
  import numpy as np
@@ -18,10 +17,18 @@ from torchcam.utils import overlay_mask
18
  from torchvision.transforms.functional import to_pil_image
19
  from tqdm import tqdm
20
  from util import transform
 
 
 
 
21
 
22
  IMAGE_PATH = os.path.join(os.getcwd(), 'src/examples')
23
- RANDOM_IMAGES_TO_SHOW = 10
24
- IMAGES_PER_ROW = 5
 
 
 
 
25
 
26
  CAM_METHODS = {
27
  "GradCAM": GradCAM,
@@ -35,27 +42,24 @@ model.load_state_dict(torch.load('src/results/models/best_model.pth', map_locati
35
  cam_model = copy.deepcopy(model)
36
  data_df = pd.read_csv('src/cache/val_df.csv')
37
 
38
- def load_random_images():
39
- random_images = list()
40
- for i in range(RANDOM_IMAGES_TO_SHOW):
41
- idx = np.random.randint(0, len(data_df))
42
- p = os.path.join(IMAGE_PATH, data_df.iloc[idx]['path'])
43
- p = p.replace('\\', '/')
44
- p = p.replace('//', '/')
45
- animal = data_df.iloc[idx]['target']
46
- if os.path.exists(p):
47
- random_images.append((animal, Image.open(p)))
48
- return random_images
49
 
50
  def get_class_name(idx):
51
- return data_df[data_df['encoded_target'] == idx]['target'].values[0]
 
 
 
52
 
53
  @lru_cache(maxsize=100)
54
  def get_translated(to_translate):
55
- return GoogleTranslator(source="en", target="de").translate(to_translate)
56
- for idx in range(90): get_translated(get_class_name(idx))
 
57
 
58
- def infer_image(image):
 
59
  image = transform(image)
60
  image = image.unsqueeze(0)
61
  with torch.no_grad():
@@ -67,7 +71,8 @@ def infer_image(image):
67
  ret[animal] = prob.item()
68
  return ret
69
 
70
- def gradcam(image, alpha, cam_method, layer):
 
71
  if layer == 'layer1': layers = [model.resnet.layer1]
72
  elif layer == 'layer2': layers = [model.resnet.layer2]
73
  elif layer == 'layer3': layers = [model.resnet.layer3]
@@ -78,105 +83,252 @@ def gradcam(image, alpha, cam_method, layer):
78
  img_tensor = transform(image).unsqueeze(0)
79
  cam = CAM_METHODS[cam_method](model, target_layer=layers)
80
  output = model(img_tensor)
81
- activation_map = cam(output.squeeze(0).argmax().item(), output)
 
82
  result = overlay_mask(image, to_pil_image(activation_map[0].squeeze(0), mode='F'), alpha=alpha)
83
  cam.remove_hooks()
84
 
85
- # height maximal 300px
86
- if result.size[1] > 300:
87
- ratio = 300 / result.size[1]
88
- result = result.resize((int(result.size[0] * ratio), 300))
89
  return result
90
 
91
- with gr.Blocks() as demo:
92
- with open('src/header.md', 'r') as f:
93
- markdown_string = f.read()
94
- header = gr.Markdown(markdown_string)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
- with gr.Row(variant="panel", equal_height=True):
97
- user_image = gr.Image(
98
- type="pil",
99
- label="Upload Your Own Image",
100
- info="You can also upload your own image for prediction.",
101
- scale=1,
102
- )
 
 
 
 
 
 
 
 
 
 
 
103
 
104
- with gr.Tab("Predict"):
105
- with gr.Column():
106
- output = gr.Label(
107
- num_top_classes=3,
108
- label="Output",
109
- info="Top three predicted classes and their confidences.",
110
- scale=5,
111
- )
112
- predict_mode_button = gr.Button(value="Predict Animal", label="Predict", info="Click to make a prediction.", scale=1)
113
- predict_mode_button.click(fn=infer_image, inputs=[user_image], outputs=output, queue=True)
 
 
 
 
 
 
114
 
115
- with gr.Tab("Explain"):
116
- with gr.Row():
117
- with gr.Column():
118
- cam_method = gr.Radio(
119
- list(CAM_METHODS.keys()),
120
- label="GradCAM Method",
121
- value="GradCAM",
122
- interactive=True,
123
- scale=2,
124
- )
125
- cam_method.description = "Here you can choose the GradCAM method."
126
- cam_method.description_place = "left"
127
-
128
- alpha = gr.Slider(
129
- minimum=.1,
130
- maximum=.9,
131
- value=0.5,
132
- interactive=True,
133
- step=.1,
134
- label="Alpha",
135
- scale=1,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  )
137
- alpha.description = "Here you can choose the alpha value."
138
- alpha.description_place = "left"
139
-
140
- layer = gr.Radio(
141
- ["layer1", "layer2", "layer3", "layer4", "all"],
142
- label="Layer",
143
- value="layer4",
144
- interactive=True,
145
- scale=2,
146
  )
147
- layer.description = "Here you can choose the layer to visualize."
148
- layer.description_place = "left"
149
-
 
 
 
 
 
 
150
  with gr.Column():
151
- output_cam = gr.Image(
152
- type="pil",
153
- label="GradCAM",
154
- info="GradCAM visualization"
155
-
156
  )
157
-
158
- gradcam_mode_button = gr.Button(value="Show GradCAM", label="GradCAM", info="Click to make a prediction.", scale=1)
159
- gradcam_mode_button.click(fn=gradcam, inputs=[user_image, alpha, cam_method, layer], outputs=output_cam, queue=True)
160
-
161
- with gr.Tab("Example Images"):
162
- with gr.Column():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  placeholder = gr.Markdown("## Example Images")
164
- showed_images = list()
165
- loaded_images = load_random_images()
166
- amount_rows = max(1, (len(loaded_images) // IMAGES_PER_ROW))
167
- if len(loaded_images) == 0:
168
- print(f"Could not find any images in {IMAGE_PATH}")
169
- amount_rows = 0
170
- for i in range(amount_rows):
171
- with gr.Row():
172
- for j in range(IMAGES_PER_ROW):
173
- animal, image = loaded_images[i * IMAGES_PER_ROW + j]
174
- showed_images.append(gr.Image(
175
- value=image,
176
- label=animal,
177
- type="pil",
178
- interactive=False,
179
- ))
180
 
181
 
182
 
 
4
  sys.path.append('src')
5
  from collections import defaultdict
6
  from functools import lru_cache
 
7
  import gradio as gr
8
  import matplotlib.pyplot as plt
9
  import numpy as np
 
17
  from torchvision.transforms.functional import to_pil_image
18
  from tqdm import tqdm
19
  from util import transform
20
+ from gradio_blocks import build_video_to_camvideo
21
+ import cv2
22
+ import ffmpeg
23
+
24
 
25
  IMAGE_PATH = os.path.join(os.getcwd(), 'src/examples')
26
+ IMAGES_PER_ROW = 7
27
+
28
+ MAXIMAL_FRAMES = 1000
29
+ BATCHES_TO_PROCESS = 10
30
+ OUTPUT_FPS = 15
31
+ MAX_OUT_FRAMES = 60
32
 
33
  CAM_METHODS = {
34
  "GradCAM": GradCAM,
 
42
  cam_model = copy.deepcopy(model)
43
  data_df = pd.read_csv('src/cache/val_df.csv')
44
 
45
+ C_NUM_TO_NAME = data_df[['encoded_target', 'target']].drop_duplicates().sort_values('encoded_target').set_index('encoded_target')['target'].to_dict()
46
+ C_NAME_TO_NUM = {v: k for k, v in C_NUM_TO_NAME.items()}
47
+ ALL_CLASSES = sorted(list(C_NUM_TO_NAME.values()), key=lambda x: x.lower())
 
 
 
 
 
 
 
 
48
 
49
  def get_class_name(idx):
50
+ return C_NUM_TO_NAME[idx]
51
+
52
+ def get_class_idx(name):
53
+ return C_NAME_TO_NUM[name]
54
 
55
  @lru_cache(maxsize=100)
56
  def get_translated(to_translate):
57
+ return "ssss"
58
+ # return GoogleTranslator(source="en", target="de").translate(to_translate)
59
+ # for idx in range(90): get_translated(get_class_name(idx))
60
 
61
+ def infer_image(image, image_sketch):
62
+ image = image if image is not None else image_sketch
63
  image = transform(image)
64
  image = image.unsqueeze(0)
65
  with torch.no_grad():
 
71
  ret[animal] = prob.item()
72
  return ret
73
 
74
+ def gradcam(image, image_sketch=None, alpha=0.5, cam_method=GradCAM, layer=None, specific_class="Predicted Class"):
75
+ image = image if image is not None else image_sketch
76
  if layer == 'layer1': layers = [model.resnet.layer1]
77
  elif layer == 'layer2': layers = [model.resnet.layer2]
78
  elif layer == 'layer3': layers = [model.resnet.layer3]
 
83
  img_tensor = transform(image).unsqueeze(0)
84
  cam = CAM_METHODS[cam_method](model, target_layer=layers)
85
  output = model(img_tensor)
86
+ class_to_explain = output.squeeze(0).argmax().item() if specific_class == "Predicted Class" else get_class_idx(specific_class)
87
+ activation_map = cam(class_to_explain, output)
88
  result = overlay_mask(image, to_pil_image(activation_map[0].squeeze(0), mode='F'), alpha=alpha)
89
  cam.remove_hooks()
90
 
91
+ # # height maximal 300px
92
+ # if result.size[1] > 300:
93
+ # ratio = 300 / result.size[1]
94
+ # result = result.resize((int(result.size[0] * ratio), 300))
95
  return result
96
 
97
+
98
+ def gradcam_video(video, alpha=0.5, cam_method=GradCAM, layer=None, specific_class="Predicted Class"):
99
+ global OUTPUT_FPS, MAXIMAL_FRAMES, BATCHES_TO_PROCESS, MAX_OUT_FRAMES
100
+ video = cv2.VideoCapture(video)
101
+ fps = int(video.get(cv2.CAP_PROP_FPS))
102
+ if OUTPUT_FPS == -1: OUTPUT_FPS = fps
103
+ width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
104
+ height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
105
+ if width > 3000 or height > 3000:
106
+ raise gr.Error("The video is too big. The maximal size is 3000x3000.")
107
+ print(f'FPS: {fps}, Width: {width}, Height: {height}')
108
+
109
+ frames = list()
110
+ success, image = video.read()
111
+ while success:
112
+ frames.append(image)
113
+ success, image = video.read()
114
+ print(f'Frames: {len(frames)}')
115
+ if len(frames) == 0:
116
+ raise gr.Error("The video is empty.")
117
+ if len(frames) >= MAXIMAL_FRAMES:
118
+ raise gr.Error(f"The video is too long. The maximal length is {MAXIMAL_FRAMES} frames.")
119
+
120
+ if len(frames) > MAX_OUT_FRAMES:
121
+ frames = frames[::len(frames) // MAX_OUT_FRAMES]
122
+
123
+ print(f'Frames to process: {len(frames)}')
124
+
125
+ processed = [Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) for frame in frames]
126
+ # generate lists in lists for the images for batch processing. 10 images per inner list..
127
+ batched = [processed[i:i + BATCHES_TO_PROCESS] for i in range(0, len(processed), BATCHES_TO_PROCESS)]
128
 
129
+ model.eval()
130
+ if layer == 'layer1': layers = [model.resnet.layer1]
131
+ elif layer == 'layer2': layers = [model.resnet.layer2]
132
+ elif layer == 'layer3': layers = [model.resnet.layer3]
133
+ elif layer == 'layer4': layers = [model.resnet.layer4]
134
+ else: layers = [model.resnet.layer1, model.resnet.layer2, model.resnet.layer3, model.resnet.layer4]
135
+ cam = CAM_METHODS[cam_method](model, target_layer=layers)
136
+ results = list()
137
+ for i, batch in enumerate(tqdm(batched)):
138
+ images_tensor = torch.stack([transform(image) for image in batch])
139
+ outputs = model(images_tensor)
140
+ out_classes = [output.argmax().item() for output in outputs]
141
+ classes_to_explain = out_classes if specific_class == "Predicted Class" else [get_class_idx(specific_class)] * len(out_classes)
142
+ activation_maps = cam(classes_to_explain, outputs)
143
+ for j, activation_map in enumerate(activation_maps[0]):
144
+ result = overlay_mask(batch[j], to_pil_image(activation_map, mode='F'), alpha=alpha)
145
+ results.append(cv2.cvtColor(np.array(result), cv2.COLOR_RGB2BGR))
146
+ cam.remove_hooks()
147
 
148
+ # save video
149
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
150
+ size = (results[0].shape[1], results[0].shape[0])
151
+ video = cv2.VideoWriter('src/results/gradcam_video.mp4', fourcc, OUTPUT_FPS, size)
152
+ for frame in results:
153
+ video.write(frame)
154
+ video.release()
155
+ return 'src/results/gradcam_video.mp4'
156
+
157
+ def load_examples():
158
+ folder_name_to_header = {
159
+ "AI_Generated": "AI Generated Images",
160
+ "true_predicted": "True Predicted Images (Validation Set)",
161
+ "false_predicted": "False Predicted Images (Validation Set)",
162
+ "others": "Other interesting images from the internet"
163
+ }
164
 
165
+ images_description = {
166
+ "AI_Generated": "These images are generated by Dalle3 and Stable Diffusion. All of them are not real images and because of that it is interesting to see how the model predicts them.",
167
+ "true_predicted": "These images are from the validation set and the model predicted them correctly.",
168
+ "false_predicted": "These images are from the validation set and the model predicted them incorrectly. Maybe you can see why the model predicted them incorrectly using the GradCAM visualization. :)",
169
+ "others": "These images are from the internet and are not part of the validation set. They are interesting because most of them show different animals."
170
+ }
171
+
172
+ loaded_images = defaultdict(list)
173
+
174
+ for image_type in ["AI_Generated", "true_predicted", "false_predicted", "others"]:
175
+ # for image_type in os.listdir(IMAGE_PATH):
176
+ full_path = os.path.join(IMAGE_PATH, image_type).replace('\\', '/').replace('//', '/')
177
+ gr.Markdown(f'## {folder_name_to_header[image_type]}')
178
+ gr.Markdown(images_description[image_type])
179
+ images_to_load = os.listdir(full_path)
180
+ rows = (len(images_to_load) // IMAGES_PER_ROW) + 1
181
+ for i in range(rows):
182
+ with gr.Row(elem_classes=["row-example-images"], equal_height=False):
183
+ for j in range(IMAGES_PER_ROW):
184
+ if i * IMAGES_PER_ROW + j >= len(images_to_load): break
185
+ image = images_to_load[i * IMAGES_PER_ROW + j]
186
+ loaded_images[image_type].append(
187
+ gr.Image(
188
+ value=os.path.join(full_path, image),
189
+ label=f"image ({get_translated(image.split('.')[0])})",
190
+ type="pil",
191
+ interactive=False,
192
+ elem_classes=["selectable_images"],
193
+ )
194
+ )
195
+ return loaded_images
196
+
197
+ css = """
198
+ #logo {text-align: right;}
199
+ p {text-align: justify; text-justify: inter-word; font-size: 1.1em; line-height: 1.2em;}
200
+ .svelte-1btp92j.selectable {cursor: pointer !important; }
201
+ """
202
+
203
+
204
+
205
+ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
206
+ # -------------------------------------------
207
+ # HEADER WITH LOGO
208
+ # -------------------------------------------
209
+ with gr.Row():
210
+ with open('src/header.md', 'r', encoding='utf-8') as f:
211
+ markdown_string = f.read()
212
+ with gr.Column(scale=10):
213
+ header = gr.Markdown(markdown_string)
214
+ with gr.Column(scale=1):
215
+ pil_logo = Image.open('animals.png')
216
+ logo = gr.Image(value=pil_logo, scale=2, interactive=False, show_download_button=False, show_label=False, container=False, elem_id="logo")
217
+
218
+ # -------------------------------------------
219
+ # INPUT IMAGE
220
+ # -------------------------------------------
221
+ with gr.Row():
222
+ with gr.Tab("Upload Image"):
223
+ with gr.Row(variant="panel", equal_height=True):
224
+ user_image = gr.Image(
225
+ type="pil",
226
+ label="Upload Your Own Image",
227
+ info="You can also upload your own image for prediction.",
228
  )
229
+ with gr.Tab("Draw Image"):
230
+ with gr.Row(variant="panel", equal_height=True):
231
+ user_image_sketched = gr.Image(
232
+ type="pil",
233
+ source="canvas",
234
+ tool="color-sketch",
235
+ label="Draw Your Own Image",
236
+ info="You can also draw your own image for prediction.",
 
237
  )
238
+
239
+ # -------------------------------------------
240
+ # TOOLS
241
+ # -------------------------------------------
242
+ with gr.Row():
243
+ # -------------------------------------------
244
+ # PREDICT
245
+ # -------------------------------------------
246
+ with gr.Tab("Predict"):
247
  with gr.Column():
248
+ output = gr.Label(
249
+ num_top_classes=5,
250
+ label="Output",
251
+ info="Top three predicted classes and their confidences.",
252
+ scale=5,
253
  )
254
+ predict_mode_button = gr.Button(value="Predict Animal", label="Predict", info="Click to make a prediction.", scale=1)
255
+ predict_mode_button.click(fn=infer_image, inputs=[user_image, user_image_sketched], outputs=output, queue=True)
256
+
257
+ # -------------------------------------------
258
+ # EXPLAIN
259
+ # -------------------------------------------
260
+ with gr.Tab("Explain"):
261
+ with gr.Row():
262
+ with gr.Column():
263
+ cam_method = gr.Radio(
264
+ list(CAM_METHODS.keys()),
265
+ label="GradCAM Method",
266
+ value="GradCAM",
267
+ interactive=True,
268
+ scale=2,
269
+ )
270
+ cam_method.description = "Here you can choose the GradCAM method."
271
+ cam_method.description_place = "left"
272
+
273
+ alpha = gr.Slider(
274
+ minimum=.1,
275
+ maximum=.9,
276
+ value=0.5,
277
+ interactive=True,
278
+ step=.1,
279
+ label="Alpha",
280
+ scale=1,
281
+ )
282
+ alpha.description = "Here you can choose the alpha value."
283
+ alpha.description_place = "left"
284
+
285
+ layer = gr.Radio(
286
+ ["layer1", "layer2", "layer3", "layer4", "all"],
287
+ label="Layer",
288
+ value="layer4",
289
+ interactive=True,
290
+ scale=2,
291
+ )
292
+ layer.description = "Here you can choose the layer to visualize."
293
+ layer.description_place = "left"
294
+
295
+ animal_to_explain = gr.Dropdown(
296
+ choices=["Predicted Class"] + ALL_CLASSES,
297
+ label="Animal",
298
+ value="Predicted Class",
299
+ interactive=True,
300
+ scale=2,
301
+ )
302
+ animal_to_explain.description = "Here you can choose the animal to explain. If you choose 'Predicted Class' the method will explain the predicted class."
303
+ animal_to_explain.description_place = "center"
304
+
305
+ with gr.Column():
306
+ output_cam = gr.Image(
307
+ type="pil",
308
+ label="GradCAM",
309
+ info="GradCAM visualization"
310
+
311
+ )
312
+
313
+ gradcam_mode_button = gr.Button(value="Show GradCAM", label="GradCAM", info="Click to make a prediction.", scale=1)
314
+ gradcam_mode_button.click(fn=gradcam, inputs=[user_image, user_image_sketched, alpha, cam_method, layer, animal_to_explain], outputs=output_cam, queue=True)
315
+
316
+ # -------------------------------------------
317
+ # GIF CAM
318
+ # -------------------------------------------
319
+ with gr.Tab("Gif Cam"):
320
+ build_video_to_camvideo(CAM_METHODS, ALL_CLASSES, gradcam_video)
321
+
322
+ # -------------------------------------------
323
+ # EXAMPLES
324
+ # -------------------------------------------
325
+ with gr.Tab("Example Images"):
326
  placeholder = gr.Markdown("## Example Images")
327
+ loaded_images = load_examples()
328
+ for k in loaded_images.keys():
329
+ for image in loaded_images[k]:
330
+ image.select(fn=lambda x: x, inputs=[image], outputs=[user_image])
331
+
 
 
 
 
 
 
 
 
 
 
 
332
 
333
 
334
 
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
 
src/Nets.py CHANGED
@@ -41,7 +41,7 @@ class SimpleCNN(nn.Module):
41
  class CustomResNet18(nn.Module):
42
  def __init__(self, num_classes=11):
43
  super(CustomResNet18, self).__init__()
44
- self.resnet = models.resnet18(pretrained=True)
45
  num_features = self.resnet.fc.in_features
46
  self.resnet.fc = nn.Linear(num_features, num_classes)
47
 
 
41
  class CustomResNet18(nn.Module):
42
  def __init__(self, num_classes=11):
43
  super(CustomResNet18, self).__init__()
44
+ self.resnet = models.resnet50(pretrained=True)
45
  num_features = self.resnet.fc.in_features
46
  self.resnet.fc = nn.Linear(num_features, num_classes)
47
 
src/cache/val_df.csv CHANGED
The diff for this file is too large to render. See raw diff
 
src/examples/AI_Generated/crow.png ADDED

Git LFS Details

  • SHA256: 7b0f826eb3f73af5f7028d6c5cf08283e8ecb3eb8e02aa24aaeea0d96db4985e
  • Pointer size: 132 Bytes
  • Size of remote file: 2.26 MB
src/examples/AI_Generated/donkey.png ADDED

Git LFS Details

  • SHA256: e3d19dc6df7a0d5de0b41c8ecb9d01386a11a9ca9ac073a60b50e6cef65f4429
  • Pointer size: 132 Bytes
  • Size of remote file: 2.68 MB
src/examples/AI_Generated/eagle.png ADDED

Git LFS Details

  • SHA256: 85a6d5d8f4172572d674082e6567ff03206ec48867a2917f04d3f738da9f7155
  • Pointer size: 132 Bytes
  • Size of remote file: 1.4 MB
src/examples/{dragonfly/353bd2bd65.jpg → AI_Generated/elephant (2).png} RENAMED
File without changes
src/examples/AI_Generated/elephant.png ADDED

Git LFS Details

  • SHA256: 173a41c3434c1e90d4c0ba52c335fca122bb0defd796fe3711637c0514bb809d
  • Pointer size: 132 Bytes
  • Size of remote file: 2.75 MB
src/examples/AI_Generated/fox.png ADDED

Git LFS Details

  • SHA256: e765dd8573867e5e9ba800d3be50dfd9212b870ece12f211898598dac6b86da1
  • Pointer size: 132 Bytes
  • Size of remote file: 2.04 MB
src/examples/AI_Generated/goat (2).png ADDED

Git LFS Details

  • SHA256: e1346e0c71f880f274de273ed6eee5e8ad2bc7b2d767560459192d5de9bec8b8
  • Pointer size: 132 Bytes
  • Size of remote file: 1.99 MB
src/examples/AI_Generated/goat.png ADDED

Git LFS Details

  • SHA256: 12610f867fb465645ce55dfdc271787c60b6c521a6c7d2337c10188525c87ada
  • Pointer size: 132 Bytes
  • Size of remote file: 5.51 MB
src/examples/AI_Generated/goldfish.png ADDED

Git LFS Details

  • SHA256: ed6e82dcabb9453a70a7ef1b28db4151f2d8a292fac17b7b130ca9cad9d3b079
  • Pointer size: 132 Bytes
  • Size of remote file: 2.46 MB
src/examples/AI_Generated/jellyfish.png ADDED

Git LFS Details

  • SHA256: 747fd042535093a4b1b68435c756ff6095e4fefbd367fd51b74ca8ccfef60f71
  • Pointer size: 132 Bytes
  • Size of remote file: 1.39 MB
src/examples/AI_Generated/koala.png ADDED

Git LFS Details

  • SHA256: 9ac3ad3beab6456614f297254589507af1df3191db8bfcbad26924c5d996e831
  • Pointer size: 132 Bytes
  • Size of remote file: 1.65 MB
src/examples/AI_Generated/otter.png ADDED

Git LFS Details

  • SHA256: e7d0bc5786dd3d31220810183f56cb4775d9f0be22739479d0c922dd4cc6ee00
  • Pointer size: 132 Bytes
  • Size of remote file: 2.42 MB
src/examples/AI_Generated/panda.png ADDED

Git LFS Details

  • SHA256: be4158fd9cbac2753909630ee15a8933cf41db4986dc3f2c05772abb992aaf51
  • Pointer size: 132 Bytes
  • Size of remote file: 2.28 MB
src/examples/AI_Generated/penguin.png ADDED

Git LFS Details

  • SHA256: a28dab5557db846c0e798790e3d541b8e4ac74401db067326077919cd99b277f
  • Pointer size: 132 Bytes
  • Size of remote file: 5.82 MB
src/examples/AI_Generated/pigeon.png ADDED

Git LFS Details

  • SHA256: d6b9dd55637ba97f1dd3c125d46babf3dc0a68d76f072f72927e9e25a08a9fa1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.36 MB
src/examples/AI_Generated/rabbit.png ADDED

Git LFS Details

  • SHA256: d376c67d72e403102393e685459084d45d6b804caac42348aa712dc646075796
  • Pointer size: 132 Bytes
  • Size of remote file: 2.36 MB
src/examples/AI_Generated/rhinoceros (2).png ADDED

Git LFS Details

  • SHA256: ec2b0e3f1b90e69a6c2a6bc2eab8706848cacb4c69b8b6e063522800639fe7ff
  • Pointer size: 132 Bytes
  • Size of remote file: 2.71 MB
src/examples/AI_Generated/rhinoceros.png ADDED

Git LFS Details

  • SHA256: 2aac1990b1b6a835100ae9f75824fb84f4a8e322aa8e7f73d67f89506c3330d2
  • Pointer size: 132 Bytes
  • Size of remote file: 2.75 MB
src/examples/AI_Generated/snake.png ADDED

Git LFS Details

  • SHA256: 3909dad360ca7a45761f778396180412327b4d3109f7c1a8fa05c5ac9a6218c5
  • Pointer size: 132 Bytes
  • Size of remote file: 1.36 MB
src/examples/AI_Generated/swan.png ADDED

Git LFS Details

  • SHA256: 5b0aa528e07a8a78f3bc4c610a40f46b769e7b540c9f228d4027d3288a2edda1
  • Pointer size: 132 Bytes
  • Size of remote file: 2.1 MB
src/examples/AI_Generated/woodpecker.png ADDED

Git LFS Details

  • SHA256: 98eae49bd0c23630df30be211d899b1cc39fd6877c7dcb8dfc8d1cf1b6f1e8b9
  • Pointer size: 132 Bytes
  • Size of remote file: 2.16 MB
src/examples/antelope/1d556456dc.jpg DELETED

Git LFS Details

  • SHA256: e43ce922cf0ece6f10e00a98452c2cc9c1368098f20a3c283c0a5c55bb8b2aac
  • Pointer size: 131 Bytes
  • Size of remote file: 175 kB
src/examples/badger/0836f4eb45.jpg DELETED

Git LFS Details

  • SHA256: 26653074c2954440e1594da396c3884564d8eb98ec3ae668556d041aa2cf622b
  • Pointer size: 131 Bytes
  • Size of remote file: 119 kB
src/examples/badger/23bfad16a7.jpg DELETED

Git LFS Details

  • SHA256: f6de1f22ef34e536e608734028a18e2f828ad66c588008a2a035d65343358494
  • Pointer size: 130 Bytes
  • Size of remote file: 23.8 kB
src/examples/badger/4c273d12a9.jpg DELETED

Git LFS Details

  • SHA256: d3684ec98d3f1a4524fa2116d34faf0d9002c3b76fc747e25e4e76a92d0f4255
  • Pointer size: 130 Bytes
  • Size of remote file: 11.1 kB
src/examples/badger/5bffbd51cf.jpg DELETED

Git LFS Details

  • SHA256: ddb4f9356747abe204444530b99051cade2920900007cecaa60c67a72c8c76eb
  • Pointer size: 130 Bytes
  • Size of remote file: 99.1 kB
src/examples/badger/87d1db4af3.jpg DELETED

Git LFS Details

  • SHA256: 3cb2a5b662288b8381fd7a8f0d0f4af2b6170a483980a973eedc1a917e4585e7
  • Pointer size: 130 Bytes
  • Size of remote file: 10.9 kB
src/examples/badger/89a8316cd4.jpg DELETED

Git LFS Details

  • SHA256: dbab117c3ea518712a4bea29ca5c7eaa4c272f763f0c15a41abb6cc9dff0c76e
  • Pointer size: 130 Bytes
  • Size of remote file: 14.2 kB
src/examples/badger/99e296bf48.jpg DELETED

Git LFS Details

  • SHA256: 54636f305b57972d7a33926cdbe89d1ea91c6df403e1520df8110d862c40fb4f
  • Pointer size: 131 Bytes
  • Size of remote file: 244 kB
src/examples/bat/16f6af0091.jpg DELETED

Git LFS Details

  • SHA256: 06a98aa308cc1d920924e1212b5059dddf7561306b2dd841191fd9ed6b731c19
  • Pointer size: 130 Bytes
  • Size of remote file: 14.3 kB
src/examples/bat/1dd514de63.jpg DELETED

Git LFS Details

  • SHA256: 50246cfb3ec00dfbe0b596e36498e7c2f4eafbc06227b3032cfddc4ed8108371
  • Pointer size: 131 Bytes
  • Size of remote file: 740 kB
src/examples/bat/1fd53c0b98.jpg DELETED

Git LFS Details

  • SHA256: 54705a491c58f1bf49d67c290fd2e216d6026725442dc113466abfb073df7a05
  • Pointer size: 131 Bytes
  • Size of remote file: 120 kB
src/examples/bat/2d028b789d.jpg DELETED

Git LFS Details

  • SHA256: 72084453d6f52328a0f14c832a78e03b33afab9a82a17fea0ba53212412d1074
  • Pointer size: 131 Bytes
  • Size of remote file: 122 kB
src/examples/bat/2f7c6c7cd5.jpg DELETED

Git LFS Details

  • SHA256: 672966f774809bc0864e0748a7215698f8baf8d1580647784c59da99142814d8
  • Pointer size: 131 Bytes
  • Size of remote file: 559 kB
src/examples/bat/330e4a8053.jpg DELETED

Git LFS Details

  • SHA256: 5b5cc5fedb38de6cd9c6db67f75eb59983fd141b6c314b1839deb831c8a87e09
  • Pointer size: 130 Bytes
  • Size of remote file: 13.4 kB
src/examples/bat/47d2c91d9b.jpg DELETED

Git LFS Details

  • SHA256: 8f27cb329903dd0f3e47a2b343b458970b63b33325bb39946413bd293a8c6c68
  • Pointer size: 130 Bytes
  • Size of remote file: 11.4 kB
src/examples/bat/513bb906a6.jpg DELETED

Git LFS Details

  • SHA256: c445ff1955991389d62c588ed75e4fc9a6428fe1399cfb2a6cbeb56a55c11af4
  • Pointer size: 130 Bytes
  • Size of remote file: 71.4 kB
src/examples/bat/5e85312fa8.jpg DELETED

Git LFS Details

  • SHA256: 127136ab936ee97d0a41447cef34b76cdef7dc416b513580ec06a051d82dc167
  • Pointer size: 131 Bytes
  • Size of remote file: 229 kB
src/examples/bat/6b4b95f0c4.jpg DELETED

Git LFS Details

  • SHA256: 6d8d98088534879a2c073881b378dafc53948610b78395557375d83f04e5b5d3
  • Pointer size: 131 Bytes
  • Size of remote file: 133 kB
src/examples/bat/6da14f603d.jpg DELETED

Git LFS Details

  • SHA256: 6cec4cbc05b24ae81e709348d7ea9d04fdc29532791b559a16b28b80785db3df
  • Pointer size: 129 Bytes
  • Size of remote file: 8.67 kB
src/examples/bat/741fa84ed0.jpg DELETED

Git LFS Details

  • SHA256: b61f463e9f1a9cc91f3eed960b9f8b5beb1fe3221277044104e529869160d272
  • Pointer size: 131 Bytes
  • Size of remote file: 379 kB
src/examples/bear/116d9b7f88.jpg DELETED

Git LFS Details

  • SHA256: 66005dc85e7b14b643892410f48ce9b2c44c1c8f8287dc63e2cca929f937661f
  • Pointer size: 130 Bytes
  • Size of remote file: 14.6 kB