Ge Zheng
commited on
Commit
·
6bd763f
1
Parent(s):
499b0a0
fix bugs in demo, support device parser for demo. (#77)
Browse files- README.md +3 -3
- tools/demo.py +13 -16
README.md
CHANGED
@@ -68,15 +68,15 @@ Step1. Download a pretrained model from the benchmark table.
|
|
68 |
Step2. Use either -n or -f to specify your detector's config. For example:
|
69 |
|
70 |
```shell
|
71 |
-
python tools/demo.py image -n yolox-s -c /path/to/your/yolox_s.pth.tar --path assets/dog.jpg --conf 0.3 --nms 0.65 --tsize 640 --save_result
|
72 |
```
|
73 |
or
|
74 |
```shell
|
75 |
-
python tools/demo.py image -f exps/default/yolox_s.py -c /path/to/your/yolox_s.pth.tar --path assets/dog.jpg --conf 0.3 --nms 0.65 --tsize 640 --save_result
|
76 |
```
|
77 |
Demo for video:
|
78 |
```shell
|
79 |
-
python tools/demo.py video -n yolox-s -c /path/to/your/yolox_s.pth.tar --path /path/to/your/video --conf 0.3 --nms 0.65 --tsize 640 --save_result
|
80 |
```
|
81 |
|
82 |
|
|
|
68 |
Step2. Use either -n or -f to specify your detector's config. For example:
|
69 |
|
70 |
```shell
|
71 |
+
python tools/demo.py image -n yolox-s -c /path/to/your/yolox_s.pth.tar --path assets/dog.jpg --conf 0.3 --nms 0.65 --tsize 640 --save_result --device [cpu/gpu]
|
72 |
```
|
73 |
or
|
74 |
```shell
|
75 |
+
python tools/demo.py image -f exps/default/yolox_s.py -c /path/to/your/yolox_s.pth.tar --path assets/dog.jpg --conf 0.3 --nms 0.65 --tsize 640 --save_result --device [cpu/gpu]
|
76 |
```
|
77 |
Demo for video:
|
78 |
```shell
|
79 |
+
python tools/demo.py video -n yolox-s -c /path/to/your/yolox_s.pth.tar --path /path/to/your/video --conf 0.3 --nms 0.65 --tsize 640 --save_result --device [cpu/gpu]
|
80 |
```
|
81 |
|
82 |
|
tools/demo.py
CHANGED
@@ -10,12 +10,11 @@ from loguru import logger
|
|
10 |
import cv2
|
11 |
|
12 |
import torch
|
13 |
-
import torch.backends.cudnn as cudnn
|
14 |
|
15 |
from yolox.data.data_augment import preproc
|
16 |
from yolox.data.datasets import COCO_CLASSES
|
17 |
from yolox.exp import get_exp
|
18 |
-
from yolox.utils import fuse_model, get_model_info, postprocess,
|
19 |
|
20 |
IMAGE_EXT = ['.jpg', '.jpeg', '.webp', '.bmp', '.png']
|
21 |
|
@@ -42,6 +41,7 @@ def make_parser():
|
|
42 |
help="pls input your expriment description file",
|
43 |
)
|
44 |
parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
|
|
|
45 |
parser.add_argument("--conf", default=None, type=float, help="test conf")
|
46 |
parser.add_argument("--nms", default=None, type=float, help="test nms threshold")
|
47 |
parser.add_argument("--tsize", default=None, type=int, help="test img size")
|
@@ -81,7 +81,7 @@ def get_image_list(path):
|
|
81 |
|
82 |
|
83 |
class Predictor(object):
|
84 |
-
def __init__(self, model, exp, cls_names=COCO_CLASSES, trt_file=None, decoder=None):
|
85 |
self.model = model
|
86 |
self.cls_names = cls_names
|
87 |
self.decoder = decoder
|
@@ -89,6 +89,7 @@ class Predictor(object):
|
|
89 |
self.confthre = exp.test_conf
|
90 |
self.nmsthre = exp.nmsthre
|
91 |
self.test_size = exp.test_size
|
|
|
92 |
if trt_file is not None:
|
93 |
from torch2trt import TRTModule
|
94 |
model_trt = TRTModule()
|
@@ -115,13 +116,17 @@ class Predictor(object):
|
|
115 |
|
116 |
img, ratio = preproc(img, self.test_size, self.rgb_means, self.std)
|
117 |
img_info['ratio'] = ratio
|
118 |
-
img = torch.from_numpy(img).unsqueeze(0)
|
|
|
|
|
119 |
|
120 |
with torch.no_grad():
|
121 |
t0 = time.time()
|
122 |
outputs = self.model(img)
|
123 |
if self.decoder is not None:
|
124 |
outputs = self.decoder(outputs, dtype=outputs.type())
|
|
|
|
|
125 |
outputs = postprocess(
|
126 |
outputs, self.num_classes, self.confthre, self.nmsthre
|
127 |
)
|
@@ -202,10 +207,6 @@ def main(exp, args):
|
|
202 |
if not args.experiment_name:
|
203 |
args.experiment_name = exp.exp_name
|
204 |
|
205 |
-
# set environment variables for distributed training
|
206 |
-
cudnn.benchmark = True
|
207 |
-
rank = 0
|
208 |
-
|
209 |
file_name = os.path.join(exp.output_dir, args.experiment_name)
|
210 |
os.makedirs(file_name, exist_ok=True)
|
211 |
|
@@ -213,9 +214,6 @@ def main(exp, args):
|
|
213 |
vis_folder = os.path.join(file_name, 'vis_res')
|
214 |
os.makedirs(vis_folder, exist_ok=True)
|
215 |
|
216 |
-
setup_logger(
|
217 |
-
file_name, distributed_rank=rank, filename="demo_log.txt", mode="a"
|
218 |
-
)
|
219 |
logger.info("Args: {}".format(args))
|
220 |
|
221 |
if args.conf is not None:
|
@@ -228,8 +226,8 @@ def main(exp, args):
|
|
228 |
model = exp.get_model()
|
229 |
logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))
|
230 |
|
231 |
-
|
232 |
-
|
233 |
model.eval()
|
234 |
|
235 |
if not args.trt:
|
@@ -238,8 +236,7 @@ def main(exp, args):
|
|
238 |
else:
|
239 |
ckpt_file = args.ckpt
|
240 |
logger.info("loading checkpoint")
|
241 |
-
|
242 |
-
ckpt = torch.load(ckpt_file, map_location=loc)
|
243 |
# load the model state dict
|
244 |
model.load_state_dict(ckpt["model"])
|
245 |
logger.info("loaded checkpoint done.")
|
@@ -262,7 +259,7 @@ def main(exp, args):
|
|
262 |
trt_file = None
|
263 |
decoder = None
|
264 |
|
265 |
-
predictor = Predictor(model, exp, COCO_CLASSES, trt_file, decoder)
|
266 |
current_time = time.localtime()
|
267 |
if args.demo == 'image':
|
268 |
image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
|
|
|
10 |
import cv2
|
11 |
|
12 |
import torch
|
|
|
13 |
|
14 |
from yolox.data.data_augment import preproc
|
15 |
from yolox.data.datasets import COCO_CLASSES
|
16 |
from yolox.exp import get_exp
|
17 |
+
from yolox.utils import fuse_model, get_model_info, postprocess, vis
|
18 |
|
19 |
IMAGE_EXT = ['.jpg', '.jpeg', '.webp', '.bmp', '.png']
|
20 |
|
|
|
41 |
help="pls input your expriment description file",
|
42 |
)
|
43 |
parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt for eval")
|
44 |
+
parser.add_argument("--device", default="cpu", type=str, help="device to run our model, can either be cpu or gpu")
|
45 |
parser.add_argument("--conf", default=None, type=float, help="test conf")
|
46 |
parser.add_argument("--nms", default=None, type=float, help="test nms threshold")
|
47 |
parser.add_argument("--tsize", default=None, type=int, help="test img size")
|
|
|
81 |
|
82 |
|
83 |
class Predictor(object):
|
84 |
+
def __init__(self, model, exp, cls_names=COCO_CLASSES, trt_file=None, decoder=None, device="cpu"):
|
85 |
self.model = model
|
86 |
self.cls_names = cls_names
|
87 |
self.decoder = decoder
|
|
|
89 |
self.confthre = exp.test_conf
|
90 |
self.nmsthre = exp.nmsthre
|
91 |
self.test_size = exp.test_size
|
92 |
+
self.device = device
|
93 |
if trt_file is not None:
|
94 |
from torch2trt import TRTModule
|
95 |
model_trt = TRTModule()
|
|
|
116 |
|
117 |
img, ratio = preproc(img, self.test_size, self.rgb_means, self.std)
|
118 |
img_info['ratio'] = ratio
|
119 |
+
img = torch.from_numpy(img).unsqueeze(0)
|
120 |
+
if self.device == "gpu":
|
121 |
+
img = img.cuda()
|
122 |
|
123 |
with torch.no_grad():
|
124 |
t0 = time.time()
|
125 |
outputs = self.model(img)
|
126 |
if self.decoder is not None:
|
127 |
outputs = self.decoder(outputs, dtype=outputs.type())
|
128 |
+
if self.device == "gpu":
|
129 |
+
outputs = outputs.cpu().numpy()
|
130 |
outputs = postprocess(
|
131 |
outputs, self.num_classes, self.confthre, self.nmsthre
|
132 |
)
|
|
|
207 |
if not args.experiment_name:
|
208 |
args.experiment_name = exp.exp_name
|
209 |
|
|
|
|
|
|
|
|
|
210 |
file_name = os.path.join(exp.output_dir, args.experiment_name)
|
211 |
os.makedirs(file_name, exist_ok=True)
|
212 |
|
|
|
214 |
vis_folder = os.path.join(file_name, 'vis_res')
|
215 |
os.makedirs(vis_folder, exist_ok=True)
|
216 |
|
|
|
|
|
|
|
217 |
logger.info("Args: {}".format(args))
|
218 |
|
219 |
if args.conf is not None:
|
|
|
226 |
model = exp.get_model()
|
227 |
logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))
|
228 |
|
229 |
+
if args.device == "gpu":
|
230 |
+
model.cuda()
|
231 |
model.eval()
|
232 |
|
233 |
if not args.trt:
|
|
|
236 |
else:
|
237 |
ckpt_file = args.ckpt
|
238 |
logger.info("loading checkpoint")
|
239 |
+
ckpt = torch.load(ckpt_file, map_location="cpu")
|
|
|
240 |
# load the model state dict
|
241 |
model.load_state_dict(ckpt["model"])
|
242 |
logger.info("loaded checkpoint done.")
|
|
|
259 |
trt_file = None
|
260 |
decoder = None
|
261 |
|
262 |
+
predictor = Predictor(model, exp, COCO_CLASSES, trt_file, decoder, args.device)
|
263 |
current_time = time.localtime()
|
264 |
if args.demo == 'image':
|
265 |
image_demo(predictor, vis_folder, args.path, current_time, args.save_result)
|