Manan Goel committed on

Commit · 261bf27
Parent(s): 2600527

chore(logger): log predictions during training to wandb tables (#1181)

Browse files:
- README.md +13 -0
- docs/quick_run.md +13 -0
- tools/train.py +2 -1
- yolox/core/trainer.py +29 -14
- yolox/data/datasets/coco.py +2 -2
- yolox/evaluators/coco_evaluator.py +31 -4
- yolox/exp/yolox_base.py +2 -2
- yolox/utils/logger.py +165 -6
README.md CHANGED

@@ -150,6 +150,19 @@ On the second machine, run
 python tools/train.py -n yolox-s -b 128 --dist-url tcp://123.123.123.123:12312 --num_machines 2 --machine_rank 1
 ```
 
+**Logging to Weights & Biases**
+
+To log metrics, predictions and model checkpoints to [W&B](https://docs.wandb.ai/guides/integrations/other/yolox) use the command line argument `--logger wandb` and use the prefix "wandb-" to specify arguments for initializing the wandb run.
+
+```shell
+python tools/train.py -n yolox-s -d 8 -b 64 --fp16 -o [--cache] --logger wandb wandb-project <project name>
+                         yolox-m
+                         yolox-l
+                         yolox-x
+```
+
+An example wandb dashboard is available [here](https://wandb.ai/manan-goel/yolox-nano/runs/3pzfeom0)
+
 **Others**
 See more information with the following command:
 ```shell
docs/quick_run.md CHANGED

@@ -76,6 +76,19 @@ python tools/train.py -n yolox-s -d 8 -b 64 --fp16 -o [--cache] --logger wandb wandb-project <project name>
                          yolox-x
 ```
 
+More WandbLogger arguments include
+
+```shell
+python tools/train.py .... --logger wandb wandb-project <project-name> \
+                wandb-name <run-name> \
+                wandb-id <run-id> \
+                wandb-save_dir <save-dir> \
+                wandb-num_eval_images <num-images> \
+                wandb-log_checkpoints <bool>
+```
+
+More information available [here](https://docs.wandb.ai/guides/integrations/other/yolox).
+
 **Multi Machine Training**
 
 We also support multi-nodes training. Just add the following args:
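A caveat worth knowing for `wandb-log_checkpoints`: the value arrives from the command line as a string and is compared literally inside `WandbLogger`, so only `True` or `true` enables checkpoint uploads. A minimal sketch of that check, mirroring the logic added in `yolox/utils/logger.py`:

```python
# Sketch of how WandbLogger interprets the wandb-log_checkpoints value: it is
# received as a string from argv and compared literally against "True"/"true".
log_checkpoints = "True"  # e.g. the value passed as `wandb-log_checkpoints True`
enabled = (log_checkpoints == "True" or log_checkpoints == "true")
assert enabled  # any other value, including "1" or "yes", leaves logging off
```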
tools/train.py CHANGED

@@ -84,7 +84,8 @@ def make_parser():
         "-l",
         "--logger",
         type=str,
-        help="Logger to be used for metrics",
+        help="Logger to be used for metrics. \
+        Implemented loggers include `tensorboard` and `wandb`.",
         default="tensorboard"
     )
     parser.add_argument(
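How the `wandb-` prefixed pairs reach the logger: argparse collects the trailing `KEY VALUE` pairs into `args.opts`, and `WandbLogger.initialize_wandb_logger` (added below in `yolox/utils/logger.py`) strips the prefix and coerces numeric values. A standalone sketch with hypothetical values:

```python
# Sketch of the wandb- prefix parsing; the opts list here is hypothetical and
# stands in for args.opts as filled by argparse from the command line tail.
opts = ["wandb-project", "my-yolox", "wandb-num_eval_images", "50", "max_epoch", "300"]

prefix = "wandb-"
wandb_params = {}
for k, v in zip(opts[0::2], opts[1::2]):
    if k.startswith(prefix):
        try:
            wandb_params[k[len(prefix):]] = int(v)   # numeric strings become ints
        except ValueError:
            wandb_params[k[len(prefix):]] = v        # everything else stays a string
print(wandb_params)  # {'project': 'my-yolox', 'num_eval_images': 50}
```

Non-prefixed pairs such as `max_epoch 300` are left for the experiment config, as before.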
yolox/core/trainer.py CHANGED

@@ -180,11 +180,11 @@ class Trainer:
             if self.args.logger == "tensorboard":
                 self.tblogger = SummaryWriter(os.path.join(self.file_name, "tensorboard"))
             elif self.args.logger == "wandb":
-                wandb_params = dict()
-                for k, v in zip(self.args.opts[0::2], self.args.opts[1::2]):
-                    if k.startswith("wandb-"):
-                        wandb_params.update({k[len("wandb-"):]: v})
-                self.wandb_logger = WandbLogger(config=vars(self.exp), **wandb_params)
+                self.wandb_logger = WandbLogger.initialize_wandb_logger(
+                    self.args,
+                    self.exp,
+                    self.evaluator.dataloader.dataset
+                )
             else:
                 raise ValueError("logger must be either 'tensorboard' or 'wandb'")
 
@@ -263,8 +263,11 @@ class Trainer:
 
         if self.rank == 0:
             if self.args.logger == "wandb":
-                self.wandb_logger.log_metrics({k: v.latest for k, v in loss_meter.items()})
-                self.wandb_logger.log_metrics({"lr": self.meter["lr"].latest})
+                metrics = {"train/" + k: v.latest for k, v in loss_meter.items()}
+                metrics.update({
+                    "train/lr": self.meter["lr"].latest
+                })
+                self.wandb_logger.log_metrics(metrics, step=self.progress_in_iter)
 
             self.meter.clear_meters()
 
@@ -322,8 +325,8 @@ class Trainer:
             evalmodel = evalmodel.module
 
         with adjust_status(evalmodel, training=False):
-            ap50_95, ap50, summary = self.exp.eval(
-                evalmodel, self.evaluator, self.is_distributed
+            (ap50_95, ap50, summary), predictions = self.exp.eval(
+                evalmodel, self.evaluator, self.is_distributed, return_outputs=True
             )
 
         update_best_ckpt = ap50_95 > self.best_ap
 
@@ -337,16 +340,17 @@ class Trainer:
                 self.wandb_logger.log_metrics({
                     "val/COCOAP50": ap50,
                     "val/COCOAP50_95": ap50_95,
-                    "epoch": self.epoch + 1,
+                    "train/epoch": self.epoch + 1,
                 })
+                self.wandb_logger.log_images(predictions)
             logger.info("\n" + summary)
         synchronize()
 
-        self.save_ckpt("last_epoch", update_best_ckpt)
+        self.save_ckpt("last_epoch", update_best_ckpt, ap=ap50_95)
         if self.save_history_ckpt:
-            self.save_ckpt(f"epoch_{self.epoch + 1}")
+            self.save_ckpt(f"epoch_{self.epoch + 1}", ap=ap50_95)
 
-    def save_ckpt(self, ckpt_name, update_best_ckpt=False):
+    def save_ckpt(self, ckpt_name, update_best_ckpt=False, ap=None):
         if self.rank == 0:
             save_model = self.ema_model.ema if self.use_model_ema else self.model
             logger.info("Save weights to {}".format(self.file_name))
 
@@ -355,6 +359,7 @@ class Trainer:
                 "model": save_model.state_dict(),
                 "optimizer": self.optimizer.state_dict(),
                 "best_ap": self.best_ap,
+                "curr_ap": ap,
             }
             save_checkpoint(
                 ckpt_state,
 
@@ -364,4 +369,14 @@ class Trainer:
             )
 
             if self.args.logger == "wandb":
-                self.wandb_logger.save_checkpoint(self.file_name, ckpt_name, update_best_ckpt)
+                self.wandb_logger.save_checkpoint(
+                    self.file_name,
+                    ckpt_name,
+                    update_best_ckpt,
+                    metadata={
+                        "epoch": self.epoch + 1,
+                        "optimizer": self.optimizer.state_dict(),
+                        "best_ap": self.best_ap,
+                        "curr_ap": ap
+                    }
+                )
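The switch from `"epoch"` to `"train/epoch"` matters because `WandbLogger` declares `train/epoch` as the step metric for all `val/*` keys, so validation charts are plotted per epoch rather than per wandb step. A minimal standalone sketch of that mechanism (offline mode, hypothetical project name, dummy values):

```python
# Sketch (not YOLOX code): bind val/* charts to the train/epoch step metric,
# which is what makes logging "train/epoch" alongside the AP values useful.
import wandb

run = wandb.init(project="yolox-demo", mode="offline")  # hypothetical project
run.define_metric("train/epoch")
run.define_metric("val/*", step_metric="train/epoch")

for epoch in range(1, 4):
    run.log({
        "val/COCOAP50": 0.1 * epoch,      # dummy metric values
        "val/COCOAP50_95": 0.05 * epoch,
        "train/epoch": epoch,             # mirrors Trainer.evaluate_and_save_model
    })
run.finish()
```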
yolox/data/datasets/coco.py CHANGED

@@ -65,8 +65,8 @@ class COCODataset(Dataset):
         remove_useless_info(self.coco)
         self.ids = self.coco.getImgIds()
         self.class_ids = sorted(self.coco.getCatIds())
-        cats = self.coco.loadCats(self.coco.getCatIds())
-        self._classes = tuple([c["name"] for c in cats])
+        self.cats = self.coco.loadCats(self.coco.getCatIds())
+        self._classes = tuple([c["name"] for c in self.cats])
         self.imgs = None
         self.name = name
         self.img_size = img_size
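The local `cats` becomes the public attribute `self.cats` because `WandbLogger` now reads it to build per-class table columns and an id-to-name lookup for box labels. The records returned by pycocotools' `loadCats` look like this (values here are illustrative):

```python
# Hedged illustration of the COCO category records and the id -> name lookup
# WandbLogger derives from them for wandb.Image box overlays.
cats = [
    {"id": 1, "name": "person", "supercategory": "person"},
    {"id": 2, "name": "bicycle", "supercategory": "vehicle"},
]
id_to_class = {c["id"]: c["name"] for c in cats}
assert id_to_class[1] == "person"
```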
yolox/evaluators/coco_evaluator.py CHANGED

@@ -8,6 +8,7 @@ import itertools
 import json
 import tempfile
 import time
+from collections import ChainMap, defaultdict
 from loguru import logger
 from tabulate import tabulate
 from tqdm import tqdm
 
@@ -120,6 +121,7 @@ class COCOEvaluator:
         trt_file=None,
         decoder=None,
         test_size=None,
+        return_outputs=False
     ):
         """
         COCO average precision (AP) Evaluation. Iterate inference on the test dataset
 
@@ -142,6 +144,7 @@ class COCOEvaluator:
             model = model.half()
         ids = []
         data_list = []
+        output_data = defaultdict()
         progress_bar = tqdm if is_main_process() else iter
 
         inference_time = 0
 
@@ -184,20 +187,29 @@ class COCOEvaluator:
             nms_end = time_synchronized()
             nms_time += nms_end - infer_end
 
-            data_list.extend(self.convert_to_coco_format(outputs, info_imgs, ids))
+            data_list_elem, image_wise_data = self.convert_to_coco_format(
+                outputs, info_imgs, ids, return_outputs=True)
+            data_list.extend(data_list_elem)
+            output_data.update(image_wise_data)
 
         statistics = torch.cuda.FloatTensor([inference_time, nms_time, n_samples])
         if distributed:
             data_list = gather(data_list, dst=0)
+            output_data = gather(output_data, dst=0)
             data_list = list(itertools.chain(*data_list))
+            output_data = dict(ChainMap(*output_data))
             torch.distributed.reduce(statistics, dst=0)
 
         eval_results = self.evaluate_prediction(data_list, statistics)
         synchronize()
+
+        if return_outputs:
+            return eval_results, output_data
         return eval_results
 
-    def convert_to_coco_format(self, outputs, info_imgs, ids):
+    def convert_to_coco_format(self, outputs, info_imgs, ids, return_outputs=False):
         data_list = []
+        image_wise_data = defaultdict(dict)
         for (output, img_h, img_w, img_id) in zip(
             outputs, info_imgs[0], info_imgs[1], ids
         ):
 
@@ -212,10 +224,22 @@ class COCOEvaluator:
                 self.img_size[0] / float(img_h), self.img_size[1] / float(img_w)
             )
             bboxes /= scale
-            bboxes = xyxy2xywh(bboxes)
-
             cls = output[:, 6]
             scores = output[:, 4] * output[:, 5]
+
+            image_wise_data.update({
+                int(img_id): {
+                    "bboxes": [box.numpy().tolist() for box in bboxes],
+                    "scores": [score.numpy().item() for score in scores],
+                    "categories": [
+                        self.dataloader.dataset.class_ids[int(cls[ind])]
+                        for ind in range(bboxes.shape[0])
+                    ],
+                }
+            })
+
+            bboxes = xyxy2xywh(bboxes)
+
             for ind in range(bboxes.shape[0]):
                 label = self.dataloader.dataset.class_ids[int(cls[ind])]
                 pred_data = {
 
@@ -226,6 +250,9 @@ class COCOEvaluator:
                     "segmentation": [],
                 }  # COCO json format
                 data_list.append(pred_data)
+
+        if return_outputs:
+            return data_list, image_wise_data
         return data_list
 
     def evaluate_prediction(self, data_dict, statistics):
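For reference, the per-image payload `convert_to_coco_format` now builds has this shape (values made up). Note the boxes are captured in xyxy pixel coordinates before the `xyxy2xywh` conversion, which is the form `wandb.Image` box overlays expect:

```python
# Sketch of the image_wise_data structure returned alongside the COCO-format
# list; keys are image ids, values are the raw per-image detections.
image_wise_data = {
    42: {
        "bboxes": [[10.0, 20.0, 110.0, 220.0]],  # [x0, y0, x1, y1] in pixels
        "scores": [0.87],                        # obj_conf * cls_conf
        "categories": [1],                       # COCO category ids
    }
}
```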
yolox/exp/yolox_base.py CHANGED

@@ -318,5 +318,5 @@ class Exp(BaseExp):
         # NOTE: trainer shouldn't be an attribute of exp object
         return trainer
 
-    def eval(self, model, evaluator, is_distributed, half=False):
-        return evaluator.evaluate(model, is_distributed, half)
+    def eval(self, model, evaluator, is_distributed, half=False, return_outputs=False):
+        return evaluator.evaluate(model, is_distributed, half, return_outputs=return_outputs)
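The extra keyword keeps old call sites working while letting the trainer ask for per-image outputs. A stand-in sketch of the dual return contract (not YOLOX code; the stub values are hypothetical):

```python
# Minimal stand-in showing the contract Exp.eval now forwards to the evaluator:
# a 3-tuple by default, or (3-tuple, per-image predictions) on request.
class StubEvaluator:
    def evaluate(self, model, is_distributed, half=False, return_outputs=False):
        eval_results = (0.5, 0.7, "summary")  # (ap50_95, ap50, summary)
        output_data = {42: {"bboxes": [], "scores": [], "categories": []}}
        return (eval_results, output_data) if return_outputs else eval_results

def exp_eval(model, evaluator, is_distributed, half=False, return_outputs=False):
    return evaluator.evaluate(model, is_distributed, half, return_outputs=return_outputs)

ap50_95, ap50, summary = exp_eval(None, StubEvaluator(), False)  # unchanged callers
(ap50_95, ap50, summary), predictions = exp_eval(
    None, StubEvaluator(), False, return_outputs=True            # new trainer path
)
```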
yolox/utils/logger.py CHANGED

@@ -5,8 +5,12 @@
 import inspect
 import os
 import sys
+from collections import defaultdict
 from loguru import logger
 
+import cv2
+import numpy as np
+
 import torch
 
 
@@ -108,6 +112,7 @@ class WandbLogger(object):
 
     For more information, please refer to:
     https://docs.wandb.ai/guides/track
+    https://docs.wandb.ai/guides/integrations/other/yolox
     """
     def __init__(self,
                  project=None,
 
@@ -116,6 +121,9 @@ class WandbLogger(object):
                  entity=None,
                  save_dir=None,
                  config=None,
+                 val_dataset=None,
+                 num_eval_images=100,
+                 log_checkpoints=False,
                  **kwargs):
         """
         Args:
 
@@ -125,7 +133,24 @@ class WandbLogger(object):
             entity (str): wandb entity name.
             save_dir (str): save directory.
             config (dict): config dict.
+            val_dataset (Dataset): validation dataset.
+            num_eval_images (int): number of images from the validation set to log.
+            log_checkpoints (bool): whether to log checkpoints.
             **kwargs: other kwargs.
+
+        Usage:
+            Any arguments for wandb.init can be provided on the command line using
+            the prefix `wandb-`.
+            Example
+            ```
+            python tools/train.py .... --logger wandb wandb-project <project-name> \
+                wandb-name <run-name> \
+                wandb-id <run-id> \
+                wandb-save_dir <save-dir> \
+                wandb-num_eval_images <num-images> \
+                wandb-log_checkpoints <bool>
+            ```
+            The val_dataset argument is not open to the command line.
         """
         try:
             import wandb
 
@@ -144,6 +169,12 @@ class WandbLogger(object):
         self.kwargs = kwargs
         self.entity = entity
         self._run = None
+        self.val_artifact = None
+        if num_eval_images == -1:
+            self.num_log_images = len(val_dataset)
+        else:
+            self.num_log_images = min(num_eval_images, len(val_dataset))
+        self.log_checkpoints = (log_checkpoints == "True" or log_checkpoints == "true")
         self._wandb_init = dict(
             project=self.project,
             name=self.name,
 
@@ -158,8 +189,17 @@ class WandbLogger(object):
 
         if self.config:
             self.run.config.update(self.config)
-        self.run.define_metric("epoch")
-        self.run.define_metric("val/*", step_metric="epoch")
+        self.run.define_metric("train/epoch")
+        self.run.define_metric("val/*", step_metric="train/epoch")
+        self.run.define_metric("train/step")
+        self.run.define_metric("train/*", step_metric="train/step")
+
+        if val_dataset and self.num_log_images != 0:
+            self.cats = val_dataset.cats
+            self.id_to_class = {
+                cls['id']: cls['name'] for cls in self.cats
+            }
+            self._log_validation_set(val_dataset)
 
     @property
     def run(self):
 
@@ -176,6 +216,32 @@ class WandbLogger(object):
             self._run = self.wandb.init(**self._wandb_init)
         return self._run
 
+    def _log_validation_set(self, val_dataset):
+        """
+        Log validation set to wandb.
+
+        Args:
+            val_dataset (Dataset): validation dataset.
+        """
+        if self.val_artifact is None:
+            self.val_artifact = self.wandb.Artifact(name="validation_images", type="dataset")
+            self.val_table = self.wandb.Table(columns=["id", "input"])
+
+            for i in range(self.num_log_images):
+                data_point = val_dataset[i]
+                img = data_point[0]
+                id = data_point[3]
+                img = np.transpose(img, (1, 2, 0))
+                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+                self.val_table.add_data(
+                    id.item(),
+                    self.wandb.Image(img)
+                )
+
+            self.val_artifact.add(self.val_table, "validation_images_table")
+            self.run.use_artifact(self.val_artifact)
+            self.val_artifact.wait()
+
     def log_metrics(self, metrics, step=None):
         """
         Args:
 
@@ -188,21 +254,98 @@ class WandbLogger(object):
                 metrics[k] = v.item()
 
         if step is not None:
-            self.run.log(metrics, step=step)
+            metrics.update({"train/step": step})
+            self.run.log(metrics)
         else:
             self.run.log(metrics)
 
-    def save_checkpoint(self, save_dir, model_name, is_best):
+    def log_images(self, predictions):
+        if len(predictions) == 0 or self.val_artifact is None or self.num_log_images == 0:
+            return
+
+        table_ref = self.val_artifact.get("validation_images_table")
+
+        columns = ["id", "predicted"]
+        for cls in self.cats:
+            columns.append(cls["name"])
+
+        result_table = self.wandb.Table(columns=columns)
+        for idx, val in table_ref.iterrows():
+
+            avg_scores = defaultdict(int)
+            num_occurrences = defaultdict(int)
+
+            if val[0] in predictions:
+                prediction = predictions[val[0]]
+                boxes = []
+
+                for i in range(len(prediction["bboxes"])):
+                    bbox = prediction["bboxes"][i]
+                    x0 = bbox[0]
+                    y0 = bbox[1]
+                    x1 = bbox[2]
+                    y1 = bbox[3]
+                    box = {
+                        "position": {
+                            "minX": min(x0, x1),
+                            "minY": min(y0, y1),
+                            "maxX": max(x0, x1),
+                            "maxY": max(y0, y1)
+                        },
+                        "class_id": prediction["categories"][i],
+                        "domain": "pixel"
+                    }
+                    avg_scores[
+                        self.id_to_class[prediction["categories"][i]]
+                    ] += prediction["scores"][i]
+                    num_occurrences[self.id_to_class[prediction["categories"][i]]] += 1
+                    boxes.append(box)
+            else:
+                boxes = []
+
+            average_class_score = []
+            for cls in self.cats:
+                if cls["name"] not in num_occurrences:
+                    score = 0
+                else:
+                    score = avg_scores[cls["name"]] / num_occurrences[cls["name"]]
+                average_class_score.append(score)
+            result_table.add_data(
+                idx,
+                self.wandb.Image(val[1], boxes={
+                    "prediction": {
+                        "box_data": boxes,
+                        "class_labels": self.id_to_class
+                    }
+                }
+                ),
+                *average_class_score
+            )
+
+        self.wandb.log({"val_results/result_table": result_table})
+
+    def save_checkpoint(self, save_dir, model_name, is_best, metadata=None):
         """
         Args:
             save_dir (str): save directory.
             model_name (str): model name.
             is_best (bool): whether the model is the best model.
+            metadata (dict): metadata to save corresponding to the checkpoint.
         """
+
+        if not self.log_checkpoints:
+            return
+
+        if "epoch" in metadata:
+            epoch = metadata["epoch"]
+        else:
+            epoch = None
+
         filename = os.path.join(save_dir, model_name + "_ckpt.pth")
         artifact = self.wandb.Artifact(
-            name=f"model-{self.run.id}",
-            type="model"
+            name=f"run_{self.run.id}_model",
+            type="model",
+            metadata=metadata
         )
         artifact.add_file(filename, name="model_ckpt.pth")
 
@@ -211,7 +354,23 @@ class WandbLogger(object):
         if is_best:
             aliases.append("best")
 
+        if epoch:
+            aliases.append(f"epoch-{epoch}")
+
         self.run.log_artifact(artifact, aliases=aliases)
 
     def finish(self):
         self.run.finish()
+
+    @classmethod
+    def initialize_wandb_logger(cls, args, exp, val_dataset):
+        wandb_params = dict()
+        prefix = "wandb-"
+        for k, v in zip(args.opts[0::2], args.opts[1::2]):
+            if k.startswith("wandb-"):
+                try:
+                    wandb_params.update({k[len(prefix):]: int(v)})
+                except ValueError:
+                    wandb_params.update({k[len(prefix):]: v})
+
+        return cls(config=vars(exp), val_dataset=val_dataset, **wandb_params)
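The core technique this file implements, reduced to a self-contained sketch: a `wandb.Table` whose image cells carry prediction box overlays, logged offline with dummy data (project name hypothetical):

```python
# Sketch (not the repository's code) of logging predictions to a wandb table
# with bounding-box overlays, the mechanism behind log_images above.
import numpy as np
import wandb

run = wandb.init(project="yolox-demo", mode="offline")  # hypothetical project
id_to_class = {1: "person"}                             # dummy class lookup
img = np.zeros((416, 416, 3), dtype=np.uint8)           # dummy validation image

boxes = [{
    "position": {"minX": 10, "minY": 20, "maxX": 110, "maxY": 220},
    "class_id": 1,
    "domain": "pixel",  # coordinates are pixels, not 0-1 fractions
}]
table = wandb.Table(columns=["id", "predicted"])
table.add_data(42, wandb.Image(img, boxes={
    "prediction": {"box_data": boxes, "class_labels": id_to_class}
}))
run.log({"val_results/result_table": table})
run.finish()
```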