henry000 committed
Commit 5bbfada · 2 Parent(s): f5a3a55 b86ec3e

🔀 [Merge] branch 'MODELv2' into TEST

Files changed (1):
  1. yolo/utils/bounding_box_utils.py +32 -21
yolo/utils/bounding_box_utils.py CHANGED
@@ -108,12 +108,13 @@ def transform_bbox(bbox: Tensor, indicator="xywh -> xyxy"):
     return bbox.to(dtype=data_type)
 
 
-def generate_anchors(image_size: List[int], anchors_list: List[Tuple[int]]):
+def generate_anchors(image_size: List[int], strides: List[int]):
     """
     Find the anchor maps for each w, h.
 
     Args:
-        anchors_list List[[w1, h1], [w2, h2], ...]: the anchor num for each predicted anchor
+        image_size List: the size of the augmented image
+        strides List[8, 16, 32, ...]: the stride for each prediction layer
 
     Returns:
         all_anchors [HW x 2]:
@@ -122,15 +123,14 @@ def generate_anchors(image_size: List[int], anchors_list: List[Tuple[int]]):
     W, H = image_size
     anchors = []
     scaler = []
-    for anchor_wh in anchors_list:
-        stride = W // anchor_wh[0]
-        anchor_num = anchor_wh[0] * anchor_wh[1]
+    for stride in strides:
+        anchor_num = W // stride * H // stride
         scaler.append(torch.full((anchor_num,), stride))
         shift = stride // 2
-        x = torch.arange(0, W, stride) + shift
-        y = torch.arange(0, H, stride) + shift
-        anchor_x, anchor_y = torch.meshgrid(x, y, indexing="ij")
-        anchor = torch.stack([anchor_y.flatten(), anchor_x.flatten()], dim=-1)
+        h = torch.arange(0, H, stride) + shift
+        w = torch.arange(0, W, stride) + shift
+        anchor_h, anchor_w = torch.meshgrid(h, w, indexing="ij")
+        anchor = torch.stack([anchor_w.flatten(), anchor_h.flatten()], dim=-1)
         anchors.append(anchor)
     all_anchors = torch.cat(anchors, dim=0)
     all_scalers = torch.cat(scaler, dim=0)
@@ -172,6 +172,7 @@ class BoxMatcher:
         Returns:
             [batch x targets x anchors]: The probabilities from `pred_cls` corresponding to the class indices specified in `target_cls`.
         """
+        # TODO: Turn 8400 to HW
         target_cls = target_cls.expand(-1, -1, 8400)
         predict_cls = predict_cls.transpose(1, 2)
         cls_probabilities = torch.gather(predict_cls, 1, target_cls)
@@ -266,24 +267,34 @@ class BoxMatcher:
 
 class Vec2Box:
     def __init__(self, model: YOLO, image_size, device):
-        if getattr(model, "strides", None) is None:
-            logger.info("🧸 Found no anchor, Make a dummy test for auto-anchor size")
-            dummy_input = torch.zeros(1, 3, *image_size).to(device)
-            dummy_output = model(dummy_input)
-            anchors_num = []
-            for predict_head in dummy_output["Main"]:
-                _, _, *anchor_num = predict_head[2].shape
-                anchors_num.append(anchor_num)
+        self.device = device
+
+        if getattr(model, "strides", None):
+            logger.info(f"🈶 Found stride of model {model.strides}")
+            self.strides = model.strides
         else:
-            logger.info(f"🈶 Found anchor {model.strides}")
-            anchors_num = [[image_size[0] // stride, image_size[0] // stride] for stride in model.strides]
+            logger.info("🧸 Found no stride of model, performing a dummy test for auto-anchor size")
+            self.strides = self.create_auto_anchor(model, image_size)
 
+        # TODO: this is an exception for onnx; remove it when the onnx device issue is fixed
         if not isinstance(model, YOLO):
            device = torch.device("cpu")
 
-        anchor_grid, scaler = generate_anchors(image_size, anchors_num)
+        anchor_grid, scaler = generate_anchors(image_size, self.strides)
         self.anchor_grid, self.scaler = anchor_grid.to(device), scaler.to(device)
-        self.anchor_norm = (anchor_grid / scaler[:, None])[None].to(device)
+
+    def create_auto_anchor(self, model: YOLO, image_size):
+        dummy_input = torch.zeros(1, 3, *image_size).to(self.device)
+        dummy_output = model(dummy_input)
+        strides = []
+        for predict_head in dummy_output["Main"]:
+            _, _, *anchor_num = predict_head[2].shape
+            strides.append(image_size[1] // anchor_num[1])
+        return strides
+
+    def update(self, image_size):
+        anchor_grid, scaler = generate_anchors(image_size, self.strides)
+        self.anchor_grid, self.scaler = anchor_grid.to(self.device), scaler.to(self.device)
 
     def __call__(self, predicts):
         preds_cls, preds_anc, preds_box = [], [], []
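For quick orientation, here is a minimal usage sketch of the reworked generate_anchors, not part of the commit. It assumes the repository is installed so the module is importable, and that the function finishes by returning all_anchors and all_scalers (the return line falls outside the hunk above); the 640x640 image size and the [8, 16, 32] strides are illustrative values only.

# Usage sketch only, not part of the commit. Assumes the repo is installed so
# that yolo.utils.bounding_box_utils is importable, and that generate_anchors
# returns (all_anchors, all_scalers) as its call site in Vec2Box suggests.
from yolo.utils.bounding_box_utils import generate_anchors

image_size = [640, 640]  # illustrative W, H of the augmented image
strides = [8, 16, 32]    # illustrative stride per prediction layer

all_anchors, all_scalers = generate_anchors(image_size, strides)

# 640//8 = 80, 640//16 = 40, 640//32 = 20, so 80*80 + 40*40 + 20*20 = 8400
# anchors in total; this is the hard-coded 8400 that the new TODO in
# BoxMatcher wants turned into HW.
print(all_anchors.shape)  # torch.Size([8400, 2]): one (w, h) center per grid cell
print(all_scalers.shape)  # torch.Size([8400]): the stride that produced each anchor

Because Vec2Box now caches self.strides, regenerating the grids after an input-size change is a single call to the new update method, for example vec2box.update([1280, 1280]), with no second dummy forward pass.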