eaglelandsonce commited on
Commit
bc8ae18
·
verified ·
1 Parent(s): c12dbb5

Delete pages/15_Simple_CNN.py

Browse files
Files changed (1) hide show
  1. pages/15_Simple_CNN.py +0 -111
pages/15_Simple_CNN.py DELETED
@@ -1,111 +0,0 @@
1
- import streamlit as st
2
- import torch
3
- import torch.nn as nn
4
- import torch.optim as optim
5
- import torchvision
6
- import torchvision.transforms as transforms
7
- import matplotlib.pyplot as plt
8
- import numpy as np
9
-
10
- # Define the CNN
11
- class SimpleCNN(nn.Module):
12
- def __init__(self):
13
- super(SimpleCNN, self).__init__()
14
- self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
15
- self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
16
- self.pool = nn.MaxPool2d(2, 2)
17
- self.fc1 = nn.Linear(32 * 8 * 8, 128)
18
- self.fc2 = nn.Linear(128, 10)
19
-
20
- def forward(self, x):
21
- x = self.pool(torch.relu(self.conv1(x)))
22
- x = self.pool(torch.relu(self.conv2(x)))
23
- x = x.view(-1, 32 * 8 * 8)
24
- x = torch.relu(self.fc1(x))
25
- x = self.fc2(x)
26
- return x
27
-
28
- # Function to train the model
29
- def train_model(num_epochs):
30
- transform = transforms.Compose([
31
- transforms.ToTensor(),
32
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
33
- ])
34
-
35
- trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
36
- trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
37
-
38
- testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
39
- testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
40
-
41
- CIFAR10_CLASSES = [
42
- 'plane', 'car', 'bird', 'cat', 'deer',
43
- 'dog', 'frog', 'horse', 'ship', 'truck'
44
- ]
45
-
46
- net = SimpleCNN()
47
- criterion = nn.CrossEntropyLoss()
48
- optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
49
-
50
- loss_values = []
51
- st.write("Training the model...")
52
-
53
- for epoch in range(num_epochs):
54
- running_loss = 0.0
55
- for i, data in enumerate(trainloader, 0):
56
- inputs, labels = data
57
- optimizer.zero_grad()
58
- outputs = net(inputs)
59
- loss = criterion(outputs, labels)
60
- loss.backward()
61
- optimizer.step()
62
- running_loss += loss.item()
63
- loss_values.append(running_loss / len(trainloader))
64
- st.write(f'Epoch {epoch + 1}, Loss: {running_loss / len(trainloader):.3f}')
65
- st.write('Finished Training')
66
-
67
- # Plot the loss values
68
- plt.figure(figsize=(10, 5))
69
- plt.plot(range(1, num_epochs + 1), loss_values, marker='o')
70
- plt.title('Training Loss over Epochs')
71
- plt.xlabel('Epoch')
72
- plt.ylabel('Loss')
73
- st.pyplot(plt)
74
-
75
- correct = 0
76
- total = 0
77
- with torch.no_grad():
78
- for data in testloader:
79
- images, labels = data
80
- outputs = net(images)
81
- _, predicted = torch.max(outputs, 1)
82
- total += labels.size(0)
83
- correct += (predicted == labels).sum().item()
84
-
85
- st.write(f'Accuracy of the network on the 10000 test images: {100 * correct / total}%')
86
-
87
- # Visualize some test images and their predictions
88
- def imshow(img):
89
- img = img / 2 + 0.5 # Unnormalize
90
- npimg = img.numpy()
91
- plt.imshow(np.transpose(npimg, (1, 2, 0)))
92
- plt.show()
93
-
94
- dataiter = iter(testloader)
95
- images, labels = next(dataiter)
96
-
97
- imshow(torchvision.utils.make_grid(images))
98
-
99
- outputs = net(images)
100
- _, predicted = torch.max(outputs, 1)
101
-
102
- st.write('Predicted: ', ' '.join(f'{CIFAR10_CLASSES[predicted[j]]:5s}' for j in range(8)))
103
- st.write('Actual: ', ' '.join(f'{CIFAR10_CLASSES[labels[j]]:5s}' for j in range(8)))
104
- st.pyplot()
105
-
106
- # Streamlit interface
107
- st.title('CIFAR-10 Classification with PyTorch')
108
- num_epochs = st.number_input('Enter number of epochs:', min_value=1, max_value=100, value=10)
109
- if st.button('Run'):
110
- train_model(num_epochs)
111
-