Update
Browse files- augment.py +0 -50
- configs.py +0 -43
- data_loader.py +0 -32
- eval.py +0 -86
- models.py +0 -36
- predict.py +0 -57
- train.py +0 -145
- tuning.py +0 -186
augment.py
DELETED
@@ -1,50 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import Augmentor
|
3 |
-
import shutil
|
4 |
-
from configs import *
|
5 |
-
|
6 |
-
# Task folders to process; each maps to "data/train/raw/Task <n>".
tasks = ["1", "2", "3", "4", "5", "6"]

# Each class is topped up with augmented samples until it holds this many images.
TARGET_IMAGES_PER_CLASS = 100

for task in tasks:
    raw_task_dir = f"data/train/raw/Task {task}"

    # Walk every class folder of the task and generate augmented images for it.
    for disease in os.listdir(raw_task_dir):
        raw_dir = f"{raw_task_dir}/{disease}"
        # Skip anything that is not a class folder (e.g. ".DS_Store").
        if not os.path.isdir(raw_dir):
            continue

        print("Augmenting images in class: ", disease, " in Task ", task)

        # Combine the raw and the external images of this class in a temp folder.
        temp_dir = f"data/temp/Task {task}/{disease}"
        os.makedirs(temp_dir, exist_ok=True)
        external_dir = f"data/train/external/Task {task}/{disease}"
        for source_dir in (raw_dir, external_dir):
            # A class without external data is simply augmented from its raw images.
            if not os.path.isdir(source_dir):
                continue
            for name in os.listdir(source_dir):
                shutil.copy(f"{source_dir}/{name}", f"{temp_dir}/{name}")

        p = Augmentor.Pipeline(temp_dir, output_directory=f"{disease}/", save_format="png")
        p.rotate(probability=0.8, max_left_rotation=5, max_right_rotation=5)
        p.flip_left_right(probability=0.8)
        p.zoom_random(probability=0.8, percentage_area=0.8)
        p.flip_top_bottom(probability=0.8)
        p.random_brightness(probability=0.8, min_factor=0.5, max_factor=1.5)
        p.random_contrast(probability=0.8, min_factor=0.5, max_factor=1.5)
        p.random_color(probability=0.8, min_factor=0.5, max_factor=1.5)

        # Generate only as many samples as are missing to reach the target.
        # A class that already has enough images is never sampled with a zero or
        # negative count (NOTE(review): Augmentor presumably treats 0 specially —
        # verify against its documentation).
        missing = TARGET_IMAGES_PER_CLASS - len(p.augmentor_images)
        if missing > 0:
            p.sample(missing)

        # Move the pipeline's output folder to data/train/augmented/Task <n>/<class>.
        output_dir = f"{temp_dir}/{disease}"
        os.makedirs(output_dir, exist_ok=True)  # present even when nothing was sampled
        augmented_task_dir = f"data/train/augmented/Task {task}"
        os.makedirs(augmented_task_dir, exist_ok=True)
        augmented_dir = f"{augmented_task_dir}/{disease}"
        os.rename(output_dir, augmented_dir)

        # Rename the augmented images to 01.png, 02.png, ... (two-digit minimum).
        for index, name in enumerate(sorted(os.listdir(augmented_dir)), start=1):
            os.rename(f"{augmented_dir}/{name}", f"{augmented_dir}/{index:02d}.png")
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
configs.py
DELETED
@@ -1,43 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import torch
|
3 |
-
from torchvision import transforms
|
4 |
-
from torch.utils.data import Dataset
|
5 |
-
from models import *
|
6 |
-
|
7 |
-
# Constants
RANDOM_SEED = 123
BATCH_SIZE = 32
NUM_EPOCHS = 100
LEARNING_RATE = 0.001
STEP_SIZE = 10
GAMMA = 0.5
NUM_PRINT = 100

# Use the first CUDA device when one is available, otherwise the CPU.
_DEVICE_NAME = "cuda:0" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(_DEVICE_NAME)

# Data locations are derived from the selected task number.
TASK = 1
RAW_DATA_DIR = f"data/train/raw/Task {TASK}"
AUG_DATA_DIR = f"data/train/augmented/Task {TASK}"
EXTERNAL_DATA_DIR = f"data/train/external/Task {TASK}"

# Model definition and checkpoint location.
NUM_CLASSES = 7
MODEL_SAVE_PATH = "output/checkpoints/model.pth"
MODEL = mobilenet_v3_small(num_classes=NUM_CLASSES)

# Image preprocessing: resize to 64x64, convert to a tensor, then normalize
# each of the three channels with mean 0.5 and std 0.5.
_PREPROCESS_STEPS = [
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
preprocess = transforms.Compose(_PREPROCESS_STEPS)
|
32 |
-
|
33 |
-
# Custom dataset class
|
34 |
-
class CustomDataset(Dataset):
    """Dataset wrapper that exposes the (image, label) pairs of another indexable object."""

    def __init__(self, dataset):
        # The wrapped object only needs to support len() and integer indexing.
        self.data = dataset

    def __len__(self):
        wrapped = self.data
        return len(wrapped)

    def __getitem__(self, idx):
        # Each item of the wrapped object is unpacked as exactly two values.
        sample = self.data[idx]
        img, label = sample
        return img, label
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data_loader.py
DELETED
@@ -1,32 +0,0 @@
|
|
1 |
-
from configs import *
|
2 |
-
from torchvision.datasets import ImageFolder
|
3 |
-
from torch.utils.data import random_split, DataLoader, Dataset
|
4 |
-
|
5 |
-
|
6 |
-
def load_data(raw_dir, augmented_dir, external_dir, preprocess):
    """Build the training and validation loaders from three image folders.

    The raw, external and augmented folders are each read with ImageFolder,
    concatenated, and split 80/20 into a training and a validation subset.

    Args:
        raw_dir: Root folder of the raw images (one sub-folder per class).
        augmented_dir: Root folder of the augmented images.
        external_dir: Root folder of the external images.
        preprocess: Transform applied to every loaded image.

    Returns:
        A ``(train_loader, valid_loader)`` tuple of DataLoaders using
        ``BATCH_SIZE``; only the training loader shuffles.

    Raises:
        ValueError: If the external or augmented folder maps class names to
            different indices than the raw folder, which would silently mix
            up labels once the datasets are concatenated.
    """
    import torch  # local import: only needed for the seeded split generator

    # Load the datasets using ImageFolder
    raw_dataset = ImageFolder(root=raw_dir, transform=preprocess)
    external_dataset = ImageFolder(root=external_dir, transform=preprocess)
    augmented_dataset = ImageFolder(root=augmented_dir, transform=preprocess)

    # Labels are only comparable when every folder yields the same mapping.
    for source_name, source_dataset in (
        ("external", external_dataset),
        ("augmented", augmented_dataset),
    ):
        if source_dataset.class_to_idx != raw_dataset.class_to_idx:
            raise ValueError(
                f"Class folders of the {source_name} dataset do not match the raw dataset: "
                f"{source_dataset.class_to_idx} != {raw_dataset.class_to_idx}"
            )

    dataset = raw_dataset + external_dataset + augmented_dataset

    print("Classes: ", *raw_dataset.classes, sep = ', ')
    print("Length of raw dataset: ", len(raw_dataset))
    print("Length of external dataset: ", len(external_dataset))
    print("Length of augmented dataset: ", len(augmented_dataset))
    print("Length of total dataset: ", len(dataset))

    # Split the dataset into train and validation sets. The split is seeded with
    # RANDOM_SEED so that repeated runs evaluate on the same validation images.
    # NOTE(review): augmented images can land in the validation subset — confirm
    # this is intended.
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    split_generator = torch.Generator().manual_seed(RANDOM_SEED)
    train_dataset, val_dataset = random_split(
        dataset, [train_size, val_size], generator=split_generator
    )

    # Create data loaders for the custom dataset
    train_loader = DataLoader(
        CustomDataset(train_dataset), batch_size=BATCH_SIZE, shuffle=True, num_workers=0
    )
    valid_loader = DataLoader(
        CustomDataset(val_dataset), batch_size=BATCH_SIZE, num_workers=0
    )

    return train_loader, valid_loader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
eval.py
DELETED
@@ -1,86 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import torch
|
3 |
-
from torchvision.transforms import transforms
|
4 |
-
from sklearn.metrics import f1_score
|
5 |
-
import pathlib
|
6 |
-
from PIL import Image
|
7 |
-
from torchmetrics import ConfusionMatrix
|
8 |
-
import matplotlib.pyplot as plt
|
9 |
-
from configs import *
|
10 |
-
from data_loader import load_data # Import the load_data function
|
11 |
-
|
12 |
-
# Root folder of the evaluation images; one sub-folder per class.
image_path = "data/test/Task 1/"

# Constants
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the model
MODEL = MODEL.to(DEVICE)
MODEL.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
MODEL.eval()

# Get class labels from the dataset. Only directories count as classes, and the
# names are sorted so that a label's index does not depend on the order in which
# the file system lists them (presumably this matches the alphabetical class
# indexing used when the training labels were produced — verify).
class_labels = sorted(
    entry
    for entry in os.listdir(image_path)
    if os.path.isdir(os.path.join(image_path, entry))
)

# Define transformation for preprocessing
preprocess = transforms.Compose(
    [
        transforms.Resize((64, 64)),  # Resize images to 64x64
        transforms.ToTensor(),  # Convert to tensor
        transforms.Normalize((0.5,), (0.5,)),  # Normalize with mean 0.5 / std 0.5
    ]
)
|
33 |
-
|
34 |
-
def predict_image(image_path, model, transform):
    """Evaluate ``model`` on every ``*.png`` below ``image_path`` and report metrics.

    The true class of an image is the index of its parent folder name in the
    module-level ``class_labels`` list. For every image the path, true class and
    predicted class are printed; afterwards accuracy and weighted F1 score are
    printed and a confusion matrix is plotted.

    Args:
        image_path: Folder that is searched recursively for PNG files.
        model: Network that is called on a single-image batch.
        transform: Preprocessing applied to each RGB image before inference.

    Returns:
        None. Results are printed and plotted.
    """
    model.eval()
    correct_predictions = 0

    # Get a list of image files. It has to exist before it is counted; sorting
    # keeps the processing order stable between runs.
    images = sorted(pathlib.Path(image_path).rglob("*.png"))
    total_predictions = len(images)
    if total_predictions == 0:
        # Nothing to score; avoid dividing by zero below.
        print("No .png images found under", image_path)
        return

    true_classes = []
    predicted_labels = []

    with torch.no_grad():
        for image_file in images:
            print('---------------------------')
            # The true label is the position of the parent folder in class_labels
            true_class = class_labels.index(image_file.parts[-2])
            print("Image path:", image_file)
            print("True class:", true_class)
            image = Image.open(image_file).convert('RGB')
            image = transform(image).unsqueeze(0)
            image = image.to(DEVICE)
            output = model(image)
            predicted_class = torch.argmax(output, dim=1).item()
            # Print the predicted class
            print("Predicted class:", predicted_class)
            # Append true and predicted labels to their respective lists
            true_classes.append(true_class)
            predicted_labels.append(predicted_class)

            # Check if the prediction is correct
            if predicted_class == true_class:
                correct_predictions += 1

    # Calculate accuracy and f1 score
    accuracy = correct_predictions / total_predictions
    print("Accuracy:", accuracy)
    f1 = f1_score(true_classes, predicted_labels, average='weighted')
    print("Weighted F1 Score:", f1)

    # Convert the lists to tensors
    predicted_labels_tensor = torch.tensor(predicted_labels)
    true_classes_tensor = torch.tensor(true_classes)

    # Create a confusion matrix
    conf_matrix = ConfusionMatrix(num_classes=NUM_CLASSES, task='multiclass')
    conf_matrix.update(predicted_labels_tensor, true_classes_tensor)

    # Plot the confusion matrix
    conf_matrix.plot()
    plt.show()


# Call predict_image function
predict_image(image_path, MODEL, preprocess)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models.py
DELETED
@@ -1,36 +0,0 @@
|
|
1 |
-
#######################################################
|
2 |
-
# This file stores all the models used in the project.#
|
3 |
-
#######################################################
|
4 |
-
|
5 |
-
# Import all models from torchvision.models
|
6 |
-
from torchvision.models import resnet50
|
7 |
-
from torchvision.models import resnet18
|
8 |
-
from torchvision.models import squeezenet1_0
|
9 |
-
from torchvision.models import vgg16
|
10 |
-
from torchvision.models import alexnet
|
11 |
-
from torchvision.models import densenet121
|
12 |
-
from torchvision.models import googlenet
|
13 |
-
from torchvision.models import inception_v3
|
14 |
-
from torchvision.models import mobilenet_v2
|
15 |
-
from torchvision.models import mobilenet_v3_small
|
16 |
-
from torchvision.models import mobilenet_v3_large
|
17 |
-
from torchvision.models import shufflenet_v2_x0_5
|
18 |
-
from torchvision.models import vgg11
|
19 |
-
from torchvision.models import vgg11_bn
|
20 |
-
from torchvision.models import vgg13
|
21 |
-
from torchvision.models import vgg13_bn
|
22 |
-
from torchvision.models import vgg16_bn
|
23 |
-
from torchvision.models import vgg19_bn
|
24 |
-
from torchvision.models import vgg19
|
25 |
-
from torchvision.models import wide_resnet50_2
|
26 |
-
from torchvision.models import wide_resnet101_2
|
27 |
-
from torchvision.models import mnasnet0_5
|
28 |
-
from torchvision.models import mnasnet0_75
|
29 |
-
from torchvision.models import mnasnet1_0
|
30 |
-
from torchvision.models import mnasnet1_3
|
31 |
-
from torchvision.models import resnext50_32x4d
|
32 |
-
from torchvision.models import resnext101_32x8d
|
33 |
-
from torchvision.models import shufflenet_v2_x1_0
|
34 |
-
from torchvision.models import shufflenet_v2_x1_5
|
35 |
-
from torchvision.models import shufflenet_v2_x2_0
|
36 |
-
from torchvision.models import squeezenet1_1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
predict.py
DELETED
@@ -1,57 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import torch
|
3 |
-
import torch.nn as nn
|
4 |
-
from torchvision import transforms
|
5 |
-
from PIL import Image
|
6 |
-
from models import *
|
7 |
-
from torchmetrics import ConfusionMatrix
|
8 |
-
import matplotlib.pyplot as plt
|
9 |
-
from configs import *
|
10 |
-
|
11 |
-
|
12 |
-
# Load your model (change this according to your model definition)
MODEL.load_state_dict(
    torch.load(MODEL_SAVE_PATH, map_location=DEVICE)
)  # Load the weights onto the configured device
MODEL = MODEL.to(DEVICE)
# A single eval() call is enough; moving the model does not change its mode.
MODEL.eval()
# Inference only: gradient tracking is switched off for the whole process.
torch.set_grad_enabled(False)
|
20 |
-
|
21 |
-
|
22 |
-
def predict_image(image_path, model=MODEL, transform=preprocess):
    """Classify a single image file and print the probability of every class.

    Args:
        image_path: Path of the image to classify.
        model: Network called on the single-image batch (defaults to ``MODEL``).
        transform: Preprocessing applied to the RGB image (defaults to ``preprocess``).

    Returns:
        A ``(predicted_label, sorted_classes)`` tuple. ``predicted_label`` is the
        index of the most probable class in the local ``classes`` list and
        ``sorted_classes`` is a list of ``(class name, probability in percent)``
        pairs ordered from most to least probable; the probabilities are 0-d
        tensors.
    """
    # NOTE(review): six names are listed here; zip() below silently ignores any
    # model outputs beyond the sixth — confirm this matches the model's output size.
    classes = [
        'Cerebral Palsy', 'Dystonia', 'Essential Tremor', 'Healthy', 'Huntington Disease', 'Parkinson Disease'
    ]

    print("---------------------------")
    print("Image path:", image_path)
    # Convert to RGB so grayscale or RGBA files still yield three channels, and
    # close the file handle as soon as the pixel data has been copied.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")
    image = transform(image).unsqueeze(0)
    image = image.to(DEVICE)
    output = model(image)

    # Softmax over the class dimension, expressed in percent
    probabilities = torch.softmax(output, dim=1)[0] * 100

    # Sort the classes by probabilities in descending order
    sorted_classes = sorted(
        zip(classes, probabilities), key=lambda pair: pair[1], reverse=True
    )

    # Report the prediction for each class
    print("Probabilities for each class:")
    for class_label, class_prob in sorted_classes:
        print(f"{class_label}: {round(class_prob.item(), 2)}%")

    # Get the predicted class
    predicted_class = sorted_classes[0][0]  # Most probable class
    predicted_label = classes.index(predicted_class)

    # Report the prediction
    print("Predicted class:", predicted_label)
    print("Predicted label:", predicted_class)
    print("---------------------------")

    return predicted_label, sorted_classes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
train.py
DELETED
@@ -1,145 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import torch
|
3 |
-
import torch.nn as nn
|
4 |
-
import torch.optim as optim
|
5 |
-
from torchvision.transforms import transforms
|
6 |
-
from torch.utils.data import DataLoader
|
7 |
-
from torchvision.utils import make_grid
|
8 |
-
from scipy.ndimage import gaussian_filter1d
|
9 |
-
import matplotlib.pyplot as plt
|
10 |
-
from models import *
|
11 |
-
from torch.utils.tensorboard import SummaryWriter
|
12 |
-
from configs import *
|
13 |
-
import data_loader
|
14 |
-
|
15 |
-
# TensorBoard writer used by the logging helper below and by the training script
writer = SummaryWriter(log_dir="output/tensorboard/training")


def plot_and_log_metrics(metrics_dict, step, prefix="Train"):
    """Log every entry of ``metrics_dict`` as the scalar ``<prefix>/<name>`` at ``step``."""
    for name in metrics_dict:
        tag = f"{prefix}/{name}"
        writer.add_scalar(tag, metrics_dict[name], step)
|
22 |
-
|
23 |
-
# Build the training and validation loaders from the configured folders
_loaders = data_loader.load_data(RAW_DATA_DIR, AUG_DATA_DIR, EXTERNAL_DATA_DIR, preprocess)
train_loader, valid_loader = _loaders

# Model on the target device, loss, plain SGD and a step-wise learning-rate decay
MODEL = MODEL.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(MODEL.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)

# Per-epoch history of losses and accuracies
TRAIN_LOSS_HIST, VAL_LOSS_HIST = [], []
AVG_TRAIN_LOSS_HIST, AVG_VAL_LOSS_HIST = [], []
TRAIN_ACC_HIST, VAL_ACC_HIST = [], []
|
41 |
-
|
42 |
-
# Training loop: one pass over the training data, then one over the validation
# data per epoch; metrics of both passes are logged to TensorBoard.
for epoch in range(NUM_EPOCHS):
    MODEL.train()  # Set model to training mode
    running_loss = 0.0  # loss accumulated since the last progress print
    epoch_loss = 0.0  # loss accumulated over the whole epoch
    total_train = 0
    correct_train = 0

    for i, (inputs, labels) in enumerate(train_loader, 0):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = MODEL(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss

        if (i + 1) % NUM_PRINT == 0:
            print(
                "[Epoch %d, Batch %d] Loss: %.6f"
                % (epoch + 1, i + 1, running_loss / NUM_PRINT)
            )
            # Only the print window restarts; the epoch total keeps growing.
            running_loss = 0.0

        _, predicted = torch.max(outputs, 1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()

    # Average over all batches of the epoch, independent of the print window.
    avg_train_loss = epoch_loss / len(train_loader)
    train_accuracy = correct_train / total_train
    TRAIN_LOSS_HIST.append(avg_train_loss)
    TRAIN_ACC_HIST.append(train_accuracy)

    # Log training metrics
    train_metrics = {
        "Loss": avg_train_loss,
        "Accuracy": train_accuracy,
    }
    plot_and_log_metrics(train_metrics, epoch, prefix="Train")

    # Learning rate scheduling
    scheduler.step()

    # Validation loop
    MODEL.eval()  # Set model to evaluation mode
    val_loss = 0.0
    correct_val = 0
    total_val = 0

    with torch.no_grad():
        for inputs, labels in valid_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = MODEL(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            total_val += labels.size(0)
            correct_val += (predicted == labels).sum().item()

    avg_val_loss = val_loss / len(valid_loader)
    val_accuracy = correct_val / total_val
    VAL_LOSS_HIST.append(avg_val_loss)
    VAL_ACC_HIST.append(val_accuracy)

    # Log validation metrics
    val_metrics = {
        "Loss": avg_val_loss,
        "Accuracy": val_accuracy,
    }
    plot_and_log_metrics(val_metrics, epoch, prefix="Validation")

    # Add sample images to TensorBoard
    sample_images, _ = next(iter(valid_loader))
    sample_images = sample_images.to(DEVICE)
    grid_image = make_grid(
        sample_images, nrow=8, normalize=True
    )
    writer.add_image("Sample Images", grid_image, global_step=epoch)
|
118 |
-
|
119 |
-
# Save the model. The checkpoint folder is created first so that a missing
# directory cannot discard the result of the whole training run.
_checkpoint_dir = os.path.dirname(MODEL_SAVE_PATH)
if _checkpoint_dir:
    os.makedirs(_checkpoint_dir, exist_ok=True)
torch.save(MODEL.state_dict(), MODEL_SAVE_PATH)
print("Model saved at", MODEL_SAVE_PATH)

# Plot loss and accuracy curves side by side, one point per epoch
_epochs = range(1, NUM_EPOCHS + 1)
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(_epochs, TRAIN_LOSS_HIST, label="Train Loss")
plt.plot(_epochs, VAL_LOSS_HIST, label="Validation Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.title("Loss Curves")

plt.subplot(1, 2, 2)
plt.plot(_epochs, TRAIN_ACC_HIST, label="Train Accuracy")
plt.plot(_epochs, VAL_ACC_HIST, label="Validation Accuracy")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.legend()
plt.title("Accuracy Curves")

plt.tight_layout()
plt.savefig("training_curves.png")

# Close TensorBoard writer
writer.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tuning.py
DELETED
@@ -1,186 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import torch
|
3 |
-
import torch.nn as nn
|
4 |
-
import torch.optim as optim
|
5 |
-
from models import * # Import your model here
|
6 |
-
from torch.utils.tensorboard import SummaryWriter
|
7 |
-
from torchvision.utils import make_grid
|
8 |
-
import optuna
|
9 |
-
from configs import *
|
10 |
-
import data_loader
|
11 |
-
|
12 |
-
# Build the training and validation loaders from the configured folders
_loaders = data_loader.load_data(RAW_DATA_DIR, AUG_DATA_DIR, EXTERNAL_DATA_DIR, preprocess)
train_loader, valid_loader = _loaders

# Model on the target device, loss, Adam and a plateau-based learning-rate decay
MODEL = MODEL.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(MODEL.parameters(), lr=LEARNING_RATE)
_plateau_options = {"mode": "min", "factor": 0.1, "patience": 10, "verbose": True}
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, **_plateau_options)

# History of losses and accuracies, appended to once per epoch
TRAIN_LOSS_HIST, VAL_LOSS_HIST = [], []
TRAIN_ACC_HIST, VAL_ACC_HIST = [], []
AVG_TRAIN_LOSS_HIST, AVG_VAL_LOSS_HIST = [], []

# TensorBoard writer for the tuning runs
writer = SummaryWriter(log_dir="output/tensorboard/tuning")

# Early stopping state: stop after this many epochs without a better validation loss
early_stopping_patience = 10
best_val_loss = float("inf")
no_improvement_count = 0
|
42 |
-
|
43 |
-
def train_epoch(epoch):
    """Train ``MODEL`` for one pass over ``train_loader`` and log the results.

    Uses the module-level ``optimizer``, ``criterion`` and ``scheduler``. The
    last batch loss, the epoch accuracy and the average epoch loss are appended
    to ``TRAIN_LOSS_HIST``, ``TRAIN_ACC_HIST`` and ``AVG_TRAIN_LOSS_HIST``; loss
    and accuracy are also written to TensorBoard, and the scheduler is stepped
    with the average training loss.

    Args:
        epoch: Zero-based epoch index, used for printing and as TensorBoard step.
    """
    MODEL.train(True)
    running_loss = 0.0  # loss accumulated since the last progress print
    epoch_loss = 0.0  # loss accumulated over the whole epoch
    total_train = 0
    correct_train = 0

    for i, (inputs, labels) in enumerate(train_loader, 0):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = MODEL(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss

        if (i + 1) % NUM_PRINT == 0:
            print(
                "[Epoch %d, Batch %d] Loss: %.6f"
                % (epoch + 1, i + 1, running_loss / NUM_PRINT)
            )
            # Only the print window restarts; the epoch total keeps growing.
            running_loss = 0.0

        _, predicted = torch.max(outputs, 1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()

    TRAIN_LOSS_HIST.append(loss.item())
    train_accuracy = correct_train / total_train
    TRAIN_ACC_HIST.append(train_accuracy)
    # Calculate the average training loss for the epoch from all batches,
    # independent of how often the print window was reset.
    avg_train_loss = epoch_loss / len(train_loader)

    writer.add_scalar("Loss/Train", avg_train_loss, epoch)
    writer.add_scalar("Accuracy/Train", train_accuracy, epoch)
    AVG_TRAIN_LOSS_HIST.append(avg_train_loss)

    # Print average training loss for the epoch
    print("[Epoch %d] Average Training Loss: %.6f" % (epoch + 1, avg_train_loss))

    # Learning rate scheduling
    lr_1 = optimizer.param_groups[0]["lr"]
    print("Learning Rate: {:.15f}".format(lr_1))
    scheduler.step(avg_train_loss)
|
86 |
-
|
87 |
-
def validate_epoch(epoch):
    """Evaluate ``MODEL`` on ``valid_loader`` and update the early-stopping state.

    The last batch loss, the average loss and the accuracy are appended to
    ``VAL_LOSS_HIST``, ``AVG_VAL_LOSS_HIST`` and ``VAL_ACC_HIST``; loss, accuracy
    and a grid of sample images are written to TensorBoard.

    Args:
        epoch: Zero-based epoch index, used for printing and as TensorBoard step.

    Returns:
        True when ``no_improvement_count`` has reached
        ``early_stopping_patience`` and training should stop, otherwise False.
    """
    global best_val_loss, no_improvement_count

    MODEL.eval()
    val_loss = 0.0
    correct_val = 0
    total_val = 0

    with torch.no_grad():
        for inputs, labels in valid_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = MODEL(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            total_val += labels.size(0)
            correct_val += (predicted == labels).sum().item()

    VAL_LOSS_HIST.append(loss.item())

    # Calculate the average validation loss for the epoch and record that
    # average (not the last batch loss) in the average-loss history.
    avg_val_loss = val_loss / len(valid_loader)
    AVG_VAL_LOSS_HIST.append(avg_val_loss)
    print("Average Validation Loss: %.6f" % (avg_val_loss))

    # Calculate the accuracy of the validation set
    val_accuracy = correct_val / total_val
    VAL_ACC_HIST.append(val_accuracy)
    print("Validation Accuracy: %.6f" % (val_accuracy))
    writer.add_scalar("Loss/Validation", avg_val_loss, epoch)
    writer.add_scalar("Accuracy/Validation", val_accuracy, epoch)

    # Add sample images to TensorBoard
    sample_images, _ = next(iter(valid_loader))  # Get a batch of sample images
    sample_images = sample_images.to(DEVICE)
    grid_image = make_grid(
        sample_images, nrow=8, normalize=True
    )  # Create a grid of images
    writer.add_image("Sample Images", grid_image, global_step=epoch)

    # Check for early stopping
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        no_improvement_count = 0
    else:
        no_improvement_count += 1

    if no_improvement_count >= early_stopping_patience:
        print(f"Early stopping after {epoch + 1} epochs without improvement.")
        return True  # Return True to stop training

    return False
|
138 |
-
|
139 |
-
def objective(trial):
    """Optuna objective: train for up to 10 epochs and score the last epoch.

    The suggested learning rate and batch size are applied by rebinding the
    module-level ``optimizer``, ``scheduler``, ``train_loader`` and
    ``valid_loader`` that ``train_epoch`` and ``validate_epoch`` read, and the
    early-stopping state is reset so that trials do not influence each other's
    stopping decision.

    Args:
        trial: Optuna trial used to suggest ``learning_rate`` and ``batch_size``.

    Returns:
        Validation accuracy minus average validation loss of the last epoch
        (higher is better, matching the study's "maximize" direction).
    """
    global best_val_loss, no_improvement_count
    global optimizer, scheduler, train_loader, valid_loader

    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-1)
    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])

    # Apply the suggested hyperparameters to the objects the epoch helpers use.
    # NOTE(review): MODEL keeps its weights from the previous trial — confirm
    # whether each trial is meant to start from fresh weights.
    optimizer = optim.Adam(MODEL.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.1, patience=10, verbose=True
    )
    train_loader = torch.utils.data.DataLoader(
        train_loader.dataset, batch_size=batch_size, shuffle=True, num_workers=0
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_loader.dataset, batch_size=batch_size, num_workers=0
    )

    # Every trial starts with a clean early-stopping state.
    best_val_loss = float("inf")
    no_improvement_count = 0

    for epoch in range(10):
        train_epoch(epoch)
        # Check for early stopping
        if validate_epoch(epoch):
            break

    # Score based on validation accuracy and loss; returned as-is because the
    # study below maximizes the objective.
    return VAL_ACC_HIST[-1] - AVG_VAL_LOSS_HIST[-1]


if __name__ == "__main__":
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=100, show_progress_bar=True)

    # Print statistics
    print("Number of finished trials: ", len(study.trials))
    pruned_trials = [
        t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED
    ]
    print("Number of pruned trials: ", len(pruned_trials))
    complete_trials = [
        t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE
    ]
    print("Number of complete trials: ", len(complete_trials))

    # Print best trial
    trial = study.best_trial
    print("Best trial:")
    print(" Value: ", trial.value)  # Already the maximized score; no negation
    print(" Params: ")
    for key, value in trial.params.items():
        print(f" {key}: {value}")

    # Close TensorBoard writer
    writer.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|