Commit
·
f767023
1
Parent(s):
655895a
offset and balance update
Browse files- utils/utils.py +10 -9
utils/utils.py
CHANGED
@@ -438,6 +438,7 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
|
438 |
|
439 |
# per output
|
440 |
nt = 0 # targets
|
|
|
441 |
for i, pi in enumerate(p): # layer index, layer predictions
|
442 |
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
|
443 |
tobj = torch.zeros_like(pi[..., 0]) # target obj
|
@@ -467,11 +468,12 @@ def compute_loss(p, targets, model): # predictions, targets, model
|
|
467 |
# with open('targets.txt', 'a') as file:
|
468 |
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
|
469 |
|
470 |
-
lobj += BCEobj(pi[..., 4], tobj) # obj loss
|
471 |
|
472 |
-
|
473 |
-
|
474 |
-
|
|
|
475 |
bs = tobj.shape[0] # batch size
|
476 |
if red == 'sum':
|
477 |
g = 3.0 # loss gain
|
@@ -508,16 +510,15 @@ def build_targets(p, targets, model):
|
|
508 |
a, t = at[j], t.repeat(na, 1, 1)[j] # filter
|
509 |
|
510 |
# overlaps
|
|
|
511 |
gxy = t[:, 2:4] # grid xy
|
512 |
z = torch.zeros_like(gxy)
|
513 |
if style == 'rect2':
|
514 |
-
g = 0.2 # offset
|
515 |
j, k = ((gxy % 1. < g) & (gxy > 1.)).T
|
516 |
a, t = torch.cat((a, a[j], a[k]), 0), torch.cat((t, t[j], t[k]), 0)
|
517 |
offsets = torch.cat((z, z[j] + off[0], z[k] + off[1]), 0) * g
|
518 |
|
519 |
elif style == 'rect4':
|
520 |
-
g = 0.5 # offset
|
521 |
j, k = ((gxy % 1. < g) & (gxy > 1.)).T
|
522 |
l, m = ((gxy % 1. > (1 - g)) & (gxy < (gain[[2, 3]] - 1.))).T
|
523 |
a, t = torch.cat((a, a[j], a[k], a[l], a[m]), 0), torch.cat((t, t[j], t[k], t[l], t[m]), 0)
|
@@ -764,11 +765,11 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
|
|
764 |
wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)]) # wh
|
765 |
|
766 |
# Filter
|
767 |
-
i = (wh0 <
|
768 |
if i:
|
769 |
print('WARNING: Extremely small objects found. '
|
770 |
-
'%g of %g labels are <
|
771 |
-
wh = wh0[(wh0 >=
|
772 |
|
773 |
# Kmeans calculation
|
774 |
from scipy.cluster.vq import kmeans
|
|
|
438 |
|
439 |
# per output
|
440 |
nt = 0 # targets
|
441 |
+
balance = [1.0, 1.0, 1.0]
|
442 |
for i, pi in enumerate(p): # layer index, layer predictions
|
443 |
b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
|
444 |
tobj = torch.zeros_like(pi[..., 0]) # target obj
|
|
|
468 |
# with open('targets.txt', 'a') as file:
|
469 |
# [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
|
470 |
|
471 |
+
lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss
|
472 |
|
473 |
+
s = 3 / (i + 1) # output count scaling
|
474 |
+
lbox *= h['giou'] * s
|
475 |
+
lobj *= h['obj'] * s
|
476 |
+
lcls *= h['cls'] * s
|
477 |
bs = tobj.shape[0] # batch size
|
478 |
if red == 'sum':
|
479 |
g = 3.0 # loss gain
|
|
|
510 |
a, t = at[j], t.repeat(na, 1, 1)[j] # filter
|
511 |
|
512 |
# overlaps
|
513 |
+
g = 0.5 # offset
|
514 |
gxy = t[:, 2:4] # grid xy
|
515 |
z = torch.zeros_like(gxy)
|
516 |
if style == 'rect2':
|
|
|
517 |
j, k = ((gxy % 1. < g) & (gxy > 1.)).T
|
518 |
a, t = torch.cat((a, a[j], a[k]), 0), torch.cat((t, t[j], t[k]), 0)
|
519 |
offsets = torch.cat((z, z[j] + off[0], z[k] + off[1]), 0) * g
|
520 |
|
521 |
elif style == 'rect4':
|
|
|
522 |
j, k = ((gxy % 1. < g) & (gxy > 1.)).T
|
523 |
l, m = ((gxy % 1. > (1 - g)) & (gxy < (gain[[2, 3]] - 1.))).T
|
524 |
a, t = torch.cat((a, a[j], a[k], a[l], a[m]), 0), torch.cat((t, t[j], t[k], t[l], t[m]), 0)
|
|
|
765 |
wh0 = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)]) # wh
|
766 |
|
767 |
# Filter
|
768 |
+
i = (wh0 < 3.0).any(1).sum()
|
769 |
if i:
|
770 |
print('WARNING: Extremely small objects found. '
|
771 |
+
'%g of %g labels are < 3 pixels in width or height.' % (i, len(wh0)))
|
772 |
+
wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels
|
773 |
|
774 |
# Kmeans calculation
|
775 |
from scipy.cluster.vq import kmeans
|