# gradioMisClassGradCAMimageInputter
#
# Module setup: imports, obtain the CIFAR-10 ResNet model via ts_lt
# (train-and-save or load-and-test), and precompute the misclassified
# test samples shown by the Gradio tabs further down the file.
import os
import torch
import torchvision
from torchvision import datasets, transforms, utils
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from torch.optim.lr_scheduler import OneCycleLR
from torch.optim.swa_utils import AveragedModel, update_bn
from torchmetrics.functional import accuracy
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import misclas_helper
import gradcam_helper
import lightningmodel
import trainsave_loadtest
from misclas_helper import display_cifar_misclassified_data
from gradcam_helper import display_gradcam_output
from misclas_helper import get_misclassified_data2
from misclas_helper import classify_images
from lightningmodel import LitResnet
from trainsave_loadtest import ts_lt
import numpy as np
import gradio as gr
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision import datasets, transforms, utils

# Decision maker for ts_lt: True = train and save, False = load and test.
save1_or_load0 = False

# Train and Save Vs Load and Test.
#   Epochs   -- argument used when training (e.g. Epochs=1 for a quick run)
#   wt_fname -- checkpoint used when testing (e.g. "/content/weights.ckpt")
model, trainer = ts_lt(save1_or_load0, Epochs=26, wt_fname="weights_92.ckpt")

# No explicit class targets are passed on to the GradCAM helper.
targets = None
device = torch.device("cpu")
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Get the misclassified data from test dataset.
# NOTE(review): 20 presumably caps how many samples are collected; it matches
# the "How many images ?" sliders' maximum of 20 — confirm in misclas_helper.
misclassified_data = misclas_helper.get_misclassified_data2(model, device, 20)
################################################################################################
# Module-level placeholder; the handlers below use their own local `fileName`.
fileName = None

# Undo a per-channel Normalize(mean=0.50, std=0.23):
# (x - (-0.50/0.23)) / (1/0.23) == x * 0.23 + 0.50.
inv_normalize = transforms.Normalize(
    mean=[-0.50 / 0.23, -0.50 / 0.23, -0.50 / 0.23],
    std=[1 / 0.23, 1 / 0.23, 1 / 0.23],
)


def _answered_yes(answer):
    """Return True when a free-text answer is "yes" (any case, whitespace ignored)."""
    return (answer or "").strip().lower() == "yes"


def hello(DoYouWantToShowMisClassifiedImages, HowManyImages):
    """Render misclassified test images for the "MisClassified Images" tab.

    Args:
        DoYouWantToShowMisClassifiedImages: textbox answer; only "yes" renders.
        HowManyImages: slider value (0-20, step 5); may arrive as a float.

    Returns:
        The figure file produced by the helper, opened as a PIL image, or
        None when the answer is not "yes".
    """
    if not _answered_yes(DoYouWantToShowMisClassifiedImages):
        return None
    # Slider values are numbers (possibly floats); pass a whole count.
    fileName = misclas_helper.display_cifar_misclassified_data(
        misclassified_data,
        classes,
        inv_normalize,
        number_of_samples=int(HowManyImages),
    )
    return Image.open(fileName)


misClass_demo = gr.Interface(
    fn=hello,
    inputs=[
        gr.Textbox(label="Do you want to show misClassified images ?",
                   placeholder="Yes / yes / YES", lines=1),
        gr.Slider(0, 20, step=5, label="How many images ?"),
    ],
    outputs=['image'],
    title="Misclassified Images",
    description="If your answer to the question is yes, then only it works !",
)

############
# 'No Image' extends the label set to index 10.
# NOTE(review): presumably the label classify_images reports for an empty
# input slot — confirm in misclas_helper.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck', 'No Image')

# "Which layer ?" slider value -> sub-module of model.model whose last child
# is used as the GradCAM target layer.
_GRADCAM_LAYERS = {
    -1: "resNetLayer2Part2",
    -2: "resNetLayer2Part1",
    -3: "Layer3",
}


def inference(DoYouWantToShowGradCAMMedImages, HowManyImages, WhichLayer, transparency):
    """Render GradCAM overlays on misclassified images for the "GradCAMMed Images" tab.

    Args:
        DoYouWantToShowGradCAMMedImages: textbox answer; only "yes" renders.
        HowManyImages: slider value (0-20, step 5); may arrive as a float.
        WhichLayer: slider value -1, -2 or -3 selecting the target layer
            (see _GRADCAM_LAYERS).
        transparency: overlay opacity slider value in [0, 1], forwarded to
            the helper (previously ignored in favour of a fixed 0.70).

    Returns:
        The figure file produced by the helper, opened as a PIL image, or
        None when the answer is not "yes".

    Raises:
        ValueError: if WhichLayer is not one of -1, -2, -3 (previously this
            surfaced as an UnboundLocalError).
    """
    if not _answered_yes(DoYouWantToShowGradCAMMedImages):
        return None
    layer_name = _GRADCAM_LAYERS.get(int(WhichLayer))
    if layer_name is None:
        raise ValueError(
            f"WhichLayer must be one of {sorted(_GRADCAM_LAYERS)}, got {WhichLayer!r}"
        )
    target_layers = [getattr(model.model, layer_name)[-1]]
    fileName = gradcam_helper.display_gradcam_output(
        misclassified_data,
        classes,
        inv_normalize,
        model.model,
        target_layers,
        targets,
        number_of_samples=int(HowManyImages),
        transparency=transparency,
    )
    return Image.open(fileName)


gradCAM_demo = gr.Interface(
    fn=inference,
    # DoYouWantToShowGradCAMMedImages, HowManyImages, WhichLayer, transparency
    inputs=[
        gr.Textbox(label="Do you want to show gradCammed images ?",
                   placeholder="Yes / yes / YES", lines=1),
        gr.Slider(0, 20, step=5, label="How many images ?"),
        gr.Slider(-3, -1, value=-1, step=1, label="Which layer ?"),
        gr.Slider(0, 1, value=0.7, label="Overall Opacity of the Overlay"),
    ],
    outputs=['image'],
    title="GradCammed Images",
    description="If your answer to the question is yes, then only it works !",
)

############


def ImageInputter(img0, img1, img2, img3, img4, img5, img6, img7, img8, img9):
    """Classify the ten image slots of the "10 images inputter" tab.

    Returns:
        A flat tuple (label0, image0, label1, image1, ...), one pair per
        (image, prediction) yielded by classify_images. This order matches
        the interleaved Textbox/Image outputs of imageInputter_demo.
    """
    list_images = [img0, img1, img2, img3, img4, img5, img6, img7, img8, img9]
    classified_data = classify_images(list_images, model.model, device)
    return tuple(
        value
        for img, pred in classified_data
        for value in (classes[pred], img)
    )


imageInputter_demo = gr.Interface(
    fn=ImageInputter,
    inputs=["image"] * 10,
    # Keyword-only labels: the first positional parameter of gr.Textbox /
    # gr.Image is the initial *value*, not a component type.
    outputs=[
        component
        for i in range(10)
        for component in (gr.Textbox(label=f"pred {i}"), gr.Image(label=f"img {i}"))
    ],
    # One entry per input component (10 per row); None leaves a slot empty.
    examples=[
        ["bird.jpg", "car.jpg", "cat.jpg", "deer.jpg", "dog.jpg",
         "frog.jpg", "horse.jpg", "plane.jpg", "ship.jpg", None],
        [None, None, None, None, "truck.jpg", None, None, None, None, None],
    ],
    title="Max 10 images input Classifier",
    description="Here's a sample image inputter. Allows you to feed in 10 images and display them with classification result. You may drag and drop images from bottom examples to the input feeders. You may copy and paste the images from examples to the input feeders. You may double click on a row of images from examples to get them filled in input feeders.",
)

############
demo = gr.TabbedInterface(
    interface_list=[misClass_demo, gradCAM_demo, imageInputter_demo],
    tab_names=["MisClassified Images", "GradCAMMed Images", "10 images inputter"],
)
demo.launch(debug=True)