henry000 commited on
Commit
f5518c0
Β·
1 Parent(s): f99f89b

πŸ› [Fix] a bug when loading the model

Browse files
yolo/config/general.yaml CHANGED
@@ -12,4 +12,4 @@ lucky_number: 10
12
  use_wandb: False
13
  use_TensorBoard: False
14
 
15
- weight: weights/v9-c.pt
 
12
  use_wandb: False
13
  use_TensorBoard: False
14
 
15
+ weight: True # Path to weight or True for auto, False for no pretrained weight
yolo/lazy.py CHANGED
@@ -22,10 +22,10 @@ def main(cfg: Config):
22
  device, use_ddp = get_device(cfg.device)
23
  dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task, use_ddp)
24
  if getattr(cfg.task, "fast_inference", False):
25
- model = FastModelLoader(cfg).load_model()
26
  else:
27
  model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight)
28
- model = model.to(device)
29
 
30
  vec2box = Vec2Box(model, cfg.image_size, device)
31
 
 
22
  device, use_ddp = get_device(cfg.device)
23
  dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task, use_ddp)
24
  if getattr(cfg.task, "fast_inference", False):
25
+ model = FastModelLoader(cfg).load_model(device)
26
  else:
27
  model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight)
28
+ model = model.to(device)
29
 
30
  vec2box = Vec2Box(model, cfg.image_size, device)
31
 
yolo/model/yolo.py CHANGED
@@ -119,7 +119,7 @@ class YOLO(nn.Module):
119
  raise ValueError(f"Unsupported layer type: {layer_type}")
120
 
121
 
122
- def create_model(model_cfg: ModelConfig, weight_path: Optional[str], class_num: int = 80) -> YOLO:
123
  """Constructs and returns a model from a Dictionary configuration file.
124
 
125
  Args:
@@ -132,12 +132,14 @@ def create_model(model_cfg: ModelConfig, weight_path: Optional[str], class_num:
132
  OmegaConf.set_struct(model_cfg, False)
133
  model = YOLO(model_cfg, class_num)
134
  if weight_path:
 
 
135
  if not os.path.exists(weight_path):
136
  logger.info(f"🌐 Weight {weight_path} not found, try downloading")
137
  prepare_weight(weight_path=weight_path)
138
  if os.path.exists(weight_path):
139
  # TODO: fix map_location
140
- model.model.load_state_dict(torch.load(weight_path), strict=False)
141
  logger.info("βœ… Success load model & weight")
142
  else:
143
  logger.info("βœ… Success load model")
 
119
  raise ValueError(f"Unsupported layer type: {layer_type}")
120
 
121
 
122
+ def create_model(model_cfg: ModelConfig, weight_path: Union[bool, str], class_num: int = 80) -> YOLO:
123
  """Constructs and returns a model from a Dictionary configuration file.
124
 
125
  Args:
 
132
  OmegaConf.set_struct(model_cfg, False)
133
  model = YOLO(model_cfg, class_num)
134
  if weight_path:
135
+ if weight_path == True:
136
+ weight_path = os.path.join("weights", f"{model_cfg.name}.pt")
137
  if not os.path.exists(weight_path):
138
  logger.info(f"🌐 Weight {weight_path} not found, try downloading")
139
  prepare_weight(weight_path=weight_path)
140
  if os.path.exists(weight_path):
141
  # TODO: fix map_location
142
+ model.model.load_state_dict(torch.load(weight_path, map_location=torch.device("cpu")), strict=False)
143
  logger.info("βœ… Success load model & weight")
144
  else:
145
  logger.info("βœ… Success load model")
yolo/utils/deploy_utils.py CHANGED
@@ -13,6 +13,8 @@ class FastModelLoader:
13
  self.cfg = cfg
14
  self.compiler = cfg.task.fast_inference
15
  self._validate_compiler()
 
 
16
  self.model_path = f"{os.path.splitext(cfg.weight)[0]}.{self.compiler}"
17
 
18
  def _validate_compiler(self):
@@ -23,14 +25,14 @@ class FastModelLoader:
23
  logger.warning("🍎 TensorRT does not support MPS devices. Using original model.")
24
  self.compiler = None
25
 
26
- def load_model(self):
27
  if self.compiler == "onnx":
28
  return self._load_onnx_model()
29
  elif self.compiler == "trt":
30
- return self._load_trt_model()
31
  elif self.compiler == "deploy":
32
  self.cfg.model.model.auxiliary = {}
33
- return create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight)
34
 
35
  def _load_onnx_model(self):
36
  from onnxruntime import InferenceSession
 
13
  self.cfg = cfg
14
  self.compiler = cfg.task.fast_inference
15
  self._validate_compiler()
16
+ if cfg.weight == True:
17
+ cfg.weight = os.path.join("weights", f"{cfg.model.name}.pt")
18
  self.model_path = f"{os.path.splitext(cfg.weight)[0]}.{self.compiler}"
19
 
20
  def _validate_compiler(self):
 
25
  logger.warning("🍎 TensorRT does not support MPS devices. Using original model.")
26
  self.compiler = None
27
 
28
+ def load_model(self, device):
29
  if self.compiler == "onnx":
30
  return self._load_onnx_model()
31
  elif self.compiler == "trt":
32
+ return self._load_trt_model().to(device)
33
  elif self.compiler == "deploy":
34
  self.cfg.model.model.auxiliary = {}
35
+ return create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight).to(device)
36
 
37
  def _load_onnx_model(self):
38
  from onnxruntime import InferenceSession