Clip TTA Augmented Tails (#5028)
Browse files
* Clip TTA Augmented Tails
Experimental TTA update.
* Update yolo.py
* Update yolo.py
* Update yolo.py
* Update yolo.py
- models/yolo.py +12 -0
models/yolo.py
CHANGED
@@ -134,6 +134,7 @@ class Model(nn.Module):
|
|
134 |
# cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
|
135 |
yi = self._descale_pred(yi, fi, si, img_size)
|
136 |
y.append(yi)
|
|
|
137 |
return torch.cat(y, 1), None # augmented inference, train
|
138 |
|
139 |
def _forward_once(self, x, profile=False, visualize=False):
|
@@ -166,6 +167,17 @@ class Model(nn.Module):
|
|
166 |
p = torch.cat((x, y, wh, p[..., 4:]), -1)
|
167 |
return p
|
168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
def _profile_one_layer(self, m, x, dt):
|
170 |
c = isinstance(m, Detect) # is final layer, copy input as inplace fix
|
171 |
o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
|
|
|
134 |
# cv2.imwrite(f'img_{si}.jpg', 255 * xi[0].cpu().numpy().transpose((1, 2, 0))[:, :, ::-1]) # save
|
135 |
yi = self._descale_pred(yi, fi, si, img_size)
|
136 |
y.append(yi)
|
137 |
+
y = self._clip_augmented(y) # clip augmented tails
|
138 |
return torch.cat(y, 1), None # augmented inference, train
|
139 |
|
140 |
def _forward_once(self, x, profile=False, visualize=False):
|
|
|
167 |
p = torch.cat((x, y, wh, p[..., 4:]), -1)
|
168 |
return p
|
169 |
|
170 |
+
def _clip_augmented(self, y):
|
171 |
+
# Clip YOLOv5 augmented inference tails
|
172 |
+
nl = self.model[-1].nl # number of detection layers (P3-P5)
|
173 |
+
g = sum(4 ** x for x in range(nl)) # grid points
|
174 |
+
e = 1 # exclude layer count
|
175 |
+
i = (y[0].shape[1] // g) * sum(4 ** x for x in range(e)) # indices
|
176 |
+
y[0] = y[0][:, :-i] # large
|
177 |
+
i = (y[-1].shape[1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices
|
178 |
+
y[-1] = y[-1][:, i:] # small
|
179 |
+
return y
|
180 |
+
|
181 |
def _profile_one_layer(self, m, x, dt):
|
182 |
c = isinstance(m, Detect) # is final layer, copy input as inplace fix
|
183 |
o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
|