✨ [Add] NMS proccess when inference image!
Browse files- yolo/tools/bbox_helper.py +22 -0
yolo/tools/bbox_helper.py
CHANGED
@@ -5,6 +5,7 @@ import torch
|
|
5 |
import torch.nn.functional as F
|
6 |
from einops import rearrange
|
7 |
from torch import Tensor
|
|
|
8 |
|
9 |
from yolo.config.config import Config, MatcherConfig
|
10 |
|
@@ -288,3 +289,24 @@ class BoxMatcher:
|
|
288 |
align_cls = align_cls * normalize_term * valid_mask[:, :, None]
|
289 |
|
290 |
return torch.cat([align_cls, align_bbox], dim=-1), valid_mask.bool()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
import torch.nn.functional as F
|
6 |
from einops import rearrange
|
7 |
from torch import Tensor
|
8 |
+
from torchvision.ops import batched_nms
|
9 |
|
10 |
from yolo.config.config import Config, MatcherConfig
|
11 |
|
|
|
289 |
align_cls = align_cls * normalize_term * valid_mask[:, :, None]
|
290 |
|
291 |
return torch.cat([align_cls, align_bbox], dim=-1), valid_mask.bool()
|
292 |
+
|
293 |
+
|
294 |
+
def bbox_nms(predicts: Tensor, min_conf: float = 0, min_iou: float = 0.5):
|
295 |
+
cls_dist, bbox = predicts.split([80, 4], dim=-1)
|
296 |
+
|
297 |
+
# filter class by confidence
|
298 |
+
cls_val, cls_idx = cls_dist.max(dim=-1, keepdim=True)
|
299 |
+
valid_mask = cls_val > min_conf
|
300 |
+
valid_cls = cls_idx[valid_mask]
|
301 |
+
valid_box = bbox[valid_mask.repeat(1, 1, 4)].view(-1, 4)
|
302 |
+
|
303 |
+
batch_idx, *_ = torch.where(valid_mask)
|
304 |
+
nms_idx = batched_nms(valid_box, valid_cls, batch_idx, min_iou)
|
305 |
+
predicts_nms = []
|
306 |
+
for idx in range(batch_idx.max() + 1):
|
307 |
+
instance_idx = nms_idx[idx == batch_idx[nms_idx]]
|
308 |
+
|
309 |
+
predict_nms = torch.cat([valid_cls[instance_idx][:, None], valid_box[instance_idx]], dim=-1)
|
310 |
+
|
311 |
+
predicts_nms.append(predict_nms)
|
312 |
+
return predicts_nms
|