♻️ [Refactor] the create_model func, clearify input
Browse files- yolo/lazy.py +1 -1
- yolo/model/yolo.py +8 -8
yolo/lazy.py
CHANGED
@@ -25,7 +25,7 @@ def main(cfg: Config):
|
|
25 |
model = FastModelLoader(cfg).load_model()
|
26 |
device = torch.device(cfg.device)
|
27 |
else:
|
28 |
-
model = create_model(cfg).to(device)
|
29 |
|
30 |
if cfg.task.task == "train":
|
31 |
trainer = ModelTrainer(cfg, model, save_path, device)
|
|
|
25 |
model = FastModelLoader(cfg).load_model()
|
26 |
device = torch.device(cfg.device)
|
27 |
else:
|
28 |
+
model = create_model(cfg.model, cfg.weight).to(device)
|
29 |
|
30 |
if cfg.task.task == "train":
|
31 |
trainer = ModelTrainer(cfg, model, save_path, device)
|
yolo/model/yolo.py
CHANGED
@@ -116,7 +116,7 @@ class YOLO(nn.Module):
|
|
116 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|
117 |
|
118 |
|
119 |
-
def create_model(
|
120 |
"""Constructs and returns a model from a Dictionary configuration file.
|
121 |
|
122 |
Args:
|
@@ -125,16 +125,16 @@ def create_model(cfg: Config) -> YOLO:
|
|
125 |
Returns:
|
126 |
YOLO: An instance of the model defined by the given configuration.
|
127 |
"""
|
128 |
-
OmegaConf.set_struct(
|
129 |
-
model = YOLO(
|
130 |
logger.info("✅ Success load model")
|
131 |
-
if
|
132 |
-
if os.path.exists(
|
133 |
-
model.model.load_state_dict(torch.load(
|
134 |
logger.info("✅ Success load model weight")
|
135 |
else:
|
136 |
-
logger.info(f"🌐 Weight {
|
137 |
-
prepare_weight(weight_path=
|
138 |
|
139 |
log_model_structure(model.model)
|
140 |
draw_model(model=model)
|
|
|
116 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|
117 |
|
118 |
|
119 |
+
def create_model(model_cfg: ModelConfig, weight_path: str) -> YOLO:
|
120 |
"""Constructs and returns a model from a Dictionary configuration file.
|
121 |
|
122 |
Args:
|
|
|
125 |
Returns:
|
126 |
YOLO: An instance of the model defined by the given configuration.
|
127 |
"""
|
128 |
+
OmegaConf.set_struct(model_cfg, False)
|
129 |
+
model = YOLO(model_cfg)
|
130 |
logger.info("✅ Success load model")
|
131 |
+
if weight_path:
|
132 |
+
if os.path.exists(weight_path):
|
133 |
+
model.model.load_state_dict(torch.load(weight_path), strict=False)
|
134 |
logger.info("✅ Success load model weight")
|
135 |
else:
|
136 |
+
logger.info(f"🌐 Weight {weight_path} not found, try downloading")
|
137 |
+
prepare_weight(weight_path=weight_path)
|
138 |
|
139 |
log_model_structure(model.model)
|
140 |
draw_model(model=model)
|