import torch
import abc
import os
import copy

import pytorch_lightning as pl
from utils.lr_scheduler import *
from torch import distributed as dist


class AbstractModel(pl.LightningModule):
    def __init__(self,
                 lr_scheduler_kwargs: dict = None,
                 optimizer_kwargs: dict = None,
                 save_path: str = None,
                 from_checkpoint: str = None,
                 load_prev_scheduler: bool = False,
                 save_weights_only: bool = True):
        """



        Args:

            lr_scheduler: Kwargs for lr_scheduler

            optimizer_kwargs: Kwargs for optimizer_kwargs

            save_path: Save trained model

            from_checkpoint: Load model from checkpoint

            load_prev_scheduler: Whether load previous scheduler from checkpoint

            load_strict: Whether load model strictly

            save_weights_only: Whether save only weights or also optimizer and lr_scheduler

            

        """
        super().__init__()
        self.initialize_model()
        
        self.metrics = {}
        for stage in ["train", "valid", "test"]:
            stage_metrics = self.initialize_metrics(stage)
            # Register metrics as attributes
            for metric_name, metric in stage_metrics.items():
                setattr(self, metric_name, metric)
                
            self.metrics[stage] = stage_metrics
        
        if lr_scheduler_kwargs is None:
            # Default lr_scheduler
            self.lr_scheduler_kwargs = {
                "class": "ConstantLRScheduler",
                "init_lr": 0,
            }
            print("No lr_scheduler_kwargs provided. The default learning rate is 0.")

        else:
            self.lr_scheduler_kwargs = lr_scheduler_kwargs
        
        if optimizer_kwargs is None:
            # Default optimizer
            self.optimizer_kwargs = {
                "class": "AdamW",
                "betas": (0.9, 0.98),
                "weight_decay": 0.01,
            }
            print("No optimizer_kwargs provided. The default optimizer is AdamW.")
        else:
            self.optimizer_kwargs = optimizer_kwargs
        self.init_optimizers()

        self.save_path = save_path
        self.save_weights_only = save_weights_only
        
        # temp_step is used for accumulating gradients
        self.temp_step = 0
        self.step = 0
        self.epoch = 0
        
        self.load_prev_scheduler = load_prev_scheduler
        self.from_checkpoint = from_checkpoint
        if from_checkpoint:
            self.load_checkpoint(from_checkpoint)

    @abc.abstractmethod
    def initialize_model(self) -> None:
        """

        All model initialization should be done here

        Note that the whole model must be named as "self.model" for model saving and loading

        """
        raise NotImplementedError
    
    @abc.abstractmethod
    def forward(self, *args, **kwargs):
        """

        Forward propagation

        """
        raise NotImplementedError
    
    @abc.abstractmethod
    def initialize_metrics(self, stage: str) -> dict:
        """

        Initialize metrics for each stage

        Args:

            stage: "train", "valid" or "test"

        

        Returns:

            A dictionary of metrics for the stage. Keys are metric names and values are metric objects

        """
        raise NotImplementedError

    @abc.abstractmethod
    def loss_func(self, stage: str, outputs, labels) -> torch.Tensor:
        """



        Args:

            stage: "train", "valid" or "test"

            outputs: model outputs for calculating loss

            labels: labels for calculating loss



        Returns:

            loss



        """
        raise NotImplementedError

    @staticmethod
    def load_weights(model, weights):
        """Partially load ``weights`` into ``model``, reporting missing and unused parameters."""
        model_dict = model.state_dict()

        unused_params = []
        missed_params = list(model_dict.keys())

        for k, v in weights.items():
            if k in model_dict.keys():
                model_dict[k] = v
                missed_params.remove(k)

            else:
                unused_params.append(k)

        if len(missed_params) > 0:
            print(f"\033[31mSome weights of {type(model).__name__} were not "
                  f"initialized from the model checkpoint: {missed_params}\033[0m")

        if len(unused_params) > 0:
            print(f"\033[31mSome weights of the model checkpoint were not used: {unused_params}\033[0m")

        model.load_state_dict(model_dict)
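
    # Sketch of how ``load_weights`` could be called outside the Lightning
    # loop (``MyModel`` and ``ckpt.pt`` are hypothetical names):
    #
    #     model = MyModel()
    #     state = torch.load("ckpt.pt", map_location="cpu")
    #     AbstractModel.load_weights(model, state["model"])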

    # Signature for pytorch-lightning >= 2.0 (a 1.9.5-compatible variant is kept below)
    def optimizer_step(
        self,
        epoch: int,
        batch_idx: int,
        optimizer,
        optimizer_closure=None,
    ) -> None:
        super().optimizer_step(epoch, batch_idx, optimizer, optimizer_closure)

        self.temp_step += 1
        if self.temp_step == self.trainer.accumulate_grad_batches:
            self.step += 1
            self.temp_step = 0
    
    # For pytorch-lightning 1.9.5
    # def optimizer_step(
    #     self,
    #     epoch: int,
    #     batch_idx: int,
    #     optimizer,
    #     optimizer_idx: int = 0,
    #     optimizer_closure=None,
    #     on_tpu: bool = False,
    #     using_native_amp: bool = False,
    #     using_lbfgs: bool = False,
    # ) -> None:
    #     super().optimizer_step(
    #         epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs
    #     )
    #     self.temp_step += 1
    #     if self.temp_step == self.trainer.accumulate_grad_batches:
    #         self.step += 1
    #         self.temp_step = 0

    def on_train_epoch_end(self):
        self.epoch += 1
    
    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        
        outputs = self(**inputs)
        loss = self.loss_func('train', outputs, labels)
        
        self.log("loss", loss, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(**inputs)
        loss = self.loss_func('valid', outputs, labels)
        self.valid_outputs.append(loss)
        return loss

    def test_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(**inputs)
        
        loss = self.loss_func('test', outputs, labels)
        self.test_outputs.append(loss)
        return loss
    
    def on_train_start(self) -> None:
        # Load previous scheduler
        if getattr(self, "prev_schechuler", None) is not None:
            try:
                self.step = self.prev_schechuler["global_step"]
                self.epoch = self.prev_schechuler["epoch"]
                self.best_value = self.prev_schechuler["best_value"]
                self.lr_scheduler.load_state_dict(self.prev_schechuler["lr_scheduler"])
                print(f"Previous training global step: {self.step}")
                print(f"Previous training epoch: {self.epoch}")
                print(f"Previous best value: {self.best_value}")
                print(f"Previous lr_scheduler: {self.prev_schechuler['lr_scheduler']}")
                
                # Load optimizer state
                if hasattr(self.trainer.strategy, "deepspeed_engine"):
                    # For DeepSpeed strategy
                    try:
                        self.trainer.strategy.deepspeed_engine.load_checkpoint(self.from_checkpoint)
                    except Exception as e:
                        print(e)

                else:
                    # For DDP strategy
                    self.optimizer.load_state_dict(self.prev_schechuler["optimizer"])

            except Exception as e:
                print(e)
                raise Exception("Error in loading previous scheduler. Please set load_prev_scheduler=False")
    
    def on_validation_epoch_start(self) -> None:
        self.valid_outputs = []

    def on_test_epoch_start(self) -> None:
        self.test_outputs = []
            
    def load_checkpoint(self, from_checkpoint: str) -> None:
        """

        Args:

            from_checkpoint:  Path to checkpoint.

        """
        
        # If ``from_checkpoint`` is a directory, load the checkpoint in it
        if os.path.isdir(from_checkpoint):
            basename = os.path.basename(from_checkpoint)
            from_checkpoint = os.path.join(from_checkpoint, f"{basename}.pt")

        state_dict = torch.load(from_checkpoint, map_location=self.device)
        self.load_weights(self.model, state_dict["model"])
        
        if self.load_prev_scheduler:
            state_dict.pop("model")
            self.prev_scheduler = state_dict
        
    def save_checkpoint(self, save_path: str, save_info: dict = None, save_weights_only: bool = True) -> None:
        """

        Save model to save_path

        Args:

            save_path: Path to save model

            save_info: Other info to save

            save_weights_only: Whether only save model weights

        """
        save_dir = os.path.dirname(save_path)
        os.makedirs(save_dir, exist_ok=True)
        
        state_dict = {} if save_info is None else save_info
        state_dict["model"] = self.model.state_dict()
        
        # Convert model weights to fp32
        for k, v in state_dict["model"].items():
            state_dict["model"][k] = v.float()
            
        if not save_weights_only:
            state_dict["global_step"] = self.step
            state_dict["epoch"] = self.epoch
            state_dict["best_value"] = getattr(self, f"best_value", None)
            state_dict["lr_scheduler"] = self.lr_schedulers().state_dict()
            
            # If not using DeepSpeed, save optimizer state
            if not hasattr(self.trainer.strategy, "deepspeed_engine"):
                state_dict["optimizer"] = self.optimizers().optimizer.state_dict()

        torch.save(state_dict, save_path)
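
    # Sketch: saving a full training-state checkpoint (the path and info here
    # are hypothetical):
    #
    #     self.save_checkpoint("ckpt/model.pt", save_info={"note": "manual"},
    #                          save_weights_only=False)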

    def check_save_condition(self, now_value: float, mode: str, save_info: dict = None) -> None:
        """

        Check whether to save model. If save_path is not None and now_value is the best, save model.

        Args:

            now_value: Current metric value

            mode: "min" or "max", meaning whether the lower the better or the higher the better

            save_info: Other info to save

        """

        assert mode in ["min", "max"], "mode should be 'min' or 'max'"

        if self.save_path is not None:
            # Interpolate any f-string style fields (e.g. {self.epoch}) embedded in the save path
            save_path = eval(f"f'{self.save_path}'")
            
            save_dir = os.path.dirname(save_path)
            os.makedirs(save_dir, exist_ok=True)
            
            # Check whether to save model
            best_value = getattr(self, "best_value", None)
            if best_value is not None:
                if (mode == "min" and now_value >= best_value) or (mode == "max" and now_value <= best_value):
                    return

            self.best_value = now_value
                
            # For DeepSpeed strategy
            if hasattr(self.trainer.strategy, "deepspeed_engine"):
                if not self.save_weights_only:
                    self.trainer.strategy.deepspeed_engine.save_checkpoint(save_path, tag="deepspeed_ckpt")
                
                # Save a complete checkpoint
                if dist.get_rank() == 0:
                    basename = os.path.basename(save_path)
                    ckpt_path = os.path.join(save_path, f"{basename}.pt")
                    self.save_checkpoint(ckpt_path, save_info, self.save_weights_only)
            
            # For normal situation
            else:
                if dist.get_rank() == 0:
                    self.save_checkpoint(save_path, save_info, self.save_weights_only)
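
    # Sketch (an assumption about intended use, not code from this repo): a
    # subclass would typically call ``check_save_condition`` from a validation
    # hook, e.g.
    #
    #     def on_validation_epoch_end(self):
    #         log_dict = self.get_log_dict("valid")
    #         self.check_save_condition(log_dict["valid_mse"], mode="min")
    #         self.reset_metrics("valid")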
            
    def reset_metrics(self, stage) -> None:
        """

        Reset metrics for given stage

        Args:

            stage: "train", "valid" or "test"

        """
        for metric in self.metrics[stage].values():
            metric.reset()
    
    def get_log_dict(self, stage: str) -> dict:
        """

        Get log dict for the stage

        Args:

            stage: "train", "valid" or "test"



        Returns:

            A dictionary of metrics for the stage. Keys are metric names and values are metric values



        """
        return {name: metric.compute() for name, metric in self.metrics[stage].items()}
    
    def log_info(self, info: dict) -> None:
        """

        Record metrics during training and testing

        Args:

            info: dict of metrics

        """
        if getattr(self, "logger", None) is not None and dist.get_rank() == 0:
            info["learning_rate"] = self.lr_scheduler.get_last_lr()[0]
            info["epoch"] = self.epoch
            self.logger.log_metrics(info, step=self.step)

    def init_optimizers(self):
        copy_optimizer_kwargs = copy.deepcopy(self.optimizer_kwargs)
        
        # No decay for layer norm and bias
        no_decay = ['LayerNorm.weight', 'bias']
        weight_decay = copy_optimizer_kwargs.pop("weight_decay")

        optimizer_grouped_parameters = [
            {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': weight_decay},
            {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]

        optimizer_cls = getattr(torch.optim, copy_optimizer_kwargs.pop("class"))
        self.optimizer = optimizer_cls(optimizer_grouped_parameters,
                                       lr=self.lr_scheduler_kwargs['init_lr'],
                                       **copy_optimizer_kwargs)

        tmp_kwargs = copy.deepcopy(self.lr_scheduler_kwargs)
        # Scheduler classes are resolved by name from utils.lr_scheduler (star-imported above)
        lr_scheduler_cls = eval(tmp_kwargs.pop("class"))
        self.lr_scheduler = lr_scheduler_cls(self.optimizer, **tmp_kwargs)
    
    def configure_optimizers(self):
        return {"optimizer": self.optimizer,
                "lr_scheduler": {"scheduler": self.lr_scheduler,
                                 "interval": "step",
                                 "frequency": 1}
                }
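

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; ``ToyRegressor`` and its
# architecture are hypothetical, and ``torchmetrics`` is assumed to be
# installed alongside pytorch-lightning). It shows how a subclass wires up
# the four abstract hooks.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torchmetrics

    class ToyRegressor(AbstractModel):
        def initialize_model(self) -> None:
            # The whole network must be assigned to ``self.model``
            # (required by save_checkpoint / load_checkpoint).
            self.model = torch.nn.Sequential(
                torch.nn.Linear(16, 64),
                torch.nn.ReLU(),
                torch.nn.Linear(64, 1),
            )

        def forward(self, x):
            return self.model(x)

        def initialize_metrics(self, stage: str) -> dict:
            return {f"{stage}_mse": torchmetrics.MeanSquaredError()}

        def loss_func(self, stage: str, outputs, labels) -> torch.Tensor:
            self.metrics[stage][f"{stage}_mse"].update(outputs, labels)
            return torch.nn.functional.mse_loss(outputs, labels)

    model = ToyRegressor(
        lr_scheduler_kwargs={"class": "ConstantLRScheduler", "init_lr": 1e-4},
        optimizer_kwargs={"class": "AdamW", "betas": (0.9, 0.98), "weight_decay": 0.01},
    )
    print(model(torch.randn(2, 16)))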