#!/bin/bash
# Usage: bash run_training_sglNodes.sh 0 dummy MERT_RVQ-VAE_CQT_330M_multinodes_debug1node

# WORKER_RANK is the rank (index) of this distributed node worker.
# With two nodes of 4 GPUs each, pass 0 on the first node and 1 on the second;
# the starting global GPU rank (0 and 4) is derived below as ACTUAL_WORKER_RANK.
WORKER_RANK=${1:-'0'}
PLATFORM=${2:-'shef'}
YAML_NAME_WITHOUT_EXT=${3:-'MERT_RVQ-VAE_CQT_95M'}
TRAINING_SETTING=${4:-'MERT_RVQ-VAE_CQT'}
MASTER_PROC_ADD=${5:-'127.0.0.1'}
DIST_PORT=${6:-'39685'}

# echo $PATH
# export PATH=$PATH:./
echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}"

MAP_PROJ_DIR=$(pwd)
echo ${MAP_PROJ_DIR}

# Defaults; the per-config cases below may override them.
NNODS=1
NPROCES_PER_NODE=8
MAX_TOKENS=1000000   # batch size in tokens, set for an 80GB A100
NUM_WORKERS=0
run_command_prefix=' '

# Folders to load:
# 1. tsv manifests listing the training audio paths
#    (see the sketch of the expected data layout after the config selection below)
# DATA_DIR=${MAP_PROJ_DIR}/data/audio_tsv
DATA_DIR=${MAP_PROJ_DIR}/data/music4all_sh  # audio manifest
# 2. folder holding the pretraining config files
CONFIG_DIR=${MAP_PROJ_DIR}/mert_fairseq/config/pretrain
# 3. clustering labels for the training data
LABEL_ROOT_DIR=${MAP_PROJ_DIR}/data/encodec_labels/custom_audio_dataset

FAIRSEQ_PATH=${MAP_PROJ_DIR}/src/fairseq
# working folder for saving checkpoints
SAVE_DIR=${MAP_PROJ_DIR}/data/fairseq_savedir/

# set to 75 for the RVQ-VAE model
LABEL_RATE=75

case $YAML_NAME_WITHOUT_EXT in
    MERT_RVQ-VAE_CQT_95M)
        TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]'
        NNODS=1
        LABEL_RATE=75
        MAX_TOKENS=1800000
        ;;
    MERT_RVQ-VAE_CQT_95M_bestrq)
        TASK_LABELS_POSTFIX='["rq_0"]'
        NNODS=1
        LABEL_RATE=75
        MAX_TOKENS=1200000
        ;;
    MERT_RVQ-VAE_CQT_330M)
        TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]'
        NNODS=1
        LABEL_RATE=75
        NPROCES_PER_NODE=8
        MAX_TOKENS=720000
        ;;
    MERT_RVQ-VAE_CQT_330M_multinodes)
        TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]'
        NNODS=4
        LABEL_RATE=75
        NPROCES_PER_NODE=8
        MAX_TOKENS=600000
        ;;
    MERT_RVQ-VAE_CQT_330M_multinodes_debug2node)
        TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]'
        NNODS=2
        LABEL_RATE=75
        NPROCES_PER_NODE=8
        MAX_TOKENS=600000
        ;;
    MERT_RVQ-VAE_CQT_330M_multinodes_debug1node)
        TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]'
        NNODS=1
        LABEL_RATE=75
        NPROCES_PER_NODE=8
        MAX_TOKENS=600000
        ;;
    *)
        echo "Unknown running config: ${YAML_NAME_WITHOUT_EXT}"
        exit 1
        ;;
esac

echo "running ${YAML_NAME_WITHOUT_EXT} .."
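# For reference, a minimal sketch of the data layout the launch command below
# expects, assuming the standard fairseq/HuBERT-style tsv manifests used for
# MERT pretraining (the file names and sample counts here are only illustrative):
#
#   ${DATA_DIR}/train.tsv
#       /absolute/path/to/audio/root          <- first line: audio root directory
#       song_0001.wav<TAB>480000              <- then: relative path and length in samples
#       song_0002.wav<TAB>512000
#
#   ${LABEL_ROOT_DIR}/train.encodec_0 ... train.encodec_7
#       one label sequence per audio line, presumably at LABEL_RATE (75) labels per second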
mkdir -p ${SAVE_DIR}
echo "checkpoint save at: ${SAVE_DIR}"
cd ${SAVE_DIR}

DISTRIBUTED_WORLD_SIZE=`expr ${NNODS} \* ${NPROCES_PER_NODE}`
ACTUAL_WORKER_RANK=`expr ${WORKER_RANK} \* ${NPROCES_PER_NODE}`
echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}, actual rank ${ACTUAL_WORKER_RANK}"

DATE_SUFFIX=`date +"%Y-%m-%d_%H-%M"`
CKPT_SAVE_DIR="${SAVE_DIR}/ckpt_${TRAINING_SETTING}_multinodes${NNODS}_${DATE_SUFFIX}/${YAML_NAME_WITHOUT_EXT}"

# Launch one process per GPU on this node; node_rank is this node's index,
# and every node points at the same master address and port passed in above.
OMP_NUM_THREADS=6 ${run_command_prefix} \
python -u -m torch.distributed.launch --use_env \
    --nproc_per_node=${NPROCES_PER_NODE} --nnodes=${NNODS} --node_rank=${WORKER_RANK} \
    --master_addr=${MASTER_PROC_ADD} --master_port=${DIST_PORT} \
    ${FAIRSEQ_PATH}/fairseq_cli/hydra_train.py -m \
    --config-dir ${CONFIG_DIR} --config-name ${YAML_NAME_WITHOUT_EXT} \
    common.user_dir=${MAP_PROJ_DIR}/mert_fairseq \
    common.tensorboard_logdir=${MAP_PROJ_DIR}/logs/pretrain_tb_${TRAINING_SETTING}_${YAML_NAME_WITHOUT_EXT}_multinodes${NNODS} \
    task.data=${DATA_DIR} \
    task.label_dir=${LABEL_ROOT_DIR} \
    task.labels=${TASK_LABELS_POSTFIX} \
    dataset.num_workers=${NUM_WORKERS} \
    dataset.max_tokens=${MAX_TOKENS} \
    dataset.disable_validation=true \
    model.label_rate=${LABEL_RATE} \
    checkpoint.save_dir=${CKPT_SAVE_DIR} \
    checkpoint.restore_file="checkpoint_last.pt"
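# Illustrative invocations, following the positional arguments parsed at the top
# of this script (the host 10.0.0.1 is a placeholder for the rank-0 node's address):
#
#   single node, 8 GPUs:
#     bash run_training_sglNodes.sh 0 dummy MERT_RVQ-VAE_CQT_95M MERT_RVQ-VAE_CQT
#
#   two nodes, 8 GPUs each (run one command per node, both pointing at node 0):
#     node 0: bash run_training_sglNodes.sh 0 dummy MERT_RVQ-VAE_CQT_330M_multinodes_debug2node MERT_RVQ-VAE_CQT 10.0.0.1 39685
#     node 1: bash run_training_sglNodes.sh 1 dummy MERT_RVQ-VAE_CQT_330M_multinodes_debug2node MERT_RVQ-VAE_CQT 10.0.0.1 39685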