Spaces:
Sleeping
Sleeping
import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.utils.class_weight import compute_class_weight

from models.resnet_model import MonkeyResNet  # uses your custom ResNet model
from utils.data_loader import get_data_loaders
# This class helps stop training early if validation loss stops improving | |
class EarlyStopping: | |
def __init__(self, patience=5): | |
self.patience = patience | |
self.counter = 0 | |
self.best_loss = float('inf') | |
self.early_stop = False | |
def __call__(self, val_loss): | |
if val_loss < self.best_loss: | |
self.best_loss = val_loss | |
self.counter = 0 | |
else: | |
self.counter += 1 | |
if self.counter >= self.patience: | |
self.early_stop = True | |
# Hyperparameters | |
data_dir = "data" | |
epochs = 50 | |
batch_size = 32 | |
lr = 0.001 | |
patience = 5 | |
# Load training and validation data | |
train_loader, val_loader, class_names = get_data_loaders(data_dir, batch_size) | |
# Use GPU if available | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
# Calculate class weights | |
train_labels = [] | |
for _, labels in train_loader: | |
train_labels.extend(labels.numpy()) | |
train_labels = np.array(train_labels) | |
class_weights = compute_class_weight( | |
class_weight='balanced', | |
classes=np.unique(train_labels), | |
y=train_labels | |
) | |
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device) | |
# Set up model, loss function, optimizer, scheduler | |
model = MonkeyResNet(num_classes=len(class_names)).to(device) | |
criterion = nn.CrossEntropyLoss(weight=class_weights) | |
optimizer = optim.Adam(model.parameters(), lr=lr) | |
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2) | |
early_stopper = EarlyStopping(patience=patience) | |
# Store values for plotting | |
train_losses, val_losses = [], [] | |
train_accuracies, val_accuracies = [], [] | |
# Start training loop | |
for epoch in range(epochs): | |
model.train() | |
train_loss = 0 | |
correct = 0 | |
total = 0 | |
for images, labels in train_loader: | |
images, labels = images.to(device), labels.to(device) | |
optimizer.zero_grad() | |
outputs = model(images) | |
loss = criterion(outputs, labels) | |
loss.backward() | |
optimizer.step() | |
train_loss += loss.item() | |
_, predicted = torch.max(outputs.data, 1) | |
total += labels.size(0) | |
correct += (predicted == labels).sum().item() | |
train_accuracy = 100 * correct / total | |
train_losses.append(train_loss) | |
train_accuracies.append(train_accuracy) | |
# Validation step | |
model.eval() | |
val_loss = 0 | |
correct_val = 0 | |
total_val = 0 | |
with torch.no_grad(): | |
for images, labels in val_loader: | |
images, labels = images.to(device), labels.to(device) | |
outputs = model(images) | |
loss = criterion(outputs, labels) | |
val_loss += loss.item() | |
_, predicted = torch.max(outputs.data, 1) | |
total_val += labels.size(0) | |
correct_val += (predicted == labels).sum().item() | |
val_accuracy = 100 * correct_val / total_val | |
val_losses.append(val_loss) | |
val_accuracies.append(val_accuracy) | |
scheduler.step(val_loss) | |
early_stopper(val_loss) | |
print(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f} - Val Loss: {val_loss:.4f} - Train Acc: {train_accuracy:.2f}%") | |
if early_stopper.early_stop: | |
print(f"Early stopping triggered at epoch {epoch+1}") | |
break | |
# Save the trained model | |
os.makedirs("models", exist_ok=True) | |
torch.save(model.state_dict(), "models/monkey_resnet.pth") | |
print("Training done. Model saved.") | |
# Save training and validation plots | |
os.makedirs("plots", exist_ok=True) | |
# Loss plot | |
plt.figure(figsize=(10, 5)) | |
plt.plot(train_losses, label="Train Loss") | |
plt.plot(val_losses, label="Val Loss") | |
plt.xlabel("Epoch") | |
plt.ylabel("Loss") | |
plt.title("Training and Validation Loss") | |
plt.legend() | |
plt.grid(True) | |
plt.savefig("plots/loss_plot.png") | |
plt.close() | |
# Accuracy plot | |
plt.figure(figsize=(10, 5)) | |
plt.plot(train_accuracies, label="Train Accuracy") | |
plt.plot(val_accuracies, label="Val Accuracy") | |
plt.xlabel("Epoch") | |
plt.ylabel("Accuracy (%)") | |
plt.title("Training and Validation Accuracy") | |
plt.legend() | |
plt.grid(True) | |
plt.savefig("plots/accuracy_plot.png") | |
plt.close() |