"""Sequential implementation of Recurrent Neural Network Language Model."""

import torch
import torch.nn as nn
import torch.nn.functional as F

from espnet.nets.lm_interface import LMInterface


class SequentialRNNLM(LMInterface, torch.nn.Module):
    """Sequential RNNLM.

    See also:
        https://github.com/pytorch/examples/blob/4581968193699de14b56527296262dd76ab43557/word_language_model/model.py

    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments to command line argument parser."""
        parser.add_argument(
            "--type",
            type=str,
            default="lstm",
            nargs="?",
            choices=["lstm", "gru"],
            help="Which type of RNN to use",
        )
        parser.add_argument(
            "--layer", "-l", type=int, default=2, help="Number of hidden layers"
        )
        parser.add_argument(
            "--unit", "-u", type=int, default=650, help="Number of hidden units"
        )
        parser.add_argument(
            "--dropout-rate", type=float, default=0.5, help="Dropout probability"
        )
        return parser
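
    # A minimal usage sketch (illustrative, not part of the original API):
    # build ``args`` with this parser and construct the LM. The vocabulary
    # size of 1000 is an arbitrary example value.
    #
    #     import argparse
    #     parser = argparse.ArgumentParser()
    #     parser = SequentialRNNLM.add_arguments(parser)
    #     args = parser.parse_args([])  # defaults: lstm, 2 layers, 650 units
    #     lm = SequentialRNNLM(1000, args)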

    def __init__(self, n_vocab, args):
        """Initialize class.

        Args:
            n_vocab (int): The size of the vocabulary.
            args (argparse.Namespace): Configurations. See :py:meth:`add_arguments`.

        """
        torch.nn.Module.__init__(self)
        self._setup(
            rnn_type=args.type.upper(),
            ntoken=n_vocab,
            ninp=args.unit,
            nhid=args.unit,
            nlayers=args.layer,
            dropout=args.dropout_rate,
        )

    def _setup(
        self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False
    ):
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        if rnn_type in ["LSTM", "GRU"]:
            self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
        else:
            try:
                nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type]
            except KeyError:
                raise ValueError(
                    "An invalid option for `--type` was supplied, "
                    "options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']"
                )
            self.rnn = nn.RNN(
                ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout
            )
        self.decoder = nn.Linear(nhid, ntoken)

        # Optionally tie the input embedding and output projection weights, as
        # in "Using the Output Embedding to Improve Language Models"
        # (Press & Wolf 2016) and "Tying Word Vectors and Word Classifiers"
        # (Inan et al. 2016). Tying requires nhid == ninp.
        if tie_weights:
            if nhid != ninp:
                raise ValueError(
                    "When using the tied flag, nhid must be equal to emsize"
                )
            self.decoder.weight = self.encoder.weight

        self._init_weights()

        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers
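
    # A shape sketch under the ``add_arguments`` defaults: ``self.encoder`` is
    # ``nn.Embedding(ntoken, 650)``, ``self.rnn`` is
    # ``nn.LSTM(650, 650, num_layers=2, dropout=0.5)``, and ``self.decoder``
    # maps each 650-dim hidden state to ``ntoken`` logits.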

    def _init_weights(self):
        """Initialize all parameters with a uniform distribution in [-0.1, 0.1]."""
        for param in self.parameters():
            param.data.uniform_(-0.1, 0.1)

    def forward(self, x, t):
        """Compute LM loss value from buffer sequences.

        Args:
            x (torch.Tensor): Input ids. (batch, len)
            t (torch.Tensor): Target ids. (batch, len)

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
                loss to backward (scalar),
                negative log-likelihood of t: -log p(t) (scalar) and
                the number of elements in x (scalar)

        Notes:
            The last two return values are used
            in perplexity: p(t)^{-1/n} = exp(-log p(t) / n)

        """
        y = self._before_loss(x, None)[0]
        # Mask out padded positions; token id 0 is assumed to be the pad symbol.
        mask = (x != 0).to(y.dtype)
        loss = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
        logp = loss * mask.view(-1)
        logp = logp.sum()
        count = mask.sum()
        return logp / count, logp, count
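
    # Illustrative only: recovering perplexity from the returned values, with
    # random token ids (batch 16, length 20; vocabulary size 1000 as above).
    #
    #     x = torch.randint(1, 1000, (16, 20))
    #     loss, nll, count = lm(x, x)
    #     ppl = torch.exp(nll / count)  # equals exp(loss) since loss = nll / count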

    def _before_loss(self, input, hidden):
        # Embed and dropout tokens, run the RNN, then project hidden states
        # back to vocabulary logits.
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(
            output.view(output.size(0) * output.size(1), output.size(2))
        )
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden

    def init_state(self, x):
        """Get an initial state for decoding.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state

        """
        bsz = 1
        weight = next(self.parameters())
        if self.rnn_type == "LSTM":
            # LSTM carries a (hidden, cell) pair; other RNN types use one tensor.
            return (
                weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid),
            )
        else:
            return weight.new_zeros(self.nlayers, bsz, self.nhid)

    def score(self, y, state, x):
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens.
            x (torch.Tensor): 2D encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (n_vocab)
                and next state for ys

        """
        # Only the newest token is fed; the recurrent state carries the prefix.
        y, new_state = self._before_loss(y[-1].view(1, 1), state)
        logp = y.log_softmax(dim=-1).view(-1)
        return logp, new_state
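
    # Illustrative only: step-wise decoding with ``init_state`` and ``score``,
    # assuming ``lm`` was built as in the construction sketch above. Neither
    # method uses ``x`` here, so ``None`` is passed; the start id 1 is an
    # arbitrary example.
    #
    #     state = lm.init_state(None)
    #     y = torch.tensor([1], dtype=torch.int64)
    #     for _ in range(10):
    #         logp, state = lm.score(y, state, None)
    #         next_id = logp.argmax().view(1)
    #         y = torch.cat([y, next_id])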
|