# Source: UnIVAL / preprocess / average_save_models.py
# (repository page header: committed by mshukor, commit "init" 26fd00c, 4.39 kB)
import torch
import numpy as np
import os
import re
def _select_state(ckpt, ema):
    """Return the weight dict of a loaded checkpoint (EMA weights if *ema*)."""
    return ckpt['extra_state']['ema'] if ema else ckpt['model']


def _matches_skip_key(name, skip_keys):
    """Return True if *name* matches any skip pattern (regex prefix match or substring)."""
    if skip_keys is None:
        return False
    return any(re.match(sk, name) or sk in name for sk in skip_keys)


def _build_save_obj(ckpt, key, state_dict, ema, extra_defaults=None):
    """Assemble the object to save: merged weights plus the first checkpoint's other entries."""
    if ema:
        extra_state = {'ema': state_dict}
        extra_state.update(extra_defaults or {})
        # Carry over the remaining extra_state entries next to the merged EMA
        # weights instead of replacing the whole extra_state dict.
        for k, v in ckpt['extra_state'].items():
            if k != 'ema':
                extra_state[k] = v
                print(k)
        save_obj = {key: extra_state}
    else:
        save_obj = {key: state_dict}
    for k, v in ckpt.items():
        if k != key:
            save_obj[k] = v
            print(k)
    return save_obj


def average(checkpoints, lambdas=None, num_models=6, output_dir=None, filename=None, skip_keys=None, ema=False):
    """Merge checkpoint weights and save the result(s) under *output_dir*.

    Two modes, selected by *num_models*:

    * ``num_models == 1``: weighted sum of all checkpoints,
      ``sum(lambdas[i] * weights_i)``, saved as ``<filename>.pt``. A key that
      is missing from checkpoint ``i``, matches *skip_keys*, or cannot be
      added because of a shape mismatch contributes the first checkpoint's
      tensor for that term instead.
    * otherwise: interpolation between ``checkpoints[0]`` and
      ``checkpoints[1]``. For every ``l`` in *lambdas* a file
      ``<filename>_l<l:.2f>.pt`` is written containing
      ``l * weights_0 + (1 - l) * weights_1``. Keys missing from the second
      checkpoint keep the first checkpoint's value.

    Every entry of the first checkpoint other than the merged weights is
    copied into the saved object unchanged.

    Args:
        checkpoints: Paths of the checkpoints to merge; the first one also
            supplies the non-weight entries of the output.
        lambdas: Per-checkpoint weights (averaging mode) or interpolation
            coefficients (interpolation mode). Defaults to ``[0.5, 0.5]``.
        num_models: ``1`` selects averaging mode, any other value selects
            interpolation mode.
        output_dir: Directory the output file(s) are written to.
        filename: Output file name stem (without ``.pt``).
        skip_keys: Optional patterns; a parameter name is excluded from
            averaging when a pattern matches it with ``re.match`` or occurs
            in it as a substring. Only used in averaging mode.
        ema: Merge ``ckpt['extra_state']['ema']`` instead of ``ckpt['model']``.
            In averaging mode the saved ``extra_state`` gets ``'epoch': 0``
            unless the first checkpoint's ``extra_state`` provides an
            ``'epoch'`` entry.
    """
    if lambdas is None:
        lambdas = [0.5, 0.5]

    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    ckpt = torch.load(checkpoints[0], map_location='cpu')
    key = 'extra_state' if ema else 'model'
    state = _select_state(ckpt, ema)
    print(lambdas)

    if num_models == 1:
        average_state = {k: v.clone() * lambdas[0] for k, v in state.items()}
        # Collected across all checkpoints so the report below is complete and
        # is defined even when only one checkpoint is given.
        skipped = set()
        for i in range(1, len(checkpoints)):
            weight = lambdas[i]
            print(checkpoints[i], weight)
            statei = _select_state(torch.load(checkpoints[i], map_location='cpu'), ema)
            for k in average_state:
                mergeable = k in statei and not _matches_skip_key(k, skip_keys)
                if mergeable:
                    try:
                        average_state[k] += weight * statei[k]
                    except RuntimeError:
                        # In-place add failed (e.g. incompatible shapes).
                        print(k, average_state[k].shape, statei[k].shape)
                        mergeable = False
                if not mergeable:
                    # Fall back to the first checkpoint's tensor for this term.
                    # Using the running sum here would rescale the key by a
                    # compounding factor on every checkpoint.
                    average_state[k] += weight * state[k]
                    skipped.add(k)

        print(skipped)
        save_obj = _build_save_obj(ckpt, key, average_state, ema, extra_defaults={'epoch': 0})
        output_path = os.path.join(output_dir, '{}.pt'.format(filename))
        print('saving', output_path)
        torch.save(save_obj, output_path)
    else:
        state_dict1 = state
        state_dict2 = _select_state(torch.load(checkpoints[1], map_location='cpu'), ema)

        for lam in lambdas:
            average_state = {k: v * lam for k, v in state_dict1.items()}
            for k in average_state:
                other = state_dict2[k] if k in state_dict2 else state_dict1[k]
                average_state[k] += (1 - lam) * other

            save_obj = _build_save_obj(ckpt, key, average_state, ema)
            output_path = os.path.join(output_dir, '{}_l{:.2f}.pt'.format(filename, lam))
            print('saving', output_path)
            torch.save(save_obj, output_path)
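# Alternative configuration (commented out): weighted average of several models.
# num_models=1 selects the averaging path of average().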
# lambdas = [1/4, 1/4, 1/4, 1/4]
# num_models=1
# output_dir='/lus/scratch/NAT/gda2204/SHARED/logs/ofa/pretrained_models/average_models/'
# filename='avg_caprefsnlivqa'
# checkpoints = [
# '/lus/scratch/NAT/gda2204/SHARED/logs/ofa/checkpoints/caption/caption_stage_1_ofaplus_base_pretrain_s2_hsep1_bs16_shuf/10_0.06_6000/checkpoint_best.pt',
# '/lus/scratch/NAT/gda2204/SHARED/logs/ofa/checkpoints/refcocoplus/refcocoplus_ofaplus_base_pretrain_s2_hsep1_fix_lr5e5_bs8_4_shuf/10_5e-5_512/checkpoint_best.pt',
# '/lus/scratch/NAT/gda2204/SHARED/logs/ofa/checkpoints/snli_ve/snli_ve_ofaplus_base_pretrain_s2_hsep1/10_5e-5/checkpoint_best.pt',
# '/lus/scratch/NAT/gda2204/SHARED/logs/ofa/checkpoints/vqa/vqa_ofaplus_base_pretrain_s2_bs16_lr1e4_shuf_hsep1/20_0.04_1e-4_480/checkpoint_best.pt',
# ]
# Weight interpolation between two checkpoints: num_models != 1 selects the
# interpolation path of average(), which writes one file per lambda
# ('<filename>_l<lambda>.pt') holding lambda * first + (1 - lambda) * second.
num_models = 6
output_dir = '/lus/scratch/NAT/gda2204/SHARED/logs/ofa/pretrained_models/average_models/'
filename = 'avg_capvqa'
lambdas = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
checkpoints = [
    '/lus/scratch/NAT/gda2204/SHARED/logs/ofa/checkpoints/vqa/vqa_ofaplus_base_pretrain_s2_bs16_lr1e4_shuf_hsep1/20_0.04_1e-4_480/checkpoint_best.pt',
    '/lus/scratch/NAT/gda2204/SHARED/logs/ofa/checkpoints/caption/caption_stage_1_ofaplus_base_pretrain_s2_hsep1_bs16_shuf/10_0.06_6000/checkpoint_best.pt',
]

# Only run the (slow, disk-writing) merge when executed as a script, so that
# importing this module to reuse average() has no side effects.
if __name__ == '__main__':
    average(checkpoints, lambdas=lambdas, num_models=num_models, output_dir=output_dir, filename=filename)