Update loss for FP16 `tobj` (#7088)
Browse files- utils/loss.py +1 -1
utils/loss.py
CHANGED
@@ -125,7 +125,7 @@ class ComputeLoss:
|
|
125 |
# Losses
|
126 |
for i, pi in enumerate(p): # layer index, layer predictions
|
127 |
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
|
128 |
-
tobj = torch.zeros(pi.shape[:4], device=self.device) # target obj
|
129 |
|
130 |
n = b.shape[0] # number of targets
|
131 |
if n:
|
|
|
125 |
# Losses
|
126 |
for i, pi in enumerate(p): # layer index, layer predictions
|
127 |
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
|
128 |
+
tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device) # target obj
|
129 |
|
130 |
n = b.shape[0] # number of targets
|
131 |
if n:
|