Adam Kunák commited on
Commit
f080104
·
unverified ·
1 Parent(s): da4f0bf

✨ [Update] BoxMatcher matching criteria (#125)

Browse files

* ✨ [Update] BoxMatcher matching criteria

Added an additional validity criterium in get_valid_matrix, which masks out anchors from targets, that are too large to predict with the given reg_max and stride values.

Implemented a new function: ensure_one_anchor, which adds a single best suited anchor for valid targets without valid anchors. It is a fallback mechanism, which enables too small or too large targets to be trained to be predicted as well, even if not perfectly.

Fixed the filter_duplicate function to use the topk_masked iou_mat for the selection, which previously sometimes matched invalid targets to anchors with duplicates.

Updated docsstrings across the BoxMatcher functions to match the changes.

* 🔨 [Update] F.one_hot calls in BoxMatcher

to a more efficient solution, without using torch.nn.functional.

torch.nn.functional.one_hot always returns a long tensor, consuming a lot of memory for tensors, which are only used as masks.

yolo/tools/loss_functions.py CHANGED
@@ -75,7 +75,7 @@ class YOLOLoss:
75
  self.dfl = DFLoss(vec2box, reg_max)
76
  self.iou = BoxLoss()
77
 
78
- self.matcher = BoxMatcher(loss_cfg.matcher, self.class_num, vec2box.anchor_grid)
79
 
80
  def separate_anchor(self, anchors):
81
  """
 
75
  self.dfl = DFLoss(vec2box, reg_max)
76
  self.iou = BoxLoss()
77
 
78
+ self.matcher = BoxMatcher(loss_cfg.matcher, self.class_num, vec2box, reg_max)
79
 
80
  def separate_anchor(self, anchors):
81
  """
yolo/utils/bounding_box_utils.py CHANGED
@@ -2,7 +2,6 @@ import math
2
  from typing import Dict, List, Optional, Tuple, Union
3
 
4
  import torch
5
- import torch.nn.functional as F
6
  from einops import rearrange
7
  from torch import Tensor, tensor
8
  from torchmetrics.detection import MeanAveragePrecision
@@ -143,28 +142,35 @@ def generate_anchors(image_size: List[int], strides: List[int]):
143
 
144
 
145
  class BoxMatcher:
146
- def __init__(self, cfg: MatcherConfig, class_num: int, anchors: Tensor) -> None:
147
  self.class_num = class_num
148
- self.anchors = anchors
 
149
  for attr_name in cfg:
150
  setattr(self, attr_name, cfg[attr_name])
151
 
152
  def get_valid_matrix(self, target_bbox: Tensor):
153
  """
154
- Get a boolean mask that indicates whether each target bounding box overlaps with each anchor.
 
155
 
156
  Args:
157
- target_bbox [batch x targets x 4]: The bounding box of each targets.
158
  Returns:
159
- [batch x targets x anchors]: A boolean tensor indicates if target bounding box overlaps with anchors.
 
160
  """
161
- Xmin, Ymin, Xmax, Ymax = target_bbox[:, :, None].unbind(3)
162
- anchors = self.anchors[None, None] # add a axis at first, second dimension
163
  anchors_x, anchors_y = anchors.unbind(dim=3)
164
- target_in_x = (Xmin < anchors_x) & (anchors_x < Xmax)
165
- target_in_y = (Ymin < anchors_y) & (anchors_y < Ymax)
166
- target_on_anchor = target_in_x & target_in_y
167
- return target_on_anchor
 
 
 
 
168
 
169
  def get_cls_matrix(self, predict_cls: Tensor, target_cls: Tensor) -> Tensor:
170
  """
@@ -194,40 +200,68 @@ class BoxMatcher:
194
  """
195
  return calculate_iou(target_bbox, predict_bbox, self.iou).clamp(0, 1)
196
 
197
- def filter_topk(self, target_matrix: Tensor, topk: int = 10) -> Tuple[Tensor, Tensor]:
198
  """
199
  Filter the top-k suitability of targets for each anchor.
200
 
201
  Args:
202
  target_matrix [batch x targets x anchors]: The suitability for each targets-anchors
 
203
  topk (int, optional): Number of top scores to retain per anchor.
204
 
205
  Returns:
206
  topk_targets [batch x targets x anchors]: Only leave the topk targets for each anchor
207
- topk_masks [batch x targets x anchors]: A boolean mask indicating the top-k scores' positions.
208
  """
209
- values, indices = target_matrix.topk(topk, dim=-1)
 
210
  topk_targets = torch.zeros_like(target_matrix, device=target_matrix.device)
211
  topk_targets.scatter_(dim=-1, index=indices, src=values)
212
- topk_masks = topk_targets > 0
213
- return topk_targets, topk_masks
214
 
215
- def filter_duplicates(self, iou_mat: Tensor, topk_mask: Tensor, grid_mask: Tensor):
216
  """
217
- Filter the maximum suitability target index of each anchor.
 
 
218
 
219
  Args:
220
- iou_mat [batch x targets x anchors]: The suitability for each targets-anchors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
  Returns:
223
  unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
 
 
224
  """
225
  duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1])
226
- max_idx = F.one_hot(iou_mat.argmax(1), topk_mask.size(1)).permute(0, 2, 1)
227
- topk_mask = torch.where(duplicates, max_idx, topk_mask)
228
- topk_mask &= grid_mask
229
- unique_indices = topk_mask.argmax(dim=1)
230
- return unique_indices[..., None], topk_mask.sum(1), topk_mask
 
 
231
 
232
  def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
233
  """Matches each target to the most suitable anchor.
@@ -273,17 +307,21 @@ class BoxMatcher:
273
  # get cls matrix (cls prob with each gt class and each predict class)
274
  cls_mat = self.get_cls_matrix(predict_cls.sigmoid(), target_cls)
275
 
276
- target_matrix = grid_mask * (iou_mat ** self.factor["iou"]) * (cls_mat ** self.factor["cls"])
277
 
278
  # choose topk
279
- topk_targets, topk_mask = self.filter_topk(target_matrix, topk=self.topk)
 
 
 
280
 
281
  # delete one anchor pred assign to mutliple gts
282
- unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask, grid_mask)
283
 
284
  align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
285
- align_cls = torch.gather(target_cls, 1, unique_indices).squeeze(-1)
286
- align_cls = F.one_hot(align_cls, self.class_num)
 
287
 
288
  # normalize class ditribution
289
  iou_mat *= topk_mask
@@ -294,7 +332,7 @@ class BoxMatcher:
294
  normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
295
  align_cls = align_cls * normalize_term * valid_mask[:, :, None]
296
  anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
297
- return anchor_matched_targets, valid_mask.bool()
298
 
299
 
300
  class Vec2Box:
@@ -305,7 +343,7 @@ class Vec2Box:
305
  logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}")
306
  self.strides = anchor_cfg.strides
307
  else:
308
- logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size")
309
  self.strides = self.create_auto_anchor(model, image_size)
310
 
311
  anchor_grid, scaler = generate_anchors(image_size, self.strides)
@@ -358,7 +396,7 @@ class Anc2Box:
358
  logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}")
359
  self.strides = anchor_cfg.strides
360
  else:
361
- logger.info("🧸 Found no stride of model, performed a dummy test for auto-anchor size")
362
  self.strides = self.create_auto_anchor(model, image_size)
363
 
364
  self.head_num = len(anchor_cfg.anchor)
 
2
  from typing import Dict, List, Optional, Tuple, Union
3
 
4
  import torch
 
5
  from einops import rearrange
6
  from torch import Tensor, tensor
7
  from torchmetrics.detection import MeanAveragePrecision
 
142
 
143
 
144
  class BoxMatcher:
145
+ def __init__(self, cfg: MatcherConfig, class_num: int, vec2box, reg_max: int) -> None:
146
  self.class_num = class_num
147
+ self.vec2box = vec2box
148
+ self.reg_max = reg_max
149
  for attr_name in cfg:
150
  setattr(self, attr_name, cfg[attr_name])
151
 
152
  def get_valid_matrix(self, target_bbox: Tensor):
153
  """
154
+ Get a boolean mask that indicates whether each target bounding box overlaps with each anchor
155
+ and is able to correctly predict it with the available reg_max value.
156
 
157
  Args:
158
+ target_bbox [batch x targets x 4]: The bounding box of each target.
159
  Returns:
160
+ [batch x targets x anchors]: A boolean tensor indicates if target bounding box overlaps
161
+ with the anchors, and the anchor is able to predict the target.
162
  """
163
+ x_min, y_min, x_max, y_max = target_bbox[:, :, None].unbind(3)
164
+ anchors = self.vec2box.anchor_grid[None, None] # add a axis at first, second dimension
165
  anchors_x, anchors_y = anchors.unbind(dim=3)
166
+ x_min_dist, x_max_dist = anchors_x - x_min, x_max - anchors_x
167
+ y_min_dist, y_max_dist = anchors_y - y_min, y_max - anchors_y
168
+ targets_dist = torch.stack((x_min_dist, y_min_dist, x_max_dist, y_max_dist), dim=-1)
169
+ targets_dist /= self.vec2box.scaler[None, None, :, None] # (1, 1, anchors, 1)
170
+ min_reg_dist, max_reg_dist = targets_dist.amin(dim=-1), targets_dist.amax(dim=-1)
171
+ target_on_anchor = min_reg_dist >= 0
172
+ target_in_reg_max = max_reg_dist <= self.reg_max - 1.01
173
+ return target_on_anchor & target_in_reg_max
174
 
175
  def get_cls_matrix(self, predict_cls: Tensor, target_cls: Tensor) -> Tensor:
176
  """
 
200
  """
201
  return calculate_iou(target_bbox, predict_bbox, self.iou).clamp(0, 1)
202
 
203
+ def filter_topk(self, target_matrix: Tensor, grid_mask: Tensor, topk: int = 10) -> Tuple[Tensor, Tensor]:
204
  """
205
  Filter the top-k suitability of targets for each anchor.
206
 
207
  Args:
208
  target_matrix [batch x targets x anchors]: The suitability for each targets-anchors
209
+ grid_mask [batch x targets x anchors]: The match validity for each target to anchors
210
  topk (int, optional): Number of top scores to retain per anchor.
211
 
212
  Returns:
213
  topk_targets [batch x targets x anchors]: Only leave the topk targets for each anchor
214
+ topk_mask [batch x targets x anchors]: A boolean mask indicating the top-k scores' positions.
215
  """
216
+ masked_target_matrix = grid_mask * target_matrix
217
+ values, indices = masked_target_matrix.topk(topk, dim=-1)
218
  topk_targets = torch.zeros_like(target_matrix, device=target_matrix.device)
219
  topk_targets.scatter_(dim=-1, index=indices, src=values)
220
+ topk_mask = topk_targets > 0
221
+ return topk_targets, topk_mask
222
 
223
+ def ensure_one_anchor(self, target_matrix: Tensor, topk_mask: tensor) -> Tensor:
224
  """
225
+ Ensures each valid target gets at least one anchor matched based on the unmasked target matrix,
226
+ which enables an otherwise invalid match. This enables too small or too large targets to be
227
+ learned as well, even if they can't be predicted perfectly.
228
 
229
  Args:
230
+ target_matrix [batch x targets x anchors]: The suitability for each targets-anchors
231
+ topk_mask [batch x targets x anchors]: A boolean mask indicating the top-k scores' positions.
232
+
233
+ Returns:
234
+ topk_mask [batch x targets x anchors]: A boolean mask indicating the updated top-k scores' positions.
235
+ """
236
+ values, indices = target_matrix.max(dim=-1)
237
+ best_anchor_mask = torch.zeros_like(target_matrix, dtype=torch.bool)
238
+ best_anchor_mask.scatter_(-1, index=indices[..., None], src=~best_anchor_mask)
239
+ matched_anchor_num = torch.sum(topk_mask, dim=-1)
240
+ target_without_anchor = (matched_anchor_num == 0) & (values > 0)
241
+ topk_mask = torch.where(target_without_anchor[..., None], best_anchor_mask, topk_mask)
242
+ return topk_mask
243
+
244
+ def filter_duplicates(self, iou_mat: Tensor, topk_mask: Tensor):
245
+ """
246
+ Filter the maximum suitability target index of each anchor based on IoU.
247
+
248
+ Args:
249
+ iou_mat [batch x targets x anchors]: The IoU for each targets-anchors
250
+ topk_mask [batch x targets x anchors]: A boolean mask indicating the top-k scores' positions.
251
 
252
  Returns:
253
  unique_indices [batch x anchors x 1]: The index of the best targets for each anchors
254
+ valid_mask [batch x anchors]: Mask indicating the validity of each anchor
255
+ topk_mask [batch x targets x anchors]: A boolean mask indicating the updated top-k scores' positions.
256
  """
257
  duplicates = (topk_mask.sum(1, keepdim=True) > 1).repeat([1, topk_mask.size(1), 1])
258
+ masked_iou_mat = topk_mask * iou_mat
259
+ best_indices = masked_iou_mat.argmax(1)[:, None, :]
260
+ best_target_mask = torch.zeros_like(duplicates, dtype=torch.bool)
261
+ best_target_mask.scatter_(1, index=best_indices, src=~best_target_mask)
262
+ topk_mask = torch.where(duplicates, best_target_mask, topk_mask)
263
+ unique_indices = topk_mask.to(torch.uint8).argmax(dim=1)
264
+ return unique_indices[..., None], topk_mask.any(dim=1), topk_mask
265
 
266
  def __call__(self, target: Tensor, predict: Tuple[Tensor]) -> Tuple[Tensor, Tensor]:
267
  """Matches each target to the most suitable anchor.
 
307
  # get cls matrix (cls prob with each gt class and each predict class)
308
  cls_mat = self.get_cls_matrix(predict_cls.sigmoid(), target_cls)
309
 
310
+ target_matrix = (iou_mat ** self.factor["iou"]) * (cls_mat ** self.factor["cls"])
311
 
312
  # choose topk
313
+ topk_targets, topk_mask = self.filter_topk(target_matrix, grid_mask, topk=self.topk)
314
+
315
+ # match best anchor to valid targets without valid anchors
316
+ topk_mask = self.ensure_one_anchor(target_matrix, topk_mask)
317
 
318
  # delete one anchor pred assign to mutliple gts
319
+ unique_indices, valid_mask, topk_mask = self.filter_duplicates(iou_mat, topk_mask)
320
 
321
  align_bbox = torch.gather(target_bbox, 1, unique_indices.repeat(1, 1, 4))
322
+ align_cls_indices = torch.gather(target_cls, 1, unique_indices)
323
+ align_cls = torch.zeros_like(align_cls_indices, dtype=torch.bool).repeat(1, 1, self.class_num)
324
+ align_cls.scatter_(-1, index=align_cls_indices, src=~align_cls)
325
 
326
  # normalize class ditribution
327
  iou_mat *= topk_mask
 
332
  normalize_term = normalize_term.permute(0, 2, 1).gather(2, unique_indices)
333
  align_cls = align_cls * normalize_term * valid_mask[:, :, None]
334
  anchor_matched_targets = torch.cat([align_cls, align_bbox], dim=-1)
335
+ return anchor_matched_targets, valid_mask
336
 
337
 
338
  class Vec2Box:
 
343
  logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}")
344
  self.strides = anchor_cfg.strides
345
  else:
346
+ logger.info(":teddy_bear: Found no stride of model, performed a dummy test for auto-anchor size")
347
  self.strides = self.create_auto_anchor(model, image_size)
348
 
349
  anchor_grid, scaler = generate_anchors(image_size, self.strides)
 
396
  logger.info(f":japanese_not_free_of_charge_button: Found stride of model {anchor_cfg.strides}")
397
  self.strides = anchor_cfg.strides
398
  else:
399
+ logger.info(":teddy_bear: Found no stride of model, performed a dummy test for auto-anchor size")
400
  self.strides = self.create_auto_anchor(model, image_size)
401
 
402
  self.head_num = len(anchor_cfg.anchor)