fix(demo): --fp16 not working in demo.py (#648)
Browse files- tools/demo.py +8 -1
tools/demo.py
CHANGED
@@ -106,6 +106,7 @@ class Predictor(object):
|
|
106 |
trt_file=None,
|
107 |
decoder=None,
|
108 |
device="cpu",
|
|
|
109 |
legacy=False,
|
110 |
):
|
111 |
self.model = model
|
@@ -116,6 +117,7 @@ class Predictor(object):
|
|
116 |
self.nmsthre = exp.nmsthre
|
117 |
self.test_size = exp.test_size
|
118 |
self.device = device
|
|
|
119 |
self.preproc = ValTransform(legacy=legacy)
|
120 |
if trt_file is not None:
|
121 |
from torch2trt import TRTModule
|
@@ -145,8 +147,11 @@ class Predictor(object):
|
|
145 |
|
146 |
img, _ = self.preproc(img, None, self.test_size)
|
147 |
img = torch.from_numpy(img).unsqueeze(0)
|
|
|
148 |
if self.device == "gpu":
|
149 |
img = img.cuda()
|
|
|
|
|
150 |
|
151 |
with torch.no_grad():
|
152 |
t0 = time.time()
|
@@ -261,6 +266,8 @@ def main(exp, args):
|
|
261 |
|
262 |
if args.device == "gpu":
|
263 |
model.cuda()
|
|
|
|
|
264 |
model.eval()
|
265 |
|
266 |
if not args.trt:
|
@@ -291,7 +298,7 @@ def main(exp, args):
|
|
291 |
trt_file = None
|
292 |
decoder = None
|
293 |
|
294 |
-
predictor = Predictor(model, exp, COCO_CLASSES, trt_file, decoder, args.device, args.legacy)
|
295 |
current_time = time.localtime()
|
296 |
if args.demo == "image":
|
297 |
image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
|
|
|
106 |
trt_file=None,
|
107 |
decoder=None,
|
108 |
device="cpu",
|
109 |
+
fp16=False,
|
110 |
legacy=False,
|
111 |
):
|
112 |
self.model = model
|
|
|
117 |
self.nmsthre = exp.nmsthre
|
118 |
self.test_size = exp.test_size
|
119 |
self.device = device
|
120 |
+
self.fp16 = fp16
|
121 |
self.preproc = ValTransform(legacy=legacy)
|
122 |
if trt_file is not None:
|
123 |
from torch2trt import TRTModule
|
|
|
147 |
|
148 |
img, _ = self.preproc(img, None, self.test_size)
|
149 |
img = torch.from_numpy(img).unsqueeze(0)
|
150 |
+
img = img.float()
|
151 |
if self.device == "gpu":
|
152 |
img = img.cuda()
|
153 |
+
if self.fp16:
|
154 |
+
img = img.half() # to FP16
|
155 |
|
156 |
with torch.no_grad():
|
157 |
t0 = time.time()
|
|
|
266 |
|
267 |
if args.device == "gpu":
|
268 |
model.cuda()
|
269 |
+
if args.fp16:
|
270 |
+
model.half() # to FP16
|
271 |
model.eval()
|
272 |
|
273 |
if not args.trt:
|
|
|
298 |
trt_file = None
|
299 |
decoder = None
|
300 |
|
301 |
+
predictor = Predictor(model, exp, COCO_CLASSES, trt_file, decoder, args.device, args.fp16, args.legacy)
|
302 |
current_time = time.localtime()
|
303 |
if args.demo == "image":
|
304 |
image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
|