eaglelandsonce commited on
Commit
63668c4
·
verified ·
1 Parent(s): 7125d94

Update pages/19_ResNet.py

Browse files
Files changed (1) hide show
  1. pages/19_ResNet.py +104 -16
pages/19_ResNet.py CHANGED
@@ -35,25 +35,18 @@ transform = transforms.Compose([
35
  transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
36
  ])
37
 
38
- train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
39
- val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
40
-
41
- # Using only 1000 samples for simplicity
42
- subset_indices = list(range(1000))
43
- train_size = int(0.8 * len(subset_indices))
44
- val_size = len(subset_indices) - train_size
45
-
46
- train_indices = subset_indices[:train_size]
47
- val_indices = subset_indices[train_size:]
48
-
49
- train_dataset = Subset(train_dataset, train_indices)
50
- val_dataset = Subset(val_dataset, val_indices)
51
 
52
  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
53
  val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
54
 
55
  dataloaders = {'train': train_loader, 'val': val_loader}
56
- class_names = datasets.CIFAR10(root='./data', download=False).classes
57
 
58
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
59
 
@@ -78,8 +71,103 @@ imshow(out, title=[class_names[x] for x in classes])
78
  # Model Preparation Section
79
  st.markdown("""
80
  ### Model Preparation
81
- We will use a pre-trained ResNet-18 model and fine-tune the final fully connected layer to match the number of classes in our dataset.
82
  """)
83
 
84
  # Load Pre-trained ResNet Model
85
- model_ft = models.resnet
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
36
  ])
37
 
38
+ full_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
39
+ subset_indices = list(range(1000)) # Use only 1000 samples for simplicity
40
+ subset_dataset = Subset(full_dataset, subset_indices)
41
+ train_size = int(0.8 * len(subset_dataset))
42
+ val_size = len(subset_dataset) - train_size
43
+ train_dataset, val_dataset = torch.utils.data.random_split(subset_dataset, [train_size, val_size])
 
 
 
 
 
 
 
44
 
45
  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
46
  val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
47
 
48
  dataloaders = {'train': train_loader, 'val': val_loader}
49
+ class_names = full_dataset.classes
50
 
51
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
52
 
 
71
  # Model Preparation Section
72
  st.markdown("""
73
  ### Model Preparation
74
+ We will use a pre-trained ResNet-18 model and fine-tune the final fully connected layer to match the number of classes in our custom dataset.
75
  """)
76
 
77
  # Load Pre-trained ResNet Model
78
+ model_ft = models.resnet18(pretrained=True)
79
+ num_ftrs = model_ft.fc.in_features
80
+ model_ft.fc = nn.Linear(num_ftrs, len(class_names))
81
+
82
+ model_ft = model_ft.to(device)
83
+
84
+ # Define Loss Function and Optimizer
85
+ criterion = nn.CrossEntropyLoss()
86
+ optimizer_ft = optim.SGD(model_ft.parameters(), lr=learning_rate, momentum=0.9)
87
+
88
+ # Training Section
89
+ st.markdown("""
90
+ ### Training
91
+ We will train the model using stochastic gradient descent (SGD) with a learning rate scheduler. The training and validation loss and accuracy will be plotted to monitor the training process.
92
+ """)
93
+
94
+ # Train and Evaluate the Model
95
+ def train_model(model, criterion, optimizer, num_epochs=5):
96
+ best_model_wts = copy.deepcopy(model.state_dict())
97
+ best_acc = 0.0
98
+ train_loss_history = []
99
+ val_loss_history = []
100
+ train_acc_history = []
101
+ val_acc_history = []
102
+
103
+ for epoch in range(num_epochs):
104
+ st.write(f'Epoch {epoch+1}/{num_epochs}')
105
+ st.write('-' * 10)
106
+
107
+ for phase in ['train', 'val']:
108
+ if phase == 'train':
109
+ model.train()
110
+ else:
111
+ model.eval()
112
+
113
+ running_loss = 0.0
114
+ running_corrects = 0
115
+
116
+ for inputs, labels in dataloaders[phase]:
117
+ inputs = inputs.to(device)
118
+ labels = labels.to(device)
119
+
120
+ optimizer.zero_grad()
121
+
122
+ with torch.set_grad_enabled(phase == 'train'):
123
+ outputs = model(inputs)
124
+ _, preds = torch.max(outputs, 1)
125
+ loss = criterion(outputs, labels)
126
+
127
+ if phase == 'train':
128
+ loss.backward()
129
+ optimizer.step()
130
+
131
+ running_loss += loss.item() * inputs.size(0)
132
+ running_corrects += torch.sum(preds == labels.data)
133
+
134
+ epoch_loss = running_loss / len(dataloaders[phase].dataset)
135
+ epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
136
+
137
+ if phase == 'train':
138
+ train_loss_history.append(epoch_loss)
139
+ train_acc_history.append(epoch_acc)
140
+ else:
141
+ val_loss_history.append(epoch_loss)
142
+ val_acc_history.append(epoch_acc)
143
+
144
+ st.write(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
145
+
146
+ if phase == 'val' and epoch_acc > best_acc:
147
+ best_acc = epoch_acc
148
+ best_model_wts = copy.deepcopy(model.state_dict())
149
+
150
+ model.load_state_dict(best_model_wts)
151
+
152
+ # Plot training history
153
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
154
+ ax1.plot(train_loss_history, label='Training Loss')
155
+ ax1.plot(val_loss_history, label='Validation Loss')
156
+ ax1.legend(loc='upper right')
157
+ ax1.set_title('Training and Validation Loss')
158
+
159
+ ax2.plot(train_acc_history, label='Training Accuracy')
160
+ ax2.plot(val_acc_history, label='Validation Accuracy')
161
+ ax2.legend(loc='lower right')
162
+ ax2.set_title('Training and Validation Accuracy')
163
+
164
+ st.pyplot(fig)
165
+
166
+ return model
167
+
168
+ if st.button('Train Model'):
169
+ model_ft = train_model(model_ft, criterion, optimizer_ft, num_epochs)
170
+ # Save the Model
171
+ torch.save(model_ft.state_dict(), 'fine_tuned_resnet.pth')
172
+ st.write("Model saved as 'fine_tuned_resnet.pth'")
173
+