Fixes

- app.py +202 -100
- requirements.txt +0 -0
- requirements_old.txt +0 -0
- src/example_videos/jellyfish_-_110877 (360p).mp4 +0 -3
- src/example_videos/monarch_-_327 (360p).mp4 +0 -3
- src/example_videos/pexels-zlatin-georgiev-5607745 (240p).mp4 +0 -3
- src/example_videos/pexels-zlatin-georgiev-7173031 (240p).mp4 +0 -3
- src/example_videos/pexels_videos_2556839 (240p).mp4 +0 -3
- src/example_videos/squirrel_on_a_wood (360p).mp4 +0 -3
- src/examples/AI_Generated/goat (2).png +0 -3
- src/examples/AI_Generated/koala.png +0 -3
- src/examples/AI_Generated/rabbit.png +0 -3
- src/examples/AI_Generated/rhinoceros.png +0 -3
- src/examples/AI_Generated/swan.png +0 -3
- src/examples/AI_Generated/woodpecker.png +0 -3
- src/examples/false_predicted/boar.jpg +0 -3
- src/examples/false_predicted/dolphin.jpg +0 -3
- src/examples/false_predicted/horse.jpg +0 -3
- src/examples/false_predicted/sparrow.jpg +0 -3
- src/examples/others/Tiger-fuers-Wohnzimmer-In-Hybridkatzen-steckt-ein-Stueck-Wildnis-2.jpg +0 -3
- src/examples/true_predicted/dragonfly.jpg +0 -3
- src/examples/true_predicted/goat.jpg +0 -3
- src/examples/true_predicted/panda.jpg +0 -3
- src/examples/true_predicted/rat.jpg +0 -3
- src/examples/true_predicted/wombat.jpg +0 -3
- src/gradio_blocks.py +36 -19
- src/header.md +7 -5
- src/results/gradcam_video.mp4 +2 -2
- src/{example_videos/butterfly_-_38947 (360p).mp4 → results/infer_image.png} +2 -2
- src/util.py +41 -1
app.py
CHANGED
@@ -1,49 +1,80 @@
 import copy
 import os
 import sys
+
 sys.path.append('src')
+import shutil
 from collections import defaultdict
 from functools import lru_cache
+
+import cv2
 import gradio as gr
-import …
+import mediapy
 import numpy as np
 import pandas as pd
 import torch
 from deep_translator import GoogleTranslator
+from gradio_blocks import build_video_to_camvideo
 from Nets import CustomResNet18
 from PIL import Image
-
-from …
-from …
+
+from pytorch_grad_cam import GradCAM, HiResCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
+from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+from pytorch_grad_cam.utils.image import show_cam_on_image
+
 from tqdm import tqdm
-
-from …
-import cv2
-import ffmpeg
-import shutil
-import mediapy
+import util
+from util import transform, CustomImageCache, imageCacheWrapper
 
+util.ImageCache = CustomImageCache(60, False)
 ffmpeg_path = shutil.which('ffmpeg')
 mediapy.set_ffmpeg(ffmpeg_path)
 
 IMAGE_PATH = os.path.join(os.getcwd(), 'src/examples')
-IMAGES_PER_ROW = …
+IMAGES_PER_ROW = 5
 
 MAXIMAL_FRAMES = 1000
-BATCHES_TO_PROCESS = …
+BATCHES_TO_PROCESS = 20
 OUTPUT_FPS = 10
-MAX_OUT_FRAMES = …
+MAX_OUT_FRAMES = 60
+
+MODEL = CustomResNet18(90).eval()
+MODEL.load_state_dict(torch.load('src/results/models/best_model.pth', map_location=torch.device('cpu')))
 
 CAM_METHODS = {
     "GradCAM": GradCAM,
-    "GradCAM++": …
+    "GradCAM++": GradCAMPlusPlus,
     "XGradCAM": XGradCAM,
-    "…
+    "HiResCAM": HiResCAM,
+    "EigenCAM": EigenCAM
 }
+
+LAYERS = {
+    'layer1': MODEL.resnet.layer1,
+    'layer2': MODEL.resnet.layer2,
+    'layer3': MODEL.resnet.layer3,
+    'layer4': MODEL.resnet.layer4,
+    'all': [MODEL.resnet.layer1, MODEL.resnet.layer2, MODEL.resnet.layer3, MODEL.resnet.layer4],
+    'layer3+4': [MODEL.resnet.layer3, MODEL.resnet.layer4]
+}
+
+CV2_COLORMAPS = {
+    "Autumn": cv2.COLORMAP_AUTUMN,
+    "Bone": cv2.COLORMAP_BONE,
+    "Jet": cv2.COLORMAP_JET,
+    "Winter": cv2.COLORMAP_WINTER,
+    "Rainbow": cv2.COLORMAP_RAINBOW,
+    "Ocean": cv2.COLORMAP_OCEAN,
+    "Summer": cv2.COLORMAP_SUMMER,
+    "Pink": cv2.COLORMAP_PINK,
+    "Hot": cv2.COLORMAP_HOT,
+    "Magma": cv2.COLORMAP_MAGMA,
+    "Inferno": cv2.COLORMAP_INFERNO,
+    "Plasma": cv2.COLORMAP_PLASMA,
+    "Twilight": cv2.COLORMAP_TWILIGHT,
+}
 
-[truncated in source]
-model.load_state_dict(torch.load('src/results/models/best_model.pth', map_location=torch.device('cpu')))
-cam_model = copy.deepcopy(model)
+# cam_model = copy.deepcopy(model)
 data_df = pd.read_csv('src/cache/val_df.csv')
 
 C_NUM_TO_NAME = data_df[['encoded_target', 'target']].drop_duplicates().sort_values('encoded_target').set_index('encoded_target')['target'].to_dict()
@@ -58,16 +89,19 @@ def get_class_idx(name):
 
 @lru_cache(maxsize=100)
 def get_translated(to_translate):
-    # return "ssss"
     return GoogleTranslator(source="en", target="de").translate(to_translate)
 for idx in range(90): get_translated(get_class_name(idx))
 
-def infer_image(image, image_sketch):
-[truncated in source]
+@imageCacheWrapper
+def infer_image(image):
+    if isinstance(image, dict):
+        # It's the image plus a sketch mask, both as PIL images -> combine them into one image
+        image = Image.blend(image["image"], image["mask"], alpha=0.5)
+    image.save('src/results/infer_image.png')
     image = transform(image)
     image = image.unsqueeze(0)
     with torch.no_grad():
-        output = …
+        output = MODEL(image)
         distribution = torch.nn.functional.softmax(output, dim=1)
         ret = defaultdict(float)
         for idx, prob in enumerate(distribution[0]):
@@ -75,32 +109,51 @@ def infer_image(image, image_sketch):
             ret[animal] = prob.item()
     return ret
 
-def gradcam(image, …
-[truncated in source]
-    if layer == 'layer1': layers = [model.resnet.layer1]
-    elif layer == 'layer2': layers = [model.resnet.layer2]
-    elif layer == 'layer3': layers = [model.resnet.layer3]
-    elif layer == 'layer4': layers = [model.resnet.layer4]
-    else: layers = [model.resnet.layer1, model.resnet.layer2, model.resnet.layer3, model.resnet.layer4]
-
-[three lines truncated in source]
-    output = model(img_tensor)
-    class_to_explain = output.squeeze(0).argmax().item() if specific_class == "Predicted Class" else get_class_idx(specific_class)
-    activation_map = cam(class_to_explain, output)
-    result = overlay_mask(image, to_pil_image(activation_map[0].squeeze(0), mode='F'), alpha=alpha)
-    cam.remove_hooks()
-[five lines truncated in source]
+def gradcam(image, colormap="Jet", use_eigen_smooth=False, use_aug_smooth=False, BWHighlight=False, alpha=0.5, cam_method=GradCAM, layer=None, specific_class="Predicted Class"):
+    if image is None:
+        raise gr.Error("Please upload an image.")
+
+    if isinstance(image, dict):
+        # It's the image plus a sketch mask, both as PIL images -> combine them into one image
+        image = Image.blend(image["image"], image["mask"], alpha=0.5)
+
+    if colormap not in CV2_COLORMAPS.keys():
+        raise gr.Error(f"Colormap {colormap} not found in {list(CV2_COLORMAPS.keys())}.")
+    else:
+        colormap = CV2_COLORMAPS[colormap]
+
+    image_width, image_height = image.size
+    if image_width > 4000 or image_height > 4000:
+        raise gr.Error("The image is too big. The maximal size is 4000x4000.")
+
+    MODEL.eval()
+    layers = LAYERS[layer]
+
+    image_tensor = transform(image).unsqueeze(0)
+    targets = [ClassifierOutputTarget(get_class_idx(specific_class))] if specific_class != "Predicted Class" else None
 
+    with CAM_METHODS[cam_method](model=MODEL, target_layers=layers) as cam:
+        grayscale_cam = cam(input_tensor=image_tensor, targets=targets, aug_smooth=use_aug_smooth, eigen_smooth=use_eigen_smooth)
+
+    grayscale_cam = grayscale_cam[0, :]
+    grayscale_cam = cv2.resize(grayscale_cam, (image_width, image_height), interpolation=cv2.INTER_CUBIC)
+    image = np.float32(image)
+    visualization = None
+    if BWHighlight:
+        image = image * grayscale_cam[..., np.newaxis]
+        visualization = image.astype(np.uint8)
+    else:
+        image = image / 255
+        visualization = show_cam_on_image(image, grayscale_cam, use_rgb=True, image_weight=alpha, colormap=colormap)
+    return Image.fromarray(visualization)
 
-def gradcam_video(video, alpha=0.5, cam_method=GradCAM, layer=None, specific_class="Predicted Class"):
+def gradcam_video(video, colormap="Jet", use_eigen_smooth=False, BWHighlight=False, alpha=0.5, cam_method=GradCAM, layer=None, specific_class="Predicted Class"):
     global OUTPUT_FPS, MAXIMAL_FRAMES, BATCHES_TO_PROCESS, MAX_OUT_FRAMES
+    if colormap not in CV2_COLORMAPS.keys():
+        raise gr.Error(f"Colormap {colormap} not found in {list(CV2_COLORMAPS.keys())}.")
+    else:
+        colormap = CV2_COLORMAPS[colormap]
     video = cv2.VideoCapture(video)
    fps = int(video.get(cv2.CAP_PROP_FPS))
     if OUTPUT_FPS == -1: OUTPUT_FPS = fps
@@ -127,36 +180,32 @@ def gradcam_video(video, alpha=0.5, cam_method=GradCAM, layer=None, specific_class="Predicted Class"):
     print(f'Frames to process: {len(frames)}')
 
     processed = [Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) for frame in frames]
-    # generate lists in lists for the images for batch processing.
+    # generate lists in lists for the images for batch processing. BATCHES_TO_PROCESS images per inner list
     batched = [processed[i:i + BATCHES_TO_PROCESS] for i in range(0, len(processed), BATCHES_TO_PROCESS)]
 
-[truncated in source]
-    if layer == 'layer1': layers = [model.resnet.layer1]
-    elif layer == 'layer2': layers = [model.resnet.layer2]
-    elif layer == 'layer3': layers = [model.resnet.layer3]
-    elif layer == 'layer4': layers = [model.resnet.layer4]
-    else: layers = [model.resnet.layer1, model.resnet.layer2, model.resnet.layer3, model.resnet.layer4]
-    cam = CAM_METHODS[cam_method](model, target_layer=layers)
+    MODEL.eval()
+    layers = LAYERS[layer]
     results = list()
-[ten lines truncated in source]
+    targets = [ClassifierOutputTarget(get_class_idx(specific_class))] if specific_class != "Predicted Class" else None
+    with CAM_METHODS[cam_method](model=MODEL, target_layers=layers) as cam:
+        for i, batch in enumerate(tqdm(batched)):
+            images_tensor = torch.stack([transform(image) for image in batch])
+
+            grayscale_cam = cam(input_tensor=images_tensor, targets=targets, aug_smooth=False, eigen_smooth=use_eigen_smooth)
+            for i, image in enumerate(batch):
+                _grayscale_cam = grayscale_cam[i, :]
+                _grayscale_cam = cv2.resize(_grayscale_cam, (width, height), interpolation=cv2.INTER_LINEAR)
+                image = np.float32(image)
+                visualization = None
+                if BWHighlight:
+                    image = image * _grayscale_cam[..., np.newaxis]
+                    visualization = image.astype(np.uint8)
+                else:
+                    image = image / 255
+                    visualization = show_cam_on_image(image, _grayscale_cam, use_rgb=True, image_weight=alpha, colormap=colormap)
+                results.append(visualization)
 
     # save video
-    # fourcc = cv2.VideoWriter_fourcc(*'AVC1')
-    # fourcc = cv2.VideoWriter_fourcc(*'MP4V')
-    # fourcc = cv2.VideoWriter_fourcc(*'XVID')
-    # size = (results[0].shape[1], results[0].shape[0])
-    # video = cv2.VideoWriter('src/results/gradcam_video.mp4', fourcc, OUTPUT_FPS, size)
-    # for frame in results:
-    #     video.write(frame)
     mediapy.write_video('src/results/gradcam_video.mp4', results, fps=OUTPUT_FPS)
     video.release()
     return 'src/results/gradcam_video.mp4'
@@ -190,10 +239,15 @@ def load_examples():
         for j in range(IMAGES_PER_ROW):
             if i * IMAGES_PER_ROW + j >= len(images_to_load): break
             image = images_to_load[i * IMAGES_PER_ROW + j]
+            name = f"{image.split('.')[0]} ({get_translated(image.split('.')[0])})"
+            image = Image.open(os.path.join(full_path, image))
+            # scale so that the longest side is 600px
+            scale = 600 / max(image.size)
+            image = image.resize((int(image.size[0] * scale), int(image.size[1] * scale)))
             loaded_images[image_type].append(
                 gr.Image(
-                    value=…
-                    label=…
+                    value=image,
+                    label=name,
                     type="pil",
                     interactive=False,
                     elem_classes=["selectable_images"],
@@ -224,22 +278,13 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
     # INPUT IMAGE
     # -------------------------------------------
     with gr.Row():
-        with gr.…
-[six lines truncated in source]
-        with gr.Tab("Draw Image"):
-            with gr.Row(variant="panel", equal_height=True):
-                user_image_sketched = gr.Image(
-                    type="pil",
-                    source="canvas",
-                    tool="color-sketch",
-                    label="Draw Your Own Image",
-                    info="You can also draw your own image for prediction.",
-                )
+        with gr.Row(variant="panel", equal_height=True):
+            user_image = gr.Image(
+                type="pil",
+                label="Upload Your Own Image",
+                tool="sketch",
+                interactive=True,
+            )
 
     # -------------------------------------------
     # TOOLS
@@ -257,7 +302,7 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
                 scale=5,
             )
            predict_mode_button = gr.Button(value="Predict Animal", label="Predict", info="Click to make a prediction.", scale=1)
-            predict_mode_button.click(fn=infer_image, inputs=[user_image…
+            predict_mode_button.click(fn=infer_image, inputs=[user_image], outputs=output, queue=True)
 
     # -------------------------------------------
     # EXPLAIN
@@ -265,16 +310,20 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
     with gr.Tab("Explain Image"):
         with gr.Row():
             with gr.Column():
+                _info = "There are different GradCAM methods. You can read more about them here: (https://github.com/jacobgil/pytorch-grad-cam#references)."
                 cam_method = gr.Radio(
                     list(CAM_METHODS.keys()),
                     label="GradCAM Method",
+                    info=_info,
                     value="GradCAM",
                     interactive=True,
                     scale=2,
                 )
-                cam_method.description = "Here you can choose the GradCAM method."
-                cam_method.description_place = "left"
 
+                _info = """
+                The alpha value is used to blend the original image with the GradCAM visualization. If you choose a value of 0.5 the original image and the GradCAM visualization will be blended equally.
+                If you choose a value of 0.1 the original image will be barely visible and if you choose a value of 0.9 the GradCAM visualization will be barely visible.
+                """
                 alpha = gr.Slider(
                     minimum=.1,
                     maximum=.9,
@@ -283,46 +332,99 @@ with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
                     step=.1,
                     label="Alpha",
                     scale=1,
+                    info=_info
                 )
-                alpha.description = "Here you can choose the alpha value."
-                alpha.description_place = "left"
 
+                _info = """
+                The layer is used to choose the layer of the ResNet50 model. The GradCAM visualization will be based on this layer.
+                Best to choose is the last layer (layer4) because it is the layer with the most information before the final prediction. This makes the GradCAM visualization the most meaningful.
+                If all layers are chosen the GradCAM visualization will be averaged over all layers.
+                """
                 layer = gr.Radio(
-                    …
+                    LAYERS.keys(),
                     label="Layer",
                     value="layer4",
                     interactive=True,
                     scale=2,
+                    info=_info
                 )
-                layer.description = "Here you can choose the layer to visualize."
-                layer.description_place = "left"
 
+                _info = """
+                Here you can choose the animal to "explain". If you choose "Predicted Class" the GradCAM visualization will be based on the predicted class.
+                If you choose a specific class the GradCAM visualization will be based on this class.
+                For example if you have an image with a dog and a cat, you can select either Cat or Dog and see if the model can focus on the correct animal.
+                """
                 animal_to_explain = gr.Dropdown(
                     choices=["Predicted Class"] + ALL_CLASSES,
                     label="Animal",
                     value="Predicted Class",
                     interactive=True,
                     scale=2,
+                    info=_info
                 )
-[two lines truncated in source]
+
+                with gr.Row():
+                    _info = """
+                    Here you can choose the colormap. Instead of a colormap you can also choose "BW Highlight" to just keep the original image and highlight the important parts of the image.
+                    If you select "BW Highlight" the colormap will be ignored.
+                    """
+                    colormap = gr.Dropdown(
+                        choices=list(CV2_COLORMAPS.keys()),
+                        label="Colormap",
+                        value="Jet",
+                        interactive=True,
+                        scale=2,
+                        info=_info
+                    )
+
+                    bw_highlight = gr.Checkbox(
+                        label="BW Highlight",
+                        value=False,
+                        interactive=True,
+                        scale=1,
+                    )
+                    bw_highlight.description = "Here you can choose if you want to highlight the important parts of the image in black and white."
+
+                with gr.Row():
+                    _info = """
+                    The Eigen Smooth is a method to smooth the GradCAM visualization.
+                    """
+                    use_eigen_smooth = gr.Checkbox(
+                        label="Eigen Smooth",
+                        value=False,
+                        interactive=True,
+                        scale=1,
+                        info=_info
+                    )
+                    _info = """
+                    The Aug Smooth is also a method to smooth the GradCAM visualization. But this method needs a lot of performance and is therefore slow.
+                    """
+                    use_aug_smooth = gr.Checkbox(
+                        label="Aug Smooth",
+                        value=False,
+                        interactive=True,
+                        scale=1,
+                        info=_info
+                    )
 
             with gr.Column():
                 output_cam = gr.Image(
                     type="pil",
                     label="GradCAM",
-                    info="GradCAM visualization"
+                    info="GradCAM visualization",
+                    scale=5,
                 )
 
            gradcam_mode_button = gr.Button(value="Show GradCAM", label="GradCAM", info="Click to make a prediction.", scale=1)
-            gradcam_mode_button.click(fn=gradcam, inputs=[user_image, …
+            gradcam_mode_button.click(fn=gradcam, inputs=[user_image, colormap, use_eigen_smooth, use_aug_smooth, bw_highlight, alpha, cam_method, layer, animal_to_explain], outputs=output_cam, queue=True)
 
     # -------------------------------------------
     # Video CAM
     # -------------------------------------------
     with gr.Tab("Explain Video"):
-        build_video_to_camvideo(CAM_METHODS, ALL_CLASSES, gradcam_video)
+        build_video_to_camvideo(CAM_METHODS, CV2_COLORMAPS, LAYERS, ALL_CLASSES, gradcam_video)
 
     # -------------------------------------------
     # EXAMPLES
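For reference, the rewritten `gradcam()` follows the standard pytorch_grad_cam recipe: a context-managed CAM object, `ClassifierOutputTarget` for a specific class, and `show_cam_on_image` for the overlay. A minimal self-contained sketch of that flow, with stand-ins that are not from the Space (torchvision's `resnet18` instead of `CustomResNet18`, ImageNet class 281 instead of an animal class, a local `cat.jpg`):

```python
# Minimal sketch of the pytorch_grad_cam flow the new gradcam() uses.
# Stand-ins (not from the Space): resnet18, ImageNet class 281 ("tabby cat"), cat.jpg.
import numpy as np
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet18
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

model = resnet18(weights="IMAGENET1K_V1").eval()
target_layers = [model.layer4]  # analogous to LAYERS['layer4'] in app.py

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

image = Image.open("cat.jpg").convert("RGB").resize((224, 224))
input_tensor = transform(image).unsqueeze(0)

# targets=None would explain the predicted class, like the "Predicted Class" option.
targets = [ClassifierOutputTarget(281)]

# The context manager registers the hooks and removes them on exit -- this is what
# replaces the explicit cam.remove_hooks() of the old torchcam-based code.
with GradCAM(model=model, target_layers=target_layers) as cam:
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]  # (224, 224) in [0, 1]

overlay = show_cam_on_image(np.float32(image) / 255, grayscale_cam, use_rgb=True, image_weight=0.5)
Image.fromarray(overlay).save("gradcam.png")
```

`image_weight` plays the role of the app's alpha slider: it blends the original image against the colormapped heatmap.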
requirements.txt
CHANGED
Binary files a/requirements.txt and b/requirements.txt differ

requirements_old.txt
ADDED
Binary file (4.01 kB).
src/example_videos/jellyfish_-_110877 (360p).mp4
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:53e45f6bd34aaeefdabfeddc9d70a4d7670a183137f2469db42f4f90e73ea296
-size 797977

src/example_videos/monarch_-_327 (360p).mp4
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d0c163fd1ca2ec280c13430f8606d14e8c490980d66b63610ccf3e5af581138e
-size 428449

src/example_videos/pexels-zlatin-georgiev-5607745 (240p).mp4
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:02d82cff5c3c4eb51edd6f07e64e33116cd6889ad45dfb9521aa89406022c539
-size 470993

src/example_videos/pexels-zlatin-georgiev-7173031 (240p).mp4
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:6931230df61b43d063905f96e313ea47ac3b3ca16fd5d749bcd167b07d83ee69
-size 268559

src/example_videos/pexels_videos_2556839 (240p).mp4
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:065667e0d791af82df46c2b7667f17d5a0687fce606fc7e5c216a0e9c3045f76
-size 402447

src/example_videos/squirrel_on_a_wood (360p).mp4
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:9af343138b0a589ccee2c9469f97ce55cbee720ffec9084429f33ae1e37e4f12
-size 2006082
src/examples/AI_Generated/goat (2).png
DELETED
Git LFS Details

src/examples/AI_Generated/koala.png
DELETED
Git LFS Details

src/examples/AI_Generated/rabbit.png
DELETED
Git LFS Details

src/examples/AI_Generated/rhinoceros.png
DELETED
Git LFS Details

src/examples/AI_Generated/swan.png
DELETED
Git LFS Details

src/examples/AI_Generated/woodpecker.png
DELETED
Git LFS Details

src/examples/false_predicted/boar.jpg
DELETED
Git LFS Details

src/examples/false_predicted/dolphin.jpg
DELETED
Git LFS Details

src/examples/false_predicted/horse.jpg
DELETED
Git LFS Details

src/examples/false_predicted/sparrow.jpg
DELETED
Git LFS Details

src/examples/others/Tiger-fuers-Wohnzimmer-In-Hybridkatzen-steckt-ein-Stueck-Wildnis-2.jpg
DELETED
Git LFS Details

src/examples/true_predicted/dragonfly.jpg
DELETED
Git LFS Details

src/examples/true_predicted/goat.jpg
DELETED
Git LFS Details

src/examples/true_predicted/panda.jpg
DELETED
Git LFS Details

src/examples/true_predicted/rat.jpg
DELETED
Git LFS Details

src/examples/true_predicted/wombat.jpg
DELETED
Git LFS Details
src/gradio_blocks.py
CHANGED
@@ -3,12 +3,12 @@ import os
 
 VIDEOS_PER_ROW = 3
 VIDEO_EXAMPLES_PATH = "src/example_videos"
-def build_video_to_camvideo(CAM_METHODS, ALL_CLASSES, gradcam_video):
+def build_video_to_camvideo(CAM_METHODS, CV2_COLORMAPS, LAYERS, ALL_CLASSES, gradcam_video):
     with gr.Row():
         with gr.Column(scale=2):
             gr.Markdown("### Video to GradCAM-Video")
             gr.Markdown("Here you can upload a video and visualize the GradCAM.")
-            gr.Markdown("Please note that this can take a while. Also currently only a maximum of …
+            gr.Markdown("Please note that this can take a while. Also currently only a maximum of 60 frames can be processed. The video will be cut to 60 frames if it is longer. Furthermore, the video can only consist of a maximum of 1000.")
             gr.Markdown("The more frames and fps the video has, the longer it takes to process and the result will be more choppy.")
             video_cam_method = gr.Radio(
                 ["GradCAM", "GradCAM++"],
@@ -17,8 +17,6 @@ def build_video_to_camvideo(CAM_METHODS, ALL_CLASSES, gradcam_video):
                 interactive=True,
                 scale=2,
             )
-            video_cam_method.description = "Here you can choose the GradCAM method."
-            video_cam_method.description_place = "left"
 
             video_alpha = gr.Slider(
                 minimum=.1,
@@ -29,35 +27,54 @@ def build_video_to_camvideo(CAM_METHODS, ALL_CLASSES, gradcam_video):
                 label="Alpha",
                 scale=1,
             )
-            video_alpha.description = "Here you can choose the alpha value."
-            video_alpha.description_place = "left"
 
             video_layer = gr.Radio(
-                …
-[two lines truncated in source]
+                LAYERS.keys(),
+                label="Layer",
+                value="layer4",
                 interactive=True,
                 scale=2,
             )
-            video_layer.description = "Here you can choose the layer to visualize."
-            video_layer.description_place = "left"
 
-            …
-[two lines truncated in source]
+            video_animal_to_explain = gr.Dropdown(
+                choices=["Predicted Class"] + ALL_CLASSES,
+                label="Animal",
                 value="Predicted Class",
                 interactive=True,
-                scale=…
+                scale=2,
             )
-[two lines truncated in source]
+
+            with gr.Row():
+                colormap = gr.Dropdown(
+                    choices=list(CV2_COLORMAPS.keys()),
+                    label="Colormap",
+                    value="Jet",
+                    interactive=True,
+                    scale=2,
+                )
+
+                bw_highlight = gr.Checkbox(
+                    label="BW Highlight",
+                    value=False,
+                    interactive=True,
+                    scale=1,
+                )
+
+            with gr.Row():
+                use_eigen_smooth = gr.Checkbox(
+                    label="Eigen Smooth",
+                    value=False,
+                    interactive=True,
+                    scale=1,
+                )
+
         with gr.Column(scale=1):
             with gr.Column():
                 video_in = gr.Video(autoplay=False, include_audio=False)
                 video_out = gr.Video(autoplay=False, include_audio=False)
 
     gif_cam_mode_button = gr.Button(value="Show GradCAM-Video", label="GradCAM", scale=1)
-    gif_cam_mode_button.click(fn=gradcam_video, inputs=[video_in, video_alpha, video_cam_method, video_layer, video_animal_to_explain], outputs=[video_out], queue=True)
+    gif_cam_mode_button.click(fn=gradcam_video, inputs=[video_in, colormap, use_eigen_smooth, bw_highlight, video_alpha, video_cam_method, video_layer, video_animal_to_explain], outputs=[video_out], queue=True)
 
     with gr.Row():
         with gr.Column():
src/header.md
CHANGED
@@ -2,15 +2,17 @@
 
 This project was created by [Ilyesse](https://github.com/ilyii) and [Gabriel](https://github.com/Gabriel9753) as part of the Explainable Machine Learning module at the [University of Applied Sciences Karlsruhe](https://www.h-ka.de/).
 
-The dataset used in this project is the [Animal Image Dataset](https://www.kaggle.com/datasets/iamsouravbanerjee/animal-image-dataset-90-different-animals) from Kaggle, comprising 90 different animal species that needed to be classified.
+The dataset used in this project is the [Animal Image Dataset](https://www.kaggle.com/datasets/iamsouravbanerjee/animal-image-dataset-90-different-animals) from Kaggle, comprising 90 different animal species that needed to be classified. We also added approx. 1000 AI-generated images across all classes to get a more diverse dataset and to improve the performance of the model.
 
-The employed model is …
+The employed model is ResNet50, which was trained on the dataset using transfer learning techniques.
 Translation of animal names by [deep-translator](https://pypi.org/project/deep-translator/).
 
 ## Usage 🦎
 
-**Predict:** In the "Predict" tab, the model can be applied to …
+**Predict:** In the "Predict" tab, the model can be applied to the uploaded image to obtain a prediction. This is also useful for finding the animal to use in the following explanation.
 
-**Explain:** Under the "Explain" tab, …
+**Explain Image:** Under the "Explain Image" tab, you can get an explanation of the prediction in the form of a generated heatmap. We are using [this](https://github.com/jacobgil/pytorch-grad-cam) cool implementation of Grad-CAM to generate the heatmaps!
 
-**…
+**Explain Video:** The same as above, but for short videos. The video is split into frames, the model is applied to each frame, and the resulting heatmaps are combined into a video again.
+
+**Example Images:** The "Example Images" section allows users to view sample images from the dataset and other sources. Some of the images and videos are from [1](https://www.pexels.com/), [2](https://pixabay.com/) and [3](https://www.bing.com/create).
src/results/gradcam_video.mp4
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:…
-size …
+oid sha256:a9617d53ad717194350c99f6b1d2a172f01e712e4109c76b16fe3f70f32c4570
+size 772080
src/{example_videos/butterfly_-_38947 (360p).mp4 → results/infer_image.png}
RENAMED
File without changes
src/util.py
CHANGED
@@ -4,7 +4,8 @@ from sklearn.preprocessing import LabelEncoder
 from tqdm import tqdm
 from PIL import Image
 import torch
-…
+import imagehash
+ImageCache = None
 
 class AnimalDataset(Dataset):
     def __init__(self, df, transform=None):
@@ -41,3 +42,42 @@ transform = transforms.Compose([
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 ])
 
+class CustomImageCache:
+    def __init__(self, cache_size=50, debug=False):
+        self.cache = dict()
+        self.cache_size = cache_size
+        self.debug = debug
+        self.cache_hits = 0
+        self.cache_misses = 0
+
+    def __getitem__(self, image):
+        if isinstance(image, dict):
+            # It's the image plus a sketch mask, both as PIL images -> combine them into one image
+            image = Image.blend(image["image"], image["mask"], alpha=0.5)
+        key = imagehash.average_hash(image)
+
+        if key in self.cache:
+            if self.debug: print("Cache hit!")
+            self.cache_hits += 1
+            return self.cache[key]
+        else:
+            if self.debug: print("Cache miss!")
+            self.cache_misses += 1
+            if len(self.cache.keys()) >= self.cache_size:
+                if self.debug: print("Cache full, popping item!")
+                self.cache.popitem()
+            self.cache[key] = image
+            return self.cache[key]
+
+    def __len__(self):
+        return len(self.cache.keys())
+
+    def print_info(self):
+        print(f"Cache size: {len(self)}")
+        print(f"Cache hits: {self.cache_hits}")
+        print(f"Cache misses: {self.cache_misses}")
+
+def imageCacheWrapper(fn):
+    def wrapper(image):
+        return fn(ImageCache[image])
+    return wrapper
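A minimal sketch of how these new pieces fit together, mirroring the wiring in app.py. Assumptions: util.py is importable and the imagehash package is installed; `describe` is a hypothetical stand-in for `infer_image`:

```python
# Sketch of the cache wiring (hypothetical describe() in place of infer_image()).
from PIL import Image

import util
from util import CustomImageCache, imageCacheWrapper

util.ImageCache = CustomImageCache(cache_size=60, debug=True)

@imageCacheWrapper
def describe(image):
    # receives the (possibly blended, cache-deduplicated) PIL image
    return image.size

img = Image.new("RGB", (224, 224), "white")
print(describe(img))  # prints "Cache miss!" then (224, 224)
print(describe(img))  # prints "Cache hit!"  then (224, 224)
```

Note the design choice: the cache memoizes the normalized input image keyed by its perceptual average hash, not the model output, so near-duplicate uploads skip the dict/mask blending step while the wrapped function still runs on every call.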