henry000 commited on
Commit
7a913f7
·
1 Parent(s): d5a73bd

🔨 [Add] weight converter for YOLOv7

Browse files
Files changed (1) hide show
  1. yolo/tools/format_converters.py +52 -0
yolo/tools/format_converters.py CHANGED
@@ -31,3 +31,55 @@ def convert_weight(old_state_dict, new_state_dict, model_size: int = 38):
31
  weight_name = ".".join([str(layer_idx), "heads", conv_idx, conv_task, *details])
32
  new_state_dict[weight_name] = weight_value
33
  return new_state_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  weight_name = ".".join([str(layer_idx), "heads", conv_idx, conv_task, *details])
32
  new_state_dict[weight_name] = weight_value
33
  return new_state_dict
34
+
35
+
36
+ head_converter = {
37
+ "head_conv": "m",
38
+ "implicit_a": "ia",
39
+ "implicit_m": "im",
40
+ }
41
+
42
+ SPP_converter = {
43
+ "pre_conv.0": "cv1",
44
+ "pre_conv.1": "cv3",
45
+ "pre_conv.2": "cv4",
46
+ "post_conv.0": "cv5",
47
+ "post_conv.1": "cv6",
48
+ "short_conv": "cv2",
49
+ "merge_conv": "cv7",
50
+ }
51
+
52
+ REP_converter = {"conv1": "rbr_dense", "conv2": "rbr_1x1", "conv": "0", "bn": "1"}
53
+
54
+
55
+ def convert_weight_v7(old_state_dict, new_state_dict):
56
+ map_weight = []
57
+ for key_name in new_state_dict.keys():
58
+ new_shape = new_state_dict[key_name].shape
59
+ old_key_name = "model." + key_name
60
+ new_key_name = key_name
61
+ if old_key_name not in old_state_dict.keys():
62
+ if "heads" in key_name:
63
+ layer_idx, _, conv_idx, conv_name, *details = key_name.split(".")
64
+ old_key_name = ".".join(["model", str(layer_idx), head_converter[conv_name], conv_idx, *details])
65
+ elif (
66
+ "pre_conv" in key_name
67
+ or "post_conv" in key_name
68
+ or "short_conv" in key_name
69
+ or "merge_conv" in key_name
70
+ ):
71
+ for key, value in SPP_converter.items():
72
+ if key in key_name:
73
+ key_name = key_name.replace(key, value)
74
+ old_key_name = "model." + key_name
75
+ elif "conv1" in key_name or "conv2" in key_name:
76
+ for key, value in REP_converter.items():
77
+ if key in key_name:
78
+ key_name = key_name.replace(key, value)
79
+ old_key_name = "model." + key_name
80
+ map_weight.append(old_key_name)
81
+ assert old_key_name in old_state_dict.keys(), f"Weight Name Mismatch!! {old_key_name}"
82
+ old_shape = old_state_dict[old_key_name].shape
83
+ assert new_shape == old_shape, "Weight Shape Mismatch!! {old_key_name}"
84
+ new_state_dict[new_key_name] = old_state_dict[old_key_name]
85
+ return new_state_dict