|
import torch |
|
import numpy as np |
|
import os |
|
import re |
|
|
|
def _build_save_obj(ckpt, key, state_dict, ema, include_epoch=False):
    """Rebuild a checkpoint dict around the averaged ``state_dict``.

    In EMA mode the averaged weights are nested under
    ``extra_state['ema']`` and the remaining ``extra_state`` entries of the
    reference checkpoint are copied next to them; in either mode every other
    top-level entry of ``ckpt`` (args, optimizer state, and — in EMA mode —
    the raw ``model`` weights) is carried over unchanged.
    """
    if ema:
        inner = {'ema': state_dict}
        if include_epoch:
            inner['epoch'] = 0
        save_obj = {key: inner}
        for k, v in ckpt['extra_state'].items():
            if k != 'ema':
                # Bug fix: the original wrote ``save_obj['extra_state'] = v``
                # here, replacing the whole extra_state dict (and the averaged
                # EMA weights) with the last copied value.
                save_obj['extra_state'][k] = v
                print(k)
    else:
        save_obj = {key: state_dict}
    for k, v in ckpt.items():
        if k != key:
            save_obj[k] = v
            print(k)
    return save_obj


def _load_state(path, ema):
    """Load a checkpoint from ``path`` and return the weight dict to average."""
    ckpt = torch.load(path, map_location='cpu')
    return ckpt['extra_state']['ema'] if ema else ckpt['model']


def average(checkpoints, lambdas=None, num_models=6, output_dir=None,
            filename=None, skip_keys=None, ema=False):
    """Average model weights from several checkpoints and save the result.

    Two modes, selected by ``num_models``:

    * ``num_models == 1``: compute a single weighted sum over *all* entries of
      ``checkpoints`` (one weight per checkpoint from ``lambdas``) and save it
      as ``<output_dir>/<filename>.pt``.
    * otherwise: interpolate between the *first two* checkpoints only; for
      every ``l`` in ``lambdas`` save ``l * ckpt0 + (1 - l) * ckpt1`` as
      ``<output_dir>/<filename>_l<l>.pt``.

    Args:
        checkpoints: list of checkpoint file paths; each file is a dict with a
            ``'model'`` key (and ``extra_state['ema']`` when ``ema`` is set).
        lambdas: mixing weights; defaults to ``[0.5, 0.5]``.
        num_models: mode selector (see above); the value itself is not used
            beyond the ``== 1`` test.
        output_dir: directory the averaged checkpoint(s) are written to.
        filename: basename (without extension) for the saved checkpoint(s).
        skip_keys: optional list of patterns; a parameter whose name matches
            one of them (``re.match`` against the key, or plain substring) is
            not averaged and instead accumulates a rescaled copy of the
            running average.
        ema: average the EMA weights stored in ``extra_state['ema']`` instead
            of the raw ``model`` weights.
    """
    # Avoid the shared-mutable-default pitfall while keeping the old default.
    if lambdas is None:
        lambdas = [0.5, 0.5]

    # First checkpoint is the reference: it supplies the non-weight entries
    # (args, optimizer state, ...) of the saved output.
    ckpt = torch.load(checkpoints[0], map_location='cpu')
    if ema:
        key = 'extra_state'
        state = ckpt['extra_state']['ema']
    else:
        key = 'model'
        state = ckpt['model']

    print(lambdas)

    if num_models == 1:
        average_state = {k: v.clone() * lambdas[0] for k, v in state.items()}
        # Initialized *before* the loop: the original re-created this set on
        # every iteration, which raised NameError at the print below whenever
        # only one checkpoint was given, and otherwise reported only the keys
        # skipped for the last checkpoint.
        skip_keys_list = set()
        for i in range(1, len(checkpoints)):
            print(checkpoints[i], lambdas[i])
            statei = _load_state(checkpoints[i], ema)
            for k in average_state:
                skipped = skip_keys is not None and (
                    any(re.match(sk, k) for sk in skip_keys)
                    or any(sk in k for sk in skip_keys))
                if k in statei and not skipped:
                    try:
                        average_state[k] += lambdas[i] * statei[k].clone()
                    except RuntimeError:
                        # Shape mismatch between checkpoints: fall back to
                        # rescaling the running average so the total weight
                        # still sums correctly (was a bare ``except:``).
                        print(k, average_state[k].shape, statei[k].shape)
                        average_state[k] += lambdas[i] * average_state[k].clone()
                else:
                    # Missing or deliberately skipped key: rescale in place.
                    average_state[k] += lambdas[i] * average_state[k].clone()
                    skip_keys_list.add(k)

        print(skip_keys_list)
        save_obj = _build_save_obj(ckpt, key, average_state, ema,
                                   include_epoch=True)
        output_path = os.path.join(output_dir, '{}.pt'.format(filename))
        print('saving', output_path)
        torch.save(save_obj, output_path)

    else:
        # Pairwise interpolation sweep between the first two checkpoints.
        state_dict1 = state  # same dict the original re-read from ``ckpt``
        state_dict2 = _load_state(checkpoints[1], ema)
        for l in lambdas:
            # ``v * l`` allocates fresh tensors, so the in-place ``+=`` below
            # never touches state_dict1/state_dict2 across sweep iterations.
            average_state = {k: v * l for k, v in state_dict1.items()}
            for k in average_state:
                if k in state_dict2:
                    average_state[k] += (1 - l) * state_dict2[k]
                else:
                    # Key absent from the second checkpoint: keep the first
                    # checkpoint's full weight (l*w1 + (1-l)*w1 == w1).
                    average_state[k] += (1 - l) * state_dict1[k]

            save_obj = _build_save_obj(ckpt, key, average_state, ema)
            output_path = os.path.join(
                output_dir, '{}_l{:.2f}.pt'.format(filename, l))
            print('saving', output_path)
            torch.save(save_obj, output_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Script configuration (hard-coded cluster paths) ------------------------
# NOTE(review): these statements run at import time — there is no
# ``if __name__ == '__main__'`` guard, so importing this module triggers the
# checkpoint loads and saves. Confirm that is intended before reusing.
# num_models=6 (i.e. != 1) selects average()'s pairwise-interpolation mode,
# which uses only the first two checkpoints and sweeps over every value in
# ``lambdas`` — consistent with the two paths listed below.
num_models=6
output_dir='/lus/scratch/NAT/gda2204/SHARED/logs/ofa/pretrained_models/average_models/'
filename='avg_capvqa'
# One output file per lambda: avg_capvqa_l0.00.pt ... avg_capvqa_l1.00.pt,
# each l*ckpt0 + (1-l)*ckpt1.
lambdas = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]

# ckpt0: VQA-finetuned model; ckpt1: caption-finetuned (stage 1) model.
checkpoints = ['/lus/scratch/NAT/gda2204/SHARED/logs/ofa/checkpoints/vqa/vqa_ofaplus_base_pretrain_s2_bs16_lr1e4_shuf_hsep1/20_0.04_1e-4_480/checkpoint_best.pt',
'/lus/scratch/NAT/gda2204/SHARED/logs/ofa/checkpoints/caption/caption_stage_1_ofaplus_base_pretrain_s2_hsep1_bs16_shuf/10_0.06_6000/checkpoint_best.pt',
]

# ema left at its default (False): averages the raw 'model' weights.
average(checkpoints, lambdas=lambdas, num_models=num_models, output_dir=output_dir, filename=filename)
|
|