|
|
|
import argparse |
|
import os.path as osp |
|
|
|
import mmengine |
|
import torch |
|
from mmengine.runner import CheckpointLoader |
|
|
|
|
|
def convert_stdc(ckpt, stdc_type):
    """Rename official STDC backbone weights to MMSegmentation key names.

    Any ``'cp.'`` substring is first removed from each key. The key is then
    rewritten by the rules below, and only keys matched by at least one rule
    are kept; all other keys are dropped.

    * ``features.<i>.`` -> ``stages.<stage_lst[i]>.``, where ``stage_lst``
      maps the flat official layer index to a ``<stage>`` or
      ``<stage>.<block>`` index. The mapping differs for STDC1 and STDC2.
      The ``features`` component may appear anywhere in the key, so
      prefixed keys such as ``backbone.features.3.conv.weight`` work too.
    * ``conv_list`` -> ``layers``
    * ``avd_layer.0`` -> ``downsample.conv`` and ``avd_layer.1`` ->
      ``downsample.bn``. Any other ``avd_layer.`` key is kept as is.

    Args:
        ckpt (dict): Source state dict mapping key names to values.
        stdc_type (str): ``'STDC1'`` selects the STDC1 layer layout; any
            other value selects the STDC2 layout.

    Returns:
        dict: New dict holding the renamed keys and the original value
        objects. Values are not copied.

    Raises:
        IndexError: If a ``features.<i>`` index has no entry in the
            selected layout, e.g. an STDC2 checkpoint converted as STDC1.
    """
    if stdc_type == 'STDC1':
        stage_lst = ['0', '1', '2.0', '2.1', '3.0', '3.1', '4.0', '4.1']
    else:
        stage_lst = [
            '0', '1', '2.0', '2.1', '2.2', '2.3', '3.0', '3.1', '3.2', '3.3',
            '3.4', '4.0', '4.1', '4.2'
        ]

    new_state_dict = {}
    for k, v in ckpt.items():
        flag = False
        if 'cp.' in k:
            k = k.replace('cp.', '')

        # Find ``features.<i>.`` on component boundaries anywhere in the key.
        # Reading only split('.')[1] would crash on prefixed keys such as
        # ``backbone.features.0...`` (int('features')).
        parts = k.split('.')
        for i in range(len(parts) - 2):
            if parts[i] == 'features' and parts[i + 1].isdecimal():
                num_layer = int(parts[i + 1])
                try:
                    stage = stage_lst[num_layer]
                except IndexError:
                    raise IndexError(
                        f'Layer features.{num_layer} does not exist in the '
                        f'selected STDC layout ({len(stage_lst)} layers); '
                        'check that the model type matches the checkpoint.'
                    ) from None
                # The stage string may contain a dot (e.g. '2.0'); joining
                # below turns it into two key components.
                parts[i:i + 2] = ['stages', stage]
                k = '.'.join(parts)
                flag = True
                break

        if 'conv_list' in k:
            k = k.replace('conv_list', 'layers')
            flag = True

        if 'avd_layer.' in k:
            if 'avd_layer.0' in k:
                k = k.replace('avd_layer.0', 'downsample.conv')
            elif 'avd_layer.1' in k:
                k = k.replace('avd_layer.1', 'downsample.bn')
            flag = True

        if flag:
            new_state_dict[k] = v

    return new_state_dict
|
|
|
|
|
def main():
    """Command-line entry point: load, convert and save an STDC checkpoint.

    Positional arguments are ``src`` (source checkpoint path), ``dst``
    (output path) and ``type`` (``STDC1`` or ``STDC2``). The weights are
    read from the checkpoint's ``'state_dict'`` or ``'model'`` entry when
    present. Otherwise the loaded object itself is treated as the state
    dict. The parent directory of ``dst`` is created if needed before the
    converted dict is written with ``torch.save``.
    """
    parser = argparse.ArgumentParser(
        description='Convert keys in official pretrained STDC1/2 to '
        'MMSegmentation style.')
    parser.add_argument('src', help='src model path')
    parser.add_argument('dst', help='save path')
    # ``choices`` rejects an invalid type before the checkpoint is loaded.
    # Unlike ``assert``, it is not stripped when running under ``python -O``.
    parser.add_argument(
        'type', choices=['STDC1', 'STDC2'], help='model type: STDC1 or STDC2')
    args = parser.parse_args()

    # NOTE(review): checkpoint loading presumably unpickles the file (via
    # torch.load); only convert checkpoints from trusted sources.
    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    weight = convert_stdc(state_dict, args.type)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)
|
|
|
|
|
if __name__ == '__main__':
    # Run the converter only when this file is executed as a script,
    # not when it is imported.
    main()
|
|