harry
commited on
Commit
·
55af1cd
1
Parent(s):
5144b79
feat: add seed setting for deterministic training
Browse files
mnist_classifier/train.py
CHANGED
@@ -7,8 +7,22 @@ from mnist_classifier.dataset import MNISTDataModule
|
|
7 |
from mnist_classifier.model import MNISTModel
|
8 |
from datetime import datetime
|
9 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
def train():
|
|
|
|
|
|
|
12 |
# Set device
|
13 |
device = torch.device('cuda')
|
14 |
print(f"Using device: {device}")
|
|
|
7 |
from mnist_classifier.model import MNISTModel
|
8 |
from datetime import datetime
|
9 |
import os
|
10 |
+
import random
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
def set_seed(seed):
|
14 |
+
torch.manual_seed(seed)
|
15 |
+
torch.cuda.manual_seed(seed)
|
16 |
+
torch.cuda.manual_seed_all(seed)
|
17 |
+
np.random.seed(seed)
|
18 |
+
random.seed(seed)
|
19 |
+
torch.backends.cudnn.deterministic = True
|
20 |
+
torch.backends.cudnn.benchmark = False
|
21 |
|
22 |
def train():
|
23 |
+
# Set seed for reproducibility
|
24 |
+
set_seed(42)
|
25 |
+
|
26 |
# Set device
|
27 |
device = torch.device('cuda')
|
28 |
print(f"Using device: {device}")
|
models/mnist_model_lr0.001_bs64_ep10.pth
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 4803144
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d1b474acf8a447dea4e3aaaf0371346ee7a7055d1c716fb371c059b9a1799bab
|
3 |
size 4803144
|