henry000 commited on
Commit
306fc38
·
1 Parent(s): 9eb2d4e

✨ [Add] automatic load weights

Browse files
Files changed (2) hide show
  1. yolo/config/config.py +2 -0
  2. yolo/model/yolo.py +4 -0
yolo/config/config.py CHANGED
@@ -135,6 +135,8 @@ class Config:
135
  use_wandb: bool
136
  use_TensorBoard: bool
137
 
 
 
138
 
139
  @dataclass
140
  class YOLOLayer(nn.Module):
 
135
  use_wandb: bool
136
  use_TensorBoard: bool
137
 
138
+ weight: Optional[str]
139
+
140
 
141
  @dataclass
142
  class YOLOLayer(nn.Module):
yolo/model/yolo.py CHANGED
@@ -1,5 +1,6 @@
1
  from typing import Any, Dict, List, Union
2
 
 
3
  import torch.nn as nn
4
  from loguru import logger
5
  from omegaconf import ListConfig, OmegaConf
@@ -125,6 +126,9 @@ def get_model(cfg: Config) -> YOLO:
125
  OmegaConf.set_struct(cfg.model, False)
126
  model = YOLO(cfg.model, cfg.class_num)
127
  logger.info("✅ Success load model")
 
 
 
128
  log_model_structure(model.model)
129
  draw_model(model=model)
130
  return model
 
1
  from typing import Any, Dict, List, Union
2
 
3
+ import torch
4
  import torch.nn as nn
5
  from loguru import logger
6
  from omegaconf import ListConfig, OmegaConf
 
126
  OmegaConf.set_struct(cfg.model, False)
127
  model = YOLO(cfg.model, cfg.class_num)
128
  logger.info("✅ Success load model")
129
+ if cfg.weight:
130
+ model.model.load_state_dict(torch.load(cfg.weight))
131
+ logger.info("✅ Success load model weight")
132
  log_model_structure(model.model)
133
  draw_model(model=model)
134
  return model