"""Split a checkpoint's ``state_dict`` into per-component weight files.

Every parameter key is routed, by substring match, to exactly one group
(decode head, auxiliary head, backbone, neck, student adapter). Each group is
written as its own ``.pth`` file inside ``--save_path``.
"""
import argparse
import os

import torch

# (substring, output filename) pairs, tested in this order; the first
# substring contained in a key decides its group. The order matters when a key
# contains more than one marker, and it matches the original if/elif chain.
_GROUPS = (
    ("decode_head", "head.pth"),
    ("auxiliary_head", "auxiliary_head.pth"),
    ("backbone", "backbone.pth"),
    ("neck", "neck.pth"),
    ("student_adapter", "student_adapter.pth"),
)


def get_args():
    """Parse command-line arguments.

    Returns:
        tuple[str, str]: ``(weight_path, save_path)`` where ``weight_path`` is
        the checkpoint file to split and ``save_path`` is the output directory.
    """
    parser = argparse.ArgumentParser(
        description="Split a checkpoint state_dict into per-component files."
    )
    # Both are required: a missing value would otherwise surface later as an
    # opaque failure inside torch.load / os.path.join.
    parser.add_argument("--weight_path", type=str, required=True,
                        help="path of the checkpoint to split")
    parser.add_argument("--save_path", type=str, required=True,
                        help="directory that receives the split .pth files")
    args = parser.parse_args()
    return args.weight_path, args.save_path


def _split_state_dict(state_dict):
    """Group ``state_dict`` entries by component; returns {filename: sub-dict}.

    Raises:
        ValueError: if a key matches none of the known component markers.
    """
    groups = {filename: {} for _, filename in _GROUPS}
    for key, value in state_dict.items():
        for marker, filename in _GROUPS:
            if marker in key:
                groups[filename][key] = value
                break
        else:
            raise ValueError(f"unexpected keys:{key}")
    return groups


def main():
    """Load the checkpoint, split its weights and save one file per group.

    Nothing is written if any key is unrecognised, because the split (which
    may raise ``ValueError``) completes before the first save.
    """
    weight_path, save_path = get_args()
    # NOTE: torch.load unpickles its input; only use on trusted checkpoints.
    checkpoint = torch.load(weight_path, map_location="cpu")
    # Training checkpoints wrap weights under "state_dict"; fall back to the
    # loaded dict itself so a bare state_dict can be split as well.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint
    groups = _split_state_dict(state_dict)

    os.makedirs(save_path, exist_ok=True)
    for filename, group in groups.items():
        torch.save(group, os.path.join(save_path, filename))


if __name__ == "__main__":
    main()