"""Train MonkeyResNet with class-weighted cross-entropy, LR scheduling and early stopping.

Running this script trains the model, writes the best (lowest validation
loss) weights to ``models/monkey_resnet.pth`` and saves loss/accuracy curves
under ``plots/``.
"""

import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.utils.class_weight import compute_class_weight

from models.resnet_model import MonkeyResNet
from utils.data_loader import get_data_loaders


class EarlyStopping:
    """Signal that training should stop once validation loss stops improving.

    Call the instance once per epoch with the validation loss. After
    ``patience`` consecutive calls without a strict improvement over the best
    loss seen so far, ``early_stop`` becomes ``True``.
    """

    def __init__(self, patience=5):
        self.patience = patience
        self.counter = 0  # consecutive epochs without improvement
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss):
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    When ``optimizer`` is given the model is put in train mode and updated
    after every batch; otherwise the model is put in eval mode and gradients
    are disabled.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_n = labels.size(0)
            # criterion returns a per-batch mean; scale by batch size so a
            # short final batch does not skew the epoch average.
            loss_sum += loss.item() * batch_n
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch_n

    if seen == 0:
        raise ValueError("Data loader yielded no samples.")
    return loss_sum / seen, 100 * correct / seen


def _save_curves(train_values, val_values, label_suffix, ylabel, title, path):
    """Plot a train/validation pair of per-epoch curves and save it to ``path``."""
    plt.figure(figsize=(10, 5))
    plt.plot(train_values, label=f"Train {label_suffix}")
    plt.plot(val_values, label=f"Val {label_suffix}")
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.savefig(path)
    plt.close()


# NOTE(review): everything below runs at import time, as in the original
# script. If the data loaders use worker processes on a spawn-based platform,
# a `if __name__ == "__main__":` guard may be required - confirm.

# Hyperparameters
data_dir = "data"
epochs = 50
batch_size = 32
lr = 0.001
patience = 5

# Load training and validation data
train_loader, val_loader, class_names = get_data_loaders(data_dir, batch_size)

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Calculate class weights from the labels of one full pass over the train set
train_labels = np.concatenate([labels.numpy() for _, labels in train_loader])
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(train_labels),
    y=train_labels,
)
if len(class_weights) != len(class_names):
    # CrossEntropyLoss needs exactly one weight per output class; fail early
    # with a clear message instead of deep inside the first loss call.
    raise ValueError(
        f"Training data contains {len(class_weights)} distinct labels but "
        f"{len(class_names)} classes are defined."
    )
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

# Set up model, loss function, optimizer, scheduler
model = MonkeyResNet(num_classes=len(class_names)).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)
early_stopper = EarlyStopping(patience=patience)

# Store values for plotting
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []

# Weights of the epoch with the lowest validation loss seen so far
best_state = None

# Start training loop
for epoch in range(epochs):
    train_loss, train_accuracy = _run_epoch(model, train_loader, criterion, device, optimizer)
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)

    # Validation step
    val_loss, val_accuracy = _run_epoch(model, val_loader, criterion, device)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)

    # Snapshot before the stopper updates its own best_loss, so the saved
    # model is the best one rather than the one `patience` epochs later.
    if val_loss < early_stopper.best_loss:
        best_state = copy.deepcopy(model.state_dict())

    scheduler.step(val_loss)
    early_stopper(val_loss)

    print(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f} - Val Loss: {val_loss:.4f} - Train Acc: {train_accuracy:.2f}% - Val Acc: {val_accuracy:.2f}%")

    if early_stopper.early_stop:
        print(f"Early stopping triggered at epoch {epoch+1}")
        break

# Save the trained model (best validation-loss weights when available)
if best_state is not None:
    model.load_state_dict(best_state)
os.makedirs("models", exist_ok=True)
torch.save(model.state_dict(), "models/monkey_resnet.pth")
print("Training done. Model saved.")

# Save training and validation plots
os.makedirs("plots", exist_ok=True)

# Loss plot
_save_curves(
    train_losses, val_losses, "Loss", "Loss",
    "Training and Validation Loss", "plots/loss_plot.png",
)

# Accuracy plot
_save_curves(
    train_accuracies, val_accuracies, "Accuracy", "Accuracy (%)",
    "Training and Validation Accuracy", "plots/accuracy_plot.png",
)