File size: 1,516 Bytes
b5fa3f1
 
f95a3d7
b5fa3f1
 
 
 
 
 
f95a3d7
b5fa3f1
 
 
 
 
 
 
 
 
 
f95a3d7
 
86ef0ef
 
b5fa3f1
 
 
 
 
 
86ef0ef
b5fa3f1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def convert_weight(old_state_dict, new_state_dict, model_size: int = 38):
    """Copy weights from an old-layout state dict into a new-layout one.

    For every layer index ``idx`` in ``range(model_size)``, the keys of
    ``new_state_dict`` whose first dotted segment equals ``str(idx)`` are
    paired with the keys of ``old_state_dict`` that contain
    ``"model.{idx + shift}."``.

    * If both groups have the same length, values are copied positionally.
      This relies on both dicts listing that layer's parameters in the same
      (insertion) order.
    * Otherwise the old group is remapped as head weights. ``dfl`` weights
      are skipped. ``cv2``/``cv3`` go to layer 37 and ``cv4``/``cv5`` go to
      layer 22. Each is written under the key
      ``"<layer>.heads.<i>.<anchor_conv|class_conv>.<rest>"``.
    * Seeing a ``cv4``/``cv5`` weight switches ``shift`` from 1 to 2 for all
      later layer indices.

    Args:
        old_state_dict: Source mapping of parameter name -> value (only read).
        new_state_dict: Destination mapping, updated in place; the head
            branch may insert keys that were not present before.
        model_size: Number of new-model layer indices to visit.

    Returns:
        The same ``new_state_dict`` object that was passed in.

    Raises:
        ValueError: If a remapped head weight belongs to a module other than
            ``cv2``-``cv5`` or ``dfl``, since it has no known destination.
    """
    # TODO: need to refactor -- layer indices 22/37 and the shift schedule
    # are hard-coded.
    # Old head module name -> (new layer index, new head sub-module name).
    head_targets = {
        "cv2": (37, "anchor_conv"),
        "cv3": (37, "class_conv"),
        "cv4": (22, "anchor_conv"),
        "cv5": (22, "class_conv"),
    }
    shift = 1  # old layer index = new layer index + shift
    for idx in range(model_size):
        layer_prefix = str(idx)
        old_marker = f"model.{idx + shift}."
        new_names = [
            name for name in new_state_dict if name.split(".")[0] == layer_prefix
        ]
        old_items = [
            (name, value) for name, value in old_state_dict.items() if old_marker in name
        ]

        if len(new_names) == len(old_items):
            # Same parameter count: assume identical ordering and copy 1:1.
            for new_name, (_, value) in zip(new_names, old_items):
                new_state_dict[new_name] = value
            continue

        # Different parameter count: treat the old layer as a detection head.
        for old_name, value in old_items:
            if "dfl" in old_name:
                continue
            _, _, conv_name, conv_idx, *details = old_name.split(".")
            try:
                layer_idx, conv_task = head_targets[conv_name]
            except KeyError:
                # Previously this reused conv_task from an earlier weight (or
                # raised UnboundLocalError), silently mislabeling the tensor.
                raise ValueError(
                    f"Unexpected head module {conv_name!r} in {old_name!r}; "
                    f"expected one of {sorted(head_targets)}"
                ) from None
            if conv_name in ("cv4", "cv5"):
                # Subsequent layers are offset by one more in the old model.
                shift = 2
            new_name = ".".join([str(layer_idx), "heads", conv_idx, conv_task, *details])
            new_state_dict[new_name] = value
    return new_state_dict