henry000 commited on
Commit
67c056a
·
1 Parent(s): c134940

🌱 [Add] a random seed for training

Browse files
Files changed (1) hide show
  1. yolo/utils/logging_utils.py +14 -0
yolo/utils/logging_utils.py CHANGED
@@ -12,12 +12,14 @@ Example:
12
  """
13
 
14
  import os
 
15
  import sys
16
  from collections import deque
17
  from pathlib import Path
18
  from typing import Any, Dict, List
19
 
20
  import numpy as np
 
21
  import wandb
22
  import wandb.errors.term
23
  from loguru import logger
@@ -49,6 +51,18 @@ def custom_logger(quite: bool = False):
49
  )
50
 
51
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  class ProgressLogger(Progress):
53
  def __init__(self, cfg: Config, exp_name: str, *args, **kwargs):
54
  local_rank = int(os.getenv("LOCAL_RANK", "0"))
 
12
  """
13
 
14
  import os
15
+ import random
16
  import sys
17
  from collections import deque
18
  from pathlib import Path
19
  from typing import Any, Dict, List
20
 
21
  import numpy as np
22
+ import torch
23
  import wandb
24
  import wandb.errors.term
25
  from loguru import logger
 
51
  )
52
 
53
 
54
+ # TODO: should be moved to correct position
55
+ def set_seed(seed):
56
+ random.seed(seed)
57
+ np.random.seed(seed)
58
+ torch.manual_seed(seed)
59
+ if torch.cuda.is_available():
60
+ torch.cuda.manual_seed(seed)
61
+ torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
62
+ torch.backends.cudnn.deterministic = True
63
+ torch.backends.cudnn.benchmark = False
64
+
65
+
66
  class ProgressLogger(Progress):
67
  def __init__(self, cfg: Config, exp_name: str, *args, **kwargs):
68
  local_rank = int(os.getenv("LOCAL_RANK", "0"))