glenn-jocher committed on
Commit
2bf34f5
·
unverified ·
1 Parent(s): ee16983

PyTorch Hub amp.autocast() inference (#2641)

Browse files

I think this should help speed up CUDA inference, as currently models may be running in FP32 inference mode on CUDA devices unnecessarily.

Files changed (1) hide show
  1. models/common.py +9 -8
models/common.py CHANGED
@@ -8,6 +8,7 @@ import requests
8
  import torch
9
  import torch.nn as nn
10
  from PIL import Image
 
11
 
12
  from utils.datasets import letterbox
13
  from utils.general import non_max_suppression, make_divisible, scale_coords, xyxy2xywh
@@ -219,17 +220,17 @@ class autoShape(nn.Module):
219
  x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
220
  t.append(time_synchronized())
221
 
222
- # Inference
223
- with torch.no_grad():
224
  y = self.model(x, augment, profile)[0] # forward
225
- t.append(time_synchronized())
226
 
227
- # Post-process
228
- y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
229
- for i in range(n):
230
- scale_coords(shape1, y[i][:, :4], shape0[i])
231
- t.append(time_synchronized())
232
 
 
233
  return Detections(imgs, y, files, t, self.names, x.shape)
234
 
235
 
 
8
  import torch
9
  import torch.nn as nn
10
  from PIL import Image
11
+ from torch.cuda import amp
12
 
13
  from utils.datasets import letterbox
14
  from utils.general import non_max_suppression, make_divisible, scale_coords, xyxy2xywh
 
220
  x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
221
  t.append(time_synchronized())
222
 
223
+ with torch.no_grad(), amp.autocast(enabled=p.device.type != 'cpu'):
224
+ # Inference
225
  y = self.model(x, augment, profile)[0] # forward
226
+ t.append(time_synchronized())
227
 
228
+ # Post-process
229
+ y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
230
+ for i in range(n):
231
+ scale_coords(shape1, y[i][:, :4], shape0[i])
 
232
 
233
+ t.append(time_synchronized())
234
  return Detections(imgs, y, files, t, self.names, x.shape)
235
 
236