henry000 commited on
Commit
323161f
·
1 Parent(s): 8fe77d2

♻️ [Refactor] the code of v7 converter, align v9

Browse files
Files changed (1) hide show
  1. yolo/utils/bounding_box_utils.py +10 -6
yolo/utils/bounding_box_utils.py CHANGED
@@ -319,9 +319,8 @@ class Anc2Box:
319
  logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size")
320
  self.strides = self.create_auto_anchor(model, image_size)
321
 
322
- self.generate_anchors(image_size)
323
- self.anchor_grid = [anchor_grid.to(device) for anchor_grid in self.anchor_grid]
324
  self.head_num = len(anchor_cfg.anchor)
 
325
  self.anchor_scale = tensor(anchor_cfg.anchor, device=device).view(self.head_num, 1, -1, 1, 1, 2)
326
  self.anchor_num = self.anchor_scale.size(2)
327
  self.class_num = model.num_classes
@@ -330,17 +329,22 @@ class Anc2Box:
330
  dummy_input = torch.zeros(1, 3, *image_size).to(self.device)
331
  dummy_output = model(dummy_input)
332
  strides = []
333
- for predict_head in dummy_output:
334
  _, _, *anchor_num = predict_head.shape
335
  strides.append(image_size[1] // anchor_num[1])
336
  return strides
337
 
338
  def generate_anchors(self, image_size: List[int]):
339
- self.anchor_grid = []
340
  for stride in self.strides:
341
  W, H = image_size[0] // stride, image_size[1] // stride
342
  anchor_h, anchor_w = torch.meshgrid([torch.arange(H), torch.arange(W)], indexing="ij")
343
- self.anchor_grid.append(torch.stack((anchor_w, anchor_h), 2).view((1, 1, H, W, 2)).float())
 
 
 
 
 
344
 
345
  def __call__(self, predicts: List[Tensor]):
346
  preds_box, preds_cls, preds_cnf = [], [], []
@@ -348,7 +352,7 @@ class Anc2Box:
348
  predict = rearrange(predict, "B (L C) h w -> B L h w C", L=self.anchor_num)
349
  pred_box, pred_cnf, pred_cls = predict.split((4, 1, self.class_num), dim=-1)
350
  pred_box = pred_box.sigmoid()
351
- pred_box[..., 0:2] = (pred_box[..., 0:2] * 2.0 - 0.5 + self.anchor_grid[layer_idx]) * self.strides[
352
  layer_idx
353
  ]
354
  pred_box[..., 2:4] = (pred_box[..., 2:4] * 2) ** 2 * self.anchor_scale[layer_idx]
 
319
  logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size")
320
  self.strides = self.create_auto_anchor(model, image_size)
321
 
 
 
322
  self.head_num = len(anchor_cfg.anchor)
323
+ self.anchor_grids = self.generate_anchors(image_size)
324
  self.anchor_scale = tensor(anchor_cfg.anchor, device=device).view(self.head_num, 1, -1, 1, 1, 2)
325
  self.anchor_num = self.anchor_scale.size(2)
326
  self.class_num = model.num_classes
 
329
  dummy_input = torch.zeros(1, 3, *image_size).to(self.device)
330
  dummy_output = model(dummy_input)
331
  strides = []
332
+ for predict_head in dummy_output["Main"]:
333
  _, _, *anchor_num = predict_head.shape
334
  strides.append(image_size[1] // anchor_num[1])
335
  return strides
336
 
337
  def generate_anchors(self, image_size: List[int]):
338
+ anchor_grids = []
339
  for stride in self.strides:
340
  W, H = image_size[0] // stride, image_size[1] // stride
341
  anchor_h, anchor_w = torch.meshgrid([torch.arange(H), torch.arange(W)], indexing="ij")
342
+ anchor_grid = torch.stack((anchor_w, anchor_h), 2).view((1, 1, H, W, 2)).float().to(self.device)
343
+ anchor_grids.append(anchor_grid)
344
+ return anchor_grids
345
+
346
+ def update(self, image_size):
347
+ self.anchor_grid = self.generate_anchors(image_size)
348
 
349
  def __call__(self, predicts: List[Tensor]):
350
  preds_box, preds_cls, preds_cnf = [], [], []
 
352
  predict = rearrange(predict, "B (L C) h w -> B L h w C", L=self.anchor_num)
353
  pred_box, pred_cnf, pred_cls = predict.split((4, 1, self.class_num), dim=-1)
354
  pred_box = pred_box.sigmoid()
355
+ pred_box[..., 0:2] = (pred_box[..., 0:2] * 2.0 - 0.5 + self.anchor_grids[layer_idx]) * self.strides[
356
  layer_idx
357
  ]
358
  pred_box[..., 2:4] = (pred_box[..., 2:4] * 2) ** 2 * self.anchor_scale[layer_idx]