#!/usr/bin/env bash
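# fairseq GPT-style language-model training presets (assumed usage sketch):
# "$DATA" is expected to point to a fairseq-preprocess'ed (binarized) data
# directory and must be set by the caller before invoking these functions.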
train_common () {
fairseq-train "$DATA" \
--combine-val \
--train-subset train \
--num-workers 2 \
--validate-interval-updates 1000 \
--save-interval-updates 1000 \
--no-epoch-checkpoints \
--ddp-backend fully_sharded \
--memory-efficient-fp16 \
--fp16-init-scale 4 \
--checkpoint-activations \
--arch transformer_lm_gpt \
--activation-fn gelu \
--share-decoder-input-output-embed \
--task language_modeling \
--sample-break-mode none \
--tokens-per-sample 2048 \
--optimizer adam --adam-betas "(0.9, 0.98)" \
--adam-eps 1e-08 \
--clip-norm 0.0 \
--lr-scheduler polynomial_decay \
--warmup-updates 750 \
--dropout 0.1 \
--attention-dropout 0.1 \
--weight-decay 0.01 \
--batch-size 16 \
--update-freq 2 \
--required-batch-size-multiple 1 \
--total-num-update 572204 \
--max-update 572204 \
--seed 1 \
--log-format json --log-interval 1 \
--distributed-world-size 8 --distributed-port 13177 \
"$@"
}
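# Model-size presets: each wrapper layers its architecture hyperparameters
# (and, for the larger models, adjusted batch/update settings) on top of
# train_common; extra arguments are forwarded to fairseq-train via "$@".

# ~125M parameters (per the function name): 12 layers, 768 embed dim, 12 heads.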
train_125M () {
train_common --decoder-layers 12 \
--decoder-embed-dim 768 \
--decoder-ffn-embed-dim 3072 \
--decoder-attention-heads 12 "$@"
}
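# ~355M parameters: 24 layers, 1024 embed dim, 16 heads; dropout disabled
# for this configuration.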
train_355M () {
train_common --decoder-layers 24 \
--decoder-embed-dim 1024 \
--decoder-ffn-embed-dim 4096 \
--decoder-attention-heads 16 \
--dropout 0.0 \
--attention-dropout 0.0 \
"$@"
}
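# ~1.3B parameters: 24 layers, 2048 embed dim, 32 heads; per-GPU batch size is
# reduced to 4 with gradient accumulation (update-freq 16), and the update
# budget is half that of the smaller presets.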
train_1.3B () {
train_common --decoder-layers 24 \
--decoder-embed-dim 2048 \
--decoder-ffn-embed-dim 8192 \
--decoder-attention-heads 32 \
--batch-size 4 \
--update-freq 16 \
--total-num-update 286102 \
--max-update 286102 \
"$@"
}
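# ~2.7B parameters: 32 layers, 2560 embed dim, 32 heads; same batch/update
# settings as the 1.3B preset.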
train_2.7B () {
train_common --decoder-layers 32 \
--decoder-embed-dim 2560 \
--decoder-ffn-embed-dim 10240 \
--decoder-attention-heads 32 \
--batch-size 4 \
--update-freq 16 \
--total-num-update 286102 \
--max-update 286102 \
"$@"
}
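# Example invocation (a minimal sketch; the data path and save directory are
# placeholders, not part of this script):
#   DATA=data-bin/my_corpus
#   train_125M --save-dir checkpoints/125M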