Spaces:
Runtime error
Runtime error
# gradioMisClassGradCAMimageInputter | |
import os | |
import math | |
import numpy as np | |
import pandas as pd | |
import seaborn as sn | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision | |
import matplotlib.pyplot as plt | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from IPython.core.display import display | |
from pl_bolts.datamodules import CIFAR10DataModule | |
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization | |
from pytorch_lightning import LightningModule, Trainer, seed_everything | |
from pytorch_lightning.callbacks import LearningRateMonitor | |
from pytorch_lightning.callbacks.progress import TQDMProgressBar | |
from pytorch_lightning.loggers import CSVLogger | |
from torch.optim.lr_scheduler import OneCycleLR | |
from torch.optim.swa_utils import AveragedModel, update_bn | |
from torchmetrics.functional import accuracy | |
from pytorch_lightning.callbacks import ModelCheckpoint | |
from torchvision import datasets, transforms, utils | |
from PIL import Image | |
from pytorch_grad_cam import GradCAM | |
from pytorch_grad_cam.utils.image import show_cam_on_image | |
fileName = None  # NOTE(review): unused — hello() works with its own local file name.


def hello(DoYouWantToShowMisClassifiedImages, HowManyImages):
    """Render misclassified CIFAR-10 samples for the "MisClassified Images" tab.

    Args:
        DoYouWantToShowMisClassifiedImages: Free-text answer from the textbox.
            The plot is produced only when it reads "yes" (case-insensitive,
            surrounding whitespace ignored); an empty/None answer means "no".
        HowManyImages: Number of samples requested via the slider. Gradio
            sliders may deliver this as a float, so it is converted to int
            before being forwarded as ``number_of_samples``.

    Returns:
        A PIL image opened from the file path returned by
        ``misclas_helper.display_cifar_misclassified_data``, or ``None`` when
        the answer is not "yes".
    """
    answer = (DoYouWantToShowMisClassifiedImages or "").strip().lower()
    if answer != "yes":
        return None
    # misclas_helper, misclassified_data and inv_normalize are not created in
    # this section of the file; they must be provided elsewhere before a call.
    file_name = misclas_helper.display_cifar_misclassified_data(
        misclassified_data,
        classes,
        inv_normalize,
        number_of_samples=int(HowManyImages),
    )
    return Image.open(file_name)
misClass_demo = gr.Interface(
    fn=hello,
    # Inputs, in order: DoYouWantToShowMisClassifiedImages (free text), HowManyImages.
    inputs=["text", gr.Slider(0, 20, step=5)],
    outputs=["image"],
    title="Misclassified Images",
    description=(
        "This only works if your answer to the question "
        "DoYouWantToShowMisClassifiedImages is yes."
    ),
)
############
# Grad-CAM settings shared by the inference() handler below.
targets = None  # forwarded as the `targets` argument of gradcam_helper.display_gradcam_output
device = torch.device("cpu")  # not referenced elsewhere in this section
# CIFAR-10 label names; order presumably matches the model's class indices — confirm.
classes = tuple("plane car bird cat deer dog frog horse ship truck".split())
def inference(DoYouWantToShowGradCAMMedImages, HowManyImages, WhichLayer, transparency): | |
if(DoYouWantToShowGradCAMMedImages.lower() == "yes"): | |
if(WhichLayer == -1): | |
target_layers = [model.model.resNetLayer2Part2[-1]] | |
elif(WhichLayer == -2): | |
target_layers = [model.model.resNetLayer2Part1[-1]] | |
elif(WhichLayer == -3): | |
target_layers = [model.model.Layer3[-1]] | |
fileName = gradcam_helper.display_gradcam_output(misclassified_data, classes, inv_normalize, model.model, target_layers, targets, number_of_samples=HowManyImages, transparency=0.70) | |
return Image.open(fileName) | |
gradCAM_demo = gr.Interface(
    fn=inference,
    # Inputs, in order: DoYouWantToShowGradCAMMedImages (free text), HowManyImages,
    # WhichLayer (-1, -2 or -3 picks the Grad-CAM target layer), transparency.
    inputs=[
        "text",
        gr.Slider(0, 20, step=5),
        gr.Slider(-3, -1, value=-1, step=1),
        gr.Slider(0, 1, value=0.7, label="Overall Opacity of the Overlay"),
    ],
    outputs=["image"],
    # Title matches the "GradCAMMed Images" tab name used for this demo.
    title="GradCAMMed Images",
    description=(
        "This only works if your answer to the question "
        "DoYouWantToShowGradCAMMedImages is yes."
    ),
)
############
def ImageInputter(img1, img2, img3, img4, img5, img6, img7, img8, img9, img10):
    """Echo the ten input images back unchanged, one per output slot.

    Returns:
        A 10-tuple ``(img1, ..., img10)`` in the same order as the arguments.
    """
    echoed = (img1, img2, img3, img4, img5, img6, img7, img8, img9, img10)
    return echoed
# Ten image inputs mirrored one-to-one onto ten image outputs by ImageInputter.
imageInputter_demo = gr.Interface(
    ImageInputter,
    ["image"] * 10,
    ["image"] * 10,
    # Each example row supplies values for the first three inputs only; the
    # file names are presumably resolved relative to the app's working dir.
    examples=[
        ["bird.jpg", "car.jpg", "cat.jpg"],
        ["deer.jpg", "dog.jpg", "frog.jpg"],
        ["horse.jpg", "plane.jpg", "ship.jpg"],
        [None, "truck.jpg", None],
    ],
    title="Max 10 images input",
    description="Here's a sample image inputter. Allows you to feed in 10 images and display them. You may drag and drop images from bottom examples to the input feeders",
)
############
# Combine the three demos into one tabbed app; each tab name is listed in the
# same position as the interface it labels.
demo = gr.TabbedInterface(
    [misClass_demo, gradCAM_demo, imageInputter_demo],
    ["MisClassified Images", "GradCAMMed Images", "10 images inputter"],
)
# Launch at import time, as before; see Gradio's launch() docs for debug=True.
demo.launch(debug=True)