henry000 commited on
Commit
3186e72
·
1 Parent(s): f211bec

✨ [Add] indicate output layer to model

Browse files
Files changed (2) hide show
  1. config/model/v7-base.yaml +1 -0
  2. model/yolo.py +8 -3
config/model/v7-base.yaml CHANGED
@@ -241,3 +241,4 @@ model:
241
  - [36,75, 76,55, 72,146] # P4/16
242
  - [142,110, 192,243, 459,401] # P5/32
243
  source: [102, 103, 104]
 
 
241
  - [36,75, 76,55, 72,146] # P4/16
242
  - [142,110, 192,243, 459,401] # P5/32
243
  source: [102, 103, 104]
244
+ output: True
model/yolo.py CHANGED
@@ -48,6 +48,7 @@ class YOLO(nn.Module):
48
  layer_type, layer_info = next(iter(layer_spec.items()))
49
  layer_args = layer_info.get("args", {})
50
  source = layer_info.get("source", -1)
 
51
 
52
  if isinstance(source, str):
53
  source = layer_indices_by_tag[source]
@@ -57,7 +58,7 @@ class YOLO(nn.Module):
57
  layer_args["nc"] = self.nc
58
  layer_args["ch"] = [output_dim[idx] for idx in source]
59
 
60
- layer = self.create_layer(layer_type, source, **layer_args)
61
  model_list.append(layer)
62
 
63
  if "tags" in layer_info:
@@ -71,6 +72,7 @@ class YOLO(nn.Module):
71
 
72
  def forward(self, x):
73
  y = [x]
 
74
  for layer in self.model:
75
  if isinstance(layer.source, list):
76
  model_input = [y[idx] for idx in layer.source]
@@ -78,7 +80,9 @@ class YOLO(nn.Module):
78
  model_input = y[layer.source]
79
  x = layer(model_input)
80
  y.append(x)
81
- return x
 
 
82
 
83
  def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
84
  if "Conv" in layer_type:
@@ -90,10 +94,11 @@ class YOLO(nn.Module):
90
  if layer_type == "IDetect":
91
  return None
92
 
93
- def create_layer(self, layer_type: str, source: Union[int, list], **kwargs):
94
  if layer_type in self.layer_map:
95
  layer = self.layer_map[layer_type](**kwargs)
96
  layer.source = source
 
97
  return layer
98
  else:
99
  raise ValueError(f"Unsupported layer type: {layer_type}")
 
48
  layer_type, layer_info = next(iter(layer_spec.items()))
49
  layer_args = layer_info.get("args", {})
50
  source = layer_info.get("source", -1)
51
+ output = layer_info.get("output", False)
52
 
53
  if isinstance(source, str):
54
  source = layer_indices_by_tag[source]
 
58
  layer_args["nc"] = self.nc
59
  layer_args["ch"] = [output_dim[idx] for idx in source]
60
 
61
+ layer = self.create_layer(layer_type, source, output, **layer_args)
62
  model_list.append(layer)
63
 
64
  if "tags" in layer_info:
 
72
 
73
  def forward(self, x):
74
  y = [x]
75
+ output = []
76
  for layer in self.model:
77
  if isinstance(layer.source, list):
78
  model_input = [y[idx] for idx in layer.source]
 
80
  model_input = y[layer.source]
81
  x = layer(model_input)
82
  y.append(x)
83
+ if layer.output:
84
+ output.append(x)
85
+ return output
86
 
87
  def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
88
  if "Conv" in layer_type:
 
94
  if layer_type == "IDetect":
95
  return None
96
 
97
+ def create_layer(self, layer_type: str, source: Union[int, list], output=False, **kwargs):
98
  if layer_type in self.layer_map:
99
  layer = self.layer_map[layer_type](**kwargs)
100
  layer.source = source
101
+ layer.output = output
102
  return layer
103
  else:
104
  raise ValueError(f"Unsupported layer type: {layer_type}")