henry000 commited on
Commit
a87899e
·
1 Parent(s): b23f927

✨ [Add] NMS proccess when inference image!

Browse files
Files changed (1) hide show
  1. yolo/tools/bbox_helper.py +22 -0
yolo/tools/bbox_helper.py CHANGED
@@ -5,6 +5,7 @@ import torch
5
  import torch.nn.functional as F
6
  from einops import rearrange
7
  from torch import Tensor
 
8
 
9
  from yolo.config.config import Config, MatcherConfig
10
 
@@ -288,3 +289,24 @@ class BoxMatcher:
288
  align_cls = align_cls * normalize_term * valid_mask[:, :, None]
289
 
290
  return torch.cat([align_cls, align_bbox], dim=-1), valid_mask.bool()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import torch.nn.functional as F
6
  from einops import rearrange
7
  from torch import Tensor
8
+ from torchvision.ops import batched_nms
9
 
10
  from yolo.config.config import Config, MatcherConfig
11
 
 
289
  align_cls = align_cls * normalize_term * valid_mask[:, :, None]
290
 
291
  return torch.cat([align_cls, align_bbox], dim=-1), valid_mask.bool()
292
+
293
+
294
+ def bbox_nms(predicts: Tensor, min_conf: float = 0, min_iou: float = 0.5):
295
+ cls_dist, bbox = predicts.split([80, 4], dim=-1)
296
+
297
+ # filter class by confidence
298
+ cls_val, cls_idx = cls_dist.max(dim=-1, keepdim=True)
299
+ valid_mask = cls_val > min_conf
300
+ valid_cls = cls_idx[valid_mask]
301
+ valid_box = bbox[valid_mask.repeat(1, 1, 4)].view(-1, 4)
302
+
303
+ batch_idx, *_ = torch.where(valid_mask)
304
+ nms_idx = batched_nms(valid_box, valid_cls, batch_idx, min_iou)
305
+ predicts_nms = []
306
+ for idx in range(batch_idx.max() + 1):
307
+ instance_idx = nms_idx[idx == batch_idx[nms_idx]]
308
+
309
+ predict_nms = torch.cat([valid_cls[instance_idx][:, None], valid_box[instance_idx]], dim=-1)
310
+
311
+ predicts_nms.append(predict_nms)
312
+ return predicts_nms