henry000 commited on
Commit
18edc1d
Β·
1 Parent(s): 70a7f92

:bug: [Fix] the loading weight order download first

Browse files
Files changed (1) hide show
  1. yolo/model/yolo.py +5 -6
yolo/model/yolo.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
- from typing import Any, Dict, List, Union
3
 
4
  import torch
5
  from loguru import logger
@@ -117,9 +117,7 @@ class YOLO(nn.Module):
117
  raise ValueError(f"Unsupported layer type: {layer_type}")
118
 
119
 
120
- def create_model(
121
- model_cfg: ModelConfig, class_num: int = 80, weight_path: str = "weights/v9-c.pt", device: device = device("cuda")
122
- ) -> YOLO:
123
  """Constructs and returns a model from a Dictionary configuration file.
124
 
125
  Args:
@@ -135,8 +133,9 @@ def create_model(
135
  if not os.path.exists(weight_path):
136
  logger.info(f"🌐 Weight {weight_path} not found, try downloading")
137
  prepare_weight(weight_path=weight_path)
138
- model.model.load_state_dict(torch.load(weight_path, map_location=device))
139
- logger.info("βœ… Success load model weight")
 
140
 
141
  log_model_structure(model.model)
142
  draw_model(model=model)
 
1
  import os
2
+ from typing import Dict, List, Optional, Union
3
 
4
  import torch
5
  from loguru import logger
 
117
  raise ValueError(f"Unsupported layer type: {layer_type}")
118
 
119
 
120
+ def create_model(model_cfg: ModelConfig, weight_path: Optional[str], device: device, class_num: int = 80) -> YOLO:
 
 
121
  """Constructs and returns a model from a Dictionary configuration file.
122
 
123
  Args:
 
133
  if not os.path.exists(weight_path):
134
  logger.info(f"🌐 Weight {weight_path} not found, try downloading")
135
  prepare_weight(weight_path=weight_path)
136
+ if os.path.exists(weight_path):
137
+ model.model.load_state_dict(torch.load(weight_path, map_location=device))
138
+ logger.info("βœ… Success load model weight")
139
 
140
  log_model_structure(model.model)
141
  draw_model(model=model)