glenn-jocher commited on
Commit
22d6088
·
1 Parent(s): 55ca5c7

speed-reproducibility fix #17

Browse files
Files changed (2) hide show
  1. train.py +1 -1
  2. utils/torch_utils.py +5 -2
train.py CHANGED
@@ -63,7 +63,7 @@ def train(hyp):
63
  weights = opt.weights # initial training weights
64
 
65
  # Configure
66
- init_seeds()
67
  with open(opt.data) as f:
68
  data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
69
  train_path = data_dict['train']
 
63
  weights = opt.weights # initial training weights
64
 
65
  # Configure
66
+ init_seeds(1)
67
  with open(opt.data) as f:
68
  data_dict = yaml.load(f, Loader=yaml.FullLoader) # model dict
69
  train_path = data_dict['train']
utils/torch_utils.py CHANGED
@@ -12,8 +12,11 @@ import torch.nn.functional as F
12
  def init_seeds(seed=0):
13
  torch.manual_seed(seed)
14
 
15
- # Reduce randomness (may be slower on Tesla GPUs) # https://pytorch.org/docs/stable/notes/randomness.html
16
- if seed == 0:
 
 
 
17
  cudnn.deterministic = False
18
  cudnn.benchmark = True
19
 
 
12
  def init_seeds(seed=0):
13
  torch.manual_seed(seed)
14
 
15
+ # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
16
+ if seed == 0: # slower, more reproducible
17
+ cudnn.deterministic = True
18
+ cudnn.benchmark = False
19
+ else: # faster, less reproducible
20
  cudnn.deterministic = False
21
  cudnn.benchmark = True
22