# Source: eraV2S14_raj / app.py (author: raja5259)
# Last commit: 26ce035 "update app.py" (verified)
# gradioMisClassGradCAMimageInputter
import os
import torch
import torchvision
from torchvision import datasets, transforms, utils
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from torch.optim.lr_scheduler import OneCycleLR
from torch.optim.swa_utils import AveragedModel, update_bn
from torchmetrics.functional import accuracy
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import misclas_helper
import gradcam_helper
import lightningmodel
import trainsave_loadtest
from misclas_helper import display_cifar_misclassified_data
from gradcam_helper import display_gradcam_output
from misclas_helper import get_misclassified_data2
from misclas_helper import classify_images
from lightningmodel import LitResnet
from trainsave_loadtest import ts_lt
import numpy as np
import gradio as gr
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision import datasets, transforms, utils
# Switch passed to ts_lt ("Train and Save Vs Load and Test"); per its name,
# 1/True trains and saves, 0/False loads the checkpoint and tests.
save1_or_load0 = False
model, trainer = ts_lt(
    save1_or_load0,
    Epochs=26,
    wt_fname="weights_92.ckpt",
)
# Usage reference for ts_lt:
#   ts_lt(save1_or_load0,                    # decision maker for training Vs testing
#         Epochs=1,                          # argument for training
#         wt_fname="/content/weights.ckpt"   # argument for testing
#         )
# Grad-CAM `targets` argument, forwarded unchanged to
# gradcam_helper.display_gradcam_output (None presumably means "explain the
# predicted class" -- NOTE(review): confirm in gradcam_helper).
targets = None
# Every model call in this app is made on the CPU.
device = torch.device("cpu")
# CIFAR-10 class names, indexed by label id.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
# Collect misclassified test samples once at startup so every Gradio request
# reuses them. The 20 is passed as the helper's third argument (presumably a
# cap on how many samples to gather -- NOTE(review): confirm in misclas_helper).
misclassified_data = misclas_helper.get_misclassified_data2(model, device, 20)
################################################################################################
# Module-level placeholder; the handlers assign their own local `fileName`.
fileName = None
# Inverse of a per-channel Normalize(mean=0.5, std=0.23). Normalize computes
# (x - mean) / std, so with mean=-0.5/0.23 and std=1/0.23 each channel maps
# to 0.23 * x + 0.5.
# NOTE(review): assumes the training transform used mean 0.5 / std 0.23 -- confirm.
inv_normalize = transforms.Normalize(
    mean=[-0.50 / 0.23] * 3,
    std=[1 / 0.23] * 3,
)
def hello(DoYouWantToShowMisClassifiedImages, HowManyImages):
    """Render a grid of misclassified test images when the user answers "yes".

    Args:
        DoYouWantToShowMisClassifiedImages: Free-text answer from the Gradio
            textbox. Only a case-insensitive "yes" (surrounding whitespace
            ignored) triggers rendering; None or any other text does not.
        HowManyImages: Number of samples requested via the slider.

    Returns:
        A PIL image opened from the file path returned by
        ``misclas_helper.display_cifar_misclassified_data``, or ``None`` when
        the answer is not "yes".
    """
    answer = (DoYouWantToShowMisClassifiedImages or "").strip().lower()
    if answer != "yes":
        return None
    # Gradio sliders may hand back a float; the helper presumably expects an
    # integer sample count, so cast defensively.
    fileName = misclas_helper.display_cifar_misclassified_data(
        misclassified_data,
        classes,
        inv_normalize,
        number_of_samples=int(HowManyImages),
    )
    return Image.open(fileName)
# Gradio tab that shows misclassified test images via `hello`.
misClass_demo = gr.Interface(
    fn=hello,
    title="Misclassified Images",
    description="If your answer to the question is yes, then only it works !",
    inputs=[
        gr.Textbox(
            label="Do you want to show misClassified images ?",
            placeholder="Yes / yes / YES",
            lines=1,
        ),
        gr.Slider(0, 20, step=5, label="How many images ?"),
    ],
    outputs=["image"],
)
############
# Class names used by the Gradio handlers below. Indices 0-9 are the CIFAR-10
# labels; the extra 'No Image' entry at index 10 is presumably the label
# classify_images assigns to an empty input slot
# (NOTE(review): confirm in misclas_helper.classify_images).
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck', 'No Image')
def inference(DoYouWantToShowGradCAMMedImages, HowManyImages, WhichLayer, transparency):
    """Render Grad-CAM overlays for the misclassified test images.

    Args:
        DoYouWantToShowGradCAMMedImages: Free-text answer from the Gradio
            textbox. Only a case-insensitive "yes" (surrounding whitespace
            ignored) triggers rendering.
        HowManyImages: Number of samples requested via the slider.
        WhichLayer: -1, -2 or -3, selecting the last child of
            ``model.model.resNetLayer2Part2``, ``model.model.resNetLayer2Part1``
            or ``model.model.Layer3`` respectively as the Grad-CAM target layer.
        transparency: Overlay opacity from the slider, forwarded to
            ``gradcam_helper.display_gradcam_output``.

    Returns:
        A PIL image opened from the file path returned by the helper, or
        ``None`` when the answer is not "yes".

    Raises:
        ValueError: If ``WhichLayer`` is not one of -1, -2, -3.
    """
    answer = (DoYouWantToShowGradCAMMedImages or "").strip().lower()
    if answer != "yes":
        # Nothing to render; returning early avoids touching an unbound file name.
        return None

    # Slider value -> name of the sub-module of model.model whose last child
    # is used as the Grad-CAM target. Dict lookup also accepts -1.0 etc.,
    # since equal int/float keys hash the same.
    layer_names = {
        -1: "resNetLayer2Part2",
        -2: "resNetLayer2Part1",
        -3: "Layer3",
    }
    layer_name = layer_names.get(WhichLayer)
    if layer_name is None:
        raise ValueError(f"WhichLayer must be -1, -2 or -3, got {WhichLayer!r}")
    target_layers = [getattr(model.model, layer_name)[-1]]

    # Gradio sliders may hand back a float; cast the sample count defensively.
    fileName = gradcam_helper.display_gradcam_output(
        misclassified_data,
        classes,
        inv_normalize,
        model.model,
        target_layers,
        targets,
        number_of_samples=int(HowManyImages),
        transparency=transparency,
    )
    return Image.open(fileName)
# Gradio tab for Grad-CAM overlays. The input components are listed in the
# same order as inference's parameters:
# (DoYouWantToShowGradCAMMedImages, HowManyImages, WhichLayer, transparency).
gradCAM_demo = gr.Interface(
    fn=inference,
    title="GradCammed Images",
    description="If your answer to the question is yes, then only it works !",
    inputs=[
        gr.Textbox(
            label="Do you want to show gradCammed images ?",
            placeholder="Yes / yes / YES",
            lines=1,
        ),
        gr.Slider(0, 20, step=5, label="How many images ?"),
        gr.Slider(-3, -1, value=-1, step=1, label="Which layer ?"),
        gr.Slider(0, 1, value=0.7, label="Overall Opacity of the Overlay"),
    ],
    outputs=["image"],
)
############
def ImageInputter(img0, img1, img2, img3, img4, img5, img6, img7, img8, img9):
    """Classify ten input slots and interleave predicted labels with images.

    All ten arguments are passed, in order, to ``classify_images`` together
    with ``model.model`` and ``device``; each item it yields is unpacked as an
    ``(image, prediction)`` pair.

    Returns:
        A flat 20-tuple ``(label0, image0, label1, image1, ..., label9, image9)``
        where ``labelK`` is ``classes[predictionK]``. Only the first ten pairs
        are used; fewer than ten raises IndexError.
    """
    pairs = [
        (image, prediction)
        for image, prediction in classify_images(
            [img0, img1, img2, img3, img4, img5, img6, img7, img8, img9],
            model.model,
            device,
        )
    ]
    flat = []
    for slot in range(10):
        image, prediction = pairs[slot]
        flat.extend((classes[prediction], image))
    return tuple(flat)
# Gradio tab: ten image inputs, each output as a (prediction text, image) pair.
# Components are created without a positional first argument because that
# parameter is the component's initial `value`. Passing "text"/"image" there
# pre-filled the textboxes with "text" and gave the images a bogus path.
imageInputter_demo = gr.Interface(
    fn=ImageInputter,
    inputs=["image"] * 10,
    # Order must match ImageInputter's return: pred 0, img 0, pred 1, img 1, ...
    outputs=[
        component
        for i in range(10)
        for component in (
            gr.Textbox(label=f"pred {i}"),
            gr.Image(label=f"img {i}"),
        )
    ],
    examples=[
        ["bird.jpg", "car.jpg", "cat.jpg", "deer.jpg", "dog.jpg", "frog.jpg", "horse.jpg", "plane.jpg", "ship.jpg"],
        [None, None, None, None, "truck.jpg", None, None, None, None],
    ],
    title="Max 10 images input Classifier",
    description="Here's a sample image inputter. Allows you to feed in 10 images and display them with classification result. You may drag and drop images from bottom examples to the input feeders. You may copy and paste the images from examples to the input feeders. You may double click on a row of images from examples to get them filled in input feeders.",
)
############
# Combine the three demos into one tabbed app; tab_names[i] labels interface_list[i].
demo = gr.TabbedInterface(
    tab_names=["MisClassified Images", "GradCAMMed Images", "10 images inputter"],
    interface_list=[misClass_demo, gradCAM_demo, imageInputter_demo],
)
# Start the Gradio server in debug mode.
demo.launch(debug=True)