✨ [Add] automatic load weights
Browse files- yolo/config/config.py +2 -0
- yolo/model/yolo.py +4 -0
yolo/config/config.py
CHANGED
@@ -135,6 +135,8 @@ class Config:
|
|
135 |
use_wandb: bool
|
136 |
use_TensorBoard: bool
|
137 |
|
|
|
|
|
138 |
|
139 |
@dataclass
|
140 |
class YOLOLayer(nn.Module):
|
|
|
135 |
use_wandb: bool
|
136 |
use_TensorBoard: bool
|
137 |
|
138 |
+
weight: Optional[str]
|
139 |
+
|
140 |
|
141 |
@dataclass
|
142 |
class YOLOLayer(nn.Module):
|
yolo/model/yolo.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
from typing import Any, Dict, List, Union
|
2 |
|
|
|
3 |
import torch.nn as nn
|
4 |
from loguru import logger
|
5 |
from omegaconf import ListConfig, OmegaConf
|
@@ -125,6 +126,9 @@ def get_model(cfg: Config) -> YOLO:
|
|
125 |
OmegaConf.set_struct(cfg.model, False)
|
126 |
model = YOLO(cfg.model, cfg.class_num)
|
127 |
logger.info("✅ Success load model")
|
|
|
|
|
|
|
128 |
log_model_structure(model.model)
|
129 |
draw_model(model=model)
|
130 |
return model
|
|
|
1 |
from typing import Any, Dict, List, Union
|
2 |
|
3 |
+
import torch
|
4 |
import torch.nn as nn
|
5 |
from loguru import logger
|
6 |
from omegaconf import ListConfig, OmegaConf
|
|
|
126 |
OmegaConf.set_struct(cfg.model, False)
|
127 |
model = YOLO(cfg.model, cfg.class_num)
|
128 |
logger.info("✅ Success load model")
|
129 |
+
if cfg.weight:
|
130 |
+
model.model.load_state_dict(torch.load(cfg.weight))
|
131 |
+
logger.info("✅ Success load model weight")
|
132 |
log_model_structure(model.model)
|
133 |
draw_model(model=model)
|
134 |
return model
|