"""Pytorch-lightning module for causal language modeling.
"""
__all__ = ("GPT2LitModel",)
import pytorch_lightning as pl
import torch


class GPT2LitModel(pl.LightningModule):
    """Lightning module for autoregressive (causal) transformer language modeling.

    Successfully tested on HuggingFace `GPT2LMHeadModel`.
    """

    def __init__(self, transformer, batch_size: int, learning_rate: float,
                 final_learning_rate: float, weight_decay: float, adam_eps: float,
                 adam_betas: tuple, scheduler_T_max: int,
                 save_model_every: int = 10_000, checkpoint: str = ""):
        super().__init__()
        # Exclude the model object and bookkeeping arguments from the saved hparams.
        self.save_hyperparameters(ignore=("transformer", "save_model_every",
                                          "checkpoint"))
        self.transformer = transformer
        self.save_model_every = save_model_every
        self.checkpoint = checkpoint or "./gpt2litmodel-logs"

    def forward(self, *args, **kwargs):
        return self.transformer(*args, **kwargs)

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        # Periodically save the transformer in HuggingFace format.
        if self.save_model_every > 0 and batch_idx % self.save_model_every == 0:
            self.transformer.save_pretrained(self.checkpoint)
        return {"loss": outputs["loss"]}

    def training_epoch_end(self, outputs):
        if self.save_model_every > 0:
            self.transformer.save_pretrained(self.checkpoint)
        # Perplexity is the exponential of the mean cross-entropy loss over the epoch.
        losses = [step_output["loss"] for step_output in outputs]
        mean_loss = torch.stack(losses).mean()
        ppl = torch.exp(mean_loss)
        self.log("ppl", ppl, on_step=False, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        # Materialize the named-parameter generator so both groups can iterate over it.
        parameters = list(self.named_parameters())
        # Biases and LayerNorm weights are excluded from weight decay.
        no_decay = ["bias", "LayerNorm.weight"]
        grouped_parameters = [
            {"params": [p for n, p in parameters
                        if not any(nd in n for nd in no_decay)],
             "weight_decay": self.hparams.weight_decay},
            {"params": [p for n, p in parameters
                        if any(nd in n for nd in no_decay)],
             "weight_decay": 0.0},
        ]
        optimizer = torch.optim.Adam(
            grouped_parameters, lr=self.hparams.learning_rate,
            eps=self.hparams.adam_eps, betas=self.hparams.adam_betas)
        # Cosine-anneal the learning rate every optimizer step down to
        # `final_learning_rate` over `scheduler_T_max` steps.
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, self.hparams.scheduler_T_max,
            eta_min=self.hparams.final_learning_rate)
        return {"optimizer": optimizer,
                "lr_scheduler": {"scheduler": lr_scheduler,
                                 "interval": "step", "frequency": 1}}
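

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It assumes a
# HuggingFace `transformers` installation and a pytorch-lightning 1.x Trainer
# (the `training_epoch_end` hook above targets that API). The tiny GPT-2
# config, the random token batches, and every hyperparameter value below are
# placeholders chosen only to show how `GPT2LitModel` plugs into `pl.Trainer`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader, Dataset
    from transformers import GPT2Config, GPT2LMHeadModel

    class RandomTokenDataset(Dataset):
        """Random token sequences; `labels` make GPT2LMHeadModel return a loss."""

        def __init__(self, num_samples: int = 64, seq_len: int = 32,
                     vocab_size: int = 256):
            self.data = torch.randint(0, vocab_size, (num_samples, seq_len))

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            ids = self.data[idx]
            return {"input_ids": ids, "labels": ids}

    config = GPT2Config(vocab_size=256, n_positions=32, n_embd=64,
                        n_layer=2, n_head=2)
    lit_model = GPT2LitModel(
        GPT2LMHeadModel(config),
        batch_size=8, learning_rate=5e-4, final_learning_rate=5e-6,
        weight_decay=0.01, adam_eps=1e-8, adam_betas=(0.9, 0.999),
        scheduler_T_max=100, save_model_every=0)

    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(lit_model, DataLoader(RandomTokenDataset(), batch_size=8))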