Manan Goel committed
Commit e578635 · 1 Parent(s): 1ee5e0e

feat(engine): support wandb logger (#1144)

docs/quick_run.md CHANGED
@@ -56,6 +56,24 @@ python tools/train.py -n yolox-s -d 8 -b 64 --fp16 -o [--cache]
 * --fp16: mixed precision training
 * --cache: caching imgs into RAM to accelerate training, which needs large system RAM.
 
+**Weights & Biases for Logging**
+
+To use W&B for logging, install wandb in your environment and log in to your W&B account using
+
+```shell
+pip install wandb
+wandb login
+```
+
+To start logging metrics to W&B during training, add `--logger wandb` to the previous command and use the prefix "wandb-" to specify arguments for initializing the wandb run.
+
+```shell
+python tools/train.py -n yolox-s -d 8 -b 64 --fp16 -o [--cache] --logger wandb wandb-project <project name>
+                         yolox-m
+                         yolox-l
+                         yolox-x
+```
+
 **Multi Machine Training**
 
 We also support multi-node training. Just add the following args:
tools/train.py CHANGED
@@ -80,6 +80,13 @@ def make_parser():
         action="store_true",
         help="occupy GPU memory first for training.",
     )
+    parser.add_argument(
+        "-l",
+        "--logger",
+        type=str,
+        help="Logger to be used for metrics",
+        default="tensorboard",
+    )
     parser.add_argument(
         "opts",
         help="Modify config options using the command-line",
yolox/core/trainer.py CHANGED
@@ -15,6 +15,7 @@ from yolox.data import DataPrefetcher
 from yolox.utils import (
     MeterBuffer,
     ModelEMA,
+    WandbLogger,
     all_reduce_norm,
     get_local_rank,
     get_model_info,
@@ -173,9 +174,18 @@
         self.evaluator = self.exp.get_evaluator(
             batch_size=self.args.batch_size, is_distributed=self.is_distributed
         )
-        # Tensorboard logger
+        # Tensorboard and Wandb loggers
         if self.rank == 0:
-            self.tblogger = SummaryWriter(os.path.join(self.file_name, "tensorboard"))
+            if self.args.logger == "tensorboard":
+                self.tblogger = SummaryWriter(os.path.join(self.file_name, "tensorboard"))
+            elif self.args.logger == "wandb":
+                wandb_params = dict()
+                for k, v in zip(self.args.opts[0::2], self.args.opts[1::2]):
+                    if k.startswith("wandb-"):
+                        wandb_params.update({k[len("wandb-"):]: v})
+                self.wandb_logger = WandbLogger(config=vars(self.exp), **wandb_params)
+            else:
+                raise ValueError("logger must be either 'tensorboard' or 'wandb'")
 
         logger.info("Training start...")
         logger.info("\n{}".format(model))
@@ -184,6 +194,9 @@
         logger.info(
             "Training of experiment is done and the best AP is {:.2f}".format(self.best_ap * 100)
         )
+        if self.rank == 0:
+            if self.args.logger == "wandb":
+                self.wandb_logger.finish()
 
     def before_epoch(self):
         logger.info("---> start train epoch{}".format(self.epoch + 1))
@@ -246,6 +259,12 @@
                 )
                 + (", size: {:d}, {}".format(self.input_size[0], eta_str))
             )
+
+            if self.rank == 0:
+                if self.args.logger == "wandb":
+                    self.wandb_logger.log_metrics({k: v.latest for k, v in loss_meter.items()})
+                    self.wandb_logger.log_metrics({"lr": self.meter["lr"].latest})
+
             self.meter.clear_meters()
 
         # random resizing
@@ -309,8 +328,15 @@
 
         self.model.train()
         if self.rank == 0:
-            self.tblogger.add_scalar("val/COCOAP50", ap50, self.epoch + 1)
-            self.tblogger.add_scalar("val/COCOAP50_95", ap50_95, self.epoch + 1)
+            if self.args.logger == "tensorboard":
+                self.tblogger.add_scalar("val/COCOAP50", ap50, self.epoch + 1)
+                self.tblogger.add_scalar("val/COCOAP50_95", ap50_95, self.epoch + 1)
+            if self.args.logger == "wandb":
+                self.wandb_logger.log_metrics({
+                    "val/COCOAP50": ap50,
+                    "val/COCOAP50_95": ap50_95,
+                    "epoch": self.epoch + 1,
+                })
             logger.info("\n" + summary)
         synchronize()
@@ -334,3 +360,6 @@
                 self.file_name,
                 ckpt_name,
             )
+
+            if self.args.logger == "wandb":
+                self.wandb_logger.save_checkpoint(self.file_name, ckpt_name, update_best_ckpt)
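
The `before_train` hunk above turns the trailing `wandb-` pairs in `args.opts` into `WandbLogger` keyword arguments. Here is a standalone sketch of that parsing; note that the prefix is removed by slicing rather than `str.lstrip`, since `lstrip` strips a *character set* and would mangle keys such as `wandb-name`.

```python
# Standalone sketch of the "wandb-" option parsing; `opts` mimics what
# argparse leaves in args.opts after the flagged arguments are consumed.
opts = ["wandb-project", "yolox", "wandb-name", "baseline"]

wandb_params = {}
for k, v in zip(opts[0::2], opts[1::2]):  # walk the flat list as (key, value) pairs
    if k.startswith("wandb-"):
        # slice off the prefix exactly; "wandb-name".lstrip("wandb-") would
        # also strip the leading "na" and yield "me"
        wandb_params[k[len("wandb-"):]] = v

print(wandb_params)  # {'project': 'yolox', 'name': 'baseline'}
```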
yolox/utils/__init__.py CHANGED
@@ -8,7 +8,7 @@ from .checkpoint import load_ckpt, save_checkpoint
 from .demo_utils import *
 from .dist import *
 from .ema import *
-from .logger import setup_logger
+from .logger import WandbLogger, setup_logger
 from .lr_scheduler import LRScheduler
 from .metric import *
 from .model_utils import *
yolox/utils/logger.py CHANGED
@@ -7,11 +7,14 @@ import os
 import sys
 from loguru import logger
 
+import torch
+
 
 def get_caller_name(depth=0):
     """
     Args:
-        depth (int): Depth of caller context, use 0 for caller depth. Default value: 0.
+        depth (int): Depth of caller context, use 0 for caller depth.
+            Default value: 0.
 
     Returns:
         str: module name of the caller
@@ -93,3 +96,122 @@ def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"):
 
     # redirect stdout/stderr to loguru
     redirect_sys_output("INFO")
+
+
+class WandbLogger(object):
+    """
+    Log training runs, datasets, models, and predictions to Weights & Biases.
+    This logger sends information to W&B at wandb.ai.
+    By default, this information includes hyperparameters,
+    system configuration and metrics, model metrics,
+    and basic data metrics and analyses.
+
+    For more information, please refer to:
+    https://docs.wandb.ai/guides/track
+    """
+    def __init__(self,
+                 project=None,
+                 name=None,
+                 id=None,
+                 entity=None,
+                 save_dir=None,
+                 config=None,
+                 **kwargs):
+        """
+        Args:
+            project (str): wandb project name.
+            name (str): wandb run name.
+            id (str): wandb run id.
+            entity (str): wandb entity name.
+            save_dir (str): save directory.
+            config (dict): config dict.
+            **kwargs: other kwargs.
+        """
+        try:
+            import wandb
+            self.wandb = wandb
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "wandb is not installed. "
+                "Please install wandb using pip install wandb"
+            )
+
+        self.project = project
+        self.name = name
+        self.id = id
+        self.save_dir = save_dir
+        self.config = config
+        self.kwargs = kwargs
+        self.entity = entity
+        self._run = None
+        self._wandb_init = dict(
+            project=self.project,
+            name=self.name,
+            id=self.id,
+            entity=self.entity,
+            dir=self.save_dir,
+            resume="allow"
+        )
+        self._wandb_init.update(**kwargs)
+
+        _ = self.run
+
+        if self.config:
+            self.run.config.update(self.config)
+        self.run.define_metric("epoch")
+        self.run.define_metric("val/*", step_metric="epoch")
+
+    @property
+    def run(self):
+        if self._run is None:
+            if self.wandb.run is not None:
+                logger.info(
+                    "There is a wandb run already in progress "
+                    "and newly created instances of `WandbLogger` will reuse "
+                    "this run. If this is not desired, call `wandb.finish()` "
+                    "before instantiating `WandbLogger`."
+                )
+                self._run = self.wandb.run
+            else:
+                self._run = self.wandb.init(**self._wandb_init)
+        return self._run
+
+    def log_metrics(self, metrics, step=None):
+        """
+        Args:
+            metrics (dict): metrics dict.
+            step (int): step number.
+        """
+        for k, v in metrics.items():
+            if isinstance(v, torch.Tensor):
+                metrics[k] = v.item()
+
+        if step is not None:
+            self.run.log(metrics, step=step)
+        else:
+            self.run.log(metrics)
+
+    def save_checkpoint(self, save_dir, model_name, is_best):
+        """
+        Args:
+            save_dir (str): save directory.
+            model_name (str): model name.
+            is_best (bool): whether the model is the best model.
+        """
+        filename = os.path.join(save_dir, model_name + "_ckpt.pth")
+        artifact = self.wandb.Artifact(
+            name=f"model-{self.run.id}",
+            type="model"
+        )
+        artifact.add_file(filename, name="model_ckpt.pth")
+
+        aliases = ["latest"]
+
+        if is_best:
+            aliases.append("best")
+
+        self.run.log_artifact(artifact, aliases=aliases)
+
+    def finish(self):
+        self.run.finish()
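
For reference, the class can also be driven outside the trainer. The following is a hypothetical standalone usage sketch; the project, run name, and metric values are placeholders, and wandb must be installed and logged in first.

```python
# Hypothetical standalone use of the WandbLogger defined above.
from yolox.utils import WandbLogger

wandb_logger = WandbLogger(
    project="yolox",              # forwarded to wandb.init(project=...)
    name="yolox-s-baseline",      # forwarded to wandb.init(name=...)
    config={"max_epoch": 300},    # recorded as the run's config
)
wandb_logger.log_metrics({"total_loss": 7.3, "lr": 0.01})     # iteration-level metrics
wandb_logger.log_metrics({"val/COCOAP50": 0.58, "epoch": 1})  # epoch-keyed val metric
wandb_logger.finish()                                         # close the run
```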