# Source page header: GabrielML — commit 8717d3d ("remove debug"), file size 6.79 kB
import copy
import os
import sys
sys.path.append('src')
from collections import defaultdict
from functools import lru_cache
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from deep_translator import GoogleTranslator
from Nets import CustomResNet18
from PIL import Image
from torchcam.methods import GradCAM, GradCAMpp, SmoothGradCAMpp, XGradCAM
from torchcam.utils import overlay_mask
from torchvision.transforms.functional import to_pil_image
from tqdm import tqdm
from util import transform
# Directory that holds the example images referenced by the validation CSV.
IMAGE_PATH = os.path.join(os.getcwd(), 'src/examples')

# Sizing of the "Example Images" grid.
IMAGES_PER_ROW = 5
RANDOM_IMAGES_TO_SHOW = 10

# Display name -> torchcam extractor class. The UI lists the keys in this order.
CAM_METHODS = {
    "GradCAM": GradCAM,
    "GradCAM++": GradCAMpp,
    "XGradCAM": XGradCAM,
    "SmoothGradCAM++": SmoothGradCAMpp,
}

# Validation table; this file reads its 'path', 'target' and 'encoded_target' columns.
data_df = pd.read_csv('src/cache/val_df.csv')

# Network in eval mode with the saved weights mapped onto the CPU.
# NOTE(review): the argument 90 is presumably the number of output classes — confirm in Nets.
model = CustomResNet18(90)
model.eval()
model.load_state_dict(
    torch.load('src/results/models/best_model.pth', map_location=torch.device('cpu'))
)

# Independent deep copy of the loaded network.
cam_model = copy.deepcopy(model)
def load_random_images():
    """Pick random example images from the validation table.

    Draws ``RANDOM_IMAGES_TO_SHOW`` random row positions from ``data_df``
    (the same row can be drawn more than once) and opens the file each row
    points to below ``IMAGE_PATH``. Rows whose file does not exist are
    skipped, so fewer images than requested may be returned.

    Returns:
        list: ``(target, PIL image)`` tuples for every file that was found.
    """
    samples = []
    for _ in range(RANDOM_IMAGES_TO_SHOW):
        row = data_df.iloc[np.random.randint(0, len(data_df))]
        # Normalise backslashes to forward slashes, then collapse doubled slashes.
        image_path = os.path.join(IMAGE_PATH, row['path'])
        image_path = image_path.replace('\\', '/').replace('//', '/')
        if not os.path.exists(image_path):
            continue
        samples.append((row['target'], Image.open(image_path)))
    return samples
def get_class_name(idx):
    """Return the ``target`` of the first ``data_df`` row whose ``encoded_target`` equals ``idx``.

    Raises:
        IndexError: if no row has that ``encoded_target`` value.
    """
    matching_targets = data_df.loc[data_df['encoded_target'] == idx, 'target']
    return matching_targets.values[0]
@lru_cache(maxsize=100)
def get_translated(to_translate):
    """Translate ``to_translate`` from English to German with GoogleTranslator.

    Results are memoised by ``lru_cache`` (up to 100 distinct inputs), so a
    repeated input does not trigger another translation call.
    """
    translator = GoogleTranslator(source="en", target="de")
    return translator.translate(to_translate)
# Fill the get_translated cache for class indices 0..89 at import time, so the
# per-class lookups in infer_image hit the cache instead of calling the translator.
for idx in range(90):
    get_translated(get_class_name(idx))
def infer_image(image):
    """Classify an image and return the probability of every class.

    Args:
        image: Image handed over by the upload widget (configured with
            ``type="pil"``), or ``None`` when nothing has been uploaded.

    Returns:
        defaultdict: Maps ``"<class name> (<German translation>)"`` to the
        softmax probability of that class, one entry per model output.

    Raises:
        gr.Error: If ``image`` is ``None``. Gradio shows the message to the
            user instead of an opaque failure from ``transform``.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # transform() yields a single sample; add the batch dimension the model expects.
    batch = transform(image).unsqueeze(0)
    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.nn.functional.softmax(logits, dim=1)[0].tolist()

    ret = defaultdict(float)
    for idx, prob in enumerate(probabilities):
        # Look the class name up once and reuse it for the translation.
        name = get_class_name(idx)
        ret[f'{name} ({get_translated(name)})'] = prob
    return ret
def gradcam(image, alpha, cam_method, layer):
    """Overlay a class activation map for the predicted class on ``image``.

    The CAM hooks are attached to ``cam_model`` (the module-level deep copy of
    the classifier), so the ``model`` used by ``infer_image`` never carries
    hooks, and they are removed even when the computation fails.

    Args:
        image: Image from the upload widget (``type="pil"``), or ``None``
            when nothing has been uploaded.
        alpha: Blending factor forwarded to ``overlay_mask``.
        cam_method: Key into ``CAM_METHODS`` selecting the extractor class.
        layer: ``'layer1'`` .. ``'layer4'`` to inspect a single ResNet stage;
            any other value (the UI sends ``'all'``) uses all four stages and
            fuses their maps into one.

    Returns:
        PIL image with the activation map blended in, scaled down to a height
        of 300 px (aspect ratio kept) when it is taller than that.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    layer_names = ('layer1', 'layer2', 'layer3', 'layer4')
    selected = (layer,) if layer in layer_names else layer_names
    layers = [getattr(cam_model.resnet, name) for name in selected]

    cam_model.eval()
    img_tensor = transform(image).unsqueeze(0)
    cam = CAM_METHODS[cam_method](cam_model, target_layer=layers)
    try:
        output = cam_model(img_tensor)
        # One activation map per target layer, for the top-scoring class.
        activation_maps = cam(output.squeeze(0).argmax().item(), output)
    finally:
        # Always detach the hooks, otherwise they pile up on cam_model after a failure.
        cam.remove_hooks()

    # fuse_cams returns the single map unchanged and merges several maps into one,
    # so the "all" option no longer shows only the first layer.
    fused_map = cam.fuse_cams(activation_maps)
    result = overlay_mask(image, to_pil_image(fused_map.squeeze(0), mode='F'), alpha=alpha)

    # Cap the height at 300 px.
    if result.size[1] > 300:
        ratio = 300 / result.size[1]
        result = result.resize((int(result.size[0] * ratio), 300))
    return result
# Gradio UI: an upload widget next to "Predict" and "Explain" tabs, plus a tab
# with randomly chosen example images.
# NOTE(review): the source had lost its indentation; the nesting below is a
# reconstruction — confirm it matches the intended layout.
with gr.Blocks() as demo:
    with open('src/header.md', 'r') as f:
        markdown_string = f.read()
    header = gr.Markdown(markdown_string)

    with gr.Row(variant="panel", equal_height=True):
        user_image = gr.Image(
            type="pil",
            label="Upload Your Own Image",
            info="You can also upload your own image for prediction.",
            scale=1,
        )

        with gr.Tab("Predict"):
            with gr.Column():
                output = gr.Label(
                    num_top_classes=3,
                    label="Output",
                    info="Top three predicted classes and their confidences.",
                    scale=5,
                )
                predict_mode_button = gr.Button(value="Predict Animal", label="Predict", info="Click to make a prediction.", scale=1)
                predict_mode_button.click(fn=infer_image, inputs=[user_image], outputs=output, queue=True)

        with gr.Tab("Explain"):
            with gr.Row():
                with gr.Column():
                    cam_method = gr.Radio(
                        list(CAM_METHODS.keys()),
                        label="GradCAM Method",
                        value="GradCAM",
                        interactive=True,
                        scale=2,
                    )
                    cam_method.description = "Here you can choose the GradCAM method."
                    cam_method.description_place = "left"

                    alpha = gr.Slider(
                        minimum=.1,
                        maximum=.9,
                        value=0.5,
                        interactive=True,
                        step=.1,
                        label="Alpha",
                        scale=1,
                    )
                    alpha.description = "Here you can choose the alpha value."
                    alpha.description_place = "left"

                    layer = gr.Radio(
                        ["layer1", "layer2", "layer3", "layer4", "all"],
                        label="Layer",
                        value="layer4",
                        interactive=True,
                        scale=2,
                    )
                    layer.description = "Here you can choose the layer to visualize."
                    layer.description_place = "left"

                with gr.Column():
                    output_cam = gr.Image(
                        type="pil",
                        label="GradCAM",
                        info="GradCAM visualization"
                    )
                    gradcam_mode_button = gr.Button(value="Show GradCAM", label="GradCAM", info="Click to make a prediction.", scale=1)
                    gradcam_mode_button.click(fn=gradcam, inputs=[user_image, alpha, cam_method, layer], outputs=output_cam, queue=True)

    with gr.Tab("Example Images"):
        with gr.Column():
            placeholder = gr.Markdown("## Example Images")
            showed_images = list()
            loaded_images = load_random_images()
            if len(loaded_images) == 0:
                print(f"Could not find any images in {IMAGE_PATH}")

            # load_random_images() skips missing files, so the count need not be a
            # multiple of IMAGES_PER_ROW. Ceiling division plus slicing renders a
            # partially filled last row instead of indexing past the end of the list.
            amount_rows = (len(loaded_images) + IMAGES_PER_ROW - 1) // IMAGES_PER_ROW
            for i in range(amount_rows):
                row_images = loaded_images[i * IMAGES_PER_ROW:(i + 1) * IMAGES_PER_ROW]
                with gr.Row():
                    for animal, image in row_images:
                        showed_images.append(gr.Image(
                            value=image,
                            label=animal,
                            type="pil",
                            interactive=False,
                        ))
def main():
    """Enable the request queue on ``demo`` and start the Gradio server."""
    demo.queue()
    demo.launch()


if __name__ == "__main__":
    main()