Commit
·
22d6088
1
Parent(s):
55ca5c7
speed-reproducibility fix #17
Browse files- train.py +1 -1
- utils/torch_utils.py +5 -2
train.py
CHANGED
@@ -63,7 +63,7 @@ def train(hyp):
|
|
63 |
weights = opt.weights # initial training weights
|
64 |
|
65 |
# Configure
|
66 |
-
init_seeds()
|
67 |
with open(opt.data) as f:
|
68 |
data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
|
69 |
train_path = data_dict['train']
|
|
|
63 |
weights = opt.weights # initial training weights
|
64 |
|
65 |
# Configure
|
66 |
+
init_seeds(1)
|
67 |
with open(opt.data) as f:
|
68 |
data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
|
69 |
train_path = data_dict['train']
|
utils/torch_utils.py
CHANGED
@@ -12,8 +12,11 @@ import torch.nn.functional as F
|
|
12 |
def init_seeds(seed=0):
|
13 |
torch.manual_seed(seed)
|
14 |
|
15 |
-
#
|
16 |
-
if seed == 0:
|
|
|
|
|
|
|
17 |
cudnn.deterministic = False
|
18 |
cudnn.benchmark = True
|
19 |
|
|
|
12 |
def init_seeds(seed=0):
|
13 |
torch.manual_seed(seed)
|
14 |
|
15 |
+
# Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
|
16 |
+
if seed == 0: # slower, more reproducible
|
17 |
+
cudnn.deterministic = True
|
18 |
+
cudnn.benchmark = False
|
19 |
+
else: # faster, less reproducible
|
20 |
cudnn.deterministic = False
|
21 |
cudnn.benchmark = True
|
22 |
|