henry000 commited on
Commit
f99f89b
·
1 Parent(s): 2d52a7f

🔧 [Add] model name into modelcfg, and move reg_max

Browse files
yolo/config/config.py CHANGED
@@ -24,6 +24,7 @@ class BlockConfig:
24
 
25
  @dataclass
26
  class ModelConfig:
 
27
  anchor: AnchorConfig
28
  model: Dict[str, BlockConfig]
29
 
 
24
 
25
  @dataclass
26
  class ModelConfig:
27
+ name: Optional[str]
28
  anchor: AnchorConfig
29
  model: Dict[str, BlockConfig]
30
 
yolo/config/model/v9-c.yaml CHANGED
@@ -1,3 +1,5 @@
 
 
1
  anchor:
2
  reg_max: 16
3
  strides: [8, 16, 32]
@@ -73,8 +75,6 @@ model:
73
  - MultiheadDetection:
74
  source: [P3, P4, P5]
75
  tags: Main
76
- args:
77
- reg_max: ${model.anchor.reg_max}
78
  output: True
79
 
80
  auxiliary:
@@ -129,6 +129,4 @@ model:
129
  - MultiheadDetection:
130
  source: [A3, A4, A5]
131
  tags: AUX
132
- args:
133
- reg_max: ${model.anchor.reg_max}
134
  output: True
 
1
+ name: v9-c
2
+
3
  anchor:
4
  reg_max: 16
5
  strides: [8, 16, 32]
 
75
  - MultiheadDetection:
76
  source: [P3, P4, P5]
77
  tags: Main
 
 
78
  output: True
79
 
80
  auxiliary:
 
129
  - MultiheadDetection:
130
  source: [A3, A4, A5]
131
  tags: AUX
 
 
132
  output: True
yolo/config/model/v9-m.yaml CHANGED
@@ -1,3 +1,5 @@
 
 
1
  anchor:
2
  reg_max: 16
3
 
@@ -72,8 +74,6 @@ model:
72
  - MultiheadDetection:
73
  source: [P3, P4, P5]
74
  tags: Main
75
- args:
76
- reg_max: ${model.anchor.reg_max}
77
  output: True
78
 
79
  auxiliary:
@@ -128,6 +128,4 @@ model:
128
  - MultiheadDetection:
129
  source: [A3, A4, A5]
130
  tags: AUX
131
- args:
132
- reg_max: ${model.anchor.reg_max}
133
  output: True
 
1
+ name: v9-m
2
+
3
  anchor:
4
  reg_max: 16
5
 
 
74
  - MultiheadDetection:
75
  source: [P3, P4, P5]
76
  tags: Main
 
 
77
  output: True
78
 
79
  auxiliary:
 
128
  - MultiheadDetection:
129
  source: [A3, A4, A5]
130
  tags: AUX
 
 
131
  output: True
yolo/config/model/v9-s.yaml CHANGED
@@ -1,3 +1,5 @@
 
 
1
  anchor:
2
  reg_max: 16
3
 
@@ -92,8 +94,6 @@ model:
92
  - MultiheadDetection:
93
  source: [P3, P4, P5]
94
  tags: Main
95
- args:
96
- reg_max: ${model.anchor.reg_max}
97
  output: True
98
 
99
  auxiliary:
@@ -129,6 +129,4 @@ model:
129
  - MultiheadDetection:
130
  source: [A3, A4, A5]
131
  tags: AUX
132
- args:
133
- reg_max: ${model.anchor.reg_max}
134
  output: True
 
1
+ name: v9-s
2
+
3
  anchor:
4
  reg_max: 16
5
 
 
94
  - MultiheadDetection:
95
  source: [P3, P4, P5]
96
  tags: Main
 
 
97
  output: True
98
 
99
  auxiliary:
 
129
  - MultiheadDetection:
130
  source: [A3, A4, A5]
131
  tags: AUX
 
 
132
  output: True
yolo/model/yolo.py CHANGED
@@ -25,8 +25,9 @@ class YOLO(nn.Module):
25
  self.num_classes = class_num
26
  self.layer_map = get_layer_map() # Get the map Dict[str: Module]
27
  self.model: List[YOLOLayer] = nn.ModuleList()
28
- self.build_model(model_cfg.model)
29
  self.strides = getattr(model_cfg.anchor, "strides", None)
 
30
 
31
  def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
32
  self.layer_index = {}
@@ -48,6 +49,7 @@ class YOLO(nn.Module):
48
  if "Detection" in layer_type:
49
  layer_args["in_channels"] = [output_dim[idx] for idx in source]
50
  layer_args["num_classes"] = self.num_classes
 
51
 
52
  # create layers
53
  layer = self.create_layer(layer_type, source, layer_info, **layer_args)
 
25
  self.num_classes = class_num
26
  self.layer_map = get_layer_map() # Get the map Dict[str: Module]
27
  self.model: List[YOLOLayer] = nn.ModuleList()
28
+ self.reg_max = getattr(model_cfg.anchor, "reg_max", 16)
29
  self.strides = getattr(model_cfg.anchor, "strides", None)
30
+ self.build_model(model_cfg.model)
31
 
32
  def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
33
  self.layer_index = {}
 
49
  if "Detection" in layer_type:
50
  layer_args["in_channels"] = [output_dim[idx] for idx in source]
51
  layer_args["num_classes"] = self.num_classes
52
+ layer_args["reg_max"] = self.reg_max
53
 
54
  # create layers
55
  layer = self.create_layer(layer_type, source, layer_info, **layer_args)