f-fl0 commited on
Commit
44527a6
Β·
unverified Β·
1 Parent(s): f1bfd74

πŸ› [Fix] bbox_nms function (#76)

Browse files

* Use valid_con instead of valid_cls in batched_nms

* Update test_bbox_nms

* Improve comments

* Apply black

tests/test_utils/test_bounding_box_utils.py CHANGED
@@ -146,23 +146,64 @@ def test_anc2box_autoanchor(inference_v7_cfg: Config):
146
 
147
 
148
  def test_bbox_nms():
149
- cls_dist = tensor(
150
- [[[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]], [[0.4, 0.4, 0.2], [0.5, 0.4, 0.1]]] # Example class distribution
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  )
152
- bbox = tensor(
153
- [[[50, 50, 100, 100], [60, 60, 110, 110]], [[40, 40, 90, 90], [70, 70, 120, 120]]], # Example bounding boxes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  dtype=float32,
155
  )
 
156
  nms_cfg = NMSConfig(min_confidence=0.5, min_iou=0.5)
157
 
158
- expected_output = [
159
- tensor(
 
 
 
 
 
 
 
 
160
  [
161
- [1.0000, 50.0000, 50.0000, 100.0000, 100.0000, 0.6682],
162
- [0.0000, 60.0000, 60.0000, 110.0000, 110.0000, 0.6457],
163
- ]
164
- )
165
- ]
 
 
 
 
166
 
167
  output = bbox_nms(cls_dist, bbox, nms_cfg)
168
 
 
146
 
147
 
148
  def test_bbox_nms():
149
+ cls_dist = torch.tensor(
150
+ [
151
+ [
152
+ [0.7, 0.1, 0.2], # High confidence, class 0
153
+ [0.3, 0.6, 0.1], # High confidence, class 1
154
+ [-3.0, -2.0, -1.0], # low confidence, class 2
155
+ [0.6, 0.2, 0.2], # Medium confidence, class 0
156
+ ],
157
+ [
158
+ [0.55, 0.25, 0.2], # Medium confidence, class 0
159
+ [-4.0, -0.5, -2.0], # low confidence, class 1
160
+ [0.15, 0.2, 0.65], # Medium confidence, class 2
161
+ [0.8, 0.1, 0.1], # High confidence, class 0
162
+ ],
163
+ ],
164
+ dtype=float32,
165
  )
166
+
167
+ bbox = torch.tensor(
168
+ [
169
+ [
170
+ [0, 0, 160, 120], # Overlaps with box 4
171
+ [160, 120, 320, 240],
172
+ [0, 120, 160, 240],
173
+ [16, 12, 176, 132],
174
+ ],
175
+ [
176
+ [0, 0, 160, 120], # Overlaps with box 4
177
+ [160, 120, 320, 240],
178
+ [0, 120, 160, 240],
179
+ [16, 12, 176, 132],
180
+ ],
181
+ ],
182
  dtype=float32,
183
  )
184
+
185
  nms_cfg = NMSConfig(min_confidence=0.5, min_iou=0.5)
186
 
187
+ # Batch 1:
188
+ # - box 1 is kept with class 0 as it has a higher confidence than box 4 i.e. box 4 is filtered out
189
+ # - box 2 is kept with class 1
190
+ # - box 3 is rejected by the confidence filter
191
+ # Batch 2:
192
+ # - box 4 is kept with class 0 as it has a higher confidence than box 1 i.e. box 1 is filtered out
193
+ # - box 2 is rejected by the confidence filter
194
+ # - box 3 is kept with class 2
195
+ expected_output = torch.tensor(
196
+ [
197
  [
198
+ [0.0, 0.0, 0.0, 160.0, 120.0, 0.6682],
199
+ [1.0, 160.0, 120.0, 320.0, 240.0, 0.6457],
200
+ ],
201
+ [
202
+ [0.0, 16.0, 12.0, 176.0, 132.0, 0.6900],
203
+ [2.0, 0.0, 120.0, 160.0, 240.0, 0.6570],
204
+ ],
205
+ ]
206
+ )
207
 
208
  output = bbox_nms(cls_dist, bbox, nms_cfg)
209
 
yolo/utils/bounding_box_utils.py CHANGED
@@ -387,7 +387,7 @@ def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig, confidence: Opt
387
  valid_box = bbox[valid_mask.repeat(1, 1, 4)].view(-1, 4)
388
 
389
  batch_idx, *_ = torch.where(valid_mask)
390
- nms_idx = batched_nms(valid_box, valid_cls, batch_idx, nms_cfg.min_iou)
391
  predicts_nms = []
392
  for idx in range(cls_dist.size(0)):
393
  instance_idx = nms_idx[idx == batch_idx[nms_idx]]
 
387
  valid_box = bbox[valid_mask.repeat(1, 1, 4)].view(-1, 4)
388
 
389
  batch_idx, *_ = torch.where(valid_mask)
390
+ nms_idx = batched_nms(valid_box, valid_con, batch_idx, nms_cfg.min_iou)
391
  predicts_nms = []
392
  for idx in range(cls_dist.size(0)):
393
  instance_idx = nms_idx[idx == batch_idx[nms_idx]]