# XavierJiezou's picture
# Upload folder using huggingface_hub
# 3c8ff2e verified
import os
import torch
# If a 'model' sub-directory exists under the current working directory, switch into it
# before importing the project backbones below.
# NOTE(review): presumably this is so that `src.backbones` resolves when launched from the
# repository root; os.chdir changes the process-wide working directory (and thus every
# relative path) as an import-time side effect — confirm callers expect this.
sub_dir = os.path.join(os.getcwd(), 'model')
if os.path.isdir(sub_dir): os.chdir(sub_dir)
from src.backbones import base_model, utae, uncrtaints
# Channel counts used to size the networks' inputs/outputs
# (names presumably refer to Sentinel-1 / Sentinel-2 — confirm).
# S1_BANDS: SAR channels; only added to the input when config.use_sar is set.
S1_BANDS = 2
# S2_BANDS: optical channels; always part of the input, and the number of output
# kernels transferred by load_model's partial output-layer copy.
S2_BANDS = 13
def get_base_model(config):
    """Construct and return a ``base_model.BaseModel`` configured by ``config``."""
    return base_model.BaseModel(config)
# for running image reconstruction
def get_generator(config):
    """Build the reconstruction network selected by ``config.model``.

    Dispatch rules (checked in this order):
      * ``"unet"`` is a substring of ``config.model``  -> ``utae.UNet``
      * ``"utae"`` is a substring of ``config.model``  -> ``utae.UNet`` when
        ``config.pretrain`` is truthy (mono-temporal data), else ``utae.UTAE``
      * ``config.model == "uncrtaints"``               -> ``uncrtaints.UNCRTAINTS``

    The number of input channels is ``S2_BANDS`` plus ``S1_BANDS`` scaled by
    ``config.use_sar`` (see ``_generator_input_dim``).

    Args:
        config: configuration namespace; which attributes are read depends on the
            selected architecture (see the constructor calls below).

    Returns:
        The instantiated network.

    Raises:
        NotImplementedError: if ``config.model`` matches none of the rules above.
    """
    if "unet" in config.model:
        model = _build_unet(config)
    elif "utae" in config.model:
        if config.pretrain:
            # on monotemporal data, just use a simple U-Net
            model = _build_unet(config)
        else:
            model = utae.UTAE(
                input_dim=_generator_input_dim(config),
                encoder_widths=config.encoder_widths,
                decoder_widths=config.decoder_widths,
                out_conv=config.out_conv,
                out_nonlin_mean=config.mean_nonLinearity,
                out_nonlin_var=config.var_nonLinearity,
                str_conv_k=4,
                str_conv_s=2,
                str_conv_p=1,
                agg_mode=config.agg_mode,
                encoder_norm=config.encoder_norm,
                norm_skip='batch',
                norm_up='batch',
                decoder_norm=config.decoder_norm,
                n_head=config.n_head,
                d_model=config.d_model,
                d_k=config.d_k,
                encoder=False,
                return_maps=False,
                pad_value=config.pad_value,
                padding_mode=config.padding_mode,
                positional_encoding=config.positional_encoding,
                scale_by=config.scale_by,
            )
    elif config.model == 'uncrtaints':
        model = uncrtaints.UNCRTAINTS(
            input_dim=_generator_input_dim(config),
            encoder_widths=config.encoder_widths,
            decoder_widths=config.decoder_widths,
            out_conv=config.out_conv,
            out_nonlin_mean=config.mean_nonLinearity,
            out_nonlin_var=config.var_nonLinearity,
            agg_mode=config.agg_mode,
            encoder_norm=config.encoder_norm,
            decoder_norm=config.decoder_norm,
            n_head=config.n_head,
            d_model=config.d_model,
            d_k=config.d_k,
            pad_value=config.pad_value,
            padding_mode=config.padding_mode,
            positional_encoding=config.positional_encoding,
            covmode=config.covmode,
            scale_by=config.scale_by,
            separate_out=config.separate_out,
            use_v=config.use_v,
            block_type=config.block_type,
            is_mono=config.pretrain,
        )
    else:
        raise NotImplementedError(f"Unsupported model: {config.model!r}")
    return model


def _generator_input_dim(config):
    """Return the input channel count: S2 bands plus S1 bands scaled by ``config.use_sar``."""
    return S1_BANDS * config.use_sar + S2_BANDS


def _build_unet(config):
    """Build the mono-temporal ``utae.UNet`` shared by the 'unet' and pretraining 'utae' paths."""
    return utae.UNet(
        input_dim=_generator_input_dim(config),
        encoder_widths=config.encoder_widths,
        decoder_widths=config.decoder_widths,
        out_conv=config.out_conv,
        out_nonlin_mean=config.mean_nonLinearity,
        out_nonlin_var=config.var_nonLinearity,
        str_conv_k=4,
        str_conv_s=2,
        str_conv_p=1,
        encoder_norm=config.encoder_norm,
        norm_skip='batch',
        norm_up='batch',
        decoder_norm=config.decoder_norm,
        encoder=False,
        return_maps=False,
        pad_value=config.pad_value,
        padding_mode=config.padding_mode,
    )
def get_model(config):
    """Return the top-level model for ``config`` (a thin alias of ``get_base_model``)."""
    model = get_base_model(config)
    return model
def save_model(config, epoch, model, name):
    """Write a training checkpoint to ``<res_dir>/<experiment_name>/<name>.pth.tar``.

    The saved dict holds, in order: ``epoch``, the full model ``state_dict``, the
    generator's ``state_dict_G``, and the generator optimizer/scheduler states
    (``optimizer_G`` / ``scheduler_G``). The target directory must already exist.
    """
    payload = dict(
        epoch=epoch,
        state_dict=model.state_dict(),
        state_dict_G=model.netG.state_dict(),
        optimizer_G=model.optimizer_G.state_dict(),
        scheduler_G=model.scheduler_G.state_dict(),
    )
    target = os.path.join(config.res_dir, config.experiment_name, f"{name}.pth.tar")
    torch.save(payload, target)
def load_model(config, model, train_out_layer=True, load_out_partly=True):
    """Initialise ``model.netG`` from the checkpoint at ``config.trained_checkp``.

    The checkpoint is read as a dict with key ``"state_dict_G"`` and, when
    ``config.resume_from`` is truthy, also ``"optimizer_G"`` and ``"scheduler_G"``
    (the layout written by ``save_model``).

    Loading modes:

    * Strict (attempted only when ``train_out_layer`` is False): every key and
      shape must match, e.g. when resuming training. On success all parameters
      are left trainable and ``model.frozen`` is set to False.
    * Partial (``train_out_layer`` is True, or the strict load raised
      ``RuntimeError``): only tensors whose name and shape match the model are
      loaded, and those loaded tensors are frozen (``requires_grad=False``);
      ``model.frozen`` is set to True. With ``train_out_layer`` the
      ``out_conv.conv.conv.0`` layer stays trainable. With ``load_out_partly`` the
      first ``S2_BANDS`` output kernels of ``out_conv.conv.conv.0`` are copied even
      if the layer's total width differs.

    If ``config.resume_from`` is truthy the optimizer and scheduler states are also
    restored. Nothing is returned; ``model`` is modified in place.

    Args:
        config: reads ``trained_checkp``, ``device``, ``resume_from`` and ``epochs``.
        model: object exposing ``netG`` (an ``nn.Module``), ``optimizer_G`` and
            ``scheduler_G``; its ``frozen`` attribute is assigned here.
        train_out_layer: keep the output conv trainable (also skips the strict attempt).
        load_out_partly: transfer the first ``S2_BANDS`` output kernels of the
            pretrained output conv.
    """
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    # Deserialise once; the optimizer/scheduler states below come from the same object.
    checkpoint = torch.load(config.trained_checkp, map_location=config.device)
    pretrained_dict = checkpoint["state_dict_G"]
    model_dict = model.netG.state_dict()
    not_str = "" if pretrained_dict.keys() == model_dict.keys() else "not "
    print(f'The new and the (pre-)trained model architectures are {not_str}identical.\n')

    # Try a strict load (all weights must match) unless the output layer must stay
    # trainable; load_state_dict raises RuntimeError on missing/unexpected keys or shapes.
    loaded_strictly = False
    if not train_out_layer:
        try:
            model.netG.load_state_dict(pretrained_dict, strict=True)
            loaded_strictly = True
        except RuntimeError:
            pass  # fall back to loading the compatible subset below

    if loaded_strictly:
        freeze_layers(model.netG, grad=True)  # all weights match: keep everything trainable
        model.frozen, freeze_these = False, []
    else:
        # Load only the compatible subset, freeze it, and keep the rest trainable.
        if load_out_partly:
            _copy_out_layer_partly(model_dict, pretrained_dict)
        # exclude layers that are missing from the model or whose dimensions mismatch
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if k in model_dict and v.size() == model_dict[k].size()}
        model_dict.update(pretrained_dict)
        model.netG.load_state_dict(model_dict, strict=False)
        model.frozen = True
        freeze_layers(model.netG, grad=True)  # start all-trainable, then freeze the loaded subset
        if train_out_layer:
            # freeze every loaded layer except the final output conv
            to_freeze = {k: v for k, v in pretrained_dict.items() if 'out_conv.conv.conv.0' not in k}
        else:
            # freeze all pre-trained layers, without exceptions
            to_freeze = pretrained_dict
        freeze_layers(model.netG, apply_to=to_freeze, grad=False)
        freeze_these = list(to_freeze.keys())

    frozen_names = set(freeze_these)
    train_these = [layer for layer in model_dict if layer not in frozen_names]
    print(f'\nFroze these layers: {freeze_these}')
    print(f'\nTrain these layers: {train_these}')

    if config.resume_from:
        # the epoch is parsed from a '<...>_<epoch>.pth.tar' file name and only used for logging
        resume_at = int(config.trained_checkp.split('.pth.tar')[0].split('_')[-1])
        print(f'\nResuming training at epoch {resume_at+1}/{config.epochs}, loading optimizers and schedulers')
        # when continuing training, also restore the previous run's optimizer and scheduler;
        # otherwise optimisation starts from scratch with the parameters loaded above
        model.optimizer_G.load_state_dict(checkpoint["optimizer_G"])
        model.scheduler_G.load_state_dict(checkpoint["scheduler_G"])
    # no return value, models are passed by reference


def _copy_out_layer_partly(model_dict, pretrained_dict):
    """Transfer the first ``S2_BANDS`` output kernels of ``out_conv.conv.conv.0``.

    Originally used to predict the 13 mean bands when mean and variance shared a
    single output layer. The slice is written in place into ``model_dict``'s
    tensors; ``Module.state_dict()`` returns references to the module's parameters,
    so this updates the model directly. ``pretrained_dict`` is then replaced by the
    first ``S2_BANDS`` rows of the model tensor, so the later shape filter keeps it
    only when the model layer has exactly ``S2_BANDS`` outputs. Layers missing from
    either dict (e.g. models with separate mean/var heads) are skipped instead of
    raising ``KeyError``.
    """
    for suffix in ("weight", "bias"):
        key = f"out_conv.conv.conv.0.{suffix}"
        if key not in model_dict or key not in pretrained_dict:
            continue
        target = model_dict[key]
        target[:S2_BANDS, ...] = pretrained_dict[key][:S2_BANDS, ...]
        pretrained_dict[key] = target[:S2_BANDS, ...]
# function to load checkpoints of individual and ensemble models
# (this is used for training and testing scripts)
def load_checkpoint(config, checkp_dir, model, name):
    """Load ``state_dict`` from ``<checkp_dir>/<experiment_name>/<name>.pth.tar`` into ``model``.

    A strict load is tried first. If it raises ``RuntimeError`` (mismatching key
    names or shapes), keys using the legacy numbered block naming are renamed and
    the result is loaded non-strictly: when the second dotted component of a key
    contains ``in_block``/``out_block`` and ends in an index ``k >= 1``, it becomes
    ``<name>.<k-1>``, e.g. ``netG.in_block1.conv.weight`` -> ``netG.in_block.0.conv.weight``.
    Keys already in indexed form are left untouched. Keys that remain missing or
    unexpected after the non-strict load are printed instead of being ignored.

    Args:
        config: reads ``experiment_name`` and ``device`` (used as ``map_location``).
        checkp_dir: root directory containing one sub-directory per experiment.
        model: ``nn.Module`` to load into (modified in place).
        name: checkpoint file stem.
    """
    print(checkp_dir)
    chckp_path = os.path.join(checkp_dir, config.experiment_name, f"{name}.pth.tar")
    print(f'Loading checkpoint {chckp_path}')
    checkpoint = torch.load(chckp_path, map_location=config.device)["state_dict"]
    try:  # try loading checkpoint strictly, all weights & their names must match
        model.load_state_dict(checkpoint, strict=True)
    except RuntimeError:
        checkpoint_renamed = {_rename_legacy_block_key(key): val for key, val in checkpoint.items()}
        result = model.load_state_dict(checkpoint_renamed, strict=False)
        if result.missing_keys or result.unexpected_keys:
            print(f'Non-strict checkpoint load: missing keys {result.missing_keys}, '
                  f'unexpected keys {result.unexpected_keys}')


def _rename_legacy_block_key(key):
    """Map ``a.in_block<k>.b`` / ``a.out_block<k>.b`` (k >= 1) to ``a.in_block.<k-1>.b``; other keys unchanged."""
    parts = key.split('.')
    if len(parts) < 2:
        return key
    component = parts[1]
    if 'in_block' not in component and 'out_block' not in component:
        return key
    stem = component.rstrip('0123456789')
    digits = component[len(stem):]
    if not digits or int(digits) < 1:
        return key  # already indexed (e.g. 'in_block.0') or not a 1-based legacy index
    parts[1] = f"{stem}.{int(digits) - 1}"
    return '.'.join(parts)
def freeze_layers(net, apply_to=None, grad=False):
    """Set ``requires_grad`` to ``grad`` on parameters of ``net``.

    When ``apply_to`` is None every parameter is affected; otherwise only the
    parameters whose name is a key of ``apply_to`` and whose size equals that
    entry's size. int64 parameters are skipped. Passing ``net=None`` is a no-op.
    """
    if net is None:
        return
    for param_name, param in net.named_parameters():
        # integer tensors cannot require gradients, so leave them alone
        eligible = hasattr(param, 'requires_grad') and param.dtype != torch.int64
        if not eligible:
            continue
        targeted = apply_to is None or (
            param_name in apply_to and param.size() == apply_to[param_name].size()
        )
        if targeted:
            param.requires_grad_(grad)