✨ [Update] BoxMatcher matching criteria (#125)
Browse files* ✨ [Update] BoxMatcher matching criteria
Added an additional validity criterium in get_valid_matrix, which masks out anchors from targets, that are too large to predict with the given reg_max and stride values.
Implemented a new function: ensure_one_anchor, which adds a single best suited anchor for valid targets without valid anchors. It is a fallback mechanism, which enables too small or too large targets to be trained to be predicted as well, even if not perfectly.
Fixed the filter_duplicate function to use the topk_masked iou_mat for the selection, which previously sometimes matched invalid targets to anchors with duplicates.
Updated docsstrings across the BoxMatcher functions to match the changes.
* 🔨 [Update] F.one_hot calls in BoxMatcher
to a more efficient solution, without using torch.nn.functional.
torch.nn.functional.one_hot always returns a long tensor, consuming a lot of memory for tensors, which are only used as masks.
- yolo/tools/loss_functions.py +1 -1
- yolo/utils/bounding_box_utils.py +71 -33
@@ -75,7 +75,7 @@ class YOLOLoss:
|
|
75 |
self.dfl = DFLoss(vec2box, reg_max)
|
76 |
self.iou = BoxLoss()
|
77 |
|
78 |
-
self.matcher = BoxMatcher(loss_cfg.matcher, self.class_num, vec2box
|
79 |
|
80 |
def separate_anchor(self, anchors):
|
81 |
"""
|
|
|
75 |
self.dfl = DFLoss(vec2box, reg_max)
|
76 |
self.iou = BoxLoss()
|
77 |
|
78 |
+
self.matcher = BoxMatcher(loss_cfg.matcher, self.class_num, vec2box, reg_max)
|
79 |
|
80 |
def separate_anchor(self, anchors):
|
81 |
"""
|
@@ -2,7 +2,6 @@ import math
|
|
2 |
from typing import Dict, List, Optional, Tuple, Union
|
3 |
|
4 |
import torch
|
5 |
-
import torch.nn.functional as F
|
6 |
from einops import rearrange
|
7 |
from torch import Tensor, tensor
|
8 |
from torchmetrics.detection import MeanAveragePrecision
|
@@ -143,28 +142,35 @@ def generate_anchors(image_size: List[int], strides: List[int]):
|
|
143 |
|
144 |
|
145 |
class BoxMatcher:
|
146 |
-
def __init__(self, cfg: MatcherConfig, class_num: int,
|
147 |
self.class_num = class_num
|
148 |
-
self.
|
|
|
149 |
for attr_name in cfg:
|
150 |
setattr(self, attr_name, cfg[attr_name])
|
151 |
|
152 |
def get_valid_matrix(self, target_bbox: Tensor):
|
153 |
"""
|
154 |
-
Get a boolean mask that indicates whether each target bounding box overlaps with each anchor
|
|
|
155 |
|
156 |
Args:
|
157 |
-
target_bbox [batch x targets x 4]: The bounding box of each
|
158 |
Returns:
|
159 |
-
[batch x targets x anchors]: A boolean tensor indicates if target bounding box overlaps
|
|
|
160 |
"""
|
161 |
-
|
162 |
-
anchors = self.
|
163 |
anchors_x, anchors_y = anchors.unbind(dim=3)
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
|
|
|
|
|
|
|
|
168 |
|
169 |
def get_cls_matrix(self, predict_cls: Tensor, target_cls: Tensor) -> Tensor:
|
170 |
"""
|
@@ -194,40 +200,68 @@ class BoxMatcher:
|
|
194 |
"""
|
195 |
return calculate_iou(target_bbox, predict_bbox, self.iou).clamp(0, 1)
|
196 |
|
197 |
-
def filter_topk(self, target_matrix: Tensor, topk: int = 10) -> Tuple[Tensor, Tensor]:
|
198 |
"""
|
199 |
Filter the top-k suitability of targets for each anchor.
|
200 |
|
201 |
Args:
|
202 |
target_matrix [batch x targets x anchors]: The suitability for each targets-anchors
|
|
|
203 |
topk (int, optional): Number of top scores to retain per anchor.
|
204 |
|
205 |
Returns:
|
206 |
topk_targets [batch x targets x anchors]: Only leave the topk targets for each anchor
|
207 |
-
|
208 |
"""
|
209 |
-
|
|
|
210 |
topk_targets = torch.zeros_like(target_matrix, device=target_matrix.device)
|
211 |
topk_targets.scatter_(dim=-1, index=indices, src=values)
|
212 |
-
|
213 |
-
return topk_targets,
|
214 |
|
215 |
-
def
|
216 |
"""
|
217 |
-
|
|
|
|
|
218 |
|
219 |
Args:
|
220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
|
222 |
Returns:
|
223 |
unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
|
|
|
|
|
224 |
"""
|
225 |
duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1])
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
|
|
|
|
231 |
|
232 |
def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
|
233 |
"""Matches each target to the most suitable anchor.
|
@@ -273,17 +307,21 @@ class BoxMatcher:
|
|
273 |
# get cls matrix (cls prob with each gt class and each predict class)
|
274 |
cls_mat = self.get_cls_matrix(predict_cls.sigmoid(), target_cls)
|
275 |
|
276 |
-
target_matrix =
|
277 |
|
278 |
# choose topk
|
279 |
-
topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
|
|
|
|
|
|
|
280 |
|
281 |
# delete one anchor pred assign to mutliple gts
|
282 |
-
unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask
|
283 |
|
284 |
align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
|
285 |
-
|
286 |
-
align_cls =
|
|
|
287 |
|
288 |
# normalize class ditribution
|
289 |
iou_mat *= topk_mask
|
@@ -294,7 +332,7 @@ class BoxMatcher:
|
|
294 |
normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
|
295 |
align_cls = align_cls * normalize_term * valid_mask[:, :, None]
|
296 |
anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
|
297 |
-
return anchor_matched_targets, valid_mask
|
298 |
|
299 |
|
300 |
class Vec2Box:
|
@@ -305,7 +343,7 @@ class Vec2Box:
|
|
305 |
logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}")
|
306 |
self.strides = anchor_cfg.strides
|
307 |
else:
|
308 |
-
logger.info("
|
309 |
self.strides = self.create_auto_anchor(model, image_size)
|
310 |
|
311 |
anchor_grid, scaler = generate_anchors(image_size, self.strides)
|
@@ -358,7 +396,7 @@ class Anc2Box:
|
|
358 |
logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}")
|
359 |
self.strides = anchor_cfg.strides
|
360 |
else:
|
361 |
-
logger.info("
|
362 |
self.strides = self.create_auto_anchor(model, image_size)
|
363 |
|
364 |
self.head_num = len(anchor_cfg.anchor)
|
|
|
2 |
from typing import Dict, List, Optional, Tuple, Union
|
3 |
|
4 |
import torch
|
|
|
5 |
from einops import rearrange
|
6 |
from torch import Tensor, tensor
|
7 |
from torchmetrics.detection import MeanAveragePrecision
|
|
|
142 |
|
143 |
|
144 |
class BoxMatcher:
|
145 |
+
def __init__(self, cfg: MatcherConfig, class_num: int, vec2box, reg_max: int) -> None:
|
146 |
self.class_num = class_num
|
147 |
+
self.vec2box = vec2box
|
148 |
+
self.reg_max = reg_max
|
149 |
for attr_name in cfg:
|
150 |
setattr(self, attr_name, cfg[attr_name])
|
151 |
|
152 |
def get_valid_matrix(self, target_bbox: Tensor):
|
153 |
"""
|
154 |
+
Get a boolean mask that indicates whether each target bounding box overlaps with each anchor
|
155 |
+
and is able to correctly predict it with the available reg_max value.
|
156 |
|
157 |
Args:
|
158 |
+
target_bbox [batch x targets x 4]: The bounding box of each target.
|
159 |
Returns:
|
160 |
+
[batch x targets x anchors]: A boolean tensor indicates if target bounding box overlaps
|
161 |
+
with the anchors, and the anchor is able to predict the target.
|
162 |
"""
|
163 |
+
x_min, y_min, x_max, y_max = target_bbox[:, :, None].unbind(3)
|
164 |
+
anchors = self.vec2box.anchor_grid[None, None] # add a axis at first, second dimension
|
165 |
anchors_x, anchors_y = anchors.unbind(dim=3)
|
166 |
+
x_min_dist, x_max_dist = anchors_x - x_min, x_max - anchors_x
|
167 |
+
y_min_dist, y_max_dist = anchors_y - y_min, y_max - anchors_y
|
168 |
+
targets_dist = torch.stack((x_min_dist, y_min_dist, x_max_dist, y_max_dist), dim=-1)
|
169 |
+
targets_dist /= self.vec2box.scaler[None, None, :, None] # (1, 1, anchors, 1)
|
170 |
+
min_reg_dist, max_reg_dist = targets_dist.amin(dim=-1), targets_dist.amax(dim=-1)
|
171 |
+
target_on_anchor = min_reg_dist >= 0
|
172 |
+
target_in_reg_max = max_reg_dist <= self.reg_max - 1.01
|
173 |
+
return target_on_anchor & target_in_reg_max
|
174 |
|
175 |
def get_cls_matrix(self, predict_cls: Tensor, target_cls: Tensor) -> Tensor:
|
176 |
"""
|
|
|
200 |
"""
|
201 |
return calculate_iou(target_bbox, predict_bbox, self.iou).clamp(0, 1)
|
202 |
|
203 |
+
def filter_topk(self, target_matrix: Tensor, grid_mask: Tensor, topk: int = 10) -> Tuple[Tensor, Tensor]:
|
204 |
"""
|
205 |
Filter the top-k suitability of targets for each anchor.
|
206 |
|
207 |
Args:
|
208 |
target_matrix [batch x targets x anchors]: The suitability for each targets-anchors
|
209 |
+
grid_mask [batch x targets x anchors]: The match validity for each target to anchors
|
210 |
topk (int, optional): Number of top scores to retain per anchor.
|
211 |
|
212 |
Returns:
|
213 |
topk_targets [batch x targets x anchors]: Only leave the topk targets for each anchor
|
214 |
+
topk_mask [batch x targets x anchors]: A boolean mask indicating the top-k scores' positions.
|
215 |
"""
|
216 |
+
masked_target_matrix = grid_mask * target_matrix
|
217 |
+
values, indices = masked_target_matrix.topk(topk, dim=-1)
|
218 |
topk_targets = torch.zeros_like(target_matrix, device=target_matrix.device)
|
219 |
topk_targets.scatter_(dim=-1, index=indices, src=values)
|
220 |
+
topk_mask = topk_targets > 0
|
221 |
+
return topk_targets, topk_mask
|
222 |
|
223 |
+
def ensure_one_anchor(self, target_matrix: Tensor, topk_mask: tensor) -> Tensor:
|
224 |
"""
|
225 |
+
Ensures each valid target gets at least one anchor matched based on the unmasked target matrix,
|
226 |
+
which enables an otherwise invalid match. This enables too small or too large targets to be
|
227 |
+
learned as well, even if they can't be predicted perfectly.
|
228 |
|
229 |
Args:
|
230 |
+
target_matrix [batch x targets x anchors]: The suitability for each targets-anchors
|
231 |
+
topk_mask [batch x targets x anchors]: A boolean mask indicating the top-k scores' positions.
|
232 |
+
|
233 |
+
Returns:
|
234 |
+
topk_mask [batch x targets x anchors]: A boolean mask indicating the updated top-k scores' positions.
|
235 |
+
"""
|
236 |
+
values, indices = target_matrix.max(dim=-1)
|
237 |
+
best_anchor_mask = torch.zeros_like(target_matrix, dtype=torch.bool)
|
238 |
+
best_anchor_mask.scatter_(-1, index=indices[..., None], src=~best_anchor_mask)
|
239 |
+
matched_anchor_num = torch.sum(topk_mask, dim=-1)
|
240 |
+
target_without_anchor = (matched_anchor_num == 0) & (values > 0)
|
241 |
+
topk_mask = torch.where(target_without_anchor[..., None], best_anchor_mask, topk_mask)
|
242 |
+
return topk_mask
|
243 |
+
|
244 |
+
def filter_duplicates(self, iou_mat: Tensor, topk_mask: Tensor):
|
245 |
+
"""
|
246 |
+
Filter the maximum suitability target index of each anchor based on IoU.
|
247 |
+
|
248 |
+
Args:
|
249 |
+
iou_mat [batch x targets x anchors]: The IoU for each targets-anchors
|
250 |
+
topk_mask [batch x targets x anchors]: A boolean mask indicating the top-k scores' positions.
|
251 |
|
252 |
Returns:
|
253 |
unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
|
254 |
+
valid_mask [batch x anchors]: Mask indicating the validity of each anchor
|
255 |
+
topk_mask [batch x targets x anchors]: A boolean mask indicating the updated top-k scores' positions.
|
256 |
"""
|
257 |
duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1])
|
258 |
+
masked_iou_mat = topk_mask * iou_mat
|
259 |
+
best_indices = masked_iou_mat.argmax(1)[:, None, :]
|
260 |
+
best_target_mask = torch.zeros_like(duplicates, dtype=torch.bool)
|
261 |
+
best_target_mask.scatter_(1, index=best_indices, src=~best_target_mask)
|
262 |
+
topk_mask = torch.where(duplicates, best_target_mask, topk_mask)
|
263 |
+
unique_indices = topk_mask.to(torch.uint8).argmax(dim=1)
|
264 |
+
return unique_indices[..., None], topk_mask.any(dim=1), topk_mask
|
265 |
|
266 |
def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
|
267 |
"""Matches each target to the most suitable anchor.
|
|
|
307 |
# get cls matrix (cls prob with each gt class and each predict class)
|
308 |
cls_mat = self.get_cls_matrix(predict_cls.sigmoid(), target_cls)
|
309 |
|
310 |
+
target_matrix = (iou_mat ** self.factor["iou"]) * (cls_mat ** self.factor["cls"])
|
311 |
|
312 |
# choose topk
|
313 |
+
topk_targets, topk_mask = self.filter_topk(target_matrix, grid_mask, topk=self.topk)
|
314 |
+
|
315 |
+
# match best anchor to valid targets without valid anchors
|
316 |
+
topk_mask = self.ensure_one_anchor(target_matrix, topk_mask)
|
317 |
|
318 |
# delete one anchor pred assign to mutliple gts
|
319 |
+
unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask)
|
320 |
|
321 |
align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
|
322 |
+
align_cls_indices = torch.gather(target_cls, 1, unique_indices)
|
323 |
+
align_cls = torch.zeros_like(align_cls_indices, dtype=torch.bool).repeat(1, 1, self.class_num)
|
324 |
+
align_cls.scatter_(-1, index=align_cls_indices, src=~align_cls)
|
325 |
|
326 |
# normalize class ditribution
|
327 |
iou_mat *= topk_mask
|
|
|
332 |
normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
|
333 |
align_cls = align_cls * normalize_term * valid_mask[:, :, None]
|
334 |
anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
|
335 |
+
return anchor_matched_targets, valid_mask
|
336 |
|
337 |
|
338 |
class Vec2Box:
|
|
|
343 |
logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}")
|
344 |
self.strides = anchor_cfg.strides
|
345 |
else:
|
346 |
+
logger.info(":teddy_bear: Found no stride of model, performed a dummy test for auto-anchor size")
|
347 |
self.strides = self.create_auto_anchor(model, image_size)
|
348 |
|
349 |
anchor_grid, scaler = generate_anchors(image_size, self.strides)
|
|
|
396 |
logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}")
|
397 |
self.strides = anchor_cfg.strides
|
398 |
else:
|
399 |
+
logger.info(":teddy_bear: Found no stride of model, performed a dummy test for auto-anchor size")
|
400 |
self.strides = self.create_auto_anchor(model, image_size)
|
401 |
|
402 |
self.head_num = len(anchor_cfg.anchor)
|