Commit
·
7eaf225
1
Parent(s):
d0d3dd1
zero-target training bug fix (#609)
Browse files- utils/general.py +5 -3
utils/general.py
CHANGED
@@ -496,8 +496,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
|
496 |
s = 3 / np # output count scaling
|
497 |
lbox *= h['giou'] * s
|
498 |
lobj *= h['obj'] * s * (1.4 if np == 4 else 1.)
|
499 |
-
|
500 |
-
lcls *= h['cls'] * s
|
501 |
bs = tobj.shape[0] # batch size
|
502 |
|
503 |
loss = lbox + lobj + lcls
|
@@ -524,7 +523,7 @@ def build_targets(p, targets, model):
|
|
524 |
gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain
|
525 |
|
526 |
# Match targets to anchors
|
527 |
-
t
|
528 |
if nt:
|
529 |
# Matches
|
530 |
r = t[:, :, 4:6] / anchors[:, None] # wh ratio
|
@@ -540,6 +539,9 @@ def build_targets(p, targets, model):
|
|
540 |
j = torch.stack((torch.ones_like(j), j, k, l, m))
|
541 |
t = t.repeat((5, 1, 1))[j]
|
542 |
offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
|
|
|
|
|
|
|
543 |
|
544 |
# Define
|
545 |
b, c = t[:, :2].long().T # image, class
|
|
|
496 |
s = 3 / np # output count scaling
|
497 |
lbox *= h['giou'] * s
|
498 |
lobj *= h['obj'] * s * (1.4 if np == 4 else 1.)
|
499 |
+
lcls *= h['cls'] * s
|
|
|
500 |
bs = tobj.shape[0] # batch size
|
501 |
|
502 |
loss = lbox + lobj + lcls
|
|
|
523 |
gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain
|
524 |
|
525 |
# Match targets to anchors
|
526 |
+
t = targets * gain
|
527 |
if nt:
|
528 |
# Matches
|
529 |
r = t[:, :, 4:6] / anchors[:, None] # wh ratio
|
|
|
539 |
j = torch.stack((torch.ones_like(j), j, k, l, m))
|
540 |
t = t.repeat((5, 1, 1))[j]
|
541 |
offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
|
542 |
+
else:
|
543 |
+
t = targets[0]
|
544 |
+
offsets = 0
|
545 |
|
546 |
# Define
|
547 |
b, c = t[:, :2].long().T # image, class
|