"""Streamlit app that trains and evaluates a feedforward network on MNIST.

The sidebar exposes the hidden-layer sizes and SGD hyperparameters, plus
buttons to train the network, measure test accuracy, and preview predictions.
"""

import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# MNIST images are 28x28 single-channel, with ten digit classes.
INPUT_SIZE = 28 * 28
NUM_CLASSES = 10
# Number of test images (and labels) shown by the "Show Test Results" button.
PREVIEW_COUNT = 8
# Session-state key under which the most recently trained model is kept.
TRAINED_NET_KEY = 'trained_net'


# Define the Feedforward Neural Network
class FeedforwardNeuralNetwork(nn.Module):
    """Fully connected classifier with three ReLU hidden layers."""

    def __init__(self, input_size, hidden1_size, hidden2_size, hidden3_size, output_size):
        """Create the four linear layers.

        Args:
            input_size: Number of features per flattened input sample.
            hidden1_size: Width of the first hidden layer.
            hidden2_size: Width of the second hidden layer.
            hidden3_size: Width of the third hidden layer.
            output_size: Number of output logits.
        """
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden1_size)
        self.fc2 = nn.Linear(hidden1_size, hidden2_size)
        self.fc3 = nn.Linear(hidden2_size, hidden3_size)
        self.fc4 = nn.Linear(hidden3_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return raw logits for a batch of inputs.

        The input is flattened to rows of ``input_size`` features; the width
        is read from the first layer so it always matches the constructor
        argument instead of a hard-coded 28 * 28.
        """
        x = x.reshape(-1, self.fc1.in_features)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        # No activation on the last layer: CrossEntropyLoss expects logits.
        return self.fc4(x)


# Function to load the data
@st.cache_resource
def load_data():
    """Build the MNIST train and test DataLoaders (downloading if needed).

    Cached with ``st.cache_resource`` so the loaders are created once and
    reused across reruns rather than being copied on every script execution.

    Returns:
        A ``(trainloader, testloader)`` tuple, both with batch size 64; only
        the training loader shuffles.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    trainset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.MNIST(
        root='./data', train=False, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
    return trainloader, testloader


# Function to train the network
def train_network(net, trainloader, criterion, optimizer, epochs):
    """Train ``net`` in place and report the mean loss of each epoch.

    Args:
        net: Model to optimize.
        trainloader: Iterable of ``(inputs, labels)`` batches supporting ``len``.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer bound to ``net``'s parameters.
        epochs: Number of passes over ``trainloader``.

    Returns:
        List with one mean batch loss per epoch.
    """
    # Avoid dividing by zero when the loader yields no batches.
    num_batches = max(len(trainloader), 1)
    loss_values = []
    net.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        epoch_loss = running_loss / num_batches
        loss_values.append(epoch_loss)
        st.write(f'Epoch {epoch + 1}: loss {epoch_loss:.3f}')
    st.write('Finished Training')
    return loss_values


# Function to test the network
def test_network(net, testloader):
    """Compute and display the classification accuracy of ``net``.

    Args:
        net: Model to evaluate; it is switched to eval mode.
        testloader: Iterable of ``(images, labels)`` batches.

    Returns:
        Accuracy as a percentage (0.0 when the loader yields no samples).
    """
    correct = 0
    total = 0
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            _, predicted = torch.max(net(images), 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total if total else 0.0
    # Report the number of images actually evaluated rather than a fixed count.
    st.write(f'Accuracy of the network on the {total} test images: {accuracy:.2f}%')
    return accuracy


# Visualize some test results
def imshow(img):
    """Render a normalized image tensor (C, H, W) in the Streamlit page."""
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    fig, ax = plt.subplots()
    # Matplotlib expects channels last, so move C to the final axis.
    ax.imshow(np.transpose(npimg, (1, 2, 0)))
    # plt.show() displays nothing inside Streamlit; hand the figure over instead.
    st.pyplot(fig)
    plt.close(fig)


# Load the data
trainloader, testloader = load_data()

# Streamlit sidebar for input parameters
st.sidebar.header('Model Hyperparameters')
hidden1_size = st.sidebar.slider('Hidden Layer 1 Size', 128, 1024, 512)
hidden2_size = st.sidebar.slider('Hidden Layer 2 Size', 128, 1024, 256)
hidden3_size = st.sidebar.slider('Hidden Layer 3 Size', 128, 1024, 128)
# Show three decimals so the 0.001 step is visible in the slider label.
learning_rate = st.sidebar.slider(
    'Learning Rate', 0.001, 0.1, 0.01, step=0.001, format='%.3f')
momentum = st.sidebar.slider('Momentum', 0.0, 1.0, 0.9, step=0.1)
epochs = st.sidebar.slider('Epochs', 1, 20, 5)

# Create the network
net = FeedforwardNeuralNetwork(
    INPUT_SIZE, hidden1_size, hidden2_size, hidden3_size, NUM_CLASSES)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)

# Add vertical space
st.write('\n' * 10)

# Train the network
if st.sidebar.button('Train Network'):
    loss_values = train_network(net, trainloader, criterion, optimizer, epochs)
    # The script reruns from the top on every interaction, which rebuilds
    # ``net``; keep the trained model in session state so later button
    # clicks evaluate it instead of a freshly initialized network.
    st.session_state[TRAINED_NET_KEY] = net

    # Plot the loss values
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(range(1, epochs + 1), loss_values, marker='o')
    ax.set_title('Training Loss Over Epochs')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.grid(True)
    st.pyplot(fig)
    plt.close(fig)

# Evaluate the trained model when one exists, otherwise the untrained network.
if TRAINED_NET_KEY in st.session_state:
    eval_net = st.session_state[TRAINED_NET_KEY]
else:
    eval_net = net

# Test the network
if st.sidebar.button('Test Network'):
    accuracy = test_network(eval_net, testloader)
    st.write(f'Test Accuracy: {accuracy:.2f}%')

if st.sidebar.button('Show Test Results'):
    images, labels = next(iter(testloader))
    # Show exactly the images whose labels are listed below.
    preview = min(PREVIEW_COUNT, labels.size(0))
    images, labels = images[:preview], labels[:preview]
    imshow(torchvision.utils.make_grid(images))
    st.write('GroundTruth: ', ' '.join(str(label.item()) for label in labels))
    eval_net.eval()
    with torch.no_grad():
        outputs = eval_net(images)
    _, predicted = torch.max(outputs, 1)
    st.write('Predicted: ', ' '.join(str(pred.item()) for pred in predicted))