Spaces:
Running
Running
# Install required packages | |
# !pip install streamlit torch torchvision matplotlib | |
# Import Libraries | |
import streamlit as st | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import torchvision # Add this import | |
from torchvision import datasets, models, transforms | |
from torch.utils.data import DataLoader, Subset | |
import numpy as np | |
import time | |
import copy # Add this import | |
import matplotlib.pyplot as plt | |
# Streamlit Interface
st.title("Simple ResNet Fine-Tuning Example")

# User Inputs
# Explicit lower bounds keep the values usable downstream: a batch size or
# epoch count below 1 would make the DataLoader / training loop fail.
st.sidebar.header("Model Parameters")
batch_size = st.sidebar.number_input("Batch Size", min_value=1, value=32, step=1)
num_epochs = st.sidebar.number_input("Number of Epochs", min_value=1, value=5, step=1)
# Without an explicit step/min, a float number_input steps by 0.01, so one
# click of "-" from 0.001 yields a negative learning rate (rejected by torch
# optimizers); an explicit format also avoids showing the default as "0.00".
learning_rate = st.sidebar.number_input(
    "Learning Rate",
    min_value=0.00001,
    max_value=1.0,
    value=0.001,
    step=0.0001,
    format="%.5f",
)
# Data Preparation Section
st.markdown("""
### Data Preparation
We will use a small subset of the CIFAR-10 dataset for quick experimentation. The dataset will be split into training and validation sets, and transformations will be applied to normalize the data.
""")

# Resize the CIFAR images to 224x224 for the pre-trained ResNet and normalize
# with the ImageNet per-channel mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
# Take the label names from the dataset already loaded, instead of building a
# third CIFAR10 instance (which reloads the training data) just for `.classes`.
class_names = train_dataset.classes

# Using only 1000 samples for simplicity. Note that the two halves come from
# different CIFAR-10 splits: indices 0-799 of the training split for training,
# and indices 800-999 of the test split for validation.
subset_indices = list(range(1000))
train_size = int(0.8 * len(subset_indices))
val_size = len(subset_indices) - train_size
train_indices = subset_indices[:train_size]
val_indices = subset_indices[train_size:]
train_dataset = Subset(train_dataset, train_indices)
val_dataset = Subset(val_dataset, val_indices)

# DataLoader requires an integer batch size; coerce in case the widget value
# arrives as a float.
train_loader = DataLoader(train_dataset, batch_size=int(batch_size), shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=int(batch_size), shuffle=False, num_workers=4)
dataloaders = {'train': train_loader, 'val': val_loader}

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Visualize a few training images
st.markdown("#### Sample Training Images")


def imshow(inp, title=None, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo normalization on a channel-first image tensor and render it in Streamlit.

    Args:
        inp: Image tensor in (C, H, W) layout, normalized with ``mean``/``std``.
            GPU tensors and tensors that require grad are accepted.
        title: Optional title. A list/tuple of labels (e.g. one per image in a
            ``make_grid`` mosaic) is joined with ", " and wrapped.
        mean: Per-channel mean that was subtracted during normalization.
        std: Per-channel std that was divided out during normalization.
            Both default to the ImageNet statistics used by this app's transform.
    """
    # CHW -> HWC for matplotlib, then invert Normalize: x * std + mean.
    img = inp.detach().cpu().numpy().transpose((1, 2, 0))
    img = np.asarray(std) * img + np.asarray(mean)
    # Rounding can push values slightly outside [0, 1]; imshow expects that range.
    img = np.clip(img, 0, 1)

    fig, ax = plt.subplots()
    ax.imshow(img)
    if title is not None:
        if isinstance(title, (list, tuple)):
            title = ", ".join(str(label) for label in title)
        ax.set_title(title, wrap=True)
    st.pyplot(fig)
    # Streamlit re-executes the script on every interaction; close the figure
    # so matplotlib does not keep accumulating open figures across reruns.
    plt.close(fig)
# Pull a single (shuffled) batch from the training loader for display.
inputs, classes = next(iter(dataloaders['train']))
# Tile the batch into one image and caption it with each image's class name.
out = torchvision.utils.make_grid(tensor=inputs)
imshow(
    out,
    title=[class_names[label] for label in classes],
)
# Model Preparation Section: explanatory text shown above the model setup.
st.markdown(
    body="""
### Model Preparation
We will use a pre-trained ResNet-18 model and fine-tune the final fully connected layer to match the number of classes in our dataset.
""",
)
# Load Pre-trained ResNet Model | |
model_ft = models.resnet | |