π [Fix] Fix the bug of using hydra
Browse files- config/config.py +13 -0
- config/config.yaml +8 -0
- model/yolo.py +5 -5
config/config.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import List, Dict, Union
|
3 |
+
|
4 |
+
|
5 |
+
@dataclass
|
6 |
+
class Model:
|
7 |
+
anchor: List[List[int]]
|
8 |
+
model: Dict[str, List[Dict[str, Union[Dict, List, int]]]]
|
9 |
+
|
10 |
+
|
11 |
+
@dataclass
|
12 |
+
class Config:
|
13 |
+
model: Model
|
config/config.yaml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
hydra:
|
2 |
+
run:
|
3 |
+
dir: ./runs
|
4 |
+
|
5 |
+
defaults:
|
6 |
+
- data_config: coco.yaml
|
7 |
+
- model: v7-base.yaml
|
8 |
+
- _self_
|
model/yolo.py
CHANGED
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Union
|
|
4 |
import torch
|
5 |
import torch.nn as nn
|
6 |
from loguru import logger
|
|
|
7 |
|
8 |
from model import module
|
9 |
from utils.tools import load_model_cfg
|
@@ -35,16 +36,15 @@ class YOLO(nn.Module):
|
|
35 |
super(YOLO, self).__init__()
|
36 |
self.nc = model_cfg["nc"]
|
37 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
38 |
-
self.build_model(model_cfg
|
39 |
|
40 |
def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
|
41 |
model_list = nn.ModuleList()
|
42 |
output_dim = [3]
|
43 |
layer_indices_by_tag = {}
|
44 |
-
|
45 |
-
for arch_name, arch in model_arch.items():
|
46 |
logger.info(f"ποΈ Building model-{arch_name}")
|
47 |
-
for layer_idx, layer_spec in enumerate(
|
48 |
layer_type, layer_info = next(iter(layer_spec.items()))
|
49 |
layer_args = layer_info.get("args", {})
|
50 |
source = layer_info.get("source", -1)
|
@@ -72,7 +72,7 @@ class YOLO(nn.Module):
|
|
72 |
def forward(self, x):
|
73 |
y = [x]
|
74 |
for layer in self.model:
|
75 |
-
if
|
76 |
model_input = [y[idx] for idx in layer.source]
|
77 |
else:
|
78 |
model_input = y[layer.source]
|
|
|
4 |
import torch
|
5 |
import torch.nn as nn
|
6 |
from loguru import logger
|
7 |
+
from omegaconf import OmegaConf
|
8 |
|
9 |
from model import module
|
10 |
from utils.tools import load_model_cfg
|
|
|
36 |
super(YOLO, self).__init__()
|
37 |
self.nc = model_cfg["nc"]
|
38 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
39 |
+
self.build_model(model_cfg.model)
|
40 |
|
41 |
def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
|
42 |
model_list = nn.ModuleList()
|
43 |
output_dim = [3]
|
44 |
layer_indices_by_tag = {}
|
45 |
+
for arch_name in model_arch:
|
|
|
46 |
logger.info(f"ποΈ Building model-{arch_name}")
|
47 |
+
for layer_idx, layer_spec in enumerate(model_arch[arch_name], start=1):
|
48 |
layer_type, layer_info = next(iter(layer_spec.items()))
|
49 |
layer_args = layer_info.get("args", {})
|
50 |
source = layer_info.get("source", -1)
|
|
|
72 |
def forward(self, x):
|
73 |
y = [x]
|
74 |
for layer in self.model:
|
75 |
+
if OmegaConf.is_list(layer.source):
|
76 |
model_input = [y[idx] for idx in layer.source]
|
77 |
else:
|
78 |
model_input = y[layer.source]
|