Create GetMismatch.py
Browse files- GetMismatch.py +14 -0
GetMismatch.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
incorrect_examples = []
|
2 |
+
incorrect_labels = []
|
3 |
+
incorrect_pred = []
|
4 |
+
model.eval()
|
5 |
+
for data,target in test_loader:
|
6 |
+
|
7 |
+
data , target = data.to(device), target.to(device)
|
8 |
+
output = model(data) # shape = torch.Size([batch_size, 10])
|
9 |
+
pred = output.argmax(dim=1, keepdim=True) #pred will be a 2d tensor of shape [batch_size,1]
|
10 |
+
idxs_mask = ((pred == target.view_as(pred))==False).view(-1)
|
11 |
+
if idxs_mask.numel(): #if index masks is non-empty append the correspoding data value in incorrect examples
|
12 |
+
incorrect_examples.append(data[idxs_mask].squeeze().cpu().numpy())
|
13 |
+
incorrect_labels.append(target[idxs_mask].cpu().numpy()) #the corresponding target to the misclassified image
|
14 |
+
incorrect_pred.append(pred[idxs_mask].squeeze().cpu().numpy()) #the corresponiding predicted class of the misclassified image
|