Spaces:
Running
Running
Update pages/19_ResNet.py
Browse files- pages/19_ResNet.py +17 -18
pages/19_ResNet.py
CHANGED
@@ -66,14 +66,15 @@ def imshow(inp, title=None):
|
|
66 |
std = np.array([0.229, 0.224, 0.225])
|
67 |
inp = std * inp + mean
|
68 |
inp = np.clip(inp, 0, 1)
|
69 |
-
plt.
|
|
|
70 |
if title is not None:
|
71 |
-
|
72 |
-
|
73 |
|
74 |
inputs, classes = next(iter(dataloaders['train']))
|
75 |
out = torchvision.utils.make_grid(inputs)
|
76 |
-
|
77 |
|
78 |
# Model Preparation Section
|
79 |
st.markdown("""
|
@@ -169,20 +170,18 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
|
|
169 |
|
170 |
# Plot training history
|
171 |
epochs_range = range(num_epochs)
|
172 |
-
plt.
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
plt.show()
|
185 |
-
st.pyplot(plt)
|
186 |
|
187 |
return model
|
188 |
|
|
|
66 |
std = np.array([0.229, 0.224, 0.225])
|
67 |
inp = std * inp + mean
|
68 |
inp = np.clip(inp, 0, 1)
|
69 |
+
fig, ax = plt.subplots()
|
70 |
+
ax.imshow(inp)
|
71 |
if title is not None:
|
72 |
+
ax.set_title(title)
|
73 |
+
st.pyplot(fig)
|
74 |
|
75 |
inputs, classes = next(iter(dataloaders['train']))
|
76 |
out = torchvision.utils.make_grid(inputs)
|
77 |
+
imshow(out, title=[class_names[x] for x in classes])
|
78 |
|
79 |
# Model Preparation Section
|
80 |
st.markdown("""
|
|
|
170 |
|
171 |
# Plot training history
|
172 |
epochs_range = range(num_epochs)
|
173 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
|
174 |
+
ax1.plot(epochs_range, train_loss_history, label='Training Loss')
|
175 |
+
ax1.plot(epochs_range, val_loss_history, label='Validation Loss')
|
176 |
+
ax1.legend(loc='upper right')
|
177 |
+
ax1.set_title('Training and Validation Loss')
|
178 |
+
|
179 |
+
ax2.plot(epochs_range, train_acc_history, label='Training Accuracy')
|
180 |
+
ax2.plot(epochs_range, val_acc_history, label='Validation Accuracy')
|
181 |
+
ax2.legend(loc='lower right')
|
182 |
+
ax2.set_title('Training and Validation Accuracy')
|
183 |
+
|
184 |
+
st.pyplot(fig)
|
|
|
|
|
185 |
|
186 |
return model
|
187 |
|