|
import os |
|
import torch |
|
|
|
# If a 'model' sub-directory exists under the current working directory,
# make it the working directory so that the 'src.backbones' package below
# resolves relative to it.
# NOTE(review): this mutates process-global state (cwd) at import time —
# any caller importing this module is affected. Verify this is intended.
sub_dir = os.path.join(os.getcwd(), 'model')
if os.path.isdir(sub_dir): os.chdir(sub_dir)

from src.backbones import base_model, utae, uncrtaints

# Channel counts of the two input modalities.
# NOTE(review): presumably Sentinel-1 SAR (2 bands) and Sentinel-2
# multi-spectral (13 bands) — confirm against the data loader.
S1_BANDS = 2
S2_BANDS = 13
|
|
|
def get_base_model(config):
    """Build the BaseModel wrapper from the given configuration namespace."""
    return base_model.BaseModel(config)
|
|
|
|
|
def get_generator(config):
    """Instantiate the generator network selected by ``config.model``.

    Dispatch (substring match on ``config.model``, checked in order):
      - contains "unet":  plain spatial U-Net (``utae.UNet``)
      - contains "utae":  ``utae.UNet`` when ``config.pretrain`` is set
                          (mono-temporal pre-training), otherwise the
                          temporal-attention ``utae.UTAE``
      - equals "uncrtaints": ``uncrtaints.UNCRTAINTS``

    Returns:
        The constructed (untrained) generator module.

    Raises:
        NotImplementedError: if ``config.model`` matches none of the above.
    """
    # SAR bands are only stacked onto the optical input when use_sar is truthy
    in_dim = S1_BANDS * config.use_sar + S2_BANDS

    # Arguments shared by both UNet instantiations below — previously this
    # whole argument list was duplicated verbatim in two branches.
    unet_kwargs = dict(
        input_dim=in_dim,
        encoder_widths=config.encoder_widths,
        decoder_widths=config.decoder_widths,
        out_conv=config.out_conv,
        out_nonlin_mean=config.mean_nonLinearity,
        out_nonlin_var=config.var_nonLinearity,
        str_conv_k=4,  # fixed strided-conv kernel/stride/padding
        str_conv_s=2,
        str_conv_p=1,
        encoder_norm=config.encoder_norm,
        norm_skip='batch',
        norm_up='batch',
        decoder_norm=config.decoder_norm,
        encoder=False,
        return_maps=False,
        pad_value=config.pad_value,
        padding_mode=config.padding_mode,
    )

    if "unet" in config.model:
        model = utae.UNet(**unet_kwargs)
    elif "utae" in config.model:
        if config.pretrain:
            # mono-temporal pre-training uses the plain U-Net backbone
            model = utae.UNet(**unet_kwargs)
        else:
            # full U-TAE: the U-Net arguments plus temporal-attention settings
            model = utae.UTAE(
                **unet_kwargs,
                agg_mode=config.agg_mode,
                n_head=config.n_head,
                d_model=config.d_model,
                d_k=config.d_k,
                positional_encoding=config.positional_encoding,
                scale_by=config.scale_by,
            )
    elif 'uncrtaints' == config.model:
        model = uncrtaints.UNCRTAINTS(
            input_dim=in_dim,
            encoder_widths=config.encoder_widths,
            decoder_widths=config.decoder_widths,
            out_conv=config.out_conv,
            out_nonlin_mean=config.mean_nonLinearity,
            out_nonlin_var=config.var_nonLinearity,
            agg_mode=config.agg_mode,
            encoder_norm=config.encoder_norm,
            decoder_norm=config.decoder_norm,
            n_head=config.n_head,
            d_model=config.d_model,
            d_k=config.d_k,
            pad_value=config.pad_value,
            padding_mode=config.padding_mode,
            positional_encoding=config.positional_encoding,
            covmode=config.covmode,
            scale_by=config.scale_by,
            separate_out=config.separate_out,
            use_v=config.use_v,
            block_type=config.block_type,
            is_mono=config.pretrain,  # mono-temporal mode during pre-training
        )
    else:
        raise NotImplementedError(f"unknown model type: {config.model}")
    return model
|
|
|
|
|
def get_model(config):
    """Public entry point for model construction; delegates to get_base_model."""
    wrapped = get_base_model(config)
    return wrapped
|
|
|
|
|
def save_model(config, epoch, model, name):
    """Serialize the training state to ``<res_dir>/<experiment_name>/<name>.pth.tar``.

    Stores the epoch counter plus the state dicts of the full wrapper model,
    of the generator ``netG``, and of the generator's optimizer and
    LR scheduler (the pieces ``load_model`` expects to find again).

    Args:
        config: namespace providing ``res_dir`` and ``experiment_name``.
        epoch:  current epoch number, stored under key "epoch".
        model:  wrapper exposing ``state_dict``, ``netG``, ``optimizer_G``
                and ``scheduler_G``.
        name:   checkpoint file stem (extension ".pth.tar" is appended).
    """
    out_dir = os.path.join(config.res_dir, config.experiment_name)
    # robustness: torch.save fails if the target directory is missing
    os.makedirs(out_dir, exist_ok=True)
    state_dict = {
        "epoch": epoch,
        "state_dict": model.state_dict(),
        "state_dict_G": model.netG.state_dict(),
        "optimizer_G": model.optimizer_G.state_dict(),
        "scheduler_G": model.scheduler_G.state_dict(),
    }
    torch.save(state_dict, os.path.join(out_dir, f"{name}.pth.tar"))
|
|
|
|
|
def load_model(config, model, train_out_layer=True, load_out_partly=True):
    """Load pre-trained generator weights from ``config.trained_checkp`` into
    ``model.netG`` (in place), optionally freezing the transferred layers.

    Args:
        config: namespace providing ``trained_checkp``, ``device``,
                ``resume_from`` and ``epochs``.
        model:  wrapper exposing ``netG``, ``optimizer_G``, ``scheduler_G``
                and a ``frozen`` flag (mutated here).
        train_out_layer: if True, keep only the output conv trainable and
                freeze all other successfully transferred layers.
        load_out_partly: if True (and in the partial-load path), transplant
                only the first S2_BANDS output channels of the pre-trained
                output conv.
    """
    # generator weights only: the checkpoint's "state_dict_G" entry
    pretrained_dict = torch.load(config.trained_checkp, map_location=config.device)["state_dict_G"]
    model_dict = model.netG.state_dict()

    # purely informational: report whether both models share exactly the same keys
    not_str = "" if pretrained_dict.keys() == model_dict.keys() else "not "
    print(f'The new and the (pre-)trained model architectures are {not_str}identical.\n')

    try:
        # NOTE(review): deliberate control flow — raising here jumps straight
        # into the except branch, which implements the partial-load +
        # freezing path. A failed strict load below lands there too.
        if train_out_layer: raise NotImplementedError
        # strict full load: every key and shape must match
        model.netG.load_state_dict(pretrained_dict, strict=True)
        # grad=True re-enables gradients on all (float) parameters
        freeze_layers(model.netG, grad=True)
        model.frozen, freeze_these = False, []
    except:
        if load_out_partly:
            # Transplant the first S2_BANDS output channels of the
            # pre-trained output conv into (a reference to) the current
            # model's weights, then shrink the pretrained entries to
            # S2_BANDS channels. Entries whose shapes still differ from
            # the current model are dropped by the size filter below.
            temp_weights, temp_biases = model_dict['out_conv.conv.conv.0.weight'], model_dict['out_conv.conv.conv.0.bias']
            temp_weights[:S2_BANDS,...] = pretrained_dict['out_conv.conv.conv.0.weight'][:S2_BANDS,...]
            temp_biases[:S2_BANDS,...] = pretrained_dict['out_conv.conv.conv.0.bias'][:S2_BANDS,...]
            pretrained_dict['out_conv.conv.conv.0.weight'] = temp_weights[:S2_BANDS,...]
            pretrained_dict['out_conv.conv.conv.0.bias'] = temp_biases[:S2_BANDS,...]
            # NOTE(review): dead code kept as-is — alternative mapping for
            # checkpoints with a single output layer for both mean and var.
            """
            if 'out_conv.conv.conv.0.weight' in pretrained_dict: # if predicting from a model with a single output layer for both mean and var
                pretrained_dict['out_conv_mean.conv.conv.0.weight'] = pretrained_dict['out_conv.conv.conv.0.weight'][:S2_BANDS,...]
                pretrained_dict['out_conv_mean.conv.conv.0.bias'] = pretrained_dict['out_conv.conv.conv.0.bias'][:S2_BANDS,...]
                if 'out_conv_var.conv.conv.0.weight' in model_dict:
                    pretrained_dict['out_conv_var.conv.conv.0.weight'] = model_dict['out_conv_var.conv.conv.0.weight']
                    pretrained_dict['out_conv_var.conv.conv.0.bias'] = model_dict['out_conv_var.conv.conv.0.bias']
            """

        # keep only entries that exist in the current model with matching shapes,
        # merge them over the fresh initialization, and load non-strictly
        pretrained_dict = {k:v for k,v in pretrained_dict.items() if k in model_dict and v.size() == model_dict[k].size()}
        model_dict.update(pretrained_dict)
        model.netG.load_state_dict(model_dict, strict=False)

        # mark the generator as (partially) frozen; first re-enable all grads,
        # then selectively switch off those we want fixed
        model.frozen = True
        freeze_layers(model.netG, grad=True)
        if train_out_layer:
            # freeze every transferred layer except the output conv
            all_but_last = {k:v for k, v in pretrained_dict.items() if 'out_conv.conv.conv.0' not in k}
            freeze_layers(model.netG, apply_to=all_but_last, grad=False)
            freeze_these = list(all_but_last.keys())
        else:
            # freeze every successfully transferred layer
            freeze_layers(model.netG, apply_to=pretrained_dict, grad=False)
            freeze_these = list(pretrained_dict.keys())
        train_these = [train_layer for train_layer in list(model_dict.keys()) if train_layer not in freeze_these]
        print(f'\nFroze these layers: {freeze_these}')
        print(f'\nTrain these layers: {train_these}')

    if config.resume_from:
        # NOTE(review): assumes the epoch number is the last underscore-separated
        # token of the checkpoint filename before ".pth.tar" — confirm naming.
        resume_at = int(config.trained_checkp.split('.pth.tar')[0].split('_')[-1])
        print(f'\nResuming training at epoch {resume_at+1}/{config.epochs}, loading optimizers and schedulers')

        # restore optimizer and scheduler state saved by save_model
        optimizer_G_dict = torch.load(config.trained_checkp, map_location=config.device)["optimizer_G"]
        model.optimizer_G.load_state_dict(optimizer_G_dict)

        scheduler_G_dict = torch.load(config.trained_checkp, map_location=config.device)["scheduler_G"]
        model.scheduler_G.load_state_dict(scheduler_G_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_checkpoint(config, checkp_dir, model, name):
    """Load ``<checkp_dir>/<experiment_name>/<name>.pth.tar`` into ``model`` (in place).

    Falls back to a key-renaming pass for checkpoints whose layer naming
    differs from the current model.
    """
    print(checkp_dir)
    chckp_path = os.path.join(checkp_dir, config.experiment_name, f"{name}.pth.tar")
    print(f'Loading checkpoint {chckp_path}')
    # only the full-model "state_dict" entry of the checkpoint is used here
    checkpoint = torch.load(chckp_path, map_location=config.device)["state_dict"]

    try:
        # first attempt: checkpoint keys match the current model exactly
        model.load_state_dict(checkpoint, strict=True)
    except:
        # Fallback: rename in_block*/out_block* keys, e.g.
        # 'x.in_block1.y' -> 'x.in_block0.y' -> 'x.in_block.0.y'
        # (decrement the trailing 1-based index, then split it into its own
        # dotted component, as in an nn.ModuleList).
        # NOTE(review): assumes the block name is always the second
        # dot-separated component of the key — confirm for all checkpoints.
        checkpoint_renamed = dict()
        for key, val in checkpoint.items():
            if 'in_block' in key or 'out_block' in key:
                strs = key.split('.')
                # decrement the trailing digit (1-based -> 0-based indexing)
                strs[1] = strs[1][:-1] + str(int(strs[1][-1])-1)
                # split the digit off into its own dotted component
                strs[1] = '.'.join([strs[1][:-1], strs[1][-1]])
                key = '.'.join(strs)
            checkpoint_renamed[key] = val
        # non-strict: keys still missing keep their current initialization
        model.load_state_dict(checkpoint_renamed, strict=False)
|
|
|
def freeze_layers(net, apply_to=None, grad=False):
    """Toggle the ``requires_grad`` flag on a network's parameters.

    Args:
        net: module whose parameters are toggled; silently a no-op if None.
        apply_to: optional state-dict-like mapping; when given, only
            parameters whose name AND shape both appear in it are touched.
        grad: target flag value (False freezes, True unfreezes).
    """
    if net is None:
        return
    for name, param in net.named_parameters():
        # integer-typed tensors cannot carry gradients, skip them
        if not hasattr(param, 'requires_grad') or param.dtype == torch.int64:
            continue
        if apply_to is None:
            # no filter given: toggle every eligible parameter
            param.requires_grad_(grad)
        elif name in apply_to and param.size() == apply_to[name].size():
            # filtered mode: only touch entries matching by name and shape
            param.requires_grad_(grad)