henry000 commited on
Commit
1c110af
Β·
1 Parent(s): 635f41a

πŸ”€ [Merge] branch 'MODELv2' into INFERENCE

Browse files
Files changed (2) hide show
  1. yolo/lazy.py +1 -1
  2. yolo/model/yolo.py +6 -4
yolo/lazy.py CHANGED
@@ -26,7 +26,7 @@ def main(cfg: Config):
26
  model = FastModelLoader(cfg).load_model()
27
  device = torch.device(cfg.device)
28
  else:
29
- model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight).to(device)
30
 
31
  vec2box = Vec2Box(model, cfg.image_size, device)
32
 
 
26
  model = FastModelLoader(cfg).load_model()
27
  device = torch.device(cfg.device)
28
  else:
29
+ model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight, device=device)
30
 
31
  vec2box = Vec2Box(model, cfg.image_size, device)
32
 
yolo/model/yolo.py CHANGED
@@ -2,9 +2,9 @@ import os
2
  from typing import Any, Dict, List, Union
3
 
4
  import torch
5
- import torch.nn as nn
6
  from loguru import logger
7
  from omegaconf import ListConfig, OmegaConf
 
8
 
9
  from yolo.config.config import Config, ModelConfig, YOLOLayer
10
  from yolo.tools.dataset_preparation import prepare_weight
@@ -117,7 +117,9 @@ class YOLO(nn.Module):
117
  raise ValueError(f"Unsupported layer type: {layer_type}")
118
 
119
 
120
- def create_model(model_cfg: ModelConfig, class_num: int = 80, weight_path: str = "weights/v9-c.pt") -> YOLO:
 
 
121
  """Constructs and returns a model from a Dictionary configuration file.
122
 
123
  Args:
@@ -133,9 +135,9 @@ def create_model(model_cfg: ModelConfig, class_num: int = 80, weight_path: str =
133
  if not os.path.exists(weight_path):
134
  logger.info(f"🌐 Weight {weight_path} not found, try downloading")
135
  prepare_weight(weight_path=weight_path)
136
- model.model.load_state_dict(torch.load(weight_path))
137
  logger.info("βœ… Success load model weight")
138
 
139
  log_model_structure(model.model)
140
  draw_model(model=model)
141
- return model
 
2
  from typing import Any, Dict, List, Union
3
 
4
  import torch
 
5
  from loguru import logger
6
  from omegaconf import ListConfig, OmegaConf
7
+ from torch import device, nn
8
 
9
  from yolo.config.config import Config, ModelConfig, YOLOLayer
10
  from yolo.tools.dataset_preparation import prepare_weight
 
117
  raise ValueError(f"Unsupported layer type: {layer_type}")
118
 
119
 
120
+ def create_model(
121
+ model_cfg: ModelConfig, class_num: int = 80, weight_path: str = "weights/v9-c.pt", device: device = device("cuda")
122
+ ) -> YOLO:
123
  """Constructs and returns a model from a Dictionary configuration file.
124
 
125
  Args:
 
135
  if not os.path.exists(weight_path):
136
  logger.info(f"🌐 Weight {weight_path} not found, try downloading")
137
  prepare_weight(weight_path=weight_path)
138
+ model.model.load_state_dict(torch.load(weight_path, map_location=device))
139
  logger.info("βœ… Success load model weight")
140
 
141
  log_model_structure(model.model)
142
  draw_model(model=model)
143
+ return model.to(device)