import torch
import numpy as np
import os
import re


def average(checkpoints, lambdas=[0.5, 0.5], num_models=6, output_dir=None,
            filename=None, skip_keys=None, ema=False):
    """Average or interpolate the weights of fairseq/OFA-style checkpoints.

    If num_models == 1, all checkpoints are merged into a single weighted
    average, using one weight per checkpoint from `lambdas`. Otherwise, the
    first two checkpoints are interpolated, writing one output checkpoint per
    value l in `lambdas` (weights l and 1 - l). With ema=True, the EMA weights
    stored under extra_state are used instead of the raw model weights.
    """
    ckpt = torch.load(checkpoints[0], map_location='cpu')
    if ema:
        key = 'extra_state'
        state = ckpt['extra_state']['ema']
    else:
        key = 'model'
        state = ckpt['model']
    print(lambdas)

    if num_models == 1:
        # Weighted average of all checkpoints: sum_i lambdas[i] * state_i.
        average_state = {k: v.clone() * lambdas[0] for k, v in state.items()}
        for i in range(1, len(checkpoints)):
            skip_keys_list = set()
            print(checkpoints[i], lambdas[i])
            if ema:
                statei = torch.load(checkpoints[i], map_location='cpu')['extra_state']['ema']
            else:
                statei = torch.load(checkpoints[i], map_location='cpu')['model']
            for k, v in average_state.items():
                if k in statei and (skip_keys is None
                                    or (not any(re.match(sk, k) for sk in skip_keys)
                                        and not any(sk in k for sk in skip_keys))):
                    try:
                        average_state[k] += lambdas[i] * statei[k].clone()
                    except Exception:
                        # Shape mismatch: fall back to re-weighting the running average.
                        print(k, average_state[k].shape, statei[k].shape)
                        average_state[k] += lambdas[i] * average_state[k].clone()
                else:
                    # Skipped key: keep only the running average's contribution.
                    average_state[k] += lambdas[i] * average_state[k].clone()
                    skip_keys_list.add(k)
        state_dict = average_state
        print(skip_keys_list)

        if ema:
            save_obj = {key: {'ema': state_dict, 'epoch': 0}}
            for k, v in ckpt['extra_state'].items():
                if k != 'ema':
                    save_obj['extra_state'][k] = v
                    print(k)
            for k, v in ckpt.items():
                if k != key:
                    save_obj[k] = v
                    print(k)
        else:
            save_obj = {key: state_dict}
            for k, v in ckpt.items():
                if k != key:
                    save_obj[k] = v
                    print(k)
        output_path = os.path.join(output_dir, '{}.pt'.format(filename))
        print('saving', output_path)
        torch.save(save_obj, output_path)
    else:
        # Weight interpolation between the first two checkpoints.
        if ema:
            state_dict1 = ckpt['extra_state']['ema']
            state_dict2 = torch.load(checkpoints[1], map_location='cpu')['extra_state']['ema']
        else:
            state_dict1 = ckpt['model']
            state_dict2 = torch.load(checkpoints[1], map_location='cpu')['model']
        for l in lambdas:
            # l * checkpoint_0 + (1 - l) * checkpoint_1
            average_state = {k: v * l for k, v in state_dict1.items()}
            for k, v in average_state.items():
                if k in state_dict2:
                    average_state[k] += (1 - l) * state_dict2[k]
                else:
                    average_state[k] += (1 - l) * state_dict1[k]
            state_dict = average_state

            if ema:
                save_obj = {key: {'ema': state_dict}}
                for k, v in ckpt['extra_state'].items():
                    if k != 'ema':
                        save_obj['extra_state'][k] = v
                        print(k)
                for k, v in ckpt.items():
                    if k != key:
                        save_obj[k] = v
                        print(k)
            else:
                save_obj = {key: state_dict}
                for k, v in ckpt.items():
                    if k != key:
                        save_obj[k] = v
                        print(k)
            output_path = os.path.join(output_dir, '{}_l{:.2f}.pt'.format(filename, l))
            print('saving', output_path)
            torch.save(save_obj, output_path)


# Average of several models:
# lambdas = [1/4, 1/4, 1/4, 1/4]
# num_models = 1
# output_dir = '/lus/scratch/NAT/gda2204/SHARED/logs/ofa/pretrained_models/average_models/'
# filename = 'avg_caprefsnlivqa'
# checkpoints = [
#     '/lus/scratch/NAT/gda2204/SHARED/logs/ofa/checkpoints/caption/caption_stage_1_ofaplus_base_pretrain_s2_hsep1_bs16_shuf/10_0.06_6000/checkpoint_best.pt',
#     '/lus/scratch/NAT/gda2204/SHARED/logs/ofa/checkpoints/refcocoplus/refcocoplus_ofaplus_base_pretrain_s2_hsep1_fix_lr5e5_bs8_4_shuf/10_5e-5_512/checkpoint_best.pt',
#     '/lus/scratch/NAT/gda2204/SHARED/logs/ofa/checkpoints/snli_ve/snli_ve_ofaplus_base_pretrain_s2_hsep1/10_5e-5/checkpoint_best.pt',
#     '/lus/scratch/NAT/gda2204/SHARED/logs/ofa/checkpoints/vqa/vqa_ofaplus_base_pretrain_s2_bs16_lr1e4_shuf_hsep1/20_0.04_1e-4_480/checkpoint_best.pt',
# ]

# For weight interpolation:
num_models = 6
output_dir = '/lus/scratch/NAT/gda2204/SHARED/logs/ofa/pretrained_models/average_models/'
filename = 'avg_capvqa'
lambdas = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
checkpoints = [
    '/lus/scratch/NAT/gda2204/SHARED/logs/ofa/checkpoints/vqa/vqa_ofaplus_base_pretrain_s2_bs16_lr1e4_shuf_hsep1/20_0.04_1e-4_480/checkpoint_best.pt',
    '/lus/scratch/NAT/gda2204/SHARED/logs/ofa/checkpoints/caption/caption_stage_1_ofaplus_base_pretrain_s2_hsep1_bs16_shuf/10_0.06_6000/checkpoint_best.pt',
]
average(checkpoints, lambdas=lambdas, num_models=num_models,
        output_dir=output_dir, filename=filename)
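
# Sketch of EMA-weight interpolation with the same function, assuming both checkpoints
# store EMA weights under extra_state['ema']. The paths and the output name
# 'avg_ema_example' below are placeholders, not real checkpoints:
# ema_checkpoints = [
#     '/path/to/first_checkpoint_with_ema.pt',
#     '/path/to/second_checkpoint_with_ema.pt',
# ]
# average(ema_checkpoints, lambdas=[0.0, 0.5, 1.0], num_models=6,
#         output_dir=output_dir, filename='avg_ema_example', ema=True)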