File size: 4,220 Bytes
258fd02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# Launcher for fairseq/hydra pre-training runs (see the launch command at the
# bottom of this script).
#
# Usage:
#   bash run_training_sglNodes.sh [WORKER_RANK] [PLATFORM] [YAML_NAME_WITHOUT_EXT] \
#        [TRAINING_SETTING] [MASTER_PROC_ADD] [DIST_PORT]
# Example:
#   bash run_training_sglNodes.sh 0 dummy MERT_RVQ-VAE_CQT_330M_multinodes_debug1node
#
# Positional arguments (all optional):
#   $1 WORKER_RANK            index of this node among the participating nodes
#                             (0 for the first node, 1 for the second, ...).
#                             The launch section multiplies it by the per-node
#                             process count to obtain the "actual rank".
#                             NOTE(review): an earlier comment here asked for the
#                             starting GPU index (0, 4, ...); that does not match
#                             the multiplication done further down.
#   $2 PLATFORM               platform tag; parsed here but not referenced by
#                             the rest of this script.
#   $3 YAML_NAME_WITHOUT_EXT  hydra config name; also selects the per-config
#                             overrides in the case statement below.
#   $4 TRAINING_SETTING       tag used in tensorboard/checkpoint directory names.
#   $5 MASTER_PROC_ADD        address of the master process.
#   $6 DIST_PORT              port of the master process.
WORKER_RANK=${1:-0}
PLATFORM=${2:-shef}
YAML_NAME_WITHOUT_EXT=${3:-MERT_RVQ-VAE_CQT_95M}
TRAINING_SETTING=${4:-MERT_RVQ-VAE_CQT}
MASTER_PROC_ADD=${5:-127.0.0.1}
DIST_PORT=${6:-39685}
# echo $PATH
# export PATH=$PATH:./
echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}"

# The project root is whatever directory the script is started from.
MAP_PROJ_DIR=$(pwd)
echo "${MAP_PROJ_DIR}"

# Defaults; the per-config case statement below may override them.
NNODS=1
MAX_TOKENS=1000000 # set for 80GB A100 batchsize
NUM_WOKERS=0

# Optional wrapper placed in front of the python command; blank means none.
run_command_prefix=' '

# Loading folders
# 1. tsv files for audio paths
# DATA_DIR=${MAP_PROJ_DIR}/data/audio_tsv
DATA_DIR="${MAP_PROJ_DIR}/data/music4all_sh" #audio_manifest
# 2. working folder for saving checkpoints and loading config files
#    MAP_PROJ_DIR is already absolute (it comes from pwd), so no extra leading
#    slash is added; a "//..." prefix is implementation-defined under POSIX.
CONFIG_DIR="${MAP_PROJ_DIR}/mert_fairseq/config/pretrain"
# 3. clustering labels for training data
LABEL_ROOT_DIR="${MAP_PROJ_DIR}/data/encodec_labels/custom_audio_dataset"

FAIRSEQ_PATH="${MAP_PROJ_DIR}/src/fairseq"
SAVE_DIR="${MAP_PROJ_DIR}/data/fairseq_savedir/"

# set 75 for the RVQ-VAE model
LABEL_RATE=75

# Per-config overrides, selected by the hydra config name.
# Every arm sets the same group of variables, consumed by the launch section:
#   TASK_LABELS_POSTFIX  list literal passed as the task.labels override
#   NNODS                number of nodes (--nnodes)
#   NPROCES_PER_NODE     processes per node, used in the rank arithmetic
#   LABEL_RATE           passed as model.label_rate
#   MAX_TOKENS           passed as dataset.max_tokens
# An unrecognised config name aborts the script with exit status 1.

# Label list shared by all configs that train on the 8 encodec codebooks.
ENCODEC_LABELS_8='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]'

case "${YAML_NAME_WITHOUT_EXT}" in
  MERT_RVQ-VAE_CQT_95M)
    TASK_LABELS_POSTFIX=${ENCODEC_LABELS_8}
    NNODS=1
    LABEL_RATE=75
    # Previously unset for this config, which made the later rank arithmetic
    # fail; 8 matches the per-node process count used by the launch command.
    NPROCES_PER_NODE=8
    MAX_TOKENS=1800000
    ;;
  MERT_RVQ-VAE_CQT_95M_bestrq)
    TASK_LABELS_POSTFIX='["rq_0"]'
    NNODS=1
    LABEL_RATE=75
    # See the note on MERT_RVQ-VAE_CQT_95M above.
    NPROCES_PER_NODE=8
    MAX_TOKENS=1200000
    ;;
  MERT_RVQ-VAE_CQT_330M)
    TASK_LABELS_POSTFIX=${ENCODEC_LABELS_8}
    NNODS=1
    LABEL_RATE=75
    NPROCES_PER_NODE=8
    MAX_TOKENS=720000
    ;;
  MERT_RVQ-VAE_CQT_330M_multinodes)
    TASK_LABELS_POSTFIX=${ENCODEC_LABELS_8}
    NNODS=4
    LABEL_RATE=75
    NPROCES_PER_NODE=8
    MAX_TOKENS=600000
    ;;
  MERT_RVQ-VAE_CQT_330M_multinodes_debug2node)
    TASK_LABELS_POSTFIX=${ENCODEC_LABELS_8}
    NNODS=2
    LABEL_RATE=75
    NPROCES_PER_NODE=8
    MAX_TOKENS=600000
    ;;
  MERT_RVQ-VAE_CQT_330M_multinodes_debug1node)
    TASK_LABELS_POSTFIX=${ENCODEC_LABELS_8}
    NNODS=1
    LABEL_RATE=75
    NPROCES_PER_NODE=8
    MAX_TOKENS=600000
    ;;
  *)
    # "${$VAR}" is a bad substitution in bash; expand the variable normally so
    # the offending name is actually reported. Diagnostics go to stderr.
    echo "Unknown running config: ${YAML_NAME_WITHOUT_EXT}" >&2
    exit 1
    ;;
esac

 # ---------------------------------------------------------------------------
# Launch section: prepare the save directory, derive rank information and
# start fairseq's hydra_train.py through torch.distributed.launch.
#
# Reads (set earlier in this script): YAML_NAME_WITHOUT_EXT, TRAINING_SETTING,
#   WORKER_RANK, MASTER_PROC_ADD, DIST_PORT, NNODS, NPROCES_PER_NODE,
#   MAP_PROJ_DIR, SAVE_DIR, CONFIG_DIR, DATA_DIR, LABEL_ROOT_DIR, FAIRSEQ_PATH,
#   TASK_LABELS_POSTFIX, NUM_WOKERS, MAX_TOKENS, LABEL_RATE, run_command_prefix.
# Optional environment overrides: INDEX (node rank), CHIEF_IP (master
#   address), LABEL_DIR (label directory).
# ---------------------------------------------------------------------------
echo "running ${YAML_NAME_WITHOUT_EXT} .."

mkdir -p "${SAVE_DIR}" || { echo "cannot create ${SAVE_DIR}" >&2; exit 1; }
echo "checkpoint save at: ${SAVE_DIR}"
# Abort rather than launch training from an unexpected working directory.
cd "${SAVE_DIR}" || { echo "cannot cd to ${SAVE_DIR}" >&2; exit 1; }

# Not every config sets the per-node process count; fall back to 8, the value
# the launch command used to hard-code.
: "${NPROCES_PER_NODE:=8}"

# DISTRIBUTED_WORLD_SIZE is computed for reference only; it is not passed to
# the launcher below.
DISTRIBUTED_WORLD_SIZE=$(( NNODS * NPROCES_PER_NODE ))
ACTUAL_WORKER_RANK=$(( WORKER_RANK * NPROCES_PER_NODE ))
echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}, actual rank ${ACTUAL_WORKER_RANK}"

DATE_SUFFIX=$(date +"%Y-%m-%d_%H-%M")
# SAVE_DIR may end with a slash; strip it to avoid "//" in the path.
# NOTE(review): the directory name embeds the launch time, so a fresh launch
# gets a fresh directory and will not see an earlier checkpoint_last.pt.
CKPT_SAVE_DIR="${SAVE_DIR%/}/ckpt_${TRAINING_SETTING}_multinodes${NNODS}_${DATE_SUFFIX}/${YAML_NAME_WITHOUT_EXT}"

# INDEX / CHIEF_IP / LABEL_DIR are never assigned in this script. Honour them
# when the environment provides them, otherwise fall back to the values this
# script does define, instead of passing empty options to the launcher.
node_rank=${INDEX:-${WORKER_RANK}}
master_addr=${CHIEF_IP:-${MASTER_PROC_ADD}}
# NOTE(review): LABEL_ROOT_DIR is assumed to be the intended label directory;
# confirm against the dataset layout.
label_dir=${LABEL_DIR:-${LABEL_ROOT_DIR}}

# The master port now follows DIST_PORT (6th argument), matching what the log
# line above reports; it used to be fixed at 25521.
# Hydra overrides are quoted so that values such as the task.labels list
# (which contains [ ] and ") reach python verbatim, without globbing.
# shellcheck disable=SC2086 # run_command_prefix is intentionally word-split
OMP_NUM_THREADS=6 ${run_command_prefix} \
  python -u -m torch.distributed.launch --use_env \
  --nproc_per_node="${NPROCES_PER_NODE}" --nnodes="${NNODS}" --node_rank="${node_rank}" \
  --master_addr="${master_addr}" --master_port="${DIST_PORT}" \
  "${FAIRSEQ_PATH}/fairseq_cli/hydra_train.py" -m \
  --config-dir "${CONFIG_DIR}" --config-name "${YAML_NAME_WITHOUT_EXT}" \
  "common.user_dir=${MAP_PROJ_DIR}/mert_fairseq" \
  "common.tensorboard_logdir=${MAP_PROJ_DIR}/logs/pretrain_tb_${TRAINING_SETTING}_${YAML_NAME_WITHOUT_EXT}_multinodes${NNODS}" \
  "task.data=${DATA_DIR}" \
  "task.label_dir=${label_dir}" \
  "task.labels=${TASK_LABELS_POSTFIX}" \
  "dataset.num_workers=${NUM_WOKERS}" \
  "dataset.max_tokens=${MAX_TOKENS}" \
  "dataset.disable_validation=true" \
  "model.label_rate=${LABEL_RATE}" \
  "checkpoint.save_dir=${CKPT_SAVE_DIR}" \
  "checkpoint.restore_file=checkpoint_last.pt"