# ktda-models / tools / extract_weight.py
# XavierJiezou's picture
# Add files using upload-large-folder tool
# 918db92 verified
import torch
import os
import argparse
def get_args():
    """Parse the command-line options of the weight-extraction script.

    Both options are mandatory: leaving one out makes argparse print a
    usage message and exit, instead of handing ``None`` to the caller.

    Returns:
        A ``(weight_path, save_path)`` tuple of strings, i.e. the values
        given for ``--weight_path`` and ``--save_path``.
    """
    parser = argparse.ArgumentParser(
        description="Split a checkpoint's state_dict into per-module files."
    )
    parser.add_argument(
        "--weight_path",
        type=str,
        required=True,
        help="path of the checkpoint file to load",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        required=True,
        help="directory that receives the extracted .pth files",
    )
    args = parser.parse_args()
    return args.weight_path, args.save_path
# (marker, output file name) pairs. Markers are tested in this order and the
# first one contained in a parameter name decides which file the entry goes to.
_MODULE_FILES = (
    ("decode_head", "head.pth"),
    ("auxiliary_head", "auxiliary_head.pth"),
    ("backbone", "backbone.pth"),
    ("neck", "neck.pth"),
    ("student_adapter", "student_adapter.pth"),
)


def _split_state_dict(state_dict):
    """Group ``state_dict`` entries by the first marker found in each key.

    Returns:
        A dict mapping every marker of ``_MODULE_FILES`` to a (possibly
        empty) dict holding the entries whose key contains that marker.

    Raises:
        ValueError: if a key contains none of the known markers.
    """
    groups = {marker: {} for marker, _ in _MODULE_FILES}
    for key, value in state_dict.items():
        for marker, bucket in groups.items():
            if marker in key:
                bucket[key] = value
                break
        else:
            # No marker matched: refuse to drop parameters silently.
            raise ValueError(f"unexpected keys:{key}")
    return groups


def main():
    """Split a checkpoint's ``state_dict`` into one ``.pth`` file per module.

    Reads the checkpoint named by ``--weight_path``, groups the entries of
    its ``"state_dict"`` by module marker and writes each group (even an
    empty one) into the ``--save_path`` directory, which is created when it
    does not exist yet.

    Raises:
        ValueError: if a key of the state dict matches no known marker;
            nothing is written in that case.
    """
    weight_path, save_path = get_args()

    # NOTE(security): torch.load unpickles the file, so only run this on
    # checkpoints from a trusted source.
    weight = torch.load(weight_path, map_location="cpu")
    groups = _split_state_dict(weight["state_dict"])

    # torch.save does not create missing directories; do it up front so a
    # fresh output location works. An empty path means the current directory.
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    for marker, filename in _MODULE_FILES:
        torch.save(groups[marker], os.path.join(save_path, filename))
# Run the extraction only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()