Spaces:
Running
Running
Update pages/19_ResNet.py
Browse files- pages/19_ResNet.py +147 -75
pages/19_ResNet.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
-
# Install
|
2 |
-
#
|
3 |
-
# You can install them using pip if you haven't already:
|
4 |
-
# pip install torch torchvision streamlit
|
5 |
|
|
|
|
|
6 |
import torch
|
7 |
import torch.nn as nn
|
8 |
import torch.optim as optim
|
@@ -12,64 +12,118 @@ import numpy as np
|
|
12 |
import time
|
13 |
import os
|
14 |
import copy
|
15 |
-
import streamlit as st
|
16 |
-
from PIL import Image
|
17 |
import matplotlib.pyplot as plt
|
18 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
-
# Data transformations
|
21 |
data_transforms = {
|
22 |
'train': transforms.Compose([
|
23 |
-
transforms.RandomResizedCrop(
|
24 |
transforms.RandomHorizontalFlip(),
|
25 |
transforms.ToTensor(),
|
26 |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
27 |
]),
|
28 |
'val': transforms.Compose([
|
29 |
-
transforms.Resize(
|
30 |
-
transforms.CenterCrop(
|
31 |
transforms.ToTensor(),
|
32 |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
33 |
]),
|
34 |
}
|
35 |
|
36 |
-
# Load datasets
|
37 |
-
data_dir = 'path/to/data'
|
38 |
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
|
39 |
for x in ['train', 'val']}
|
40 |
-
dataloaders = {x: DataLoader(image_datasets[x], batch_size=
|
41 |
for x in ['train', 'val']}
|
42 |
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
|
43 |
class_names = image_datasets['train'].classes
|
|
|
44 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
45 |
|
46 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
model_ft = models.resnet18(pretrained=True)
|
48 |
num_ftrs = model_ft.fc.in_features
|
49 |
model_ft.fc = nn.Linear(num_ftrs, len(class_names))
|
|
|
50 |
model_ft = model_ft.to(device)
|
51 |
|
52 |
-
# Define
|
53 |
criterion = nn.CrossEntropyLoss()
|
54 |
-
optimizer_ft = optim.SGD(model_ft.parameters(), lr=
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
-
#
|
58 |
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
|
59 |
since = time.time()
|
60 |
-
|
61 |
best_model_wts = copy.deepcopy(model.state_dict())
|
62 |
best_acc = 0.0
|
|
|
|
|
|
|
|
|
63 |
|
64 |
for epoch in range(num_epochs):
|
65 |
-
|
66 |
-
|
67 |
|
68 |
for phase in ['train', 'val']:
|
69 |
if phase == 'train':
|
70 |
-
model.train()
|
71 |
else:
|
72 |
-
model.eval()
|
73 |
|
74 |
running_loss = 0.0
|
75 |
running_corrects = 0
|
@@ -98,66 +152,84 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
|
|
98 |
epoch_loss = running_loss / dataset_sizes[phase]
|
99 |
epoch_acc = running_corrects.double() / dataset_sizes[phase]
|
100 |
|
101 |
-
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
if phase == 'val' and epoch_acc > best_acc:
|
105 |
best_acc = epoch_acc
|
106 |
best_model_wts = copy.deepcopy(model.state_dict())
|
107 |
|
108 |
-
|
109 |
|
110 |
time_elapsed = time.time() - since
|
111 |
-
|
112 |
-
|
113 |
-
print('Best val Acc: {:4f}'.format(best_acc))
|
114 |
|
115 |
model.load_state_dict(best_model_wts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
return model
|
117 |
|
118 |
-
|
119 |
-
model_ft = train_model(model_ft, criterion, optimizer_ft,
|
120 |
-
|
121 |
-
|
122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
|
124 |
-
# Streamlit Interface
|
125 |
-
st.title("Image Classification with Fine-tuned ResNet")
|
126 |
-
|
127 |
-
uploaded_file = st.file_uploader("Choose an image...", type="jpg")
|
128 |
-
|
129 |
-
if uploaded_file is not None:
|
130 |
-
image = Image.open(uploaded_file)
|
131 |
-
st.image(image, caption='Uploaded Image.', use_column_width=True)
|
132 |
-
st.write("")
|
133 |
-
st.write("Classifying...")
|
134 |
-
|
135 |
-
model_ft = models.resnet18(pretrained=True)
|
136 |
-
num_ftrs = model_ft.fc.in_features
|
137 |
-
model_ft.fc = nn.Linear(num_ftrs, len(class_names))
|
138 |
-
model_ft.load_state_dict(torch.load('model_ft.pth'))
|
139 |
-
model_ft = model_ft.to(device)
|
140 |
-
model_ft.eval()
|
141 |
-
|
142 |
-
preprocess = T.Compose([
|
143 |
-
T.Resize(256),
|
144 |
-
T.CenterCrop(224),
|
145 |
-
T.ToTensor(),
|
146 |
-
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
147 |
-
])
|
148 |
-
|
149 |
-
img = preprocess(image).unsqueeze(0)
|
150 |
-
img = img.to(device)
|
151 |
-
|
152 |
-
with torch.no_grad():
|
153 |
-
outputs = model_ft(img)
|
154 |
-
_, preds = torch.max(outputs, 1)
|
155 |
-
predicted_class = class_names[preds[0]]
|
156 |
-
|
157 |
-
st.write(f"Predicted Class: {predicted_class}")
|
158 |
-
|
159 |
-
# Plotting the image with matplotlib
|
160 |
-
fig, ax = plt.subplots()
|
161 |
-
ax.imshow(image)
|
162 |
-
ax.set_title(f"Predicted: {predicted_class}")
|
163 |
-
st.pyplot(fig)
|
|
|
1 |
+
# Install required packages
|
2 |
+
# !pip install streamlit torch torchvision matplotlib datasets transformers
|
|
|
|
|
3 |
|
4 |
+
# Import Libraries
|
5 |
+
import streamlit as st
|
6 |
import torch
|
7 |
import torch.nn as nn
|
8 |
import torch.optim as optim
|
|
|
12 |
import time
|
13 |
import os
|
14 |
import copy
|
|
|
|
|
15 |
import matplotlib.pyplot as plt
|
16 |
+
from transformers import Trainer, TrainingArguments
|
17 |
+
from datasets import load_dataset
|
18 |
+
|
19 |
+
# Streamlit Interface
|
20 |
+
st.title("Fine-Tuning ResNet for Custom Image Classification")
|
21 |
+
|
22 |
+
# Introduction Section
|
23 |
+
st.markdown("""
|
24 |
+
### Introduction
|
25 |
+
In this exercise, we will fine-tune a pre-trained ResNet model on a custom image classification task using PyTorch. The ResNet (Residual Network) architecture helps in training very deep neural networks by using skip connections to mitigate the vanishing gradient problem.
|
26 |
+
""")
|
27 |
+
|
28 |
+
# User Inputs
|
29 |
+
st.sidebar.header("Model Parameters")
|
30 |
+
data_dir = st.sidebar.text_input("Path to Dataset Directory", 'path_to_caltech101_dataset')
|
31 |
+
input_size = st.sidebar.number_input("Input Size", value=224)
|
32 |
+
batch_size = st.sidebar.number_input("Batch Size", value=32)
|
33 |
+
num_epochs = st.sidebar.number_input("Number of Epochs", value=25)
|
34 |
+
learning_rate = st.sidebar.number_input("Learning Rate", value=0.001)
|
35 |
+
momentum = st.sidebar.number_input("Momentum", value=0.9)
|
36 |
+
|
37 |
+
# Data Preparation Section
|
38 |
+
st.markdown("""
|
39 |
+
### Data Preparation
|
40 |
+
We will use the Caltech 101 dataset, which contains images from 101 object categories. The dataset will be split into training and validation sets, and transformations will be applied to augment the data and normalize it.
|
41 |
+
""")
|
42 |
|
|
|
43 |
data_transforms = {
|
44 |
'train': transforms.Compose([
|
45 |
+
transforms.RandomResizedCrop(input_size),
|
46 |
transforms.RandomHorizontalFlip(),
|
47 |
transforms.ToTensor(),
|
48 |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
49 |
]),
|
50 |
'val': transforms.Compose([
|
51 |
+
transforms.Resize(input_size),
|
52 |
+
transforms.CenterCrop(input_size),
|
53 |
transforms.ToTensor(),
|
54 |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
55 |
]),
|
56 |
}
|
57 |
|
|
|
|
|
58 |
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
|
59 |
for x in ['train', 'val']}
|
60 |
+
dataloaders = {x: DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4)
|
61 |
for x in ['train', 'val']}
|
62 |
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
|
63 |
class_names = image_datasets['train'].classes
|
64 |
+
|
65 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
66 |
|
67 |
+
# Visualize a few training images
|
68 |
+
st.markdown("#### Sample Training Images")
|
69 |
+
def imshow(inp, title=None):
|
70 |
+
inp = inp.numpy().transpose((1, 2, 0))
|
71 |
+
mean = np.array([0.485, 0.456, 0.406])
|
72 |
+
std = np.array([0.229, 0.224, 0.225])
|
73 |
+
inp = std * inp + mean
|
74 |
+
inp = np.clip(inp, 0, 1)
|
75 |
+
plt.imshow(inp)
|
76 |
+
if title is not None:
|
77 |
+
plt.title(title)
|
78 |
+
plt.pause(0.001)
|
79 |
+
|
80 |
+
inputs, classes = next(iter(dataloaders['train']))
|
81 |
+
out = torchvision.utils.make_grid(inputs)
|
82 |
+
st.pyplot(imshow(out, title=[class_names[x] for x in classes]))
|
83 |
+
|
84 |
+
# Model Preparation Section
|
85 |
+
st.markdown("""
|
86 |
+
### Model Preparation
|
87 |
+
We will use a pre-trained ResNet-18 model and fine-tune the final fully connected layer to match the number of classes in our custom dataset.
|
88 |
+
""")
|
89 |
+
|
90 |
+
# Load Pre-trained ResNet Model
|
91 |
model_ft = models.resnet18(pretrained=True)
|
92 |
num_ftrs = model_ft.fc.in_features
|
93 |
model_ft.fc = nn.Linear(num_ftrs, len(class_names))
|
94 |
+
|
95 |
model_ft = model_ft.to(device)
|
96 |
|
97 |
+
# Define Loss Function and Optimizer
|
98 |
criterion = nn.CrossEntropyLoss()
|
99 |
+
optimizer_ft = optim.SGD(model_ft.parameters(), lr=learning_rate, momentum=momentum)
|
100 |
+
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
|
101 |
+
|
102 |
+
# Training Section
|
103 |
+
st.markdown("""
|
104 |
+
### Training
|
105 |
+
We will train the model using stochastic gradient descent (SGD) with momentum and a learning rate scheduler. The training and validation loss and accuracy will be plotted to monitor the training process.
|
106 |
+
""")
|
107 |
|
108 |
+
# Train and Evaluate the Model
|
109 |
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
|
110 |
since = time.time()
|
|
|
111 |
best_model_wts = copy.deepcopy(model.state_dict())
|
112 |
best_acc = 0.0
|
113 |
+
train_loss_history = []
|
114 |
+
val_loss_history = []
|
115 |
+
train_acc_history = []
|
116 |
+
val_acc_history = []
|
117 |
|
118 |
for epoch in range(num_epochs):
|
119 |
+
st.write('Epoch {}/{}'.format(epoch, num_epochs - 1))
|
120 |
+
st.write('-' * 10)
|
121 |
|
122 |
for phase in ['train', 'val']:
|
123 |
if phase == 'train':
|
124 |
+
model.train()
|
125 |
else:
|
126 |
+
model.eval()
|
127 |
|
128 |
running_loss = 0.0
|
129 |
running_corrects = 0
|
|
|
152 |
epoch_loss = running_loss / dataset_sizes[phase]
|
153 |
epoch_acc = running_corrects.double() / dataset_sizes[phase]
|
154 |
|
155 |
+
st.write('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
|
156 |
+
|
157 |
+
if phase == 'train':
|
158 |
+
train_loss_history.append(epoch_loss)
|
159 |
+
train_acc_history.append(epoch_acc)
|
160 |
+
else:
|
161 |
+
val_loss_history.append(epoch_loss)
|
162 |
+
val_acc_history.append(epoch_acc)
|
163 |
|
164 |
if phase == 'val' and epoch_acc > best_acc:
|
165 |
best_acc = epoch_acc
|
166 |
best_model_wts = copy.deepcopy(model.state_dict())
|
167 |
|
168 |
+
st.write()
|
169 |
|
170 |
time_elapsed = time.time() - since
|
171 |
+
st.write('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
|
172 |
+
st.write('Best val Acc: {:4f}'.format(best_acc))
|
|
|
173 |
|
174 |
model.load_state_dict(best_model_wts)
|
175 |
+
|
176 |
+
# Plot training history
|
177 |
+
epochs_range = range(num_epochs)
|
178 |
+
plt.figure(figsize=(10, 5))
|
179 |
+
plt.subplot(1, 2, 1)
|
180 |
+
plt.plot(epochs_range, train_loss_history, label='Training Loss')
|
181 |
+
plt.plot(epochs_range, val_loss_history, label='Validation Loss')
|
182 |
+
plt.legend(loc='upper right')
|
183 |
+
plt.title('Training and Validation Loss')
|
184 |
+
|
185 |
+
plt.subplot(1, 2, 2)
|
186 |
+
plt.plot(epochs_range, train_acc_history, label='Training Accuracy')
|
187 |
+
plt.plot(epochs_range, val_acc_history, label='Validation Accuracy')
|
188 |
+
plt.legend(loc='lower right')
|
189 |
+
plt.title('Training and Validation Accuracy')
|
190 |
+
plt.show()
|
191 |
+
st.pyplot(plt)
|
192 |
+
|
193 |
return model
|
194 |
|
195 |
+
if st.button('Train Model'):
|
196 |
+
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs)
|
197 |
+
# Save the Model
|
198 |
+
torch.save(model_ft.state_dict(), 'fine_tuned_resnet.pth')
|
199 |
+
st.write("Model saved as 'fine_tuned_resnet.pth'")
|
200 |
+
|
201 |
+
# Hugging Face Integration Section
|
202 |
+
st.markdown("""
|
203 |
+
### Hugging Face Integration
|
204 |
+
We will use the Hugging Face library to load the dataset and prepare it for training. This integration will allow us to leverage the benefits of Hugging Face's powerful tools and APIs.
|
205 |
+
""")
|
206 |
+
|
207 |
+
dataset = load_dataset('caltech101', split='train')
|
208 |
+
|
209 |
+
def preprocess_function(examples):
|
210 |
+
return {'pixel_values': [data_transforms['train'](image) for image in examples['image']], 'labels': examples['label']}
|
211 |
+
|
212 |
+
dataset = dataset.map(preprocess_function, batched=True)
|
213 |
+
|
214 |
+
training_args = TrainingArguments(
|
215 |
+
output_dir='./results',
|
216 |
+
evaluation_strategy="epoch",
|
217 |
+
per_device_train_batch_size=8,
|
218 |
+
per_device_eval_batch_size=8,
|
219 |
+
num_train_epochs=3,
|
220 |
+
save_strategy="epoch",
|
221 |
+
logging_dir='./logs',
|
222 |
+
)
|
223 |
+
|
224 |
+
trainer = Trainer(
|
225 |
+
model=model_ft,
|
226 |
+
args=training_args,
|
227 |
+
train_dataset=dataset['train'],
|
228 |
+
eval_dataset=dataset['val'],
|
229 |
+
tokenizer=None,
|
230 |
+
)
|
231 |
+
|
232 |
+
if st.button('Train with Hugging Face'):
|
233 |
+
trainer.train()
|
234 |
+
st.write("Model trained using Hugging Face")
|
235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|