Sreekanth Tangirala committed on
Commit
ae63f95
·
1 Parent(s): c773c40
Files changed (1) hide show
  1. train.py +44 -22
train.py CHANGED
@@ -18,39 +18,59 @@ def get_transforms():
18
  std=[0.229, 0.224, 0.225])
19
  ])
20
 
21
- def get_data(subset_size=None):
22
  """
23
  Load and prepare the dataset
24
  Args:
25
  subset_size (int): If provided, return only a subset of data
 
26
  """
27
  transform = get_transforms()
28
- trainset = torchvision.datasets.CIFAR10(
29
  root='./data',
30
- train=True,
31
  download=True,
32
  transform=transform
33
  )
34
 
35
  if subset_size:
36
- indices = torch.randperm(len(trainset))[:subset_size]
37
- trainset = Subset(trainset, indices)
38
 
39
- trainloader = DataLoader(
40
- trainset,
41
  batch_size=32,
42
- shuffle=True,
43
  num_workers=2
44
  )
45
 
46
- return trainloader
47
 
48
- def train_model(model, trainloader, epochs=100, device='cuda'):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  """
50
  Train the model
51
  Args:
52
  model: The ResNet50 model
53
  trainloader: DataLoader for training data
 
54
  epochs (int): Number of epochs to train
55
  device (str): Device to train on ('cuda' or 'cpu')
56
  """
@@ -100,18 +120,19 @@ def train_model(model, trainloader, epochs=100, device='cuda'):
100
  epoch_acc = 100. * correct / total
101
  avg_loss = running_loss/len(trainloader)
102
 
103
- # Update epoch status with more detailed format
104
- epoch_pbar.write(f'Epoch {epoch+1}: Loss: {avg_loss:.3f} | Accuracy: {epoch_acc:.2f}%')
 
105
 
106
- scheduler.step(epoch_acc)
107
 
108
- if epoch_acc > best_acc:
109
- best_acc = epoch_acc
110
  save_model(model, 'best_model.pth')
111
- epoch_pbar.write(f'New best accuracy: {epoch_acc:.2f}%')
112
-
113
- if epoch_acc > 70:
114
- epoch_pbar.write(f"\nReached target accuracy of 70%!")
115
  break
116
 
117
  if __name__ == "__main__":
@@ -119,11 +140,12 @@ if __name__ == "__main__":
119
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
120
  print(f"Using device: {device}")
121
 
122
- # Get data
123
- trainloader = get_data(subset_size=5000) # Using subset for initial testing
 
124
 
125
  # Initialize model
126
  model = get_model(num_classes=10)
127
 
128
  # Train model
129
- train_model(model, trainloader, epochs=20, device=device)
 
18
  std=[0.229, 0.224, 0.225])
19
  ])
20
 
21
def get_data(subset_size=None, train=True):
    """
    Load and prepare the CIFAR-10 dataset

    Args:
        subset_size (int): If provided (and non-zero), return only a random
            subset of this many samples. A value larger than the split
            keeps the whole split.
        train (bool): If True, return training data, else test data

    Returns:
        DataLoader over the selected split with batch size 32; batches are
        shuffled only when ``train`` is truthy.

    Raises:
        ValueError: If ``subset_size`` is negative.
    """
    # A negative size would slice the permutation from the end and silently
    # return len(dataset) - |subset_size| samples, so reject it up front
    # (before any download is attempted).
    if subset_size is not None and subset_size < 0:
        raise ValueError(
            f"subset_size must be non-negative, got {subset_size}"
        )

    dataset = torchvision.datasets.CIFAR10(
        root='./data',
        train=train,
        download=True,
        transform=get_transforms(),
    )

    if subset_size:
        # Unseeded random selection: the chosen samples differ between calls.
        # Convert to plain ints so the dataset is indexed with Python
        # integers rather than 0-d tensors.
        indices = torch.randperm(len(dataset))[:subset_size].tolist()
        dataset = Subset(dataset, indices)

    return DataLoader(
        dataset,
        batch_size=32,
        shuffle=bool(train),  # keep evaluation order fixed for test data
        num_workers=2,
    )
48
 
49
def evaluate_model(model, testloader, device):
    """
    Evaluate the model on test data

    Args:
        model: Module to evaluate; called as ``model(inputs)`` and expected
            to return per-class scores along dimension 1
        testloader: Iterable yielding ``(inputs, labels)`` batches
        device: Device each batch is moved to before the forward pass

    Returns:
        float: Accuracy in percent (0-100), or 0.0 when ``testloader``
        yields no samples.
    """
    # Remember the caller's mode so that evaluating in the middle of a
    # training loop does not leave the model stuck in eval mode.
    was_training = model.training
    model.eval()
    correct = 0
    total = 0

    try:
        with torch.no_grad():
            for inputs, labels in testloader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
    finally:
        # Restore the previous mode even if a batch raised.
        model.train(was_training)

    # An empty loader would otherwise divide by zero.
    if total == 0:
        return 0.0
    return 100. * correct / total
66
+
67
+ def train_model(model, trainloader, testloader, epochs=100, device='cuda'):
68
  """
69
  Train the model
70
  Args:
71
  model: The ResNet50 model
72
  trainloader: DataLoader for training data
73
+ testloader: DataLoader for test data
74
  epochs (int): Number of epochs to train
75
  device (str): Device to train on ('cuda' or 'cpu')
76
  """
 
120
  epoch_acc = 100. * correct / total
121
  avg_loss = running_loss/len(trainloader)
122
 
123
+ # Evaluate on test data
124
+ test_acc = evaluate_model(model, testloader, device)
125
+ epoch_pbar.write(f'Epoch {epoch+1}: Train Loss: {avg_loss:.3f} | Train Acc: {epoch_acc:.2f}% | Test Acc: {test_acc:.2f}%')
126
 
127
+ scheduler.step(test_acc) # Using test accuracy for scheduler
128
 
129
+ if test_acc > best_acc:
130
+ best_acc = test_acc
131
  save_model(model, 'best_model.pth')
132
+ epoch_pbar.write(f'New best test accuracy: {test_acc:.2f}%')
133
+
134
+ if test_acc > 70:
135
+ epoch_pbar.write(f"\nReached target accuracy of 70% on test data!")
136
  break
137
 
138
  if __name__ == "__main__":
 
140
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
141
  print(f"Using device: {device}")
142
 
143
+ # Get train and test data
144
+ trainloader = get_data(subset_size=5000, train=True)
145
+ testloader = get_data(subset_size=1000, train=False)
146
 
147
  # Initialize model
148
  model = get_model(num_classes=10)
149
 
150
  # Train model
151
+ train_model(model, trainloader, testloader, epochs=20, device=device)