|
# BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension |
|
|
|
[https://arxiv.org/abs/1910.13461](https://arxiv.org/abs/1910.13461) |
|
|
|
## Introduction |
|
|
|
BART is a sequence-to-sequence model trained with a denoising pretraining objective. We show that this pretraining objective is more generic: BART matches [RoBERTa](../roberta) results on SQuAD and GLUE and achieves state-of-the-art results on summarization (XSum, CNN/DailyMail), long-form generative question answering (ELI5), and dialogue response generation (ConvAI2). See the associated paper for more details.
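
As described in the paper, the model corrupts text with a noising function and learns to reconstruct the original. The sketch below is illustrative only (it is not the fairseq implementation, and the function and parameter names are made up): it shows the text-infilling transformation used in the paper, where spans with lengths drawn from a Poisson distribution (lambda = 3) are each replaced by a single `<mask>` token.

```python
# Illustrative sketch of BART's text-infilling noise (hypothetical helper, not fairseq code).
import numpy as np

def text_infill(tokens, mask_ratio=0.3, poisson_lambda=3.0, mask_token='<mask>'):
    """Replace random spans (lengths ~ Poisson(poisson_lambda)) with a single <mask> each."""
    tokens = list(tokens)
    num_to_mask = int(round(len(tokens) * mask_ratio))
    masked = 0
    while masked < num_to_mask and len(tokens) > 1:
        span_len = np.random.poisson(poisson_lambda)              # 0-length spans just insert a <mask>
        start = np.random.randint(0, max(1, len(tokens) - span_len + 1))
        tokens[start:start + span_len] = [mask_token]             # the whole span collapses to one <mask>
        masked += max(span_len, 1)
    return tokens

print(text_infill('the cat sat on the mat and purred loudly'.split()))
# e.g. ['the', '<mask>', 'on', 'the', 'mat', '<mask>', 'loudly']
```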
|
|
|
## Pre-trained models |
|
|
|
Model | Description | # params | Download
---|---|---|---
`bart.base` | BART model with 6 encoder and decoder layers | 140M | [bart.base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz)
`bart.large` | BART model with 12 encoder and decoder layers | 400M | [bart.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz)
`bart.large.mnli` | `bart.large` finetuned on `MNLI` | 400M | [bart.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz)
`bart.large.cnn` | `bart.large` finetuned on `CNN-DM` | 400M | [bart.large.cnn.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz)
`bart.large.xsum` | `bart.large` finetuned on `Xsum` | 400M | [bart.large.xsum.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.xsum.tar.gz)
|
|
|
## Results |
|
|
|
**[GLUE (Wang et al., 2019)](https://gluebenchmark.com/)** |
|
_(dev set, single model, single-task finetuning)_ |
|
|
|
Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
---|---|---|---|---|---|---|---|---
`roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
`bart.large` | 89.9 | 94.9 | 92.5 | 87.0 | 96.6 | 90.4 | 62.8 | 91.2
|
|
|
**[SQuAD (Rajpurkar et al., 2018)](https://rajpurkar.github.io/SQuAD-explorer/)** |
|
_(dev set, no additional data used)_ |
|
|
|
Model | SQuAD 1.1 EM/F1 | SQuAD 2.0 EM/F1
---|---|---
`roberta.large` | 88.9/94.6 | 86.5/89.4
`bart.large` | 88.8/94.6 | 86.1/89.2
|
|
|
**[CNN/Daily Mail](http://nlpprogress.com/english/summarization.html)** |
|
_(test set, no additional data used)_ |
|
|
|
Model | R1 | R2 | RL
---|---|---|---
`BERTSUMEXTABS` | 42.13 | 19.60 | 39.18
`bart.large` | 44.16 | 21.28 | 40.90
|
|
|
## Example usage |
|
|
|
##### Load BART from torch.hub (PyTorch >= 1.1): |
|
```python |
|
import torch |
|
bart = torch.hub.load('pytorch/fairseq', 'bart.large') |
|
bart.eval() # disable dropout (or leave in train mode to finetune) |
|
``` |
|
|
|
##### Load BART (for PyTorch 1.0 or custom models): |
|
```python |
|
# Download bart.large model |
|
wget https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz |
|
tar -xzvf bart.large.tar.gz |
|
|
|
# Load the model in fairseq |
|
from fairseq.models.bart import BARTModel |
|
bart = BARTModel.from_pretrained('/path/to/bart.large', checkpoint_file='model.pt') |
|
bart.eval() # disable dropout (or leave in train mode to finetune) |
|
``` |
|
|
|
##### Apply Byte-Pair Encoding (BPE) to input text: |
|
```python |
|
tokens = bart.encode('Hello world!') |
|
assert tokens.tolist() == [0, 31414, 232, 328, 2] |
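# 0 and 2 are the BOS (<s>) and EOS (</s>) markers that encode() adds around the BPE tokens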
|
bart.decode(tokens) # 'Hello world!' |
|
``` |
|
|
|
##### Extract features from BART: |
|
```python |
|
# Extract the last layer's features |
|
last_layer_features = bart.extract_features(tokens) |
|
assert last_layer_features.size() == torch.Size([1, 5, 1024]) |
|
|
|
# Extract all layers' features from the decoder (layer 0 is the embedding layer)
|
all_layers = bart.extract_features(tokens, return_all_hiddens=True) |
|
assert len(all_layers) == 13 |
|
assert torch.all(all_layers[-1] == last_layer_features) |
|
``` |
|
|
|
##### Use BART for sentence-pair classification tasks: |
|
```python |
|
# Download BART already finetuned for MNLI |
|
bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli') |
|
bart.eval() # disable dropout for evaluation |
|
|
|
# Encode a pair of sentences and make a prediction |
|
tokens = bart.encode('BART is a seq2seq model.', 'BART is not sequence to sequence.') |
|
bart.predict('mnli', tokens).argmax() # 0: contradiction |
|
|
|
# Encode another pair of sentences |
|
tokens = bart.encode('BART is denoising autoencoder.', 'BART is version of autoencoder.') |
|
bart.predict('mnli', tokens).argmax() # 2: entailment |
|
``` |
|
|
|
##### Register a new (randomly initialized) classification head: |
|
```python |
|
bart.register_classification_head('new_task', num_classes=3) |
|
logprobs = bart.predict('new_task', tokens) |
|
``` |
|
|
|
##### Batched prediction: |
|
```python |
|
import torch |
|
from fairseq.data.data_utils import collate_tokens |
|
|
|
bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli') |
|
bart.eval() |
|
|
|
batch_of_pairs = [
    ['BART is a seq2seq model.', 'BART is not sequence to sequence.'],
    ['BART is denoising autoencoder.', 'BART is version of autoencoder.'],
]

batch = collate_tokens(
    [bart.encode(pair[0], pair[1]) for pair in batch_of_pairs], pad_idx=1
)
|
|
|
logprobs = bart.predict('mnli', batch) |
|
print(logprobs.argmax(dim=1)) |
|
# tensor([0, 2]) |
|
``` |
|
|
|
##### Using the GPU: |
|
```python |
|
bart.cuda() |
|
bart.predict('new_task', tokens) |
|
``` |
|
|
|
#### Filling masks: |
|
|
|
BART can be used to fill multiple `<mask>` tokens in the input. |
|
```python |
|
bart = torch.hub.load('pytorch/fairseq', 'bart.base') |
|
bart.eval() |
|
bart.fill_mask(['The cat <mask> on the <mask>.'], topk=3, beam=10) |
|
# [[('The cat was on the ground.', tensor(-0.6183)), ('The cat was on the floor.', tensor(-0.6798)), ('The cat sleeps on the couch.', tensor(-0.6830))]] |
|
``` |
|
|
|
Note that by default we enforce the output length to match the input length. |
|
This can be disabled by setting ``match_source_len=False``: |
|
```python
|
bart.fill_mask(['The cat <mask> on the <mask>.'], topk=3, beam=10, match_source_len=False) |
|
# [[('The cat was on the ground.', tensor(-0.6185)), ('The cat was asleep on the couch.', tensor(-0.6276)), ('The cat was on the floor.', tensor(-0.6800))]] |
|
``` |
|
|
|
Example code to fill masks for a batch of sentences using the GPU:
|
```python
|
bart.cuda() |
|
bart.fill_mask(['The cat <mask> on the <mask>.', 'The dog <mask> on the <mask>.'], topk=3, beam=10) |
|
# [[('The cat was on the ground.', tensor(-0.6183)), ('The cat was on the floor.', tensor(-0.6798)), ('The cat sleeps on the couch.', tensor(-0.6830))], [('The dog was on the ground.', tensor(-0.6190)), ('The dog lay on the ground.', tensor(-0.6711)), |
|
('The dog was asleep on the couch', tensor(-0.6796))]] |
|
``` |
|
|
|
#### Evaluating the `bart.large.mnli` model: |
|
|
|
Example Python code snippet to evaluate accuracy on the MNLI `dev_matched` set:
|
```python |
|
label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
ncorrect, nsamples = 0, 0
bart.cuda()
bart.eval()
with open('glue_data/MNLI/dev_matched.tsv') as fin:
    fin.readline()
    for index, line in enumerate(fin):
        tokens = line.strip().split('\t')
        sent1, sent2, target = tokens[8], tokens[9], tokens[-1]
        tokens = bart.encode(sent1, sent2)
        prediction = bart.predict('mnli', tokens).argmax().item()
        prediction_label = label_map[prediction]
        ncorrect += int(prediction_label == target)
        nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
|
# Expected output: 0.9010 |
|
``` |
|
|
|
#### Evaluating the `bart.large.cnn` model: |
|
- Follow the instructions [here](https://github.com/abisee/cnn-dailymail) to download and process the data into files such that `test.source` and `test.target` have one line for each non-tokenized sample.
- For simpler preprocessing, you can also `wget https://cdn-datasets.huggingface.co/summarization/cnn_dm_v2.tgz`, although there is no guarantee of identical scores.
- `huggingface/transformers` has a simpler interface that supports [single-gpu](https://github.com/huggingface/transformers/blob/master/examples/legacy/seq2seq/run_eval.py) and [multi-gpu](https://github.com/huggingface/transformers/blob/master/examples/legacy/seq2seq/run_distributed_eval.py) beam search.
  In `huggingface/transformers`, the corresponding model identifiers are `facebook/bart-large-cnn` and `facebook/bart-large-xsum`; a minimal pipeline example is sketched below.
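
As a hedged illustration (not part of this repo), the sketch below uses the `transformers` summarization pipeline with the `facebook/bart-large-cnn` checkpoint; the input path is an assumption carried over from the preprocessing step above.

```python
# Hypothetical sketch using huggingface/transformers rather than fairseq;
# assumes the `transformers` package is installed and `cnn_dm/test.source` exists.
from transformers import pipeline

summarizer = pipeline('summarization', model='facebook/bart-large-cnn')
with open('cnn_dm/test.source') as f:
    article = f.readline().strip()  # first non-tokenized source article
summary = summarizer(article, max_length=140, min_length=55, truncation=True)
print(summary[0]['summary_text'])
```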
|
|
|
In `fairseq`, summaries can be generated using: |
|
|
|
```bash |
|
cp data-bin/cnn_dm/dict.source.txt checkpoints/
python examples/bart/summarize.py \
  --model-dir pytorch/fairseq \
  --model-file bart.large.cnn \
  --src cnn_dm/test.source \
  --out cnn_dm/test.hypo
|
``` |
|
|
|
To calculate ROUGE scores, install `files2rouge` from [here](https://github.com/pltrdy/files2rouge).
|
|
|
```bash |
|
export CLASSPATH=/path/to/stanford-corenlp-full-2016-10-31/stanford-corenlp-3.7.0.jar |
|
|
|
# Tokenize hypothesis and target files. |
|
cat test.hypo | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > test.hypo.tokenized |
|
cat test.target | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > test.hypo.target |
|
files2rouge test.hypo.tokenized test.hypo.target |
|
# Expected output: (ROUGE-2 Average_F: 0.21238) |
|
``` |
|
|
|
|
|
## Finetuning |
|
|
|
- [Finetuning on GLUE](README.glue.md) |
|
- [Finetuning on CNN-DM](README.summarization.md) |
|
|
|
## Citation |
|
|
|
```bibtex |
|
@article{lewis2019bart,
  title = {BART: Denoising Sequence-to-Sequence Pre-training for Natural
           Language Generation, Translation, and Comprehension},
  author = {Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and
            Abdelrahman Mohamed and Omer Levy and Veselin Stoyanov and Luke Zettlemoyer},
  journal = {arXiv preprint arXiv:1910.13461},
  year = {2019},
}
|
``` |
|
|