File size: 4,715 Bytes
97daae4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import copy
import os

import optuna
import torch
import torch.nn as nn
import wandb
from optuna.trial import TrialState
from torch import optim
from torchmetrics.classification import MulticlassAccuracy
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights

import engine, data, utils
from train import device, LEARNIGN_RATE, NUM_EPOCH, NUM_CLASSES
# Weights & Biases login.
# SECURITY: the API key must not be committed to source control. It is read
# from the WANDB_API_KEY environment variable instead; when the variable is
# unset, API_KEY is None and wandb.login falls back to its own credential
# lookup. Any key that was previously committed here should be revoked.
API_KEY = os.environ.get("WANDB_API_KEY")
wandb.login(key=API_KEY)
print("[LOG]: Login Succesfull.")
def objective(trial, n_trials=100):
    """Objective function to be optimized by Optuna.

    Hyperparameters sampled for each trial: learning rate ("lr"), number of
    training epochs ("n_estimators") and optimizer class ("optimizer").

    Inputs:
    - trial (optuna.trial._trial.Trial): Optuna trial
    - n_trials (int): Unused; kept so existing callers keep working.

    Returns:
    - accuracy: The validation accuracy of the last epoch that was run
      (0.0 if no epoch ran). Parameter to be maximized.

    Raises:
    - optuna.exceptions.TrialPruned: when the pruner asks to stop the trial.
    """
    lr = trial.suggest_float("lr", LEARNIGN_RATE * 1e-2, LEARNIGN_RATE, log=True)
    # NOTE: the parameter key "n_estimators" is kept as-is so stored study
    # parameters keep the same name; it holds the number of epochs.
    n_epochs = trial.suggest_int("n_estimators", NUM_EPOCH // 2, NUM_EPOCH)
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"])

    # Train a private copy of the module-level model so that every trial
    # starts from the same weights instead of continuing from whatever the
    # previous trials left behind.
    trial_model = copy.deepcopy(model)

    optimizer = getattr(optim, optimizer_name)(trial_model.parameters(), lr=lr)
    loss_fn = torch.nn.CrossEntropyLoss()
    accuracy_fn = MulticlassAccuracy(num_classes=NUM_CLASSES).to(device)

    wandb.init(
        # set the wandb project where this run will be logged
        project="food-app",
        # track hyperparameters and run metadata
        config={
            "optimizer": optimizer_name,
            "architecture": "Efficientnet B0",
            "dataset": "Food101",
            "epochs": n_epochs,
        },
    )

    # Early-stopping state: stop after `patience` epochs without a new best
    # validation loss.
    best_loss = float("inf")
    patience = 5
    early_stop = 0
    val_acc = 0.0  # defined up front in case zero epochs are sampled

    try:
        for epoch in range(n_epochs):
            train_model, train_loss, train_acc = engine.train_loop(
                model=trial_model,
                dataloader=data.train_dataloader,
                loss_fn=loss_fn,
                optimizer=optimizer,
                accuracy_fn=accuracy_fn,
                device=device,
            )
            val_loss, val_acc = engine.validation(
                model=trial_model,
                dataloader=data.valid_dataloader,
                loss_fn=loss_fn,
                accuracy_fn=accuracy_fn,
                # only ask for image logging on the last scheduled epoch
                log_images=(epoch == n_epochs - 1),
                device=device,
            )

            # Log before any pruning / early-stop exit so the metrics of the
            # final epoch are not lost.
            wandb.log({
                "Train Loss": train_loss,
                "Train Accuracy": train_acc,
                "Validation Loss": val_loss,
                "Validation Accuracy": val_acc,
            })

            # For pruning (stops trial early if not promising)
            trial.report(val_acc, epoch)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()

            # Checkpoint / early stopping are driven by validation LOSS
            # (lower is better), not by accuracy.
            if val_loss < best_loss:
                early_stop = 0
                best_loss = val_loss
                utils.save_model(
                    model=train_model,
                    target_dir="./save_model",
                    model_name="best_model.pth",
                )
            else:
                early_stop += 1
                if early_stop == patience:
                    break
    finally:
        # Always close the W&B run, including when the trial is pruned or
        # training raises, so the next trial starts a fresh run.
        wandb.finish()

    return val_acc
# Build EfficientNet-B0 initialised with torchvision's DEFAULT weight set.
pretrained_weights = EfficientNet_B0_Weights.DEFAULT
model = efficientnet_b0(weights=pretrained_weights)

# Replace the classifier with dropout followed by a linear layer that emits
# NUM_CLASSES outputs (no softmax layer is appended).
head_layers = [
    nn.Dropout(p=0.2, inplace=True),
    nn.Linear(1280, NUM_CLASSES),
]
model.classifier = nn.Sequential(*head_layers)

# Restore the saved checkpoint into the model and move it to the target device.
checkpoint_path = "save_model/train_model_4e-06.pth"
model = utils.load_model(model, checkpoint_path).to(device)
# Create an Optuna study that maximizes the value returned by `objective`
# (the validation accuracy).
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=11)

# Find number of pruned and completed trials
pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

# Display the study statistics
print("\nStudy statistics: ")
print(" Number of finished trials: ", len(study.trials))
print(" Number of pruned trials: ", len(pruned_trials))
print(" Number of complete trials: ", len(complete_trials))

# study.best_trial raises when no trial has completed (e.g. every trial was
# pruned or failed), so only report it when there is one.
if complete_trials:
    trial = study.best_trial
    print("Best trial:")
    print(" Value: ", trial.value)
    print(" Params: ")
    for key, value in trial.params.items():
        print(f" {key}: {value}")
else:
    print("Best trial: none (no trial completed)")