harry commited on
Commit
55af1cd
·
1 Parent(s): 5144b79

feat: add seed setting for deterministic training

Browse files
mnist_classifier/train.py CHANGED
@@ -7,8 +7,22 @@ from mnist_classifier.dataset import MNISTDataModule
7
  from mnist_classifier.model import MNISTModel
8
  from datetime import datetime
9
  import os
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def train():
 
 
 
12
  # Set device
13
  device = torch.device('cuda')
14
  print(f"Using device: {device}")
 
7
  from mnist_classifier.model import MNISTModel
8
  from datetime import datetime
9
  import os
10
+ import random
11
+ import numpy as np
12
+
13
+ def set_seed(seed):
14
+ torch.manual_seed(seed)
15
+ torch.cuda.manual_seed(seed)
16
+ torch.cuda.manual_seed_all(seed)
17
+ np.random.seed(seed)
18
+ random.seed(seed)
19
+ torch.backends.cudnn.deterministic = True
20
+ torch.backends.cudnn.benchmark = False
21
 
22
  def train():
23
+ # Set seed for reproducibility
24
+ set_seed(42)
25
+
26
  # Set device
27
  device = torch.device('cuda')
28
  print(f"Using device: {device}")
models/mnist_model_lr0.001_bs64_ep10.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1f00fa1ee4fd08e6a5c41d3952b64e27b8bb122182f432332e18c9ee2af67609
3
  size 4803144
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d1b474acf8a447dea4e3aaaf0371346ee7a7055d1c716fb371c059b9a1799bab
3
  size 4803144