glenn-jocher committed
Commit f767023 · 1 Parent(s): 655895a

offset and balance update

Files changed (1)
  1. utils/utils.py +10 -9
utils/utils.py CHANGED
@@ -438,6 +438,7 @@ def compute_loss(p, targets, model):  # predictions, targets, model
 
     # per output
     nt = 0  # targets
+    balance = [1.0, 1.0, 1.0]
     for i, pi in enumerate(p):  # layer index, layer predictions
         b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
         tobj = torch.zeros_like(pi[..., 0])  # target obj
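Note: the new balance list gives each detection output its own multiplier on the objectness loss (one entry per layer; all 1.0 here, so behavior is unchanged until the weights are tuned). A minimal sketch of the weighting below, assuming three output layers and BCEobj = nn.BCEWithLogitsLoss() as in compute_loss; the tensor shapes are illustrative only, not taken from the commit:

    import torch
    import torch.nn as nn

    BCEobj = nn.BCEWithLogitsLoss()  # objectness criterion, as in compute_loss
    balance = [1.0, 1.0, 1.0]        # one weight per detection layer
    p = [torch.randn(2, 3, s, s, 85) for s in (80, 40, 20)]  # illustrative predictions
    lobj = torch.zeros(1)
    for i, pi in enumerate(p):       # layer index, layer predictions
        tobj = torch.zeros_like(pi[..., 4])       # target obj (all background here)
        lobj += BCEobj(pi[..., 4], tobj) * balance[i]  # per-layer weighted obj loss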
@@ -467,11 +468,12 @@ def compute_loss(p, targets, model):  # predictions, targets, model
             # with open('targets.txt', 'a') as file:
             #     [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
 
-        lobj += BCEobj(pi[..., 4], tobj)  # obj loss
+        lobj += BCEobj(pi[..., 4], tobj) * balance[i]  # obj loss
 
-    lbox *= h['giou']
-    lobj *= h['obj']
-    lcls *= h['cls']
+    s = 3 / (i + 1)  # output count scaling
+    lbox *= h['giou'] * s
+    lobj *= h['obj'] * s
+    lcls *= h['cls'] * s
     bs = tobj.shape[0]  # batch size
     if red == 'sum':
         g = 3.0  # loss gain
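After the loop ends, i holds the index of the last output layer, so s = 3 / (i + 1) reduces to 3 divided by the number of outputs: with the standard three detection layers s == 1.0 and the gains are unchanged, while heads with a different output count get rescaled to a comparable total. A worked sketch (the hyperparameter values below are placeholders, not the repo defaults):

    h = {'giou': 0.05, 'obj': 1.0, 'cls': 0.5}  # placeholder gains
    for num_outputs in (2, 3, 4):
        i = num_outputs - 1   # value of i after `for i, pi in enumerate(p)`
        s = 3 / (i + 1)       # output count scaling
        print(num_outputs, h['giou'] * s, h['obj'] * s, h['cls'] * s)
    # 3 outputs -> s = 1.0 (unchanged); 2 outputs -> s = 1.5; 4 outputs -> s = 0.75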
@@ -508,16 +510,15 @@ def build_targets(p, targets, model):
             a, t = at[j], t.repeat(na, 1, 1)[j]  # filter
 
             # overlaps
+            g = 0.5  # offset
             gxy = t[:, 2:4]  # grid xy
             z = torch.zeros_like(gxy)
             if style == 'rect2':
-                g = 0.2  # offset
                 j, k = ((gxy % 1. < g) & (gxy > 1.)).T
                 a, t = torch.cat((a, a[j], a[k]), 0), torch.cat((t, t[j], t[k]), 0)
                 offsets = torch.cat((z, z[j] + off[0], z[k] + off[1]), 0) * g
 
             elif style == 'rect4':
-                g = 0.5  # offset
                 j, k = ((gxy % 1. < g) & (gxy > 1.)).T
                 l, m = ((gxy % 1. > (1 - g)) & (gxy < (gain[[2, 3]] - 1.))).T
                 a, t = torch.cat((a, a[j], a[k], a[l], a[m]), 0), torch.cat((t, t[j], t[k], t[l], t[m]), 0)
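Note: both rect2 and rect4 now share a single offset g = 0.5 hoisted above the if/elif, replacing the previous per-style values (0.2 for rect2, 0.5 for rect4). The effect is that a target whose center falls within half a cell of a grid boundary is also assigned to the neighboring cell(s). A small standalone sketch of the neighbor test, with made-up grid coordinates:

    import torch

    g = 0.5  # offset
    gxy = torch.tensor([[3.2, 5.9],   # x-fraction 0.2 < g -> extra match along x
                        [7.7, 2.1]])  # y-fraction 0.1 < g -> extra match along y
    j, k = ((gxy % 1. < g) & (gxy > 1.)).T
    print(j)  # tensor([ True, False])
    print(k)  # tensor([False,  True])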
@@ -764,11 +765,11 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
     wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])  # wh
 
     # Filter
-    i = (wh0 < 4.0).any(1).sum()
+    i = (wh0 < 3.0).any(1).sum()
     if i:
         print('WARNING: Extremely small objects found. '
-              '%g of %g labels are < 4 pixels in width or height.' % (i, len(wh0)))
-    wh = wh0[(wh0 >= 4.0).any(1)]  # filter > 2 pixels
+              '%g of %g labels are < 3 pixels in width or height.' % (i, len(wh0)))
+    wh = wh0[(wh0 >= 2.0).any(1)]  # filter > 2 pixels
 
     # Kmeans calculation
     from scipy.cluster.vq import kmeans
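The warning threshold drops from 4 px to 3 px, and the size filter now keeps every label with at least one side of 2 px or more, which brings the code closer to its own '# filter > 2 pixels' comment. A small numpy sketch of the two expressions on made-up width/height pairs:

    import numpy as np

    wh0 = np.array([[1.5, 1.0],   # both sides < 2 px -> dropped by the filter
                    [2.5, 1.0],   # one side >= 2 px  -> kept, but counted in the warning
                    [8.0, 6.0]])  # normal label      -> kept
    i = (wh0 < 3.0).any(1).sum()   # labels with width or height < 3 px: 2
    wh = wh0[(wh0 >= 2.0).any(1)]  # drop labels with both sides < 2 px
    print(i, len(wh))  # 2 2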
 