File size: 11,177 Bytes
3c8ff2e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 |
import os
import torch
# If a 'model' subdirectory exists under the CWD, switch into it so that the
# relative 'src.backbones' import below resolves.
# NOTE(review): this mutates the process-global working directory at import
# time — confirm no caller depends on the original CWD after importing this module.
sub_dir = os.path.join(os.getcwd(), 'model')
if os.path.isdir(sub_dir): os.chdir(sub_dir)
from src.backbones import base_model, utae, uncrtaints
# Channel counts of the two satellite inputs: Sentinel-1 (SAR, 2 bands) and
# Sentinel-2 (multispectral optical, 13 bands).
S1_BANDS = 2
S2_BANDS = 13
def get_base_model(config):
    """Instantiate the training wrapper (BaseModel) for the given config."""
    return base_model.BaseModel(config)
# for running image reconstruction
def _build_unet(config):
    """Build the monotemporal U-Net generator.

    This constructor was previously duplicated verbatim in two branches of
    get_generator ('unet' models and 'utae' pre-training); it is factored out
    so the two argument lists cannot drift apart.
    """
    return utae.UNet(
        input_dim=S1_BANDS*config.use_sar+S2_BANDS,
        encoder_widths=config.encoder_widths,
        decoder_widths=config.decoder_widths,
        out_conv=config.out_conv,
        out_nonlin_mean=config.mean_nonLinearity,
        out_nonlin_var=config.var_nonLinearity,
        str_conv_k=4,
        str_conv_s=2,
        str_conv_p=1,
        encoder_norm=config.encoder_norm,
        norm_skip='batch',
        norm_up='batch',
        decoder_norm=config.decoder_norm,
        encoder=False,
        return_maps=False,
        pad_value=config.pad_value,
        padding_mode=config.padding_mode,
    )


def get_generator(config):
    """Return the generator network selected by ``config.model``.

    Selection rules (checked in order):
      - name contains 'unet'      -> monotemporal U-Net
      - name contains 'utae'      -> U-Net when ``config.pretrain`` (monotemporal
                                     pre-training), full UTAE otherwise
      - name equals 'uncrtaints'  -> UNCRTAINTS model

    Raises:
        NotImplementedError: for any other ``config.model`` value.
    """
    if "unet" in config.model:
        model = _build_unet(config)
    elif "utae" in config.model:
        if config.pretrain:
            # on monotemporal data, just use a simple U-Net
            model = _build_unet(config)
        else:
            model = utae.UTAE(
                input_dim=S1_BANDS*config.use_sar+S2_BANDS,
                encoder_widths=config.encoder_widths,
                decoder_widths=config.decoder_widths,
                out_conv=config.out_conv,
                out_nonlin_mean=config.mean_nonLinearity,
                out_nonlin_var=config.var_nonLinearity,
                str_conv_k=4,
                str_conv_s=2,
                str_conv_p=1,
                agg_mode=config.agg_mode,
                encoder_norm=config.encoder_norm,
                norm_skip='batch',
                norm_up='batch',
                decoder_norm=config.decoder_norm,
                n_head=config.n_head,
                d_model=config.d_model,
                d_k=config.d_k,
                encoder=False,
                return_maps=False,
                pad_value=config.pad_value,
                padding_mode=config.padding_mode,
                positional_encoding=config.positional_encoding,
                scale_by=config.scale_by
            )
    elif 'uncrtaints' == config.model:
        model = uncrtaints.UNCRTAINTS(
            input_dim=S1_BANDS*config.use_sar+S2_BANDS,
            encoder_widths=config.encoder_widths,
            decoder_widths=config.decoder_widths,
            out_conv=config.out_conv,
            out_nonlin_mean=config.mean_nonLinearity,
            out_nonlin_var=config.var_nonLinearity,
            agg_mode=config.agg_mode,
            encoder_norm=config.encoder_norm,
            decoder_norm=config.decoder_norm,
            n_head=config.n_head,
            d_model=config.d_model,
            d_k=config.d_k,
            pad_value=config.pad_value,
            padding_mode=config.padding_mode,
            positional_encoding=config.positional_encoding,
            covmode=config.covmode,
            scale_by=config.scale_by,
            separate_out=config.separate_out,
            use_v=config.use_v,
            block_type=config.block_type,
            is_mono=config.pretrain
        )
    else:
        raise NotImplementedError(f"Unknown model: {config.model}")
    return model
def get_model(config):
    """Top-level factory: build the BaseModel training wrapper for ``config``."""
    model = get_base_model(config)
    return model
def save_model(config, epoch, model, name):
    """Write a checkpoint ``<res_dir>/<experiment_name>/<name>.pth.tar``.

    The checkpoint bundles the epoch index, the full wrapper state, the
    generator weights, and the generator's optimizer and scheduler states.
    """
    out_path = os.path.join(config.res_dir, config.experiment_name, f"{name}.pth.tar")
    checkpoint = {
        "epoch": epoch,
        "state_dict": model.state_dict(),
        "state_dict_G": model.netG.state_dict(),
        "optimizer_G": model.optimizer_G.state_dict(),
        "scheduler_G": model.scheduler_G.state_dict(),
    }
    torch.save(checkpoint, out_path)
def load_model(config, model, train_out_layer=True, load_out_partly=True):
    """Load (pre-)trained generator weights from ``config.trained_checkp`` into ``model.netG``.

    A strict load is attempted first (all parameter names and shapes must
    match, e.g. when resuming training). Otherwise only the compatible subset
    of weights is loaded and frozen, while incompatible weights stay trainable.

    Args:
        config: namespace providing ``trained_checkp``, ``device``,
            ``resume_from`` and ``epochs``.
        model: wrapper holding ``netG``, ``optimizer_G`` and ``scheduler_G``.
        train_out_layer: if True, always take the partial-loading path and
            keep the final output layer trainable.
        load_out_partly: if True, copy the first S2_BANDS output kernels from
            the pretrained checkpoint's output layer (e.g. pretrained net has
            fewer output channels than the novel model).

    No return value; models are modified in place (passed by reference).
    """
    # load the checkpoint once and reuse it below (previously the file was
    # re-read for the optimizer and scheduler states)
    checkpoint = torch.load(config.trained_checkp, map_location=config.device)
    # load pre-trained weights, but only the matching ones
    pretrained_dict = checkpoint["state_dict_G"]
    model_dict = model.netG.state_dict()
    not_str = "" if pretrained_dict.keys() == model_dict.keys() else "not "
    print(f'The new and the (pre-)trained model architectures are {not_str}identical.\n')
    try:  # try loading checkpoint strictly, all weights must match
        # (this is satisfied e.g. when resuming training)
        if train_out_layer: raise NotImplementedError  # move to 'except' case
        model.netG.load_state_dict(pretrained_dict, strict=True)
        freeze_layers(model.netG, grad=True)  # set all weights to trainable, no need to freeze
        model.frozen, freeze_these = False, []  # ... as all weights match appropriately
    # NOTE: was a bare `except:`; the strict load raises RuntimeError on any
    # mismatch and the shortcut above raises NotImplementedError — catch only those
    except (NotImplementedError, RuntimeError):
        # some weights don't match (e.g. when loading from a pre-trained U-Net):
        # load only the compatible subset, freeze it, keep the rest trainable

        # load output layer partly, e.g. when pretrained net has 3 output channels but novel model has 13
        if load_out_partly:
            # overwrite output layer even when dimensions mismatch (this overwrites kernels individually)
            #""" # these lines were used for predicting the 13 mean bands when mean and var shared a single output layer
            temp_weights, temp_biases = model_dict['out_conv.conv.conv.0.weight'], model_dict['out_conv.conv.conv.0.bias']
            temp_weights[:S2_BANDS,...] = pretrained_dict['out_conv.conv.conv.0.weight'][:S2_BANDS,...]
            temp_biases[:S2_BANDS,...] = pretrained_dict['out_conv.conv.conv.0.bias'][:S2_BANDS,...]
            pretrained_dict['out_conv.conv.conv.0.weight'] = temp_weights[:S2_BANDS,...]
            pretrained_dict['out_conv.conv.conv.0.bias'] = temp_biases[:S2_BANDS,...]
            """
            if 'out_conv.conv.conv.0.weight' in pretrained_dict: # if predicting from a model with a single output layer for both mean and var
                pretrained_dict['out_conv_mean.conv.conv.0.weight'] = pretrained_dict['out_conv.conv.conv.0.weight'][:S2_BANDS,...]
                pretrained_dict['out_conv_mean.conv.conv.0.bias'] = pretrained_dict['out_conv.conv.conv.0.bias'][:S2_BANDS,...]
                if 'out_conv_var.conv.conv.0.weight' in model_dict:
                    pretrained_dict['out_conv_var.conv.conv.0.weight'] = model_dict['out_conv_var.conv.conv.0.weight']
                    pretrained_dict['out_conv_var.conv.conv.0.bias'] = model_dict['out_conv_var.conv.conv.0.bias']
            """
        # check for size mismatch and exclude layers whose dimensions mismatch (they won't be loaded)
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if k in model_dict and v.size() == model_dict[k].size()}
        model_dict.update(pretrained_dict)
        model.netG.load_state_dict(model_dict, strict=False)

        # freeze the loaded pretrained weights
        model.frozen = True
        freeze_layers(model.netG, grad=True)  # set all weights to trainable, except final ...
        if train_out_layer:
            # freeze all but the last (output) layer
            all_but_last = {k: v for k, v in pretrained_dict.items() if 'out_conv.conv.conv.0' not in k}
            freeze_layers(model.netG, apply_to=all_but_last, grad=False)
            freeze_these = list(all_but_last.keys())
        else:  # freeze all pre-trained layers, without exceptions
            freeze_layers(model.netG, apply_to=pretrained_dict, grad=False)
            freeze_these = list(pretrained_dict.keys())
        train_these = [layer for layer in model_dict.keys() if layer not in freeze_these]
        print(f'\nFroze these layers: {freeze_these}')
        print(f'\nTrain these layers: {train_these}')
    if config.resume_from:
        # resume epoch is encoded in the checkpoint filename, e.g. '..._12.pth.tar'
        resume_at = int(config.trained_checkp.split('.pth.tar')[0].split('_')[-1])
        print(f'\nResuming training at epoch {resume_at+1}/{config.epochs}, loading optimizers and schedulers')
        # if continuing training, then also load states of previous runs' optimizers and schedulers
        # ---else, we start optimizing from scratch but with the model parameters loaded above
        model.optimizer_G.load_state_dict(checkpoint["optimizer_G"])
        model.scheduler_G.load_state_dict(checkpoint["scheduler_G"])
    # no return value, models are passed by reference
# function to load checkpoints of individual and ensemble models
# (this is used for training and testing scripts)
def load_checkpoint(config, checkp_dir, model, name):
    """Load ``<checkp_dir>/<experiment_name>/<name>.pth.tar`` into ``model``.

    A strict load is attempted first; if parameter names mismatch (e.g.
    checkpoints written with an older block-naming scheme), keys of the form
    'in_blockN' / 'out_blockN' are renamed to 'in_block.N-1' / 'out_block.N-1'
    and the load is retried non-strictly.
    """
    print(checkp_dir)
    chckp_path = os.path.join(checkp_dir, config.experiment_name, f"{name}.pth.tar")
    print(f'Loading checkpoint {chckp_path}')
    checkpoint = torch.load(chckp_path, map_location=config.device)["state_dict"]
    try:  # try loading checkpoint strictly, all weights & their names must match
        model.load_state_dict(checkpoint, strict=True)
    # NOTE: was a bare `except:`; a strict load_state_dict mismatch raises RuntimeError
    except RuntimeError:
        # rename keys:
        # in_block1 -> in_block.0, out_block1 -> out_block.0, etc.
        checkpoint_renamed = dict()
        for key, val in checkpoint.items():
            if 'in_block' in key or 'out_block' in key:
                strs = key.split('.')
                # decrement the trailing block index ...
                strs[1] = strs[1][:-1] + str(int(strs[1][-1])-1)
                # ... and separate it from the block name with a dot
                strs[1] = '.'.join([strs[1][:-1], strs[1][-1]])
                key = '.'.join(strs)
            checkpoint_renamed[key] = val
        model.load_state_dict(checkpoint_renamed, strict=False)
def freeze_layers(net, apply_to=None, grad=False):
    """Set ``requires_grad`` on the (float) parameters of ``net``.

    Args:
        net: the network whose parameters to (un)freeze; ``None`` is a no-op.
        apply_to: optional dict of parameter name -> tensor; when given, only
            parameters whose name appears there with a matching shape are
            touched, all others are left unchanged.
        grad: the ``requires_grad`` value to apply.
    """
    if net is None:
        return
    for name, param in net.named_parameters():
        # integer-typed parameters carry no gradients — skip them
        if not hasattr(param, 'requires_grad') or param.dtype == torch.int64:
            continue
        if apply_to is None:
            # no selection given: apply indiscriminately to all layers
            param.requires_grad_(grad)
        elif name in apply_to and param.size() == apply_to[name].size():
            # selection given: only flip layers listed with a matching shape
            param.requires_grad_(grad)