glenn-jocher commited on
Commit
7eaf225
·
1 Parent(s): d0d3dd1

zero-target training bug fix (#609)

Browse files
Files changed (1) hide show
  1. utils/general.py +5 -3
utils/general.py CHANGED
@@ -496,8 +496,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
496
  s = 3 / np # output count scaling
497
  lbox *= h['giou'] * s
498
  lobj *= h['obj'] * s * (1.4 if np == 4 else 1.)
499
- if model.nc > 1:
500
- lcls *= h['cls'] * s
501
  bs = tobj.shape[0] # batch size
502
 
503
  loss = lbox + lobj + lcls
@@ -524,7 +523,7 @@ def build_targets(p, targets, model):
524
  gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain
525
 
526
  # Match targets to anchors
527
- t, offsets = targets * gain, 0
528
  if nt:
529
  # Matches
530
  r = t[:, :, 4:6] / anchors[:, None] # wh ratio
@@ -540,6 +539,9 @@ def build_targets(p, targets, model):
540
  j = torch.stack((torch.ones_like(j), j, k, l, m))
541
  t = t.repeat((5, 1, 1))[j]
542
  offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
 
 
 
543
 
544
  # Define
545
  b, c = t[:, :2].long().T # image, class
 
496
  s = 3 / np # output count scaling
497
  lbox *= h['giou'] * s
498
  lobj *= h['obj'] * s * (1.4 if np == 4 else 1.)
499
+ lcls *= h['cls'] * s
 
500
  bs = tobj.shape[0] # batch size
501
 
502
  loss = lbox + lobj + lcls
 
523
  gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain
524
 
525
  # Match targets to anchors
526
+ t = targets * gain
527
  if nt:
528
  # Matches
529
  r = t[:, :, 4:6] / anchors[:, None] # wh ratio
 
539
  j = torch.stack((torch.ones_like(j), j, k, l, m))
540
  t = t.repeat((5, 1, 1))[j]
541
  offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
542
+ else:
543
+ t = targets[0]
544
+ offsets = 0
545
 
546
  # Define
547
  b, c = t[:, :2].long().T # image, class