glenn-jocher commited on
Commit
84bfa89
·
unverified ·
1 Parent(s): 302a1b0

Consolidate `init_seeds()` (#4849)

Browse files
Files changed (2) hide show
  1. utils/general.py +5 -3
  2. utils/torch_utils.py +0 -10
utils/general.py CHANGED
@@ -29,7 +29,6 @@ import yaml
29
 
30
  from utils.downloads import gsutil_getsize
31
  from utils.metrics import box_iou, fitness
32
- from utils.torch_utils import init_torch_seeds
33
 
34
  # Settings
35
  torch.set_printoptions(linewidth=320, precision=5, profile='long')
@@ -91,10 +90,13 @@ def set_logging(rank=-1, verbose=True):
91
 
92
 
93
  def init_seeds(seed=0):
94
- # Initialize random number generator (RNG) seeds
 
 
95
  random.seed(seed)
96
  np.random.seed(seed)
97
- init_torch_seeds(seed)
 
98
 
99
 
100
  def get_latest_run(search_dir='.'):
 
29
 
30
  from utils.downloads import gsutil_getsize
31
  from utils.metrics import box_iou, fitness
 
32
 
33
  # Settings
34
  torch.set_printoptions(linewidth=320, precision=5, profile='long')
 
90
 
91
 
92
  def init_seeds(seed=0):
93
+ # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
94
+ # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
95
+ import torch.backends.cudnn as cudnn
96
  random.seed(seed)
97
  np.random.seed(seed)
98
+ torch.manual_seed(seed)
99
+ cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
100
 
101
 
102
  def get_latest_run(search_dir='.'):
utils/torch_utils.py CHANGED
@@ -15,7 +15,6 @@ from copy import deepcopy
15
  from pathlib import Path
16
 
17
  import torch
18
- import torch.backends.cudnn as cudnn
19
  import torch.distributed as dist
20
  import torch.nn as nn
21
  import torch.nn.functional as F
@@ -41,15 +40,6 @@ def torch_distributed_zero_first(local_rank: int):
41
  dist.barrier(device_ids=[0])
42
 
43
 
44
- def init_torch_seeds(seed=0):
45
- # Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
46
- torch.manual_seed(seed)
47
- if seed == 0: # slower, more reproducible
48
- cudnn.benchmark, cudnn.deterministic = False, True
49
- else: # faster, less reproducible
50
- cudnn.benchmark, cudnn.deterministic = True, False
51
-
52
-
53
  def date_modified(path=__file__):
54
  # return human-readable file modification date, i.e. '2021-3-26'
55
  t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)
 
15
  from pathlib import Path
16
 
17
  import torch
 
18
  import torch.distributed as dist
19
  import torch.nn as nn
20
  import torch.nn.functional as F
 
40
  dist.barrier(device_ids=[0])
41
 
42
 
 
 
 
 
 
 
 
 
 
43
  def date_modified(path=__file__):
44
  # return human-readable file modification date, i.e. '2021-3-26'
45
  t = datetime.datetime.fromtimestamp(Path(path).stat().st_mtime)