cycool29 committed on
Commit
5daa5f5
1 Parent(s): e6f2a04
Files changed (8)
  1. augment.py +0 -50
  2. configs.py +0 -43
  3. data_loader.py +0 -32
  4. eval.py +0 -86
  5. models.py +0 -36
  6. predict.py +0 -57
  7. train.py +0 -145
  8. tuning.py +0 -186
augment.py DELETED
@@ -1,50 +0,0 @@
import os
import shutil

import Augmentor

from configs import *

tasks = ["1", "2", "3", "4", "5", "6"]

for task in tasks:
    # Loop through all classes in each task and generate augmented images for them
    for disease in os.listdir("data/train/raw/Task " + task):
        if disease != ".DS_Store":
            print("Augmenting images in class:", disease, "in Task", task)
            # Create a temp folder to combine the raw data and the external data
            if not os.path.exists(f"data/temp/Task {task}/{disease}/"):
                os.makedirs(f"data/temp/Task {task}/{disease}/")
            for file in os.listdir(f"data/train/raw/Task {task}/{disease}"):
                shutil.copy(f"data/train/raw/Task {task}/{disease}/{file}", f"data/temp/Task {task}/{disease}/{file}")
            for file in os.listdir(f"data/train/external/Task {task}/{disease}"):
                shutil.copy(f"data/train/external/Task {task}/{disease}/{file}", f"data/temp/Task {task}/{disease}/{file}")
            p = Augmentor.Pipeline(f"data/temp/Task {task}/{disease}", output_directory=f"{disease}/", save_format="png")
            p.rotate(probability=0.8, max_left_rotation=5, max_right_rotation=5)
            p.flip_left_right(probability=0.8)
            p.zoom_random(probability=0.8, percentage_area=0.8)
            p.flip_top_bottom(probability=0.8)
            p.random_brightness(probability=0.8, min_factor=0.5, max_factor=1.5)
            p.random_contrast(probability=0.8, min_factor=0.5, max_factor=1.5)
            p.random_color(probability=0.8, min_factor=0.5, max_factor=1.5)
            # Generate enough samples to bring each class up to 100 images
            # (clamped at 0 in case a class already holds 100 or more)
            p.sample(max(0, 100 - len(p.augmentor_images)))
            # Create data/train/augmented/Task <task> if it does not exist
            if not os.path.exists(f"data/train/augmented/Task {task}/"):
                os.makedirs(f"data/train/augmented/Task {task}/")
            # Move the pipeline's output folder into the augmented directory
            os.rename(
                f"data/temp/Task {task}/{disease}/{disease}",
                f"data/train/augmented/Task {task}/{disease}",
            )
            # Rename the augmented images to zero-padded numbers (01.png, 02.png, ...)
            for number, file in enumerate(os.listdir(f"data/train/augmented/Task {task}/{disease}"), start=1):
                os.rename(
                    f"data/train/augmented/Task {task}/{disease}/{file}",
                    f"data/train/augmented/Task {task}/{disease}/{str(number).zfill(2)}.png",
                )
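As a quick sanity check after running the script, a minimal sketch like the following (assuming the directory layout created above, where originals stay in raw/external and only generated samples land in augmented) can confirm each class reached the 100-image target:

import os

task = "1"  # any of the tasks handled above
for disease in sorted(os.listdir(f"data/train/augmented/Task {task}")):
    total = sum(
        len(os.listdir(f"data/train/{split}/Task {task}/{disease}"))
        for split in ("raw", "external", "augmented")
    )
    print(f"Task {task} / {disease}: {total} images (target: 100)")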
configs.py DELETED
@@ -1,43 +0,0 @@
import os
import torch
from torchvision import transforms
from torch.utils.data import Dataset
from models import *

# Constants
RANDOM_SEED = 123
BATCH_SIZE = 32
NUM_EPOCHS = 100
LEARNING_RATE = 0.001
STEP_SIZE = 10
GAMMA = 0.5
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
NUM_PRINT = 100
TASK = 1
RAW_DATA_DIR = r"data/train/raw/Task " + str(TASK)
AUG_DATA_DIR = r"data/train/augmented/Task " + str(TASK)
EXTERNAL_DATA_DIR = r"data/train/external/Task " + str(TASK)
NUM_CLASSES = 7
MODEL_SAVE_PATH = "output/checkpoints/model.pth"
MODEL = mobilenet_v3_small(num_classes=NUM_CLASSES)

preprocess = transforms.Compose(
    [
        transforms.Resize((64, 64)),  # Resize images to 64x64
        transforms.ToTensor(),  # Convert to tensor
        # Normalize 3 channels
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)


# Custom dataset class
class CustomDataset(Dataset):
    def __init__(self, dataset):
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, label = self.data[idx]
        return img, label
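For orientation, a minimal sketch of what `preprocess` produces (the sample path is hypothetical; any RGB image on disk works):

from PIL import Image

from configs import preprocess

img = Image.open("data/train/raw/Task 1/Healthy/01.png").convert("RGB")  # hypothetical path
x = preprocess(img)
print(x.shape)  # torch.Size([3, 64, 64]); values land in [-1, 1] after normalization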
data_loader.py DELETED
@@ -1,32 +0,0 @@
from configs import *
from torchvision.datasets import ImageFolder
from torch.utils.data import random_split, DataLoader


def load_data(raw_dir, augmented_dir, external_dir, preprocess):
    # Load the datasets using ImageFolder
    raw_dataset = ImageFolder(root=raw_dir, transform=preprocess)
    external_dataset = ImageFolder(root=external_dir, transform=preprocess)
    augmented_dataset = ImageFolder(root=augmented_dir, transform=preprocess)
    # "+" on Datasets yields a ConcatDataset of all three sources
    dataset = raw_dataset + external_dataset + augmented_dataset

    print("Classes:", ", ".join(raw_dataset.classes))
    print("Length of raw dataset:", len(raw_dataset))
    print("Length of external dataset:", len(external_dataset))
    print("Length of augmented dataset:", len(augmented_dataset))
    print("Length of total dataset:", len(dataset))

    # Split the dataset into train and validation sets,
    # seeded with RANDOM_SEED from configs for reproducibility
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(
        dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(RANDOM_SEED),
    )

    # Create data loaders for the custom dataset
    train_loader = DataLoader(
        CustomDataset(train_dataset), batch_size=BATCH_SIZE, shuffle=True, num_workers=0
    )
    valid_loader = DataLoader(
        CustomDataset(val_dataset), batch_size=BATCH_SIZE, num_workers=0
    )

    return train_loader, valid_loader
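A minimal usage sketch, matching the call made in train.py (assuming the Task 1 data directories from configs exist):

from configs import RAW_DATA_DIR, AUG_DATA_DIR, EXTERNAL_DATA_DIR, preprocess
import data_loader

train_loader, valid_loader = data_loader.load_data(
    RAW_DATA_DIR, AUG_DATA_DIR, EXTERNAL_DATA_DIR, preprocess
)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # e.g. torch.Size([32, 3, 64, 64]) torch.Size([32])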
eval.py DELETED
@@ -1,86 +0,0 @@
import os
import pathlib

import torch
from torchvision import transforms
from sklearn.metrics import f1_score
from PIL import Image
from torchmetrics import ConfusionMatrix
import matplotlib.pyplot as plt

from configs import *

image_path = "data/test/Task 1/"

# Get class labels from the test directory; sorting (and skipping hidden files
# such as .DS_Store) matches the ordering ImageFolder used at training time
class_labels = sorted(d for d in os.listdir(image_path) if not d.startswith("."))

# Load the model (DEVICE comes from configs)
MODEL = MODEL.to(DEVICE)
MODEL.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
MODEL.eval()

# Define transformation for preprocessing, matching the training-time pipeline
preprocess = transforms.Compose(
    [
        transforms.Resize((64, 64)),  # Resize images to 64x64
        transforms.ToTensor(),  # Convert to tensor
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize 3 channels
    ]
)


def predict_image(image_path, model, transform):
    model.eval()
    correct_predictions = 0

    # Get a list of image files (must come before counting them)
    images = list(pathlib.Path(image_path).rglob("*.png"))
    total_predictions = len(images)

    true_classes = []
    predicted_labels = []

    with torch.no_grad():
        for image_file in images:
            print("---------------------------")
            # The true label is the position of the parent folder's name
            # in the sorted class list
            true_class = class_labels.index(image_file.parts[-2])
            print("Image path:", image_file)
            print("True class:", true_class)
            image = Image.open(image_file).convert("RGB")
            image = transform(image).unsqueeze(0)
            image = image.to(DEVICE)
            output = model(image)
            predicted_class = torch.argmax(output, dim=1).item()
            # Print the predicted class
            print("Predicted class:", predicted_class)
            # Append true and predicted labels to their respective lists
            true_classes.append(true_class)
            predicted_labels.append(predicted_class)

            # Check if the prediction is correct
            if predicted_class == true_class:
                correct_predictions += 1

    # Calculate accuracy and F1 score
    accuracy = correct_predictions / total_predictions
    print("Accuracy:", accuracy)
    f1 = f1_score(true_classes, predicted_labels, average="weighted")
    print("Weighted F1 Score:", f1)

    # Convert the lists to tensors
    predicted_labels_tensor = torch.tensor(predicted_labels)
    true_classes_tensor = torch.tensor(true_classes)

    # Create and plot a confusion matrix
    conf_matrix = ConfusionMatrix(num_classes=NUM_CLASSES, task="multiclass")
    conf_matrix.update(predicted_labels_tensor, true_classes_tensor)
    conf_matrix.plot()
    plt.show()


# Call predict_image function
predict_image(image_path, MODEL, preprocess)
models.py DELETED
@@ -1,36 +0,0 @@
#######################################################
# This file stores all the models used in the project.#
#######################################################

# Import all models from torchvision.models
from torchvision.models import (
    alexnet,
    densenet121,
    googlenet,
    inception_v3,
    mnasnet0_5,
    mnasnet0_75,
    mnasnet1_0,
    mnasnet1_3,
    mobilenet_v2,
    mobilenet_v3_large,
    mobilenet_v3_small,
    resnet18,
    resnet50,
    resnext50_32x4d,
    resnext101_32x8d,
    shufflenet_v2_x0_5,
    shufflenet_v2_x1_0,
    shufflenet_v2_x1_5,
    shufflenet_v2_x2_0,
    squeezenet1_0,
    squeezenet1_1,
    vgg11,
    vgg11_bn,
    vgg13,
    vgg13_bn,
    vgg16,
    vgg16_bn,
    vgg19,
    vgg19_bn,
    wide_resnet50_2,
    wide_resnet101_2,
)
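Since configs.py instantiates MODEL from this namespace, switching architectures is a one-line change there; a sketch:

# in configs.py — swap the backbone while keeping the rest of the pipeline
from models import *

NUM_CLASSES = 7
# MODEL = mobilenet_v3_small(num_classes=NUM_CLASSES)  # current choice
MODEL = resnet18(num_classes=NUM_CLASSES)  # e.g. try a ResNet-18 instead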
predict.py DELETED
@@ -1,57 +0,0 @@
import torch
from PIL import Image

from models import *
from configs import *


# Load the trained weights onto DEVICE and switch to inference mode
MODEL.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
MODEL = MODEL.to(DEVICE)
MODEL.eval()
torch.set_grad_enabled(False)


def predict_image(image_path, model=MODEL, transform=preprocess):
    # NOTE: only six labels are listed here, while configs.NUM_CLASSES is 7
    classes = [
        'Cerebral Palsy', 'Dystonia', 'Essential Tremor', 'Healthy', 'Huntington Disease', 'Parkinson Disease'
    ]

    print("---------------------------")
    print("Image path:", image_path)
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0)
    image = image.to(DEVICE)
    output = model(image)

    # Softmax turns the logits into percentages
    probabilities = torch.softmax(output, dim=1)[0] * 100

    # Sort the classes by probability in descending order
    sorted_classes = sorted(
        zip(classes, probabilities), key=lambda x: x[1], reverse=True
    )

    # Report the prediction for each class
    print("Probabilities for each class:")
    for class_label, class_prob in sorted_classes:
        print(f"{class_label}: {round(class_prob.item(), 2)}%")

    # Get the predicted class
    predicted_class = sorted_classes[0][0]  # Most probable class
    predicted_label = classes.index(predicted_class)

    # Report the prediction
    print("Predicted class:", predicted_label)
    print("Predicted label:", predicted_class)
    print("---------------------------")

    return predicted_label, sorted_classes
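A minimal usage sketch, assuming the checkpoint saved by train.py exists at MODEL_SAVE_PATH (the image path is hypothetical):

from predict import predict_image

label, ranking = predict_image("data/test/Task 1/Healthy/01.png")  # hypothetical path
print(label)          # integer index of the top class
print(ranking[0][0])  # name of the top class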
train.py DELETED
@@ -1,145 +0,0 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

from models import *
from configs import *
import data_loader

# Set up TensorBoard writer
writer = SummaryWriter(log_dir="output/tensorboard/training")


# Define a function for logging metrics
def plot_and_log_metrics(metrics_dict, step, prefix="Train"):
    for metric_name, metric_value in metrics_dict.items():
        writer.add_scalar(f"{prefix}/{metric_name}", metric_value, step)


# Data loader
train_loader, valid_loader = data_loader.load_data(
    RAW_DATA_DIR, AUG_DATA_DIR, EXTERNAL_DATA_DIR, preprocess
)

# Initialize model, criterion, optimizer, and scheduler
MODEL = MODEL.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(MODEL.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)

# Lists to store training and validation loss history
TRAIN_LOSS_HIST = []
VAL_LOSS_HIST = []
TRAIN_ACC_HIST = []
VAL_ACC_HIST = []

# Training loop
for epoch in range(NUM_EPOCHS):
    MODEL.train()  # Set model to training mode
    running_loss = 0.0  # loss since the last NUM_PRINT report
    epoch_loss = 0.0  # loss accumulated over the whole epoch
    total_train = 0
    correct_train = 0

    for i, (inputs, labels) in enumerate(train_loader, 0):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = MODEL(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        epoch_loss += loss.item()

        if (i + 1) % NUM_PRINT == 0:
            print(
                "[Epoch %d, Batch %d] Loss: %.6f"
                % (epoch + 1, i + 1, running_loss / NUM_PRINT)
            )
            running_loss = 0.0

        _, predicted = torch.max(outputs, 1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()

    # Average over all batches, not just those since the last report
    avg_train_loss = epoch_loss / len(train_loader)
    TRAIN_LOSS_HIST.append(avg_train_loss)
    TRAIN_ACC_HIST.append(correct_train / total_train)

    # Log training metrics
    train_metrics = {
        "Loss": avg_train_loss,
        "Accuracy": correct_train / total_train,
    }
    plot_and_log_metrics(train_metrics, epoch, prefix="Train")

    # Learning rate scheduling
    scheduler.step()

    # Validation loop
    MODEL.eval()  # Set model to evaluation mode
    val_loss = 0.0
    correct_val = 0
    total_val = 0

    with torch.no_grad():
        for inputs, labels in valid_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = MODEL(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            total_val += labels.size(0)
            correct_val += (predicted == labels).sum().item()

    avg_val_loss = val_loss / len(valid_loader)
    VAL_LOSS_HIST.append(avg_val_loss)
    VAL_ACC_HIST.append(correct_val / total_val)

    # Log validation metrics
    val_metrics = {
        "Loss": avg_val_loss,
        "Accuracy": correct_val / total_val,
    }
    plot_and_log_metrics(val_metrics, epoch, prefix="Validation")

    # Add sample images to TensorBoard
    sample_images, _ = next(iter(valid_loader))
    sample_images = sample_images.to(DEVICE)
    grid_image = make_grid(sample_images, nrow=8, normalize=True)
    writer.add_image("Sample Images", grid_image, global_step=epoch)

# Save the model
torch.save(MODEL.state_dict(), MODEL_SAVE_PATH)
print("Model saved at", MODEL_SAVE_PATH)

# Plot loss and accuracy curves
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(range(1, NUM_EPOCHS + 1), TRAIN_LOSS_HIST, label="Train Loss")
plt.plot(range(1, NUM_EPOCHS + 1), VAL_LOSS_HIST, label="Validation Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.title("Loss Curves")

plt.subplot(1, 2, 2)
plt.plot(range(1, NUM_EPOCHS + 1), TRAIN_ACC_HIST, label="Train Accuracy")
plt.plot(range(1, NUM_EPOCHS + 1), VAL_ACC_HIST, label="Validation Accuracy")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.legend()
plt.title("Accuracy Curves")

plt.tight_layout()
plt.savefig("training_curves.png")

# Close TensorBoard writer
writer.close()
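The writer above logs under output/tensorboard/training (tuning.py logs under output/tensorboard/tuning), so both runs can be inspected with the standard TensorBoard CLI: `tensorboard --logdir output/tensorboard`.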
tuning.py DELETED
@@ -1,186 +0,0 @@
import torch
import torch.nn as nn
import torch.optim as optim
import optuna
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

from models import *
from configs import *
import data_loader

# Data loader
train_loader, valid_loader = data_loader.load_data(
    RAW_DATA_DIR, AUG_DATA_DIR, EXTERNAL_DATA_DIR, preprocess
)

# Initialize model, criterion, optimizer, and scheduler
MODEL = MODEL.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(MODEL.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=10, verbose=True
)

# Lists to store training and validation loss history
TRAIN_LOSS_HIST = []
VAL_LOSS_HIST = []
TRAIN_ACC_HIST = []
VAL_ACC_HIST = []
AVG_TRAIN_LOSS_HIST = []
AVG_VAL_LOSS_HIST = []

# Create a TensorBoard writer for logging
writer = SummaryWriter(log_dir="output/tensorboard/tuning")

# Define early stopping parameters
early_stopping_patience = 10  # Number of epochs to wait for improvement
best_val_loss = float("inf")
no_improvement_count = 0


def train_epoch(epoch):
    MODEL.train()
    running_loss = 0.0  # loss since the last NUM_PRINT report
    epoch_loss = 0.0  # loss accumulated over the whole epoch
    total_train = 0
    correct_train = 0

    for i, (inputs, labels) in enumerate(train_loader, 0):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = MODEL(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        epoch_loss += loss.item()

        if (i + 1) % NUM_PRINT == 0:
            print(
                "[Epoch %d, Batch %d] Loss: %.6f"
                % (epoch + 1, i + 1, running_loss / NUM_PRINT)
            )
            running_loss = 0.0

        _, predicted = torch.max(outputs, 1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()

    TRAIN_LOSS_HIST.append(loss.item())  # last-batch loss
    train_accuracy = correct_train / total_train
    TRAIN_ACC_HIST.append(train_accuracy)
    # Calculate the average training loss for the epoch
    # (accumulated separately, since running_loss is reset at every report)
    avg_train_loss = epoch_loss / len(train_loader)

    writer.add_scalar("Loss/Train", avg_train_loss, epoch)
    writer.add_scalar("Accuracy/Train", train_accuracy, epoch)
    AVG_TRAIN_LOSS_HIST.append(avg_train_loss)

    # Print average training loss for the epoch
    print("[Epoch %d] Average Training Loss: %.6f" % (epoch + 1, avg_train_loss))

    # Learning rate scheduling
    lr_1 = optimizer.param_groups[0]["lr"]
    print("Learning Rate: {:.15f}".format(lr_1))
    scheduler.step(avg_train_loss)


def validate_epoch(epoch):
    global best_val_loss, no_improvement_count

    MODEL.eval()
    val_loss = 0.0
    correct_val = 0
    total_val = 0

    with torch.no_grad():
        for inputs, labels in valid_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = MODEL(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            total_val += labels.size(0)
            correct_val += (predicted == labels).sum().item()

    # Calculate the average validation loss for the epoch
    avg_val_loss = val_loss / len(valid_loader)
    VAL_LOSS_HIST.append(avg_val_loss)
    AVG_VAL_LOSS_HIST.append(avg_val_loss)  # was appending the last batch loss only
    print("Average Validation Loss: %.6f" % (avg_val_loss))

    # Calculate the accuracy of the validation set
    val_accuracy = correct_val / total_val
    VAL_ACC_HIST.append(val_accuracy)
    print("Validation Accuracy: %.6f" % (val_accuracy))
    writer.add_scalar("Loss/Validation", avg_val_loss, epoch)
    writer.add_scalar("Accuracy/Validation", val_accuracy, epoch)

    # Add sample images to TensorBoard
    sample_images, _ = next(iter(valid_loader))  # Get a batch of sample images
    sample_images = sample_images.to(DEVICE)
    grid_image = make_grid(sample_images, nrow=8, normalize=True)  # Create a grid of images
    writer.add_image("Sample Images", grid_image, global_step=epoch)

    # Check for early stopping
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        no_improvement_count = 0
    else:
        no_improvement_count += 1

    if no_improvement_count >= early_stopping_patience:
        print(f"Early stopping after {epoch + 1} epochs without improvement.")
        return True  # Return True to stop training
    return False


def objective(trial):
    global best_val_loss, no_improvement_count, optimizer

    # Reset the early-stopping state for each trial
    # (NOTE: MODEL weights still carry over between trials as written)
    best_val_loss = float("inf")
    no_improvement_count = 0

    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-1)
    # NOTE: batch_size is suggested but load_data() reads BATCH_SIZE from
    # configs, so the suggestion is recorded without being applied here
    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])

    # Rebuild the optimizer with the suggested learning rate; the global
    # declaration above makes train_epoch pick it up
    optimizer = optim.Adam(MODEL.parameters(), lr=learning_rate)

    for epoch in range(10):
        train_epoch(epoch)
        early_stopping = validate_epoch(epoch)

        # Check for early stopping
        if early_stopping:
            break

    # Calculate a weighted score based on validation accuracy and loss
    validation_score = VAL_ACC_HIST[-1] - AVG_VAL_LOSS_HIST[-1]

    # The study is created with direction="maximize", so return the score as-is
    # (Optuna minimizes by default; negating here would invert the search)
    return validation_score


if __name__ == "__main__":
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=100, show_progress_bar=True)

    # Print statistics
    print("Number of finished trials: ", len(study.trials))
    pruned_trials = [
        t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED
    ]
    print("Number of pruned trials: ", len(pruned_trials))
    complete_trials = [
        t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE
    ]
    print("Number of complete trials: ", len(complete_trials))

    # Print best trial
    trial = study.best_trial
    print("Best trial:")
    print("  Value: ", trial.value)
    print("  Params: ")
    for key, value in trial.params.items():
        print(f"    {key}: {value}")

    # Close TensorBoard writer
    writer.close()
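Once the study finishes, the trial history can be kept for later analysis; a minimal sketch using Optuna's built-in dataframe export (pandas required; the CSV path is arbitrary):

# run after study.optimize(...) has returned
df = study.trials_dataframe()  # one row per trial: state, value, params
df.to_csv("output/optuna_trials.csv")  # arbitrary output path
print(df.sort_values("value", ascending=False).head())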