YOLO / yolo /tools /format_converters.py
henry000's picture
🔨 [Add] weight transform for v9seg model
f1585d3
raw
history blame
5.63 kB
def convert_weight(old_state_dict, new_state_dict, model_size: int = 38):
    """Copy weights from ``old_state_dict`` into ``new_state_dict`` layer by layer.

    For each ``idx`` in ``range(model_size)`` the keys of ``new_state_dict``
    whose first dot-separated component equals ``idx`` are paired with the
    entries of ``old_state_dict`` whose name contains
    ``model.{idx + shift}.`` (``shift`` starts at 1).

    * If both groups have the same size they are matched positionally, in
      dict iteration order.
    * Otherwise every old entry of that layer is renamed to
      ``{layer}.heads.{conv_idx}.{task}.<rest>`` using a fixed table
      (``cv2``/``cv3`` -> layer 37, ``cv4``/``cv5`` -> layer 22) and stored
      under that name. Entries whose name contains ``dfl`` are skipped.
      Once a ``cv4``/``cv5`` entry has been seen, ``shift`` becomes 2 for
      all following layers.

    Args:
        old_state_dict: Mapping of old weight names (``model.<idx>.…``) to
            weight values.
        new_state_dict: Mapping keyed by new weight names (``<idx>.…``); it
            is updated in place.
        model_size: Number of layer indices to walk through.

    Returns:
        The same ``new_state_dict`` object, after the update.

    Raises:
        ValueError: If a layer that cannot be matched positionally contains
            a non-``dfl`` weight whose conv name is not one of
            ``cv2``..``cv5``, or whose name has fewer than four
            dot-separated components.
    """
    # TODO: need to refactor
    # Old conv name -> (hard-coded new layer index, new conv task name).
    head_targets = {
        "cv2": (37, "anchor_conv"),
        "cv3": (37, "class_conv"),
        "cv4": (22, "anchor_conv"),
        "cv5": (22, "class_conv"),
    }
    shift = 1
    for idx in range(model_size):
        layer_tag = str(idx)
        # Built once per layer; ``shift`` only changes after this layer's
        # old entries have been collected.
        old_marker = f"model.{idx + shift}."
        new_names = [name for name in new_state_dict if name.split(".")[0] == layer_tag]
        old_items = [(name, value) for name, value in old_state_dict.items() if old_marker in name]

        if len(new_names) == len(old_items):
            for new_name, (_, weight_value) in zip(new_names, old_items):
                new_state_dict[new_name] = weight_value
            continue

        for old_name, weight_value in old_items:
            if "dfl" in old_name:
                continue
            _, _, conv_name, conv_idx, *details = old_name.split(".")
            if conv_name not in head_targets:
                # Previously this either crashed with UnboundLocalError or
                # silently reused the task name of an earlier weight.
                raise ValueError(f"Cannot convert weight {old_name!r}: unknown conv name {conv_name!r}")
            layer_idx, conv_task = head_targets[conv_name]
            if conv_name in ("cv4", "cv5"):
                shift = 2
            new_name = ".".join([str(layer_idx), "heads", conv_idx, conv_task, *details])
            new_state_dict[new_name] = weight_value
    return new_state_dict
# New head sub-module name -> old sub-module name; convert_weight_v7 uses it
# to rebuild the old key of a weight whose new key contains "heads".
head_converter = {
    "head_conv": "m",
    "implicit_a": "ia",
    "implicit_m": "im",
}

# Substring substitutions (new -> old) that convert_weight_v7 applies to keys
# containing "pre_conv", "post_conv", "short_conv" or "merge_conv".
SPP_converter = {
    "pre_conv.0": "cv1",
    "pre_conv.1": "cv3",
    "pre_conv.2": "cv4",
    "post_conv.0": "cv5",
    "post_conv.1": "cv6",
    "short_conv": "cv2",
    "merge_conv": "cv7",
}

# Substring substitutions (new -> old) that convert_weight_v7 applies to keys
# containing "conv1" or "conv2". They are applied in insertion order, so the
# longer "conv1"/"conv2" entries must stay ahead of the bare "conv" entry.
REP_converter = {
    "conv1": "rbr_dense",
    "conv2": "rbr_1x1",
    "conv": "0",
    "bn": "1",
}
def convert_weight_v7(old_state_dict, new_state_dict):
    """Fill ``new_state_dict`` with the matching weights of ``old_state_dict``.

    For every key of ``new_state_dict`` the old key is looked up as
    ``"model." + key``. When that key does not exist, the old name is
    rebuilt from the new one:

    * keys containing ``heads`` are re-assembled as
      ``model.<layer>.<head_converter[conv_name]>.<conv_idx>.<rest>``;
    * keys containing ``pre_conv``/``post_conv``/``short_conv``/``merge_conv``
      have every ``SPP_converter`` substring replaced;
    * keys containing ``conv1``/``conv2`` have every ``REP_converter``
      substring replaced (in the dict's insertion order).

    Args:
        old_state_dict: Mapping of old weight names to values exposing
            ``.shape``.
        new_state_dict: Mapping of new weight names to values exposing
            ``.shape``; its values are overwritten in place.

    Returns:
        The same ``new_state_dict`` object, after the update.

    Raises:
        AssertionError: If the derived old key is missing from
            ``old_state_dict`` or the old and new shapes differ.
        KeyError: If a ``heads`` key uses a conv name that is not in
            ``head_converter``.
    """
    spp_markers = ("pre_conv", "post_conv", "short_conv", "merge_conv")
    # Only existing keys are reassigned below, so iterating the dict while
    # writing to it does not change its size.
    for new_key_name in new_state_dict:
        new_shape = new_state_dict[new_key_name].shape
        old_key_name = "model." + new_key_name

        if old_key_name not in old_state_dict:
            if "heads" in new_key_name:
                layer_idx, _, conv_idx, conv_name, *details = new_key_name.split(".")
                old_key_name = ".".join(["model", str(layer_idx), head_converter[conv_name], conv_idx, *details])
            elif any(marker in new_key_name for marker in spp_markers):
                translated = new_key_name
                for new_part, old_part in SPP_converter.items():
                    translated = translated.replace(new_part, old_part)
                old_key_name = "model." + translated
            elif "conv1" in new_key_name or "conv2" in new_key_name:
                translated = new_key_name
                for new_part, old_part in REP_converter.items():
                    translated = translated.replace(new_part, old_part)
                old_key_name = "model." + translated

        # Raised explicitly (same exception type as the former ``assert``
        # statements) so the checks are not stripped under ``python -O``.
        if old_key_name not in old_state_dict:
            raise AssertionError(f"Weight Name Mismatch!! {old_key_name}")
        old_weight = old_state_dict[old_key_name]
        if new_shape != old_weight.shape:
            raise AssertionError(f"Weight Shape Mismatch!! {old_key_name}")
        new_state_dict[new_key_name] = old_weight
    return new_state_dict
# Substring substitutions applied by convert_weight_seg to turn an old weight
# name into the new naming scheme.
replace_dict = {"cv": "conv", ".m.": ".bottleneck."}


def convert_weight_seg(old_state_dict, new_state_dict):
    """Copy weights from ``old_state_dict`` into ``new_state_dict``.

    Each old name ``model.<old_idx>.…`` is first translated directly: the
    layer index is shifted by ``diff`` and the ``replace_dict`` substitutions
    are applied. ``diff`` starts at -1, becomes 3 once a key with layer index
    23 is seen and -19 once a key with layer index 41 is seen; it keeps its
    value for all following keys.

    If the translated name is not a key of ``new_state_dict`` the weight is
    handled by conv name instead:

    * names whose conv component contains ``proto`` and names containing
      ``dfl`` are skipped;
    * ``cv2``..``cv7`` are renamed to
      ``model.<layer>.<heads>.<conv_idx>.<task>.<rest>`` using a fixed table.

    A two-line diagnostic is printed when the final name is absent from
    ``new_state_dict`` or the shapes differ; the weight is stored under the
    final name either way.

    Args:
        old_state_dict: Mapping of old weight names to values exposing
            ``.shape``.
        new_state_dict: Mapping of new weight names to values exposing
            ``.shape``; it is updated in place.

    Returns:
        The same ``new_state_dict`` object, after the update.

    Raises:
        ValueError: If a weight that has no direct translation uses a conv
            name outside ``cv2``..``cv7`` (and is not a ``proto``/``dfl``
            weight), or if a name is not of the form ``model.<int>.…`` with
            at least four dot-separated components.
    """
    # Old conv name -> (hard-coded new layer index, heads container, task).
    head_layout = {
        "cv2": (44, "detect.heads", "anchor_conv"),
        "cv3": (44, "detect.heads", "class_conv"),
        "cv6": (44, "heads", "mask_conv"),
        "cv4": (25, "detect.heads", "anchor_conv"),
        "cv5": (25, "detect.heads", "class_conv"),
        "cv7": (25, "heads", "mask_conv"),
    }
    diff = -1
    for old_weight_name, old_weight in old_state_dict.items():
        old_idx = int(old_weight_name.split(".")[1])
        if old_idx == 23:
            diff = 3
        elif old_idx == 41:
            diff = -19
        new_idx = old_idx + diff
        # count=1: only the layer index itself is shifted, not a later
        # sub-module index that happens to have the same value.
        new_weight_name = old_weight_name.replace(f".{old_idx}.", f".{new_idx}.", 1)
        for key, val in replace_dict.items():
            new_weight_name = new_weight_name.replace(key, val)

        if new_weight_name not in new_state_dict:
            _, _, conv_name, conv_idx, *details = old_weight_name.split(".")
            # NOTE(review): proto weights are not transferred. The previous
            # code built a name for them from possibly-unset variables and
            # then discarded it; confirm whether they should be mapped.
            if "proto" in conv_name or "dfl" in old_weight_name:
                continue
            if conv_name not in head_layout:
                # Previously the layer index / task of an earlier weight
                # would have been reused here (or UnboundLocalError raised).
                raise ValueError(f"Cannot convert weight {old_weight_name!r}: unknown conv name {conv_name!r}")
            layer_idx, heads, conv_task = head_layout[conv_name]
            new_weight_name = ".".join(["model", str(layer_idx), heads, conv_idx, conv_task, *details])

        if new_weight_name not in new_state_dict:
            # The target has no shape to report; indexing it would raise
            # KeyError and hide the diagnostic.
            print(f"new: {new_weight_name}, old: {old_weight_name}")
            print(f"<missing> {old_weight.shape}")
        elif new_state_dict[new_weight_name].shape != old_weight.shape:
            print(f"new: {new_weight_name}, old: {old_weight_name}")
            print(f"{new_state_dict[new_weight_name].shape} {old_weight.shape}")
        new_state_dict[new_weight_name] = old_weight
    return new_state_dict