Spaces:
Running
Running
Update pages/13_FFNN.py
Browse files- pages/13_FFNN.py +20 -5
pages/13_FFNN.py
CHANGED
@@ -59,6 +59,8 @@ def train_network(net, trainloader, criterion, optimizer, epochs):
|
|
59 |
def test_network(net, testloader):
|
60 |
correct = 0
|
61 |
total = 0
|
|
|
|
|
62 |
with torch.no_grad():
|
63 |
for data in testloader:
|
64 |
images, labels = data
|
@@ -66,9 +68,11 @@ def test_network(net, testloader):
|
|
66 |
_, predicted = torch.max(outputs.data, 1)
|
67 |
total += labels.size(0)
|
68 |
correct += (predicted == labels).sum().item()
|
|
|
|
|
69 |
accuracy = 100 * correct / total
|
70 |
st.write(f'Accuracy of the network on the 10000 test images: {accuracy:.2f}%')
|
71 |
-
return accuracy
|
72 |
|
73 |
# Load the data
|
74 |
trainloader, testloader = load_data()
|
@@ -102,11 +106,22 @@ if st.sidebar.button('Train Network'):
|
|
102 |
plt.ylabel('Loss')
|
103 |
plt.grid(True)
|
104 |
st.pyplot(plt)
|
|
|
|
|
|
|
105 |
|
106 |
# Test the network
|
107 |
-
if st.sidebar.button('Test Network'):
|
108 |
-
accuracy = test_network(
|
109 |
st.write(f'Test Accuracy: {accuracy:.2f}%')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
|
111 |
# Visualize some test results
|
112 |
def imshow(img):
|
@@ -115,12 +130,12 @@ def imshow(img):
|
|
115 |
plt.imshow(np.transpose(npimg, (1, 2, 0)))
|
116 |
plt.show()
|
117 |
|
118 |
-
if st.sidebar.button('Show Test Results'):
|
119 |
dataiter = iter(testloader)
|
120 |
images, labels = next(dataiter) # Use next function
|
121 |
imshow(torchvision.utils.make_grid(images))
|
122 |
|
123 |
-
outputs =
|
124 |
_, predicted = torch.max(outputs, 1)
|
125 |
|
126 |
st.write('GroundTruth vs Predicted')
|
|
|
59 |
def test_network(net, testloader):
|
60 |
correct = 0
|
61 |
total = 0
|
62 |
+
all_labels = []
|
63 |
+
all_predicted = []
|
64 |
with torch.no_grad():
|
65 |
for data in testloader:
|
66 |
images, labels = data
|
|
|
68 |
_, predicted = torch.max(outputs.data, 1)
|
69 |
total += labels.size(0)
|
70 |
correct += (predicted == labels).sum().item()
|
71 |
+
all_labels.extend(labels.numpy())
|
72 |
+
all_predicted.extend(predicted.numpy())
|
73 |
accuracy = 100 * correct / total
|
74 |
st.write(f'Accuracy of the network on the 10000 test images: {accuracy:.2f}%')
|
75 |
+
return accuracy, all_labels, all_predicted
|
76 |
|
77 |
# Load the data
|
78 |
trainloader, testloader = load_data()
|
|
|
106 |
plt.ylabel('Loss')
|
107 |
plt.grid(True)
|
108 |
st.pyplot(plt)
|
109 |
+
|
110 |
+
# Store the trained model in the session state
|
111 |
+
st.session_state['trained_model'] = net
|
112 |
|
113 |
# Test the network
|
114 |
+
if 'trained_model' in st.session_state and st.sidebar.button('Test Network'):
|
115 |
+
accuracy, all_labels, all_predicted = test_network(st.session_state['trained_model'], testloader)
|
116 |
st.write(f'Test Accuracy: {accuracy:.2f}%')
|
117 |
+
|
118 |
+
# Display results in a table
|
119 |
+
st.write('GroundTruth vs Predicted')
|
120 |
+
results = pd.DataFrame({
|
121 |
+
'Ground Truth': all_labels,
|
122 |
+
'Predicted': all_predicted
|
123 |
+
})
|
124 |
+
st.table(results.head(50)) # Display first 50 results for brevity
|
125 |
|
126 |
# Visualize some test results
|
127 |
def imshow(img):
|
|
|
130 |
plt.imshow(np.transpose(npimg, (1, 2, 0)))
|
131 |
plt.show()
|
132 |
|
133 |
+
if 'trained_model' in st.session_state and st.sidebar.button('Show Test Results'):
|
134 |
dataiter = iter(testloader)
|
135 |
images, labels = next(dataiter) # Use next function
|
136 |
imshow(torchvision.utils.make_grid(images))
|
137 |
|
138 |
+
outputs = st.session_state['trained_model'](images)
|
139 |
_, predicted = torch.max(outputs, 1)
|
140 |
|
141 |
st.write('GroundTruth vs Predicted')
|