# Spiral-ND / optuna_unused.py
import os
import optuna
from optuna.trial import TrialState
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
from configs import *
import data_loader
from torch.utils.tensorboard import SummaryWriter
import time
import numpy as np
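# DEVICE, MODEL, COMBINED_DATA_DIR and preprocess used below are assumed to come
# from configs.py via the star import above.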
torch.cuda.empty_cache()
print(f"Using device: {DEVICE}")
EPOCHS = 10
# N_TRIALS = 10
# TIMEOUT = 5000
EARLY_STOPPING_PATIENCE = 4  # Number of epochs with no improvement to trigger early stopping
# Create a TensorBoard writer
writer = SummaryWriter(log_dir="output/tensorboard/tuning")
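# The logged runs can be inspected afterwards with, for example:
#   tensorboard --logdir output/tensorboard/tuning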
# Function to create or modify data loaders with the specified batch size
def create_data_loaders(batch_size):
    train_loader, valid_loader = data_loader.load_data(
        COMBINED_DATA_DIR + "1",
        preprocess,
        batch_size=batch_size,
    )
    return train_loader, valid_loader
def rand_bbox(size, lam):
    """Sample a random bounding box covering roughly (1 - lam) of the image area."""
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1.0 - lam)
    cut_w = np.int_(W * cut_rat)
    cut_h = np.int_(H * cut_rat)
    # Uniformly sample the box centre, then clip the box to the image bounds
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2
def cutmix_data(input, target, alpha=1.0):
    """Apply CutMix: paste a random patch from shuffled samples and mix the targets."""
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    rand_index = torch.randperm(input.size()[0])
    bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), lam)
    input[:, :, bbx1:bbx2, bby1:bby2] = input[rand_index, :, bbx1:bbx2, bby1:bby2]
    # Adjust lam to the exact fraction of the image left untouched by the patch
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (input.size()[-1] * input.size()[-2]))
    targets_a = target
    targets_b = target[rand_index]
    return input, targets_a, targets_b, lam
def cutmix_criterion(criterion, outputs, targets_a, targets_b, lam):
    return lam * criterion(outputs, targets_a) + (1 - lam) * criterion(
        outputs, targets_b
    )
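# Illustrative use of the two CutMix helpers in a training step (a sketch only,
# not invoked by this script):
#   inputs, targets_a, targets_b, lam = cutmix_data(inputs, targets, alpha=1.0)
#   outputs = model(inputs)
#   loss = cutmix_criterion(criterion, outputs, targets_a, targets_b, lam)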
# Objective function for optimization
def objective(trial, model=MODEL):
    model = model.to(DEVICE)
    batch_size = trial.suggest_categorical("batch_size", [16, 32])
    train_loader, valid_loader = create_data_loaders(batch_size)
    lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    # Note: gamma is suggested and logged, but CosineAnnealingLR does not use it
    gamma = trial.suggest_float("gamma", 0.1, 0.9, step=0.1)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    # Print the best hyperparameters found in earlier completed trials, if any
    completed_trials = [
        t for t in trial.study.trials if t.state == TrialState.COMPLETE
    ]
    if completed_trials:
        print("\nBest Hyperparameters:")
        print(f"{trial.study.best_trial.params}")
    print(f"\n[INFO] Trial: {trial.number}")
    print(f"Batch Size: {batch_size}")
    print(f"Learning Rate: {lr}")
    print(f"Gamma: {gamma}\n")
    early_stopping_counter = 0
    best_accuracy = 0.0
    for epoch in range(EPOCHS):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader, 0):
            data, target = data.to(DEVICE), target.to(DEVICE)
            optimizer.zero_grad()
            # torchvision's GoogLeNet returns a namedtuple with .logits in training mode
            if model.__class__.__name__ == "GoogLeNet":
                output = model(data).logits
            else:
                output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
        scheduler.step()
        # Validation pass
        model.eval()
        correct = 0
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(valid_loader, 0):
                data, target = data.to(DEVICE), target.to(DEVICE)
                output = model(data)
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
        accuracy = correct / len(valid_loader.dataset)
        if accuracy >= 1.0:
            print("Desired accuracy of 1.0 achieved. Stopping early.")
            return accuracy
        # Log hyperparameters and accuracy to TensorBoard
        writer.add_scalar("Accuracy", accuracy, trial.number)
        writer.add_hparams(
            {"batch_size": batch_size, "lr": lr, "gamma": gamma},
            {"accuracy": accuracy},
        )
        print(f"[EPOCH {epoch + 1}] Accuracy: {accuracy:.4f}")
        # Report to Optuna so the Hyperband pruner can cut unpromising trials
        trial.report(accuracy, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1
        # Early stopping check
        if early_stopping_counter >= EARLY_STOPPING_PATIENCE:
            print(f"\nEarly stopping at epoch {epoch + 1}")
            break
    # Discard clearly unpromising configurations late in the search
    if trial.number > 10 and trial.params["lr"] < 1e-3 and best_accuracy < 0.7:
        raise optuna.TrialPruned()
    return best_accuracy
if __name__ == "__main__":
    hyperband_pruner = optuna.pruners.HyperbandPruner()
    # Record the start time
    start_time = time.time()
    # storage = optuna.storages.InMemoryStorage()
    study = optuna.create_study(
        direction="maximize",
        pruner=hyperband_pruner,
        study_name="hyperparameter_tuning",
        storage="sqlite:///" + MODEL.__class__.__name__ + ".sqlite3",
        load_if_exists=True,  # resume the study if the SQLite file already exists
    )
    # No n_trials or timeout is given, so optimization runs until interrupted
    study.optimize(objective)
    # Record the end time
    end_time = time.time()
    # Calculate the duration of hyperparameter tuning
    tuning_duration = end_time - start_time
    print(f"Hyperparameter tuning took {tuning_duration:.2f} seconds.")
    best_trial = study.best_trial
    print("\nBest Trial:")
    print(f"  Trial Number: {best_trial.number}")
    print(f"  Best Accuracy: {best_trial.value:.4f}")
    print("  Hyperparameters:")
    for key, value in best_trial.params.items():
        print(f"    {key}: {value}")