♻️ [Refactor] the finding source input code in yolo
Browse files- yolo/model/yolo.py +19 -16
yolo/model/yolo.py
CHANGED
@@ -25,10 +25,8 @@ class YOLO(nn.Module):
|
|
25 |
self.build_model(model_cfg.model)
|
26 |
|
27 |
def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
|
28 |
-
|
29 |
-
output_dim = [3]
|
30 |
-
layer_indices_by_tag = {}
|
31 |
-
layer_idx = 1
|
32 |
logger.info(f"🚜 Building YOLO")
|
33 |
for arch_name in model_arch:
|
34 |
logger.info(f" 🏗️ Building {arch_name}")
|
@@ -37,11 +35,7 @@ class YOLO(nn.Module):
|
|
37 |
layer_args = layer_info.get("args", {})
|
38 |
|
39 |
# Get input source
|
40 |
-
source = layer_info.get("source", -1)
|
41 |
-
if isinstance(source, str):
|
42 |
-
source = layer_indices_by_tag[source]
|
43 |
-
elif isinstance(source, ListConfig):
|
44 |
-
source = [layer_indices_by_tag[idx] if isinstance(idx, str) else idx for idx in source]
|
45 |
|
46 |
# Find in channels
|
47 |
if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "CBLinear"]):
|
@@ -52,19 +46,17 @@ class YOLO(nn.Module):
|
|
52 |
|
53 |
# create layers
|
54 |
layer = self.create_layer(layer_type, source, layer_info, **layer_args)
|
55 |
-
|
56 |
|
57 |
-
if
|
58 |
-
if
|
59 |
raise ValueError(f"Duplicate tag '{layer_info['tags']}' found.")
|
60 |
-
|
61 |
|
62 |
out_channels = self.get_out_channels(layer_type, layer_args, output_dim, source)
|
63 |
output_dim.append(out_channels)
|
64 |
layer_idx += 1
|
65 |
|
66 |
-
self.model = model_list
|
67 |
-
|
68 |
def forward(self, x):
|
69 |
y = [x]
|
70 |
output = []
|
@@ -91,7 +83,18 @@ class YOLO(nn.Module):
|
|
91 |
if layer_type == "IDetect":
|
92 |
return None
|
93 |
|
94 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
if layer_type in self.layer_map:
|
96 |
layer = self.layer_map[layer_type](**kwargs)
|
97 |
setattr(layer, "source", source)
|
|
|
25 |
self.build_model(model_cfg.model)
|
26 |
|
27 |
def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
|
28 |
+
self.layer_index = {}
|
29 |
+
output_dim, layer_idx = [3], 1
|
|
|
|
|
30 |
logger.info(f"🚜 Building YOLO")
|
31 |
for arch_name in model_arch:
|
32 |
logger.info(f" 🏗️ Building {arch_name}")
|
|
|
35 |
layer_args = layer_info.get("args", {})
|
36 |
|
37 |
# Get input source
|
38 |
+
source = self.get_source_idx(layer_info.get("source", -1), layer_idx)
|
|
|
|
|
|
|
|
|
39 |
|
40 |
# Find in channels
|
41 |
if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "CBLinear"]):
|
|
|
46 |
|
47 |
# create layers
|
48 |
layer = self.create_layer(layer_type, source, layer_info, **layer_args)
|
49 |
+
self.model.append(layer)
|
50 |
|
51 |
+
if layer.tags:
|
52 |
+
if layer.tags in self.layer_index:
|
53 |
raise ValueError(f"Duplicate tag '{layer_info['tags']}' found.")
|
54 |
+
self.layer_index[layer.tags] = layer_idx
|
55 |
|
56 |
out_channels = self.get_out_channels(layer_type, layer_args, output_dim, source)
|
57 |
output_dim.append(out_channels)
|
58 |
layer_idx += 1
|
59 |
|
|
|
|
|
60 |
def forward(self, x):
|
61 |
y = [x]
|
62 |
output = []
|
|
|
83 |
if layer_type == "IDetect":
|
84 |
return None
|
85 |
|
86 |
+
def get_source_idx(self, source: Union[ListConfig, str, int], layer_idx: int):
|
87 |
+
if isinstance(source, ListConfig):
|
88 |
+
return [self.get_source_idx(index, layer_idx) for index in source]
|
89 |
+
if isinstance(source, str):
|
90 |
+
source = self.layer_index[source]
|
91 |
+
if source < 0:
|
92 |
+
source += layer_idx
|
93 |
+
if source > 0:
|
94 |
+
setattr(self.model[source - 1], "save", True)
|
95 |
+
return source
|
96 |
+
|
97 |
+
def create_layer(self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs) -> YOLOLayer:
|
98 |
if layer_type in self.layer_map:
|
99 |
layer = self.layer_map[layer_type](**kwargs)
|
100 |
setattr(layer, "source", source)
|