Spaces:
Running
Running
Update pages/13_FFNN.py
Browse files- pages/13_FFNN.py +20 -4
pages/13_FFNN.py
CHANGED
@@ -37,6 +37,7 @@ def load_data():
|
|
37 |
|
38 |
# Function to train the network
|
39 |
def train_network(net, trainloader, criterion, optimizer, epochs):
|
|
|
40 |
for epoch in range(epochs):
|
41 |
running_loss = 0.0
|
42 |
for i, data in enumerate(trainloader, 0):
|
@@ -47,8 +48,11 @@ def train_network(net, trainloader, criterion, optimizer, epochs):
|
|
47 |
loss.backward()
|
48 |
optimizer.step()
|
49 |
running_loss += loss.item()
|
50 |
-
|
|
|
|
|
51 |
st.write('Finished Training')
|
|
|
52 |
|
53 |
# Function to test the network
|
54 |
def test_network(net, testloader):
|
@@ -61,7 +65,9 @@ def test_network(net, testloader):
|
|
61 |
_, predicted = torch.max(outputs.data, 1)
|
62 |
total += labels.size(0)
|
63 |
correct += (predicted == labels).sum().item()
|
64 |
-
|
|
|
|
|
65 |
|
66 |
# Load the data
|
67 |
trainloader, testloader = load_data()
|
@@ -85,11 +91,21 @@ st.write('\n' * 10)
|
|
85 |
|
86 |
# Train the network
|
87 |
if st.sidebar.button('Train Network'):
|
88 |
-
train_network(net, trainloader, criterion, optimizer, epochs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
|
90 |
# Test the network
|
91 |
if st.sidebar.button('Test Network'):
|
92 |
-
test_network(net, testloader)
|
|
|
93 |
|
94 |
# Visualize some test results
|
95 |
def imshow(img):
|
|
|
37 |
|
38 |
# Function to train the network
|
39 |
def train_network(net, trainloader, criterion, optimizer, epochs):
|
40 |
+
loss_values = []
|
41 |
for epoch in range(epochs):
|
42 |
running_loss = 0.0
|
43 |
for i, data in enumerate(trainloader, 0):
|
|
|
48 |
loss.backward()
|
49 |
optimizer.step()
|
50 |
running_loss += loss.item()
|
51 |
+
epoch_loss = running_loss / len(trainloader)
|
52 |
+
loss_values.append(epoch_loss)
|
53 |
+
st.write(f'Epoch {epoch + 1}: loss {epoch_loss:.3f}')
|
54 |
st.write('Finished Training')
|
55 |
+
return loss_values
|
56 |
|
57 |
# Function to test the network
|
58 |
def test_network(net, testloader):
|
|
|
65 |
_, predicted = torch.max(outputs.data, 1)
|
66 |
total += labels.size(0)
|
67 |
correct += (predicted == labels).sum().item()
|
68 |
+
accuracy = 100 * correct / total
|
69 |
+
st.write(f'Accuracy of the network on the 10000 test images: {accuracy:.2f}%')
|
70 |
+
return accuracy
|
71 |
|
72 |
# Load the data
|
73 |
trainloader, testloader = load_data()
|
|
|
91 |
|
92 |
# Train the network
|
93 |
if st.sidebar.button('Train Network'):
|
94 |
+
loss_values = train_network(net, trainloader, criterion, optimizer, epochs)
|
95 |
+
|
96 |
+
# Plot the loss values
|
97 |
+
plt.figure(figsize=(10, 5))
|
98 |
+
plt.plot(range(1, epochs + 1), loss_values, marker='o')
|
99 |
+
plt.title('Training Loss Over Epochs')
|
100 |
+
plt.xlabel('Epoch')
|
101 |
+
plt.ylabel('Loss')
|
102 |
+
plt.grid(True)
|
103 |
+
st.pyplot(plt)
|
104 |
|
105 |
# Test the network
|
106 |
if st.sidebar.button('Test Network'):
|
107 |
+
accuracy = test_network(net, testloader)
|
108 |
+
st.write(f'Test Accuracy: {accuracy:.2f}%')
|
109 |
|
110 |
# Visualize some test results
|
111 |
def imshow(img):
|