henry000 commited on
Commit
f1585d3
·
1 Parent(s): 00c86de

🔨 [Add] wegith transform for v9seg model

Browse files
Files changed (1) hide show
  1. yolo/tools/format_converters.py +52 -0
yolo/tools/format_converters.py CHANGED
@@ -83,3 +83,55 @@ def convert_weight_v7(old_state_dict, new_state_dict):
83
  assert new_shape == old_shape, "Weight Shape Mismatch!! {old_key_name}"
84
  new_state_dict[new_key_name] = old_state_dict[old_key_name]
85
  return new_state_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  assert new_shape == old_shape, "Weight Shape Mismatch!! {old_key_name}"
84
  new_state_dict[new_key_name] = old_state_dict[old_key_name]
85
  return new_state_dict
86
+
87
+
88
+ replace_dict = {"cv": "conv", ".m.": ".bottleneck."}
89
+
90
+
91
+ def convert_weight_seg(old_state_dict, new_state_dict):
92
+ diff = -1
93
+ for old_weight_name in old_state_dict.keys():
94
+ old_idx = int(old_weight_name.split(".")[1])
95
+ if old_idx == 23:
96
+ diff = 3
97
+ elif old_idx == 41:
98
+ diff = -19
99
+ new_idx = old_idx + diff
100
+ new_weight_name = old_weight_name.replace(f".{old_idx}.", f".{new_idx}.")
101
+ for key, val in replace_dict.items():
102
+ new_weight_name = new_weight_name.replace(key, val)
103
+
104
+ if new_weight_name not in new_state_dict.keys():
105
+ heads = "heads"
106
+ _, _, conv_name, conv_idx, *details = old_weight_name.split(".")
107
+ if "proto" in conv_name:
108
+ conv_idx = "3"
109
+ new_weight_name = ".".join(["model", str(layer_idx), heads, conv_task, *details])
110
+ continue
111
+ if "dfl" in old_weight_name:
112
+ continue
113
+ if conv_name == "cv2" or conv_name == "cv3" or conv_name == "cv6":
114
+ layer_idx = 44
115
+ heads = "detect.heads"
116
+ if conv_name == "cv4" or conv_name == "cv5" or conv_name == "cv7":
117
+ layer_idx = 25
118
+ heads = "detect.heads"
119
+
120
+ if conv_name == "cv2" or conv_name == "cv4":
121
+ conv_task = "anchor_conv"
122
+ if conv_name == "cv3" or conv_name == "cv5":
123
+ conv_task = "class_conv"
124
+ if conv_name == "cv6" or conv_name == "cv7":
125
+ conv_task = "mask_conv"
126
+ heads = "heads"
127
+
128
+ new_weight_name = ".".join(["model", str(layer_idx), heads, conv_idx, conv_task, *details])
129
+
130
+ if (
131
+ new_weight_name not in new_state_dict.keys()
132
+ or new_state_dict[new_weight_name].shape != old_state_dict[old_weight_name].shape
133
+ ):
134
+ print(f"new: {new_weight_name}, old: {old_weight_name}")
135
+ print(f"{new_state_dict[new_weight_name].shape} {old_state_dict[old_weight_name].shape}")
136
+ new_state_dict[new_weight_name] = old_state_dict[old_weight_name]
137
+ return new_state_dict