henry000 committed on
Commit 8fe77d2 · 2 Parent(s): a54ff08 0d49177

🔀 [Merge] branch 'MODELv7' into TEST
examples/notebook_inference.ipynb CHANGED
@@ -16,8 +16,15 @@
     "project_root = Path().resolve().parent\n",
     "sys.path.append(str(project_root))\n",
     "\n",
-    "from yolo import AugmentationComposer, Config, create_model, custom_logger, draw_bboxes, Vec2Box, PostProccess\n",
-    "from yolo.utils.bounding_box_utils import Anc2Box"
+    "from yolo import (\n",
+    "    AugmentationComposer,\n",
+    "    Config,\n",
+    "    PostProccess,\n",
+    "    create_converter,\n",
+    "    create_model,\n",
+    "    custom_logger,\n",
+    "    draw_bboxes,\n",
+    ")"
    ]
   },
   {
@@ -48,8 +55,7 @@
     "    cfg: Config = compose(config_name=CONFIG_NAME, overrides=[\"task=inference\", f\"task.data.source={IMAGE_PATH}\", f\"model={MODEL}\"])\n",
     "    model = create_model(cfg.model, class_num=CLASS_NUM).to(device)\n",
     "    transform = AugmentationComposer([], cfg.image_size)\n",
-    "    converter = Anc2Box(model, cfg.model.anchor, cfg.image_size, device)\n",
-    "    # converter = Vec2Box(model, cfg.model.anchor, cfg.image_size, device)\n",
+    "    converter = create_converter(cfg.model.name, model, cfg.model.anchor, cfg.image_size, device)\n",
     "    post_proccess = PostProccess(converter, cfg.task.nms)"
    ]
   },
@@ -86,23 +92,6 @@
     "\n",
     "![image](../demo/images/output/visualize.png)"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "%load_ext autoreload\n",
-    "%autoreload 2"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
   }
  ],
 "metadata": {
yolo/__init__.py CHANGED
@@ -3,7 +3,7 @@ from yolo.model.yolo import create_model
 from yolo.tools.data_loader import AugmentationComposer, create_dataloader
 from yolo.tools.drawer import draw_bboxes
 from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
-from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms
+from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms, create_converter
 from yolo.utils.deploy_utils import FastModelLoader
 from yolo.utils.logging_utils import custom_logger
 from yolo.utils.model_utils import PostProccess
@@ -18,6 +18,7 @@ all = [
     "Vec2Box",
     "Anc2Box",
     "bbox_nms",
+    "create_converter",
     "AugmentationComposer",
     "create_dataloader",
     "FastModelLoader",
yolo/lazy.py CHANGED
@@ -10,7 +10,7 @@ from yolo.config.config import Config
 from yolo.model.yolo import create_model
 from yolo.tools.data_loader import create_dataloader
 from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
-from yolo.utils.bounding_box_utils import Vec2Box
+from yolo.utils.bounding_box_utils import create_converter
 from yolo.utils.deploy_utils import FastModelLoader
 from yolo.utils.logging_utils import ProgressLogger
 from yolo.utils.model_utils import get_device
@@ -27,13 +27,14 @@ def main(cfg: Config):
     model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight)
     model = model.to(device)
 
-    vec2box = Vec2Box(model, cfg.model.anchor, cfg.image_size, device)
+    converter = create_converter(cfg.model.name, model, cfg.model.anchor, cfg.image_size, device)
+
     if cfg.task.task == "train":
-        solver = ModelTrainer(cfg, model, vec2box, progress, device, use_ddp)
+        solver = ModelTrainer(cfg, model, converter, progress, device, use_ddp)
     if cfg.task.task == "validation":
-        solver = ModelValidator(cfg.task, cfg.dataset, model, vec2box, progress, device)
+        solver = ModelValidator(cfg.task, cfg.dataset, model, converter, progress, device)
     if cfg.task.task == "inference":
-        solver = ModelTester(cfg, model, vec2box, progress, device)
+        solver = ModelTester(cfg, model, converter, progress, device)
     progress.start()
     solver.solve(dataloader)
 
yolo/utils/bounding_box_utils.py CHANGED
@@ -364,6 +364,14 @@ class Anc2Box:
         return preds_cls, None, preds_box, preds_cnf.sigmoid()
 
 
+def create_converter(model_version: str = "v9-c", *args, **kwargs):
+    if "v7" in model_version:  # check model if v7
+        converter = Anc2Box(*args, **kwargs)
+    else:
+        converter = Vec2Box(*args, **kwargs)
+    return converter
+
+
 def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig, confidence: Optional[Tensor] = None):
     cls_dist = cls_dist.sigmoid() * (1 if confidence is None else confidence)
 
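
The new factory only inspects the model-name string: any name containing "v7" gets the anchor-based `Anc2Box` decoder, everything else falls back to `Vec2Box`. A minimal, self-contained sketch of that rule (the stub classes stand in for the real ones defined in this file):

```python
# Self-contained sketch of the dispatch rule added above; the two stub classes are
# placeholders for the real Anc2Box / Vec2Box in yolo/utils/bounding_box_utils.py.
class Anc2Box:  # stand-in: anchor-based decoder, selected for v7-style models
    def __init__(self, *args, **kwargs):
        pass

class Vec2Box:  # stand-in: anchor-free decoder, used for everything else (e.g. v9-c)
    def __init__(self, *args, **kwargs):
        pass

def create_converter(model_version: str = "v9-c", *args, **kwargs):
    # Any model name containing "v7" selects the anchor-based path.
    if "v7" in model_version:
        return Anc2Box(*args, **kwargs)
    return Vec2Box(*args, **kwargs)

assert isinstance(create_converter("v7"), Anc2Box)
assert isinstance(create_converter("v7-x"), Anc2Box)
assert isinstance(create_converter("v9-c"), Vec2Box)
```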