Spaces:
Sleeping
Sleeping
import copy | |
import os | |
import sys | |
sys.path.append('src') | |
from collections import defaultdict | |
from functools import lru_cache | |
import gradio as gr | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import pandas as pd | |
import torch | |
from deep_translator import GoogleTranslator | |
from Nets import CustomResNet18 | |
from PIL import Image | |
from torchcam.methods import GradCAM, GradCAMpp, SmoothGradCAMpp, XGradCAM | |
from torchcam.utils import overlay_mask | |
from torchvision.transforms.functional import to_pil_image | |
from tqdm import tqdm | |
from util import transform | |
# Directory under which the image paths listed in the validation CSV live.
IMAGE_PATH = os.path.join(os.getcwd(), 'src/examples')
# Number of random examples the "Example Images" tab tries to show, and how
# many of them are laid out per row.
RANDOM_IMAGES_TO_SHOW = 10
IMAGES_PER_ROW = 5
# Display name -> torchcam CAM extractor class selectable in the "Explain" tab.
# The dict keeps insertion order, so the radio buttons appear in this order.
CAM_METHODS = dict([
    ("GradCAM", GradCAM),
    ("GradCAM++", GradCAMpp),
    ("XGradCAM", XGradCAM),
    ("SmoothGradCAM++", SmoothGradCAMpp),
])
# 90-class classifier, switched to inference mode before its trained weights
# are loaded; map_location forces the checkpoint onto the CPU.
model = CustomResNet18(90)
model.eval()
model.load_state_dict(
    torch.load(
        'src/results/models/best_model.pth',
        map_location=torch.device('cpu'),
    )
)
# Independent copy with the same weights, presumably so CAM hooks can be
# attached without touching `model` — confirm against the CAM code.
cam_model = copy.deepcopy(model)
# Validation metadata: the visible code reads its 'path', 'target' and
# 'encoded_target' columns.
data_df = pd.read_csv('src/cache/val_df.csv')
def load_random_images():
    """Load up to RANDOM_IMAGES_TO_SHOW randomly chosen example images.

    Rows of ``data_df`` are drawn independently (with replacement). Rows whose
    image file does not exist under IMAGE_PATH are skipped, so the returned
    list may be shorter than RANDOM_IMAGES_TO_SHOW (and is empty when
    ``data_df`` has no rows).

    Returns:
        list of ``(target_label, PIL.Image.Image)`` tuples with pixel data
        already loaded into memory.
    """
    random_images = list()
    if data_df.empty:
        # np.random.randint(0, 0) would raise ValueError.
        return random_images
    for _ in range(RANDOM_IMAGES_TO_SHOW):
        row = data_df.iloc[np.random.randint(0, len(data_df))]
        p = os.path.join(IMAGE_PATH, row['path'])
        # Normalise backslashes (presumably Windows-style paths in the CSV)
        # and the doubled separator the join can leave behind.
        p = p.replace('\\', '/')
        p = p.replace('//', '/')
        if os.path.exists(p):
            # Image.open is lazy and keeps the file handle open until the data
            # is read; load the pixels now so the file is closed immediately.
            with Image.open(p) as img:
                img.load()
            random_images.append((row['target'], img))
    return random_images
@lru_cache(maxsize=None)
def get_class_name(idx):
    """Return the ``target`` name of the row whose ``encoded_target`` equals *idx*.

    Memoised: the lookup filters the whole ``data_df`` (which nothing in this
    module modifies after loading), and it runs for every class on every
    prediction. The key space is the fixed set of class indices, so the
    unbounded cache stays small.

    Raises:
        IndexError: if no row has ``encoded_target == idx``.
    """
    return data_df[data_df['encoded_target'] == idx]['target'].values[0]


@lru_cache(maxsize=None)
def get_translated(to_translate):
    """Translate the English text *to_translate* into German via Google Translate.

    Memoised so that each class name costs at most one remote translation
    request per process. A call that raises is not cached and will be retried
    on the next request.
    """
    return GoogleTranslator(source="en", target="de").translate(to_translate)


# Translate all 90 class names once at start-up so that predictions are
# served from the cache instead of issuing 90 translation requests each.
for idx in range(90):
    get_translated(get_class_name(idx))
def infer_image(image):
    """Classify *image* and return a probability for every model output class.

    Args:
        image: PIL image from the upload widget, or ``None`` if the user
            clicked "Predict" without uploading anything.

    Returns:
        ``defaultdict(float)`` mapping ``"<name> (<German translation>)"`` to the
        softmax probability of that class, as consumed by ``gr.Label``.

    Raises:
        gr.Error: if no image was provided (shown to the user by Gradio).
    """
    if image is None:
        # transform(None) would fail with an opaque traceback; report it instead.
        raise gr.Error("Please upload an image first.")
    batch = transform(image).unsqueeze(0)
    with torch.no_grad():
        output = model(batch)
    distribution = torch.nn.functional.softmax(output, dim=1)
    ret = defaultdict(float)
    # .tolist() converts the whole row to Python floats in one call.
    for idx, prob in enumerate(distribution[0].tolist()):
        name = get_class_name(idx)
        ret[f'{name} ({get_translated(name)})'] = prob
    return ret
def gradcam(image, alpha, cam_method, layer):
    """Overlay a class-activation map for the top-1 predicted class on *image*.

    Args:
        image: PIL image from the upload widget, or ``None`` if nothing was
            uploaded.
        alpha: overlay transparency forwarded to ``overlay_mask``.
        cam_method: key into ``CAM_METHODS`` selecting the torchcam extractor.
        layer: ``"layer1"`` .. ``"layer4"`` to explain a single ResNet stage;
            any other value (the UI offers ``"all"``) uses all four stages and
            fuses their maps.

    Returns:
        PIL image with the activation map overlaid, downscaled to at most
        300 px high (aspect ratio kept).

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    resnet = cam_model.resnet
    stage_names = ("layer1", "layer2", "layer3", "layer4")
    if layer in stage_names:
        layers = [getattr(resnet, layer)]
    else:
        layers = [getattr(resnet, name) for name in stage_names]
    cam_model.eval()
    img_tensor = transform(image).unsqueeze(0)
    # Hooks go on the dedicated copy so the model used by infer_image never
    # carries CAM hooks, and they are removed even if the forward/CAM pass fails.
    cam = CAM_METHODS[cam_method](cam_model, target_layer=layers)
    try:
        output = cam_model(img_tensor)
        activation_maps = cam(output.squeeze(0).argmax().item(), output)
    finally:
        cam.remove_hooks()
    # torchcam returns one map per target layer; merge them rather than
    # silently displaying only the first layer's map.
    if len(activation_maps) > 1:
        activation_map = cam.fuse_cams(activation_maps)
    else:
        activation_map = activation_maps[0]
    result = overlay_mask(image, to_pil_image(activation_map.squeeze(0), mode='F'), alpha=alpha)
    # Cap the displayed height at 300 px, keeping the aspect ratio.
    max_height = 300
    if result.size[1] > max_height:
        ratio = max_height / result.size[1]
        result = result.resize((int(result.size[0] * ratio), max_height))
    return result
with gr.Blocks() as demo:
    # NOTE(review): the nesting below was reconstructed from a copy with no
    # indentation — confirm that the tabs sit inside the upload row and that
    # the "Show GradCAM" button belongs under the GradCAM output column.
    with open('src/header.md', 'r', encoding='utf-8') as f:
        markdown_string = f.read()
    header = gr.Markdown(markdown_string)
    with gr.Row(variant="panel", equal_height=True):
        user_image = gr.Image(
            type="pil",
            label="Upload Your Own Image",
            info="You can also upload your own image for prediction.",
            scale=1,
        )
        with gr.Tab("Predict"):
            with gr.Column():
                output = gr.Label(
                    num_top_classes=3,
                    label="Output",
                    info="Top three predicted classes and their confidences.",
                    scale=5,
                )
                predict_mode_button = gr.Button(value="Predict Animal", label="Predict", info="Click to make a prediction.", scale=1)
                predict_mode_button.click(fn=infer_image, inputs=[user_image], outputs=output, queue=True)
        with gr.Tab("Explain"):
            with gr.Row():
                with gr.Column():
                    cam_method = gr.Radio(
                        list(CAM_METHODS.keys()),
                        label="GradCAM Method",
                        value="GradCAM",
                        interactive=True,
                        scale=2,
                    )
                    cam_method.description = "Here you can choose the GradCAM method."
                    cam_method.description_place = "left"
                    alpha = gr.Slider(
                        minimum=.1,
                        maximum=.9,
                        value=0.5,
                        interactive=True,
                        step=.1,
                        label="Alpha",
                        scale=1,
                    )
                    alpha.description = "Here you can choose the alpha value."
                    alpha.description_place = "left"
                    layer = gr.Radio(
                        ["layer1", "layer2", "layer3", "layer4", "all"],
                        label="Layer",
                        value="layer4",
                        interactive=True,
                        scale=2,
                    )
                    layer.description = "Here you can choose the layer to visualize."
                    layer.description_place = "left"
                with gr.Column():
                    output_cam = gr.Image(
                        type="pil",
                        label="GradCAM",
                        info="GradCAM visualization"
                    )
                    gradcam_mode_button = gr.Button(value="Show GradCAM", label="GradCAM", info="Click to make a prediction.", scale=1)
                    gradcam_mode_button.click(fn=gradcam, inputs=[user_image, alpha, cam_method, layer], outputs=output_cam, queue=True)
        with gr.Tab("Example Images"):
            with gr.Column():
                placeholder = gr.Markdown("## Example Images")
                showed_images = list()
                loaded_images = load_random_images()
                if len(loaded_images) == 0:
                    print(f"Could not find any images in {IMAGE_PATH}")
                # load_random_images skips missing files, so the count need not
                # be a multiple of IMAGES_PER_ROW: use ceil division and slice
                # each row so a short last row is shown instead of raising
                # IndexError (or being dropped).
                amount_rows = -(-len(loaded_images) // IMAGES_PER_ROW)
                for i in range(amount_rows):
                    with gr.Row():
                        row = loaded_images[i * IMAGES_PER_ROW:(i + 1) * IMAGES_PER_ROW]
                        for animal, image in row:
                            showed_images.append(gr.Image(
                                value=image,
                                label=animal,
                                type="pil",
                                interactive=False,
                            ))
if __name__ == "__main__":
    # queue() returns the Blocks object itself, so the request queue (used by
    # the queue=True click handlers) is enabled and the server started in one chain.
    demo.queue().launch()