🔨 [Add] Anc2Box converter, for YOLOv7's output
Browse files
- yolo/__init__.py +2 -1
- yolo/config/config.py +3 -1
- yolo/model/yolo.py +0 -1
- yolo/utils/bounding_box_utils.py +61 -6
- yolo/utils/model_utils.py +6 -4
yolo/__init__.py
CHANGED
@@ -3,7 +3,7 @@ from yolo.model.yolo import create_model
|
|
3 |
from yolo.tools.data_loader import AugmentationComposer, create_dataloader
|
4 |
from yolo.tools.drawer import draw_bboxes
|
5 |
from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
|
6 |
-
from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms
|
7 |
from yolo.utils.deploy_utils import FastModelLoader
|
8 |
from yolo.utils.logging_utils import custom_logger
|
9 |
from yolo.utils.model_utils import PostProccess
|
@@ -16,6 +16,7 @@ all = [
|
|
16 |
"validate_log_directory",
|
17 |
"draw_bboxes",
|
18 |
"Vec2Box",
|
|
|
19 |
"bbox_nms",
|
20 |
"AugmentationComposer",
|
21 |
"create_dataloader",
|
|
|
3 |
from yolo.tools.data_loader import AugmentationComposer, create_dataloader
|
4 |
from yolo.tools.drawer import draw_bboxes
|
5 |
from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
|
6 |
+
from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms
|
7 |
from yolo.utils.deploy_utils import FastModelLoader
|
8 |
from yolo.utils.logging_utils import custom_logger
|
9 |
from yolo.utils.model_utils import PostProccess
|
|
|
16 |
"validate_log_directory",
|
17 |
"draw_bboxes",
|
18 |
"Vec2Box",
|
19 |
+
"Anc2Box",
|
20 |
"bbox_nms",
|
21 |
"AugmentationComposer",
|
22 |
"create_dataloader",
|
yolo/config/config.py
CHANGED
@@ -6,8 +6,10 @@ from torch import nn
|
|
6 |
|
7 |
@dataclass
|
8 |
class AnchorConfig:
|
9 |
-
reg_max: int
|
10 |
strides: List[int]
|
|
|
|
|
|
|
11 |
|
12 |
|
13 |
@dataclass
|
|
|
6 |
|
7 |
@dataclass
class AnchorConfig:
    """Anchor-related settings shared by the model and the box converters."""

    # One stride per prediction head; Anc2Box builds one anchor grid per entry.
    strides: List[int]
    # Optional; YOLO reads it with getattr(..., "reg_max", 16).
    reg_max: Optional[int]
    # Optional; NOTE(review): presumably the number of anchors per head — confirm against the configs.
    anchor_num: Optional[int]
    # One flat list per prediction head; Anc2Box reshapes each list into 2-value pairs
    # that scale the predicted box size.
    anchor: List[List[int]]
|
13 |
|
14 |
|
15 |
@dataclass
|
yolo/model/yolo.py
CHANGED
@@ -26,7 +26,6 @@ class YOLO(nn.Module):
|
|
26 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
27 |
self.model: List[YOLOLayer] = nn.ModuleList()
|
28 |
self.reg_max = getattr(model_cfg.anchor, "reg_max", 16)
|
29 |
-
self.strides = getattr(model_cfg.anchor, "strides", None)
|
30 |
self.build_model(model_cfg.model)
|
31 |
|
32 |
def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
|
|
|
26 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
27 |
self.model: List[YOLOLayer] = nn.ModuleList()
|
28 |
self.reg_max = getattr(model_cfg.anchor, "reg_max", 16)
|
|
|
29 |
self.build_model(model_cfg.model)
|
30 |
|
31 |
def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
|
yolo/utils/bounding_box_utils.py
CHANGED
@@ -1,14 +1,14 @@
|
|
1 |
import math
|
2 |
-
from typing import Dict, List, Tuple
|
3 |
|
4 |
import torch
|
5 |
import torch.nn.functional as F
|
6 |
from einops import rearrange
|
7 |
from loguru import logger
|
8 |
-
from torch import Tensor, arange
|
9 |
from torchvision.ops import batched_nms
|
10 |
|
11 |
-
from yolo.config.config import MatcherConfig, ModelConfig, NMSConfig
|
12 |
from yolo.model.yolo import YOLO
|
13 |
|
14 |
|
@@ -308,9 +308,64 @@ class Vec2Box:
|
|
308 |
return preds_cls, preds_anc, preds_box
|
309 |
|
310 |
|
311 |
-
|
312 |
-
|
313 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
314 |
|
315 |
# filter class by confidence
|
316 |
cls_val, cls_idx = cls_dist.max(dim=-1, keepdim=True)
|
|
|
1 |
import math
|
2 |
+
from typing import Dict, List, Optional, Tuple
|
3 |
|
4 |
import torch
|
5 |
import torch.nn.functional as F
|
6 |
from einops import rearrange
|
7 |
from loguru import logger
|
8 |
+
from torch import Tensor, arange, tensor
|
9 |
from torchvision.ops import batched_nms
|
10 |
|
11 |
+
from yolo.config.config import AnchorConfig, MatcherConfig, ModelConfig, NMSConfig
|
12 |
from yolo.model.yolo import YOLO
|
13 |
|
14 |
|
|
|
308 |
return preds_cls, preds_anc, preds_box
|
309 |
|
310 |
|
311 |
+
class Anc2Box:
    """Decode anchor-based head outputs (YOLOv7 style) into boxes, classes and objectness.

    Each prediction head is expected as a ``(B, L * (5 + class_num), h, w)`` tensor, where
    ``L`` is the number of anchors per head taken from ``anchor_cfg.anchor``.
    """

    def __init__(self, model: YOLO, anchor_cfg: AnchorConfig, image_size, device):
        """Prepare per-head anchor grids and anchor scales.

        Args:
            model: Model providing ``num_classes``; also used for a dummy forward pass
                when the config carries no strides.
            anchor_cfg: Anchor settings (``strides`` and per-head ``anchor`` lists).
            image_size: Input size as ``(W, H)``.
            device: Device on which grids and anchor scales are stored.
        """
        self.device = device

        # A missing attribute and an explicit ``None`` both mean "infer the strides".
        strides = getattr(anchor_cfg, "strides", None)
        if strides is not None:
            logger.info(f"🈶 Found stride of model {strides}")
            self.strides = strides
        else:
            logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size")
            self.strides = self.create_auto_anchor(model, image_size)

        self.generate_anchors(image_size)
        self.anchor_grid = [anchor_grid.to(device) for anchor_grid in self.anchor_grid]
        self.head_num = len(anchor_cfg.anchor)
        # (head, 1, anchors_per_head, 1, 1, 2): indexing by head broadcasts over (B, L, h, w, 2).
        self.anchor_scale = tensor(anchor_cfg.anchor, device=device).view(self.head_num, 1, -1, 1, 1, 2)
        self.anchor_num = self.anchor_scale.size(2)
        self.class_num = model.num_classes

    def create_auto_anchor(self, model: YOLO, image_size):
        """Infer the stride of every prediction head from a dummy forward pass.

        Args:
            model: Callable returning either the list of head tensors or a dict that
                stores that list under ``"Main"``.
            image_size: Input size as ``(W, H)``, the same convention as ``generate_anchors``.

        Returns:
            A list with one integer stride per head.
        """
        W, H = image_size
        # Tensors are laid out (B, C, H, W), so height comes first here.
        dummy_input = torch.zeros(1, 3, H, W).to(self.device)
        with torch.no_grad():  # shape probing only, no graph needed
            dummy_output = model(dummy_input)
        heads = dummy_output["Main"] if isinstance(dummy_output, dict) else dummy_output

        strides = []
        for predict_head in heads:
            _, _, *anchor_num = predict_head.shape  # anchor_num == [h, w]
            strides.append(W // anchor_num[1])
        return strides

    def generate_anchors(self, image_size: List[int]):
        """Build one ``(1, 1, H, W, 2)`` grid of ``(x, y)`` cell indices per stride.

        The grids are stored in ``self.anchor_grid``.
        """
        self.anchor_grid = []
        for stride in self.strides:
            W, H = image_size[0] // stride, image_size[1] // stride
            anchor_h, anchor_w = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
            self.anchor_grid.append(torch.stack((anchor_w, anchor_h), 2).view((1, 1, H, W, 2)).float())

    def __call__(self, predicts: List[Tensor]):
        """Decode the raw head outputs.

        Args:
            predicts: One ``(B, L * (5 + class_num), h, w)`` tensor per head.

        Returns:
            ``(preds_cls, None, preds_box, preds_cnf)``: raw class scores, a ``None``
            placeholder, boxes after the ``"xycwh -> xyxy"`` transform, and
            sigmoid-activated objectness, each flattened over all heads.
        """
        preds_box, preds_cls, preds_cnf = [], [], []
        for layer_idx, predict in enumerate(predicts):
            predict = rearrange(predict, "B (L C) h w -> B L h w C", L=self.anchor_num)
            pred_box, pred_cnf, pred_cls = predict.split((4, 1, self.class_num), dim=-1)
            pred_box = pred_box.sigmoid()
            # Built out-of-place: writing into the sigmoid output in place would
            # invalidate the tensor autograd keeps for the backward pass.
            pred_xy = (pred_box[..., 0:2] * 2.0 - 0.5 + self.anchor_grid[layer_idx]) * self.strides[layer_idx]
            pred_wh = (pred_box[..., 2:4] * 2) ** 2 * self.anchor_scale[layer_idx]
            pred_box = torch.concat((pred_xy, pred_wh), dim=-1)
            preds_box.append(rearrange(pred_box, "B L h w A -> B (L h w) A"))
            preds_cls.append(rearrange(pred_cls, "B L h w C -> B (L h w) C"))
            preds_cnf.append(rearrange(pred_cnf, "B L h w C -> B (L h w) C"))

        preds_box = torch.concat(preds_box, dim=1)
        preds_cls = torch.concat(preds_cls, dim=1)
        preds_cnf = torch.concat(preds_cnf, dim=1)

        preds_box = transform_bbox(preds_box, "xycwh -> xyxy")
        return preds_cls, None, preds_box, preds_cnf.sigmoid()
|
365 |
+
|
366 |
+
|
367 |
+
def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig, confidence: Optional[Tensor]):
|
368 |
+
cls_dist = cls_dist.sigmoid() * (1 if confidence is None else confidence)
|
369 |
|
370 |
# filter class by confidence
|
371 |
cls_val, cls_idx = cls_dist.max(dim=-1, keepdim=True)
|
yolo/utils/model_utils.py
CHANGED
@@ -103,15 +103,17 @@ class PostProccess:
|
|
103 |
scale back the prediction and do nms for pred_bbox
|
104 |
"""
|
105 |
|
106 |
-
def __init__(self,
|
107 |
-
self.
|
108 |
self.nms = nms_cfg
|
109 |
|
110 |
def __call__(self, predict, rev_tensor: Optional[Tensor] = None):
|
111 |
-
|
|
|
|
|
112 |
if rev_tensor is not None:
|
113 |
pred_bbox = (pred_bbox - rev_tensor[:, None, 1:]) / rev_tensor[:, 0:1, None]
|
114 |
-
pred_bbox = bbox_nms(pred_class, pred_bbox, self.nms)
|
115 |
return pred_bbox
|
116 |
|
117 |
|
|
|
103 |
scale back the prediction and do nms for pred_bbox
|
104 |
"""
|
105 |
|
106 |
+
def __init__(self, converter, nms_cfg: NMSConfig) -> None:
    """Keep the prediction converter and the NMS settings for later calls."""
    self.converter, self.nms = converter, nms_cfg
|
109 |
|
110 |
def __call__(self, predict, rev_tensor: Optional[Tensor] = None):
    """Convert the "Main" head outputs to boxes, optionally rescale them, then run NMS.

    The converter may return three or four items; a fourth item is passed to
    ``bbox_nms`` as the confidence. When ``rev_tensor`` is given, columns ``1:`` are
    subtracted from the boxes and the result is divided by column ``0``.
    """
    outputs = self.converter(predict["Main"])
    pred_class, _, pred_bbox = outputs[:3]

    pred_conf = None
    if len(outputs) == 4:
        pred_conf = outputs[3]

    if rev_tensor is not None:
        shifted = pred_bbox - rev_tensor[:, None, 1:]
        pred_bbox = shifted / rev_tensor[:, 0:1, None]

    return bbox_nms(pred_class, pred_bbox, self.nms, pred_conf)
|
118 |
|
119 |
|