|
Tutorial: Classifying Names with a Character-Level RNN
======================================================
|
|
|
In this tutorial we will extend fairseq to support *classification* tasks. In
particular we will re-implement the PyTorch tutorial for `Classifying Names with
a Character-Level RNN <https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html>`_
in fairseq. It is recommended to quickly skim that tutorial before beginning
this one.
|
|
|
This tutorial covers: |
|
|
|
1. **Preprocessing the data** to create dictionaries.
2. **Registering a new Model** that encodes an input sentence with a simple RNN
   and predicts the output label.
3. **Registering a new Task** that loads our dictionaries and dataset.
4. **Training the Model** using the existing command-line tools.
5. **Writing an evaluation script** that imports fairseq and allows us to
   interactively evaluate our model on new inputs.
|
|
|
|
|
1. Preprocessing the data
-------------------------
|
|
|
|
|
The original tutorial provides raw data, but we'll work with a modified version
of the data that is already tokenized into characters and split into separate
train, valid and test sets.
|
|
|
Download and extract the data from here: |
|
`tutorial_names.tar.gz <https://dl.fbaipublicfiles.com/fairseq/data/tutorial_names.tar.gz>`_ |
|
|
|
Once extracted, let's preprocess the data using the :ref:`fairseq-preprocess`
command-line tool to create the dictionaries. While this tool is primarily
intended for sequence-to-sequence problems, we're able to reuse it here by
treating the label as a "target" sequence of length 1. We'll also output the
preprocessed files in "raw" format to enhance readability:
|
|
|
.. code-block:: console

  > fairseq-preprocess \
|
After running the above command you should see a new directory,
:file:`names-bin/`, containing the dictionaries for *inputs* and *labels*.
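
A dictionary simply maps each symbol to an integer index. As a rough, pure-Python illustration of that idea, the sketch below reserves four special symbols first (assuming fairseq's default order of ``<s>``, ``<pad>``, ``</s>`` and ``<unk>``), gives the next indices to characters in the order they are first seen, and encodes a space-separated name by mapping unknown characters to ``<unk>`` and appending ``</s>``. The ``CharDictionary`` class is hypothetical and much simpler than :class:`fairseq.data.Dictionary`; the dictionaries written by :ref:`fairseq-preprocess` may order ordinary symbols differently (for example by frequency).

.. code-block:: python

    from collections import Counter


    class CharDictionary:
        """Toy symbol-to-index mapping (hypothetical; not fairseq's Dictionary)."""

        # Assumed fairseq-style special symbols, reserved before any character.
        SPECIALS = ["<s>", "<pad>", "</s>", "<unk>"]

        def __init__(self):
            self.symbols = list(self.SPECIALS)
            self.indices = {s: i for i, s in enumerate(self.symbols)}
            self.counts = Counter()

        def add_symbol(self, sym):
            """Return the index of ``sym``, adding it if it is new."""
            if sym not in self.indices:
                self.indices[sym] = len(self.symbols)
                self.symbols.append(sym)
            self.counts[sym] += 1
            return self.indices[sym]

        def pad(self):
            return self.indices["<pad>"]

        def eos(self):
            return self.indices["</s>"]

        def unk(self):
            return self.indices["<unk>"]

        def encode_line(self, line):
            """Split on whitespace, map unknown symbols to <unk>, append </s>."""
            ids = [self.indices.get(tok, self.unk()) for tok in line.split()]
            return ids + [self.eos()]

        def __len__(self):
            return len(self.symbols)


    # Build a dictionary from a few space-tokenized training names.
    train_lines = ["S a t o s h i", "S i n b a d"]
    vocab = CharDictionary()
    for line in train_lines:
        for ch in line.split():
            vocab.add_symbol(ch)

    print(len(vocab))                     # 14
    print(vocab.encode_line("S a m"))     # [4, 5, 3, 2]

Here ``m`` never appears in the training lines, so it is encoded as ``<unk>`` (index 3), and the trailing ``2`` is the appended ``</s>``.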
|
|
|
|
|
2. Registering a new Model
--------------------------
|
|
|
|
|
Next we'll register a new model in fairseq that will encode an input sentence
with a simple RNN and predict the output label. Compared to the original PyTorch
tutorial, our version will also work with batches of data and GPU Tensors.
|
|
|
First let's copy the simple RNN module implemented in the `PyTorch tutorial
<https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html#creating-the-network>`_.
Create a new file named :file:`fairseq/models/rnn_classifier.py` with the
following contents::
|
|
|
    import torch
    import torch.nn as nn

    class RNN(nn.Module):

        def __init__(self, input_size, hidden_size, output_size):
            super(RNN, self).__init__()

            self.hidden_size = hidden_size

            self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
            self.i2o = nn.Linear(input_size + hidden_size, output_size)
            self.softmax = nn.LogSoftmax(dim=1)

        def forward(self, input, hidden):
            combined = torch.cat((input, hidden), 1)
            hidden = self.i2h(combined)
            output = self.i2o(combined)
            output = self.softmax(output)
            return output, hidden

        def initHidden(self):
            return torch.zeros(1, self.hidden_size)
|
|
|
We must also *register* this model with fairseq using the
:func:`~fairseq.models.register_model` function decorator. Once the model is
registered we'll be able to use it with the existing :ref:`Command-line Tools`.
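
One way to picture a registration decorator is as a name-keyed lookup table: decorating a class records it under a string, so that a string given on the command line can later be mapped back to that class. The plain-Python sketch below illustrates only this pattern; ``TOY_MODEL_REGISTRY``, ``register_toy_model`` and ``build_toy_model`` are hypothetical, simplified stand-ins and not fairseq's actual implementation.

.. code-block:: python

    # Hypothetical registry: maps a user-facing name to a class.
    TOY_MODEL_REGISTRY = {}


    def register_toy_model(name):
        """Return a decorator that records a class under ``name``."""

        def register_cls(cls):
            if name in TOY_MODEL_REGISTRY:
                raise ValueError("Cannot register duplicate model ({})".format(name))
            TOY_MODEL_REGISTRY[name] = cls
            return cls  # the class itself is returned unchanged

        return register_cls


    @register_toy_model("toy_classifier")
    class ToyClassifier:
        pass


    def build_toy_model(name):
        """Look up a registered class by the string a user typed, then instantiate it."""
        return TOY_MODEL_REGISTRY[name]()


    model = build_toy_model("toy_classifier")
    print(type(model).__name__)  # ToyClassifier

In this sketch, registration happens as a side effect of executing the class definition, and because the decorator returns the class unchanged, ``ToyClassifier`` can still be used directly.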
|
|
|
All registered models must implement the :class:`~fairseq.models.BaseFairseqModel`
interface, so we'll create a small wrapper class in the same file and register
it in fairseq with the name ``'rnn_classifier'``::
|
|
|
    from fairseq.models import BaseFairseqModel, register_model

    # Note: the register_model "decorator" should immediately precede the
    # definition of the Model class.

    @register_model('rnn_classifier')
    class FairseqRNNClassifier(BaseFairseqModel):

        @staticmethod
        def add_args(parser):
            # Models can override this method to add new command-line arguments.
            # Here we'll add a new command-line argument to configure the
            # dimensionality of the hidden state. argparse stores it as
            # ``args.hidden_dim``, which ``build_model()`` reads below.
            parser.add_argument(
                '--hidden-dim', type=int,
                help='dimensionality of the hidden state',
            )

        @classmethod
        def build_model(cls, args, task):
            # Fairseq initializes models by calling the ``build_model()``
            # function. This provides more flexibility, since the returned model
            # instance can be of a different type than the one that was called.
            # In this case we'll just return a FairseqRNNClassifier instance.

            # Initialize our RNN module
            rnn = RNN(
                # We'll define the Task in the next section, but for now just
                # notice that the task holds the dictionaries for the "source"
                # (i.e., the input sentence) and "target" (i.e., the label).
                input_size=len(task.source_dictionary),
                hidden_size=args.hidden_dim,
                output_size=len(task.target_dictionary),
            )

            # Return the wrapped version of the module
            return FairseqRNNClassifier(
                rnn=rnn,
                input_vocab=task.source_dictionary,
            )

        def __init__(self, rnn, input_vocab):
            super(FairseqRNNClassifier, self).__init__()

            self.rnn = rnn
            self.input_vocab = input_vocab

            # The RNN module in the tutorial expects one-hot inputs, so we can
            # precompute the identity matrix to help convert from indices to
            # one-hot vectors. We register it as a buffer so that it is moved to
            # the GPU when ``cuda()`` is called.
            self.register_buffer('one_hot_inputs', torch.eye(len(input_vocab)))

        def forward(self, src_tokens, src_lengths):
            # The inputs to the ``forward()`` function are determined by the
            # Task, and in particular the ``'net_input'`` key in each
            # mini-batch. We'll define the Task in the next section, but for
            # now just know that *src_tokens* has shape `(batch, src_len)` and
            # *src_lengths* has shape `(batch)`.
            bsz, max_src_len = src_tokens.size()

            # Initialize the RNN hidden state. Compared to the original PyTorch
            # tutorial we'll also handle batched inputs and work on the GPU.
            hidden = self.rnn.initHidden()
            hidden = hidden.repeat(bsz, 1)  # expand for batched inputs
            hidden = hidden.to(src_tokens.device)  # move to GPU

            for i in range(max_src_len):
                # WARNING: The inputs have padding, so we should mask those
                # elements here so that padding doesn't affect the results.
                # This is left as an exercise for the reader. The padding symbol
                # is given by ``self.input_vocab.pad()`` and the unpadded length
                # of each input is given by *src_lengths*.

                # One-hot encode a batch of input characters.
                input = self.one_hot_inputs[src_tokens[:, i].long()]

                # Feed the input to our RNN.
                output, hidden = self.rnn(input, hidden)

            # Return the final output state for making a prediction
            return output
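
The padding exercise mentioned in the ``forward()`` comments comes down to one rule: at time step *i*, only sequences with ``i < src_lengths[b]`` should update their hidden state, and the others keep the state they already had. The pure-Python sketch below demonstrates that rule with a hypothetical scalar recurrence, ``tanh(w_x * x + w_h * h)``, standing in for the RNN module (it feeds raw token ids instead of one-hot vectors, purely to stay small). It assumes a padding index of 1 and right-padded inputs. In the real model the same idea could be applied per batch element, for instance by choosing between the new and previous hidden state with a boolean mask built from *src_lengths* (``torch.where`` is one option), but that translation is left for you to verify.

.. code-block:: python

    import math

    PAD = 1  # assumed padding index


    def step(x, h, w_x=0.5, w_h=0.8):
        """Hypothetical scalar RNN cell standing in for the real RNN module."""
        return math.tanh(w_x * x + w_h * h)


    def run_batch(src_tokens, src_lengths, masked=True):
        """Return the final hidden state of each sequence in a right-padded batch."""
        bsz, max_len = len(src_tokens), len(src_tokens[0])
        hidden = [0.0] * bsz
        for i in range(max_len):
            for b in range(bsz):
                new_h = step(src_tokens[b][i], hidden[b])
                # Only update while still inside the unpadded part of sequence b.
                if not masked or i < src_lengths[b]:
                    hidden[b] = new_h
        return hidden


    # Two right-padded sequences: the second has 2 real tokens and 2 pads.
    batch = [[4, 5, 6, 7], [4, 5, PAD, PAD]]
    lengths = [4, 2]

    masked = run_batch(batch, lengths, masked=True)
    unmasked = run_batch(batch, lengths, masked=False)
    alone = run_batch([[4, 5]], [2])  # the short sequence with no padding

    print(masked[1] == alone[0])    # True: padding no longer affects the result
    print(unmasked[1] == alone[0])  # False: padding changed the final state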
|
|
|
Finally let's define a *named architecture* with the configuration for our
model. This is done with the :func:`~fairseq.models.register_model_architecture`
function decorator. Thereafter this named architecture can be selected by name
when training with the existing command-line tools::
|
|
|
    from fairseq.models import register_model_architecture

    # The first argument to ``register_model_architecture()`` should be the name
    # of the model we registered above (i.e., 'rnn_classifier'). The function we
    # register here should take a single argument *args* and modify it in-place
    # to match the desired architecture.

    @register_model_architecture('rnn_classifier', 'pytorch_tutorial_rnn')
    def pytorch_tutorial_rnn(args):
        # We use ``getattr()`` to prioritize arguments that are explicitly given
        # on the command-line, so that the defaults defined below are only used
        # when no other value has been specified.
        args.hidden_dim = getattr(args, 'hidden_dim', 128)
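
The ``getattr()`` pattern only works if options the user did not pass are *absent* from *args*, rather than present with a value of ``None``. The standalone sketch below emulates that with ``argparse.SUPPRESS`` as the parser-wide default; we assume fairseq arranges something similar for model-specific arguments, but this is not taken from its code. It also shows argparse turning a ``--hidden-dim`` option (the spelling is our assumption) into the attribute ``args.hidden_dim``.

.. code-block:: python

    import argparse


    def pytorch_tutorial_rnn(args):
        # Same pattern as the architecture function above: keep an explicit
        # value if one was parsed, otherwise fall back to the default.
        args.hidden_dim = getattr(args, "hidden_dim", 128)


    # argument_default=SUPPRESS leaves unspecified options out of the namespace
    # entirely, so getattr() sees them as missing.
    parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
    parser.add_argument("--hidden-dim", type=int, metavar="N",
                        help="dimensionality of the hidden state")

    defaults = parser.parse_args([])
    pytorch_tutorial_rnn(defaults)
    print(defaults.hidden_dim)  # 128

    explicit = parser.parse_args(["--hidden-dim", "64"])
    pytorch_tutorial_rnn(explicit)
    print(explicit.hidden_dim)  # 64

Without ``argparse.SUPPRESS`` the unset option would be stored as ``None``, ``getattr()`` would return that ``None``, and the architecture default of 128 would never apply.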
|
|
|
|
|
3. Registering a new Task
-------------------------
|
|
|
|
|
Now we'll register a new :class:`~fairseq.tasks.FairseqTask` that will load our
dictionaries and dataset. Tasks can also control how the data is batched into
mini-batches, but in this tutorial we'll reuse the batching provided by
:class:`fairseq.data.LanguagePairDataset`.
|
|
|
Create a new file named :file:`fairseq/tasks/simple_classification.py` with the
following contents::
|
|
|
    import os
    import torch

    from fairseq.data import Dictionary, LanguagePairDataset
    from fairseq.tasks import LegacyFairseqTask, register_task


    @register_task('simple_classification')
    class SimpleClassificationTask(LegacyFairseqTask):

        @staticmethod
        def add_args(parser):
            # Add some command-line arguments for specifying where the data is
            # located and the maximum supported input length.
            parser.add_argument('data', metavar='FILE',
                                help='file prefix for data')
            parser.add_argument('--max-positions', default=1024, type=int,
                                help='max input length')

        @classmethod
        def setup_task(cls, args, **kwargs):
            # Here we can perform any setup required for the task. This may include
            # loading Dictionaries, initializing shared Embedding layers, etc.
            # In this case we'll just load the Dictionaries.
            input_vocab = Dictionary.load(os.path.join(args.data, 'dict.input.txt'))
            label_vocab = Dictionary.load(os.path.join(args.data, 'dict.label.txt'))
            print('| [input] dictionary: {} types'.format(len(input_vocab)))
            print('| [label] dictionary: {} types'.format(len(label_vocab)))

            return SimpleClassificationTask(args, input_vocab, label_vocab)

        def __init__(self, args, input_vocab, label_vocab):
            super().__init__(args)
            self.input_vocab = input_vocab
            self.label_vocab = label_vocab

        def load_dataset(self, split, **kwargs):
            """Load a given dataset split (e.g., train, valid, test)."""

            prefix = os.path.join(self.args.data, '{}.input-label'.format(split))

            # Read input sentences.
            sentences, lengths = [], []
            with open(prefix + '.input', encoding='utf-8') as file:
                for line in file:
                    sentence = line.strip()

                    # Tokenize the sentence, splitting on spaces
                    tokens = self.input_vocab.encode_line(
                        sentence, add_if_not_exist=False,
                    )

                    sentences.append(tokens)
                    lengths.append(tokens.numel())

            # Read labels.
            labels = []
            with open(prefix + '.label', encoding='utf-8') as file:
                for line in file:
                    label = line.strip()
                    labels.append(
                        # Convert label to a numeric ID.
                        torch.LongTensor([self.label_vocab.add_symbol(label)])
                    )

            assert len(sentences) == len(labels)
            print('| {} {} {} examples'.format(self.args.data, split, len(sentences)))

            # We reuse LanguagePairDataset since classification can be modeled as a
            # sequence-to-sequence task where the target sequence has length 1.
            self.datasets[split] = LanguagePairDataset(
                src=sentences,
                src_sizes=lengths,
                src_dict=self.input_vocab,
                tgt=labels,
                tgt_sizes=torch.ones(len(labels)),  # targets have length 1
                tgt_dict=self.label_vocab,
                left_pad_source=False,
                # Since our target is a single class label, there's no need for
                # teacher forcing. If we set this to ``True`` then our Model's
                # ``forward()`` method would receive an additional argument called
                # *prev_output_tokens* that would contain a shifted version of the
                # target sequence.
                input_feeding=False,
            )

        def max_positions(self):
            """Return the max input length allowed by the task."""
            # The source should be no longer than *args.max_positions* and the
            # "target" has max length 1.
            return (self.args.max_positions, 1)

        @property
        def source_dictionary(self):
            """Return the source :class:`~fairseq.data.Dictionary`."""
            return self.input_vocab

        @property
        def target_dictionary(self):
            """Return the target :class:`~fairseq.data.Dictionary`."""
            return self.label_vocab

        # We could override this method if we wanted more control over how batches
        # are constructed, but it's not necessary for this tutorial since we can
        # reuse the batching provided by LanguagePairDataset.
        #
        # def get_batch_iterator(
        #     self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
        #     ignore_invalid_inputs=False, required_batch_size_multiple=1,
        #     seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1,
        #     data_buffer_size=0, disable_iterator_cache=False,
        # ):
        #     (...)
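
Because the Task delegates batching to :class:`~fairseq.data.LanguagePairDataset`, it helps to picture what a collated mini-batch looks like with ``left_pad_source=False``: source sequences are padded on the right up to the longest one in the batch, and their original lengths are kept alongside. The function below is a simplified, hypothetical stand-in for that step (the real collater does more work than this), written in plain Python and assuming padding index 1 and end-of-sentence index 2. It produces the ``src_tokens`` and ``src_lengths`` that our model's ``forward()`` receives.

.. code-block:: python

    def collate_right_pad(samples, pad_idx):
        """Right-pad token lists into a rectangular batch (hypothetical helper).

        ``samples`` is a list of dicts with a ``'source'`` list of token ids and
        a ``'target'`` list holding a single label id.
        """
        lengths = [len(s["source"]) for s in samples]
        max_len = max(lengths)
        src_tokens = [
            s["source"] + [pad_idx] * (max_len - len(s["source"])) for s in samples
        ]
        # Each target is a single label, so targets need no padding.
        target = [s["target"][0] for s in samples]
        return {
            "net_input": {"src_tokens": src_tokens, "src_lengths": lengths},
            "target": target,
        }


    PAD, EOS = 1, 2  # assumed special indices
    samples = [
        {"source": [4, 5, 6, EOS], "target": [7]},
        {"source": [8, EOS], "target": [9]},
    ]
    batch = collate_right_pad(samples, pad_idx=PAD)
    print(batch["net_input"]["src_tokens"])   # [[4, 5, 6, 2], [8, 2, 1, 1]]
    print(batch["net_input"]["src_lengths"])  # [4, 2]

The trailing ``1`` values in the second row are exactly the padding positions that the masking exercise in ``forward()`` is about.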
|
|
|
|
|
4. Training the Model
---------------------
|
|
|
|
|
Now we're ready to train the model. We can use the existing :ref:`fairseq-train`
command-line tool for this, making sure to specify our new Task
(``simple_classification``) and Model architecture (``pytorch_tutorial_rnn``):
|
|
|
.. note::

  You can also configure the dimensionality of the hidden state by passing the
  hidden-dimension argument defined in ``FairseqRNNClassifier.add_args()`` to
  :ref:`fairseq-train`.
|
|
|
.. code-block:: console

  > fairseq-train names-bin \
  (...)
  | epoch 027 | loss 1.200 | ppl 2.30 | wps 15728 | ups 119.4 | wpb 116 | bsz 116 | num_updates 3726 | lr 1.5625e-05 | gnorm 1.290 | clip 0% | oom 0 | wall 32 | train_wall 21
  | epoch 027 | valid on 'valid' subset | valid_loss 1.41304 | valid_ppl 2.66 | num_updates 3726 | best 1.41208
  | done training in 31.6 seconds
|
|
|
The model files should appear in the :file:`checkpoints/` directory. |
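
The ``loss`` and ``ppl`` columns in the training log above are consistent with a loss measured in base 2 (bits), with perplexity equal to ``2 ** loss``. The quick check below plugs in the logged training and validation losses; the values are copied from the log, and the base-2 reading is inferred from how the numbers line up.

.. code-block:: python

    # Values copied from the training log above.
    train_loss, valid_loss = 1.200, 1.41304

    train_ppl = 2 ** train_loss
    valid_ppl = 2 ** valid_loss

    print("{:.2f} {:.2f}".format(train_ppl, valid_ppl))  # 2.30 2.66

Both results match the logged ``ppl 2.30`` and ``valid_ppl 2.66``, whereas a natural-log interpretation (``e ** 1.2`` is about 3.32) would not.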
|
|
|
|
|
5. Writing an evaluation script
-------------------------------
|
|
|
|
|
Finally we can write a short script to evaluate our model on new inputs. Create
a new file named :file:`eval_classifier.py` with the following contents::
|
|
|
    from fairseq import checkpoint_utils, data, options, tasks

    # Parse command-line arguments for generation
    parser = options.get_generation_parser(default_task='simple_classification')
    args = options.parse_args_and_arch(parser)

    # Setup task
    task = tasks.setup_task(args)

    # Load model
    print('| loading model from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble([args.path], task=task)
    model = models[0]

    while True:
        sentence = input('\nInput: ')

        # Tokenize into characters
        chars = ' '.join(list(sentence.strip()))
        tokens = task.source_dictionary.encode_line(
            chars, add_if_not_exist=False,
        )

        # Build mini-batch to feed to the model
        batch = data.language_pair_dataset.collate(
            samples=[{'id': -1, 'source': tokens}],  # bsz = 1
            pad_idx=task.source_dictionary.pad(),
            eos_idx=task.source_dictionary.eos(),
            left_pad_source=False,
            input_feeding=False,
        )

        # Feed batch to the model and get predictions
        preds = model(**batch['net_input'])

        # Print top 3 predictions and their log-probabilities
        top_scores, top_labels = preds[0].topk(k=3)
        for score, label_idx in zip(top_scores, top_labels):
            label_name = task.target_dictionary.string([label_idx])
            print('({:.2f})\t{}'.format(score, label_name))
|
|
|
Now we can evaluate our model interactively. Note that we have included the
original data path (:file:`names-bin/`) so that the dictionaries can be loaded:
|
|
|
.. code-block:: console

  > python eval_classifier.py names-bin
  | [input] dictionary: 64 types
  | [label] dictionary: 24 types
  | loading model from checkpoints/checkpoint_best.pt

  Input: Satoshi
  (-0.61) Japanese
  (-1.20) Arabic
  (-2.86) Italian

  Input: Sinbad
  (-0.30) Arabic
  (-1.76) English
  (-4.08) Russian
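
Because the RNN module ends with a ``LogSoftmax``, the printed scores are log-probabilities, so ``exp(score)`` recovers a probability. The snippet below converts the scores shown for ``Satoshi`` in the session above and picks the best entries with ``heapq.nlargest``, a plain-Python counterpart to the ``topk()`` call in :file:`eval_classifier.py`. The score values are copied from the sample output; nothing here calls fairseq.

.. code-block:: python

    import heapq
    import math

    # Log-probabilities printed for the input "Satoshi" in the session above.
    scores = {"Japanese": -0.61, "Arabic": -1.20, "Italian": -2.86}

    # Pick the two best labels, analogous to ``preds[0].topk(k=2)``.
    top2 = heapq.nlargest(2, scores.items(), key=lambda kv: kv[1])

    for label, log_prob in top2:
        print("{:<10} {:.0%}".format(label, math.exp(log_prob)))
    # Japanese   54%
    # Arabic     30%

These three scores cover only the top 3 of the 24 labels, so their probabilities need not sum to 1.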
|
|