File size: 1,387 Bytes
918db92 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import torch
import os
import argparse
def get_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the process arguments, which is
            what a plain ``get_args()`` call has always done.

    Returns:
        A ``(weight_path, save_path)`` tuple with the values given for
        ``--weight_path`` and ``--save_path``.

    Raises:
        SystemExit: Raised by argparse (after printing a usage message) when
            a required option is missing or the arguments are malformed.
    """
    parser = argparse.ArgumentParser(
        description="Split a checkpoint's state_dict into per-component files."
    )
    # Both options are mandatory: without them the values would be None and
    # the script would only fail later with a much less helpful error.
    parser.add_argument(
        "--weight_path",
        type=str,
        required=True,
        help="path of the checkpoint file to split",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        required=True,
        help="directory that receives the per-component .pth files",
    )
    args = parser.parse_args(argv)
    return args.weight_path, args.save_path
# Ordered (key marker, output file name) pairs. The order matters: a key is
# assigned to the FIRST marker it contains, so e.g. a key containing both
# "backbone" and "neck" is stored with the backbone weights.
_COMPONENT_FILES = (
    ("decode_head", "head.pth"),
    ("auxiliary_head", "auxiliary_head.pth"),
    ("backbone", "backbone.pth"),
    ("neck", "neck.pth"),
    ("student_adapter", "student_adapter.pth"),
)


def _split_state_dict(state_dict):
    """Group state_dict entries by the first component marker in their key.

    Args:
        state_dict: Mapping from parameter name to value.

    Returns:
        A dict mapping every marker of ``_COMPONENT_FILES`` to a dict holding
        the entries whose key contains that marker (possibly empty).

    Raises:
        ValueError: If a key contains none of the known markers.
    """
    groups = {marker: {} for marker, _ in _COMPONENT_FILES}
    for key, value in state_dict.items():
        for marker, _ in _COMPONENT_FILES:
            if marker in key:
                groups[marker][key] = value
                break
        else:
            # No marker matched: refuse to silently drop weights.
            raise ValueError(f"unexpected keys:{key}")
    return groups


def main():
    """Split a checkpoint's ``state_dict`` into one file per model component.

    Reads the checkpoint path and the output directory from the command line,
    loads the checkpoint on the CPU, groups the entries of its ``state_dict``
    by component and writes each group (even an empty one) to its own file
    inside the output directory.

    Raises:
        KeyError: If the loaded checkpoint has no ``"state_dict"`` entry.
        ValueError: If a key of the state dict matches no known component;
            nothing is written in that case.
    """
    weight_path, save_path = get_args()

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only run this script on checkpoints you trust.
    weight = torch.load(weight_path, map_location="cpu")
    groups = _split_state_dict(weight["state_dict"])

    # torch.save does not create missing directories, so make sure the
    # destination exists before writing the first file.
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    for marker, file_name in _COMPONENT_FILES:
        torch.save(groups[marker], os.path.join(save_path, file_name))
# Run the split only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
|