glenn-jocher commited on
Commit
52c1399
·
unverified ·
1 Parent(s): c84dd27

DetectMultiBackend() return `device` update (#6958)

Browse files
Files changed (1) hide show
  1. models/common.py +2 -1
models/common.py CHANGED
@@ -458,7 +458,8 @@ class DetectMultiBackend(nn.Module):
458
  y = (y.astype(np.float32) - zero_point) * scale # re-scale
459
  y[..., :4] *= [w, h, w, h] # xywh normalized to pixels
460
 
461
- y = torch.tensor(y) if isinstance(y, np.ndarray) else y
 
462
  return (y, []) if val else y
463
 
464
  def warmup(self, imgsz=(1, 3, 640, 640)):
 
458
  y = (y.astype(np.float32) - zero_point) * scale # re-scale
459
  y[..., :4] *= [w, h, w, h] # xywh normalized to pixels
460
 
461
+ if isinstance(y, np.ndarray):
462
+ y = torch.tensor(y, device=self.device)
463
  return (y, []) if val else y
464
 
465
  def warmup(self, imgsz=(1, 3, 640, 640)):