"""Pytorch-lightning module for causal language modeling.
"""

__all__ = ("GPT2LitModel",)

import pytorch_lightning as pl
import torch


class GPT2LitModel(pl.LightningModule):
    """Lightning module for autoregressive (causal) transformer language modeling.
    Successfully tested on HuggingFace `GPT2LMHeadModel`.
    """

    def __init__(self, transformer, batch_size: int, learning_rate: float,
                 final_learning_rate: float, weight_decay: float, adam_eps: float,
                 adam_betas: tuple, scheduler_T_max: int,
                 save_model_every: int = 10_000, checkpoint: str = ""):
        super().__init__()
        # Record hyperparameters, excluding the model itself and the
        # checkpointing options, which are stored as plain attributes below.
        self.save_hyperparameters(ignore=("transformer", "save_model_every",
                                          "checkpoint"))
        self.transformer = transformer
        self.save_model_every = save_model_every
        self.checkpoint = checkpoint or "./gpt2litmodel-logs"

    def forward(self, *args, **kwargs):
        # Delegate directly to the wrapped HF transformer.
        return self.transformer(*args, **kwargs)

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)

        # Periodically checkpoint the wrapped HF model in the middle of an epoch.
        if self.save_model_every > 0 and batch_idx % self.save_model_every == 0:
            self.transformer.save_pretrained(self.checkpoint)

        return {'loss': outputs['loss']}

    def training_epoch_end(self, outputs):
        # Save the wrapped HF model again at the end of every epoch.
        if self.save_model_every > 0:
            self.transformer.save_pretrained(self.checkpoint)

        losses = [step_output["loss"] for step_output in outputs]
        mean_loss = torch.tensor(losses).mean()
        ppl = torch.exp(mean_loss)

        self.log("ppl", ppl, on_step=False, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        # Materialize the named parameters so they can be iterated over twice.
        parameters = list(self.named_parameters())
        # Biases and LayerNorm weights are excluded from weight decay.
        no_decay = ["bias", "LayerNorm.weight"]
        grouped_parameters = [
            {"params": [p for n, p in parameters
                        if not any(nd in n for nd in no_decay)],
             "weight_decay": self.hparams.weight_decay},
            {"params": [p for n, p in parameters
                        if any(nd in n for nd in no_decay)],
             "weight_decay": 0.0}]
        optimizer = torch.optim.Adam(
            grouped_parameters, lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
            eps=self.hparams.adam_eps, betas=self.hparams.adam_betas)

        # Cosine-anneal the learning rate per optimizer step down to
        # `final_learning_rate` (see the `interval: step` setting below).
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, self.hparams.scheduler_T_max,
            eta_min=self.hparams.final_learning_rate)

        return {'optimizer': optimizer,
                'lr_scheduler': {'scheduler': lr_scheduler,
                                 'interval': 'step', 'frequency': 1}}
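
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module). It assumes a
# HuggingFace `GPT2LMHeadModel`, a `train_loader` yielding dicts with
# `input_ids`, `attention_mask`, and `labels`, and placeholder hyperparameter
# values chosen for the example:
#
#     from transformers import GPT2Config, GPT2LMHeadModel
#
#     transformer = GPT2LMHeadModel(GPT2Config())
#     lit_model = GPT2LitModel(
#         transformer,
#         batch_size=8,
#         learning_rate=5e-4,
#         final_learning_rate=5e-8,
#         weight_decay=0.01,
#         adam_eps=1e-8,
#         adam_betas=(0.9, 0.999),
#         scheduler_T_max=100_000,
#     )
#     trainer = pl.Trainer(max_epochs=3)
#     trainer.fit(lit_model, train_loader)
# ---------------------------------------------------------------------------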