wwqgtxx commited on
Commit
88f2356
·
1 Parent(s): 9b3f86b

fix(demo): --fp16 not working in demo.py (#648)

Browse files
Files changed (1) hide show
  1. tools/demo.py +8 -1
tools/demo.py CHANGED
@@ -106,6 +106,7 @@ class Predictor(object):
106
  trt_file=None,
107
  decoder=None,
108
  device="cpu",
 
109
  legacy=False,
110
  ):
111
  self.model = model
@@ -116,6 +117,7 @@ class Predictor(object):
116
  self.nmsthre = exp.nmsthre
117
  self.test_size = exp.test_size
118
  self.device = device
 
119
  self.preproc = ValTransform(legacy=legacy)
120
  if trt_file is not None:
121
  from torch2trt import TRTModule
@@ -145,8 +147,11 @@ class Predictor(object):
145
 
146
  img, _ = self.preproc(img, None, self.test_size)
147
  img = torch.from_numpy(img).unsqueeze(0)
 
148
  if self.device == "gpu":
149
  img = img.cuda()
 
 
150
 
151
  with torch.no_grad():
152
  t0 = time.time()
@@ -261,6 +266,8 @@ def main(exp, args):
261
 
262
  if args.device == "gpu":
263
  model.cuda()
 
 
264
  model.eval()
265
 
266
  if not args.trt:
@@ -291,7 +298,7 @@ def main(exp, args):
291
  trt_file = None
292
  decoder = None
293
 
294
- predictor = Predictor(model, exp, COCO_CLASSES, trt_file, decoder, args.device, args.legacy)
295
  current_time = time.localtime()
296
  if args.demo == "image":
297
  image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
 
106
  trt_file=None,
107
  decoder=None,
108
  device="cpu",
109
+ fp16=False,
110
  legacy=False,
111
  ):
112
  self.model = model
 
117
  self.nmsthre = exp.nmsthre
118
  self.test_size = exp.test_size
119
  self.device = device
120
+ self.fp16 = fp16
121
  self.preproc = ValTransform(legacy=legacy)
122
  if trt_file is not None:
123
  from torch2trt import TRTModule
 
147
 
148
  img, _ = self.preproc(img, None, self.test_size)
149
  img = torch.from_numpy(img).unsqueeze(0)
150
+ img = img.float()
151
  if self.device == "gpu":
152
  img = img.cuda()
153
+ if self.fp16:
154
+ img = img.half() # to FP16
155
 
156
  with torch.no_grad():
157
  t0 = time.time()
 
266
 
267
  if args.device == "gpu":
268
  model.cuda()
269
+ if args.fp16:
270
+ model.half() # to FP16
271
  model.eval()
272
 
273
  if not args.trt:
 
298
  trt_file = None
299
  decoder = None
300
 
301
+ predictor = Predictor(model, exp, COCO_CLASSES, trt_file, decoder, args.device, args.fp16, args.legacy)
302
  current_time = time.localtime()
303
  if args.demo == "image":
304
  image_demo(predictor, vis_folder, args.path, current_time, args.save_result)