glenn-jocher commited on
Commit
a2d617e
·
unverified ·
1 Parent(s): 6f12803

Update loss for FP16 `tobj` (#7088)

Browse files
Files changed (1) hide show
  1. utils/loss.py +1 -1
utils/loss.py CHANGED
@@ -125,7 +125,7 @@ class ComputeLoss:
125
  # Losses
126
  for i, pi in enumerate(p): # layer index, layer predictions
127
  b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
128
- tobj = torch.zeros(pi.shape[:4], device=self.device) # target obj
129
 
130
  n = b.shape[0] # number of targets
131
  if n:
 
125
  # Losses
126
  for i, pi in enumerate(p): # layer index, layer predictions
127
  b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
128
+ tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device) # target obj
129
 
130
  n = b.shape[0] # number of targets
131
  if n: