eaglelandsonce commited on
Commit
925cb8a
·
verified ·
1 Parent(s): 5c6143b

Update pages/13_FFNN.py

Browse files
Files changed (1) hide show
  1. pages/13_FFNN.py +20 -5
pages/13_FFNN.py CHANGED
@@ -59,6 +59,8 @@ def train_network(net, trainloader, criterion, optimizer, epochs):
59
  def test_network(net, testloader):
60
  correct = 0
61
  total = 0
 
 
62
  with torch.no_grad():
63
  for data in testloader:
64
  images, labels = data
@@ -66,9 +68,11 @@ def test_network(net, testloader):
66
  _, predicted = torch.max(outputs.data, 1)
67
  total += labels.size(0)
68
  correct += (predicted == labels).sum().item()
 
 
69
  accuracy = 100 * correct / total
70
  st.write(f'Accuracy of the network on the 10000 test images: {accuracy:.2f}%')
71
- return accuracy
72
 
73
  # Load the data
74
  trainloader, testloader = load_data()
@@ -102,11 +106,22 @@ if st.sidebar.button('Train Network'):
102
  plt.ylabel('Loss')
103
  plt.grid(True)
104
  st.pyplot(plt)
 
 
 
105
 
106
  # Test the network
107
- if st.sidebar.button('Test Network'):
108
- accuracy = test_network(net, testloader)
109
  st.write(f'Test Accuracy: {accuracy:.2f}%')
 
 
 
 
 
 
 
 
110
 
111
  # Visualize some test results
112
  def imshow(img):
@@ -115,12 +130,12 @@ def imshow(img):
115
  plt.imshow(np.transpose(npimg, (1, 2, 0)))
116
  plt.show()
117
 
118
- if st.sidebar.button('Show Test Results'):
119
  dataiter = iter(testloader)
120
  images, labels = next(dataiter) # Use next function
121
  imshow(torchvision.utils.make_grid(images))
122
 
123
- outputs = net(images)
124
  _, predicted = torch.max(outputs, 1)
125
 
126
  st.write('GroundTruth vs Predicted')
 
59
  def test_network(net, testloader):
60
  correct = 0
61
  total = 0
62
+ all_labels = []
63
+ all_predicted = []
64
  with torch.no_grad():
65
  for data in testloader:
66
  images, labels = data
 
68
  _, predicted = torch.max(outputs.data, 1)
69
  total += labels.size(0)
70
  correct += (predicted == labels).sum().item()
71
+ all_labels.extend(labels.numpy())
72
+ all_predicted.extend(predicted.numpy())
73
  accuracy = 100 * correct / total
74
  st.write(f'Accuracy of the network on the 10000 test images: {accuracy:.2f}%')
75
+ return accuracy, all_labels, all_predicted
76
 
77
  # Load the data
78
  trainloader, testloader = load_data()
 
106
  plt.ylabel('Loss')
107
  plt.grid(True)
108
  st.pyplot(plt)
109
+
110
+ # Store the trained model in the session state
111
+ st.session_state['trained_model'] = net
112
 
113
  # Test the network
114
+ if 'trained_model' in st.session_state and st.sidebar.button('Test Network'):
115
+ accuracy, all_labels, all_predicted = test_network(st.session_state['trained_model'], testloader)
116
  st.write(f'Test Accuracy: {accuracy:.2f}%')
117
+
118
+ # Display results in a table
119
+ st.write('GroundTruth vs Predicted')
120
+ results = pd.DataFrame({
121
+ 'Ground Truth': all_labels,
122
+ 'Predicted': all_predicted
123
+ })
124
+ st.table(results.head(50)) # Display first 50 results for brevity
125
 
126
  # Visualize some test results
127
  def imshow(img):
 
130
  plt.imshow(np.transpose(npimg, (1, 2, 0)))
131
  plt.show()
132
 
133
+ if 'trained_model' in st.session_state and st.sidebar.button('Show Test Results'):
134
  dataiter = iter(testloader)
135
  images, labels = next(dataiter) # Use next function
136
  imshow(torchvision.utils.make_grid(images))
137
 
138
+ outputs = st.session_state['trained_model'](images)
139
  _, predicted = torch.max(outputs, 1)
140
 
141
  st.write('GroundTruth vs Predicted')