Pay Better Attention to Attention: Head Selection in Multilingual and Multi-Domain Sequence Modeling (Gong et al., 2021)
https://arxiv.org/pdf/2106.10840.pdf
Introduction
We present attention head selection strategies for multilingual and multi-domain sequence modeling, covering text translation, speech recognition, and speech translation tasks.
Below are examples of training multilingual and multi-domain speech recognition models.
Data Preparation
Prepare mTEDx data as in the mTEDx example and CoVoST data as in the CoVoST example; prepare EuroParl data similarly.
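For reference, fairseq ships preparation scripts for these corpora under examples/speech_to_text. The invocations below are a sketch based on the mTEDx and CoVoST example docs; ${MTEDX_ROOT} and ${COVOST_ROOT} are placeholder paths, and the vocabulary settings are assumptions to adjust to your setup:
python examples/speech_to_text/prep_mtedx_data.py \
--data-root ${MTEDX_ROOT} --task asr \
--vocab-type unigram --vocab-size 1000
python examples/speech_to_text/prep_covost_data.py \
--data-root ${COVOST_ROOT} \
--vocab-type char --src-lang fr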
Training a multilingual ASR model with attention head selection
data_dir=<path to mtedx data>
train_subset="train_ar_ar_tedx,train_de_de_tedx,train_el_el_tedx,train_es_es_tedx,train_fr_fr_tedx,train_it_it_tedx,train_pt_pt_tedx,train_ru_ru_tedx"
valid_subset="valid_ar_ar_tedx,valid_de_de_tedx,valid_el_el_tedx,valid_es_es_tedx,valid_fr_fr_tedx,valid_it_it_tedx,valid_pt_pt_tedx,valid_ru_ru_tedx"
strategy=<subset or group>
fairseq-train ${data_dir} \
--user-dir examples/attention_head_selection/src \
--train-subset "${train_subset}" \
--valid-subset "${valid_subset}" \
--config-yaml 'config_asr.yaml' \
--arch 'head_selection_s2t_transformer_s' \
--task 'speech_to_text_head_selection' \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--lr-scheduler 'inverse_sqrt' --stop-min-lr -1.0 --warmup-updates 10000 \
--lr 5e-4 \
--clip-norm 10.0 \
--seed 1 \
--max-epoch 400 \
--max-tokens 32000 \
--ignore-prefix-size 1 \
--dropout 0.3 \
--optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
--skip-invalid-size-inputs-valid-test \
--encoder-attn-head-select \
--total-encoder-attention-heads 8 \
--decoder-self-attn-head-select \
--total-decoder-attention-heads 8 \
--attn-head-select-strategy ${strategy} \
--task-type lang
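Notes on the flags above: --encoder-attn-head-select and --decoder-self-attn-head-select enable head selection in the encoder and decoder self-attention, and --total-encoder-attention-heads / --total-decoder-attention-heads set the size of the head pool selected from. --attn-head-select-strategy takes one of the two strategies from the paper, subset or group. --task-type lang treats each language as a task; the multi-domain recipe below uses domain instead. --ignore-prefix-size 1 excludes the leading language-ID token of each target from the training loss, and the matching --prefix-size 1 at inference forces that token as the decoding prefix.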
Training a multi-domain ASR model with attention head selection
data_dir=<path to multi-domain data>
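# Subset suffixes map to the three corpora: _tedx = mTEDx, _cv = CoVoST, _ep = EuroParl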
train_subset="train_es_es_tedx,train_fr_fr_tedx,train_pt_pt_tedx,train_it_it_tedx,train_ru_ru_tedx,train_el_el_tedx,train_ar_ar_tedx,train_de_de_tedx,train_ar_ar_cv,train_de_de_cv,train_es_es_cv,train_fr_fr_cv,train_it_it_cv,train_pt_pt_cv,train_ru_ru_cv,train_de_de_ep,train_es_es_ep,train_fr_fr_ep,train_it_it_ep,train_pt_pt_ep"
valid_subset="dev_es_es_tedx,dev_fr_fr_tedx,dev_pt_pt_tedx,dev_it_it_tedx,dev_ru_ru_tedx,dev_el_el_tedx,dev_ar_ar_tedx,dev_de_de_tedx,dev_ar_ar_cv,dev_de_de_cv,dev_es_es_cv,dev_fr_fr_cv,dev_it_it_cv,dev_pt_pt_cv,dev_ru_ru_cv,dev_de_de_ep,dev_es_es_ep,dev_fr_fr_ep,dev_it_it_ep,dev_pt_pt_ep"
strategy=<subset or group>
fairseq-train ${data_dir} \
--user-dir examples/attention_head_selection/src \
--train-subset "${train_subset}" \
--valid-subset "${valid_subset}" \
--config-yaml 'config_asr.yaml' \
--arch 'head_selection_s2t_transformer_s' \
--task 'speech_to_text_head_selection' \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--lr-scheduler 'inverse_sqrt' --stop-min-lr -1.0 --warmup-updates 10000 \
--lr 5e-4 \
--clip-norm 10.0 \
--seed 1 \
--max-epoch 400 \
--max-tokens 32000 \
--ignore-prefix-size 1 \
--dropout 0.3 \
--optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
--skip-invalid-size-inputs-valid-test \
--encoder-attn-head-select \
--total-encoder-attention-heads 8 \
--decoder-self-attn-head-select \
--total-decoder-attention-heads 8 \
--attn-head-select-strategy ${strategy} \
--task-type domain
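Apart from the subsets, this recipe differs from the multilingual one only in --task-type domain, which treats each corpus (domain) as a task.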
Inference in multilingual setting
MODEL_DIR=<checkpoint directory>
data_dir=<path to mtedx data>
gen_subset=<data to test, e.g., test_ar_ar_tedx>
train_subset="train_ar_ar_tedx,train_de_de_tedx,train_el_el_tedx,train_es_es_tedx,train_fr_fr_tedx,train_it_it_tedx,train_pt_pt_tedx,train_ru_ru_tedx"
last_n=10
CHECKPOINT_FILENAME="avg_last_${last_n}_checkpoint.pt"
CHECKPOINT="_avg"
RESULTS="${MODEL_DIR}/ckpt${CHECKPOINT}"
if [ ! -d "${RESULTS}" ]; then
mkdir -p "${RESULTS}"
fi
python scripts/average_checkpoints.py \
--inputs ${MODEL_DIR} --num-epoch-checkpoints ${last_n} \
--output "${MODEL_DIR}/${CHECKPOINT_FILENAME}"
fairseq-generate ${data_dir} \
--user-dir examples/attention_head_selection/src \
--arch 'head_selection_s2t_transformer_s' \
--task 'speech_to_text_head_selection' \
--train-subset ${train_subset} \
--gen-subset ${gen_subset} \
--path "${MODEL_DIR}/${CHECKPOINT_FILENAME}" \
--config-yaml 'config_asr.yaml' \
--prefix-size 1 \
--max-tokens 40000 --beam 5 \
--skip-invalid-size-inputs-valid-test \
--results-path ${RESULTS} \
--scoring wer --wer-tokenizer 13a \
--wer-lowercase --wer-remove-punct --remove-bpe
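fairseq-generate writes the hypotheses and the corpus-level WER to generate-${gen_subset}.txt under the --results-path directory.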
Inference in multi-domain setting
MODEL_DIR=<checkpoint directory>
data_dir=<path to multi-domain data>
gen_subset=<data to test, e.g., test_pt_pt_cv>
train_subset="train_es_es_tedx,train_fr_fr_tedx,train_pt_pt_tedx,train_it_it_tedx,train_ru_ru_tedx,train_el_el_tedx,train_ar_ar_tedx,train_de_de_tedx,train_ar_ar_cv,train_de_de_cv,train_es_es_cv,train_fr_fr_cv,train_it_it_cv,train_pt_pt_cv,train_ru_ru_cv,train_de_de_ep,train_es_es_ep,train_fr_fr_ep,train_it_it_ep,train_pt_pt_ep"
last_n=10
CHECKPOINT_FILENAME="avg_last_${last_n}_checkpoint.pt"
CHECKPOINT="_avg"
RESULTS="${MODEL_DIR}/ckpt${CHECKPOINT}"
if [ ! -d "${RESULTS}" ]; then
mkdir -p "${RESULTS}"
fi
python scripts/average_checkpoints.py \
--inputs ${MODEL_DIR} --num-epoch-checkpoints ${last_n} \
--output "${MODEL_DIR}/${CHECKPOINT_FILENAME}"
fairseq-generate ${data_dir} \
--user-dir examples/attention_head_selection/src \
--arch 'head_selection_s2t_transformer_s' \
--task 'speech_to_text_head_selection' \
--train-subset ${train_subset} \
--gen-subset ${gen_subset} \
--path "${MODEL_DIR}/${CHECKPOINT_FILENAME}" \
--config-yaml 'config_asr.yaml' \
--prefix-size 1 \
--max-tokens 40000 --beam 5 \
--skip-invalid-size-inputs-valid-test \
--results-path ${RESULTS} \
--scoring wer --wer-tokenizer 13a \
--wer-lowercase --wer-remove-punct --remove-bpe
Citation
@article{gong2021pay,
title={Pay Better Attention to Attention: Head Selection in Multilingual and Multi-Domain Sequence Modeling},
author={Gong, Hongyu and Tang, Yun and Pino, Juan and Li, Xian},
journal={arXiv preprint arXiv:2106.10840},
year={2021}
}