File size: 4,516 Bytes
29d411b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
import os
import argparse
import copy

def parse_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace: carries ``source`` and ``prefix`` (both required
        strings) and ``target`` (a string, or ``None`` when the option is
        not given).
    """
    cli = argparse.ArgumentParser()
    # The two mandatory options are configured identically.
    for flag in ('--source', '--prefix'):
        cli.add_argument(flag, type=str, required=True)
    cli.add_argument('--target', type=str, default=None)
    return cli.parse_args()

def _merge_expert_into_fc2(state_dict, experts, expert_idx):
    """Append expert ``expert_idx`` to every ``mlp.fc2`` tensor of ``state_dict``.

    For each key containing ``'mlp.fc2'`` the matching expert key is obtained
    by replacing ``'fc2.'`` with ``'experts.<expert_idx>.'``; the expert tensor
    is concatenated after the fc2 tensor along dim 0.

    Args:
        state_dict: mapping of tensor names to tensors, updated in place.
        experts: mapping of expert tensor names to tensors.
        expert_idx: index of the expert to merge.

    Returns:
        bool: ``True`` when every fc2 tensor was merged. ``False`` when at
        least one expert tensor is missing; ``state_dict`` is then left
        untouched.
    """
    merged = {}
    for key, value in state_dict.items():
        if 'mlp.fc2' not in key:
            continue
        expert_key = key.replace('fc2.', f'experts.{expert_idx}.')
        if expert_key not in experts:
            return False
        merged[key] = torch.cat([value, experts[expert_key]], dim=0)
    # Apply only after all lookups succeeded so a failure never leaves a
    # half-merged state dict behind.
    state_dict.update(merged)
    return True


def _drop_auxiliary_entries(state_dict, head_weight_names, num_heads):
    """Remove the associate keypoint heads and all expert tensors in place.

    Args:
        state_dict: mapping of tensor names to tensors, updated in place.
        head_weight_names: tensor names under the ``keypoint_head`` prefix.
        num_heads: number of ``associate_keypoint_heads.<j>`` groups to drop.

    Raises:
        KeyError: if an expected associate-head tensor is absent.
    """
    for j in range(num_heads):
        for tensor_name in head_weight_names:
            state_dict.pop(tensor_name.replace('keypoint_head',
                                               f'associate_keypoint_heads.{j}'))
    # Materialise the key list first: the dict is mutated while filtering.
    for key in [k for k in state_dict if 'expert' in k]:
        del state_dict[key]


def main():
    """Split one multi-head, multi-expert checkpoint into per-dataset files.

    The checkpoint at ``--source`` is loaded on CPU. One file is written per
    output into ``--target`` (default: the directory of ``--source``):

    * ``<prefix>coco.pth``: expert 0 merged into every ``mlp.fc2`` tensor,
      ``keypoint_head`` kept as stored.
    * ``<prefix><name>.pth`` for each name in ``names``: expert ``i + 1``
      merged into every ``mlp.fc2`` tensor, ``associate_keypoint_heads.<i>``
      copied over ``keypoint_head`` and the final layer truncated to
      ``num_keypoints[i]`` rows.

    In every output the ``associate_keypoint_heads.*`` tensors and all keys
    containing ``'expert'`` are removed. Processing of the named outputs
    stops at the first index whose expert tensors are not in the checkpoint.

    Raises:
        KeyError: if the checkpoint lacks ``'state_dict'``, the tensors of
            expert 0, or any expected associate-head tensor.
    """
    args = parse_args()

    if args.target is None:
        # dirname copes with absolute paths in the root directory and with
        # the platform's path separator, unlike splitting on '/'.
        args.target = os.path.dirname(args.source)

    ckpt = torch.load(args.source, map_location='cpu', weights_only=True)

    # ckpt itself is never modified, so the expert tensors can be referenced
    # directly; torch.cat always produces new tensors.
    experts = {key: value
               for key, value in ckpt['state_dict'].items()
               if 'mlp.experts' in key}

    weight_names = ['keypoint_head.deconv_layers.0.weight',
                    'keypoint_head.deconv_layers.1.weight',
                    'keypoint_head.deconv_layers.1.bias',
                    'keypoint_head.deconv_layers.1.running_mean',
                    'keypoint_head.deconv_layers.1.running_var',
                    'keypoint_head.deconv_layers.1.num_batches_tracked',
                    'keypoint_head.deconv_layers.3.weight',
                    'keypoint_head.deconv_layers.4.weight',
                    'keypoint_head.deconv_layers.4.bias',
                    'keypoint_head.deconv_layers.4.running_mean',
                    'keypoint_head.deconv_layers.4.running_var',
                    'keypoint_head.deconv_layers.4.num_batches_tracked',
                    'keypoint_head.final_layer.weight',
                    'keypoint_head.final_layer.bias']
    final_layer_names = ['keypoint_head.final_layer.weight',
                         'keypoint_head.final_layer.bias']

    # Output name and final-layer row count for associate head i / expert i + 1.
    names = ['aic', 'mpii', 'ap10k', 'apt36k', 'wholebody']
    num_keypoints = [14, 16, 17, 17, 133]
    num_heads = len(names)

    # Output using expert 0 and the primary keypoint head.
    new_ckpt = copy.deepcopy(ckpt)
    state_dict = new_ckpt['state_dict']
    if not _merge_expert_into_fc2(state_dict, experts, 0):
        raise KeyError("checkpoint has no 'experts.0' tensors matching its "
                       "'mlp.fc2' tensors")
    _drop_auxiliary_entries(state_dict, weight_names, num_heads)
    torch.save(new_ckpt, os.path.join(args.target, args.prefix + 'coco.pth'))

    for i, (name, keypoints) in enumerate(zip(names, num_keypoints)):
        new_ckpt = copy.deepcopy(ckpt)
        state_dict = new_ckpt['state_dict']

        if not _merge_expert_into_fc2(state_dict, experts, i + 1):
            # This expert is not in the checkpoint: stop without writing
            # this or any later output.
            break

        # Promote associate head i to be the primary keypoint head.
        for tensor_name in weight_names:
            state_dict[tensor_name] = state_dict[
                tensor_name.replace('keypoint_head',
                                    f'associate_keypoint_heads.{i}')]

        for tensor_name in final_layer_names:
            # clone() detaches the slice from its backing storage; torch.save
            # would otherwise serialize the whole untruncated storage.
            state_dict[tensor_name] = state_dict[tensor_name][:keypoints].clone()

        _drop_auxiliary_entries(state_dict, weight_names, num_heads)

        torch.save(new_ckpt,
                   os.path.join(args.target, f'{args.prefix}{name}.pth'))

# Run the checkpoint split only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()