henry000 commited on
Commit
1132d27
Β·
1 Parent(s): ba6baa6

πŸ’„ [Update] logging function, log only rank0 device

Browse files
Files changed (1) hide show
  1. yolo/utils/logging_utils.py +17 -2
yolo/utils/logging_utils.py CHANGED
@@ -68,8 +68,8 @@ def set_seed(seed):
68
  class ProgressLogger(Progress):
69
  def __init__(self, cfg: Config, exp_name: str, *args, **kwargs):
70
  set_seed(cfg.lucky_number)
71
- local_rank = int(os.getenv("LOCAL_RANK", "0"))
72
- self.quite_mode = local_rank or getattr(cfg, "quite", False)
73
  custom_logger(self.quite_mode)
74
  self.save_path = validate_log_directory(cfg, exp_name=cfg.name)
75
 
@@ -93,13 +93,23 @@ class ProgressLogger(Progress):
93
  project="YOLO", resume="allow", mode="online", dir=self.save_path, id=None, name=exp_name
94
  )
95
 
 
 
 
 
 
 
 
 
96
  def get_renderable(self):
97
  renderable = Group(*self.get_renderables(), self.ap_table)
98
  return renderable
99
 
 
100
  def start_train(self, num_epochs: int):
101
  self.task_epoch = self.add_task(f"[cyan]Start Training {num_epochs} epochs", total=num_epochs)
102
 
 
103
  def start_one_epoch(
104
  self, num_batches: int, task: str = "Train", optimizer: Optimizer = None, epoch_idx: int = None
105
  ):
@@ -115,6 +125,7 @@ class ProgressLogger(Progress):
115
  self.wandb.log({lr_name: lr_value}, step=epoch_idx)
116
  self.batch_task = self.add_task(f"[green] Phase: {task}", total=num_batches)
117
 
 
118
  def one_batch(self, batch_info: Dict[str, Tensor] = None):
119
  epoch_descript = "[cyan]" + self.task + "[white] |"
120
  batch_descript = "|"
@@ -127,6 +138,7 @@ class ProgressLogger(Progress):
127
  if hasattr(self, "task_epoch"):
128
  self.update(self.task_epoch, description=epoch_descript)
129
 
 
130
  def finish_one_epoch(self, batch_info: Dict[str, Any] = None, epoch_idx: int = -1):
131
  if self.task == "Train":
132
  prefix = "Loss/"
@@ -137,9 +149,11 @@ class ProgressLogger(Progress):
137
  self.wandb.log(batch_info, step=epoch_idx)
138
  self.remove_task(self.batch_task)
139
 
 
140
  def start_pycocotools(self):
141
  self.batch_task = self.add_task("[green]Run pycocotools", total=1)
142
 
 
143
  def finish_pycocotools(self, result, epoch_idx=-1):
144
  ap_table, ap_main = make_ap_table(result, self.ap_past_list, self.last_result, epoch_idx)
145
  self.last_result = np.maximum(result, self.last_result)
@@ -152,6 +166,7 @@ class ProgressLogger(Progress):
152
  self.refresh()
153
  self.remove_task(self.batch_task)
154
 
 
155
  def finish_train(self):
156
  self.remove_task(self.task_epoch)
157
  self.stop()
 
68
  class ProgressLogger(Progress):
69
  def __init__(self, cfg: Config, exp_name: str, *args, **kwargs):
70
  set_seed(cfg.lucky_number)
71
+ self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
72
+ self.quite_mode = self.local_rank or getattr(cfg, "quite", False)
73
  custom_logger(self.quite_mode)
74
  self.save_path = validate_log_directory(cfg, exp_name=cfg.name)
75
 
 
93
  project="YOLO", resume="allow", mode="online", dir=self.save_path, id=None, name=exp_name
94
  )
95
 
96
+ def rank_check(logging_function):
97
+ def wrapper(self, *args, **kwargs):
98
+ if getattr(self, "local_rank", 0) != 0:
99
+ return
100
+ return logging_function(self, *args, **kwargs)
101
+
102
+ return wrapper
103
+
104
  def get_renderable(self):
105
  renderable = Group(*self.get_renderables(), self.ap_table)
106
  return renderable
107
 
108
+ @rank_check
109
  def start_train(self, num_epochs: int):
110
  self.task_epoch = self.add_task(f"[cyan]Start Training {num_epochs} epochs", total=num_epochs)
111
 
112
+ @rank_check
113
  def start_one_epoch(
114
  self, num_batches: int, task: str = "Train", optimizer: Optimizer = None, epoch_idx: int = None
115
  ):
 
125
  self.wandb.log({lr_name: lr_value}, step=epoch_idx)
126
  self.batch_task = self.add_task(f"[green] Phase: {task}", total=num_batches)
127
 
128
+ @rank_check
129
  def one_batch(self, batch_info: Dict[str, Tensor] = None):
130
  epoch_descript = "[cyan]" + self.task + "[white] |"
131
  batch_descript = "|"
 
138
  if hasattr(self, "task_epoch"):
139
  self.update(self.task_epoch, description=epoch_descript)
140
 
141
+ @rank_check
142
  def finish_one_epoch(self, batch_info: Dict[str, Any] = None, epoch_idx: int = -1):
143
  if self.task == "Train":
144
  prefix = "Loss/"
 
149
  self.wandb.log(batch_info, step=epoch_idx)
150
  self.remove_task(self.batch_task)
151
 
152
+ @rank_check
153
  def start_pycocotools(self):
154
  self.batch_task = self.add_task("[green]Run pycocotools", total=1)
155
 
156
+ @rank_check
157
  def finish_pycocotools(self, result, epoch_idx=-1):
158
  ap_table, ap_main = make_ap_table(result, self.ap_past_list, self.last_result, epoch_idx)
159
  self.last_result = np.maximum(result, self.last_result)
 
166
  self.refresh()
167
  self.remove_task(self.batch_task)
168
 
169
+ @rank_check
170
  def finish_train(self):
171
  self.remove_task(self.task_epoch)
172
  self.stop()