|
"""Transformer language model.""" |
|
|
|
from typing import Any |
|
from typing import List |
|
from typing import Tuple |
|
|
|
import logging |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from espnet.nets.lm_interface import LMInterface |
|
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding |
|
from espnet.nets.pytorch_backend.transformer.encoder import Encoder |
|
from espnet.nets.pytorch_backend.transformer.mask import subsequent_mask |
|
from espnet.nets.scorer_interface import BatchScorerInterface |
|
from espnet.utils.cli_utils import strtobool |
|
|
|
|
|
class TransformerLM(nn.Module, LMInterface, BatchScorerInterface):
    """Transformer language model.

    Token ids are embedded, run through a self-attention ``Encoder`` with a
    mask built by :meth:`_target_mask`, and projected onto the vocabulary by a
    linear layer. Token id ``0`` is treated as padding: such positions are
    masked out of the attention mask and of the loss.
    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments to command line argument parser.

        Args:
            parser: Parser object exposing ``add_argument``
                (e.g. ``argparse.ArgumentParser``).

        Returns:
            The same ``parser`` object, with the model options registered.

        """
        parser.add_argument(
            "--layer", type=int, default=4, help="Number of hidden layers"
        )
        parser.add_argument(
            "--unit",
            type=int,
            default=1024,
            help="Number of hidden units in feedforward layer",
        )
        parser.add_argument(
            "--att-unit",
            type=int,
            default=256,
            help="Number of hidden units in attention layer",
        )
        parser.add_argument(
            "--embed-unit",
            type=int,
            default=128,
            help="Number of hidden units in embedding layer",
        )
        parser.add_argument(
            "--head", type=int, default=2, help="Number of multi head attention"
        )
        parser.add_argument(
            "--dropout-rate", type=float, default=0.5, help="dropout probability"
        )
        parser.add_argument(
            "--att-dropout-rate",
            type=float,
            default=0.0,
            help="att dropout probability",
        )
        parser.add_argument(
            "--emb-dropout-rate",
            type=float,
            default=0.0,
            help="emb dropout probability",
        )
        parser.add_argument(
            "--tie-weights",
            type=strtobool,
            default=False,
            help="Tie input and output embeddings",
        )
        parser.add_argument(
            "--pos-enc",
            default="sinusoidal",
            choices=["sinusoidal", "none"],
            help="positional encoding",
        )
        return parser

    def __init__(self, n_vocab, args):
        """Initialize class.

        Args:
            n_vocab (int): The size of the vocabulary
            args (argparse.Namespace): configurations. see py:method:`add_arguments`

        Raises:
            ValueError: If ``args.pos_enc`` is neither "sinusoidal" nor "none".
            AssertionError: If weight tying is requested while
                ``args.att_unit != args.embed_unit``.

        """
        nn.Module.__init__(self)

        # These three options may be absent from ``args``; fall back to the
        # same defaults that ``add_arguments`` declares.
        emb_dropout_rate = getattr(args, "emb_dropout_rate", 0.0)
        tie_weights = getattr(args, "tie_weights", False)
        att_dropout_rate = getattr(args, "att_dropout_rate", 0.0)

        if args.pos_enc == "sinusoidal":
            pos_enc_class = PositionalEncoding
        elif args.pos_enc == "none":

            # Factory with the same call shape as a positional-encoding class;
            # it ignores its arguments and returns an empty ``nn.Sequential``,
            # which passes its input through unchanged.
            def pos_enc_class(*_args, **_kwargs):
                return nn.Sequential()

        else:
            raise ValueError(f"unknown pos-enc option: {args.pos_enc}")

        self.embed = nn.Embedding(n_vocab, args.embed_unit)

        # ``None`` marks "no embedding dropout"; see ``_embed``.
        if emb_dropout_rate == 0.0:
            self.embed_drop = None
        else:
            self.embed_drop = nn.Dropout(emb_dropout_rate)

        self.encoder = Encoder(
            idim=args.embed_unit,
            attention_dim=args.att_unit,
            attention_heads=args.head,
            linear_units=args.unit,
            num_blocks=args.layer,
            dropout_rate=args.dropout_rate,
            attention_dropout_rate=att_dropout_rate,
            input_layer="linear",
            pos_enc_class=pos_enc_class,
        )
        self.decoder = nn.Linear(args.att_unit, n_vocab)

        logging.info("Tie weights set to {}".format(tie_weights))
        logging.info("Dropout set to {}".format(args.dropout_rate))
        logging.info("Emb Dropout set to {}".format(emb_dropout_rate))
        logging.info("Att Dropout set to {}".format(att_dropout_rate))

        if tie_weights:
            # ``decoder.weight`` is (n_vocab, att_unit) and ``embed.weight`` is
            # (n_vocab, embed_unit), so sharing them needs equal widths.
            assert (
                args.att_unit == args.embed_unit
            ), "Tie Weights: True need embedding and final dimensions to match"
            self.decoder.weight = self.embed.weight

    def _embed(self, ids):
        """Embed token ids and apply embedding dropout when it is configured.

        Args:
            ids (torch.Tensor): Integer token ids of any shape accepted by
                ``nn.Embedding``.

        Returns:
            torch.Tensor: Embeddings with a trailing dimension of
                ``embed_unit``; passed through ``self.embed_drop`` when that
                module is not ``None``.

        """
        emb = self.embed(ids)
        if self.embed_drop is not None:
            emb = self.embed_drop(emb)
        return emb

    def _target_mask(self, ys_in_pad):
        """Build the attention mask for a batch of (zero-padded) token ids.

        Args:
            ys_in_pad (torch.Tensor): Token ids; entries equal to 0 are
                treated as padding.

        Returns:
            torch.Tensor: Boolean mask, the element-wise AND of the non-padding
                mask (broadcast over a new second-to-last axis) and the mask
                returned by ``subsequent_mask`` for the sequence length.

        """
        ys_mask = ys_in_pad != 0
        # NOTE(review): assumes subsequent_mask returns a (len, len) mask that
        # hides future positions -- confirm against its definition.
        m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
        return ys_mask.unsqueeze(-2) & m

    def forward(
        self, x: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute LM loss value from buffer sequences.

        Args:
            x (torch.Tensor): Input ids. (batch, len)
            t (torch.Tensor): Target ids. (batch, len)

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
                loss to backward (scalar),
                negative log-likelihood of t: -log p(t) (scalar) and
                the number of non-padding (non-zero) elements in x (scalar)

        Notes:
            The last two return values are used
            in perplexity: p(t)^{-1/n} = exp(-log p(t) / n)

            If every element of ``x`` is 0 the count is zero and the first
            return value is NaN (0 / 0).

        """
        xm = x != 0

        h, _ = self.encoder(self._embed(x), self._target_mask(x))
        y = self.decoder(h)

        # Per-token loss, flattened over (batch * len); padding positions are
        # zeroed out below so they contribute neither to the sum nor the count.
        loss = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
        mask = xm.to(dtype=loss.dtype)
        logp = loss * mask.view(-1)
        logp = logp.sum()
        count = mask.sum()
        return logp / count, logp, count

    def score(
        self, y: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): encoder feature that generates ys.
                Not used by this scorer.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (n_vocab)
                and next state for ys

        """
        # Add a leading batch axis of size 1 for the encoder.
        y = y.unsqueeze(0)

        h, _, cache = self.encoder.forward_one_step(
            self._embed(y), self._target_mask(y), cache=state
        )
        # Only the last position predicts the next token.
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1).squeeze(0)
        return logp, cache

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch (required).

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).
                Not used by this scorer.

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.

        """
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)

        if states[0] is None:
            batch_state = None
        else:
            # Transpose [batch][layer] -> [layer], stacking each layer's
            # per-hypothesis caches along a new leading batch axis.
            batch_state = [
                torch.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]

        h, _, states = self.encoder.forward_one_step(
            self._embed(ys), self._target_mask(ys), cache=batch_state
        )
        # Only the last position predicts the next token.
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1)

        # Transpose the updated caches back: [layer][batch] -> [batch][layer].
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list
|
|