`torch.split()` 1.7.0 compatibility fix (#7102)
Browse files* Update loss.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Update loss.py
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- utils/loss.py +9 -6
utils/loss.py
CHANGED
@@ -108,13 +108,15 @@ class ComputeLoss:
|
|
108 |
if g > 0:
|
109 |
BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
|
110 |
|
111 |
-
|
112 |
-
self.balance = {3: [4.0, 1.0, 0.4]}.get(
|
113 |
-
self.ssi = list(
|
114 |
self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
|
|
|
|
|
|
|
|
|
115 |
self.device = device
|
116 |
-
for k in 'na', 'nc', 'nl', 'anchors':
|
117 |
-
setattr(self, k, getattr(det, k))
|
118 |
|
119 |
def __call__(self, p, targets): # predictions, targets
|
120 |
lcls = torch.zeros(1, device=self.device) # class loss
|
@@ -129,7 +131,8 @@ class ComputeLoss:
|
|
129 |
|
130 |
n = b.shape[0] # number of targets
|
131 |
if n:
|
132 |
-
pxy, pwh, _, pcls = pi[b, a, gj, gi].tensor_split((2, 4, 5), dim=1) #
|
|
|
133 |
|
134 |
# Regression
|
135 |
pxy = pxy.sigmoid() * 2 - 0.5
|
|
|
108 |
if g > 0:
|
109 |
BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
|
110 |
|
111 |
+
m = de_parallel(model).model[-1] # Detect() module
|
112 |
+
self.balance = {3: [4.0, 1.0, 0.4]}.get(m.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7
|
113 |
+
self.ssi = list(m.stride).index(16) if autobalance else 0 # stride 16 index
|
114 |
self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
|
115 |
+
self.na = m.na # number of anchors
|
116 |
+
self.nc = m.nc # number of classes
|
117 |
+
self.nl = m.nl # number of layers
|
118 |
+
self.anchors = m.anchors
|
119 |
self.device = device
|
|
|
|
|
120 |
|
121 |
def __call__(self, p, targets): # predictions, targets
|
122 |
lcls = torch.zeros(1, device=self.device) # class loss
|
|
|
131 |
|
132 |
n = b.shape[0] # number of targets
|
133 |
if n:
|
134 |
+
# pxy, pwh, _, pcls = pi[b, a, gj, gi].tensor_split((2, 4, 5), dim=1) # faster, requires torch 1.8.0
|
135 |
+
pxy, pwh, _, pcls = pi[b, a, gj, gi].split((2, 2, 1, self.nc), 1) # target-subset of predictions
|
136 |
|
137 |
# Regression
|
138 |
pxy = pxy.sigmoid() * 2 - 0.5
|