henry000 commited on
Commit
89c6a27
ยท
1 Parent(s): a0976c9

๐Ÿš€ [Add] deploy option, auto remove aux head

Browse files
examples/notebook_inference.ipynb CHANGED
@@ -43,8 +43,8 @@
43
  "outputs": [],
44
  "source": [
45
  "with initialize(config_path=CONFIG_PATH, version_base=None, job_name=\"notebook_job\"):\n",
46
- " cfg: Config = compose(config_name=CONFIG_NAME, overrides=[\"task=inference\", f\"task.data.source={IMAGE_PATH}\", \"model=v9-c-deploy\"])\n",
47
- " model = create_model(cfg.model, class_num=CLASS_NUM, weight_path=WEIGHT_PATH).to(device)\n",
48
  " transform = AugmentationComposer([], cfg.image_size)\n",
49
  " vec2box = Vec2Box(model, cfg.image_size, device)"
50
  ]
@@ -70,7 +70,7 @@
70
  " predict = vec2box(predict[\"Main\"])\n",
71
  "\n",
72
  "predict_box = bbox_nms(predict[0], predict[2], cfg.task.nms)\n",
73
- "draw_bboxes(image, predict_box, save_path='../demo/images/output/', idx2label=cfg.class_list)"
74
  ]
75
  },
76
  {
@@ -81,13 +81,6 @@
81
  "\n",
82
  "![image](../demo/images/output/visualize.png)"
83
  ]
84
- },
85
- {
86
- "cell_type": "code",
87
- "execution_count": null,
88
- "metadata": {},
89
- "outputs": [],
90
- "source": []
91
  }
92
  ],
93
  "metadata": {
 
43
  "outputs": [],
44
  "source": [
45
  "with initialize(config_path=CONFIG_PATH, version_base=None, job_name=\"notebook_job\"):\n",
46
+ " cfg: Config = compose(config_name=CONFIG_NAME, overrides=[\"task=inference\", f\"task.data.source={IMAGE_PATH}\", \"model=v9-m\"])\n",
47
+ " model = create_model(cfg.model, class_num=CLASS_NUM, weight_path=WEIGHT_PATH, device = device)\n",
48
  " transform = AugmentationComposer([], cfg.image_size)\n",
49
  " vec2box = Vec2Box(model, cfg.image_size, device)"
50
  ]
 
70
  " predict = vec2box(predict[\"Main\"])\n",
71
  "\n",
72
  "predict_box = bbox_nms(predict[0], predict[2], cfg.task.nms)\n",
73
+ "draw_bboxes(image, predict_box, idx2label=cfg.class_list)"
74
  ]
75
  },
76
  {
 
81
  "\n",
82
  "![image](../demo/images/output/visualize.png)"
83
  ]
 
 
 
 
 
 
 
84
  }
85
  ],
86
  "metadata": {
yolo/utils/deploy_utils.py CHANGED
@@ -9,14 +9,15 @@ from yolo.model.yolo import create_model
9
 
10
 
11
  class FastModelLoader:
12
- def __init__(self, cfg: Config):
13
  self.cfg = cfg
 
14
  self.compiler = cfg.task.fast_inference
15
  self._validate_compiler()
16
  self.model_path = f"{os.path.splitext(cfg.weight)[0]}.{self.compiler}"
17
 
18
  def _validate_compiler(self):
19
- if self.compiler not in ["onnx", "trt"]:
20
  logger.warning(f"โš ๏ธ Compiler '{self.compiler}' is not supported. Using original model.")
21
  self.compiler = None
22
  if self.cfg.device == "mps" and self.compiler == "trt":
@@ -28,7 +29,11 @@ class FastModelLoader:
28
  return self._load_onnx_model()
29
  elif self.compiler == "trt":
30
  return self._load_trt_model()
31
- return create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight)
 
 
 
 
32
 
33
  def _load_onnx_model(self):
34
  from onnxruntime import InferenceSession
 
9
 
10
 
11
  class FastModelLoader:
12
+ def __init__(self, cfg: Config, device):
13
  self.cfg = cfg
14
+ self.device = device
15
  self.compiler = cfg.task.fast_inference
16
  self._validate_compiler()
17
  self.model_path = f"{os.path.splitext(cfg.weight)[0]}.{self.compiler}"
18
 
19
  def _validate_compiler(self):
20
+ if self.compiler not in ["onnx", "trt", "deploy"]:
21
  logger.warning(f"โš ๏ธ Compiler '{self.compiler}' is not supported. Using original model.")
22
  self.compiler = None
23
  if self.cfg.device == "mps" and self.compiler == "trt":
 
29
  return self._load_onnx_model()
30
  elif self.compiler == "trt":
31
  return self._load_trt_model()
32
+ elif self.compiler == "deploy":
33
+ self.cfg.model.model.auxiliary = {}
34
+ return create_model(
35
+ self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight, device=self.device
36
+ )
37
 
38
  def _load_onnx_model(self):
39
  from onnxruntime import InferenceSession