raja5259 commited on
Commit
926a474
·
verified ·
1 Parent(s): 3f08bb8

update misclas_helper

Browse files
Files changed (1) hide show
  1. misclas_helper.py +82 -3
misclas_helper.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import math
3
  import numpy as np
4
  import pandas as pd
 
5
  import torch
6
  import torch.nn as nn
7
  import torch.nn.functional as F
@@ -24,10 +25,9 @@ from PIL import Image
24
  from pytorch_grad_cam import GradCAM
25
  from pytorch_grad_cam.utils.image import show_cam_on_image
26
 
27
- # Denormalize the data using test mean and std deviation
28
  inv_normalize = transforms.Normalize(
29
- mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
30
- std=[1/0.23, 1/0.23, 1/0.23]
31
  )
32
 
33
 
@@ -138,3 +138,82 @@ def display_cifar_misclassified_data(data: list,
138
 
139
  # Plot the misclassified data
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import math
3
  import numpy as np
4
  import pandas as pd
5
+ import seaborn as sn
6
  import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
 
25
  from pytorch_grad_cam import GradCAM
26
  from pytorch_grad_cam.utils.image import show_cam_on_image
27
 
 
28
  inv_normalize = transforms.Normalize(
29
+ mean=[-0.50/0.23, -0.50/0.23, -0.50/0.23],
30
+ std=[1/0.23, 1/0.23, 1/0.23]
31
  )
32
 
33
 
 
138
 
139
  # Plot the misclassified data
140
 
141
+
142
+ def crop_image_pil2(image): #Crop image with 1:1 output aspect ratio
143
+
144
+ image = Image.fromarray(image)
145
+ print("image type = ", type(Image))
146
+ width, height = image.size
147
+ if width == height:
148
+ return image
149
+ offset = int(abs(height-width)/2)
150
+ if width>height:
151
+ image = image.crop([offset,0,width-offset,height])
152
+ else:
153
+ image = image.crop([0,offset,width,height-offset])
154
+ return image
155
+
156
+ def resize_image_pil2(image, new_width, new_height):
157
+ # Convert to PIL image
158
+ img = crop_image_pil2(image)
159
+ img = Image.fromarray(np.array(img))
160
+ # Get original size
161
+ width, height = img.size
162
+
163
+ # Calculate scale
164
+ width_scale = new_width / width # RAJA see if this can be deleted
165
+ height_scale = new_height / height # RAJA see if this can be deleted
166
+ # Resize
167
+ # resized = img.resize((int(width*width_scale), int(height*height_scale)), Image.NEAREST)
168
+ resized = img.resize((32, 32), Image.NEAREST)
169
+ # Crop to exact size
170
+ return resized
171
+
172
+ def classify_images(list_images, model, device):
173
+ """
174
+ Function to run the model on test set and return misclassified images
175
+ :param model: Network Architecture
176
+ :param device: CPU/GPU
177
+ :param test_loader: DataLoader for test set
178
+ """
179
+
180
+ test_transforms = torchvision.transforms.Compose(
181
+ [
182
+ torchvision.transforms.ToTensor(),
183
+ cifar10_normalization(),
184
+ ]
185
+ )
186
+ # Prepare the model for evaluation i.e. drop the dropout layer
187
+ model.eval()
188
+ # List to store misclassified Images
189
+ classified_data = []
190
+
191
+ # Reset the gradients
192
+ with torch.no_grad():
193
+ # Extract images, labels in a batch
194
+ for image in list_images:
195
+ #print("image type = ", type(image))
196
+ orig_image = image
197
+ if(image is None):
198
+ pred = 10 #This entry indicates none in classes, empty string
199
+ else:
200
+ #print("before resize image shape = ", image.shape)
201
+ image = resize_image_pil2(image, 32, 32)
202
+ image = np.asarray(image)
203
+ #print("numpy image dtype = ", image.dtype)
204
+ #print("before test_transforms image shape = ", image.shape)
205
+ image = test_transforms(image)
206
+ #print("after test_transforms image shape = ", image.shape)
207
+
208
+ image = image.unsqueeze(0)
209
+ #print("after squeeze image shape = ", image.shape)
210
+
211
+ # Get the model prediction on the image
212
+ output = model(image)
213
+
214
+ # Convert the output from one-hot encoding to a value
215
+ pred = output.argmax(dim=1, keepdim=True)
216
+
217
+ classified_data.append((orig_image, pred))
218
+
219
+ return classified_data