Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- fairseq/examples/hubert/tests/test_feature_and_unit.sh +92 -0
- fairseq/examples/hubert/tests/test_finetuned_asr.sh +46 -0
- fairseq/examples/joint_alignment_translation/README.md +89 -0
- fairseq/examples/joint_alignment_translation/prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh +118 -0
- fairseq/examples/language_model/README.adaptive_inputs.md +39 -0
- fairseq/examples/language_model/README.conv.md +40 -0
- fairseq/examples/language_model/README.md +123 -0
- fairseq/examples/language_model/prepare-wikitext-103.sh +33 -0
- fairseq/examples/laser/README.md +144 -0
- fairseq/examples/laser/laser_src/__init__.py +8 -0
- fairseq/examples/laser/laser_src/laser_lstm.py +585 -0
- fairseq/examples/laser/laser_src/laser_task.py +334 -0
- fairseq/examples/laser/laser_src/laser_transformer.py +354 -0
- fairseq/examples/laser/laser_src/multitask_data_utils.py +143 -0
- fairseq/examples/latent_depth/README.md +77 -0
- fairseq/examples/latent_depth/latent_depth_src/__init__.py +9 -0
- fairseq/examples/latent_depth/latent_depth_src/models/latent_multilingual_transformer.py +76 -0
- fairseq/examples/latent_depth/latent_depth_src/modules/__init__.py +0 -0
- fairseq/examples/latent_depth/latent_depth_src/modules/latent_layers.py +75 -0
- fairseq/examples/linformer/README.md +22 -0
- fairseq/examples/linformer/linformer_src/__init__.py +6 -0
- fairseq/examples/linformer/linformer_src/models/linformer_roberta.py +120 -0
- fairseq/examples/linformer/linformer_src/modules/__init__.py +0 -0
- fairseq/examples/linformer/linformer_src/modules/linformer_sentence_encoder.py +54 -0
- fairseq/examples/linformer/linformer_src/modules/linformer_sentence_encoder_layer.py +65 -0
- fairseq/examples/linformer/linformer_src/modules/multihead_linear_attention.py +481 -0
- fairseq/examples/m2m_100/README.md +241 -0
- fairseq/examples/m2m_100/install_dependecies.sh +78 -0
- fairseq/examples/m2m_100/process_data/clean_histogram.py +52 -0
- fairseq/examples/m2m_100/process_data/dedup_data.py +91 -0
- fairseq/examples/m2m_100/process_data/remove_too_much_punc.py +36 -0
- fairseq/examples/m2m_100/tok.sh +83 -0
- fairseq/examples/m2m_100/tokenizers/README.md +18 -0
- fairseq/examples/m2m_100/tokenizers/seg_ja.sh +11 -0
- fairseq/examples/m2m_100/tokenizers/seg_ko.sh +12 -0
- fairseq/examples/m2m_100/tokenizers/thirdparty/.gitignore +12 -0
- fairseq/examples/m2m_100/tokenizers/tokenize_indic.py +23 -0
- fairseq/examples/m2m_100/tokenizers/tokenize_thai.py +13 -0
- fairseq/examples/m2m_100/tokenizers/tokenize_zh.py +14 -0
- fairseq/examples/m2m_100/tokenizers/tokenizer_ar.sh +27 -0
- fairseq/examples/mbart/README.md +123 -0
- fairseq/examples/megatron_11b/README.md +161 -0
- fairseq/examples/megatron_11b/detok.py +32 -0
- fairseq/examples/mms/MODEL_CARD.md +63 -0
- fairseq/examples/mms/README.md +215 -0
- fairseq/examples/mms/asr/config/infer_common.yaml +32 -0
- fairseq/examples/mms/asr/infer/example_infer_adapter.sh +3 -0
- fairseq/examples/mms/asr/infer/mms_infer.py +63 -0
- fairseq/examples/mms/asr/tutorial/MMS_ASR_Inference_Colab.ipynb +0 -0
- fairseq/examples/mms/data_prep/README.md +47 -0
fairseq/examples/hubert/tests/test_feature_and_unit.sh
ADDED
@@ -0,0 +1,92 @@
#!/bin/bash

set -e

sizes="base large xlarge"

declare -A ckpt_urls
ckpt_urls[base]="https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt"
ckpt_urls[large]="https://dl.fbaipublicfiles.com/hubert/hubert_large_ll60k.pt"
ckpt_urls[xlarge]="https://dl.fbaipublicfiles.com/hubert/hubert_xtralarge_ll60k.pt"

declare -A km_layers
km_layers[base]=9
km_layers[large]=20
km_layers[xlarge]=30

declare -A km_urls
km_urls[base]="https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960_L9_km500.bin"

declare -A km_nunits
km_nunits[base]=500

test_dir=./examples/hubert/tests
split=sample

echo -e "${test_dir}\n6313-76958-0021.flac\t190800" > "${test_dir}/${split}.tsv"

check_feature () {
  echo "checking features..."

  size=$1
  ckpt_url=$2
  km_layer=$3
  ckpt_path="$test_dir/$(basename "$ckpt_url")"

  if [ ! -f "$ckpt_path" ]; then
    echo "downloading $ckpt_url to $ckpt_path"
    wget "$ckpt_url" -O "$ckpt_path"
  fi

  python ./examples/hubert/simple_kmeans/dump_hubert_feature.py \
    "${test_dir}" "${split}" "${ckpt_path}" "${km_layer}" 1 0 "${test_dir}"

  if diff -q "${test_dir}/${split}.${size}.L${km_layer}.npy" "${test_dir}/${split}_0_1.npy" &>/dev/null; then
    echo "...passed npy check"
  else
    echo "...failed npy check"
  fi

  if diff -q "${test_dir}/${split}.${size}.L${km_layer}.len" "${test_dir}/${split}_0_1.len" &>/dev/null; then
    echo "...passed len check"
  else
    echo "...failed len check"
  fi
}


check_unit () {
  echo "checking units..."

  size=$1
  km_url=$2
  km_layer=$3
  km_nunit=$4
  km_path="$test_dir/$(basename "$km_url")"

  if [ ! -f "$km_path" ]; then
    echo "downloading $km_url to $km_path"
    wget "$km_url" -O "$km_path"
  fi

  python ./examples/hubert/simple_kmeans/dump_km_label.py \
    "${test_dir}" "${split}" "${km_path}" 1 0 "${test_dir}"

  if diff -q "${test_dir}/${split}.${size}.L${km_layer}.km${km_nunit}.km" "${test_dir}/${split}_0_1.km" &>/dev/null; then
    echo "...passed unit check"
  else
    echo "...failed unit check"
  fi
}


for size in $sizes; do
  echo "=== Running unit test for HuBERT $size ==="
  check_feature "$size" "${ckpt_urls[$size]}" "${km_layers[$size]}"

  if [ -n "${km_urls[$size]}" ]; then
    check_unit "$size" "${km_urls[$size]}" "${km_layers[$size]}" "${km_nunits[$size]}"
  fi

  rm -f $test_dir/${split}_0_1.*
done
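The dump scripts checked above write a flat feature matrix plus a companion length file with one frame count per utterance. The snippet below is a minimal sketch (not part of the test) of how such a `${split}_0_1.npy` / `.len` pair could be split back into per-utterance arrays; the exact on-disk layout is an assumption inferred from the checks above.

```python
import numpy as np

# Hypothetical paths mirroring the test's rank-0, 1-shard outputs.
feat_path = "examples/hubert/tests/sample_0_1.npy"
len_path = "examples/hubert/tests/sample_0_1.len"

feats = np.load(feat_path)                           # (total_frames, feat_dim), utterances concatenated
lengths = np.loadtxt(len_path, dtype=int, ndmin=1)   # one frame count per utterance

assert lengths.sum() == feats.shape[0], "length file must account for every frame"

# Split the flat matrix back into one array per utterance.
offsets = np.cumsum(lengths)[:-1]
per_utterance = np.split(feats, offsets, axis=0)
print(len(per_utterance), per_utterance[0].shape)
```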
fairseq/examples/hubert/tests/test_finetuned_asr.sh
ADDED
@@ -0,0 +1,46 @@
#!/bin/bash

set -e

sizes="large xlarge"

declare -A ckpt_urls
ckpt_urls[large]="https://dl.fbaipublicfiles.com/hubert/hubert_large_ll60k_finetune_ls960.pt"
ckpt_urls[xlarge]="https://dl.fbaipublicfiles.com/hubert/hubert_xtralarge_ll60k_finetune_ls960.pt"

test_dir=$(pwd)/examples/hubert/tests
split=sample

echo -e "${test_dir}\n6313-76958-0021.flac\t190800" > "${test_dir}/${split}.tsv"
echo -e "K E E P | A | G O I N G | A N D | I F | Y O U ' R E | L U C K Y | Y O U ' L L | R U N | P L U M B | I N T O | T H E M | W A S | T H E | J E E R I N G | A N S W E R | A S | T H E | S L E E P Y | C O W M E N | S P U R R E D | T H E I R | P O N I E S | O N | T O W A R D | C A M P | M U T T E R I N G | T H E I R | D I S A P P R O V A L | O F | T A K I N G | A L O N G | A | B U N C H | O F | B O Y S | O N | A | C A T T L E | D R I V E |" > "${test_dir}/${split}.ltr"

check_asr () {
  echo "checking asr outputs..."

  size=$1
  ckpt_url=$2
  ckpt_path="$test_dir/$(basename "$ckpt_url")"

  if [ ! -f "$ckpt_path" ]; then
    echo "downloading $ckpt_url to $ckpt_path"
    wget "$ckpt_url" -O "$ckpt_path"
  fi

  python examples/speech_recognition/new/infer.py \
    --config-dir examples/hubert/config/decode --config-name infer_viterbi \
    common_eval.path="${ckpt_path}" task.data="${test_dir}" task.normalize=true \
    decoding.results_path="${test_dir}/pred" \
    common_eval.results_path="${test_dir}/pred" \
    common_eval.quiet=false dataset.gen_subset="${split}"

  if diff -q "${test_dir}/pred/hypo.word" "${test_dir}/${split}.${size}.hypo.word" &>/dev/null; then
    echo "...passed word check"
  else
    echo "...failed word check"
  fi
  rm -rf "${test_dir}/pred"
}

for size in $sizes; do
  check_asr "$size" "${ckpt_urls[$size]}"
done
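The check above only reports pass/fail on an exact `diff` of the hypothesis files. When it fails, a word error rate between the two files gives a softer signal of how far off the decode is; below is a minimal sketch using plain edit distance (an illustration only, not part of the test script, and the file paths are placeholders).

```python
def wer(ref_words, hyp_words):
    """Word error rate via Levenshtein distance over word sequences."""
    d = list(range(len(hyp_words) + 1))
    for i, r in enumerate(ref_words, 1):
        prev, d[0] = d[0], i
        for j, h in enumerate(hyp_words, 1):
            cur = min(d[j] + 1,          # deletion
                      d[j - 1] + 1,      # insertion
                      prev + (r != h))   # substitution (or match)
            prev, d[j] = d[j], cur
    return d[-1] / max(len(ref_words), 1)

# Placeholder paths: reference hypotheses shipped with the test vs. fresh decode output.
with open("sample.large.hypo.word") as f_ref, open("pred/hypo.word") as f_hyp:
    refs, hyps = f_ref.read().split(), f_hyp.read().split()
print(f"WER: {wer(refs, hyps):.3f}")
```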
fairseq/examples/joint_alignment_translation/README.md
ADDED
@@ -0,0 +1,89 @@
# Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)

This page includes instructions for training models described in [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](https://arxiv.org/abs/1909.02074).

## Training a joint alignment-translation model on WMT'18 En-De

##### 1. Extract and preprocess the WMT'18 En-De data
```bash
./prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh
```

##### 2. Generate alignments from statistical alignment toolkits, e.g. Giza++/FastAlign.
In this example, we use FastAlign.
```bash
git clone git@github.com:clab/fast_align.git
pushd fast_align
mkdir build
cd build
cmake ..
make
popd
ALIGN=fast_align/build/fast_align
paste bpe.32k/train.en bpe.32k/train.de | awk -F '\t' '{print $1 " ||| " $2}' > bpe.32k/train.en-de
$ALIGN -i bpe.32k/train.en-de -d -o -v > bpe.32k/train.align
```

##### 3. Preprocess the dataset with the above generated alignments.
```bash
fairseq-preprocess \
    --source-lang en --target-lang de \
    --trainpref bpe.32k/train \
    --validpref bpe.32k/valid \
    --testpref bpe.32k/test \
    --align-suffix align \
    --destdir binarized/ \
    --joined-dictionary \
    --workers 32
```

##### 4. Train a model
```bash
fairseq-train \
    binarized \
    --arch transformer_wmt_en_de_big_align --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --activation-fn relu \
    --lr 0.0002 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
    --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
    --max-tokens 3500 --label-smoothing 0.1 \
    --save-dir ./checkpoints --log-interval 1000 --max-update 60000 \
    --keep-interval-updates -1 --save-interval-updates 0 \
    --load-alignments --criterion label_smoothed_cross_entropy_with_alignment \
    --fp16
```

Note that the `--fp16` flag requires CUDA 9.1 or greater and a Volta GPU or newer.

If you want to train the above model with big batches (assuming your machine has 8 GPUs):
- add `--update-freq 8` to simulate training on 8x8=64 GPUs
- increase the learning rate; 0.0007 works well for big batches

##### 5. Evaluate and generate the alignments (BPE level)
```bash
fairseq-generate \
    binarized --gen-subset test --print-alignment \
    --source-lang en --target-lang de \
    --path checkpoints/checkpoint_best.pt --beam 5 --nbest 1
```

##### 6. Other resources.
The code for:
1. preparing alignment test sets
2. converting BPE level alignments to token level alignments
3. symmetrizing bidirectional alignments
4. evaluating alignments using AER metric
can be found [here](https://github.com/lilt/alignment-scripts)

## Citation

```bibtex
@inproceedings{garg2019jointly,
  title = {Jointly Learning to Align and Translate with Transformer Models},
  author = {Garg, Sarthak and Peitz, Stephan and Nallasamy, Udhyakumar and Paulik, Matthias},
  booktitle = {Conference on Empirical Methods in Natural Language Processing (EMNLP)},
  address = {Hong Kong},
  month = {November},
  url = {https://arxiv.org/abs/1909.02074},
  year = {2019},
}
```
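Step 6 points at external scripts for converting BPE-level alignments back to word level. As a rough illustration of what that conversion involves (a simplified sketch, not the lilt/alignment-scripts implementation), one can map each BPE token to the word it came from using fastBPE's `@@` continuation markers and then re-index the alignment pairs:

```python
def bpe_to_word_map(bpe_tokens):
    """Map each BPE token position to the index of the word it belongs to."""
    mapping, word_idx = [], 0
    for tok in bpe_tokens:
        mapping.append(word_idx)
        if not tok.endswith("@@"):   # '@@' marks a continued word in fastBPE output
            word_idx += 1
    return mapping

def convert_alignment(align_line, src_bpe, tgt_bpe):
    """Convert 'i-j' BPE-level pairs to word-level pairs (deduplicated, sorted)."""
    src_map = bpe_to_word_map(src_bpe.split())
    tgt_map = bpe_to_word_map(tgt_bpe.split())
    pairs = set()
    for pair in align_line.split():
        i, j = map(int, pair.split("-"))
        pairs.add((src_map[i], tgt_map[j]))
    return " ".join(f"{i}-{j}" for i, j in sorted(pairs))

print(convert_alignment("0-0 1-1 2-2", "We@@ lcome home", "Wilkom@@ men daheim"))
# -> "0-0 1-1"  (both halves of each split word collapse onto one word index)
```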
fairseq/examples/joint_alignment_translation/prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh
ADDED
@@ -0,0 +1,118 @@
#!/bin/bash

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

echo 'Cloning Moses github repository (for tokenization scripts)...'
git clone https://github.com/moses-smt/mosesdecoder.git

SCRIPTS=mosesdecoder/scripts
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
CLEAN=$SCRIPTS/training/clean-corpus-n.perl
REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl

URLS=(
    "http://statmt.org/wmt13/training-parallel-europarl-v7.tgz"
    "http://statmt.org/wmt13/training-parallel-commoncrawl.tgz"
    "http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz"
    "http://data.statmt.org/wmt18/translation-task/rapid2016.tgz"
    "http://data.statmt.org/wmt17/translation-task/dev.tgz"
    "http://statmt.org/wmt14/test-full.tgz"
)
CORPORA=(
    "training/europarl-v7.de-en"
    "commoncrawl.de-en"
    "training-parallel-nc-v13/news-commentary-v13.de-en"
    "rapid2016.de-en"
)

if [ ! -d "$SCRIPTS" ]; then
    echo "Please set SCRIPTS variable correctly to point to Moses scripts."
    exit
fi

src=en
tgt=de
lang=en-de
prep=wmt18_en_de
tmp=$prep/tmp
orig=orig
dev=dev/newstest2012
codes=32000
bpe=bpe.32k

mkdir -p $orig $tmp $prep $bpe

cd $orig

for ((i=0;i<${#URLS[@]};++i)); do
    url=${URLS[i]}
    file=$(basename $url)
    if [ -f $file ]; then
        echo "$file already exists, skipping download"
    else
        wget "$url"
        if [ -f $file ]; then
            echo "$url successfully downloaded."
        else
            echo "$url not successfully downloaded."
            exit 1
        fi
        if [ ${file: -4} == ".tgz" ]; then
            tar zxvf $file
        elif [ ${file: -4} == ".tar" ]; then
            tar xvf $file
        fi
    fi
done
cd ..

echo "pre-processing train data..."
for l in $src $tgt; do
    rm -rf $tmp/train.tags.$lang.tok.$l
    for f in "${CORPORA[@]}"; do
        cat $orig/$f.$l | \
            perl $REM_NON_PRINT_CHAR | \
            perl $TOKENIZER -threads 8 -l $l -no-escape >> $tmp/train.tags.$lang.tok.$l
    done
done

echo "pre-processing test data..."
for l in $src $tgt; do
    if [ "$l" == "$src" ]; then
        t="src"
    else
        t="ref"
    fi
    grep '<seg id' $orig/test-full/newstest2014-deen-$t.$l.sgm | \
        sed -e 's/<seg id="[0-9]*">\s*//g' | \
        sed -e 's/\s*<\/seg>\s*//g' | \
        sed -e "s/\’/\'/g" | \
        perl $TOKENIZER -threads 8 -l $l -no-escape > $tmp/test.$l
    echo ""
done

# apply length filtering before BPE
perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train 1 100

# use newstest2012 for valid
echo "pre-processing valid data..."
for l in $src $tgt; do
    rm -rf $tmp/valid.$l
    cat $orig/$dev.$l | \
        perl $REM_NON_PRINT_CHAR | \
        perl $TOKENIZER -threads 8 -l $l -no-escape >> $tmp/valid.$l
done

mkdir output
mv $tmp/{train,valid,test}.{$src,$tgt} output

#BPE
git clone https://github.com/glample/fastBPE.git
pushd fastBPE
g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast
popd
fastBPE/fast learnbpe $codes output/train.$src output/train.$tgt > $bpe/codes
for split in {train,valid,test}; do for lang in {en,de}; do fastBPE/fast applybpe $bpe/$split.$lang output/$split.$lang $bpe/codes; done; done
fairseq/examples/language_model/README.adaptive_inputs.md
ADDED
@@ -0,0 +1,39 @@
# Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)

## Pre-trained models

Description | Parameters | Dataset | Model and Test set(s)
---|---:|---|---
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 1026M | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2)
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 247M | [WikiText-103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.v2.tar.bz2)

## Training an LM with adaptive inputs

First, see the general [language modeling README](README.md) for instructions on
preprocessing the WikiText-103 data.

Then use the following training command to train a model with adaptive inputs
using the `transformer_lm_wiki103` model architecture:
```bash
fairseq-train --task language_modeling \
    data-bin/wikitext-103 \
    --save-dir checkpoints/transformer_wikitext-103 \
    --arch transformer_lm_wiki103 \
    --max-update 286000 --lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --stop-min-lr 1e-09 --optimizer nag --min-lr 0.0001 --clip-norm 0.1 \
    --criterion adaptive_loss --max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=legacy_ddp
```

## Citation

```bibtex
@inproceedings{
    baevski2018adaptive,
    title={Adaptive Input Representations for Neural Language Modeling},
    author={Alexei Baevski and Michael Auli},
    booktitle={International Conference on Learning Representations},
    year={2019},
    url={https://openreview.net/forum?id=ByxZX20qFQ},
}
```
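The learning-rate flags above (`--lr-scheduler cosine --lr-period-updates 270000 --t-mult 2 --lr-shrink 0.75 --warmup-updates 16000`) describe a warmup phase followed by cosine annealing with restarts. The sketch below is a generic SGDR-style schedule, not fairseq's exact scheduler, that plugs in the same numbers purely to visualize the shape of the curve.

```python
import math

def sgdr_lr(step, lr_max=1.0, lr_min=1e-4, warmup=16000, warmup_init=1e-7,
            period=270000, t_mult=2.0, shrink=0.75):
    """Linear warmup, then cosine annealing with geometrically growing restart periods."""
    if step < warmup:                      # linear warmup from warmup_init to lr_max
        return warmup_init + (lr_max - warmup_init) * step / warmup
    t = step - warmup
    cur_period, cur_max = period, lr_max
    while t >= cur_period:                 # advance into the next restart cycle
        t -= cur_period
        cur_period *= t_mult               # --t-mult: each cycle is twice as long
        cur_max *= shrink                  # --lr-shrink: each restart peaks lower
    return lr_min + 0.5 * (cur_max - lr_min) * (1 + math.cos(math.pi * t / cur_period))

# lr at warmup start, at warmup end, and mid-way through the first cosine cycle
print(sgdr_lr(0), sgdr_lr(16000), sgdr_lr(16000 + 135000))
```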
fairseq/examples/language_model/README.conv.md
ADDED
@@ -0,0 +1,40 @@
# Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)

## Example usage

First download and preprocess the data following the main [language modeling README](README.md).

Then to train a convolutional LM using the `fconv_lm_dauphin_wikitext103`
architecture:
```bash
fairseq-train --task language_modeling \
    data-bin/wikitext-103 \
    --save-dir checkpoints/fconv_wikitext-103 \
    --arch fconv_lm_dauphin_wikitext103 \
    --adaptive-softmax-cutoff 10000,20000,200000 \
    --dropout 0.2 \
    --criterion adaptive_loss \
    --optimizer nag --clip-norm 0.1 --weight-decay 5e-06 \
    --lr 1.0 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \
    --max-tokens 1024 --tokens-per-sample 1024 \
    --ddp-backend legacy_ddp \
    --max-epoch 35
```

And evaluate with:
```bash
fairseq-eval-lm data-bin/wikitext-103 --path checkpoints/fconv_wiki103/checkpoint_best.pt
```

## Citation

```bibtex
@inproceedings{dauphin2017language,
  title={Language Modeling with Gated Convolutional Networks},
  author={Dauphin, Yann N and Fan, Angela and Auli, Michael and Grangier, David},
  booktitle={Proceedings of the 34th International Conference on Machine Learning-Volume 70},
  pages={933--941},
  year={2017},
  organization={JMLR}
}
```
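The `--adaptive-softmax-cutoff 10000,20000,200000` flag in the training command partitions the output vocabulary into a small head of frequent words and progressively cheaper tail clusters. As a standalone illustration of the idea (using PyTorch's built-in module rather than fairseq's adaptive loss, and with a made-up vocabulary size), the same cutoffs could be exercised like this:

```python
import torch
import torch.nn as nn

vocab_size = 267_744               # illustrative WikiText-103-scale vocabulary (assumed)
hidden_dim = 1024
cutoffs = [10000, 20000, 200000]   # head plus two tail clusters, as in the command above

adaptive = nn.AdaptiveLogSoftmaxWithLoss(hidden_dim, vocab_size, cutoffs=cutoffs)

hidden = torch.randn(8, hidden_dim)             # decoder outputs for 8 positions
targets = torch.randint(0, vocab_size, (8,))    # gold next-token ids

out = adaptive(hidden, targets)
print(out.output.shape, out.loss.item())        # per-position target log-probs, mean NLL
```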
fairseq/examples/language_model/README.md
ADDED
@@ -0,0 +1,123 @@
# Neural Language Modeling

## Pre-trained models

Model | Description | Dataset | Download
---|---|---|---
`transformer_lm.gbw.adaptive_huge` | Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) <br> 1026M params | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2)
`transformer_lm.wiki103.adaptive` | Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) <br> 247M params | [WikiText-103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.v2.tar.bz2)
`transformer_lm.wmt19.en` | English LM <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) | [WMT News Crawl](http://data.statmt.org/news-crawl/) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.en.tar.gz)
`transformer_lm.wmt19.de` | German LM <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) | [WMT News Crawl](http://data.statmt.org/news-crawl/) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.de.tar.gz)
`transformer_lm.wmt19.ru` | Russian LM <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) | [WMT News Crawl](http://data.statmt.org/news-crawl/) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.gz)

## Example usage

We require a few additional Python dependencies for preprocessing:
```bash
pip install fastBPE sacremoses
```

To sample from a language model using PyTorch Hub:
```python
import torch

# List available models
torch.hub.list('pytorch/fairseq')  # [..., 'transformer_lm.wmt19.en', ...]

# Load an English LM trained on WMT'19 News Crawl data
en_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt19.en', tokenizer='moses', bpe='fastbpe')
en_lm.eval()  # disable dropout

# Move model to GPU
en_lm.cuda()

# Sample from the language model
en_lm.sample('Barack Obama', beam=1, sampling=True, sampling_topk=10, temperature=0.8)
# "Barack Obama is coming to Sydney and New Zealand (...)"

# Compute perplexity for a sequence
en_lm.score('Barack Obama is coming to Sydney and New Zealand')['positional_scores'].mean().neg().exp()
# tensor(15.1474)

# The same interface can be used with custom models as well
from fairseq.models.transformer_lm import TransformerLanguageModel
custom_lm = TransformerLanguageModel.from_pretrained('/path/to/model/dir', 'checkpoint100.pt', tokenizer='moses', bpe='fastbpe')
custom_lm.sample('Barack Obama', beam=5)
# "Barack Obama (...)"
```

## Training a transformer language model with the CLI tools

### 1) Preprocess the data

First download and prepare the [WikiText-103 dataset](https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/):
```bash
cd examples/language_model/
bash prepare-wikitext-103.sh
cd ../..
```

Next preprocess/binarize the data:
```bash
TEXT=examples/language_model/wikitext-103
fairseq-preprocess \
    --only-source \
    --trainpref $TEXT/wiki.train.tokens \
    --validpref $TEXT/wiki.valid.tokens \
    --testpref $TEXT/wiki.test.tokens \
    --destdir data-bin/wikitext-103 \
    --workers 20
```

### 2) Train a language model

Next we'll train a basic transformer language model on wikitext-103. For more
advanced usage, see the [adaptive inputs README](README.adaptive_inputs.md).

To train a basic LM (assumes 2 GPUs):
```
$ fairseq-train --task language_modeling \
    data-bin/wikitext-103 \
    --save-dir checkpoints/transformer_wikitext-103 \
    --arch transformer_lm --share-decoder-input-output-embed \
    --dropout 0.1 \
    --optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 --clip-norm 0.0 \
    --lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
    --tokens-per-sample 512 --sample-break-mode none \
    --max-tokens 2048 --update-freq 16 \
    --fp16 \
    --max-update 50000
```

If you run out of memory, try reducing `--max-tokens` (max number of tokens per
batch) or `--tokens-per-sample` (max sequence length). You can also adjust
`--update-freq` to accumulate gradients and simulate training on a different
number of GPUs.

### 3) Evaluate

```bash
fairseq-eval-lm data-bin/wikitext-103 \
    --path checkpoints/transformer_wiki103/checkpoint_best.pt \
    --batch-size 2 \
    --tokens-per-sample 512 \
    --context-window 400
# | Evaluated 245569 tokens in 56.1s (4379.02 tokens/s)
# | Loss: 3.4164, Perplexity: 30.46
```

*Note:* The `--context-window` option controls how much context is provided to
each token when computing perplexity. When the window size is 0, the dataset is
chunked into segments of length 512 and perplexity is computed over each segment
normally. However, this results in worse (higher) perplexity since tokens that
appear earlier in each segment have less conditioning. When the maximum window
size is used (511 in this case), then we compute perplexity for each token
fully conditioned on 511 tokens of context. This slows down evaluation
significantly, since we must run a separate forward pass for every token in the
dataset, but results in better (lower) perplexity.


## Convolutional language models

Please see the [convolutional LM README](README.conv.md) for instructions on
training convolutional language models.
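To make the `--context-window` note concrete, here is a small, self-contained sketch (not fairseq code) of how the two evaluation modes carve up a token stream: with window 0 each 512-token segment is scored in isolation, while with window 511 only one new token is scored per forward pass, always prefixed by 511 tokens of already-seen context.

```python
def eval_segments(num_tokens, tokens_per_sample=512, context_window=0):
    """Yield (context_start, score_start, end) index triples for LM evaluation.

    Tokens in [score_start, end) contribute to perplexity; tokens in
    [context_start, score_start) are fed to the model as extra left context only.
    """
    stride = tokens_per_sample - context_window   # new tokens scored per forward pass
    segments, pos = [], 0
    while pos < num_tokens:
        context_start = max(0, pos - context_window)
        end = min(pos + stride, num_tokens)
        segments.append((context_start, pos, end))
        pos = end
    return segments

# window 0: disjoint 512-token chunks; early tokens in each chunk see little context
print(eval_segments(2048, 512, 0)[:2])    # [(0, 0, 512), (512, 512, 1024)]
# window 511: one new token per pass, conditioned on up to 511 previous tokens
print(eval_segments(2048, 512, 511)[:2])  # [(0, 0, 1), (0, 1, 2)]
```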
fairseq/examples/language_model/prepare-wikitext-103.sh
ADDED
@@ -0,0 +1,33 @@
#!/bin/bash
# Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh

URLS=(
    "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip"
)
FILES=(
    "wikitext-103-v1.zip"
)

for ((i=0;i<${#URLS[@]};++i)); do
    file=${FILES[i]}
    if [ -f $file ]; then
        echo "$file already exists, skipping download"
    else
        url=${URLS[i]}
        wget "$url"
        if [ -f $file ]; then
            echo "$url successfully downloaded."
        else
            echo "$url not successfully downloaded."
            exit -1
        fi
        if [ ${file: -4} == ".tgz" ]; then
            tar zxvf $file
        elif [ ${file: -4} == ".tar" ]; then
            tar xvf $file
        elif [ ${file: -4} == ".zip" ]; then
            unzip $file
        fi
    fi
done
cd ..
fairseq/examples/laser/README.md
ADDED
@@ -0,0 +1,144 @@
# LASER Language-Agnostic SEntence Representations

LASER is a library to calculate and use multilingual sentence embeddings.

You can find more information about LASER and how to use it on the official [LASER repository](https://github.com/facebookresearch/LASER).

This folder contains source code for training LASER embeddings.


## Prepare data and configuration file

Binarize your data with fairseq, as described [here](https://fairseq.readthedocs.io/en/latest/getting_started.html#data-pre-processing).

Create a json config file with this format:
```
{
  "src_vocab": "/path/to/spm.src.cvocab",
  "tgt_vocab": "/path/to/spm.tgt.cvocab",
  "train": [
    {
      "type": "translation",
      "id": 0,
      "src": "/path/to/srclang1-tgtlang0/train.srclang1",
      "tgt": "/path/to/srclang1-tgtlang0/train.tgtlang0"
    },
    {
      "type": "translation",
      "id": 1,
      "src": "/path/to/srclang1-tgtlang1/train.srclang1",
      "tgt": "/path/to/srclang1-tgtlang1/train.tgtlang1"
    },
    {
      "type": "translation",
      "id": 0,
      "src": "/path/to/srclang2-tgtlang0/train.srclang2",
      "tgt": "/path/to/srclang2-tgtlang0/train.tgtlang0"
    },
    {
      "type": "translation",
      "id": 1,
      "src": "/path/to/srclang2-tgtlang1/train.srclang2",
      "tgt": "/path/to/srclang2-tgtlang1/train.tgtlang1"
    },
    ...
  ],
  "valid": [
    {
      "type": "translation",
      "id": 0,
      "src": "/unused",
      "tgt": "/unused"
    }
  ]
}
```
where paths are paths to binarized indexed fairseq dataset files.
`id` represents the target language id.


## Training Command Line Example

```
fairseq-train \
  /path/to/configfile_described_above.json \
  --user-dir examples/laser/laser_src \
  --log-interval 100 --log-format simple \
  --task laser --arch laser_lstm \
  --save-dir . \
  --optimizer adam \
  --lr 0.001 \
  --lr-scheduler inverse_sqrt \
  --clip-norm 5 \
  --warmup-updates 90000 \
  --update-freq 2 \
  --dropout 0.0 \
  --encoder-dropout-out 0.1 \
  --max-tokens 2000 \
  --max-epoch 50 \
  --encoder-bidirectional \
  --encoder-layers 5 \
  --encoder-hidden-size 512 \
  --decoder-layers 1 \
  --decoder-hidden-size 2048 \
  --encoder-embed-dim 320 \
  --decoder-embed-dim 320 \
  --decoder-lang-embed-dim 32 \
  --warmup-init-lr 0.001 \
  --disable-validation
```


## Applications

We showcase several applications of multilingual sentence embeddings
with code to reproduce our results (in the directory "tasks").

* [**Cross-lingual document classification**](https://github.com/facebookresearch/LASER/tree/master/tasks/mldoc) using the
  [*MLDoc*](https://github.com/facebookresearch/MLDoc) corpus [2,6]
* [**WikiMatrix**](https://github.com/facebookresearch/LASER/tree/master/tasks/WikiMatrix)
  Mining 135M Parallel Sentences in 1620 Language Pairs from Wikipedia [7]
* [**Bitext mining**](https://github.com/facebookresearch/LASER/tree/master/tasks/bucc) using the
  [*BUCC*](https://comparable.limsi.fr/bucc2018/bucc2018-task.html) corpus [3,5]
* [**Cross-lingual NLI**](https://github.com/facebookresearch/LASER/tree/master/tasks/xnli)
  using the [*XNLI*](https://www.nyu.edu/projects/bowman/xnli/) corpus [4,5,6]
* [**Multilingual similarity search**](https://github.com/facebookresearch/LASER/tree/master/tasks/similarity) [1,6]
* [**Sentence embedding of text files**](https://github.com/facebookresearch/LASER/tree/master/tasks/embed)
  an example of how to calculate sentence embeddings for arbitrary text files in any of the supported languages.

**For all tasks, we use exactly the same multilingual encoder, without any task specific optimization or fine-tuning.**


## References

[1] Holger Schwenk and Matthijs Douze,
[*Learning Joint Multilingual Sentence Representations with Neural Machine Translation*](https://aclanthology.info/papers/W17-2619/w17-2619),
ACL workshop on Representation Learning for NLP, 2017

[2] Holger Schwenk and Xian Li,
[*A Corpus for Multilingual Document Classification in Eight Languages*](http://www.lrec-conf.org/proceedings/lrec2018/pdf/658.pdf),
LREC, pages 3548-3551, 2018.

[3] Holger Schwenk,
[*Filtering and Mining Parallel Data in a Joint Multilingual Space*](http://aclweb.org/anthology/P18-2037)
ACL, July 2018

[4] Alexis Conneau, Guillaume Lample, Ruty Rinott, Adina Williams, Samuel R. Bowman, Holger Schwenk and Veselin Stoyanov,
[*XNLI: Cross-lingual Sentence Understanding through Inference*](https://aclweb.org/anthology/D18-1269),
EMNLP, 2018.

[5] Mikel Artetxe and Holger Schwenk,
[*Margin-based Parallel Corpus Mining with Multilingual Sentence Embeddings*](https://arxiv.org/abs/1811.01136)
arXiv, Nov 3 2018.

[6] Mikel Artetxe and Holger Schwenk,
[*Massively Multilingual Sentence Embeddings for Zero-Shot Cross-Lingual Transfer and Beyond*](https://arxiv.org/abs/1812.10464)
arXiv, Dec 26 2018.

[7] Holger Schwenk, Vishrav Chaudhary, Shuo Sun, Hongyu Gong and Paco Guzman,
[*WikiMatrix: Mining 135M Parallel Sentences in 1620 Language Pairs from Wikipedia*](https://arxiv.org/abs/1907.05791)
arXiv, July 11 2019.

[8] Holger Schwenk, Guillaume Wenzek, Sergey Edunov, Edouard Grave and Armand Joulin
[*CCMatrix: Mining Billions of High-Quality Parallel Sentences on the WEB*](https://arxiv.org/abs/1911.04944)
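Since the config above is plain JSON with one entry per binarized bitext and `id` carrying the target-language index, it is often easier to generate than to hand-edit. A minimal sketch (the paths and language lists are placeholders, not part of the example):

```python
import json

src_langs = ["srclang1", "srclang2"]   # placeholder source language codes
tgt_langs = ["tgtlang0", "tgtlang1"]   # index in this list becomes the "id" field
data_root = "/path/to"                 # placeholder root for the binarized bitexts

config = {
    "src_vocab": f"{data_root}/spm.src.cvocab",
    "tgt_vocab": f"{data_root}/spm.tgt.cvocab",
    "train": [
        {
            "type": "translation",
            "id": tgt_id,
            "src": f"{data_root}/{src}-{tgt}/train.{src}",
            "tgt": f"{data_root}/{src}-{tgt}/train.{tgt}",
        }
        for src in src_langs
        for tgt_id, tgt in enumerate(tgt_langs)
    ],
    "valid": [{"type": "translation", "id": 0, "src": "/unused", "tgt": "/unused"}],
}

with open("laser_config.json", "w") as f:
    json.dump(config, f, indent=2)
```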
fairseq/examples/laser/laser_src/__init__.py
ADDED
@@ -0,0 +1,8 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .laser_task import *  # noqa
from .laser_lstm import *  # noqa
from .laser_transformer import *  # noqa
fairseq/examples/laser/laser_src/laser_lstm.py
ADDED
@@ -0,0 +1,585 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import options, utils

from fairseq.models import (
    FairseqEncoder,
    FairseqIncrementalDecoder,
    FairseqEncoderDecoderModel,
    register_model,
    register_model_architecture,
)


@register_model("laser_lstm")
class LSTMModel(FairseqEncoderDecoderModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens=None,
        tgt_tokens=None,
        tgt_lengths=None,
        target_language_id=None,
        dataset_name="",
    ):
        assert target_language_id is not None

        src_encoder_out = self.encoder(src_tokens, src_lengths, dataset_name)
        return self.decoder(
            prev_output_tokens, src_encoder_out, lang_id=target_language_id
        )

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument(
            "--dropout",
            default=0.1,
            type=float,
            metavar="D",
            help="dropout probability",
        )
        parser.add_argument(
            "--encoder-embed-dim",
            type=int,
            metavar="N",
            help="encoder embedding dimension",
        )
        parser.add_argument(
            "--encoder-embed-path",
            default=None,
            type=str,
            metavar="STR",
            help="path to pre-trained encoder embedding",
        )
        parser.add_argument(
            "--encoder-hidden-size", type=int, metavar="N", help="encoder hidden size"
        )
        parser.add_argument(
            "--encoder-layers", type=int, metavar="N", help="number of encoder layers"
        )
        parser.add_argument(
            "--encoder-bidirectional",
            action="store_true",
            help="make all layers of encoder bidirectional",
        )
        parser.add_argument(
            "--decoder-embed-dim",
            type=int,
            metavar="N",
            help="decoder embedding dimension",
        )
        parser.add_argument(
            "--decoder-embed-path",
            default=None,
            type=str,
            metavar="STR",
            help="path to pre-trained decoder embedding",
        )
        parser.add_argument(
            "--decoder-hidden-size", type=int, metavar="N", help="decoder hidden size"
        )
        parser.add_argument(
            "--decoder-layers", type=int, metavar="N", help="number of decoder layers"
        )
        parser.add_argument(
            "--decoder-out-embed-dim",
            type=int,
            metavar="N",
            help="decoder output embedding dimension",
        )
        parser.add_argument(
            "--decoder-zero-init",
            type=str,
            metavar="BOOL",
            help="initialize the decoder hidden/cell state to zero",
        )
        parser.add_argument(
            "--decoder-lang-embed-dim",
            type=int,
            metavar="N",
            help="decoder language embedding dimension",
        )
        parser.add_argument(
            "--fixed-embeddings",
            action="store_true",
            help="keep embeddings fixed (ENCODER ONLY)",
        )  # TODO Also apply to decoder embeddings?

        # Granular dropout settings (if not specified these default to --dropout)
        parser.add_argument(
            "--encoder-dropout-in",
            type=float,
            metavar="D",
            help="dropout probability for encoder input embedding",
        )
        parser.add_argument(
            "--encoder-dropout-out",
            type=float,
            metavar="D",
            help="dropout probability for encoder output",
        )
        parser.add_argument(
            "--decoder-dropout-in",
            type=float,
            metavar="D",
            help="dropout probability for decoder input embedding",
        )
        parser.add_argument(
            "--decoder-dropout-out",
            type=float,
            metavar="D",
            help="dropout probability for decoder output",
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # make sure that all args are properly defaulted (in case there are any new ones)
        base_architecture(args)

        def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
            embed_dict = utils.parse_embedding(embed_path)
            utils.print_embed_overlap(embed_dict, dictionary)
            return utils.load_embedding(embed_dict, dictionary, embed_tokens)

        pretrained_encoder_embed = None
        if args.encoder_embed_path:
            pretrained_encoder_embed = load_pretrained_embedding_from_file(
                args.encoder_embed_path, task.source_dictionary, args.encoder_embed_dim
            )
        pretrained_decoder_embed = None
        if args.decoder_embed_path:
            pretrained_decoder_embed = load_pretrained_embedding_from_file(
                args.decoder_embed_path, task.target_dictionary, args.decoder_embed_dim
            )

        num_langs = task.num_tasks if hasattr(task, "num_tasks") else 0

        encoder = LSTMEncoder(
            dictionary=task.source_dictionary,
            embed_dim=args.encoder_embed_dim,
            hidden_size=args.encoder_hidden_size,
            num_layers=args.encoder_layers,
            dropout_in=args.encoder_dropout_in,
            dropout_out=args.encoder_dropout_out,
            bidirectional=args.encoder_bidirectional,
            pretrained_embed=pretrained_encoder_embed,
            fixed_embeddings=args.fixed_embeddings,
        )
        decoder = LSTMDecoder(
            dictionary=task.target_dictionary,
            embed_dim=args.decoder_embed_dim,
            hidden_size=args.decoder_hidden_size,
            out_embed_dim=args.decoder_out_embed_dim,
            num_layers=args.decoder_layers,
            dropout_in=args.decoder_dropout_in,
            dropout_out=args.decoder_dropout_out,
            zero_init=options.eval_bool(args.decoder_zero_init),
            encoder_embed_dim=args.encoder_embed_dim,
            encoder_output_units=encoder.output_units,
            pretrained_embed=pretrained_decoder_embed,
            num_langs=num_langs,
            lang_embed_dim=args.decoder_lang_embed_dim,
        )
        return cls(encoder, decoder)


class LSTMEncoder(FairseqEncoder):
    """LSTM encoder."""

    def __init__(
        self,
        dictionary,
        embed_dim=512,
        hidden_size=512,
        num_layers=1,
        dropout_in=0.1,
        dropout_out=0.1,
        bidirectional=False,
        left_pad=True,
        pretrained_embed=None,
        padding_value=0.0,
        fixed_embeddings=False,
    ):
        super().__init__(dictionary)
        self.num_layers = num_layers
        self.dropout_in = dropout_in
        self.dropout_out = dropout_out
        self.bidirectional = bidirectional
        self.hidden_size = hidden_size

        num_embeddings = len(dictionary)
        self.padding_idx = dictionary.pad()
        if pretrained_embed is None:
            self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx)
        else:
            self.embed_tokens = pretrained_embed
        if fixed_embeddings:
            self.embed_tokens.weight.requires_grad = False

        self.lstm = LSTM(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=self.dropout_out if num_layers > 1 else 0.0,
            bidirectional=bidirectional,
        )
        self.left_pad = left_pad
        self.padding_value = padding_value

        self.output_units = hidden_size
        if bidirectional:
            self.output_units *= 2

    def forward(self, src_tokens, src_lengths, dataset_name):
        if self.left_pad:
            # convert left-padding to right-padding
            src_tokens = utils.convert_padding_direction(
                src_tokens,
                self.padding_idx,
                left_to_right=True,
            )

        bsz, seqlen = src_tokens.size()

        # embed tokens
        x = self.embed_tokens(src_tokens)
        x = F.dropout(x, p=self.dropout_in, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # pack embedded source tokens into a PackedSequence
        try:
            packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data.tolist())
        except BaseException:
            raise Exception(f"Packing failed in dataset {dataset_name}")

        # apply LSTM
        if self.bidirectional:
            state_size = 2 * self.num_layers, bsz, self.hidden_size
        else:
            state_size = self.num_layers, bsz, self.hidden_size
        h0 = x.data.new(*state_size).zero_()
        c0 = x.data.new(*state_size).zero_()
        packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0))

        # unpack outputs and apply dropout
        x, _ = nn.utils.rnn.pad_packed_sequence(
            packed_outs, padding_value=self.padding_value
        )
        x = F.dropout(x, p=self.dropout_out, training=self.training)
        assert list(x.size()) == [seqlen, bsz, self.output_units]

        if self.bidirectional:

            def combine_bidir(outs):
                return torch.cat(
                    [
                        torch.cat([outs[2 * i], outs[2 * i + 1]], dim=0).view(
                            1, bsz, self.output_units
                        )
                        for i in range(self.num_layers)
                    ],
                    dim=0,
                )

            final_hiddens = combine_bidir(final_hiddens)
            final_cells = combine_bidir(final_cells)

        encoder_padding_mask = src_tokens.eq(self.padding_idx).t()

        # Set padded outputs to -inf so they are not selected by max-pooling
        padding_mask = src_tokens.eq(self.padding_idx).t().unsqueeze(-1)
        if padding_mask.any():
            x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x)

        # Build the sentence embedding by max-pooling over the encoder outputs
        sentemb = x.max(dim=0)[0]

        return {
            "sentemb": sentemb,
            "encoder_out": (x, final_hiddens, final_cells),
            "encoder_padding_mask": encoder_padding_mask
            if encoder_padding_mask.any()
            else None,
        }

    def reorder_encoder_out(self, encoder_out_dict, new_order):
        encoder_out_dict["sentemb"] = encoder_out_dict["sentemb"].index_select(
            0, new_order
        )
        encoder_out_dict["encoder_out"] = tuple(
            eo.index_select(1, new_order) for eo in encoder_out_dict["encoder_out"]
        )
        if encoder_out_dict["encoder_padding_mask"] is not None:
            encoder_out_dict["encoder_padding_mask"] = encoder_out_dict[
                "encoder_padding_mask"
            ].index_select(1, new_order)
        return encoder_out_dict

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return int(1e5)  # an arbitrary large number


class LSTMDecoder(FairseqIncrementalDecoder):
    """LSTM decoder."""

    def __init__(
        self,
        dictionary,
        embed_dim=512,
        hidden_size=512,
        out_embed_dim=512,
        num_layers=1,
        dropout_in=0.1,
        dropout_out=0.1,
        zero_init=False,
        encoder_embed_dim=512,
        encoder_output_units=512,
        pretrained_embed=None,
        num_langs=1,
        lang_embed_dim=0,
    ):
        super().__init__(dictionary)
        self.dropout_in = dropout_in
        self.dropout_out = dropout_out
        self.hidden_size = hidden_size

        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        if pretrained_embed is None:
            self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
        else:
            self.embed_tokens = pretrained_embed

        self.layers = nn.ModuleList(
            [
                LSTMCell(
                    input_size=encoder_output_units + embed_dim + lang_embed_dim
                    if layer == 0
                    else hidden_size,
                    hidden_size=hidden_size,
                )
                for layer in range(num_layers)
            ]
        )
        if hidden_size != out_embed_dim:
            self.additional_fc = Linear(hidden_size, out_embed_dim)
        self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)

        if zero_init:
            self.sentemb2init = None
        else:
            self.sentemb2init = Linear(
                encoder_output_units, 2 * num_layers * hidden_size
            )

        if lang_embed_dim == 0:
            self.embed_lang = None
        else:
            self.embed_lang = nn.Embedding(num_langs, lang_embed_dim)
            nn.init.uniform_(self.embed_lang.weight, -0.1, 0.1)

    def forward(
        self, prev_output_tokens, encoder_out_dict, incremental_state=None, lang_id=0
    ):
        sentemb = encoder_out_dict["sentemb"]
        encoder_out = encoder_out_dict["encoder_out"]

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
        bsz, seqlen = prev_output_tokens.size()

        # get outputs from encoder
        encoder_outs, _, _ = encoder_out[:3]
        srclen = encoder_outs.size(0)

        # embed tokens
        x = self.embed_tokens(prev_output_tokens)
        x = F.dropout(x, p=self.dropout_in, training=self.training)

        # embed language identifier
        if self.embed_lang is not None:
            lang_ids = prev_output_tokens.data.new_full((bsz,), lang_id)
            langemb = self.embed_lang(lang_ids)
            # TODO Should we dropout here???

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # initialize previous states (or get from cache during incremental generation)
        cached_state = utils.get_incremental_state(
            self, incremental_state, "cached_state"
        )
        if cached_state is not None:
            prev_hiddens, prev_cells, input_feed = cached_state
        else:
            num_layers = len(self.layers)
            if self.sentemb2init is None:
                prev_hiddens = [
                    x.data.new(bsz, self.hidden_size).zero_() for i in range(num_layers)
                ]
                prev_cells = [
                    x.data.new(bsz, self.hidden_size).zero_() for i in range(num_layers)
                ]
            else:
                init = self.sentemb2init(sentemb)
                prev_hiddens = [
                    init[:, (2 * i) * self.hidden_size : (2 * i + 1) * self.hidden_size]
                    for i in range(num_layers)
                ]
                prev_cells = [
                    init[
                        :,
                        (2 * i + 1) * self.hidden_size : (2 * i + 2) * self.hidden_size,
                    ]
                    for i in range(num_layers)
                ]
            input_feed = x.data.new(bsz, self.hidden_size).zero_()

        attn_scores = x.data.new(srclen, seqlen, bsz).zero_()
        outs = []
        for j in range(seqlen):
            if self.embed_lang is None:
                input = torch.cat((x[j, :, :], sentemb), dim=1)
            else:
                input = torch.cat((x[j, :, :], sentemb, langemb), dim=1)

            for i, rnn in enumerate(self.layers):
                # recurrent cell
                hidden, cell = rnn(input, (prev_hiddens[i], prev_cells[i]))

                # hidden state becomes the input to the next layer
                input = F.dropout(hidden, p=self.dropout_out, training=self.training)

                # save state for next time step
                prev_hiddens[i] = hidden
                prev_cells[i] = cell

            out = hidden
            out = F.dropout(out, p=self.dropout_out, training=self.training)

            # input feeding
            input_feed = out

            # save final output
            outs.append(out)

        # cache previous states (no-op except during incremental generation)
        utils.set_incremental_state(
            self,
            incremental_state,
            "cached_state",
            (prev_hiddens, prev_cells, input_feed),
        )

        # collect outputs across time steps
        x = torch.cat(outs, dim=0).view(seqlen, bsz, self.hidden_size)

        # T x B x C -> B x T x C
        x = x.transpose(1, 0)

        # srclen x tgtlen x bsz -> bsz x tgtlen x srclen
        attn_scores = attn_scores.transpose(0, 2)

        # project back to size of vocabulary
        if hasattr(self, "additional_fc"):
            x = self.additional_fc(x)
            x = F.dropout(x, p=self.dropout_out, training=self.training)
        x = self.fc_out(x)

        return x, attn_scores

    def reorder_incremental_state(self, incremental_state, new_order):
        super().reorder_incremental_state(incremental_state, new_order)
        cached_state = utils.get_incremental_state(
            self, incremental_state, "cached_state"
        )
        if cached_state is None:
            return

        def reorder_state(state):
            if isinstance(state, list):
                return [reorder_state(state_i) for state_i in state]
            return state.index_select(0, new_order)

        new_state = tuple(map(reorder_state, cached_state))
        utils.set_incremental_state(self, incremental_state, "cached_state", new_state)

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return int(1e5)  # an arbitrary large number


def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.uniform_(m.weight, -0.1, 0.1)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m


def LSTM(input_size, hidden_size, **kwargs):
    m = nn.LSTM(input_size, hidden_size, **kwargs)
    for name, param in m.named_parameters():
        if "weight" in name or "bias" in name:
            param.data.uniform_(-0.1, 0.1)
    return m


def LSTMCell(input_size, hidden_size, **kwargs):
    m = nn.LSTMCell(input_size, hidden_size, **kwargs)
    for name, param in m.named_parameters():
        if "weight" in name or "bias" in name:
            param.data.uniform_(-0.1, 0.1)
    return m


def Linear(in_features, out_features, bias=True, dropout=0):
    """Weight-normalized Linear layer (input: N x T x C)"""
    m = nn.Linear(in_features, out_features, bias=bias)
    m.weight.data.uniform_(-0.1, 0.1)
    if bias:
        m.bias.data.uniform_(-0.1, 0.1)
    return m


@register_model_architecture("laser_lstm", "laser_lstm")
def base_architecture(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
|
566 |
+
args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
|
567 |
+
args.encoder_hidden_size = getattr(
|
568 |
+
args, "encoder_hidden_size", args.encoder_embed_dim
|
569 |
+
)
|
570 |
+
args.encoder_layers = getattr(args, "encoder_layers", 1)
|
571 |
+
args.encoder_bidirectional = getattr(args, "encoder_bidirectional", False)
|
572 |
+
args.encoder_dropout_in = getattr(args, "encoder_dropout_in", args.dropout)
|
573 |
+
args.encoder_dropout_out = getattr(args, "encoder_dropout_out", args.dropout)
|
574 |
+
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
|
575 |
+
args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
|
576 |
+
args.decoder_hidden_size = getattr(
|
577 |
+
args, "decoder_hidden_size", args.decoder_embed_dim
|
578 |
+
)
|
579 |
+
args.decoder_layers = getattr(args, "decoder_layers", 1)
|
580 |
+
args.decoder_out_embed_dim = getattr(args, "decoder_out_embed_dim", 512)
|
581 |
+
args.decoder_dropout_in = getattr(args, "decoder_dropout_in", args.dropout)
|
582 |
+
args.decoder_dropout_out = getattr(args, "decoder_dropout_out", args.dropout)
|
583 |
+
args.decoder_zero_init = getattr(args, "decoder_zero_init", "0")
|
584 |
+
args.decoder_lang_embed_dim = getattr(args, "decoder_lang_embed_dim", 0)
|
585 |
+
args.fixed_embeddings = getattr(args, "fixed_embeddings", False)
|
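The `reorder_incremental_state` override above exists because beam search reorders hypotheses between decoding steps, so any cached LSTM states must be permuted the same way. A minimal, framework-free sketch of that reordering (the tensor sizes and `new_order` below are made up for illustration):

```python
import torch

# Beam search produces a permutation of the (batch * beam) dimension; every cached
# state tensor is reordered along dim 0 with index_select, as in the method above.
prev_hiddens = [torch.randn(6, 4)]            # one (bsz*beam) x hidden tensor per layer
new_order = torch.tensor([1, 0, 2, 4, 3, 5])  # hypothetical reordering from the search

reordered = [h.index_select(0, new_order) for h in prev_hiddens]
assert reordered[0].shape == prev_hiddens[0].shape
```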
fairseq/examples/laser/laser_src/laser_task.py
ADDED
@@ -0,0 +1,334 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


from collections import OrderedDict, defaultdict
import json
import os
import logging
from argparse import ArgumentError

from fairseq import options, models
from fairseq.data import (
    data_utils,
    Dictionary,
    LanguagePairDataset,
    IndexedDataset,
    FairseqDataset,
)
from .multitask_data_utils import (
    MultitaskDatasetWrapper,
    MultidatasetEpochBatchIterator,
)


from fairseq.tasks import LegacyFairseqTask, register_task

logger = logging.getLogger(__name__)


@register_task("laser")
class LaserTask(LegacyFairseqTask):
    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "configfile", metavar="PATH", help="dataset configuration file in json"
        )
        parser.add_argument(
            "--weighting-alpha",
            type=float,
            default=None,
            help="alpha for automatic weighting",
        )
        parser.add_argument(
            "--raw-text", action="store_true", help="load raw text dataset"
        )
        parser.add_argument(
            "--left-pad-source",
            default="True",
            type=str,
            metavar="BOOL",
            help="pad the source on the left (default: True)",
        )
        parser.add_argument(
            "--left-pad-target",
            default="False",
            type=str,
            metavar="BOOL",
            help="pad the target on the left (default: False)",
        )
        try:
            parser.add_argument(
                "--max-source-positions",
                default=1024,
                type=int,
                metavar="N",
                help="max number of tokens in the source sequence",
            )
            parser.add_argument(
                "--max-target-positions",
                default=1024,
                type=int,
                metavar="N",
                help="max number of tokens in the target sequence",
            )
        except ArgumentError:
            # this might have already been defined. Once we transition this to hydra it should be fine to add it here.
            pass

    def __init__(self, args, config, src_dictionary, tgt_dictionary, num_tasks):
        super().__init__(args)
        self.config = config
        self.src_dictionary = src_dictionary
        self.tgt_dictionary = tgt_dictionary
        self.num_tasks = num_tasks

    @classmethod
    def setup_task(cls, args, **kwargs):
        with open(args.configfile, "r") as f:
            config = json.load(f)
        num_tasks = max(dataset["id"] for dataset in config["train"]) + 1

        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        src_dictionary = Dictionary.load(config["src_vocab"])
        tgt_dictionary = Dictionary.load(config["tgt_vocab"])

        logger.info(
            "| src Dictionary {} : {} types".format(
                config["src_vocab"], len(src_dictionary)
            )
        )
        logger.info(
            "| tgt Dictionary {} : {} types".format(
                config["tgt_vocab"], len(tgt_dictionary)
            )
        )

        return cls(args, config, src_dictionary, tgt_dictionary, num_tasks)

    # Experimental overriding for backtranslation
    def build_model(self, args, from_checkpoint=False):
        model = models.build_model(args, self)
        return model

    def dataset(self, split):
        if split not in self.datasets:
            raise KeyError("Dataset not loaded: " + split)
        return self.datasets[split]

    def load_dataset(self, split, epoch=1, **kwargs):
        """Load a dataset split."""

        def indexed_dataset(path, dictionary):
            if self.args.raw_text:
                raise Exception("Unable to handle raw text.")
            dataset = IndexedDataset(path, fix_lua_indexing=True)

            return dataset

        pair_datasets = OrderedDict()

        if split == "valid":
            self.datasets[split] = pair_datasets
            return

        if split not in self.config:
            raise FileNotFoundError(
                "Dataset not found in config file: {}".format(split)
            )

        size_by_corpus = defaultdict(int)
        size_sum = 0
        size_sum_with_subsampling = 0
        init_pair_datasets = {}

        for dataset_config in self.config[split]:
            src_path = os.path.dirname(dataset_config["src"])
            corpus_name = src_path.split("/")[-2]
            language_pair_name = src_path.split("/")[-1]
            pair_datasets_key = corpus_name + "-" + language_pair_name

            logger.info(f"loading... {pair_datasets_key}")
            if "src" in dataset_config:
                src_dataset = indexed_dataset(
                    dataset_config["src"], self.src_dictionary
                )
            else:
                src_dataset = None

            if "tgt" in dataset_config:
                tgt_dataset = indexed_dataset(
                    dataset_config["tgt"], self.tgt_dictionary
                )
            else:
                tgt_dataset = None

            dataset = LanguagePairDataset(
                src_dataset,
                src_dataset.sizes,
                self.src_dictionary,
                tgt_dataset,
                tgt_dataset.sizes,
                self.tgt_dictionary,
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )

            if pair_datasets_key in init_pair_datasets:
                logger.warning(
                    f"Ignoring already added {pair_datasets_key}. "
                    f"Consider using `sample` key in order to upsample."
                )
            else:
                init_pair_datasets[pair_datasets_key] = {
                    "dataset": dataset,
                    "sample": dataset_config.get("sample", None),
                    "id": dataset_config.get("id", None),
                    "len": len(dataset),
                }

        length_sum = 0
        weighted_freqs_sum = 0
        freq_per_dataset = {}
        vmax = 0
        vmin = 1
        weighted_freq_per_dataset = {}

        if self.args.weighting_alpha:
            for key in init_pair_datasets:
                if init_pair_datasets[key]["sample"] is None:
                    length_sum += len(init_pair_datasets[key]["dataset"])

            for key in init_pair_datasets:
                if init_pair_datasets[key]["sample"] is None:
                    val = float(init_pair_datasets[key]["len"]) / length_sum
                    freq_per_dataset[key] = val
                    weighted_freqs_sum += val ** self.args.weighting_alpha

            for key in freq_per_dataset:
                val = (
                    freq_per_dataset[key] ** self.args.weighting_alpha
                    / weighted_freqs_sum
                )
                vmin = min(vmin, val)
                vmax = max(vmax, val)
                weighted_freq_per_dataset[key] = val

        for pair_datasets_key in init_pair_datasets:
            dataset_config = init_pair_datasets[pair_datasets_key]
            dataset = dataset_config["dataset"]
            sample = dataset_config["sample"]
            if sample is None:
                sample = 1.0

            if pair_datasets_key in weighted_freq_per_dataset:
                w = vmax / weighted_freq_per_dataset[pair_datasets_key]
                sample = w

            sample = round(sample)

            initial_sample = sample
            initial_pair_datasets_key = pair_datasets_key

            while sample >= 1.0:
                assert (
                    pair_datasets_key not in pair_datasets
                ), f"{pair_datasets_key} already in"
                size_sum_with_subsampling += len(dataset)
                pair_datasets[pair_datasets_key] = MultitaskDatasetWrapper(
                    dataset, dataset_config.get("id", 0), 1.0, name=pair_datasets_key
                )
                size_sum += len(dataset)
                sample -= 1.0
                pair_datasets_key += "-up"

            assert sample < 1e-6, f"sample remains > 0 {pair_datasets_key}"

            logger.info(
                f"added pair {initial_pair_datasets_key} length {len(dataset)} new_length = {len(dataset)*initial_sample}"
            )
            size_by_corpus[corpus_name] += len(dataset)

        self.datasets[split] = pair_datasets
        logger.info(
            f"Datasets number = {len(self.datasets[split])} size = {size_sum} size_sum_with_subsampling = {size_sum_with_subsampling}"
        )

    @property
    def source_dictionary(self):
        return self.src_dictionary

    @property
    def target_dictionary(self):
        return self.tgt_dictionary

    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
        grouped_shuffling=False,
        update_epoch_batch_itr=False,
        **kwargs,
    ):

        assert isinstance(dataset, OrderedDict)
        assert len(dataset)
        assert isinstance(dataset[next(iter(dataset))], FairseqDataset)

        # initialize the dataset with the correct starting epoch
        for _, dt in dataset.items():
            dt.set_epoch(epoch)

        indices = OrderedDict()
        batch_sampler = OrderedDict()

        with data_utils.numpy_seed(seed + epoch):
            for key, dt in dataset.items():
                logger.info(f"\t ordered_indices {key}")
                indices[key] = dt.ordered_indices()

            # filter examples that are too large
            if max_positions is not None:
                for key, dt in dataset.items():
                    logger.info(f"\t filter_by_size {key}")
                    indices[key], ignored = dt.filter_indices_by_size(
                        indices[key], max_positions
                    )

            for key, dt in dataset.items():
                logger.info(f"\t batch_by_size {key}")
                batch_sampler[key] = data_utils.batch_by_size(
                    indices[key],
                    dt.num_tokens,
                    max_tokens=max_tokens,
                    max_sentences=max_sentences,
                    required_batch_size_multiple=required_batch_size_multiple,
                )

        epoch_iter = MultidatasetEpochBatchIterator(
            dataset=dataset,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
        )

        return epoch_iter
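For orientation, `setup_task` and `load_dataset` above only rely on a handful of keys in the JSON passed as `configfile`: `src_vocab`, `tgt_vocab`, and one entry per bitext under each split, each with `src`/`tgt` paths plus optional `id` and `sample`. The snippet below writes a hypothetical config in that shape; the paths and language pairs are placeholders, not part of this diff.

```python
import json

# Hypothetical LASER task config; corpus and language-pair names are recovered
# from the last two directory components of each "src" path (see load_dataset).
config = {
    "src_vocab": "data-bin/dict.src.txt",
    "tgt_vocab": "data-bin/dict.tgt.txt",
    "train": [
        {"id": 0,
         "src": "data-bin/europarl/de-en/train.bpe.de",
         "tgt": "data-bin/europarl/de-en/train.bpe.en"},
        {"id": 1,
         "src": "data-bin/europarl/fr-en/train.bpe.fr",
         "tgt": "data-bin/europarl/fr-en/train.bpe.en",
         "sample": 2},
    ],
    "valid": [],
}

with open("laser.json", "w") as f:
    json.dump(config, f, indent=2)
```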
fairseq/examples/laser/laser_src/laser_transformer.py
ADDED
@@ -0,0 +1,354 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

from typing import Any, Dict, List, Optional
from torch import Tensor

import torch
import torch.nn as nn

from fairseq.models import (
    FairseqEncoderDecoderModel,
    register_model,
    register_model_architecture,
)
from fairseq.models.transformer import (
    base_architecture,
    Embedding,
    TransformerModel,
    TransformerEncoder,
    TransformerDecoder,
)
from fairseq.modules import (
    TransformerDecoderLayer,
)

logger = logging.getLogger(__name__)


@register_model("laser_transformer")
class LaserTransformerModel(FairseqEncoderDecoderModel):
    """Train Transformer for LASER task

    Requires --task laser
    """

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens=None,
        tgt_tokens=None,
        tgt_lengths=None,
        target_language_id=-1,
        dataset_name="",
    ):
        laser_encoder_out = self.encoder(src_tokens, src_lengths)
        return self.decoder(
            prev_output_tokens, laser_encoder_out, lang_id=target_language_id
        )

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        TransformerModel.add_args(parser)
        parser.add_argument(
            "--decoder-lang-embed-dim",
            type=int,
            metavar="N",
            help="decoder language embedding dimension",
        )

    @classmethod
    def build_model(cls, args, task):
        base_laser_transformer_architecture(args)

        num_langs = task.num_tasks if hasattr(task, "num_tasks") else 0

        def load_embed_tokens(dictionary, embed_dim):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()

            return Embedding(num_embeddings, embed_dim, padding_idx)

        encoder_embed_tokens = load_embed_tokens(
            task.source_dictionary, args.encoder_embed_dim
        )
        decoder_embed_tokens = load_embed_tokens(
            task.target_dictionary, args.decoder_embed_dim
        )
        num_langs = task.num_tasks if hasattr(task, "num_tasks") else 0

        encoder = LaserTransformerEncoder(
            args, task.source_dictionary, encoder_embed_tokens
        )

        decoder = LaserTransformerDecoder(
            args,
            task.target_dictionary,
            decoder_embed_tokens,
            num_langs=num_langs,
            lang_embed_dim=args.decoder_lang_embed_dim,
        )

        return cls(encoder, decoder)


class LaserTransformerEncoder(TransformerEncoder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, src_tokens, *args, **kwargs):
        encoder_out = super().forward(src_tokens, *args, **kwargs)

        x = encoder_out["encoder_out"][0]  # T x B x C
        padding_mask = src_tokens.eq(self.padding_idx).t().unsqueeze(-1)

        if padding_mask.any():
            x = x.float().masked_fill_(padding_mask, float("-inf")).type_as(x)

        # Build the sentence embedding by max-pooling over the encoder outputs
        sentemb = x.max(dim=0)[0]

        # The PyTorch Mobile lite interpreter does not support returning a
        # NamedTuple from `forward`, so we use a dictionary instead.
        # TorchScript does not support mixed values, so the values are all lists.
        # The empty list is equivalent to None.
        return {"sentemb": [sentemb]}  # B x C

    @torch.jit.export
    def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
        """
        Same as the one in transformer.py, with new_sentemb
        """
        if len(encoder_out["sentemb"]) == 0:
            new_sentemb = []
        else:
            new_sentemb = [encoder_out["sentemb"][0].index_select(0, new_order)]

        return {
            "sentemb": new_sentemb,  # B x C
        }


class LaserTransformerDecoder(TransformerDecoder):
    def __init__(self, args, dictionary, *kargs, **kwargs):
        self.num_langs = kwargs.get("num_langs", 1)
        self.lang_embed_dim = kwargs.get("lang_embed_dim", 0)
        kwargs.pop("num_langs", None)
        kwargs.pop("lang_embed_dim", None)

        super().__init__(args, dictionary, *kargs, **kwargs, no_encoder_attn=True)

        if self.lang_embed_dim == 0:
            self.embed_lang = None
        else:
            self.embed_lang = nn.Embedding(self.num_langs, self.lang_embed_dim)
            nn.init.uniform_(self.embed_lang.weight, -0.1, 0.1)

        if self.output_projection is not None:
            laser_output_embed_dim = (
                self.output_embed_dim + self.lang_embed_dim + args.encoder_embed_dim
            )
            self.output_projection = nn.Linear(
                laser_output_embed_dim, len(dictionary), bias=False
            )
            nn.init.normal_(
                self.output_projection.weight,
                mean=0,
                std=laser_output_embed_dim ** -0.5,
            )

    def build_decoder_layer(self, args, no_encoder_attn=False):
        decoder_embed_dim = args.decoder_embed_dim
        args.decoder_embed_dim = (
            decoder_embed_dim + self.lang_embed_dim + args.encoder_embed_dim
        )
        res = TransformerDecoderLayer(args, no_encoder_attn=True)
        args.decoder_embed_dim = decoder_embed_dim

        return res

    def extract_features(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        full_context_alignment: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        lang_id: Optional[int] = None,
    ):
        """
        Similar to *forward* but only return features.

        Includes several features from "Jointly Learning to Align and
        Translate with Transformer Models" (Garg et al., EMNLP 2019).

        Args:
            full_context_alignment (bool, optional): don't apply
                auto-regressive mask to self-attention (default: False).
            alignment_layer (int, optional): return mean alignment over
                heads at this layer (default: last layer).
            alignment_heads (int, optional): only average alignment over
                this many heads (default: all heads).

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        if alignment_layer is None:
            alignment_layer = self.num_layers - 1

        # embed positions
        positions = (
            self.embed_positions(
                prev_output_tokens, incremental_state=incremental_state
            )
            if self.embed_positions is not None
            else None
        )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        bsz, seqlen = prev_output_tokens.size()

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.quant_noise is not None:
            x = self.quant_noise(x)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)

        x = self.dropout_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        if self.embed_lang is not None:
            lang_ids = prev_output_tokens.data.new_full((bsz,), lang_id)
            langemb = self.embed_lang(lang_ids)
            langemb = langemb.unsqueeze(0)
            repeat_vals = [x.shape[0] // langemb.shape[0]] + [-1] * (
                len(langemb.shape) - 1
            )
            x = torch.cat((x, langemb.expand(*repeat_vals)), dim=-1)

        sentemb = encoder_out["sentemb"][0]
        sentemb = sentemb.unsqueeze(0)

        repeat_vals = [x.shape[0] // sentemb.shape[0]] + [-1] * (len(sentemb.shape) - 1)
        x = torch.cat((x, sentemb.expand(*repeat_vals)), dim=-1)

        self_attn_padding_mask: Optional[Tensor] = None
        if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
            self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)

        # decoder layers
        attn: Optional[Tensor] = None
        inner_states: List[Optional[Tensor]] = [x]
        for idx, layer in enumerate(self.layers):
            if incremental_state is None and not full_context_alignment:
                self_attn_mask = self.buffered_future_mask(x)
            else:
                self_attn_mask = None

            x, layer_attn, _ = layer(
                x,
                None,
                None,
                incremental_state,
                self_attn_mask=self_attn_mask,
                self_attn_padding_mask=self_attn_padding_mask,
                need_attn=bool((idx == alignment_layer)),
                need_head_weights=bool((idx == alignment_layer)),
            )
            inner_states.append(x)
            if layer_attn is not None and idx == alignment_layer:
                attn = layer_attn.float().to(x)

        if attn is not None:
            if alignment_heads is not None:
                attn = attn[:alignment_heads]

            # average probabilities over heads
            attn = attn.mean(dim=0)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        return x, {"attn": [attn], "inner_states": inner_states}

    def forward(
        self,
        prev_output_tokens,
        encoder_out: Optional[Dict[str, List[Tensor]]] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
        lang_id: Optional[int] = None,
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`
            features_only (bool, optional): only return features without
                applying output layer (default: False).

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """

        assert lang_id is not None

        x, extra = self.extract_features(
            prev_output_tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
            lang_id=lang_id,
        )
        if not features_only:
            x = self.output_layer(x)
        return x, extra


@register_model_architecture("laser_transformer", "laser_transformer")
def base_laser_transformer_architecture(args):
    base_architecture(args)
    args.decoder_lang_embed_dim = getattr(args, "decoder_lang_embed_dim", 0)
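The sentence embedding in `LaserTransformerEncoder.forward` is a max-pool over time with padded positions masked to `-inf`. A standalone sketch of just that pooling step, with random tensors standing in for real encoder states:

```python
import torch

T, B, C = 5, 2, 8
x = torch.randn(T, B, C)                      # encoder states, T x B x C
src_pad = torch.zeros(B, T, dtype=torch.bool)
src_pad[1, 3:] = True                         # pretend the 2nd sentence has a padded tail

mask = src_pad.t().unsqueeze(-1)              # T x B x 1, as in the forward above
x = x.masked_fill(mask, float("-inf"))        # padded positions can never win the max
sentemb = x.max(dim=0)[0]                     # B x C sentence embeddings
assert sentemb.shape == (B, C)
```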
fairseq/examples/laser/laser_src/multitask_data_utils.py
ADDED
@@ -0,0 +1,143 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict

import numpy as np

from fairseq.data import BaseWrapperDataset, FairseqDataset, iterators


class MultiItr(object):
    def __init__(self, itr):
        self.itr = itr
        self._counts = [0 for x in itr]

    def __len__(self):
        return sum(len(itr) for itr in self.itr)

    def __iter__(self):
        return self

    def __next__(self):
        ratios = [count / len(itr) for count, itr in zip(self._counts, self.itr)]
        idx = ratios.index(min(ratios))
        self._counts[idx] += 1
        return next(self.itr[idx])


class MultidatasetEpochBatchIterator(iterators.EpochBatchIterating):
    """A wrapper around multiple epoch batch iterators."""

    def __init__(
        self,
        dataset,
        batch_sampler,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
    ):

        assert isinstance(dataset, OrderedDict)
        assert len(dataset)
        assert isinstance(dataset[next(iter(dataset))], FairseqDataset)

        self.iterators = []

        self.epoch = epoch
        for key, dt in dataset.items():
            epoch_iter = iterators.EpochBatchIterator(
                dataset=dt,
                collate_fn=dt.collater,
                batch_sampler=batch_sampler[key],
                seed=seed,
                num_shards=num_shards,
                shard_id=shard_id,
                num_workers=0,
                epoch=epoch,
            )
            self.iterators.append(epoch_iter)

    def __len__(self):
        return sum(len(itr) for itr in self.iterators)

    def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
        # `self.epoch += 1` should be handled by underlying `EpochBatchIterator`s.
        return MultiItr(
            [
                itr.next_epoch_itr(
                    shuffle=shuffle, fix_batches_to_gpus=fix_batches_to_gpus
                )
                for itr in self.iterators
            ]
        )

    def end_of_epoch(self):
        return all(itr.end_of_epoch() for itr in self.iterators)

    @property
    def next_epoch_idx(self):
        """Return the epoch index after *next_epoch_itr* is called."""

        epochs = [itr.next_epoch_idx for itr in self.iterators]
        self.epoch = epochs[0]
        assert all(epoch == self.epoch for epoch in epochs)

        return self.epoch

    @property
    def iterations_in_epoch(self):
        return sum(itr.iterations_in_epoch for itr in self.iterators)

    def state_dict(self):
        return {
            "iterators": [it.state_dict() for it in self.iterators],
            "epoch": self.epoch,
        }

    def load_state_dict(self, state_dict):
        self.epoch = state_dict["epoch"]
        for it, d in zip(self.iterators, state_dict["iterators"]):
            it.load_state_dict(d)


class MultitaskDatasetWrapper(BaseWrapperDataset):
    """A wrapper for a multitask dataset."""

    def __init__(self, dataset, target_language_id, sample=1.0, name=""):
        super().__init__(dataset)
        self.target_language_id = target_language_id
        self.sample = sample
        self.name = name

    def collater(self, *args, **kwargs):
        ans = self.dataset.collater(*args, **kwargs)
        if "net_input" in ans:
            ans["net_input"]["target_language_id"] = self.target_language_id
            ans["net_input"]["dataset_name"] = self.name
        return ans

    def num_tokens(self, *args, **kwargs):
        return self.dataset.num_tokens(*args, **kwargs)

    def ordered_indices(self, *args, **kwargs):
        indices = self.dataset.ordered_indices(*args, **kwargs)
        # Hacky solution for sampling
        size = int(self.sample * indices.shape[0])

        return indices.take(np.sort(np.random.permutation(indices.shape[0])[:size]))

    def size(self, index: int):
        return self.dataset.size(index)

    @property
    def supports_prefetch(self):
        """Whether this dataset supports prefetching."""
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        return self.dataset.prefetch(indices)
fairseq/examples/latent_depth/README.md
ADDED
@@ -0,0 +1,77 @@
# Deep Transformers with Latent Depth (Li et al., 2020)

[https://arxiv.org/abs/2009.13102](https://arxiv.org/abs/2009.13102).

## Introduction

We present a probabilistic framework to automatically learn which layer(s) to use by learning the posterior distributions of layer selection. As an extension of this framework, we propose a novel method to train one shared Transformer network for multilingual machine translation with different layer selection posteriors for each language pair.

## Training a multilingual model with latent depth

Below is an example of training with latent depth in the decoder for one-to-many (O2M) related languages. We use the same preprocessed (numberized and binarized) TED8 dataset as in [Balancing Training for Multilingual Neural Machine Translation (Wang et al., 2020)](https://github.com/cindyxinyiwang/multiDDS), which can be generated by [the script](https://github.com/cindyxinyiwang/multiDDS/blob/multiDDS/util_scripts/prepare_multilingual_data.sh) the authors provided.
```bash
lang_pairs_str="eng-aze,eng-bel,eng-ces,eng-glg,eng-por,eng-rus,eng-slk,eng-tur"
databin_dir=<path to binarized data>

fairseq-train ${databin_dir} \
  --user-dir examples/latent_depth/latent_depth_src \
  --lang-pairs "${lang_pairs_str}" \
  --arch multilingual_transformer_iwslt_de_en \
  --task multilingual_translation_latent_depth \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
  --share-encoders \
  --share-decoders \
  --decoder-langtok \
  --share-decoder-input-output-embed \
  --dropout 0.3 --attention-dropout 0.3 \
  --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
  --lr-scheduler inverse_sqrt --stop-min-lr 1e-9 --warmup-init-lr 1e-7 --warmup-updates 8000 \
  --max-tokens 4096 --update-freq 1 \
  --lr 0.0015 \
  --clip-norm 1.0 \
  --seed 2 \
  --ddp-backend=legacy_ddp \
  --encoder-layers 12 \
  --decoder-layers 24 \
  --decoder-latent-layer \
  --sparsity-weight 0.1 \
  --anneal-updates 5000 \
  --soft-update 500 \
  --target-layers 12 \
  --share-weight 0.1
```
## Inference command

```bash
lang_pairs_str="eng-aze,eng-bel,eng-ces,eng-glg,eng-por,eng-rus,eng-slk,eng-tur"
databin_dir=<path to binarized data>
model_path=<path to checkpoint>
src_lang=<source language to translate from>
tgt_lang=<target language to translate to>
gen_data=<name of data split, e.g. valid, test, etc>

fairseq-generate ${databin_dir} \
  --path ${model_path} \
  --task multilingual_translation_latent_depth \
  --decoder-latent-layer \
  --lang-pairs "${lang_pairs_str}" \
  -s ${src_lang} -t ${tgt_lang} \
  --gen-subset $gen_data \
  --scoring sacrebleu \
  --remove-bpe 'sentencepiece' \
  --lenpen 1.0 \
  --beam 5 \
  --decoder-langtok \
  --max-tokens 4096
```


## Citation
```bibtex
@article{li2020deep,
  title={Deep Transformers with Latent Depth},
  author={Li, Xian and Stickland, Asa Cooper and Tang, Yuqing and Kong, Xiang},
  journal={arXiv preprint arXiv:2009.13102},
  year={2020}
}
```
fairseq/examples/latent_depth/latent_depth_src/__init__.py
ADDED
@@ -0,0 +1,9 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import multilingual_translation_latent_depth  # noqa
from .loss import latent_depth  # noqa
from .models import latent_multilingual_transformer  # noqa
from .modules import latent_layers  # noqa
fairseq/examples/latent_depth/latent_depth_src/models/latent_multilingual_transformer.py
ADDED
@@ -0,0 +1,76 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.models import register_model, register_model_architecture
from fairseq.models.multilingual_transformer import MultilingualTransformerModel
from fairseq.models.transformer import (
    TransformerDecoder,
    TransformerEncoder,
    base_architecture,
)
from fairseq.utils import safe_hasattr

from .latent_transformer import LatentTransformerDecoder, LatentTransformerEncoder


@register_model("latent_multilingual_transformer")
class LatentMultilingualTransformerModel(MultilingualTransformerModel):
    """A variant of the standard multilingual Transformer model whose encoder
    and/or decoder supports latent depth, as in "Deep Transformers with Latent
    Depth" (https://arxiv.org/abs/2009.13102).
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        MultilingualTransformerModel.add_args(parser)
        parser.add_argument(
            '--soft-select',
            action='store_true',
            help='use soft samples in training and inference',
        )
        parser.add_argument(
            '--sampling-tau',
            type=float,
            default=5.,
            help='sampling temperature',
        )

    @classmethod
    def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
        if is_encoder:
            if safe_hasattr(args, "encoder_latent_layer") and args.encoder_latent_layer:
                return LatentTransformerEncoder(
                    args, lang_dict, embed_tokens, num_logits=len(langs)
                )
            else:
                return TransformerEncoder(args, lang_dict, embed_tokens)
        else:
            if safe_hasattr(args, "decoder_latent_layer") and args.decoder_latent_layer:
                return LatentTransformerDecoder(
                    args, lang_dict, embed_tokens, num_logits=len(langs)
                )
            else:
                return TransformerDecoder(args, lang_dict, embed_tokens)


@register_model_architecture(
    "latent_multilingual_transformer", "latent_multilingual_transformer"
)
def latent_multilingual_architecture(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.encoder_layers = getattr(args, "encoder_layers", 12)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
    args.decoder_layers = getattr(args, "decoder_layers", 24)
    args.share_encoders = getattr(args, "share_encoders", True)
    args.share_decoders = getattr(args, "share_decoders", True)
    args.share_encoder_embeddings = getattr(args, "share_encoder_embeddings", True)
    args.share_decoder_embeddings = getattr(args, "share_decoder_embeddings", True)

    base_architecture(args)
fairseq/examples/latent_depth/latent_depth_src/modules/__init__.py
ADDED
File without changes
fairseq/examples/latent_depth/latent_depth_src/modules/latent_layers.py
ADDED
@@ -0,0 +1,75 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn


class LayerSelect(nn.Module):
    """Compute samples (from a Gumbel-Sigmoid distribution) which are used as
    either (soft) weighting or (hard) selection of residual connections.
    https://arxiv.org/abs/2009.13102
    """
    def __init__(self, num_layers, num_logits, soft_select=False, sampling_tau=5.):
        super(LayerSelect, self).__init__()
        self.layer_logits = torch.nn.Parameter(
            torch.Tensor(num_logits, num_layers),
            requires_grad=True,
        )
        self.hard_select = not soft_select
        self.tau = sampling_tau
        self.detach_grad = False
        self.layer_samples = [None] * num_logits

    def sample(self, logit_idx):
        """To leverage the efficiency of distributed training, samples for all
        layers are computed at once for each logit_idx. Logits are parameters
        learnt independent of each other.

        Args:
            logit_idx: The index of logit parameters used for sampling.
        """
        assert logit_idx is not None
        self.samples = self._gumbel_sigmoid(
            self.layer_logits[logit_idx, :].detach()
            if self.detach_grad
            else self.layer_logits[logit_idx, :],
            dim=-1,
            tau=self.tau,
            hard=self.hard_select,
        )
        self.layer_samples[logit_idx] = self.samples

    def forward(self, i):
        sample = self.samples[i]
        return sample

    def _gumbel_sigmoid(
        self, logits, tau=1, hard=False, eps=1e-10, dim=-1, threshold=0.5
    ):
        # ~Gumbel(0,1)
        gumbels1 = (
            -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
            .exponential_()
            .log()
        )
        gumbels2 = (
            -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
            .exponential_()
            .log()
        )
        # Difference of two gumbels because we apply a sigmoid
        gumbels1 = (logits + gumbels1 - gumbels2) / tau
        y_soft = gumbels1.sigmoid()
        if hard:
            # Straight through.
            y_hard = torch.zeros_like(
                logits, memory_format=torch.legacy_contiguous_format
            ).masked_fill(y_soft > threshold, 1.0)
            ret = y_hard - y_soft.detach() + y_soft
        else:
            # Reparametrization trick.
            ret = y_soft
        return ret
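A rough usage sketch for `LayerSelect`: one Gumbel-Sigmoid draw per layer is taken for a given logit index (one index per language pair), and gate `i` then weights layer `i`. How exactly the gate enters the residual connection is an assumption here; the actual wiring lives in `latent_transformer.py`, which is not part of this excerpt.

```python
import torch

# Assumes the LayerSelect class above is in scope.
layer_select = LayerSelect(num_layers=4, num_logits=8)
torch.nn.init.normal_(layer_select.layer_logits)   # the real model learns these

layer_select.sample(logit_idx=3)      # gates for all 4 layers of language pair 3
residual = torch.randn(2, 5, 16)
layer_out = torch.randn(2, 5, 16)
gate = layer_select(0)                # scalar gate for layer 0 (0 or 1 when hard)
x = residual + gate * layer_out       # assumed way of gating the residual branch
```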
fairseq/examples/linformer/README.md
ADDED
@@ -0,0 +1,22 @@
# Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)

This example contains code to train Linformer models as described in our paper
[Linformer: Self-Attention with Linear Complexity](https://arxiv.org/abs/2006.04768).

## Training a new Linformer RoBERTa model

You can mostly follow the [RoBERTa pretraining README](/examples/roberta/README.pretraining.md),
updating your training command with `--user-dir examples/linformer/linformer_src --arch linformer_roberta_base`.

## Citation

If you use our work, please cite:

```bibtex
@article{wang2020linformer,
  title={Linformer: Self-Attention with Linear Complexity},
  author={Wang, Sinong and Li, Belinda and Khabsa, Madian and Fang, Han and Ma, Hao},
  journal={arXiv preprint arXiv:2006.04768},
  year={2020}
}
```
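For intuition, the `--compressed` ratio (defined in `linformer_roberta.py` below) controls how aggressively keys and values are projected along the sequence dimension before attention, which is what makes the attention cost linear in sequence length. A minimal sketch of that idea, assuming a single head and made-up sizes (this is not fairseq's `MultiheadLinearAttention`, just the core projection):

```python
import torch
import torch.nn as nn

n, d, compressed = 512, 64, 4
k = n // compressed                    # compressed sequence length

E = nn.Linear(n, k, bias=False)        # length-wise projection shared by K and V
q = torch.randn(1, n, d)
kv = torch.randn(1, n, d)

k_proj = E(kv.transpose(1, 2)).transpose(1, 2)                          # 1 x k x d
v_proj = E(kv.transpose(1, 2)).transpose(1, 2)                          # 1 x k x d
scores = torch.softmax(q @ k_proj.transpose(1, 2) / d ** 0.5, dim=-1)   # 1 x n x k
out = scores @ v_proj                  # 1 x n x d, cost O(n*k) instead of O(n^2)
assert out.shape == (1, n, d)
```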
fairseq/examples/linformer/linformer_src/__init__.py
ADDED
@@ -0,0 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .models import linformer_roberta  # noqa
fairseq/examples/linformer/linformer_src/models/linformer_roberta.py
ADDED
@@ -0,0 +1,120 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Linformer: Self-Attention with Linear Complexity
"""

import logging

import torch
from fairseq import utils
from fairseq.models import register_model, register_model_architecture
from fairseq.models.roberta import (
    init_bert_params,
    roberta_base_architecture,
    roberta_large_architecture,
    RobertaEncoder,
    RobertaModel,
)
from fairseq.utils import safe_hasattr

from ..modules.linformer_sentence_encoder import LinformerTransformerEncoder


logger = logging.getLogger(__name__)


@register_model("linformer_roberta")
class LinformerModel(RobertaModel):
    @staticmethod
    def add_args(parser):
        RobertaModel.add_args(parser)

        # add args for Linformer
        parser.add_argument(
            "--compressed", type=int, help="compressed ratio of sequence length"
        )
        parser.add_argument(
            "--shared-kv-compressed",
            type=int,
            help="share compressed matrix between k and v, in each layer",
        )
        parser.add_argument(
            "--shared-layer-kv-compressed",
            type=int,
            help="share compressed matrix between k and v and across all layers",
        )
        parser.add_argument(
            "--freeze-compress",
            type=int,
            help="freeze the parameters in compressed layer",
        )

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""

        # make sure all arguments are present
        base_architecture(args)

        if not safe_hasattr(args, "max_positions"):
            args.max_positions = args.tokens_per_sample

        encoder = LinformerEncoder(args, task.source_dictionary)
        return cls(args, encoder)


class LinformerEncoder(RobertaEncoder):
    """Linformer encoder."""

    def __init__(self, args, dictionary):
        super().__init__(args, dictionary)
        self.register_buffer("version", torch.tensor(2))

    def build_encoder(self, args, dictionary, embed_tokens):
        encoder = LinformerTransformerEncoder(args, dictionary, embed_tokens)
        encoder.apply(init_bert_params)
        return encoder

    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)
        prefix = name + "." if name != "" else ""

        # some old checkpoints had weight sharing implemented incorrectly
        # (note: this was correct in the original paper code)
        if utils.item(state_dict.get(f"{prefix}version", torch.tensor(1))) < 2:
            state_dict[f"{prefix}version"] = torch.tensor(1)
            # check if input embeddings and output embeddings were tied
            if not torch.allclose(
                state_dict[f"{prefix}sentence_encoder.embed_tokens.weight"],
                state_dict[f"{prefix}lm_head.weight"],
            ):
                # they weren't tied, re-init the LM head without weight sharing
                self.lm_head = self.build_lm_head(
                    embed_dim=self.args.encoder_embed_dim,
                    output_dim=len(self.dictionary),
                    activation_fn=self.args.activation_fn,
                    weight=None,  # don't share weights
                )


@register_model_architecture("linformer_roberta", "linformer_roberta")
def base_architecture(args):
    args.compressed = getattr(args, "compressed", 4)
    args.shared_kv_compressed = getattr(args, "shared_kv_compressed", 0)
    args.shared_layer_kv_compressed = getattr(args, "shared_layer_kv_compressed", 0)
    args.freeze_compress = getattr(args, "freeze_compress", 0)
    roberta_base_architecture(args)


@register_model_architecture("linformer_roberta", "linformer_roberta_base")
def linformer_roberta_base_architecture(args):
    base_architecture(args)


@register_model_architecture("linformer_roberta", "linformer_roberta_large")
def linformer_roberta_large_architecture(args):
    roberta_large_architecture(args)
    base_architecture(args)
fairseq/examples/linformer/linformer_src/modules/__init__.py
ADDED
File without changes
fairseq/examples/linformer/linformer_src/modules/linformer_sentence_encoder.py
ADDED
@@ -0,0 +1,54 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch.nn as nn
from fairseq.models.transformer import TransformerEncoder

from .linformer_sentence_encoder_layer import LinformerTransformerEncoderLayer


class LinformerTransformerEncoder(TransformerEncoder):
    """
    Implementation for a Bi-directional Linformer based Sentence Encoder used
    in BERT/XLM style pre-trained models.

    This first computes the token embedding using the token embedding matrix,
    position embeddings (if specified) and segment embeddings
    (if specified). After applying the specified number of
    LinformerEncoderLayers, it outputs all the internal states of the
    encoder as well as the final representation associated with the first
    token (usually CLS token).

    Input:
        - tokens: B x T matrix representing sentences
        - segment_labels: B x T matrix representing segment label for tokens

    Output:
        - a tuple of the following:
            - a list of internal model states used to compute the
              predictions where each tensor has shape T x B x C
            - sentence representation associated with first input token
              in format B x C.
    """

    def __init__(self, args, dictionary, embed_tokens):
        self.compress_layer = None
        super().__init__(args, dictionary, embed_tokens)

    def build_encoder_layer(self, args):
        if self.args.shared_layer_kv_compressed == 1 and self.compress_layer is None:
            compress_layer = nn.Linear(
                self.args.max_positions,
                self.args.max_positions // self.args.compressed,
            )
            # initialize parameters for compressed layer
            nn.init.xavier_uniform_(compress_layer.weight, gain=1 / math.sqrt(2))
            if self.args.freeze_compress == 1:
                compress_layer.weight.requires_grad = False
            self.compress_layer = compress_layer

        return LinformerTransformerEncoderLayer(args, self.compress_layer)
fairseq/examples/linformer/linformer_src/modules/linformer_sentence_encoder_layer.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from fairseq import utils
|
8 |
+
from fairseq.modules import TransformerEncoderLayer
|
9 |
+
|
10 |
+
from .multihead_linear_attention import MultiheadLinearAttention
|
11 |
+
|
12 |
+
|
13 |
+
class LinformerTransformerEncoderLayer(TransformerEncoderLayer):
|
14 |
+
"""
|
15 |
+
Implements a Linformer Encoder Layer used in BERT/XLM style pre-trained
|
16 |
+
models.
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, args, shared_compress_layer):
|
20 |
+
# wrap in a list so it's not automatically registered by PyTorch
|
21 |
+
self.shared_compress_layer = [shared_compress_layer]
|
22 |
+
|
23 |
+
super().__init__(args)
|
24 |
+
|
25 |
+
self.register_buffer("version", torch.tensor(2))
|
26 |
+
|
27 |
+
def build_self_attention(self, embed_dim, args):
|
28 |
+
return MultiheadLinearAttention(
|
29 |
+
embed_dim,
|
30 |
+
args.encoder_attention_heads,
|
31 |
+
dropout=args.dropout,
|
32 |
+
self_attention=True,
|
33 |
+
q_noise=args.quant_noise_pq,
|
34 |
+
qn_block_size=args.quant_noise_pq_block_size,
|
35 |
+
compressed=args.compressed,
|
36 |
+
max_seq_len=args.max_positions,
|
37 |
+
shared_kv_compressed=args.shared_kv_compressed,
|
38 |
+
shared_compress_layer=self.shared_compress_layer[0],
|
39 |
+
freeze_compress=args.freeze_compress,
|
40 |
+
)
|
41 |
+
|
42 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
43 |
+
super().upgrade_state_dict_named(state_dict, name)
|
44 |
+
prefix = name + "." if name != "" else ""
|
45 |
+
|
46 |
+
# some old checkpoints had weight sharing implemented incorrectly
|
47 |
+
# (note: this was correct in the original paper code)
|
48 |
+
if utils.item(state_dict.get(f"{prefix}version", torch.tensor(1))) < 2:
|
49 |
+
state_dict[f"{prefix}version"] = torch.tensor(1)
|
50 |
+
# check compression layer sharing
|
51 |
+
if f"{prefix}shared_compress_layer.weight" in state_dict:
|
52 |
+
# reinitialize block without sharing compression layer to match
|
53 |
+
# old behavior
|
54 |
+
self.shared_compress_layer = [
|
55 |
+
torch.nn.Linear(
|
56 |
+
self.shared_compress_layer[0].weight.size(1),
|
57 |
+
self.shared_compress_layer[0].weight.size(0),
|
58 |
+
)
|
59 |
+
]
|
60 |
+
self.self_attn = self.build_self_attention(self.embed_dim, self.args)
|
61 |
+
# delete shared_compress_layer, since it's already copied to
|
62 |
+
# self_attn.compress_k.weight
|
63 |
+
del state_dict[f"{prefix}shared_compress_layer.weight"]
|
64 |
+
if f"{prefix}shared_compress_layer.bias" in state_dict:
|
65 |
+
del state_dict[f"{prefix}shared_compress_layer.bias"]
|
fairseq/examples/linformer/linformer_src/modules/multihead_linear_attention.py
ADDED
@@ -0,0 +1,481 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
from typing import Dict, Optional, Tuple
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from fairseq import utils
|
12 |
+
from fairseq.incremental_decoding_utils import with_incremental_state
|
13 |
+
from fairseq.modules.quant_noise import quant_noise
|
14 |
+
from torch import Tensor, nn
|
15 |
+
from torch.nn import Parameter
|
16 |
+
|
17 |
+
|
18 |
+
@with_incremental_state
|
19 |
+
class MultiheadLinearAttention(nn.Module):
|
20 |
+
"""Multi-headed linformer attention.
|
21 |
+
|
22 |
+
Projects the key and values down to the compressed dimension, before computing self-attention.
|
23 |
+
|
24 |
+
See "Linformer: Self-Attention with Linear Complexity" for more details.
|
25 |
+
"""
|
26 |
+
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
embed_dim,
|
30 |
+
num_heads,
|
31 |
+
kdim=None,
|
32 |
+
vdim=None,
|
33 |
+
dropout=0.0,
|
34 |
+
bias=True,
|
35 |
+
add_bias_kv=False,
|
36 |
+
add_zero_attn=False,
|
37 |
+
self_attention=False,
|
38 |
+
encoder_decoder_attention=False,
|
39 |
+
q_noise=0.0,
|
40 |
+
qn_block_size=8,
|
41 |
+
compressed=1,
|
42 |
+
max_seq_len=256,
|
43 |
+
shared_kv_compressed=0,
|
44 |
+
shared_compress_layer=None,
|
45 |
+
freeze_compress=0,
|
46 |
+
):
|
47 |
+
super().__init__()
|
48 |
+
self.embed_dim = embed_dim
|
49 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
50 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
51 |
+
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
52 |
+
|
53 |
+
self.num_heads = num_heads
|
54 |
+
self.dropout = dropout
|
55 |
+
self.head_dim = embed_dim // num_heads
|
56 |
+
assert (
|
57 |
+
self.head_dim * num_heads == self.embed_dim
|
58 |
+
), "embed_dim must be divisible by num_heads"
|
59 |
+
self.scaling = self.head_dim ** -0.5
|
60 |
+
|
61 |
+
self.self_attention = self_attention
|
62 |
+
self.encoder_decoder_attention = encoder_decoder_attention
|
63 |
+
|
64 |
+
assert not self.self_attention or self.qkv_same_dim, (
|
65 |
+
"Self-attention requires query, key and " "value to be of the same size"
|
66 |
+
)
|
67 |
+
|
68 |
+
self.k_proj = quant_noise(
|
69 |
+
nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
|
70 |
+
)
|
71 |
+
self.v_proj = quant_noise(
|
72 |
+
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
|
73 |
+
)
|
74 |
+
self.q_proj = quant_noise(
|
75 |
+
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
76 |
+
)
|
77 |
+
|
78 |
+
# used for compress sequence to subsequence
|
79 |
+
if shared_compress_layer is None:
|
80 |
+
self.compress_seq_len = max_seq_len // compressed
|
81 |
+
self.compress_k = nn.Linear(max_seq_len, self.compress_seq_len, bias=False)
|
82 |
+
if shared_kv_compressed == 0:
|
83 |
+
self.compress_v = nn.Linear(
|
84 |
+
max_seq_len, self.compress_seq_len, bias=False
|
85 |
+
)
|
86 |
+
self.layerwise_sharing = False
|
87 |
+
else:
|
88 |
+
self.compress_k = shared_compress_layer
|
89 |
+
if shared_kv_compressed == 0:
|
90 |
+
self.compress_v = shared_compress_layer
|
91 |
+
self.layerwise_sharing = True
|
92 |
+
self.shared_kv_compressed = shared_kv_compressed
|
93 |
+
|
94 |
+
self.out_proj = quant_noise(
|
95 |
+
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
96 |
+
)
|
97 |
+
|
98 |
+
if add_bias_kv:
|
99 |
+
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
100 |
+
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
101 |
+
else:
|
102 |
+
self.bias_k = self.bias_v = None
|
103 |
+
|
104 |
+
self.add_zero_attn = add_zero_attn
|
105 |
+
|
106 |
+
self.reset_parameters()
|
107 |
+
|
108 |
+
if freeze_compress == 1:
|
109 |
+
self.compress_k.weight.requires_grad = False
|
110 |
+
if shared_kv_compressed == 0:
|
111 |
+
self.compress_v.weight.requires_grad = False
|
112 |
+
|
113 |
+
self.onnx_trace = False
|
114 |
+
|
115 |
+
def prepare_for_onnx_export_(self):
|
116 |
+
self.onnx_trace = True
|
117 |
+
|
118 |
+
def reset_parameters(self):
|
119 |
+
if self.qkv_same_dim:
|
120 |
+
# Empirically observed the convergence to be much better with
|
121 |
+
# the scaled initialization
|
122 |
+
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
123 |
+
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
124 |
+
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
125 |
+
if (
|
126 |
+
not self.layerwise_sharing
|
127 |
+
): # otherwise, we already initialize the parameters
|
128 |
+
nn.init.xavier_uniform_(self.compress_k.weight, gain=1 / math.sqrt(2))
|
129 |
+
if self.shared_kv_compressed == 0:
|
130 |
+
nn.init.xavier_uniform_(
|
131 |
+
self.compress_v.weight, gain=1 / math.sqrt(2)
|
132 |
+
)
|
133 |
+
else:
|
134 |
+
nn.init.xavier_uniform_(self.k_proj.weight)
|
135 |
+
nn.init.xavier_uniform_(self.v_proj.weight)
|
136 |
+
nn.init.xavier_uniform_(self.q_proj.weight)
|
137 |
+
if (
|
138 |
+
not self.layerwise_sharing
|
139 |
+
): # otherwise, we already initialize the parameters
|
140 |
+
nn.init.xavier_uniform_(self.compress_k.weight)
|
141 |
+
if self.shared_kv_compressed == 0:
|
142 |
+
nn.init.xavier_uniform_(self.compress_v.weight)
|
143 |
+
|
144 |
+
nn.init.xavier_uniform_(self.out_proj.weight)
|
145 |
+
if self.out_proj.bias is not None:
|
146 |
+
nn.init.constant_(self.out_proj.bias, 0.0)
|
147 |
+
if self.bias_k is not None:
|
148 |
+
nn.init.xavier_normal_(self.bias_k)
|
149 |
+
if self.bias_v is not None:
|
150 |
+
nn.init.xavier_normal_(self.bias_v)
|
151 |
+
|
152 |
+
def forward(
|
153 |
+
self,
|
154 |
+
query,
|
155 |
+
key: Optional[Tensor],
|
156 |
+
value: Optional[Tensor],
|
157 |
+
key_padding_mask: Optional[Tensor] = None,
|
158 |
+
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
159 |
+
need_weights: bool = True,
|
160 |
+
static_kv: bool = False,
|
161 |
+
attn_mask: Optional[Tensor] = None,
|
162 |
+
before_softmax: bool = False,
|
163 |
+
need_head_weights: bool = False,
|
164 |
+
) -> Tuple[Tensor, Optional[Tensor]]:
|
165 |
+
"""Input shape: Time x Batch x Channel
|
166 |
+
|
167 |
+
Args:
|
168 |
+
key_padding_mask (ByteTensor, optional): mask to exclude
|
169 |
+
keys that are pads, of shape `(batch, src_len)`, where
|
170 |
+
padding elements are indicated by 1s.
|
171 |
+
need_weights (bool, optional): return the attention weights,
|
172 |
+
averaged over heads (default: False).
|
173 |
+
attn_mask (ByteTensor, optional): typically used to
|
174 |
+
implement causal attention, where the mask prevents the
|
175 |
+
attention from looking forward in time (default: None).
|
176 |
+
before_softmax (bool, optional): return the raw attention
|
177 |
+
weights and values before the attention softmax.
|
178 |
+
need_head_weights (bool, optional): return the attention
|
179 |
+
weights for each head. Implies *need_weights*. Default:
|
180 |
+
return the average attention weights over all heads.
|
181 |
+
"""
|
182 |
+
if need_head_weights:
|
183 |
+
need_weights = True
|
184 |
+
|
185 |
+
tgt_len, bsz, embed_dim = query.size()
|
186 |
+
assert embed_dim == self.embed_dim
|
187 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
188 |
+
|
189 |
+
if incremental_state is not None:
|
190 |
+
saved_state = self._get_input_buffer(incremental_state)
|
191 |
+
if saved_state is not None and "prev_key" in saved_state:
|
192 |
+
# previous time steps are cached - no need to recompute
|
193 |
+
# key and value if they are static
|
194 |
+
if static_kv:
|
195 |
+
assert self.encoder_decoder_attention and not self.self_attention
|
196 |
+
key = value = None
|
197 |
+
else:
|
198 |
+
saved_state = None
|
199 |
+
|
200 |
+
if self.self_attention:
|
201 |
+
q = self.q_proj(query)
|
202 |
+
|
203 |
+
k_input = query.permute(1, 2, 0).contiguous() # B * C * T
|
204 |
+
k_input = (
|
205 |
+
F.linear(k_input, self.compress_k.weight[:, 0:tgt_len])
|
206 |
+
.permute(2, 0, 1)
|
207 |
+
.contiguous()
|
208 |
+
)
|
209 |
+
k = self.k_proj(k_input)
|
210 |
+
|
211 |
+
v_input = query.permute(1, 2, 0).contiguous() # B * C * T
|
212 |
+
if self.shared_kv_compressed == 0:
|
213 |
+
v_input = (
|
214 |
+
F.linear(v_input, self.compress_v.weight[:, 0:tgt_len])
|
215 |
+
.permute(2, 0, 1)
|
216 |
+
.contiguous()
|
217 |
+
)
|
218 |
+
if self.shared_kv_compressed == 1: # use shared kv compressed linear layer
|
219 |
+
v_input = (
|
220 |
+
F.linear(v_input, self.compress_k.weight[:, 0:tgt_len])
|
221 |
+
.permute(2, 0, 1)
|
222 |
+
.contiguous()
|
223 |
+
)
|
224 |
+
v = self.v_proj(v_input)
|
225 |
+
elif self.encoder_decoder_attention:
|
226 |
+
# encoder-decoder attention
|
227 |
+
q = self.q_proj(query)
|
228 |
+
if key is None:
|
229 |
+
assert value is None
|
230 |
+
k = v = None
|
231 |
+
else:
|
232 |
+
k = self.k_proj(key)
|
233 |
+
v = self.v_proj(key)
|
234 |
+
|
235 |
+
else:
|
236 |
+
assert key is not None and value is not None
|
237 |
+
q = self.q_proj(query)
|
238 |
+
k = self.k_proj(key)
|
239 |
+
v = self.v_proj(value)
|
240 |
+
q *= self.scaling
|
241 |
+
|
242 |
+
if self.bias_k is not None:
|
243 |
+
assert self.bias_v is not None
|
244 |
+
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
|
245 |
+
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
|
246 |
+
if attn_mask is not None:
|
247 |
+
attn_mask = torch.cat(
|
248 |
+
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
249 |
+
)
|
250 |
+
if key_padding_mask is not None:
|
251 |
+
key_padding_mask = torch.cat(
|
252 |
+
[
|
253 |
+
key_padding_mask,
|
254 |
+
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
|
255 |
+
],
|
256 |
+
dim=1,
|
257 |
+
)
|
258 |
+
|
259 |
+
q = (
|
260 |
+
q.contiguous()
|
261 |
+
.view(tgt_len, bsz * self.num_heads, self.head_dim)
|
262 |
+
.transpose(0, 1)
|
263 |
+
)
|
264 |
+
if k is not None:
|
265 |
+
k = (
|
266 |
+
k.contiguous()
|
267 |
+
.view(-1, bsz * self.num_heads, self.head_dim)
|
268 |
+
.transpose(0, 1)
|
269 |
+
)
|
270 |
+
if v is not None:
|
271 |
+
v = (
|
272 |
+
v.contiguous()
|
273 |
+
.view(-1, bsz * self.num_heads, self.head_dim)
|
274 |
+
.transpose(0, 1)
|
275 |
+
)
|
276 |
+
|
277 |
+
if saved_state is not None:
|
278 |
+
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
|
279 |
+
if "prev_key" in saved_state:
|
280 |
+
_prev_key = saved_state["prev_key"]
|
281 |
+
assert _prev_key is not None
|
282 |
+
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
|
283 |
+
if static_kv:
|
284 |
+
k = prev_key
|
285 |
+
else:
|
286 |
+
assert k is not None
|
287 |
+
k = torch.cat([prev_key, k], dim=1)
|
288 |
+
if "prev_value" in saved_state:
|
289 |
+
_prev_value = saved_state["prev_value"]
|
290 |
+
assert _prev_value is not None
|
291 |
+
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
|
292 |
+
if static_kv:
|
293 |
+
v = prev_value
|
294 |
+
else:
|
295 |
+
assert v is not None
|
296 |
+
v = torch.cat([prev_value, v], dim=1)
|
297 |
+
prev_key_padding_mask: Optional[Tensor] = None
|
298 |
+
if "prev_key_padding_mask" in saved_state:
|
299 |
+
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
|
300 |
+
assert k is not None and v is not None
|
301 |
+
key_padding_mask = MultiheadLinearAttention._append_prev_key_padding_mask(
|
302 |
+
key_padding_mask=key_padding_mask,
|
303 |
+
prev_key_padding_mask=prev_key_padding_mask,
|
304 |
+
batch_size=bsz,
|
305 |
+
src_len=k.size(1),
|
306 |
+
static_kv=static_kv,
|
307 |
+
)
|
308 |
+
|
309 |
+
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
|
310 |
+
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
|
311 |
+
saved_state["prev_key_padding_mask"] = key_padding_mask
|
312 |
+
# In this branch incremental_state is never None
|
313 |
+
assert incremental_state is not None
|
314 |
+
incremental_state = self._set_input_buffer(incremental_state, saved_state)
|
315 |
+
assert k is not None
|
316 |
+
src_len = k.size(1)
|
317 |
+
|
318 |
+
if self.add_zero_attn:
|
319 |
+
assert v is not None
|
320 |
+
src_len += 1
|
321 |
+
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
|
322 |
+
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
|
323 |
+
if attn_mask is not None:
|
324 |
+
attn_mask = torch.cat(
|
325 |
+
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
|
326 |
+
)
|
327 |
+
|
328 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
329 |
+
attn_weights = MultiheadLinearAttention.apply_sparse_mask(
|
330 |
+
attn_weights, tgt_len, src_len, bsz
|
331 |
+
)
|
332 |
+
|
333 |
+
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
334 |
+
|
335 |
+
if attn_mask is not None:
|
336 |
+
attn_mask = attn_mask.unsqueeze(0)
|
337 |
+
if self.onnx_trace:
|
338 |
+
attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
|
339 |
+
attn_weights += attn_mask
|
340 |
+
|
341 |
+
if before_softmax:
|
342 |
+
return attn_weights, v
|
343 |
+
|
344 |
+
attn_weights_float = utils.softmax(
|
345 |
+
attn_weights, dim=-1, onnx_trace=self.onnx_trace
|
346 |
+
)
|
347 |
+
attn_weights = attn_weights_float.type_as(attn_weights)
|
348 |
+
attn_probs = F.dropout(
|
349 |
+
attn_weights,
|
350 |
+
p=self.dropout,
|
351 |
+
training=self.training,
|
352 |
+
)
|
353 |
+
assert v is not None
|
354 |
+
attn = torch.bmm(attn_probs, v)
|
355 |
+
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
356 |
+
if self.onnx_trace and attn.size(1) == 1:
|
357 |
+
# when ONNX tracing a single decoder step (sequence length == 1)
|
358 |
+
# the transpose is a no-op copy before view, thus unnecessary
|
359 |
+
attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
|
360 |
+
else:
|
361 |
+
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
362 |
+
attn = self.out_proj(attn)
|
363 |
+
attn_weights: Optional[Tensor] = None
|
364 |
+
if need_weights:
|
365 |
+
attn_weights = attn_weights_float.view(
|
366 |
+
bsz, self.num_heads, tgt_len, src_len
|
367 |
+
).transpose(1, 0)
|
368 |
+
if not need_head_weights:
|
369 |
+
# average attention weights over heads
|
370 |
+
attn_weights = attn_weights.mean(dim=0)
|
371 |
+
|
372 |
+
return attn, attn_weights
|
373 |
+
|
374 |
+
@staticmethod
|
375 |
+
def _append_prev_key_padding_mask(
|
376 |
+
key_padding_mask: Optional[Tensor],
|
377 |
+
prev_key_padding_mask: Optional[Tensor],
|
378 |
+
batch_size: int,
|
379 |
+
src_len: int,
|
380 |
+
static_kv: bool,
|
381 |
+
) -> Optional[Tensor]:
|
382 |
+
# saved key padding masks have shape (bsz, seq_len)
|
383 |
+
if prev_key_padding_mask is not None and static_kv:
|
384 |
+
new_key_padding_mask = prev_key_padding_mask
|
385 |
+
elif prev_key_padding_mask is not None and key_padding_mask is not None:
|
386 |
+
new_key_padding_mask = torch.cat(
|
387 |
+
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
|
388 |
+
)
|
389 |
+
# During incremental decoding, as the padding token enters and
|
390 |
+
# leaves the frame, there will be a time when prev or current
|
391 |
+
# is None
|
392 |
+
elif prev_key_padding_mask is not None:
|
393 |
+
filler = torch.zeros(
|
394 |
+
(batch_size, src_len - prev_key_padding_mask.size(1)),
|
395 |
+
device=prev_key_padding_mask.device,
|
396 |
+
)
|
397 |
+
new_key_padding_mask = torch.cat(
|
398 |
+
[prev_key_padding_mask.float(), filler.float()], dim=1
|
399 |
+
)
|
400 |
+
elif key_padding_mask is not None:
|
401 |
+
filler = torch.zeros(
|
402 |
+
(batch_size, src_len - key_padding_mask.size(1)),
|
403 |
+
device=key_padding_mask.device,
|
404 |
+
)
|
405 |
+
new_key_padding_mask = torch.cat(
|
406 |
+
[filler.float(), key_padding_mask.float()], dim=1
|
407 |
+
)
|
408 |
+
else:
|
409 |
+
new_key_padding_mask = prev_key_padding_mask
|
410 |
+
return new_key_padding_mask
|
411 |
+
|
412 |
+
@torch.jit.export
|
413 |
+
def reorder_incremental_state(
|
414 |
+
self,
|
415 |
+
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
|
416 |
+
new_order: Tensor,
|
417 |
+
):
|
418 |
+
"""Reorder buffered internal state (for incremental generation)."""
|
419 |
+
input_buffer = self._get_input_buffer(incremental_state)
|
420 |
+
if input_buffer is not None:
|
421 |
+
for k in input_buffer.keys():
|
422 |
+
input_buffer_k = input_buffer[k]
|
423 |
+
if input_buffer_k is not None:
|
424 |
+
if self.encoder_decoder_attention and input_buffer_k.size(
|
425 |
+
0
|
426 |
+
) == new_order.size(0):
|
427 |
+
break
|
428 |
+
input_buffer[k] = input_buffer_k.index_select(0, new_order)
|
429 |
+
incremental_state = self._set_input_buffer(incremental_state, input_buffer)
|
430 |
+
return incremental_state
|
431 |
+
|
432 |
+
def _get_input_buffer(
|
433 |
+
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
|
434 |
+
) -> Dict[str, Optional[Tensor]]:
|
435 |
+
result = self.get_incremental_state(incremental_state, "attn_state")
|
436 |
+
if result is not None:
|
437 |
+
return result
|
438 |
+
else:
|
439 |
+
empty_result: Dict[str, Optional[Tensor]] = {}
|
440 |
+
return empty_result
|
441 |
+
|
442 |
+
def _set_input_buffer(
|
443 |
+
self,
|
444 |
+
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
|
445 |
+
buffer: Dict[str, Optional[Tensor]],
|
446 |
+
):
|
447 |
+
return self.set_incremental_state(incremental_state, "attn_state", buffer)
|
448 |
+
|
449 |
+
def apply_sparse_mask(attn_weights, tgt_len: int, src_len: int, bsz: int):
|
450 |
+
return attn_weights
|
451 |
+
|
452 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
453 |
+
prefix = name + "." if name != "" else ""
|
454 |
+
items_to_add = {}
|
455 |
+
keys_to_remove = []
|
456 |
+
for k in state_dict.keys():
|
457 |
+
if k.endswith(prefix + "in_proj_weight"):
|
458 |
+
# in_proj_weight used to be q + k + v with same dimensions
|
459 |
+
dim = int(state_dict[k].shape[0] / 3)
|
460 |
+
items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
|
461 |
+
items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
|
462 |
+
items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
|
463 |
+
|
464 |
+
keys_to_remove.append(k)
|
465 |
+
|
466 |
+
k_bias = prefix + "in_proj_bias"
|
467 |
+
if k_bias in state_dict.keys():
|
468 |
+
dim = int(state_dict[k].shape[0] / 3)
|
469 |
+
items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
|
470 |
+
items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
|
471 |
+
dim : 2 * dim
|
472 |
+
]
|
473 |
+
items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
|
474 |
+
|
475 |
+
keys_to_remove.append(prefix + "in_proj_bias")
|
476 |
+
|
477 |
+
for k in keys_to_remove:
|
478 |
+
del state_dict[k]
|
479 |
+
|
480 |
+
for key, value in items_to_add.items():
|
481 |
+
state_dict[key] = value
|
fairseq/examples/m2m_100/README.md
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Beyond English-Centric Multilingual Machine Translation
|
2 |
+
|
3 |
+
## Introduction
|
4 |
+
In this work, we create a true Many-to-Many multilingual translation model that can translate directly between any pair of 100 languages. Our focus on non-English-Centric models brings gains of more than 10 BLEU when directly translating between non-English directions while performing competitively with the best single systems of WMT.
|
5 |
+
|
6 |
+
If you are new to using fairseq, read the following walkthrough. Otherwise, skip to the sections below.
|
7 |
+
|
8 |
+
0. **Generation Data**
|
9 |
+
|
10 |
+
To download the generation data, follow the below commands. Note that all datasets need to be detokenized *before* applying SPM in the data preprocessing step. If you use these evaluation datasets, please cite their associated papers.
|
11 |
+
```bash
|
12 |
+
# WMT - use sacrebleu, example here:
|
13 |
+
sacrebleu -t wmt14 -l fr-en --echo src > wmt.test.fr-en.fr
|
14 |
+
sacrebleu -t wmt14 -l fr-en --echo ref > wmt.test.fr-en.en
|
15 |
+
|
16 |
+
# WAT
|
17 |
+
wget http://lotus.kuee.kyoto-u.ac.jp/WAT/my-en-data/wat2020.my-en.zip
|
18 |
+
unzip wat2020.my-en.zip
|
19 |
+
|
20 |
+
# FLORES
|
21 |
+
# download from: https://github.com/facebookresearch/flores
|
22 |
+
|
23 |
+
# TED - need to detokenize with Moses!
|
24 |
+
# from: https://github.com/neulab/word-embeddings-for-nmt
|
25 |
+
wget http://phontron.com/data/ted_talks.tar.gz
|
26 |
+
|
27 |
+
# Autshumato
|
28 |
+
# request to download: https://repo.sadilar.org/handle/20.500.12185/397
|
29 |
+
|
30 |
+
# Tatoeba Challenge
|
31 |
+
# available here: https://github.com/Helsinki-NLP/Tatoeba-Challenge
|
32 |
+
```
|
33 |
+
|
34 |
+
1. **Training Data**
|
35 |
+
|
36 |
+
To produce the training data, we use a combination of [CCMatrix](https://arxiv.org/abs/1911.04944) and [CCAligned](https://arxiv.org/abs/1911.06154). Check out the instructions [here](https://github.com/facebookresearch/LASER/tree/master/tasks/CCMatrix) to download the raw data.
|
37 |
+
|
38 |
+
2. **Preprocess Data**
|
39 |
+
|
40 |
+
After downloading raw data, you will need to postprocess the data, then apply SPM, then binarize. Note that it is very important you run the postprocessing script, because this removes any instance of the evaluation data in the mined training data.
|
41 |
+
|
42 |
+
```bash
|
43 |
+
# preprocess data
|
44 |
+
|
45 |
+
# remove sentences with more than 50% punctuation
|
46 |
+
python /path/to/fairseq/examples/m2m_100/process_data/remove_too_much_punc.py
|
47 |
+
|
48 |
+
# deduplicate training data
|
49 |
+
paste /path/to/datadir/train.$src /path/to/datadir/train.$tgt | awk '!x[$0]++' > /path/to/datadir/train.dedup
|
50 |
+
echo "keeping $(wc -l /path/to/datadir/train.dedup) bitext out of $(wc -l /path/to/datadir/train.$src)"
|
51 |
+
cut -f1 /path/to/datadir/train.dedup > /path/to/datadir/train.$src
|
52 |
+
cut -f2 /path/to/datadir/train.dedup > /path/to/datadir/train.$tgt
|
53 |
+
|
54 |
+
# remove all instances of evaluation data from the training data
|
55 |
+
python /path/to/fairseq/examples/m2m_100/process_data/dedup_data.py
|
56 |
+
|
57 |
+
# frequency cleaning
|
58 |
+
wget https://dl.fbaipublicfiles.com/m2m_100/histograms.tar.gz
|
59 |
+
tar -xvzf histograms.tar.gz
|
60 |
+
python /path/to/fairseq/examples/m2m_100/process_data/clean_histogram.py --src $src --tgt $tgt --src-file /path/to/source/file --tgt-file /path/to/output/file --src-output-file source_output.$src --tgt-output-file target_output.$tgt --histograms /path/to/histograms
|
61 |
+
|
62 |
+
# apply SPM
|
63 |
+
wget https://dl.fbaipublicfiles.com/m2m_100/spm.128k.model
|
64 |
+
python /path/to/fairseq/scripts/spm_encode.py \
|
65 |
+
--model spm.128k.model \
|
66 |
+
--output_format=piece \
|
67 |
+
--inputs=/path/to/input/file/here \
|
68 |
+
--outputs=/path/to/output/file/here
|
69 |
+
|
70 |
+
# length ratio cleaning
|
71 |
+
perl mosesdecoder/scripts/training/clean-corpus-n.perl --ratio 3 /path/to/training/data/train.spm.$src-$tgt $src $tgt /path/to/output/directory/train.spm.$src-$tgt 1 250
|
72 |
+
|
73 |
+
# binarize data
|
74 |
+
wget https://dl.fbaipublicfiles.com/m2m_100/data_dict.128k.txt
|
75 |
+
fairseq-preprocess \
|
76 |
+
--source-lang $src --target-lang $tgt \
|
77 |
+
--testpref spm.$src.$tgt \
|
78 |
+
--thresholdsrc 0 --thresholdtgt 0 \
|
79 |
+
--destdir data_bin \
|
80 |
+
--srcdict data_dict.128k.txt --tgtdict data_dict.128k.txt
|
81 |
+
```
|
82 |
+
|
83 |
+
3. **Training Scripts**
|
84 |
+
|
85 |
+
To reproduce the training of our models, we train with fairseq-py's multilingual translation [task](https://github.com/pytorch/fairseq/tree/main/examples/multilingual). If you are interested in model parallel training, also check out [fairscale](https://github.com/facebookresearch/fairscale).
|
86 |
+
|
87 |
+
4. **Generation**
|
88 |
+
|
89 |
+
To generate from our models, follow the the commands in the generation section below.
|
90 |
+
|
91 |
+
|
92 |
+
If you use any of the resources listed here, please cite:
|
93 |
+
```bibtex
|
94 |
+
@article{fan2020beyond,
|
95 |
+
title={Beyond English-Centric Multilingual Machine Translation},
|
96 |
+
author={Fan, Angela and Bhosale, Shruti and Schwenk, Holger and Ma, Zhiyi and El-Kishky, Ahmed and Goyal, Siddharth and Baines, Mandeep and Celebi, Onur and Wenzek, Guillaume and Chaudhary, Vishrav and Goyal, Naman and Birch, Tom and Liptchinsky, Vitaliy and Edunov, Sergey and Grave, Edouard and Auli, Michael and Joulin, Armand},
|
97 |
+
journal={arXiv preprint},
|
98 |
+
year={2020}
|
99 |
+
}
|
100 |
+
|
101 |
+
@article{schwenk2019ccmatrix,
|
102 |
+
title={Ccmatrix: Mining billions of high-quality parallel sentences on the web},
|
103 |
+
author={Schwenk, Holger and Wenzek, Guillaume and Edunov, Sergey and Grave, Edouard and Joulin, Armand},
|
104 |
+
journal={arXiv preprint arXiv:1911.04944},
|
105 |
+
year={2019}
|
106 |
+
}
|
107 |
+
|
108 |
+
@article{el2019massive,
|
109 |
+
title={A Massive Collection of Cross-Lingual Web-Document Pairs},
|
110 |
+
author={El-Kishky, Ahmed and Chaudhary, Vishrav and Guzman, Francisco and Koehn, Philipp},
|
111 |
+
journal={arXiv preprint arXiv:1911.06154},
|
112 |
+
year={2019}
|
113 |
+
}
|
114 |
+
```
|
115 |
+
|
116 |
+
|
117 |
+
## Trained Models
|
118 |
+
|
119 |
+
### 418M and 1.2B Model
|
120 |
+
We include the last checkpoint for both of these models.
|
121 |
+
|
122 |
+
```bash
|
123 |
+
wget https://dl.fbaipublicfiles.com/m2m_100/model_dict.128k.txt
|
124 |
+
wget https://dl.fbaipublicfiles.com/m2m_100/language_pairs_small_models.txt
|
125 |
+
|
126 |
+
# 418M parameter model
|
127 |
+
wget https://dl.fbaipublicfiles.com/m2m_100/418M_last_checkpoint.pt
|
128 |
+
|
129 |
+
# 1.2B parameter model
|
130 |
+
wget https://dl.fbaipublicfiles.com/m2m_100/1.2B_last_checkpoint.pt
|
131 |
+
|
132 |
+
# Generation:
|
133 |
+
fairseq-generate $binarized_data_path --batch-size 32 --path $path_to_model --fixed-dictionary model_dict.128k.txt -s en -t fr --remove-bpe 'sentencepiece' --beam 5 --task translation_multi_simple_epoch --lang-pairs language_pairs_small_models.txt --decoder-langtok --encoder-langtok src --gen-subset test > gen_out
|
134 |
+
```
|
135 |
+
|
136 |
+
### 12B Model
|
137 |
+
12B parameter model trained on many-to-many training data for 100 languages. We include the last checkpoint, average of last 5 checkpoints, average of last 10 checkpoints. There isn't a universally best choice out of these three, but all three versions are pretty close in accuracy. You can either sweep over the 3 checkpoints on a dev test and use the best performing checkpoint for final testing. Or the last checkpoint can be a good default choice.
|
138 |
+
|
139 |
+
**Model Download Links**
|
140 |
+
Configuration | 2 32GB GPUs | 4 16GB GPUs | 6 12GB GPUs | 8 8GB GPUs
|
141 |
+
:--|:--|:--|:--|:--
|
142 |
+
Last Checkpoint | [12b_last_chk_2_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_last_chk_2_gpus.pt) | [12b_last_chk_4_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_last_chk_4_gpus.pt) | [12b_last_chk_6_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_last_chk_6_gpus.pt) | [12b_last_chk_8_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_last_chk_8_gpus.pt)
|
143 |
+
Average of last 5 checkpoints | [12b_avg5_chk_2_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_avg5_chk_2_gpus.pt) | [12b_avg5_chk_4_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_avg5_chk_4_gpus.pt) | [12b_avg5_chk_6_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_avg5_chk_6_gpus.pt) | [12b_avg5_chk_8_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_avg5_chk_8_gpus.pt)
|
144 |
+
Average of last 10 checkpoints | [12b_avg10_chk_2_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_avg10_chk_2_gpus.pt) | [12b_avg10_chk_4_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_avg10_chk_4_gpus.pt) | [12b_avg10_chk_6_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_avg10_chk_6_gpus.pt) | [12b_avg10_chk_8_gpus.pt](https://dl.fbaipublicfiles.com/m2m_100/12b_avg10_chk_8_gpus.pt)
|
145 |
+
|
146 |
+
**Generation Arguments**
|
147 |
+
Configuration | 2 32GB GPUs | 4 16GB GPUs | 6 12GB GPUs | 8 8GB GPUs
|
148 |
+
:--|:--|:--|:--|:--
|
149 |
+
`--pipeline-encoder-balance` | `[26]` | `[1,15,10]` | `[1,9,9,7]` | `[1,6,6,6,7]`
|
150 |
+
`--pipeline-encoder-devices` | `[0]` | `[0,1,0]` | `[0,1,2,0]` | `[0,4,5,1,0]`
|
151 |
+
`--pipeline-decoder-balance` | `[3,22,1]` | `[3,11,11,1]` | `[3,7,7,8,1]` | `[1,6,6,6,6,1]`
|
152 |
+
`--pipeline-decoder-devices` | `[0,1,0]` | `[0,2,3,0]` | `[0,3,4,5,0]` | `[0,2,6,7,3,0]`
|
153 |
+
|
154 |
+
|
155 |
+
## SentencePiece Model
|
156 |
+
|
157 |
+
```bash
|
158 |
+
wget https://dl.fbaipublicfiles.com/m2m_100/spm.128k.model
|
159 |
+
```
|
160 |
+
|
161 |
+
## Generation with M2M-100
|
162 |
+
|
163 |
+
### Encode using our SentencePiece Model
|
164 |
+
|
165 |
+
Note: Install SentencePiece from [here](https://github.com/google/sentencepiece)
|
166 |
+
|
167 |
+
```bash
|
168 |
+
fairseq=/path/to/fairseq
|
169 |
+
cd $fairseq
|
170 |
+
sacrebleu --echo src -l de-fr -t wmt19 | head -n 20 > raw_input.de-fr.de
|
171 |
+
sacrebleu --echo ref -l de-fr -t wmt19 | head -n 20 > raw_input.de-fr.fr
|
172 |
+
wget https://dl.fbaipublicfiles.com/m2m_100/spm.128k.model
|
173 |
+
for lang in de fr ; do
|
174 |
+
python scripts/spm_encode.py \
|
175 |
+
--model spm.128k.model \
|
176 |
+
--output_format=piece \
|
177 |
+
--inputs=raw_input.de-fr.${lang} \
|
178 |
+
--outputs=spm.de-fr.${lang}
|
179 |
+
done
|
180 |
+
```
|
181 |
+
|
182 |
+
### Binarization
|
183 |
+
|
184 |
+
```bash
|
185 |
+
wget https://dl.fbaipublicfiles.com/m2m_100/data_dict.128k.txt
|
186 |
+
fairseq-preprocess \
|
187 |
+
--source-lang de --target-lang fr \
|
188 |
+
--testpref spm.de-fr \
|
189 |
+
--thresholdsrc 0 --thresholdtgt 0 \
|
190 |
+
--destdir data_bin \
|
191 |
+
--srcdict data_dict.128k.txt --tgtdict data_dict.128k.txt
|
192 |
+
```
|
193 |
+
|
194 |
+
### Generation for the 12B model
|
195 |
+
|
196 |
+
Note that generation can currently be run using 2 32GB / 4 16GB / 6 12GB / 8 8GB GPUs, and the corresponding model checkpoints and pipeline arguments can be found in the [12B Model Section](#12b-model).
|
197 |
+
Generation on CPUs will be added in the future.
|
198 |
+
|
199 |
+
```bash
|
200 |
+
wget https://dl.fbaipublicfiles.com/m2m_100/model_dict.128k.txt
|
201 |
+
wget https://dl.fbaipublicfiles.com/m2m_100/language_pairs.txt
|
202 |
+
wget https://dl.fbaipublicfiles.com/m2m_100/12b_last_chk_4_gpus.pt
|
203 |
+
fairseq-generate \
|
204 |
+
data_bin \
|
205 |
+
--batch-size 1 \
|
206 |
+
--path 12b_last_chk_4_gpus.pt \
|
207 |
+
--fixed-dictionary model_dict.128k.txt \
|
208 |
+
-s de -t fr \
|
209 |
+
--remove-bpe 'sentencepiece' \
|
210 |
+
--beam 5 \
|
211 |
+
--task translation_multi_simple_epoch \
|
212 |
+
--lang-pairs language_pairs.txt \
|
213 |
+
--decoder-langtok --encoder-langtok src \
|
214 |
+
--gen-subset test \
|
215 |
+
--fp16 \
|
216 |
+
--dataset-impl mmap \
|
217 |
+
--distributed-world-size 1 --distributed-no-spawn \
|
218 |
+
--pipeline-model-parallel \
|
219 |
+
--pipeline-chunks 1 \
|
220 |
+
--pipeline-encoder-balance '[1,15,10]' \
|
221 |
+
--pipeline-encoder-devices '[0,1,0]' \
|
222 |
+
--pipeline-decoder-balance '[3,11,11,1]' \
|
223 |
+
--pipeline-decoder-devices '[0,2,3,0]' > gen_out
|
224 |
+
```
|
225 |
+
## Evaluation with M2M-100
|
226 |
+
|
227 |
+
### Tokenization
|
228 |
+
|
229 |
+
Note: Refer to tokenizers/README.md for more details on tokenization.
|
230 |
+
|
231 |
+
```bash
|
232 |
+
cd ${fairseq}/examples/m2m_100
|
233 |
+
cat ${fairseq}/gen_out | grep -P "^H" | sort -V | cut -f 3- | sh tok.sh fr > hyp
|
234 |
+
cat ${fairseq}/raw_input.de-fr.fr | sh tok.sh fr > ref
|
235 |
+
```
|
236 |
+
|
237 |
+
### BLEU
|
238 |
+
|
239 |
+
```bash
|
240 |
+
sacrebleu -tok 'none' ref < hyp
|
241 |
+
```
|
fairseq/examples/m2m_100/install_dependecies.sh
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the MIT license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
|
8 |
+
CWD=`pwd`
|
9 |
+
INSTALL_PATH=$CWD/tokenizers/thirdparty
|
10 |
+
|
11 |
+
MOSES=$INSTALL_PATH/mosesdecoder
|
12 |
+
if [ ! -d $MOSES ]; then
|
13 |
+
echo 'Cloning Moses github repository (for tokenization scripts)...'
|
14 |
+
git clone https://github.com/moses-smt/mosesdecoder.git $MOSES
|
15 |
+
cd $MOSES
|
16 |
+
# To deal with differences in handling ' vs "
|
17 |
+
git checkout 03578921cc1a03402
|
18 |
+
cd -
|
19 |
+
fi
|
20 |
+
|
21 |
+
WMT16_SCRIPTS=$INSTALL_PATH/wmt16-scripts
|
22 |
+
if [ ! -d $WMT16_SCRIPTS ]; then
|
23 |
+
echo 'Cloning Romanian tokenization scripts'
|
24 |
+
git clone https://github.com/rsennrich/wmt16-scripts.git $WMT16_SCRIPTS
|
25 |
+
fi
|
26 |
+
|
27 |
+
KYTEA=$INSTALL_PATH/kytea
|
28 |
+
if [ ! -f $KYTEA/bin/kytea ]; then
|
29 |
+
git clone https://github.com/neubig/kytea.git $KYTEA
|
30 |
+
cd $KYTEA
|
31 |
+
autoreconf -i
|
32 |
+
./configure --prefix=`pwd`
|
33 |
+
make
|
34 |
+
make install
|
35 |
+
cd ..
|
36 |
+
fi
|
37 |
+
|
38 |
+
export MECAB=$INSTALL_PATH/mecab-0.996-ko-0.9.2
|
39 |
+
if [ ! -f $MECAB/bin/mecab ]; then
|
40 |
+
cd $INSTALL_PATH
|
41 |
+
curl -LO https://bitbucket.org/eunjeon/mecab-ko/downloads/mecab-0.996-ko-0.9.2.tar.gz
|
42 |
+
tar zxfv mecab-0.996-ko-0.9.2.tar.gz
|
43 |
+
cd mecab-0.996-ko-0.9.2/
|
44 |
+
./configure --prefix=`pwd`
|
45 |
+
make
|
46 |
+
make install
|
47 |
+
|
48 |
+
cd ..
|
49 |
+
curl -LO https://bitbucket.org/eunjeon/mecab-ko-dic/downloads/mecab-ko-dic-2.1.1-20180720.tar.gz
|
50 |
+
tar zxfv mecab-ko-dic-2.1.1-20180720.tar.gz
|
51 |
+
cd mecab-ko-dic-2.1.1-20180720/
|
52 |
+
./autogen.sh
|
53 |
+
./configure --prefix=`pwd` --with-dicdir=$MECAB/lib/mecab/dic/mecab-ko-dic --with-mecab-config=$MECAB/bin/mecab-config
|
54 |
+
make
|
55 |
+
sh -c 'echo "dicdir=$MECAB/lib/mecab/dic/mecab-ko-dic" > $MECAB/etc/mecabrc'
|
56 |
+
make install
|
57 |
+
cd $CWD
|
58 |
+
fi
|
59 |
+
|
60 |
+
INDIC_RESOURCES_PATH=$INSTALL_PATH/indic_nlp_resources
|
61 |
+
if [ ! -d $INDIC_RESOURCES_PATH ]; then
|
62 |
+
echo 'Cloning indic_nlp_resources'
|
63 |
+
git clone https://github.com/anoopkunchukuttan/indic_nlp_resources.git $INDIC_RESOURCES_PATH
|
64 |
+
fi
|
65 |
+
|
66 |
+
|
67 |
+
if [ ! -f $INSTALL_PATH/seg_my.py ]; then
|
68 |
+
cd $INSTALL_PATH
|
69 |
+
wget http://lotus.kuee.kyoto-u.ac.jp/WAT/my-en-data/wat2020.my-en.zip
|
70 |
+
unzip wat2020.my-en.zip
|
71 |
+
# switch to python3
|
72 |
+
cat wat2020.my-en/myseg.py |sed 's/^sys.std/###sys.std/g' | sed 's/### sys/sys/g' | sed 's/unichr/chr/g' > seg_my.py
|
73 |
+
cd $CWD
|
74 |
+
fi
|
75 |
+
|
76 |
+
|
77 |
+
pip install pythainlp sacrebleu indic-nlp-library
|
78 |
+
|
fairseq/examples/m2m_100/process_data/clean_histogram.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
parser = argparse.ArgumentParser()
|
4 |
+
parser.add_argument('--src', type=str, help='Source language')
|
5 |
+
parser.add_argument('--tgt', type=str, help='Target language')
|
6 |
+
parser.add_argument('--src-file', type=str, help='Input source file')
|
7 |
+
parser.add_argument('--tgt-file', type=str, help='Input target file')
|
8 |
+
parser.add_argument('--src-output-file', type=str, help='Output source file')
|
9 |
+
parser.add_argument('--tgt-output-file', type=str, help='Output target file')
|
10 |
+
parser.add_argument('--threshold', type=float, default=0.5, help='Threshold')
|
11 |
+
parser.add_argument('--threshold-character', type=str, default=']', help='Threshold character')
|
12 |
+
parser.add_argument('--histograms', type=str, help='Path to histograms')
|
13 |
+
|
14 |
+
args = parser.parse_args()
|
15 |
+
|
16 |
+
|
17 |
+
def read_hist(f):
|
18 |
+
ch = []
|
19 |
+
for line in f:
|
20 |
+
c = line[0]
|
21 |
+
if c == args.threshold_character:
|
22 |
+
break
|
23 |
+
ch.append(c)
|
24 |
+
return ch
|
25 |
+
|
26 |
+
|
27 |
+
with(open("{}/{}".format(args.histograms, args.src), 'r', encoding='utf8')) as f:
|
28 |
+
ch1 = read_hist(f)
|
29 |
+
|
30 |
+
with(open("{}/{}".format(args.histograms, args.tgt), 'r', encoding='utf8')) as f:
|
31 |
+
ch2 = read_hist(f)
|
32 |
+
|
33 |
+
print("Accepted characters for {}: {}".format(args.src, ch1))
|
34 |
+
print("Accepted characters for {}: {}".format(args.tgt, ch2))
|
35 |
+
|
36 |
+
with open(args.src_file, 'r', encoding='utf8') as fs1, open(args.tgt_file, 'r', encoding='utf8') as fs2, open(args.src_output_file, 'w', encoding='utf8') as fos1, open(args.tgt_output_file, 'w', encoding='utf8') as fos2:
|
37 |
+
ls1 = fs1.readline()
|
38 |
+
ls2 = fs2.readline()
|
39 |
+
|
40 |
+
while ls1 or ls2:
|
41 |
+
cnt1 = len([c for c in ls1.strip() if c in ch1])
|
42 |
+
cnt2 = len([c for c in ls2.strip() if c in ch2])
|
43 |
+
|
44 |
+
if cnt1 / len(ls1) > args.threshold and cnt2 / len(ls2) > args.threshold:
|
45 |
+
fos1.write(ls1)
|
46 |
+
fos2.write(ls2)
|
47 |
+
else:
|
48 |
+
print("{} {} {} \n{} {} {}".format(args.src, cnt1 / len(ls1), ls1.strip(), args.tgt, cnt2 / len(ls2), ls2.strip()))
|
49 |
+
|
50 |
+
ls1 = fs1.readline()
|
51 |
+
ls2 = fs2.readline()
|
52 |
+
|
fairseq/examples/m2m_100/process_data/dedup_data.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from collections import namedtuple
|
3 |
+
import os
|
4 |
+
|
5 |
+
DATADIR = "/path/to/train_data"
|
6 |
+
DEDUP_FROM_DIR = "/path/to/eval/data"
|
7 |
+
OUTPUT_DIR = "/path/to/output/data"
|
8 |
+
|
9 |
+
|
10 |
+
def main(args):
|
11 |
+
languages = set()
|
12 |
+
for language_directory in os.listdir(DATADIR):
|
13 |
+
if "_" in language_directory:
|
14 |
+
src, tgt = language_directory.split("_")
|
15 |
+
languages.add(LanguagePair(src=src, tgt=tgt))
|
16 |
+
|
17 |
+
data = existing_data()
|
18 |
+
train_languages = sorted(languages)
|
19 |
+
for language_pair in train_languages[args.start_index:args.start_index + args.size]:
|
20 |
+
print(language_pair)
|
21 |
+
dedup(language_pair, data)
|
22 |
+
|
23 |
+
|
24 |
+
LanguagePair = namedtuple("LanguagePair", ["src", "tgt"])
|
25 |
+
|
26 |
+
|
27 |
+
def existing_data():
|
28 |
+
data = set()
|
29 |
+
for file in os.listdir(DEDUP_FROM_DIR):
|
30 |
+
with open(os.path.join(DEDUP_FROM_DIR, file)) as f:
|
31 |
+
data |= set(f.readlines())
|
32 |
+
return data
|
33 |
+
|
34 |
+
def dedup(language_pair, data, verbose=True, output=True):
|
35 |
+
train_filenames = LanguagePair(
|
36 |
+
src=f"{DATADIR}/{language_pair.src}_{language_pair.tgt}/train.{language_pair.src}",
|
37 |
+
tgt=f"{DATADIR}/{language_pair.src}_{language_pair.tgt}/train.{language_pair.tgt}",
|
38 |
+
)
|
39 |
+
|
40 |
+
output_filenames = LanguagePair(
|
41 |
+
src=f"{OUTPUT_DIR}/train.dedup.{language_pair.src}-{language_pair.tgt}.{language_pair.src}",
|
42 |
+
tgt=f"{OUTPUT_DIR}/train.dedup.{language_pair.src}-{language_pair.tgt}.{language_pair.tgt}"
|
43 |
+
)
|
44 |
+
|
45 |
+
# If output exists, skip this pair. It has already been done.
|
46 |
+
if (os.path.exists(output_filenames.src) and
|
47 |
+
os.path.exists(output_filenames.tgt)):
|
48 |
+
if verbose:
|
49 |
+
print(f"{language_pair.src}-{language_pair.tgt} already done.")
|
50 |
+
return
|
51 |
+
|
52 |
+
if verbose:
|
53 |
+
print(f"{language_pair.src}-{language_pair.tgt} ready, will check dups.")
|
54 |
+
|
55 |
+
# If there is no output, no need to actually do the loop.
|
56 |
+
if not output:
|
57 |
+
return
|
58 |
+
|
59 |
+
if os.path.exists(train_filenames.src) and os.path.exists(train_filenames.tgt):
|
60 |
+
with open(train_filenames.src) as f:
|
61 |
+
train_source = f.readlines()
|
62 |
+
|
63 |
+
with open(train_filenames.tgt) as f:
|
64 |
+
train_target = f.readlines()
|
65 |
+
|
66 |
+
# do dedup
|
67 |
+
new_train_source = []
|
68 |
+
new_train_target = []
|
69 |
+
for i, train_line in enumerate(train_source):
|
70 |
+
if train_line not in data and train_target[i] not in data:
|
71 |
+
new_train_source.append(train_line)
|
72 |
+
new_train_target.append(train_target[i])
|
73 |
+
|
74 |
+
assert len(train_source) == len(train_target)
|
75 |
+
assert len(new_train_source) == len(new_train_target)
|
76 |
+
assert len(new_train_source) <= len(train_source)
|
77 |
+
|
78 |
+
with open(output_filenames.src, "w") as o:
|
79 |
+
for line in new_train_source:
|
80 |
+
o.write(line)
|
81 |
+
|
82 |
+
with open(output_filenames.tgt, "w") as o:
|
83 |
+
for line in new_train_target:
|
84 |
+
o.write(line)
|
85 |
+
|
86 |
+
|
87 |
+
if __name__ == '__main__':
|
88 |
+
parser = argparse.ArgumentParser()
|
89 |
+
parser.add_argument("-s", "--start-index", required=True, type=int)
|
90 |
+
parser.add_argument("-n", "--size", required=True, type=int)
|
91 |
+
main(parser.parse_args())
|
fairseq/examples/m2m_100/process_data/remove_too_much_punc.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gzip
|
2 |
+
import argparse
|
3 |
+
from string import punctuation
|
4 |
+
|
5 |
+
def len_no_punc(s, punc):
|
6 |
+
return len([ch for ch in s if ch in punc])
|
7 |
+
|
8 |
+
def filter_overpunc(len_npunc, len_sen):
|
9 |
+
return len_npunc < 0.5*len_sen
|
10 |
+
|
11 |
+
def main(args):
|
12 |
+
punc = punctuation + "—|–"
|
13 |
+
print('Processing file {}'.format(args.input))
|
14 |
+
with gzip.open(args.input, 'rt', encoding=args.encoding) as tsv:
|
15 |
+
with open(args.bitext + '.' + args.src_lang, 'wt', encoding=args.encoding) as fsrc:
|
16 |
+
with open(args.bitext + '.' + args.tgt_lang, 'wt', encoding=args.encoding) as ftgt:
|
17 |
+
line = tsv.readline()
|
18 |
+
fields = line.split('\t')
|
19 |
+
|
20 |
+
src, tgt = fields[1], fields[2]
|
21 |
+
|
22 |
+
nchar_npunc_src = len_no_punc(src, punc)
|
23 |
+
nchar_npunc_tgt = len_no_punc(tgt, punc)
|
24 |
+
|
25 |
+
if filter_overpunc(nchar_npunc_src, len(src)) and filter_overpunc(nchar_npunc_tgt, len(tgt)):
|
26 |
+
fsrc.write(src.strip() + '\n')
|
27 |
+
ftgt.write(tgt.strip() + '\n')
|
28 |
+
|
29 |
+
if __name__ == '__main__':
|
30 |
+
parser = argparse.ArgumentParser()
|
31 |
+
parser.add_argument("--input", required=True, type=str)
|
32 |
+
parser.add_argument('--encoding', default='utf-8', help='character encoding for input/output')
|
33 |
+
parser.add_argument('--bitext', type=str, required=True, help='language direction')
|
34 |
+
parser.add_argument('--src-lang', type=str, required=True, help='Source language')
|
35 |
+
parser.add_argument('--tgt-lang', type=str, required=True, help='Target language')
|
36 |
+
main(parser.parse_args())
|
fairseq/examples/m2m_100/tok.sh
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
# Copyright (c) 2019-present, Facebook, Inc.
|
3 |
+
# All rights reserved.
|
4 |
+
#
|
5 |
+
# This source code is licensed under the license found in the
|
6 |
+
# LICENSE file in the root directory of this source tree.
|
7 |
+
#
|
8 |
+
|
9 |
+
set -e
|
10 |
+
|
11 |
+
TOKENIZERS_SCRIPTS=tokenizers
|
12 |
+
INSTALL_PATH=$TOKENIZERS_SCRIPTS/thirdparty
|
13 |
+
|
14 |
+
N_THREADS=8
|
15 |
+
|
16 |
+
lg=$1
|
17 |
+
|
18 |
+
MOSES=$INSTALL_PATH/mosesdecoder
|
19 |
+
REPLACE_UNICODE_PUNCT=$MOSES/scripts/tokenizer/replace-unicode-punctuation.perl
|
20 |
+
NORM_PUNC=$MOSES/scripts/tokenizer/normalize-punctuation.perl
|
21 |
+
REM_NON_PRINT_CHAR=$MOSES/scripts/tokenizer/remove-non-printing-char.perl
|
22 |
+
TOKENIZER=$MOSES/scripts/tokenizer/tokenizer.perl
|
23 |
+
|
24 |
+
# special tokenization for Romanian
|
25 |
+
WMT16_SCRIPTS=$INSTALL_PATH/wmt16-scripts
|
26 |
+
|
27 |
+
NORMALIZE_ROMANIAN=$WMT16_SCRIPTS/preprocess/normalise-romanian.py
|
28 |
+
REMOVE_DIACRITICS=$WMT16_SCRIPTS/preprocess/remove-diacritics.py
|
29 |
+
|
30 |
+
# Burmese
|
31 |
+
MY_SEGMENT=$INSTALL_PATH/seg_my.py
|
32 |
+
|
33 |
+
# Arabic
|
34 |
+
AR_TOKENIZER=$TOKENIZERS_SCRIPTS/tokenizer_ar.sh
|
35 |
+
|
36 |
+
# Korean
|
37 |
+
KO_SEGMENT=$TOKENIZERS_SCRIPTS/seg_ko.sh
|
38 |
+
|
39 |
+
# Japanese
|
40 |
+
JA_SEGMENT=$TOKENIZERS_SCRIPTS/seg_ja.sh
|
41 |
+
|
42 |
+
# Indic
|
43 |
+
IN_TOKENIZER=$TOKENIZERS_SCRIPTS/tokenize_indic.py
|
44 |
+
INDIC_RESOURCES_PATH=$INSTALL_PATH/indic_nlp_resources
|
45 |
+
|
46 |
+
# Thai
|
47 |
+
THAI_TOKENIZER=$TOKENIZERS_SCRIPTS/tokenize_thai.py
|
48 |
+
|
49 |
+
# Chinese
|
50 |
+
CHINESE_TOKENIZER=$TOKENIZERS_SCRIPTS/tokenize_zh.py
|
51 |
+
|
52 |
+
# Chinese
|
53 |
+
if [ "$lg" = "zh" ]; then
|
54 |
+
cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | python $CHINESE_TOKENIZER
|
55 |
+
# Thai
|
56 |
+
elif [ "$lg" = "th" ]; then
|
57 |
+
cat - | python $THAI_TOKENIZER
|
58 |
+
# Japanese
|
59 |
+
elif [ "$lg" = "ja" ]; then
|
60 |
+
cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | ${JA_SEGMENT}
|
61 |
+
# Korean
|
62 |
+
elif [ "$lg" = "ko" ]; then
|
63 |
+
cat - | $REM_NON_PRINT_CHAR | ${KO_SEGMENT}
|
64 |
+
# Romanian
|
65 |
+
elif [ "$lg" = "ro" ]; then
|
66 |
+
cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | $NORMALIZE_ROMANIAN | $REMOVE_DIACRITICS | $TOKENIZER -no-escape -threads $N_THREADS -l $lg
|
67 |
+
# Burmese
|
68 |
+
elif [ "$lg" = "my" ]; then
|
69 |
+
cat - | python ${MY_SEGMENT}
|
70 |
+
# Arabic
|
71 |
+
elif [ "$lg" = "ar" ]; then
|
72 |
+
cat - | ${AR_TOKENIZER}
|
73 |
+
# Indic
|
74 |
+
elif [ "$lg" = "ne" ]; then
|
75 |
+
cat - | python ${IN_TOKENIZER} $lg
|
76 |
+
elif [ "$lg" = "si" ]; then
|
77 |
+
cat - | python ${IN_TOKENIZER} $lg
|
78 |
+
elif [ "$lg" = "hi" ]; then
|
79 |
+
cat - | python ${IN_TOKENIZER} $lg
|
80 |
+
# other languages
|
81 |
+
else
|
82 |
+
cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | $TOKENIZER -no-escape -threads $N_THREADS -l $lg
|
83 |
+
fi
|
fairseq/examples/m2m_100/tokenizers/README.md
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# M2M-100 Tokenization
|
2 |
+
|
3 |
+
We apply different tokenization strategies for different languages following the existing literature. Here we provide tok.sh a tokenizer that can be used to reproduce our results.
|
4 |
+
|
5 |
+
To reproduce the results, follow these steps:
|
6 |
+
|
7 |
+
```
|
8 |
+
tgt_lang=...
|
9 |
+
reference_translation=...
|
10 |
+
cat generation_output | grep -P "^H" | sort -V | cut -f 3- | sh tok.sh $tgt_lang > hyp
|
11 |
+
cat $reference_translation |sh tok.sh $tgt_lang > ref
|
12 |
+
sacrebleu -tok 'none' ref < hyp
|
13 |
+
```
|
14 |
+
|
15 |
+
## Installation
|
16 |
+
|
17 |
+
Tools needed for all the languages except Arabic can be installed by running install_dependencies.sh
|
18 |
+
If you want to evaluate Arabic models, please follow the instructions provided here: http://alt.qcri.org/tools/arabic-normalizer/ to install
|
fairseq/examples/m2m_100/tokenizers/seg_ja.sh
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the MIT license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
SCRIPT=`realpath $0`
|
7 |
+
KYTEA=`dirname $SCRIPT`/thirdparty/kytea
|
8 |
+
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$KYTEA/lib:/usr/local/lib
|
9 |
+
export PATH=$PATH:"$KYTEA/bin"
|
10 |
+
|
11 |
+
cat - | tr -d "[:blank:]" | kytea -notags
|
fairseq/examples/m2m_100/tokenizers/seg_ko.sh
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the MIT license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
SCRIPT=`realpath $0`
|
7 |
+
MECAB=`dirname $SCRIPT`/thirdparty/mecab-0.996-ko-0.9.2
|
8 |
+
|
9 |
+
export PATH=$PATH:"$MECAB/bin":"$MECAB/lib"
|
10 |
+
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:"$MECAB/lib"
|
11 |
+
|
12 |
+
cat - | mecab -O wakati
|
fairseq/examples/m2m_100/tokenizers/thirdparty/.gitignore
ADDED
@@ -0,0 +1,12 @@
seg_my.py
indic_nlp_library/
indic_nlp_resources/
kytea/
mecab-0.996-ko-0.9.2.tar.gz
mecab-0.996-ko-0.9.2/
mosesdecoder/
wat2020.my-en.zip
wat2020.my-en/
wmt16-scripts/
mecab-ko-dic-2.1.1-20180720/
mecab-ko-dic-2.1.1-20180720.tar.gz
fairseq/examples/m2m_100/tokenizers/tokenize_indic.py
ADDED
@@ -0,0 +1,23 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Use: echo {text} | python tokenize_indic.py {language}

import sys

from indicnlp.normalize.indic_normalize import IndicNormalizerFactory
from indicnlp.tokenize.indic_tokenize import trivial_tokenize


factory = IndicNormalizerFactory()
normalizer = factory.get_normalizer(
    sys.argv[1], remove_nuktas=False, nasals_mode="do_nothing"
)

for line in sys.stdin:
    normalized_line = normalizer.normalize(line.strip())
    tokenized_line = " ".join(trivial_tokenize(normalized_line, sys.argv[1]))
    print(tokenized_line)
fairseq/examples/m2m_100/tokenizers/tokenize_thai.py
ADDED
@@ -0,0 +1,13 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import sys

from pythainlp import word_tokenize


for line in sys.stdin:
    print(" ".join(word_tokenize(line.strip())))
fairseq/examples/m2m_100/tokenizers/tokenize_zh.py
ADDED
@@ -0,0 +1,14 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import fileinput

import sacrebleu


for line in fileinput.input():
    print(sacrebleu.tokenize_zh(line))
fairseq/examples/m2m_100/tokenizers/tokenizer_ar.sh
ADDED
@@ -0,0 +1,27 @@
#!/usr/bin/env sh
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Please follow the instructions here http://alt.qcri.org/tools/arabic-normalizer/
# to install tools needed for Arabic

echo "Please install Arabic tools: http://alt.qcri.org/tools/arabic-normalizer/"
echo "Then update environment variables in tokenizer_ar.sh"
exit 1

SVMTOOL=...
GOMOSESGO=...
QCRI_ARABIC_NORMALIZER=...

export PERL5LIB="$SVMTOOL/lib":"$GOMOSESGO/bin/MADA-3.2":$PERL5LIB


tempfile=$(mktemp)
cat - > $tempfile

cd $QCRI_ARABIC_NORMALIZER

bash qcri_normalizer_mada3.2_aramorph1.2.1.sh $tempfile
cat $tempfile.mada_norm-aramorph.europarl_tok
fairseq/examples/mbart/README.md
ADDED
@@ -0,0 +1,123 @@
# MBART: Multilingual Denoising Pre-training for Neural Machine Translation
[https://arxiv.org/abs/2001.08210]

## Introduction

MBART is a sequence-to-sequence denoising auto-encoder pre-trained on large-scale monolingual corpora in many languages using the BART objective. mBART is one of the first methods for pre-training a complete sequence-to-sequence model by denoising full texts in multiple languages, while previous approaches have focused only on the encoder, the decoder, or reconstructing parts of the text.

## Pre-trained models

Model | Description | # params | Download
---|---|---|---
`mbart.CC25` | mBART model with 12 encoder and decoder layers trained on the monolingual corpora of 25 languages | 610M | [mbart.CC25.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.cc25.v2.tar.gz)
`mbart.ft.ro_en` | mBART CC25 model finetuned on the ro-en language pair | 610M | [mbart.cc25.ft.enro.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.cc25.ft.enro.tar.gz)

## Results

**[WMT16 EN-RO](https://www.statmt.org/wmt16/translation-task.html)**

_(test set, no additional data used)_

Model | en-ro | ro-en
---|---|---
`Random` | 34.3 | 34.0
`mbart.cc25` | 37.7 | 37.8
`mbart.enro.bilingual` | 38.5 | 38.5

## BPE data

Install SPM [here](https://github.com/google/sentencepiece), then download the model and apply BPE to the data:

```bash
# download model
wget https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.cc25.v2.tar.gz
tar -xzvf mbart.cc25.v2.tar.gz

# bpe data
SPM=/path/to/sentencepiece/build/src/spm_encode
MODEL=sentence.bpe.model
${SPM} --model=${MODEL} < ${DATA}/${TRAIN}.${SRC} > ${DATA}/${TRAIN}.spm.${SRC} &
${SPM} --model=${MODEL} < ${DATA}/${TRAIN}.${TGT} > ${DATA}/${TRAIN}.spm.${TGT} &
${SPM} --model=${MODEL} < ${DATA}/${VALID}.${SRC} > ${DATA}/${VALID}.spm.${SRC} &
${SPM} --model=${MODEL} < ${DATA}/${VALID}.${TGT} > ${DATA}/${VALID}.spm.${TGT} &
${SPM} --model=${MODEL} < ${DATA}/${TEST}.${SRC} > ${DATA}/${TEST}.spm.${SRC} &
${SPM} --model=${MODEL} < ${DATA}/${TEST}.${TGT} > ${DATA}/${TEST}.spm.${TGT} &
```
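The variables `DATA`, `TRAIN`, `VALID`, `TEST`, `SRC`, and `TGT` above are not defined in this README. A minimal sketch of plausible values for the EN-RO setup follows; every name here is an illustrative assumption, so adjust it to your own data layout:

```bash
# Illustrative assumptions only -- adapt to where your raw parallel text lives
DATA=/path/to/wmt16_en_ro   # folder containing the train/valid/test text files
TRAIN=train
VALID=valid
TEST=test
SRC=en_XX
TGT=ro_RO
NAME=en_ro                  # destination subfolder used by fairseq-preprocess below
DEST=/path/to/data-bin
```
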

## Preprocess data

```bash
DICT=dict.txt
fairseq-preprocess \
  --source-lang ${SRC} \
  --target-lang ${TGT} \
  --trainpref ${DATA}/${TRAIN}.spm \
  --validpref ${DATA}/${VALID}.spm \
  --testpref ${DATA}/${TEST}.spm \
  --destdir ${DEST}/${NAME} \
  --thresholdtgt 0 \
  --thresholdsrc 0 \
  --srcdict ${DICT} \
  --tgtdict ${DICT} \
  --workers 70
```

## Finetune on EN-RO
Finetune the mBART CC25 model on the preprocessed EN-RO data.

```bash
PRETRAIN=mbart.cc25 # fix if you moved the downloaded checkpoint
langs=ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN

fairseq-train path_2_data \
  --encoder-normalize-before --decoder-normalize-before \
  --arch mbart_large --layernorm-embedding \
  --task translation_from_pretrained_bart \
  --source-lang en_XX --target-lang ro_RO \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
  --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
  --lr-scheduler polynomial_decay --lr 3e-05 --warmup-updates 2500 --total-num-update 40000 \
  --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
  --max-tokens 1024 --update-freq 2 \
  --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \
  --seed 222 --log-format simple --log-interval 2 \
  --restore-file $PRETRAIN \
  --reset-optimizer --reset-meters --reset-dataloader --reset-lr-scheduler \
  --langs $langs \
  --ddp-backend legacy_ddp
```

## Generate on EN-RO
Compute sacreBLEU for the finetuned en-ro model.

Get the tokenizer [here](https://github.com/rsennrich/wmt16-scripts).
```bash
wget https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.cc25.ft.enro.tar.gz
tar -xzvf mbart.cc25.ft.enro.tar.gz
```

```bash
model_dir=MBART_finetuned_enro # fix if you moved the checkpoint

fairseq-generate path_2_data \
  --path $model_dir/model.pt \
  --task translation_from_pretrained_bart \
  --gen-subset test \
  -t ro_RO -s en_XX \
  --bpe 'sentencepiece' --sentencepiece-model $model_dir/sentence.bpe.model \
  --sacrebleu --remove-bpe 'sentencepiece' \
  --batch-size 32 --langs $langs > en_ro

cat en_ro | grep -P "^H" | sort -V | cut -f 3- | sed 's/\[ro_RO\]//g' | $TOKENIZER ro > en_ro.hyp
cat en_ro | grep -P "^T" | sort -V | cut -f 2- | sed 's/\[ro_RO\]//g' | $TOKENIZER ro > en_ro.ref
sacrebleu -tok 'none' -s 'none' en_ro.ref < en_ro.hyp
```

## Citation

```bibtex
@article{liu2020multilingual,
    title={Multilingual Denoising Pre-training for Neural Machine Translation},
    author={Yinhan Liu and Jiatao Gu and Naman Goyal and Xian Li and Sergey Edunov and Marjan Ghazvininejad and Mike Lewis and Luke Zettlemoyer},
    year={2020},
    eprint={2001.08210},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
```
fairseq/examples/megatron_11b/README.md
ADDED
@@ -0,0 +1,161 @@
# Megatron-11b

Megatron-11b is a unidirectional language model with `11B` parameters based on [Megatron-LM](https://arxiv.org/pdf/1909.08053.pdf). Following the original Megatron work, we trained the model using intra-layer model parallelism with each layer's parameters split across 8 GPUs.

Megatron-11b is trained on the same data and uses the same byte-pair encoding (BPE) as [RoBERTa](https://arxiv.org/pdf/1907.11692.pdf).

## Pre-trained models

Model | Description | # params | # filesize | Download
---|---|---|---|---
`megatron_11b` | megatron_11b unidirectional language model | 11B | 19GB | [megatron_11b.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/model_parallel/megatron_11b.tar.gz)

#### Architecture:

Param | Value
---|---
embed_dim | 3072
ffn_dim | 3072 * 6
layers | 72
attention heads | 32

#### Training details:

Param | Value
---|---
bsz | 512
num_updates | 300,000
peak_lr | 1.5e-04
lr scheduler | inverse_sqrt
clip norm | 0.0


## Example training command (model parallel)

Megatron-11b contains too many parameters to train on a single GPU. Following
the original Megatron work, we adopt an intra-layer model parallel training
approach in which each layer's parameters are split across multiple GPUs and
activations and gradients are communicated during the forward/backward pass,
respectively. We similarly split the loss computation using the
`vocab_parallel_cross_entropy` criterion.

The following training command illustrates how to do model parallel training in
fairseq. We assume that each machine (node) has 8 GPUs among which to split the
model parameters (`--model-parallel-size 8`). If you have access to multiple
nodes, you may combine this with data parallel training by increasing
`--distributed-world-size`.

To train Megatron-11b on a single node:

```bash
fairseq-train <DATA_PATH> \
  --distributed-world-size 8 \
  --memory-efficient-fp16 \
  --num-workers 2 \
  --model-parallel-size 8 \
  --criterion vocab_parallel_cross_entropy \
  --task language_modeling \
  --sample-break-mode none \
  --tokens-per-sample 1024 \
  --arch transformer_lm_megatron_11b \
  --share-decoder-input-output-embed \
  --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 --clip-norm 0.0 \
  --lr-scheduler inverse_sqrt --lr 0.00015 \
  --warmup-updates 3000 --weight-decay 0.01 \
  --dropout 0.1 --attention-dropout 0.1 \
  --batch-size 2 \
  --max-update 300000;
```

Note: The above was tested on a `DGX-1` box with `8xV100-32GB` GPUs.

## Results

**[Wikitext103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/)**

Model | Valid perplexity | Test perplexity
---|---|---
`megatron_11b` | 10.64 | 10.54


## Evaluating `megatron_11b` on Wikitext-103

#### 1. Downloading Megatron-11b
```bash
# WARNING: this file is 19GB
wget https://dl.fbaipublicfiles.com/fairseq/models/model_parallel/megatron_11b.tar.gz
tar -xzvf megatron_11b.tar.gz
```

#### 2. Download Wikitext-103
```bash
wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip
unzip wikitext-103-raw-v1.zip
```

#### 3. Detokenize test tokens
Megatron-11b uses a byte-level BPE that expects raw (untokenized) input. Since
the wikitext-103 dataset comes tokenized, we apply a simple detokenization
process to restore the untokenized test set:

```bash
python -m examples.megatron_11b.detok wikitext-103-raw/wiki.test.raw > wikitext-103-raw/wiki.test.detok
```

#### 4. BPE encoding
```bash
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'

python -m examples.roberta.multiprocessing_bpe_encoder \
  --encoder-json encoder.json \
  --vocab-bpe vocab.bpe \
  --inputs "wikitext-103-raw/wiki.test.detok" \
  --outputs "wikitext-103-raw/wiki.test.bpe" \
  --workers 60;
```

#### 5. Fairseq binarize
```bash
fairseq-preprocess \
  --only-source \
  --testpref wikitext-103-raw/wiki.test.bpe \
  --srcdict megatron_11b/dict.txt \
  --destdir wikitext103-bin;
```

#### 6. Evaluating perplexity
We can now evaluate perplexity on the test set. Note that because we've modified
the test set (via detokenization and BPE), the perplexity reported by
`fairseq-eval-lm` needs to be renormalized.

Compute unnormalized perplexity:

```bash
DATA_PATH=wikitext103-bin/
fairseq-eval-lm \
  $DATA_PATH \
  --path megatron_11b/model.pt \
  --task language_modeling \
  --gen-subset test \
  --batch-size 8 \
  --criterion cross_entropy \
  --context-window 992 \
  --distributed-world-size 8 \
  --model-parallel-size 8;
# Expected PPL (unnormalized_ppl): [8.46]
# Note: the eval command needs to run on 8 GPUs for the released model
```
Renormalizing formula: `2 ^ ( log_2(unnormalized_PPL) * (270847 / 245566))`.
PPL after normalization: `10.54`

To renormalize the perplexity, we must account for the change in token count
after detokenizing and applying BPE. The formula for this is:
`2 ^ ( log_2(unnormalized_PPL) * (new_token_cnt / orig_token_cnt))`

For the wikitext-103 test set, the original token count is `245566` and the
token count after detokenization and applying BPE is `270847`.

The perplexity after renormalization is:
`2 ^ ( log_2(8.46) * (270847 / 245566)) = 10.54`
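As a quick sanity check, the renormalization above can be reproduced directly from the numbers reported in this section (a one-liner sketch that only re-does the arithmetic):

```bash
# Recompute the renormalized perplexity from the unnormalized PPL and the token counts above
python -c "import math; print(2 ** (math.log2(8.46) * (270847 / 245566)))"
# prints ~10.54
```
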
fairseq/examples/megatron_11b/detok.py
ADDED
@@ -0,0 +1,32 @@
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import fileinput

import sacremoses


def main():
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("files", nargs="*", help="input files")
    args = parser.parse_args()

    detok = sacremoses.MosesDetokenizer()

    for line in fileinput.input(args.files, openhook=fileinput.hook_compressed):
        print(
            detok.detokenize(line.strip().split(" "))
            .replace(" @", "")
            .replace("@ ", "")
            .replace(" =", "=")
            .replace("= ", "=")
            .replace(" – ", "–")
        )


if __name__ == "__main__":
    main()
fairseq/examples/mms/MODEL_CARD.md
ADDED
@@ -0,0 +1,63 @@
# MMS Model Card

## Model details

**Organization developing the model** The FAIR team

**Model version** This is version 1 of the model.

**Model type** MMS is a speech model, based on the transformer architecture. The pre-trained model comes in two sizes: 300M and 1B parameters. We fine-tune the model for speech recognition and make it available in the 1B variant. We also fine-tune the 1B variant for language identification.

**License** CC BY-NC

**Where to send questions or comments about the model** Questions and comments about MMS can be sent via the [GitHub repository](https://github.com/pytorch/fairseq/tree/master/examples/mms) of the project, by opening an issue and tagging it as MMS.

## Uses

**Primary intended uses** The primary use of MMS is to perform speech processing research for many more languages and to perform tasks such as automatic speech recognition, language identification, and speech synthesis.

**Primary intended users** The primary intended users of the model are researchers in speech processing, machine learning and artificial intelligence.

**Out-of-scope use cases** Fine-tuning the pre-trained models on other labeled datasets or downstream tasks requires further risk evaluation and mitigation.

## Bias and Risks

The MMS models were pre-trained on a blend of data from different domains, including readings of the New Testament. In the paper, we describe two studies analyzing gender bias and the use of religious language which conclude that models perform equally well for both genders and that, on average, there is little bias for religious language (section 8 of the paper).

# Training Details

## Training Data

MMS is pre-trained on VoxPopuli (parliamentary speech), MLS (read audiobooks), VoxLingua-107 (YouTube speech), CommonVoice (read Wikipedia text), BABEL (telephone conversations), MMS-lab-U (New Testament readings), and MMS-unlab (various read Christian texts).
Models are fine-tuned on FLEURS, VoxLingua-107, MLS, CommonVoice, and MMS-lab. We obtained the language information for MMS-lab, MMS-lab-U and MMS-unlab from our data source and did not manually verify it for every language.

## Training Procedure

Please refer to the research paper for details on this.

# Evaluation

## Testing Data, Factors & Metrics

We evaluate the model on different benchmarks for the downstream tasks. The evaluation details are presented in the paper. The model's performance is measured using standard metrics such as character error rate, word error rate, and classification accuracy.


# Citation

**BibTeX:**

```
@article{pratap2023mms,
  title={Scaling Speech Technology to 1,000+ Languages},
  author={Vineel Pratap and Andros Tjandra and Bowen Shi and Paden Tomasello and Arun Babu and Sayani Kundu and Ali Elkahky and Zhaoheng Ni and Apoorv Vyas and Maryam Fazel-Zarandi and Alexei Baevski and Yossi Adi and Xiaohui Zhang and Wei-Ning Hsu and Alexis Conneau and Michael Auli},
  journal={arXiv},
  year={2023}
}
```

# Model Card Contact

Please reach out to the authors at: [[email protected]](mailto:[email protected]) [[email protected]](mailto:[email protected]) [[email protected]](mailto:[email protected]) [[email protected]](mailto:[email protected])
fairseq/examples/mms/README.md
ADDED
@@ -0,0 +1,215 @@
# MMS: Scaling Speech Technology to 1000+ languages

The Massively Multilingual Speech (MMS) project expands speech technology from about 100 languages to over 1,000 by building a single multilingual speech recognition model supporting over 1,100 languages (more than 10 times as many as before), language identification models able to identify over [4,000 languages](https://dl.fbaipublicfiles.com/mms/misc/language_coverage_mms.html) (40 times more than before), pretrained models supporting over 1,400 languages, and text-to-speech models for over 1,100 languages. Our goal is to make it easier for people to access information and to use devices in their preferred language.

You can find details in the paper [Scaling Speech Technology to 1000+ languages](https://research.facebook.com/publications/scaling-speech-technology-to-1000-languages/) and the [blog post](https://ai.facebook.com/blog/multilingual-model-speech-recognition/).

An overview of the languages covered by MMS can be found [here](https://dl.fbaipublicfiles.com/mms/misc/language_coverage_mms.html).

## 🤗 Transformers

MMS has been added to Transformers. For more information, please refer to [Transformers' MMS docs](https://huggingface.co/docs/transformers/main/en/model_doc/mms).

[Click here](https://huggingface.co/models?other=mms) to find all MMS checkpoints on the Hub.

Check out the [demo](https://huggingface.co/spaces/facebook/MMS).

## Finetuned models
### ASR

| Model | Languages | Dataset | Checkpoint | Dictionary* | Supported languages | 🤗 Hub |
|---|---|---|---|---|---|---
MMS-1B:FL102 | 102 | FLEURS | [download](https://dl.fbaipublicfiles.com/mms/asr/mms1b_fl102.pt) | [download](https://dl.fbaipublicfiles.com/mms/asr/dict/mms1b_fl102/eng.txt) | [download](https://dl.fbaipublicfiles.com/mms/asr/mms1b_fl102_langs.html) | [🤗 Hub](https://huggingface.co/facebook/mms-1b-fl102)
MMS-1B:L1107 | 1107 | MMS-lab | [download](https://dl.fbaipublicfiles.com/mms/asr/mms1b_l1107.pt) | [download](https://dl.fbaipublicfiles.com/mms/asr/dict/mms1b_l1107/eng.txt) | [download](https://dl.fbaipublicfiles.com/mms/asr/mms1b_l1107_langs.html) | [🤗 Hub](https://huggingface.co/facebook/mms-1b-l1107)
MMS-1B-all | 1162 | MMS-lab + FLEURS <br>+ CV + VP + MLS | [download](https://dl.fbaipublicfiles.com/mms/asr/mms1b_all.pt) | [download](https://dl.fbaipublicfiles.com/mms/asr/dict/mms1b_all/eng.txt) | [download](https://dl.fbaipublicfiles.com/mms/asr/mms1b_all_langs.html) | [🤗 Hub](https://huggingface.co/facebook/mms-1b-all)

\* In the `Dictionary` column, we provide the download link for the token dictionary in English. To download the token dictionary for a different language supported by the model, modify the language code in the URL appropriately. For example, to get the token dictionary of the FL102 model for Hindi, use [this](https://dl.fbaipublicfiles.com/mms/asr/dict/mms1b_fl102/hin.txt) link.

### TTS
1. Download the list of [iso codes](https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html) of 1107 languages.
2. Find the iso code of the target language and download the checkpoint. Each folder contains 3 files: `G_100000.pth`, `config.json`, `vocab.txt`. `G_100000.pth` is the generator trained for 100K updates, `config.json` is the training config, and `vocab.txt` is the vocabulary for the TTS model.
```
# Examples:
wget https://dl.fbaipublicfiles.com/mms/tts/eng.tar.gz # English (eng)
wget https://dl.fbaipublicfiles.com/mms/tts/azj-script_latin.tar.gz # North Azerbaijani (azj-script_latin)
```
The above command downloads the generator only, which is enough to run TTS inference. If you want the full model checkpoint, which also includes the discriminator (`D_100000.pth`) and the optimizer states, download as follows.
```
# Example (full checkpoint: generator + discriminator + optimizer):
wget https://dl.fbaipublicfiles.com/mms/tts/full_model/eng.tar.gz # English (eng)
```


### LID

\# Languages | Dataset | Model | Dictionary | Supported languages | 🤗 Hub |
|---|---|---|---|---|---
126 | FLEURS + VL + MMS-lab-U + MMS-unlab | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l126.pt) | [download](https://dl.fbaipublicfiles.com/mms/lid/dict/l126/dict.lang.txt) | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l126_langs.html) | [🤗 Hub](https://huggingface.co/facebook/mms-lid-126)
256 | FLEURS + VL + MMS-lab-U + MMS-unlab | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l256.pt) | [download](https://dl.fbaipublicfiles.com/mms/lid/dict/l256/dict.lang.txt) | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l256_langs.html) | [🤗 Hub](https://huggingface.co/facebook/mms-lid-256)
512 | FLEURS + VL + MMS-lab-U + MMS-unlab | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l512.pt) | [download](https://dl.fbaipublicfiles.com/mms/lid/dict/l512/dict.lang.txt) | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l512_langs.html) | [🤗 Hub](https://huggingface.co/facebook/mms-lid-512)
1024 | FLEURS + VL + MMS-lab-U + MMS-unlab | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l1024.pt) | [download](https://dl.fbaipublicfiles.com/mms/lid/dict/l1024/dict.lang.txt) | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l1024_langs.html) | [🤗 Hub](https://huggingface.co/facebook/mms-lid-1024)
2048 | FLEURS + VL + MMS-lab-U + MMS-unlab | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l2048.pt) | [download](https://dl.fbaipublicfiles.com/mms/lid/dict/l2048/dict.lang.txt) | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l2048_langs.html) | [🤗 Hub](https://huggingface.co/facebook/mms-lid-2048)
4017 | FLEURS + VL + MMS-lab-U + MMS-unlab | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l4017.pt) | [download](https://dl.fbaipublicfiles.com/mms/lid/dict/l4017/dict.lang.txt) | [download](https://dl.fbaipublicfiles.com/mms/lid/mms1b_l4017_langs.html) | [🤗 Hub](https://huggingface.co/facebook/mms-lid-4017)

## Commands to run inference

### ASR
Run this command to transcribe one or more audio files:
```shell command
cd /path/to/fairseq-py/
python examples/mms/asr/infer/mms_infer.py --model "/path/to/asr/model" --lang lang_code \
  --audio "/path/to/audio_1.wav" "/path/to/audio_2.wav" "/path/to/audio_3.wav"
```
We also provide an IPython notebook example inside the `asr/tutorial` folder: [ipynb](https://github.com/facebookresearch/fairseq/blob/main/examples/mms/asr/tutorial/MMS_ASR_Inference_Colab.ipynb) or [open it in Colab](https://colab.research.google.com/github/facebookresearch/fairseq/blob/main/examples/mms/asr/tutorial/MMS_ASR_Inference_Colab.ipynb).

For more advanced configuration, and to calculate CER/WER, you can prepare a manifest folder with this format (a sketch for building `dev.tsv` follows the example):
```
$ ls /path/to/manifest
dev.tsv
dev.wrd
dev.ltr
dev.uid

# dev.tsv: each line contains <audio> <number_of_sample>
# if you don't have this information, please run misc/get_sample_size.py

$ cat dev.tsv
/
/path/to/audio_1.wav 180000
/path/to/audio_2.wav 200000

$ cat dev.ltr
t h i s | i s | o n e |
t h i s | i s | t w o |

$ cat dev.wrd
this is one
this is two

$ cat dev.uid
audio_1
audio_2
```
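As referenced above, `dev.tsv` needs the number of samples for each audio file. A minimal sketch for producing those lines with the `soundfile` package (the same call that `mms_infer.py` uses); the file list below is a placeholder:

```python
# Build dev.tsv: the first line is the audio root, then one "<audio>\t<number_of_samples>" line per file.
import soundfile as sf

audio_files = ["/path/to/audio_1.wav", "/path/to/audio_2.wav"]  # placeholders

with open("dev.tsv", "w") as f:
    f.write("/\n")
    for path in audio_files:
        f.write(f"{path}\t{sf.SoundFile(path).frames}\n")
```
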
Followed by the command below:
```
lang_code=<iso_code>

PYTHONPATH=. PREFIX=INFER HYDRA_FULL_ERROR=1 python examples/speech_recognition/new/infer.py -m --config-dir examples/mms/asr/config/ --config-name infer_common decoding.type=viterbi dataset.max_tokens=4000000 distributed_training.distributed_world_size=1 "common_eval.path='/path/to/asr/model'" task.data='/path/to/manifest' dataset.gen_subset="${lang_code}:dev" common_eval.post_process=letter
```
Available options:
* To get the raw character-based output, change to `common_eval.post_process=none`

* To maximize GPU efficiency or avoid out-of-memory (OOM) errors, tune the `dataset.max_tokens=???` size

* To run language model decoding, install the flashlight python bindings using
```
git clone --recursive [email protected]:flashlight/flashlight.git
cd flashlight;
git checkout 035ead6efefb82b47c8c2e643603e87d38850076
cd bindings/python
python3 setup.py install
```
Train a [KenLM language model](https://github.com/flashlight/wav2letter/tree/main/recipes/rasr#language-model) and prepare a lexicon file in [this](https://dl.fbaipublicfiles.com/wav2letter/rasr/tutorial/lexicon.txt) format. Pretrained language models from our paper can be found on the [🤗 Hub](https://huggingface.co/facebook/mms-cclms/).

```
LANG=<iso> # for example - 'eng', 'azj-script_latin'
PYTHONPATH=. PREFIX=INFER HYDRA_FULL_ERROR=1 python examples/speech_recognition/new/infer.py --config-dir=examples/mms/asr/config \
  --config-name=infer_common decoding.type=kenlm distributed_training.distributed_world_size=1 \
  decoding.unique_wer_file=true decoding.beam=500 decoding.beamsizetoken=50 \
  task.data=<MANIFEST_FOLDER_PATH> common_eval.path='<MODEL_PATH.pt>' decoding.lexicon=<LEXICON_FILE> decoding.lmpath=<LM_FILE> \
  decoding.results_path=<OUTPUT_DIR> dataset.gen_subset=${LANG}:dev decoding.lmweight=??? decoding.wordscore=???
```
We typically sweep `lmweight` in the range of 0 to 5 and `wordscore` in the range of -3 to 3 (a sketch of such a sweep is shown below). The output directory will contain the reference and hypothesis outputs from the decoder.
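A minimal sketch of such a sweep, simply looping the command above over a grid (the grid values and the results-path naming are illustrative assumptions; the angle-bracket placeholders are the same as above):

```bash
# Illustrative lmweight/wordscore grid search over the decoding command above
for lmweight in 0 1 2 3 4 5; do
  for wordscore in -3 -2 -1 0 1 2 3; do
    PYTHONPATH=. PREFIX=INFER HYDRA_FULL_ERROR=1 python examples/speech_recognition/new/infer.py \
      --config-dir=examples/mms/asr/config --config-name=infer_common decoding.type=kenlm \
      distributed_training.distributed_world_size=1 decoding.beam=500 decoding.beamsizetoken=50 \
      task.data=<MANIFEST_FOLDER_PATH> common_eval.path='<MODEL_PATH.pt>' \
      decoding.lexicon=<LEXICON_FILE> decoding.lmpath=<LM_FILE> \
      decoding.results_path=<OUTPUT_DIR>/lmw${lmweight}_ws${wordscore} \
      dataset.gen_subset=${LANG}:dev decoding.lmweight=${lmweight} decoding.wordscore=${wordscore}
  done
done
```
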
For decoding with character-based language models, use an empty lexicon file (`decoding.lexicon=`), set `decoding.unitlm=True` and sweep over `decoding.silweight` instead of `wordscore`.

### TTS
Note: clone and install [VITS](https://github.com/jaywalnut310/vits) before running inference.
```shell script
## English TTS
$ PYTHONPATH=$PYTHONPATH:/path/to/vits python examples/mms/tts/infer.py --model-dir /path/to/model/eng \
--wav ./example.wav --txt "Expanding the language coverage of speech technology \
has the potential to improve access to information for many more people"

## Maithili TTS
$ PYTHONPATH=$PYTHONPATH:/path/to/vits python examples/mms/tts/infer.py --model-dir /path/to/model/mai \
--wav ./example.wav --txt "मुदा आइ धरि ई तकनीक सौ सं किछु बेसी भाषा तक सीमित छल जे सात हजार \
सं बेसी ज्ञात भाषाक एकटा अंश अछी"
```
`example.wav` contains the synthesized audio for the language.

We also provide an IPython notebook example inside the `tts/tutorial` folder: [ipynb](https://github.com/facebookresearch/fairseq/blob/main/examples/mms/tts/tutorial/MMS_TTS_Inference_Colab.ipynb) or [open it in Colab](https://colab.research.google.com/github/facebookresearch/fairseq/blob/main/examples/mms/tts/tutorial/MMS_TTS_Inference_Colab.ipynb).


### LID


Prepare two files in this format; each manifest line contains <audio> and <number_of_sample>:
```
#/path/to/manifest.tsv
/
/path/to/audio1.wav 180000
/path/to/audio2.wav 240000
/path/to/audio3.wav 160000

# /path/to/manifest.lang
eng 1
eng 1
eng 1
```

Download the model and the corresponding dictionary file for the LID model.
Use the following command to run inference:
```shell script
$ PYTHONPATH='.' python3 examples/mms/lid/infer.py /path/to/dict/l126/ --path /path/to/models/mms1b_l126.pt \
  --task audio_classification --infer-manifest /path/to/manifest.tsv --output-path <OUTDIR>
```
The above command assumes there is a file named `dict.lang.txt` in `/path/to/dict/l126/`. `<OUTDIR>/predictions.txt` will contain the predictions from the model for the audio files in `manifest.tsv`.

We also provide an IPython notebook example inside the `lid/tutorial` folder: [ipynb](https://github.com/facebookresearch/fairseq/blob/main/examples/mms/lid/tutorial/MMS_LID_Inference_Colab.ipynb) or [open it in Colab](https://colab.research.google.com/github/facebookresearch/fairseq/blob/main/examples/mms/lid/tutorial/MMS_LID_Inference_Colab.ipynb).

## Fine-tuning

### ASR

MMS Adapter fine-tuning has been added to the official 🤗 Transformers examples [here](https://github.com/huggingface/transformers/tree/main/examples/pytorch/speech-recognition#connectionist-temporal-classification-with-adapters).
For a more step-by-step explanation of how to fine-tune MMS, please have a look at the blog [**Fine-tuning MMS Adapter Models for Multi-Lingual ASR**](https://huggingface.co/blog/mms_adapters) on 🤗 blogs.

### TTS

For a guide on how to fine-tune MMS TTS checkpoints using the 🤗 Transformers implementation, please have a look at this [repository](https://github.com/ylacombe/finetune-hf-vits).

## Pretrained models

| Model | Link | 🤗 Hub |
|---|---|---
MMS-300M | [download](https://dl.fbaipublicfiles.com/mms/pretraining/base_300m.pt) | [🤗 Hub](https://huggingface.co/facebook/mms-300m)
MMS-1B | [download](https://dl.fbaipublicfiles.com/mms/pretraining/base_1b.pt) | [🤗 Hub](https://huggingface.co/facebook/mms-1b)

Example commands to finetune the pretrained models can be found [here](https://github.com/facebookresearch/fairseq/tree/main/examples/wav2vec#fine-tune-a-pre-trained-model-with-ctc).

## Forced Alignment Tooling

We also developed an efficient forced alignment algorithm implemented on GPU which is able to process very long audio files. This algorithm is open sourced and we provide instructions on how to use it [here](data_prep). We also open source a multilingual alignment model trained on 31K hours of data in 1,130 languages, as well as text normalization scripts.


# License

The MMS code and model weights are released under the CC-BY-NC 4.0 license.

# Citation

**BibTeX:**

```
@article{pratap2023mms,
  title={Scaling Speech Technology to 1,000+ Languages},
  author={Vineel Pratap and Andros Tjandra and Bowen Shi and Paden Tomasello and Arun Babu and Sayani Kundu and Ali Elkahky and Zhaoheng Ni and Apoorv Vyas and Maryam Fazel-Zarandi and Alexei Baevski and Yossi Adi and Xiaohui Zhang and Wei-Ning Hsu and Alexis Conneau and Michael Auli},
  journal={arXiv},
  year={2023}
}
```
fairseq/examples/mms/asr/config/infer_common.yaml
ADDED
@@ -0,0 +1,32 @@
# @package _global_
# defaults:
#   - hydra/launcher: submitit_slurm

# @package _group_

task:
  _name: audio_finetuning
  data: null
  labels: ltr
common_eval:
  path: null
  post_process: letter
  # model_overrides: "{'task':{'multi_corpus_keys':None}}"
decoding:
  type: viterbi
  lexicon: null
  unique_wer_file: false
  results_path: null
distributed_training:
  ddp_backend: legacy_ddp
  distributed_world_size: 1
hydra:
  run:
    dir: ${common_eval.results_path}/${dataset.gen_subset}
  sweep:
    dir: /checkpoint/${env:USER}/${env:PREFIX}/${common_eval.results_path}
    subdir: ${dataset.gen_subset}
dataset:
  max_tokens: 2_000_000
  gen_subset: dev
  required_batch_size_multiple: 1
fairseq/examples/mms/asr/infer/example_infer_adapter.sh
ADDED
@@ -0,0 +1,3 @@
#!/bin/bash
lang="$1"
PYTHONPATH=. PREFIX=INFER HYDRA_FULL_ERROR=1 python examples/speech_recognition/new/infer.py -m --config-dir examples/mms/asr/config/ --config-name infer_common decoding.type=viterbi dataset.max_tokens=4000000 distributed_training.distributed_world_size=1 "common_eval.path='/fsx-wav2vec/androstj/exps/wav2vec/mms/v4/finetune/xl1b_d5_dfls_0_0.3_u300k__ft_on_d5_127_dbeta1/ft_smax_adp_common.seed:1__dataset.max_tokens:2880000__optimization.lr:[0.001]__optimization.max_update:4000__merged_ckpt/checkpoints/checkpoint_last.pt'" task.data=/fsx-wav2vec/androstj/dataset/v4/fl/fseq dataset.gen_subset="${lang}:${lang}/dev" common_eval.post_process=none
fairseq/examples/mms/asr/infer/mms_infer.py
ADDED
@@ -0,0 +1,63 @@
#!/usr/bin/env python -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import soundfile as sf
import tempfile
from pathlib import Path
import os
import subprocess
import sys
import re

def parser():
    parser = argparse.ArgumentParser(description="ASR inference script for MMS model")
    parser.add_argument("--model", type=str, help="path to ASR model", required=True)
    parser.add_argument("--audio", type=str, help="path to audio file", required=True, nargs='+')
    parser.add_argument("--lang", type=str, help="audio language", required=True)
    parser.add_argument("--format", type=str, choices=["none", "letter"], default="letter")
    parser.add_argument("--extra-infer-args", type=str, default="")
    return parser.parse_args()

def reorder_decode(hypos):
    outputs = []
    for hypo in hypos:
        idx = int(re.findall("\(None-(\d+)\)$", hypo)[0])
        hypo = re.sub("\(\S+\)$", "", hypo).strip()
        outputs.append((idx, hypo))
    outputs = sorted(outputs)
    return outputs

def process(args):
    with tempfile.TemporaryDirectory() as tmpdir:
        print(">>> preparing tmp manifest dir ...", file=sys.stderr)
        tmpdir = Path(tmpdir)
        with open(tmpdir / "dev.tsv", "w") as fw, open(tmpdir / "dev.uid", "w") as fu:
            fw.write("/\n")
            for audio in args.audio:
                nsample = sf.SoundFile(audio).frames
                fw.write(f"{audio}\t{nsample}\n")
                fu.write(f"{audio}\n")
        with open(tmpdir / "dev.ltr", "w") as fw:
            fw.write("d u m m y | d u m m y |\n"*len(args.audio))
        with open(tmpdir / "dev.wrd", "w") as fw:
            fw.write("dummy dummy\n"*len(args.audio))
        cmd = f"""
        PYTHONPATH=. PREFIX=INFER HYDRA_FULL_ERROR=1 python examples/speech_recognition/new/infer.py -m --config-dir examples/mms/asr/config/ --config-name infer_common decoding.type=viterbi dataset.max_tokens=1440000 distributed_training.distributed_world_size=1 "common_eval.path='{args.model}'" task.data={tmpdir} dataset.gen_subset="{args.lang}:dev" common_eval.post_process={args.format} decoding.results_path={tmpdir} {args.extra_infer_args}
        """
        print(">>> loading model & running inference ...", file=sys.stderr)
        subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL,)
        with open(tmpdir/"hypo.word") as fr:
            hypos = fr.readlines()
            outputs = reorder_decode(hypos)
            for ii, hypo in outputs:
                hypo = re.sub("\(\S+\)$", "", hypo).strip()
                print(f'===============\nInput: {args.audio[ii]}\nOutput: {hypo}')


if __name__ == "__main__":
    args = parser()
    process(args)
fairseq/examples/mms/asr/tutorial/MMS_ASR_Inference_Colab.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
fairseq/examples/mms/data_prep/README.md
ADDED
@@ -0,0 +1,47 @@
# Data Preparation

We describe the process of aligning long audio files with their transcripts and generating shorter audio segments below.

- Step 1: Download and install torchaudio using the nightly version. We have open sourced the CTC forced alignment algorithm described in our paper via [torchaudio](https://github.com/pytorch/audio/pull/3348).
```
pip install --pre torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118
```

- Step 2: Download [uroman](https://github.com/isi-nlp/uroman) from Github. It is a universal romanizer which converts text in any script to the Latin alphabet. Use [this link](https://www.isi.edu/~ulf/uroman.html) to try their web interface.
```
git clone [email protected]:isi-nlp/uroman.git
```

- Step 3: Install a few other dependencies
```
apt install sox
pip install sox dataclasses
```

- Step 4: Create a text file containing the transcript for a (long) audio file. Each line in the text file will correspond to a separate audio segment that will be generated upon alignment.

Example content of the input text file:
```
Text of the desired first segment
Text of the desired second segment
Text of the desired third segment
```

- Step 5: Run forced alignment and segment the audio file into shorter segments.
```
python align_and_segment.py --audio /path/to/audio.wav --text_filepath /path/to/textfile --lang <iso> --outdir /path/to/output --uroman /path/to/uroman/bin
```

The above command will generate the audio segments in the output directory, one per line of the input text file, along with a `manifest.json` file containing the segmented audio filepaths and their corresponding transcripts (a minimal sketch for reading it follows the example below).

```
> head /path/to/output/manifest.json

{"audio_start_sec": 0.0, "audio_filepath": "/path/to/output/segment1.flac", "duration": 6.8, "text": "she wondered afterwards how she could have spoken with that hard serenity how she could have", "normalized_text": "she wondered afterwards how she could have spoken with that hard serenity how she could have", "uroman_tokens": "s h e w o n d e r e d a f t e r w a r d s h o w s h e c o u l d h a v e s p o k e n w i t h t h a t h a r d s e r e n i t y h o w s h e c o u l d h a v e"}
{"audio_start_sec": 6.8, "audio_filepath": "/path/to/output/segment2.flac", "duration": 5.3, "text": "gone steadily on with story after story poem after poem till", "normalized_text": "gone steadily on with story after story poem after poem till", "uroman_tokens": "g o n e s t e a d i l y o n w i t h s t o r y a f t e r s t o r y p o e m a f t e r p o e m t i l l"}
{"audio_start_sec": 12.1, "audio_filepath": "/path/to/output/segment3.flac", "duration": 5.9, "text": "allan's grip on her hands relaxed and he fell into a heavy tired sleep", "normalized_text": "allan's grip on her hands relaxed and he fell into a heavy tired sleep", "uroman_tokens": "a l l a n ' s g r i p o n h e r h a n d s r e l a x e d a n d h e f e l l i n t o a h e a v y t i r e d s l e e p"}
```
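The manifest is one JSON object per line; a minimal sketch for loading it (the path is a placeholder):

```python
# Read the line-delimited manifest.json written by align_and_segment.py
import json

with open("/path/to/output/manifest.json") as f:
    for line in f:
        segment = json.loads(line)
        print(segment["audio_filepath"], segment["duration"], segment["text"])
```
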

To visualize the segmented audio files, the [Speech Data Explorer](https://github.com/NVIDIA/NeMo/tree/main/tools/speech_data_explorer) tool from the NeMo toolkit can be used.

As our alignment model outputs uroman tokens for input audio in any language, it also works with non-English audio and the corresponding transcripts.