🔨 [Add] wegith transform for v9seg model
Browse files
yolo/tools/format_converters.py
CHANGED
@@ -83,3 +83,55 @@ def convert_weight_v7(old_state_dict, new_state_dict):
|
|
83 |
assert new_shape == old_shape, "Weight Shape Mismatch!! {old_key_name}"
|
84 |
new_state_dict[new_key_name] = old_state_dict[old_key_name]
|
85 |
return new_state_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
assert new_shape == old_shape, "Weight Shape Mismatch!! {old_key_name}"
|
84 |
new_state_dict[new_key_name] = old_state_dict[old_key_name]
|
85 |
return new_state_dict
|
86 |
+
|
87 |
+
|
88 |
+
replace_dict = {"cv": "conv", ".m.": ".bottleneck."}
|
89 |
+
|
90 |
+
|
91 |
+
def convert_weight_seg(old_state_dict, new_state_dict):
|
92 |
+
diff = -1
|
93 |
+
for old_weight_name in old_state_dict.keys():
|
94 |
+
old_idx = int(old_weight_name.split(".")[1])
|
95 |
+
if old_idx == 23:
|
96 |
+
diff = 3
|
97 |
+
elif old_idx == 41:
|
98 |
+
diff = -19
|
99 |
+
new_idx = old_idx + diff
|
100 |
+
new_weight_name = old_weight_name.replace(f".{old_idx}.", f".{new_idx}.")
|
101 |
+
for key, val in replace_dict.items():
|
102 |
+
new_weight_name = new_weight_name.replace(key, val)
|
103 |
+
|
104 |
+
if new_weight_name not in new_state_dict.keys():
|
105 |
+
heads = "heads"
|
106 |
+
_, _, conv_name, conv_idx, *details = old_weight_name.split(".")
|
107 |
+
if "proto" in conv_name:
|
108 |
+
conv_idx = "3"
|
109 |
+
new_weight_name = ".".join(["model", str(layer_idx), heads, conv_task, *details])
|
110 |
+
continue
|
111 |
+
if "dfl" in old_weight_name:
|
112 |
+
continue
|
113 |
+
if conv_name == "cv2" or conv_name == "cv3" or conv_name == "cv6":
|
114 |
+
layer_idx = 44
|
115 |
+
heads = "detect.heads"
|
116 |
+
if conv_name == "cv4" or conv_name == "cv5" or conv_name == "cv7":
|
117 |
+
layer_idx = 25
|
118 |
+
heads = "detect.heads"
|
119 |
+
|
120 |
+
if conv_name == "cv2" or conv_name == "cv4":
|
121 |
+
conv_task = "anchor_conv"
|
122 |
+
if conv_name == "cv3" or conv_name == "cv5":
|
123 |
+
conv_task = "class_conv"
|
124 |
+
if conv_name == "cv6" or conv_name == "cv7":
|
125 |
+
conv_task = "mask_conv"
|
126 |
+
heads = "heads"
|
127 |
+
|
128 |
+
new_weight_name = ".".join(["model", str(layer_idx), heads, conv_idx, conv_task, *details])
|
129 |
+
|
130 |
+
if (
|
131 |
+
new_weight_name not in new_state_dict.keys()
|
132 |
+
or new_state_dict[new_weight_name].shape != old_state_dict[old_weight_name].shape
|
133 |
+
):
|
134 |
+
print(f"new: {new_weight_name}, old: {old_weight_name}")
|
135 |
+
print(f"{new_state_dict[new_weight_name].shape} {old_state_dict[old_weight_name].shape}")
|
136 |
+
new_state_dict[new_weight_name] = old_state_dict[old_weight_name]
|
137 |
+
return new_state_dict
|