# File size: 5,626 bytes
# Commits (blame): b5fa3f1 f95a3d7 b5fa3f1 f95a3d7 b5fa3f1 f95a3d7 86ef0ef b5fa3f1 86ef0ef b5fa3f1 7a913f7 f1585d3
def convert_weight(old_state_dict, new_state_dict, model_size: int = 38):
    """Copy weights from an old checkpoint into ``new_state_dict``, layer by layer.

    For every layer index ``idx`` in ``range(model_size)``:

    * the new entries whose first dotted component equals ``str(idx)`` are
      collected in dict order;
    * the old entries whose name contains ``"model.<idx + shift>."`` are
      collected in dict order (``shift`` starts at 1).

    If both groups have the same length, values are copied positionally.
    Otherwise each old entry (except ``dfl`` ones) is renamed to
    ``"<layer>.heads.<conv_idx>.<task>.<details>"``:

    * ``cv4``/``cv5`` go to layer 22, and ``shift`` becomes 2 for all later
      layer indices; any other conv name goes to layer 37.
    * ``cv2``/``cv4`` map to ``anchor_conv`` and ``cv3``/``cv5`` to ``class_conv``.

    Returns:
        ``new_state_dict``, updated in place.
    """
    # TODO: need to refactor
    shift = 1
    for idx in range(model_size):
        layer_tag = str(idx)
        targets = [name for name in new_state_dict if name.split(".")[0] == layer_tag]
        # The marker is fixed before the loop below may change ``shift``; the
        # new offset only takes effect from the next layer index onward.
        marker = f"model.{idx + shift}."
        sources = [(name, value) for name, value in old_state_dict.items() if marker in name]

        if len(targets) == len(sources):
            for target, (_, value) in zip(targets, sources):
                new_state_dict[target] = value
            continue

        # Group sizes differ: treat the old entries as detection-head weights.
        for name, value in sources:
            if "dfl" in name:
                continue
            _, _, conv_name, conv_idx, *details = name.split(".")
            if conv_name in ("cv4", "cv5"):
                layer_idx = 22
                shift = 2
            else:
                layer_idx = 37
            # NOTE(review): conv_task keeps its previous value for conv names
            # other than cv2..cv5 (and is unbound if none was seen yet).
            if conv_name in ("cv2", "cv4"):
                conv_task = "anchor_conv"
            if conv_name in ("cv3", "cv5"):
                conv_task = "class_conv"
            head_name = ".".join([str(layer_idx), "heads", conv_idx, conv_task, *details])
            new_state_dict[head_name] = value
    return new_state_dict
# New head sub-module names -> attribute names used by the old v7 checkpoint.
head_converter = {
    "head_conv": "m",
    "implicit_a": "ia",
    "implicit_m": "im",
}
# Substring rewrites from the new pre/post/short/merge conv names to the old
# checkpoint's cvN names.
SPP_converter = {
    "pre_conv.0": "cv1",
    "pre_conv.1": "cv3",
    "pre_conv.2": "cv4",
    "post_conv.0": "cv5",
    "post_conv.1": "cv6",
    "short_conv": "cv2",
    "merge_conv": "cv7",
}
# Substring rewrites for rep-conv names. They are applied in insertion order,
# so "conv1"/"conv2" are rewritten before the generic "conv" -> "0".
REP_converter = {"conv1": "rbr_dense", "conv2": "rbr_1x1", "conv": "0", "bn": "1"}


def _rename_by_substrings(name, table):
    """Apply each ``src -> dst`` substring replacement in ``table`` to ``name``, in order."""
    for src, dst in table.items():
        name = name.replace(src, dst)
    return name


def convert_weight_v7(old_state_dict, new_state_dict):
    """Fill ``new_state_dict`` with the matching tensors from an old v7 checkpoint.

    Each new key ``k`` is looked up as ``"model." + k`` in ``old_state_dict``.
    If that name is absent, it is translated:

    * keys containing ``"heads"`` (``"<layer>.heads.<idx>.<name>.<details>"``)
      become ``"model.<layer>.<head_converter[name]>.<idx>.<details>"``;
    * keys containing ``pre_conv``/``post_conv``/``short_conv``/``merge_conv``
      are rewritten with ``SPP_converter``;
    * keys containing ``conv1``/``conv2`` are rewritten with ``REP_converter``.

    Args:
        old_state_dict: Mapping of old weight names to tensors.
        new_state_dict: Mapping to fill in place; values must expose ``.shape``.

    Returns:
        ``new_state_dict``, updated in place.

    Raises:
        AssertionError: if the translated old name is missing, or its shape
            differs from the new entry's shape.
        KeyError: if a ``heads`` key uses a sub-module name not in ``head_converter``.
    """
    for new_key_name in new_state_dict:
        new_shape = new_state_dict[new_key_name].shape
        old_key_name = "model." + new_key_name
        if old_key_name not in old_state_dict:
            if "heads" in new_key_name:
                layer_idx, _, conv_idx, conv_name, *details = new_key_name.split(".")
                old_key_name = ".".join(["model", layer_idx, head_converter[conv_name], conv_idx, *details])
            elif any(tag in new_key_name for tag in ("pre_conv", "post_conv", "short_conv", "merge_conv")):
                old_key_name = "model." + _rename_by_substrings(new_key_name, SPP_converter)
            elif "conv1" in new_key_name or "conv2" in new_key_name:
                old_key_name = "model." + _rename_by_substrings(new_key_name, REP_converter)
        assert old_key_name in old_state_dict, f"Weight Name Mismatch!! {old_key_name}"
        old_shape = old_state_dict[old_key_name].shape
        # The original message lacked the f-prefix and printed "{old_key_name}" literally.
        assert new_shape == old_shape, f"Weight Shape Mismatch!! {old_key_name}: expected {new_shape}, got {old_shape}"
        # Only existing keys are reassigned, so mutating while iterating is safe.
        new_state_dict[new_key_name] = old_state_dict[old_key_name]
    return new_state_dict
# Substring rewrites applied to every index-shifted old segmentation weight name.
replace_dict = {"cv": "conv", ".m.": ".bottleneck."}


def convert_weight_seg(old_state_dict, new_state_dict):
    """Copy weights from an old segmentation checkpoint into ``new_state_dict``.

    Each old name ``model.<idx>.<rest>`` is first mapped by shifting ``<idx>``
    and then applying ``replace_dict``. The offset starts at -1, becomes +3 once
    a weight of layer 23 is seen, and becomes -19 once a weight of layer 41 is
    seen.

    Names that still do not exist in ``new_state_dict`` are treated as head
    weights and rebuilt as ``model.<layer>.<heads>.<conv_idx>.<task>.<details>``
    from the old ``cv2``..``cv7`` conv names. ``proto`` and ``dfl`` weights
    that need this fallback are skipped.

    Weights whose target is missing (or whose conv name has no mapping) are
    reported with ``print`` and skipped. Weights whose shape differs are
    reported but still copied.

    Args:
        old_state_dict: Mapping of old weight names to tensors. Iteration order
            matters, because the index offset switches only when layers 23 and
            41 are encountered.
        new_state_dict: Mapping to fill in place; values must expose ``.shape``.

    Returns:
        ``new_state_dict``, updated in place.
    """
    # old conv name -> (new layer index, heads prefix, task name)
    head_targets = {
        "cv2": (44, "detect.heads", "anchor_conv"),
        "cv3": (44, "detect.heads", "class_conv"),
        "cv4": (25, "detect.heads", "anchor_conv"),
        "cv5": (25, "detect.heads", "class_conv"),
        "cv6": (44, "heads", "mask_conv"),
        "cv7": (25, "heads", "mask_conv"),
    }
    diff = -1
    for old_weight_name, old_weight in old_state_dict.items():
        old_idx = int(old_weight_name.split(".")[1])
        if old_idx == 23:
            diff = 3
        elif old_idx == 41:
            diff = -19
        new_idx = old_idx + diff
        # Rewrite only the layer index (the first ".<idx>." occurrence), not
        # sub-module indices that happen to have the same number.
        new_weight_name = old_weight_name.replace(f".{old_idx}.", f".{new_idx}.", 1)
        for key, val in replace_dict.items():
            new_weight_name = new_weight_name.replace(key, val)

        if new_weight_name not in new_state_dict:
            _, _, conv_name, conv_idx, *details = old_weight_name.split(".")
            if "proto" in conv_name or "dfl" in old_weight_name:
                continue
            target = head_targets.get(conv_name)
            if target is None:
                print(f"new: <unmapped>, old: {old_weight_name}")
                continue
            layer_idx, heads, conv_task = target
            new_weight_name = ".".join(["model", str(layer_idx), heads, conv_idx, conv_task, *details])

        if new_weight_name not in new_state_dict:
            # Report and skip rather than inserting an unexpected key.
            print(f"new: {new_weight_name} (missing), old: {old_weight_name}")
            continue
        new_shape = new_state_dict[new_weight_name].shape
        if new_shape != old_weight.shape:
            print(f"new: {new_weight_name}, old: {old_weight_name}")
            print(f"{new_shape} {old_weight.shape}")
        new_state_dict[new_weight_name] = old_weight
    return new_state_dict