""" Transfer PuzzleTuning Pre-Training checkpoints Script ver: Oct 23rd 17:00 write a model based on the weight of a checkpoint file EG: create a vit-base based on PuzzleTuning SAE """ import argparse import sys sys.path.append('..') import os import torch import torch.nn as nn from Backbone import getmodel, GetPromptModel from SSL_structures import SAE # Transfer pretrained MSHT checkpoints to normal model state_dict def transfer_model_encoder(check_point_path, save_model_path, model_idx='ViT', prompt_mode=None, Prompt_Token_num=20, edge_size=384, given_name=None): if not os.path.exists(save_model_path): os.makedirs(save_model_path) if given_name is not None: given_path = os.path.join(save_model_path, given_name) else: given_path = None if prompt_mode == "Deep" or prompt_mode == "Shallow": model = GetPromptModel.build_promptmodel(edge_size=edge_size, model_idx=model_idx, patch_size=16, Prompt_Token_num=Prompt_Token_num, VPT_type=prompt_mode, base_state_dict=None) # elif prompt_mode == "Other" or prompt_mode == None: else: model = getmodel.get_model(model_idx=model_idx, pretrained_backbone=False, edge_size=edge_size) ''' state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch} TempBest_state = {'model': best_model_wts, 'epoch': best_epoch_idx} ''' state = torch.load(check_point_path) transfer_name = os.path.splitext(os.path.split(check_point_path)[1])[0] + '_of_' try: model_state = state['model'] try: print("checkpoint epoch", state['epoch']) if prompt_mode is not None: save_model_path = os.path.join(save_model_path, transfer_name + model_idx + '_E_' + str(state['epoch']) + '_promptstate' + '.pth') else: save_model_path = os.path.join(save_model_path, transfer_name + model_idx + '_E_' + str(state['epoch']) + '_transfer' + '.pth') except: print("no 'epoch' in state") if prompt_mode is not None: save_model_path = os.path.join(save_model_path, transfer_name + model_idx + '_promptstate' + '.pth') else: save_model_path = os.path.join(save_model_path, transfer_name + model_idx + '_transfer' + '.pth') except: print("not a checkpoint state (no 'model' in state)") model_state = state if prompt_mode is not None: save_model_path = os.path.join(save_model_path, transfer_name + model_idx + '_promptstate' + '.pth') else: save_model_path = os.path.join(save_model_path, transfer_name + model_idx + '_transfer' + '.pth') try: model.load_state_dict(model_state) print("model loaded") print("model :", model_idx) gpu_use = 0 except: try: model = nn.DataParallel(model) model.load_state_dict(model_state, False) print("DataParallel model loaded") print("model :", model_idx) gpu_use = -1 except: print("model loading erro!!") gpu_use = -2 if given_path is not None: save_model_path = given_path if gpu_use == -1: # print(model) if prompt_mode is not None: prompt_state_dict = model.module.obtain_prompt() # fixme maybe bug at DP module.obtain_prompt, just model.obtain_prompt is enough print('prompt obtained') torch.save(prompt_state_dict, save_model_path) else: torch.save(model.module.state_dict(), save_model_path) print('model trained by multi-GPUs has its single GPU copy saved at ', save_model_path) elif gpu_use == 0: if prompt_mode is not None: prompt_state_dict = model.obtain_prompt() print('prompt obtained') torch.save(prompt_state_dict, save_model_path) else: torch.save(model.state_dict(), save_model_path) print('model trained by a single GPU has been saved at ', save_model_path) else: print('erro') def transfer_model_decoder(check_point_path, save_model_path, 
model_idx='sae_vit_base_patch16_decoder', dec_idx='swin_unet', prompt_mode=None, Prompt_Token_num=20, edge_size=384): if not os.path.exists(save_model_path): os.makedirs(save_model_path) state = torch.load(check_point_path) transfer_name = os.path.splitext(os.path.split(check_point_path)[1])[0] + '_of_' model = SAE.__dict__[model_idx](img_size=edge_size, prompt_mode=prompt_mode, Prompt_Token_num=Prompt_Token_num, basic_state_dict=None, dec_idx=dec_idx) try: model_state = state['model'] try: print("checkpoint epoch", state['epoch']) save_model_path = os.path.join(save_model_path, transfer_name + 'Decoder_' + dec_idx + '_E_' + str(state['epoch']) + '.pth') except: print("no 'epoch' in state") save_model_path = os.path.join(save_model_path, transfer_name + 'Decoder_' + dec_idx + '.pth') except: print("not a checkpoint state (no 'model' in state)") model_state = state save_model_path = os.path.join(save_model_path, transfer_name + 'Decoder_' + dec_idx + '.pth') try: model.load_state_dict(model_state) print("model loaded") print("model :", model_idx) gpu_use = 0 except: try: model = nn.DataParallel(model) model.load_state_dict(model_state, False) print("DataParallel model loaded") print("model :", model_idx) gpu_use = -1 except: print("model loading erro!!") gpu_use = -2 else: model = model.decoder if gpu_use == -1: torch.save(model.module.decoder.state_dict(), save_model_path) print('model trained by multi-GPUs has its single GPU copy saved at ', save_model_path) elif gpu_use == 0: torch.save(model.state_dict(), save_model_path) print('model trained by a single GPU has been saved at ', save_model_path) else: print('erro') def get_args_parser(): parser = argparse.ArgumentParser('Take pre-trained model from PuzzleTuning', add_help=False) # Model Name or index parser.add_argument('--given_name', default=None, type=str, help='name of the weight-state-dict') parser.add_argument('--model_idx', default='ViT', type=str, help='taking the weight to the specified model') parser.add_argument('--edge_size', default=224, type=int, help='images input size for model') # PromptTuning parser.add_argument('--PromptTuning', default=None, type=str, help='Deep/Shallow to use Prompt Tuning model instead of Finetuning model, by default None') # Prompt_Token_num parser.add_argument('--Prompt_Token_num', default=20, type=int, help='Prompt_Token_num') # PATH settings parser.add_argument('--checkpoint_path', default=None, type=str, help='check_point_path') parser.add_argument('--save_model_path', default=None, type=str, help='out put weight path for pre-trained model') return parser def main(args): # fixme: now need a CUDA device as the model is save as a CUDA model! 
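
# A minimal usage sketch (not part of the original pipeline): rebuild a plain ViT
# and load a state_dict produced by transfer_model_encoder with prompt_mode=None.
# The function name and the weight_path argument are hypothetical illustrations.
def example_load_transferred_encoder(weight_path, edge_size=224):
    # rebuild the same backbone that transfer_model_encoder targets (no pretrained weights)
    model = getmodel.get_model(model_idx='ViT', pretrained_backbone=False, edge_size=edge_size)
    # the transferred file stores a plain single-GPU state_dict, so a strict load should work
    model.load_state_dict(torch.load(weight_path, map_location='cpu'))
    model.eval()
    return model
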
def main(args):
    # fixme: currently needs a CUDA device, as the checkpoint is saved as a CUDA model!

    # PuzzleTuning template
    """
    # Prompt model
    # transfer_model_encoder(checkpoint_path, save_model_path, model_idx='ViT', edge_size=224,
    #                        prompt_mode='Deep', Prompt_Token_num=20, given_name=given_name)
    # not a prompt model
    # transfer_model_encoder(checkpoint_path, save_model_path, model_idx='ViT', edge_size=224, given_name=given_name)
    # decoder
    # transfer_model_decoder(checkpoint_path, save_model_path, model_idx='sae_vit_base_patch16_decoder',
    #                        dec_idx='swin_unet', edge_size=224, prompt_mode='Deep')

    # PuzzleTuning experiment transfer records:

    # 1 cyclic puzzle with automatically decreasing ratio and automatically looping patch size,
    #   transferred from timm, PromptTuning: VPT-Deep, seg_decoder: None (core method)
    # ViT_b16_224_timm_PuzzleTuning_SAE_CPIAm_Prompt_Deep_tokennum_20_E_199_promptstate.pth
    checkpoint_path = '/root/autodl-tmp/runs/PuzzleTuning_SAE_vit_base_patch16_Prompt_Deep_tokennum_20_tr_timm_CPIAm/PuzzleTuning_sae_vit_base_patch16_Prompt_Deep_tokennum_20_checkpoint-199.pth'
    save_model_path = '/root/autodl-tmp/output_models'
    given_name = r'ViT_b16_224_timm_PuzzleTuning_SAE_CPIAm_Prompt_Deep_tokennum_20_E_50_promptstate.pth'
    transfer_model_encoder(checkpoint_path, save_model_path, model_idx='ViT', edge_size=224,
                           prompt_mode='Deep', Prompt_Token_num=20, given_name=given_name)

    # PuzzleTuning ablation studies: SAE + different curricula + different VPT/ViT

    # 2 cyclic puzzle with automatically decreasing ratio and automatically looping patch size,
    #   transferred from timm, PromptTuning: None, seg_decoder: None
    # ViT_b16_224_timm_PuzzleTuning_SAE_CPIAm_E_199.pth
    checkpoint_path = '/root/autodl-tmp/runs/PuzzleTuning_SAE_vit_base_patch16_tr_timm_CPIAm/PuzzleTuning_sae_vit_base_patch16_checkpoint-199.pth'
    save_model_path = '/root/autodl-tmp/output_models'
    given_name = r'ViT_b16_224_timm_PuzzleTuning_SAE_CPIAm_E_199.pth'
    transfer_model_encoder(checkpoint_path, save_model_path, model_idx='ViT', edge_size=224, given_name=given_name)

    # 3 fixed puzzle ratio, fixed patch size, transferred from timm,
    #   PromptTuning: VPT-Deep, seg_decoder: None (server pt1)
    # ViT_b16_224_timm_PuzzleTuning_fixp16fixr25_SAE_CPIAm_Prompt_Deep_tokennum_20_E_199_promptstate.pth
    checkpoint_path = '/root/autodl-tmp/runs/PuzzleTuning_SAE_fixp16fixr25_vit_base_Prompt_Deep_tokennum_20_tr_timm_CPIAm/PuzzleTuning_sae_vit_base_patch16_Prompt_Deep_tokennum_20_checkpoint-199.pth'
    save_model_path = '/root/autodl-tmp/output_models'
    given_name = r'ViT_b16_224_timm_PuzzleTuning_SAE_fixp16fixr25_CPIAm_Prompt_Deep_tokennum_20_E_199_promptstate.pth'
    transfer_model_encoder(checkpoint_path, save_model_path, model_idx='ViT', edge_size=224,
                           prompt_mode='Deep', Prompt_Token_num=20, given_name=given_name)

    # 4 fixed puzzle ratio, fixed patch size, transferred from timm,
    #   PromptTuning: None, seg_decoder: None (server pt2)
    # ViT_b16_224_timm_PuzzleTuning_fixp16fixr25_SAE_CPIAm_E_199.pth
    checkpoint_path = '/root/autodl-tmp/runs/PuzzleTuning_SAE_fixp16fixr25_vit_base_patch16_tr_timm_CPIAm/PuzzleTuning_sae_vit_base_patch16_checkpoint-199.pth'
    save_model_path = '/root/autodl-tmp/output_models'
    given_name = r'ViT_b16_224_timm_PuzzleTuning_SAE_fixp16fixr25_CPIAm_E_199.pth'
    transfer_model_encoder(checkpoint_path, save_model_path, model_idx='ViT', edge_size=224, given_name=given_name)

    # 5 varying puzzle ratio, fixed patch size, transferred from timm,
    #   PromptTuning: VPT-Deep, seg_decoder: None, strategy: ratio-decay (server pt3)
    # ViT_b16_224_timm_PuzzleTuning_fixp16ratiodecay_SAE_CPIAm_Prompt_Deep_tokennum_20_E_199_promptstate.pth
    checkpoint_path = '/root/autodl-tmp/runs/PuzzleTuning_SAE_fixp16ratiodecay_vit_base_Prompt_Deep_tokennum_20_tr_timm_CPIAm/PuzzleTuning_sae_vit_base_patch16_Prompt_Deep_tokennum_20_checkpoint-199.pth'
    save_model_path = '/root/autodl-tmp/output_models'
    given_name = r'ViT_b16_224_timm_PuzzleTuning_SAE_fixp16ratiodecay_CPIAm_Prompt_Deep_tokennum_20_E_199_promptstate.pth'
    transfer_model_encoder(checkpoint_path, save_model_path, model_idx='ViT', edge_size=224,
                           prompt_mode='Deep', Prompt_Token_num=20, given_name=given_name)

    # 6 varying puzzle ratio, fixed patch size, transferred from timm,
    #   PromptTuning: None, seg_decoder: None (server pt4)
    # ViT_b16_224_timm_PuzzleTuning_fixp16ratiodecay_SAE_CPIAm_E_199.pth
    checkpoint_path = '/root/autodl-tmp/runs/PuzzleTuning_SAE_fixp16ratiodecay_vit_base_patch16_tr_timm_CPIAm/PuzzleTuning_sae_vit_base_patch16_checkpoint-199.pth'
    save_model_path = '/root/autodl-tmp/output_models'
    given_name = r'ViT_b16_224_timm_PuzzleTuning_SAE_fixp16ratiodecay_CPIAm_E_199.pth'
    transfer_model_encoder(checkpoint_path, save_model_path, model_idx='ViT', edge_size=224, given_name=given_name)

    # PuzzleTuning ablation studies: no puzzle upstream, i.e., VPT + MAE

    # 7 MAE + VPT, transferred from timm, PromptTuning: VPT-Deep, seg_decoder: None (A40*4 server pt5)
    # ViT_b16_224_timm_PuzzleTuning_MAE_CPIAm_Prompt_Deep_tokennum_20_E_199_promptstate.pth
    checkpoint_path = '/root/autodl-tmp/runs/PuzzleTuning_MAE_vit_base_Prompt_Deep_tokennum_20_tr_timm_CPIAm/PuzzleTuning_mae_vit_base_patch16_Prompt_Deep_tokennum_20_checkpoint-199.pth'
    save_model_path = '/root/autodl-tmp/output_models'
    given_name = r'ViT_b16_224_timm_PuzzleTuning_MAE_CPIAm_Prompt_Deep_tokennum_20_E_199_promptstate.pth'
    transfer_model_encoder(checkpoint_path, save_model_path, model_idx='ViT', edge_size=224,
                           prompt_mode='Deep', Prompt_Token_num=20, given_name=given_name)

    # 8 cyclic puzzle with automatically decreasing ratio and automatically looping patch size,
    #   transferred from MAEImageNet, PromptTuning: VPT-Deep, seg_decoder: None (A100-PCIE*2 server pt6)
    # ViT_b16_224_MAEImageNet_PuzzleTuning_SAE_CPIAm_Prompt_Deep_tokennum_20_E_199_promptstate.pth
    checkpoint_path = '/root/autodl-tmp/runs/PuzzleTuning_SAE_vit_base_patch16_Prompt_Deep_tokennum_20_tr_MAEImageNet_CPIAm/PuzzleTuning_sae_vit_base_patch16_Prompt_Deep_tokennum_20_checkpoint-199.pth'
    save_model_path = '/root/autodl-tmp/output_models'
    given_name = r'ViT_b16_224_MAEImageNet_PuzzleTuning_SAE_CPIAm_Prompt_Deep_tokennum_20_E_199_promptstate.pth'
    transfer_model_encoder(checkpoint_path, save_model_path, model_idx='ViT', edge_size=224,
                           prompt_mode='Deep', Prompt_Token_num=20, given_name=given_name)

    # 9 cyclic puzzle with automatically decreasing ratio and automatically looping patch size,
    #   transferred from Random, PromptTuning: VPT-Deep, seg_decoder: None (A40*4 server pt7)
    # ViT_b16_224_Random_PuzzleTuning_SAE_CPIAm_Prompt_Deep_tokennum_20_E_199_promptstate.pth
    checkpoint_path = '/root/autodl-tmp/runs/PuzzleTuning_SAE_vit_base_patch16_Prompt_Deep_tokennum_20_tr_Random_CPIAm/PuzzleTuning_sae_vit_base_patch16_Prompt_Deep_tokennum_20_checkpoint-199.pth'
    save_model_path = '/root/autodl-tmp/output_models'
    given_name = r'ViT_b16_224_Random_PuzzleTuning_SAE_CPIAm_Prompt_Deep_tokennum_20_E_199_promptstate.pth'
    transfer_model_encoder(checkpoint_path, save_model_path, model_idx='ViT', edge_size=224,
                           prompt_mode='Deep', Prompt_Token_num=20, given_name=given_name)

    # 10 cyclic puzzle with automatically decreasing ratio and automatically looping patch size,
    #    transferred from MAEImageNet, PromptTuning: None, seg_decoder: None (4090*6 server pt8)
    # ViT_b16_224_MAEImageNet_PuzzleTuning_SAE_CPIAm_E_199.pth
    checkpoint_path = '/root/autodl-tmp/runs/PuzzleTuning_SAE_vit_base_patch16_tr_MAEImageNet_CPIAm/PuzzleTuning_sae_vit_base_patch16_checkpoint-199.pth'
    save_model_path = '/root/autodl-tmp/output_models'
    given_name = r'ViT_b16_224_MAEImageNet_PuzzleTuning_SAE_CPIAm_E_199.pth'
    transfer_model_encoder(checkpoint_path, save_model_path, model_idx='ViT', edge_size=224, given_name=given_name)
    """
    transfer_model_encoder(args.checkpoint_path, args.save_model_path, model_idx=args.model_idx,
                           edge_size=args.edge_size, prompt_mode=args.PromptTuning,
                           Prompt_Token_num=args.Prompt_Token_num, given_name=args.given_name)


if __name__ == '__main__':
    args = get_args_parser()
    args = args.parse_args()
    main(args)
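
# Example CLI usage (a sketch: the script file name and all paths below are
# placeholders; the flags mirror get_args_parser above):
#
#   python checkpoint_transfer.py \
#       --model_idx ViT --edge_size 224 \
#       --PromptTuning Deep --Prompt_Token_num 20 \
#       --checkpoint_path /path/to/PuzzleTuning_sae_checkpoint-199.pth \
#       --save_model_path /path/to/output_models \
#       --given_name ViT_b16_224_PuzzleTuning_promptstate.pth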