♻️ [Refactor] the code of v7 converter, align v9
Browse files
yolo/utils/bounding_box_utils.py
CHANGED
@@ -319,9 +319,8 @@ class Anc2Box:
|
|
319 |
logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size")
|
320 |
self.strides = self.create_auto_anchor(model, image_size)
|
321 |
|
322 |
-
self.generate_anchors(image_size)
|
323 |
-
self.anchor_grid = [anchor_grid.to(device) for anchor_grid in self.anchor_grid]
|
324 |
self.head_num = len(anchor_cfg.anchor)
|
|
|
325 |
self.anchor_scale = tensor(anchor_cfg.anchor, device=device).view(self.head_num, 1, -1, 1, 1, 2)
|
326 |
self.anchor_num = self.anchor_scale.size(2)
|
327 |
self.class_num = model.num_classes
|
@@ -330,17 +329,22 @@ class Anc2Box:
|
|
330 |
dummy_input = torch.zeros(1, 3, *image_size).to(self.device)
|
331 |
dummy_output = model(dummy_input)
|
332 |
strides = []
|
333 |
-
for predict_head in dummy_output:
|
334 |
_, _, *anchor_num = predict_head.shape
|
335 |
strides.append(image_size[1] // anchor_num[1])
|
336 |
return strides
|
337 |
|
338 |
def generate_anchors(self, image_size: List[int]):
|
339 |
-
|
340 |
for stride in self.strides:
|
341 |
W, H = image_size[0] // stride, image_size[1] // stride
|
342 |
anchor_h, anchor_w = torch.meshgrid([torch.arange(H), torch.arange(W)], indexing="ij")
|
343 |
-
|
|
|
|
|
|
|
|
|
|
|
344 |
|
345 |
def __call__(self, predicts: List[Tensor]):
|
346 |
preds_box, preds_cls, preds_cnf = [], [], []
|
@@ -348,7 +352,7 @@ class Anc2Box:
|
|
348 |
predict = rearrange(predict, "B (L C) h w -> B L h w C", L=self.anchor_num)
|
349 |
pred_box, pred_cnf, pred_cls = predict.split((4, 1, self.class_num), dim=-1)
|
350 |
pred_box = pred_box.sigmoid()
|
351 |
-
pred_box[..., 0:2] = (pred_box[..., 0:2] * 2.0 - 0.5 + self.anchor_grid[layer_idx]) * self.strides[
|
352 |
layer_idx
|
353 |
]
|
354 |
pred_box[..., 2:4] = (pred_box[..., 2:4] * 2) ** 2 * self.anchor_scale[layer_idx]
|
|
|
319 |
logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size")
|
320 |
self.strides = self.create_auto_anchor(model, image_size)
|
321 |
|
|
|
|
|
322 |
self.head_num = len(anchor_cfg.anchor)
|
323 |
+
self.anchor_grids = self.generate_anchors(image_size)
|
324 |
self.anchor_scale = tensor(anchor_cfg.anchor, device=device).view(self.head_num, 1, -1, 1, 1, 2)
|
325 |
self.anchor_num = self.anchor_scale.size(2)
|
326 |
self.class_num = model.num_classes
|
|
|
329 |
dummy_input = torch.zeros(1, 3, *image_size).to(self.device)
|
330 |
dummy_output = model(dummy_input)
|
331 |
strides = []
|
332 |
+
for predict_head in dummy_output["Main"]:
|
333 |
_, _, *anchor_num = predict_head.shape
|
334 |
strides.append(image_size[1] // anchor_num[1])
|
335 |
return strides
|
336 |
|
337 |
def generate_anchors(self, image_size: List[int]):
|
338 |
+
anchor_grids = []
|
339 |
for stride in self.strides:
|
340 |
W, H = image_size[0] // stride, image_size[1] // stride
|
341 |
anchor_h, anchor_w = torch.meshgrid([torch.arange(H), torch.arange(W)], indexing="ij")
|
342 |
+
anchor_grid = torch.stack((anchor_w, anchor_h), 2).view((1, 1, H, W, 2)).float().to(self.device)
|
343 |
+
anchor_grids.append(anchor_grid)
|
344 |
+
return anchor_grids
|
345 |
+
|
346 |
+
def update(self, image_size):
|
347 |
+
self.anchor_grid = self.generate_anchors(image_size)
|
348 |
|
349 |
def __call__(self, predicts: List[Tensor]):
|
350 |
preds_box, preds_cls, preds_cnf = [], [], []
|
|
|
352 |
predict = rearrange(predict, "B (L C) h w -> B L h w C", L=self.anchor_num)
|
353 |
pred_box, pred_cnf, pred_cls = predict.split((4, 1, self.class_num), dim=-1)
|
354 |
pred_box = pred_box.sigmoid()
|
355 |
+
pred_box[..., 0:2] = (pred_box[..., 0:2] * 2.0 - 0.5 + self.anchor_grids[layer_idx]) * self.strides[
|
356 |
layer_idx
|
357 |
]
|
358 |
pred_box[..., 2:4] = (pred_box[..., 2:4] * 2) ** 2 * self.anchor_scale[layer_idx]
|