import os
import torch

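# change into the 'model' subdirectory if it exists, presumably so that the
# 'src.backbones' package imported below can be found from the repository root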
sub_dir = os.path.join(os.getcwd(), 'model')
if os.path.isdir(sub_dir): os.chdir(sub_dir)
from src.backbones import base_model, utae, uncrtaints

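# per-sensor channel counts: Sentinel-1 SAR contributes 2 bands, Sentinel-2 optical contributes 13 bands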
S1_BANDS = 2
S2_BANDS = 13

def get_base_model(config):
    model = base_model.BaseModel(config)
    return model

# for running image reconstruction
def get_generator(config):
    if "unet" in config.model:
        model = utae.UNet(
            input_dim=S1_BANDS*config.use_sar+S2_BANDS,
            encoder_widths=config.encoder_widths,
            decoder_widths=config.decoder_widths,
            out_conv=config.out_conv,
            out_nonlin_mean=config.mean_nonLinearity,
            out_nonlin_var=config.var_nonLinearity,
            str_conv_k=4,
            str_conv_s=2,
            str_conv_p=1,
            encoder_norm=config.encoder_norm,
            norm_skip='batch',
            norm_up='batch',
            decoder_norm=config.decoder_norm,
            encoder=False,
            return_maps=False,
            pad_value=config.pad_value,
            padding_mode=config.padding_mode,
        )
    elif "utae" in config.model:
        if config.pretrain:
            # on monotemporal data, just use a simple U-Net
            model = utae.UNet(
                input_dim=S1_BANDS*config.use_sar+S2_BANDS, 
                encoder_widths=config.encoder_widths,
                decoder_widths=config.decoder_widths,
                out_conv=config.out_conv,
                out_nonlin_mean=config.mean_nonLinearity,
                out_nonlin_var=config.var_nonLinearity,
                str_conv_k=4,
                str_conv_s=2,
                str_conv_p=1,
                encoder_norm=config.encoder_norm,
                norm_skip='batch',
                norm_up='batch',
                decoder_norm=config.decoder_norm,
                encoder=False,
                return_maps=False,
                pad_value=config.pad_value,
                padding_mode=config.padding_mode,
            )
        else:
            model = utae.UTAE(
                input_dim=S1_BANDS*config.use_sar+S2_BANDS,
                encoder_widths=config.encoder_widths,
                decoder_widths=config.decoder_widths,
                out_conv=config.out_conv,
                out_nonlin_mean=config.mean_nonLinearity,
                out_nonlin_var=config.var_nonLinearity,
                str_conv_k=4,
                str_conv_s=2,
                str_conv_p=1,
                agg_mode=config.agg_mode,
                encoder_norm=config.encoder_norm,
                norm_skip='batch',
                norm_up='batch',
                decoder_norm=config.decoder_norm,
                n_head=config.n_head,
                d_model=config.d_model,
                d_k=config.d_k,
                encoder=False,
                return_maps=False,
                pad_value=config.pad_value,
                padding_mode=config.padding_mode,
                positional_encoding=config.positional_encoding,
                scale_by=config.scale_by
            )
    elif 'uncrtaints' == config.model:
        model = uncrtaints.UNCRTAINTS(
                input_dim=S1_BANDS*config.use_sar+S2_BANDS,
                encoder_widths=config.encoder_widths,
                decoder_widths=config.decoder_widths, 
                out_conv=config.out_conv,
                out_nonlin_mean=config.mean_nonLinearity,
                out_nonlin_var=config.var_nonLinearity,
                agg_mode=config.agg_mode,
                encoder_norm=config.encoder_norm,
                decoder_norm=config.decoder_norm,
                n_head=config.n_head,
                d_model=config.d_model,
                d_k=config.d_k,
                pad_value=config.pad_value,
                padding_mode=config.padding_mode,
                positional_encoding=config.positional_encoding,
                covmode=config.covmode,
                scale_by=config.scale_by,
                separate_out=config.separate_out,
                use_v=config.use_v,
                block_type=config.block_type,
                is_mono=config.pretrain
            )
    else: raise NotImplementedError
    return model

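# Illustrative usage sketch (not part of the original pipeline): get_generator() only
# reads attributes off the config object, so a SimpleNamespace with the expected fields
# is enough for a quick smoke test. All attribute values below are assumptions, not the
# project's defaults.
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(
#       model='unet', use_sar=True,
#       encoder_widths=[64, 64, 64, 128], decoder_widths=[64, 64, 64, 128],
#       out_conv=[S2_BANDS], mean_nonLinearity=True, var_nonLinearity='softplus',
#       encoder_norm='group', decoder_norm='batch',
#       pad_value=0, padding_mode='reflect',
#   )
#   net = get_generator(cfg)  # builds a monotemporal U-Net with 2+13 input channels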

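# build the complete model; per save_model()/load_model() below, the BaseModel holds the generator as model.netG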
def get_model(config):
    return get_base_model(config)


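# checkpoints are written to <res_dir>/<experiment_name>/<name>.pth.tar and bundle the
# full model state, the generator netG, and its optimizer and scheduler states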
def save_model(config, epoch, model, name):
    state_dict = {"epoch":          epoch,
                  "state_dict":     model.state_dict(),
                  "state_dict_G":   model.netG.state_dict(),
                  "optimizer_G":    model.optimizer_G.state_dict(),
                  "scheduler_G":    model.scheduler_G.state_dict()}
    torch.save(state_dict,
        os.path.join(config.res_dir, config.experiment_name, f"{name}.pth.tar"),
    )


def load_model(config, model, train_out_layer=True, load_out_partly=True):
    # load pre-trained checkpoints, but only the weights that match the current architecture
    
    pretrained_dict = torch.load(config.trained_checkp, map_location=config.device)["state_dict_G"]
    model_dict      = model.netG.state_dict() 

    not_str = "" if pretrained_dict.keys() == model_dict.keys() else "not "
    print(f'The new and the (pre-)trained model architectures are {not_str}identical.\n')

    try:  # try loading the checkpoint strictly, i.e. all weights must match
        # (this is satisfied e.g. when resuming training)

        if train_out_layer: raise NotImplementedError  # deliberately jump to the 'except' branch below
        model.netG.load_state_dict(pretrained_dict, strict=True)
        freeze_layers(model.netG, grad=True)    # set all weights to trainable, no need to freeze
        model.frozen, freeze_these = False, []  # ... as all weights match appropriately
    except Exception:
        # if some weights don't match (e.g. when loading from a pre-trained U-Net), then only load the compatible subset,
        # freeze the compatible weights and make the incompatible weights trainable

        # load the output layer only partly, e.g. when the pretrained net has 3 output channels but the new model has 13
        if load_out_partly:
            # overwrite output layer even when dimensions mismatch (this overwrites kernels individually)
            #""" # these lines were used for predicting the 13 mean bands when mean and var shared a single output layer
            temp_weights, temp_biases       = model_dict['out_conv.conv.conv.0.weight'], model_dict['out_conv.conv.conv.0.bias']
            temp_weights[:S2_BANDS,...]     = pretrained_dict['out_conv.conv.conv.0.weight'][:S2_BANDS,...]
            temp_biases[:S2_BANDS,...]      = pretrained_dict['out_conv.conv.conv.0.bias'][:S2_BANDS,...]
            pretrained_dict['out_conv.conv.conv.0.weight'] = temp_weights[:S2_BANDS,...]
            pretrained_dict['out_conv.conv.conv.0.bias']   = temp_biases[:S2_BANDS,...]
            """
            if 'out_conv.conv.conv.0.weight' in pretrained_dict: # if predicting from a model with a single output layer for both mean and var
                pretrained_dict['out_conv_mean.conv.conv.0.weight'] = pretrained_dict['out_conv.conv.conv.0.weight'][:S2_BANDS,...]
                pretrained_dict['out_conv_mean.conv.conv.0.bias']   = pretrained_dict['out_conv.conv.conv.0.bias'][:S2_BANDS,...]
            if 'out_conv_var.conv.conv.0.weight' in model_dict:
                pretrained_dict['out_conv_var.conv.conv.0.weight'] = model_dict['out_conv_var.conv.conv.0.weight']
                pretrained_dict['out_conv_var.conv.conv.0.bias']   = model_dict['out_conv_var.conv.conv.0.bias']
            """

        # check for size mismatch and exclude layers whose dimensions mismatch (they won't be loaded)
        pretrained_dict = {k:v for k,v in pretrained_dict.items() if k in model_dict and v.size() == model_dict[k].size()}
        model_dict.update(pretrained_dict) 
        model.netG.load_state_dict(model_dict, strict=False)
        
        # freeze pretrained weights 
        model.frozen = True
        freeze_layers(model.netG, grad=True)  # first set all weights to trainable; pretrained layers are frozen again below
        if train_out_layer:
            # freeze all but last layer
            all_but_last = {k:v for k, v in pretrained_dict.items() if 'out_conv.conv.conv.0' not in k}
            freeze_layers(model.netG, apply_to=all_but_last, grad=False)
            freeze_these = list(all_but_last.keys())
        else: # freeze all pre-trained layers, without exceptions
            freeze_layers(model.netG, apply_to=pretrained_dict, grad=False)
            freeze_these = list(pretrained_dict.keys())
    train_these = [train_layer for train_layer in list(model_dict.keys()) if train_layer not in freeze_these]
    print(f'\nFroze these layers: {freeze_these}')
    print(f'\nTrain these layers: {train_these}')

    if config.resume_from:
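        # the epoch index is recovered from the checkpoint filename, which is expected to end in '_<epoch>.pth.tar'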
        resume_at = int(config.trained_checkp.split('.pth.tar')[0].split('_')[-1])
        print(f'\nResuming training at epoch {resume_at+1}/{config.epochs}, loading optimizers and schedulers')
        # if continuing training, then also load states of previous runs' optimizers and schedulers
        # ---else, we start optimizing from scratch but with the model parameters loaded above
        optimizer_G_dict = torch.load(config.trained_checkp, map_location=config.device)["optimizer_G"]
        model.optimizer_G.load_state_dict(optimizer_G_dict)

        scheduler_G_dict = torch.load(config.trained_checkp, map_location=config.device)["scheduler_G"]
        model.scheduler_G.load_state_dict(scheduler_G_dict)

    # no return value, models are passed by reference


# function to load checkpoints of individual and ensemble models
# (this is used for training and testing scripts)
def load_checkpoint(config, checkp_dir, model, name):
    print(checkp_dir)
    chckp_path = os.path.join(checkp_dir, config.experiment_name, f"{name}.pth.tar")
    print(f'Loading checkpoint {chckp_path}')
    checkpoint = torch.load(chckp_path, map_location=config.device)["state_dict"]

    try:  # try loading the checkpoint strictly, i.e. all weights & their names must match
        model.load_state_dict(checkpoint, strict=True)
    except Exception:
        # rename keys: shift the trailing block index down by one and join it with a dot,
        #   e.g. in_block1 -> in_block.0, out_block1 -> out_block.0
        checkpoint_renamed = dict()
        for key, val in checkpoint.items():
            if 'in_block' in key or 'out_block' in key:
                strs    = key.split('.')
                strs[1] = strs[1][:-1] + str(int(strs[1][-1])-1)
                strs[1] = '.'.join([strs[1][:-1], strs[1][-1]])
                key     = '.'.join(strs)
            checkpoint_renamed[key] = val
        model.load_state_dict(checkpoint_renamed, strict=False)

def freeze_layers(net, apply_to=None, grad=False):
    if net is not None:
        for k, v in net.named_parameters():
            # only touch parameters that can carry gradients (skip integer-typed ones)
            if hasattr(v, 'requires_grad') and v.dtype != torch.int64:
                if apply_to is not None:
                    # flip requires_grad only for the listed parameters whose sizes match
                    if k in apply_to.keys() and v.size() == apply_to[k].size(): 
                        v.requires_grad_(grad)
                else: # otherwise apply indiscriminately to all layers
                    v.requires_grad_(grad)
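

if __name__ == '__main__':
    # Minimal, self-contained sketch (illustrative only): demonstrate freeze_layers()
    # on a small stand-in network; it does not depend on the project's config object.
    demo = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.Conv2d(8, 3, 3))
    freeze_layers(demo, grad=False)                # freeze every parameter ...
    last = {k: v for k, v in demo.named_parameters() if k.startswith('1.')}
    freeze_layers(demo, apply_to=last, grad=True)  # ... then unfreeze only the last conv
    print({k: v.requires_grad for k, v in demo.named_parameters()})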