# pages/19_ResNet.py — from the "pytorch" Hugging Face Space by eaglelandsonce
# (commit 7125d94, 2.93 kB; file-viewer header lines converted to comments
# because they were not valid Python)
# Install required packages
# !pip install streamlit torch torchvision matplotlib
# Import Libraries
import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision # Add this import
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import time
import copy # Add this import
import matplotlib.pyplot as plt
# Streamlit interface: page title plus sidebar-driven hyperparameters.
st.title("Simple ResNet Fine-Tuning Example")

# User-tunable training hyperparameters, collected from the sidebar.
sidebar = st.sidebar
sidebar.header("Model Parameters")
batch_size = sidebar.number_input("Batch Size", value=32)
num_epochs = sidebar.number_input("Number of Epochs", value=5)
learning_rate = sidebar.number_input("Learning Rate", value=0.001)
# Data Preparation section: explanatory text shown above the data pipeline.
_DATA_PREP_TEXT = """
### Data Preparation
We will use a small subset of the CIFAR-10 dataset for quick experimentation. The dataset will be split into training and validation sets, and transformations will be applied to normalize the data.
"""
st.markdown(_DATA_PREP_TEXT)
# Preprocessing pipeline: upscale 32x32 CIFAR-10 images to the 224x224
# input size expected by ImageNet-pretrained ResNets, convert to tensor,
# and normalize with the standard ImageNet channel statistics.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
]
transform = transforms.Compose(_preprocess_steps)
# Load CIFAR-10 (downloading to ./data on first run); the same normalization
# pipeline is applied to both the train and test splits.
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# Read the class names before the datasets are rebound to Subset views below.
# This avoids instantiating CIFAR10 a third time just to access .classes.
class_names = train_dataset.classes

# Using only 1000 samples for simplicity: the first 800 indices of the train
# split for training and indices 800-999 of the test split for validation.
subset_indices = list(range(1000))
train_size = int(0.8 * len(subset_indices))
val_size = len(subset_indices) - train_size
train_indices = subset_indices[:train_size]
val_indices = subset_indices[train_size:]
train_dataset = Subset(train_dataset, train_indices)
val_dataset = Subset(val_dataset, val_indices)

# NOTE(review): num_workers=4 spawns worker processes per Streamlit rerun;
# if the app hangs on some platforms, 0 is the safe value — verify in deploy.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
dataloaders = {'train': train_loader, 'val': val_loader}

# Train on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Section header shown above the sample-image grid rendered below.
st.markdown("#### Sample Training Images")
def imshow(inp, title=None):
    """Un-normalize a CHW image tensor and render it in the Streamlit app.

    Args:
        inp: Image tensor of shape (C, H, W), normalized with the ImageNet
            mean/std used in the transform pipeline above.
        title: Optional plot title (e.g. a list of class names).
    """
    # Move to CPU before converting so the function also accepts CUDA tensors.
    inp = inp.cpu().numpy().transpose((1, 2, 0))
    # Invert the Normalize transform (x * std + mean), then clamp to [0, 1]
    # so matplotlib receives valid pixel values.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = np.clip(std * inp + mean, 0, 1)
    fig, ax = plt.subplots()
    ax.imshow(inp)
    if title is not None:
        ax.set_title(title)
    st.pyplot(fig)
    # Close the figure so repeated Streamlit reruns don't accumulate figures.
    plt.close(fig)
# Pull one batch from the training loader, tile it into a single image grid,
# and display it with the corresponding class names as the title.
# NOTE(review): `inputs`/`classes`/`out` are module-level names; keeping them
# unchanged in case later (unseen) code reuses them.
inputs, classes = next(iter(dataloaders['train']))
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[class_names[x] for x in classes])
# Model Preparation section: explanatory text shown above the model setup.
_MODEL_PREP_TEXT = """
### Model Preparation
We will use a pre-trained ResNet-18 model and fine-tune the final fully connected layer to match the number of classes in our dataset.
"""
st.markdown(_MODEL_PREP_TEXT)
# Load Pre-trained ResNet Model
model_ft = models.resnet