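"""Gradio demo: a 90-class animal classifier (custom ResNet-18) with GradCAM-style explanations for images and videos."""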
import copy
import os
import sys
sys.path.append('src')
from collections import defaultdict
from functools import lru_cache

import cv2
import gradio as gr
import numpy as np
import pandas as pd
import torch
from deep_translator import GoogleTranslator
from Nets import CustomResNet18
from PIL import Image
from torchcam.methods import GradCAM, GradCAMpp, SmoothGradCAMpp, XGradCAM
from torchcam.utils import overlay_mask
from torchvision.transforms.functional import to_pil_image
from tqdm import tqdm
from util import transform
from gradio_blocks import build_video_to_camvideo
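# Demo configuration: example-gallery layout and limits for the video-CAM tool.
# OUTPUT_FPS = -1 would fall back to the source video's FPS (see gradcam_video).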
IMAGE_PATH = os.path.join(os.getcwd(), 'src/examples')
IMAGES_PER_ROW = 7
MAXIMAL_FRAMES = 1000
BATCHES_TO_PROCESS = 10
OUTPUT_FPS = 15
MAX_OUT_FRAMES = 60
CAM_METHODS = {
    "GradCAM": GradCAM,
    "GradCAM++": GradCAMpp,
    "XGradCAM": XGradCAM,
    "SmoothGradCAM++": SmoothGradCAMpp,
}
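# Load the trained 90-class classifier for CPU inference; the class names come
# from the cached validation dataframe.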
model = CustomResNet18(90).eval()
model.load_state_dict(torch.load('src/results/models/best_model.pth', map_location=torch.device('cpu')))
cam_model = copy.deepcopy(model)
data_df = pd.read_csv('src/cache/val_df.csv')
C_NUM_TO_NAME = data_df[['encoded_target', 'target']].drop_duplicates().sort_values('encoded_target').set_index('encoded_target')['target'].to_dict()
C_NAME_TO_NUM = {v: k for k, v in C_NUM_TO_NAME.items()}
ALL_CLASSES = sorted(list(C_NUM_TO_NAME.values()), key=lambda x: x.lower())
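# Convenience lookups between class indices, English names, and German translations.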
def get_class_name(idx):
    return C_NUM_TO_NAME[idx]

def get_class_idx(name):
    return C_NAME_TO_NUM[name]
@lru_cache(maxsize=None)
def get_translated(to_translate):
    # Translate once per class name; fall back to English if the service fails.
    try:
        return GoogleTranslator(source="en", target="de").translate(to_translate)
    except Exception:
        return to_translate
# Cache can be pre-warmed: for idx in range(90): get_translated(get_class_name(idx))
def infer_image(image, image_sketch):
    # Prefer the uploaded image; fall back to the sketch canvas.
    image = image if image is not None else image_sketch
    image = transform(image)
    image = image.unsqueeze(0)
    with torch.no_grad():
        output = model(image)
    distribution = torch.nn.functional.softmax(output, dim=1)
    # gr.Label expects a {class name: confidence} mapping.
    ret = defaultdict(float)
    for idx, prob in enumerate(distribution[0]):
        animal = f'{get_class_name(idx)} ({get_translated(get_class_name(idx))})'
        ret[animal] = prob.item()
    return ret
def get_target_layers(layer):
    # Map the UI layer choice onto the ResNet blocks that torchcam should hook.
    if layer in ('layer1', 'layer2', 'layer3', 'layer4'):
        return [getattr(model.resnet, layer)]
    return [model.resnet.layer1, model.resnet.layer2, model.resnet.layer3, model.resnet.layer4]

def gradcam(image, image_sketch=None, alpha=0.5, cam_method="GradCAM", layer=None, specific_class="Predicted Class"):
    image = image if image is not None else image_sketch
    layers = get_target_layers(layer)
    model.eval()
    img_tensor = transform(image).unsqueeze(0)
    cam = CAM_METHODS[cam_method](model, target_layer=layers)
    output = model(img_tensor)
    # Explain the argmax prediction unless the user picked a specific class.
    class_to_explain = output.squeeze(0).argmax().item() if specific_class == "Predicted Class" else get_class_idx(specific_class)
    activation_map = cam(class_to_explain, output)
    result = overlay_mask(image, to_pil_image(activation_map[0].squeeze(0), mode='F'), alpha=alpha)
    # torchcam registers hooks on the target layers; remove them so repeated calls don't stack.
    cam.remove_hooks()
    return result
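# Video variant: decode frames with OpenCV, subsample long videos toward
# MAX_OUT_FRAMES, run CAM in batches, and re-encode the overlays as MP4.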
def gradcam_video(video, alpha=0.5, cam_method="GradCAM", layer=None, specific_class="Predicted Class"):
    global OUTPUT_FPS
    capture = cv2.VideoCapture(video)
    fps = int(capture.get(cv2.CAP_PROP_FPS))
    if OUTPUT_FPS == -1:
        OUTPUT_FPS = fps
    width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
    if width > 3000 or height > 3000:
        raise gr.Error("The video is too big. The maximal size is 3000x3000.")
    print(f'FPS: {fps}, Width: {width}, Height: {height}')
    frames = []
    success, image = capture.read()
    while success:
        frames.append(image)
        success, image = capture.read()
    capture.release()
    print(f'Frames: {len(frames)}')
    if len(frames) == 0:
        raise gr.Error("The video is empty.")
    if len(frames) >= MAXIMAL_FRAMES:
        raise gr.Error(f"The video is too long. The maximal length is {MAXIMAL_FRAMES} frames.")
    if len(frames) > MAX_OUT_FRAMES:
        # Thin out the frames; the step is chosen so roughly MAX_OUT_FRAMES remain.
        frames = frames[::len(frames) // MAX_OUT_FRAMES]
    print(f'Frames to process: {len(frames)}')
    # OpenCV decodes to BGR; the model and PIL expect RGB.
    processed = [Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) for frame in frames]
    # Group the frames into batches of BATCHES_TO_PROCESS images for batched inference.
    batched = [processed[i:i + BATCHES_TO_PROCESS] for i in range(0, len(processed), BATCHES_TO_PROCESS)]
    model.eval()
    layers = get_target_layers(layer)
    cam = CAM_METHODS[cam_method](model, target_layer=layers)
    results = []
    for batch in tqdm(batched):
        images_tensor = torch.stack([transform(image) for image in batch])
        outputs = model(images_tensor)
        out_classes = [output.argmax().item() for output in outputs]
        classes_to_explain = out_classes if specific_class == "Predicted Class" else [get_class_idx(specific_class)] * len(out_classes)
        # cam returns one activation map per hooked layer; use the first.
        activation_maps = cam(classes_to_explain, outputs)
        for j, activation_map in enumerate(activation_maps[0]):
            result = overlay_mask(batch[j], to_pil_image(activation_map, mode='F'), alpha=alpha)
            results.append(cv2.cvtColor(np.array(result), cv2.COLOR_RGB2BGR))
    cam.remove_hooks()
    # Re-encode the overlaid frames as an MP4; VideoWriter expects (width, height).
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    size = (results[0].shape[1], results[0].shape[0])
    writer = cv2.VideoWriter('src/results/gradcam_video.mp4', fourcc, OUTPUT_FPS, size)
    for frame in results:
        writer.write(frame)
    writer.release()
    return 'src/results/gradcam_video.mp4'
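# Build the example-image gallery: one section per folder in src/examples,
# each with a header, a short description, and a grid of clickable thumbnails.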
def load_examples():
    folder_name_to_header = {
        "AI_Generated": "AI-Generated Images",
        "true_predicted": "Correctly Predicted Images (Validation Set)",
        "false_predicted": "Incorrectly Predicted Images (Validation Set)",
        "others": "Other Interesting Images from the Internet",
    }
    images_description = {
        "AI_Generated": "These images were generated by DALL-E 3 and Stable Diffusion. None of them are real photographs, which makes it interesting to see how the model handles them.",
        "true_predicted": "These images are from the validation set, and the model predicted them correctly.",
        "false_predicted": "These images are from the validation set, and the model predicted them incorrectly. Maybe the GradCAM visualization can show you why. :)",
        "others": "These images are from the internet and are not part of the validation set. They are interesting because most of them show a variety of animals.",
    }
    loaded_images = defaultdict(list)
    for image_type in ["AI_Generated", "true_predicted", "false_predicted", "others"]:
        full_path = os.path.join(IMAGE_PATH, image_type).replace('\\', '/').replace('//', '/')
        gr.Markdown(f'## {folder_name_to_header[image_type]}')
        gr.Markdown(images_description[image_type])
        images_to_load = os.listdir(full_path)
        # Ceiling division, so a partially filled last row is still rendered.
        rows = (len(images_to_load) + IMAGES_PER_ROW - 1) // IMAGES_PER_ROW
        for i in range(rows):
            with gr.Row(elem_classes=["row-example-images"], equal_height=False):
                for j in range(IMAGES_PER_ROW):
                    if i * IMAGES_PER_ROW + j >= len(images_to_load):
                        break
                    image = images_to_load[i * IMAGES_PER_ROW + j]
                    name = image.split('.')[0]
                    loaded_images[image_type].append(
                        gr.Image(
                            value=os.path.join(full_path, image),
                            label=f"{name} ({get_translated(name)})",
                            type="pil",
                            interactive=False,
                            elem_classes=["selectable_images"],
                        )
                    )
    return loaded_images
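# Minimal CSS: right-align the logo, justify paragraph text, and give the
# example images a pointer cursor so they look clickable.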
css = """ | |
#logo {text-align: right;} | |
p {text-align: justify; text-justify: inter-word; font-size: 1.1em; line-height: 1.2em;} | |
.svelte-1btp92j.selectable {cursor: pointer !important; } | |
""" | |
with gr.Blocks(theme='freddyaboulton/dracula_revamped', css=css) as demo:
    # -------------------------------------------
    # HEADER WITH LOGO
    # -------------------------------------------
    with gr.Row():
        with open('src/header.md', 'r', encoding='utf-8') as f:
            markdown_string = f.read()
        with gr.Column(scale=10):
            header = gr.Markdown(markdown_string)
        with gr.Column(scale=1):
            pil_logo = Image.open('animals.png')
            logo = gr.Image(value=pil_logo, scale=2, interactive=False, show_download_button=False, show_label=False, container=False, elem_id="logo")
    # -------------------------------------------
    # INPUT IMAGE
    # -------------------------------------------
    with gr.Row():
        with gr.Tab("Upload Image"):
            with gr.Row(variant="panel", equal_height=True):
                user_image = gr.Image(
                    type="pil",
                    label="Upload Your Own Image",
                    info="You can also upload your own image for prediction.",
                )
        with gr.Tab("Draw Image"):
            with gr.Row(variant="panel", equal_height=True):
                user_image_sketched = gr.Image(
                    type="pil",
                    source="canvas",
                    tool="color-sketch",
                    label="Draw Your Own Image",
                    info="You can also draw your own image for prediction.",
                )
    # -------------------------------------------
    # TOOLS
    # -------------------------------------------
    with gr.Row():
        # -------------------------------------------
        # PREDICT
        # -------------------------------------------
        with gr.Tab("Predict"):
            with gr.Column():
                output = gr.Label(
                    num_top_classes=5,
                    label="Output",
                    info="The five most likely classes and their confidences.",
                    scale=5,
                )
                predict_mode_button = gr.Button(value="Predict Animal", label="Predict", info="Click to make a prediction.", scale=1)
                predict_mode_button.click(fn=infer_image, inputs=[user_image, user_image_sketched], outputs=output, queue=True)
        # -------------------------------------------
        # EXPLAIN
        # -------------------------------------------
        with gr.Tab("Explain"):
            with gr.Row():
                with gr.Column():
                    cam_method = gr.Radio(
                        list(CAM_METHODS.keys()),
                        label="GradCAM Method",
                        value="GradCAM",
                        interactive=True,
                        scale=2,
                    )
                    cam_method.description = "Here you can choose the GradCAM method."
                    cam_method.description_place = "left"
                    alpha = gr.Slider(
                        minimum=.1,
                        maximum=.9,
                        value=0.5,
                        interactive=True,
                        step=.1,
                        label="Alpha",
                        scale=1,
                    )
                    alpha.description = "Here you can choose the alpha value, i.e. the opacity of the CAM overlay."
                    alpha.description_place = "left"
                    layer = gr.Radio(
                        ["layer1", "layer2", "layer3", "layer4", "all"],
                        label="Layer",
                        value="layer4",
                        interactive=True,
                        scale=2,
                    )
                    layer.description = "Here you can choose the layer to visualize."
                    layer.description_place = "left"
                    animal_to_explain = gr.Dropdown(
                        choices=["Predicted Class"] + ALL_CLASSES,
                        label="Animal",
                        value="Predicted Class",
                        interactive=True,
                        scale=2,
                    )
                    animal_to_explain.description = "Here you can choose the animal to explain. If you choose 'Predicted Class', the method explains the predicted class."
                    animal_to_explain.description_place = "center"
                with gr.Column():
                    output_cam = gr.Image(
                        type="pil",
                        label="GradCAM",
                        info="GradCAM visualization",
                    )
                    gradcam_mode_button = gr.Button(value="Show GradCAM", label="GradCAM", info="Click to compute the GradCAM visualization.", scale=1)
                    gradcam_mode_button.click(fn=gradcam, inputs=[user_image, user_image_sketched, alpha, cam_method, layer, animal_to_explain], outputs=output_cam, queue=True)
        # -------------------------------------------
        # GIF CAM
        # -------------------------------------------
        with gr.Tab("Gif Cam"):
            build_video_to_camvideo(CAM_METHODS, ALL_CLASSES, gradcam_video)
        # -------------------------------------------
        # EXAMPLES
        # -------------------------------------------
        with gr.Tab("Example Images"):
            placeholder = gr.Markdown("## Example Images")
            loaded_images = load_examples()
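            # Wire every example thumbnail so that clicking it copies the image
            # into the upload component for prediction and explanation.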
            for k in loaded_images.keys():
                for image in loaded_images[k]:
                    image.select(fn=lambda x: x, inputs=[image], outputs=[user_image])
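# queue() enables Gradio's request queue, so long-running CAM jobs are processed one after another.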
if __name__ == "__main__":
    demo.queue()
    demo.launch()