Spaces:
Sleeping
Sleeping
Sreekanth Tangirala
commited on
Commit
·
ae63f95
1
Parent(s):
c773c40
test loop
Browse files
train.py
CHANGED
@@ -18,39 +18,59 @@ def get_transforms():
|
|
18 |
std=[0.229, 0.224, 0.225])
|
19 |
])
|
20 |
|
21 |
-
def get_data(subset_size=None):
|
22 |
"""
|
23 |
Load and prepare the dataset
|
24 |
Args:
|
25 |
subset_size (int): If provided, return only a subset of data
|
|
|
26 |
"""
|
27 |
transform = get_transforms()
|
28 |
-
|
29 |
root='./data',
|
30 |
-
train=
|
31 |
download=True,
|
32 |
transform=transform
|
33 |
)
|
34 |
|
35 |
if subset_size:
|
36 |
-
indices = torch.randperm(len(
|
37 |
-
|
38 |
|
39 |
-
|
40 |
-
|
41 |
batch_size=32,
|
42 |
-
shuffle=True,
|
43 |
num_workers=2
|
44 |
)
|
45 |
|
46 |
-
return
|
47 |
|
48 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
"""
|
50 |
Train the model
|
51 |
Args:
|
52 |
model: The ResNet50 model
|
53 |
trainloader: DataLoader for training data
|
|
|
54 |
epochs (int): Number of epochs to train
|
55 |
device (str): Device to train on ('cuda' or 'cpu')
|
56 |
"""
|
@@ -100,18 +120,19 @@ def train_model(model, trainloader, epochs=100, device='cuda'):
|
|
100 |
epoch_acc = 100. * correct / total
|
101 |
avg_loss = running_loss/len(trainloader)
|
102 |
|
103 |
-
#
|
104 |
-
|
|
|
105 |
|
106 |
-
scheduler.step(
|
107 |
|
108 |
-
if
|
109 |
-
best_acc =
|
110 |
save_model(model, 'best_model.pth')
|
111 |
-
epoch_pbar.write(f'New best accuracy: {
|
112 |
-
|
113 |
-
if
|
114 |
-
epoch_pbar.write(f"\nReached target accuracy of 70
|
115 |
break
|
116 |
|
117 |
if __name__ == "__main__":
|
@@ -119,11 +140,12 @@ if __name__ == "__main__":
|
|
119 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
120 |
print(f"Using device: {device}")
|
121 |
|
122 |
-
# Get data
|
123 |
-
trainloader = get_data(subset_size=5000)
|
|
|
124 |
|
125 |
# Initialize model
|
126 |
model = get_model(num_classes=10)
|
127 |
|
128 |
# Train model
|
129 |
-
train_model(model, trainloader, epochs=20, device=device)
|
|
|
18 |
std=[0.229, 0.224, 0.225])
|
19 |
])
|
20 |
|
21 |
+
def get_data(subset_size=None, train=True):
|
22 |
"""
|
23 |
Load and prepare the dataset
|
24 |
Args:
|
25 |
subset_size (int): If provided, return only a subset of data
|
26 |
+
train (bool): If True, return training data, else test data
|
27 |
"""
|
28 |
transform = get_transforms()
|
29 |
+
dataset = torchvision.datasets.CIFAR10(
|
30 |
root='./data',
|
31 |
+
train=train,
|
32 |
download=True,
|
33 |
transform=transform
|
34 |
)
|
35 |
|
36 |
if subset_size:
|
37 |
+
indices = torch.randperm(len(dataset))[:subset_size]
|
38 |
+
dataset = Subset(dataset, indices)
|
39 |
|
40 |
+
dataloader = DataLoader(
|
41 |
+
dataset,
|
42 |
batch_size=32,
|
43 |
+
shuffle=True if train else False,
|
44 |
num_workers=2
|
45 |
)
|
46 |
|
47 |
+
return dataloader
|
48 |
|
49 |
+
def evaluate_model(model, testloader, device):
|
50 |
+
"""
|
51 |
+
Evaluate the model on test data
|
52 |
+
"""
|
53 |
+
model.eval()
|
54 |
+
correct = 0
|
55 |
+
total = 0
|
56 |
+
|
57 |
+
with torch.no_grad():
|
58 |
+
for inputs, labels in testloader:
|
59 |
+
inputs, labels = inputs.to(device), labels.to(device)
|
60 |
+
outputs = model(inputs)
|
61 |
+
_, predicted = outputs.max(1)
|
62 |
+
total += labels.size(0)
|
63 |
+
correct += predicted.eq(labels).sum().item()
|
64 |
+
|
65 |
+
return 100. * correct / total
|
66 |
+
|
67 |
+
def train_model(model, trainloader, testloader, epochs=100, device='cuda'):
|
68 |
"""
|
69 |
Train the model
|
70 |
Args:
|
71 |
model: The ResNet50 model
|
72 |
trainloader: DataLoader for training data
|
73 |
+
testloader: DataLoader for test data
|
74 |
epochs (int): Number of epochs to train
|
75 |
device (str): Device to train on ('cuda' or 'cpu')
|
76 |
"""
|
|
|
120 |
epoch_acc = 100. * correct / total
|
121 |
avg_loss = running_loss/len(trainloader)
|
122 |
|
123 |
+
# Evaluate on test data
|
124 |
+
test_acc = evaluate_model(model, testloader, device)
|
125 |
+
epoch_pbar.write(f'Epoch {epoch+1}: Train Loss: {avg_loss:.3f} | Train Acc: {epoch_acc:.2f}% | Test Acc: {test_acc:.2f}%')
|
126 |
|
127 |
+
scheduler.step(test_acc) # Using test accuracy for scheduler
|
128 |
|
129 |
+
if test_acc > best_acc:
|
130 |
+
best_acc = test_acc
|
131 |
save_model(model, 'best_model.pth')
|
132 |
+
epoch_pbar.write(f'New best test accuracy: {test_acc:.2f}%')
|
133 |
+
|
134 |
+
if test_acc > 70:
|
135 |
+
epoch_pbar.write(f"\nReached target accuracy of 70% on test data!")
|
136 |
break
|
137 |
|
138 |
if __name__ == "__main__":
|
|
|
140 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
141 |
print(f"Using device: {device}")
|
142 |
|
143 |
+
# Get train and test data
|
144 |
+
trainloader = get_data(subset_size=5000, train=True)
|
145 |
+
testloader = get_data(subset_size=1000, train=False)
|
146 |
|
147 |
# Initialize model
|
148 |
model = get_model(num_classes=10)
|
149 |
|
150 |
# Train model
|
151 |
+
train_model(model, trainloader, testloader, epochs=20, device=device)
|