# Hugging Face Spaces status: Runtime error
# gradioMisClassGradCAMimageInputter | |
import os | |
import torch | |
import torchvision | |
from torchvision import datasets, transforms, utils | |
from pl_bolts.datamodules import CIFAR10DataModule | |
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization | |
from pytorch_lightning import LightningModule, Trainer, seed_everything | |
from pytorch_lightning.callbacks import LearningRateMonitor | |
from pytorch_lightning.callbacks.progress import TQDMProgressBar | |
from pytorch_lightning.loggers import CSVLogger | |
from torch.optim.lr_scheduler import OneCycleLR | |
from torch.optim.swa_utils import AveragedModel, update_bn | |
from torchmetrics.functional import accuracy | |
import pandas as pd | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import misclas_helper | |
import gradcam_helper | |
import lightningmodel | |
import trainsave_loadtest | |
from misclas_helper import display_cifar_misclassified_data | |
from gradcam_helper import display_gradcam_output | |
from misclas_helper import get_misclassified_data2 | |
from misclas_helper import classify_images | |
from lightningmodel import LitResnet | |
from trainsave_loadtest import ts_lt | |
import numpy as np | |
import gradio as gr | |
from PIL import Image | |
from pytorch_grad_cam import GradCAM | |
from pytorch_grad_cam.utils.image import show_cam_on_image | |
from torchvision import datasets, transforms, utils | |
# Decision flag for ts_lt ("Train and Save Vs Load and Test"); the name
# indicates 1/True = train and save, 0/False = load the checkpoint and test.
save1_or_load0 = False
# Usage of ts_lt:
#   ts_lt(save1_or_load0,                # decision maker for training Vs testing
#         Epochs=1,                      # argument for training
#         wt_fname="/content/weights.ckpt")  # argument for testing
model, trainer = ts_lt(save1_or_load0, Epochs=26, wt_fname="weights_92.ckpt")  # Train and Save Vs Load and Test
# Forwarded unchanged to gradcam_helper.display_gradcam_output.
# NOTE(review): presumably None means "use the top predicted class" as in
# pytorch_grad_cam — confirm in gradcam_helper.
targets = None
# Everything in this app (misclassification search, Grad-CAM, classification)
# runs on the CPU.
device = torch.device("cpu")
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
# Get the misclassified data from test dataset.
# NOTE(review): 20 is presumably the number of samples collected; it matches
# the max of the "How many images ?" sliders — confirm in misclas_helper.
misclassified_data = misclas_helper.get_misclassified_data2(model, device, 20)
################################################################################################ | |
# Module-level name only; the callbacks assign their own local fileName.
fileName = None
# Inverse of a Normalize(mean=m, std=s) step, for displaying images:
# Normalize computes (x - mean) / std, so mean = -m/s and std = 1/s maps a
# normalised value y back to y * s + m. Here m = 0.50 and s = 0.23 on all
# three channels.
# NOTE(review): confirm 0.50 / 0.23 match the normalisation applied when the
# data was loaded (not visible in this file).
inv_normalize = transforms.Normalize(
    mean=[-0.50 / 0.23] * 3,
    std=[1 / 0.23] * 3,
)
def hello(DoYouWantToShowMisClassifiedImages, HowManyImages):
    """Render a grid of misclassified test images.

    Gradio callback for the "MisClassified Images" tab.

    Args:
        DoYouWantToShowMisClassifiedImages: Free-text answer from the textbox.
            Only "yes" renders anything; case and surrounding whitespace are
            ignored, and an empty/None answer counts as "no".
        HowManyImages: Number of samples requested by the slider.

    Returns:
        A PIL image opened from the file path returned by
        ``misclas_helper.display_cifar_misclassified_data``, or ``None`` when
        the answer is not "yes" or no images were requested.
    """
    answer = (DoYouWantToShowMisClassifiedImages or "").strip().lower()
    if answer != "yes":
        return None
    # NOTE(review): slider values may arrive as floats (e.g. 5.0); the helper
    # presumably expects an integer count.
    number_of_samples = int(HowManyImages or 0)
    if number_of_samples <= 0:
        # Nothing requested: skip building an empty figure.
        return None
    fileName = misclas_helper.display_cifar_misclassified_data(
        misclassified_data, classes, inv_normalize,
        number_of_samples=number_of_samples,
    )
    return Image.open(fileName)
# "MisClassified Images" tab: the textbox gates rendering and the slider
# (0-20 in steps of 5) is forwarded to hello as HowManyImages.
misClass_demo = gr.Interface(
    hello,
    [
        gr.Textbox(
            label="Do you want to show misClassified images ?",
            placeholder="Yes / yes / YES",
            lines=1,
        ),
        gr.Slider(0, 20, step=5, label="How many images ?"),
    ],
    ['image'],
    title="Misclassified Images",
    description="If your answer to the question is yes, then only it works !",
)
############ | |
# Rebind the global label tuple with an 11th entry. hello, inference and
# ImageInputter look `classes` up at call time, so they all see this version.
# (targets and device are already set above; re-assigning the same values
# here was redundant.)
# NOTE(review): 'No Image' (index 10) presumably labels empty input slots in
# misclas_helper.classify_images — confirm there.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck', 'No Image')
def inference(DoYouWantToShowGradCAMMedImages, HowManyImages, WhichLayer, transparency):
    """Render Grad-CAM overlays for the precomputed misclassified images.

    Gradio callback for the "GradCAMMed Images" tab.

    Args:
        DoYouWantToShowGradCAMMedImages: Free-text answer. Only "yes" renders
            anything; case and surrounding whitespace are ignored, and an
            empty/None answer counts as "no".
        HowManyImages: Number of samples requested by the slider.
        WhichLayer: -1, -2 or -3. Selects which block of ``model.model``
            supplies the Grad-CAM target layer (that block's last element).
        transparency: Overlay opacity from the slider (0-1). It is forwarded
            to the helper; ``None`` falls back to the previous fixed 0.70.

    Returns:
        A PIL image opened from the file path returned by
        ``gradcam_helper.display_gradcam_output``, or ``None`` when the answer
        is not "yes" or no images were requested.

    Raises:
        ValueError: If ``WhichLayer`` is not one of -1, -2, -3.
    """
    answer = (DoYouWantToShowGradCAMMedImages or "").strip().lower()
    if answer != "yes":
        return None
    # NOTE(review): slider values may arrive as floats; the helper presumably
    # expects an integer count.
    number_of_samples = int(HowManyImages or 0)
    if number_of_samples <= 0:
        return None
    # Slider value -> attribute of model.model whose last element is the
    # Grad-CAM target. Float values such as -1.0 compare and hash equal to the
    # int keys, so they look up correctly. Anything else used to leave
    # target_layers unbound.
    layer_attrs = {
        -1: "resNetLayer2Part2",
        -2: "resNetLayer2Part1",
        -3: "Layer3",
    }
    layer_attr = layer_attrs.get(WhichLayer)
    if layer_attr is None:
        raise ValueError(f"WhichLayer must be -1, -2 or -3, got {WhichLayer!r}")
    target_layers = [getattr(model.model, layer_attr)[-1]]
    # Forward the user's opacity so the "Overall Opacity of the Overlay"
    # slider actually takes effect (it was previously hard-coded to 0.70).
    fileName = gradcam_helper.display_gradcam_output(
        misclassified_data, classes, inv_normalize, model.model,
        target_layers, targets,
        number_of_samples=number_of_samples,
        transparency=0.70 if transparency is None else transparency,
    )
    return Image.open(fileName)
# "GradCAMMed Images" tab. The four inputs map positionally onto
# inference(DoYouWantToShowGradCAMMedImages, HowManyImages, WhichLayer, transparency).
gradCAM_demo = gr.Interface(
    inference,
    [
        gr.Textbox(
            label="Do you want to show gradCammed images ?",
            placeholder="Yes / yes / YES",
            lines=1,
        ),
        gr.Slider(0, 20, step=5, label="How many images ?"),
        # -1 / -2 / -3 choose the Grad-CAM target layer inside inference.
        gr.Slider(-3, -1, value=-1, step=1, label="Which layer ?"),
        gr.Slider(0, 1, value=0.7, label="Overall Opacity of the Overlay"),
    ],
    ['image'],
    title="GradCammed Images",
    description="If your answer to the question is yes, then only it works !",
)
############ | |
def ImageInputter(img0, img1, img2, img3, img4, img5, img6, img7, img8, img9):
    """Classify the ten image slots of the "10 images inputter" tab.

    Args:
        img0 ... img9: Values of the ten Gradio ``"image"`` inputs; a slot the
            user left empty arrives as ``None``.

    Returns:
        A flat 20-tuple ``(label0, image0, label1, image1, ..., label9, image9)``,
        matching the interleaved Textbox/Image output components. Each label is
        ``classes[pred]`` for the ``(image, pred)`` pair that
        ``classify_images`` produced for that slot.

    Raises:
        IndexError: If ``classify_images`` yields fewer than ten pairs.
    """
    list_images = [img0, img1, img2, img3, img4, img5, img6, img7, img8, img9]
    # NOTE(review): assumes classify_images yields one (image, class_index)
    # pair per input, in input order — confirm in misclas_helper.
    classified_data = [
        (img, pred)
        for img, pred in classify_images(list_images, model.model, device)
    ]
    outputs = []
    for slot in range(len(list_images)):
        img, pred = classified_data[slot]
        outputs += [classes[pred], img]
    return tuple(outputs)
# "10 images inputter" tab: ten image inputs -> ten (prediction, image) pairs.
imageInputter_demo = gr.Interface(
    ImageInputter,
    # One upload slot per ImageInputter argument (img0 ... img9).
    ["image"] * 10,
    # Interleaved outputs: pred 0, img 0, pred 1, img 1, ... matching the flat
    # tuple ImageInputter returns. Labels are passed by keyword because the
    # first positional parameter of gr.Textbox / gr.Image is the initial
    # *value*. Passing "text"/"image" there pre-fills the box with the literal
    # string "text" and points the image at a non-existent file named "image".
    [
        component
        for i in range(10)
        for component in (
            gr.Textbox(label=f"pred {i}"),
            gr.Image(label=f"img {i}"),
        )
    ],
    # Each example row must provide exactly one entry per input component (10);
    # None leaves that slot empty.
    examples=[
        ["bird.jpg", "car.jpg", "cat.jpg", "deer.jpg", "dog.jpg",
         "frog.jpg", "horse.jpg", "plane.jpg", "ship.jpg", None],
        [None, None, None, None, "truck.jpg", None, None, None, None, None],
    ],
    title="Max 10 images input Classifier",
    description="Here's a sample image inputter. Allows you to feed in 10 images and display them with classification result. You may drag and drop images from bottom examples to the input feeders. You may copy and paste the images from examples to the input feeders. You may double click on a row of images from examples to get them filled in input feeders.",
)
############ | |
# Assemble the three demos into one tabbed app (tab titles are listed in the
# same order as the interfaces) and start it in debug mode.
demo = gr.TabbedInterface(
    [misClass_demo, gradCAM_demo, imageInputter_demo],
    ["MisClassified Images", "GradCAMMed Images", "10 images inputter"],
)
demo.launch(debug=True)