Commit 8efa55f · Feng Wang committed · Parent: 5c838b8

feat(tools): add assignment visualizer (#1616)

README.md CHANGED

@@ -10,7 +10,8 @@ This repo is an implementation of PyTorch version YOLOX, there is also a [MegEng
 <img src="assets/git_fig.png" width="1000" >
 
 ## Updates!!
-* 【2022/04/14】 We suport jit compile op.
+* 【2023/02/28】 We support assignment visualization tool, see doc [here](./docs/assignment_visualization.md).
+* 【2022/04/14】 We support jit compile op.
 * 【2021/08/19】 We optimize the training process with **2x** faster training and **~1%** higher performance! See [notes](docs/updates_note.md) for more details.
 * 【2021/08/05】 We release [MegEngine version YOLOX](https://github.com/MegEngine/YOLOX).
 * 【2021/07/28】 We fix the fatal error of [memory leak](https://github.com/Megvii-BaseDetection/YOLOX/issues/103)

@@ -206,6 +207,7 @@ python -m yolox.tools.eval -n yolox-s -c yolox_s.pth -b 1 -d 1 --conf 0.001 --f
 * [Training on custom data](docs/train_custom_data.md)
 * [Caching for custom data](docs/cache.md)
 * [Manipulating training image size](docs/manipulate_training_image_size.md)
+* [Assignment visualization](docs/assignment_visualization.md)
 * [Freezing model](docs/freeze_module.md)
 
 </details>

@@ -243,8 +245,8 @@ If you use YOLOX in your research, please cite our work by using the following B
 }
 ```
 ## In memory of Dr. Jian Sun
-Without the guidance of [Dr. Sun Jian](http://www.jiansun.org/), YOLOX would not have been released and open sourced to the community.
-The passing away of Dr. Sun Jian is a great loss to the Computer Vision field. We have added this section here to express our remembrance and condolences to our captain Dr. Sun.
+Without the guidance of [Dr. Jian Sun](http://www.jiansun.org/), YOLOX would not have been released and open sourced to the community.
+The passing away of Dr. Jian is a huge loss to the Computer Vision field. We add this section here to express our remembrance and condolences to our captain Dr. Jian.
 It is hoped that every AI practitioner in the world will stick to the concept of "continuous innovation to expand cognitive boundaries, and extraordinary technology to achieve product value" and move forward all the way.
 
 <div align="center"><img src="assets/sunjian.png" width="200"></div>
docs/assignment_visualization.md ADDED

@@ -0,0 +1,29 @@
+# Visualize label assignment
+
+This tutorial explains how to visualize your label assignment results when training with YOLOX.
+
+## 1. Visualization command
+
+We provide a visualization tool to help you inspect your label assignment results. You can find it in [`tools/visualize_assign.py`](../tools/visualize_assign.py).
+
+Here is an example command to visualize your label assignment results:
+
+```shell
+python3 tools/visualize_assign.py -f /path/to/your/exp.py yolox-s -d 1 -b 8 --max-batch 2
+```
+
+`--max-batch` is the maximum number of batches to visualize. The default value is 1, which means the tool only visualizes the first batch.
+
+Note that mosaic augmentation is used by the default dataloader, so you will also see the mosaic results in the visualization.
+
+After running the command, the logger will show where the visualization results are saved. Open them and move on to step 2.
+
+## 2. Check the visualization result
+
+Here is an example visualization result:
+<div align="center"><img src="../assets/assignment.png" width="640"></div>
+
+The dots inside a box are the anchors matched to that gt box. **The dots share the color of their box**, to help you determine which object each anchor is assigned to. Note that the boxes and dots are an **instance-level** visualization, so objects of the same class may have different colors.
+**If a gt box doesn't match any anchor, the box is drawn in red and the red text "unmatched" is drawn above it**.
+
+Please feel free to open an issue if you have any questions.
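The dot positions come from the detection head's grid: each anchor center is its grid cell index offset by half a cell, scaled by the stride of its FPN level (this mirrors the `(x_shifts + 0.5) * expanded_strides` expression in `visualize_assign_result` below). A minimal sketch of that computation; the grid size and stride here are made-up values, not taken from the commit:

```python
import torch

stride = 8                                       # stride of one FPN level
x_shifts = torch.arange(4).repeat(4)             # grid x indices of a 4x4 map
y_shifts = torch.arange(4).repeat_interleave(4)  # grid y indices
centers_x = (x_shifts + 0.5) * stride            # pixel x of each anchor center
centers_y = (y_shifts + 0.5) * stride            # pixel y of each anchor center
print(centers_x[:4])  # tensor([ 4., 12., 20., 28.])
```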
tools/visualize_assign.py ADDED

@@ -0,0 +1,93 @@
+#!/usr/bin/env python3
+# Copyright (c) Megvii, Inc. and its affiliates.
+
+import os
+import sys
+import random
+import time
+import warnings
+from loguru import logger
+
+import torch
+import torch.backends.cudnn as cudnn
+
+from yolox.exp import Exp, get_exp
+from yolox.core import Trainer
+from yolox.utils import configure_module, configure_omp
+from yolox.tools.train import make_parser
+
+
+class AssignVisualizer(Trainer):
+
+    def __init__(self, exp: Exp, args):
+        super().__init__(exp, args)
+        self.batch_cnt = 0
+        self.vis_dir = os.path.join(self.file_name, "vis")
+        os.makedirs(self.vis_dir, exist_ok=True)
+
+    def train_one_iter(self):
+        iter_start_time = time.time()
+
+        inps, targets = self.prefetcher.next()
+        inps = inps.to(self.data_type)
+        targets = targets.to(self.data_type)
+        targets.requires_grad = False
+        inps, targets = self.exp.preprocess(inps, targets, self.input_size)
+        data_end_time = time.time()
+
+        with torch.cuda.amp.autocast(enabled=self.amp_training):
+            path_prefix = os.path.join(self.vis_dir, f"assign_vis_{self.batch_cnt}_")
+            self.model.visualize(inps, targets, path_prefix)
+
+        if self.use_model_ema:
+            self.ema_model.update(self.model)
+
+        iter_end_time = time.time()
+        self.meter.update(
+            iter_time=iter_end_time - iter_start_time,
+            data_time=data_end_time - iter_start_time,
+        )
+        self.batch_cnt += 1
+        if self.batch_cnt >= self.args.max_batch:
+            sys.exit(0)
+
+    def after_train(self):
+        logger.info("Finished visualizing assignment, exit...")
+
+
+def assign_vis_parser():
+    parser = make_parser()
+    parser.add_argument("--max-batch", type=int, default=1, help="max batch of images to visualize")
+    return parser
+
+
+@logger.catch
+def main(exp: Exp, args):
+    if exp.seed is not None:
+        random.seed(exp.seed)
+        torch.manual_seed(exp.seed)
+        cudnn.deterministic = True
+        warnings.warn(
+            "You have chosen to seed training. This will turn on the CUDNN deterministic setting, "
+            "which can slow down your training considerably! You may see unexpected behavior "
+            "when restarting from checkpoints."
+        )
+
+    # set environment variables for distributed training
+    configure_omp()
+    cudnn.benchmark = True
+
+    visualizer = AssignVisualizer(exp, args)
+    visualizer.train()
+
+
+if __name__ == "__main__":
+    configure_module()
+    args = assign_vis_parser().parse_args()
+    exp = get_exp(args.exp_file, args.name)
+    exp.merge(args.opts)
+
+    if not args.experiment_name:
+        args.experiment_name = exp.exp_name
+
+    main(exp, args)
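The script drives a full `Trainer` setup only to reach `train_one_iter`, then exits once `--max-batch` batches have been drawn; `after_train` still runs because `Trainer.train()` invokes it in a `finally` block. A hypothetical programmatic invocation, assuming the repo root is on `sys.path` (the CLI shown in the docs above is the supported entry point):

```python
from yolox.exp import get_exp
from tools.visualize_assign import AssignVisualizer, assign_vis_parser

# hypothetical: feed the same flags the CLI would receive
args = assign_vis_parser().parse_args(
    ["-n", "yolox-s", "-d", "1", "-b", "8", "--max-batch", "2"]
)
exp = get_exp(args.exp_file, args.name)
args.experiment_name = args.experiment_name or exp.exp_name

visualizer = AssignVisualizer(exp, args)
visualizer.train()  # exits via sys.exit(0) after max_batch batches
```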
yolox/core/trainer.py CHANGED

@@ -1,5 +1,4 @@
 #!/usr/bin/env python3
-# -*- coding:utf-8 -*-
 # Copyright (c) Megvii, Inc. and its affiliates.
 
 import datetime
yolox/models/yolo_head.py CHANGED

@@ -9,7 +9,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from yolox.utils import bboxes_iou, meshgrid
+from yolox.utils import bboxes_iou, cxcywh2xyxy, meshgrid, visualize_assign
 
 from .losses import IOUloss
 from .network_blocks import BaseConv, DWConv

@@ -511,11 +511,7 @@ class YOLOXHead(nn.Module):
         )
 
     def get_geometry_constraint(
-        self,
-        gt_bboxes_per_image,
-        expanded_strides,
-        x_shifts,
-        y_shifts,
+        self, gt_bboxes_per_image, expanded_strides, x_shifts, y_shifts,
     ):
         """
         Calculate whether the center of an object is located in a fixed range of

@@ -546,8 +542,6 @@ class YOLOXHead(nn.Module):
         return anchor_filter, geometry_relation
 
     def simota_matching(self, cost, pair_wise_ious, gt_classes, num_gt, fg_mask):
-        # Dynamic K
-        # ---------------------------------------------------------------
         matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
 
         n_candidate_k = min(10, pair_wise_ious.size(1))

@@ -580,3 +574,68 @@ class YOLOXHead(nn.Module):
             fg_mask_inboxes
         ]
         return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
+
+    def visualize_assign_result(self, xin, labels=None, imgs=None, save_prefix="assign_vis_"):
+        # original forward logic
+        outputs, x_shifts, y_shifts, expanded_strides = [], [], [], []
+        # TODO: use forward logic here.
+
+        for k, (cls_conv, reg_conv, stride_this_level, x) in enumerate(
+            zip(self.cls_convs, self.reg_convs, self.strides, xin)
+        ):
+            x = self.stems[k](x)
+            cls_x = x
+            reg_x = x
+
+            cls_feat = cls_conv(cls_x)
+            cls_output = self.cls_preds[k](cls_feat)
+            reg_feat = reg_conv(reg_x)
+            reg_output = self.reg_preds[k](reg_feat)
+            obj_output = self.obj_preds[k](reg_feat)
+
+            output = torch.cat([reg_output, obj_output, cls_output], 1)
+            output, grid = self.get_output_and_grid(output, k, stride_this_level, xin[0].type())
+            x_shifts.append(grid[:, :, 0])
+            y_shifts.append(grid[:, :, 1])
+            expanded_strides.append(
+                torch.full((1, grid.shape[1]), stride_this_level).type_as(xin[0])
+            )
+            outputs.append(output)
+
+        outputs = torch.cat(outputs, 1)
+        bbox_preds = outputs[:, :, :4]  # [batch, n_anchors_all, 4]
+        obj_preds = outputs[:, :, 4:5]  # [batch, n_anchors_all, 1]
+        cls_preds = outputs[:, :, 5:]  # [batch, n_anchors_all, n_cls]
+
+        # calculate targets
+        total_num_anchors = outputs.shape[1]
+        x_shifts = torch.cat(x_shifts, 1)  # [1, n_anchors_all]
+        y_shifts = torch.cat(y_shifts, 1)  # [1, n_anchors_all]
+        expanded_strides = torch.cat(expanded_strides, 1)
+
+        nlabel = (labels.sum(dim=2) > 0).sum(dim=1)  # number of objects
+        for batch_idx, (img, num_gt, label) in enumerate(zip(imgs, nlabel, labels)):
+            img = imgs[batch_idx].permute(1, 2, 0).to(torch.uint8)
+            num_gt = int(num_gt)
+            if num_gt == 0:
+                fg_mask = outputs.new_zeros(total_num_anchors).bool()
+            else:
+                gt_bboxes_per_image = label[:num_gt, 1:5]
+                gt_classes = label[:num_gt, 0]
+                bboxes_preds_per_image = bbox_preds[batch_idx]
+                _, fg_mask, _, matched_gt_inds, _ = self.get_assignments(  # noqa
+                    batch_idx, num_gt, gt_bboxes_per_image, gt_classes,
+                    bboxes_preds_per_image, expanded_strides, x_shifts,
+                    y_shifts, cls_preds, obj_preds,
+                )
+
+            img = img.cpu().numpy().copy()  # copy is crucial here
+            coords = torch.stack([
+                ((x_shifts + 0.5) * expanded_strides).flatten()[fg_mask],
+                ((y_shifts + 0.5) * expanded_strides).flatten()[fg_mask],
+            ], 1)
+
+            xyxy_boxes = cxcywh2xyxy(gt_bboxes_per_image)
+            save_name = save_prefix + str(batch_idx) + ".png"
+            img = visualize_assign(img, xyxy_boxes, coords, matched_gt_inds, save_name)
+            logger.info(f"save img to {save_name}")
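One detail worth calling out in `visualize_assign_result` above: the per-image object count is derived from the zero-padded label tensor, where any row with a non-zero sum is a real object. A small self-contained check of that expression (the tensor values are made up):

```python
import torch

# labels are zero-padded to [batch, max_objects, 5] with rows (cls, cx, cy, w, h)
labels = torch.zeros(2, 3, 5)
labels[0, 0] = torch.tensor([1., 50., 50., 20., 10.])  # one real object in image 0
nlabel = (labels.sum(dim=2) > 0).sum(dim=1)            # same expression as in the method
print(nlabel)  # tensor([1, 0])
```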
yolox/models/yolox.py CHANGED

@@ -46,3 +46,7 @@ class YOLOX(nn.Module):
         outputs = self.head(fpn_outs)
 
         return outputs
+
+    def visualize(self, x, targets, save_prefix="assign_vis_"):
+        fpn_outs = self.backbone(x)
+        self.head.visualize_assign_result(fpn_outs, targets, x, save_prefix)
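The model-level hook just runs the backbone and delegates to the head. A minimal sketch of calling it directly, assuming inputs in the YOLOX dataloader format (labels zero-padded to `[batch, max_objects, 5]`); note that, as written, `visualize_assign_result` leaves `gt_bboxes_per_image` undefined in the `num_gt == 0` branch, so each image should contain at least one gt box:

```python
# minimal sketch, not the supported CLI path
import torch
from yolox.exp import get_exp

exp = get_exp(None, "yolox-s")
model = exp.get_model().eval()              # randomly initialized weights

imgs = torch.zeros(1, 3, 640, 640)          # dummy input batch
targets = torch.zeros(1, 50, 5)             # zero-padded (cls, cx, cy, w, h) rows
targets[0, 0] = torch.tensor([0., 320., 320., 100., 100.])  # one gt box

model.visualize(imgs, targets, save_prefix="assign_vis_")
```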
yolox/utils/__init__.py CHANGED

@@ -1,5 +1,4 @@
 #!/usr/bin/env python3
-# -*- coding:utf-8 -*-
 # Copyright (c) Megvii Inc. All rights reserved.
 
 from .allreduce_norm import *
yolox/utils/boxes.py CHANGED

@@ -1,5 +1,4 @@
 #!/usr/bin/env python3
-# -*- coding:utf-8 -*-
 # Copyright (c) Megvii Inc. All rights reserved.
 
 import numpy as np

@@ -15,6 +14,7 @@ __all__ = [
     "adjust_box_anns",
     "xyxy2xywh",
     "xyxy2cxcywh",
+    "cxcywh2xyxy",
 ]
 
 

@@ -133,3 +133,11 @@ def xyxy2cxcywh(bboxes):
     bboxes[:, 0] = bboxes[:, 0] + bboxes[:, 2] * 0.5
     bboxes[:, 1] = bboxes[:, 1] + bboxes[:, 3] * 0.5
     return bboxes
+
+
+def cxcywh2xyxy(bboxes):
+    bboxes[:, 0] = bboxes[:, 0] - bboxes[:, 2] * 0.5
+    bboxes[:, 1] = bboxes[:, 1] - bboxes[:, 3] * 0.5
+    bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2]
+    bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3]
+    return bboxes
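A quick sanity check of the new converter, assuming `yolox.utils` re-exports it as the `__all__` addition suggests. Note that, like `xyxy2cxcywh`, it mutates its argument in place, so pass a clone if the original boxes are still needed; the values below are made up:

```python
import torch
from yolox.utils import cxcywh2xyxy

boxes = torch.tensor([[50., 50., 20., 10.]])  # (cx, cy, w, h)
xyxy = cxcywh2xyxy(boxes.clone())             # clone: conversion is in-place
print(xyxy)                                   # tensor([[40., 45., 60., 55.]])
```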
yolox/utils/demo_utils.py CHANGED

@@ -2,10 +2,51 @@
 # Copyright (c) Megvii Inc. All rights reserved.
 
 import os
+import random
 
+import cv2
 import numpy as np
 
-__all__ = ["mkdir", "nms", "multiclass_nms", "demo_postprocess"]
+__all__ = [
+    "mkdir", "nms", "multiclass_nms", "demo_postprocess", "random_color", "visualize_assign"
+]
+
+
+def random_color():
+    return random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)
+
+
+def visualize_assign(img, boxes, coords, match_results, save_name=None) -> np.ndarray:
+    """Visualize the label assignment result.
+
+    Args:
+        img: image to visualize on
+        boxes: gt boxes in xyxy format
+        coords: coordinates of the matched anchors
+        match_results: matched gt index of each anchor coordinate
+        save_name: filename to save the image to; if None, the image is not saved. Default: None.
+    """
+    for box_id, box in enumerate(boxes):
+        x1, y1, x2, y2 = box
+        color = random_color()
+        assign_coords = coords[match_results == box_id]
+        if assign_coords.numel() == 0:
+            # unmatched boxes are red
+            color = (0, 0, 255)
+            cv2.putText(
+                img, "unmatched", (int(x1), int(y1) - 5),
+                cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 1
+            )
+        else:
+            for coord in assign_coords:
+                # draw assigned anchor
+                cv2.circle(img, (int(coord[0]), int(coord[1])), 3, color, -1)
+        cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
+
+    if save_name is not None:
+        cv2.imwrite(save_name, img)
+
+    return img
 
 
 def mkdir(path):
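A self-contained smoke test of the new helper on dummy data; the array and tensor values are invented for illustration. `match_results` holds the gt index each anchor center was assigned to, so gt 1 below gets no anchors and is drawn in red with the "unmatched" label:

```python
import numpy as np
import torch
from yolox.utils import visualize_assign

img = np.zeros((480, 640, 3), dtype=np.uint8)        # blank BGR canvas
boxes = torch.tensor([[100., 100., 200., 200.],      # gt 0
                      [300.,  50., 400., 150.]])     # gt 1: no anchors -> red
coords = torch.tensor([[120., 130.], [150., 160.]])  # matched anchor centers
match_results = torch.tensor([0, 0])                 # both anchors belong to gt 0

out = visualize_assign(img, boxes, coords, match_results, save_name="demo_assign.png")
```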