Update pages/19_ResNet.py
pages/19_ResNet.py  CHANGED  (+29 -80)
@@ -1,85 +1,57 @@
 # Install required packages
-# !pip install streamlit torch torchvision matplotlib

 # Import Libraries
 import streamlit as st
 import torch
 import torch.nn as nn
 import torch.optim as optim
-import torchvision # Add this import
 from torchvision import datasets, models, transforms
-from torch.utils.data import DataLoader
 import numpy as np
 import time
-import os
-import copy
 import matplotlib.pyplot as plt
-from transformers import Trainer, TrainingArguments
-from datasets import load_dataset

 # Streamlit Interface
-st.title("Fine-Tuning
-
-# Introduction Section
-st.markdown("""
-### Introduction
-In this exercise, we will fine-tune a pre-trained ResNet model on a custom image classification task using PyTorch. The ResNet (Residual Network) architecture helps in training very deep neural networks by using skip connections to mitigate the vanishing gradient problem.
-""")

 # User Inputs
 st.sidebar.header("Model Parameters")
-input_size = st.sidebar.number_input("Input Size", value=224)
 batch_size = st.sidebar.number_input("Batch Size", value=32)
-num_epochs = st.sidebar.number_input("Number of Epochs", value=
 learning_rate = st.sidebar.number_input("Learning Rate", value=0.001)
-momentum = st.sidebar.number_input("Momentum", value=0.9)

 # Data Preparation Section
 st.markdown("""
 ### Data Preparation
-We will use the CIFAR-10 dataset
 """)

 transform = transforms.Compose([
-    transforms.Resize(
     transforms.ToTensor(),
     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
 ])

-
-

 train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
 val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

 dataloaders = {'train': train_loader, 'val': val_loader}
-
-class_names = train_dataset.classes

 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

-# Visualize a few training images
-st.markdown("#### Sample Training Images")
-def imshow(inp, title=None):
-    inp = inp.numpy().transpose((1, 2, 0))
-    mean = np.array([0.485, 0.456, 0.406])
-    std = np.array([0.229, 0.224, 0.225])
-    inp = std * inp + mean
-    inp = np.clip(inp, 0, 1)
-    fig, ax = plt.subplots()
-    ax.imshow(inp)
-    if title is not None:
-        ax.set_title(title)
-    st.pyplot(fig)
-
-inputs, classes = next(iter(dataloaders['train']))
-out = torchvision.utils.make_grid(inputs)
-imshow(out, title=[class_names[x] for x in classes])
-
 # Model Preparation Section
 st.markdown("""
 ### Model Preparation
-We will use a pre-trained ResNet-18 model and fine-tune the final fully connected layer to match the number of classes in our
 """)

 # Load Pre-trained ResNet Model
@@ -91,18 +63,16 @@ model_ft = model_ft.to(device)

 # Define Loss Function and Optimizer
 criterion = nn.CrossEntropyLoss()
-optimizer_ft = optim.SGD(model_ft.parameters(), lr=learning_rate, momentum=
-exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

 # Training Section
 st.markdown("""
 ### Training
-We will train the model using stochastic gradient descent (SGD) with
 """)

 # Train and Evaluate the Model
-def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
-    since = time.time()
     best_model_wts = copy.deepcopy(model.state_dict())
     best_acc = 0.0
     train_loss_history = []
@@ -111,7 +81,7 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
     val_acc_history = []

     for epoch in range(num_epochs):
-        st.write('Epoch {}/{}'
         st.write('-' * 10)

         for phase in ['train', 'val']:
@@ -141,13 +111,8 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
                 running_loss += loss.item() * inputs.size(0)
                 running_corrects += torch.sum(preds == labels.data)

-
-
-
-            epoch_loss = running_loss / dataset_sizes[phase]
-            epoch_acc = running_corrects.double() / dataset_sizes[phase]
-
-            st.write('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

             if phase == 'train':
                 train_loss_history.append(epoch_loss)
@@ -156,49 +121,33 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
                 val_loss_history.append(epoch_loss)
                 val_acc_history.append(epoch_acc)

             if phase == 'val' and epoch_acc > best_acc:
                 best_acc = epoch_acc
                 best_model_wts = copy.deepcopy(model.state_dict())

-        st.write()
-
-    time_elapsed = time.time() - since
-    st.write('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
-    st.write('Best val Acc: {:4f}'.format(best_acc))
-
     model.load_state_dict(best_model_wts)
-
     # Plot training history
-    epochs_range = range(num_epochs)
     fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
-    ax1.plot(
-    ax1.plot(
     ax1.legend(loc='upper right')
     ax1.set_title('Training and Validation Loss')

-    ax2.plot(
-    ax2.plot(
     ax2.legend(loc='lower right')
     ax2.set_title('Training and Validation Accuracy')

     st.pyplot(fig)
-
     return model

 if st.button('Train Model'):
-    model_ft = train_model(model_ft, criterion, optimizer_ft,
     # Save the Model
     torch.save(model_ft.state_dict(), 'fine_tuned_resnet.pth')
     st.write("Model saved as 'fine_tuned_resnet.pth'")

-# Hugging Face Integration Section
-st.markdown("""
-### Hugging Face Integration
-We will use the Hugging Face library to load the dataset and prepare it for training. This integration will allow us to leverage the benefits of Hugging Face's powerful tools and APIs.
-""")
-
-# This part is just illustrative since Hugging Face's Trainer does not natively support ResNet.
-# However, you can still follow a similar approach for transformer models and NLP datasets.
-if st.button('Train with Hugging Face'):
-    st.write("This section is illustrative and typically used for NLP tasks with Hugging Face transformers.")
-
 # Install required packages
+# !pip install streamlit torch torchvision matplotlib

 # Import Libraries
 import streamlit as st
 import torch
 import torch.nn as nn
 import torch.optim as optim
 from torchvision import datasets, models, transforms
+from torch.utils.data import DataLoader, Subset
 import numpy as np
 import time
+import copy  # still required below by copy.deepcopy
 import matplotlib.pyplot as plt

 # Streamlit Interface
+st.title("Simple ResNet Fine-Tuning Example")

 # User Inputs
 st.sidebar.header("Model Parameters")
 batch_size = st.sidebar.number_input("Batch Size", value=32)
+num_epochs = st.sidebar.number_input("Number of Epochs", value=5)
 learning_rate = st.sidebar.number_input("Learning Rate", value=0.001)

 # Data Preparation Section
 st.markdown("""
 ### Data Preparation
+We will use a small subset of the CIFAR-10 dataset for quick experimentation. The dataset will be split into training and validation sets, and transformations will be applied to normalize the data.
 """)

 transform = transforms.Compose([
+    transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
 ])

+full_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
+subset_indices = list(range(1000)) # Use only 1000 samples for simplicity
+subset_dataset = Subset(full_dataset, subset_indices)
+train_size = int(0.8 * len(subset_dataset))
+val_size = len(subset_dataset) - train_size
+train_dataset, val_dataset = torch.utils.data.random_split(subset_dataset, [train_size, val_size])

 train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
 val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

 dataloaders = {'train': train_loader, 'val': val_loader}
+class_names = full_dataset.classes

 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

 # Model Preparation Section
 st.markdown("""
 ### Model Preparation
+We will use a pre-trained ResNet-18 model and fine-tune the final fully connected layer to match the number of classes in our dataset.
 """)

 # Load Pre-trained ResNet Model
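The unchanged model-loading lines are collapsed between these two hunks; the only visible piece is the context line model_ft = model_ft.to(device) in the next hunk header. A minimal sketch of what that block typically looks like for this script, assuming a torchvision ResNet-18 whose final fully connected layer is replaced to match class_names (an illustration, not the file's exact contents):

# Sketch (assumption): typical contents of the collapsed model-loading block
model_ft = models.resnet18(pretrained=True)           # start from ImageNet weights
num_ftrs = model_ft.fc.in_features                     # width of the original classifier head
model_ft.fc = nn.Linear(num_ftrs, len(class_names))    # new head: one logit per CIFAR-10 class
model_ft = model_ft.to(device)                         # matches the context line in the next hunk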

 # Define Loss Function and Optimizer
 criterion = nn.CrossEntropyLoss()
+optimizer_ft = optim.SGD(model_ft.parameters(), lr=learning_rate, momentum=0.9)

 # Training Section
 st.markdown("""
 ### Training
+We will train the model using stochastic gradient descent (SGD) with momentum. The training and validation loss and accuracy will be plotted to monitor the training process.
 """)

 # Train and Evaluate the Model
+def train_model(model, criterion, optimizer, num_epochs=5):
     best_model_wts = copy.deepcopy(model.state_dict())
     best_acc = 0.0
     train_loss_history = []
     ...
     val_acc_history = []

     for epoch in range(num_epochs):
+        st.write(f'Epoch {epoch+1}/{num_epochs}')
         st.write('-' * 10)

         for phase in ['train', 'val']:
             ...
                 running_loss += loss.item() * inputs.size(0)
                 running_corrects += torch.sum(preds == labels.data)

+            epoch_loss = running_loss / len(dataloaders[phase].dataset)
+            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

             if phase == 'train':
                 train_loss_history.append(epoch_loss)
                 ...
                 val_loss_history.append(epoch_loss)
                 val_acc_history.append(epoch_acc)

+            st.write(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
+
             if phase == 'val' and epoch_acc > best_acc:
                 best_acc = epoch_acc
                 best_model_wts = copy.deepcopy(model.state_dict())

     model.load_state_dict(best_model_wts)
+
     # Plot training history
     fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
+    ax1.plot(train_loss_history, label='Training Loss')
+    ax1.plot(val_loss_history, label='Validation Loss')
     ax1.legend(loc='upper right')
     ax1.set_title('Training and Validation Loss')

+    ax2.plot(train_acc_history, label='Training Accuracy')
+    ax2.plot(val_acc_history, label='Validation Accuracy')
     ax2.legend(loc='lower right')
     ax2.set_title('Training and Validation Accuracy')

     st.pyplot(fig)
+
     return model

 if st.button('Train Model'):
+    model_ft = train_model(model_ft, criterion, optimizer_ft, num_epochs)
     # Save the Model
     torch.save(model_ft.state_dict(), 'fine_tuned_resnet.pth')
     st.write("Model saved as 'fine_tuned_resnet.pth'")
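The per-phase batch loop inside train_model is untouched by this commit, so the hunks above show only its accumulation lines (running_loss, running_corrects). For reference, a standard PyTorch loop consistent with those lines looks roughly like the sketch below; it is an assumption about the collapsed, unchanged code, not a quote of the file.

# Sketch (assumption): the unchanged inner loop that produces preds and loss
if phase == 'train':
    model.train()   # enable dropout / batch-norm updates
else:
    model.eval()

running_loss = 0.0
running_corrects = 0

for inputs, labels in dataloaders[phase]:
    inputs, labels = inputs.to(device), labels.to(device)
    optimizer.zero_grad()
    with torch.set_grad_enabled(phase == 'train'):
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)
        if phase == 'train':
            loss.backward()
            optimizer.step()
    running_loss += loss.item() * inputs.size(0)
    running_corrects += torch.sum(preds == labels.data)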
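After training, the app stores the fine-tuned weights in fine_tuned_resnet.pth. A short usage sketch for loading them back for inference, assuming the same ResNet-18 architecture with the replaced final layer (the variable names here are illustrative, not from the file):

# Rebuild the same architecture, then load the saved weights
model = models.resnet18()
model.fc = nn.Linear(model.fc.in_features, len(class_names))
model.load_state_dict(torch.load('fine_tuned_resnet.pth', map_location=device))
model = model.to(device)
model.eval()  # inference mode

with torch.no_grad():
    images, labels = next(iter(val_loader))
    outputs = model(images.to(device))
    predicted = outputs.argmax(dim=1)  # class index per image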