cycool29 committed
Commit 59908f1 · Parent: 672baaa
Files changed (6):
  1. augment.py +28 -25
  2. configs.py +10 -10
  3. eval.py +6 -6
  4. models.py +6 -0
  5. train.py +141 -119
  6. tuning.py +44 -37
augment.py CHANGED
@@ -2,30 +2,33 @@ import os
 import Augmentor
 import shutil
 from configs import *
+import uuid
 
 tasks = ["1", "2", "3", "4", "5", "6"]
 
-for task in tasks:
+for task in ["1"]:
     # Loop through all folders in Task 1 and generate augmented images for each class
-    for disease in os.listdir("data/train/raw/Task " + task):
-        if disease != ".DS_Store":
-            print("Augmenting images in class: ", disease, " in Task ", task)
+    for class_label in ['Alzheimer Disease', 'Cerebral Palsy', 'Dystonia', 'Essential Tremor', 'Healthy', 'Huntington Disease', 'Parkinson Disease']:
+        if class_label != ".DS_Store":
+            print("Augmenting images in class: ", class_label, " in Task ", task)
             # Create a temp folder to combine the raw data and the external data
-            if not os.path.exists(f"data/temp/Task {task}/{disease}/"):
-                os.makedirs(f"data/temp/Task {task}/{disease}/")
-            for file in os.listdir(f"data/train/raw/Task {task}/{disease}"):
-                shutil.copy(
-                    f"data/train/raw/Task {task}/{disease}/{file}",
-                    f"data/temp/Task {task}/{disease}/{file}",
-                )
-            for file in os.listdir(f"data/train/external/Task {task}/{disease}"):
-                shutil.copy(
-                    f"data/train/external/Task {task}/{disease}/{file}",
-                    f"data/temp/Task {task}/{disease}/{file}",
-                )
+            if not os.path.exists(f"{TEMP_DATA_DIR}Task {task}/{class_label}/"):
+                os.makedirs(f"{TEMP_DATA_DIR}Task {task}/{class_label}/")
+            if os.path.exists(f"{RAW_DATA_DIR}Task {task}/{class_label}"):
+                for file in os.listdir(f"{RAW_DATA_DIR}Task {task}/{class_label}"):
+                    shutil.copy(
+                        f"{RAW_DATA_DIR}Task {task}/{class_label}/{file}",
+                        f"{TEMP_DATA_DIR}Task {task}/{class_label}/{str(uuid.uuid4())}.png",
+                    )
+            if os.path.exists(f"{EXTERNAL_DATA_DIR}Task {task}/{class_label}"):
+                for file in os.listdir(f"{EXTERNAL_DATA_DIR}Task {task}/{class_label}"):
+                    shutil.copy(
+                        f"{EXTERNAL_DATA_DIR}Task {task}/{class_label}/{file}",
+                        f"{TEMP_DATA_DIR}Task {task}/{class_label}/{str(uuid.uuid4())}.png",
+                    )
             p = Augmentor.Pipeline(
-                f"data/temp/Task {task}/{disease}",
-                output_directory=f"{disease}/",
+                f"{TEMP_DATA_DIR}Task {task}/{class_label}",
+                output_directory=f"{class_label}/",
                 save_format="png",
             )
             p.rotate(probability=0.8, max_left_rotation=5, max_right_rotation=5)
@@ -39,20 +42,20 @@ for task in tasks:
             p.sample(100 - len(p.augmentor_images))
             # Move the folder to data/train/Task 1/augmented
             # Create the folder if it does not exist
-            if not os.path.exists(f"data/train/augmented/Task {task}/"):
-                os.makedirs(f"data/train/augmented/Task {task}/")
+            if not os.path.exists(f"{AUG_DATA_DIR}Task {task}/"):
+                os.makedirs(f"{AUG_DATA_DIR}Task {task}/")
             # Move all images in the data/train/Task 1/i folder to data/train/Task 1/augmented/i
             os.rename(
-                f"data/temp/Task {task}/{disease}/{disease}",
-                f"data/train/augmented/Task {task}/{disease}",
+                f"{TEMP_DATA_DIR}Task {task}/{class_label}/{class_label}",
+                f"{AUG_DATA_DIR}Task {task}/{class_label}",
             )
             # Rename all the augmented images to [01, 02, 03]
             number = 0
-            for file in os.listdir(f"data/train/augmented/Task {task}/{disease}"):
+            for file in os.listdir(f"{AUG_DATA_DIR}Task {task}/{class_label}"):
                 number = int(number) + 1
                 if len(str(number)) == 1:
                     number = "0" + str(number)
                 os.rename(
-                    f"data/train/augmented/Task {task}/{disease}/{file}",
-                    f"data/train/augmented/Task {task}/{disease}/{number}.png",
+                    f"{AUG_DATA_DIR}Task {task}/{class_label}/{file}",
+                    f"{AUG_DATA_DIR}Task {task}/{class_label}/{number}.png",
                 )
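Note on the new copy step: the uuid4 filenames keep raw and external images from colliding in the shared temp folder, and `p.sample` generates only the shortfall up to 100 images per class. A minimal sketch of that pattern, using an illustrative class folder and a `max(0, ...)` guard that the script above relies on implicitly:

```python
import os
import shutil
import uuid

import Augmentor

# Illustrative paths only; the script above derives these from configs.py.
SOURCES = ["data/train/raw/Task 1/Healthy", "data/train/external/Task 1/Healthy"]
TEMP = "data/temp/Task 1/Healthy"

os.makedirs(TEMP, exist_ok=True)
for src in SOURCES:
    if os.path.exists(src):
        for name in os.listdir(src):
            # uuid4 names prevent raw and external files from overwriting each other
            shutil.copy(os.path.join(src, name), os.path.join(TEMP, f"{uuid.uuid4()}.png"))

p = Augmentor.Pipeline(TEMP, output_directory="Healthy/", save_format="png")
p.rotate(probability=0.8, max_left_rotation=5, max_right_rotation=5)
# p.augmentor_images lists the source images the pipeline found, so this tops
# the class up to 100 samples; the guard avoids a negative sample count.
p.sample(max(0, 100 - len(p.augmentor_images)))
```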
configs.py CHANGED
@@ -6,23 +6,23 @@ from models import *
 
 # Constants
 RANDOM_SEED = 123
-BATCH_SIZE = 64
+BATCH_SIZE = 16
 NUM_EPOCHS = 100
-LEARNING_RATE = 1.6317268278715415e-05
-OPTIMIZER_NAME = "Adam"
+LEARNING_RATE = 5.847227637580824e-05
 STEP_SIZE = 10
-GAMMA = 0.5
+GAMMA = 1.0
 DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 NUM_PRINT = 100
 TASK = 1
-RAW_DATA_DIR = r"data/train/raw/Task " + str(TASK)
-AUG_DATA_DIR = r"data/train/augmented/Task " + str(TASK)
-EXTERNAL_DATA_DIR = r"data/train/external/Task " + str(TASK)
+RAW_DATA_DIR = r"data/train/raw/Task "
+AUG_DATA_DIR = r"data/train/augmented/Task "
+EXTERNAL_DATA_DIR = r"data/train/external/Task "
+TEMP_DATA_DIR = "data/temp/"
 NUM_CLASSES = 7
-# Define classes as listdir of augmented data
-CLASSES = os.listdir("data/train/augmented/Task 1/")
+EARLY_STOPPING_PATIENCE = 20
+CLASSES = ['Alzheimer Disease', 'Cerebral Palsy', 'Dystonia', 'Essential Tremor', 'Healthy', 'Huntington Disease', 'Parkinson Disease']
 MODEL_SAVE_PATH = "output/checkpoints/model.pth"
-MODEL = googlenet(num_classes=NUM_CLASSES)
+MODEL = efficientnet_b1(num_classes=NUM_CLASSES)
 
 print(CLASSES)
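The data-directory constants are now bare `"Task "` stems rather than full paths. A quick illustration of how callers are expected to compose them with `TASK` (train.py and tuning.py below do exactly this):

```python
TASK = 1
RAW_DATA_DIR = r"data/train/raw/Task "

# Appending the task number at the call site replaces the old baked-in paths.
train_dir = RAW_DATA_DIR + str(TASK)
assert train_dir == "data/train/raw/Task 1"
```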
eval.py CHANGED
@@ -1,10 +1,9 @@
 import os
 import torch
 from torchvision.transforms import transforms
-from sklearn.metrics import f1_score
 import pathlib
 from PIL import Image
-from torchmetrics import ConfusionMatrix
+from torchmetrics import ConfusionMatrix, Accuracy, F1Score
 import matplotlib.pyplot as plt
 from configs import *
 from data_loader import load_data  # Import the load_data function
@@ -19,7 +18,6 @@ MODEL = MODEL.to(DEVICE)
 MODEL.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
 MODEL.eval()
 
-
 def predict_image(image_path, model, transform):
     model.eval()
     correct_predictions = 0
@@ -32,6 +30,9 @@ def predict_image(image_path, model, transform):
     true_classes = []
     predicted_labels = []
 
+    accuracy_metric = Accuracy(num_classes=NUM_CLASSES, task="multiclass")
+    f1_metric = F1Score(num_classes=NUM_CLASSES, task="multiclass")
+
     with torch.no_grad():
         for image_file in images:
             print("---------------------------")
@@ -57,7 +58,7 @@ def predict_image(image_path, model, transform):
     # Calculate accuracy and f1 score
     accuracy = correct_predictions / total_predictions
     print("Accuracy:", accuracy)
-    f1 = f1_score(true_classes, predicted_labels, average="weighted")
+    f1 = f1_metric(torch.tensor(predicted_labels), torch.tensor(true_classes)).item()
     print("Weighted F1 Score:", f1)
 
     # Convert the lists to tensors
@@ -66,13 +67,12 @@ def predict_image(image_path, model, transform):
 
     # Create a confusion matrix
     conf_matrix = ConfusionMatrix(num_classes=NUM_CLASSES, task="multiclass")
-    conf_matrix.update(predicted_labels_tensor, true_classes_tensor)
+    conf_matrix(predicted_labels_tensor, true_classes_tensor)
 
     # Plot the confusion matrix
     conf_matrix.compute()
     conf_matrix.plot()
     plt.show()
 
-
 # Call predict_image function
 predict_image(image_path, MODEL, preprocess)
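One caveat with the torchmetrics switch: `F1Score(task="multiclass", ...)` defaults to macro averaging, so the value printed under "Weighted F1 Score" is no longer weighted unless `average="weighted"` is passed explicitly. A small sketch of the call that would match the old sklearn behaviour:

```python
import torch
from torchmetrics import F1Score

NUM_CLASSES = 7  # mirrors configs.py

# average="weighted" reproduces sklearn's f1_score(..., average="weighted");
# the multiclass default in torchmetrics is macro averaging.
f1_metric = F1Score(task="multiclass", num_classes=NUM_CLASSES, average="weighted")

preds = torch.tensor([0, 2, 1, 3])   # toy predictions
target = torch.tensor([0, 1, 1, 3])  # toy ground truth
print(f1_metric(preds, target).item())
```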
models.py CHANGED
@@ -34,3 +34,9 @@ from torchvision.models import shufflenet_v2_x1_0
 from torchvision.models import shufflenet_v2_x1_5
 from torchvision.models import shufflenet_v2_x2_0
 from torchvision.models import squeezenet1_1
+from torchvision.models import efficientnet_v2_s
+from torchvision.models import efficientnet_v2_m
+from torchvision.models import efficientnet_v2_l
+from torchvision.models import efficientnet_b0
+from torchvision.models import efficientnet_b1
+
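For context on how configs.py uses the new imports: torchvision's EfficientNet constructors accept `num_classes`, which sizes the final classifier head and, with no `weights` argument, leaves the network randomly initialized. A brief sketch:

```python
import torch
from torchvision.models import efficientnet_b1

# num_classes resizes the classifier head; without a weights argument the
# model is randomly initialized, matching MODEL in configs.py.
model = efficientnet_b1(num_classes=7)

x = torch.randn(1, 3, 240, 240)  # 240x240 is EfficientNet-B1's nominal input size
print(model(x).shape)            # torch.Size([1, 7])
```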
train.py CHANGED
@@ -8,50 +8,32 @@ from torch.utils.tensorboard import SummaryWriter
 from configs import *
 import data_loader
 
-# Set up TensorBoard writer
-writer = SummaryWriter(log_dir="output/tensorboard/training")
 
-# Define a function for plotting and logging metrics
-def plot_and_log_metrics(metrics_dict, step, prefix="Train"):
+def setup_tensorboard():
+    return SummaryWriter(log_dir="output/tensorboard/training")
+
+
+def load_and_preprocess_data():
+    return data_loader.load_data(
+        RAW_DATA_DIR + str(TASK), AUG_DATA_DIR + str(TASK), EXTERNAL_DATA_DIR + str(TASK), preprocess
+    )
+
+
+def initialize_model_optimizer_scheduler():
+    model = MODEL.to(DEVICE)
+    criterion = nn.CrossEntropyLoss()
+    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
+    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)
+    return model, criterion, optimizer, scheduler
+
+
+def plot_and_log_metrics(metrics_dict, step, writer, prefix="Train"):
     for metric_name, metric_value in metrics_dict.items():
         writer.add_scalar(f"{prefix}/{metric_name}", metric_value, step)
 
-# Data loader
-train_loader, valid_loader = data_loader.load_data(
-    RAW_DATA_DIR, AUG_DATA_DIR, EXTERNAL_DATA_DIR, preprocess
-)
-
-# Initialize model, criterion, optimizer, and scheduler
-MODEL = MODEL.to(DEVICE)
-criterion = nn.CrossEntropyLoss()
-if OPTIMIZER_NAME == "LBFGS":
-    optimizer = optim.LBFGS(MODEL.parameters(), lr=LEARNING_RATE)
-elif OPTIMIZER_NAME == "Adam":
-    optimizer = optim.Adam(MODEL.parameters(), lr=LEARNING_RATE)
-elif OPTIMIZER_NAME == "SGD":
-    optimizer = optim.SGD(MODEL.parameters(), lr=LEARNING_RATE)
-
-scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)
-
-# Define early stopping parameters
-early_stopping_patience = 20  # Number of epochs with no improvement to wait before stopping
-best_val_loss = float("inf")
-best_val_accuracy = 0.0
-no_improvement_count = 0
-
-# Lists to store training and validation loss history
-TRAIN_LOSS_HIST = []
-VAL_LOSS_HIST = []
-AVG_TRAIN_LOSS_HIST = []
-AVG_VAL_LOSS_HIST = []
-TRAIN_ACC_HIST = []
-VAL_ACC_HIST = []
-
-# Training loop
-for epoch in range(NUM_EPOCHS):
-    print(f"[Epoch: {epoch + 1}]")
-    print("Learning rate:", scheduler.get_last_lr()[0])
-    MODEL.train()  # Set model to training mode
+
+def train_one_epoch(model, criterion, optimizer, train_loader, epoch):
+    model.train()
     running_loss = 0.0
     total_train = 0
     correct_train = 0
@@ -59,16 +41,13 @@ for epoch in range(NUM_EPOCHS):
     for i, (inputs, labels) in enumerate(train_loader, 0):
         inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
         optimizer.zero_grad()
-        if MODEL.__class__.__name__ == "GoogLeNet":  # the shit GoogLeNet has a different output
-            outputs = MODEL(inputs).logits
+        if model.__class__.__name__ == "GoogLeNet":
+            outputs = model(inputs).logits
         else:
-            outputs = MODEL(inputs)
+            outputs = model(inputs)
         loss = criterion(outputs, labels)
         loss.backward()
-        if OPTIMIZER_NAME == "LBFGS":
-            optimizer.step(closure=lambda: loss)
-        else:
-            optimizer.step()
+        optimizer.step()
         running_loss += loss.item()
 
         if (i + 1) % NUM_PRINT == 0:
@@ -83,21 +62,11 @@ for epoch in range(NUM_EPOCHS):
         correct_train += (predicted == labels).sum().item()
 
     avg_train_loss = running_loss / len(train_loader)
-    AVG_TRAIN_LOSS_HIST.append(avg_train_loss)
-    TRAIN_ACC_HIST.append(correct_train / total_train)
-
-    # Log training metrics
-    train_metrics = {
-        "Loss": avg_train_loss,
-        "Accuracy": correct_train / total_train,
-    }
-    plot_and_log_metrics(train_metrics, epoch, prefix="Train")
-
-    # Learning rate scheduling
-    scheduler.step()
-
-    # Validation loop
-    MODEL.eval()  # Set model to evaluation mode
+    return avg_train_loss, correct_train / total_train
+
+
+def validate_model(model, criterion, valid_loader):
+    model.eval()
     val_loss = 0.0
     correct_val = 0
     total_val = 0
@@ -105,67 +74,120 @@ for epoch in range(NUM_EPOCHS):
     with torch.no_grad():
         for inputs, labels in valid_loader:
             inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
-            outputs = MODEL(inputs)
+            outputs = model(inputs)
             loss = criterion(outputs, labels)
             val_loss += loss.item()
-            # Calculate accuracy
             _, predicted = torch.max(outputs, 1)
             total_val += labels.size(0)
             correct_val += (predicted == labels).sum().item()
 
     avg_val_loss = val_loss / len(valid_loader)
-    AVG_VAL_LOSS_HIST.append(avg_val_loss)
-    VAL_ACC_HIST.append(correct_val / total_val)
-
-    # Log validation metrics
-    val_metrics = {
-        "Loss": avg_val_loss,
-        "Accuracy": correct_val / total_val,
-    }
-    plot_and_log_metrics(val_metrics, epoch, prefix="Validation")
-
-    # Print average training and validation metrics
-    print(f"Average Training Loss: {avg_train_loss:.6f}")
-    print(f"Average Validation Loss: {avg_val_loss:.6f}")
-    print(f"Training Accuracy: {correct_train / total_train:.6f}")
-    print(f"Validation Accuracy: {correct_val / total_val:.6f}")
-
-    # Check for early stopping based on validation accuracy
-    if correct_val / total_val > best_val_accuracy:
-        best_val_accuracy = correct_val / total_val
-        no_improvement_count = 0
-    else:
-        no_improvement_count += 1
-
-    # Early stopping condition
-    if no_improvement_count >= early_stopping_patience:
-        print("Early stopping: Validation accuracy did not improve for {} consecutive epochs.".format(early_stopping_patience))
-        break  # Stop training
-
-# Save the model
-torch.save(MODEL.state_dict(), MODEL_SAVE_PATH)
-print("Model saved at", MODEL_SAVE_PATH)
-
-# Plot loss and accuracy curves
-plt.figure(figsize=(12, 4))
-plt.subplot(1, 2, 1)
-plt.plot(range(1, len(AVG_TRAIN_LOSS_HIST) + 1), AVG_TRAIN_LOSS_HIST, label="Average Train Loss")
-plt.plot(range(1, len(AVG_VAL_LOSS_HIST) + 1), AVG_VAL_LOSS_HIST, label="Average Validation Loss")
-plt.xlabel("Epochs")
-plt.ylabel("Loss")
-plt.legend()
-plt.title("Loss Curves")
-
-plt.subplot(1, 2, 2)
-plt.plot(range(1, len(TRAIN_ACC_HIST) + 1), TRAIN_ACC_HIST, label="Train Accuracy")
-plt.plot(range(1, len(VAL_ACC_HIST) + 1), VAL_ACC_HIST, label="Validation Accuracy")
-plt.xlabel("Epochs")
-plt.ylabel("Accuracy")
-plt.legend()
-plt.title("Accuracy Curves")
-
-plt.tight_layout()
-plt.savefig("training_curves.png")
-
-# Close TensorBoard writer
-writer.close()
+    return avg_val_loss, correct_val / total_val
+
+
+def main_training_loop():
+    writer = setup_tensorboard()
+    train_loader, valid_loader = load_and_preprocess_data()
+    model, criterion, optimizer, scheduler = initialize_model_optimizer_scheduler()
+
+    best_val_loss = float("inf")
+    best_val_accuracy = 0.0
+    no_improvement_count = 0
+
+    AVG_TRAIN_LOSS_HIST = []
+    AVG_VAL_LOSS_HIST = []
+    TRAIN_ACC_HIST = []
+    VAL_ACC_HIST = []
+
+    for epoch in range(NUM_EPOCHS):
+        print(f"[Epoch: {epoch + 1}]")
+        print("Learning rate:", scheduler.get_last_lr()[0])
+
+        avg_train_loss, train_accuracy = train_one_epoch(
+            model, criterion, optimizer, train_loader, epoch
+        )
+        AVG_TRAIN_LOSS_HIST.append(avg_train_loss)
+        TRAIN_ACC_HIST.append(train_accuracy)
+
+        # Log training metrics
+        train_metrics = {
+            "Loss": avg_train_loss,
+            "Accuracy": train_accuracy,
+        }
+        plot_and_log_metrics(train_metrics, epoch, writer=writer, prefix="Train")
+
+        # Learning rate scheduling
+        scheduler.step()
+
+        avg_val_loss, val_accuracy = validate_model(model, criterion, valid_loader)
+        AVG_VAL_LOSS_HIST.append(avg_val_loss)
+        VAL_ACC_HIST.append(val_accuracy)
+
+        # Log validation metrics
+        val_metrics = {
+            "Loss": avg_val_loss,
+            "Accuracy": val_accuracy,
+        }
+        plot_and_log_metrics(val_metrics, epoch, writer=writer, prefix="Validation")
+
+        # Print average training and validation metrics
+        print(f"Average Training Loss: {avg_train_loss:.6f}")
+        print(f"Average Validation Loss: {avg_val_loss:.6f}")
+        print(f"Training Accuracy: {train_accuracy:.6f}")
+        print(f"Validation Accuracy: {val_accuracy:.6f}")
+
+        # Check for early stopping based on validation accuracy
+        if val_accuracy > best_val_accuracy:
+            best_val_accuracy = val_accuracy
+            no_improvement_count = 0
+        else:
+            no_improvement_count += 1
+
+        # Early stopping condition
+        if no_improvement_count >= EARLY_STOPPING_PATIENCE:
+            print(
+                "Early stopping: Validation accuracy did not improve for {} consecutive epochs.".format(
+                    EARLY_STOPPING_PATIENCE
+                )
+            )
+            break
+
+    # Save the model
+    torch.save(model.state_dict(), MODEL_SAVE_PATH)
+    print("Model saved at", MODEL_SAVE_PATH)
+
+    # Plot loss and accuracy curves
+    plt.figure(figsize=(12, 4))
+    plt.subplot(1, 2, 1)
+    plt.plot(
+        range(1, len(AVG_TRAIN_LOSS_HIST) + 1),
+        AVG_TRAIN_LOSS_HIST,
+        label="Average Train Loss",
+    )
+    plt.plot(
+        range(1, len(AVG_VAL_LOSS_HIST) + 1),
+        AVG_VAL_LOSS_HIST,
+        label="Average Validation Loss",
+    )
+    plt.xlabel("Epochs")
+    plt.ylabel("Loss")
+    plt.legend()
+    plt.title("Loss Curves")
+
+    plt.subplot(1, 2, 2)
+    plt.plot(range(1, len(TRAIN_ACC_HIST) + 1), TRAIN_ACC_HIST, label="Train Accuracy")
+    plt.plot(range(1, len(VAL_ACC_HIST) + 1), VAL_ACC_HIST, label="Validation Accuracy")
+    plt.xlabel("Epochs")
+    plt.ylabel("Accuracy")
+    plt.legend()
+    plt.title("Accuracy Curves")
+
+    plt.tight_layout()
+    plt.savefig("training_curves.png")
+
+    # Close TensorBoard writer
+    writer.close()
+
+
+if __name__ == "__main__":
+    main_training_loop()
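The early-stopping logic now keys off EARLY_STOPPING_PATIENCE from configs.py. The core pattern, isolated here as a sketch with a fabricated accuracy stream standing in for validate_model:

```python
# Patience-based early stopping, extracted from the loop above.
EARLY_STOPPING_PATIENCE = 3  # configs.py uses 20

best_val_accuracy = 0.0
no_improvement_count = 0

for epoch, val_accuracy in enumerate([0.51, 0.60, 0.59, 0.58, 0.57], start=1):
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        no_improvement_count = 0
    else:
        no_improvement_count += 1
    if no_improvement_count >= EARLY_STOPPING_PATIENCE:
        print(f"Early stopping at epoch {epoch}; best accuracy {best_val_accuracy:.2f}")
        break
```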
tuning.py CHANGED
@@ -9,21 +9,27 @@ from configs import *
 import data_loader
 from torch.utils.tensorboard import SummaryWriter
 
-optuna.logging.set_verbosity(optuna.logging.DEBUG)
-
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 EPOCHS = 10
+N_TRIALS = 50
+TIMEOUT = 3600  # 1 hour
 
 # Create a TensorBoard writer
-writer = SummaryWriter(log_dir="output/tensorboard/tuning/", )
+writer = SummaryWriter(log_dir="output/tensorboard/tuning")
+
 
 def create_data_loaders(batch_size):
     # Create or modify data loaders with the specified batch size
     train_loader, valid_loader = data_loader.load_data(
-        RAW_DATA_DIR, AUG_DATA_DIR, EXTERNAL_DATA_DIR, preprocess, batch_size=batch_size
+        RAW_DATA_DIR + str(TASK),
+        AUG_DATA_DIR + str(TASK),
+        EXTERNAL_DATA_DIR + str(TASK),
+        preprocess,
+        batch_size=batch_size,
     )
     return train_loader, valid_loader
 
+
 def objective(trial, model=MODEL):
     # Generate the model.
     model = model.to(DEVICE)
@@ -35,11 +41,16 @@ def objective(trial, model=MODEL):
     train_loader, valid_loader = create_data_loaders(batch_size)
 
     # Generate the optimizer.
-    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "SGD"])
-    lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
-    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
+    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
+    optimizer = optim.Adam(model.parameters(), lr=lr)
     criterion = nn.CrossEntropyLoss()
 
+    # Suggest the gamma parameter for the learning rate scheduler.
+    gamma = trial.suggest_float("gamma", 0.1, 1.0, step=0.1)
+
+    # Create a learning rate scheduler with the suggested gamma.
+    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)
+
     # Training of the model.
     for epoch in range(EPOCHS):
         print(f"[Epoch: {epoch} | Trial: {trial.number}]")
@@ -47,16 +58,18 @@ def objective(trial, model=MODEL):
         for batch_idx, (data, target) in enumerate(train_loader, 0):
             data, target = data.to(DEVICE), target.to(DEVICE)
             optimizer.zero_grad()
-            if model.__class__.__name__ == "GoogLeNet":  # the shit GoogLeNet has a different output
+            if (
+                model.__class__.__name__ == "GoogLeNet"
+            ):  # the shit GoogLeNet has a different output
                 output = model(data).logits
             else:
                 output = model(data)
             loss = criterion(output, target)
             loss.backward()
-            if optimizer_name == "LBFGS":
-                optimizer.step(closure=lambda: loss)
-            else:
-                optimizer.step()
+            optimizer.step()
+
+        # Update the learning rate using the scheduler.
+        scheduler.step()
 
     # Validation of the model.
     model.eval()
@@ -74,14 +87,8 @@ def objective(trial, model=MODEL):
     # Log hyperparameters and accuracy to TensorBoard
     writer.add_scalar("Accuracy", accuracy, trial.number)
    writer.add_hparams(
-        {
-            "batch_size": batch_size,
-            "optimizer": optimizer_name,
-            "lr": lr
-        },
-        {
-            "accuracy": accuracy
-        }
+        {"batch_size": batch_size, "lr": lr, "gamma": gamma},
+        {"accuracy": accuracy},
     )
 
     # Print hyperparameters and accuracy
@@ -93,29 +100,29 @@ def objective(trial, model=MODEL):
     if trial.should_prune():
         raise optuna.exceptions.TrialPruned()
 
+    if trial.number > 10 and trial.params["lr"] < 1e-3 and accuracy < 0.7:
+        raise optuna.exceptions.TrialPruned()  # Prune unpromising trials
+
     return accuracy
 
+
 if __name__ == "__main__":
     pruner = optuna.pruners.HyperbandPruner()
-    study = optuna.create_study(direction="maximize", pruner=pruner, study_name="handetect")
-    study.optimize(objective, n_trials=100, timeout=1000)
-
-    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
-    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])
-
-    print("Study statistics: ")
-    print("  Number of finished trials: ", len(study.trials))
-    print("  Number of pruned trials: ", len(pruned_trials))
-    print("  Number of complete trials: ", len(complete_trials))
+    study = optuna.create_study(
+        direction="maximize",  # Adjust the direction as per your optimization goal
+        pruner=pruner,
+        study_name="hyperparameter_tuning",
+    )
 
+    # Optimize the hyperparameters
+    study.optimize(
+        objective, n_trials=100, timeout=3600
+    )  # Adjust the number of trials and timeout as needed
+
+    # Print the best trial
+    best_trial = study.best_trial
     print("Best trial:")
-    trial = study.best_trial
-
-    print("  Value: ", trial.value)
-
+    print("  Value: ", best_trial.value)
     print("  Params: ")
-    for key, value in trial.params.items():
+    for key, value in best_trial.params.items():
         print("    {}: {}".format(key, value))
-
-    # Close TensorBoard writer
-    writer.close()