Add ComputeLoss() class (#1950)
Files changed:
- test.py +4 −4
- train.py +5 −3
- utils/loss.py +129 −115
test.py
CHANGED
@@ -13,7 +13,6 @@ from models.experimental import attempt_load
 from utils.datasets import create_dataloader
 from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, check_requirements, \
     box_iou, non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path, colorstr
-from utils.loss import compute_loss
 from utils.metrics import ap_per_class, ConfusionMatrix
 from utils.plots import plot_images, output_to_target, plot_study_txt
 from utils.torch_utils import select_device, time_synchronized
@@ -36,7 +35,8 @@ def test(data,
          save_hybrid=False,  # for hybrid auto-labelling
          save_conf=False,  # save auto-label confidences
          plots=True,
-         log_imgs=0):  # number of logged images
+         log_imgs=0,  # number of logged images
+         compute_loss=None):

     # Initialize/load model and set device
     training = model is not None
@@ -111,8 +111,8 @@ def test(data,
             t0 += time_synchronized() - t

         # Compute loss
-        if training:
-            loss += compute_loss([x.float() for x in train_out], targets, model)[1][:3]  # box, obj, cls
+        if compute_loss:
+            loss += compute_loss([x.float() for x in train_out], targets)[1][:3]  # box, obj, cls

         # Run NMS
         targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device)  # to pixels
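Note on the new signature: a standalone run of test.py now computes no loss at all (compute_loss defaults to None), while train.py passes its ComputeLoss instance so validation also reports box/obj/cls losses. A minimal sketch of both call patterns, assuming the usual test.test(data, weights=...) entry point; the dataset yaml and weights path are placeholders, not part of this diff:

import test  # yolov5's test.py
from utils.loss import ComputeLoss

# Standalone evaluation: loss stays zero, only detection metrics are computed.
results = test.test('data/coco128.yaml', weights='yolov5s.pt')

# During training: pass the loss instance to also accumulate val box/obj/cls losses.
# results = test.test('data/coco128.yaml', model=model, dataloader=testloader,
#                     compute_loss=ComputeLoss(model))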
train.py
CHANGED
@@ -29,7 +29,7 @@ from utils.general import labels_to_class_weights, increment_path, labels_to_ima
     fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
     check_requirements, print_mutation, set_logging, one_cycle, colorstr
 from utils.google_utils import attempt_download
-from utils.loss import compute_loss
+from utils.loss import ComputeLoss
 from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
 from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first

@@ -227,6 +227,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
     results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
     scheduler.last_epoch = start_epoch - 1  # do not move
     scaler = amp.GradScaler(enabled=cuda)
+    compute_loss = ComputeLoss(model)  # init loss class
     logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n'
                 f'Using {dataloader.num_workers} dataloader workers\n'
                 f'Logging results to {save_dir}\n'
@@ -286,7 +287,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
            # Forward
            with amp.autocast(enabled=cuda):
                pred = model(imgs)  # forward
-                loss, loss_items = compute_loss(pred, targets.to(device), model)  # loss scaled by batch_size
+                loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
            if rank != -1:
                loss *= opt.world_size  # gradient averaged between devices in DDP mode
            if opt.quad:
@@ -344,7 +345,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
                                                 dataloader=testloader,
                                                 save_dir=save_dir,
                                                 plots=plots and final_epoch,
-                                                 log_imgs=opt.log_imgs if wandb else 0)
+                                                 log_imgs=opt.log_imgs if wandb else 0,
+                                                 compute_loss=compute_loss)

            # Write
            with open(results_file, 'a') as f:
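The training-side pattern after this change: the criterion is built once, up front, from the model's hyperparameters and Detect() head, then called like a function on every step. A minimal sketch, assuming model carries .hyp and .gr as train.py sets them (optimizer and dataloader construction elided):

from utils.loss import ComputeLoss

compute_loss = ComputeLoss(model)  # reads hyp, anchors, strides once at init
for imgs, targets, paths, _ in dataloader:
    imgs = imgs.to(device).float() / 255.0  # uint8 to float32, 0-255 to 0.0-1.0
    pred = model(imgs)  # list of per-layer predictions
    loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch size
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()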
utils/loss.py
CHANGED
@@ -85,119 +85,133 @@ class QFocalLoss(nn.Module):
         return loss


[115 lines removed: the function-based compute_loss() and build_targets(), whose logic is carried over into the ComputeLoss class below]
+class ComputeLoss:
+    # Compute losses
+    def __init__(self, model, autobalance=False):
+        super(ComputeLoss, self).__init__()
+        device = next(model.parameters()).device  # get model device
+        h = model.hyp  # hyperparameters
+
+        # Define criteria
+        BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
+        BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
+
+        # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
+        self.cp, self.cn = smooth_BCE(eps=0.0)
+
+        # Focal loss
+        g = h['fl_gamma']  # focal loss gamma
+        if g > 0:
+            BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
+
+        det = model.module.model[-1] if is_parallel(model) else model.model[-1]  # Detect() module
+        self.balance = {3: [3.67, 1.0, 0.43], 4: [3.78, 1.0, 0.39, 0.22], 5: [3.88, 1.0, 0.37, 0.17, 0.10]}[det.nl]
+        # self.balance = [1.0] * det.nl
+        self.ssi = (det.stride == 16).nonzero(as_tuple=False).item()  # stride 16 index
+        self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, model.gr, h, autobalance
+        for k in 'na', 'nc', 'nl', 'anchors':
+            setattr(self, k, getattr(det, k))
+
+    def __call__(self, p, targets):  # predictions, targets, model
+        device = targets.device
+        lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
+        tcls, tbox, indices, anchors = self.build_targets(p, targets)  # targets
+
+        # Losses
+        for i, pi in enumerate(p):  # layer index, layer predictions
+            b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
+            tobj = torch.zeros_like(pi[..., 0], device=device)  # target obj
+
+            n = b.shape[0]  # number of targets
+            if n:
+                ps = pi[b, a, gj, gi]  # prediction subset corresponding to targets
+
+                # Regression
+                pxy = ps[:, :2].sigmoid() * 2. - 0.5
+                pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
+                pbox = torch.cat((pxy, pwh), 1)  # predicted box
+                iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)  # iou(prediction, target)
+                lbox += (1.0 - iou).mean()  # iou loss
+
+                # Objectness
+                tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype)  # iou ratio
+
+                # Classification
+                if self.nc > 1:  # cls loss (only if multiple classes)
+                    t = torch.full_like(ps[:, 5:], self.cn, device=device)  # targets
+                    t[range(n), tcls[i]] = self.cp
+                    lcls += self.BCEcls(ps[:, 5:], t)  # BCE
+
+                # Append targets to text file
+                # with open('targets.txt', 'a') as file:
+                #     [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
+
+            obji = self.BCEobj(pi[..., 4], tobj)
+            lobj += obji * self.balance[i]  # obj loss
+            if self.autobalance:
+                self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()
+
+        if self.autobalance:
+            self.balance = [x / self.balance[self.ssi] for x in self.balance]
+        lbox *= self.hyp['box']
+        lobj *= self.hyp['obj']
+        lcls *= self.hyp['cls']
+        bs = tobj.shape[0]  # batch size
+
+        loss = lbox + lobj + lcls
+        return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()
+
+    def build_targets(self, p, targets):
+        # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
+        na, nt = self.na, targets.shape[0]  # number of anchors, targets
+        tcls, tbox, indices, anch = [], [], [], []
+        gain = torch.ones(7, device=targets.device)  # normalized to gridspace gain
+        ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)  # same as .repeat_interleave(nt)
+        targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)  # append anchor indices
+
+        g = 0.5  # bias
+        off = torch.tensor([[0, 0],
+                            [1, 0], [0, 1], [-1, 0], [0, -1],  # j,k,l,m
+                            # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
+                            ], device=targets.device).float() * g  # offsets
+
+        for i in range(self.nl):
+            anchors = self.anchors[i]
+            gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]]  # xyxy gain
+
+            # Match targets to anchors
+            t = targets * gain
+            if nt:
+                # Matches
+                r = t[:, :, 4:6] / anchors[:, None]  # wh ratio
+                j = torch.max(r, 1. / r).max(2)[0] < self.hyp['anchor_t']  # compare
+                # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t']  # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
+                t = t[j]  # filter
+
+                # Offsets
+                gxy = t[:, 2:4]  # grid xy
+                gxi = gain[[2, 3]] - gxy  # inverse
+                j, k = ((gxy % 1. < g) & (gxy > 1.)).T
+                l, m = ((gxi % 1. < g) & (gxi > 1.)).T
+                j = torch.stack((torch.ones_like(j), j, k, l, m))
+                t = t.repeat((5, 1, 1))[j]
+                offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
+            else:
+                t = targets[0]
+                offsets = 0
+
+            # Define
+            b, c = t[:, :2].long().T  # image, class
+            gxy = t[:, 2:4]  # grid xy
+            gwh = t[:, 4:6]  # grid wh
+            gij = (gxy - offsets).long()
+            gi, gj = gij.T  # grid xy indices
+
+            # Append
+            a = t[:, 6].long()  # anchor indices
+            indices.append((b, a, gj.clamp_(0, gain[3] - 1), gi.clamp_(0, gain[2] - 1)))  # image, anchor, grid indices
+            tbox.append(torch.cat((gxy - gij, gwh), 1))  # box
+            anch.append(anchors[a])  # anchors
+            tcls.append(c)  # class
+
+        return tcls, tbox, indices, anch
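For callers, the return convention of ComputeLoss.__call__ is worth spelling out: the first value is the batch-scaled scalar used for backprop, the second a detached component vector used for logging, which is why test.py slices [1][:3]. Illustrative unpacking (names are ours, not from the diff):

loss, loss_items = compute_loss(pred, targets)
# loss       == (lbox + lobj + lcls) * batch_size, attached to the autograd graph
# loss_items == detached tensor [lbox, lobj, lcls, total], for logging only
box_loss, obj_loss, cls_loss, total = loss_items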
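One step worth spelling out is the g = 0.5 offset logic in build_targets(): each target trains not only its own grid cell but also up to two neighbouring cells whose border its centre sits within half a cell of. A small runnable sketch with invented values (the 13x13 grid and the centre are hypothetical):

import torch

g = 0.5
gxy = torch.tensor([[3.2, 5.7]])          # hypothetical target centre, grid units
gain = torch.tensor([0., 0., 13., 13.])   # 13x13 grid at this layer (illustrative)
gxi = gain[[2, 3]] - gxy                  # distance to the far grid edges

j, k = ((gxy % 1. < g) & (gxy > 1.)).T    # near the left/top border of its cell?
l, m = ((gxi % 1. < g) & (gxi > 1.)).T    # near the right/bottom border?
# x % 1 = 0.2 < 0.5 -> j=True: the cell to the left is also assigned
# (13 - 5.7) % 1 = 0.3 < 0.5 -> m=True: the cell below is also assigned
print(j.item(), k.item(), l.item(), m.item())  # True False False True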
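Finally, the per-layer objectness weights self.balance down-weight small-stride layers, and with autobalance=True they adapt online as a slow exponential moving average of the inverse objectness loss, renormalised so the stride-16 layer (index self.ssi) stays at 1.0. A standalone sketch of just that update, with hypothetical loss values:

balance = [3.67, 1.0, 0.43]     # initial P3/P4/P5 objectness weights (3-layer head)
ssi = 1                          # index of the stride-16 layer
obj_losses = [0.8, 0.5, 0.2]     # hypothetical per-layer objectness losses

for i, obji in enumerate(obj_losses):
    balance[i] = balance[i] * 0.9999 + 0.0001 / obji  # slow EMA toward 1 / loss
balance = [x / balance[ssi] for x in balance]          # pin stride-16 weight at 1.0
print(balance)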