henry000 committed on
Commit
d5a73bd
·
1 Parent(s): d13852b

🔨 [Add] Anc2Box converter, for YOLOv7's output

Browse files
yolo/__init__.py CHANGED
@@ -3,7 +3,7 @@ from yolo.model.yolo import create_model
3
  from yolo.tools.data_loader import AugmentationComposer, create_dataloader
4
  from yolo.tools.drawer import draw_bboxes
5
  from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
6
- from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms
7
  from yolo.utils.deploy_utils import FastModelLoader
8
  from yolo.utils.logging_utils import custom_logger
9
  from yolo.utils.model_utils import PostProccess
@@ -16,6 +16,7 @@ all = [
16
  "validate_log_directory",
17
  "draw_bboxes",
18
  "Vec2Box",
 
19
  "bbox_nms",
20
  "AugmentationComposer",
21
  "create_dataloader",
 
3
  from yolo.tools.data_loader import AugmentationComposer, create_dataloader
4
  from yolo.tools.drawer import draw_bboxes
5
  from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
6
+ from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms
7
  from yolo.utils.deploy_utils import FastModelLoader
8
  from yolo.utils.logging_utils import custom_logger
9
  from yolo.utils.model_utils import PostProccess
 
16
  "validate_log_directory",
17
  "draw_bboxes",
18
  "Vec2Box",
19
+ "Anc2Box",
20
  "bbox_nms",
21
  "AugmentationComposer",
22
  "create_dataloader",
yolo/config/config.py CHANGED
@@ -6,8 +6,10 @@ from torch import nn
6
 
7
  @dataclass
8
  class AnchorConfig:
9
- reg_max: int
10
  strides: List[int]
 
 
 
11
 
12
 
13
  @dataclass
 
6
 
7
  @dataclass
8
  class AnchorConfig:
 
9
  strides: List[int]
10
+ reg_max: Optional[int]
11
+ anchor_num: Optional[int]
12
+ anchor: List[List[int]]
13
 
14
 
15
  @dataclass
yolo/model/yolo.py CHANGED
@@ -26,7 +26,6 @@ class YOLO(nn.Module):
26
  self.layer_map = get_layer_map() # Get the map Dict[str: Module]
27
  self.model: List[YOLOLayer] = nn.ModuleList()
28
  self.reg_max = getattr(model_cfg.anchor, "reg_max", 16)
29
- self.strides = getattr(model_cfg.anchor, "strides", None)
30
  self.build_model(model_cfg.model)
31
 
32
  def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
 
26
  self.layer_map = get_layer_map() # Get the map Dict[str: Module]
27
  self.model: List[YOLOLayer] = nn.ModuleList()
28
  self.reg_max = getattr(model_cfg.anchor, "reg_max", 16)
 
29
  self.build_model(model_cfg.model)
30
 
31
  def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
yolo/utils/bounding_box_utils.py CHANGED
@@ -1,14 +1,14 @@
1
  import math
2
- from typing import Dict, List, Tuple
3
 
4
  import torch
5
  import torch.nn.functional as F
6
  from einops import rearrange
7
  from loguru import logger
8
- from torch import Tensor, arange
9
  from torchvision.ops import batched_nms
10
 
11
- from yolo.config.config import MatcherConfig, ModelConfig, NMSConfig
12
  from yolo.model.yolo import YOLO
13
 
14
 
@@ -308,9 +308,64 @@ class Vec2Box:
308
  return preds_cls, preds_anc, preds_box
309
 
310
 
311
- def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig):
312
- # TODO change function to class or set 80 to class_num instead of a number
313
- cls_dist = cls_dist.sigmoid()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
314
 
315
  # filter class by confidence
316
  cls_val, cls_idx = cls_dist.max(dim=-1, keepdim=True)
 
1
  import math
2
+ from typing import Dict, List, Optional, Tuple
3
 
4
  import torch
5
  import torch.nn.functional as F
6
  from einops import rearrange
7
  from loguru import logger
8
+ from torch import Tensor, arange, tensor
9
  from torchvision.ops import batched_nms
10
 
11
+ from yolo.config.config import AnchorConfig, MatcherConfig, ModelConfig, NMSConfig
12
  from yolo.model.yolo import YOLO
13
 
14
 
 
308
  return preds_cls, preds_anc, preds_box
309
 
310
 
311
class Anc2Box:
    """Decode anchor-based detection heads into boxes, class scores and objectness.

    Each head tensor is laid out as ``B (L C) h w`` where ``L`` is the number of
    anchors per cell and ``C`` splits into 4 box values, 1 objectness value and
    ``class_num`` class values. Calling the instance returns
    ``(preds_cls, None, preds_box, preds_cnf)``; the ``None`` keeps the first
    three positions aligned with the 3-tuple that ``Vec2Box`` returns.
    """

    def __init__(self, model: YOLO, anchor_cfg: AnchorConfig, image_size, device):
        """Prepare per-head cell grids and anchor scales.

        Args:
            model: detection model; only called when strides must be inferred,
                and read for ``num_classes``.
            anchor_cfg: anchor settings providing ``anchor`` (one list of anchor
                sizes per head) and, optionally, ``strides``.
            image_size: two-element input size used to build the grids.
            device: device on which grids and anchor scales are placed.
        """
        self.device = device

        # getattr with a default covers both a missing attribute and an explicit
        # ``None``; hasattr alone would accept ``strides=None`` and crash later.
        strides = getattr(anchor_cfg, "strides", None)
        if strides is not None:
            logger.info(f"🈶 Found stride of model {anchor_cfg.strides}")
            self.strides = strides
        else:
            logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size")
            self.strides = self.create_auto_anchor(model, image_size)

        self.generate_anchors(image_size)
        self.anchor_grid = [anchor_grid.to(device) for anchor_grid in self.anchor_grid]
        self.head_num = len(anchor_cfg.anchor)
        # (head, 1, anchors-per-head, 1, 1, 2): indexing by head leaves a tensor
        # that broadcasts against a (B, L, h, w, 2) prediction.
        self.anchor_scale = tensor(anchor_cfg.anchor, device=device).view(self.head_num, 1, -1, 1, 1, 2)
        self.anchor_num = self.anchor_scale.size(2)
        self.class_num = model.num_classes

    def create_auto_anchor(self, model: YOLO, image_size):
        """Infer one stride per head by running a zero image through ``model``.

        Returns:
            A list with ``image_size[1] // <last spatial dim of the head>`` for
            every prediction head.
        """
        dummy_input = torch.zeros(1, 3, *image_size).to(self.device)
        # Only the output shapes are needed, so skip building an autograd graph.
        with torch.no_grad():
            dummy_output = model(dummy_input)
        # The post-processing path reads the heads from the "Main" entry of the
        # model output; accept that mapping form as well as a plain sequence.
        if isinstance(dummy_output, dict):
            dummy_output = dummy_output["Main"]

        strides = []
        for predict_head in dummy_output:
            _, _, *anchor_num = predict_head.shape
            strides.append(image_size[1] // anchor_num[1])
        return strides

    def generate_anchors(self, image_size: List[int]):
        """Build one ``(1, 1, H, W, 2)`` float grid of cell indices per stride.

        The last axis holds ``(column index, row index)`` for each cell. Results
        are stored in ``self.anchor_grid``.
        """
        self.anchor_grid = []
        for stride in self.strides:
            W, H = image_size[0] // stride, image_size[1] // stride
            anchor_h, anchor_w = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
            self.anchor_grid.append(torch.stack((anchor_w, anchor_h), 2).view((1, 1, H, W, 2)).float())

    def __call__(self, predicts: List[Tensor]):
        """Decode every head and concatenate the results along the anchor axis.

        Args:
            predicts: one tensor per head, shaped ``B (L C) h w``.

        Returns:
            ``(preds_cls, None, preds_box, preds_cnf)`` where ``preds_cls`` holds
            raw class values, ``preds_box`` is the output of
            ``transform_bbox(..., "xycwh -> xyxy")`` and ``preds_cnf`` is the
            sigmoid of the objectness value.
        """
        preds_box, preds_cls, preds_cnf = [], [], []
        for layer_idx, predict in enumerate(predicts):
            predict = rearrange(predict, "B (L C) h w -> B L h w C", L=self.anchor_num)
            pred_box, pred_cnf, pred_cls = predict.split((4, 1, self.class_num), dim=-1)
            pred_box = pred_box.sigmoid()

            # Build the decoded box out-of-place: writing into the sigmoid output
            # would invalidate the tensor autograd saved for the backward pass.
            stride = self.strides[layer_idx]
            pred_xy = (pred_box[..., 0:2] * 2.0 - 0.5 + self.anchor_grid[layer_idx]) * stride
            pred_wh = (pred_box[..., 2:4] * 2) ** 2 * self.anchor_scale[layer_idx]
            pred_box = torch.cat((pred_xy, pred_wh), dim=-1)

            preds_box.append(rearrange(pred_box, "B L h w A -> B (L h w) A"))
            preds_cls.append(rearrange(pred_cls, "B L h w C -> B (L h w) C"))
            preds_cnf.append(rearrange(pred_cnf, "B L h w C -> B (L h w) C"))

        preds_box = torch.concat(preds_box, dim=1)
        preds_cls = torch.concat(preds_cls, dim=1)
        preds_cnf = torch.concat(preds_cnf, dim=1)

        preds_box = transform_bbox(preds_box, "xycwh -> xyxy")
        return preds_cls, None, preds_box, preds_cnf.sigmoid()
365
+
366
+
367
+ def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig, confidence: Optional[Tensor]):
368
+ cls_dist = cls_dist.sigmoid() * (1 if confidence is None else confidence)
369
 
370
  # filter class by confidence
371
  cls_val, cls_idx = cls_dist.max(dim=-1, keepdim=True)
yolo/utils/model_utils.py CHANGED
@@ -103,15 +103,17 @@ class PostProccess:
103
  scale back the prediction and do nms for pred_bbox
104
  """
105
 
106
- def __init__(self, vec2box, nms_cfg: NMSConfig) -> None:
107
- self.vec2box = vec2box
108
  self.nms = nms_cfg
109
 
110
  def __call__(self, predict, rev_tensor: Optional[Tensor] = None):
111
- pred_class, _, pred_bbox = self.vec2box(predict["Main"])
 
 
112
  if rev_tensor is not None:
113
  pred_bbox = (pred_bbox - rev_tensor[:, None, 1:]) / rev_tensor[:, 0:1, None]
114
- pred_bbox = bbox_nms(pred_class, pred_bbox, self.nms)
115
  return pred_bbox
116
 
117
 
 
103
  scale back the prediction and do nms for pred_bbox
104
  """
105
 
106
def __init__(self, converter, nms_cfg: NMSConfig) -> None:
    """Keep the prediction converter and the NMS settings for later calls.

    Args:
        converter: callable applied to the model's "Main" output in ``__call__``.
        nms_cfg: configuration forwarded to ``bbox_nms``.
    """
    self.nms = nms_cfg
    self.converter = converter
109
 
110
def __call__(self, predict, rev_tensor: Optional[Tensor] = None):
    """Convert raw head outputs to boxes, optionally undo the input transform, then run NMS.

    Args:
        predict: model output; its "Main" entry is handed to the converter.
        rev_tensor: optional tensor whose column 0 divides the boxes and whose
            remaining columns are subtracted from them beforehand.

    Returns:
        Whatever ``bbox_nms`` returns for the converted predictions.
    """
    converted = self.converter(predict["Main"])
    pred_class, _, pred_bbox = converted[:3]

    # A converter may append a fourth item (a confidence tensor); otherwise
    # NMS runs without one.
    if len(converted) == 4:
        pred_conf = converted[3]
    else:
        pred_conf = None

    if rev_tensor is not None:
        offset = rev_tensor[:, None, 1:]
        scale = rev_tensor[:, 0:1, None]
        pred_bbox = (pred_bbox - offset) / scale

    pred_bbox = bbox_nms(pred_class, pred_bbox, self.nms, pred_conf)
    return pred_bbox
118
 
119