Commit 261bf27 · Parent: 2600527
Committed by Manan Goel

chore(logger): log predictions during training to wandb tables (#1181)
README.md CHANGED
@@ -150,6 +150,19 @@ On the second machine, run
 python tools/train.py -n yolox-s -b 128 --dist-url tcp://123.123.123.123:12312 --num_machines 2 --machine_rank 1
 ```
 
+**Logging to Weights & Biases**
+
+To log metrics, predictions, and model checkpoints to [W&B](https://docs.wandb.ai/guides/integrations/other/yolox), use the command line argument `--logger wandb` and the prefix "wandb-" to specify arguments for initializing the wandb run.
+
+```shell
+python tools/train.py -n yolox-s -d 8 -b 64 --fp16 -o [--cache] --logger wandb wandb-project <project name>
+                         yolox-m
+                         yolox-l
+                         yolox-x
+```
+
+An example wandb dashboard is available [here](https://wandb.ai/manan-goel/yolox-nano/runs/3pzfeom0).
+
 **Others**
 See more information with the following command:
 ```shell
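As a rough illustration of what the `wandb-` prefix does (a sketch, not part of this commit): each `wandb-<key> <value>` pair is stripped of its prefix and forwarded as a keyword argument when the logger initializes the run, so a command like the one above behaves roughly as follows, where `"my-project"` is a placeholder value:

```python
import wandb

# Approximate effect of `--logger wandb wandb-project my-project`;
# "my-project" is a placeholder, and WandbLogger forwards the other
# wandb-* pairs (name, id, save_dir, ...) the same way.
run = wandb.init(project="my-project")
```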
docs/quick_run.md CHANGED
@@ -76,6 +76,19 @@ python tools/train.py -n yolox-s -d 8 -b 64 --fp16 -o [--cache] --logger wandb w
 yolox-x
 ```
 
+More WandbLogger arguments include:
+
+```shell
+python tools/train.py .... --logger wandb wandb-project <project-name> \
+                wandb-name <run-name> \
+                wandb-id <run-id> \
+                wandb-save_dir <save-dir> \
+                wandb-num_eval_images <num-images> \
+                wandb-log_checkpoints <bool>
+```
+
+More information is available [here](https://docs.wandb.ai/guides/integrations/other/yolox).
+
 **Multi Machine Training**
 
 We also support multi-nodes training. Just add the following args:
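Two of these arguments have semantics worth noting, taken from the `WandbLogger.__init__` changes later in this commit: `wandb-num_eval_images -1` logs the whole validation set, and `wandb-log_checkpoints` arrives as a string, so only the literal values `True` or `true` enable checkpoint logging. A minimal sketch with illustrative values:

```python
# Illustrative values only; the logic mirrors WandbLogger.__init__ below.
num_eval_images = -1      # -1 means "log every validation image"
val_dataset_len = 5000    # hypothetical validation set size

num_log_images = (
    val_dataset_len
    if num_eval_images == -1
    else min(num_eval_images, val_dataset_len)
)

log_checkpoints = "True"  # CLI values arrive as strings, not bools
log_checkpoints = log_checkpoints in ("True", "true")
```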
tools/train.py CHANGED
@@ -84,7 +84,8 @@ def make_parser():
         "-l",
         "--logger",
         type=str,
-        help="Logger to be used for metrics",
+        help="Logger to be used for metrics. \
+        Implemented loggers include `tensorboard` and `wandb`.",
         default="tensorboard"
     )
     parser.add_argument(
yolox/core/trainer.py CHANGED
@@ -180,11 +180,11 @@ class Trainer:
             if self.args.logger == "tensorboard":
                 self.tblogger = SummaryWriter(os.path.join(self.file_name, "tensorboard"))
             elif self.args.logger == "wandb":
-                wandb_params = dict()
-                for k, v in zip(self.args.opts[0::2], self.args.opts[1::2]):
-                    if k.startswith("wandb-"):
-                        wandb_params.update({k[len("wandb-"):]: v})
-                self.wandb_logger = WandbLogger(config=vars(self.exp), **wandb_params)
+                self.wandb_logger = WandbLogger.initialize_wandb_logger(
+                    self.args,
+                    self.exp,
+                    self.evaluator.dataloader.dataset
+                )
             else:
                 raise ValueError("logger must be either 'tensorboard' or 'wandb'")
 
@@ -263,8 +263,11 @@
 
         if self.rank == 0:
             if self.args.logger == "wandb":
-                self.wandb_logger.log_metrics({k: v.latest for k, v in loss_meter.items()})
-                self.wandb_logger.log_metrics({"lr": self.meter["lr"].latest})
+                metrics = {"train/" + k: v.latest for k, v in loss_meter.items()}
+                metrics.update({
+                    "train/lr": self.meter["lr"].latest
+                })
+                self.wandb_logger.log_metrics(metrics, step=self.progress_in_iter)
 
             self.meter.clear_meters()
 
@@ -322,8 +325,8 @@
             evalmodel = evalmodel.module
 
         with adjust_status(evalmodel, training=False):
-            ap50_95, ap50, summary = self.exp.eval(
-                evalmodel, self.evaluator, self.is_distributed
+            (ap50_95, ap50, summary), predictions = self.exp.eval(
+                evalmodel, self.evaluator, self.is_distributed, return_outputs=True
             )
 
         update_best_ckpt = ap50_95 > self.best_ap
@@ -337,16 +340,17 @@
                 self.wandb_logger.log_metrics({
                     "val/COCOAP50": ap50,
                     "val/COCOAP50_95": ap50_95,
-                    "epoch": self.epoch + 1,
+                    "train/epoch": self.epoch + 1,
                 })
+                self.wandb_logger.log_images(predictions)
             logger.info("\n" + summary)
         synchronize()
 
-        self.save_ckpt("last_epoch", update_best_ckpt)
+        self.save_ckpt("last_epoch", update_best_ckpt, ap=ap50_95)
         if self.save_history_ckpt:
-            self.save_ckpt(f"epoch_{self.epoch + 1}")
+            self.save_ckpt(f"epoch_{self.epoch + 1}", ap=ap50_95)
 
-    def save_ckpt(self, ckpt_name, update_best_ckpt=False):
+    def save_ckpt(self, ckpt_name, update_best_ckpt=False, ap=None):
         if self.rank == 0:
             save_model = self.ema_model.ema if self.use_model_ema else self.model
             logger.info("Save weights to {}".format(self.file_name))
@@ -355,6 +359,7 @@
                 "model": save_model.state_dict(),
                 "optimizer": self.optimizer.state_dict(),
                 "best_ap": self.best_ap,
+                "curr_ap": ap,
             }
             save_checkpoint(
                 ckpt_state,
@@ -364,4 +369,14 @@
             )
 
         if self.args.logger == "wandb":
-            self.wandb_logger.save_checkpoint(self.file_name, ckpt_name, update_best_ckpt)
+            self.wandb_logger.save_checkpoint(
+                self.file_name,
+                ckpt_name,
+                update_best_ckpt,
+                metadata={
+                    "epoch": self.epoch + 1,
+                    "optimizer": self.optimizer.state_dict(),
+                    "best_ap": self.best_ap,
+                    "curr_ap": ap
+                }
+            )
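The trainer now unpacks `(ap50_95, ap50, summary), predictions` from `exp.eval(..., return_outputs=True)` and forwards `predictions` to `wandb_logger.log_images`. A sketch of that contract with hypothetical values; the keys mirror `image_wise_data` as built in `coco_evaluator.py` below:

```python
# Hypothetical example of the `predictions` dict handed to log_images:
# one entry per COCO image id, holding xyxy boxes, per-box scores,
# and COCO category ids.
predictions = {
    42: {
        "bboxes": [[10.0, 20.0, 110.0, 220.0]],  # xyxy pixel coordinates
        "scores": [0.93],                        # obj_conf * cls_conf
        "categories": [18],                      # COCO category ids
    }
}
```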
yolox/data/datasets/coco.py CHANGED
@@ -65,8 +65,8 @@ class COCODataset(Dataset):
         remove_useless_info(self.coco)
         self.ids = self.coco.getImgIds()
         self.class_ids = sorted(self.coco.getCatIds())
-        cats = self.coco.loadCats(self.coco.getCatIds())
-        self._classes = tuple([c["name"] for c in cats])
+        self.cats = self.coco.loadCats(self.coco.getCatIds())
+        self._classes = tuple([c["name"] for c in self.cats])
         self.imgs = None
         self.name = name
         self.img_size = img_size
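`cats` becomes an attribute so the new `WandbLogger` can read class names from the validation dataset. For reference, pycocotools' `loadCats` returns records like the following (entry values illustrative), which the logger reduces to an id-to-name mapping:

```python
# Illustrative COCO category record as returned by pycocotools loadCats.
cats = [{"id": 18, "name": "dog", "supercategory": "animal"}]

# WandbLogger.__init__ builds its id_to_class lookup the same way:
id_to_class = {c["id"]: c["name"] for c in cats}  # {18: "dog"}
```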
yolox/evaluators/coco_evaluator.py CHANGED
@@ -8,6 +8,7 @@ import itertools
 import json
 import tempfile
 import time
+from collections import ChainMap, defaultdict
 from loguru import logger
 from tabulate import tabulate
 from tqdm import tqdm
@@ -120,6 +121,7 @@ class COCOEvaluator:
         trt_file=None,
         decoder=None,
         test_size=None,
+        return_outputs=False
     ):
         """
         COCO average precision (AP) Evaluation. Iterate inference on the test dataset
@@ -142,6 +144,7 @@
             model = model.half()
         ids = []
         data_list = []
+        output_data = defaultdict()
         progress_bar = tqdm if is_main_process() else iter
 
         inference_time = 0
@@ -184,20 +187,29 @@
             nms_end = time_synchronized()
             nms_time += nms_end - infer_end
 
-            data_list.extend(self.convert_to_coco_format(outputs, info_imgs, ids))
+            data_list_elem, image_wise_data = self.convert_to_coco_format(
+                outputs, info_imgs, ids, return_outputs=True)
+            data_list.extend(data_list_elem)
+            output_data.update(image_wise_data)
 
         statistics = torch.cuda.FloatTensor([inference_time, nms_time, n_samples])
         if distributed:
             data_list = gather(data_list, dst=0)
+            output_data = gather(output_data, dst=0)
             data_list = list(itertools.chain(*data_list))
+            output_data = dict(ChainMap(*output_data))
             torch.distributed.reduce(statistics, dst=0)
 
         eval_results = self.evaluate_prediction(data_list, statistics)
         synchronize()
+
+        if return_outputs:
+            return eval_results, output_data
         return eval_results
 
-    def convert_to_coco_format(self, outputs, info_imgs, ids):
+    def convert_to_coco_format(self, outputs, info_imgs, ids, return_outputs=False):
         data_list = []
+        image_wise_data = defaultdict(dict)
         for (output, img_h, img_w, img_id) in zip(
             outputs, info_imgs[0], info_imgs[1], ids
         ):
@@ -212,10 +224,22 @@
                 self.img_size[0] / float(img_h), self.img_size[1] / float(img_w)
             )
             bboxes /= scale
-            bboxes = xyxy2xywh(bboxes)
-
             cls = output[:, 6]
             scores = output[:, 4] * output[:, 5]
+
+            image_wise_data.update({
+                int(img_id): {
+                    "bboxes": [box.numpy().tolist() for box in bboxes],
+                    "scores": [score.numpy().item() for score in scores],
+                    "categories": [
+                        self.dataloader.dataset.class_ids[int(cls[ind])]
+                        for ind in range(bboxes.shape[0])
+                    ],
+                }
+            })
+
+            bboxes = xyxy2xywh(bboxes)
+
             for ind in range(bboxes.shape[0]):
                 label = self.dataloader.dataset.class_ids[int(cls[ind])]
                 pred_data = {
@@ -226,6 +250,9 @@
                     "segmentation": [],
                 }  # COCO json format
                 data_list.append(pred_data)
+
+        if return_outputs:
+            return data_list, image_wise_data
         return data_list
 
     def evaluate_prediction(self, data_dict, statistics):
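One subtle line in the distributed branch above: `gather(output_data, dst=0)` yields one image-wise dict per rank, and `dict(ChainMap(*output_data))` flattens them back into a single mapping. A self-contained sketch with placeholder values:

```python
from collections import ChainMap

# Two per-rank prediction dicts as gather() would return them; values
# are placeholders. ChainMap merges the mappings (the first dict wins
# if the same image id ever appeared on two ranks).
output_data = [{1: "rank0 preds"}, {2: "rank1 preds"}]
merged = dict(ChainMap(*output_data))
assert merged == {1: "rank0 preds", 2: "rank1 preds"}
```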
yolox/exp/yolox_base.py CHANGED
@@ -318,5 +318,5 @@ class Exp(BaseExp):
         # NOTE: trainer shouldn't be an attribute of exp object
         return trainer
 
-    def eval(self, model, evaluator, is_distributed, half=False):
-        return evaluator.evaluate(model, is_distributed, half)
+    def eval(self, model, evaluator, is_distributed, half=False, return_outputs=False):
+        return evaluator.evaluate(model, is_distributed, half, return_outputs=return_outputs)
yolox/utils/logger.py CHANGED
@@ -5,8 +5,12 @@
 import inspect
 import os
 import sys
+from collections import defaultdict
 from loguru import logger
 
+import cv2
+import numpy as np
+
 import torch
 
 
@@ -108,6 +112,7 @@ class WandbLogger(object):
 
     For more information, please refer to:
     https://docs.wandb.ai/guides/track
+    https://docs.wandb.ai/guides/integrations/other/yolox
     """
     def __init__(self,
                  project=None,
@@ -116,6 +121,9 @@
                  entity=None,
                  save_dir=None,
                  config=None,
+                 val_dataset=None,
+                 num_eval_images=100,
+                 log_checkpoints=False,
                  **kwargs):
         """
         Args:
@@ -125,7 +133,24 @@
             entity (str): wandb entity name.
             save_dir (str): save directory.
             config (dict): config dict.
+            val_dataset (Dataset): validation dataset.
+            num_eval_images (int): number of images from the validation set to log.
+            log_checkpoints (bool): whether to log checkpoints.
             **kwargs: other kwargs.
+
+        Usage:
+            Any arguments for wandb.init can be provided on the command line using
+            the prefix `wandb-`.
+            Example:
+            ```
+            python tools/train.py .... --logger wandb wandb-project <project-name> \
+                wandb-name <run-name> \
+                wandb-id <run-id> \
+                wandb-save_dir <save-dir> \
+                wandb-num_eval_images <num-images> \
+                wandb-log_checkpoints <bool>
+            ```
+            The val_dataset argument is not open to the command line.
         """
         try:
             import wandb
@@ -144,6 +169,12 @@
         self.kwargs = kwargs
         self.entity = entity
         self._run = None
+        self.val_artifact = None
+        if num_eval_images == -1:
+            self.num_log_images = len(val_dataset)
+        else:
+            self.num_log_images = min(num_eval_images, len(val_dataset))
+        self.log_checkpoints = (log_checkpoints == "True" or log_checkpoints == "true")
         self._wandb_init = dict(
             project=self.project,
             name=self.name,
@@ -158,8 +189,17 @@
 
         if self.config:
             self.run.config.update(self.config)
-        self.run.define_metric("epoch")
-        self.run.define_metric("val/", step_metric="epoch")
+        self.run.define_metric("train/epoch")
+        self.run.define_metric("val/*", step_metric="train/epoch")
+        self.run.define_metric("train/step")
+        self.run.define_metric("train/*", step_metric="train/step")
+
+        if val_dataset and self.num_log_images != 0:
+            self.cats = val_dataset.cats
+            self.id_to_class = {
+                cls['id']: cls['name'] for cls in self.cats
+            }
+            self._log_validation_set(val_dataset)
 
     @property
     def run(self):
@@ -176,6 +216,32 @@
             self._run = self.wandb.init(**self._wandb_init)
         return self._run
 
+    def _log_validation_set(self, val_dataset):
+        """
+        Log validation set to wandb.
+
+        Args:
+            val_dataset (Dataset): validation dataset.
+        """
+        if self.val_artifact is None:
+            self.val_artifact = self.wandb.Artifact(name="validation_images", type="dataset")
+            self.val_table = self.wandb.Table(columns=["id", "input"])
+
+            for i in range(self.num_log_images):
+                data_point = val_dataset[i]
+                img = data_point[0]
+                id = data_point[3]
+                img = np.transpose(img, (1, 2, 0))
+                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+                self.val_table.add_data(
+                    id.item(),
+                    self.wandb.Image(img)
+                )
+
+            self.val_artifact.add(self.val_table, "validation_images_table")
+            self.run.use_artifact(self.val_artifact)
+            self.val_artifact.wait()
+
     def log_metrics(self, metrics, step=None):
         """
         Args:
@@ -188,21 +254,98 @@
                 metrics[k] = v.item()
 
         if step is not None:
-            self.run.log(metrics, step=step)
+            metrics.update({"train/step": step})
+            self.run.log(metrics)
         else:
             self.run.log(metrics)
 
-    def save_checkpoint(self, save_dir, model_name, is_best):
+    def log_images(self, predictions):
+        if len(predictions) == 0 or self.val_artifact is None or self.num_log_images == 0:
+            return
+
+        table_ref = self.val_artifact.get("validation_images_table")
+
+        columns = ["id", "predicted"]
+        for cls in self.cats:
+            columns.append(cls["name"])
+
+        result_table = self.wandb.Table(columns=columns)
+        for idx, val in table_ref.iterrows():
+
+            avg_scores = defaultdict(int)
+            num_occurrences = defaultdict(int)
+
+            if val[0] in predictions:
+                prediction = predictions[val[0]]
+                boxes = []
+
+                for i in range(len(prediction["bboxes"])):
+                    bbox = prediction["bboxes"][i]
+                    x0 = bbox[0]
+                    y0 = bbox[1]
+                    x1 = bbox[2]
+                    y1 = bbox[3]
+                    box = {
+                        "position": {
+                            "minX": min(x0, x1),
+                            "minY": min(y0, y1),
+                            "maxX": max(x0, x1),
+                            "maxY": max(y0, y1)
+                        },
+                        "class_id": prediction["categories"][i],
+                        "domain": "pixel"
+                    }
+                    avg_scores[
+                        self.id_to_class[prediction["categories"][i]]
+                    ] += prediction["scores"][i]
+                    num_occurrences[self.id_to_class[prediction["categories"][i]]] += 1
+                    boxes.append(box)
+            else:
+                boxes = []
+
+            average_class_score = []
+            for cls in self.cats:
+                if cls["name"] not in num_occurrences:
+                    score = 0
+                else:
+                    score = avg_scores[cls["name"]] / num_occurrences[cls["name"]]
+                average_class_score.append(score)
+            result_table.add_data(
+                idx,
+                self.wandb.Image(val[1], boxes={
+                        "prediction": {
+                            "box_data": boxes,
+                            "class_labels": self.id_to_class
+                        }
+                    }
+                ),
+                *average_class_score
+            )
+
+        self.wandb.log({"val_results/result_table": result_table})
+
+    def save_checkpoint(self, save_dir, model_name, is_best, metadata=None):
         """
         Args:
             save_dir (str): save directory.
             model_name (str): model name.
            is_best (bool): whether the model is the best model.
+            metadata (dict): metadata to save corresponding to the checkpoint.
         """
+
+        if not self.log_checkpoints:
+            return
+
+        if "epoch" in metadata:
+            epoch = metadata["epoch"]
+        else:
+            epoch = None
+
         filename = os.path.join(save_dir, model_name + "_ckpt.pth")
         artifact = self.wandb.Artifact(
-            name=f"model-{self.run.id}",
-            type="model"
+            name=f"run_{self.run.id}_model",
+            type="model",
+            metadata=metadata
         )
         artifact.add_file(filename, name="model_ckpt.pth")
 
@@ -211,7 +354,23 @@
         if is_best:
             aliases.append("best")
 
+        if epoch:
+            aliases.append(f"epoch-{epoch}")
+
         self.run.log_artifact(artifact, aliases=aliases)
 
     def finish(self):
         self.run.finish()
+
+    @classmethod
+    def initialize_wandb_logger(cls, args, exp, val_dataset):
+        wandb_params = dict()
+        prefix = "wandb-"
+        for k, v in zip(args.opts[0::2], args.opts[1::2]):
+            if k.startswith("wandb-"):
+                try:
+                    wandb_params.update({k[len(prefix):]: int(v)})
+                except ValueError:
+                    wandb_params.update({k[len(prefix):]: v})
+
+        return cls(config=vars(exp), val_dataset=val_dataset, **wandb_params)
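To make the CLI contract of `initialize_wandb_logger` concrete, here is a standalone sketch of the parsing loop with a mocked `args` (values hypothetical). Integer-looking values are coerced with `int()` and everything else stays a string, which is why `__init__` compares `log_checkpoints` against the strings `"True"`/`"true"`:

```python
from types import SimpleNamespace

# Mocked argparse result: `opts` holds the trailing "wandb-<key> <value>" pairs.
args = SimpleNamespace(opts=[
    "wandb-project", "yolox-demo",
    "wandb-num_eval_images", "50",
    "wandb-log_checkpoints", "True",
])

prefix = "wandb-"
wandb_params = {}
for k, v in zip(args.opts[0::2], args.opts[1::2]):
    if k.startswith(prefix):
        try:
            wandb_params[k[len(prefix):]] = int(v)   # "50" -> 50
        except ValueError:
            wandb_params[k[len(prefix):]] = v        # non-ints pass through

print(wandb_params)
# {'project': 'yolox-demo', 'num_eval_images': 50, 'log_checkpoints': 'True'}
```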