csisc commited on
Commit
d1a9313
·
1 Parent(s): d139de5

Create GetMismatch.py

Browse files
Files changed (1) hide show
  1. GetMismatch.py +14 -0
GetMismatch.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ incorrect_examples = []
2
+ incorrect_labels = []
3
+ incorrect_pred = []
4
+ model.eval()
5
+ for data,target in test_loader:
6
+
7
+ data , target = data.to(device), target.to(device)
8
+ output = model(data) # shape = torch.Size([batch_size, 10])
9
+ pred = output.argmax(dim=1, keepdim=True) #pred will be a 2d tensor of shape [batch_size,1]
10
+ idxs_mask = ((pred == target.view_as(pred))==False).view(-1)
11
+ if idxs_mask.numel(): #if index masks is non-empty append the correspoding data value in incorrect examples
12
+ incorrect_examples.append(data[idxs_mask].squeeze().cpu().numpy())
13
+ incorrect_labels.append(target[idxs_mask].cpu().numpy()) #the corresponding target to the misclassified image
14
+ incorrect_pred.append(pred[idxs_mask].squeeze().cpu().numpy()) #the corresponiding predicted class of the misclassified image