eaglelandsonce commited on
Commit
2e0ae2e
·
verified ·
1 Parent(s): 45282e7

Update pages/15_CNN.py

Browse files
Files changed (1) hide show
  1. pages/15_CNN.py +117 -96
pages/15_CNN.py CHANGED
@@ -2,109 +2,130 @@ import streamlit as st
2
  import torch
3
  import torch.nn as nn
4
  import torch.optim as optim
5
- import torchvision.transforms as transforms
6
  import torchvision
 
 
7
  from torch.utils.data import DataLoader
8
- from PIL import Image
9
  import numpy as np
10
 
11
- # Define the CNN model
12
- class Net(nn.Module):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def __init__(self):
14
- super(Net, self).__init__()
15
- self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
16
- self.pool = nn.MaxPool2d(2, 2)
17
- self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
18
- self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
19
- self.fc1 = nn.Linear(128 * 4 * 4, 256)
20
- self.fc2 = nn.Linear(256, 128)
21
- self.fc3 = nn.Linear(128, 10)
22
-
 
 
 
 
 
 
 
23
  def forward(self, x):
24
- x = self.pool(F.relu(self.conv1(x)))
25
- x = self.pool(F.relu(self.conv2(x)))
26
- x = self.pool(F.relu(self.conv3(x)))
27
- x = x.view(-1, 128 * 4 * 4)
28
- x = F.relu(self.fc1(x))
29
- x = F.relu(self.fc2(x))
30
- x = self.fc3(x)
31
- return x
32
-
33
- # Load pre-trained model (if available)
34
- def load_model():
35
- net = Net()
36
- try:
37
- net.load_state_dict(torch.load('cnn_model.pth'))
38
- net.eval()
39
- st.write("Model loaded successfully")
40
- except FileNotFoundError:
41
- st.write("No pre-trained model found. Please train the model first.")
42
- return net
43
-
44
- # Function to predict the class of an image
45
- def predict(image, model):
46
- transform = transforms.Compose([
47
- transforms.Resize((32, 32)),
48
- transforms.ToTensor(),
49
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
50
- ])
51
- image = transform(image).unsqueeze(0)
52
- outputs = model(image)
53
- _, predicted = torch.max(outputs.data, 1)
54
- return predicted.item()
55
-
56
- # Training function
57
- def train_model():
58
- transform = transforms.Compose([
59
- transforms.ToTensor(),
60
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
61
- ])
62
-
63
- trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
64
- trainloader = DataLoader(trainset, batch_size=100, shuffle=True, num_workers=2)
65
-
66
- net = Net()
67
- criterion = nn.CrossEntropyLoss()
68
- optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
69
-
70
- st.write("Training the model...")
71
- for epoch in range(10):
72
- running_loss = 0.0
73
- for i, data in enumerate(trainloader, 0):
74
- inputs, labels = data
75
- optimizer.zero_grad()
76
- outputs = net(inputs)
77
  loss = criterion(outputs, labels)
78
- loss.backward()
79
- optimizer.step()
 
 
80
 
81
- running_loss += loss.item()
82
- if i % 100 == 99:
83
- st.write(f'[Epoch {epoch + 1}, Batch {i + 1}] loss: {running_loss / 100:.3f}')
84
- running_loss = 0.0
 
85
 
86
- st.write('Finished Training')
87
- torch.save(net.state_dict(), 'cnn_model.pth')
88
- st.write('Model saved as cnn_model.pth')
89
- return net
 
 
 
 
 
90
 
91
- # Streamlit interface
92
- st.title("CNN Image Classification with CIFAR-10")
93
-
94
- mode = st.sidebar.selectbox("Mode", ["Train", "Predict"])
95
-
96
- if mode == "Train":
97
- if st.button("Train Model"):
98
- model = train_model()
99
-
100
- if mode == "Predict":
101
- model = load_model()
102
- uploaded_file = st.file_uploader("Choose an image...", type="jpg")
103
- if uploaded_file is not None:
104
- image = Image.open(uploaded_file)
105
- st.image(image, caption='Uploaded Image.', use_column_width=True)
106
- st.write("")
107
- st.write("Classifying...")
108
- class_idx = predict(image, model)
109
- classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
110
- st.write(f'Prediction: {classes[class_idx]}')
 
2
  import torch
3
  import torch.nn as nn
4
  import torch.optim as optim
 
5
  import torchvision
6
+ import torchvision.transforms as transforms
7
+ import matplotlib.pyplot as plt
8
  from torch.utils.data import DataLoader
 
9
  import numpy as np
10
 
11
+ # Device configuration
12
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
+
14
+ # Streamlit interface
15
+ st.title("CNN for Image Classification using CIFAR-10")
16
+
17
+ # Hyperparameters
18
+ num_epochs = st.sidebar.slider("Number of epochs", 1, 20, 10)
19
+ batch_size = st.sidebar.slider("Batch size", 10, 200, 100, step=10)
20
+ learning_rate = st.sidebar.slider("Learning rate", 0.0001, 0.01, 0.001, step=0.0001)
21
+
22
+ # CIFAR-10 dataset
23
+ transform = transforms.Compose(
24
+ [transforms.ToTensor(),
25
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
26
+
27
+ train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True,
28
+ download=True, transform=transform)
29
+
30
+ test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False,
31
+ download=True, transform=transform)
32
+
33
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
34
+ test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
35
+
36
+ # Define a Convolutional Neural Network
37
+ class CNN(nn.Module):
38
  def __init__(self):
39
+ super(CNN, self).__init__()
40
+ self.layer1 = nn.Sequential(
41
+ nn.Conv2d(3, 32, kernel_size=3, padding=1),
42
+ nn.BatchNorm2d(32),
43
+ nn.ReLU(),
44
+ nn.MaxPool2d(kernel_size=2, stride=2))
45
+ self.layer2 = nn.Sequential(
46
+ nn.Conv2d(32, 64, kernel_size=3),
47
+ nn.BatchNorm2d(64),
48
+ nn.ReLU(),
49
+ nn.MaxPool2d(2))
50
+ self.fc1 = nn.Linear(6*6*64, 600)
51
+ self.drop = nn.Dropout2d(0.25)
52
+ self.fc2 = nn.Linear(600, 100)
53
+ self.fc3 = nn.Linear(100, 10)
54
+
55
  def forward(self, x):
56
+ out = self.layer1(x)
57
+ out = self.layer2(out)
58
+ out = out.view(out.size(0), -1)
59
+ out = self.fc1(out)
60
+ out = self.drop(out)
61
+ out = self.fc2(out)
62
+ out = self.fc3(out)
63
+ return out
64
+
65
+ model = CNN().to(device)
66
+
67
+ # Loss and optimizer
68
+ criterion = nn.CrossEntropyLoss()
69
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
70
+
71
+ # Lists to store losses
72
+ train_losses = []
73
+ test_losses = []
74
+
75
+ # Train the model
76
+ total_step = len(train_loader)
77
+ for epoch in range(num_epochs):
78
+ train_loss = 0
79
+ for i, (images, labels) in enumerate(train_loader):
80
+ images = images.to(device)
81
+ labels = labels.to(device)
82
+
83
+ # Forward pass
84
+ outputs = model(images)
85
+ loss = criterion(outputs, labels)
86
+
87
+ # Backward and optimize
88
+ optimizer.zero_grad()
89
+ loss.backward()
90
+ optimizer.step()
91
+
92
+ train_loss += loss.item()
93
+
94
+ train_loss /= total_step
95
+ train_losses.append(train_loss)
96
+ st.write(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')
97
+
98
+ # Test the model
99
+ model.eval()
100
+ with torch.no_grad():
101
+ test_loss = 0
102
+ correct = 0
103
+ total = 0
104
+ for images, labels in test_loader:
105
+ images = images.to(device)
106
+ labels = labels.to(device)
107
+ outputs = model(images)
 
108
  loss = criterion(outputs, labels)
109
+ test_loss += loss.item()
110
+ _, predicted = torch.max(outputs.data, 1)
111
+ total += labels.size(0)
112
+ correct += (predicted == labels).sum().item()
113
 
114
+ test_loss /= len(test_loader)
115
+ test_losses.append(test_loss)
116
+ accuracy = 100 * correct / total
117
+ st.write(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')
118
+ model.train()
119
 
120
+ # Plotting the loss
121
+ fig, ax = plt.subplots()
122
+ ax.plot(range(1, num_epochs + 1), train_losses, label='Train Loss')
123
+ ax.plot(range(1, num_epochs + 1), test_losses, label='Test Loss')
124
+ ax.set_xlabel('Epoch')
125
+ ax.set_ylabel('Loss')
126
+ ax.set_title('Training and Test Loss')
127
+ ax.legend()
128
+ st.pyplot(fig)
129
 
130
+ # Save the model checkpoint
131
+ torch.save(model.state_dict(), 'cnn_model.pth')