#!/bin/bash
#####################################
#          Hubert ED model          #
#####################################
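# Stage-1 MT fine-tuning: initialize from an mBART checkpoint and train on
# binarized en-de translation data via fairseq's hydra entry point.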
[ $# -gt 4 ] && echo "Usage: $0 <world_size> <update_freq> [w2v_path] [mbart_path]" && exit 1
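# Positional arguments; any that are omitted fall back to the defaults below.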
world_size=$1
update_freq=$2
w2v_path=$3
mbart_path=$4
[ -z "$world_size" ] && world_size=8
[ -z "$update_freq" ] && update_freq=2
[ -z "$w2v_path" ] && w2v_path=${HOME}/dataset/iwslt_mustc/pretrain_ed_model_cfg.pt
[ -z "$mbart_path" ] && mbart_path="/mnt/default/v-junyiao/released_exsp/mbart50.pretrained/model.pt"
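# Example invocation (the script name and paths here are illustrative):
#   bash finetune_mt_ende.sh 8 2 /path/to/pretrain_ed_model_cfg.pt /path/to/mbart50/model.pt
# The default w2v_path/mbart_path above point at the authors' storage and
# will need to be replaced with local copies.

# Language-ID list matching the dictionary of the mBART-50 checkpoint.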
langs=ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN,af_ZA,az_AZ,bn_IN,fa_IR,he_IL,hr_HR,id_ID,ka_GE,km_KH,mk_MK,ml_IN,mn_MN,mr_IN,pl_PL,ps_AF,pt_XX,sv_SE,sw_KE,ta_IN,te_IN,th_TH,tl_XX,uk_UA,ur_PK,xh_ZA,gl_ES,sl_SI
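# Binarized fairseq data (.bin/.idx) for the en-de direction; adjust to your setup.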
DATA_DIR=/mnt/default/lozhou/speechdata/mt_data/en-de/com-filter-ende/bin-idx
exp_name=tune_mbart_com_filter_le-4
SAVE_DIR="${HOME}/data/iwslt/mt_stage1_en-de/$exp_name"
[ -d "$SAVE_DIR" ] || mkdir -p "$SAVE_DIR"
CODE_ROOT=${HOME}/code/SpeechT5/YiTrans
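# Launch training through fairseq's hydra CLI. Bare key=value pairs override
# fields already defined in the mt_translation config; keys prefixed with '+'
# add fields that the base config does not define.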
python $CODE_ROOT/fairseq/fairseq_cli/hydra_train.py \
--config-dir $CODE_ROOT/yitrans_iwslt22/config/finetune_mt \
--config-name mt_translation \
common.user_dir=$CODE_ROOT/yitrans_iwslt22 \
distributed_training.distributed_world_size=${world_size} \
optimization.update_freq=[$update_freq] \
\
+task.data=$DATA_DIR \
+task.source_lang="en_XX" +task.target_lang="de_DE" \
+task.langs=\"$langs\" \
+task.normalize=false \
+task.append_source_id=true \
\
+model.dropout=0.2 \
+model.attention_dropout=0.1 \
model.activation_dropout=0.1 \
model.decoder_layerdrop=0 \
model.layerdrop=0 \
model.freeze_finetune_updates=0 \
\
model.w2v_path=$w2v_path \
+model.no_pretrained_weights=true \
+model.load_pretrained_mbart_from=$mbart_path \
+model.share_enc_dec_embeddings=true \
+model.share_s2t_t2t_embeddings=false \
+model.use_rel_pos_enc=false \
\
dataset.train_subset="train" \
dataset.valid_subset="valid" \
dataset.num_workers=4 \
dataset.max_tokens=2000 \
\
optimization.max_epoch=50 \
optimization.clip_norm=5 \
optimization.max_update=200000 \
lr_scheduler.total_num_update=200000 \
\
checkpoint.save_interval=1 \
checkpoint.save_interval_updates=5000 \
checkpoint.keep_last_epochs=5 \
checkpoint.keep_best_checkpoints=5 \
\
common.seed=222 \
common.log_interval=100 \
common.log_format="json" \
\
checkpoint.best_checkpoint_metric="accuracy" \
checkpoint.maximize_best_checkpoint_metric=true \
common.tensorboard_logdir=$SAVE_DIR \
checkpoint.save_dir=$SAVE_DIR \
hydra.run.dir=$SAVE_DIR \
hydra.job.name=$exp_name
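
# Checkpoints, the hydra run directory, and tensorboard logs are all written
# to $SAVE_DIR.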