henry000 commited on
Commit
40afe67
·
1 Parent(s): 7daf6f0

♻️ [Refactor] the create_model func, clearify input

Browse files
Files changed (2) hide show
  1. yolo/lazy.py +1 -1
  2. yolo/model/yolo.py +8 -8
yolo/lazy.py CHANGED
@@ -25,7 +25,7 @@ def main(cfg: Config):
25
  model = FastModelLoader(cfg).load_model()
26
  device = torch.device(cfg.device)
27
  else:
28
- model = create_model(cfg).to(device)
29
 
30
  if cfg.task.task == "train":
31
  trainer = ModelTrainer(cfg, model, save_path, device)
 
25
  model = FastModelLoader(cfg).load_model()
26
  device = torch.device(cfg.device)
27
  else:
28
+ model = create_model(cfg.model, cfg.weight).to(device)
29
 
30
  if cfg.task.task == "train":
31
  trainer = ModelTrainer(cfg, model, save_path, device)
yolo/model/yolo.py CHANGED
@@ -116,7 +116,7 @@ class YOLO(nn.Module):
116
  raise ValueError(f"Unsupported layer type: {layer_type}")
117
 
118
 
119
- def create_model(cfg: Config) -> YOLO:
120
  """Constructs and returns a model from a Dictionary configuration file.
121
 
122
  Args:
@@ -125,16 +125,16 @@ def create_model(cfg: Config) -> YOLO:
125
  Returns:
126
  YOLO: An instance of the model defined by the given configuration.
127
  """
128
- OmegaConf.set_struct(cfg.model, False)
129
- model = YOLO(cfg.model)
130
  logger.info("✅ Success load model")
131
- if cfg.weight:
132
- if os.path.exists(cfg.weight):
133
- model.model.load_state_dict(torch.load(cfg.weight), strict=False)
134
  logger.info("✅ Success load model weight")
135
  else:
136
- logger.info(f"🌐 Weight {cfg.weight} not found, try downloading")
137
- prepare_weight(weight_path=cfg.weight)
138
 
139
  log_model_structure(model.model)
140
  draw_model(model=model)
 
116
  raise ValueError(f"Unsupported layer type: {layer_type}")
117
 
118
 
119
+ def create_model(model_cfg: ModelConfig, weight_path: str) -> YOLO:
120
  """Constructs and returns a model from a Dictionary configuration file.
121
 
122
  Args:
 
125
  Returns:
126
  YOLO: An instance of the model defined by the given configuration.
127
  """
128
+ OmegaConf.set_struct(model_cfg, False)
129
+ model = YOLO(model_cfg)
130
  logger.info("✅ Success load model")
131
+ if weight_path:
132
+ if os.path.exists(weight_path):
133
+ model.model.load_state_dict(torch.load(weight_path), strict=False)
134
  logger.info("✅ Success load model weight")
135
  else:
136
+ logger.info(f"🌐 Weight {weight_path} not found, try downloading")
137
+ prepare_weight(weight_path=weight_path)
138
 
139
  log_model_structure(model.model)
140
  draw_model(model=model)