🔨 [Add] weight converter for YOLOv7
Browse files
yolo/tools/format_converters.py
CHANGED
@@ -31,3 +31,55 @@ def convert_weight(old_state_dict, new_state_dict, model_size: int = 38):
|
|
31 |
weight_name = ".".join([str(layer_idx), "heads", conv_idx, conv_task, *details])
|
32 |
new_state_dict[weight_name] = weight_value
|
33 |
return new_state_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
weight_name = ".".join([str(layer_idx), "heads", conv_idx, conv_task, *details])
|
32 |
new_state_dict[weight_name] = weight_value
|
33 |
return new_state_dict
|
34 |
+
|
35 |
+
|
36 |
+
head_converter = {
|
37 |
+
"head_conv": "m",
|
38 |
+
"implicit_a": "ia",
|
39 |
+
"implicit_m": "im",
|
40 |
+
}
|
41 |
+
|
42 |
+
SPP_converter = {
|
43 |
+
"pre_conv.0": "cv1",
|
44 |
+
"pre_conv.1": "cv3",
|
45 |
+
"pre_conv.2": "cv4",
|
46 |
+
"post_conv.0": "cv5",
|
47 |
+
"post_conv.1": "cv6",
|
48 |
+
"short_conv": "cv2",
|
49 |
+
"merge_conv": "cv7",
|
50 |
+
}
|
51 |
+
|
52 |
+
REP_converter = {"conv1": "rbr_dense", "conv2": "rbr_1x1", "conv": "0", "bn": "1"}
|
53 |
+
|
54 |
+
|
55 |
+
def convert_weight_v7(old_state_dict, new_state_dict):
|
56 |
+
map_weight = []
|
57 |
+
for key_name in new_state_dict.keys():
|
58 |
+
new_shape = new_state_dict[key_name].shape
|
59 |
+
old_key_name = "model." + key_name
|
60 |
+
new_key_name = key_name
|
61 |
+
if old_key_name not in old_state_dict.keys():
|
62 |
+
if "heads" in key_name:
|
63 |
+
layer_idx, _, conv_idx, conv_name, *details = key_name.split(".")
|
64 |
+
old_key_name = ".".join(["model", str(layer_idx), head_converter[conv_name], conv_idx, *details])
|
65 |
+
elif (
|
66 |
+
"pre_conv" in key_name
|
67 |
+
or "post_conv" in key_name
|
68 |
+
or "short_conv" in key_name
|
69 |
+
or "merge_conv" in key_name
|
70 |
+
):
|
71 |
+
for key, value in SPP_converter.items():
|
72 |
+
if key in key_name:
|
73 |
+
key_name = key_name.replace(key, value)
|
74 |
+
old_key_name = "model." + key_name
|
75 |
+
elif "conv1" in key_name or "conv2" in key_name:
|
76 |
+
for key, value in REP_converter.items():
|
77 |
+
if key in key_name:
|
78 |
+
key_name = key_name.replace(key, value)
|
79 |
+
old_key_name = "model." + key_name
|
80 |
+
map_weight.append(old_key_name)
|
81 |
+
assert old_key_name in old_state_dict.keys(), f"Weight Name Mismatch!! {old_key_name}"
|
82 |
+
old_shape = old_state_dict[old_key_name].shape
|
83 |
+
assert new_shape == old_shape, "Weight Shape Mismatch!! {old_key_name}"
|
84 |
+
new_state_dict[new_key_name] = old_state_dict[old_key_name]
|
85 |
+
return new_state_dict
|