henry000 commited on
Commit
c8710f3
ยท
1 Parent(s): f5518c0

๐Ÿ› [Fix] a bug of initialize vec2box

Browse files
Files changed (1) hide show
  1. yolo/utils/bounding_box_utils.py +5 -1
yolo/utils/bounding_box_utils.py CHANGED
@@ -266,7 +266,7 @@ class BoxMatcher:
266
 
267
  class Vec2Box:
268
  def __init__(self, model: YOLO, image_size, device):
269
- if model.strides is None:
270
  logger.info("๐Ÿงธ Found no anchor, Make a dummy test for auto-anchor size")
271
  dummy_input = torch.zeros(1, 3, *image_size).to(device)
272
  dummy_output = model(dummy_input)
@@ -277,6 +277,10 @@ class Vec2Box:
277
  else:
278
  logger.info(f"๐Ÿˆถ Found anchor {model.strides}")
279
  anchors_num = [[image_size[0] // stride, image_size[0] // stride] for stride in model.strides]
 
 
 
 
280
  anchor_grid, scaler = generate_anchors(image_size, anchors_num)
281
  self.anchor_grid, self.scaler = anchor_grid.to(device), scaler.to(device)
282
  self.anchor_norm = (anchor_grid / scaler[:, None])[None].to(device)
 
266
 
267
  class Vec2Box:
268
  def __init__(self, model: YOLO, image_size, device):
269
+ if getattr(model, "strides", None) is None:
270
  logger.info("๐Ÿงธ Found no anchor, Make a dummy test for auto-anchor size")
271
  dummy_input = torch.zeros(1, 3, *image_size).to(device)
272
  dummy_output = model(dummy_input)
 
277
  else:
278
  logger.info(f"๐Ÿˆถ Found anchor {model.strides}")
279
  anchors_num = [[image_size[0] // stride, image_size[0] // stride] for stride in model.strides]
280
+
281
+ if not isinstance(model, YOLO):
282
+ device = torch.device("cpu")
283
+
284
  anchor_grid, scaler = generate_anchors(image_size, anchors_num)
285
  self.anchor_grid, self.scaler = anchor_grid.to(device), scaler.to(device)
286
  self.anchor_norm = (anchor_grid / scaler[:, None])[None].to(device)