Add new classes and features
- app.py +109 -57
- src/Nets.py +1 -38
- src/Roboto-Regular.ttf +0 -0
- src/cache/val_df.csv +0 -0
- src/examples/{false_predicted/squirrel.jpg → false/bee.jpg} +2 -2
- src/examples/{false_predicted/chimpanzee.jpg → false/coyote.jpg} +2 -2
- src/examples/{true_predicted/cat.jpg → false/donkey.jpg} +2 -2
- src/examples/false/goat.jpg +3 -0
- src/examples/false/hornbill.jpg +3 -0
- src/examples/false_predicted/starfish.jpg +0 -3
- src/examples/true/dolphin.jpg +3 -0
- src/examples/true/dragonfly.jpg +3 -0
- src/examples/{false_predicted → true}/koala.jpg +2 -2
- src/examples/{false_predicted → true}/sheep.jpg +2 -2
- src/examples/true/squid.jpg +3 -0
- src/examples/true_predicted/cockroach.jpg +0 -3
- src/examples/true_predicted/flamingo.jpg +0 -3
- src/examples/true_predicted/gorilla.jpg +0 -3
- src/examples/true_predicted/grasshopper.jpg +0 -3
- src/gradio_blocks.py +2 -2
- src/header.md +2 -2
- src/results/gradcam_video.mp4 +2 -2
- src/results/infer_image.png +2 -2
- src/results/models/best_model.pth +2 -2
- src/util.py +1 -75
app.py
CHANGED
@@ -1,3 +1,4 @@
+from concurrent.futures import ThreadPoolExecutor
 import copy
 import os
 import sys
@@ -16,31 +17,43 @@ import torch
from deep_translator import GoogleTranslator
from gradio_blocks import build_video_to_camvideo
from Nets import CustomResNet18
-from PIL import Image
+from PIL import Image, ImageDraw, ImageFont

from pytorch_grad_cam import GradCAM, HiResCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

from tqdm import tqdm
-import
-
+from util import transform
+
+font = ImageFont.truetype("src/Roboto-Regular.ttf", 16)

-util.ImageCache = CustomImageCache(60, False)
ffmpeg_path = shutil.which('ffmpeg')
mediapy.set_ffmpeg(ffmpeg_path)

IMAGE_PATH = os.path.join(os.getcwd(), 'src/examples')
IMAGES_PER_ROW = 5

-MAXIMAL_FRAMES =
-BATCHES_TO_PROCESS =
+MAXIMAL_FRAMES = 700
+BATCHES_TO_PROCESS = 20
OUTPUT_FPS = 10
-MAX_OUT_FRAMES =
+MAX_OUT_FRAMES = 70

-MODEL = CustomResNet18(
+MODEL = CustomResNet18(111).eval()
MODEL.load_state_dict(torch.load('src/results/models/best_model.pth', map_location=torch.device('cpu')))

+LANGUAGES_TO_SELECT = {
+    "None": None,
+    "German": "de",
+    "French": "fr",
+    "Spanish": "es",
+    "Italian": "it",
+    "Finnish": "fi",
+    "Ukrainian": "uk",
+    "Japanese": "ja",
+    "Hebrew": "iw"
+}
+
CAM_METHODS = {
    "GradCAM": GradCAM,
    "GradCAM++": GradCAMPlusPlus,
@@ -87,16 +100,21 @@ def get_class_name(idx):
def get_class_idx(name):
    return C_NAME_TO_NUM[name]

-@lru_cache(maxsize=
-def get_translated(to_translate):
-
-
+@lru_cache(maxsize=len(LANGUAGES_TO_SELECT.keys())*111)
+def get_translated(to_translate, target_language="German"):
+    target_language = LANGUAGES_TO_SELECT[target_language] if target_language in LANGUAGES_TO_SELECT else target_language
+    if target_language == "en": return to_translate
+    if target_language not in LANGUAGES_TO_SELECT.values(): raise gr.Error(f'Language {target_language} not found.')
+    return GoogleTranslator(source="en", target=target_language).translate(to_translate)
+# for idx in range(111): get_translated(get_class_name(idx))
+with ThreadPoolExecutor(max_workers=30) as executor:
+    # give the executor the list of images and args (in this case, the target language)
+    # and let the executor map the function to the list of images
+    for language in tqdm(LANGUAGES_TO_SELECT.keys(), desc='Preloading translations'):
+        executor.map(get_translated, ALL_CLASSES, [language] * len(ALL_CLASSES))

-
-
-    if isinstance(image, dict):
-        # Its the image and a mask as pillow both -> Combine them to one image
-        image = Image.blend(image["image"], image["mask"], alpha=0.5)
+def infer_image(image, target_language):
+    if image is None: raise gr.Error("Please upload an image.")
    image.save('src/results/infer_image.png')
    image = transform(image)
    image = image.unsqueeze(0)
@@ -105,11 +123,13 @@ def infer_image(image):
    distribution = torch.nn.functional.softmax(output, dim=1)
    ret = defaultdict(float)
    for idx, prob in enumerate(distribution[0]):
-        animal = f'{get_class_name(idx)}
+        animal = f'{get_class_name(idx)}'
+        if target_language is not None and target_language != "None":
+            animal += f' ({get_translated(get_class_name(idx), target_language)})'
        ret[animal] = prob.item()
    return ret

-def gradcam(image, colormap="Jet", use_eigen_smooth=False, use_aug_smooth=False, BWHighlight=False, alpha=0.5, cam_method=GradCAM, layer=None, specific_class="Predicted Class"):
+def gradcam(image, colormap="Jet", use_eigen_smooth=False, use_aug_smooth=False, BWHighlight=False, alpha=0.5, cam_method=GradCAM, layer=None, specific_class="Predicted Class", label_image=True, target_lang="German"):
    if image is None:
        raise gr.Error("Please upload an image.")

@@ -123,8 +143,8 @@ def gradcam(image, colormap="Jet", use_eigen_smooth=False, use_aug_smooth=False,
    colormap = CV2_COLORMAPS[colormap]

    image_width, image_height = image.size
-    if image_width >
-        raise gr.Error("The image is too big. The maximal size is
+    if image_width > 6000 or image_height > 6000:
+        raise gr.Error("The image is too big. The maximal size is 6000x6000.")


    MODEL.eval()
@@ -135,6 +155,8 @@ def gradcam(image, colormap="Jet", use_eigen_smooth=False, use_aug_smooth=False,

    with CAM_METHODS[cam_method](model=MODEL, target_layers=layers) as cam:
        grayscale_cam = cam(input_tensor=image_tensor, targets=targets, aug_smooth=use_aug_smooth, eigen_smooth=use_eigen_smooth)
+        if label_image:
+            predicted_animal = get_class_name(np.argmax(cam.outputs.cpu().data.numpy(), axis=-1)[0])

    grayscale_cam = grayscale_cam[0, :]
    grayscale_cam = cv2.resize(grayscale_cam, (image_width, image_height), interpolation=cv2.INTER_CUBIC)
@@ -146,10 +168,25 @@ def gradcam(image, colormap="Jet", use_eigen_smooth=False, use_aug_smooth=False,
    else:
        image = image / 255
    visualization = show_cam_on_image(image, grayscale_cam, use_rgb=True, image_weight=alpha, colormap=colormap)
-
+
+    if label_image:
+        # add alpha channel to visualization
+        visualization = np.concatenate([visualization, np.ones((image_height, image_width, 1), dtype=np.uint8) * 255], axis=-1)
+        plt_image = Image.fromarray(visualization, mode="RGBA")
+        draw = ImageDraw.Draw(plt_image)
+        draw.rectangle((5, 5, 150, 30), fill=(10, 10, 10, 100))
+        animal = predicted_animal.capitalize()
+        if target_lang is not None and target_lang != "None":
+            animal += f' ({get_translated(animal, target_lang)})'
+        draw.text((10, 7), animal, font=font, fill=(255, 125, 0, 255))
+        visualization = np.array(plt_image)
+
+    out_image = Image.fromarray(visualization)
+    return out_image

def gradcam_video(video, colormap="Jet", use_eigen_smooth=False, BWHighlight=False, alpha=0.5, cam_method=GradCAM, layer=None, specific_class="Predicted Class"):
    global OUTPUT_FPS, MAXIMAL_FRAMES, BATCHES_TO_PROCESS, MAX_OUT_FRAMES
+    if video is None: raise gr.Error("Please upload a video.")
    if colormap not in CV2_COLORMAPS.keys():
        raise gr.Error(f"Colormap {colormap} not found in {list(CV2_COLORMAPS.keys())}.")
    else:
@@ -159,8 +196,8 @@ def gradcam_video(video, colormap="Jet", use_eigen_smooth=False, BWHighlight=Fal
    if OUTPUT_FPS == -1: OUTPUT_FPS = fps
    width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
-    if width >
-        raise gr.Error("The video is too big. The maximal size is
+    if width > 2000 or height > 2000:
+        raise gr.Error("The video is too big. The maximal size is 2000x2000.")
    print(f'FPS: {fps}, Width: {width}, Height: {height}')

    frames = list()
@@ -213,21 +250,21 @@ def gradcam_video(video, colormap="Jet", use_eigen_smooth=False, BWHighlight=Fal
def load_examples():
    folder_name_to_header = {
        "AI_Generated": "AI Generated Images",
-        "
-        "
+        "true": "True Predicted Images (Validation Set)",
+        "false": "False Predicted Images (Validation Set)",
        "others": "Other interesting images from the internet"
    }

    images_description = {
        "AI_Generated": "These images are generated by Dalle3 and Stable Diffusion. All of them are not real images and because of that it is interesting to see how the model predicts them.",
-        "
-        "
+        "true": "These images are from the validation set and the model predicted them correctly.",
+        "false": "These images are from the validation set and the model predicted them incorrectly. Maybe you can see why the model predicted them incorrectly using the GradCAM visualization. :)",
        "others": "These images are from the internet and are not part of the validation set. They are interesting because most of them show different animals."
    }

    loaded_images = defaultdict(list)

-    for image_type in ["AI_Generated", "
+    for image_type in ["AI_Generated", "true", "false", "others"]:
    # for image_type in os.listdir(IMAGE_PATH):
        full_path = os.path.join(IMAGE_PATH, image_type).replace('\\', '/').replace('//', '/')
        gr.Markdown(f'## {folder_name_to_header[image_type]}')
@@ -239,7 +276,7 @@ def load_examples():
            for j in range(IMAGES_PER_ROW):
                if i * IMAGES_PER_ROW + j >= len(images_to_load): break
                image = images_to_load[i * IMAGES_PER_ROW + j]
-                name = f"{image.split('.')[0]}
+                name = f"{image.split('.')[0]}"
                image = Image.open(os.path.join(full_path, image))
                # scale so that the longest side is 600px
                scale = 600 / max(image.size)
@@ -273,7 +310,15 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
    with gr.Column(scale=1):
        pil_logo = Image.open('animals.png')
        logo = gr.Image(value=pil_logo, scale=2, interactive=False, show_download_button=False, show_label=False, container=False, elem_id="logo")
-
+
+        animal_translation_target_language = gr.Dropdown(
+            choices=LANGUAGES_TO_SELECT.keys(),
+            label="Translation language for animals",
+            value="German",
+            interactive=True,
+            scale=2,
+        )
+
    # -------------------------------------------
    # INPUT IMAGE
    # -------------------------------------------
@@ -282,7 +327,6 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
    user_image = gr.Image(
        type="pil",
        label="Upload Your Own Image",
-        tool="sketch",
        interactive=True,
    )

@@ -301,8 +345,9 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
        info="Top three predicted classes and their confidences.",
        scale=5,
    )
-
-
+    with gr.Row():
+        predict_mode_button = gr.Button(value="Predict Animal", label="Predict", info="Click to make a prediction.", scale=6)
+        predict_mode_button.click(fn=infer_image, inputs=[user_image, animal_translation_target_language], outputs=output, queue=True)

    # -------------------------------------------
    # EXPLAIN
@@ -348,20 +393,28 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
        scale=2,
        info=_info
    )
+    with gr.Row():

-
-
-
-
-
-
-
-
-
-
-
-
-
+        _info = """
+        Here you can choose the animal to "explain". If you choose "Predicted Class" the GradCAM visualization will be based on the predicted class.
+        If you choose a specific class the GradCAM visualization will be based on this class.
+        For example if you have an image with a dog and a cat, you can select either Cat or Dog and see if the model can focus on the correct animal.
+        """
+        animal_to_explain = gr.Dropdown(
+            choices=["Predicted Class"] + ALL_CLASSES,
+            label="Animal",
+            value="Predicted Class",
+            interactive=True,
+            scale=4,
+            info=_info
+        )
+
+        show_predicted_class = gr.Checkbox(
+            label="Show Predicted Class",
+            value=True,
+            interactive=True,
+            scale=1,
+        )

    with gr.Row():
        _info = """
@@ -371,7 +424,7 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
        colormap = gr.Dropdown(
            choices=list(CV2_COLORMAPS.keys()),
            label="Colormap",
-            value="
+            value="Inferno",
            interactive=True,
            scale=2,
            info=_info
@@ -410,15 +463,16 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:


    with gr.Column():
+        gradcam_mode_button = gr.Button(value="Show GradCAM", label="GradCAM", info="Click to make a prediction.", scale=1)
        output_cam = gr.Image(
            type="pil",
            label="GradCAM",
            info="GradCAM visualization",
-
+            show_label=False,
+            scale=7,
        )
-
-        gradcam_mode_button
-        gradcam_mode_button.click(fn=gradcam, inputs=[user_image, colormap, use_eigen_smooth, use_aug_smooth, bw_highlight, alpha, cam_method, layer, animal_to_explain], outputs=output_cam, queue=True)
+        _inputs = [user_image, colormap, use_eigen_smooth, use_aug_smooth, bw_highlight, alpha, cam_method, layer, animal_to_explain, show_predicted_class, animal_translation_target_language]
+        gradcam_mode_button.click(fn=gradcam, inputs=_inputs, outputs=output_cam, queue=True)

    # -------------------------------------------
    # Video CAM
@@ -434,11 +488,9 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
    loaded_images = load_examples()
    for k in loaded_images.keys():
        for image in loaded_images[k]:
-            image.select(fn=lambda x: x, inputs=[image], outputs=[user_image])
-
-
-
+            image.select(fn=lambda x: x, inputs=[image], outputs=[user_image], queue=True, scroll_to_output=True)

if __name__ == "__main__":
    demo.queue()
-
+    print("Starting Gradio server...")
+    demo.launch(show_tips=True)
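For reference, the translation feature added above boils down to memoizing deep-translator calls. A minimal sketch of that pattern (the helper name and example word below are illustrative, not taken from the repo):

from functools import lru_cache
from deep_translator import GoogleTranslator

# Memoize translations so each (name, language) pair hits the Google endpoint only once.
@lru_cache(maxsize=None)
def translate_name(name: str, lang: str = "de") -> str:
    return GoogleTranslator(source="en", target=lang).translate(name)

print(translate_name("squirrel", "de"))  # first call goes over the network
print(translate_name("squirrel", "de"))  # second call is served from the cache

app.py applies the same idea with a bounded lru_cache and pre-warms it for all 111 class names per language through a ThreadPoolExecutor at startup.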
src/Nets.py
CHANGED
@@ -1,47 +1,10 @@
-import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from torchvision import models

-class SimpleCNN(nn.Module):
-    def __init__(self, k_size=3, pool_size=2, num_classes=1):
-        super(SimpleCNN, self).__init__()
-        self.relu = nn.ReLU()
-        # First Convolutional Layer
-        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=k_size, padding=1)
-        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=k_size, stride=1, padding=1)
-        self.pool1 = nn.MaxPool2d(kernel_size=pool_size)
-
-        # Second Convolutional Layer
-        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=k_size, stride=1, padding=1)
-        self.conv4 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=k_size, stride=1, padding=1)
-        self.pool2 = nn.MaxPool2d(kernel_size=pool_size)
-
-        self.conv5 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=k_size, stride=1, padding=1)
-        self.conv6 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=k_size, stride=1, padding=1)
-        self.pool3 = nn.MaxPool2d(kernel_size=pool_size)
-
-        self.conv7 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=k_size, stride=1, padding=1)
-        self.conv8 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=k_size, stride=1, padding=1)
-        self.pool4 = nn.MaxPool2d(kernel_size=pool_size)
-
-        # Fully Connected Layers
-        self.fc = nn.Linear(64*14*14, num_classes) # Adjust the input features based on your input image size
-
-    def forward(self, x):
-        x = self.pool1(self.relu(self.conv2(self.relu(self.conv1(x)))))
-        x = self.pool2(self.relu(self.conv4(self.relu(self.conv3(x)))))
-        x = self.pool3(self.relu(self.conv6(self.relu(self.conv5(x)))))
-        x = self.pool4(self.relu(self.conv8(self.relu(self.conv7(x)))))
-        # print(x.shape)
-        x = x.view(x.size(0), -1)
-        x = self.fc(x)
-        return x
-
 class CustomResNet18(nn.Module):
     def __init__(self, num_classes=11):
         super(CustomResNet18, self).__init__()
-        self.resnet = models.
+        self.resnet = models.resnet18(pretrained=True)
         num_features = self.resnet.fc.in_features
         self.resnet.fc = nn.Linear(num_features, num_classes)

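The trimmed Nets.py keeps only the ResNet18 wrapper that app.py loads and hands to pytorch_grad_cam. A condensed, hypothetical stand-alone version of that flow (checkpoint and example paths follow the repo; the output filename is arbitrary, and layer4 mirrors the app's default target layer):

import numpy as np
import torch
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from Nets import CustomResNet18
from util import transform

# Load the 111-class checkpoint and run Grad-CAM on the last ResNet block.
model = CustomResNet18(111).eval()
model.load_state_dict(torch.load('src/results/models/best_model.pth', map_location='cpu'))

image = Image.open('src/examples/true/koala.jpg').convert('RGB').resize((224, 224))
input_tensor = transform(image).unsqueeze(0)

with GradCAM(model=model, target_layers=[model.resnet.layer4]) as cam:
    grayscale_cam = cam(input_tensor=input_tensor)[0, :]  # no explicit targets -> explain the predicted class

rgb = np.asarray(image, dtype=np.float32) / 255.0
overlay = show_cam_on_image(rgb, grayscale_cam, use_rgb=True)
Image.fromarray(overlay).save('gradcam_koala.png')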
src/Roboto-Regular.ttf
ADDED
Binary file (515 kB).
src/cache/val_df.csv
CHANGED
The diff for this file is too large to render.
src/examples/{false_predicted/squirrel.jpg → false/bee.jpg}
RENAMED
File without changes

src/examples/{false_predicted/chimpanzee.jpg → false/coyote.jpg}
RENAMED
File without changes

src/examples/{true_predicted/cat.jpg → false/donkey.jpg}
RENAMED
File without changes

src/examples/false/goat.jpg
ADDED
Git LFS image file

src/examples/false/hornbill.jpg
ADDED
Git LFS image file

src/examples/false_predicted/starfish.jpg
DELETED
Git LFS image file

src/examples/true/dolphin.jpg
ADDED
Git LFS image file

src/examples/true/dragonfly.jpg
ADDED
Git LFS image file

src/examples/{false_predicted → true}/koala.jpg
RENAMED
File without changes

src/examples/{false_predicted → true}/sheep.jpg
RENAMED
File without changes

src/examples/true/squid.jpg
ADDED
Git LFS image file

src/examples/true_predicted/cockroach.jpg
DELETED
Git LFS image file

src/examples/true_predicted/flamingo.jpg
DELETED
Git LFS image file

src/examples/true_predicted/gorilla.jpg
DELETED
Git LFS image file

src/examples/true_predicted/grasshopper.jpg
DELETED
Git LFS image file
src/gradio_blocks.py
CHANGED
@@ -29,7 +29,7 @@ def build_video_to_camvideo(CAM_METHODS, CV2_COLORMAPS, LAYERS, ALL_CLASSES, gra
    )

    video_layer = gr.Radio(
-
+        [f"layer{i}" for i in range(1, 5)],
        label="Layer",
        value="layer4",
        interactive=True,
@@ -48,7 +48,7 @@ def build_video_to_camvideo(CAM_METHODS, CV2_COLORMAPS, LAYERS, ALL_CLASSES, gra
    colormap = gr.Dropdown(
        choices=list(CV2_COLORMAPS.keys()),
        label="Colormap",
-        value="
+        value="Inferno",
        interactive=True,
        scale=2,
    )
src/header.md
CHANGED
@@ -2,9 +2,9 @@

 This project was created by [Ilyesse](https://github.com/ilyii) and [Gabriel](https://github.com/Gabriel9753) as part of the Explainable Machine Learning module at the [University of Applied Sciences Karlsruhe](https://www.h-ka.de/).

-The dataset used in this project is the [Animal Image Dataset](https://www.kaggle.com/datasets/iamsouravbanerjee/animal-image-dataset-90-different-animals) from Kaggle, comprising 90 different animal species that needed to be classified. We also added approx. 1000 AI generated images for all classes to get a more diverse dataset and also improve the performance of the model.
+The dataset used in this project is the [Animal Image Dataset](https://www.kaggle.com/datasets/iamsouravbanerjee/animal-image-dataset-90-different-animals) from Kaggle, comprising 90 different animal species that needed to be classified. To add a little more animals to the data, we added an additional 21 unique classes, so we were now working with our own 111-animals dataset. We also added approx. 1000 AI generated images for all classes to get a more diverse dataset and also improve the performance of the model.

-The employed model is
+The employed model is ResNet18, which was trained on the dataset using transfer learning techniques.
 Translation of animal names by [deep-translator](https://pypi.org/project/deep-translator/).

 ## Usage
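The "transfer learning" mentioned here is the standard torchvision recipe that CustomResNet18 implements: reuse ImageNet weights and retrain a new 111-way classifier head. A rough sketch with illustrative hyperparameters and dummy data (neither is taken from the repo):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchvision import models

# ImageNet-pretrained backbone with the classifier head replaced for 111 animal classes.
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 111)

# Dummy batch just to make the sketch runnable; the real data is the 111-animals dataset.
loader = DataLoader(TensorDataset(torch.randn(8, 3, 224, 224), torch.randint(0, 111, (8,))), batch_size=4)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # assumed settings, not those behind best_model.pth
criterion = nn.CrossEntropyLoss()

model.train()
for images, labels in loader:
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()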
src/results/gradcam_video.mp4
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:d88ec14ff35116bf5d8bd65454616aba242d8f79bde4dcbd717aabbcc910670a
+size 917687
src/results/infer_image.png
CHANGED
Git LFS image file
src/results/models/best_model.pth
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:e6a3f852efacebef8dee4ba74c0a73a7f33bf2180c4272dbf233a5c6157d7531
+size 45015274
src/util.py
CHANGED
@@ -1,83 +1,9 @@
 import torchvision.transforms as transforms
-from torch.utils.data import DataLoader, Dataset
-from sklearn.preprocessing import LabelEncoder
-from tqdm import tqdm
-from PIL import Image
 import torch
-import imagehash
-ImageCache = None

-class AnimalDataset(Dataset):
-    def __init__(self, df, transform=None):
-        self.paths = df["path"].values
-        self.targets = df["target"].values
-        self.encoded_target = df['encoded_target'].values
-        self.transform = transform
-        self.images = []
-        for path in tqdm(self.paths):
-            self.images.append(Image.open(path).convert("RGB").resize((224, 224)))
-
-    def __len__(self):
-        return len(self.paths)
-
-    def __getitem__(self, idx):
-        img = self.images[idx]
-        if self.transform:
-            img = self.transform(img)
-        target = self.targets[idx]
-        encoded_target = torch.tensor(self.encoded_target[idx]).type(torch.LongTensor)
-        return img, encoded_target, target
-
-train_transform = transforms.Compose([
-    transforms.Resize((224,224)),
-    transforms.RandomHorizontalFlip(),
-    transforms.RandomRotation(10),
-    transforms.ToTensor(),
-    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-])
 # Define the transformation pipeline
 transform = transforms.Compose([
     transforms.Resize((224,224)),
     transforms.ToTensor(), # Convert the images to PyTorch tensors
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-])
-
-class CustomImageCache:
-    def __init__(self, cache_size=50, debug=False):
-        self.cache = dict()
-        self.cache_size = 50
-        self.debug = debug
-        self.cache_hits = 0
-        self.cache_misses = 0
-
-    def __getitem__(self, image):
-        if isinstance(image, dict):
-            # Its the image and a mask as pillow both -> Combine them to one image
-            image = Image.blend(image["image"], image["mask"], alpha=0.5)
-        key = imagehash.average_hash(image)
-
-        if key in self.cache:
-            if self.debug: print("Cache hit!")
-            self.cache_hits += 1
-            return self.cache[key]
-        else:
-            if self.debug: print("Cache miss!")
-            self.cache_misses += 1
-            if len(self.cache.keys()) >= self.cache_size:
-                if self.debug: print("Cache full, popping item!")
-                self.cache.popitem()
-            self.cache[key] = image
-            return self.cache[key]
-
-    def __len__(self):
-        return len(self.cache.keys())
-
-    def print_info(self):
-        print(f"Cache size: {len(self)}")
-        print(f"Cache hits: {self.cache_hits}")
-        print(f"Cache misses: {self.cache_misses}")
-
-def imageCacheWrapper(fn):
-    def wrapper(image):
-        return fn(ImageCache[image])
-    return wrapper
+])
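After this commit, util.py exposes only the inference transform that app.py imports. A hypothetical end-to-end prediction using it, mirroring infer_image() (the example image path is one of the files added in this commit):

import torch
from PIL import Image
from Nets import CustomResNet18
from util import transform  # Resize(224) + ToTensor + ImageNet normalization

model = CustomResNet18(111)
model.load_state_dict(torch.load('src/results/models/best_model.pth', map_location='cpu'))
model.eval()

image = Image.open('src/examples/true/dolphin.jpg').convert('RGB')
batch = transform(image).unsqueeze(0)              # shape: (1, 3, 224, 224)
with torch.no_grad():
    probs = torch.nn.functional.softmax(model(batch), dim=1)[0]
print(probs.topk(3))                               # top-3 class probabilities and indices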