π [Update] logging function, log only rank0 device
yolo/utils/logging_utils.py (changed, +17 −2)
@@ -68,8 +68,8 @@ def set_seed(seed):
 class ProgressLogger(Progress):
     def __init__(self, cfg: Config, exp_name: str, *args, **kwargs):
         set_seed(cfg.lucky_number)
-        local_rank = int(os.getenv("LOCAL_RANK", "0"))
-        self.quite_mode = local_rank or getattr(cfg, "quite", False)
+        self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
+        self.quite_mode = self.local_rank or getattr(cfg, "quite", False)
         custom_logger(self.quite_mode)
         self.save_path = validate_log_directory(cfg, exp_name=cfg.name)
 
@@ -93,13 +93,23 @@ class ProgressLogger(Progress):
                 project="YOLO", resume="allow", mode="online", dir=self.save_path, id=None, name=exp_name
             )
 
+    def rank_check(logging_function):
+        def wrapper(self, *args, **kwargs):
+            if getattr(self, "local_rank", 0) != 0:
+                return
+            return logging_function(self, *args, **kwargs)
+
+        return wrapper
+
     def get_renderable(self):
         renderable = Group(*self.get_renderables(), self.ap_table)
         return renderable
 
+    @rank_check
     def start_train(self, num_epochs: int):
         self.task_epoch = self.add_task(f"[cyan]Start Training {num_epochs} epochs", total=num_epochs)
 
+    @rank_check
     def start_one_epoch(
         self, num_batches: int, task: str = "Train", optimizer: Optimizer = None, epoch_idx: int = None
     ):
@@ -115,6 +125,7 @@ class ProgressLogger(Progress):
                 self.wandb.log({lr_name: lr_value}, step=epoch_idx)
         self.batch_task = self.add_task(f"[green] Phase: {task}", total=num_batches)
 
+    @rank_check
     def one_batch(self, batch_info: Dict[str, Tensor] = None):
         epoch_descript = "[cyan]" + self.task + "[white] |"
         batch_descript = "|"
@@ -127,6 +138,7 @@ class ProgressLogger(Progress):
         if hasattr(self, "task_epoch"):
             self.update(self.task_epoch, description=epoch_descript)
 
+    @rank_check
     def finish_one_epoch(self, batch_info: Dict[str, Any] = None, epoch_idx: int = -1):
         if self.task == "Train":
             prefix = "Loss/"
@@ -137,9 +149,11 @@ class ProgressLogger(Progress):
             self.wandb.log(batch_info, step=epoch_idx)
         self.remove_task(self.batch_task)
 
+    @rank_check
     def start_pycocotools(self):
        self.batch_task = self.add_task("[green]Run pycocotools", total=1)
 
+    @rank_check
     def finish_pycocotools(self, result, epoch_idx=-1):
         ap_table, ap_main = make_ap_table(result, self.ap_past_list, self.last_result, epoch_idx)
         self.last_result = np.maximum(result, self.last_result)
@@ -152,6 +166,7 @@ class ProgressLogger(Progress):
         self.refresh()
         self.remove_task(self.batch_task)
 
+    @rank_check
     def finish_train(self):
         self.remove_task(self.task_epoch)
         self.stop()