CuriousDolphin commited on
Commit
90d8dcc
Β·
1 Parent(s): b6e1550

add onnx export and inference

Browse files
.gitignore CHANGED
@@ -1,3 +1,5 @@
1
  .venv
2
  __pycache__
3
- data/cache
 
 
 
1
  .venv
2
  __pycache__
3
+ data/cache
4
+ data/onnx
5
+ .DS_Store
README.md CHANGED
@@ -6,4 +6,4 @@ app_port: 7000
6
  pinned: true
7
  ---
8
 
9
- # Simple DETR gradio implementation (object detections & panoptic segmentation)
 
6
  pinned: true
7
  ---
8
 
9
+ # Simple DETR gradio implementation (object detection & panoptic segmentation)
data/assets/detr_architecture.png ADDED
data/{assets β†’ images}/000000039769.jpg RENAMED
File without changes
data/images/MOTO_GP_landing_page-Hero_image_Medium.jpeg ADDED
data/{assets β†’ images}/dog_bike_car.jpeg RENAMED
File without changes
data/{assets β†’ images}/download.png RENAMED
File without changes
data/images/sample1.png ADDED
detr/{detr.py β†’ detr_models.py} RENAMED
@@ -10,7 +10,10 @@ from torch import nn
10
  from torchvision.models import resnet50
11
  from panopticapi.utils import id2rgb, rgb2id
12
  from supervision import Detections, BoxAnnotator, MaskAnnotator
 
 
13
  from PIL import Image
 
14
 
15
  torch.set_grad_enabled(False)
16
 
@@ -18,11 +21,14 @@ torch.set_grad_enabled(False)
18
  # https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/detr_demo.ipynb#scrollTo=cfCcEYjg7y46
19
 
20
  DETR_DEMO_WEIGHTS_URI = "https://dl.fbaipublicfiles.com/detr/detr_demo-da2a99e9.pth"
21
-
22
  TORCH_HOME = os.path.abspath(os.curdir) + "/data/cache"
23
-
24
  os.environ["TORCH_HOME"] = TORCH_HOME
25
 
 
 
 
 
26
  print("Torch home:", TORCH_HOME)
27
 
28
 
@@ -40,6 +46,17 @@ def normalize_img(image):
40
  return transform(image).unsqueeze(0)
41
 
42
 
 
 
 
 
 
 
 
 
 
 
 
43
  # for output bounding box post-processing
44
  def box_cxcywh_to_xyxy(x):
45
  x_c, y_c, w, h = x.unbind(1)
@@ -199,6 +216,100 @@ class SimpleDetr:
199
  )
200
  return annotated
201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
  class PanopticDetrResenet101:
204
  @cache
@@ -235,6 +346,32 @@ class PanopticDetrResenet101:
235
  panoptic_seg[panoptic_seg_id == id] = np.asarray(next(palette)) * 255
236
  return panoptic_seg
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
  # COCO classes
240
  CLASSES = [
@@ -330,3 +467,6 @@ CLASSES = [
330
  "hair drier",
331
  "toothbrush",
332
  ]
 
 
 
 
10
  from torchvision.models import resnet50
11
  from panopticapi.utils import id2rgb, rgb2id
12
  from supervision import Detections, BoxAnnotator, MaskAnnotator
13
+ import onnx
14
+ import onnxruntime
15
  from PIL import Image
16
+ from pathlib import Path
17
 
18
  torch.set_grad_enabled(False)
19
 
 
21
  # https://colab.research.google.com/github/facebookresearch/detr/blob/colab/notebooks/detr_demo.ipynb#scrollTo=cfCcEYjg7y46
22
 
23
  DETR_DEMO_WEIGHTS_URI = "https://dl.fbaipublicfiles.com/detr/detr_demo-da2a99e9.pth"
 
24
  TORCH_HOME = os.path.abspath(os.curdir) + "/data/cache"
25
+ ONNX_DIR = os.path.abspath(os.curdir) + "/data/onnx"
26
  os.environ["TORCH_HOME"] = TORCH_HOME
27
 
28
+ Path(TORCH_HOME).mkdir(exist_ok=True)
29
+ Path(ONNX_DIR).mkdir(exist_ok=True)
30
+
31
+
32
  print("Torch home:", TORCH_HOME)
33
 
34
 
 
46
  return transform(image).unsqueeze(0)
47
 
48
 
49
+ def normalize_img_800_800(image):
50
+ transform = T.Compose(
51
+ [
52
+ T.Resize((800, 800)),
53
+ T.ToTensor(),
54
+ T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
55
+ ]
56
+ )
57
+ return transform(image).unsqueeze(0)
58
+
59
+
60
  # for output bounding box post-processing
61
  def box_cxcywh_to_xyxy(x):
62
  x_c, y_c, w, h = x.unbind(1)
 
216
  )
217
  return annotated
218
 
219
+ def export(self):
220
+ model_path = f"{ONNX_DIR}/detr_simple_demo_onnx.onnx"
221
+ dummy_image = torch.ones(1, 3, 800, 800, device="cpu")
222
+ input_names = ["inputs"]
223
+ output_names = ["pred_logits", "pred_boxes"]
224
+ torch.onnx.export(
225
+ self.model,
226
+ dummy_image,
227
+ model_path,
228
+ input_names=input_names,
229
+ output_names=output_names,
230
+ # dynamic_axes={input_names[0]: {0: "batch_size", 2: "height", 3: "width"}}, #!TODO
231
+ export_params=True,
232
+ training=torch.onnx.TrainingMode.EVAL,
233
+ opset_version=14,
234
+ )
235
+ onnx_model = onnx.load(model_path)
236
+
237
+ # Check the model
238
+ try:
239
+ onnx.checker.check_model(onnx_model)
240
+ except onnx.checker.ValidationError as e:
241
+ print(f"The model is invalid: {e}")
242
+ else:
243
+ print("The model is valid!")
244
+ return model_path
245
+
246
+
247
+ class SimpleDetrOnnx:
248
+ @cache
249
+ def __init__(self):
250
+ self.box_annotator: BoxAnnotator = BoxAnnotator()
251
+ onnx_sess_opts = onnxruntime.SessionOptions()
252
+ onnx_sess_opts.graph_optimization_level = (
253
+ onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
254
+ # onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
255
+ )
256
+ onnx_sess_opts.enable_mem_pattern = True
257
+ onnx_sess_opts.enable_cpu_mem_arena = True
258
+ self.ort_session = onnxruntime.InferenceSession(
259
+ f"{ONNX_DIR}/detr_simple_demo.onnx",
260
+ sess_options=onnx_sess_opts,
261
+ providers=[
262
+ "CUDAExecutionProvider",
263
+ "CoreMLExecutionProvider",
264
+ "CPUExecutionProvider",
265
+ ],
266
+ )
267
+ self.classes = {}
268
+ self.metadata = self.ort_session.get_modelmeta()
269
+ self.providers = self.ort_session.get_providers()
270
+ print(f"[OnnxRuntime] Providers:{self.providers}")
271
+ print(
272
+ f"[OnnxRuntime] medatadata: {self.metadata.custom_metadata_map} {type(self.metadata.custom_metadata_map)}"
273
+ )
274
+
275
+ def detect(self, image, conf):
276
+ # dummy_image = np.ones((1, 3, 600, 800), dtype=np.float32)
277
+ im = normalize_img_800_800(image).numpy()
278
+ print("SHAPE", im.shape)
279
+ ort_inputs = {self.ort_session.get_inputs()[0].name: im}
280
+ outputs = self.ort_session.run(None, ort_inputs)
281
+ pred_logits = torch.tensor(
282
+ outputs[0]
283
+ ) # conversion to torch for simplicity (softmax etc)
284
+ pred_boxes = torch.tensor(outputs[1])
285
+ scores = pred_logits.softmax(-1)[0, :, :-1]
286
+ keep = scores.max(-1).values > conf
287
+ bboxes_scaled = rescale_bboxes(pred_boxes[0, keep], image.size)
288
+ probas = scores[keep]
289
+ class_id = []
290
+ confidence = []
291
+ for prob in probas:
292
+ cls_id = prob.argmax()
293
+ c = prob[cls_id]
294
+ class_id.append(int(cls_id))
295
+ confidence.append(float(c))
296
+ print(class_id, confidence)
297
+ detections = Detections(
298
+ xyxy=bboxes_scaled.cpu().detach().numpy(),
299
+ class_id=np.array(class_id),
300
+ confidence=np.array(confidence),
301
+ )
302
+ annotated = self.box_annotator.annotate(
303
+ scene=np.array(image),
304
+ skip_label=False,
305
+ detections=detections,
306
+ labels=[
307
+ f"{CLASSES[cls_id]} {conf:.2f}"
308
+ for cls_id, conf in zip(detections.class_id, detections.confidence)
309
+ ],
310
+ )
311
+ return annotated
312
+
313
 
314
  class PanopticDetrResenet101:
315
  @cache
 
346
  panoptic_seg[panoptic_seg_id == id] = np.asarray(next(palette)) * 255
347
  return panoptic_seg
348
 
349
+ def export(self):
350
+ model_path = f"{ONNX_DIR}/detr_resnet101_panoptic.onnx"
351
+ dummy_image = torch.ones(1, 3, 800, 800, device="cpu")
352
+ input_names = ["inputs"]
353
+ output_names = ["pred_logits", "pred_boxes", "pred_masks"]
354
+ torch.onnx.export(
355
+ self.model,
356
+ dummy_image,
357
+ model_path,
358
+ input_names=input_names,
359
+ output_names=output_names,
360
+ export_params=True,
361
+ training=torch.onnx.TrainingMode.EVAL,
362
+ opset_version=14,
363
+ )
364
+ onnx_model = onnx.load(model_path)
365
+
366
+ # Check the model
367
+ try:
368
+ onnx.checker.check_model(onnx_model)
369
+ except onnx.checker.ValidationError as e:
370
+ print(f"The model is invalid: {e}")
371
+ else:
372
+ print("The model is valid!")
373
+ return model_path
374
+
375
 
376
  # COCO classes
377
  CLASSES = [
 
467
  "hair drier",
468
  "toothbrush",
469
  ]
470
+
471
+ # model = SimpleDetr()
472
+ # model.export()
detr/main_gradio.py CHANGED
@@ -3,20 +3,26 @@ import supervision as sv
3
  import os
4
  from time import perf_counter
5
 
6
- from detr import SimpleDetr, PanopticDetrResenet101
7
 
 
8
  ASSETS_DIR = os.path.abspath(os.curdir) + "/data/assets"
9
-
10
- print("Assets:", ASSETS_DIR)
11
 
12
 
13
  def run_inference(image, confidence, model_name, progress=gr.Progress(track_tqdm=True)):
14
  progress(0.1, "loading model..")
 
 
15
  t0 = perf_counter()
16
- if model_name == "detr_demo_boxes":
17
  model = SimpleDetr()
18
- else:
19
  model = PanopticDetrResenet101()
 
 
 
 
20
  t1 = perf_counter()
21
  progress(0.1, "Inference..")
22
 
@@ -25,43 +31,111 @@ def run_inference(image, confidence, model_name, progress=gr.Progress(track_tqdm
25
  return annotated_img, {"load_model": t1 - t0, "inference": t2 - t1}, None
26
 
27
 
28
- with gr.Blocks() as inference_gradio:
29
- gr.Markdown("# DETR inference")
30
- with gr.Row():
31
- with gr.Column():
32
- img_file = gr.Image(type="pil")
33
- # with gr.Row():
34
- model_name = gr.Dropdown(
35
- label="Model",
36
- scale=3,
37
- choices=["detr_demo_boxes", "detr_resnet101_panoptic"],
38
- value="detr_demo_boxes",
39
- )
40
 
41
- conf = gr.Slider(label="Confidence", minimum=0, maximum=0.99, value=0.5)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- with gr.Row():
44
- start_btn = gr.Button("Start", variant="primary")
45
- with gr.Column():
46
- annotated_img = gr.Image(label="Annotated Image")
47
- speed = gr.JSON(label="speed")
48
- examples = gr.Examples(
49
- examples=[
50
- [path]
51
- for path in sv.list_files_with_extensions(
52
- directory=ASSETS_DIR, extensions=["jpeg", "jpg", "png"]
53
- )
54
- ],
55
- inputs=[img_file],
56
- )
57
- start_btn.click(
58
- fn=run_inference,
59
- inputs=[img_file, conf, model_name],
60
- outputs=[annotated_img, speed],
61
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  if __name__ == "__main__":
64
- inference_gradio.queue(2).launch(
65
  debug=True,
66
  server_name="0.0.0.0",
67
  server_port=7000,
 
3
  import os
4
  from time import perf_counter
5
 
6
+ from detr_models import SimpleDetr, PanopticDetrResenet101, SimpleDetrOnnx, ONNX_DIR
7
 
8
+ IMAGES_DIR = os.path.abspath(os.curdir) + "/data/images"
9
  ASSETS_DIR = os.path.abspath(os.curdir) + "/data/assets"
10
+ print("images:", IMAGES_DIR)
 
11
 
12
 
13
  def run_inference(image, confidence, model_name, progress=gr.Progress(track_tqdm=True)):
14
  progress(0.1, "loading model..")
15
+ if not image:
16
+ raise gr.Error("Provide image.")
17
  t0 = perf_counter()
18
+ if model_name == "detr_simple_demo":
19
  model = SimpleDetr()
20
+ elif model_name == "detr_resnet101_panoptic":
21
  model = PanopticDetrResenet101()
22
+ elif model_name == "detr_simple_demo_onnx":
23
+ if not os.path.exists(f"{ONNX_DIR}/detr_simple_demo_onnx.onnx"):
24
+ raise gr.Error("ONNX model not found, please export it first!")
25
+ model = SimpleDetrOnnx()
26
  t1 = perf_counter()
27
  progress(0.1, "Inference..")
28
 
 
31
  return annotated_img, {"load_model": t1 - t0, "inference": t2 - t1}, None
32
 
33
 
34
+ def export_model(model_name, progress=gr.Progress(track_tqdm=True)):
35
+ progress(0.1, "Conversion..")
36
+ t0 = perf_counter()
37
+ if model_name == "detr_simple_demo":
38
+ model = SimpleDetr()
39
+ elif model_name == "detr_resnet101_panoptic":
40
+ model = PanopticDetrResenet101()
41
+
42
+ model_path = model.export()
43
+ t1 = perf_counter()
44
+ return model_path, {"export_time": t1 - t0}
45
+
46
 
47
+ with gr.Blocks() as demo:
48
+ gr.Markdown("# DETR: Detection Transformer")
49
+ # gr.Image(value=f"{ASSETS_DIR}/detr_architecture.png")
50
+ with gr.Tab("Torch Inference"):
51
+ with gr.Row():
52
+ with gr.Column():
53
+ img_file = gr.Image(type="pil")
54
+ model_name = gr.Dropdown(
55
+ label="Model",
56
+ choices=[
57
+ "detr_simple_demo",
58
+ "detr_resnet101_panoptic",
59
+ ],
60
+ value="detr_simple_demo",
61
+ )
62
 
63
+ conf = gr.Slider(label="Confidence", minimum=0, maximum=0.99, value=0.5)
64
+
65
+ with gr.Row():
66
+ start_btn = gr.Button("Start", variant="primary")
67
+ with gr.Column():
68
+ annotated_img = gr.Image(label="Annotated Image")
69
+ speed = gr.JSON(label="speed")
70
+ examples = gr.Examples(
71
+ examples=[
72
+ [path]
73
+ for path in sv.list_files_with_extensions(
74
+ directory=IMAGES_DIR, extensions=["jpeg", "jpg", "png"]
75
+ )
76
+ ],
77
+ inputs=[img_file],
78
+ )
79
+ start_btn.click(
80
+ fn=run_inference,
81
+ inputs=[img_file, conf, model_name],
82
+ outputs=[annotated_img, speed],
83
+ )
84
+ with gr.Tab("ONNX Inference"):
85
+ with gr.Row():
86
+ with gr.Column():
87
+ img_file = gr.Image(type="pil")
88
+ model_name = gr.Dropdown(
89
+ label="Model",
90
+ choices=[
91
+ "detr_simple_demo_onnx",
92
+ ],
93
+ value="detr_simple_demo_onnx",
94
+ )
95
+ conf = gr.Slider(label="Confidence", minimum=0, maximum=0.99, value=0.7)
96
+ with gr.Row():
97
+ start_btn = gr.Button("Start", variant="primary")
98
+ with gr.Column():
99
+ annotated_img = gr.Image(label="Annotated Image")
100
+ speed = gr.JSON(label="speed")
101
+ examples = gr.Examples(
102
+ examples=[
103
+ [path]
104
+ for path in sv.list_files_with_extensions(
105
+ directory=IMAGES_DIR, extensions=["jpeg", "jpg", "png"]
106
+ )
107
+ ],
108
+ inputs=[img_file],
109
+ )
110
+ start_btn.click(
111
+ fn=run_inference,
112
+ inputs=[img_file, conf, model_name],
113
+ outputs=[annotated_img, speed],
114
+ )
115
+ with gr.Tab("ONNX export"):
116
+ with gr.Row():
117
+ with gr.Column():
118
+ model_name = gr.Dropdown(
119
+ label="Model",
120
+ choices=[
121
+ "detr_simple_demo",
122
+ "detr_resnet101_panoptic",
123
+ ],
124
+ value="detr_simple_demo",
125
+ )
126
+ with gr.Row():
127
+ export_btn = gr.Button("Export", variant="primary")
128
+ with gr.Column():
129
+ onnx_file = gr.File()
130
+ result = gr.JSON(label="result")
131
+ export_btn.click(
132
+ fn=export_model,
133
+ inputs=[model_name],
134
+ outputs=[onnx_file, result],
135
+ )
136
 
137
  if __name__ == "__main__":
138
+ demo.queue(2).launch(
139
  debug=True,
140
  server_name="0.0.0.0",
141
  server_port=7000,
requirements.txt CHANGED
@@ -5,4 +5,6 @@ matplotlib
5
  torchvision
6
  supervision==0.17.1
7
  git+https://github.com/cocodataset/panopticapi.git
8
- seaborn
 
 
 
5
  torchvision
6
  supervision==0.17.1
7
  git+https://github.com/cocodataset/panopticapi.git
8
+ seaborn
9
+ onnx
10
+ onnxruntime