harry commited on
Commit
f774571
·
1 Parent(s): 55af1cd

feat: enhance training loop with tqdm progress bar and configurable parameters

Browse files
mnist_classifier/train.py CHANGED
@@ -9,6 +9,7 @@ from datetime import datetime
9
  import os
10
  import random
11
  import numpy as np
 
12
 
13
  def set_seed(seed):
14
  torch.manual_seed(seed)
@@ -20,6 +21,11 @@ def set_seed(seed):
20
  torch.backends.cudnn.benchmark = False
21
 
22
  def train():
 
 
 
 
 
23
  # Set seed for reproducibility
24
  set_seed(42)
25
 
@@ -32,18 +38,14 @@ def train():
32
  writer = SummaryWriter(log_dir)
33
 
34
  # Setup data
35
- data_module = MNISTDataModule(batch_size=64, val_batch_size=1000)
36
  train_loader, test_loader = data_module.get_dataloaders()
37
 
38
  # Initialize model, optimizer, and loss function
39
  model = MNISTModel().to(device)
40
- optimizer = optim.Adam(model.parameters())
41
  criterion = nn.CrossEntropyLoss()
42
 
43
- # Training loop
44
- learning_rate = 0.001
45
- batch_size = 64
46
- epochs = 10
47
 
48
  num_epochs = epochs
49
  for epoch in range(num_epochs):
@@ -52,28 +54,44 @@ def train():
52
  correct = 0
53
  total = 0
54
 
55
- for batch_idx, batch in enumerate(train_loader):
56
- images, labels = batch[0].to(device), batch[1].to(device)
57
-
58
- optimizer.zero_grad()
59
- outputs = model(images)
60
- loss = criterion(outputs, labels)
61
- loss.backward()
62
- optimizer.step()
63
-
64
- running_loss += loss.item()
65
- _, predicted = outputs.max(1)
66
- total += labels.size(0)
67
- correct += predicted.eq(labels).sum().item()
68
-
69
- if batch_idx % 100 == 99:
70
- writer.add_scalar('training loss',
71
- running_loss / 100,
72
- epoch * len(train_loader) + batch_idx)
73
- writer.add_scalar('training accuracy',
74
- 100. * correct / total,
75
- epoch * len(train_loader) + batch_idx)
76
- running_loss = 0.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  # Validation phase
79
  model.eval()
 
9
  import os
10
  import random
11
  import numpy as np
12
+ from tqdm import tqdm
13
 
14
  def set_seed(seed):
15
  torch.manual_seed(seed)
 
21
  torch.backends.cudnn.benchmark = False
22
 
23
  def train():
24
+ # Training loop
25
+ learning_rate = 0.001
26
+ batch_size = 128
27
+ epochs = 10
28
+
29
  # Set seed for reproducibility
30
  set_seed(42)
31
 
 
38
  writer = SummaryWriter(log_dir)
39
 
40
  # Setup data
41
+ data_module = MNISTDataModule(batch_size=batch_size, val_batch_size=1000)
42
  train_loader, test_loader = data_module.get_dataloaders()
43
 
44
  # Initialize model, optimizer, and loss function
45
  model = MNISTModel().to(device)
46
+ optimizer = optim.Adam(model.parameters(), lr=learning_rate)
47
  criterion = nn.CrossEntropyLoss()
48
 
 
 
 
 
49
 
50
  num_epochs = epochs
51
  for epoch in range(num_epochs):
 
54
  correct = 0
55
  total = 0
56
 
57
+ with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch") as pbar:
58
+ for batch_idx, batch in enumerate(train_loader):
59
+ images, labels = batch[0].to(device), batch[1].to(device)
60
+
61
+ if batch_idx == 0:
62
+ print(f"images shape: {images.shape}")
63
+ print(f"labels shape: {labels.shape}")
64
+
65
+ # print number of images in batch
66
+ print(f"Number of images in batch: {len(images)}")
67
+
68
+ optimizer.zero_grad()
69
+ outputs = model(images)
70
+ loss = criterion(outputs, labels)
71
+ loss.backward()
72
+ optimizer.step()
73
+
74
+ running_loss += loss.item()
75
+ _, predicted = outputs.max(1)
76
+ total += labels.size(0)
77
+ correct += predicted.eq(labels).sum().item()
78
+
79
+ # Update tqdm progress bar
80
+ pbar.set_postfix({
81
+ 'loss': running_loss / (batch_idx + 1),
82
+ 'accuracy': 100. * correct / total,
83
+ 'step': batch_idx + 1
84
+ })
85
+ pbar.update(1)
86
+
87
+ if batch_idx % 100 == 99:
88
+ writer.add_scalar('training loss',
89
+ running_loss / 100,
90
+ epoch * len(train_loader) + batch_idx)
91
+ writer.add_scalar('training accuracy',
92
+ 100. * correct / total,
93
+ epoch * len(train_loader) + batch_idx)
94
+ running_loss = 0.0
95
 
96
  # Validation phase
97
  model.eval()
models/mnist_model_lr0.001_bs128_ep10.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f9d6050aca93a46463f77e1a9dd4566da96e07905b9b872b519fa964f6984fc
3
+ size 4803156