🔧 [Add] model name into modelcfg, and move reg_max
Browse files- yolo/config/config.py +1 -0
- yolo/config/model/v9-c.yaml +2 -4
- yolo/config/model/v9-m.yaml +2 -4
- yolo/config/model/v9-s.yaml +2 -4
- yolo/model/yolo.py +3 -1
yolo/config/config.py
CHANGED
@@ -24,6 +24,7 @@ class BlockConfig:
|
|
24 |
|
25 |
@dataclass
|
26 |
class ModelConfig:
|
|
|
27 |
anchor: AnchorConfig
|
28 |
model: Dict[str, BlockConfig]
|
29 |
|
|
|
24 |
|
25 |
@dataclass
|
26 |
class ModelConfig:
|
27 |
+
name: Optional[str]
|
28 |
anchor: AnchorConfig
|
29 |
model: Dict[str, BlockConfig]
|
30 |
|
yolo/config/model/v9-c.yaml
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
anchor:
|
2 |
reg_max: 16
|
3 |
strides: [8, 16, 32]
|
@@ -73,8 +75,6 @@ model:
|
|
73 |
- MultiheadDetection:
|
74 |
source: [P3, P4, P5]
|
75 |
tags: Main
|
76 |
-
args:
|
77 |
-
reg_max: ${model.anchor.reg_max}
|
78 |
output: True
|
79 |
|
80 |
auxiliary:
|
@@ -129,6 +129,4 @@ model:
|
|
129 |
- MultiheadDetection:
|
130 |
source: [A3, A4, A5]
|
131 |
tags: AUX
|
132 |
-
args:
|
133 |
-
reg_max: ${model.anchor.reg_max}
|
134 |
output: True
|
|
|
1 |
+
name: v9-c
|
2 |
+
|
3 |
anchor:
|
4 |
reg_max: 16
|
5 |
strides: [8, 16, 32]
|
|
|
75 |
- MultiheadDetection:
|
76 |
source: [P3, P4, P5]
|
77 |
tags: Main
|
|
|
|
|
78 |
output: True
|
79 |
|
80 |
auxiliary:
|
|
|
129 |
- MultiheadDetection:
|
130 |
source: [A3, A4, A5]
|
131 |
tags: AUX
|
|
|
|
|
132 |
output: True
|
yolo/config/model/v9-m.yaml
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
anchor:
|
2 |
reg_max: 16
|
3 |
|
@@ -72,8 +74,6 @@ model:
|
|
72 |
- MultiheadDetection:
|
73 |
source: [P3, P4, P5]
|
74 |
tags: Main
|
75 |
-
args:
|
76 |
-
reg_max: ${model.anchor.reg_max}
|
77 |
output: True
|
78 |
|
79 |
auxiliary:
|
@@ -128,6 +128,4 @@ model:
|
|
128 |
- MultiheadDetection:
|
129 |
source: [A3, A4, A5]
|
130 |
tags: AUX
|
131 |
-
args:
|
132 |
-
reg_max: ${model.anchor.reg_max}
|
133 |
output: True
|
|
|
1 |
+
name: v9-m
|
2 |
+
|
3 |
anchor:
|
4 |
reg_max: 16
|
5 |
|
|
|
74 |
- MultiheadDetection:
|
75 |
source: [P3, P4, P5]
|
76 |
tags: Main
|
|
|
|
|
77 |
output: True
|
78 |
|
79 |
auxiliary:
|
|
|
128 |
- MultiheadDetection:
|
129 |
source: [A3, A4, A5]
|
130 |
tags: AUX
|
|
|
|
|
131 |
output: True
|
yolo/config/model/v9-s.yaml
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
anchor:
|
2 |
reg_max: 16
|
3 |
|
@@ -92,8 +94,6 @@ model:
|
|
92 |
- MultiheadDetection:
|
93 |
source: [P3, P4, P5]
|
94 |
tags: Main
|
95 |
-
args:
|
96 |
-
reg_max: ${model.anchor.reg_max}
|
97 |
output: True
|
98 |
|
99 |
auxiliary:
|
@@ -129,6 +129,4 @@ model:
|
|
129 |
- MultiheadDetection:
|
130 |
source: [A3, A4, A5]
|
131 |
tags: AUX
|
132 |
-
args:
|
133 |
-
reg_max: ${model.anchor.reg_max}
|
134 |
output: True
|
|
|
1 |
+
name: v9-s
|
2 |
+
|
3 |
anchor:
|
4 |
reg_max: 16
|
5 |
|
|
|
94 |
- MultiheadDetection:
|
95 |
source: [P3, P4, P5]
|
96 |
tags: Main
|
|
|
|
|
97 |
output: True
|
98 |
|
99 |
auxiliary:
|
|
|
129 |
- MultiheadDetection:
|
130 |
source: [A3, A4, A5]
|
131 |
tags: AUX
|
|
|
|
|
132 |
output: True
|
yolo/model/yolo.py
CHANGED
@@ -25,8 +25,9 @@ class YOLO(nn.Module):
|
|
25 |
self.num_classes = class_num
|
26 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
27 |
self.model: List[YOLOLayer] = nn.ModuleList()
|
28 |
-
self.
|
29 |
self.strides = getattr(model_cfg.anchor, "strides", None)
|
|
|
30 |
|
31 |
def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
|
32 |
self.layer_index = {}
|
@@ -48,6 +49,7 @@ class YOLO(nn.Module):
|
|
48 |
if "Detection" in layer_type:
|
49 |
layer_args["in_channels"] = [output_dim[idx] for idx in source]
|
50 |
layer_args["num_classes"] = self.num_classes
|
|
|
51 |
|
52 |
# create layers
|
53 |
layer = self.create_layer(layer_type, source, layer_info, **layer_args)
|
|
|
25 |
self.num_classes = class_num
|
26 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
27 |
self.model: List[YOLOLayer] = nn.ModuleList()
|
28 |
+
self.reg_max = getattr(model_cfg.anchor, "reg_max", 16)
|
29 |
self.strides = getattr(model_cfg.anchor, "strides", None)
|
30 |
+
self.build_model(model_cfg.model)
|
31 |
|
32 |
def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
|
33 |
self.layer_index = {}
|
|
|
49 |
if "Detection" in layer_type:
|
50 |
layer_args["in_channels"] = [output_dim[idx] for idx in source]
|
51 |
layer_args["num_classes"] = self.num_classes
|
52 |
+
layer_args["reg_max"] = self.reg_max
|
53 |
|
54 |
# create layers
|
55 |
layer = self.create_layer(layer_type, source, layer_info, **layer_args)
|