Update
Browse files- configs.py +29 -6
- convert.py +7 -0
- requirements.txt +0 -0
- train.py +5 -1
- tuning.py +1 -1
configs.py
CHANGED
@@ -3,14 +3,16 @@ import torch
|
|
3 |
from torchvision import transforms
|
4 |
from torch.utils.data import Dataset
|
5 |
from models import *
|
6 |
-
|
|
|
|
|
7 |
# Constants
|
8 |
RANDOM_SEED = 123
|
9 |
BATCH_SIZE = 32
|
10 |
-
NUM_EPOCHS =
|
11 |
LEARNING_RATE = 0.00017588413773574044
|
12 |
STEP_SIZE = 10
|
13 |
-
GAMMA = 0.
|
14 |
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
15 |
NUM_PRINT = 100
|
16 |
TASK = 1
|
@@ -22,10 +24,31 @@ NUM_CLASSES = 7
|
|
22 |
EARLY_STOPPING_PATIENCE = 20
|
23 |
CLASSES = ['Alzheimer Disease', 'Cerebral Palsy', 'Dystonia', 'Essential Tremor', 'Healthy', 'Huntington Disease', 'Parkinson Disease']
|
24 |
MODEL_SAVE_PATH = "output/checkpoints/model.pth"
|
25 |
-
MODEL = squeezenet1_0(num_classes=NUM_CLASSES)
|
26 |
|
27 |
-
print(CLASSES)
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
30 |
preprocess = transforms.Compose(
|
31 |
[
|
@@ -47,4 +70,4 @@ class CustomDataset(Dataset):
|
|
47 |
|
48 |
def __getitem__(self, idx):
|
49 |
img, label = self.data[idx]
|
50 |
-
return img, label
|
|
|
3 |
from torchvision import transforms
|
4 |
from torch.utils.data import Dataset
|
5 |
from models import *
|
6 |
+
import torch.nn as nn
|
7 |
+
from torchvision.models import squeezenet1_0, SqueezeNet1_0_Weights
|
8 |
+
from torchvision.models import squeezenet1_0
|
9 |
# Constants
|
10 |
RANDOM_SEED = 123
|
11 |
BATCH_SIZE = 32
|
12 |
+
NUM_EPOCHS = 40
|
13 |
LEARNING_RATE = 0.00017588413773574044
|
14 |
STEP_SIZE = 10
|
15 |
+
GAMMA = 0.9
|
16 |
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
17 |
NUM_PRINT = 100
|
18 |
TASK = 1
|
|
|
24 |
EARLY_STOPPING_PATIENCE = 20
|
25 |
CLASSES = ['Alzheimer Disease', 'Cerebral Palsy', 'Dystonia', 'Essential Tremor', 'Healthy', 'Huntington Disease', 'Parkinson Disease']
|
26 |
MODEL_SAVE_PATH = "output/checkpoints/model.pth"
|
|
|
27 |
|
|
|
28 |
|
29 |
+
|
30 |
+
class SqueezeNet1_0WithDropout(nn.Module):
|
31 |
+
def __init__(self, num_classes=1000):
|
32 |
+
super(SqueezeNet1_0WithDropout, self).__init__()
|
33 |
+
squeezenet = squeezenet1_0(weights=SqueezeNet1_0_Weights)
|
34 |
+
self.features = squeezenet.features
|
35 |
+
self.classifier = nn.Sequential(
|
36 |
+
nn.Conv2d(512, num_classes, kernel_size=1),
|
37 |
+
nn.BatchNorm2d(num_classes), # add batch normalization
|
38 |
+
nn.ReLU(inplace=True),
|
39 |
+
nn.AdaptiveAvgPool2d((1, 1))
|
40 |
+
)
|
41 |
+
|
42 |
+
def forward(self, x):
|
43 |
+
x = self.features(x)
|
44 |
+
x = self.classifier(x)
|
45 |
+
x = torch.flatten(x, 1)
|
46 |
+
return x
|
47 |
+
|
48 |
+
|
49 |
+
|
50 |
+
MODEL = SqueezeNet1_0WithDropout(num_classes=7)
|
51 |
+
print(CLASSES)
|
52 |
|
53 |
preprocess = transforms.Compose(
|
54 |
[
|
|
|
70 |
|
71 |
def __getitem__(self, idx):
|
72 |
img, label = self.data[idx]
|
73 |
+
return img, label
|
convert.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import onnx2tf
|
3 |
+
from configs import *
|
4 |
+
|
5 |
+
torch.onnx.export(model=MODEL, args=torch.randn(1, 3, 64, 64), f='output/checkpoints/model.onnx', verbose=True, input_names=['input'], output_names=['output'])
|
6 |
+
|
7 |
+
onnx2tf.convert(input_onnx_file_path='output/checkpoints/model.onnx', output_folder_path='output/checkpoints/converted/')
|
requirements.txt
CHANGED
Binary files a/requirements.txt and b/requirements.txt differ
|
|
train.py
CHANGED
@@ -153,6 +153,10 @@ def main_training_loop():
|
|
153 |
break
|
154 |
|
155 |
# Save the model
|
|
|
|
|
|
|
|
|
156 |
torch.save(model.state_dict(), MODEL_SAVE_PATH)
|
157 |
print("Model saved at", MODEL_SAVE_PATH)
|
158 |
|
@@ -190,4 +194,4 @@ def main_training_loop():
|
|
190 |
|
191 |
|
192 |
if __name__ == "__main__":
|
193 |
-
main_training_loop()
|
|
|
153 |
break
|
154 |
|
155 |
# Save the model
|
156 |
+
MODEL_SAVE_PATH = "output/checkpoints/model.pth"
|
157 |
+
|
158 |
+
# Ensure the parent directory exists
|
159 |
+
os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
|
160 |
torch.save(model.state_dict(), MODEL_SAVE_PATH)
|
161 |
print("Model saved at", MODEL_SAVE_PATH)
|
162 |
|
|
|
194 |
|
195 |
|
196 |
if __name__ == "__main__":
|
197 |
+
main_training_loop()
|
tuning.py
CHANGED
@@ -12,7 +12,7 @@ from torch.utils.tensorboard import SummaryWriter
|
|
12 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
13 |
EPOCHS = 10
|
14 |
N_TRIALS = 50
|
15 |
-
TIMEOUT =
|
16 |
|
17 |
# Create a TensorBoard writer
|
18 |
writer = SummaryWriter(log_dir="output/tensorboard/tuning")
|
|
|
12 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
13 |
EPOCHS = 10
|
14 |
N_TRIALS = 50
|
15 |
+
TIMEOUT = 900
|
16 |
|
17 |
# Create a TensorBoard writer
|
18 |
writer = SummaryWriter(log_dir="output/tensorboard/tuning")
|