henry000 commited on
Commit
7c6ce21
Β·
1 Parent(s): 787b81d

πŸ› [Fix] Fix the bug of using hydra

Browse files
Files changed (3) hide show
  1. config/config.py +13 -0
  2. config/config.yaml +8 -0
  3. model/yolo.py +5 -5
config/config.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List, Dict, Union
3
+
4
+
5
+ @dataclass
6
+ class Model:
7
+ anchor: List[List[int]]
8
+ model: Dict[str, List[Dict[str, Union[Dict, List, int]]]]
9
+
10
+
11
+ @dataclass
12
+ class Config:
13
+ model: Model
config/config.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ hydra:
2
+ run:
3
+ dir: ./runs
4
+
5
+ defaults:
6
+ - data_config: coco.yaml
7
+ - model: v7-base.yaml
8
+ - _self_
model/yolo.py CHANGED
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Union
4
  import torch
5
  import torch.nn as nn
6
  from loguru import logger
 
7
 
8
  from model import module
9
  from utils.tools import load_model_cfg
@@ -35,16 +36,15 @@ class YOLO(nn.Module):
35
  super(YOLO, self).__init__()
36
  self.nc = model_cfg["nc"]
37
  self.layer_map = get_layer_map() # Get the map Dict[str: Module]
38
- self.build_model(model_cfg["model"])
39
 
40
  def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
41
  model_list = nn.ModuleList()
42
  output_dim = [3]
43
  layer_indices_by_tag = {}
44
-
45
- for arch_name, arch in model_arch.items():
46
  logger.info(f"πŸ—οΈ Building model-{arch_name}")
47
- for layer_idx, layer_spec in enumerate(arch, start=1):
48
  layer_type, layer_info = next(iter(layer_spec.items()))
49
  layer_args = layer_info.get("args", {})
50
  source = layer_info.get("source", -1)
@@ -72,7 +72,7 @@ class YOLO(nn.Module):
72
  def forward(self, x):
73
  y = [x]
74
  for layer in self.model:
75
- if isinstance(layer.source, list):
76
  model_input = [y[idx] for idx in layer.source]
77
  else:
78
  model_input = y[layer.source]
 
4
  import torch
5
  import torch.nn as nn
6
  from loguru import logger
7
+ from omegaconf import OmegaConf
8
 
9
  from model import module
10
  from utils.tools import load_model_cfg
 
36
  super(YOLO, self).__init__()
37
  self.nc = model_cfg["nc"]
38
  self.layer_map = get_layer_map() # Get the map Dict[str: Module]
39
+ self.build_model(model_cfg.model)
40
 
41
  def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
42
  model_list = nn.ModuleList()
43
  output_dim = [3]
44
  layer_indices_by_tag = {}
45
+ for arch_name in model_arch:
 
46
  logger.info(f"πŸ—οΈ Building model-{arch_name}")
47
+ for layer_idx, layer_spec in enumerate(model_arch[arch_name], start=1):
48
  layer_type, layer_info = next(iter(layer_spec.items()))
49
  layer_args = layer_info.get("args", {})
50
  source = layer_info.get("source", -1)
 
72
  def forward(self, x):
73
  y = [x]
74
  for layer in self.model:
75
+ if OmegaConf.is_list(layer.source):
76
  model_input = [y[idx] for idx in layer.source]
77
  else:
78
  model_input = y[layer.source]