eaglelandsonce commited on
Commit
9373ca7
·
verified ·
1 Parent(s): f6cff54

Update pages/15_CNN.py

Browse files
Files changed (1) hide show
  1. pages/15_CNN.py +61 -9
pages/15_CNN.py CHANGED
@@ -5,7 +5,9 @@ import torch.optim as optim
5
  import torchvision
6
  import torchvision.transforms as transforms
7
  import matplotlib.pyplot as plt
 
8
  from torch.utils.data import DataLoader
 
9
  import numpy as np
10
 
11
  # Device configuration
@@ -13,6 +15,9 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
 
14
  # Streamlit interface
15
  st.title("CNN for Image Classification using CIFAR-10")
 
 
 
16
 
17
  # Hyperparameters
18
  num_epochs = st.sidebar.slider("Number of epochs", 1, 20, 10)
@@ -33,6 +38,26 @@ test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False,
33
  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
34
  test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  # Define a Convolutional Neural Network
37
  class CNN(nn.Module):
38
  def __init__(self):
@@ -81,9 +106,13 @@ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
81
 
82
  # Button to start training
83
  if st.button("Start Training"):
84
- # Lists to store losses
85
  train_losses = []
86
  test_losses = []
 
 
 
 
87
 
88
  # Train the model
89
  total_step = len(train_loader)
@@ -114,6 +143,8 @@ if st.button("Start Training"):
114
  test_loss = 0
115
  correct = 0
116
  total = 0
 
 
117
  for images, labels in test_loader:
118
  images = images.to(device)
119
  labels = labels.to(device)
@@ -123,21 +154,42 @@ if st.button("Start Training"):
123
  _, predicted = torch.max(outputs.data, 1)
124
  total += labels.size(0)
125
  correct += (predicted == labels).sum().item()
 
 
126
 
127
  test_loss /= len(test_loader)
128
  test_losses.append(test_loss)
129
  accuracy = 100 * correct / total
 
130
  st.write(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')
131
  model.train()
132
 
133
- # Plotting the loss
134
- fig, ax = plt.subplots()
135
- ax.plot(range(1, num_epochs + 1), train_losses, label='Train Loss')
136
- ax.plot(range(1, num_epochs + 1), test_losses, label='Test Loss')
137
- ax.set_xlabel('Epoch')
138
- ax.set_ylabel('Loss')
139
- ax.set_title('Training and Test Loss')
140
- ax.legend()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  st.pyplot(fig)
142
 
143
  # Save the model checkpoint
 
5
  import torchvision
6
  import torchvision.transforms as transforms
7
  import matplotlib.pyplot as plt
8
+ import seaborn as sns
9
  from torch.utils.data import DataLoader
10
+ from sklearn.metrics import confusion_matrix
11
  import numpy as np
12
 
13
  # Device configuration
 
15
 
16
  # Streamlit interface
17
  st.title("CNN for Image Classification using CIFAR-10")
18
+ st.write("""
19
+ This application demonstrates how to build and train a Convolutional Neural Network (CNN) for image classification using the CIFAR-10 dataset. You can adjust hyperparameters, visualize sample images, and see the model's performance.
20
+ """)
21
 
22
  # Hyperparameters
23
  num_epochs = st.sidebar.slider("Number of epochs", 1, 20, 10)
 
38
  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
39
  test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
40
 
41
+ # Display some sample images
42
+ st.write("## Sample Images from CIFAR-10 Dataset")
43
+ sample_images, sample_labels = iter(train_loader).next()
44
+ fig, axes = plt.subplots(1, 6, figsize=(15, 5))
45
+ for i in range(6):
46
+ axes[i].imshow(np.transpose(sample_images[i].numpy(), (1, 2, 0)))
47
+ axes[i].set_title(f'Label: {sample_labels[i].item()}')
48
+ axes[i].axis('off')
49
+ st.pyplot(fig)
50
+
51
+ # Class distribution
52
+ st.write("## Class Distribution in CIFAR-10 Dataset")
53
+ class_names = train_dataset.classes
54
+ class_counts = np.bincount([sample_labels[i].item() for i in range(len(sample_labels))])
55
+ fig, ax = plt.subplots()
56
+ sns.barplot(x=class_names, y=class_counts, ax=ax)
57
+ ax.set_ylabel('Count')
58
+ ax.set_title('Class Distribution')
59
+ st.pyplot(fig)
60
+
61
  # Define a Convolutional Neural Network
62
  class CNN(nn.Module):
63
  def __init__(self):
 
106
 
107
  # Button to start training
108
  if st.button("Start Training"):
109
+ # Lists to store losses and accuracy
110
  train_losses = []
111
  test_losses = []
112
+ test_accuracies = []
113
+
114
+ # Progress bar
115
+ progress_bar = st.progress(0)
116
 
117
  # Train the model
118
  total_step = len(train_loader)
 
143
  test_loss = 0
144
  correct = 0
145
  total = 0
146
+ all_labels = []
147
+ all_predictions = []
148
  for images, labels in test_loader:
149
  images = images.to(device)
150
  labels = labels.to(device)
 
154
  _, predicted = torch.max(outputs.data, 1)
155
  total += labels.size(0)
156
  correct += (predicted == labels).sum().item()
157
+ all_labels.extend(labels.cpu().numpy())
158
+ all_predictions.extend(predicted.cpu().numpy())
159
 
160
  test_loss /= len(test_loader)
161
  test_losses.append(test_loss)
162
  accuracy = 100 * correct / total
163
+ test_accuracies.append(accuracy)
164
  st.write(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')
165
  model.train()
166
 
167
+ # Update progress bar
168
+ progress_bar.progress((epoch + 1) / num_epochs)
169
+
170
+ # Plotting the loss and accuracy
171
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
172
+ ax1.plot(range(1, num_epochs + 1), train_losses, label='Train Loss')
173
+ ax1.plot(range(1, num_epochs + 1), test_losses, label='Test Loss')
174
+ ax1.set_xlabel('Epoch')
175
+ ax1.set_ylabel('Loss')
176
+ ax1.set_title('Training and Test Loss')
177
+ ax1.legend()
178
+
179
+ ax2.plot(range(1, num_epochs + 1), test_accuracies, label='Test Accuracy')
180
+ ax2.set_xlabel('Epoch')
181
+ ax2.set_ylabel('Accuracy (%)')
182
+ ax2.set_title('Test Accuracy')
183
+ ax2.legend()
184
+ st.pyplot(fig)
185
+
186
+ # Confusion Matrix
187
+ cm = confusion_matrix(all_labels, all_predictions)
188
+ fig, ax = plt.subplots(figsize=(10, 10))
189
+ sns.heatmap(cm, annot=True, fmt="d", xticklabels=class_names, yticklabels=class_names, cmap='Blues')
190
+ ax.set_xlabel('Predicted')
191
+ ax.set_ylabel('True')
192
+ ax.set_title('Confusion Matrix')
193
  st.pyplot(fig)
194
 
195
  # Save the model checkpoint