⚡️ [Update] usable tags, reduce training ram cost
Browse files- yolo/config/config.py +1 -0
- yolo/model/yolo.py +6 -4
yolo/config/config.py
CHANGED
@@ -108,6 +108,7 @@ class YOLOLayer(nn.Module):
|
|
108 |
output: bool
|
109 |
tags: str
|
110 |
layer_type: str
|
|
|
111 |
|
112 |
def __post_init__(self):
|
113 |
super().__init__()
|
|
|
108 |
output: bool
|
109 |
tags: str
|
110 |
layer_type: str
|
111 |
+
usable: bool
|
112 |
|
113 |
def __post_init__(self):
|
114 |
super().__init__()
|
yolo/model/yolo.py
CHANGED
@@ -69,7 +69,8 @@ class YOLO(nn.Module):
|
|
69 |
else:
|
70 |
model_input = y[layer.source]
|
71 |
x = layer(model_input)
|
72 |
-
|
|
|
73 |
y[index] = x
|
74 |
if layer.output:
|
75 |
output.append(x)
|
@@ -92,10 +93,10 @@ class YOLO(nn.Module):
|
|
92 |
return [self.get_source_idx(index, layer_idx) for index in source]
|
93 |
if isinstance(source, str):
|
94 |
source = self.layer_index[source]
|
95 |
-
if source <
|
96 |
source += layer_idx
|
97 |
-
if source > 0:
|
98 |
-
|
99 |
return source
|
100 |
|
101 |
def create_layer(self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs) -> YOLOLayer:
|
@@ -106,6 +107,7 @@ class YOLO(nn.Module):
|
|
106 |
setattr(layer, "in_c", kwargs.get("in_channels", None))
|
107 |
setattr(layer, "output", layer_info.get("output", False))
|
108 |
setattr(layer, "tags", layer_info.get("tags", None))
|
|
|
109 |
return layer
|
110 |
else:
|
111 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|
|
|
69 |
else:
|
70 |
model_input = y[layer.source]
|
71 |
x = layer(model_input)
|
72 |
+
y[-1] = x
|
73 |
+
if layer.usable:
|
74 |
y[index] = x
|
75 |
if layer.output:
|
76 |
output.append(x)
|
|
|
93 |
return [self.get_source_idx(index, layer_idx) for index in source]
|
94 |
if isinstance(source, str):
|
95 |
source = self.layer_index[source]
|
96 |
+
if source < -1:
|
97 |
source += layer_idx
|
98 |
+
if source > 0: # Using Previous Layer's Output
|
99 |
+
self.model[source - 1].usable = True
|
100 |
return source
|
101 |
|
102 |
def create_layer(self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs) -> YOLOLayer:
|
|
|
107 |
setattr(layer, "in_c", kwargs.get("in_channels", None))
|
108 |
setattr(layer, "output", layer_info.get("output", False))
|
109 |
setattr(layer, "tags", layer_info.get("tags", None))
|
110 |
+
setattr(layer, "usable", 0)
|
111 |
return layer
|
112 |
else:
|
113 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|