update misclas_helper
misclas_helper.py CHANGED (+82 -3)
@@ -2,6 +2,7 @@ import os
 import math
 import numpy as np
 import pandas as pd
+import seaborn as sn
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -24,10 +25,9 @@ from PIL import Image
 from pytorch_grad_cam import GradCAM
 from pytorch_grad_cam.utils.image import show_cam_on_image
 
-# Denormalize the data using test mean and std deviation
 inv_normalize = transforms.Normalize(
-
-
+    mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
+    std=[1/0.23, 1/0.23, 1/0.23]
 )
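Why these values: transforms.Normalize computes y = (x - mean) / std, so chaining a second Normalize with mean = -m/s and std = 1/s maps y back to x. A minimal sketch, assuming the forward normalization used mean 0.50 and std 0.23 per channel (which is what -0.50/0.23 and 1/0.23 imply):

import torch
from torchvision import transforms

normalize = transforms.Normalize(mean=[0.50] * 3, std=[0.23] * 3)  # assumed forward transform
inv_normalize = transforms.Normalize(mean=[-0.50 / 0.23] * 3, std=[1 / 0.23] * 3)

x = torch.rand(3, 32, 32)  # fake CIFAR-sized image in [0, 1]
assert torch.allclose(inv_normalize(normalize(x)), x, atol=1e-5)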
@@ -138,3 +138,82 @@ def display_cifar_misclassified_data(data: list,
 
     # Plot the misclassified data
 
+
+def crop_image_pil2(image):  # Crop image with 1:1 output aspect ratio
+
+    image = Image.fromarray(image)
+    print("image type = ", type(image))
+    width, height = image.size
+    if width == height:
+        return image
+    offset = int(abs(height - width) / 2)
+    if width > height:
+        image = image.crop([offset, 0, width - offset, height])
+    else:
+        image = image.crop([0, offset, width, height - offset])
+    return image
+
+def resize_image_pil2(image, new_width, new_height):
+    # Convert to PIL image
+    img = crop_image_pil2(image)
+    img = Image.fromarray(np.array(img))
+    # Get original size
+    width, height = img.size
+
+    # Calculate scale
+    width_scale = new_width / width  # RAJA see if this can be deleted
+    height_scale = new_height / height  # RAJA see if this can be deleted
+    # Resize
+    # resized = img.resize((int(width*width_scale), int(height*height_scale)), Image.NEAREST)
+    resized = img.resize((32, 32), Image.NEAREST)
+    # Crop to exact size
+    return resized
+
+def classify_images(list_images, model, device):
+    """
+    Run the model on a list of images and return (image, prediction) pairs
+    :param list_images: List of images as numpy arrays (None marks an empty slot)
+    :param model: Network Architecture
+    :param device: CPU/GPU
+    """
+
+    test_transforms = torchvision.transforms.Compose(
+        [
+            torchvision.transforms.ToTensor(),
+            cifar10_normalization(),
+        ]
+    )
+    # Prepare the model for evaluation i.e. drop the dropout layer
+    model.eval()
+    # List to store the classified images
+    classified_data = []
+
+    # Disable gradient tracking for inference
+    with torch.no_grad():
+        for image in list_images:
+            # print("image type = ", type(image))
+            orig_image = image
+            if image is None:
+                pred = 10  # This entry indicates none of the classes (empty string)
+            else:
+                # print("before resize image shape = ", image.shape)
+                image = resize_image_pil2(image, 32, 32)
+                image = np.asarray(image)
+                # print("numpy image dtype = ", image.dtype)
+                # print("before test_transforms image shape = ", image.shape)
+                image = test_transforms(image)
+                # print("after test_transforms image shape = ", image.shape)
+
+                # Add a batch dimension and move to the model's device
+                image = image.unsqueeze(0).to(device)
+                # print("after unsqueeze image shape = ", image.shape)
+
+                # Get the model prediction on the image
+                output = model(image)
+
+                # Take the index of the highest-scoring class
+                pred = output.argmax(dim=1, keepdim=True)
+
+            classified_data.append((orig_image, pred))
+
+    return classified_data
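crop_image_pil2 center-crops a NumPy array to a square, and resize_image_pil2 then scales it to the fixed 32x32 CIFAR input size (the new_width/new_height arguments are currently unused in favor of the hard-coded (32, 32)). A minimal usage sketch with a synthetic frame; the 48x64 shape is arbitrary:

import numpy as np

frame = (np.random.rand(48, 64, 3) * 255).astype(np.uint8)  # H x W x C uint8 image

square = crop_image_pil2(frame)           # 64x48 -> center-cropped to 48x48
small = resize_image_pil2(frame, 32, 32)  # cropped, then resized to 32x32
print(square.size, small.size)            # (48, 48) (32, 32)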
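classify_images takes raw images (or None for empty slots) and returns (original_image, predicted_class_index) pairs, with index 10 marking a missing image. A hedged sketch of a caller, reusing frame from the sketch above; MyCifar10Model and weights.pt are hypothetical stand-ins for whatever the Space actually loads:

import torch

# CIFAR-10 class names, plus "" at index 10 for the "no image" marker
classes = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck", ""]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = MyCifar10Model().to(device)                                   # hypothetical model class
net.load_state_dict(torch.load("weights.pt", map_location=device))  # hypothetical weights

results = classify_images([frame, None], net, device)
for img, pred in results:
    idx = pred if isinstance(pred, int) else int(pred)
    print(classes[idx] if classes[idx] else "(no image)")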