🚑️ [Fix] progress bar in DDP or multiGPU env
yolo/utils/logging_utils.py (+10 -5)

@@ -23,9 +23,10 @@ import wandb
 from lightning import LightningModule, Trainer, seed_everything
 from lightning.pytorch.callbacks import Callback, RichModelSummary, RichProgressBar
 from lightning.pytorch.callbacks.progress.rich_progress import CustomProgress
-from lightning.pytorch.loggers import WandbLogger
+from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
+from lightning.pytorch.utilities import rank_zero_only
 from omegaconf import ListConfig
-from rich import reconfigure
+from rich import get_console, reconfigure
 from rich.console import Console, Group
 from rich.logging import RichHandler
 from rich.table import Table
@@ -60,6 +61,7 @@ class YOLOCustomProgress(CustomProgress):
 
 class YOLORichProgressBar(RichProgressBar):
     @override
+    @rank_zero_only
     def _init_progress(self, trainer: "Trainer") -> None:
         if self.is_enabled and (self.progress is None or self._progress_stopped):
             self._reset_progress_bar_ids()
@@ -85,6 +87,7 @@ class YOLORichProgressBar(RichProgressBar):
         return Text("[cyan]Train [white]|")
 
     @override
+    @rank_zero_only
     def on_train_start(self, trainer, pl_module):
         self._init_progress(trainer)
         num_epochs = trainer.max_epochs - 1
@@ -97,6 +100,7 @@ class YOLORichProgressBar(RichProgressBar):
         self.progress.update(self.task_epoch, advance=-0.5)
 
     @override
+    @rank_zero_only
    def on_train_batch_end(self, trainer, pl_module, outputs, batch: Any, batch_idx: int):
         self._update(self.train_progress_bar_id, batch_idx + 1)
         self._update_metrics(trainer, pl_module)
@@ -106,7 +110,7 @@ class YOLORichProgressBar(RichProgressBar):
         metrics.pop("v_num")
         for metrics_name, metrics_val in metrics.items():
             if "Loss_step" in metrics_name:
-                epoch_descript += f"{metrics_name.removesuffix('_step'): ^9}|"
+                epoch_descript += f"{metrics_name.removesuffix('_step').split('/')[1]: ^9}|"
                 batch_descript += f" {metrics_val:2.2f} |"
 
         self.progress.update(self.task_epoch, advance=1 / self.total_train_batches, description=epoch_descript)
@@ -114,12 +118,14 @@ class YOLORichProgressBar(RichProgressBar):
         self.refresh()
 
     @override
+    @rank_zero_only
     def on_train_end(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
         self._update_metrics(trainer, pl_module)
         self.progress.remove_task(self.train_progress_bar_id)
         self.train_progress_bar_id = None
 
     @override
+    @rank_zero_only
     def on_validation_end(self, trainer: "Trainer", pl_module: "LightningModule") -> None:
         if trainer.state.fn == "fit":
             self._update_metrics(trainer, pl_module)
@@ -162,8 +168,6 @@ class YOLORichModelSummary(RichModelSummary):
         **summarize_kwargs: Any,
     ) -> None:
         from lightning.pytorch.utilities.model_summary import get_human_readable_count
-        from rich import get_console
-        from rich.table import Table
 
         console = get_console()
 
@@ -223,6 +227,7 @@ class ImageLogger(Callback):
 
 
 def setup(cfg: Config):
+    seed_everything(cfg.lucky_number)
     if hasattr(cfg, "quite"):
         logger.removeHandler("YOLO_logger")
         return
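
Why this fixes the bar: under DDP every process runs the Lightning callback hooks, so each rank was creating and redrawing its own Rich progress bar and the output interleaved. Decorating the hooks with rank_zero_only confines them to global rank 0. A minimal sketch of the pattern (the callback class and its print output are illustrative only, not code from this repository):

    from lightning.pytorch.callbacks import Callback
    from lightning.pytorch.utilities import rank_zero_only


    class RankZeroProgress(Callback):
        # Hypothetical callback used only to illustrate the decorator's effect.
        @rank_zero_only
        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
            # Every rank still calls this hook, but the decorator runs the body
            # only when rank_zero_only.rank == 0 (other ranks just get None),
            # so a single process owns the console and the bar stays intact.
            print(f"[rank 0] finished batch {batch_idx}")

The epoch header also changes: the metric key is now split on "/" so only the loss name is displayed. Assuming a logged key of the form "Loss/BoxLoss_step" (the exact keys are not shown in this diff), the new expression produces:

    >>> "Loss/BoxLoss_step".removesuffix("_step").split("/")[1]
    'BoxLoss'
    >>> f"{'BoxLoss': ^9}|"  # centred in 9 characters for the header row
    ' BoxLoss |'

Finally, setup() now calls seed_everything(cfg.lucky_number), so all workers start from the same random state.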