|
import torch
|
|
import os
|
|
import argparse
|
|
import copy
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            callers that pass nothing behave exactly as before.

    Returns:
        argparse.Namespace with the attributes ``source`` (str, required),
        ``prefix`` (str, required) and ``target`` (str or ``None``).
    """
    parser = argparse.ArgumentParser(
        description='Split a checkpoint into one checkpoint file per dataset.')
    parser.add_argument(
        '--source', type=str, required=True,
        help='path of the checkpoint file to load')
    parser.add_argument(
        '--prefix', type=str, required=True,
        help='string prepended to the name of every written file')
    parser.add_argument(
        '--target', type=str, default=None,
        help='directory the output files are written to '
             '(default: the directory containing --source)')
    return parser.parse_args(argv)
|
|
|
|
def _merge_expert(source_state, fc2_keys, experts, expert_idx):
    """Concatenate one expert onto every shared ``mlp.fc2`` tensor.

    Args:
        source_state: State dict holding the shared ``mlp.fc2`` tensors.
        fc2_keys: Keys of ``source_state`` that contain ``'mlp.fc2'``.
        experts: Mapping from ``mlp.experts`` keys to their tensors.
        expert_idx: Index of the expert to merge.

    Returns:
        A dict mapping each fc2 key to ``cat([fc2, expert], dim=0)``, or
        ``None`` as soon as one fc2 tensor has no entry for this expert.
        ``source_state`` itself is never modified.
    """
    merged = {}
    for key in fc2_keys:
        expert_key = key.replace('fc2.', f'experts.{expert_idx}.')
        if expert_key not in experts:
            return None
        merged[key] = torch.cat([source_state[key], experts[expert_key]], dim=0)
    return merged


def _strip_multitask_keys(state_dict, weight_names, num_heads):
    """Remove associate-head and expert entries from ``state_dict`` in place.

    Args:
        state_dict: State dict of the checkpoint about to be saved.
        weight_names: ``keypoint_head.*`` tensor names whose
            ``associate_keypoint_heads.<j>.*`` counterparts are dropped.
        num_heads: Number of associate heads (``j`` in ``range(num_heads)``).
    """
    for head_idx in range(num_heads):
        head_prefix = f'associate_keypoint_heads.{head_idx}'
        for tensor_name in weight_names:
            # Default of None: a checkpoint with fewer heads must not crash.
            state_dict.pop(tensor_name.replace('keypoint_head', head_prefix), None)

    # Materialise the key list first; the dict is mutated while we walk it.
    for key in [key for key in state_dict if 'expert' in key]:
        del state_dict[key]


def main():
    """Split the source checkpoint into one checkpoint file per dataset.

    For every output, the selected expert tensor is concatenated (dim 0) onto
    each ``mlp.fc2`` tensor, the matching head is moved to ``keypoint_head``
    and all ``associate_keypoint_heads.*`` / ``expert`` entries are removed.

    Files written into ``args.target`` (default: directory of ``args.source``):
        ``<prefix>coco.pth``  -- expert 0, original ``keypoint_head``.
        ``<prefix><name>.pth`` -- expert ``i + 1`` and associate head ``i`` for
        each ``name`` in ``names``; the final layer keeps only the first
        ``num_keypoints[i]`` rows.

    Processing of the per-dataset files stops at the first expert index that
    is not present in the checkpoint.

    Raises:
        KeyError: if expert 0 is missing for some ``mlp.fc2`` tensor, or if an
            expert exists but its associate head tensors do not.
    """
    args = parse_args()

    if args.target is None:
        # os.path.dirname is portable and keeps '/' for a root-level source,
        # unlike splitting the path on '/' by hand.
        args.target = os.path.dirname(args.source)

    ckpt = torch.load(args.source, map_location='cpu', weights_only=True)
    source_state = ckpt['state_dict']

    # Read-only views into the loaded checkpoint; torch.cat allocates new
    # tensors, so no defensive deep copy is needed just to collect these.
    experts = {key: value for key, value in source_state.items()
               if 'mlp.experts' in key}
    fc2_keys = [key for key in source_state if 'mlp.fc2' in key]

    weight_names = [
        'keypoint_head.deconv_layers.0.weight',
        'keypoint_head.deconv_layers.1.weight',
        'keypoint_head.deconv_layers.1.bias',
        'keypoint_head.deconv_layers.1.running_mean',
        'keypoint_head.deconv_layers.1.running_var',
        'keypoint_head.deconv_layers.1.num_batches_tracked',
        'keypoint_head.deconv_layers.3.weight',
        'keypoint_head.deconv_layers.4.weight',
        'keypoint_head.deconv_layers.4.bias',
        'keypoint_head.deconv_layers.4.running_mean',
        'keypoint_head.deconv_layers.4.running_var',
        'keypoint_head.deconv_layers.4.num_batches_tracked',
        'keypoint_head.final_layer.weight',
        'keypoint_head.final_layer.bias',
    ]
    final_layer_names = ('keypoint_head.final_layer.weight',
                         'keypoint_head.final_layer.bias')

    names = ['aic', 'mpii', 'ap10k', 'apt36k', 'wholebody']
    num_keypoints = [14, 16, 17, 17, 133]
    num_heads = len(names)

    # An empty target means "current directory"; makedirs('') would raise.
    if args.target:
        os.makedirs(args.target, exist_ok=True)

    # --- coco: expert 0 with the original keypoint_head --------------------
    merged = _merge_expert(source_state, fc2_keys, experts, 0)
    if merged is None:
        raise KeyError('expert 0 is missing for at least one mlp.fc2 tensor')

    new_ckpt = copy.deepcopy(ckpt)
    new_ckpt['state_dict'].update(merged)
    _strip_multitask_keys(new_ckpt['state_dict'], weight_names, num_heads)
    torch.save(new_ckpt, os.path.join(args.target, args.prefix + 'coco.pth'))

    # --- remaining datasets: expert i + 1 with associate head i ------------
    for i, (name, keypoint_count) in enumerate(zip(names, num_keypoints)):
        merged = _merge_expert(source_state, fc2_keys, experts, i + 1)
        if merged is None:
            # The checkpoint holds fewer experts than datasets: stop here.
            break

        new_ckpt = copy.deepcopy(ckpt)
        state_dict = new_ckpt['state_dict']
        state_dict.update(merged)

        head_prefix = f'associate_keypoint_heads.{i}'
        for tensor_name in weight_names:
            state_dict[tensor_name] = state_dict[
                tensor_name.replace('keypoint_head', head_prefix)]

        for tensor_name in final_layer_names:
            # clone(): torch.save serialises the whole storage behind a view,
            # so a bare slice would still write the untruncated tensor.
            state_dict[tensor_name] = state_dict[tensor_name][:keypoint_count].clone()

        _strip_multitask_keys(state_dict, weight_names, num_heads)
        torch.save(new_ckpt,
                   os.path.join(args.target, f'{args.prefix}{name}.pth'))
|
|
|
|
# Run the split only when this file is executed as a script, so importing
# the module has no side effects.
if __name__ == '__main__':
    main()
|
|
|