"""Gradio demo: classify an uploaded animal image and explain it with GradCAM.

At import time the module loads the trained model weights and the validation
dataframe, pre-translates every class name to German, and builds the Gradio
UI (``demo``). Running the file as a script launches the app.
"""
import copy
import os
import sys

# Must run before the project-local imports (`Nets`, `util`) below.
sys.path.append('src')

from collections import defaultdict
from functools import lru_cache

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from deep_translator import GoogleTranslator
from Nets import CustomResNet18
from PIL import Image
from torchcam.methods import GradCAM, GradCAMpp, SmoothGradCAMpp, XGradCAM
from torchcam.utils import overlay_mask
from torchvision.transforms.functional import to_pil_image
from tqdm import tqdm
from util import transform

IMAGE_PATH = os.path.join(os.getcwd(), 'src/examples')
RANDOM_IMAGES_TO_SHOW = 10
IMAGES_PER_ROW = 5
# Maps the label shown in the UI to the torchcam extractor class.
CAM_METHODS = {
    "GradCAM": GradCAM,
    "GradCAM++": GradCAMpp,
    "XGradCAM": XGradCAM,
    "SmoothGradCAM++": SmoothGradCAMpp,
}
# Names of the ResNet stages that can be selected individually in the UI.
_SINGLE_LAYERS = ("layer1", "layer2", "layer3", "layer4")

# NOTE(review): 90 is presumably the number of output classes — it matches the
# range used for the translation pre-warm loop below; confirm against Nets.
model = CustomResNet18(90).eval()
model.load_state_dict(torch.load('src/results/models/best_model.pth', map_location=torch.device('cpu')))
# NOTE(review): `cam_model` is not referenced anywhere in this file; kept for
# backward compatibility in case another module imports it.
cam_model = copy.deepcopy(model)

data_df = pd.read_csv('src/cache/val_df.csv')


def load_random_images():
    """Return up to ``RANDOM_IMAGES_TO_SHOW`` random ``(target, image)`` pairs.

    Rows are drawn from ``data_df`` with replacement. A row whose image file
    does not exist under ``IMAGE_PATH`` is skipped, so the result may contain
    fewer entries than requested (possibly none) and may contain duplicates.
    """
    random_images = []
    for _ in range(RANDOM_IMAGES_TO_SHOW):
        row = data_df.iloc[np.random.randint(0, len(data_df))]
        # Normalise Windows separators and accidental double slashes.
        p = os.path.join(IMAGE_PATH, row['path'])
        p = p.replace('\\', '/')
        p = p.replace('//', '/')
        if os.path.exists(p):
            random_images.append((row['target'], Image.open(p)))
    return random_images


def get_class_name(idx):
    """Return the ``target`` of the first row whose ``encoded_target`` is *idx*.

    Raises:
        IndexError: if no row in ``data_df`` has that ``encoded_target``.
    """
    return data_df[data_df['encoded_target'] == idx]['target'].values[0]


@lru_cache(maxsize=100)
def get_translated(to_translate):
    """Translate *to_translate* from English to German (result is cached)."""
    return GoogleTranslator(source="en", target="de").translate(to_translate)


# Warm the translation cache so predictions do not wait on the translator.
for idx in range(90):
    get_translated(get_class_name(idx))


def infer_image(image):
    """Run the classifier on *image* and return per-class probabilities.

    Args:
        image: input passed to ``transform`` (the UI supplies a PIL image).

    Returns:
        A ``defaultdict(float)`` mapping ``"<english> (<german>)"`` labels to
        the softmax probability of that class.
    """
    image = transform(image)
    image = image.unsqueeze(0)
    with torch.no_grad():
        output = model(image)
    distribution = torch.nn.functional.softmax(output, dim=1)
    ret = defaultdict(float)
    for idx, prob in enumerate(distribution[0]):
        # Look the class name up once; each lookup filters the dataframe.
        class_name = get_class_name(idx)
        animal = f'{class_name} ({get_translated(class_name)})'
        ret[animal] = prob.item()
    return ret


def gradcam(image, alpha, cam_method, layer):
    """Overlay a class-activation map for the predicted class on *image*.

    Args:
        image: PIL image to explain.
        alpha: blending factor forwarded to ``overlay_mask``.
        cam_method: key of ``CAM_METHODS`` selecting the extractor.
        layer: ``"layer1"`` … ``"layer4"`` for a single ResNet stage; any other
            value selects all four stages.

    Returns:
        The overlaid PIL image, scaled down so its height is at most 300px.
    """
    if layer in _SINGLE_LAYERS:
        layers = [getattr(model.resnet, layer)]
    else:
        layers = [getattr(model.resnet, name) for name in _SINGLE_LAYERS]

    model.eval()
    img_tensor = transform(image).unsqueeze(0)
    cam = CAM_METHODS[cam_method](model, target_layer=layers)
    try:
        # No `torch.no_grad()` here: the CAM methods need gradients.
        output = model(img_tensor)
        activation_map = cam(output.squeeze(0).argmax().item(), output)
    finally:
        # Always detach the hooks from the shared global model, even when the
        # forward pass or the CAM computation fails.
        cam.remove_hooks()

    # NOTE(review): only the first returned map is used, so with several target
    # layers the maps of the remaining layers are ignored — confirm intent.
    result = overlay_mask(image, to_pil_image(activation_map[0].squeeze(0), mode='F'), alpha=alpha)

    # Cap the height at 300px, keeping the aspect ratio.
    if result.size[1] > 300:
        ratio = 300 / result.size[1]
        result = result.resize((int(result.size[0] * ratio), 300))
    return result


# NOTE(review): the widget nesting below was reconstructed from a
# whitespace-flattened source — confirm it matches the intended layout.
with gr.Blocks() as demo:
    with open('src/header.md', 'r') as f:
        markdown_string = f.read()
    header = gr.Markdown(markdown_string)

    with gr.Row(variant="panel", equal_height=True):
        user_image = gr.Image(
            type="pil",
            label="Upload Your Own Image",
            info="You can also upload your own image for prediction.",
            scale=1,
        )

        with gr.Tab("Predict"):
            with gr.Column():
                output = gr.Label(
                    num_top_classes=3,
                    label="Output",
                    info="Top three predicted classes and their confidences.",
                    scale=5,
                )
                predict_mode_button = gr.Button(value="Predict Animal", label="Predict", info="Click to make a prediction.", scale=1)
                predict_mode_button.click(fn=infer_image, inputs=[user_image], outputs=output, queue=True)

        with gr.Tab("Explain"):
            with gr.Row():
                with gr.Column():
                    cam_method = gr.Radio(
                        list(CAM_METHODS.keys()),
                        label="GradCAM Method",
                        value="GradCAM",
                        interactive=True,
                        scale=2,
                    )
                    cam_method.description = "Here you can choose the GradCAM method."
                    cam_method.description_place = "left"

                    alpha = gr.Slider(
                        minimum=.1,
                        maximum=.9,
                        value=0.5,
                        interactive=True,
                        step=.1,
                        label="Alpha",
                        scale=1,
                    )
                    alpha.description = "Here you can choose the alpha value."
                    alpha.description_place = "left"

                    layer = gr.Radio(
                        ["layer1", "layer2", "layer3", "layer4", "all"],
                        label="Layer",
                        value="layer4",
                        interactive=True,
                        scale=2,
                    )
                    layer.description = "Here you can choose the layer to visualize."
                    layer.description_place = "left"

                with gr.Column():
                    output_cam = gr.Image(
                        type="pil",
                        label="GradCAM",
                        info="GradCAM visualization"
                    )
                    gradcam_mode_button = gr.Button(value="Show GradCAM", label="GradCAM", info="Click to make a prediction.", scale=1)
                    gradcam_mode_button.click(fn=gradcam, inputs=[user_image, alpha, cam_method, layer], outputs=output_cam, queue=True)

        with gr.Tab("Example Images"):
            with gr.Column():
                placeholder = gr.Markdown("## Example Images")
                showed_images = []
                loaded_images = load_random_images()
                if not loaded_images:
                    print(f"Could not find any images in {IMAGE_PATH}")
                # Lay the images out in rows of at most IMAGES_PER_ROW. Slicing
                # (instead of indexing a fixed grid) keeps a short final row
                # from raising IndexError and shows every loaded image.
                for row_start in range(0, len(loaded_images), IMAGES_PER_ROW):
                    with gr.Row():
                        for animal, image in loaded_images[row_start:row_start + IMAGES_PER_ROW]:
                            showed_images.append(gr.Image(
                                value=image,
                                label=animal,
                                type="pil",
                                interactive=False,
                            ))


if __name__ == "__main__":
    demo.queue()
    demo.launch()