π [Merge] branch 'MODELv2' into INFERENCE
Browse files- yolo/lazy.py +1 -1
- yolo/model/yolo.py +6 -4
yolo/lazy.py
CHANGED
@@ -26,7 +26,7 @@ def main(cfg: Config):
|
|
26 |
model = FastModelLoader(cfg).load_model()
|
27 |
device = torch.device(cfg.device)
|
28 |
else:
|
29 |
-
model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight
|
30 |
|
31 |
vec2box = Vec2Box(model, cfg.image_size, device)
|
32 |
|
|
|
26 |
model = FastModelLoader(cfg).load_model()
|
27 |
device = torch.device(cfg.device)
|
28 |
else:
|
29 |
+
model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight, device=device)
|
30 |
|
31 |
vec2box = Vec2Box(model, cfg.image_size, device)
|
32 |
|
yolo/model/yolo.py
CHANGED
@@ -2,9 +2,9 @@ import os
|
|
2 |
from typing import Any, Dict, List, Union
|
3 |
|
4 |
import torch
|
5 |
-
import torch.nn as nn
|
6 |
from loguru import logger
|
7 |
from omegaconf import ListConfig, OmegaConf
|
|
|
8 |
|
9 |
from yolo.config.config import Config, ModelConfig, YOLOLayer
|
10 |
from yolo.tools.dataset_preparation import prepare_weight
|
@@ -117,7 +117,9 @@ class YOLO(nn.Module):
|
|
117 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|
118 |
|
119 |
|
120 |
-
def create_model(
|
|
|
|
|
121 |
"""Constructs and returns a model from a Dictionary configuration file.
|
122 |
|
123 |
Args:
|
@@ -133,9 +135,9 @@ def create_model(model_cfg: ModelConfig, class_num: int = 80, weight_path: str =
|
|
133 |
if not os.path.exists(weight_path):
|
134 |
logger.info(f"π Weight {weight_path} not found, try downloading")
|
135 |
prepare_weight(weight_path=weight_path)
|
136 |
-
model.model.load_state_dict(torch.load(weight_path))
|
137 |
logger.info("β
Success load model weight")
|
138 |
|
139 |
log_model_structure(model.model)
|
140 |
draw_model(model=model)
|
141 |
-
return model
|
|
|
2 |
from typing import Any, Dict, List, Union
|
3 |
|
4 |
import torch
|
|
|
5 |
from loguru import logger
|
6 |
from omegaconf import ListConfig, OmegaConf
|
7 |
+
from torch import device, nn
|
8 |
|
9 |
from yolo.config.config import Config, ModelConfig, YOLOLayer
|
10 |
from yolo.tools.dataset_preparation import prepare_weight
|
|
|
117 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|
118 |
|
119 |
|
120 |
+
def create_model(
|
121 |
+
model_cfg: ModelConfig, class_num: int = 80, weight_path: str = "weights/v9-c.pt", device: device = device("cuda")
|
122 |
+
) -> YOLO:
|
123 |
"""Constructs and returns a model from a Dictionary configuration file.
|
124 |
|
125 |
Args:
|
|
|
135 |
if not os.path.exists(weight_path):
|
136 |
logger.info(f"π Weight {weight_path} not found, try downloading")
|
137 |
prepare_weight(weight_path=weight_path)
|
138 |
+
model.model.load_state_dict(torch.load(weight_path, map_location=device))
|
139 |
logger.info("β
Success load model weight")
|
140 |
|
141 |
log_model_structure(model.model)
|
142 |
draw_model(model=model)
|
143 |
+
return model.to(device)
|