|
import torch |
|
import torch.nn as nn |
|
import optuna |
|
from optuna.trial import TrialState |
|
from torch import optim |
|
import engine, data, utils |
|
from train import device, LEARNIGN_RATE, NUM_EPOCH, NUM_CLASSES |
|
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights |
|
from torchmetrics.classification import MulticlassAccuracy |
|
|
|
import wandb |
|
|
|
|
|
|
|
API_KEY = "881252af31786a1cf813449b9b4124955f54703e" |
|
|
|
wandb.login(key=API_KEY) |
|
print("[LOG]: Login Succesfull.") |
|
|
|
|
|
|
|
|
|
def objective(trial, n_trials=100): |
|
"""Objective function to be optimized by Optuna. |
|
Hyperparameters chosen to be optimized: optimizer, learning rate, |
|
dropout values, number of convolutional layers, number of filters of |
|
convolutional layers, number of neurons of fully connected layers. |
|
Inputs: |
|
- trial (optuna.trial._trial.Trial): Optuna trial |
|
Returns: |
|
- accuracy(torch.Tensor): The test accuracy. Parameter to be maximized. |
|
""" |
|
|
|
lr = trial.suggest_float("lr", LEARNIGN_RATE*1e-2, LEARNIGN_RATE, log=True) |
|
n_epochs = trial.suggest_int('n_estimators', NUM_EPOCH//2, NUM_EPOCH) |
|
|
|
optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"]) |
|
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr) |
|
|
|
|
|
loss_fn = torch.nn.CrossEntropyLoss() |
|
accuracy_fn = MulticlassAccuracy(num_classes = NUM_CLASSES).to(device) |
|
|
|
wandb.init( |
|
|
|
project="food-app", |
|
|
|
|
|
config={ |
|
"optimizer": trial.params["optimizer"], |
|
"architecture": "Efficientnet B0", |
|
"dataset": "Food101", |
|
"epochs": trial.params["n_estimators"], |
|
} |
|
) |
|
|
|
|
|
best_loss = 100 |
|
patience = 5 |
|
early_stop = 0 |
|
for epoch in range(n_epochs): |
|
|
|
train_model, train_loss, train_acc = engine.train_loop(model = model, dataloader = data.train_dataloader, |
|
loss_fn = loss_fn, optimizer = optimizer, |
|
accuracy_fn = accuracy_fn, device = device) |
|
|
|
val_loss, val_acc = engine.validation(model = model, dataloader = data.valid_dataloader, loss_fn = loss_fn, |
|
accuracy_fn = accuracy_fn, log_images=(epoch==(wandb.config.epochs-1)), device = device) |
|
|
|
|
|
|
|
trial.report(val_acc, epoch) |
|
|
|
if trial.should_prune(): |
|
raise optuna.exceptions.TrialPruned() |
|
|
|
if val_acc < best_loss: |
|
early_stop = 0 |
|
best_loss = val_acc |
|
utils.save_model(model = train_model, target_dir = "./save_model", model_name = f"best_model.pth") |
|
|
|
else: |
|
early_stop += 1 |
|
|
|
if early_stop == patience: |
|
break |
|
|
|
wandb.log({"Train Loss": train_loss, "Train Accuracy": train_acc, "Validation Loss": val_loss, "Validation Accuracy": val_acc}) |
|
|
|
wandb.finish() |
|
|
|
return val_acc |
|
|
|
|
|
model = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT) |
|
model.classifier = nn.Sequential( |
|
nn.Dropout(p = 0.2, inplace = True), |
|
nn.Linear(1280, NUM_CLASSES), |
|
|
|
) |
|
|
|
|
|
|
|
model = utils.load_model(model, "save_model/train_model_4e-06.pth").to(device) |
|
|
|
|
|
|
|
study = optuna.create_study(direction="maximize") |
|
study.optimize(objective, n_trials=11) |
|
|
|
|
|
pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED]) |
|
complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE]) |
|
|
|
|
|
print("\nStudy statistics: ") |
|
print(" Number of finished trials: ", len(study.trials)) |
|
print(" Number of pruned trials: ", len(pruned_trials)) |
|
print(" Number of complete trials: ", len(complete_trials)) |
|
|
|
trial = study.best_trial |
|
print("Best trial:") |
|
print(" Value: ", trial.value) |
|
print(" Params: ") |
|
for key, value in trial.params.items(): |
|
print(" {}: {}".format(key, value)) |