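#!/bin/bash
# Distributed pretraining launcher (fairseq + hydra).
# Positional arguments, all optional:
#   1: node (worker) rank               default: $INDEX from the environment
#   2: platform tag                     default: 'shef' (not used further in this script)
#   3: hydra config name (no .yaml)     default: 'MERT_RVQ-VAE_CQT_95M'
#   4: training setting tag for paths   default: 'MERT_RVQ-VAE_CQT'
#   5: master node address              default: $CHIEF_IP from the environment
#   6: distributed init port            default: '25520'
# Example invocation (script name is illustrative; INDEX and CHIEF_IP are expected
# to be exported by the cluster launcher):
#   bash run_training.sh 0 shef EAT_pretraining_music_multinodes MERT_RVQ-VAE_CQT 10.0.0.1 25520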
WORKER_RANK=${1:-$INDEX}
PLATFORM=${2:-'shef'}
YAML_NAME_WITHOUT_EXT=${3:-'MERT_RVQ-VAE_CQT_95M'}
TRAINING_SETTING=${4:-'MERT_RVQ-VAE_CQT'}
MASTER_PROC_ADD=${5:-$CHIEF_IP}
DIST_PORT=${6:-'25520'}
# echo $PATH
# export PATH=$PATH:./
echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}"
MAP_PROJ_DIR=$(pwd)
echo $MAP_PROJ_DIR
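# Default distributed/training settings; NNODS and BATCH_SIZE may be overridden per config below.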
NNODS=1
BATCH_SIZE=12
NUM_WORKERS=6
run_command_prefix=' '
# Loading folders
# 1. tsv files for audio paths
# DATA_DIR=${MAP_PROJ_DIR}/data/audio_tsv
DATA_DIR=${MAP_PROJ_DIR}/data/music4all_sh #audio_manifest
# 2. working folder for saving checkpoints and loading config files
CONFIG_DIR=${MAP_PROJ_DIR}/mert_fairseq/config/pretrain
# 3. clustering labels for training data
LABEL_ROOT_DIR=${MAP_PROJ_DIR}/data/encodec_labels/custom_audio_dataset
FAIRSEQ_PATH=${MAP_PROJ_DIR}/src/fairseq
SAVE_DIR=${MAP_PROJ_DIR}/data/fairseq_savedir/
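# Per-config overrides. Only EAT_pretraining_music_multinodes is handled below; add a branch
# (NNODS, NPROCES_PER_NODE, LABEL_RATE, BATCH_SIZE) before launching with any other config,
# including the default MERT_RVQ-VAE_CQT_95M.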
case $YAML_NAME_WITHOUT_EXT in
EAT_pretraining_music_multinodes)
NNODS=4
NPROCES_PER_NODE=8
LABEL_RATE=25
BATCH_SIZE=12
;;
*)
echo "Unknown running config: ${$YAML_NAME_WITHOUT_EXT}"
exit 1
;;
esac
echo "running ${YAML_NAME_WITHOUT_EXT} ..."
mkdir -p ${SAVE_DIR}
echo "checkpoint save at: ${SAVE_DIR}"
cd ${SAVE_DIR} || exit 1
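# Total world size = nodes * processes per node; this node's first global rank is
# its worker rank * processes per node.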
DISTRIBUTED_WORLD_SIZE=$(( NNODS * NPROCES_PER_NODE ))
ACTUAL_WORKER_RANK=$(( WORKER_RANK * NPROCES_PER_NODE ))
echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}, actual rank ${ACTUAL_WORKER_RANK}"
DATE_SUFFIX=$(date +"%Y-%m-%d_%H-%M")
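# Launch fairseq hydra training. OMP_NUM_THREADS caps CPU threads per process;
# run_command_prefix (empty by default) is prepended to the command, e.g. for a wrapper launcher.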
OMP_NUM_THREADS=6 ${run_command_prefix} \
python -u ${FAIRSEQ_PATH}/fairseq_cli/hydra_train.py \
--config-dir ${CONFIG_DIR} --config-name ${YAML_NAME_WITHOUT_EXT} \
common.user_dir=${MAP_PROJ_DIR}/mert_fairseq \
common.tensorboard_logdir=${MAP_PROJ_DIR}/logs/pretrain_tb_${TRAINING_SETTING}_${YAML_NAME_WITHOUT_EXT}_multinodes${NNODS} \
checkpoint.save_dir=${SAVE_DIR}/ckpt_${TRAINING_SETTING}_multinodes${NNODS}_${DATE_SUFFIX}/${YAML_NAME_WITHOUT_EXT} \
distributed_training.distributed_rank=${ACTUAL_WORKER_RANK} \
distributed_training.distributed_world_size=${DISTRIBUTED_WORLD_SIZE} \
distributed_training.distributed_num_procs=${DISTRIBUTED_WORLD_SIZE} \
distributed_training.nprocs_per_node=${NPROCES_PER_NODE} \
distributed_training.distributed_init_method="tcp://${MASTER_PROC_ADD}:${DIST_PORT}" \
task.data=${DATA_DIR} \
dataset.num_workers=${NUM_WORKERS} \
dataset.batch_size=${BATCH_SIZE} \
dataset.disable_validation=true

# Extra dependencies, if missing: pip install h5py timm -i https://mirrors.tencent.com/pypi/simple/