f-fl0
commited on
π [Fix] bbox_nms function (#76)
Browse files* Use valid_con instead of valid_cls in batched_nms
* Update test_bbox_nms
* Improve comments
* Apply black
tests/test_utils/test_bounding_box_utils.py
CHANGED
@@ -146,23 +146,64 @@ def test_anc2box_autoanchor(inference_v7_cfg: Config):
|
|
146 |
|
147 |
|
148 |
def test_bbox_nms():
|
149 |
-
cls_dist = tensor(
|
150 |
-
[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
)
|
152 |
-
|
153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
dtype=float32,
|
155 |
)
|
|
|
156 |
nms_cfg = NMSConfig(min_confidence=0.5, min_iou=0.5)
|
157 |
|
158 |
-
|
159 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
[
|
161 |
-
[
|
162 |
-
[0
|
163 |
-
]
|
164 |
-
|
165 |
-
|
|
|
|
|
|
|
|
|
166 |
|
167 |
output = bbox_nms(cls_dist, bbox, nms_cfg)
|
168 |
|
|
|
146 |
|
147 |
|
148 |
def test_bbox_nms():
|
149 |
+
cls_dist = torch.tensor(
|
150 |
+
[
|
151 |
+
[
|
152 |
+
[0.7, 0.1, 0.2], # High confidence, class 0
|
153 |
+
[0.3, 0.6, 0.1], # High confidence, class 1
|
154 |
+
[-3.0, -2.0, -1.0], # low confidence, class 2
|
155 |
+
[0.6, 0.2, 0.2], # Medium confidence, class 0
|
156 |
+
],
|
157 |
+
[
|
158 |
+
[0.55, 0.25, 0.2], # Medium confidence, class 0
|
159 |
+
[-4.0, -0.5, -2.0], # low confidence, class 1
|
160 |
+
[0.15, 0.2, 0.65], # Medium confidence, class 2
|
161 |
+
[0.8, 0.1, 0.1], # High confidence, class 0
|
162 |
+
],
|
163 |
+
],
|
164 |
+
dtype=float32,
|
165 |
)
|
166 |
+
|
167 |
+
bbox = torch.tensor(
|
168 |
+
[
|
169 |
+
[
|
170 |
+
[0, 0, 160, 120], # Overlaps with box 4
|
171 |
+
[160, 120, 320, 240],
|
172 |
+
[0, 120, 160, 240],
|
173 |
+
[16, 12, 176, 132],
|
174 |
+
],
|
175 |
+
[
|
176 |
+
[0, 0, 160, 120], # Overlaps with box 4
|
177 |
+
[160, 120, 320, 240],
|
178 |
+
[0, 120, 160, 240],
|
179 |
+
[16, 12, 176, 132],
|
180 |
+
],
|
181 |
+
],
|
182 |
dtype=float32,
|
183 |
)
|
184 |
+
|
185 |
nms_cfg = NMSConfig(min_confidence=0.5, min_iou=0.5)
|
186 |
|
187 |
+
# Batch 1:
|
188 |
+
# - box 1 is kept with class 0 as it has a higher confidence than box 4 i.e. box 4 is filtered out
|
189 |
+
# - box 2 is kept with class 1
|
190 |
+
# - box 3 is rejected by the confidence filter
|
191 |
+
# Batch 2:
|
192 |
+
# - box 4 is kept with class 0 as it has a higher confidence than box 1 i.e. box 1 is filtered out
|
193 |
+
# - box 2 is rejected by the confidence filter
|
194 |
+
# - box 3 is kept with class 2
|
195 |
+
expected_output = torch.tensor(
|
196 |
+
[
|
197 |
[
|
198 |
+
[0.0, 0.0, 0.0, 160.0, 120.0, 0.6682],
|
199 |
+
[1.0, 160.0, 120.0, 320.0, 240.0, 0.6457],
|
200 |
+
],
|
201 |
+
[
|
202 |
+
[0.0, 16.0, 12.0, 176.0, 132.0, 0.6900],
|
203 |
+
[2.0, 0.0, 120.0, 160.0, 240.0, 0.6570],
|
204 |
+
],
|
205 |
+
]
|
206 |
+
)
|
207 |
|
208 |
output = bbox_nms(cls_dist, bbox, nms_cfg)
|
209 |
|
yolo/utils/bounding_box_utils.py
CHANGED
@@ -387,7 +387,7 @@ def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig, confidence: Opt
|
|
387 |
valid_box = bbox[valid_mask.repeat(1, 1, 4)].view(-1, 4)
|
388 |
|
389 |
batch_idx, *_ = torch.where(valid_mask)
|
390 |
-
nms_idx = batched_nms(valid_box,
|
391 |
predicts_nms = []
|
392 |
for idx in range(cls_dist.size(0)):
|
393 |
instance_idx = nms_idx[idx == batch_idx[nms_idx]]
|
|
|
387 |
valid_box = bbox[valid_mask.repeat(1, 1, 4)].view(-1, 4)
|
388 |
|
389 |
batch_idx, *_ = torch.where(valid_mask)
|
390 |
+
nms_idx = batched_nms(valid_box, valid_con, batch_idx, nms_cfg.min_iou)
|
391 |
predicts_nms = []
|
392 |
for idx in range(cls_dist.size(0)):
|
393 |
instance_idx = nms_idx[idx == batch_idx[nms_idx]]
|