tidalove commited on
Commit
a292bf5
·
verified ·
1 Parent(s): 833e9a7

Create demo_api.py

Browse files
Files changed (1) hide show
  1. tools/demo_api.py +180 -0
tools/demo_api.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding:utf-8 -*-
3
+
4
+ import os
5
+ import json
6
+ from loguru import logger
7
+
8
+ import cv2
9
+
10
+ import torch
11
+
12
+ from yolox.data.datasets import COCO_CLASSES
13
+ from yolox.exp import get_exp
14
+ from yolox.utils import fuse_model, get_model_info, postprocess, vis
15
+
16
+ IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]
17
+
18
+ def get_image_list(path):
19
+ image_names = []
20
+ for maindir, subdir, file_name_list in os.walk(path):
21
+ for filename in file_name_list:
22
+ apath = os.path.join(maindir, filename)
23
+ ext = os.path.splitext(apath)[1]
24
+ if ext in IMAGE_EXT:
25
+ image_names.append(apath)
26
+ return image_names
27
+
28
+
29
+ class Predictor(object):
30
+ def __init__(
31
+ self,
32
+ model,
33
+ exp,
34
+ cls_names=COCO_CLASSES,
35
+ trt_file=None,
36
+ decoder=None,
37
+ device="cpu",
38
+ fp16=False,
39
+ legacy=False,
40
+ ):
41
+ self.model = model
42
+ self.cls_names = cls_names
43
+ self.decoder = decoder
44
+ self.num_classes = exp.num_classes
45
+ self.confthre = exp.test_conf
46
+ self.nmsthre = exp.nmsthre
47
+ self.test_size = exp.test_size
48
+ self.device = device
49
+ self.fp16 = fp16
50
+ self.preproc = ValTransform(legacy=legacy)
51
+ if trt_file is not None:
52
+ from torch2trt import TRTModule
53
+
54
+ model_trt = TRTModule()
55
+ model_trt.load_state_dict(torch.load(trt_file))
56
+
57
+ x = torch.ones(1, 3, exp.test_size[0], exp.test_size[1]).cuda()
58
+ self.model(x)
59
+ self.model = model_trt
60
+
61
+ def inference(self, img):
62
+ img_info = {"id": 0}
63
+ if isinstance(img, str):
64
+ img_info["file_name"] = os.path.basename(img)
65
+ img = cv2.imread(img)
66
+ else:
67
+ img_info["file_name"] = None
68
+
69
+ height, width = img.shape[:2]
70
+ img_info["height"] = height
71
+ img_info["width"] = width
72
+ img_info["raw_img"] = img
73
+
74
+ ratio = min(self.test_size[0] / img.shape[0], self.test_size[1] / img.shape[1])
75
+ img_info["ratio"] = ratio
76
+
77
+ img, _ = self.preproc(img, None, self.test_size)
78
+ img = torch.from_numpy(img).unsqueeze(0)
79
+ img = img.float()
80
+ if self.device == "gpu":
81
+ img = img.cuda()
82
+ if self.fp16:
83
+ img = img.half() # to FP16
84
+
85
+ with torch.no_grad():
86
+ t0 = time.time()
87
+ outputs = self.model(img)
88
+ if self.decoder is not None:
89
+ outputs = self.decoder(outputs, dtype=outputs.type())
90
+ outputs = postprocess(
91
+ outputs, self.num_classes, self.confthre,
92
+ self.nmsthre, class_agnostic=True
93
+ )
94
+ logger.info("Infer time: {:.4f}s".format(time.time() - t0))
95
+ return outputs, img_info
96
+
97
+ def visual(self, output, img_info, cls_conf=0.35):
98
+ ratio = img_info["ratio"]
99
+ img = img_info["raw_img"]
100
+ if output is None:
101
+ return img
102
+ output = output.cpu()
103
+
104
+ bboxes = output[:, 0:4]
105
+
106
+ # preprocessing: resize
107
+ bboxes /= ratio
108
+
109
+ cls = output[:, 6]
110
+ scores = output[:, 4] * output[:, 5]
111
+
112
+ vis_res = vis(img, bboxes, scores, cls, cls_conf, self.cls_names)
113
+ return vis_res
114
+
115
+ def build_predictor(
116
+ exp_file, model_name, ckpt_path, device="cpu", fp16=False, fuse=False, trt=False, conf=0.3, nms=0.3, tsize=None
117
+ ):
118
+ # load experiment
119
+ exp = get_exp(exp_file, model_name)
120
+ if conf is not None:
121
+ exp.test_conf = conf
122
+ if nms is not None:
123
+ exp.nmsthre = nms
124
+ if tsize is not None:
125
+ exp.test_size = (tsize, tsize)
126
+
127
+ # create & initialize model
128
+ model = exp.get_model()
129
+ if device == "gpu":
130
+ model.cuda()
131
+ if fp16:
132
+ model.half()
133
+ model.eval()
134
+ logger.info("Model Summary: {}".format(get_model_info(model, exp.test_size)))
135
+
136
+ predictor = Predictor(
137
+ model, exp, COCO_CLASSES,
138
+ None, decoder=None,
139
+ device=device, fp16=fp16, legacy=False
140
+ )
141
+
142
+ return predictor
143
+
144
+ def run_detection(predictor, path):
145
+ # COCO output format: { images: [{id: 0, filename: "x.jpg"}, ...],
146
+ # annotations: [{id: 0, image_id: 0, bbox: [0 0 0 0], score: 0.35, class: 1}, ... ] }
147
+ if os.path.isdir(path):
148
+ files = get_image_list(path)
149
+ else:
150
+ files = [path]
151
+ files.sort()
152
+
153
+ img_list = []
154
+ ann_list = []
155
+
156
+ for img_id, image_name in enumerate(files):
157
+
158
+ outputs, img_info = predictor.inference(image_name)
159
+ ratio = img_info["ratio"]
160
+
161
+ img_entry = {"id": img_id,
162
+ "filename": image_name }
163
+ img_list.append(img_entry)
164
+
165
+ for id, output in enumerate(outputs):
166
+ ann_entry = {"id": id,
167
+ "image_id": img_id,
168
+ "bbox": output[:4] / ratio,
169
+ "cls": output[6],
170
+ "score": output[4] * output[5] }
171
+ ann_list.append(ann_entry)
172
+
173
+ data_dict = { "images": img_list,
174
+ "annotations": ann_list
175
+ }
176
+
177
+ with open(f"{path}/results.json", w) as f:
178
+ json.dump(data_dict, f)
179
+
180
+ return f"{path}/results.json"