from concurrent.futures import ThreadPoolExecutor
import os
import sys
sys.path.append('src')  # make the local modules in src/ importable
import shutil
from collections import defaultdict
from functools import lru_cache
import cv2
import gradio as gr
import mediapy
import numpy as np
import pandas as pd
import torch
from deep_translator import GoogleTranslator
from gradio_blocks import build_video_to_camvideo
from Nets import CustomResNet18
from PIL import Image, ImageDraw, ImageFont
from pytorch_grad_cam import GradCAM, HiResCAM, GradCAMPlusPlus, XGradCAM, EigenCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from tqdm import tqdm
from util import transform
font = ImageFont.truetype("src/Roboto-Regular.ttf", 16)
ffmpeg_path = shutil.which('ffmpeg')
mediapy.set_ffmpeg(ffmpeg_path)
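# Global configuration: example-gallery layout and limits for video processing.
# MAXIMAL_FRAMES rejects overly long uploads, MAX_OUT_FRAMES caps how many
# frames are actually run through GradCAM, and BATCHES_TO_PROCESS is the batch
# size used when feeding frames to the model.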
IMAGE_PATH = os.path.join(os.getcwd(), 'src/examples')
IMAGES_PER_ROW = 5
MAXIMAL_FRAMES = 700
BATCHES_TO_PROCESS = 20
OUTPUT_FPS = 10
MAX_OUT_FRAMES = 70
MODEL = CustomResNet18(111).eval()
MODEL.load_state_dict(torch.load('src/results/models/best_model.pth', map_location=torch.device('cpu')))
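# The model consumes batches of images preprocessed by `transform` (from util)
# and returns raw logits over the 111 animal classes.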
LANGUAGES_TO_SELECT = {
    "None": None,
    "German": "de",
    "French": "fr",
    "Spanish": "es",
    "Italian": "it",
    "Finnish": "fi",
    "Ukrainian": "uk",
    "Japanese": "ja",
    "Hebrew": "iw"  # Google Translate still expects the legacy ISO code "iw" for Hebrew
}
CAM_METHODS = {
    "GradCAM": GradCAM,
    "GradCAM++": GradCAMPlusPlus,
    "XGradCAM": XGradCAM,
    "HiResCAM": HiResCAM,
    "EigenCAM": EigenCAM
}
# pytorch-grad-cam expects `target_layers` as a list of modules, so single
# layers are wrapped in one-element lists.
LAYERS = {
    'layer1': [MODEL.resnet.layer1],
    'layer2': [MODEL.resnet.layer2],
    'layer3': [MODEL.resnet.layer3],
    'layer4': [MODEL.resnet.layer4],
    'all': [MODEL.resnet.layer1, MODEL.resnet.layer2, MODEL.resnet.layer3, MODEL.resnet.layer4],
    'layer3+4': [MODEL.resnet.layer3, MODEL.resnet.layer4]
}
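# When several target layers are given (e.g. 'all' or 'layer3+4'),
# pytorch-grad-cam computes a CAM per layer and averages them into a single heatmap.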
CV2_COLORMAPS = {
    "Autumn": cv2.COLORMAP_AUTUMN,
    "Bone": cv2.COLORMAP_BONE,
    "Jet": cv2.COLORMAP_JET,
    "Winter": cv2.COLORMAP_WINTER,
    "Rainbow": cv2.COLORMAP_RAINBOW,
    "Ocean": cv2.COLORMAP_OCEAN,
    "Summer": cv2.COLORMAP_SUMMER,
    "Pink": cv2.COLORMAP_PINK,
    "Hot": cv2.COLORMAP_HOT,
    "Magma": cv2.COLORMAP_MAGMA,
    "Inferno": cv2.COLORMAP_INFERNO,
    "Plasma": cv2.COLORMAP_PLASMA,
    "Twilight": cv2.COLORMAP_TWILIGHT,
}
data_df = pd.read_csv('src/cache/val_df.csv')
# Build the mapping between encoded class indices and animal names from the
# validation dataframe.
C_NUM_TO_NAME = (
    data_df[['encoded_target', 'target']]
    .drop_duplicates()
    .sort_values('encoded_target')
    .set_index('encoded_target')['target']
    .to_dict()
)
C_NAME_TO_NUM = {v: k for k, v in C_NUM_TO_NAME.items()}
ALL_CLASSES = sorted(C_NUM_TO_NAME.values(), key=lambda x: x.lower())
def get_class_name(idx):
    return C_NUM_TO_NAME[idx]
def get_class_idx(name):
    return C_NAME_TO_NUM[name]
@lru_cache(maxsize=None)  # cache translations; the preloading pass below warms this cache
def get_translated(to_translate, target_language="German"):
    # Accept either a display name ("German") or an ISO code ("de").
    target_language = LANGUAGES_TO_SELECT.get(target_language, target_language)
    if target_language is None or target_language == "en":
        return to_translate
    if target_language not in LANGUAGES_TO_SELECT.values():
        raise gr.Error(f'Language {target_language} not found.')
    return GoogleTranslator(source="en", target=target_language).translate(to_translate)
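# Example: get_translated("dog", "German") returns "Hund"; with "None" or
# English as the target the text is returned unchanged.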
# Warm the translation cache for every language up front so the UI does not
# block on the translation API later. Consuming the map() iterator with list()
# forces the calls to complete and surfaces any errors.
with ThreadPoolExecutor(max_workers=30) as executor:
    for language in tqdm(LANGUAGES_TO_SELECT.keys(), desc='Preloading translations'):
        list(executor.map(get_translated, ALL_CLASSES, [language] * len(ALL_CLASSES)))
def infer_image(image, target_language):
    if image is None: raise gr.Error("Please upload an image.")
    image.save('src/results/infer_image.png')
    image = transform(image)
    image = image.unsqueeze(0)
    with torch.no_grad():
        output = MODEL(image)
    distribution = torch.nn.functional.softmax(output, dim=1)
    ret = defaultdict(float)
    for idx, prob in enumerate(distribution[0]):
        animal = get_class_name(idx)
        if target_language is not None and target_language != "None":
            animal += f' ({get_translated(get_class_name(idx), target_language)})'
        ret[animal] = prob.item()
    return ret
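# infer_image returns a {label: probability} dict; gr.Label renders it as the
# top predicted classes with confidence bars.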
def gradcam(image, colormap="Jet", use_eigen_smooth=False, use_aug_smooth=False, BWHighlight=False, alpha=0.5, cam_method="GradCAM", layer="layer4", specific_class="Predicted Class", label_image=True, target_lang="German"):
    if image is None:
        raise gr.Error("Please upload an image.")
    if isinstance(image, dict):
        # The input is an image plus a drawn mask (both PIL) -> blend them into one image.
        image = Image.blend(image["image"], image["mask"], alpha=0.5)
    if colormap not in CV2_COLORMAPS:
        raise gr.Error(f"Colormap {colormap} not found in {list(CV2_COLORMAPS.keys())}.")
    colormap = CV2_COLORMAPS[colormap]
    image_width, image_height = image.size
    if image_width > 6000 or image_height > 6000:
        raise gr.Error("The image is too big. The maximal size is 6000x6000.")
    MODEL.eval()
    layers = LAYERS[layer]
    image_tensor = transform(image).unsqueeze(0)
    targets = [ClassifierOutputTarget(get_class_idx(specific_class))] if specific_class != "Predicted Class" else None
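    # targets=None makes pytorch-grad-cam explain the model's own top
    # prediction; a ClassifierOutputTarget pins the explanation to a chosen class.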
    with CAM_METHODS[cam_method](model=MODEL, target_layers=layers) as cam:
        grayscale_cam = cam(input_tensor=image_tensor, targets=targets, aug_smooth=use_aug_smooth, eigen_smooth=use_eigen_smooth)
        if label_image:
            predicted_animal = get_class_name(np.argmax(cam.outputs.cpu().data.numpy(), axis=-1)[0])
    grayscale_cam = grayscale_cam[0, :]
    grayscale_cam = cv2.resize(grayscale_cam, (image_width, image_height), interpolation=cv2.INTER_CUBIC)
    image = np.float32(image)
    if BWHighlight:
        # Keep the original colors and darken the unimportant regions.
        image = image * grayscale_cam[..., np.newaxis]
        visualization = image.astype(np.uint8)
    else:
        image = image / 255
        visualization = show_cam_on_image(image, grayscale_cam, use_rgb=True, image_weight=alpha, colormap=colormap)
    if label_image:
        # Add an alpha channel so the label box can be drawn semi-transparently.
        visualization = np.concatenate([visualization, np.ones((image_height, image_width, 1), dtype=np.uint8) * 255], axis=-1)
        plt_image = Image.fromarray(visualization, mode="RGBA")
        draw = ImageDraw.Draw(plt_image)
        draw.rectangle((5, 5, 150, 30), fill=(10, 10, 10, 100))
        animal = predicted_animal.capitalize()
        if target_lang is not None and target_lang != "None":
            animal += f' ({get_translated(animal, target_lang)})'
        draw.text((10, 7), animal, font=font, fill=(255, 125, 0, 255))
        visualization = np.array(plt_image)
    out_image = Image.fromarray(visualization)
    return out_image
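# Usage sketch (the file path is illustrative, not part of the repo layout):
#   img = Image.open('src/examples/others/some_image.png')
#   gradcam(img, colormap="Inferno", layer="layer4").save('src/results/cam.png')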
def gradcam_video(video, colormap="Jet", use_eigen_smooth=False, BWHighlight=False, alpha=0.5, cam_method="GradCAM", layer="layer4", specific_class="Predicted Class"):
    global OUTPUT_FPS, MAXIMAL_FRAMES, BATCHES_TO_PROCESS, MAX_OUT_FRAMES
    if video is None: raise gr.Error("Please upload a video.")
    if colormap not in CV2_COLORMAPS:
        raise gr.Error(f"Colormap {colormap} not found in {list(CV2_COLORMAPS.keys())}.")
    colormap = CV2_COLORMAPS[colormap]
    video = cv2.VideoCapture(video)
    fps = int(video.get(cv2.CAP_PROP_FPS))
    if OUTPUT_FPS == -1: OUTPUT_FPS = fps  # -1 means "keep the source frame rate"
    width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    if width > 2000 or height > 2000:
        raise gr.Error("The video is too big. The maximal size is 2000x2000.")
    print(f'FPS: {fps}, Width: {width}, Height: {height}')
    frames = list()
    success, image = video.read()
    while success:
        frames.append(image)
        success, image = video.read()
    print(f'Frames: {len(frames)}')
    if len(frames) == 0:
        raise gr.Error("The video is empty.")
    if len(frames) >= MAXIMAL_FRAMES:
        raise gr.Error(f"The video is too long. The maximal length is {MAXIMAL_FRAMES} frames.")
    if len(frames) > MAX_OUT_FRAMES:
        # Subsample with a fixed integer stride so roughly MAX_OUT_FRAMES frames remain.
        frames = frames[::len(frames) // MAX_OUT_FRAMES]
    print(f'Frames to process: {len(frames)}')
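    # OpenCV delivers frames in BGR channel order; convert them to RGB PIL
    # images so they match what `transform` and the model expect.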
    processed = [Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) for frame in frames]
    # Split the frames into chunks of BATCHES_TO_PROCESS images for batched inference.
    batched = [processed[i:i + BATCHES_TO_PROCESS] for i in range(0, len(processed), BATCHES_TO_PROCESS)]
    MODEL.eval()
    layers = LAYERS[layer]
    results = list()
    targets = [ClassifierOutputTarget(get_class_idx(specific_class))] if specific_class != "Predicted Class" else None
    with CAM_METHODS[cam_method](model=MODEL, target_layers=layers) as cam:
        for batch in tqdm(batched):
            images_tensor = torch.stack([transform(image) for image in batch])
            grayscale_cam = cam(input_tensor=images_tensor, targets=targets, aug_smooth=False, eigen_smooth=use_eigen_smooth)
            for i, image in enumerate(batch):
                _grayscale_cam = grayscale_cam[i, :]
                _grayscale_cam = cv2.resize(_grayscale_cam, (width, height), interpolation=cv2.INTER_LINEAR)
                image = np.float32(image)
                if BWHighlight:
                    image = image * _grayscale_cam[..., np.newaxis]
                    visualization = image.astype(np.uint8)
                else:
                    image = image / 255
                    visualization = show_cam_on_image(image, _grayscale_cam, use_rgb=True, image_weight=alpha, colormap=colormap)
                results.append(visualization)
    # Save the rendered frames as a video.
    mediapy.write_video('src/results/gradcam_video.mp4', results, fps=OUTPUT_FPS)
    video.release()
    return 'src/results/gradcam_video.mp4'
def load_examples():
    folder_name_to_header = {
        "AI_Generated": "AI-Generated Images",
        "true": "Correctly Predicted Images (Validation Set)",
        "false": "Incorrectly Predicted Images (Validation Set)",
        "others": "Other Interesting Images From the Internet"
    }
    images_description = {
        "AI_Generated": "These images were generated with DALL-E 3 and Stable Diffusion. None of them are real photographs, which makes it interesting to see how the model handles them.",
        "true": "These images are from the validation set and the model predicted them correctly.",
        "false": "These images are from the validation set and the model predicted them incorrectly. Maybe you can see why using the GradCAM visualization. :)",
        "others": "These images are from the internet and are not part of the validation set. They are interesting because most of them show several different animals."
    }
    loaded_images = defaultdict(list)
    for image_type in ["AI_Generated", "true", "false", "others"]:
        full_path = os.path.join(IMAGE_PATH, image_type).replace('\\', '/').replace('//', '/')
        gr.Markdown(f'## {folder_name_to_header[image_type]}')
        gr.Markdown(images_description[image_type])
        images_to_load = os.listdir(full_path)
        # Ceiling division so a full final row does not produce an extra empty row.
        rows = (len(images_to_load) + IMAGES_PER_ROW - 1) // IMAGES_PER_ROW
        for i in range(rows):
            with gr.Row(elem_classes=["row-example-images"], equal_height=False):
                for j in range(IMAGES_PER_ROW):
                    if i * IMAGES_PER_ROW + j >= len(images_to_load): break
                    image = images_to_load[i * IMAGES_PER_ROW + j]
                    name = image.split('.')[0]
                    image = Image.open(os.path.join(full_path, image))
                    # Scale so that the longest side is 600px.
                    scale = 600 / max(image.size)
                    image = image.resize((int(image.size[0] * scale), int(image.size[1] * scale)))
                    loaded_images[image_type].append(
                        gr.Image(
                            value=image,
                            label=name,
                            type="pil",
                            interactive=False,
                            elem_classes=["selectable_images"],
                        )
                    )
    return loaded_images
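# load_examples returns the created gr.Image components grouped by folder so
# that select handlers can be attached after the layout has been built.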
css = """ | |
#logo {text-align: right;} | |
p {text-align: justify; text-justify: inter-word; font-size: 1.1em; line-height: 1.2em;} | |
.svelte-1btp92j.selectable {cursor: pointer !important; } | |
""" | |
with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
    # -------------------------------------------
    # HEADER WITH LOGO
    # -------------------------------------------
    with gr.Row():
        with open('src/header.md', 'r', encoding='utf-8') as f:
            markdown_string = f.read()
        with gr.Column(scale=10):
            header = gr.Markdown(markdown_string)
        with gr.Column(scale=1):
            pil_logo = Image.open('animals.png')
            logo = gr.Image(value=pil_logo, scale=2, interactive=False, show_download_button=False, show_label=False, container=False, elem_id="logo")
            animal_translation_target_language = gr.Dropdown(
                choices=list(LANGUAGES_TO_SELECT.keys()),
                label="Translation language for animals",
                value="German",
                interactive=True,
                scale=2,
            )
    # -------------------------------------------
    # INPUT IMAGE
    # -------------------------------------------
    with gr.Row():
        with gr.Row(variant="panel", equal_height=True):
            user_image = gr.Image(
                type="pil",
                label="Upload Your Own Image",
                interactive=True,
            )
    # -------------------------------------------
    # TOOLS
    # -------------------------------------------
    with gr.Row():
        # -------------------------------------------
        # PREDICT
        # -------------------------------------------
        with gr.Tab("Predict"):
            with gr.Column():
                # Shows the five most confident classes for the uploaded image.
                output = gr.Label(
                    num_top_classes=5,
                    label="Output",
                    scale=5,
                )
                with gr.Row():
                    predict_mode_button = gr.Button(value="Predict Animal", scale=6)
                predict_mode_button.click(fn=infer_image, inputs=[user_image, animal_translation_target_language], outputs=output, queue=True)
        # -------------------------------------------
        # EXPLAIN
        # -------------------------------------------
        with gr.Tab("Explain Image"):
            with gr.Row():
                with gr.Column():
                    _info = "There are different GradCAM methods. You can read more about them here: https://github.com/jacobgil/pytorch-grad-cam#references"
                    cam_method = gr.Radio(
                        list(CAM_METHODS.keys()),
                        label="GradCAM Method",
                        info=_info,
                        value="GradCAM",
                        interactive=True,
                        scale=2,
                    )
                    _info = """
                    The alpha value blends the original image with the GradCAM visualization. At 0.5 both are weighted equally;
                    at 0.1 the original image is barely visible, and at 0.9 the GradCAM visualization is barely visible.
                    """
                    alpha = gr.Slider(
                        minimum=.1,
                        maximum=.9,
                        value=0.5,
                        interactive=True,
                        step=.1,
                        label="Alpha",
                        scale=1,
                        info=_info
                    )
                    _info = """
                    Selects the layer of the ResNet18 model the GradCAM visualization is based on.
                    The last layer (layer4) is usually the best choice, since it carries the most semantic information before the final prediction and therefore yields the most meaningful visualization.
                    If several layers are chosen, the GradCAM visualization is averaged over them.
                    """
                    layer = gr.Radio(
                        list(LAYERS.keys()),
                        label="Layer",
                        value="layer4",
                        interactive=True,
                        scale=2,
                        info=_info
                    )
                    with gr.Row():
                        _info = """
                        Choose the animal to "explain". With "Predicted Class" the GradCAM visualization is based on the class the model predicts;
                        with a specific class it is based on that class instead.
                        For example, on an image with a dog and a cat you can select either Cat or Dog and check whether the model focuses on the correct animal.
                        """
                        animal_to_explain = gr.Dropdown(
                            choices=["Predicted Class"] + ALL_CLASSES,
                            label="Animal",
                            value="Predicted Class",
                            interactive=True,
                            scale=4,
                            info=_info
                        )
                        show_predicted_class = gr.Checkbox(
                            label="Show Predicted Class",
                            value=True,
                            interactive=True,
                            scale=1,
                        )
                    with gr.Row():
                        _info = """
                        Choose the colormap for the heatmap. Alternatively, enable "BW Highlight" to keep the original image and only
                        highlight its important parts; in that case the colormap is ignored.
                        """
                        colormap = gr.Dropdown(
                            choices=list(CV2_COLORMAPS.keys()),
                            label="Colormap",
                            value="Inferno",
                            interactive=True,
                            scale=2,
                            info=_info
                        )
                        bw_highlight = gr.Checkbox(
                            label="BW Highlight",
                            value=False,
                            interactive=True,
                            scale=1,
                            info="Keep the original image and darken the unimportant regions instead of overlaying a colormap.",
                        )
                    with gr.Row():
                        _info = """
                        Eigen Smooth reduces noise in the GradCAM visualization by using the first principal component of the activations.
                        """
                        use_eigen_smooth = gr.Checkbox(
                            label="Eigen Smooth",
                            value=False,
                            interactive=True,
                            scale=1,
                            info=_info
                        )
                        _info = """
                        Aug Smooth also smooths the GradCAM visualization, by averaging over augmented versions of the image (test-time augmentation). This requires several forward passes and is therefore slow.
                        """
                        use_aug_smooth = gr.Checkbox(
                            label="Aug Smooth",
                            value=False,
                            interactive=True,
                            scale=1,
                            info=_info
                        )
                with gr.Column():
                    gradcam_mode_button = gr.Button(value="Show GradCAM", scale=1)
                    output_cam = gr.Image(
                        type="pil",
                        label="GradCAM",
                        show_label=False,
                        scale=7,
                    )
            _inputs = [user_image, colormap, use_eigen_smooth, use_aug_smooth, bw_highlight, alpha, cam_method, layer, animal_to_explain, show_predicted_class, animal_translation_target_language]
            gradcam_mode_button.click(fn=gradcam, inputs=_inputs, outputs=output_cam, queue=True)
        # -------------------------------------------
        # VIDEO CAM
        # -------------------------------------------
        with gr.Tab("Explain Video"):
            build_video_to_camvideo(CAM_METHODS, CV2_COLORMAPS, LAYERS, ALL_CLASSES, gradcam_video)
        # -------------------------------------------
        # EXAMPLES
        # -------------------------------------------
        with gr.Tab("Example Images"):
            placeholder = gr.Markdown("## Example Images")
            loaded_images = load_examples()
            for k in loaded_images.keys():
                for image in loaded_images[k]:
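                    # `inputs=[image]` binds each gr.Image component at creation time,
                    # so every select handler forwards its own example image to the
                    # upload field (the lambda just passes the value through).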
                    image.select(fn=lambda x: x, inputs=[image], outputs=[user_image], queue=True, scroll_to_output=True)
if __name__ == "__main__":
    demo.queue()
    print("Starting Gradio server...")
    demo.launch(show_tips=True)