✨ [Add] indicate output layer to model
Browse files- config/model/v7-base.yaml +1 -0
- model/yolo.py +8 -3
config/model/v7-base.yaml
CHANGED
@@ -241,3 +241,4 @@ model:
|
|
241 |
- [36,75, 76,55, 72,146] # P4/16
|
242 |
- [142,110, 192,243, 459,401] # P5/32
|
243 |
source: [102, 103, 104]
|
|
|
|
241 |
- [36,75, 76,55, 72,146] # P4/16
|
242 |
- [142,110, 192,243, 459,401] # P5/32
|
243 |
source: [102, 103, 104]
|
244 |
+
output: True
|
model/yolo.py
CHANGED
@@ -48,6 +48,7 @@ class YOLO(nn.Module):
|
|
48 |
layer_type, layer_info = next(iter(layer_spec.items()))
|
49 |
layer_args = layer_info.get("args", {})
|
50 |
source = layer_info.get("source", -1)
|
|
|
51 |
|
52 |
if isinstance(source, str):
|
53 |
source = layer_indices_by_tag[source]
|
@@ -57,7 +58,7 @@ class YOLO(nn.Module):
|
|
57 |
layer_args["nc"] = self.nc
|
58 |
layer_args["ch"] = [output_dim[idx] for idx in source]
|
59 |
|
60 |
-
layer = self.create_layer(layer_type, source, **layer_args)
|
61 |
model_list.append(layer)
|
62 |
|
63 |
if "tags" in layer_info:
|
@@ -71,6 +72,7 @@ class YOLO(nn.Module):
|
|
71 |
|
72 |
def forward(self, x):
|
73 |
y = [x]
|
|
|
74 |
for layer in self.model:
|
75 |
if isinstance(layer.source, list):
|
76 |
model_input = [y[idx] for idx in layer.source]
|
@@ -78,7 +80,9 @@ class YOLO(nn.Module):
|
|
78 |
model_input = y[layer.source]
|
79 |
x = layer(model_input)
|
80 |
y.append(x)
|
81 |
-
|
|
|
|
|
82 |
|
83 |
def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
|
84 |
if "Conv" in layer_type:
|
@@ -90,10 +94,11 @@ class YOLO(nn.Module):
|
|
90 |
if layer_type == "IDetect":
|
91 |
return None
|
92 |
|
93 |
-
def create_layer(self, layer_type: str, source: Union[int, list], **kwargs):
|
94 |
if layer_type in self.layer_map:
|
95 |
layer = self.layer_map[layer_type](**kwargs)
|
96 |
layer.source = source
|
|
|
97 |
return layer
|
98 |
else:
|
99 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|
|
|
48 |
layer_type, layer_info = next(iter(layer_spec.items()))
|
49 |
layer_args = layer_info.get("args", {})
|
50 |
source = layer_info.get("source", -1)
|
51 |
+
output = layer_info.get("output", False)
|
52 |
|
53 |
if isinstance(source, str):
|
54 |
source = layer_indices_by_tag[source]
|
|
|
58 |
layer_args["nc"] = self.nc
|
59 |
layer_args["ch"] = [output_dim[idx] for idx in source]
|
60 |
|
61 |
+
layer = self.create_layer(layer_type, source, output, **layer_args)
|
62 |
model_list.append(layer)
|
63 |
|
64 |
if "tags" in layer_info:
|
|
|
72 |
|
73 |
def forward(self, x):
|
74 |
y = [x]
|
75 |
+
output = []
|
76 |
for layer in self.model:
|
77 |
if isinstance(layer.source, list):
|
78 |
model_input = [y[idx] for idx in layer.source]
|
|
|
80 |
model_input = y[layer.source]
|
81 |
x = layer(model_input)
|
82 |
y.append(x)
|
83 |
+
if layer.output:
|
84 |
+
output.append(x)
|
85 |
+
return output
|
86 |
|
87 |
def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
|
88 |
if "Conv" in layer_type:
|
|
|
94 |
if layer_type == "IDetect":
|
95 |
return None
|
96 |
|
97 |
+
def create_layer(self, layer_type: str, source: Union[int, list], output=False, **kwargs):
|
98 |
if layer_type in self.layer_map:
|
99 |
layer = self.layer_map[layer_type](**kwargs)
|
100 |
layer.source = source
|
101 |
+
layer.output = output
|
102 |
return layer
|
103 |
else:
|
104 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|