henry000 commited on
Commit
5727efb
·
1 Parent(s): da24bd9

⚡️ [Update] usable tags, reduce training ram cost

Browse files
Files changed (2) hide show
  1. yolo/config/config.py +1 -0
  2. yolo/model/yolo.py +6 -4
yolo/config/config.py CHANGED
@@ -108,6 +108,7 @@ class YOLOLayer(nn.Module):
108
  output: bool
109
  tags: str
110
  layer_type: str
 
111
 
112
  def __post_init__(self):
113
  super().__init__()
 
108
  output: bool
109
  tags: str
110
  layer_type: str
111
+ usable: bool
112
 
113
  def __post_init__(self):
114
  super().__init__()
yolo/model/yolo.py CHANGED
@@ -69,7 +69,8 @@ class YOLO(nn.Module):
69
  else:
70
  model_input = y[layer.source]
71
  x = layer(model_input)
72
- if hasattr(layer, "save"):
 
73
  y[index] = x
74
  if layer.output:
75
  output.append(x)
@@ -92,10 +93,10 @@ class YOLO(nn.Module):
92
  return [self.get_source_idx(index, layer_idx) for index in source]
93
  if isinstance(source, str):
94
  source = self.layer_index[source]
95
- if source < 0:
96
  source += layer_idx
97
- if source > 0:
98
- setattr(self.model[source - 1], "save", True)
99
  return source
100
 
101
  def create_layer(self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs) -> YOLOLayer:
@@ -106,6 +107,7 @@ class YOLO(nn.Module):
106
  setattr(layer, "in_c", kwargs.get("in_channels", None))
107
  setattr(layer, "output", layer_info.get("output", False))
108
  setattr(layer, "tags", layer_info.get("tags", None))
 
109
  return layer
110
  else:
111
  raise ValueError(f"Unsupported layer type: {layer_type}")
 
69
  else:
70
  model_input = y[layer.source]
71
  x = layer(model_input)
72
+ y[-1] = x
73
+ if layer.usable:
74
  y[index] = x
75
  if layer.output:
76
  output.append(x)
 
93
  return [self.get_source_idx(index, layer_idx) for index in source]
94
  if isinstance(source, str):
95
  source = self.layer_index[source]
96
+ if source < -1:
97
  source += layer_idx
98
+ if source > 0: # Using Previous Layer's Output
99
+ self.model[source - 1].usable = True
100
  return source
101
 
102
  def create_layer(self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs) -> YOLOLayer:
 
107
  setattr(layer, "in_c", kwargs.get("in_channels", None))
108
  setattr(layer, "output", layer_info.get("output", False))
109
  setattr(layer, "tags", layer_info.get("tags", None))
110
+ setattr(layer, "usable", 0)
111
  return layer
112
  else:
113
  raise ValueError(f"Unsupported layer type: {layer_type}")