glenn-jocher commited on
Commit
305c6a0
·
1 Parent(s): c1a2a7a

compute_loss() leaf variable update

Browse files
Files changed (1) hide show
  1. utils/utils.py +3 -3
utils/utils.py CHANGED
@@ -439,7 +439,7 @@ class BCEBlurWithLogitsLoss(nn.Module):
439
 
440
  def compute_loss(p, targets, model): # predictions, targets, model
441
  device = targets.device
442
- lcls, lbox, lobj = torch.zeros(3, 1, device=device)
443
  tcls, tbox, indices, anchors = build_targets(p, targets, model) # targets
444
  h = model.hyp # hyperparameters
445
 
@@ -482,13 +482,13 @@ def compute_loss(p, targets, model): # predictions, targets, model
482
  if model.nc > 1: # cls loss (only if multiple classes)
483
  t = torch.full_like(ps[:, 5:], cn, device=device) # targets
484
  t[range(n), tcls[i]] = cp
485
- lcls = lcls + BCEcls(ps[:, 5:], t) # BCE
486
 
487
  # Append targets to text file
488
  # with open('targets.txt', 'a') as file:
489
  # [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
490
 
491
- lobj = lobj + BCEobj(pi[..., 4], tobj) * balance[i] # obj loss
492
 
493
  s = 3 / np # output count scaling
494
  lbox *= h['giou'] * s
 
439
 
440
  def compute_loss(p, targets, model): # predictions, targets, model
441
  device = targets.device
442
+ lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
443
  tcls, tbox, indices, anchors = build_targets(p, targets, model) # targets
444
  h = model.hyp # hyperparameters
445
 
 
482
  if model.nc > 1: # cls loss (only if multiple classes)
483
  t = torch.full_like(ps[:, 5:], cn, device=device) # targets
484
  t[range(n), tcls[i]] = cp
485
+ lcls += BCEcls(ps[:, 5:], t) # BCE
486
 
487
  # Append targets to text file
488
  # with open('targets.txt', 'a') as file:
489
  # [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
490
 
491
+ lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss
492
 
493
  s = 3 / np # output count scaling
494
  lbox *= h['giou'] * s