glenn-jocher commited on
Commit
b0ba101
·
unverified ·
1 Parent(s): 4effd06

`ComputeLoss()` indexing/speed improvements (#7048)

Browse files

* device as class attribute

* Update loss.py

* Update loss.py

* improve zeros

* tensor split

Files changed (1) hide show
  1. utils/loss.py +19 -18
utils/loss.py CHANGED
@@ -89,9 +89,10 @@ class QFocalLoss(nn.Module):
89
 
90
 
91
  class ComputeLoss:
 
 
92
  # Compute losses
93
  def __init__(self, model, autobalance=False):
94
- self.sort_obj_iou = False
95
  device = next(model.parameters()).device # get model device
96
  h = model.hyp # hyperparameters
97
 
@@ -111,26 +112,28 @@ class ComputeLoss:
111
  self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7
112
  self.ssi = list(det.stride).index(16) if autobalance else 0 # stride 16 index
113
  self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
 
114
  for k in 'na', 'nc', 'nl', 'anchors':
115
  setattr(self, k, getattr(det, k))
116
 
117
- def __call__(self, p, targets): # predictions, targets, model
118
- device = targets.device
119
- lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
 
120
  tcls, tbox, indices, anchors = self.build_targets(p, targets) # targets
121
 
122
  # Losses
123
  for i, pi in enumerate(p): # layer index, layer predictions
124
  b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
125
- tobj = torch.zeros_like(pi[..., 0], device=device) # target obj
126
 
127
  n = b.shape[0] # number of targets
128
  if n:
129
- ps = pi[b, a, gj, gi] # prediction subset corresponding to targets
130
 
131
  # Regression
132
- pxy = ps[:, :2].sigmoid() * 2 - 0.5
133
- pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
134
  pbox = torch.cat((pxy, pwh), 1) # predicted box
135
  iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # iou(prediction, target)
136
  lbox += (1.0 - iou).mean() # iou loss
@@ -144,9 +147,9 @@ class ComputeLoss:
144
 
145
  # Classification
146
  if self.nc > 1: # cls loss (only if multiple classes)
147
- t = torch.full_like(ps[:, 5:], self.cn, device=device) # targets
148
  t[range(n), tcls[i]] = self.cp
149
- lcls += self.BCEcls(ps[:, 5:], t) # BCE
150
 
151
  # Append targets to text file
152
  # with open('targets.txt', 'a') as file:
@@ -170,15 +173,15 @@ class ComputeLoss:
170
  # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
171
  na, nt = self.na, targets.shape[0] # number of anchors, targets
172
  tcls, tbox, indices, anch = [], [], [], []
173
- gain = torch.ones(7, device=targets.device) # normalized to gridspace gain
174
- ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
175
  targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices
176
 
177
  g = 0.5 # bias
178
  off = torch.tensor([[0, 0],
179
  [1, 0], [0, 1], [-1, 0], [0, -1], # j,k,l,m
180
  # [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm
181
- ], device=targets.device).float() * g # offsets
182
 
183
  for i in range(self.nl):
184
  anchors = self.anchors[i]
@@ -206,14 +209,12 @@ class ComputeLoss:
206
  offsets = 0
207
 
208
  # Define
209
- b, c = t[:, :2].long().T # image, class
210
- gxy = t[:, 2:4] # grid xy
211
- gwh = t[:, 4:6] # grid wh
212
  gij = (gxy - offsets).long()
213
- gi, gj = gij.T # grid xy indices
214
 
215
  # Append
216
- a = t[:, 6].long() # anchor indices
217
  indices.append((b, a, gj.clamp_(0, gain[3] - 1), gi.clamp_(0, gain[2] - 1))) # image, anchor, grid indices
218
  tbox.append(torch.cat((gxy - gij, gwh), 1)) # box
219
  anch.append(anchors[a]) # anchors
 
89
 
90
 
91
  class ComputeLoss:
92
+ sort_obj_iou = False
93
+
94
  # Compute losses
95
  def __init__(self, model, autobalance=False):
 
96
  device = next(model.parameters()).device # get model device
97
  h = model.hyp # hyperparameters
98
 
 
112
  self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7
113
  self.ssi = list(det.stride).index(16) if autobalance else 0 # stride 16 index
114
  self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
115
+ self.device = device
116
  for k in 'na', 'nc', 'nl', 'anchors':
117
  setattr(self, k, getattr(det, k))
118
 
119
+ def __call__(self, p, targets): # predictions, targets
120
+ lcls = torch.zeros(1, device=self.device) # class loss
121
+ lbox = torch.zeros(1, device=self.device) # box loss
122
+ lobj = torch.zeros(1, device=self.device) # object loss
123
  tcls, tbox, indices, anchors = self.build_targets(p, targets) # targets
124
 
125
  # Losses
126
  for i, pi in enumerate(p): # layer index, layer predictions
127
  b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
128
+ tobj = torch.zeros(pi.shape[:4], device=self.device) # target obj
129
 
130
  n = b.shape[0] # number of targets
131
  if n:
132
+ pxy, pwh, _, pcls = pi[b, a, gj, gi].tensor_split((2, 4, 5), dim=1) # target-subset of predictions
133
 
134
  # Regression
135
+ pxy = pxy.sigmoid() * 2 - 0.5
136
+ pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i]
137
  pbox = torch.cat((pxy, pwh), 1) # predicted box
138
  iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True) # iou(prediction, target)
139
  lbox += (1.0 - iou).mean() # iou loss
 
147
 
148
  # Classification
149
  if self.nc > 1: # cls loss (only if multiple classes)
150
+ t = torch.full_like(pcls, self.cn, device=self.device) # targets
151
  t[range(n), tcls[i]] = self.cp
152
+ lcls += self.BCEcls(pcls, t) # BCE
153
 
154
  # Append targets to text file
155
  # with open('targets.txt', 'a') as file:
 
173
  # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
174
  na, nt = self.na, targets.shape[0] # number of anchors, targets
175
  tcls, tbox, indices, anch = [], [], [], []
176
+ gain = torch.ones(7, device=self.device) # normalized to gridspace gain
177
+ ai = torch.arange(na, device=self.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
178
  targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2) # append anchor indices
179
 
180
  g = 0.5 # bias
181
  off = torch.tensor([[0, 0],
182
  [1, 0], [0, 1], [-1, 0], [0, -1], # j,k,l,m
183
  # [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm
184
+ ], device=self.device).float() * g # offsets
185
 
186
  for i in range(self.nl):
187
  anchors = self.anchors[i]
 
209
  offsets = 0
210
 
211
  # Define
212
+ bc, gxy, gwh, a = t.unsafe_chunk(4, dim=1) # (image, class), grid xy, grid wh, anchors
213
+ a, (b, c) = a.long().view(-1), bc.long().T # anchors, image, class
 
214
  gij = (gxy - offsets).long()
215
+ gi, gj = gij.T # grid indices
216
 
217
  # Append
 
218
  indices.append((b, a, gj.clamp_(0, gain[3] - 1), gi.clamp_(0, gain[2] - 1))) # image, anchor, grid indices
219
  tbox.append(torch.cat((gxy - gij, gwh), 1)) # box
220
  anch.append(anchors[a]) # anchors