SahithiR committed
Commit 1682e68 · 1 Parent(s): 42b2b5d

Upload 3 files

Files changed (3)
  1. CUSTOMRESNET.py +88 -12
  2. incorrect_images.py +65 -0
  3. visualize.py +70 -0
CUSTOMRESNET.py CHANGED
@@ -1,6 +1,12 @@
-import torch.nn.functional as F
 import torch
 import torch.nn as nn
+from torch.nn import functional as F
+from torch import optim
+from pytorch_lightning import LightningModule
+from torchmetrics import Accuracy
+from utils.visualize import find_lr
+
+
 
 class MyModel(nn.Module):
     def __init__(self):
@@ -72,19 +78,89 @@ class MyModel(nn.Module):
         x=x+R2
 
         x = self.maxpool(x)
-        #x = x.randn(512, 1)
-
-        # squeeze the tensor to size 512x
-        #x = x.squeeze(dim=[2, 3])
+
         x = x.squeeze(dim=2)
         x = x.squeeze(dim=2)
-        #x = x.squeeze(dim=2).squeeze(dim=3)
-        #x = x.squeeze(dim=2)
-
-        #x = x.view(512, 10)
-
         x = self.fc(x)
-
         x = x.view(-1, 10)
-        x = F.log_softmax(x, dim=-1)
+
         return x
+
+
+class Model(LightningModule):
+    def __init__(self, dataset,max_epochs=24):
+        super(Model, self).__init__()
+
+        self.dataset = dataset
+        self.network= MyModel()
+        self.criterion = nn.CrossEntropyLoss()
+        self.train_accuracy = Accuracy(task='multiclass', num_classes=10)
+        self.val_accuracy = Accuracy(task='multiclass', num_classes=10)
+
+        self.max_epochs = max_epochs
+
+    def forward(self, x):
+        return self.network(x)
+
+    def common_step(self, batch, mode):
+        x, y = batch
+        logits = self.forward(x)
+        loss = self.criterion(logits, y)
+
+        acc_metric = getattr(self, f'{mode}_accuracy')
+        acc_metric(logits, y)
+
+        return loss
+
+    def training_step(self, batch, batch_idx):
+        loss = self.common_step(batch, 'train')
+        self.log("train_loss", loss, on_epoch=True, prog_bar=True, logger=True)
+        self.log("train_acc", self.train_accuracy, on_epoch=True, prog_bar=True, logger=True)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        loss = self.common_step(batch, 'val')
+        self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)
+        self.log("val_acc", self.val_accuracy, on_epoch=True, prog_bar=True, logger=True)
+        return loss
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
+        if isinstance(batch, list):
+            x, _ = batch
+        else:
+            x = batch
+        return self.forward(x)
+
+    def configure_optimizers(self):
+        optimizer = optim.Adam(self.parameters(), lr=1e-7, weight_decay=1e-2)
+        best_lr = find_lr(self, self.train_dataloader(), optimizer, self.criterion)
+        scheduler = optim.lr_scheduler.OneCycleLR(
+            optimizer,
+            max_lr=best_lr,
+            steps_per_epoch=len(self.dataset.train_loader),
+            epochs=self.max_epochs,
+            pct_start=5/self.max_epochs,
+            div_factor=100,
+            three_phase=False,
+            final_div_factor=100,
+            anneal_strategy='linear'
+        )
+        return {
+            'optimizer': optimizer,
+            'lr_scheduler': {
+                "scheduler": scheduler,
+                "interval": "step",
+            }
+        }
+
+    def prepare_data(self):
+        self.dataset.download()
+
+    def train_dataloader(self):
+        return self.dataset.train_loader
+
+    def val_dataloader(self):
+        return self.dataset.test_loader
+
+    def predict_dataloader(self):
+        return self.val_dataloader()
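
The new Model class wraps MyModel in a PyTorch Lightning LightningModule: the training, validation, and prediction loops, the logging, and the OneCycleLR schedule come from a Lightning Trainer instead of a hand-written loop, and configure_optimizers runs the LR range test before building the schedule. Below is a minimal usage sketch, assuming a 10-class dataset wrapper that exposes download(), train_loader, and test_loader as this code expects; the utils.dataset.CIFAR10Data name is hypothetical.

from pytorch_lightning import Trainer

from CUSTOMRESNET import Model
from utils.dataset import CIFAR10Data   # hypothetical wrapper: download() / train_loader / test_loader

dataset = CIFAR10Data(batch_size=512)            # assumption: any object with those attributes works
model = Model(dataset, max_epochs=24)            # LightningModule wrapping MyModel
trainer = Trainer(max_epochs=model.max_epochs)   # Lightning drives training and validation
trainer.fit(model)                               # uses Model.train_dataloader() / val_dataloader()

Note that the first fit call includes the LR range test pass, since configure_optimizers invokes find_lr before constructing the scheduler.
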
incorrect_images.py ADDED
@@ -0,0 +1,65 @@
+import pandas as pd
+from collections import defaultdict
+
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import ModelSummary, LearningRateMonitor
+
+from .visualize import plot_examples, get_cam_visualisation, get_incorrect_preds
+
+
+class incorrect(object):
+    def __init__(self, model, max_epochs=None, precision="32-true"):
+        self.model = model
+        self.dataset = model.dataset
+        self.incorrect_preds = None
+        self.grad_cam = None
+        self.trainer = Trainer(callbacks=[ModelSummary(max_depth=10), LearningRateMonitor(logging_interval='step')],
+                               max_epochs=max_epochs or model.max_epochs, precision=precision)
+        self.incorrect_preds = None
+        self.incorrect_preds_pd = None
+        self.grad_cam = None
+
+    def execute(self):
+        self.trainer.fit(self.model)
+
+    def get_incorrect_preds(self):
+        self.incorrect_preds = defaultdict(list)
+        incorrect_images = list()
+        processed = 0
+        results = self.trainer.predict(self.model, self.model.predict_dataloader())
+        for (data, target), pred in zip(self.model.predict_dataloader(), results):
+            ind, pred_, truth = get_incorrect_preds(pred, target)
+            self.incorrect_preds["indices"] += [x + processed for x in ind]
+            incorrect_images += data[ind]
+            self.incorrect_preds["ground_truths"] += truth
+            self.incorrect_preds["predicted_vals"] += pred_
+            processed += len(data)
+        self.incorrect_preds_pd = pd.DataFrame(self.incorrect_preds)
+        self.incorrect_preds["images"] = incorrect_images
+
+    def show_incorrect(self, cams=False, target_layer=None):
+        if self.incorrect_preds is None:
+            self.get_incorrect_preds()
+
+        images = list()
+        labels = list()
+
+        for i in range(20):
+            image = self.incorrect_preds["images"][i]
+            pred = self.incorrect_preds["predicted_vals"][i]
+            truth = self.incorrect_preds["ground_truths"][i]
+
+            if cams:
+                image = get_cam_visualisation(self.model, self.dataset, image, pred, target_layer)
+            else:
+                image = self.dataset.show_transform(image).cpu()
+
+            if self.dataset.classes is not None:
+                pred = f'{pred}:{self.dataset.classes[pred]}'
+                truth = f'{truth}:{self.dataset.classes[truth]}'
+            label = f'{pred}/{truth}'
+
+            images.append(image)
+            labels.append(label)
+
+        plot_examples(images, labels, figsize=(10, 8))
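
incorrect_images.py collects misclassifications after training: execute() fits the model through its own Trainer, get_incorrect_preds() runs prediction over the validation loader and records indices, predictions, and ground truths, and show_incorrect() plots the first 20 failures, optionally as Grad-CAM overlays. A usage sketch, continuing from the sketch above and assuming these files live in the utils package implied by "from utils.visualize import find_lr"; the layer3 attribute passed as target_layer is hypothetical, since MyModel's layers are not shown in this diff.

from utils.incorrect_images import incorrect

runner = incorrect(model, precision="16-mixed")    # max_epochs defaults to model.max_epochs
runner.execute()                                   # trains via the wrapped Trainer
runner.show_incorrect()                            # predicted/true labels for 20 misclassified images
runner.show_incorrect(cams=True, target_layer=model.network.layer3)  # hypothetical layer inside MyModel
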
visualize.py ADDED
@@ -0,0 +1,70 @@
+import torch
+import torchinfo
+from torch_lr_finder import LRFinder
+from matplotlib import pyplot as plt
+from pytorch_grad_cam import GradCAM
+from pytorch_grad_cam.utils.image import show_cam_on_image
+from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
+
+SEED = 42
+DEVICE = None
+
+
+def get_device():
+    global DEVICE
+    if DEVICE is not None:
+        return DEVICE
+
+    if torch.cuda.is_available():
+        DEVICE = "cuda"
+    elif torch.backends.mps.is_available():
+        DEVICE = "mps"
+    else:
+        DEVICE = "cpu"
+    print("Device Selected:", DEVICE)
+    return DEVICE
+
+
+def seed(seed=SEED):
+    torch.manual_seed(seed)
+    if get_device() == 'cuda':
+        torch.cuda.manual_seed(seed)
+
+
+def plot_examples(images, labels, figsize=None, n=20):
+    _ = plt.figure(figsize=figsize)
+
+    for i in range(n):
+        plt.subplot(4, n//4, i + 1)
+        plt.tight_layout()
+        image = images[i]
+        plt.imshow(image, cmap='gray')
+        label = labels[i]
+        plt.title(str(label))
+        plt.xticks([])
+        plt.yticks([])
+
+
+def find_lr(model, data_loader, optimizer, criterion):
+    lr_finder = LRFinder(model, optimizer, criterion)
+    lr_finder.range_test(data_loader, end_lr=0.1, num_iter=200, step_mode='exp')
+    _, best_lr = lr_finder.plot()
+    lr_finder.reset()
+    return best_lr
+
+
+def get_incorrect_preds(prediction, labels):
+    prediction = prediction.argmax(dim=1)
+    indices = prediction.ne(labels).nonzero().reshape(-1).tolist()
+    return indices, prediction[indices].tolist(), labels[indices].tolist()
+
+
+def get_cam_visualisation(model, dataset, input_tensor, label, target_layer, use_cuda=False):
+    grad_cam = GradCAM(model=model, target_layers=[target_layer], use_cuda=use_cuda)
+    targets = [ClassifierOutputTarget(label)]
+    grayscale_cam = grad_cam(input_tensor=input_tensor.unsqueeze(0), targets=targets)
+    grayscale_cam = grayscale_cam[0, :]
+    output = show_cam_on_image(dataset.show_transform(input_tensor).cpu().numpy(), grayscale_cam,use_rgb=True)
+    return output
+
+
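
visualize.py holds the shared helpers: device selection and seeding, a 4-row image grid, an LR range test built on torch_lr_finder, misclassification indexing, and Grad-CAM rendering through pytorch_grad_cam. A short sketch of calling find_lr on its own, mirroring the optimizer settings used in configure_optimizers above and reusing the model object from the first sketch:

import torch.nn as nn
from torch import optim

from utils.visualize import seed, find_lr

seed()                                              # reproducible runs on the selected device
optimizer = optim.Adam(model.parameters(), lr=1e-7, weight_decay=1e-2)
best_lr = find_lr(model, model.train_dataloader(), optimizer, nn.CrossEntropyLoss())
print("Suggested max_lr for OneCycleLR:", best_lr)
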