Tutorial: Simple LSTM
=====================

In this tutorial we will extend fairseq by adding a new
:class:`~fairseq.models.FairseqEncoderDecoderModel` that encodes a source
sentence with an LSTM and then passes the final hidden state to a second LSTM
that decodes the target sentence (without attention).

This tutorial covers:

1. **Writing an Encoder and Decoder** to encode/decode the source/target
   sentence, respectively.
2. **Registering a new Model** so that it can be used with the existing
   :ref:`Command-line Tools`.
3. **Training the Model** using the existing command-line tools.
4. **Making generation faster** by modifying the Decoder to use
   :ref:`Incremental decoding`.


1. Building an Encoder and Decoder
----------------------------------

In this section we'll define a simple LSTM Encoder and Decoder. All Encoders
should implement the :class:`~fairseq.models.FairseqEncoder` interface and
Decoders should implement the :class:`~fairseq.models.FairseqDecoder` interface.
These interfaces themselves extend :class:`torch.nn.Module`, so FairseqEncoders
and FairseqDecoders can be written and used in the same ways as ordinary PyTorch
Modules.


Encoder
~~~~~~~

Our Encoder will embed the tokens in the source sentence, feed them to a
:class:`torch.nn.LSTM` and return the final hidden state. To create our encoder
save the following in a new file named :file:`fairseq/models/simple_lstm.py`::

    import torch.nn as nn
    from fairseq import utils
    from fairseq.models import FairseqEncoder

    class SimpleLSTMEncoder(FairseqEncoder):

        def __init__(
            self, args, dictionary, embed_dim=128, hidden_dim=128, dropout=0.1,
        ):
            super().__init__(dictionary)
            self.args = args

            # Our encoder will embed the inputs before feeding them to the LSTM.
            self.embed_tokens = nn.Embedding(
                num_embeddings=len(dictionary),
                embedding_dim=embed_dim,
                padding_idx=dictionary.pad(),
            )
            self.dropout = nn.Dropout(p=dropout)

            # We'll use a single-layer, unidirectional LSTM for simplicity.
            self.lstm = nn.LSTM(
                input_size=embed_dim,
                hidden_size=hidden_dim,
                num_layers=1,
                bidirectional=False,
                batch_first=True,
            )

        def forward(self, src_tokens, src_lengths):
            # The inputs to the ``forward()`` function are determined by the
            # Task, and in particular the ``'net_input'`` key in each
            # mini-batch. We discuss Tasks in the next tutorial, but for now just
            # know that *src_tokens* has shape `(batch, src_len)` and *src_lengths*
            # has shape `(batch)`.

            # Note that the source is typically padded on the left. This can be
            # configured by adding the `--left-pad-source "False"` command-line
            # argument, but here we'll make the Encoder handle either kind of
            # padding by converting everything to be right-padded.
            if self.args.left_pad_source:
                # Convert left-padding to right-padding.
                src_tokens = utils.convert_padding_direction(
                    src_tokens,
                    padding_idx=self.dictionary.pad(),
                    left_to_right=True
                )

            # Embed the source.
            x = self.embed_tokens(src_tokens)

            # Apply dropout.
            x = self.dropout(x)

            # Pack the sequence into a PackedSequence object to feed to the LSTM.
            x = nn.utils.rnn.pack_padded_sequence(x, src_lengths, batch_first=True)

            # Get the output from the LSTM.
            _outputs, (final_hidden, _final_cell) = self.lstm(x)

            # Return the Encoder's output. This can be any object and will be
            # passed directly to the Decoder.
            return {
                # this will have shape `(bsz, hidden_dim)`
                'final_hidden': final_hidden.squeeze(0),
            }

        # Encoders are required to implement this method so that we can rearrange
        # the order of the batch elements during inference (e.g., beam search).
        def reorder_encoder_out(self, encoder_out, new_order):
            """
            Reorder encoder output according to `new_order`.

            Args:
                encoder_out: output from the ``forward()`` method
                new_order (LongTensor): desired order

            Returns:
                `encoder_out` rearranged according to `new_order`
            """
            final_hidden = encoder_out['final_hidden']
            return {
                'final_hidden': final_hidden.index_select(0, new_order),
            }
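
At this point it can be handy to sanity-check the Encoder's output shape. The
snippet below is only an illustrative sketch that you could run at the bottom of
the file; it builds a toy :class:`~fairseq.data.Dictionary` and uses a hand-made
``argparse.Namespace`` as a stand-in for fairseq's parsed command-line
arguments::

    import torch
    from argparse import Namespace
    from fairseq.data import Dictionary

    # Tiny dictionary and fake args object (stand-ins for the real task
    # dictionary and the parsed command-line arguments).
    toy_dict = Dictionary()
    for word in 'hello world'.split():
        toy_dict.add_symbol(word)
    toy_args = Namespace(left_pad_source=False)

    encoder = SimpleLSTMEncoder(toy_args, toy_dict, embed_dim=8, hidden_dim=8)
    src_tokens = torch.LongTensor(
        [[toy_dict.index('hello'), toy_dict.index('world'), toy_dict.eos()]]
    )
    src_lengths = torch.LongTensor([3])

    encoder_out = encoder(src_tokens, src_lengths)
    print(encoder_out['final_hidden'].shape)  # torch.Size([1, 8])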


Decoder
~~~~~~~

Our Decoder will predict the next word, conditioned on the Encoder's final
hidden state and an embedded representation of the previous target word; using
the ground-truth previous word as input during training is sometimes called
*teacher forcing*. More specifically, we'll use a :class:`torch.nn.LSTM` to
produce a sequence of hidden states that we'll project to the size of the output
vocabulary to predict each target word.

::

    import torch
    from fairseq.models import FairseqDecoder

    class SimpleLSTMDecoder(FairseqDecoder):

        def __init__(
            self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
            dropout=0.1,
        ):
            super().__init__(dictionary)

            # Our decoder will embed the inputs before feeding them to the LSTM.
            self.embed_tokens = nn.Embedding(
                num_embeddings=len(dictionary),
                embedding_dim=embed_dim,
                padding_idx=dictionary.pad(),
            )
            self.dropout = nn.Dropout(p=dropout)

            # We'll use a single-layer, unidirectional LSTM for simplicity.
            self.lstm = nn.LSTM(
                # For the first layer we'll concatenate the Encoder's final hidden
                # state with the embedded target tokens.
                input_size=encoder_hidden_dim + embed_dim,
                hidden_size=hidden_dim,
                num_layers=1,
                bidirectional=False,
            )

            # Define the output projection.
            self.output_projection = nn.Linear(hidden_dim, len(dictionary))

        # During training Decoders are expected to take the entire target sequence
        # (shifted right by one position) and produce logits over the vocabulary.
        # The *prev_output_tokens* tensor begins with the end-of-sentence symbol,
        # ``dictionary.eos()``, followed by the target sequence.
        def forward(self, prev_output_tokens, encoder_out):
            """
            Args:
                prev_output_tokens (LongTensor): previous decoder outputs of shape
                    `(batch, tgt_len)`, for teacher forcing
                encoder_out (Tensor, optional): output from the encoder, used for
                    encoder-side attention

            Returns:
                tuple:
                    - the last decoder layer's output of shape
                      `(batch, tgt_len, vocab)`
                    - the last decoder layer's attention weights of shape
                      `(batch, tgt_len, src_len)`
            """
            bsz, tgt_len = prev_output_tokens.size()

            # Extract the final hidden state from the Encoder.
            final_encoder_hidden = encoder_out['final_hidden']

            # Embed the target sequence, which has been shifted right by one
            # position and now starts with the end-of-sentence symbol.
            x = self.embed_tokens(prev_output_tokens)

            # Apply dropout.
            x = self.dropout(x)

            # Concatenate the Encoder's final hidden state to *every* embedded
            # target token.
            x = torch.cat(
                [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
                dim=2,
            )

            # Using PackedSequence objects in the Decoder is harder than in the
            # Encoder, since the targets are not sorted in descending length order,
            # which is a requirement of ``pack_padded_sequence()``. Instead we'll
            # feed nn.LSTM directly.
            initial_state = (
                final_encoder_hidden.unsqueeze(0),  # hidden
                torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
            )
            output, _ = self.lstm(
                x.transpose(0, 1),  # convert to shape `(tgt_len, bsz, dim)`
                initial_state,
            )
            x = output.transpose(0, 1)  # convert to shape `(bsz, tgt_len, hidden)`

            # Project the outputs to the size of the vocabulary.
            x = self.output_projection(x)

            # Return the logits and ``None`` for the attention weights
            return x, None
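
Continuing the toy setup sketched after the Encoder above (so ``toy_dict``,
``encoder_out`` and ``torch`` are assumed to already be in scope), a quick check
that the Decoder produces per-step vocabulary logits might look like this::

    decoder = SimpleLSTMDecoder(toy_dict, encoder_hidden_dim=8, embed_dim=8, hidden_dim=8)
    prev_output_tokens = torch.LongTensor(
        [[toy_dict.eos(), toy_dict.index('hello'), toy_dict.index('world')]]
    )
    logits, _ = decoder(prev_output_tokens, encoder_out)
    print(logits.shape)  # torch.Size([1, 3, len(toy_dict)])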


2. Registering the Model
------------------------

Now that we've defined our Encoder and Decoder we must *register* our model with
fairseq using the :func:`~fairseq.models.register_model` function decorator.
Once the model is registered we'll be able to use it with the existing
:ref:`Command-line Tools`.

All registered models must implement the
:class:`~fairseq.models.BaseFairseqModel` interface. For sequence-to-sequence
models (i.e., any model with a single Encoder and Decoder), we can instead
implement the :class:`~fairseq.models.FairseqEncoderDecoderModel` interface.

Create a small wrapper class in the same file and register it in fairseq with
the name ``'simple_lstm'``::

    from fairseq.models import FairseqEncoderDecoderModel, register_model

    # Note: the register_model "decorator" should immediately precede the
    # definition of the Model class.

    @register_model('simple_lstm')
    class SimpleLSTMModel(FairseqEncoderDecoderModel):

        @staticmethod
        def add_args(parser):
            # Models can override this method to add new command-line arguments.
            # Here we'll add some new command-line arguments to configure dropout
            # and the dimensionality of the embeddings and hidden states.
            parser.add_argument(
                '--encoder-embed-dim', type=int, metavar='N',
                help='dimensionality of the encoder embeddings',
            )
            parser.add_argument(
                '--encoder-hidden-dim', type=int, metavar='N',
                help='dimensionality of the encoder hidden state',
            )
            parser.add_argument(
                '--encoder-dropout', type=float, default=0.1,
                help='encoder dropout probability',
            )
            parser.add_argument(
                '--decoder-embed-dim', type=int, metavar='N',
                help='dimensionality of the decoder embeddings',
            )
            parser.add_argument(
                '--decoder-hidden-dim', type=int, metavar='N',
                help='dimensionality of the decoder hidden state',
            )
            parser.add_argument(
                '--decoder-dropout', type=float, default=0.1,
                help='decoder dropout probability',
            )

        @classmethod
        def build_model(cls, args, task):
            # Fairseq initializes models by calling the ``build_model()``
            # function. This provides more flexibility, since the returned model
            # instance can be of a different type than the one that was called.
            # In this case we'll just return a SimpleLSTMModel instance.

            # Initialize our Encoder and Decoder.
            encoder = SimpleLSTMEncoder(
                args=args,
                dictionary=task.source_dictionary,
                embed_dim=args.encoder_embed_dim,
                hidden_dim=args.encoder_hidden_dim,
                dropout=args.encoder_dropout,
            )
            decoder = SimpleLSTMDecoder(
                dictionary=task.target_dictionary,
                encoder_hidden_dim=args.encoder_hidden_dim,
                embed_dim=args.decoder_embed_dim,
                hidden_dim=args.decoder_hidden_dim,
                dropout=args.decoder_dropout,
            )
            model = SimpleLSTMModel(encoder, decoder)

            # Print the model architecture.
            print(model)

            return model

        # We could override the ``forward()`` if we wanted more control over how
        # the encoder and decoder interact, but it's not necessary for this
        # tutorial since we can inherit the default implementation provided by
        # the FairseqEncoderDecoderModel base class, which looks like:
        #
        # def forward(self, src_tokens, src_lengths, prev_output_tokens):
        #     encoder_out = self.encoder(src_tokens, src_lengths)
        #     decoder_out = self.decoder(prev_output_tokens, encoder_out)
        #     return decoder_out

Finally let's define a *named architecture* with the configuration for our
model. This is done with the :func:`~fairseq.models.register_model_architecture`
function decorator. Thereafter this named architecture can be used with the
``--arch`` command-line argument, e.g., ``--arch tutorial_simple_lstm``::

    from fairseq.models import register_model_architecture

    # The first argument to ``register_model_architecture()`` should be the name
    # of the model we registered above (i.e., 'simple_lstm'). The function we
    # register here should take a single argument *args* and modify it in-place
    # to match the desired architecture.

    @register_model_architecture('simple_lstm', 'tutorial_simple_lstm')
    def tutorial_simple_lstm(args):
        # We use ``getattr()`` to prioritize arguments that are explicitly given
        # on the command-line, so that the defaults defined below are only used
        # when no other value has been specified.
        args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
        args.encoder_hidden_dim = getattr(args, 'encoder_hidden_dim', 256)
        args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
        args.decoder_hidden_dim = getattr(args, 'decoder_hidden_dim', 256)
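
To make the ``getattr()`` fallback pattern concrete, here is a toy illustration
that uses a hand-built ``argparse.Namespace`` as a stand-in for the ``args``
object that fairseq constructs from the command line::

    from argparse import Namespace

    args = Namespace(encoder_embed_dim=512)  # as if --encoder-embed-dim 512 were given
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)  # stays 512
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)  # falls back to 256
    print(args.encoder_embed_dim, args.decoder_embed_dim)  # 512 256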


3. Training the Model
---------------------

Now we're ready to train the model. We can use the existing :ref:`fairseq-train`
command-line tool for this, making sure to specify our new Model architecture
(``--arch tutorial_simple_lstm``).

.. note::

    Make sure you've already preprocessed the data from the IWSLT example in the
    :file:`examples/translation/` directory.
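
If you haven't preprocessed the data yet, one possible sequence of commands is
sketched below, adapted from the commands in the translation example's README;
treat the script name and paths as assumptions that may vary between fairseq
versions.

.. code-block:: console

    > cd examples/translation/
    > bash prepare-iwslt14.sh
    > cd ../..
    > TEXT=examples/translation/iwslt14.tokenized.de-en
    > fairseq-preprocess --source-lang de --target-lang en \
      --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
      --destdir data-bin/iwslt14.tokenized.de-en

With the binarized data in :file:`data-bin/iwslt14.tokenized.de-en`, training
looks like this: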

.. code-block:: console

    > fairseq-train data-bin/iwslt14.tokenized.de-en \
      --arch tutorial_simple_lstm \
      --encoder-dropout 0.2 --decoder-dropout 0.2 \
      --optimizer adam --lr 0.005 --lr-shrink 0.5 \
      --max-tokens 12000
    (...)
    | epoch 052 | loss 4.027 | ppl 16.30 | wps 420805 | ups 39.7 | wpb 9841 | bsz 400 | num_updates 20852 | lr 1.95313e-05 | gnorm 0.218 | clip 0% | oom 0 | wall 529 | train_wall 396
    | epoch 052 | valid on 'valid' subset | valid_loss 4.74989 | valid_ppl 26.91 | num_updates 20852 | best 4.74954

The model files should appear in the :file:`checkpoints/` directory. While this
model architecture is not very good, we can use the :ref:`fairseq-generate`
script to generate translations and compute our BLEU score over the test set:

.. code-block:: console

    > fairseq-generate data-bin/iwslt14.tokenized.de-en \
      --path checkpoints/checkpoint_best.pt \
      --beam 5 \
      --remove-bpe
    (...)
    | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
    | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)


4. Making generation faster
---------------------------

While autoregressive generation from sequence-to-sequence models is inherently
slow, our implementation above is especially slow because it recomputes the
entire sequence of Decoder hidden states for every output token (i.e., it is
``O(n^2)``). We can make this significantly faster by instead caching the
previous hidden states.

In fairseq this is called :ref:`Incremental decoding`. Incremental decoding is a
special mode at inference time where the Model only receives a single timestep
of input corresponding to the immediately previous output token (for teacher
forcing) and must produce the next output incrementally. Thus the model must
cache any long-term state that is needed about the sequence, e.g., hidden
states, convolutional states, etc.

To implement incremental decoding we will modify our model to implement the
:class:`~fairseq.models.FairseqIncrementalDecoder` interface. Compared to the
standard :class:`~fairseq.models.FairseqDecoder` interface, the incremental
decoder interface allows ``forward()`` methods to take an extra keyword argument
(*incremental_state*) that can be used to cache state across time-steps.
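
Before rewriting the Decoder, it helps to see the calling pattern we are about
to support. The greedy loop below is an illustrative sketch only (the real loop
lives in fairseq's sequence generator); ``decoder``, ``encoder_out``,
``tgt_dict`` and ``max_len`` are assumed to be in scope::

    import torch

    incremental_state = {}  # shared cache the decoder reads from and writes to
    tokens = torch.LongTensor([[tgt_dict.eos()]])  # generation starts with EOS
    for _ in range(max_len):
        logits, _ = decoder(tokens, encoder_out, incremental_state=incremental_state)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)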

Let's replace our ``SimpleLSTMDecoder`` with an incremental one::

    import torch
    from fairseq.models import FairseqIncrementalDecoder

    class SimpleLSTMDecoder(FairseqIncrementalDecoder):

        def __init__(
            self, dictionary, encoder_hidden_dim=128, embed_dim=128, hidden_dim=128,
            dropout=0.1,
        ):
            # This remains the same as before.
            super().__init__(dictionary)
            self.embed_tokens = nn.Embedding(
                num_embeddings=len(dictionary),
                embedding_dim=embed_dim,
                padding_idx=dictionary.pad(),
            )
            self.dropout = nn.Dropout(p=dropout)
            self.lstm = nn.LSTM(
                input_size=encoder_hidden_dim + embed_dim,
                hidden_size=hidden_dim,
                num_layers=1,
                bidirectional=False,
            )
            self.output_projection = nn.Linear(hidden_dim, len(dictionary))

        # We now take an additional kwarg (*incremental_state*) for caching the
        # previous hidden and cell states.
        def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
            if incremental_state is not None:
                # If the *incremental_state* argument is not ``None`` then we are
                # in incremental inference mode. While *prev_output_tokens* will
                # still contain the entire decoded prefix, we will only use the
                # last step and assume that the rest of the state is cached.
                prev_output_tokens = prev_output_tokens[:, -1:]

            # This remains the same as before.
            bsz, tgt_len = prev_output_tokens.size()
            final_encoder_hidden = encoder_out['final_hidden']
            x = self.embed_tokens(prev_output_tokens)
            x = self.dropout(x)
            x = torch.cat(
                [x, final_encoder_hidden.unsqueeze(1).expand(bsz, tgt_len, -1)],
                dim=2,
            )

            # We will now check the cache and load the cached previous hidden and
            # cell states, if they exist, otherwise we will initialize them to
            # zeros (as before). We will use the ``utils.get_incremental_state()``
            # and ``utils.set_incremental_state()`` helpers.
            initial_state = utils.get_incremental_state(
                self, incremental_state, 'prev_state',
            )
            if initial_state is None:
                # first time initialization, same as the original version
                initial_state = (
                    final_encoder_hidden.unsqueeze(0),  # hidden
                    torch.zeros_like(final_encoder_hidden).unsqueeze(0),  # cell
                )

            # Run one step of our LSTM.
            output, latest_state = self.lstm(x.transpose(0, 1), initial_state)

            # Update the cache with the latest hidden and cell states.
            utils.set_incremental_state(
                self, incremental_state, 'prev_state', latest_state,
            )

            # This remains the same as before.
            x = output.transpose(0, 1)
            x = self.output_projection(x)
            return x, None

        # The ``FairseqIncrementalDecoder`` interface also requires implementing a
        # ``reorder_incremental_state()`` method, which is used during beam search
        # to select and reorder the incremental state.
        def reorder_incremental_state(self, incremental_state, new_order):
            # Load the cached state.
            prev_state = utils.get_incremental_state(
                self, incremental_state, 'prev_state',
            )

            # Reorder batches according to *new_order*.
            reordered_state = (
                prev_state[0].index_select(1, new_order),  # hidden
                prev_state[1].index_select(1, new_order),  # cell
            )

            # Update the cached state.
            utils.set_incremental_state(
                self, incremental_state, 'prev_state', reordered_state,
            )
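
To convince yourself that the cached path matches the original full
recomputation, a quick (hypothetical) sanity check along the following lines can
be added at the bottom of the file. It reuses the toy tensors from the earlier
sketches (``encoder_out`` and ``prev_output_tokens`` are assumed to be in scope,
and ``decoder`` has been re-instantiated from the new incremental class)::

    decoder.eval()  # disable dropout so both paths are deterministic

    # One full pass over the whole prefix (no caching).
    full_logits, _ = decoder(prev_output_tokens, encoder_out)

    # Step through the prefix one token at a time, reusing the cached state.
    incremental_state = {}
    for i in range(prev_output_tokens.size(1)):
        step_logits, _ = decoder(
            prev_output_tokens[:, :i + 1], encoder_out,
            incremental_state=incremental_state,
        )

    # The final-step logits should agree (up to floating-point noise).
    assert torch.allclose(step_logits[:, -1], full_logits[:, -1], atol=1e-4)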

Finally, we can rerun generation and observe the speedup:

.. code-block:: console

    # Before

    > fairseq-generate data-bin/iwslt14.tokenized.de-en \
      --path checkpoints/checkpoint_best.pt \
      --beam 5 \
      --remove-bpe
    (...)
    | Translated 6750 sentences (153132 tokens) in 17.3s (389.12 sentences/s, 8827.68 tokens/s)
    | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)

    # After

    > fairseq-generate data-bin/iwslt14.tokenized.de-en \
      --path checkpoints/checkpoint_best.pt \
      --beam 5 \
      --remove-bpe
    (...)
    | Translated 6750 sentences (153132 tokens) in 5.5s (1225.54 sentences/s, 27802.94 tokens/s)
    | Generate test with beam=5: BLEU4 = 8.18, 38.8/12.1/4.7/2.0 (BP=1.000, ratio=1.066, syslen=139865, reflen=131146)