eaglelandsonce commited on
Commit
eb2cc92
·
verified ·
1 Parent(s): 600d0d3

Update pages/19_ResNet.py

Browse files
Files changed (1) hide show
  1. pages/19_ResNet.py +17 -18
pages/19_ResNet.py CHANGED
@@ -66,14 +66,15 @@ def imshow(inp, title=None):
66
  std = np.array([0.229, 0.224, 0.225])
67
  inp = std * inp + mean
68
  inp = np.clip(inp, 0, 1)
69
- plt.imshow(inp)
 
70
  if title is not None:
71
- plt.title(title)
72
- plt.pause(0.001)
73
 
74
  inputs, classes = next(iter(dataloaders['train']))
75
  out = torchvision.utils.make_grid(inputs)
76
- st.pyplot(imshow(out, title=[class_names[x] for x in classes]))
77
 
78
  # Model Preparation Section
79
  st.markdown("""
@@ -169,20 +170,18 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
169
 
170
  # Plot training history
171
  epochs_range = range(num_epochs)
172
- plt.figure(figsize=(10, 5))
173
- plt.subplot(1, 2, 1)
174
- plt.plot(epochs_range, train_loss_history, label='Training Loss')
175
- plt.plot(epochs_range, val_loss_history, label='Validation Loss')
176
- plt.legend(loc='upper right')
177
- plt.title('Training and Validation Loss')
178
-
179
- plt.subplot(1, 2, 2)
180
- plt.plot(epochs_range, train_acc_history, label='Training Accuracy')
181
- plt.plot(epochs_range, val_acc_history, label='Validation Accuracy')
182
- plt.legend(loc='lower right')
183
- plt.title('Training and Validation Accuracy')
184
- plt.show()
185
- st.pyplot(plt)
186
 
187
  return model
188
 
 
66
  std = np.array([0.229, 0.224, 0.225])
67
  inp = std * inp + mean
68
  inp = np.clip(inp, 0, 1)
69
+ fig, ax = plt.subplots()
70
+ ax.imshow(inp)
71
  if title is not None:
72
+ ax.set_title(title)
73
+ st.pyplot(fig)
74
 
75
  inputs, classes = next(iter(dataloaders['train']))
76
  out = torchvision.utils.make_grid(inputs)
77
+ imshow(out, title=[class_names[x] for x in classes])
78
 
79
  # Model Preparation Section
80
  st.markdown("""
 
170
 
171
  # Plot training history
172
  epochs_range = range(num_epochs)
173
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
174
+ ax1.plot(epochs_range, train_loss_history, label='Training Loss')
175
+ ax1.plot(epochs_range, val_loss_history, label='Validation Loss')
176
+ ax1.legend(loc='upper right')
177
+ ax1.set_title('Training and Validation Loss')
178
+
179
+ ax2.plot(epochs_range, train_acc_history, label='Training Accuracy')
180
+ ax2.plot(epochs_range, val_acc_history, label='Validation Accuracy')
181
+ ax2.legend(loc='lower right')
182
+ ax2.set_title('Training and Validation Accuracy')
183
+
184
+ st.pyplot(fig)
 
 
185
 
186
  return model
187