eaglelandsonce commited on
Commit
f6cff54
·
verified ·
1 Parent(s): d98f068

Update pages/15_CNN.py

Browse files
Files changed (1) hide show
  1. pages/15_CNN.py +60 -57
pages/15_CNN.py CHANGED
@@ -79,64 +79,67 @@ model = CNN().to(device)
79
  criterion = nn.CrossEntropyLoss()
80
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
81
 
82
- # Lists to store losses
83
- train_losses = []
84
- test_losses = []
85
-
86
- # Train the model
87
- total_step = len(train_loader)
88
- for epoch in range(num_epochs):
89
- train_loss = 0
90
- for i, (images, labels) in enumerate(train_loader):
91
- images = images.to(device)
92
- labels = labels.to(device)
93
-
94
- # Forward pass
95
- outputs = model(images)
96
- loss = criterion(outputs, labels)
97
-
98
- # Backward and optimize
99
- optimizer.zero_grad()
100
- loss.backward()
101
- optimizer.step()
102
-
103
- train_loss += loss.item()
104
-
105
- train_loss /= total_step
106
- train_losses.append(train_loss)
107
- st.write(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')
108
-
109
- # Test the model
110
- model.eval()
111
- with torch.no_grad():
112
- test_loss = 0
113
- correct = 0
114
- total = 0
115
- for images, labels in test_loader:
116
  images = images.to(device)
117
  labels = labels.to(device)
 
 
118
  outputs = model(images)
119
  loss = criterion(outputs, labels)
120
- test_loss += loss.item()
121
- _, predicted = torch.max(outputs.data, 1)
122
- total += labels.size(0)
123
- correct += (predicted == labels).sum().item()
124
-
125
- test_loss /= len(test_loader)
126
- test_losses.append(test_loss)
127
- accuracy = 100 * correct / total
128
- st.write(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')
129
- model.train()
130
-
131
- # Plotting the loss
132
- fig, ax = plt.subplots()
133
- ax.plot(range(1, num_epochs + 1), train_losses, label='Train Loss')
134
- ax.plot(range(1, num_epochs + 1), test_losses, label='Test Loss')
135
- ax.set_xlabel('Epoch')
136
- ax.set_ylabel('Loss')
137
- ax.set_title('Training and Test Loss')
138
- ax.legend()
139
- st.pyplot(fig)
140
-
141
- # Save the model checkpoint
142
- torch.save(model.state_dict(), 'cnn_model.pth')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  criterion = nn.CrossEntropyLoss()
80
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
81
 
82
+ # Button to start training
83
+ if st.button("Start Training"):
84
+ # Lists to store losses
85
+ train_losses = []
86
+ test_losses = []
87
+
88
+ # Train the model
89
+ total_step = len(train_loader)
90
+ for epoch in range(num_epochs):
91
+ train_loss = 0
92
+ for i, (images, labels) in enumerate(train_loader):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  images = images.to(device)
94
  labels = labels.to(device)
95
+
96
+ # Forward pass
97
  outputs = model(images)
98
  loss = criterion(outputs, labels)
99
+
100
+ # Backward and optimize
101
+ optimizer.zero_grad()
102
+ loss.backward()
103
+ optimizer.step()
104
+
105
+ train_loss += loss.item()
106
+
107
+ train_loss /= total_step
108
+ train_losses.append(train_loss)
109
+ st.write(f'Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}')
110
+
111
+ # Test the model
112
+ model.eval()
113
+ with torch.no_grad():
114
+ test_loss = 0
115
+ correct = 0
116
+ total = 0
117
+ for images, labels in test_loader:
118
+ images = images.to(device)
119
+ labels = labels.to(device)
120
+ outputs = model(images)
121
+ loss = criterion(outputs, labels)
122
+ test_loss += loss.item()
123
+ _, predicted = torch.max(outputs.data, 1)
124
+ total += labels.size(0)
125
+ correct += (predicted == labels).sum().item()
126
+
127
+ test_loss /= len(test_loader)
128
+ test_losses.append(test_loss)
129
+ accuracy = 100 * correct / total
130
+ st.write(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')
131
+ model.train()
132
+
133
+ # Plotting the loss
134
+ fig, ax = plt.subplots()
135
+ ax.plot(range(1, num_epochs + 1), train_losses, label='Train Loss')
136
+ ax.plot(range(1, num_epochs + 1), test_losses, label='Test Loss')
137
+ ax.set_xlabel('Epoch')
138
+ ax.set_ylabel('Loss')
139
+ ax.set_title('Training and Test Loss')
140
+ ax.legend()
141
+ st.pyplot(fig)
142
+
143
+ # Save the model checkpoint
144
+ torch.save(model.state_dict(), 'cnn_model.pth')
145
+ st.write("Model training completed and saved as 'cnn_model.pth'")