food-app / hyperpara.py
Jash-RA's picture
Food app for food classification
9e0645e
import torch
import torch.nn as nn
import optuna
from optuna.trial import TrialState
from torch import optim
import engine, data, utils
from train import device, LEARNIGN_RATE, NUM_EPOCH, NUM_CLASSES
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
from torchmetrics.classification import MulticlassAccuracy
import wandb
# WanDB login
API_KEY = "881252af31786a1cf813449b9b4124955f54703e"
wandb.login(key=API_KEY)
print("[LOG]: Login Succesfull.")
def objective(trial, n_trials=100):
"""Objective function to be optimized by Optuna.
Hyperparameters chosen to be optimized: optimizer, learning rate,
dropout values, number of convolutional layers, number of filters of
convolutional layers, number of neurons of fully connected layers.
Inputs:
- trial (optuna.trial._trial.Trial): Optuna trial
Returns:
- accuracy(torch.Tensor): The test accuracy. Parameter to be maximized.
"""
lr = trial.suggest_float("lr", LEARNIGN_RATE*1e-2, LEARNIGN_RATE, log=True) # Learning rates
n_epochs = trial.suggest_int('n_estimators', NUM_EPOCH//2, NUM_EPOCH)
optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"])
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr) #torch.optim.Adam(model.parameters(), lr = lr)# getattr(optim, optimizer_name)(model.parameters(), lr=lr)
loss_fn = torch.nn.CrossEntropyLoss()
accuracy_fn = MulticlassAccuracy(num_classes = NUM_CLASSES).to(device)
wandb.init(
# set the wandb project where this run will be logged
project="food-app",
# track hyperparameters and run metadata
config={
"optimizer": trial.params["optimizer"],
"architecture": "Efficientnet B0",
"dataset": "Food101",
"epochs": trial.params["n_estimators"],
}
)
# Training of the model
best_loss = 100
patience = 5
early_stop = 0
for epoch in range(n_epochs):
train_model, train_loss, train_acc = engine.train_loop(model = model, dataloader = data.train_dataloader,
loss_fn = loss_fn, optimizer = optimizer,
accuracy_fn = accuracy_fn, device = device)
val_loss, val_acc = engine.validation(model = model, dataloader = data.valid_dataloader, loss_fn = loss_fn,
accuracy_fn = accuracy_fn, log_images=(epoch==(wandb.config.epochs-1)), device = device)
# For pruning (stops trial early if not promising)
trial.report(val_acc, epoch)
# Handle pruning based on the intermediate value.
if trial.should_prune():
raise optuna.exceptions.TrialPruned()
if val_acc < best_loss:
early_stop = 0
best_loss = val_acc
utils.save_model(model = train_model, target_dir = "./save_model", model_name = f"best_model.pth")
else:
early_stop += 1
if early_stop == patience:
break
wandb.log({"Train Loss": train_loss, "Train Accuracy": train_acc, "Validation Loss": val_loss, "Validation Accuracy": val_acc})
wandb.finish()
return val_acc
model = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
model.classifier = nn.Sequential(
nn.Dropout(p = 0.2, inplace = True),
nn.Linear(1280, NUM_CLASSES),
# nn.Softmax()
)
# isExist = os.path.exists("save_model/train_model_4e-06.pth")
# print(isExist)
model = utils.load_model(model, "save_model/train_model_4e-06.pth").to(device)
# print(model)
# Create an Optuna study to maximize test accuracy
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=11)
# Find number of pruned and completed trials
pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])
# Display the study statistics
print("\nStudy statistics: ")
print(" Number of finished trials: ", len(study.trials))
print(" Number of pruned trials: ", len(pruned_trials))
print(" Number of complete trials: ", len(complete_trials))
trial = study.best_trial
print("Best trial:")
print(" Value: ", trial.value)
print(" Params: ")
for key, value in trial.params.items():
print(" {}: {}".format(key, value))