cycool29 committed
Commit b828b8f · 1 Parent(s): 905a3f3

Upload 8 files

handetect/main.py ADDED
@@ -0,0 +1,217 @@
+ import os
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torchvision import transforms
+ from torch.utils.data import DataLoader, random_split, Dataset
+ from torchvision.datasets import ImageFolder
+ import matplotlib.pyplot as plt
+ from models import *
+ from scipy.ndimage import gaussian_filter1d
+ 
+ # Constants
+ RANDOM_SEED = 123
+ BATCH_SIZE = 32
+ NUM_EPOCHS = 100
+ LEARNING_RATE = 0.0001
+ DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ NUM_PRINT = 100
+ DATA_DIR = r"data/train/Task 1"
+ NUM_CLASSES = len(os.listdir(DATA_DIR))
+ 
+ # Seed for reproducible splits and augmentations
+ torch.manual_seed(RANDOM_SEED)
+ 
+ # Preprocessing for the unaugmented copy of the dataset
+ preprocess = transforms.Compose(
+     [
+         transforms.Resize((64, 64)),  # Resize images to 64x64
+         transforms.ToTensor(),  # Convert to tensor
+         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize
+     ]
+ )
+ 
+ # Augmentation pipeline for a second copy of the dataset
+ augmentation = transforms.Compose(
+     [
+         transforms.Resize((64, 64)),  # Resize images to 64x64
+         transforms.RandomHorizontalFlip(p=0.5),  # Random horizontal flip
+         transforms.RandomRotation(degrees=45),  # Random rotation
+         transforms.RandomVerticalFlip(p=0.5),  # Random vertical flip
+         transforms.RandomGrayscale(p=0.1),  # Random grayscale
+         transforms.ColorJitter(
+             brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5
+         ),  # Random color jitter
+         transforms.ToTensor(),  # Convert to tensor
+         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize
+     ]
+ )
+ 
+ # Load the dataset twice (clean + augmented) and concatenate the copies
+ original_dataset = ImageFolder(root=DATA_DIR, transform=preprocess)
+ augmented_dataset = ImageFolder(root=DATA_DIR, transform=augmentation)
+ dataset = original_dataset + augmented_dataset
+ 
+ print("Length of dataset: ", len(dataset))
+ print("Classes: ", original_dataset.classes)
+ 
+ 
+ # Thin wrapper so the split subsets can be handed to DataLoader
+ class CustomDataset(Dataset):
+     def __init__(self, dataset):
+         self.data = dataset
+ 
+     def __len__(self):
+         return len(self.data)
+ 
+     def __getitem__(self, idx):
+         img, label = self.data[idx]
+         return img, label
+ 
+ 
+ # Split the dataset into train and validation sets (80/20)
+ train_size = int(0.8 * len(dataset))
+ val_size = len(dataset) - train_size
+ train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
+ 
+ # Create data loaders for the custom dataset
+ train_loader = DataLoader(
+     CustomDataset(train_dataset), batch_size=BATCH_SIZE, shuffle=True, num_workers=0
+ )
+ valid_loader = DataLoader(
+     CustomDataset(val_dataset), batch_size=BATCH_SIZE, num_workers=0
+ )
+ 
+ # Initialize model, criterion, optimizer, and scheduler
+ model = resnet18(pretrained=False, num_classes=NUM_CLASSES)
+ model = model.to(DEVICE)
+ criterion = nn.CrossEntropyLoss()
+ # Adam optimizer
+ optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
+ # ReduceLROnPlateau scheduler
+ scheduler = optim.lr_scheduler.ReduceLROnPlateau(
+     optimizer, mode="min", factor=0.1, patience=10, verbose=True
+ )
+ 
+ # Lists to store training and validation history
+ TRAIN_LOSS_HIST = []
+ VAL_LOSS_HIST = []
+ AVG_TRAIN_LOSS_HIST = []
+ AVG_VAL_LOSS_HIST = []
+ TRAIN_ACC_HIST = []
+ VAL_ACC_HIST = []
+ 
+ # Training loop
+ for epoch in range(NUM_EPOCHS):
+     model.train()  # Set model to training mode
+     running_loss = 0.0  # loss since the last progress print
+     epoch_loss = 0.0  # loss accumulated over the whole epoch
+     total_train = 0
+     correct_train = 0
+ 
+     for i, (inputs, labels) in enumerate(train_loader):
+         inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
+         optimizer.zero_grad()
+         outputs = model(inputs)
+         loss = criterion(outputs, labels)
+         loss.backward()
+         optimizer.step()
+         running_loss += loss.item()
+         epoch_loss += loss.item()
+ 
+         if (i + 1) % NUM_PRINT == 0:
+             print(
+                 "[Epoch %d, Batch %d] Loss: %.6f"
+                 % (epoch + 1, i + 1, running_loss / NUM_PRINT)
+             )
+             running_loss = 0.0
+ 
+         _, predicted = torch.max(outputs, 1)
+         total_train += labels.size(0)
+         correct_train += (predicted == labels).sum().item()
+ 
+     TRAIN_ACC_HIST.append(correct_train / total_train)
+     TRAIN_LOSS_HIST.append(loss.item())  # loss of the final batch in the epoch
+ 
+     # Average training loss for the epoch, accumulated separately from
+     # running_loss (which is reset after every progress print)
+     avg_train_loss = epoch_loss / len(train_loader)
+     AVG_TRAIN_LOSS_HIST.append(avg_train_loss)
+ 
+     # Print average training loss for the epoch
+     print("[Epoch %d] Average Training Loss: %.6f" % (epoch + 1, avg_train_loss))
+ 
+     # Learning rate scheduling (plateau criterion: average training loss)
+     lr_1 = optimizer.param_groups[0]["lr"]
+     print("Learning Rate: {:.15f}".format(lr_1))
+     scheduler.step(avg_train_loss)
+ 
+     # Validation loop
+     model.eval()  # Set model to evaluation mode
+     val_loss = 0.0
+     correct_val = 0
+     total_val = 0
+ 
+     with torch.no_grad():
+         for inputs, labels in valid_loader:
+             inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
+             outputs = model(inputs)
+             loss = criterion(outputs, labels)
+             val_loss += loss.item()
+             # Calculate accuracy
+             _, predicted = torch.max(outputs, 1)
+             total_val += labels.size(0)
+             correct_val += (predicted == labels).sum().item()
+ 
+     VAL_LOSS_HIST.append(loss.item())  # loss of the final validation batch
+ 
+     # Calculate the average validation loss for the epoch
+     avg_val_loss = val_loss / len(valid_loader)
+     AVG_VAL_LOSS_HIST.append(avg_val_loss)
+     print("Average Validation Loss: %.6f" % (avg_val_loss))
+ 
+     # Calculate the accuracy of the validation set
+     val_accuracy = correct_val / total_val
+     VAL_ACC_HIST.append(val_accuracy)
+     print("Validation Accuracy: %.6f" % (val_accuracy))
+ 
+ # End of training loop
+ 
+ # Save the model
+ model_save_path = "model.pth"
+ torch.save(model.state_dict(), model_save_path)
+ print("Model saved at", model_save_path)
+ 
+ print("Generating loss plot...")
+ # Smooth the curves before plotting
+ # https://stackoverflow.com/questions/5283649/plot-smooth-line-with-pyplot
+ avg_train_loss_line = gaussian_filter1d(AVG_TRAIN_LOSS_HIST, sigma=2)
+ avg_val_loss_line = gaussian_filter1d(AVG_VAL_LOSS_HIST, sigma=2)
+ train_loss_line = gaussian_filter1d(TRAIN_LOSS_HIST, sigma=2)
+ val_loss_line = gaussian_filter1d(VAL_LOSS_HIST, sigma=2)
+ train_acc_line = gaussian_filter1d(TRAIN_ACC_HIST, sigma=2)
+ val_acc_line = gaussian_filter1d(VAL_ACC_HIST, sigma=2)
+ plt.plot(range(1, NUM_EPOCHS + 1), train_loss_line, label="Train Loss")
+ plt.plot(range(1, NUM_EPOCHS + 1), val_loss_line, label="Validation Loss")
+ plt.xlabel("Epochs")
+ plt.ylabel("Loss")
+ plt.legend()
+ plt.title("Train Loss and Validation Loss")
+ plt.savefig("loss_plot.png")
+ plt.clf()
+ plt.plot(range(1, NUM_EPOCHS + 1), avg_train_loss_line, label="Average Train Loss")
+ plt.plot(range(1, NUM_EPOCHS + 1), avg_val_loss_line, label="Average Validation Loss")
+ plt.xlabel("Epochs")
+ plt.ylabel("Loss")
+ plt.legend()
+ plt.title("Average Train Loss and Average Validation Loss")
+ plt.savefig("avg_loss_plot.png")
+ plt.clf()
+ plt.plot(range(1, NUM_EPOCHS + 1), train_acc_line, label="Train Accuracy")
+ plt.plot(range(1, NUM_EPOCHS + 1), val_acc_line, label="Validation Accuracy")
+ plt.xlabel("Epochs")
+ plt.ylabel("Accuracy")
+ plt.legend()
+ plt.title("Train Accuracy and Validation Accuracy")
+ plt.savefig("accuracy_plot.png")
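
A quick smoke test of the data/model wiring can catch path and shape problems before committing to the 100-epoch run above. A minimal sketch, assuming the same data/train/Task 1 layout that main.py expects (resnet18 is taken straight from torchvision, as models.py below re-exports it):

    import torch
    from torch.utils.data import DataLoader
    from torchvision import transforms
    from torchvision.datasets import ImageFolder
    from torchvision.models import resnet18

    # Same path and preprocessing as main.py
    DATA_DIR = r"data/train/Task 1"
    preprocess = transforms.Compose(
        [
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    dataset = ImageFolder(root=DATA_DIR, transform=preprocess)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)

    # Untrained model with the same head size the training script builds
    model = resnet18(pretrained=False, num_classes=len(dataset.classes))

    # One forward pass confirms image decoding, tensor shapes, and class count
    inputs, labels = next(iter(loader))
    outputs = model(inputs)
    print(outputs.shape)  # expected: torch.Size([4, len(dataset.classes)])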
handetect/models.py ADDED
@@ -0,0 +1,340 @@
+ #######################################################
+ # This file stores all the models used in the project.#
+ #######################################################
+ 
+ import torch
+ from torchvision.models import resnet50
+ from torchvision.models import resnet18
+ 
+ 
+ # Bottleneck block used by ResNet-50
+ class Bottleneck(torch.nn.Module):
+     expansion = 4
+ 
+     def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
+         super(Bottleneck, self).__init__()
+         # 1x1 convolution to reduce channels (intermediate channels)
+         self.conv1 = torch.nn.Conv2d(
+             in_channels, out_channels, kernel_size=1, stride=1, padding=0
+         )
+         self.batch_norm1 = torch.nn.BatchNorm2d(out_channels)
+         # 3x3 convolution with the specified stride
+         self.conv2 = torch.nn.Conv2d(
+             out_channels, out_channels, kernel_size=3, stride=stride, padding=1
+         )
+         self.batch_norm2 = torch.nn.BatchNorm2d(out_channels)
+         # 1x1 convolution to expand channels back
+         self.conv3 = torch.nn.Conv2d(
+             out_channels,
+             out_channels * self.expansion,
+             kernel_size=1,
+             stride=1,
+             padding=0,
+         )
+         self.batch_norm3 = torch.nn.BatchNorm2d(out_channels * self.expansion)
+ 
+         self.i_downsample = i_downsample
+         self.stride = stride
+         self.relu = torch.nn.ReLU()
+ 
+     # Forward the input x through the block
+     def forward(self, x):
+         identity = x.clone()
+         x = self.relu(self.batch_norm1(self.conv1(x)))
+         x = self.relu(self.batch_norm2(self.conv2(x)))
+         x = self.conv3(x)
+         x = self.batch_norm3(x)
+ 
+         # Downsample the identity if needed
+         if self.i_downsample is not None:
+             identity = self.i_downsample(identity)
+         # Add the skip connection
+         x += identity
+         x = self.relu(x)
+ 
+         return x
+ 
+ 
+ # Basic block; unused for now, but reusable for ResNet-18/34-style variants
+ class Block(torch.nn.Module):
+     expansion = 1
+ 
+     def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
+         super(Block, self).__init__()
+ 
+         self.conv1 = torch.nn.Conv2d(
+             in_channels,
+             out_channels,
+             kernel_size=3,
+             padding=1,
+             stride=stride,
+             bias=False,
+         )
+         self.batch_norm1 = torch.nn.BatchNorm2d(out_channels)
+         # The second 3x3 convolution keeps stride 1 so its output matches the
+         # (possibly downsampled) identity branch
+         self.conv2 = torch.nn.Conv2d(
+             out_channels,
+             out_channels,
+             kernel_size=3,
+             padding=1,
+             stride=1,
+             bias=False,
+         )
+         self.batch_norm2 = torch.nn.BatchNorm2d(out_channels)
+ 
+         self.i_downsample = i_downsample
+         self.stride = stride
+         self.relu = torch.nn.ReLU()
+ 
+     def forward(self, x):
+         identity = x.clone()
+ 
+         x = self.relu(self.batch_norm1(self.conv1(x)))
+         x = self.batch_norm2(self.conv2(x))
+ 
+         if self.i_downsample is not None:
+             identity = self.i_downsample(identity)
+         x += identity
+         x = self.relu(x)
+         return x
+ 
+ 
+ class ResNet(torch.nn.Module):
+     def __init__(self, ResBlock, layer_list, num_classes, num_channels=3):
+         super(ResNet, self).__init__()
+         self.in_channels = 64
+         # Initial conv layer
+         self.conv1 = torch.nn.Conv2d(
+             num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
+         )
+         self.batch_norm1 = torch.nn.BatchNorm2d(64)
+         self.relu = torch.nn.ReLU()
+         self.max_pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+         # Four stages of residual blocks
+         self.layer1 = self._make_layer(ResBlock, layer_list[0], planes=64)
+         self.layer2 = self._make_layer(ResBlock, layer_list[1], planes=128, stride=2)
+         self.layer3 = self._make_layer(ResBlock, layer_list[2], planes=256, stride=2)
+         self.layer4 = self._make_layer(ResBlock, layer_list[3], planes=512, stride=2)
+ 
+         self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
+         self.fc = torch.nn.Linear(512 * ResBlock.expansion, num_classes)
+ 
+     def forward(self, x):
+         x = self.relu(self.batch_norm1(self.conv1(x)))
+         x = self.max_pool(x)
+ 
+         x = self.layer1(x)
+         x = self.layer2(x)
+         x = self.layer3(x)
+         x = self.layer4(x)
+ 
+         x = self.avgpool(x)
+         x = x.reshape(x.shape[0], -1)
+         x = self.fc(x)
+ 
+         return x
+ 
+     def _make_layer(self, ResBlock, blocks, planes, stride=1):
+         # planes is the base number of output channels for the stage
+         ii_downsample = None
+         layers = []
+ 
+         # Project the identity when the spatial size or channel count changes
+         if stride != 1 or self.in_channels != planes * ResBlock.expansion:
+             ii_downsample = torch.nn.Sequential(
+                 torch.nn.Conv2d(
+                     self.in_channels,
+                     planes * ResBlock.expansion,
+                     kernel_size=1,
+                     stride=stride,
+                 ),
+                 torch.nn.BatchNorm2d(planes * ResBlock.expansion),
+             )
+ 
+         layers.append(
+             ResBlock(
+                 self.in_channels, planes, i_downsample=ii_downsample, stride=stride
+             )
+         )
+         self.in_channels = planes * ResBlock.expansion
+ 
+         for _ in range(blocks - 1):
+             layers.append(ResBlock(self.in_channels, planes))
+ 
+         return torch.nn.Sequential(*layers)
+ 
+ 
+ # The list gives the number of residual blocks in each of the four stages
+ def ResNet50(num_classes, channels=3):
+     return ResNet(Bottleneck, [3, 4, 6, 3], num_classes, channels)
+ 
+ 
+ # VGG16 model
+ class VGG16(torch.nn.Module):
+     def __init__(self, num_classes):
+         super().__init__()
+ 
+         self.block_1 = torch.nn.Sequential(
+             torch.nn.Conv2d(
+                 in_channels=3,
+                 out_channels=64,
+                 kernel_size=(3, 3),
+                 stride=(1, 1),
+                 padding=1,
+             ),
+             torch.nn.ReLU(),
+             torch.nn.Conv2d(
+                 in_channels=64,
+                 out_channels=64,
+                 kernel_size=(3, 3),
+                 stride=(1, 1),
+                 padding=1,
+             ),
+             torch.nn.ReLU(),
+             torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
+         )
+ 
+         self.block_2 = torch.nn.Sequential(
+             torch.nn.Conv2d(
+                 in_channels=64,
+                 out_channels=128,
+                 kernel_size=(3, 3),
+                 stride=(1, 1),
+                 padding=1,
+             ),
+             torch.nn.ReLU(),
+             torch.nn.Conv2d(
+                 in_channels=128,
+                 out_channels=128,
+                 kernel_size=(3, 3),
+                 stride=(1, 1),
+                 padding=1,
+             ),
+             torch.nn.ReLU(),
+             torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
+         )
+ 
+         self.block_3 = torch.nn.Sequential(
+             torch.nn.Conv2d(
+                 in_channels=128,
+                 out_channels=256,
+                 kernel_size=(3, 3),
+                 stride=(1, 1),
+                 padding=1,
+             ),
+             torch.nn.ReLU(),
+             torch.nn.Conv2d(
+                 in_channels=256,
+                 out_channels=256,
+                 kernel_size=(3, 3),
+                 stride=(1, 1),
+                 padding=1,
+             ),
+             torch.nn.ReLU(),
+             torch.nn.Conv2d(
+                 in_channels=256,
+                 out_channels=256,
+                 kernel_size=(3, 3),
+                 stride=(1, 1),
+                 padding=1,
+             ),
+             torch.nn.ReLU(),
+             torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
+         )
+ 
+         self.block_4 = torch.nn.Sequential(
+             torch.nn.Conv2d(
+                 in_channels=256,
+                 out_channels=512,
+                 kernel_size=(3, 3),
+                 stride=(1, 1),
+                 padding=1,
+             ),
+             torch.nn.ReLU(),
+             torch.nn.Conv2d(
+                 in_channels=512,
+                 out_channels=512,
+                 kernel_size=(3, 3),
+                 stride=(1, 1),
+                 padding=1,
+             ),
+             torch.nn.ReLU(),
+             torch.nn.Conv2d(
+                 in_channels=512,
+                 out_channels=512,
+                 kernel_size=(3, 3),
+                 stride=(1, 1),
+                 padding=1,
+             ),
+             torch.nn.ReLU(),
+             torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
+         )
+ 
+         self.block_5 = torch.nn.Sequential(
+             torch.nn.Conv2d(
+                 in_channels=512,
+                 out_channels=512,
+                 kernel_size=(3, 3),
+                 stride=(1, 1),
+                 padding=1,
+             ),
+             torch.nn.ReLU(),
+             torch.nn.Conv2d(
+                 in_channels=512,
+                 out_channels=512,
+                 kernel_size=(3, 3),
+                 stride=(1, 1),
+                 padding=1,
+             ),
+             torch.nn.ReLU(),
+             torch.nn.Conv2d(
+                 in_channels=512,
+                 out_channels=512,
+                 kernel_size=(3, 3),
+                 stride=(1, 1),
+                 padding=1,
+             ),
+             torch.nn.ReLU(),
+             torch.nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)),
+         )
+ 
+         height, width = 3, 3
+         self.classifier = torch.nn.Sequential(
+             torch.nn.Linear(512 * height * width, 4096),
+             torch.nn.ReLU(True),
+             torch.nn.Dropout(p=0.5),
+             torch.nn.Linear(4096, 4096),
+             torch.nn.ReLU(True),
+             torch.nn.Dropout(p=0.5),
+             torch.nn.Linear(4096, num_classes),
+         )
+ 
+         # Kaiming initialization for all conv and linear layers
+         for m in self.modules():
+             if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
+                 torch.nn.init.kaiming_uniform_(
+                     m.weight, mode="fan_in", nonlinearity="relu"
+                 )
+                 if m.bias is not None:
+                     m.bias.detach().zero_()
+ 
+         self.avgpool = torch.nn.AdaptiveAvgPool2d((height, width))
+ 
+     def forward(self, x):
+         x = self.block_1(x)
+         x = self.block_2(x)
+         x = self.block_3(x)
+         x = self.block_4(x)
+         x = self.block_5(x)
+         x = self.avgpool(x)
+         x = x.view(x.size(0), -1)  # flatten
+ 
+         logits = self.classifier(x)
+ 
+         return logits
+ 
+ 
+ # ResNet18 model
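
The custom ResNet50 and VGG16 classes are not exercised by main.py, which uses torchvision's resnet18 re-exported from this file, so a standalone shape check is a cheap way to confirm they wire up. A minimal sketch, run from the handetect directory; the class count of 8 is a placeholder, and the 64x64 input size mirrors main.py's preprocessing:

    import torch
    from models import ResNet50, VGG16  # the classes defined above

    # Random batch shaped like the 64x64 RGB inputs main.py produces
    x = torch.randn(2, 3, 64, 64)
    print(ResNet50(num_classes=8)(x).shape)  # torch.Size([2, 8])
    print(VGG16(num_classes=8)(x).shape)  # torch.Size([2, 8])

Both networks end in adaptive average pooling, so they tolerate inputs this small even though the canonical architectures assume 224x224.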
handetect/predict.py ADDED
@@ -0,0 +1,84 @@
+ import os
+ import torch
+ from torchvision import transforms
+ from PIL import Image
+ from handetect.models import *
+ 
+ # Path to the trained model checkpoint
+ model_checkpoint_path = "model.pth"
+ 
+ DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ 
+ # Class names come from the training data directory (same relative path as
+ # main.py), sorted to match the class order ImageFolder used during training
+ DATA_DIR = r"data/train/Task 1"
+ classes = sorted(os.listdir(DATA_DIR))
+ NUM_CLASSES = len(classes)
+ 
+ # Transformation for preprocessing the input image
+ preprocess = transforms.Compose(
+     [
+         transforms.Resize((64, 64)),  # Resize to match the training input size
+         transforms.Grayscale(num_output_channels=3),  # Grayscale, replicated to 3 channels
+         transforms.ToTensor(),
+         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize the image
+     ]
+ )
+ 
+ # Rebuild the architecture and load the checkpoint on the same device
+ model = resnet18(pretrained=False, num_classes=NUM_CLASSES)
+ model.load_state_dict(torch.load(model_checkpoint_path, map_location=DEVICE))
+ model = model.to(DEVICE)
+ model.eval()
+ torch.set_grad_enabled(False)
+ 
+ 
+ def predict_image(image_path, model=model, transform=preprocess):
+     print("---------------------------")
+     print("Image path:", image_path)
+     image = Image.open(image_path)
+     image = transform(image).unsqueeze(0)
+     image = image.to(DEVICE)
+     output = model(image)
+ 
+     # Softmax turns the logits into percentages
+     probabilities = torch.softmax(output, dim=1)[0] * 100
+ 
+     # Sort the classes by probability in descending order
+     sorted_classes = sorted(
+         zip(classes, probabilities), key=lambda x: x[1], reverse=True
+     )
+ 
+     # Report the probability for each class
+     print("Probabilities for each class:")
+     for class_label, class_prob in sorted_classes:
+         print(f"{class_label}: {round(class_prob.item(), 2)}%")
+ 
+     # Most probable class name and its index
+     predicted_class = sorted_classes[0][0]
+     predicted_label = classes.index(predicted_class)
+ 
+     # Report the prediction
+     print("Predicted label:", predicted_label)
+     print("Predicted class:", predicted_class)
+     print("---------------------------")
+ 
+     return sorted_classes
+ 
+ 
+ # # Example: call the predict_image function
+ # sorted_probabilities = predict_image(image_path, model, preprocess)
+ 
+ # # Access probabilities for each class in sorted order
+ # for class_label, class_prob in sorted_probabilities:
+ #     print(f"{class_label}: {class_prob}%")
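
The commented block above sketches usage; a concrete call could look like the following, assuming the package is importable from the repo root. The image path is a placeholder, not a file shipped in this commit:

    from handetect.predict import predict_image

    # Placeholder path; point it at any real image on disk
    results = predict_image("data/train/Task 1/some_class/sample.png")
    top_class, top_prob = results[0]
    print(top_class, round(top_prob.item(), 2))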
index.py ADDED
@@ -0,0 +1,48 @@
+ import gradio as gr
+ import handetect.predict as predict
+ 
+ 
+ def process_file(upload_filepath):
+     # Run the classifier and format each prediction as "class: probability%"
+     sorted_classes = predict.predict_image(upload_filepath)
+     result = []
+     for class_label, class_prob in sorted_classes:
+         result.append(f"{class_label}: {round(class_prob.item(), 2)}%")
+     # One string for each of the three output textboxes
+     return result[0], result[1], result[2]
+ 
+ 
+ demo = gr.Interface(
+     fn=process_file,
+     title="HANDETECT",
+     description="An innovative AI-powered system that facilitates early detection and monitoring of movement disorders through handwriting assessment",
+     inputs=[
+         gr.inputs.Image(source="upload", type="filepath", label="Choose Image"),
+     ],
+     outputs=[
+         gr.outputs.Textbox(label="Prediction 1"),
+         gr.outputs.Textbox(label="Prediction 2"),
+         gr.outputs.Textbox(label="Prediction 3"),
+     ],
+ )
+ 
+ demo.launch(share=True)
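
The gr.inputs and gr.outputs namespaces belong to Gradio's legacy API and were removed in later releases. If the Space moves to a current Gradio, the same interface could be declared with top-level components; a sketch under that assumption, not a drop-in replacement:

    import gradio as gr
    import handetect.predict as predict


    def process_file(upload_filepath):
        # Same handler logic as index.py above
        sorted_classes = predict.predict_image(upload_filepath)
        result = [f"{c}: {round(p.item(), 2)}%" for c, p in sorted_classes]
        return result[0], result[1], result[2]


    demo = gr.Interface(
        fn=process_file,
        title="HANDETECT",
        inputs=gr.Image(type="filepath", label="Choose Image"),
        outputs=[gr.Textbox(label=f"Prediction {i}") for i in (1, 2, 3)],
    )
    demo.launch()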
static/icon.png ADDED
static/script.js ADDED
@@ -0,0 +1,40 @@
+ // Get references to HTML elements
+ const uploadButton = document.getElementById("upload-button");
+ const imageUpload = document.getElementById("image-upload");
+ const resultsOutput = document.getElementById("results-output");
+ 
+ // Function to handle image upload
+ function handleImageUpload() {
+     const file = imageUpload.files[0];
+     if (!file) {
+         alert("Please select an image file.");
+         return;
+     }
+ 
+     // Create a FormData object to send the image file to the server
+     const formData = new FormData();
+     formData.append("file", file);
+ 
+     // Send the image to the server using fetch
+     fetch("/", {
+         method: "POST",
+         body: formData,
+     })
+         // .then((response) => response.json())
+         // .then((data) => {
+         //     // Display the results in the resultsOutput element
+         //     resultsOutput.innerHTML = "<h2>Results</h2>";
+         //     for (const key in data) {
+         //         if (data.hasOwnProperty(key)) {
+         //             resultsOutput.innerHTML += `<p>${key}: ${data[key]}</p>`;
+         //         }
+         //     }
+         // })
+         .catch((error) => {
+             console.error("Error uploading image:", error);
+             alert("Error uploading image. Please try again.");
+         });
+ }
+ 
+ // Add a click event listener to the upload button
+ uploadButton.addEventListener("click", handleImageUpload);
static/styles.css ADDED
@@ -0,0 +1,97 @@
+ /* Reset some default styles */
+ * {
+     margin: 0;
+     padding: 0;
+     box-sizing: border-box;
+ }
+ 
+ hr {
+     border: none;
+     margin: 40px 0;
+ }
+ 
+ /* Set a futuristic background */
+ body {
+     background-color: #0f0f0f;
+     color: #ffffff;
+     font-family: Arial, sans-serif;
+ }
+ 
+ /* Style the header */
+ header {
+     text-align: center;
+     padding: 20px;
+ }
+ 
+ h1 {
+     font-size: 36px;
+     margin-bottom: 10px;
+ }
+ 
+ p {
+     font-size: 18px;
+ }
+ 
+ /* Style the upload section */
+ .upload-section {
+     text-align: center;
+     margin: 40px auto;
+ }
+ 
+ .upload-box {
+     background-color: #272727;
+     padding: 10px;
+     border-radius: 10px;
+ }
+ 
+ h2 {
+     font-size: 24px;
+     margin-bottom: 10px;
+ }
+ 
+ .upload-area {
+     border: 2px dashed #555555;
+     cursor: pointer;
+ }
+ 
+ .upload-label {
+     display: block;
+     text-align: center;
+     font-size: 16px;
+     color: #555555;
+     border: 2px dashed #555555;
+     cursor: pointer;
+ }
+ 
+ button {
+     background-color: #1e90ff;
+     color: #ffffff;
+     padding: 10px 20px;
+     border: none;
+     border-radius: 5px;
+     font-size: 18px;
+     cursor: pointer;
+     transition: background-color 0.3s ease;
+ }
+ 
+ button:hover {
+     background-color: #0077b6;
+ }
+ 
+ /* Style the results section */
+ .results-section {
+     margin-top: 40px;
+     text-align: center;
+ }
+ 
+ #results-output {
+     font-size: 20px;
+     padding: 20px;
+     background-color: rgba(255, 255, 255, 0.1);
+     border-radius: 10px;
+     margin: 0 auto;
+     max-width: 600px;
+ }
templates/index.html ADDED
@@ -0,0 +1,40 @@
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+     <meta charset="UTF-8">
+     <meta name="viewport" content="width=device-width, initial-scale=1.0">
+     <title>HANDETECT</title>
+     <!-- Add your CSS file link here -->
+     <link rel="stylesheet" href="{{ url_for('static', filename='styles.css') }}">
+ </head>
+ <body>
+     <header>
+         <h1>HANDETECT</h1>
+         <p>An innovative AI-powered system that facilitates early detection and monitoring of movement disorders through handwriting assessment</p>
+     </header>
+     <main>
+         <section class="upload-section">
+             <div class="upload-box">
+                 <h2>Upload an Image</h2>
+                 <input id="image-upload" type="file" name="file" style="visibility: hidden">
+                 <label for="image-upload" class="upload-label">
+                     <hr>
+                     <span>Drag &amp; Drop an Image or Click to Upload</span>
+                     <hr>
+                 </label>
+                 <hr>
+                 <button id="upload-button">Upload</button>
+             </div>
+         </section>
+         <section class="results-section">
+             <h2>Results</h2>
+             <div id="results-output">
+                 <p>{{ message }}</p>
+             </div>
+         </section>
+     </main>
+     <!-- Add your JavaScript file link here if needed -->
+     <script src="{{ url_for('static', filename='script.js') }}"></script>
+ </body>
+ </html>