Spaces:
Runtime error
Runtime error
import os | |
import math | |
import numpy as np | |
import pandas as pd | |
import seaborn as sn | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torchvision | |
import matplotlib.pyplot as plt | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from pl_bolts.datamodules import CIFAR10DataModule | |
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization | |
from pytorch_lightning import LightningModule, Trainer, seed_everything | |
from pytorch_lightning.callbacks import LearningRateMonitor | |
from pytorch_lightning.callbacks.progress import TQDMProgressBar | |
from pytorch_lightning.loggers import CSVLogger | |
from torch.optim.lr_scheduler import OneCycleLR | |
from torch.optim.swa_utils import AveragedModel, update_bn | |
from torchmetrics.functional import accuracy | |
from pytorch_lightning.callbacks import ModelCheckpoint | |
from torchvision import datasets, transforms, utils | |
from PIL import Image | |
from pytorch_grad_cam import GradCAM | |
from pytorch_grad_cam.utils.image import show_cam_on_image | |
# Inverse of the CIFAR-10 normalization applied to the test images in
# get_misclassified_data2 (``cifar10_normalization()``), used to map
# normalized tensors back to displayable [0, 1] RGB.
# Normalize computes (x - mean) / std, so the inverse is another Normalize
# with mean' = -mean / std and std' = 1 / std. The per-channel constants
# mirror pl_bolts' cifar10_normalization(); the previous 0.50 / 0.23
# approximation did not match them and distorted displayed colours.
_CIFAR10_MEAN = [x / 255.0 for x in [125.3, 123.0, 113.9]]
_CIFAR10_STD = [x / 255.0 for x in [63.0, 62.1, 66.7]]
inv_normalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(_CIFAR10_MEAN, _CIFAR10_STD)],
    std=[1 / s for s in _CIFAR10_STD],
)
def _build_cifar10_test_loader():
    """Return the CIFAR-10 test DataLoader built with pl_bolts' CIFAR10DataModule.

    The dataset directory comes from the ``PATH_DATASETS`` environment
    variable (default: current directory). Test images are converted to
    tensors and normalized with ``cifar10_normalization()``.
    """
    path_datasets = os.environ.get("PATH_DATASETS", ".")
    batch_size = 256 if torch.cuda.is_available() else 64
    # os.cpu_count() returns None when the CPU count cannot be determined.
    num_workers = (os.cpu_count() or 0) // 2
    train_transforms = torchvision.transforms.Compose(
        [
            torchvision.transforms.RandomCrop(32, padding=4),
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.ToTensor(),
            cifar10_normalization(),
        ]
    )
    test_transforms = torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            cifar10_normalization(),
        ]
    )
    cifar10_dm = CIFAR10DataModule(
        data_dir=path_datasets,
        batch_size=batch_size,
        num_workers=num_workers,
        train_transforms=train_transforms,
        test_transforms=test_transforms,
        val_transforms=test_transforms,
    )
    cifar10_dm.prepare_data()
    cifar10_dm.setup()
    return cifar10_dm.test_dataloader()


def get_misclassified_data2(model, device, count, test_loader=None):
    """Run ``model`` over a test set and collect up to ``count`` misclassified samples.

    :param model: Network architecture (a ``torch.nn.Module``) whose output
        holds one score per class along dim 1. It is switched to eval mode
        but not moved to ``device``.
    :param device: CPU/GPU device that each batch is moved to.
    :param count: Maximum number of misclassified samples to return;
        ``count <= 0`` returns an empty list.
    :param test_loader: Optional iterable of ``(data, target)`` batches. When
        omitted, the CIFAR-10 test split is loaded via ``CIFAR10DataModule``.
    :return: List of ``(image, label, pred)`` tuples in test-set order, where
        ``image`` keeps a leading batch dimension of 1, ``label`` is the
        target for that image and ``pred`` is the predicted class index with
        shape ``(1, 1)``.
    """
    # Eval mode: disables dropout and makes batch-norm use running statistics.
    model.eval()
    misclassified_data = []
    if count <= 0:
        return misclassified_data
    # Only build (and possibly download) CIFAR-10 when no loader was supplied.
    if test_loader is None:
        test_loader = _build_cifar10_test_loader()
    # Disable gradient tracking: inference only.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            # One forward pass per batch; class index per sample, shape (B, 1).
            preds = model(data).argmax(dim=1, keepdim=True)
            wrong = (preds.squeeze(1) != target).nonzero(as_tuple=True)[0]
            for idx in wrong.tolist():
                misclassified_data.append(
                    (data[idx].unsqueeze(0), target[idx], preds[idx].unsqueeze(0))
                )
                # Stop scanning the whole test set once enough were found.
                if len(misclassified_data) >= count:
                    return misclassified_data
    return misclassified_data
# NOTE: important prerequisite (predecessor #2) for gradioMisClass.
def display_cifar_misclassified_data(data: list,
                                     classes: list[str],
                                     inv_normalize: transforms.Normalize,
                                     number_of_samples: int = 10,
                                     output_path: str = 'imshow_output_misclas.png'):
    """Plot a grid of misclassified images titled with their classes and save it to disk.

    :param data: List of ``(image, label[, pred])`` tuples, e.g. as returned by
        ``get_misclassified_data2``. ``image`` is a normalized CHW tensor,
        optionally with a leading batch dimension of 1.
    :param classes: Name of classes in the dataset, indexed by class id.
    :param inv_normalize: Transform that undoes the dataset normalization.
    :param number_of_samples: Number of images to plot (capped at ``len(data)``).
    :param output_path: File the figure is saved to.
    :return: ``output_path``.
    """
    number_of_samples = min(number_of_samples, len(data))
    x_count = 5
    # Round up so a partial last row still fits (e.g. 7 samples -> 2 rows);
    # flooring made subplot indices overflow the grid.
    y_count = max(1, math.ceil(number_of_samples / x_count))
    fig = plt.figure(figsize=(10, 10))
    for i in range(number_of_samples):
        sample = data[i]
        ax = fig.add_subplot(y_count, x_count, i + 1)
        img = inv_normalize(sample[0].squeeze().detach().to('cpu'))
        # CHW -> HWC for imshow; clamp rounding overshoot into the valid range.
        ax.imshow(img.permute(1, 2, 0).clamp(0, 1).numpy())
        title = f"Correct: {classes[int(sample[1])]}"
        if len(sample) > 2:
            title += f"\nOutput: {classes[int(sample[2])]}"
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])
    fig.savefig(output_path)
    # Release the figure: repeated calls (e.g. from a web UI) otherwise leak memory.
    plt.close(fig)
    return output_path
# Plot the misclassified data | |