๐ [Add] deploy option, auto remove aux head
Browse files
examples/notebook_inference.ipynb
CHANGED
@@ -43,8 +43,8 @@
|
|
43 |
"outputs": [],
|
44 |
"source": [
|
45 |
"with initialize(config_path=CONFIG_PATH, version_base=None, job_name=\"notebook_job\"):\n",
|
46 |
-
" cfg: Config = compose(config_name=CONFIG_NAME, overrides=[\"task=inference\", f\"task.data.source={IMAGE_PATH}\", \"model=v9-
|
47 |
-
" model = create_model(cfg.model, class_num=CLASS_NUM, weight_path=WEIGHT_PATH
|
48 |
" transform = AugmentationComposer([], cfg.image_size)\n",
|
49 |
" vec2box = Vec2Box(model, cfg.image_size, device)"
|
50 |
]
|
@@ -70,7 +70,7 @@
|
|
70 |
" predict = vec2box(predict[\"Main\"])\n",
|
71 |
"\n",
|
72 |
"predict_box = bbox_nms(predict[0], predict[2], cfg.task.nms)\n",
|
73 |
-
"draw_bboxes(image, predict_box,
|
74 |
]
|
75 |
},
|
76 |
{
|
@@ -81,13 +81,6 @@
|
|
81 |
"\n",
|
82 |
""
|
83 |
]
|
84 |
-
},
|
85 |
-
{
|
86 |
-
"cell_type": "code",
|
87 |
-
"execution_count": null,
|
88 |
-
"metadata": {},
|
89 |
-
"outputs": [],
|
90 |
-
"source": []
|
91 |
}
|
92 |
],
|
93 |
"metadata": {
|
|
|
43 |
"outputs": [],
|
44 |
"source": [
|
45 |
"with initialize(config_path=CONFIG_PATH, version_base=None, job_name=\"notebook_job\"):\n",
|
46 |
+
" cfg: Config = compose(config_name=CONFIG_NAME, overrides=[\"task=inference\", f\"task.data.source={IMAGE_PATH}\", \"model=v9-m\"])\n",
|
47 |
+
" model = create_model(cfg.model, class_num=CLASS_NUM, weight_path=WEIGHT_PATH, device = device)\n",
|
48 |
" transform = AugmentationComposer([], cfg.image_size)\n",
|
49 |
" vec2box = Vec2Box(model, cfg.image_size, device)"
|
50 |
]
|
|
|
70 |
" predict = vec2box(predict[\"Main\"])\n",
|
71 |
"\n",
|
72 |
"predict_box = bbox_nms(predict[0], predict[2], cfg.task.nms)\n",
|
73 |
+
"draw_bboxes(image, predict_box, idx2label=cfg.class_list)"
|
74 |
]
|
75 |
},
|
76 |
{
|
|
|
81 |
"\n",
|
82 |
""
|
83 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
}
|
85 |
],
|
86 |
"metadata": {
|
yolo/utils/deploy_utils.py
CHANGED
@@ -9,14 +9,15 @@ from yolo.model.yolo import create_model
|
|
9 |
|
10 |
|
11 |
class FastModelLoader:
|
12 |
-
def __init__(self, cfg: Config):
|
13 |
self.cfg = cfg
|
|
|
14 |
self.compiler = cfg.task.fast_inference
|
15 |
self._validate_compiler()
|
16 |
self.model_path = f"{os.path.splitext(cfg.weight)[0]}.{self.compiler}"
|
17 |
|
18 |
def _validate_compiler(self):
|
19 |
-
if self.compiler not in ["onnx", "trt"]:
|
20 |
logger.warning(f"โ ๏ธ Compiler '{self.compiler}' is not supported. Using original model.")
|
21 |
self.compiler = None
|
22 |
if self.cfg.device == "mps" and self.compiler == "trt":
|
@@ -28,7 +29,11 @@ class FastModelLoader:
|
|
28 |
return self._load_onnx_model()
|
29 |
elif self.compiler == "trt":
|
30 |
return self._load_trt_model()
|
31 |
-
|
|
|
|
|
|
|
|
|
32 |
|
33 |
def _load_onnx_model(self):
|
34 |
from onnxruntime import InferenceSession
|
|
|
9 |
|
10 |
|
11 |
class FastModelLoader:
|
12 |
+
def __init__(self, cfg: Config, device):
|
13 |
self.cfg = cfg
|
14 |
+
self.device = device
|
15 |
self.compiler = cfg.task.fast_inference
|
16 |
self._validate_compiler()
|
17 |
self.model_path = f"{os.path.splitext(cfg.weight)[0]}.{self.compiler}"
|
18 |
|
19 |
def _validate_compiler(self):
|
20 |
+
if self.compiler not in ["onnx", "trt", "deploy"]:
|
21 |
logger.warning(f"โ ๏ธ Compiler '{self.compiler}' is not supported. Using original model.")
|
22 |
self.compiler = None
|
23 |
if self.cfg.device == "mps" and self.compiler == "trt":
|
|
|
29 |
return self._load_onnx_model()
|
30 |
elif self.compiler == "trt":
|
31 |
return self._load_trt_model()
|
32 |
+
elif self.compiler == "deploy":
|
33 |
+
self.cfg.model.model.auxiliary = {}
|
34 |
+
return create_model(
|
35 |
+
self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight, device=self.device
|
36 |
+
)
|
37 |
|
38 |
def _load_onnx_model(self):
|
39 |
from onnxruntime import InferenceSession
|