glenn-jocher pre-commit-ci[bot] commited on
Commit
05aae17
·
unverified ·
1 Parent(s): a600bae

`torch.split()` 1.7.0 compatibility fix (#7102)

Browse files

* Update loss.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update loss.py

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Files changed (1) hide show
  1. utils/loss.py +9 -6
utils/loss.py CHANGED
@@ -108,13 +108,15 @@ class ComputeLoss:
108
  if g > 0:
109
  BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
110
 
111
- det = de_parallel(model).model[-1] # Detect() module
112
- self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7
113
- self.ssi = list(det.stride).index(16) if autobalance else 0 # stride 16 index
114
  self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
 
 
 
 
115
  self.device = device
116
- for k in 'na', 'nc', 'nl', 'anchors':
117
- setattr(self, k, getattr(det, k))
118
 
119
  def __call__(self, p, targets): # predictions, targets
120
  lcls = torch.zeros(1, device=self.device) # class loss
@@ -129,7 +131,8 @@ class ComputeLoss:
129
 
130
  n = b.shape[0] # number of targets
131
  if n:
132
- pxy, pwh, _, pcls = pi[b, a, gj, gi].tensor_split((2, 4, 5), dim=1) # target-subset of predictions
 
133
 
134
  # Regression
135
  pxy = pxy.sigmoid() * 2 - 0.5
 
108
  if g > 0:
109
  BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
110
 
111
+ m = de_parallel(model).model[-1] # Detect() module
112
+ self.balance = {3: [4.0, 1.0, 0.4]}.get(m.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7
113
+ self.ssi = list(m.stride).index(16) if autobalance else 0 # stride 16 index
114
  self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
115
+ self.na = m.na # number of anchors
116
+ self.nc = m.nc # number of classes
117
+ self.nl = m.nl # number of layers
118
+ self.anchors = m.anchors
119
  self.device = device
 
 
120
 
121
  def __call__(self, p, targets): # predictions, targets
122
  lcls = torch.zeros(1, device=self.device) # class loss
 
131
 
132
  n = b.shape[0] # number of targets
133
  if n:
134
+ # pxy, pwh, _, pcls = pi[b, a, gj, gi].tensor_split((2, 4, 5), dim=1) # faster, requires torch 1.8.0
135
+ pxy, pwh, _, pcls = pi[b, a, gj, gi].split((2, 2, 1, self.nc), 1) # target-subset of predictions
136
 
137
  # Regression
138
  pxy = pxy.sigmoid() * 2 - 0.5