Spaces:
Runtime error
Runtime error
Upload 3 files
Browse files- CUSTOMRESNET.py +88 -12
- incorrect_images.py +65 -0
- visualize.py +70 -0
CUSTOMRESNET.py
CHANGED
@@ -1,6 +1,12 @@
|
|
1 |
-
import torch.nn.functional as F
|
2 |
import torch
|
3 |
import torch.nn as nn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
class MyModel(nn.Module):
|
6 |
def __init__(self):
|
@@ -72,19 +78,89 @@ class MyModel(nn.Module):
|
|
72 |
x=x+R2
|
73 |
|
74 |
x = self.maxpool(x)
|
75 |
-
|
76 |
-
|
77 |
-
# squeeze the tensor to size 512x
|
78 |
-
#x = x.squeeze(dim=[2, 3])
|
79 |
x = x.squeeze(dim=2)
|
80 |
x = x.squeeze(dim=2)
|
81 |
-
#x = x.squeeze(dim=2).squeeze(dim=3)
|
82 |
-
#x = x.squeeze(dim=2)
|
83 |
-
|
84 |
-
#x = x.view(512, 10)
|
85 |
-
|
86 |
x = self.fc(x)
|
87 |
-
|
88 |
x = x.view(-1, 10)
|
89 |
-
|
90 |
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import torch
|
2 |
import torch.nn as nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
from torch import optim
|
5 |
+
from pytorch_lightning import LightningModule
|
6 |
+
from torchmetrics import Accuracy
|
7 |
+
from utils.visualize import find_lr
|
8 |
+
|
9 |
+
|
10 |
|
11 |
class MyModel(nn.Module):
|
12 |
def __init__(self):
|
|
|
78 |
x=x+R2
|
79 |
|
80 |
x = self.maxpool(x)
|
81 |
+
|
|
|
|
|
|
|
82 |
x = x.squeeze(dim=2)
|
83 |
x = x.squeeze(dim=2)
|
|
|
|
|
|
|
|
|
|
|
84 |
x = self.fc(x)
|
|
|
85 |
x = x.view(-1, 10)
|
86 |
+
|
87 |
return x
|
88 |
+
|
89 |
+
|
90 |
+
class Model(LightningModule):
|
91 |
+
def __init__(self, dataset,max_epochs=24):
|
92 |
+
super(Model, self).__init__()
|
93 |
+
|
94 |
+
self.dataset = dataset
|
95 |
+
self.network= MyModel()
|
96 |
+
self.criterion = nn.CrossEntropyLoss()
|
97 |
+
self.train_accuracy = Accuracy(task='multiclass', num_classes=10)
|
98 |
+
self.val_accuracy = Accuracy(task='multiclass', num_classes=10)
|
99 |
+
|
100 |
+
self.max_epochs = max_epochs
|
101 |
+
|
102 |
+
def forward(self, x):
|
103 |
+
return self.network(x)
|
104 |
+
|
105 |
+
def common_step(self, batch, mode):
|
106 |
+
x, y = batch
|
107 |
+
logits = self.forward(x)
|
108 |
+
loss = self.criterion(logits, y)
|
109 |
+
|
110 |
+
acc_metric = getattr(self, f'{mode}_accuracy')
|
111 |
+
acc_metric(logits, y)
|
112 |
+
|
113 |
+
return loss
|
114 |
+
|
115 |
+
def training_step(self, batch, batch_idx):
|
116 |
+
loss = self.common_step(batch, 'train')
|
117 |
+
self.log("train_loss", loss, on_epoch=True, prog_bar=True, logger=True)
|
118 |
+
self.log("train_acc", self.train_accuracy, on_epoch=True, prog_bar=True, logger=True)
|
119 |
+
return loss
|
120 |
+
|
121 |
+
def validation_step(self, batch, batch_idx):
|
122 |
+
loss = self.common_step(batch, 'val')
|
123 |
+
self.log("val_loss", loss, on_epoch=True, prog_bar=True, logger=True)
|
124 |
+
self.log("val_acc", self.val_accuracy, on_epoch=True, prog_bar=True, logger=True)
|
125 |
+
return loss
|
126 |
+
|
127 |
+
def predict_step(self, batch, batch_idx, dataloader_idx=0):
|
128 |
+
if isinstance(batch, list):
|
129 |
+
x, _ = batch
|
130 |
+
else:
|
131 |
+
x = batch
|
132 |
+
return self.forward(x)
|
133 |
+
|
134 |
+
def configure_optimizers(self):
|
135 |
+
optimizer = optim.Adam(self.parameters(), lr=1e-7, weight_decay=1e-2)
|
136 |
+
best_lr = find_lr(self, self.train_dataloader(), optimizer, self.criterion)
|
137 |
+
scheduler = optim.lr_scheduler.OneCycleLR(
|
138 |
+
optimizer,
|
139 |
+
max_lr=best_lr,
|
140 |
+
steps_per_epoch=len(self.dataset.train_loader),
|
141 |
+
epochs=self.max_epochs,
|
142 |
+
pct_start=5/self.max_epochs,
|
143 |
+
div_factor=100,
|
144 |
+
three_phase=False,
|
145 |
+
final_div_factor=100,
|
146 |
+
anneal_strategy='linear'
|
147 |
+
)
|
148 |
+
return {
|
149 |
+
'optimizer': optimizer,
|
150 |
+
'lr_scheduler': {
|
151 |
+
"scheduler": scheduler,
|
152 |
+
"interval": "step",
|
153 |
+
}
|
154 |
+
}
|
155 |
+
|
156 |
+
def prepare_data(self):
|
157 |
+
self.dataset.download()
|
158 |
+
|
159 |
+
def train_dataloader(self):
|
160 |
+
return self.dataset.train_loader
|
161 |
+
|
162 |
+
def val_dataloader(self):
|
163 |
+
return self.dataset.test_loader
|
164 |
+
|
165 |
+
def predict_dataloader(self):
|
166 |
+
return self.val_dataloader()
|
incorrect_images.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from collections import defaultdict
|
3 |
+
|
4 |
+
from pytorch_lightning import Trainer
|
5 |
+
from pytorch_lightning.callbacks import ModelSummary, LearningRateMonitor
|
6 |
+
|
7 |
+
from .visualize import plot_examples, get_cam_visualisation, get_incorrect_preds
|
8 |
+
|
9 |
+
|
10 |
+
class incorrect(object):
|
11 |
+
def __init__(self, model, max_epochs=None, precision="32-true"):
|
12 |
+
self.model = model
|
13 |
+
self.dataset = model.dataset
|
14 |
+
self.incorrect_preds = None
|
15 |
+
self.grad_cam = None
|
16 |
+
self.trainer = Trainer(callbacks=[ModelSummary(max_depth=10), LearningRateMonitor(logging_interval='step')],
|
17 |
+
max_epochs=max_epochs or model.max_epochs, precision=precision)
|
18 |
+
self.incorrect_preds = None
|
19 |
+
self.incorrect_preds_pd = None
|
20 |
+
self.grad_cam = None
|
21 |
+
|
22 |
+
def execute(self):
|
23 |
+
self.trainer.fit(self.model)
|
24 |
+
|
25 |
+
def get_incorrect_preds(self):
|
26 |
+
self.incorrect_preds = defaultdict(list)
|
27 |
+
incorrect_images = list()
|
28 |
+
processed = 0
|
29 |
+
results = self.trainer.predict(self.model, self.model.predict_dataloader())
|
30 |
+
for (data, target), pred in zip(self.model.predict_dataloader(), results):
|
31 |
+
ind, pred_, truth = get_incorrect_preds(pred, target)
|
32 |
+
self.incorrect_preds["indices"] += [x + processed for x in ind]
|
33 |
+
incorrect_images += data[ind]
|
34 |
+
self.incorrect_preds["ground_truths"] += truth
|
35 |
+
self.incorrect_preds["predicted_vals"] += pred_
|
36 |
+
processed += len(data)
|
37 |
+
self.incorrect_preds_pd = pd.DataFrame(self.incorrect_preds)
|
38 |
+
self.incorrect_preds["images"] = incorrect_images
|
39 |
+
|
40 |
+
def show_incorrect(self, cams=False, target_layer=None):
|
41 |
+
if self.incorrect_preds is None:
|
42 |
+
self.get_incorrect_preds()
|
43 |
+
|
44 |
+
images = list()
|
45 |
+
labels = list()
|
46 |
+
|
47 |
+
for i in range(20):
|
48 |
+
image = self.incorrect_preds["images"][i]
|
49 |
+
pred = self.incorrect_preds["predicted_vals"][i]
|
50 |
+
truth = self.incorrect_preds["ground_truths"][i]
|
51 |
+
|
52 |
+
if cams:
|
53 |
+
image = get_cam_visualisation(self.model, self.dataset, image, pred, target_layer)
|
54 |
+
else:
|
55 |
+
image = self.dataset.show_transform(image).cpu()
|
56 |
+
|
57 |
+
if self.dataset.classes is not None:
|
58 |
+
pred = f'{pred}:{self.dataset.classes[pred]}'
|
59 |
+
truth = f'{truth}:{self.dataset.classes[truth]}'
|
60 |
+
label = f'{pred}/{truth}'
|
61 |
+
|
62 |
+
images.append(image)
|
63 |
+
labels.append(label)
|
64 |
+
|
65 |
+
plot_examples(images, labels, figsize=(10, 8))
|
visualize.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torchinfo
|
3 |
+
from torch_lr_finder import LRFinder
|
4 |
+
from matplotlib import pyplot as plt
|
5 |
+
from pytorch_grad_cam import GradCAM
|
6 |
+
from pytorch_grad_cam.utils.image import show_cam_on_image
|
7 |
+
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
|
8 |
+
|
9 |
+
SEED = 42
|
10 |
+
DEVICE = None
|
11 |
+
|
12 |
+
|
13 |
+
def get_device():
|
14 |
+
global DEVICE
|
15 |
+
if DEVICE is not None:
|
16 |
+
return DEVICE
|
17 |
+
|
18 |
+
if torch.cuda.is_available():
|
19 |
+
DEVICE = "cuda"
|
20 |
+
elif torch.backends.mps.is_available():
|
21 |
+
DEVICE = "mps"
|
22 |
+
else:
|
23 |
+
DEVICE = "cpu"
|
24 |
+
print("Device Selected:", DEVICE)
|
25 |
+
return DEVICE
|
26 |
+
|
27 |
+
|
28 |
+
def seed(seed=SEED):
|
29 |
+
torch.manual_seed(seed)
|
30 |
+
if get_device() == 'cuda':
|
31 |
+
torch.cuda.manual_seed(seed)
|
32 |
+
|
33 |
+
|
34 |
+
def plot_examples(images, labels, figsize=None, n=20):
|
35 |
+
_ = plt.figure(figsize=figsize)
|
36 |
+
|
37 |
+
for i in range(n):
|
38 |
+
plt.subplot(4, n//4, i + 1)
|
39 |
+
plt.tight_layout()
|
40 |
+
image = images[i]
|
41 |
+
plt.imshow(image, cmap='gray')
|
42 |
+
label = labels[i]
|
43 |
+
plt.title(str(label))
|
44 |
+
plt.xticks([])
|
45 |
+
plt.yticks([])
|
46 |
+
|
47 |
+
|
48 |
+
def find_lr(model, data_loader, optimizer, criterion):
|
49 |
+
lr_finder = LRFinder(model, optimizer, criterion)
|
50 |
+
lr_finder.range_test(data_loader, end_lr=0.1, num_iter=200, step_mode='exp')
|
51 |
+
_, best_lr = lr_finder.plot()
|
52 |
+
lr_finder.reset()
|
53 |
+
return best_lr
|
54 |
+
|
55 |
+
|
56 |
+
def get_incorrect_preds(prediction, labels):
|
57 |
+
prediction = prediction.argmax(dim=1)
|
58 |
+
indices = prediction.ne(labels).nonzero().reshape(-1).tolist()
|
59 |
+
return indices, prediction[indices].tolist(), labels[indices].tolist()
|
60 |
+
|
61 |
+
|
62 |
+
def get_cam_visualisation(model, dataset, input_tensor, label, target_layer, use_cuda=False):
|
63 |
+
grad_cam = GradCAM(model=model, target_layers=[target_layer], use_cuda=use_cuda)
|
64 |
+
targets = [ClassifierOutputTarget(label)]
|
65 |
+
grayscale_cam = grad_cam(input_tensor=input_tensor.unsqueeze(0), targets=targets)
|
66 |
+
grayscale_cam = grayscale_cam[0, :]
|
67 |
+
output = show_cam_on_image(dataset.show_transform(input_tensor).cpu().numpy(), grayscale_cam,use_rgb=True)
|
68 |
+
return output
|
69 |
+
|
70 |
+
|