import csv
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import rcParams
from torch.utils.tensorboard import SummaryWriter
from torchcontrib.optim import SWA

import data_loader
from configs import *
from models import *

rcParams["font.family"] = "Times New Roman"

SWA_START = 5  # Epoch at which SWA weight averaging begins
SWA_FREQ = 5  # Fold weights into the SWA average every SWA_FREQ epochs


def rand_bbox(size, lam):
    """Sample a random box covering a (1 - lam) fraction of the image area."""
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1.0 - lam)
    cut_w = np.int_(W * cut_rat)
    cut_h = np.int_(H * cut_rat)

    # Uniformly sample the box centre
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2


def cutmix_data(input, target, alpha=1.0):
    """Paste a random patch from a shuffled copy of the batch into the input."""
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    # Create the permutation on the input's device so indexing works on GPU
    rand_index = torch.randperm(input.size(0), device=input.device)
    bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), lam)
    input[:, :, bbx1:bbx2, bby1:bby2] = input[rand_index, :, bbx1:bbx2, bby1:bby2]
    # Recompute lam from the exact pixel area of the (possibly clipped) box
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (input.size()[-1] * input.size()[-2]))
    targets_a = target
    targets_b = target[rand_index]
    return input, targets_a, targets_b, lam


def cutmix_criterion(criterion, outputs, targets_a, targets_b, lam):
    """Convex combination of the losses against both sets of targets."""
    return lam * criterion(outputs, targets_a) + (1 - lam) * criterion(
        outputs, targets_b
    )


def setup_tensorboard():
    return SummaryWriter(log_dir="output/tensorboard/training")


def load_and_preprocess_data():
    return data_loader.load_data(
        COMBINED_DATA_DIR + "1",
        preprocess,
    )


def initialize_model_optimizer_scheduler():
    model = MODEL.to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
    return model, criterion, optimizer, scheduler


def plot_and_log_metrics(metrics_dict, step, writer, prefix="Train"):
    for metric_name, metric_value in metrics_dict.items():
        writer.add_scalar(f"{prefix}/{metric_name}", metric_value, step)


def train_one_epoch(model, criterion, optimizer, train_loader, epoch, alpha):
    model.train()
    running_loss = 0.0  # loss since the last progress print
    epoch_loss = 0.0  # loss accumulated over the whole epoch
    total_train = 0
    correct_train = 0

    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()

        # Apply CutMix with the alpha supplied by the caller
        inputs, targets_a, targets_b, lam = cutmix_data(inputs, labels, alpha=alpha)
        outputs = model(inputs)

        loss = cutmix_criterion(criterion, outputs, targets_a, targets_b, lam)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        epoch_loss += loss.item()

        if (i + 1) % NUM_PRINT == 0:
            print(
                f"[Epoch {epoch + 1}, Batch {i + 1}/{len(train_loader)}] "
                f"Loss: {running_loss / NUM_PRINT:.6f}"
            )
            running_loss = 0.0

        # Under CutMix each prediction counts toward both targets, weighted by lam
        _, predicted = torch.max(outputs, 1)
        total_train += labels.size(0)
        correct_train += (
            lam * (predicted == targets_a).sum().item()
            + (1 - lam) * (predicted == targets_b).sum().item()
        )

    avg_train_loss = epoch_loss / len(train_loader)
    return avg_train_loss, correct_train / total_train
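
# A minimal, self-contained check (not called by the training loop) that the
# `lam` returned by cutmix_data matches the fraction of pixels left untouched.
# The batch shape (8, 3, 32, 32) and the helper name are illustrative choices,
# not part of the original pipeline.
def _check_cutmix_lam():
    # Give each sample a distinct constant value so pasted pixels are detectable
    x = (
        torch.arange(8, dtype=torch.float32)
        .view(8, 1, 1, 1)
        .expand(8, 3, 32, 32)
        .contiguous()
    )
    y = torch.arange(8)
    mixed, _, _, lam = cutmix_data(x.clone(), y)
    # Per-sample fraction of pixels replaced by the pasted patch
    changed = (mixed != x).any(dim=1).float().mean(dim=(1, 2))
    for frac in changed.tolist():
        # Samples whose donor is themselves change nothing; all others should
        # have exactly (1 - lam) of their pixels overwritten
        assert frac == 0.0 or abs(frac - (1 - lam)) < 1e-6, (frac, lam)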
def validate_model(model, criterion, valid_loader):
    model.eval()
    val_loss = 0.0
    correct_val = 0
    total_val = 0

    with torch.no_grad():
        for inputs, labels in valid_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total_val += labels.size(0)
            correct_val += (predicted == labels).sum().item()

    avg_val_loss = val_loss / len(valid_loader)
    return avg_val_loss, correct_val / total_val


def main_training_loop():
    writer = setup_tensorboard()
    train_loader, valid_loader = load_and_preprocess_data()
    model, criterion, optimizer, scheduler = initialize_model_optimizer_scheduler()

    best_val_accuracy = 0.0
    no_improvement_count = 0
    epoch_metrics = []
    AVG_TRAIN_LOSS_HIST = []
    AVG_VAL_LOSS_HIST = []
    TRAIN_ACC_HIST = []
    VAL_ACC_HIST = []

    # Wrap the optimizer in SWA's manual mode: the loop below decides when to
    # fold the current weights into the running average, so the automatic
    # swa_start/swa_freq step counting (which would never trigger, since
    # train_one_epoch steps the base optimizer) is not used
    swa_optimizer = SWA(optimizer)

    for epoch in range(NUM_EPOCHS):
        # Linear warm-up: set the learning rate before training the epoch
        if epoch < WARMUP_EPOCHS:
            lr = LEARNING_RATE * (epoch + 1) / WARMUP_EPOCHS
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

        # Read the rate from the optimizer so the warm-up value is reported too
        current_lr = optimizer.param_groups[0]["lr"]
        print(f"\n[Epoch: {epoch + 1}/{NUM_EPOCHS}]")
        print("Learning rate:", current_lr)

        avg_train_loss, train_accuracy = train_one_epoch(
            model, criterion, optimizer, train_loader, epoch, CUTMIX_ALPHA
        )
        AVG_TRAIN_LOSS_HIST.append(avg_train_loss)
        TRAIN_ACC_HIST.append(train_accuracy)

        # Log training metrics
        train_metrics = {
            "Loss": avg_train_loss,
            "Accuracy": train_accuracy,
        }
        plot_and_log_metrics(train_metrics, epoch, writer=writer, prefix="Train")

        avg_val_loss, val_accuracy = validate_model(model, criterion, valid_loader)
        AVG_VAL_LOSS_HIST.append(avg_val_loss)
        VAL_ACC_HIST.append(val_accuracy)

        # Log validation metrics
        val_metrics = {
            "Loss": avg_val_loss,
            "Accuracy": val_accuracy,
        }
        plot_and_log_metrics(val_metrics, epoch, writer=writer, prefix="Validation")

        # Record the epoch's metrics only after validation has produced them
        epoch_metrics.append(
            {
                "Epoch": epoch + 1,
                "Train Loss": avg_train_loss,
                "Train Accuracy": train_accuracy,
                "Validation Loss": avg_val_loss,
                "Validation Accuracy": val_accuracy,
                "Learning Rate": current_lr,
            }
        )

        # Print average training and validation metrics
        print(f"Average Training Loss: {avg_train_loss:.6f}")
        print(f"Average Validation Loss: {avg_val_loss:.6f}")
        print(f"Training Accuracy: {train_accuracy:.6f}")
        print(f"Validation Accuracy: {val_accuracy:.6f}")

        # Cosine annealing takes over once the warm-up phase is finished
        if epoch >= WARMUP_EPOCHS:
            scheduler.step()

        # Early stopping based on validation accuracy
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            no_improvement_count = 0
        else:
            no_improvement_count += 1

        if no_improvement_count >= EARLY_STOPPING_PATIENCE:
            print(
                f"Early stopping: validation accuracy did not improve for "
                f"{EARLY_STOPPING_PATIENCE} consecutive epochs."
            )
            break

        # Fold the current weights into the SWA running average
        if epoch >= SWA_START and epoch % SWA_FREQ == 0:
            swa_optimizer.update_swa()

    # Replace the model weights with the SWA average (see the BatchNorm note
    # in refresh_bn_statistics below)
    swa_optimizer.swap_swa_sgd()

    csv_filename = "training_metrics.csv"
    with open(csv_filename, mode="w", newline="") as csv_file:
        fieldnames = [
            "Epoch",
            "Train Loss",
            "Train Accuracy",
            "Validation Loss",
            "Validation Accuracy",
            "Learning Rate",
        ]
        # Distinct name so the TensorBoard `writer` is not shadowed
        csv_writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        csv_writer.writeheader()
        for metric in epoch_metrics:
            csv_writer.writerow(metric)
    print(f"Metrics saved to {csv_filename}")

    # Ensure the parent directory exists before saving
    os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
    torch.save(model.state_dict(), MODEL_SAVE_PATH)
    print("\nModel saved at", MODEL_SAVE_PATH)

    # Plot loss and accuracy curves
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(
        range(1, len(AVG_TRAIN_LOSS_HIST) + 1),
        AVG_TRAIN_LOSS_HIST,
        label="Average Train Loss",
    )
    plt.plot(
        range(1, len(AVG_VAL_LOSS_HIST) + 1),
        AVG_VAL_LOSS_HIST,
        label="Average Validation Loss",
    )
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()
    plt.title("Loss Curves")

    plt.subplot(1, 2, 2)
    plt.plot(range(1, len(TRAIN_ACC_HIST) + 1), TRAIN_ACC_HIST, label="Train Accuracy")
    plt.plot(range(1, len(VAL_ACC_HIST) + 1), VAL_ACC_HIST, label="Validation Accuracy")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.title("Accuracy Curves")
    plt.tight_layout()
    plt.savefig("training_curves.png")

    # Close the TensorBoard writer
    writer.close()
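
# torchcontrib's SWA documentation notes that models containing BatchNorm
# layers need their running statistics recomputed after swap_swa_sgd(), since
# the averaged weights were never used during training. The helper below is a
# minimal sketch of that step; the function name is an illustrative addition,
# and passing `device` assumes the loader yields CPU tensors.
def refresh_bn_statistics(swa_optimizer, train_loader, model):
    # One full pass over the training data re-estimates the BatchNorm running
    # mean/variance under the averaged weights
    swa_optimizer.bn_update(train_loader, model, device=DEVICE)


# Intended usage: call refresh_bn_statistics(swa_optimizer, train_loader, model)
# right after swap_swa_sgd() and before torch.save, if MODEL uses BatchNorm.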
if __name__ == "__main__":
    main_training_loop()