:bug: [Fix] the loading weight order download first
Browse files- yolo/model/yolo.py +5 -6
yolo/model/yolo.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import os
|
2 |
-
from typing import
|
3 |
|
4 |
import torch
|
5 |
from loguru import logger
|
@@ -117,9 +117,7 @@ class YOLO(nn.Module):
|
|
117 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|
118 |
|
119 |
|
120 |
-
def create_model(
|
121 |
-
model_cfg: ModelConfig, class_num: int = 80, weight_path: str = "weights/v9-c.pt", device: device = device("cuda")
|
122 |
-
) -> YOLO:
|
123 |
"""Constructs and returns a model from a Dictionary configuration file.
|
124 |
|
125 |
Args:
|
@@ -135,8 +133,9 @@ def create_model(
|
|
135 |
if not os.path.exists(weight_path):
|
136 |
logger.info(f"π Weight {weight_path} not found, try downloading")
|
137 |
prepare_weight(weight_path=weight_path)
|
138 |
-
|
139 |
-
|
|
|
140 |
|
141 |
log_model_structure(model.model)
|
142 |
draw_model(model=model)
|
|
|
1 |
import os
|
2 |
+
from typing import Dict, List, Optional, Union
|
3 |
|
4 |
import torch
|
5 |
from loguru import logger
|
|
|
117 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|
118 |
|
119 |
|
120 |
+
def create_model(model_cfg: ModelConfig, weight_path: Optional[str], device: device, class_num: int = 80) -> YOLO:
|
|
|
|
|
121 |
"""Constructs and returns a model from a Dictionary configuration file.
|
122 |
|
123 |
Args:
|
|
|
133 |
if not os.path.exists(weight_path):
|
134 |
logger.info(f"π Weight {weight_path} not found, try downloading")
|
135 |
prepare_weight(weight_path=weight_path)
|
136 |
+
if os.path.exists(weight_path):
|
137 |
+
model.model.load_state_dict(torch.load(weight_path, map_location=device))
|
138 |
+
logger.info("β
Success load model weight")
|
139 |
|
140 |
log_model_structure(model.model)
|
141 |
draw_model(model=model)
|