henry000 commited on
Commit
745aab9
Β·
1 Parent(s): ba3c274

🎨 [Update] logging model place

Browse files
Files changed (2) hide show
  1. examples/example_train.py +0 -2
  2. yolo/model/yolo.py +3 -2
examples/example_train.py CHANGED
@@ -13,7 +13,6 @@ from yolo.model.yolo import get_model
13
  from yolo.tools.log_helper import custom_logger
14
  from yolo.tools.trainer import Trainer
15
  from yolo.utils.dataloader import get_dataloader
16
- from yolo.utils.drawer import draw_model
17
  from yolo.utils.get_dataset import prepare_dataset
18
 
19
 
@@ -24,7 +23,6 @@ def main(cfg: Config):
24
 
25
  dataloader = get_dataloader(cfg)
26
  model = get_model(cfg)
27
- draw_model(model=model)
28
  # TODO: get_device or rank, for DDP mode
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
 
 
13
  from yolo.tools.log_helper import custom_logger
14
  from yolo.tools.trainer import Trainer
15
  from yolo.utils.dataloader import get_dataloader
 
16
  from yolo.utils.get_dataset import prepare_dataset
17
 
18
 
 
23
 
24
  dataloader = get_dataloader(cfg)
25
  model = get_model(cfg)
 
26
  # TODO: get_device or rank, for DDP mode
27
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
 
yolo/model/yolo.py CHANGED
@@ -7,6 +7,7 @@ from omegaconf import ListConfig, OmegaConf
7
  from yolo.config.config import Config, Model, YOLOLayer
8
  from yolo.tools.layer_helper import get_layer_map
9
  from yolo.tools.log_helper import log_model
 
10
 
11
 
12
  class YOLO(nn.Module):
@@ -24,8 +25,6 @@ class YOLO(nn.Module):
24
  self.layer_map = get_layer_map() # Get the map Dict[str: Module]
25
  self.model: List[YOLOLayer] = nn.ModuleList()
26
  self.build_model(model_cfg.model)
27
- # TODO: Move to other position
28
- log_model(self.model)
29
 
30
  def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
31
  self.layer_index = {}
@@ -126,4 +125,6 @@ def get_model(cfg: Config) -> YOLO:
126
  OmegaConf.set_struct(cfg.model, False)
127
  model = YOLO(cfg.model, cfg.hyper.data.class_num)
128
  logger.info("βœ… Success load model")
 
 
129
  return model
 
7
  from yolo.config.config import Config, Model, YOLOLayer
8
  from yolo.tools.layer_helper import get_layer_map
9
  from yolo.tools.log_helper import log_model
10
+ from yolo.utils.drawer import draw_model
11
 
12
 
13
  class YOLO(nn.Module):
 
25
  self.layer_map = get_layer_map() # Get the map Dict[str: Module]
26
  self.model: List[YOLOLayer] = nn.ModuleList()
27
  self.build_model(model_cfg.model)
 
 
28
 
29
  def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
30
  self.layer_index = {}
 
125
  OmegaConf.set_struct(cfg.model, False)
126
  model = YOLO(cfg.model, cfg.hyper.data.class_num)
127
  logger.info("βœ… Success load model")
128
+ log_model(model.model)
129
+ draw_model(model=model)
130
  return model