# pytorch / pages /19_ResNet.py
# eaglelandsonce's picture
# Update pages/19_ResNet.py
# f6317cd verified
# raw
# history blame
# 5.39 kB
# Install required packages
# !pip install streamlit torch torchvision matplotlib
# Import Libraries
import copy
import time

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, models, transforms
# Streamlit Interface
st.title("Simple ResNet Fine-Tuning Example")

# User Inputs
# Every widget gets an explicit lower bound and step so the sidebar cannot
# produce values the rest of the script cannot use (a DataLoader rejects a
# non-positive batch size, and zero epochs would train nothing).
st.sidebar.header("Model Parameters")
batch_size = st.sidebar.number_input("Batch Size", min_value=1, value=32, step=1)
num_epochs = st.sidebar.number_input("Number of Epochs", min_value=1, value=5, step=1)
# Float inputs default to a two-decimal display; widen the format and shrink
# the step so a value such as 0.001 is shown and adjustable as entered.
learning_rate = st.sidebar.number_input(
    "Learning Rate",
    min_value=0.0001,
    value=0.001,
    step=0.0001,
    format="%.4f",
)
# Data Preparation Section
st.markdown("""
### Data Preparation
We will use a small subset of the CIFAR-10 dataset for quick experimentation. The dataset will be split into training and validation sets, and transformations will be applied to normalize the data.
""")

# Preprocessing pipeline: resize to 224x224, convert to a tensor, then
# normalise each channel with the fixed mean/std triples below.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

full_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Keep only the first 1000 samples so a run stays short.
subset_indices = list(range(1000))
subset_dataset = Subset(full_dataset, subset_indices)

# 80/20 random split of the subset into training and validation parts.
train_size = int(0.8 * len(subset_dataset))
val_size = len(subset_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    subset_dataset,
    [train_size, val_size],
)

# One loader per phase; only the training loader shuffles its samples.
dataloaders = {
    split: DataLoader(
        split_dataset,
        batch_size=batch_size,
        shuffle=(split == 'train'),
        num_workers=4,
    )
    for split, split_dataset in (('train', train_dataset), ('val', val_dataset))
}
train_loader = dataloaders['train']
val_loader = dataloaders['val']

class_names = full_dataset.classes

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Model Preparation Section
st.markdown("""
### Model Preparation
We will use a pre-trained ResNet-18 model and fine-tune the final fully connected layer to match the number of classes in our dataset.
""")

# Load Pre-trained ResNet Model
# Newer torchvision releases deprecate ``pretrained=True`` in favour of an
# explicit ``weights=`` enum. Prefer the enum when the installed version has
# it, and fall back to the legacy keyword otherwise so older environments
# keep working. IMAGENET1K_V1 is the checkpoint ``pretrained=True`` selects.
if hasattr(models, "ResNet18_Weights"):
    model_ft = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
else:
    model_ft = models.resnet18(pretrained=True)

# Swap the classification head for one sized to this dataset's classes.
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(class_names))
model_ft = model_ft.to(device)

# Define Loss Function and Optimizer
# The optimizer receives every parameter of the network, so the whole model
# (not only the new head) is updated during training.
criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(model_ft.parameters(), lr=learning_rate, momentum=0.9)
# Training Section
# The page text describes the optimizer actually configured in this script
# (SGD with momentum); no learning-rate scheduler is created or stepped here.
st.markdown("""
### Training
We will train the model using stochastic gradient descent (SGD) with momentum. The training and validation loss and accuracy will be plotted to monitor the training process.
""")
# Train and Evaluate the Model
def train_model(model, criterion, optimizer, num_epochs=5):
    """Train ``model`` and return it loaded with its best validation weights.

    Every epoch runs a training pass and then a validation pass over the
    module-level ``dataloaders`` mapping (keys ``'train'`` and ``'val'``),
    moving each batch to the module-level ``device``. Per-phase loss and
    accuracy are written to the Streamlit page as they are computed, and a
    two-panel loss/accuracy figure is rendered once training finishes.

    Args:
        model: Network to train; it is updated in place.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer that steps ``model``'s parameters.
        num_epochs: Number of train + validation epochs to run.

    Returns:
        The same ``model`` object, with the weights from the epoch that
        reached the highest validation accuracy restored.
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    # Per-phase metric history, kept as plain Python floats so it can be
    # plotted regardless of which device the tensors lived on.
    history = {phase: {'loss': [], 'acc': []} for phase in ('train', 'val')}

    for epoch in range(num_epochs):
        st.write(f'Epoch {epoch+1}/{num_epochs}')
        st.write('-' * 10)

        for phase in ('train', 'val'):
            is_training = phase == 'train'
            # train(False) is what eval() does: switch to inference mode.
            model.train(is_training)

            loader = dataloaders[phase]
            num_samples = len(loader.dataset)
            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in loader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                # Gradients are tracked only while training.
                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    preds = outputs.argmax(dim=1)
                    loss = criterion(outputs, labels)
                    if is_training:
                        loss.backward()
                        optimizer.step()

                # Scale by the batch size so the epoch figure is a per-sample
                # average (assumes the criterion averages over the batch).
                running_loss += loss.item() * inputs.size(0)
                # .item() pulls the count back to a Python int, so nothing
                # tensor-valued (or device-bound) leaks into the history.
                running_corrects += (preds == labels).sum().item()

            epoch_loss = running_loss / num_samples
            epoch_acc = running_corrects / num_samples
            history[phase]['loss'].append(epoch_loss)
            history[phase]['acc'].append(epoch_acc)
            st.write(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Snapshot the weights whenever validation accuracy improves.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    model.load_state_dict(best_model_wts)

    # Plot training history
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    ax1.plot(history['train']['loss'], label='Training Loss')
    ax1.plot(history['val']['loss'], label='Validation Loss')
    ax1.legend(loc='upper right')
    ax1.set_title('Training and Validation Loss')
    ax2.plot(history['train']['acc'], label='Training Accuracy')
    ax2.plot(history['val']['acc'], label='Validation Accuracy')
    ax2.legend(loc='lower right')
    ax2.set_title('Training and Validation Accuracy')
    st.pyplot(fig)
    # Release the figure; pyplot otherwise keeps it alive across reruns.
    plt.close(fig)

    return model
# Train only when the user clicks the button. Saving and the confirmation
# message sit inside the same branch so a checkpoint is written only after a
# training run has actually happened, not on every script execution.
if st.button('Train Model'):
    model_ft = train_model(model_ft, criterion, optimizer_ft, num_epochs)
    # Save the Model
    torch.save(model_ft.state_dict(), 'fine_tuned_resnet.pth')
    st.write("Model saved as 'fine_tuned_resnet.pth'")