# Install required packages
# !pip install streamlit torch torchvision matplotlib
# Import Libraries
import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision  # for torchvision.utils.make_grid
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import time
import copy  # for deep-copying the best model weights
import matplotlib.pyplot as plt
# Streamlit Interface
st.title("Simple ResNet Fine-Tuning Example")
# User Inputs
st.sidebar.header("Model Parameters")
batch_size = st.sidebar.number_input("Batch Size", min_value=1, value=32, step=1)
num_epochs = st.sidebar.number_input("Number of Epochs", min_value=1, value=5, step=1)
learning_rate = st.sidebar.number_input("Learning Rate", min_value=1e-5, value=1e-3, step=1e-4, format="%.5f")
# Data Preparation Section
st.markdown("""
### Data Preparation
We will use a small subset of the CIFAR-10 dataset for quick experimentation. The dataset will be split into training and validation sets, and transformations will be applied to normalize the data.
""")
# Resize to 224x224 and normalize with ImageNet statistics, since the
# pre-trained ResNet-18 was trained on ImageNet-style inputs
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])
full_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
subset_indices = list(range(1000)) # Use only 1000 samples for simplicity
subset_dataset = Subset(full_dataset, subset_indices)
train_size = int(0.8 * len(subset_dataset))
val_size = len(subset_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(subset_dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
dataloaders = {'train': train_loader, 'val': val_loader}
class_names = full_dataset.classes
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
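# Optional sanity check (not essential to training): surface the split sizes
# and the compute device in the app, using only names defined above.
st.write(f"Training samples: {len(train_dataset)} | Validation samples: {len(val_dataset)}")
st.write(f"Using device: {device}")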
# Visualize a few training images
st.markdown("#### Sample Training Images")
def imshow(inp, title=None):
    """Un-normalize a tensor image and display it in the Streamlit app."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean  # undo the normalization applied by the transform
    inp = np.clip(inp, 0, 1)
    fig, ax = plt.subplots()
    ax.imshow(inp)
    if title is not None:
        ax.set_title(title)
    st.pyplot(fig)
inputs, classes = next(iter(dataloaders['train']))
out = torchvision.utils.make_grid(inputs[:8])  # a full batch is unreadable; show 8
imshow(out, title=', '.join(class_names[x] for x in classes[:8]))
# Model Preparation Section
st.markdown("""
### Model Preparation
We will use a pre-trained ResNet-18 model and fine-tune the final fully connected layer to match the number of classes in our custom dataset.
""")
# Load Pre-trained ResNet Model (torchvision >= 0.13 weights API;
# older versions use models.resnet18(pretrained=True))
model_ft = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(class_names))
model_ft = model_ft.to(device)
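# Optional variant (feature extraction): to train only the new head, you could
# freeze the backbone *before* replacing fc; the replacement layer's parameters
# keep requires_grad=True by default. Shown as a sketch, not used below:
#     for param in model_ft.parameters():
#         param.requires_grad = False
#     model_ft.fc = nn.Linear(num_ftrs, len(class_names))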
# Define Loss Function and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(model_ft.parameters(), lr=learning_rate, momentum=0.9)
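# Step learning-rate scheduler used in the Training section below: decay the
# learning rate by 10x every 7 epochs. The step_size/gamma values follow the
# standard PyTorch transfer-learning recipe and can be tuned.
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)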
# Training Section
st.markdown("""
### Training
We will train the model using stochastic gradient descent (SGD) with a learning rate scheduler. The training and validation loss and accuracy will be plotted to monitor the training process.
""")
# Train and Evaluate the Model
def train_model(model, criterion, optimizer, scheduler, num_epochs=5):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    train_loss_history = []
    val_loss_history = []
    train_acc_history = []
    val_acc_history = []
    for epoch in range(num_epochs):
        st.write(f'Epoch {epoch+1}/{num_epochs}')
        st.write('-' * 10)
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            running_loss = 0.0
            running_corrects = 0
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                # Track gradients only during the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()  # decay the learning rate once per training epoch
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            # .item() yields a plain float so the history can be plotted directly
            epoch_acc = (running_corrects.double() / len(dataloaders[phase].dataset)).item()
            if phase == 'train':
                train_loss_history.append(epoch_loss)
                train_acc_history.append(epoch_acc)
            else:
                val_loss_history.append(epoch_loss)
                val_acc_history.append(epoch_acc)
            st.write(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            # Keep a copy of the weights with the best validation accuracy
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
    # Restore the best weights seen during training
    model.load_state_dict(best_model_wts)
    # Plot training history
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    ax1.plot(train_loss_history, label='Training Loss')
    ax1.plot(val_loss_history, label='Validation Loss')
    ax1.legend(loc='upper right')
    ax1.set_title('Training and Validation Loss')
    ax2.plot(train_acc_history, label='Training Accuracy')
    ax2.plot(val_acc_history, label='Validation Accuracy')
    ax2.legend(loc='lower right')
    ax2.set_title('Training and Validation Accuracy')
    st.pyplot(fig)
    return model
if st.button('Train Model'):
    model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs)
    # Save the fine-tuned weights
    torch.save(model_ft.state_dict(), 'fine_tuned_resnet.pth')
    st.write("Model saved as 'fine_tuned_resnet.pth'")