File size: 808 Bytes
d1a9313
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch

# Collect the misclassified samples of the test set.
#
# One entry is appended per batch that contains at least one mistake, and
# the three lists stay index-aligned: incorrect_examples[i],
# incorrect_labels[i] and incorrect_pred[i] describe the same batch.
# Relies on `model`, `test_loader` and `device` being defined earlier.
incorrect_examples = []  # misclassified inputs, as numpy arrays
incorrect_labels = []    # ground-truth targets of those inputs
incorrect_pred = []      # classes the model predicted instead

model.eval()

# Evaluation only: disable autograd so no graph is built for the forward
# passes (less memory, faster).
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)

        # presumably [batch_size, num_classes] -- confirm against the model
        output = model(data)
        # Index of the highest score along dim 1; keepdim keeps that axis
        # as size 1 so the target can be reshaped to match.
        pred = output.argmax(dim=1, keepdim=True)

        # Flat boolean mask with one entry per sample, True where the
        # prediction differs from the target.
        wrong_mask = (pred != target.view_as(pred)).view(-1)

        # `numel()` would count every mask entry (True or False) and so be
        # non-zero for any non-empty batch; `any()` checks that at least
        # one sample was actually misclassified.
        if not wrong_mask.any():
            continue

        wrong_data = data[wrong_mask]
        num_wrong = wrong_data.shape[0]

        # Drop size-1 axes as before, but a bare squeeze() also removes
        # the leading axis when exactly one sample is wrong; restore it so
        # every entry has the same layout regardless of the mistake count.
        wrong_data = wrong_data.squeeze()
        if num_wrong == 1:
            wrong_data = wrong_data.unsqueeze(0)

        incorrect_examples.append(wrong_data.cpu().numpy())
        # The corresponding target of each misclassified sample.
        incorrect_labels.append(target[wrong_mask].cpu().numpy())
        # The corresponding predicted class, flattened to 1-D so a single
        # mistake yields a length-1 array instead of a 0-d scalar array.
        incorrect_pred.append(pred[wrong_mask].reshape(-1).cpu().numpy())