eaglelandsonce commited on
Commit
7125d94
·
verified ·
1 Parent(s): f6317cd

Update pages/19_ResNet.py

Browse files
Files changed (1) hide show
  1. pages/19_ResNet.py +35 -103
pages/19_ResNet.py CHANGED
@@ -6,10 +6,12 @@ import streamlit as st
6
  import torch
7
  import torch.nn as nn
8
  import torch.optim as optim
 
9
  from torchvision import datasets, models, transforms
10
  from torch.utils.data import DataLoader, Subset
11
  import numpy as np
12
  import time
 
13
  import matplotlib.pyplot as plt
14
 
15
  # Streamlit Interface
@@ -33,21 +35,46 @@ transform = transforms.Compose([
33
  transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
34
  ])
35
 
36
- full_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
37
- subset_indices = list(range(1000)) # Use only 1000 samples for simplicity
38
- subset_dataset = Subset(full_dataset, subset_indices)
39
- train_size = int(0.8 * len(subset_dataset))
40
- val_size = len(subset_dataset) - train_size
41
- train_dataset, val_dataset = torch.utils.data.random_split(subset_dataset, [train_size, val_size])
 
 
 
 
 
 
 
42
 
43
  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
44
  val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
45
 
46
  dataloaders = {'train': train_loader, 'val': val_loader}
47
- class_names = full_dataset.classes
48
 
49
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  # Model Preparation Section
52
  st.markdown("""
53
  ### Model Preparation
@@ -55,99 +82,4 @@ We will use a pre-trained ResNet-18 model and fine-tune the final fully connecte
55
  """)
56
 
57
  # Load Pre-trained ResNet Model
58
- model_ft = models.resnet18(pretrained=True)
59
- num_ftrs = model_ft.fc.in_features
60
- model_ft.fc = nn.Linear(num_ftrs, len(class_names))
61
-
62
- model_ft = model_ft.to(device)
63
-
64
- # Define Loss Function and Optimizer
65
- criterion = nn.CrossEntropyLoss()
66
- optimizer_ft = optim.SGD(model_ft.parameters(), lr=learning_rate, momentum=0.9)
67
-
68
- # Training Section
69
- st.markdown("""
70
- ### Training
71
- We will train the model using stochastic gradient descent (SGD) with a learning rate scheduler. The training and validation loss and accuracy will be plotted to monitor the training process.
72
- """)
73
-
74
- # Train and Evaluate the Model
75
- def train_model(model, criterion, optimizer, num_epochs=5):
76
- best_model_wts = copy.deepcopy(model.state_dict())
77
- best_acc = 0.0
78
- train_loss_history = []
79
- val_loss_history = []
80
- train_acc_history = []
81
- val_acc_history = []
82
-
83
- for epoch in range(num_epochs):
84
- st.write(f'Epoch {epoch+1}/{num_epochs}')
85
- st.write('-' * 10)
86
-
87
- for phase in ['train', 'val']:
88
- if phase == 'train':
89
- model.train()
90
- else:
91
- model.eval()
92
-
93
- running_loss = 0.0
94
- running_corrects = 0
95
-
96
- for inputs, labels in dataloaders[phase]:
97
- inputs = inputs.to(device)
98
- labels = labels.to(device)
99
-
100
- optimizer.zero_grad()
101
-
102
- with torch.set_grad_enabled(phase == 'train'):
103
- outputs = model(inputs)
104
- _, preds = torch.max(outputs, 1)
105
- loss = criterion(outputs, labels)
106
-
107
- if phase == 'train':
108
- loss.backward()
109
- optimizer.step()
110
-
111
- running_loss += loss.item() * inputs.size(0)
112
- running_corrects += torch.sum(preds == labels.data)
113
-
114
- epoch_loss = running_loss / len(dataloaders[phase].dataset)
115
- epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
116
-
117
- if phase == 'train':
118
- train_loss_history.append(epoch_loss)
119
- train_acc_history.append(epoch_acc)
120
- else:
121
- val_loss_history.append(epoch_loss)
122
- val_acc_history.append(epoch_acc)
123
-
124
- st.write(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
125
-
126
- if phase == 'val' and epoch_acc > best_acc:
127
- best_acc = epoch_acc
128
- best_model_wts = copy.deepcopy(model.state_dict())
129
-
130
- model.load_state_dict(best_model_wts)
131
-
132
- # Plot training history
133
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
134
- ax1.plot(train_loss_history, label='Training Loss')
135
- ax1.plot(val_loss_history, label='Validation Loss')
136
- ax1.legend(loc='upper right')
137
- ax1.set_title('Training and Validation Loss')
138
-
139
- ax2.plot(train_acc_history, label='Training Accuracy')
140
- ax2.plot(val_acc_history, label='Validation Accuracy')
141
- ax2.legend(loc='lower right')
142
- ax2.set_title('Training and Validation Accuracy')
143
-
144
- st.pyplot(fig)
145
-
146
- return model
147
-
148
- if st.button('Train Model'):
149
- model_ft = train_model(model_ft, criterion, optimizer_ft, num_epochs)
150
- # Save the Model
151
- torch.save(model_ft.state_dict(), 'fine_tuned_resnet.pth')
152
- st.write("Model saved as 'fine_tuned_resnet.pth'")
153
-
 
6
  import torch
7
  import torch.nn as nn
8
  import torch.optim as optim
9
+ import torchvision # Add this import
10
  from torchvision import datasets, models, transforms
11
  from torch.utils.data import DataLoader, Subset
12
  import numpy as np
13
  import time
14
+ import copy # Add this import
15
  import matplotlib.pyplot as plt
16
 
17
  # Streamlit Interface
 
35
  transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
36
  ])
37
 
38
+ train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
39
+ val_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
40
+
41
+ # Using only 1000 samples for simplicity
42
+ subset_indices = list(range(1000))
43
+ train_size = int(0.8 * len(subset_indices))
44
+ val_size = len(subset_indices) - train_size
45
+
46
+ train_indices = subset_indices[:train_size]
47
+ val_indices = subset_indices[train_size:]
48
+
49
+ train_dataset = Subset(train_dataset, train_indices)
50
+ val_dataset = Subset(val_dataset, val_indices)
51
 
52
  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
53
  val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
54
 
55
  dataloaders = {'train': train_loader, 'val': val_loader}
56
+ class_names = datasets.CIFAR10(root='./data', download=False).classes
57
 
58
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
59
 
60
+ # Visualize a few training images
61
+ st.markdown("#### Sample Training Images")
62
+ def imshow(inp, title=None):
63
+ inp = inp.numpy().transpose((1, 2, 0))
64
+ mean = np.array([0.485, 0.456, 0.406])
65
+ std = np.array([0.229, 0.224, 0.225])
66
+ inp = std * inp + mean
67
+ inp = np.clip(inp, 0, 1)
68
+ fig, ax = plt.subplots()
69
+ ax.imshow(inp)
70
+ if title is not None:
71
+ ax.set_title(title)
72
+ st.pyplot(fig)
73
+
74
+ inputs, classes = next(iter(dataloaders['train']))
75
+ out = torchvision.utils.make_grid(inputs)
76
+ imshow(out, title=[class_names[x] for x in classes])
77
+
78
  # Model Preparation Section
79
  st.markdown("""
80
  ### Model Preparation
 
82
  """)
83
 
84
  # Load Pre-trained ResNet Model
85
+ model_ft = models.resnet