🔧 [Fix] a bug when loading the model
Changed files:
- yolo/config/general.yaml (+1 -1)
- yolo/lazy.py (+2 -2)
- yolo/model/yolo.py (+4 -2)
- yolo/utils/deploy_utils.py (+5 -3)
yolo/config/general.yaml

@@ -12,4 +12,4 @@ lucky_number: 10
 use_wandb: False
 use_TensorBoard: False
 
-weight:
+weight: True # Path to weight or True for auto, False for no pretrained weight
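The new `weight` value is deliberately overloaded: a string is an explicit checkpoint path, `True` means "derive the path from the model name", and `False` disables pretrained weights. A standalone sketch of that contract as the rest of this commit implements it (`resolve_weight_path` and the model name "v9-c" are illustrative, not part of the patch):

    import os
    from typing import Optional, Union

    def resolve_weight_path(weight: Union[bool, str, None], model_name: str) -> Optional[str]:
        # True -> auto path derived from the model name (see create_model below).
        if weight is True:
            return os.path.join("weights", f"{model_name}.pt")
        # False / None / "" -> no pretrained weight.
        return weight or None

    assert resolve_weight_path(True, "v9-c") == os.path.join("weights", "v9-c.pt")
    assert resolve_weight_path(False, "v9-c") is None
    assert resolve_weight_path("weights/custom.pt", "v9-c") == "weights/custom.pt"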
yolo/lazy.py

@@ -22,10 +22,10 @@ def main(cfg: Config):
     device, use_ddp = get_device(cfg.device)
     dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task, use_ddp)
     if getattr(cfg.task, "fast_inference", False):
-        model = FastModelLoader(cfg).load_model()
+        model = FastModelLoader(cfg).load_model(device)
     else:
         model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight)
-
+        model = model.to(device)
 
     vec2box = Vec2Box(model, cfg.image_size, device)
 
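Without the added `model.to(device)`, the model built by `create_model` stays on CPU while later tensors live on the selected device, which matches the loading bug named in the commit title. A standalone PyTorch sketch of the pattern (illustrative, not project code):

    import torch
    import torch.nn as nn

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = nn.Linear(4, 2)               # modules are created on CPU by default
    model = model.to(device)              # the step the patch adds after loading
    x = torch.randn(1, 4, device=device)  # inputs live on the target device
    print(model(x).shape)                 # torch.Size([1, 2]); skipping .to(device)
                                          # on a CUDA host raises a RuntimeError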
yolo/model/yolo.py

@@ -119,7 +119,7 @@ class YOLO(nn.Module):
         raise ValueError(f"Unsupported layer type: {layer_type}")
 
 
-def create_model(model_cfg: ModelConfig, weight_path: Optional[str], class_num: int = 80) -> YOLO:
+def create_model(model_cfg: ModelConfig, weight_path: Union[bool, str], class_num: int = 80) -> YOLO:
     """Constructs and returns a model from a Dictionary configuration file.
 
     Args:
@@ -132,12 +132,14 @@ def create_model(model_cfg: ModelConfig, weight_path: Optional[str], class_num: int = 80) -> YOLO:
     OmegaConf.set_struct(model_cfg, False)
     model = YOLO(model_cfg, class_num)
     if weight_path:
+        if weight_path == True:
+            weight_path = os.path.join("weights", f"{model_cfg.name}.pt")
         if not os.path.exists(weight_path):
             logger.info(f"🌐 Weight {weight_path} not found, try downloading")
             prepare_weight(weight_path=weight_path)
         if os.path.exists(weight_path):
             # TODO: fix map_location
-            model.model.load_state_dict(torch.load(weight_path), strict=False)
+            model.model.load_state_dict(torch.load(weight_path, map_location=torch.device("cpu")), strict=False)
             logger.info("✅ Success load model & weight")
     else:
         logger.info("✅ Success load model")
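Two fixes land in `create_model`: the literal `True` is resolved to weights/<model_cfg.name>.pt before any filesystem check, and `torch.load` now pins tensors to CPU (addressing the `# TODO: fix map_location`), so a checkpoint written on a GPU machine still loads on a host without one; lazy.py then moves the model to the target device. A standalone round-trip sketch ("demo.pt" is a throwaway filename, not from the patch):

    import torch
    import torch.nn as nn

    net = nn.Linear(8, 4)
    torch.save(net.state_dict(), "demo.pt")

    # map_location forces deserialization onto CPU regardless of where the
    # checkpoint was saved; strict=False tolerates missing/unexpected keys
    # (e.g. a variant with auxiliary heads stripped).
    state = torch.load("demo.pt", map_location=torch.device("cpu"))
    net.load_state_dict(state, strict=False)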
yolo/utils/deploy_utils.py

@@ -13,6 +13,8 @@ class FastModelLoader:
         self.cfg = cfg
         self.compiler = cfg.task.fast_inference
         self._validate_compiler()
+        if cfg.weight == True:
+            cfg.weight = os.path.join("weights", f"{cfg.model.name}.pt")
         self.model_path = f"{os.path.splitext(cfg.weight)[0]}.{self.compiler}"
 
     def _validate_compiler(self):
@@ -23,14 +25,14 @@ class FastModelLoader:
             logger.warning("🍎 TensorRT does not support MPS devices. Using original model.")
             self.compiler = None
 
-    def load_model(self):
+    def load_model(self, device):
         if self.compiler == "onnx":
             return self._load_onnx_model()
         elif self.compiler == "trt":
-            return self._load_trt_model()
+            return self._load_trt_model().to(device)
         elif self.compiler == "deploy":
             self.cfg.model.model.auxiliary = {}
-            return create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight)
+            return create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight).to(device)
 
     def _load_onnx_model(self):
         from onnxruntime import InferenceSession
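The constructor change also guards a crash: `self.model_path` is built with `os.path.splitext(cfg.weight)`, which requires a string, so the bool `True` has to be resolved to a real path first (`os.path.splitext(True)` raises TypeError). Meanwhile `load_model(device)` lets the TensorRT and deploy branches hand back a model already on the right device, matching the updated lazy.py call site. A sketch of the path derivation (the model name "v9-c" is assumed):

    import os

    weight = os.path.join("weights", "v9-c.pt")  # what cfg.weight == True resolves to
    compiler = "onnx"

    # Swap the checkpoint extension for the compiler suffix, as in __init__ above.
    model_path = f"{os.path.splitext(weight)[0]}.{compiler}"
    print(model_path)  # weights/v9-c.onnx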