henry000 commited on
Commit
92614f8
·
1 Parent(s): 856cce6

♻️ [Refactor] the finding source input code in yolo

Browse files
Files changed (1) hide show
  1. yolo/model/yolo.py +19 -16
yolo/model/yolo.py CHANGED
@@ -25,10 +25,8 @@ class YOLO(nn.Module):
25
  self.build_model(model_cfg.model)
26
 
27
  def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
28
- model_list = nn.ModuleList()
29
- output_dim = [3]
30
- layer_indices_by_tag = {}
31
- layer_idx = 1
32
  logger.info(f"🚜 Building YOLO")
33
  for arch_name in model_arch:
34
  logger.info(f" 🏗️ Building {arch_name}")
@@ -37,11 +35,7 @@ class YOLO(nn.Module):
37
  layer_args = layer_info.get("args", {})
38
 
39
  # Get input source
40
- source = layer_info.get("source", -1)
41
- if isinstance(source, str):
42
- source = layer_indices_by_tag[source]
43
- elif isinstance(source, ListConfig):
44
- source = [layer_indices_by_tag[idx] if isinstance(idx, str) else idx for idx in source]
45
 
46
  # Find in channels
47
  if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "CBLinear"]):
@@ -52,19 +46,17 @@ class YOLO(nn.Module):
52
 
53
  # create layers
54
  layer = self.create_layer(layer_type, source, layer_info, **layer_args)
55
- model_list.append(layer)
56
 
57
- if "tags" in layer_info:
58
- if layer_info["tags"] in layer_indices_by_tag:
59
  raise ValueError(f"Duplicate tag '{layer_info['tags']}' found.")
60
- layer_indices_by_tag[layer_info["tags"]] = layer_idx
61
 
62
  out_channels = self.get_out_channels(layer_type, layer_args, output_dim, source)
63
  output_dim.append(out_channels)
64
  layer_idx += 1
65
 
66
- self.model = model_list
67
-
68
  def forward(self, x):
69
  y = [x]
70
  output = []
@@ -91,7 +83,18 @@ class YOLO(nn.Module):
91
  if layer_type == "IDetect":
92
  return None
93
 
94
- def create_layer(self, layer_type: str, source: Union[int, list], layer_info, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
95
  if layer_type in self.layer_map:
96
  layer = self.layer_map[layer_type](**kwargs)
97
  setattr(layer, "source", source)
 
25
  self.build_model(model_cfg.model)
26
 
27
  def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
28
+ self.layer_index = {}
29
+ output_dim, layer_idx = [3], 1
 
 
30
  logger.info(f"🚜 Building YOLO")
31
  for arch_name in model_arch:
32
  logger.info(f" 🏗️ Building {arch_name}")
 
35
  layer_args = layer_info.get("args", {})
36
 
37
  # Get input source
38
+ source = self.get_source_idx(layer_info.get("source", -1), layer_idx)
 
 
 
 
39
 
40
  # Find in channels
41
  if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "CBLinear"]):
 
46
 
47
  # create layers
48
  layer = self.create_layer(layer_type, source, layer_info, **layer_args)
49
+ self.model.append(layer)
50
 
51
+ if layer.tags:
52
+ if layer.tags in self.layer_index:
53
  raise ValueError(f"Duplicate tag '{layer_info['tags']}' found.")
54
+ self.layer_index[layer.tags] = layer_idx
55
 
56
  out_channels = self.get_out_channels(layer_type, layer_args, output_dim, source)
57
  output_dim.append(out_channels)
58
  layer_idx += 1
59
 
 
 
60
  def forward(self, x):
61
  y = [x]
62
  output = []
 
83
  if layer_type == "IDetect":
84
  return None
85
 
86
+ def get_source_idx(self, source: Union[ListConfig, str, int], layer_idx: int):
87
+ if isinstance(source, ListConfig):
88
+ return [self.get_source_idx(index, layer_idx) for index in source]
89
+ if isinstance(source, str):
90
+ source = self.layer_index[source]
91
+ if source < 0:
92
+ source += layer_idx
93
+ if source > 0:
94
+ setattr(self.model[source - 1], "save", True)
95
+ return source
96
+
97
+ def create_layer(self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs) -> YOLOLayer:
98
  if layer_type in self.layer_map:
99
  layer = self.layer_map[layer_type](**kwargs)
100
  setattr(layer, "source", source)