mainakhf committed
Commit 1493156 · 1 Parent(s): 5a4faf3

Upload 2 files

utils/image_classification.py ADDED
@@ -0,0 +1,307 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from PIL import Image
from tempfile import TemporaryDirectory
import streamlit as st

cudnn.benchmark = True
plt.ion()  # interactive mode


class classifier():
    def __init__(self):
        self.data_transforms = None
        self.data_dir = None
        self.image_datasets = None
        self.dataloaders = None
        self.dataset_sizes = None
        self.class_names = None
        self.device = None
        self.num_classes = None

    def data_loader(self, path, batch_size=4):
        # Data augmentation and normalization for training,
        # just normalization for validation and test
        self.data_transforms = {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]),
            'val': transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ]),
            'test': transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
        }

        self.data_dir = path
        self.image_datasets = {x: datasets.ImageFolder(os.path.join(self.data_dir, x),
                                                       self.data_transforms[x])
                               for x in ['train', 'val', 'test']}
        self.dataloaders = {x: torch.utils.data.DataLoader(self.image_datasets[x], batch_size=batch_size,
                                                           shuffle=True, num_workers=4)
                            for x in ['train', 'val', 'test']}
        self.dataset_sizes = {x: len(self.image_datasets[x]) for x in ['train', 'val', 'test']}
        self.class_names = self.image_datasets['train'].classes
        self.num_classes = len(self.class_names)
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def train(self, model, criterion, optimizer, scheduler, num_epochs=25):
        since = time.time()

        # Create a temporary directory to save training checkpoints
        with TemporaryDirectory() as tempdir:
            best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')

            torch.save(model.state_dict(), best_model_params_path)
            best_acc = 0.0

            for epoch in range(num_epochs):
                print(f'Epoch {epoch+1}/{num_epochs}')
                print('-' * 10)
                st.sidebar.subheader(f':blue[Epoch {epoch+1}/{num_epochs}]', divider='blue')

                # Each epoch has a training and a validation phase
                for phase in ['train', 'val']:
                    if phase == 'train':
                        model.train()  # Set model to training mode
                    else:
                        model.eval()   # Set model to evaluation mode

                    running_loss = 0.0
                    running_corrects = 0

                    # Iterate over data
                    for inputs, labels in self.dataloaders[phase]:
                        inputs = inputs.to(self.device)
                        labels = labels.to(self.device)

                        # zero the parameter gradients
                        optimizer.zero_grad()

                        # forward; track history only in the training phase
                        with torch.set_grad_enabled(phase == 'train'):
                            outputs = model(inputs)
                            _, preds = torch.max(outputs, 1)
                            loss = criterion(outputs, labels)

                            # backward + optimize only in the training phase
                            if phase == 'train':
                                loss.backward()
                                optimizer.step()

                        # statistics
                        running_loss += loss.item() * inputs.size(0)
                        running_corrects += torch.sum(preds == labels.data)
                    if phase == 'train':
                        scheduler.step()

                    epoch_loss = running_loss / self.dataset_sizes[phase]
                    epoch_acc = running_corrects.double() / self.dataset_sizes[phase]

                    print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
                    st.sidebar.caption(f':blue[{phase[0].upper() + phase[1:]} Loss:] {epoch_loss:.4f} :blue[ Accuracy:] {epoch_acc:.4f}')

                    # keep a copy of the best weights seen on the validation set
                    if phase == 'val' and epoch_acc > best_acc:
                        best_acc = epoch_acc
                        torch.save(model.state_dict(), best_model_params_path)

                print()

            time_elapsed = time.time() - since
            print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
            print(f'Best val Accuracy: {best_acc:.4f}')
            st.sidebar.caption(f':green[Training complete in] {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
            st.sidebar.subheader(f':blue[Best val Accuracy:] {best_acc:.4f}')

            # load best model weights
            model.load_state_dict(torch.load(best_model_params_path))
        return model

    def train_model(self, model_name, epochs):
        # Load a pretrained backbone, replace its classification head with a layer
        # sized for this dataset, and optimize only the parameters of that new head.
        num_classes = self.num_classes
        if model_name == 'EfficientNet_B0':
            model = models.efficientnet_b0(pretrained=True)
            model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
            optimizer = torch.optim.SGD(model.classifier[1].parameters(), lr=0.001, momentum=0.9)

        elif model_name == 'EfficientNet_B1':
            model = models.efficientnet_b1(pretrained=True)
            model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
            optimizer = torch.optim.SGD(model.classifier[1].parameters(), lr=0.001, momentum=0.9)

        elif model_name == 'MnasNet0_5':
            model = models.mnasnet0_5(pretrained=True)
            model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
            optimizer = torch.optim.SGD(model.classifier[1].parameters(), lr=0.001, momentum=0.9)

        elif model_name == 'MnasNet0_75':
            model = models.mnasnet0_75(pretrained=True)
            model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
            optimizer = torch.optim.SGD(model.classifier[1].parameters(), lr=0.001, momentum=0.9)

        elif model_name == 'MnasNet1_0':
            model = models.mnasnet1_0(pretrained=True)
            model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
            optimizer = torch.optim.SGD(model.classifier[1].parameters(), lr=0.001, momentum=0.9)

        elif model_name == 'MobileNet_v2':
            model = models.mobilenet_v2(pretrained=True)
            model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
            optimizer = torch.optim.SGD(model.classifier[1].parameters(), lr=0.001, momentum=0.9)

        elif model_name == 'MobileNet_v3_small':
            model = models.mobilenet_v3_small(pretrained=True)
            model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes)
            optimizer = torch.optim.SGD(model.classifier[3].parameters(), lr=0.001, momentum=0.9)

        elif model_name == 'MobileNet_v3_large':
            model = models.mobilenet_v3_large(pretrained=True)
            model.classifier[3] = nn.Linear(model.classifier[3].in_features, num_classes)
            optimizer = torch.optim.SGD(model.classifier[3].parameters(), lr=0.001, momentum=0.9)

        elif model_name == 'RegNet_y_400mf':
            model = models.regnet_y_400mf(pretrained=True)
            model.fc = nn.Linear(model.fc.in_features, num_classes)
            optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)

        elif model_name == 'ShuffleNet_v2_x0_5':
            model = models.shufflenet_v2_x0_5(pretrained=True)
            model.fc = nn.Linear(model.fc.in_features, num_classes)
            optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)

        elif model_name == 'ShuffleNet_v2_x1_0':
            model = models.shufflenet_v2_x1_0(pretrained=True)
            model.fc = nn.Linear(model.fc.in_features, num_classes)
            optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)

        elif model_name == 'ShuffleNet_v2_x1_5':
            model = models.shufflenet_v2_x1_5(pretrained=True)
            model.fc = nn.Linear(model.fc.in_features, num_classes)
            optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)

        elif model_name == 'SqueezeNet 1_0':
            model = models.squeezenet1_0(pretrained=True)
            model.classifier[1] = nn.Conv2d(model.classifier[1].in_channels, num_classes,
                                            model.classifier[1].kernel_size, model.classifier[1].stride)
            optimizer = torch.optim.SGD(model.classifier[1].parameters(), lr=0.001, momentum=0.9)

        elif model_name == 'SqueezeNet 1_1':
            model = models.squeezenet1_1(pretrained=True)
            model.classifier[1] = nn.Conv2d(model.classifier[1].in_channels, num_classes,
                                            model.classifier[1].kernel_size, model.classifier[1].stride)
            optimizer = torch.optim.SGD(model.classifier[1].parameters(), lr=0.001, momentum=0.9)

        else:
            raise ValueError(f'Unsupported model name: {model_name}')

        model = model.to(self.device)  # move the model to the same device as the data
        exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
        criterion = nn.CrossEntropyLoss()
        model_ft = self.train(model, criterion, optimizer, exp_lr_scheduler,
                              num_epochs=epochs)
        torch.save(model_ft.state_dict(), 'model.pt')
        return model_ft

    def imshow(self, inp, title=None):
        """Display image for Tensor."""
        inp = inp.numpy().transpose((1, 2, 0))
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        inp = std * inp + mean
        inp = np.clip(inp, 0, 1)
        plt.imshow(inp)
        if title is not None:
            plt.title(title)
        plt.pause(0.001)  # pause a bit so that plots are updated

    def visualize_model(self, model, num_images=6):
        was_training = model.training
        model.eval()
        images_so_far = 0
        fig = plt.figure()

        with torch.no_grad():
            for i, (inputs, labels) in enumerate(self.dataloaders['val']):
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)

                for j in range(inputs.size()[0]):
                    images_so_far += 1
                    ax = plt.subplot(num_images // 2, 2, images_so_far)
                    ax.axis('off')
                    ax.set_title(f'predicted: {self.class_names[preds[j]]}')
                    self.imshow(inputs.cpu().data[j])

                    if images_so_far == num_images:
                        model.train(mode=was_training)
                        return
            model.train(mode=was_training)

    def pytorch_predict(self, model):
        '''
        Make predictions from a PyTorch model on the test split.
        '''
        # set model to evaluation mode
        model.eval()

        y_true = torch.tensor([], dtype=torch.long, device=self.device)
        all_outputs = torch.tensor([], device=self.device)

        # deactivate the autograd engine to reduce memory usage and speed up computation
        with torch.no_grad():
            for data in self.dataloaders['test']:
                inputs = [i.to(self.device) for i in data[:-1]]
                labels = data[-1].to(self.device)

                outputs = model(*inputs)
                y_true = torch.cat((y_true, labels), 0)
                all_outputs = torch.cat((all_outputs, outputs), 0)

        y_true = y_true.cpu().numpy()
        _, y_pred = torch.max(all_outputs, 1)
        y_pred = y_pred.cpu().numpy()
        y_pred_prob = F.softmax(all_outputs, dim=1).cpu().numpy()

        return y_true, y_pred, y_pred_prob
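A minimal usage sketch for this module (assuming a dataset directory containing train/, val/, and test/ subfolders; the path and model name below are placeholders, not values from the commit):

    from utils.image_classification import classifier

    clf = classifier()
    clf.data_loader('path/to/dataset', batch_size=4)           # placeholder path; expects train/ val/ test/ subfolders
    model = clf.train_model('MobileNet_v3_small', epochs=5)    # fine-tunes the new head and writes model.pt
    y_true, y_pred, y_pred_prob = clf.pytorch_predict(model)   # labels, predictions, and softmax scores on the test split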
utils/object_detection.py ADDED
File without changes