henry000 commited on
Commit
1d2b161
·
1 Parent(s): 240dcb0

🚑️ [Fix] progress bar in DDP or multiGPU env

Browse files
Files changed (1) hide show
  1. yolo/utils/logging_utils.py +10 -5
yolo/utils/logging_utils.py CHANGED
@@ -23,9 +23,10 @@ import wandb
23
  from lightning import LightningModule, Trainer, seed_everything
24
  from lightning.pytorch.callbacks import Callback, RichModelSummary, RichProgressBar
25
  from lightning.pytorch.callbacks.progress.rich_progress import CustomProgress
26
- from lightning.pytorch.loggers import WandbLogger
 
27
  from omegaconf import ListConfig
28
- from rich import reconfigure
29
  from rich.console import Console, Group
30
  from rich.logging import RichHandler
31
  from rich.table import Table
@@ -60,6 +61,7 @@ class YOLOCustomProgress(CustomProgress):
60
 
61
  class YOLORichProgressBar(RichProgressBar):
62
  @override
 
63
  def _init_progress(self, trainer: "Trainer") -> None:
64
  if self.is_enabled and (self.progress is None or self._progress_stopped):
65
  self._reset_progress_bar_ids()
@@ -85,6 +87,7 @@ class YOLORichProgressBar(RichProgressBar):
85
  return Text("[cyan]Train [white]|")
86
 
87
  @override
 
88
  def on_train_start(self, trainer, pl_module):
89
  self._init_progress(trainer)
90
  num_epochs = trainer.max_epochs - 1
@@ -97,6 +100,7 @@ class YOLORichProgressBar(RichProgressBar):
97
  self.progress.update(self.task_epoch, advance=-0.5)
98
 
99
  @override
 
100
  def on_train_batch_end(self, trainer, pl_module, outputs, batch: Any, batch_idx: int):
101
  self._update(self.train_progress_bar_id, batch_idx + 1)
102
  self._update_metrics(trainer, pl_module)
@@ -106,7 +110,7 @@ class YOLORichProgressBar(RichProgressBar):
106
  metrics.pop("v_num")
107
  for metrics_name, metrics_val in metrics.items():
108
  if "Loss_step" in metrics_name:
109
- epoch_descript += f"{metrics_name.removesuffix('_step'): ^9}|"
110
  batch_descript += f" {metrics_val:2.2f} |"
111
 
112
  self.progress.update(self.task_epoch, advance=1 / self.total_train_batches, description=epoch_descript)
@@ -114,12 +118,14 @@ class YOLORichProgressBar(RichProgressBar):
114
  self.refresh()
115
 
116
  @override
 
117
  def on_train_end(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
118
  self._update_metrics(trainer, pl_module)
119
  self.progress.remove_task(self.train_progress_bar_id)
120
  self.train_progress_bar_id = None
121
 
122
  @override
 
123
  def on_validation_end(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
124
  if trainer.state.fn == "fit":
125
  self._update_metrics(trainer, pl_module)
@@ -162,8 +168,6 @@ class YOLORichModelSummary(RichModelSummary):
162
  **summarize_kwargs: Any,
163
  ) -> None:
164
  from lightning.pytorch.utilities.model_summary import get_human_readable_count
165
- from rich import get_console
166
- from rich.table import Table
167
 
168
  console = get_console()
169
 
@@ -223,6 +227,7 @@ class ImageLogger(Callback):
223
 
224
 
225
  def setup(cfg: Config):
 
226
  if hasattr(cfg, "quite"):
227
  logger.removeHandler("YOLO_logger")
228
  return
 
23
  from lightning import LightningModule, Trainer, seed_everything
24
  from lightning.pytorch.callbacks import Callback, RichModelSummary, RichProgressBar
25
  from lightning.pytorch.callbacks.progress.rich_progress import CustomProgress
26
+ from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
27
+ from lightning.pytorch.utilities import rank_zero_only
28
  from omegaconf import ListConfig
29
+ from rich import get_console, reconfigure
30
  from rich.console import Console, Group
31
  from rich.logging import RichHandler
32
  from rich.table import Table
 
61
 
62
  class YOLORichProgressBar(RichProgressBar):
63
  @override
64
+ @rank_zero_only
65
  def _init_progress(self, trainer: "Trainer") -> None:
66
  if self.is_enabled and (self.progress is None or self._progress_stopped):
67
  self._reset_progress_bar_ids()
 
87
  return Text("[cyan]Train [white]|")
88
 
89
  @override
90
+ @rank_zero_only
91
  def on_train_start(self, trainer, pl_module):
92
  self._init_progress(trainer)
93
  num_epochs = trainer.max_epochs - 1
 
100
  self.progress.update(self.task_epoch, advance=-0.5)
101
 
102
  @override
103
+ @rank_zero_only
104
  def on_train_batch_end(self, trainer, pl_module, outputs, batch: Any, batch_idx: int):
105
  self._update(self.train_progress_bar_id, batch_idx + 1)
106
  self._update_metrics(trainer, pl_module)
 
110
  metrics.pop("v_num")
111
  for metrics_name, metrics_val in metrics.items():
112
  if "Loss_step" in metrics_name:
113
+ epoch_descript += f"{metrics_name.removesuffix('_step').split('/')[1]: ^9}|"
114
  batch_descript += f" {metrics_val:2.2f} |"
115
 
116
  self.progress.update(self.task_epoch, advance=1 / self.total_train_batches, description=epoch_descript)
 
118
  self.refresh()
119
 
120
  @override
121
+ @rank_zero_only
122
  def on_train_end(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
123
  self._update_metrics(trainer, pl_module)
124
  self.progress.remove_task(self.train_progress_bar_id)
125
  self.train_progress_bar_id = None
126
 
127
  @override
128
+ @rank_zero_only
129
  def on_validation_end(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
130
  if trainer.state.fn == "fit":
131
  self._update_metrics(trainer, pl_module)
 
168
  **summarize_kwargs: Any,
169
  ) -> None:
170
  from lightning.pytorch.utilities.model_summary import get_human_readable_count
 
 
171
 
172
  console = get_console()
173
 
 
227
 
228
 
229
  def setup(cfg: Config):
230
+ seed_everything(cfg.lucky_number)
231
  if hasattr(cfg, "quite"):
232
  logger.removeHandler("YOLO_logger")
233
  return