Ge
commited on
Commit
·
c61fc24
1
Parent(s):
7fd430f
fix bugs in demo
Browse files- tools/demo.py +4 -5
tools/demo.py
CHANGED
@@ -81,13 +81,12 @@ def get_image_list(path):
|
|
81 |
|
82 |
|
83 |
class Predictor(object):
|
84 |
-
def __init__(self, model, exp, cls_names=COCO_CLASSES, trt_file=None, decoder=None,
|
85 |
self.model = model
|
86 |
self.cls_names = cls_names
|
87 |
self.decoder = decoder
|
88 |
self.num_classes = exp.num_classes
|
89 |
self.confthre = exp.test_conf
|
90 |
-
self.conv_vis = conf_vis
|
91 |
self.nmsthre = exp.nmsthre
|
92 |
self.test_size = exp.test_size
|
93 |
self.device = device
|
@@ -159,7 +158,7 @@ def image_demo(predictor, vis_folder, path, current_time, save_result):
|
|
159 |
files.sort()
|
160 |
for image_name in files:
|
161 |
outputs, img_info = predictor.inference(image_name)
|
162 |
-
result_image = predictor.visual(outputs[0], img_info, predictor.
|
163 |
if save_result:
|
164 |
save_folder = os.path.join(
|
165 |
vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
|
@@ -192,7 +191,7 @@ def imageflow_demo(predictor, vis_folder, current_time, args):
|
|
192 |
ret_val, frame = cap.read()
|
193 |
if ret_val:
|
194 |
outputs, img_info = predictor.inference(frame)
|
195 |
-
result_frame = predictor.visual(outputs[0], img_info, predictor.
|
196 |
if args.save_result:
|
197 |
vid_writer.write(result_frame)
|
198 |
ch = cv2.waitKey(1)
|
@@ -261,7 +260,7 @@ def main(exp, args):
|
|
261 |
trt_file = None
|
262 |
decoder = None
|
263 |
|
264 |
-
predictor = Predictor(model, exp, COCO_CLASSES, trt_file, decoder, args.
|
265 |
current_time = time.localtime()
|
266 |
if args.demo == 'image':
|
267 |
image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
|
|
|
81 |
|
82 |
|
83 |
class Predictor(object):
|
84 |
+
def __init__(self, model, exp, cls_names=COCO_CLASSES, trt_file=None, decoder=None, device="cpu"):
|
85 |
self.model = model
|
86 |
self.cls_names = cls_names
|
87 |
self.decoder = decoder
|
88 |
self.num_classes = exp.num_classes
|
89 |
self.confthre = exp.test_conf
|
|
|
90 |
self.nmsthre = exp.nmsthre
|
91 |
self.test_size = exp.test_size
|
92 |
self.device = device
|
|
|
158 |
files.sort()
|
159 |
for image_name in files:
|
160 |
outputs, img_info = predictor.inference(image_name)
|
161 |
+
result_image = predictor.visual(outputs[0], img_info, predictor.confthre)
|
162 |
if save_result:
|
163 |
save_folder = os.path.join(
|
164 |
vis_folder, time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
|
|
|
191 |
ret_val, frame = cap.read()
|
192 |
if ret_val:
|
193 |
outputs, img_info = predictor.inference(frame)
|
194 |
+
result_frame = predictor.visual(outputs[0], img_info, predictor.confthre)
|
195 |
if args.save_result:
|
196 |
vid_writer.write(result_frame)
|
197 |
ch = cv2.waitKey(1)
|
|
|
260 |
trt_file = None
|
261 |
decoder = None
|
262 |
|
263 |
+
predictor = Predictor(model, exp, COCO_CLASSES, trt_file, decoder, args.device)
|
264 |
current_time = time.localtime()
|
265 |
if args.demo == 'image':
|
266 |
image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
|