diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..39e7ae7fd0fdd2d8e5bc370225bb1f3eb8648ac8 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -32,4 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..b174072d5944c7044abfdde417bb3ec9eb33521e
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,18 @@
+__pycache__
+flagged
+result
+
+# Developing mode
+_*.sh
+_*.json
+*.lst
+yard*
+*.out
+evaluation/evalset_selection
+mfa
+egs/svc/*wavmark
+egs/svc/custom
+egs/svc/*/dev*
+egs/svc/dev_exp_config.json
+bins/svc/demo*
+bins/svc/preprocess_custom.py
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..49d1a413ec53285bb9a4ec813702c8fab99d9d1d
--- /dev/null
+++ b/app.py
@@ -0,0 +1,78 @@
+import gradio as gr
+
+
+SUPPORTED_TARGET_SINGERS = {
+    "Adele": "vocalist_l1_Adele",
+    "Beyonce": "vocalist_l1_Beyonce",
+    "Bruno Mars": "vocalist_l1_BrunoMars",
+    "John Mayer": "vocalist_l1_JohnMayer",
+    "Michael Jackson": "vocalist_l1_MichaelJackson",
+    "Taylor Swift": "vocalist_l1_TaylorSwift",
+    "Jacky Cheung 张学友": "vocalist_l1_张学友",
+    "Jian Li 李健": "vocalist_l1_李健",
+    "Feng Wang 汪峰": "vocalist_l1_汪峰",
+    "Faye Wong 王菲": "vocalist_l1_王菲",
+    "Yijie Shi 石倚洁": "vocalist_l1_石倚洁",
+    "Tsai Chin 蔡琴": "vocalist_l1_蔡琴",
+    "Ying Na 那英": "vocalist_l1_那英",
+    "Eason Chan 陈奕迅": "vocalist_l1_陈奕迅",
+    "David Tao 陶喆": "vocalist_l1_陶喆",
+}
+
+
+def svc_inference(
+    source_audio,
+    target_singer,
+    diffusion_steps=1000,
+    key_shift_mode="auto",
+    key_shift_num=0,
+):
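+    # Placeholder: the conversion logic is not implemented in this demo stub.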
+    pass
+
+
+demo_inputs = [
+    gr.Audio(
+        sources=["upload", "microphone"],
+        label="Upload (or record) a song you want to listen",
+    ),
+    gr.Radio(
+        choices=list(SUPPORTED_TARGET_SINGERS.keys()),
+        label="Target Singer",
+        value="Jian Li 李健",
+    ),
+    gr.Slider(
+        1,
+        1000,
+        value=1000,
+        step=1,
+        label="Diffusion Inference Steps",
+        info="As the step number increases, the synthesis quality will be better while the inference speed will be lower",
+    ),
+    gr.Radio(
+        choices=["Auto Shift", "Key Shift"],
+        value="Auto Shift",
+        label="Pitch Shift Control",
+        info='If you want to control the specific pitch shift value, you need to choose "Key Shift"',
+    ),
+    gr.Slider(
+        -6,
+        6,
+        value=0,
+        step=1,
+        label="Key Shift Values",
+        info='How many semitones you want to transpose. This parameter will work only if you choose "Key Shift"',
+    ),
+]
+
+demo_outputs = gr.Audio(label="")
+
+
+demo = gr.Interface(
+    fn=svc_inference,
+    inputs=demo_inputs,
+    outputs=demo_outputs,
+    title="Amphion Singing Voice Conversion",
+)
+
+if __name__ == "__main__":
+    demo.launch(show_api=False)
diff --git a/ckpts/svc/vocalist_l1_contentvec+whisper/args.json b/ckpts/svc/vocalist_l1_contentvec+whisper/args.json
new file mode 100755
index 0000000000000000000000000000000000000000..836d5e81420921d4ec096d3445c0ff5964e13b73
--- /dev/null
+++ b/ckpts/svc/vocalist_l1_contentvec+whisper/args.json
@@ -0,0 +1,256 @@
+{
+    "base_config": "config/diffusion.json",
+    "dataset": [
+        "vocalist_l1",
+    ],
+    "exp_name": "vocalist_l1_contentvec+whisper",
+    "inference": {
+        "diffusion": {
+            "scheduler": "pndm",
+            "scheduler_settings": {
+                "num_inference_timesteps": 1000,
+            },
+        },
+    },
+    "model": {
+        "condition_encoder": {
+            "content_encoder_dim": 384,
+            "contentvec_dim": 256,
+            "f0_max": 1100,
+            "f0_min": 50,
+            "input_loudness_dim": 1,
+            "input_melody_dim": 1,
+            "merge_mode": "add",
+            "mert_dim": 256,
+            "n_bins_loudness": 256,
+            "n_bins_melody": 256,
+            "output_content_dim": 384,
+            "output_loudness_dim": 384,
+            "output_melody_dim": 384,
+            "output_singer_dim": 384,
+            "pitch_max": 1100,
+            "pitch_min": 50,
+            "singer_table_size": 512,
+            "use_conformer_for_content_features": false,
+            "use_contentvec": true,
+            "use_log_f0": true,
+            "use_log_loudness": true,
+            "use_mert": false,
+            "use_singer_encoder": true,
+            "use_spkid": true,
+            "use_wenet": false,
+            "use_whisper": true,
+            "wenet_dim": 512,
+            "whisper_dim": 1024,
+        },
+        "diffusion": {
+            "bidilconv": {
+                "base_channel": 384,
+                "conditioner_size": 384,
+                "conv_kernel_size": 3,
+                "dilation_cycle_length": 4,
+                "n_res_block": 20,
+            },
+            "model_type": "bidilconv",
+            "scheduler": "ddpm",
+            "scheduler_settings": {
+                "beta_end": 0.02,
+                "beta_schedule": "linear",
+                "beta_start": 0.0001,
+                "num_train_timesteps": 1000,
+            },
+            "step_encoder": {
+                "activation": "SiLU",
+                "dim_hidden_layer": 512,
+                "dim_raw_embedding": 128,
+                "max_period": 10000,
+                "num_layer": 2,
+            },
+            "unet2d": {
+                "down_block_types": [
+                    "CrossAttnDownBlock2D",
+                    "CrossAttnDownBlock2D",
+                    "CrossAttnDownBlock2D",
+                    "DownBlock2D",
+                ],
+                "in_channels": 1,
+                "mid_block_type": "UNetMidBlock2DCrossAttn",
+                "only_cross_attention": false,
+                "out_channels": 1,
+                "up_block_types": [
+                    "UpBlock2D",
+                    "CrossAttnUpBlock2D",
+                    "CrossAttnUpBlock2D",
+                    "CrossAttnUpBlock2D",
+                ],
+            },
+        },
+    },
+    "model_type": "DiffWaveNetSVC",
+    "preprocess": {
+        "audio_dir": "audios",
+        "bits": 8,
+        "content_feature_batch_size": 16,
+        "contentvec_batch_size": 1,
+        "contentvec_dir": "contentvec",
+        "contentvec_file": "pretrained/contentvec/checkpoint_best_legacy_500.pt",
+        "contentvec_frameshift": 0.02,
+        "contentvec_sample_rate": 16000,
+        "dur_dir": "durs",
+        "duration_dir": "duration",
+        "emo2id": "emo2id.json",
+        "energy_dir": "energys",
+        "extract_audio": false,
+        "extract_contentvec_feature": true,
+        "extract_energy": true,
+        "extract_label": false,
+        "extract_mcep": false,
+        "extract_mel": true,
+        "extract_mert_feature": false,
+        "extract_pitch": true,
+        "extract_uv": true,
+        "extract_wenet_feature": false,
+        "extract_whisper_feature": true,
+        "f0_max": 1100,
+        "f0_min": 50,
+        "file_lst": "file.lst",
+        "fmax": 12000,
+        "fmin": 0,
+        "hop_size": 256,
+        "is_label": true,
+        "is_mu_law": true,
+        "lab_dir": "labs",
+        "label_dir": "labels",
+        "mcep_dir": "mcep",
+        "mel_dir": "mels",
+        "mel_min_max_norm": true,
+        "mel_min_max_stats_dir": "mel_min_max_stats",
+        "mert_dir": "mert",
+        "mert_feature_layer": -1,
+        "mert_frameshit": 0.01333,
+        "mert_hop_size": 320,
+        "mert_model": "m-a-p/MERT-v1-330M",
+        "min_level_db": -115,
+        "mu_law_norm": false,
+        "n_fft": 1024,
+        "n_mel": 100,
+        "num_silent_frames": 8,
+        "num_workers": 8,
+        "phone_seq_file": "phone_seq_file",
+        "pin_memory": true,
+        "pitch_bin": 256,
+        "pitch_dir": "pitches",
+        "pitch_extractor": "parselmouth",
+        "pitch_max": 1100.0,
+        "pitch_min": 50.0,
+        "processed_dir": "ckpts/svc/vocalist_l1_contentvec+whisper/data",
+        "ref_level_db": 20,
+        "sample_rate": 24000,
+        "spk2id": "singers.json",
+        "train_file": "train.json",
+        "trim_fft_size": 512,
+        "trim_hop_size": 128,
+        "trim_silence": false,
+        "trim_top_db": 30,
+        "trimmed_wav_dir": "trimmed_wavs",
+        "use_audio": false,
+        "use_contentvec": true,
+        "use_dur": false,
+        "use_emoid": false,
+        "use_frame_duration": false,
+        "use_frame_energy": true,
+        "use_frame_pitch": true,
+        "use_lab": false,
+        "use_label": false,
+        "use_log_scale_energy": false,
+        "use_log_scale_pitch": false,
+        "use_mel": true,
+        "use_mert": false,
+        "use_min_max_norm_mel": true,
+        "use_one_hot": false,
+        "use_phn_seq": false,
+        "use_phone_duration": false,
+        "use_phone_energy": false,
+        "use_phone_pitch": false,
+        "use_spkid": true,
+        "use_uv": true,
+        "use_wav": false,
+        "use_wenet": false,
+        "use_whisper": true,
+        "utt2emo": "utt2emo",
+        "utt2spk": "utt2singer",
+        "uv_dir": "uvs",
+        "valid_file": "test.json",
+        "wav_dir": "wavs",
+        "wenet_batch_size": 1,
+        "wenet_config": "pretrained/wenet/20220506_u2pp_conformer_exp/train.yaml",
+        "wenet_dir": "wenet",
+        "wenet_downsample_rate": 4,
+        "wenet_frameshift": 0.01,
+        "wenet_model_path": "pretrained/wenet/20220506_u2pp_conformer_exp/final.pt",
+        "wenet_sample_rate": 16000,
+        "whisper_batch_size": 30,
+        "whisper_dir": "whisper",
+        "whisper_downsample_rate": 2,
+        "whisper_frameshift": 0.01,
+        "whisper_model": "medium",
+        "whisper_model_path": "pretrained/whisper/medium.pt",
+        "win_size": 1024,
+    },
+    "supported_model_type": [
+        "Fastspeech2",
+        "DiffSVC",
+        "Transformer",
+        "EDM",
+        "CD",
+    ],
+    "train": {
+        "adamw": {
+            "lr": 0.0004,
+        },
+        "batch_size": 32,
+        "dataloader": {
+            "num_worker": 8,
+            "pin_memory": true,
+        },
+        "ddp": true,
+        "epochs": 50000,
+        "gradient_accumulation_step": 1,
+        "keep_checkpoint_max": 5,
+        "keep_last": [
+            5,
+            -1,
+        ],
+        "max_epoch": -1,
+        "max_steps": 1000000,
+        "multi_speaker_training": false,
+        "optimizer": "AdamW",
+        "random_seed": 10086,
+        "reducelronplateau": {
+            "factor": 0.8,
+            "min_lr": 0.0001,
+            "patience": 10,
+        },
+        "run_eval": [
+            false,
+            true,
+        ],
+        "sampler": {
+            "drop_last": true,
+            "holistic_shuffle": false,
+        },
+        "save_checkpoint_stride": [
+            3,
+            10,
+        ],
+        "save_checkpoints_steps": 10000,
+        "save_summary_steps": 500,
+        "scheduler": "ReduceLROnPlateau",
+        "total_training_steps": 50000,
+        "tracker": [
+            "tensorboard",
+        ],
+        "valid_interval": 10000,
+    },
+    "use_custom_dataset": true,
+}
\ No newline at end of file
diff --git a/ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/optimizer.bin b/ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/optimizer.bin
new file mode 100755
index 0000000000000000000000000000000000000000..6b5604e3770d0c8de693930332f32ef2e0b16fe0
--- /dev/null
+++ b/ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/optimizer.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:836af10b834c7aec9209eb19ce43559e6ef1e3a59bd6468e90cadbc9a18749ef
+size 249512389
diff --git a/ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/pytorch_model.bin b/ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/pytorch_model.bin
new file mode 100755
index 0000000000000000000000000000000000000000..a11911352aa92208e5246cea59e52b3de1f0d704
--- /dev/null
+++ b/ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/pytorch_model.bin
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d54eed12bef331095fc367f196d07c5061d5cb72dd6fe0e1e4453b997bf1d68d
+size 124755137
diff --git a/ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/random_states_0.pkl b/ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/random_states_0.pkl
new file mode 100755
index 0000000000000000000000000000000000000000..be96aac1818d9f8fc4dedfcc530ee1e8ea9f78f7
--- /dev/null
+++ b/ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/random_states_0.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6798ddffadcd7d5405a77e667c674c474e4fef0cba817fdd300c7c985c1e82fe
+size 14599
diff --git a/ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/singers.json b/ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/singers.json
new file mode 100755
index 0000000000000000000000000000000000000000..cd56250fa8be439b4ac6d2afe15fed300a69c973
--- /dev/null
+++ b/ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/singers.json
@@ -0,0 +1,17 @@
+{
+    "vocalist_l1_Adele": 0,
+    "vocalist_l1_Beyonce": 1,
+    "vocalist_l1_BrunoMars": 2,
+    "vocalist_l1_JohnMayer": 3,
+    "vocalist_l1_MichaelJackson": 4,
+    "vocalist_l1_TaylorSwift": 5,
+    "vocalist_l1_张学友": 6,
+    "vocalist_l1_李健": 7,
+    "vocalist_l1_汪峰": 8,
+    "vocalist_l1_王菲": 9,
+    "vocalist_l1_石倚洁": 10,
+    "vocalist_l1_蔡琴": 11,
+    "vocalist_l1_那英": 12,
+    "vocalist_l1_陈奕迅": 13,
+    "vocalist_l1_陶喆": 14
+}
\ No newline at end of file
diff --git a/ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/mel_min_max_stats/mel_max.npy b/ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/mel_min_max_stats/mel_max.npy
new file mode 100755
index 0000000000000000000000000000000000000000..f74cf6fe3127f22eb07c931f1f9ece4c07ed00ed
--- /dev/null
+++ b/ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/mel_min_max_stats/mel_max.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:04131849378aa4f525a701909f743c303f8d56571682572b888046ead9f3e2ab
+size 528
diff --git a/ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/mel_min_max_stats/mel_min.npy b/ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/mel_min_max_stats/mel_min.npy
new file mode 100755
index 0000000000000000000000000000000000000000..20326231f2c3925360e7b102eb98e22bb9a238f5
--- /dev/null
+++ b/ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/mel_min_max_stats/mel_min.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ef4895ebef0e9949a6e623315bdc8a68490ba95d2f81b2be9f5146f904203016
+size 528
diff --git a/ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/meta_info.json b/ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/meta_info.json
new file mode 100755
index 0000000000000000000000000000000000000000..5d4cf31177f9dd2bac9538df9eb649ac522fcd69
--- /dev/null
+++ b/ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/meta_info.json
@@ -0,0 +1,31 @@
+{
+    "dataset": "vocalist_l1",
+    "train": {
+        "size": 3180,
+        "hours": 6.1643
+    },
+    "test": {
+        "size": 114,
+        "hours": 0.2224
+    },
+    "singers": {
+        "size": 15,
+        "training_minutes": {
+            "vocalist_l1_陶喆": 45.51,
+            "vocalist_l1_陈奕迅": 43.36,
+            "vocalist_l1_汪峰": 41.08,
+            "vocalist_l1_李健": 38.9,
+            "vocalist_l1_JohnMayer": 30.83,
+            "vocalist_l1_Adele": 27.23,
+            "vocalist_l1_那英": 27.02,
+            "vocalist_l1_石倚洁": 24.93,
+            "vocalist_l1_张学友": 18.31,
+            "vocalist_l1_TaylorSwift": 18.31,
+            "vocalist_l1_王菲": 16.78,
+            "vocalist_l1_MichaelJackson": 15.13,
+            "vocalist_l1_蔡琴": 10.12,
+            "vocalist_l1_BrunoMars": 6.29,
+            "vocalist_l1_Beyonce": 6.06
+        }
+    }
+}
\ No newline at end of file
diff --git a/ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/pitches/statistics.json b/ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/pitches/statistics.json
new file mode 100755
index 0000000000000000000000000000000000000000..472551c609057a7eb7ba05b1eceba3ebc0461ed4
--- /dev/null
+++ b/ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/pitches/statistics.json
@@ -0,0 +1,242 @@
+{
+    "vocalist_l1_Adele": {
+        "voiced_positions": {
+            "mean": 336.5038018286193,
+            "std": 100.2148774476881,
+            "median": 332.98363792619296,
+            "min": 59.99838412340723,
+            "max": 1099.849325287837
+        },
+        "total_positions": {
+            "mean": 231.79366581704338,
+            "std": 176.6042850107386,
+            "median": 273.2844263775394,
+            "min": 0.0,
+            "max": 1099.849325287837
+        }
+    },
+    "vocalist_l1_Beyonce": {
+        "voiced_positions": {
+            "mean": 357.5678927636881,
+            "std": 130.1132620135807,
+            "median": 318.2981879228934,
+            "min": 70.29719673914867,
+            "max": 1050.354470112099
+        },
+        "total_positions": {
+            "mean": 267.5248026267327,
+            "std": 191.71600807951046,
+            "median": 261.91981963774066,
+            "min": 0.0,
+            "max": 1050.354470112099
+        }
+    },
+    "vocalist_l1_BrunoMars": {
+        "voiced_positions": {
+            "mean": 330.92612740814315,
+            "std": 86.51034158515388,
+            "median": 324.65585832605217,
+            "min": 58.74277302450286,
+            "max": 999.2818302992808
+        },
+        "total_positions": {
+            "mean": 237.26076288057826,
+            "std": 166.09898203490803,
+            "median": 286.3097386522132,
+            "min": 0.0,
+            "max": 999.2818302992808
+        }
+    },
+    "vocalist_l1_JohnMayer": {
+        "voiced_positions": {
+            "mean": 218.3531239166661,
+            "std": 77.89887175223768,
+            "median": 200.19060542586652,
+            "min": 53.371912740674716,
+            "max": 1098.1986774161685
+        },
+        "total_positions": {
+            "mean": 112.95331907131244,
+            "std": 122.65534824070893,
+            "median": 124.71389285965317,
+            "min": 0.0,
+            "max": 1098.1986774161685
+        }
+    },
+    "vocalist_l1_MichaelJackson": {
+        "voiced_positions": {
+            "mean": 293.4663654519906,
+            "std": 89.02211325650234,
+            "median": 284.4323483619402,
+            "min": 61.14507754070825,
+            "max": 1096.4247902272325
+        },
+        "total_positions": {
+            "mean": 172.1013565770682,
+            "std": 159.79551912957191,
+            "median": 212.82938711725973,
+            "min": 0.0,
+            "max": 1096.4247902272325
+        }
+    },
+    "vocalist_l1_TaylorSwift": {
+        "voiced_positions": {
+            "mean": 302.5346928039029,
+            "std": 87.1724728626562,
+            "median": 286.91670244246586,
+            "min": 51.31173137207717,
+            "max": 1098.9374311806605
+        },
+        "total_positions": {
+            "mean": 169.90968097339214,
+            "std": 163.7133164876362,
+            "median": 220.90943653386546,
+            "min": 0.0,
+            "max": 1098.9374311806605
+        }
+    },
+    "vocalist_l1_张学友": {
+        "voiced_positions": {
+            "mean": 233.6845479691867,
+            "std": 66.47140810463938,
+            "median": 228.28695118043396,
+            "min": 51.65338480121057,
+            "max": 1094.4381927885959
+        },
+        "total_positions": {
+            "mean": 167.79543637603194,
+            "std": 119.28338415844308,
+            "median": 194.81504136428546,
+            "min": 0.0,
+            "max": 1094.4381927885959
+        }
+    },
+    "vocalist_l1_李健": {
+        "voiced_positions": {
+            "mean": 234.98401896504657,
+            "std": 71.3955175177514,
+            "median": 221.86415264367847,
+            "min": 54.070687769392585,
+            "max": 1096.3342286660531
+        },
+        "total_positions": {
+            "mean": 148.74760079412246,
+            "std": 126.70486473504008,
+            "median": 180.21374566147688,
+            "min": 0.0,
+            "max": 1096.3342286660531
+        }
+    },
+    "vocalist_l1_汪峰": {
+        "voiced_positions": {
+            "mean": 284.27752567207864,
+            "std": 78.51774150654873,
+            "median": 278.26186808969493,
+            "min": 54.30945929095861,
+            "max": 1053.6870553733015
+        },
+        "total_positions": {
+            "mean": 172.41584497486713,
+            "std": 151.74272125914902,
+            "median": 216.27534661524862,
+            "min": 0.0,
+            "max": 1053.6870553733015
+        }
+    },
+    "vocalist_l1_王菲": {
+        "voiced_positions": {
+            "mean": 339.1661679865587,
+            "std": 86.86768172635271,
+            "median": 327.4151031268507,
+            "min": 51.21299842481366,
+            "max": 1096.7044574066776
+        },
+        "total_positions": {
+            "mean": 217.726880186,
+            "std": 176.8748978138034,
+            "median": 277.8608050501477,
+            "min": 0.0,
+            "max": 1096.7044574066776
+        }
+    },
+    "vocalist_l1_石倚洁": {
+        "voiced_positions": {
+            "mean": 279.67710779262256,
+            "std": 87.82306577322389,
+            "median": 271.13024912248443,
+            "min": 59.604772357481075,
+            "max": 1098.0574674417153
+        },
+        "total_positions": {
+            "mean": 205.49634806008135,
+            "std": 144.6064344590865,
+            "median": 234.19454400899718,
+            "min": 0.0,
+            "max": 1098.0574674417153
+        }
+    },
+    "vocalist_l1_蔡琴": {
+        "voiced_positions": {
+            "mean": 258.9105806499278,
+            "std": 67.4079737418162,
+            "median": 250.29778287949176,
+            "min": 54.81875790199644,
+            "max": 930.3733192171918
+        },
+        "total_positions": {
+            "mean": 197.64675891035662,
+            "std": 124.80889987119957,
+            "median": 228.14775033720753,
+            "min": 0.0,
+            "max": 930.3733192171918
+        }
+    },
+    "vocalist_l1_那英": {
+        "voiced_positions": {
+            "mean": 358.98655838013195,
+            "std": 91.30591323348871,
+            "median": 346.95185476261275,
+            "min": 71.62879029165369,
+            "max": 1085.4349856526985
+        },
+        "total_positions": {
+            "mean": 243.83317702162077,
+            "std": 183.68660712060583,
+            "median": 294.9745603259994,
+            "min": 0.0,
+            "max": 1085.4349856526985
+        }
+    },
+    "vocalist_l1_陈奕迅": {
+        "voiced_positions": {
+            "mean": 222.0124146654594,
+            "std": 68.65002654904572,
+            "median": 218.9200565540147,
+            "min": 50.48503062529368,
+            "max": 1084.6336454006018
+        },
+        "total_positions": {
+            "mean": 154.2275169157727,
+            "std": 117.16740631313343,
+            "median": 176.89315636838086,
+            "min": 0.0,
+            "max": 1084.6336454006018
+        }
+    },
+    "vocalist_l1_陶喆": {
+        "voiced_positions": {
+            "mean": 242.58206762395713,
+            "std": 69.61805791083957,
+            "median": 227.5222796096177,
+            "min": 50.44809060945403,
+            "max": 1098.4942623171203
+        },
+        "total_positions": {
+            "mean": 171.59040988406485,
+            "std": 124.93911390018495,
+            "median": 204.4328861811408,
+            "min": 0.0,
+            "max": 1098.4942623171203
+        }
+    }
+}
\ No newline at end of file
diff --git a/ckpts/svc/vocalist_l1_contentvec+whisper/log/vocalist_l1_contentvec+whisper/events.out.tfevents.1696052302.mmnewyardnodesz63219.120.0 b/ckpts/svc/vocalist_l1_contentvec+whisper/log/vocalist_l1_contentvec+whisper/events.out.tfevents.1696052302.mmnewyardnodesz63219.120.0
new file mode 100755
index 0000000000000000000000000000000000000000..df0c6ae73d3d4df3a0c2856e4ddd75bfc4cc520b
--- /dev/null
+++ b/ckpts/svc/vocalist_l1_contentvec+whisper/log/vocalist_l1_contentvec+whisper/events.out.tfevents.1696052302.mmnewyardnodesz63219.120.0
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d7f490fd0c97876e24bfc44413365ded7ff5d22c1c79f0dac0b754f3b32df76f
+size 88
diff --git a/ckpts/svc/vocalist_l1_contentvec+whisper/log/vocalist_l1_contentvec+whisper/events.out.tfevents.1696052302.mmnewyardnodesz63219.120.1 b/ckpts/svc/vocalist_l1_contentvec+whisper/log/vocalist_l1_contentvec+whisper/events.out.tfevents.1696052302.mmnewyardnodesz63219.120.1
new file mode 100755
index 0000000000000000000000000000000000000000..4ee06f708e74717bc23b3130ddcb6f82e5cf84ee
--- /dev/null
+++ b/ckpts/svc/vocalist_l1_contentvec+whisper/log/vocalist_l1_contentvec+whisper/events.out.tfevents.1696052302.mmnewyardnodesz63219.120.1
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e01bcf2fa621ba563b70568c18fe0742d0f48cafae83a6e8beb0bb6d1f6d146d
+size 77413046
diff --git a/ckpts/svc/vocalist_l1_contentvec+whisper/singers.json b/ckpts/svc/vocalist_l1_contentvec+whisper/singers.json
new file mode 100755
index 0000000000000000000000000000000000000000..cd56250fa8be439b4ac6d2afe15fed300a69c973
--- /dev/null
+++ b/ckpts/svc/vocalist_l1_contentvec+whisper/singers.json
@@ -0,0 +1,17 @@
+{
+    "vocalist_l1_Adele": 0,
+    "vocalist_l1_Beyonce": 1,
+    "vocalist_l1_BrunoMars": 2,
+    "vocalist_l1_JohnMayer": 3,
+    "vocalist_l1_MichaelJackson": 4,
+    "vocalist_l1_TaylorSwift": 5,
+    "vocalist_l1_张学友": 6,
+    "vocalist_l1_李健": 7,
+    "vocalist_l1_汪峰": 8,
+    "vocalist_l1_王菲": 9,
+    "vocalist_l1_石倚洁": 10,
+    "vocalist_l1_蔡琴": 11,
+    "vocalist_l1_那英": 12,
+    "vocalist_l1_陈奕迅": 13,
+    "vocalist_l1_陶喆": 14
+}
\ No newline at end of file
diff --git a/egs/svc/MultipleContentsSVC/README.md b/egs/svc/MultipleContentsSVC/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..ac999e6253076f79ca59ed05fac168e5679feaea
--- /dev/null
+++ b/egs/svc/MultipleContentsSVC/README.md
@@ -0,0 +1,153 @@
+# Leveraging Content-based Features from Multiple Acoustic Models for Singing Voice Conversion
+
+[![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2310.11160)
+[![demo](https://img.shields.io/badge/SVC-Demo-red)](https://www.zhangxueyao.com/data/MultipleContentsSVC/index.html)
+
+<br>
+<div align="center">
+<img src="../../../imgs/svc/MultipleContentsSVC.png" width="85%">
+</div>
+<br>
+
+This is the official implementation of the paper "[Leveraging Content-based Features from Multiple Acoustic Models for Singing Voice Conversion](https://arxiv.org/abs/2310.11160)" (NeurIPS 2023 Workshop on Machine Learning for Audio). Specifically,
+
+- The multiple content features are from [Whisper](https://github.com/openai/whisper) and [ContentVec](https://github.com/auspicious3000/contentvec) (a minimal merging sketch follows this list).
+- The acoustic model is based on Bidirectional Non-Causal Dilated CNN (called `DiffWaveNetSVC` in Amphion), which is similar to [WaveNet](https://arxiv.org/pdf/1609.03499.pdf), [DiffWave](https://openreview.net/forum?id=a-xFK8Ymz5J), and [DiffSVC](https://ieeexplore.ieee.org/document/9688219).
+- The vocoder is based on the [BigVGAN](https://github.com/NVIDIA/BigVGAN) architecture, which we fine-tuned on over 120 hours of singing voice data.
+
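+As a rough illustration (a sketch only, not Amphion's exact code), the condition encoder can be viewed as projecting each content feature into a shared dimension and merging them additively, mirroring the `whisper_dim: 1024`, `contentvec_dim: 256`, `output_content_dim: 384`, and `merge_mode: "add"` settings in the released checkpoint's `args.json`:
+
+```python
+import torch.nn as nn
+
+
+class ContentFusionSketch(nn.Module):
+    """Minimal sketch: project Whisper and ContentVec frame-level features and merge them by addition."""
+
+    def __init__(self, whisper_dim=1024, contentvec_dim=256, output_dim=384):
+        super().__init__()
+        self.whisper_proj = nn.Linear(whisper_dim, output_dim)
+        self.contentvec_proj = nn.Linear(contentvec_dim, output_dim)
+
+    def forward(self, whisper_feat, contentvec_feat):
+        # Both feature sequences are assumed to be aligned to the same frame rate beforehand.
+        return self.whisper_proj(whisper_feat) + self.contentvec_proj(contentvec_feat)
+```
+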
+There are four stages in total:
+
+1. Data preparation
+2. Features extraction
+3. Training
+4. Inference/conversion
+
+> **NOTE:** You need to run every command of this recipe in the `Amphion` root path:
+> ```bash
+> cd Amphion
+> ```
+
+## 1. Data Preparation
+
+### Dataset Download
+
+By default, we utilize five datasets for training: M4Singer, Opencpop, OpenSinger, SVCC, and VCTK. How to download them is detailed [here](../../datasets/README.md).
+
+### Configuration
+
+Specify the dataset paths in `exp_config.json`. Note that you can change the `dataset` list to use your preferred datasets.
+
+```json
+    "dataset": [
+        "m4singer",
+        "opencpop",
+        "opensinger",
+        "svcc",
+        "vctk"
+    ],
+    "dataset_path": {
+        // TODO: Fill in your dataset path
+        "m4singer": "[M4Singer dataset path]",
+        "opencpop": "[Opencpop dataset path]",
+        "opensinger": "[OpenSinger dataset path]",
+        "svcc": "[SVCC dataset path]",
+        "vctk": "[VCTK dataset path]"
+    },
+```
+
+## 2. Features Extraction
+
+### Content-based Pretrained Models Download
+
+By default, we utilize Whisper and ContentVec to extract content features. How to download them is detailed [here](../../../pretrained/README.md).
+
+### Configuration
+
+Specify the dataset path and the output path for saving the processed data and the training model in `exp_config.json`:
+
+```json
+    // TODO: Fill in the output log path. The default value is "Amphion/ckpts/svc"
+    "log_dir": "ckpts/svc",
+    "preprocess": {
+        // TODO: Fill in the output data path. The default value is "Amphion/data"
+        "processed_dir": "data",
+        ...
+    },
+```
+
+### Run
+
+Run `run.sh` as the preprocessing stage (set `--stage 1`).
+
+```bash
+sh egs/svc/MultipleContentsSVC/run.sh --stage 1
+```
+
+> **NOTE:** `CUDA_VISIBLE_DEVICES` is set to `"0"` by default. You can change it when running `run.sh` by specifying, for example, `--gpu "1"`.
+
+## 3. Training
+
+### Configuration
+
+We provide the default hyperparameters in `exp_config.json`. They can work on a single NVIDIA 24GB GPU. You can adjust them based on your GPU machines.
+
+```json
+"train": {
+        "batch_size": 32,
+        ...
+        "adamw": {
+            "lr": 2.0e-4
+        },
+        ...
+    }
+```
+
+### Run
+
+Run `run.sh` as the training stage (set `--stage 2`). Specify an experiment name to run the following command. The tensorboard logs and checkpoints will be saved in `Amphion/ckpts/svc/[YourExptName]`.
+
+```bash
+sh egs/svc/MultipleContentsSVC/run.sh --stage 2 --name [YourExptName]
+```
+
+> **NOTE:** `CUDA_VISIBLE_DEVICES` is set to `"0"` by default. You can change it when running `run.sh` by specifying, for example, `--gpu "0,1,2,3"`.
+
+## 4. Inference/Conversion
+
+### Pretrained Vocoder Download
+
+We fine-tune the official BigVGAN pretrained model on over 120 hours of singing voice data. The benefits of fine-tuning have been investigated in our paper (see this [demo page](https://www.zhangxueyao.com/data/MultipleContentsSVC/vocoder.html)). The final pretrained singing voice vocoder is released [here](../../../pretrained/README.md#amphion-singing-bigvgan) (called `Amphion Singing BigVGAN`).
+
+### Run
+
+For inference/conversion, you need to specify the following configurations when running `run.sh`:
+
+| Parameters                                          | Description                                                                                                                                | Example                                                                                                                                                                            |
+| --------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `--infer_expt_dir`                                  | The experimental directory which contains `checkpoint`                                                                                     | `Amphion/ckpts/svc/[YourExptName]`                                                                                                                                                 |
+| `--infer_output_dir`                                | The output directory to save inferred audios.                                                                                              | `Amphion/ckpts/svc/[YourExptName]/result`                                                                                                                                          |
+| `--infer_source_file` or `--infer_source_audio_dir` | The inference source (can be a json file or a dir).                                                                                        | The `infer_source_file` could be `Amphion/data/[YourDataset]/test.json`, and the `infer_source_audio_dir` is a folder which includes several audio files (*.wav, *.mp3 or *.flac). |
+| `--infer_target_speaker`                            | The target speaker you want to convert into. You can refer to `Amphion/ckpts/svc/[YourExptName]/singers.json` to choose a trained speaker. | For opencpop dataset, the speaker name would be `opencpop_female1`.                                                                                                                |
+| `--infer_key_shift`                                 | How many semitones you want to transpose.                                                                                                  | `"autoshift"` (by default), `3`, `-3`, etc.                                                                                                                                        |
+
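+The `--infer_key_shift` value is in semitones: shifting by `k` semitones corresponds to scaling F0 by a factor of 2^(k/12). A minimal sketch (not Amphion's implementation) of such a transposition:
+
+```python
+def shift_f0(f0_hz, semitones):
+    """Transpose an F0 contour: k semitones scale F0 by 2 ** (k / 12)."""
+    return f0_hz * (2.0 ** (semitones / 12.0))
+```
+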
+For example, if you want to make `opencpop_female1` sing the songs in the `[Your Audios Folder]`, just run:
+
+```bash
+sh egs/svc/MultipleContentsSVC/run.sh --stage 3 --gpu "0" \
+	--infer_expt_dir Amphion/ckpts/svc/[YourExptName] \
+	--infer_output_dir Amphion/ckpts/svc/[YourExptName]/result \
+	--infer_source_audio_dir [Your Audios Folder] \
+	--infer_target_speaker "opencpop_female1" \
+	--infer_key_shift "autoshift"
+```
+
+## Citations
+
+```bibtex
+@article{zhang2023leveraging,
+  title={Leveraging Content-based Features from Multiple Acoustic Models for Singing Voice Conversion},
+  author={Zhang, Xueyao and Gu, Yicheng and Chen, Haopeng and Fang, Zihao and Zou, Lexiao and Xue, Liumeng and Wu, Zhizheng},
+  journal={Machine Learning for Audio Workshop, NeurIPS 2023},
+  year={2023}
+}
+```
diff --git a/egs/svc/MultipleContentsSVC/exp_config.json b/egs/svc/MultipleContentsSVC/exp_config.json
new file mode 100755
index 0000000000000000000000000000000000000000..7047855abd18c25760fcdd46ec63da5c4b7ad8ba
--- /dev/null
+++ b/egs/svc/MultipleContentsSVC/exp_config.json
@@ -0,0 +1,126 @@
+{
+    "base_config": "config/diffusion.json",
+    "model_type": "DiffWaveNetSVC",
+    "dataset": [
+        "m4singer",
+        "opencpop",
+        "opensinger",
+        "svcc",
+        "vctk"
+    ],
+    "dataset_path": {
+        // TODO: Fill in your dataset path
+        "m4singer": "[M4Singer dataset path]",
+        "opencpop": "[Opencpop dataset path]",
+        "opensinger": "[OpenSinger dataset path]",
+        "svcc": "[SVCC dataset path]",
+        "vctk": "[VCTK dataset path]"
+    },
+    // TODO: Fill in the output log path. The default value is "Amphion/ckpts/svc"
+    "log_dir": "ckpts/svc",
+    "preprocess": {
+        // TODO: Fill in the output data path. The default value is "Amphion/data"
+        "processed_dir": "data",
+        // Config for features extraction
+        "extract_mel": true,
+        "extract_pitch": true,
+        "extract_energy": true,
+        "extract_whisper_feature": true,
+        "extract_contentvec_feature": true,
+        "extract_wenet_feature": false,
+        "whisper_batch_size": 30, // decrease it if your GPU is out of memory
+        "contentvec_batch_size": 1,
+        // Fill in the content-based pretrained model's path
+        "contentvec_file": "pretrained/contentvec/checkpoint_best_legacy_500.pt",
+        "wenet_model_path": "pretrained/wenet/20220506_u2pp_conformer_exp/final.pt",
+        "wenet_config": "pretrained/wenet/20220506_u2pp_conformer_exp/train.yaml",
+        "whisper_model": "medium",
+        "whisper_model_path": "pretrained/whisper/medium.pt",
+        // Config for features usage
+        "use_mel": true,
+        "use_min_max_norm_mel": true,
+        "use_frame_pitch": true,
+        "use_frame_energy": true,
+        "use_spkid": true,
+        "use_whisper": true,
+        "use_contentvec": true,
+        "use_wenet": false,
+        "n_mel": 100,
+        "sample_rate": 24000
+    },
+    "model": {
+        "condition_encoder": {
+            // Config for features usage
+            "use_whisper": true,
+            "use_contentvec": true,
+            "use_wenet": false,
+            "whisper_dim": 1024,
+            "contentvec_dim": 256,
+            "wenet_dim": 512,
+            "use_singer_encoder": false,
+            "pitch_min": 50,
+            "pitch_max": 1100
+        },
+        "diffusion": {
+            "scheduler": "ddpm",
+            "scheduler_settings": {
+                "num_train_timesteps": 1000,
+                "beta_start": 1.0e-4,
+                "beta_end": 0.02,
+                "beta_schedule": "linear"
+            },
+            // Diffusion steps encoder
+            "step_encoder": {
+                "dim_raw_embedding": 128,
+                "dim_hidden_layer": 512,
+                "activation": "SiLU",
+                "num_layer": 2,
+                "max_period": 10000
+            },
+            // Diffusion decoder
+            "model_type": "bidilconv",
+            // bidilconv, unet2d, TODO: unet1d
+            "bidilconv": {
+                "base_channel": 512,
+                "n_res_block": 40,
+                "conv_kernel_size": 3,
+                "dilation_cycle_length": 4,
+                // specifically, 1 means no dilation
+                "conditioner_size": 384
+            }
+        }
+    },
+    "train": {
+        "batch_size": 32,
+        "gradient_accumulation_step": 1,
+        "max_epoch": -1, // -1 means no limit
+        "save_checkpoint_stride": [
+            3,
+            50
+        ],
+        "keep_last": [
+            3,
+            2
+        ],
+        "run_eval": [
+            true,
+            true
+        ],
+        "adamw": {
+            "lr": 2.0e-4
+        },
+        "reducelronplateau": {
+            "factor": 0.8,
+            "patience": 30,
+            "min_lr": 1.0e-4
+        },
+        "dataloader": {
+            "num_worker": 8,
+            "pin_memory": true
+        },
+        "sampler": {
+            "holistic_shuffle": false,
+            "drop_last": true
+        }
+    }
+}
\ No newline at end of file
diff --git a/egs/svc/MultipleContentsSVC/run.sh b/egs/svc/MultipleContentsSVC/run.sh
new file mode 120000
index 0000000000000000000000000000000000000000..f8daac3da463c177e36cdf041342566cc4243257
--- /dev/null
+++ b/egs/svc/MultipleContentsSVC/run.sh
@@ -0,0 +1 @@
+../_template/run.sh
\ No newline at end of file
diff --git a/egs/svc/README.md b/egs/svc/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..5961eaab3782ff96ddbb65a246527ab768498fa5
--- /dev/null
+++ b/egs/svc/README.md
@@ -0,0 +1,34 @@
+# Amphion Singing Voice Conversion (SVC) Recipe
+
+## Quick Start
+
+We provide a **[beginner recipe](MultipleContentsSVC)** to demonstrate how to train a cutting-edge SVC model. Specifically, it is also an official implementation of the paper "[Leveraging Content-based Features from Multiple Acoustic Models for Singing Voice Conversion](https://arxiv.org/abs/2310.11160)" (NeurIPS 2023 Workshop on Machine Learning for Audio). Some demos can be seen [here](https://www.zhangxueyao.com/data/MultipleContentsSVC/index.html).
+
+## Supported Model Architectures
+
+The main idea of SVC is to first disentangle the speaker-agnostic representations from the source audio, and then inject the desired speaker information to synthesize the target. This usually involves an acoustic decoder and a subsequent waveform synthesizer (vocoder):
+
+<br>
+<div align="center">
+  <img src="../../imgs/svc/pipeline.png" width="70%">
+</div>
+<br>
+
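+To make the pipeline concrete, here is a high-level sketch (the function and variable names below are illustrative, not Amphion's actual API):
+
+```python
+def convert(source_wav, target_singer_id, content_extractor, prosody_extractor,
+            speaker_lookup_table, acoustic_decoder, vocoder):
+    """Illustrative SVC pipeline; all callables are supplied by the concrete recipe."""
+    # 1. Disentangle speaker-agnostic representations from the source audio
+    content = content_extractor(source_wav)   # e.g., WeNet / Whisper / ContentVec features
+    prosody = prosody_extractor(source_wav)   # F0 and energy
+    # 2. Inject the desired speaker information
+    speaker_emb = speaker_lookup_table[target_singer_id]
+    # 3. The acoustic decoder predicts the target acoustic features (e.g., mel spectrogram)
+    mel = acoustic_decoder(content, prosody, speaker_emb)
+    # 4. The waveform synthesizer (vocoder) generates the converted audio
+    return vocoder(mel)
+```
+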
+Until now, Amphion SVC has supported the following features and models:
+
+- **Speaker-agnostic Representations**:
+  - Content Features: Sourcing from [WeNet](https://github.com/wenet-e2e/wenet), [Whisper](https://github.com/openai/whisper), and [ContentVec](https://github.com/auspicious3000/contentvec).
+  - Prosody Features: F0 and energy.
+- **Speaker Embeddings**:
+  - Speaker Look-Up Table.
+  - Reference Encoder (👨‍💻 developing): It can be used for zero-shot SVC.
+- **Acoustic Decoders**:
+  - Diffusion-based models:
+    - **[DiffWaveNetSVC](MultipleContentsSVC)**: The encoder is based on Bidirectional Non-Causal Dilated CNN, which is similar to [WaveNet](https://arxiv.org/pdf/1609.03499.pdf), [DiffWave](https://openreview.net/forum?id=a-xFK8Ymz5J), and [DiffSVC](https://ieeexplore.ieee.org/document/9688219).
+    - **[DiffComoSVC](DiffComoSVC)** (👨‍💻 developing): The diffusion framework is based on [Consistency Model](https://proceedings.mlr.press/v202/song23a.html). It can significantly accelerate the inference process of the diffusion model.
+  - Transformer-based models:
+    - **[TransformerSVC](TransformerSVC)**: Encoder-only and Non-autoregressive Transformer Architecture.
+  - VAE- and Flow-based models:
+    - **[VitsSVC]()** (👨‍💻 developing): It is designed as a [VITS](https://arxiv.org/abs/2106.06103)-like model whose textual input is replaced by the content features, which is similar to [so-vits-svc](https://github.com/svc-develop-team/so-vits-svc).
+- **Waveform Synthesizers (Vocoders)**:
+  - The supported vocoders can be seen in [Amphion Vocoder Recipe](../vocoder/README.md).
diff --git a/egs/svc/_template/run.sh b/egs/svc/_template/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8dc870fdef8b1464000021def5627f91d1676bbe
--- /dev/null
+++ b/egs/svc/_template/run.sh
@@ -0,0 +1,150 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+######## Build Experiment Environment ###########
+exp_dir=$(cd `dirname $0`; pwd)
+work_dir=$(dirname $(dirname $(dirname $exp_dir)))
+
+export WORK_DIR=$work_dir
+export PYTHONPATH=$work_dir
+export PYTHONIOENCODING=UTF-8
+
+######## Parse the Given Parameters from the Command ###########
+options=$(getopt -o c:n:s --long gpu:,config:,name:,stage:,resume:,resume_from_ckpt_path:,resume_type:,infer_expt_dir:,infer_output_dir:,infer_source_file:,infer_source_audio_dir:,infer_target_speaker:,infer_key_shift:,infer_vocoder_dir: -- "$@")
+eval set -- "$options"
+
+while true; do
+  case $1 in
+    # Experimental Configuration File
+    -c | --config) shift; exp_config=$1 ; shift ;;
+    # Experimental Name
+    -n | --name) shift; exp_name=$1 ; shift ;;
+    # Running Stage
+    -s | --stage) shift; running_stage=$1 ; shift ;;
+    # Visible GPU machines. The default value is "0".
+    --gpu) shift; gpu=$1 ; shift ;;
+
+    # [Only for Training] Resume configuration
+    --resume) shift; resume=$1 ; shift ;;
+    # [Only for Training] The specific checkpoint path that you want to resume from.
+    --resume_from_ckpt_path) shift; resume_from_ckpt_path=$1 ; shift ;;
+    # [Only for Training] `resume` for loading all the things (including model weights, optimizer, scheduler, and random states). `finetune` for loading only the model weights.
+    --resume_type) shift; resume_type=$1 ; shift ;;
+
+    # [Only for Inference] The experiment dir. The value is like "[Your path to save logs and checkpoints]/[YourExptName]"
+    --infer_expt_dir) shift; infer_expt_dir=$1 ; shift ;;
+    # [Only for Inference] The output dir to save inferred audios. Its default value is "$infer_expt_dir/result"
+    --infer_output_dir) shift; infer_output_dir=$1 ; shift ;;
+    # [Only for Inference] The inference source (can be a json file or a dir). For example, the source_file can be "[Your path to save processed data]/[YourDataset]/test.json", and the source_audio_dir can be "$work_dir/source_audio" which includes several audio files (*.wav, *.mp3 or *.flac).
+    --infer_source_file) shift; infer_source_file=$1 ; shift ;;
+    --infer_source_audio_dir) shift; infer_source_audio_dir=$1 ; shift ;;
+    # [Only for Inference] Specify the target speaker you want to convert into. You can refer to "[Your path to save logs and checkpoints]/[Your Expt Name]/singers.json". In this singer look-up table, you can see the usable speaker names (all the keys of the dictionary). For example, for opencpop dataset, the speaker name would be "opencpop_female1".
+    --infer_target_speaker) shift; infer_target_speaker=$1 ; shift ;;
+    # [Only for Inference] For advanced users, you can modify the trans_key parameters into an integer (which means the semitones you want to transpose). Its default value is "autoshift".
+    --infer_key_shift) shift; infer_key_shift=$1 ; shift ;;
+    # [Only for Inference] The vocoder dir. Its default value is Amphion/pretrained/bigvgan. See Amphion/pretrained/README.md to download the pretrained BigVGAN vocoders.
+    --infer_vocoder_dir) shift; infer_vocoder_dir=$1 ; shift ;;
+
+    --) shift ; break ;;
+    *) echo "Invalid option: $1" exit 1 ;;
+  esac
+done
+
+
+### Value check ###
+if [ -z "$running_stage" ]; then
+    echo "[Error] Please specify the running stage"
+    exit 1
+fi
+
+if [ -z "$exp_config" ]; then
+    exp_config="${exp_dir}"/exp_config.json
+fi
+echo "Exprimental Configuration File: $exp_config"
+
+if [ -z "$gpu" ]; then
+    gpu="0"
+fi
+
+######## Features Extraction ###########
+if [ $running_stage -eq 1 ]; then
+    CUDA_VISIBLE_DEVICES=$gpu python "${work_dir}"/bins/svc/preprocess.py \
+        --config $exp_config \
+        --num_workers 4
+fi
+
+######## Training ###########
+if [ $running_stage -eq 2 ]; then
+    if [ -z "$exp_name" ]; then
+        echo "[Error] Please specify the experiments name"
+        exit 1
+    fi
+    echo "Exprimental Name: $exp_name"
+
+    if [ "$resume" = true ]; then
+        echo "Automatically resume from the experimental dir..."
+        CUDA_VISIBLE_DEVICES="$gpu" accelerate launch "${work_dir}"/bins/svc/train.py \
+            --config "$exp_config" \
+            --exp_name "$exp_name" \
+            --log_level info \
+            --resume
+    else
+        CUDA_VISIBLE_DEVICES=$gpu accelerate launch "${work_dir}"/bins/svc/train.py \
+            --config "$exp_config" \
+            --exp_name "$exp_name" \
+            --log_level info \
+            --resume_from_ckpt_path "$resume_from_ckpt_path" \
+            --resume_type "$resume_type"
+    fi
+fi
+
+######## Inference/Conversion ###########
+if [ $running_stage -eq 3 ]; then
+    if [ -z "$infer_expt_dir" ]; then
+        echo "[Error] Please specify the experimental directionary. The value is like [Your path to save logs and checkpoints]/[YourExptName]"
+        exit 1
+    fi
+
+    if [ -z "$infer_output_dir" ]; then
+        infer_output_dir="$infer_expt_dir/result"
+    fi
+
+    if [ -z "$infer_source_file" ] && [ -z "$infer_source_audio_dir" ]; then
+        echo "[Error] Please specify the source file/dir. The inference source (can be a json file or a dir). For example, the source_file can be "[Your path to save processed data]/[YourDataset]/test.json", and the source_audio_dir should include several audio files (*.wav, *.mp3 or *.flac)."
+        exit 1
+    fi
+
+    if [ -z "$infer_source_file" ]; then
+        infer_source=$infer_source_audio_dir
+    fi
+
+    if [ -z "$infer_source_audio_dir" ]; then
+        infer_source=$infer_source_file
+    fi
+
+    if [ -z "$infer_target_speaker" ]; then
+        echo "[Error] Please specify the target speaker. You can refer to "[Your path to save logs and checkpoints]/[Your Expt Name]/singers.json". In this singer look-up table, you can see the usable speaker names (all the keys of the dictionary). For example, for opencpop dataset, the speaker name would be "opencpop_female1""
+        exit 1
+    fi
+
+    if [ -z "$infer_key_shift" ]; then
+        infer_key_shift="autoshift"
+    fi
+
+    if [ -z "$infer_vocoder_dir" ]; then
+        infer_vocoder_dir="$work_dir"/pretrained/bigvgan
+        echo "[Warning] You don't specify the infer_vocoder_dir. It is set $infer_vocoder_dir by default. Make sure that you have followed Amphoion/pretrained/README.md to download the pretrained BigVGAN vocoder checkpoint."
+    fi
+
+    CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/svc/inference.py \
+        --config $exp_config \
+        --acoustics_dir $infer_expt_dir \
+        --vocoder_dir $infer_vocoder_dir \
+        --target_singer $infer_target_speaker \
+        --trans_key $infer_key_shift \
+        --source $infer_source \
+        --output_dir $infer_output_dir  \
+        --log_level debug
+fi
\ No newline at end of file
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..729bbd0058b1ed399c793231ef645db106a071cf
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,258 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import os
+import glob
+import shutil
+from tqdm import tqdm
+import json
+import torch
+import time
+
+from models.svc.diffusion.diffusion_inference import DiffusionInference
+from models.svc.comosvc.comosvc_inference import ComoSVCInference
+from models.svc.transformer.transformer_inference import TransformerInference
+from utils.util import load_config
+from utils.audio_slicer import split_audio, merge_segments_encodec
+from processors import acoustic_extractor, content_extractor
+
+
+def build_inference(args, cfg, infer_type="from_dataset"):
+    supported_inference = {
+        "DiffWaveNetSVC": DiffusionInference,
+        "DiffComoSVC": ComoSVCInference,
+        "TransformerSVC": TransformerInference,
+    }
+
+    inference_class = supported_inference[cfg.model_type]
+    return inference_class(args, cfg, infer_type)
+
+
+def prepare_for_audio_file(args, cfg, num_workers=1):
+    preprocess_path = cfg.preprocess.processed_dir
+    audio_name = cfg.inference.source_audio_name
+    temp_audio_dir = os.path.join(preprocess_path, audio_name)
+
+    ### eval file
+    t = time.time()
+    eval_file = prepare_source_eval_file(cfg, temp_audio_dir, audio_name)
+    args.source = eval_file
+    with open(eval_file, "r") as f:
+        metadata = json.load(f)
+    print("Prepare for meta eval data: {:.1f}s".format(time.time() - t))
+
+    ### acoustic features
+    t = time.time()
+    acoustic_extractor.extract_utt_acoustic_features_serial(
+        metadata, temp_audio_dir, cfg
+    )
+    acoustic_extractor.cal_mel_min_max(
+        dataset=audio_name, output_path=preprocess_path, cfg=cfg, metadata=metadata
+    )
+    acoustic_extractor.cal_pitch_statistics_svc(
+        dataset=audio_name, output_path=preprocess_path, cfg=cfg, metadata=metadata
+    )
+    print("Prepare for acoustic features: {:.1f}s".format(time.time() - t))
+
+    ### content features
+    t = time.time()
+    content_extractor.extract_utt_content_features_dataloader(
+        cfg, metadata, num_workers
+    )
+    print("Prepare for content features: {:.1f}s".format(time.time() - t))
+    return args, cfg, temp_audio_dir
+
+
+def merge_for_audio_segments(audio_files, args, cfg):
+    audio_name = cfg.inference.source_audio_name
+    target_singer_name = args.target_singer
+
+    merge_segments_encodec(
+        wav_files=audio_files,
+        fs=cfg.preprocess.sample_rate,
+        output_path=os.path.join(
+            args.output_dir, "{}_{}.wav".format(audio_name, target_singer_name)
+        ),
+        overlap_duration=cfg.inference.segments_overlap_duration,
+    )
+
+    for tmp_file in audio_files:
+        os.remove(tmp_file)
+
+
+def prepare_source_eval_file(cfg, temp_audio_dir, audio_name):
+    """
+    Prepare the eval file (json) for an audio
+    """
+
+    audio_chunks_results = split_audio(
+        wav_file=cfg.inference.source_audio_path,
+        target_sr=cfg.preprocess.sample_rate,
+        output_dir=os.path.join(temp_audio_dir, "wavs"),
+        max_duration_of_segment=cfg.inference.segments_max_duration,
+        overlap_duration=cfg.inference.segments_overlap_duration,
+    )
+
+    metadata = []
+    for i, res in enumerate(audio_chunks_results):
+        res["index"] = i
+        res["Dataset"] = audio_name
+        res["Singer"] = audio_name
+        res["Uid"] = "{}_{}".format(audio_name, res["Uid"])
+        metadata.append(res)
+
+    eval_file = os.path.join(temp_audio_dir, "eval.json")
+    with open(eval_file, "w") as f:
+        json.dump(metadata, f, indent=4, ensure_ascii=False, sort_keys=True)
+
+    return eval_file
+
+
+def cuda_relevant(deterministic=False):
+    torch.cuda.empty_cache()
+    # TF32 on Ampere and above
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.enabled = True
+    torch.backends.cudnn.allow_tf32 = True
+    # Deterministic
+    torch.backends.cudnn.deterministic = deterministic
+    torch.backends.cudnn.benchmark = not deterministic
+    torch.use_deterministic_algorithms(deterministic)
+
+
+def infer(args, cfg, infer_type):
+    # Build inference
+    t = time.time()
+    trainer = build_inference(args, cfg, infer_type)
+    print("Model Init: {:.1f}s".format(time.time() - t))
+
+    # Run inference
+    t = time.time()
+    output_audio_files = trainer.inference()
+    print("Model inference: {:.1f}s".format(time.time() - t))
+    return output_audio_files
+
+
+def build_parser():
+    r"""Build argument parser for inference.py.
+    Anything else should be put in an extra config YAML file.
+    """
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--config",
+        type=str,
+        required=True,
+        help="JSON/YAML file for configurations.",
+    )
+    parser.add_argument(
+        "--acoustics_dir",
+        type=str,
+        help="Acoustics model checkpoint directory. If a directory is given, "
+        "search for the latest checkpoint dir in the directory. If a specific "
+        "checkpoint dir is given, directly load the checkpoint.",
+    )
+    parser.add_argument(
+        "--vocoder_dir",
+        type=str,
+        required=True,
+        help="Vocoder checkpoint directory. Searching behavior is the same as "
+        "the acoustics one.",
+    )
+    parser.add_argument(
+        "--target_singer",
+        type=str,
+        required=True,
+        help="convert to a specific singer (e.g. --target_singers singer_id).",
+    )
+    parser.add_argument(
+        "--trans_key",
+        default=0,
+        help="0: no pitch shift; autoshift: pitch shift;  int: key shift.",
+    )
+    parser.add_argument(
+        "--source",
+        type=str,
+        default="source_audio",
+        help="Source audio file or directory. If a JSON file is given, "
+        "inference from dataset is applied. If a directory is given, "
+        "inference from all wav/flac/mp3 audio files in the directory is applied. "
+        "Default: inference from all wav/flac/mp3 audio files in ./source_audio",
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="conversion_results",
+        help="Output directory. Default: ./conversion_results",
+    )
+    parser.add_argument(
+        "--log_level",
+        type=str,
+        default="warning",
+        help="Logging level. Default: warning",
+    )
+    parser.add_argument(
+        "--keep_cache",
+        action="store_true",
+        default=True,
+        help="Keep cache files. Only applicable to inference from files.",
+    )
+    parser.add_argument(
+        "--diffusion_inference_steps",
+        type=int,
+        default=1000,
+        help="Number of inference steps. Only applicable to diffusion inference.",
+    )
+    return parser
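+
+
+# A minimal invocation sketch for the parser above (the paths and the singer id
+# are placeholders, not files shipped with this diff):
+#
+#   python inference.py \
+#       --config <exp_dir>/args.json \
+#       --acoustics_dir <exp_dir> \
+#       --vocoder_dir <vocoder_ckpt_dir> \
+#       --target_singer <singer_id> \
+#       --source source_audio \
+#       --output_dir conversion_results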
+
+
+def main():
+    ### Parse arguments and config
+    args = build_parser().parse_args()
+    cfg = load_config(args.config)
+
+    # CUDA settings
+    cuda_relevant()
+
+    if os.path.isdir(args.source):
+        ### Infer from file
+
+        # Get all the source audio files (.wav, .flac, .mp3)
+        source_audio_dir = args.source
+        audio_list = []
+        for suffix in ["wav", "flac", "mp3"]:
+            audio_list += glob.glob(
+                os.path.join(source_audio_dir, "**/*.{}".format(suffix)), recursive=True
+            )
+        print("There are {} source audios: ".format(len(audio_list)))
+
+        # Infer for every file as dataset
+        output_root_path = args.output_dir
+        for audio_path in tqdm(audio_list):
+            audio_name = audio_path.split("/")[-1].split(".")[0]
+            args.output_dir = os.path.join(output_root_path, audio_name)
+            print("\n{}\nConversion for {}...\n".format("*" * 10, audio_name))
+
+            cfg.inference.source_audio_path = audio_path
+            cfg.inference.source_audio_name = audio_name
+            cfg.inference.segments_max_duration = 10.0
+            cfg.inference.segments_overlap_duration = 1.0
+
+            # Prepare metadata and features
+            args, cfg, cache_dir = prepare_for_audio_file(args, cfg)
+
+            # Infer from file
+            output_audio_files = infer(args, cfg, infer_type="from_file")
+
+            # Merge the split segments
+            merge_for_audio_segments(output_audio_files, args, cfg)
+
+            # Keep or remove caches
+            if not args.keep_cache:
+                os.removedirs(cache_dir)
+
+    else:
+        ### Infer from dataset
+        infer(args, cfg, infer_type="from_dataset")
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/base/__init__.py b/models/base/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe0221047a62e0b9b3ddd112c79a700c48834fd1
--- /dev/null
+++ b/models/base/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .new_trainer import BaseTrainer
+from .new_inference import BaseInference
diff --git a/models/base/base_dataset.py b/models/base/base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7d9848bb08f29669a51f3dde200d31bafe1d8da
--- /dev/null
+++ b/models/base/base_dataset.py
@@ -0,0 +1,350 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import numpy as np
+import torch.utils.data
+from torch.nn.utils.rnn import pad_sequence
+from utils.data_utils import *
+from processors.acoustic_extractor import cal_normalized_mel
+from text import text_to_sequence
+from text.text_token_collation import phoneIDCollation
+
+
+class BaseDataset(torch.utils.data.Dataset):
+    def __init__(self, cfg, dataset, is_valid=False):
+        """
+        Args:
+            cfg: config
+            dataset: dataset name
+            is_valid: whether to load the validation split instead of the training split
+        """
+
+        assert isinstance(dataset, str)
+
+        # self.data_root = processed_data_dir
+        self.cfg = cfg
+
+        processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
+        meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
+        self.metafile_path = os.path.join(processed_data_dir, meta_file)
+        self.metadata = self.get_metadata()
+
+        # Load spk2id (JSON) and utt2spk (tab-separated) mappings:
+        #   spk2id:  {spk1: 0, spk2: 1, ...}
+        #   utt2spk: lines of "<dataset>_<uid>\t<spk>", parsed into {dataset_uid: spk1, ...}
+        if cfg.preprocess.use_spkid:
+            spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
+            with open(spk2id_path, "r") as f:
+                self.spk2id = json.load(f)
+
+            utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
+            self.utt2spk = dict()
+            with open(utt2spk_path, "r") as f:
+                for line in f.readlines():
+                    utt, spk = line.strip().split("\t")
+                    self.utt2spk[utt] = spk
+
+        if cfg.preprocess.use_uv:
+            self.utt2uv_path = {}
+            for utt_info in self.metadata:
+                dataset = utt_info["Dataset"]
+                uid = utt_info["Uid"]
+                utt = "{}_{}".format(dataset, uid)
+                self.utt2uv_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.uv_dir,
+                    uid + ".npy",
+                )
+
+        if cfg.preprocess.use_frame_pitch:
+            self.utt2frame_pitch_path = {}
+            for utt_info in self.metadata:
+                dataset = utt_info["Dataset"]
+                uid = utt_info["Uid"]
+                utt = "{}_{}".format(dataset, uid)
+
+                self.utt2frame_pitch_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.pitch_dir,
+                    uid + ".npy",
+                )
+
+        if cfg.preprocess.use_frame_energy:
+            self.utt2frame_energy_path = {}
+            for utt_info in self.metadata:
+                dataset = utt_info["Dataset"]
+                uid = utt_info["Uid"]
+                utt = "{}_{}".format(dataset, uid)
+
+                self.utt2frame_energy_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.energy_dir,
+                    uid + ".npy",
+                )
+
+        if cfg.preprocess.use_mel:
+            self.utt2mel_path = {}
+            for utt_info in self.metadata:
+                dataset = utt_info["Dataset"]
+                uid = utt_info["Uid"]
+                utt = "{}_{}".format(dataset, uid)
+
+                self.utt2mel_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.mel_dir,
+                    uid + ".npy",
+                )
+
+        if cfg.preprocess.use_linear:
+            self.utt2linear_path = {}
+            for utt_info in self.metadata:
+                dataset = utt_info["Dataset"]
+                uid = utt_info["Uid"]
+                utt = "{}_{}".format(dataset, uid)
+
+                self.utt2linear_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.linear_dir,
+                    uid + ".npy",
+                )
+
+        if cfg.preprocess.use_audio:
+            self.utt2audio_path = {}
+            for utt_info in self.metadata:
+                dataset = utt_info["Dataset"]
+                uid = utt_info["Uid"]
+                utt = "{}_{}".format(dataset, uid)
+
+                self.utt2audio_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.audio_dir,
+                    uid + ".npy",
+                )
+        elif cfg.preprocess.use_label:
+            self.utt2label_path = {}
+            for utt_info in self.metadata:
+                dataset = utt_info["Dataset"]
+                uid = utt_info["Uid"]
+                utt = "{}_{}".format(dataset, uid)
+
+                self.utt2label_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.label_dir,
+                    uid + ".npy",
+                )
+        elif cfg.preprocess.use_one_hot:
+            self.utt2one_hot_path = {}
+            for utt_info in self.metadata:
+                dataset = utt_info["Dataset"]
+                uid = utt_info["Uid"]
+                utt = "{}_{}".format(dataset, uid)
+
+                self.utt2one_hot_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.one_hot_dir,
+                    uid + ".npy",
+                )
+
+        if cfg.preprocess.use_text or cfg.preprocess.use_phone:
+            self.utt2seq = {}
+            for utt_info in self.metadata:
+                dataset = utt_info["Dataset"]
+                uid = utt_info["Uid"]
+                utt = "{}_{}".format(dataset, uid)
+
+                if cfg.preprocess.use_text:
+                    text = utt_info["Text"]
+                    sequence = text_to_sequence(text, cfg.preprocess.text_cleaners)
+                elif cfg.preprocess.use_phone:
+                    # Load the phoneme sequence from the phone file
+                    phone_path = os.path.join(
+                        processed_data_dir, cfg.preprocess.phone_dir, uid + ".phone"
+                    )
+                    with open(phone_path, "r") as fin:
+                        phones = fin.readlines()
+                        assert len(phones) == 1
+                        phones = phones[0].strip()
+                    phones_seq = phones.split(" ")
+
+                    phon_id_collator = phoneIDCollation(cfg, dataset=dataset)
+                    sequence = phon_id_collator.get_phone_id_sequence(cfg, phones_seq)
+
+                self.utt2seq[utt] = sequence
+
+    def get_metadata(self):
+        with open(self.metafile_path, "r", encoding="utf-8") as f:
+            metadata = json.load(f)
+
+        return metadata
+
+    def get_dataset_name(self):
+        return self.metadata[0]["Dataset"]
+
+    def __getitem__(self, index):
+        utt_info = self.metadata[index]
+
+        dataset = utt_info["Dataset"]
+        uid = utt_info["Uid"]
+        utt = "{}_{}".format(dataset, uid)
+
+        single_feature = dict()
+
+        if self.cfg.preprocess.use_spkid:
+            single_feature["spk_id"] = np.array(
+                [self.spk2id[self.utt2spk[utt]]], dtype=np.int32
+            )
+
+        if self.cfg.preprocess.use_mel:
+            mel = np.load(self.utt2mel_path[utt])
+            assert mel.shape[0] == self.cfg.preprocess.n_mel  # [n_mels, T]
+            if self.cfg.preprocess.use_min_max_norm_mel:
+                # do mel norm
+                mel = cal_normalized_mel(mel, utt_info["Dataset"], self.cfg.preprocess)
+
+            if "target_len" not in single_feature.keys():
+                single_feature["target_len"] = mel.shape[1]
+            single_feature["mel"] = mel.T  # [T, n_mels]
+
+        if self.cfg.preprocess.use_linear:
+            linear = np.load(self.utt2linear_path[utt])
+            if "target_len" not in single_feature.keys():
+                single_feature["target_len"] = linear.shape[1]
+            single_feature["linear"] = linear.T  # [T, n_linear]
+
+        if self.cfg.preprocess.use_frame_pitch:
+            frame_pitch_path = self.utt2frame_pitch_path[utt]
+            frame_pitch = np.load(frame_pitch_path)
+            if "target_len" not in single_feature.keys():
+                single_feature["target_len"] = len(frame_pitch)
+            aligned_frame_pitch = align_length(
+                frame_pitch, single_feature["target_len"]
+            )
+            single_feature["frame_pitch"] = aligned_frame_pitch
+
+            if self.cfg.preprocess.use_uv:
+                frame_uv_path = self.utt2uv_path[utt]
+                frame_uv = np.load(frame_uv_path)
+                aligned_frame_uv = align_length(frame_uv, single_feature["target_len"])
+                aligned_frame_uv = [
+                    0 if frame_uv else 1 for frame_uv in aligned_frame_uv
+                ]
+                aligned_frame_uv = np.array(aligned_frame_uv)
+                single_feature["frame_uv"] = aligned_frame_uv
+
+        if self.cfg.preprocess.use_frame_energy:
+            frame_energy_path = self.utt2frame_energy_path[utt]
+            frame_energy = np.load(frame_energy_path)
+            if "target_len" not in single_feature.keys():
+                single_feature["target_len"] = len(frame_energy)
+            aligned_frame_energy = align_length(
+                frame_energy, single_feature["target_len"]
+            )
+            single_feature["frame_energy"] = aligned_frame_energy
+
+        if self.cfg.preprocess.use_audio:
+            audio = np.load(self.utt2audio_path[utt])
+            single_feature["audio"] = audio
+            single_feature["audio_len"] = audio.shape[0]
+
+        if self.cfg.preprocess.use_phone or self.cfg.preprocess.use_text:
+            single_feature["phone_seq"] = np.array(self.utt2seq[utt])
+            single_feature["phone_len"] = len(self.utt2seq[utt])
+
+        return single_feature
+
+    def __len__(self):
+        return len(self.metadata)
+
+
+class BaseCollator(object):
+    """Zero-pads model inputs and targets based on number of frames per step"""
+
+    def __init__(self, cfg):
+        self.cfg = cfg
+
+    def __call__(self, batch):
+        packed_batch_features = dict()
+
+        # mel: [b, T, n_mels]
+        # frame_pitch, frame_energy: [1, T]
+        # target_len: [1]
+        # spk_id: [b, 1]
+        # mask: [b, T, 1]
+
+        for key in batch[0].keys():
+            if key == "target_len":
+                packed_batch_features["target_len"] = torch.LongTensor(
+                    [b["target_len"] for b in batch]
+                )
+                masks = [
+                    torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
+                ]
+                packed_batch_features["mask"] = pad_sequence(
+                    masks, batch_first=True, padding_value=0
+                )
+            elif key == "phone_len":
+                packed_batch_features["phone_len"] = torch.LongTensor(
+                    [b["phone_len"] for b in batch]
+                )
+                masks = [
+                    torch.ones((b["phone_len"], 1), dtype=torch.long) for b in batch
+                ]
+                packed_batch_features["phn_mask"] = pad_sequence(
+                    masks, batch_first=True, padding_value=0
+                )
+            elif key == "audio_len":
+                packed_batch_features["audio_len"] = torch.LongTensor(
+                    [b["audio_len"] for b in batch]
+                )
+                masks = [
+                    torch.ones((b["audio_len"], 1), dtype=torch.long) for b in batch
+                ]
+            else:
+                values = [torch.from_numpy(b[key]) for b in batch]
+                packed_batch_features[key] = pad_sequence(
+                    values, batch_first=True, padding_value=0
+                )
+        return packed_batch_features
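+
+
+# A minimal wiring sketch for the two classes above (``cfg`` is hypothetical and
+# must provide the ``preprocess`` fields used in ``BaseDataset.__init__``, e.g.
+# ``processed_dir``, ``train_file`` and the ``use_*`` switches):
+#
+#   dataset = BaseDataset(cfg, "<dataset_name>", is_valid=False)
+#   loader = torch.utils.data.DataLoader(
+#       dataset, batch_size=8, collate_fn=BaseCollator(cfg), shuffle=True
+#   )
+#   batch = next(iter(loader))  # e.g. batch["mel"]: [B, T, n_mels], batch["mask"]: [B, T, 1]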
+
+
+class BaseTestDataset(torch.utils.data.Dataset):
+    def __init__(self, cfg, args):
+        raise NotImplementedError
+
+    def get_metadata(self):
+        raise NotImplementedError
+
+    def __getitem__(self, index):
+        raise NotImplementedError
+
+    def __len__(self):
+        return len(self.metadata)
+
+
+class BaseTestCollator(object):
+    """Zero-pads model inputs and targets based on number of frames per step"""
+
+    def __init__(self, cfg):
+        raise NotImplementedError
+
+    def __call__(self, batch):
+        raise NotImplementedError
diff --git a/models/base/base_inference.py b/models/base/base_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..2713f19a0d61f06bca1f01de5ccd8a3b4d2cc02f
--- /dev/null
+++ b/models/base/base_inference.py
@@ -0,0 +1,220 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import os
+import re
+import time
+from pathlib import Path
+
+import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from models.vocoders.vocoder_inference import synthesis
+from utils.util import set_all_random_seed
+from utils.util import load_config
+
+
+def parse_vocoder(vocoder_dir):
+    r"""Parse vocoder config"""
+    vocoder_dir = os.path.abspath(vocoder_dir)
+    ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
+    ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
+    ckpt_path = str(ckpt_list[0])
+    vocoder_cfg = load_config(os.path.join(vocoder_dir, "args.json"), lowercase=True)
+    vocoder_cfg.model.bigvgan = vocoder_cfg.vocoder
+    return vocoder_cfg, ckpt_path
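+
+
+# Expected layout of ``vocoder_dir`` as implied by the globbing above
+# (illustrative file names; checkpoint stems must be integers, and the largest
+# one is picked):
+#
+#   <vocoder_dir>/
+#       args.json      # vocoder configuration loaded via load_config(..., lowercase=True)
+#       400000.pt      # checkpoints named by step
+#       350000.pt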
+
+
+class BaseInference(object):
+    def __init__(self, cfg, args):
+        self.cfg = cfg
+        self.args = args
+        self.model_type = cfg.model_type
+        self.avg_rtf = list()
+        set_all_random_seed(10086)
+        os.makedirs(args.output_dir, exist_ok=True)
+
+        if torch.cuda.is_available():
+            self.device = torch.device("cuda")
+        else:
+            self.device = torch.device("cpu")
+            torch.set_num_threads(10)  # cap CPU threads for CPU-only inference
+
+        # Load acoustic model
+        self.model = self.create_model().to(self.device)
+        state_dict = self.load_state_dict()
+        self.load_model(state_dict)
+        self.model.eval()
+
+        # Load vocoder model if necessary
+        if self.args.checkpoint_dir_vocoder is not None:
+            self.get_vocoder_info()
+
+    def create_model(self):
+        raise NotImplementedError
+
+    def load_state_dict(self):
+        self.checkpoint_file = self.args.checkpoint_file
+        if self.checkpoint_file is None:
+            assert self.args.checkpoint_dir is not None
+            checkpoint_path = os.path.join(self.args.checkpoint_dir, "checkpoint")
+            checkpoint_filename = open(checkpoint_path).readlines()[-1].strip()
+            self.checkpoint_file = os.path.join(
+                self.args.checkpoint_dir, checkpoint_filename
+            )
+
+        self.checkpoint_dir = os.path.split(self.checkpoint_file)[0]
+
+        print("Restore acoustic model from {}".format(self.checkpoint_file))
+        raw_state_dict = torch.load(self.checkpoint_file, map_location=self.device)
+        self.am_restore_step = re.findall(r"step-(.+?)_loss", self.checkpoint_file)[0]
+
+        return raw_state_dict
+
+    def load_model(self, model):
+        raise NotImplementedError
+
+    def get_vocoder_info(self):
+        self.checkpoint_dir_vocoder = self.args.checkpoint_dir_vocoder
+        self.vocoder_cfg = os.path.join(
+            os.path.dirname(self.checkpoint_dir_vocoder), "args.json"
+        )
+        self.cfg.vocoder = load_config(self.vocoder_cfg, lowercase=True)
+        self.vocoder_tag = self.checkpoint_dir_vocoder.split("/")[-2].split(":")[-1]
+        self.vocoder_steps = self.checkpoint_dir_vocoder.split("/")[-1].split(".")[0]
+
+    def build_test_utt_data(self):
+        raise NotImplementedError
+
+    def build_testdata_loader(self, args, target_speaker=None):
+        datasets, collate = self.build_test_dataset()
+        self.test_dataset = datasets(self.cfg, args, target_speaker)
+        self.test_collate = collate(self.cfg)
+        self.test_batch_size = min(
+            self.cfg.train.batch_size, len(self.test_dataset.metadata)
+        )
+        test_loader = DataLoader(
+            self.test_dataset,
+            collate_fn=self.test_collate,
+            num_workers=self.args.num_workers,
+            batch_size=self.test_batch_size,
+            shuffle=False,
+        )
+        return test_loader
+
+    def inference_each_batch(self, batch_data):
+        raise NotImplementedError
+
+    def inference_for_batches(self, args, target_speaker=None):
+        ###### Construct test_batch ######
+        loader = self.build_testdata_loader(args, target_speaker)
+
+        n_batch = len(loader)
+        now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
+        print(
+            "Model eval time: {}, batch_size = {}, n_batch = {}".format(
+                now, self.test_batch_size, n_batch
+            )
+        )
+        self.model.eval()
+
+        ###### Inference for each batch ######
+        pred_res = []
+        with torch.no_grad():
+            for i, batch_data in enumerate(loader if n_batch == 1 else tqdm(loader)):
+                # Put the data to device
+                for k, v in batch_data.items():
+                    batch_data[k] = batch_data[k].to(self.device)
+
+                y_pred, stats = self.inference_each_batch(batch_data)
+
+                pred_res += y_pred
+
+        return pred_res
+
+    def inference(self, feature):
+        raise NotImplementedError
+
+    def synthesis_by_vocoder(self, pred):
+        audios_pred = synthesis(
+            self.vocoder_cfg,
+            self.checkpoint_dir_vocoder,
+            len(pred),
+            pred,
+        )
+        return audios_pred
+
+    def __call__(self, utt):
+        feature = self.build_test_utt_data(utt)
+        start_time = time.time()
+        with torch.no_grad():
+            outputs = self.inference(feature)[0]
+        time_used = time.time() - start_time
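+        # Real-time factor: synthesis time divided by the duration of the
+        # generated audio (frames * hop_size / sample_rate); values below 1
+        # mean faster than real time.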
+        rtf = time_used / (
+            outputs.shape[1]
+            * self.cfg.preprocess.hop_size
+            / self.cfg.preprocess.sample_rate
+        )
+        print("Time used: {:.3f}, RTF: {:.4f}".format(time_used, rtf))
+        self.avg_rtf.append(rtf)
+        audios = outputs.cpu().squeeze().numpy().reshape(-1, 1)
+        return audios
+
+
+def base_parser():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--config", default="config.json", help="json files for configurations."
+    )
+    parser.add_argument("--use_ddp_inference", default=False)
+    parser.add_argument("--n_workers", default=1, type=int)
+    parser.add_argument("--local_rank", default=-1, type=int)
+    parser.add_argument(
+        "--batch_size", default=1, type=int, help="Batch size for inference"
+    )
+    parser.add_argument(
+        "--num_workers",
+        default=1,
+        type=int,
+        help="Worker number for inference dataloader",
+    )
+    parser.add_argument(
+        "--checkpoint_dir",
+        type=str,
+        default=None,
+        help="Checkpoint dir including model file and configuration",
+    )
+    parser.add_argument(
+        "--checkpoint_file", help="checkpoint file", type=str, default=None
+    )
+    parser.add_argument(
+        "--test_list", help="test utterance list for testing", type=str, default=None
+    )
+    parser.add_argument(
+        "--checkpoint_dir_vocoder",
+        help="Vocoder's checkpoint dir including model file and configuration",
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default=None,
+        help="Output dir for saving generated results",
+    )
+    return parser
+
+
+if __name__ == "__main__":
+    parser = base_parser()
+    args = parser.parse_args()
+    cfg = load_config(args.config)
+
+    # Build inference
+    inference = BaseInference(cfg, args)
+    inference()
diff --git a/models/base/base_sampler.py b/models/base/base_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..149d1437eb1d3a00ca8c9895b150b39b2a3635fa
--- /dev/null
+++ b/models/base/base_sampler.py
@@ -0,0 +1,136 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import random
+
+from torch.utils.data import ConcatDataset, Dataset
+from torch.utils.data.sampler import (
+    BatchSampler,
+    RandomSampler,
+    Sampler,
+    SequentialSampler,
+)
+
+
+class ScheduledSampler(Sampler):
+    """A sampler that samples data from a given concat-dataset.
+
+    Args:
+        concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets
+        batch_size (int): batch size
+        holistic_shuffle (bool): whether to shuffle the whole dataset or not
+        logger (logging.Logger): logger to print warning message
+
+    Usage:
+        For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True:
+        >>> list(ScheduledSampler(ConcatDataset([[0, 1, 2], [3, 4, 5], [6, 7, 8]]), batch_size=3, holistic_shuffle=False))
+        [3, 4, 5, 0, 1, 2, 6, 7, 8]
+    """
+
+    def __init__(
+        self,
+        concat_dataset,
+        batch_size,
+        holistic_shuffle,
+        logger=None,
+        loader_type="train",
+    ):
+        if not isinstance(concat_dataset, ConcatDataset):
+            raise ValueError(
+                "concat_dataset must be an instance of ConcatDataset, but got {}".format(
+                    type(concat_dataset)
+                )
+            )
+        if not isinstance(batch_size, int):
+            raise ValueError(
+                "batch_size must be an integer, but got {}".format(type(batch_size))
+            )
+        if not isinstance(holistic_shuffle, bool):
+            raise ValueError(
+                "holistic_shuffle must be a boolean, but got {}".format(
+                    type(holistic_shuffle)
+                )
+            )
+
+        self.concat_dataset = concat_dataset
+        self.batch_size = batch_size
+        self.holistic_shuffle = holistic_shuffle
+
+        affected_dataset_name = []
+        affected_dataset_len = []
+        for dataset in concat_dataset.datasets:
+            dataset_len = len(dataset)
+            dataset_name = dataset.get_dataset_name()
+            if dataset_len < batch_size:
+                affected_dataset_name.append(dataset_name)
+                affected_dataset_len.append(dataset_len)
+
+        self.type = loader_type
+        for dataset_name, dataset_len in zip(
+            affected_dataset_name, affected_dataset_len
+        ):
+            if not loader_type == "valid":
+                logger.warning(
+                    "The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format(
+                        loader_type, dataset_name, dataset_len, batch_size
+                    )
+                )
+
+    def __len__(self):
+        # the number of batches with drop last
+        num_of_batches = sum(
+            [
+                math.floor(len(dataset) / self.batch_size)
+                for dataset in self.concat_dataset.datasets
+            ]
+        )
+        # if samples are not enough for one batch, we don't drop last
+        if self.type == "valid" and num_of_batches < 1:
+            return len(self.concat_dataset)
+        return num_of_batches * self.batch_size
+
+    def __iter__(self):
+        iters = []
+        for dataset in self.concat_dataset.datasets:
+            iters.append(
+                SequentialSampler(dataset).__iter__()
+                if not self.holistic_shuffle
+                else RandomSampler(dataset).__iter__()
+            )
+        # e.g. [0, 200, 400]
+        init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1]
+        output_batches = []
+        for dataset_idx in range(len(self.concat_dataset.datasets)):
+            cur_batch = []
+            for idx in iters[dataset_idx]:
+                cur_batch.append(idx + init_indices[dataset_idx])
+                if len(cur_batch) == self.batch_size:
+                    output_batches.append(cur_batch)
+                    cur_batch = []
+            # if loader_type is valid, we don't need to drop last
+            if self.type == "valid" and len(cur_batch) > 0:
+                output_batches.append(cur_batch)
+
+        # force drop last in training
+        random.shuffle(output_batches)
+        output_indices = [item for sublist in output_batches for item in sublist]
+        return iter(output_indices)
+
+
+def build_samplers(concat_dataset: Dataset, cfg, logger, loader_type):
+    sampler = ScheduledSampler(
+        concat_dataset,
+        cfg.train.batch_size,
+        cfg.train.sampler.holistic_shuffle,
+        logger,
+        loader_type,
+    )
+    batch_sampler = BatchSampler(
+        sampler,
+        cfg.train.batch_size,
+        cfg.train.sampler.drop_last if not loader_type == "valid" else False,
+    )
+    return sampler, batch_sampler
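+
+
+# A minimal usage sketch (``cfg`` is hypothetical and must provide
+# ``train.batch_size``, ``train.sampler.holistic_shuffle`` and
+# ``train.sampler.drop_last``; ``collator`` is whatever collate_fn the caller uses):
+#
+#   concat_dataset = ConcatDataset(per_dataset_list)
+#   sampler, batch_sampler = build_samplers(concat_dataset, cfg, logger, "train")
+#   loader = torch.utils.data.DataLoader(
+#       concat_dataset, batch_sampler=batch_sampler, collate_fn=collator
+#   )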
diff --git a/models/base/base_trainer.py b/models/base/base_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8782216dc13ce5d9de05ae790faeb82cf7cfd501
--- /dev/null
+++ b/models/base/base_trainer.py
@@ -0,0 +1,348 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import collections
+import json
+import os
+import sys
+import time
+
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel
+from torch.utils.data import ConcatDataset, DataLoader
+from torch.utils.tensorboard import SummaryWriter
+
+from models.base.base_sampler import BatchSampler
+from utils.util import (
+    Logger,
+    remove_older_ckpt,
+    save_config,
+    set_all_random_seed,
+    ValueWindow,
+)
+
+
+class BaseTrainer(object):
+    def __init__(self, args, cfg):
+        self.args = args
+        self.log_dir = args.log_dir
+        self.cfg = cfg
+
+        self.checkpoint_dir = os.path.join(args.log_dir, "checkpoints")
+        os.makedirs(self.checkpoint_dir, exist_ok=True)
+        if not cfg.train.ddp or args.local_rank == 0:
+            self.sw = SummaryWriter(os.path.join(args.log_dir, "events"))
+            self.logger = self.build_logger()
+        self.time_window = ValueWindow(50)
+
+        self.step = 0
+        self.epoch = -1
+        self.max_epochs = self.cfg.train.epochs
+        self.max_steps = self.cfg.train.max_steps
+
+        # set random seed & init distributed training
+        set_all_random_seed(self.cfg.train.random_seed)
+        if cfg.train.ddp:
+            dist.init_process_group(backend="nccl")
+
+        if cfg.model_type not in ["AutoencoderKL", "AudioLDM"]:
+            self.singers = self.build_singers_lut()
+
+        # setup data_loader
+        self.data_loader = self.build_data_loader()
+
+        # setup model & enable distributed training
+        self.model = self.build_model()
+        print(self.model)
+
+        if isinstance(self.model, dict):
+            for key, value in self.model.items():
+                value.cuda(self.args.local_rank)
+                if key == "PQMF":
+                    continue
+                if cfg.train.ddp:
+                    self.model[key] = DistributedDataParallel(
+                        value, device_ids=[self.args.local_rank]
+                    )
+        else:
+            self.model.cuda(self.args.local_rank)
+            if cfg.train.ddp:
+                self.model = DistributedDataParallel(
+                    self.model, device_ids=[self.args.local_rank]
+                )
+
+        # create criterion
+        self.criterion = self.build_criterion()
+        if isinstance(self.criterion, dict):
+            for key, value in self.criterion.items():
+                self.criterion[key].cuda(args.local_rank)
+        else:
+            self.criterion.cuda(self.args.local_rank)
+
+        # optimizer
+        self.optimizer = self.build_optimizer()
+        self.scheduler = self.build_scheduler()
+
+        # save config file
+        self.config_save_path = os.path.join(self.checkpoint_dir, "args.json")
+
+    def build_logger(self):
+        log_file = os.path.join(self.checkpoint_dir, "train.log")
+        logger = Logger(log_file, level=self.args.log_level).logger
+
+        return logger
+
+    def build_dataset(self):
+        raise NotImplementedError
+
+    def build_data_loader(self):
+        Dataset, Collator = self.build_dataset()
+        # build dataset instance for each dataset and combine them by ConcatDataset
+        datasets_list = []
+        for dataset in self.cfg.dataset:
+            subdataset = Dataset(self.cfg, dataset, is_valid=False)
+            datasets_list.append(subdataset)
+        train_dataset = ConcatDataset(datasets_list)
+
+        train_collate = Collator(self.cfg)
+        # TODO: multi-GPU training
+        if self.cfg.train.ddp:
+            raise NotImplementedError("DDP is not supported yet.")
+
+        # sampler will provide indices to batch_sampler, which will perform batching and yield batch indices
+        batch_sampler = BatchSampler(
+            cfg=self.cfg, concat_dataset=train_dataset, dataset_list=datasets_list
+        )
+
+        # use batch_sampler argument instead of (sampler, shuffle, drop_last, batch_size)
+        train_loader = DataLoader(
+            train_dataset,
+            collate_fn=train_collate,
+            num_workers=self.args.num_workers,
+            batch_sampler=batch_sampler,
+            pin_memory=False,
+        )
+        if not self.cfg.train.ddp or self.args.local_rank == 0:
+            datasets_list = []
+            for dataset in self.cfg.dataset:
+                subdataset = Dataset(self.cfg, dataset, is_valid=True)
+                datasets_list.append(subdataset)
+            valid_dataset = ConcatDataset(datasets_list)
+            valid_collate = Collator(self.cfg)
+            batch_sampler = BatchSampler(
+                cfg=self.cfg, concat_dataset=valid_dataset, dataset_list=datasets_list
+            )
+            valid_loader = DataLoader(
+                valid_dataset,
+                collate_fn=valid_collate,
+                num_workers=1,
+                batch_sampler=batch_sampler,
+            )
+        else:
+            raise NotImplementedError("DDP is not supported yet.")
+            # valid_loader = None
+        data_loader = {"train": train_loader, "valid": valid_loader}
+        return data_loader
+
+    def build_singers_lut(self):
+        # combine singers
+        if not os.path.exists(os.path.join(self.log_dir, self.cfg.preprocess.spk2id)):
+            singers = collections.OrderedDict()
+        else:
+            with open(
+                os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "r"
+            ) as singer_file:
+                singers = json.load(singer_file)
+        singer_count = len(singers)
+        for dataset in self.cfg.dataset:
+            singer_lut_path = os.path.join(
+                self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
+            )
+            with open(singer_lut_path, "r") as f:
+                singer_lut = json.load(f)
+            for singer in singer_lut.keys():
+                if singer not in singers:
+                    singers[singer] = singer_count
+                    singer_count += 1
+        with open(
+            os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "w"
+        ) as singer_file:
+            json.dump(singers, singer_file, indent=4, ensure_ascii=False)
+        print(
+            "singers have been dumped to {}".format(
+                os.path.join(self.log_dir, self.cfg.preprocess.spk2id)
+            )
+        )
+        return singers
+
+    def build_model(self):
+        raise NotImplementedError()
+
+    def build_optimizer(self):
+        raise NotImplementedError
+
+    def build_scheduler(self):
+        raise NotImplementedError()
+
+    def build_criterion(self):
+        raise NotImplementedError
+
+    def get_state_dict(self):
+        raise NotImplementedError
+
+    def save_config_file(self):
+        save_config(self.config_save_path, self.cfg)
+
+    # TODO, save without module.
+    def save_checkpoint(self, state_dict, saved_model_path):
+        torch.save(state_dict, saved_model_path)
+
+    def load_checkpoint(self):
+        checkpoint_path = os.path.join(self.checkpoint_dir, "checkpoint")
+        assert os.path.exists(checkpoint_path)
+        checkpoint_filename = open(checkpoint_path).readlines()[-1].strip()
+        model_path = os.path.join(self.checkpoint_dir, checkpoint_filename)
+        assert os.path.exists(model_path)
+        if not self.cfg.train.ddp or self.args.local_rank == 0:
+            self.logger.info(f"Re(store) from {model_path}")
+        checkpoint = torch.load(model_path, map_location="cpu")
+        return checkpoint
+
+    def load_model(self, checkpoint):
+        raise NotImplementedError
+
+    def restore(self):
+        checkpoint = self.load_checkpoint()
+        self.load_model(checkpoint)
+
+    def train_step(self, data):
+        raise NotImplementedError(
+            f"Need to implement function {sys._getframe().f_code.co_name} in "
+            f"your sub-class of {self.__class__.__name__}. "
+        )
+
+    @torch.no_grad()
+    def eval_step(self):
+        raise NotImplementedError(
+            f"Need to implement function {sys._getframe().f_code.co_name} in "
+            f"your sub-class of {self.__class__.__name__}. "
+        )
+
+    def write_summary(self, losses, stats):
+        raise NotImplementedError(
+            f"Need to implement function {sys._getframe().f_code.co_name} in "
+            f"your sub-class of {self.__class__.__name__}. "
+        )
+
+    def write_valid_summary(self, losses, stats):
+        raise NotImplementedError(
+            f"Need to implement function {sys._getframe().f_code.co_name} in "
+            f"your sub-class of {self.__class__.__name__}. "
+        )
+
+    def echo_log(self, losses, mode="Training"):
+        message = [
+            "{} - Epoch {} Step {}: [{:.3f} s/step]".format(
+                mode, self.epoch + 1, self.step, self.time_window.average
+            )
+        ]
+
+        for key in sorted(losses.keys()):
+            if isinstance(losses[key], dict):
+                for k, v in losses[key].items():
+                    message.append(
+                        str(k).split("/")[-1] + "=" + str(round(float(v), 5))
+                    )
+            else:
+                message.append(
+                    str(key).split("/")[-1] + "=" + str(round(float(losses[key]), 5))
+                )
+        self.logger.info(", ".join(message))
+
+    def eval_epoch(self):
+        self.logger.info("Validation...")
+        valid_losses = {}
+        for i, batch_data in enumerate(self.data_loader["valid"]):
+            for k, v in batch_data.items():
+                if isinstance(v, torch.Tensor):
+                    batch_data[k] = v.cuda()
+            valid_loss, valid_stats, total_valid_loss = self.eval_step(batch_data, i)
+            for key in valid_loss:
+                if key not in valid_losses:
+                    valid_losses[key] = 0
+                valid_losses[key] += valid_loss[key]
+
+        # Add mel and audio to the Tensorboard
+        # Average loss
+        for key in valid_losses:
+            valid_losses[key] /= i + 1
+        self.echo_log(valid_losses, "Valid")
+        return valid_losses, valid_stats
+
+    def train_epoch(self):
+        for i, batch_data in enumerate(self.data_loader["train"]):
+            start_time = time.time()
+            # Put the data to cuda device
+            for k, v in batch_data.items():
+                if isinstance(v, torch.Tensor):
+                    batch_data[k] = v.cuda(self.args.local_rank)
+
+            # Training step
+            train_losses, train_stats, total_loss = self.train_step(batch_data)
+            self.time_window.append(time.time() - start_time)
+
+            if self.args.local_rank == 0 or not self.cfg.train.ddp:
+                if self.step % self.args.stdout_interval == 0:
+                    self.echo_log(train_losses, "Training")
+
+                if self.step % self.cfg.train.save_summary_steps == 0:
+                    self.logger.info(f"Save summary as step {self.step}")
+                    self.write_summary(train_losses, train_stats)
+
+                if (
+                    self.step % self.cfg.train.save_checkpoints_steps == 0
+                    and self.step != 0
+                ):
+                    saved_model_name = "step-{:07d}_loss-{:.4f}.pt".format(
+                        self.step, total_loss
+                    )
+                    saved_model_path = os.path.join(
+                        self.checkpoint_dir, saved_model_name
+                    )
+                    saved_state_dict = self.get_state_dict()
+                    self.save_checkpoint(saved_state_dict, saved_model_path)
+                    self.save_config_file()
+                    # keep max n models
+                    remove_older_ckpt(
+                        saved_model_name,
+                        self.checkpoint_dir,
+                        max_to_keep=self.cfg.train.keep_checkpoint_max,
+                    )
+
+                if self.step != 0 and self.step % self.cfg.train.valid_interval == 0:
+                    if isinstance(self.model, dict):
+                        for key in self.model.keys():
+                            self.model[key].eval()
+                    else:
+                        self.model.eval()
+                    # Evaluate one epoch and get average loss
+                    valid_losses, valid_stats = self.eval_epoch()
+                    if isinstance(self.model, dict):
+                        for key in self.model.keys():
+                            self.model[key].train()
+                    else:
+                        self.model.train()
+                    # Write validation losses to summary.
+                    self.write_valid_summary(valid_losses, valid_stats)
+            self.step += 1
+
+    def train(self):
+        for epoch in range(max(0, self.epoch), self.max_epochs):
+            self.train_epoch()
+            self.epoch += 1
+            if self.step > self.max_steps:
+                self.logger.info("Training finished!")
+                break
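+
+
+# Subclassing sketch (illustrative only): a concrete trainer implements the
+# abstract builders above and then drives training, e.g.
+#
+#   class MyTrainer(BaseTrainer):
+#       def build_dataset(self): ...      # return (Dataset, Collator) classes
+#       def build_model(self): ...
+#       def build_optimizer(self): ...
+#       def build_scheduler(self): ...
+#       def build_criterion(self): ...
+#       def get_state_dict(self): ...
+#       def train_step(self, data): ...   # return (losses, stats, total_loss)
+#
+#   trainer = MyTrainer(args, cfg)
+#   trainer.train()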
diff --git a/models/base/new_dataset.py b/models/base/new_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..2201bb4132ab86d1110092d7ab9e509296367a22
--- /dev/null
+++ b/models/base/new_dataset.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+from abc import abstractmethod
+from pathlib import Path
+
+import json5
+import torch
+import yaml
+
+
+# TODO: for training and validating
+class BaseDataset(torch.utils.data.Dataset):
+    r"""Base dataset for training and validating."""
+
+    def __init__(self, args, cfg, is_valid=False):
+        pass
+
+
+class BaseTestDataset(torch.utils.data.Dataset):
+    r"""Test dataset for inference."""
+
+    def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
+        assert infer_type in ["from_dataset", "from_file"]
+
+        self.args = args
+        self.cfg = cfg
+        self.infer_type = infer_type
+
+    @abstractmethod
+    def __getitem__(self, index):
+        pass
+
+    def __len__(self):
+        return len(self.metadata)
+
+    def get_metadata(self):
+        path = Path(self.args.source)
+        if path.suffix == ".json" or path.suffix == ".jsonc":
+            metadata = json5.load(open(self.args.source, "r"))
+        elif path.suffix == ".yaml" or path.suffix == ".yml":
+            metadata = yaml.full_load(open(self.args.source, "r"))
+        else:
+            raise ValueError(f"Unsupported file type: {path.suffix}")
+
+        return metadata
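+
+
+# ``args.source`` may point to a JSON/JSON5 or YAML metadata file; a minimal
+# illustrative entry (field names follow the conventions used elsewhere in this
+# codebase, e.g. "Dataset"/"Uid"):
+#
+#   [
+#       {"Dataset": "<dataset_name>", "Uid": "<uid>", "Singer": "<singer_id>"}
+#   ]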
diff --git a/models/base/new_inference.py b/models/base/new_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..4813fca4aba192fb8737dd74f37f6d430e1909a4
--- /dev/null
+++ b/models/base/new_inference.py
@@ -0,0 +1,249 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import random
+import re
+import time
+from abc import abstractmethod
+from pathlib import Path
+
+import accelerate
+import json5
+import numpy as np
+import torch
+from accelerate.logging import get_logger
+from torch.utils.data import DataLoader
+
+from models.vocoders.vocoder_inference import synthesis
+from utils.io import save_audio
+from utils.util import load_config
+from utils.audio_slicer import is_silence
+
+EPS = 1.0e-12
+
+
+class BaseInference(object):
+    def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
+        super().__init__()
+
+        start = time.monotonic_ns()
+        self.args = args
+        self.cfg = cfg
+
+        assert infer_type in ["from_dataset", "from_file"]
+        self.infer_type = infer_type
+
+        # init with accelerate
+        self.accelerator = accelerate.Accelerator()
+        self.accelerator.wait_for_everyone()
+
+        # Use accelerate logger for distributed inference
+        with self.accelerator.main_process_first():
+            self.logger = get_logger("inference", log_level=args.log_level)
+
+        # Log some info
+        self.logger.info("=" * 56)
+        self.logger.info("||\t\t" + "New inference process started." + "\t\t||")
+        self.logger.info("=" * 56)
+        self.logger.info("\n")
+        self.logger.debug(f"Using {args.log_level.upper()} logging level.")
+
+        self.acoustics_dir = args.acoustics_dir
+        self.logger.debug(f"Acoustic dir: {args.acoustics_dir}")
+        self.vocoder_dir = args.vocoder_dir
+        self.logger.debug(f"Vocoder dir: {args.vocoder_dir}")
+        # should be in svc inferencer
+        # self.target_singer = args.target_singer
+        # self.logger.info(f"Target singers: {args.target_singer}")
+        # self.trans_key = args.trans_key
+        # self.logger.info(f"Trans key: {args.trans_key}")
+
+        os.makedirs(args.output_dir, exist_ok=True)
+
+        # set random seed
+        with self.accelerator.main_process_first():
+            start = time.monotonic_ns()
+            self._set_random_seed(self.cfg.train.random_seed)
+            end = time.monotonic_ns()
+            self.logger.debug(
+                f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
+            )
+            self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
+
+        # setup data_loader
+        with self.accelerator.main_process_first():
+            self.logger.info("Building dataset...")
+            start = time.monotonic_ns()
+            self.test_dataloader = self._build_dataloader()
+            end = time.monotonic_ns()
+            self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
+
+        # setup model
+        with self.accelerator.main_process_first():
+            self.logger.info("Building model...")
+            start = time.monotonic_ns()
+            self.model = self._build_model()
+            end = time.monotonic_ns()
+            # self.logger.debug(self.model)
+            self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms")
+
+        # init with accelerate
+        self.logger.info("Initializing accelerate...")
+        start = time.monotonic_ns()
+        self.accelerator = accelerate.Accelerator()
+        self.model = self.accelerator.prepare(self.model)
+        end = time.monotonic_ns()
+        self.accelerator.wait_for_everyone()
+        self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms")
+
+        with self.accelerator.main_process_first():
+            self.logger.info("Loading checkpoint...")
+            start = time.monotonic_ns()
+            # TODO: currently assumes only the latest checkpoint is used
+            self.__load_model(os.path.join(args.acoustics_dir, "checkpoint"))
+            end = time.monotonic_ns()
+            self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms")
+
+        self.model.eval()
+        self.accelerator.wait_for_everyone()
+
+    ### Abstract methods ###
+    @abstractmethod
+    def _build_test_dataset(self):
+        pass
+
+    @abstractmethod
+    def _build_model(self):
+        pass
+
+    @abstractmethod
+    @torch.inference_mode()
+    def _inference_each_batch(self, batch_data):
+        pass
+
+    ### Abstract methods end ###
+
+    @torch.inference_mode()
+    def inference(self):
+        for i, batch in enumerate(self.test_dataloader):
+            y_pred = self._inference_each_batch(batch).cpu()
+            mel_min, mel_max = self.test_dataset.target_mel_extrema
+            y_pred = (y_pred + 1.0) / 2.0 * (mel_max - mel_min + EPS) + mel_min
+            y_ls = y_pred.chunk(self.test_batch_size)
+            tgt_ls = batch["target_len"].cpu().chunk(self.test_batch_size)
+            j = 0
+            for it, l in zip(y_ls, tgt_ls):
+                l = l.item()
+                it = it.squeeze(0)[:l]
+                uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
+                torch.save(it, os.path.join(self.args.output_dir, f"{uid}.pt"))
+                j += 1
+
+        vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
+
+        res = synthesis(
+            cfg=vocoder_cfg,
+            vocoder_weight_file=vocoder_ckpt,
+            n_samples=None,
+            pred=[
+                torch.load(
+                    os.path.join(self.args.output_dir, "{}.pt".format(i["Uid"]))
+                ).numpy(force=True)
+                for i in self.test_dataset.metadata
+            ],
+        )
+
+        output_audio_files = []
+        for it, wav in zip(self.test_dataset.metadata, res):
+            uid = it["Uid"]
+            file = os.path.join(self.args.output_dir, f"{uid}.wav")
+            output_audio_files.append(file)
+
+            wav = wav.numpy(force=True)
+            save_audio(
+                file,
+                wav,
+                self.cfg.preprocess.sample_rate,
+                add_silence=False,
+                turn_up=not is_silence(wav, self.cfg.preprocess.sample_rate),
+            )
+            os.remove(os.path.join(self.args.output_dir, f"{uid}.pt"))
+
+        return sorted(output_audio_files)
+
+    # TODO: LEGACY CODE
+    def _build_dataloader(self):
+        datasets, collate = self._build_test_dataset()
+        self.test_dataset = datasets(self.args, self.cfg, self.infer_type)
+        self.test_collate = collate(self.cfg)
+        self.test_batch_size = min(
+            self.cfg.train.batch_size, len(self.test_dataset.metadata)
+        )
+        test_dataloader = DataLoader(
+            self.test_dataset,
+            collate_fn=self.test_collate,
+            num_workers=1,
+            batch_size=self.test_batch_size,
+            shuffle=False,
+        )
+        return test_dataloader
+
+    def __load_model(self, checkpoint_dir: str = None, checkpoint_path: str = None):
+        r"""Load model from checkpoint. If checkpoint_path is None, it will
+        load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
+        None, it will load the checkpoint specified by checkpoint_path. **Only use this
+        method after** ``accelerator.prepare()``.
+        """
+        if checkpoint_path is None:
+            ls = []
+            for i in Path(checkpoint_dir).iterdir():
+                if re.match(r"epoch-\d+_step-\d+_loss-[\d.]+", str(i.stem)):
+                    ls.append(i)
+            ls.sort(
+                key=lambda x: int(x.stem.split("_")[-3].split("-")[-1]), reverse=True
+            )
+            checkpoint_path = ls[0]
+        else:
+            checkpoint_path = Path(checkpoint_path)
+        self.accelerator.load_state(str(checkpoint_path))
+        # set epoch and step
+        self.epoch = int(checkpoint_path.stem.split("_")[-3].split("-")[-1])
+        self.step = int(checkpoint_path.stem.split("_")[-2].split("-")[-1])
+        return str(checkpoint_path)
+
+    @staticmethod
+    def _set_random_seed(seed):
+        r"""Set random seed for all possible random modules."""
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.random.manual_seed(seed)
+
+    @staticmethod
+    def _parse_vocoder(vocoder_dir):
+        r"""Parse vocoder config"""
+        vocoder_dir = os.path.abspath(vocoder_dir)
+        ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
+        ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
+        ckpt_path = str(ckpt_list[0])
+        vocoder_cfg = load_config(
+            os.path.join(vocoder_dir, "args.json"), lowercase=True
+        )
+        return vocoder_cfg, ckpt_path
+
+    @staticmethod
+    def __count_parameters(model):
+        return sum(p.numel() for p in model.parameters())
+
+    def __dump_cfg(self, path):
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+        json5.dump(
+            self.cfg,
+            open(path, "w"),
+            indent=4,
+            sort_keys=True,
+            ensure_ascii=False,
+            quote_keys=True,
+        )
diff --git a/models/base/new_trainer.py b/models/base/new_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d013d2bc2f2e47e5c7646cac5c63cc88c04486b
--- /dev/null
+++ b/models/base/new_trainer.py
@@ -0,0 +1,722 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+import random
+import shutil
+import time
+from abc import abstractmethod
+from pathlib import Path
+
+import accelerate
+import json5
+import numpy as np
+import torch
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration
+from torch.utils.data import ConcatDataset, DataLoader
+from tqdm import tqdm
+
+from models.base.base_sampler import build_samplers
+from optimizer.optimizers import NoamLR
+
+
+class BaseTrainer(object):
+    r"""The base trainer for all tasks. Any trainer should inherit from this class."""
+
+    def __init__(self, args=None, cfg=None):
+        super().__init__()
+
+        self.args = args
+        self.cfg = cfg
+
+        cfg.exp_name = args.exp_name
+
+        # init with accelerate
+        self._init_accelerator()
+        self.accelerator.wait_for_everyone()
+
+        # Use accelerate logger for distributed training
+        with self.accelerator.main_process_first():
+            self.logger = get_logger(args.exp_name, log_level=args.log_level)
+
+        # Log some info
+        self.logger.info("=" * 56)
+        self.logger.info("||\t\t" + "New training process started." + "\t\t||")
+        self.logger.info("=" * 56)
+        self.logger.info("\n")
+        self.logger.debug(f"Using {args.log_level.upper()} logging level.")
+        self.logger.info(f"Experiment name: {args.exp_name}")
+        self.logger.info(f"Experiment directory: {self.exp_dir}")
+        self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
+        if self.accelerator.is_main_process:
+            os.makedirs(self.checkpoint_dir, exist_ok=True)
+        self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
+
+        # init counts
+        self.batch_count: int = 0
+        self.step: int = 0
+        self.epoch: int = 0
+        self.max_epoch = (
+            self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
+        )
+        self.logger.info(
+            "Max epoch: {}".format(
+                self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
+            )
+        )
+
+        # Check values
+        if self.accelerator.is_main_process:
+            self.__check_basic_configs()
+            # Set runtime configs
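+            # ``save_checkpoint_stride`` and ``keep_last`` are parallel lists:
+            # one checkpoint queue per stride, and a non-positive keep_last
+            # value means "keep every checkpoint" for that queue.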
+            self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
+            self.checkpoints_path = [
+                [] for _ in range(len(self.save_checkpoint_stride))
+            ]
+            self.keep_last = [
+                i if i > 0 else float("inf") for i in self.cfg.train.keep_last
+            ]
+            self.run_eval = self.cfg.train.run_eval
+
+        # set random seed
+        with self.accelerator.main_process_first():
+            start = time.monotonic_ns()
+            self._set_random_seed(self.cfg.train.random_seed)
+            end = time.monotonic_ns()
+            self.logger.debug(
+                f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
+            )
+            self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
+
+        # setup data_loader
+        with self.accelerator.main_process_first():
+            self.logger.info("Building dataset...")
+            start = time.monotonic_ns()
+            self.train_dataloader, self.valid_dataloader = self._build_dataloader()
+            end = time.monotonic_ns()
+            self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
+
+        # setup model
+        with self.accelerator.main_process_first():
+            self.logger.info("Building model...")
+            start = time.monotonic_ns()
+            self.model = self._build_model()
+            end = time.monotonic_ns()
+            self.logger.debug(self.model)
+            self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
+            self.logger.info(
+                f"Model parameters: {self.__count_parameters(self.model)/1e6:.2f}M"
+            )
+        # optimizer & scheduler
+        with self.accelerator.main_process_first():
+            self.logger.info("Building optimizer and scheduler...")
+            start = time.monotonic_ns()
+            self.optimizer = self.__build_optimizer()
+            self.scheduler = self.__build_scheduler()
+            end = time.monotonic_ns()
+            self.logger.info(
+                f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
+            )
+
+        # accelerate prepare
+        self.logger.info("Initializing accelerate...")
+        start = time.monotonic_ns()
+        (
+            self.train_dataloader,
+            self.valid_dataloader,
+            self.model,
+            self.optimizer,
+            self.scheduler,
+        ) = self.accelerator.prepare(
+            self.train_dataloader,
+            self.valid_dataloader,
+            self.model,
+            self.optimizer,
+            self.scheduler,
+        )
+        end = time.monotonic_ns()
+        self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")
+
+        # create criterion
+        with self.accelerator.main_process_first():
+            self.logger.info("Building criterion...")
+            start = time.monotonic_ns()
+            self.criterion = self._build_criterion()
+            end = time.monotonic_ns()
+            self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")
+
+        # Resume or Finetune
+        with self.accelerator.main_process_first():
+            if args.resume:
+                ## Automatically resume according to the current experiment name
+                self.logger.info("Resuming from {}...".format(self.checkpoint_dir))
+                start = time.monotonic_ns()
+                ckpt_path = self.__load_model(
+                    checkpoint_dir=self.checkpoint_dir, resume_type=args.resume_type
+                )
+                end = time.monotonic_ns()
+                self.logger.info(
+                    f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
+                )
+                self.checkpoints_path = json.load(
+                    open(os.path.join(ckpt_path, "ckpts.json"), "r")
+                )
+            elif args.resume_from_ckpt_path and args.resume_from_ckpt_path != "":
+                ## Resume from the given checkpoint path
+                if not os.path.exists(args.resume_from_ckpt_path):
+                    raise ValueError(
+                        "[Error] The resumed checkpoint path {} don't exist.".format(
+                            args.resume_from_ckpt_path
+                        )
+                    )
+
+                self.logger.info(
+                    "Resuming from {}...".format(args.resume_from_ckpt_path)
+                )
+                start = time.monotonic_ns()
+                ckpt_path = self.__load_model(
+                    checkpoint_path=args.resume_from_ckpt_path,
+                    resume_type=args.resume_type,
+                )
+                end = time.monotonic_ns()
+                self.logger.info(
+                    f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
+                )
+
+        # save config file path
+        self.config_save_path = os.path.join(self.exp_dir, "args.json")
+
+    ### Following are abstract methods that should be implemented in child classes ###
+    @abstractmethod
+    def _build_dataset(self):
+        r"""Build dataset for model training/validating/evaluating."""
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def _build_criterion():
+        r"""Build criterion function for model loss calculation."""
+        pass
+
+    @abstractmethod
+    def _build_model(self):
+        r"""Build model for training/validating/evaluating."""
+        pass
+
+    @abstractmethod
+    def _forward_step(self, batch):
+        r"""One forward step of the neural network. This abstract method is trying to
+        unify ``_train_step`` and ``_valid_step`` and avoid redundant implementation.
+        However, for special case that using different forward step pattern for
+        training and validating, you could just override this method with ``pass`` and
+        implement ``_train_step`` and ``_valid_step`` separately.
+        """
+        pass
+
+    @abstractmethod
+    def _save_auxiliary_states(self):
+        r"""To save some auxiliary states when saving model's ckpt"""
+        pass
+
+    ### Abstract methods end ###
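+
+    # A minimal subclass sketch (illustrative only; ``MyTrainer``, ``MyModel``,
+    # ``MyDataset`` and ``MyCollator`` are hypothetical names, not part of this repo):
+    #
+    #     class MyTrainer(BaseTrainer):
+    #         def _build_dataset(self):
+    #             return MyDataset, MyCollator
+    #
+    #         @staticmethod
+    #         def _build_criterion():
+    #             return torch.nn.MSELoss(reduction="none")
+    #
+    #         def _build_model(self):
+    #             return MyModel(self.cfg)
+    #
+    #         def _forward_step(self, batch):
+    #             pred = self.model(batch)
+    #             return self.criterion(pred, batch["mel"]).mean()
+    #
+    #         def _save_auxiliary_states(self):
+    #             pass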
+
+    ### THIS IS MAIN ENTRY ###
+    def train_loop(self):
+        r"""Training loop. The public entry of training process."""
+        # Wait everyone to prepare before we move on
+        self.accelerator.wait_for_everyone()
+        # dump config file
+        if self.accelerator.is_main_process:
+            self.__dump_cfg(self.config_save_path)
+        self.model.train()
+        self.optimizer.zero_grad()
+        # Wait to ensure good to go
+        self.accelerator.wait_for_everyone()
+        while self.epoch < self.max_epoch:
+            self.logger.info("\n")
+            self.logger.info("-" * 32)
+            self.logger.info("Epoch {}: ".format(self.epoch))
+
+            ### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict)
+            ### The current form is inconvenient for models with multiple losses
+            # Do training & validating epoch
+            train_loss = self._train_epoch()
+            self.logger.info("  |- Train/Loss: {:.6f}".format(train_loss))
+            valid_loss = self._valid_epoch()
+            self.logger.info("  |- Valid/Loss: {:.6f}".format(valid_loss))
+            self.accelerator.log(
+                {"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss},
+                step=self.epoch,
+            )
+
+            self.accelerator.wait_for_everyone()
+            # TODO: what is scheduler?
+            self.scheduler.step(valid_loss)  # FIXME: only ReduceLROnPlateau expects a metric here; other schedulers should be stepped without one
+
+            # Check if hit save_checkpoint_stride and run_eval
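+            # `save_checkpoint_stride` defines one saving interval (in epochs) per rule,
+            # `keep_last` defines how many checkpoints each rule retains, and `run_eval`
+            # marks whether a rule triggers evaluation; a checkpoint is only deleted once
+            # no rule still references it (see the conflict search below).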
+            run_eval = False
+            if self.accelerator.is_main_process:
+                save_checkpoint = False
+                hit_dix = []
+                for i, num in enumerate(self.save_checkpoint_stride):
+                    if self.epoch % num == 0:
+                        save_checkpoint = True
+                        hit_dix.append(i)
+                        run_eval |= self.run_eval[i]
+
+            self.accelerator.wait_for_everyone()
+            if self.accelerator.is_main_process and save_checkpoint:
+                path = os.path.join(
+                    self.checkpoint_dir,
+                    "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+                        self.epoch, self.step, train_loss
+                    ),
+                )
+                self.tmp_checkpoint_save_path = path
+                self.accelerator.save_state(path)
+                print(f"save checkpoint in {path}")
+                json.dump(
+                    self.checkpoints_path,
+                    open(os.path.join(path, "ckpts.json"), "w"),
+                    ensure_ascii=False,
+                    indent=4,
+                )
+                self._save_auxiliary_states()
+
+                # Remove old checkpoints
+                to_remove = []
+                for idx in hit_dix:
+                    self.checkpoints_path[idx].append(path)
+                    while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
+                        to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
+
+                # Search conflicts
+                total = set()
+                for i in self.checkpoints_path:
+                    total |= set(i)
+                do_remove = set()
+                for idx, path in to_remove[::-1]:
+                    if path in total:
+                        self.checkpoints_path[idx].insert(0, path)
+                    else:
+                        do_remove.add(path)
+
+                # Remove old checkpoints
+                for path in do_remove:
+                    shutil.rmtree(path, ignore_errors=True)
+                    self.logger.debug(f"Remove old checkpoint: {path}")
+
+            self.accelerator.wait_for_everyone()
+            if run_eval:
+                # TODO: run evaluation
+                pass
+
+            # Update info for each epoch
+            self.epoch += 1
+
+        # Finish training and save final checkpoint
+        self.accelerator.wait_for_everyone()
+        if self.accelerator.is_main_process:
+            self.accelerator.save_state(
+                os.path.join(
+                    self.checkpoint_dir,
+                    "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+                        self.epoch, self.step, valid_loss
+                    ),
+                )
+            )
+            self._save_auxiliary_states()
+
+        self.accelerator.end_training()
+
+    ### Following are methods that can be used directly in child classes ###
+    def _train_epoch(self):
+        r"""Training epoch. Should return average loss of a batch (sample) over
+        one epoch. See ``train_loop`` for usage.
+        """
+        self.model.train()
+        epoch_sum_loss: float = 0.0
+        epoch_step: int = 0
+        for batch in tqdm(
+            self.train_dataloader,
+            desc=f"Training Epoch {self.epoch}",
+            unit="batch",
+            colour="GREEN",
+            leave=False,
+            dynamic_ncols=True,
+            smoothing=0.04,
+            disable=not self.accelerator.is_main_process,
+        ):
+            # Do training step and BP
+            with self.accelerator.accumulate(self.model):
+                loss = self._train_step(batch)
+                self.accelerator.backward(loss)
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+            self.batch_count += 1
+
+            # Update info for each step
+            # TODO: should "step" count optimizer updates (BP steps) or raw batch counts?
+            if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
+                epoch_sum_loss += loss
+                self.accelerator.log(
+                    {
+                        "Step/Train Loss": loss,
+                        "Step/Learning Rate": self.optimizer.param_groups[0]["lr"],
+                    },
+                    step=self.step,
+                )
+                self.step += 1
+                epoch_step += 1
+
+        self.accelerator.wait_for_everyone()
+        return (
+            epoch_sum_loss
+            / len(self.train_dataloader)
+            * self.cfg.train.gradient_accumulation_step
+        )
+
+    @torch.inference_mode()
+    def _valid_epoch(self):
+        r"""Testing epoch. Should return average loss of a batch (sample) over
+        one epoch. See ``train_loop`` for usage.
+        """
+        self.model.eval()
+        epoch_sum_loss = 0.0
+        for batch in tqdm(
+            self.valid_dataloader,
+            desc=f"Validating Epoch {self.epoch}",
+            unit="batch",
+            colour="GREEN",
+            leave=False,
+            dynamic_ncols=True,
+            smoothing=0.04,
+            disable=not self.accelerator.is_main_process,
+        ):
+            batch_loss = self._valid_step(batch)
+            epoch_sum_loss += batch_loss.item()
+
+        self.accelerator.wait_for_everyone()
+        return epoch_sum_loss / len(self.valid_dataloader)
+
+    def _train_step(self, batch):
+        r"""Training forward step. Should return average loss of a sample over
+        one batch. Provoke ``_forward_step`` is recommended except for special case.
+        See ``_train_epoch`` for usage.
+        """
+        return self._forward_step(batch)
+
+    @torch.inference_mode()
+    def _valid_step(self, batch):
+        r"""Testing forward step. Should return average loss of a sample over
+        one batch. Provoke ``_forward_step`` is recommended except for special case.
+        See ``_test_epoch`` for usage.
+        """
+        return self._forward_step(batch)
+
+    def __load_model(
+        self,
+        checkpoint_dir: str = None,
+        checkpoint_path: str = None,
+        resume_type: str = "",
+    ):
+        r"""Load model from checkpoint. If checkpoint_path is None, it will
+        load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
+        None, it will load the checkpoint specified by checkpoint_path. **Only use this
+        method after** ``accelerator.prepare()``.
+        """
+        if checkpoint_path is None:
+            ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
+            ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
+            checkpoint_path = ls[0]
+            self.logger.info("Resume from {}...".format(checkpoint_path))
+
+        if resume_type in ["resume", ""]:
+            # Load all the things, including model weights, optimizer, scheduler, and random states.
+            self.accelerator.load_state(input_dir=checkpoint_path)
+
+            # set epoch and step
+            self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
+            self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
+
+        elif resume_type == "finetune":
+            # Load only the model weights
+            accelerate.load_checkpoint_and_dispatch(
+                self.accelerator.unwrap_model(self.model),
+                os.path.join(checkpoint_path, "pytorch_model.bin"),
+            )
+            self.logger.info("Load model weights for finetune...")
+
+        else:
+            raise ValueError("Resume_type must be `resume` or `finetune`.")
+
+        return checkpoint_path
+
+    # TODO: LEGACY CODE
+    def _build_dataloader(self):
+        Dataset, Collator = self._build_dataset()
+
+        # build dataset instance for each dataset and combine them by ConcatDataset
+        datasets_list = []
+        for dataset in self.cfg.dataset:
+            subdataset = Dataset(self.cfg, dataset, is_valid=False)
+            datasets_list.append(subdataset)
+        train_dataset = ConcatDataset(datasets_list)
+        train_collate = Collator(self.cfg)
+        _, batch_sampler = build_samplers(train_dataset, self.cfg, self.logger, "train")
+        self.logger.debug(f"train batch_sampler: {list(batch_sampler)}")
+        self.logger.debug(f"length: {train_dataset.cumulative_sizes}")
+        # TODO: use config instead of (sampler, shuffle, drop_last, batch_size)
+        train_loader = DataLoader(
+            train_dataset,
+            collate_fn=train_collate,
+            batch_sampler=batch_sampler,
+            num_workers=self.cfg.train.dataloader.num_worker,
+            pin_memory=self.cfg.train.dataloader.pin_memory,
+        )
+
+        # Build valid dataloader
+        datasets_list = []
+        for dataset in self.cfg.dataset:
+            subdataset = Dataset(self.cfg, dataset, is_valid=True)
+            datasets_list.append(subdataset)
+        valid_dataset = ConcatDataset(datasets_list)
+        valid_collate = Collator(self.cfg)
+        _, batch_sampler = build_samplers(valid_dataset, self.cfg, self.logger, "valid")
+        self.logger.debug(f"valid batch_sampler: {list(batch_sampler)}")
+        self.logger.debug(f"length: {valid_dataset.cumulative_sizes}")
+        valid_loader = DataLoader(
+            valid_dataset,
+            collate_fn=valid_collate,
+            batch_sampler=batch_sampler,
+            num_workers=self.cfg.train.dataloader.num_worker,
+            pin_memory=self.cfg.train.dataloader.pin_memory,
+        )
+        return train_loader, valid_loader
+
+    @staticmethod
+    def _set_random_seed(seed):
+        r"""Set random seed for all possible random modules."""
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.random.manual_seed(seed)
+
+    def _check_nan(self, loss, y_pred, y_gt):
+        if torch.any(torch.isnan(loss)):
+            self.logger.fatal("Fatal Error: Training stopped since loss has NaN!")
+            self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True)
+
+            # Log predictions and targets, escalating to error level when they contain NaN
+            pred_has_nan = torch.any(torch.isnan(y_pred))
+            gt_has_nan = torch.any(torch.isnan(y_gt))
+            log_pred = self.logger.error if pred_has_nan else self.logger.debug
+            log_gt = self.logger.error if gt_has_nan else self.logger.debug
+            log_pred(f"y_pred has NaN: {pred_has_nan}", in_order=True)
+            log_gt(f"y_gt has NaN: {gt_has_nan}", in_order=True)
+            log_pred(f"y_pred: {y_pred}", in_order=True)
+            log_gt(f"y_gt: {y_gt}", in_order=True)
+
+            # TODO: still OK to save tracking?
+            self.accelerator.end_training()
+            raise RuntimeError("Loss has Nan! See log for more info.")
+
+    ### Protected methods end ###
+
+    ## Following are private methods ##
+    ## !!! These are inconvenient for GAN-based model training. It'd be better to move these to svc_trainer.py if needed.
+    def __build_optimizer(self):
+        r"""Build optimizer for model."""
+        # Make case-insensitive matching
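+        # The config section named after the optimizer is passed through as kwargs,
+        # e.g. (illustrative values, not defaults of this repo)
+        # cfg.train.adamw = {"lr": 4.0e-4, "betas": [0.9, 0.98]} when
+        # cfg.train.optimizer == "AdamW".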
+        if self.cfg.train.optimizer.lower() == "adadelta":
+            optimizer = torch.optim.Adadelta(
+                self.model.parameters(), **self.cfg.train.adadelta
+            )
+            self.logger.info("Using Adadelta optimizer.")
+        elif self.cfg.train.optimizer.lower() == "adagrad":
+            optimizer = torch.optim.Adagrad(
+                self.model.parameters(), **self.cfg.train.adagrad
+            )
+            self.logger.info("Using Adagrad optimizer.")
+        elif self.cfg.train.optimizer.lower() == "adam":
+            optimizer = torch.optim.Adam(self.model.parameters(), **self.cfg.train.adam)
+            self.logger.info("Using Adam optimizer.")
+        elif self.cfg.train.optimizer.lower() == "adamw":
+            optimizer = torch.optim.AdamW(
+                self.model.parameters(), **self.cfg.train.adamw
+            )
+        elif self.cfg.train.optimizer.lower() == "sparseadam":
+            optimizer = torch.optim.SparseAdam(
+                self.model.parameters(), **self.cfg.train.sparseadam
+            )
+        elif self.cfg.train.optimizer.lower() == "adamax":
+            optimizer = torch.optim.Adamax(
+                self.model.parameters(), **self.cfg.train.adamax
+            )
+        elif self.cfg.train.optimizer.lower() == "asgd":
+            optimizer = torch.optim.ASGD(self.model.parameters(), **self.cfg.train.asgd)
+        elif self.cfg.train.optimizer.lower() == "lbfgs":
+            optimizer = torch.optim.LBFGS(
+                self.model.parameters(), **self.cfg.train.lbfgs
+            )
+        elif self.cfg.train.optimizer.lower() == "nadam":
+            optimizer = torch.optim.NAdam(
+                self.model.parameters(), **self.cfg.train.nadam
+            )
+        elif self.cfg.train.optimizer.lower() == "radam":
+            optimizer = torch.optim.RAdam(
+                self.model.parameters(), **self.cfg.train.radam
+            )
+        elif self.cfg.train.optimizer.lower() == "rmsprop":
+            optimizer = torch.optim.RMSprop(
+                self.model.parameters(), **self.cfg.train.rmsprop
+            )
+        elif self.cfg.train.optimizer.lower() == "rprop":
+            optimizer = torch.optim.Rprop(
+                self.model.parameters(), **self.cfg.train.rprop
+            )
+        elif self.cfg.train.optimizer.lower() == "sgd":
+            optimizer = torch.optim.SGD(self.model.parameters(), **self.cfg.train.sgd)
+        else:
+            raise NotImplementedError(
+                f"Optimizer {self.cfg.train.optimizer} not supported yet!"
+            )
+        return optimizer
+
+    def __build_scheduler(self):
+        r"""Build scheduler for optimizer."""
+        # Make case-insensitive matching
+        if self.cfg.train.scheduler.lower() == "lambdalr":
+            scheduler = torch.optim.lr_scheduler.LambdaLR(
+                self.optimizer, **self.cfg.train.lambdalr
+            )
+        elif self.cfg.train.scheduler.lower() == "multiplicativelr":
+            scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
+                self.optimizer, **self.cfg.train.multiplicativelr
+            )
+        elif self.cfg.train.scheduler.lower() == "steplr":
+            scheduler = torch.optim.lr_scheduler.StepLR(
+                self.optimizer, **self.cfg.train.steplr
+            )
+        elif self.cfg.train.scheduler.lower() == "multisteplr":
+            scheduler = torch.optim.lr_scheduler.MultiStepLR(
+                self.optimizer, **self.cfg.train.multisteplr
+            )
+        elif self.cfg.train.scheduler.lower() == "constantlr":
+            scheduler = torch.optim.lr_scheduler.ConstantLR(
+                self.optimizer, **self.cfg.train.constantlr
+            )
+        elif self.cfg.train.scheduler.lower() == "linearlr":
+            scheduler = torch.optim.lr_scheduler.LinearLR(
+                self.optimizer, **self.cfg.train.linearlr
+            )
+        elif self.cfg.train.scheduler.lower() == "exponentiallr":
+            scheduler = torch.optim.lr_scheduler.ExponentialLR(
+                self.optimizer, **self.cfg.train.exponentiallr
+            )
+        elif self.cfg.train.scheduler.lower() == "polynomiallr":
+            scheduler = torch.optim.lr_scheduler.PolynomialLR(
+                self.optimizer, **self.cfg.train.polynomiallr
+            )
+        elif self.cfg.train.scheduler.lower() == "cosineannealinglr":
+            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+                self.optimizer, **self.cfg.train.cosineannealinglr
+            )
+        elif self.cfg.train.scheduler.lower() == "sequentiallr":
+            scheduler = torch.optim.lr_scheduler.SequentialLR(
+                self.optimizer, **self.cfg.train.sequentiallr
+            )
+        elif self.cfg.train.scheduler.lower() == "reducelronplateau":
+            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+                self.optimizer, **self.cfg.train.reducelronplateau
+            )
+        elif self.cfg.train.scheduler.lower() == "cycliclr":
+            scheduler = torch.optim.lr_scheduler.CyclicLR(
+                self.optimizer, **self.cfg.train.cycliclr
+            )
+        elif self.cfg.train.scheduler.lower() == "onecyclelr":
+            scheduler = torch.optim.lr_scheduler.OneCycleLR(
+                self.optimizer, **self.cfg.train.onecyclelr
+            )
+        elif self.cfg.train.scheduler.lower() == "cosineannearingwarmrestarts":
+            scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+                self.optimizer, **self.cfg.train.cosineannearingwarmrestarts
+            )
+        elif self.cfg.train.scheduler.lower() == "noamlr":
+            scheduler = NoamLR(self.optimizer, **self.cfg.train.lr_scheduler)
+        else:
+            raise NotImplementedError(
+                f"Scheduler {self.cfg.train.scheduler} not supported yet!"
+            )
+        return scheduler
+
+    def _init_accelerator(self):
+        self.exp_dir = os.path.join(
+            os.path.abspath(self.cfg.log_dir), self.args.exp_name
+        )
+        project_config = ProjectConfiguration(
+            project_dir=self.exp_dir,
+            logging_dir=os.path.join(self.exp_dir, "log"),
+        )
+        self.accelerator = accelerate.Accelerator(
+            gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
+            log_with=self.cfg.train.tracker,
+            project_config=project_config,
+        )
+        if self.accelerator.is_main_process:
+            os.makedirs(project_config.project_dir, exist_ok=True)
+            os.makedirs(project_config.logging_dir, exist_ok=True)
+        with self.accelerator.main_process_first():
+            self.accelerator.init_trackers(self.args.exp_name)
+
+    def __check_basic_configs(self):
+        if self.cfg.train.gradient_accumulation_step <= 0:
+            self.logger.fatal("Invalid gradient_accumulation_step value!")
+            self.logger.error(
+                f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
+            )
+            self.accelerator.end_training()
+            raise ValueError(
+                f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
+            )
+        # TODO: check other values
+
+    @staticmethod
+    def __count_parameters(model):
+        model_param = 0.0
+        if isinstance(model, dict):
+            for key, value in model.items():
+                model_param += sum(p.numel() for p in model[key].parameters())
+        else:
+            model_param = sum(p.numel() for p in model.parameters())
+        return model_param
+
+    def __dump_cfg(self, path):
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+        json5.dump(
+            self.cfg,
+            open(path, "w"),
+            indent=4,
+            sort_keys=True,
+            ensure_ascii=False,
+            quote_keys=True,
+        )
+
+    ### Private methods end ###
diff --git a/models/svc/__init__.py b/models/svc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/svc/base/__init__.py b/models/svc/base/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..38c2b1686db550b3b9892b8bc6e594cd847aafd1
--- /dev/null
+++ b/models/svc/base/__init__.py
@@ -0,0 +1,7 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .svc_inference import SVCInference
+from .svc_trainer import SVCTrainer
diff --git a/models/svc/base/svc_dataset.py b/models/svc/base/svc_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..cad356bd97b92b9917db6c973b31a42552f5fa76
--- /dev/null
+++ b/models/svc/base/svc_dataset.py
@@ -0,0 +1,425 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import random
+import torch
+from torch.nn.utils.rnn import pad_sequence
+import json
+import os
+import numpy as np
+from utils.data_utils import *
+from processors.acoustic_extractor import cal_normalized_mel, load_mel_extrema
+from processors.content_extractor import (
+    ContentvecExtractor,
+    WhisperExtractor,
+    WenetExtractor,
+)
+from models.base.base_dataset import (
+    BaseCollator,
+    BaseDataset,
+)
+from models.base.new_dataset import BaseTestDataset
+
+EPS = 1.0e-12
+
+
+class SVCDataset(BaseDataset):
+    def __init__(self, cfg, dataset, is_valid=False):
+        BaseDataset.__init__(self, cfg, dataset, is_valid=is_valid)
+
+        cfg = self.cfg
+
+        if cfg.model.condition_encoder.use_whisper:
+            self.whisper_aligner = WhisperExtractor(self.cfg)
+            self.utt2whisper_path = load_content_feature_path(
+                self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.whisper_dir
+            )
+
+        if cfg.model.condition_encoder.use_contentvec:
+            self.contentvec_aligner = ContentvecExtractor(self.cfg)
+            self.utt2contentVec_path = load_content_feature_path(
+                self.metadata,
+                cfg.preprocess.processed_dir,
+                cfg.preprocess.contentvec_dir,
+            )
+
+        if cfg.model.condition_encoder.use_mert:
+            self.utt2mert_path = load_content_feature_path(
+                self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.mert_dir
+            )
+        if cfg.model.condition_encoder.use_wenet:
+            self.wenet_aligner = WenetExtractor(self.cfg)
+            self.utt2wenet_path = load_content_feature_path(
+                self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.wenet_dir
+            )
+
+    def __getitem__(self, index):
+        single_feature = BaseDataset.__getitem__(self, index)
+
+        utt_info = self.metadata[index]
+        dataset = utt_info["Dataset"]
+        uid = utt_info["Uid"]
+        utt = "{}_{}".format(dataset, uid)
+
+        if self.cfg.model.condition_encoder.use_whisper:
+            assert "target_len" in single_feature.keys()
+            aligned_whisper_feat = self.whisper_aligner.offline_align(
+                np.load(self.utt2whisper_path[utt]), single_feature["target_len"]
+            )
+            single_feature["whisper_feat"] = aligned_whisper_feat
+
+        if self.cfg.model.condition_encoder.use_contentvec:
+            assert "target_len" in single_feature.keys()
+            aligned_contentvec = self.contentvec_aligner.offline_align(
+                np.load(self.utt2contentVec_path[utt]), single_feature["target_len"]
+            )
+            single_feature["contentvec_feat"] = aligned_contentvec
+
+        if self.cfg.model.condition_encoder.use_mert:
+            assert "target_len" in single_feature.keys()
+            aligned_mert_feat = align_content_feature_length(
+                np.load(self.utt2mert_path[utt]),
+                single_feature["target_len"],
+                source_hop=self.cfg.preprocess.mert_hop_size,
+            )
+            single_feature["mert_feat"] = aligned_mert_feat
+
+        if self.cfg.model.condition_encoder.use_wenet:
+            assert "target_len" in single_feature.keys()
+            aligned_wenet_feat = self.wenet_aligner.offline_align(
+                np.load(self.utt2wenet_path[utt]), single_feature["target_len"]
+            )
+            single_feature["wenet_feat"] = aligned_wenet_feat
+
+        # print(single_feature.keys())
+        # for k, v in single_feature.items():
+        #     if type(v) in [torch.Tensor, np.ndarray]:
+        #         print(k, v.shape)
+        #     else:
+        #         print(k, v)
+        # exit()
+
+        return self.clip_if_too_long(single_feature)
+
+    def __len__(self):
+        return len(self.metadata)
+
+    def random_select(self, feature_seq_len, max_seq_len, ending_ts=2812):
+        """
+        ending_ts: to avoid invalid whisper features for audios longer than 30s
+            (2812 frames = 30s * 24000 Hz sample rate / 256 hop size)
+        """
+        ts = max(feature_seq_len - max_seq_len, 0)
+        ts = min(ts, ending_ts - max_seq_len)
+
+        start = random.randint(0, ts)
+        end = start + max_seq_len
+        return start, end
+
+    def clip_if_too_long(self, sample, max_seq_len=512):
+        """
+        sample:
+            {
+                'spk_id': (1,),
+                'target_len': int,
+                'mel': (seq_len, dim),
+                'frame_pitch': (seq_len,),
+                'frame_energy': (seq_len,),
+                'content_vector_feat': (seq_len, dim),
+            }
+        """
+        if sample["target_len"] <= max_seq_len:
+            return sample
+
+        start, end = self.random_select(sample["target_len"], max_seq_len)
+        sample["target_len"] = end - start
+
+        for k in sample.keys():
+            if k not in ["spk_id", "target_len"]:
+                sample[k] = sample[k][start:end]
+
+        return sample
+
+
+class SVCCollator(BaseCollator):
+    """Zero-pads model inputs and targets based on number of frames per step"""
+
+    def __init__(self, cfg):
+        BaseCollator.__init__(self, cfg)
+
+    def __call__(self, batch):
+        parsed_batch_features = BaseCollator.__call__(self, batch)
+        return parsed_batch_features
+
+
+class SVCTestDataset(BaseTestDataset):
+    def __init__(self, args, cfg, infer_type):
+        BaseTestDataset.__init__(self, args, cfg, infer_type)
+        self.metadata = self.get_metadata()
+
+        target_singer = args.target_singer
+        self.cfg = cfg
+        self.trans_key = args.trans_key
+        assert type(target_singer) == str
+
+        self.target_singer = target_singer.split("_")[-1]
+        self.target_dataset = target_singer.replace(
+            "_{}".format(self.target_singer), ""
+        )
+
+        self.target_mel_extrema = load_mel_extrema(cfg.preprocess, self.target_dataset)
+        self.target_mel_extrema = torch.as_tensor(
+            self.target_mel_extrema[0]
+        ), torch.as_tensor(self.target_mel_extrema[1])
+
+        ######### Load source acoustic features #########
+        if cfg.preprocess.use_spkid:
+            spk2id_path = os.path.join(args.acoustics_dir, cfg.preprocess.spk2id)
+            # utt2sp_path = os.path.join(self.data_root, cfg.preprocess.utt2spk)
+
+            with open(spk2id_path, "r") as f:
+                self.spk2id = json.load(f)
+            # print("self.spk2id", self.spk2id)
+
+        if cfg.preprocess.use_uv:
+            self.utt2uv_path = {
+                f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join(
+                    cfg.preprocess.processed_dir,
+                    utt_info["Dataset"],
+                    cfg.preprocess.uv_dir,
+                    utt_info["Uid"] + ".npy",
+                )
+                for utt_info in self.metadata
+            }
+
+        if cfg.preprocess.use_frame_pitch:
+            self.utt2frame_pitch_path = {
+                f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join(
+                    cfg.preprocess.processed_dir,
+                    utt_info["Dataset"],
+                    cfg.preprocess.pitch_dir,
+                    utt_info["Uid"] + ".npy",
+                )
+                for utt_info in self.metadata
+            }
+
+            # Target F0 median
+            target_f0_statistics_path = os.path.join(
+                cfg.preprocess.processed_dir,
+                self.target_dataset,
+                cfg.preprocess.pitch_dir,
+                "statistics.json",
+            )
+            self.target_pitch_median = json.load(open(target_f0_statistics_path, "r"))[
+                f"{self.target_dataset}_{self.target_singer}"
+            ]["voiced_positions"]["median"]
+
+            # Source F0 median (if infer from file)
+            if infer_type == "from_file":
+                source_audio_name = cfg.inference.source_audio_name
+                source_f0_statistics_path = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    source_audio_name,
+                    cfg.preprocess.pitch_dir,
+                    "statistics.json",
+                )
+                self.source_pitch_median = json.load(
+                    open(source_f0_statistics_path, "r")
+                )[f"{source_audio_name}_{source_audio_name}"]["voiced_positions"][
+                    "median"
+                ]
+            else:
+                self.source_pitch_median = None
+
+        if cfg.preprocess.use_frame_energy:
+            self.utt2frame_energy_path = {
+                f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join(
+                    cfg.preprocess.processed_dir,
+                    utt_info["Dataset"],
+                    cfg.preprocess.energy_dir,
+                    utt_info["Uid"] + ".npy",
+                )
+                for utt_info in self.metadata
+            }
+
+        if cfg.preprocess.use_mel:
+            self.utt2mel_path = {
+                f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join(
+                    cfg.preprocess.processed_dir,
+                    utt_info["Dataset"],
+                    cfg.preprocess.mel_dir,
+                    utt_info["Uid"] + ".npy",
+                )
+                for utt_info in self.metadata
+            }
+
+        ######### Load source content features' path #########
+        if cfg.model.condition_encoder.use_whisper:
+            self.whisper_aligner = WhisperExtractor(cfg)
+            self.utt2whisper_path = load_content_feature_path(
+                self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.whisper_dir
+            )
+
+        if cfg.model.condition_encoder.use_contentvec:
+            self.contentvec_aligner = ContentvecExtractor(cfg)
+            self.utt2contentVec_path = load_content_feature_path(
+                self.metadata,
+                cfg.preprocess.processed_dir,
+                cfg.preprocess.contentvec_dir,
+            )
+
+        if cfg.model.condition_encoder.use_mert:
+            self.utt2mert_path = load_content_feature_path(
+                self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.mert_dir
+            )
+        if cfg.model.condition_encoder.use_wenet:
+            self.wenet_aligner = WenetExtractor(cfg)
+            self.utt2wenet_path = load_content_feature_path(
+                self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.wenet_dir
+            )
+
+    def __getitem__(self, index):
+        single_feature = {}
+
+        utt_info = self.metadata[index]
+        dataset = utt_info["Dataset"]
+        uid = utt_info["Uid"]
+        utt = "{}_{}".format(dataset, uid)
+
+        source_dataset = self.metadata[index]["Dataset"]
+
+        if self.cfg.preprocess.use_spkid:
+            single_feature["spk_id"] = np.array(
+                [self.spk2id[f"{self.target_dataset}_{self.target_singer}"]],
+                dtype=np.int32,
+            )
+
+        ######### Get Acoustic Features Item #########
+        if self.cfg.preprocess.use_mel:
+            mel = np.load(self.utt2mel_path[utt])
+            assert mel.shape[0] == self.cfg.preprocess.n_mel  # [n_mels, T]
+            if self.cfg.preprocess.use_min_max_norm_mel:
+                # mel norm
+                mel = cal_normalized_mel(mel, source_dataset, self.cfg.preprocess)
+
+            if "target_len" not in single_feature.keys():
+                single_feature["target_len"] = mel.shape[1]
+            single_feature["mel"] = mel.T  # [T, n_mels]
+
+        if self.cfg.preprocess.use_frame_pitch:
+            frame_pitch_path = self.utt2frame_pitch_path[utt]
+            frame_pitch = np.load(frame_pitch_path)
+
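+            # Key-shift control: an integer `trans_key` transposes the F0 contour by that
+            # many semitones; any other truthy value shifts the source F0 so that its
+            # median matches the target singer's voiced-F0 median.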
+            if self.trans_key:
+                try:
+                    self.trans_key = int(self.trans_key)
+                except:
+                    pass
+                if type(self.trans_key) == int:
+                    frame_pitch = transpose_key(frame_pitch, self.trans_key)
+                elif self.trans_key:
+                    assert self.target_singer
+
+                    frame_pitch = pitch_shift_to_target(
+                        frame_pitch, self.target_pitch_median, self.source_pitch_median
+                    )
+
+            if "target_len" not in single_feature.keys():
+                single_feature["target_len"] = len(frame_pitch)
+            aligned_frame_pitch = align_length(
+                frame_pitch, single_feature["target_len"]
+            )
+            single_feature["frame_pitch"] = aligned_frame_pitch
+
+            if self.cfg.preprocess.use_uv:
+                frame_uv_path = self.utt2uv_path[utt]
+                frame_uv = np.load(frame_uv_path)
+                aligned_frame_uv = align_length(frame_uv, single_feature["target_len"])
+                aligned_frame_uv = [
+                    0 if frame_uv else 1 for frame_uv in aligned_frame_uv
+                ]
+                aligned_frame_uv = np.array(aligned_frame_uv)
+                single_feature["frame_uv"] = aligned_frame_uv
+
+        if self.cfg.preprocess.use_frame_energy:
+            frame_energy_path = self.utt2frame_energy_path[utt]
+            frame_energy = np.load(frame_energy_path)
+            if "target_len" not in single_feature.keys():
+                single_feature["target_len"] = len(frame_energy)
+            aligned_frame_energy = align_length(
+                frame_energy, single_feature["target_len"]
+            )
+            single_feature["frame_energy"] = aligned_frame_energy
+
+        ######### Get Content Features Item #########
+        if self.cfg.model.condition_encoder.use_whisper:
+            assert "target_len" in single_feature.keys()
+            aligned_whisper_feat = self.whisper_aligner.offline_align(
+                np.load(self.utt2whisper_path[utt]), single_feature["target_len"]
+            )
+            single_feature["whisper_feat"] = aligned_whisper_feat
+
+        if self.cfg.model.condition_encoder.use_contentvec:
+            assert "target_len" in single_feature.keys()
+            aligned_contentvec = self.contentvec_aligner.offline_align(
+                np.load(self.utt2contentVec_path[utt]), single_feature["target_len"]
+            )
+            single_feature["contentvec_feat"] = aligned_contentvec
+
+        if self.cfg.model.condition_encoder.use_mert:
+            assert "target_len" in single_feature.keys()
+            aligned_mert_feat = align_content_feature_length(
+                np.load(self.utt2mert_path[utt]),
+                single_feature["target_len"],
+                source_hop=self.cfg.preprocess.mert_hop_size,
+            )
+            single_feature["mert_feat"] = aligned_mert_feat
+
+        if self.cfg.model.condition_encoder.use_wenet:
+            assert "target_len" in single_feature.keys()
+            aligned_wenet_feat = self.wenet_aligner.offline_align(
+                np.load(self.utt2wenet_path[utt]), single_feature["target_len"]
+            )
+            single_feature["wenet_feat"] = aligned_wenet_feat
+
+        return single_feature
+
+    def __len__(self):
+        return len(self.metadata)
+
+
+class SVCTestCollator:
+    """Zero-pads model inputs and targets based on number of frames per step"""
+
+    def __init__(self, cfg):
+        self.cfg = cfg
+
+    def __call__(self, batch):
+        packed_batch_features = dict()
+
+        # mel: [b, T, n_mels]
+        # frame_pitch, frame_energy: [1, T]
+        # target_len: [1]
+        # spk_id: [b, 1]
+        # mask: [b, T, 1]
+
+        for key in batch[0].keys():
+            if key == "target_len":
+                packed_batch_features["target_len"] = torch.LongTensor(
+                    [b["target_len"] for b in batch]
+                )
+                masks = [
+                    torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
+                ]
+                packed_batch_features["mask"] = pad_sequence(
+                    masks, batch_first=True, padding_value=0
+                )
+            else:
+                values = [torch.from_numpy(b[key]) for b in batch]
+                packed_batch_features[key] = pad_sequence(
+                    values, batch_first=True, padding_value=0
+                )
+
+        return packed_batch_features
diff --git a/models/svc/base/svc_inference.py b/models/svc/base/svc_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..52f88d5d915e1616292c03927b4f51557351f58b
--- /dev/null
+++ b/models/svc/base/svc_inference.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from models.base.new_inference import BaseInference
+from models.svc.base.svc_dataset import SVCTestCollator, SVCTestDataset
+
+
+class SVCInference(BaseInference):
+    def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
+        BaseInference.__init__(self, args, cfg, infer_type)
+
+    def _build_test_dataset(self):
+        return SVCTestDataset, SVCTestCollator
diff --git a/models/svc/base/svc_trainer.py b/models/svc/base/svc_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2a093a86712bb7ccfa786a6c18dd1683ffc013c
--- /dev/null
+++ b/models/svc/base/svc_trainer.py
@@ -0,0 +1,111 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+
+import torch
+import torch.nn as nn
+
+from models.base.new_trainer import BaseTrainer
+from models.svc.base.svc_dataset import SVCCollator, SVCDataset
+
+
+class SVCTrainer(BaseTrainer):
+    r"""The base trainer for all SVC models. It inherits from BaseTrainer and implements
+    ``build_criterion``, ``_build_dataset`` and ``_build_singer_lut`` methods. You can inherit from this
+    class, and implement ``_build_model``, ``_forward_step``.
+    """
+
+    def __init__(self, args=None, cfg=None):
+        self.args = args
+        self.cfg = cfg
+
+        self._init_accelerator()
+
+        # Only for SVC tasks
+        with self.accelerator.main_process_first():
+            self.singers = self._build_singer_lut()
+
+        # Super init
+        BaseTrainer.__init__(self, args, cfg)
+
+        # Only for SVC tasks
+        self.task_type = "SVC"
+        self.logger.info("Task type: {}".format(self.task_type))
+
+    ### Following are methods only for SVC tasks ###
+    # TODO: LEGACY CODE, NEED TO BE REFACTORED
+    def _build_dataset(self):
+        return SVCDataset, SVCCollator
+
+    @staticmethod
+    def _build_criterion():
+        criterion = nn.MSELoss(reduction="none")
+        return criterion
+
+    @staticmethod
+    def _compute_loss(criterion, y_pred, y_gt, loss_mask):
+        """
+        Args:
+            criterion: MSELoss(reduction='none')
+            y_pred, y_gt: (bs, seq_len, D)
+            loss_mask: (bs, seq_len, 1)
+        Returns:
+            loss: Tensor of shape []
+        """
+
+        # (bs, seq_len, D)
+        loss = criterion(y_pred, y_gt)
+        # expand loss_mask to (bs, seq_len, D)
+        loss_mask = loss_mask.repeat(1, 1, loss.shape[-1])
+
+        loss = torch.sum(loss * loss_mask) / torch.sum(loss_mask)
+        return loss
+
+    def _save_auxiliary_states(self):
+        """
+        To save the singer's look-up table in the checkpoint saving path
+        """
+        with open(
+            os.path.join(self.tmp_checkpoint_save_path, self.cfg.preprocess.spk2id), "w"
+        ) as f:
+            json.dump(self.singers, f, indent=4, ensure_ascii=False)
+
+    def _build_singer_lut(self):
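+        """Merge the per-dataset spk2id maps into a single global singer look-up table.
+        When resuming, an existing table is loaded first so that singer ids stay stable,
+        and the merged table is dumped to the experiment directory.
+        """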
+        resumed_singer_path = None
+        if self.args.resume_from_ckpt_path and self.args.resume_from_ckpt_path != "":
+            resumed_singer_path = os.path.join(
+                self.args.resume_from_ckpt_path, self.cfg.preprocess.spk2id
+            )
+        if os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)):
+            resumed_singer_path = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
+
+        if resumed_singer_path:
+            with open(resumed_singer_path, "r") as f:
+                singers = json.load(f)
+        else:
+            singers = dict()
+
+        for dataset in self.cfg.dataset:
+            singer_lut_path = os.path.join(
+                self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
+            )
+            with open(singer_lut_path, "r") as singer_lut_file:
+                singer_lut = json.load(singer_lut_file)
+            for singer in singer_lut.keys():
+                if singer not in singers:
+                    singers[singer] = len(singers)
+
+        with open(
+            os.path.join(self.exp_dir, self.cfg.preprocess.spk2id), "w"
+        ) as singer_file:
+            json.dump(singers, singer_file, indent=4, ensure_ascii=False)
+        print(
+            "singers have been dumped to {}".format(
+                os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
+            )
+        )
+        return singers
diff --git a/models/svc/comosvc/__init__.py b/models/svc/comosvc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..19f1cb162e95d8a992002beaa0c0d8bada9cddd5
--- /dev/null
+++ b/models/svc/comosvc/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/models/svc/comosvc/comosvc.py b/models/svc/comosvc/comosvc.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cecd7a3f40f3a78f0df06ef2340159d321d6117
--- /dev/null
+++ b/models/svc/comosvc/comosvc.py
@@ -0,0 +1,377 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Adapted from https://github.com/zhenye234/CoMoSpeech"""
+
+import torch
+import torch.nn as nn
+import copy
+import numpy as np
+import math
+from tqdm.auto import tqdm
+
+from utils.ssim import SSIM
+
+from models.svc.transformer.conformer import Conformer, BaseModule
+from models.svc.diffusion.diffusion_wrapper import DiffusionWrapper
+from models.svc.comosvc.utils import slice_segments, rand_ids_segments
+
+
+class Consistency(nn.Module):
+    def __init__(self, cfg, distill=False):
+        super().__init__()
+        self.cfg = cfg
+        # self.denoise_fn = GradLogPEstimator2d(96)
+        self.denoise_fn = DiffusionWrapper(self.cfg)
+        self.cfg = cfg.model.comosvc
+        self.teacher = not distill
+        self.P_mean = self.cfg.P_mean
+        self.P_std = self.cfg.P_std
+        self.sigma_data = self.cfg.sigma_data
+        self.sigma_min = self.cfg.sigma_min
+        self.sigma_max = self.cfg.sigma_max
+        self.rho = self.cfg.rho
+        self.N = self.cfg.n_timesteps
+        self.ssim_loss = SSIM()
+
+        # Time step discretization
+        step_indices = torch.arange(self.N)
+        # karras boundaries formula
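+        #   t_i = (sigma_min^(1/rho) + i/(N-1) * (sigma_max^(1/rho) - sigma_min^(1/rho)))^rho,
+        #   for i = 0..N-1, later prepended with t = 0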
+        t_steps = (
+            self.sigma_min ** (1 / self.rho)
+            + step_indices
+            / (self.N - 1)
+            * (self.sigma_max ** (1 / self.rho) - self.sigma_min ** (1 / self.rho))
+        ) ** self.rho
+        self.t_steps = torch.cat(
+            [torch.zeros_like(t_steps[:1]), self.round_sigma(t_steps)]
+        )
+
+    def init_consistency_training(self):
+        self.denoise_fn_ema = copy.deepcopy(self.denoise_fn)
+        self.denoise_fn_pretrained = copy.deepcopy(self.denoise_fn)
+
+    def EDMPrecond(self, x, sigma, cond, denoise_fn, mask, spk=None):
+        """
+        karras diffusion reverse process
+
+        Args:
+            x: noisy mel-spectrogram [B x n_mel x L]
+            sigma: noise level [B x 1 x 1]
+            cond: output of conformer encoder [B x n_mel x L]
+            denoise_fn: denoiser neural network e.g. DilatedCNN
+            mask: mask of padded frames [B x n_mel x L]
+
+        Returns:
+            denoised mel-spectrogram [B x n_mel x L]
+        """
+        sigma = sigma.reshape(-1, 1, 1)
+
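+        # EDM preconditioning (Karras et al., 2022), matching the code below:
+        #   c_skip  = sigma_data^2 / (sigma^2 + sigma_data^2)
+        #   c_out   = sigma * sigma_data / sqrt(sigma^2 + sigma_data^2)
+        #   c_in    = 1 / sqrt(sigma_data^2 + sigma^2)
+        #   c_noise = ln(sigma) / 4
+        # giving D(x) = c_skip * x + c_out * F(c_in * x, c_noise, cond).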
+        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
+        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
+        c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
+        c_noise = sigma.log() / 4
+
+        x_in = c_in * x
+        x_in = x_in.transpose(1, 2)
+        x = x.transpose(1, 2)
+        cond = cond.transpose(1, 2)
+        F_x = denoise_fn(x_in, c_noise.squeeze(), cond)
+        # F_x =  denoise_fn((c_in * x), mask, cond, c_noise.flatten())
+        D_x = c_skip * x + c_out * (F_x)
+        D_x = D_x.transpose(1, 2)
+        return D_x
+
+    def EDMLoss(self, x_start, cond, mask):
+        """
+        compute loss for EDM model
+
+        Args:
+            x_start: ground truth mel-spectrogram [B x n_mel x L]
+            cond: output of conformer encoder [B x n_mel x L]
+            mask: mask of padded frames [B x n_mel x L]
+        """
+        rnd_normal = torch.randn([x_start.shape[0], 1, 1], device=x_start.device)
+        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
+        weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
+
+        # follow Grad-TTS, start from Gaussian noise with mean cond and std I
+        noise = (torch.randn_like(x_start) + cond) * sigma
+        D_yn = self.EDMPrecond(x_start + noise, sigma, cond, self.denoise_fn, mask)
+        loss = weight * ((D_yn - x_start) ** 2)
+        loss = torch.sum(loss * mask) / torch.sum(mask)
+        return loss
+
+    def round_sigma(self, sigma):
+        return torch.as_tensor(sigma)
+
+    def edm_sampler(
+        self,
+        latents,
+        cond,
+        nonpadding,
+        num_steps=50,
+        sigma_min=0.002,
+        sigma_max=80,
+        rho=7,
+        S_churn=0,
+        S_min=0,
+        S_max=float("inf"),
+        S_noise=1,
+        # S_churn=40 ,S_min=0.05,S_max=50,S_noise=1.003,# S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
+        # S_churn=30 ,S_min=0.01,S_max=30,S_noise=1.007,
+        # S_churn=30 ,S_min=0.01,S_max=1,S_noise=1.007,
+        # S_churn=80 ,S_min=0.05,S_max=50,S_noise=1.003,
+    ):
+        """
+        karras diffusion sampler
+
+        Args:
+            latents: noisy mel-spectrogram [B x n_mel x L]
+            cond: output of conformer encoder [B x n_mel x L]
+            nonpadding: mask of padded frames [B x n_mel x L]
+            num_steps: number of steps for diffusion inference
+
+        Returns:
+            denoised mel-spectrogram [B x n_mel x L]
+        """
+        # Time step discretization.
+        step_indices = torch.arange(num_steps, device=latents.device)
+
+        num_steps = num_steps + 1
+        t_steps = (
+            sigma_max ** (1 / rho)
+            + step_indices
+            / (num_steps - 1)
+            * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
+        ) ** rho
+        t_steps = torch.cat([self.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])
+
+        # Main sampling loop.
+        x_next = latents * t_steps[0]
+        # wrap in tqdm for progress bar
+        bar = tqdm(enumerate(zip(t_steps[:-1], t_steps[1:])))
+        for i, (t_cur, t_next) in bar:
+            x_cur = x_next
+            # Increase noise temporarily.
+            gamma = (
+                min(S_churn / num_steps, np.sqrt(2) - 1)
+                if S_min <= t_cur <= S_max
+                else 0
+            )
+            t_hat = self.round_sigma(t_cur + gamma * t_cur)
+            t = torch.zeros((x_cur.shape[0], 1, 1), device=x_cur.device)
+            t[:, 0, 0] = t_hat
+            t_hat = t
+            x_hat = x_cur + (
+                t_hat**2 - t_cur**2
+            ).sqrt() * S_noise * torch.randn_like(x_cur)
+            # Euler step.
+            denoised = self.EDMPrecond(x_hat, t_hat, cond, self.denoise_fn, nonpadding)
+            d_cur = (x_hat - denoised) / t_hat
+            x_next = x_hat + (t_next - t_hat) * d_cur
+
+        return x_next
+
+    def CTLoss_D(self, y, cond, mask):
+        """
+        compute loss for consistency distillation
+
+        Args:
+            y: ground truth mel-spectrogram [B x n_mel x L]
+            cond: output of conformer encoder [B x n_mel x L]
+            mask: mask of padded frames [B x n_mel x L]
+        """
+        with torch.no_grad():
+            mu = 0.95
+            for p, ema_p in zip(
+                self.denoise_fn.parameters(), self.denoise_fn_ema.parameters()
+            ):
+                ema_p.mul_(mu).add_(p, alpha=1 - mu)
+
+        n = torch.randint(1, self.N, (y.shape[0],))
+        z = torch.randn_like(y) + cond
+
+        tn_1 = self.t_steps[n + 1].reshape(-1, 1, 1).to(y.device)
+        f_theta = self.EDMPrecond(y + tn_1 * z, tn_1, cond, self.denoise_fn, mask)
+
+        with torch.no_grad():
+            tn = self.t_steps[n].reshape(-1, 1, 1).to(y.device)
+
+            # One Euler step of the teacher ODE from t_{n+1} down to t_n, using the frozen pretrained teacher
+            x_hat = y + tn_1 * z
+            denoised = self.EDMPrecond(
+                x_hat, tn_1, cond, self.denoise_fn_pretrained, mask
+            )
+            d_cur = (x_hat - denoised) / tn_1
+            y_tn = x_hat + (tn - tn_1) * d_cur
+
+            f_theta_ema = self.EDMPrecond(y_tn, tn, cond, self.denoise_fn_ema, mask)
+
+        # loss = (f_theta - f_theta_ema.detach()) ** 2
+        # loss = torch.sum(loss * mask) / torch.sum(mask)
+        loss = self.ssim_loss(f_theta, f_theta_ema.detach())
+        loss = torch.sum(loss * mask) / torch.sum(mask)
+
+        return loss
+
+    def get_t_steps(self, N):
+        N = N + 1
+        step_indices = torch.arange(N)  # , device=latents.device)
+        t_steps = (
+            self.sigma_min ** (1 / self.rho)
+            + step_indices
+            / (N - 1)
+            * (self.sigma_max ** (1 / self.rho) - self.sigma_min ** (1 / self.rho))
+        ) ** self.rho
+
+        return t_steps.flip(0)
+
+    def CT_sampler(self, latents, cond, nonpadding, t_steps=1):
+        """
+        consistency distillation sampler
+
+        Args:
+            latents: noisy mel-spectrogram [B x n_mel x L]
+            cond: output of conformer encoder [B x n_mel x L]
+            nonpadding: mask of padded frames [B x n_mel x L]
+            t_steps: number of steps for diffusion inference
+
+        Returns:
+            denoised mel-spectrogram [B x n_mel x L]
+        """
+        # one-step
+        if t_steps == 1:
+            t_steps = [80]
+        # multi-step
+        else:
+            t_steps = self.get_t_steps(t_steps)
+
+        t_steps = torch.as_tensor(t_steps).to(latents.device)
+        latents = latents * t_steps[0]
+        _t = torch.zeros((latents.shape[0], 1, 1), device=latents.device)
+        _t[:, 0, 0] = t_steps[0]
+        x = self.EDMPrecond(latents, _t, cond, self.denoise_fn_ema, nonpadding)
+
+        # Multi-step sampling: re-noise to sigma = t, then apply the consistency model again
+        for t in t_steps[1:-1]:
+            z = torch.randn_like(x) + cond
+            x_tn = x + (t**2 - self.sigma_min**2).sqrt() * z
+            _t = torch.zeros((x.shape[0], 1, 1), device=x.device)
+            _t[:, 0, 0] = t
+            t = _t
+            x = self.EDMPrecond(x_tn, t, cond, self.denoise_fn_ema, nonpadding)
+        return x
+
+    def forward(self, x, nonpadding, cond, t_steps=1, infer=False):
+        """
+        calculate loss or sample mel-spectrogram
+
+        Args:
+            x:
+                training: ground truth mel-spectrogram [B x n_mel x L]
+                inference: output of encoder [B x n_mel x L]
+        """
+        if self.teacher:  # teacher model -- karras diffusion
+            if not infer:
+                loss = self.EDMLoss(x, cond, nonpadding)
+                return loss
+            else:
+                shape = (cond.shape[0], self.cfg.n_mel, cond.shape[2])
+                x = torch.randn(shape, device=x.device) + cond
+                x = self.edm_sampler(x, cond, nonpadding, t_steps)
+
+            return x
+        else:  # Consistency distillation
+            if not infer:
+                loss = self.CTLoss_D(x, cond, nonpadding)
+                return loss
+
+            else:
+                shape = (cond.shape[0], self.cfg.n_mel, cond.shape[2])
+                x = torch.randn(shape, device=x.device) + cond
+                x = self.CT_sampler(x, cond, nonpadding, t_steps=1)
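+                # NOTE: one-step consistency sampling is used here regardless of the t_steps
+                # argument; pass t_steps=t_steps to CT_sampler to enable multi-step sampling.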
+
+            return x
+
+
+class ComoSVC(BaseModule):
+    def __init__(self, cfg):
+        super().__init__()
+        self.cfg = cfg
+        self.cfg.model.comosvc.n_mel = self.cfg.preprocess.n_mel
+        self.distill = self.cfg.model.comosvc.distill
+        self.encoder = Conformer(self.cfg.model.comosvc)
+        self.decoder = Consistency(self.cfg, distill=self.distill)
+        self.ssim_loss = SSIM()
+
+    @torch.no_grad()
+    def forward(self, x_mask, x, n_timesteps, temperature=1.0):
+        """
+        Generates a mel-spectrogram from pitch, content vector, and energy features. Returns:
+            1. encoder outputs (from conformer)
+            2. decoder outputs (from diffusion-based decoder)
+
+        Args:
+            x_mask : mask of padded frames in mel-spectrogram. [B x L x n_mel]
+            x : output of encoder framework. [B x L x d_condition]
+            n_timesteps : number of steps to use for reverse diffusion in decoder.
+            temperature : controls variance of terminal distribution.
+        """
+
+        # Get encoder_outputs `mu_x`
+        mu_x = self.encoder(x, x_mask)
+        encoder_outputs = mu_x
+
+        mu_x = mu_x.transpose(1, 2)
+        x_mask = x_mask.transpose(1, 2)
+
+        # Generate sample by performing reverse dynamics
+        decoder_outputs = self.decoder(
+            mu_x, x_mask, mu_x, t_steps=n_timesteps, infer=True
+        )
+        decoder_outputs = decoder_outputs.transpose(1, 2)
+        return encoder_outputs, decoder_outputs
+
+    def compute_loss(self, x_mask, x, mel, out_size=None, skip_diff=False):
+        """
+        Computes 2 losses:
+            1. prior loss: loss between mel-spectrogram and encoder outputs.
+            2. diffusion loss: loss between gaussian noise and its reconstruction by diffusion-based decoder.
+
+        Args:
+            x_mask : mask of padded frames in mel-spectrogram. [B x L x n_mel]
+            x : output of encoder framework. [B x L x d_condition]
+            mel : ground truth mel-spectrogram. [B x L x n_mel]
+        """
+
+        mu_x = self.encoder(x, x_mask)
+        # prior loss
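+        # (negative log-likelihood of mel under N(mu_x, I), as in Grad-TTS)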
+        prior_loss = torch.sum(
+            0.5 * ((mel - mu_x) ** 2 + math.log(2 * math.pi)) * x_mask
+        )
+        prior_loss = prior_loss / (torch.sum(x_mask) * self.cfg.model.comosvc.n_mel)
+        # ssim loss
+        ssim_loss = self.ssim_loss(mu_x, mel)
+        ssim_loss = torch.sum(ssim_loss * x_mask) / torch.sum(x_mask)
+
+        x_mask = x_mask.transpose(1, 2)
+        mu_x = mu_x.transpose(1, 2)
+        mel = mel.transpose(1, 2)
+        if not self.distill and skip_diff:
+            # Encoder warm-up phase: the diffusion loss is skipped and set to zero
+            diff_loss = prior_loss.clone()
+            diff_loss.fill_(0)
+        else:
+            # For distillation, stop gradients from flowing back into the encoder
+            if self.distill:
+                mu_y = mu_x.detach()
+            else:
+                mu_y = mu_x
+            mask_y = x_mask
+
+            diff_loss = self.decoder(mel, mask_y, mu_y, infer=False)
+
+        return ssim_loss, prior_loss, diff_loss
diff --git a/models/svc/comosvc/comosvc_inference.py b/models/svc/comosvc/comosvc_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..2783ec7e468c367c7d2f5f8988ed1f7e272d4cb7
--- /dev/null
+++ b/models/svc/comosvc/comosvc_inference.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from models.svc.base import SVCInference
+from modules.encoder.condition_encoder import ConditionEncoder
+from models.svc.comosvc.comosvc import ComoSVC
+
+
+class ComoSVCInference(SVCInference):
+    def __init__(self, args, cfg, infer_type="from_dataset"):
+        SVCInference.__init__(self, args, cfg, infer_type)
+
+    def _build_model(self):
+        # TODO: sort out the config
+        self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
+        self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
+        self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
+        self.acoustic_mapper = ComoSVC(self.cfg)
+        if self.cfg.model.comosvc.distill:
+            self.acoustic_mapper.decoder.init_consistency_training()
+        model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
+        return model
+
+    def _inference_each_batch(self, batch_data):
+        device = self.accelerator.device
+        for k, v in batch_data.items():
+            batch_data[k] = v.to(device)
+
+        cond = self.condition_encoder(batch_data)
+        mask = batch_data["mask"]
+        encoder_pred, decoder_pred = self.acoustic_mapper(
+            mask, cond, self.cfg.inference.comosvc.inference_steps
+        )
+
+        return decoder_pred
diff --git a/models/svc/comosvc/comosvc_trainer.py b/models/svc/comosvc/comosvc_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ba49fd4539b8ae351137a85595ff9cfba1f4677
--- /dev/null
+++ b/models/svc/comosvc/comosvc_trainer.py
@@ -0,0 +1,295 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import os
+import json5
+from collections import OrderedDict
+from tqdm import tqdm
+import json
+import shutil
+
+from models.svc.base import SVCTrainer
+from modules.encoder.condition_encoder import ConditionEncoder
+from models.svc.comosvc.comosvc import ComoSVC
+
+
+class ComoSVCTrainer(SVCTrainer):
+    r"""The base trainer for all diffusion models. It inherits from SVCTrainer and
+    implements ``_build_model`` and ``_forward_step`` methods.
+    """
+
+    def __init__(self, args=None, cfg=None):
+        SVCTrainer.__init__(self, args, cfg)
+        self.distill = cfg.model.comosvc.distill
+        self.skip_diff = True
+        if self.distill:  # and args.resume is None:
+            self.teacher_model_path = cfg.model.teacher_model_path
+            self.teacher_state_dict = self._load_teacher_state_dict()
+            self._load_teacher_model(self.teacher_state_dict)
+            self.acoustic_mapper.decoder.init_consistency_training()
+
+    ### Following are methods only for comoSVC models ###
+    def _load_teacher_state_dict(self):
+        self.checkpoint_file = self.teacher_model_path
+        print("Load teacher acoustic model from {}".format(self.checkpoint_file))
+        raw_state_dict = torch.load(self.checkpoint_file)  # , map_location=self.device)
+        return raw_state_dict
+
+    def _load_teacher_model(self, state_dict):
+        raw_dict = state_dict
+        clean_dict = OrderedDict()
+        for k, v in raw_dict.items():
+            if k.startswith("module."):
+                clean_dict[k[7:]] = v
+            else:
+                clean_dict[k] = v
+        self.model.load_state_dict(clean_dict)
+
+    def _build_model(self):
+        r"""Build the model for training. This function is called in ``__init__`` function."""
+
+        # TODO: sort out the config
+        self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
+        self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
+        self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
+        self.acoustic_mapper = ComoSVC(self.cfg)
+        model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
+        return model
+
+    def _forward_step(self, batch):
+        r"""Forward step for training and inference. This function is called
+        in ``_train_step`` & ``_test_step`` function.
+        """
+        loss = {}
+        mask = batch["mask"]
+        mel_input = batch["mel"]
+        cond = self.condition_encoder(batch)
+        if self.distill:
+            cond = cond.detach()
+        # Skip the diffusion loss for the first `fast_steps` steps to warm up the encoder
+        self.skip_diff = self.step < self.cfg.train.fast_steps
+        ssim_loss, prior_loss, diff_loss = self.acoustic_mapper.compute_loss(
+            mask, cond, mel_input, skip_diff=self.skip_diff
+        )
+        if self.distill:
+            loss["distil_loss"] = diff_loss
+        else:
+            loss["ssim_loss_encoder"] = ssim_loss
+            loss["prior_loss_encoder"] = prior_loss
+            loss["diffusion_loss_decoder"] = diff_loss
+
+        return loss
+
+    def _train_epoch(self):
+        r"""Training epoch. Should return average loss of a batch (sample) over
+        one epoch. See ``train_loop`` for usage.
+        """
+        self.model.train()
+        epoch_sum_loss: float = 0.0
+        epoch_step: int = 0
+        for batch in tqdm(
+            self.train_dataloader,
+            desc=f"Training Epoch {self.epoch}",
+            unit="batch",
+            colour="GREEN",
+            leave=False,
+            dynamic_ncols=True,
+            smoothing=0.04,
+            disable=not self.accelerator.is_main_process,
+        ):
+            # Do training step and BP
+            with self.accelerator.accumulate(self.model):
+                loss = self._train_step(batch)
+                total_loss = 0
+                for k, v in loss.items():
+                    total_loss += v
+                self.accelerator.backward(total_loss)
+                enc_grad_norm = torch.nn.utils.clip_grad_norm_(
+                    self.acoustic_mapper.encoder.parameters(), max_norm=1
+                )
+                dec_grad_norm = torch.nn.utils.clip_grad_norm_(
+                    self.acoustic_mapper.decoder.parameters(), max_norm=1
+                )
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+            self.batch_count += 1
+
+            # Update info for each step
+            # TODO: step means BP counts or batch counts?
+            if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
+                epoch_sum_loss += total_loss
+                log_info = {}
+                for k, v in loss.items():
+                    key = "Step/Train Loss/{}".format(k)
+                    log_info[key] = v
+                log_info["Step/Learning Rate"]: self.optimizer.param_groups[0]["lr"]
+                self.accelerator.log(
+                    log_info,
+                    step=self.step,
+                )
+                self.step += 1
+                epoch_step += 1
+
+        self.accelerator.wait_for_everyone()
+        return (
+            epoch_sum_loss
+            / len(self.train_dataloader)
+            * self.cfg.train.gradient_accumulation_step,
+            loss,
+        )
+
+    def train_loop(self):
+        r"""Training loop. The public entry of training process."""
+        # Wait everyone to prepare before we move on
+        self.accelerator.wait_for_everyone()
+        # dump config file
+        if self.accelerator.is_main_process:
+            self.__dump_cfg(self.config_save_path)
+        self.model.train()
+        self.optimizer.zero_grad()
+        # Wait to ensure good to go
+        self.accelerator.wait_for_everyone()
+        while self.epoch < self.max_epoch:
+            self.logger.info("\n")
+            self.logger.info("-" * 32)
+            self.logger.info("Epoch {}: ".format(self.epoch))
+
+            ### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict)
+            ### It's inconvenient for the model with multiple losses
+            # Do training & validating epoch
+            train_loss, loss = self._train_epoch()
+            self.logger.info("  |- Train/Loss: {:.6f}".format(train_loss))
+            for k, v in loss.items():
+                self.logger.info("  |- Train/Loss/{}: {:.6f}".format(k, v))
+            valid_loss = self._valid_epoch()
+            self.logger.info("  |- Valid/Loss: {:.6f}".format(valid_loss))
+            self.accelerator.log(
+                {"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss},
+                step=self.epoch,
+            )
+
+            self.accelerator.wait_for_everyone()
+            # TODO: document which LR scheduler is expected here
+            self.scheduler.step(valid_loss)  # FIXME: is stepping once per epoch on the valid loss the intended behavior?
+
+            # Check if hit save_checkpoint_stride and run_eval
+            run_eval = False
+            if self.accelerator.is_main_process:
+                save_checkpoint = False
+                hit_dix = []
+                for i, num in enumerate(self.save_checkpoint_stride):
+                    if self.epoch % num == 0:
+                        save_checkpoint = True
+                        hit_dix.append(i)
+                        run_eval |= self.run_eval[i]
+
+            self.accelerator.wait_for_everyone()
+            if (
+                self.accelerator.is_main_process
+                and save_checkpoint
+                and (self.distill or not self.skip_diff)
+            ):
+                path = os.path.join(
+                    self.checkpoint_dir,
+                    "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+                        self.epoch, self.step, train_loss
+                    ),
+                )
+                self.accelerator.save_state(path)
+                json.dump(
+                    self.checkpoints_path,
+                    open(os.path.join(path, "ckpts.json"), "w"),
+                    ensure_ascii=False,
+                    indent=4,
+                )
+
+                # Remove old checkpoints
+                to_remove = []
+                for idx in hit_dix:
+                    self.checkpoints_path[idx].append(path)
+                    while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
+                        to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
+
+                # Search conflicts
+                total = set()
+                for i in self.checkpoints_path:
+                    total |= set(i)
+                do_remove = set()
+                for idx, path in to_remove[::-1]:
+                    if path in total:
+                        self.checkpoints_path[idx].insert(0, path)
+                    else:
+                        do_remove.add(path)
+
+                # Remove old checkpoints
+                for path in do_remove:
+                    shutil.rmtree(path, ignore_errors=True)
+                    self.logger.debug(f"Remove old checkpoint: {path}")
+
+            self.accelerator.wait_for_everyone()
+            if run_eval:
+                # TODO: run evaluation
+                pass
+
+            # Update info for each epoch
+            self.epoch += 1
+
+        # Finish training and save final checkpoint
+        self.accelerator.wait_for_everyone()
+        if self.accelerator.is_main_process:
+            self.accelerator.save_state(
+                os.path.join(
+                    self.checkpoint_dir,
+                    "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+                        self.epoch, self.step, valid_loss
+                    ),
+                )
+            )
+        self.accelerator.end_training()
+
+    @torch.inference_mode()
+    def _valid_epoch(self):
+        r"""Testing epoch. Should return average loss of a batch (sample) over
+        one epoch. See ``train_loop`` for usage.
+        """
+        self.model.eval()
+        epoch_sum_loss = 0.0
+        for batch in tqdm(
+            self.valid_dataloader,
+            desc=f"Validating Epoch {self.epoch}",
+            unit="batch",
+            colour="GREEN",
+            leave=False,
+            dynamic_ncols=True,
+            smoothing=0.04,
+            disable=not self.accelerator.is_main_process,
+        ):
+            batch_loss = self._valid_step(batch)
+            for k, v in batch_loss.items():
+                epoch_sum_loss += v
+
+        self.accelerator.wait_for_everyone()
+        return epoch_sum_loss / len(self.valid_dataloader)
+
+    @staticmethod
+    def __count_parameters(model):
+        model_param = 0.0
+        if isinstance(model, dict):
+            for key, value in model.items():
+                model_param += sum(p.numel() for p in model[key].parameters())
+        else:
+            model_param = sum(p.numel() for p in model.parameters())
+        return model_param
+
+    def __dump_cfg(self, path):
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+        json5.dump(
+            self.cfg,
+            open(path, "w"),
+            indent=4,
+            sort_keys=True,
+            ensure_ascii=False,
+            quote_keys=True,
+        )
diff --git a/models/svc/comosvc/utils.py b/models/svc/comosvc/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f576f9a237d0a22ddfdb160122b906da9bcf889
--- /dev/null
+++ b/models/svc/comosvc/utils.py
@@ -0,0 +1,31 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+
+def slice_segments(x, ids_str, segment_size=200):
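+    # Slice a `segment_size`-frame window from each item of a batched feature tensor
+    # x of shape [B, C, T], starting at frame ids_str[i] for the i-th item.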
+    ret = torch.zeros_like(x[:, :, :segment_size])
+    for i in range(x.size(0)):
+        idx_str = ids_str[i]
+        idx_end = idx_str + segment_size
+        ret[i] = x[i, :, idx_str:idx_end]
+    return ret
+
+
+def rand_ids_segments(lengths, segment_size=200):
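+    # Draw a random start frame for each item so that a full `segment_size`-frame
+    # window fits inside the item's valid length.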
+    b = lengths.shape[0]
+    ids_str_max = lengths - segment_size
+    ids_str = (torch.rand([b]).to(device=lengths.device) * ids_str_max).to(
+        dtype=torch.long
+    )
+    return ids_str
+
+
+def fix_len_compatibility(length, num_downsamplings_in_unet=2):
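+    # Round `length` up to the nearest multiple of 2**num_downsamplings_in_unet so the
+    # U-Net's down/up-sampling path preserves the time dimension.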
+    while True:
+        if length % (2**num_downsamplings_in_unet) == 0:
+            return length
+        length += 1
diff --git a/models/svc/diffusion/__init__.py b/models/svc/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/svc/diffusion/diffusion_inference.py b/models/svc/diffusion/diffusion_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..a752ef8f195b59d0ac0ad402dc35ce5840626ab9
--- /dev/null
+++ b/models/svc/diffusion/diffusion_inference.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler
+
+from models.svc.base import SVCInference
+from models.svc.diffusion.diffusion_inference_pipeline import DiffusionInferencePipeline
+from models.svc.diffusion.diffusion_wrapper import DiffusionWrapper
+from modules.encoder.condition_encoder import ConditionEncoder
+
+
+class DiffusionInference(SVCInference):
+    def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
+        SVCInference.__init__(self, args, cfg, infer_type)
+
+        settings = {
+            **cfg.model.diffusion.scheduler_settings,
+            **cfg.inference.diffusion.scheduler_settings,
+        }
+        settings.pop("num_inference_timesteps")
+
+        if cfg.inference.diffusion.scheduler.lower() == "ddpm":
+            self.scheduler = DDPMScheduler(**settings)
+            self.logger.info("Using DDPM scheduler.")
+        elif cfg.inference.diffusion.scheduler.lower() == "ddim":
+            self.scheduler = DDIMScheduler(**settings)
+            self.logger.info("Using DDIM scheduler.")
+        elif cfg.inference.diffusion.scheduler.lower() == "pndm":
+            self.scheduler = PNDMScheduler(**settings)
+            self.logger.info("Using PNDM scheduler.")
+        else:
+            raise NotImplementedError(
+                "Unsupported scheduler type: {}".format(
+                    cfg.inference.diffusion.scheduler.lower()
+                )
+            )
+
+        self.pipeline = DiffusionInferencePipeline(
+            self.model[1],
+            self.scheduler,
+            args.diffusion_inference_steps,
+        )
+
+    def _build_model(self):
+        self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
+        self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
+        self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
+        self.acoustic_mapper = DiffusionWrapper(self.cfg)
+        model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
+        return model
+
+    def _inference_each_batch(self, batch_data):
+        device = self.accelerator.device
+        for k, v in batch_data.items():
+            batch_data[k] = v.to(device)
+
+        conditioner = self.model[0](batch_data)
+        noise = torch.randn_like(batch_data["mel"], device=device)
+        y_pred = self.pipeline(noise, conditioner)
+        return y_pred
diff --git a/models/svc/diffusion/diffusion_inference_pipeline.py b/models/svc/diffusion/diffusion_inference_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2461aada99179ac17a2aaffebdb24864af1f5ee
--- /dev/null
+++ b/models/svc/diffusion/diffusion_inference_pipeline.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from diffusers import DiffusionPipeline
+
+
+class DiffusionInferencePipeline(DiffusionPipeline):
+    def __init__(self, network, scheduler, num_inference_timesteps=1000):
+        super().__init__()
+
+        self.register_modules(network=network, scheduler=scheduler)
+        self.num_inference_timesteps = num_inference_timesteps
+
+    @torch.inference_mode()
+    def __call__(
+        self,
+        initial_noise: torch.Tensor,
+        conditioner: torch.Tensor = None,
+    ):
+        r"""
+        Args:
+            initial_noise: The initial noise to be denoised.
+            conditioner: The conditioning tensor produced by the condition encoder.
+
+        The number of denoising steps is controlled by ``num_inference_timesteps``
+        (set in ``__init__``); more steps usually yield higher quality at the cost of
+        slower inference.
+        """
+
+        mel = initial_noise
+        batch_size = mel.size(0)
+        self.scheduler.set_timesteps(self.num_inference_timesteps)
+
+        for t in self.progress_bar(self.scheduler.timesteps):
+            timestep = torch.full((batch_size,), t, device=mel.device, dtype=torch.long)
+
+            # 1. predict noise model_output
+            model_output = self.network(mel, timestep, conditioner)
+
+            # 2. denoise, compute previous step: x_t -> x_t-1
+            mel = self.scheduler.step(model_output, t, mel).prev_sample
+
+            # 3. clamp
+            mel = mel.clamp(-1.0, 1.0)
+
+        return mel
diff --git a/models/svc/diffusion/diffusion_trainer.py b/models/svc/diffusion/diffusion_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f5aeb56a825f84c57bb1d2ba9a5ff5a32d5f486
--- /dev/null
+++ b/models/svc/diffusion/diffusion_trainer.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from diffusers import DDPMScheduler
+
+from models.svc.base import SVCTrainer
+from modules.encoder.condition_encoder import ConditionEncoder
+from .diffusion_wrapper import DiffusionWrapper
+
+
+class DiffusionTrainer(SVCTrainer):
+    r"""The base trainer for all diffusion models. It inherits from SVCTrainer and
+    implements ``_build_model`` and ``_forward_step`` methods.
+    """
+
+    def __init__(self, args=None, cfg=None):
+        SVCTrainer.__init__(self, args, cfg)
+
+        # Only for SVC tasks using diffusion
+        self.noise_scheduler = DDPMScheduler(
+            **self.cfg.model.diffusion.scheduler_settings,
+        )
+        self.diffusion_timesteps = (
+            self.cfg.model.diffusion.scheduler_settings.num_train_timesteps
+        )
+
+    ### Following are methods only for diffusion models ###
+    def _build_model(self):
+        r"""Build the model for training. This function is called in ``__init__`` function."""
+
+        # TODO: sort out the config
+        self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
+        self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
+        self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
+        self.acoustic_mapper = DiffusionWrapper(self.cfg)
+        model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
+
+        num_of_params_encoder = self.count_parameters(self.condition_encoder)
+        num_of_params_am = self.count_parameters(self.acoustic_mapper)
+        num_of_params = num_of_params_encoder + num_of_params_am
+        log = "Diffusion Model's Parameters: #Encoder is {:.2f}M, #Diffusion is {:.2f}M. The total is {:.2f}M".format(
+            num_of_params_encoder / 1e6, num_of_params_am / 1e6, num_of_params / 1e6
+        )
+        self.logger.info(log)
+
+        return model
+
+    def count_parameters(self, model):
+        model_param = 0.0
+        if isinstance(model, dict):
+            for key, value in model.items():
+                model_param += sum(p.numel() for p in model[key].parameters())
+        else:
+            model_param = sum(p.numel() for p in model.parameters())
+        return model_param
+
+    def _forward_step(self, batch):
+        r"""Forward step for training and inference. This function is called
+        in ``_train_step`` & ``_test_step`` function.
+        """
+
+        device = self.accelerator.device
+
+        mel_input = batch["mel"]
+        noise = torch.randn_like(mel_input, device=device, dtype=torch.float32)
+        batch_size = mel_input.size(0)
+        timesteps = torch.randint(
+            0,
+            self.diffusion_timesteps,
+            (batch_size,),
+            device=device,
+            dtype=torch.long,
+        )
+
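+        # Forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.
+        # The acoustic mapper is then trained to predict the injected noise (epsilon-prediction).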
+        noisy_mel = self.noise_scheduler.add_noise(mel_input, noise, timesteps)
+        conditioner = self.condition_encoder(batch)
+
+        y_pred = self.acoustic_mapper(noisy_mel, timesteps, conditioner)
+
+        # TODO: Predict noise or gt should be configurable
+        loss = self._compute_loss(self.criterion, y_pred, noise, batch["mask"])
+        self._check_nan(loss, y_pred, noise)
+
+        # FIXME: Clarify that we should not divide it with batch size here
+        return loss
diff --git a/models/svc/diffusion/diffusion_wrapper.py b/models/svc/diffusion/diffusion_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef66c2b6b85ceb8fe7a2cf9b53c62edc6b3ef6bc
--- /dev/null
+++ b/models/svc/diffusion/diffusion_wrapper.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch.nn as nn
+
+from modules.diffusion import BiDilConv
+from modules.encoder.position_encoder import PositionEncoder
+
+
+class DiffusionWrapper(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+
+        self.cfg = cfg
+        self.diff_cfg = cfg.model.diffusion
+
+        self.diff_encoder = PositionEncoder(
+            d_raw_emb=self.diff_cfg.step_encoder.dim_raw_embedding,
+            d_out=self.diff_cfg.bidilconv.base_channel,
+            d_mlp=self.diff_cfg.step_encoder.dim_hidden_layer,
+            activation_function=self.diff_cfg.step_encoder.activation,
+            n_layer=self.diff_cfg.step_encoder.num_layer,
+            max_period=self.diff_cfg.step_encoder.max_period,
+        )
+
+        # FIXME: Only support BiDilConv now for debug
+        if self.diff_cfg.model_type.lower() == "bidilconv":
+            self.neural_network = BiDilConv(
+                input_channel=self.cfg.preprocess.n_mel, **self.diff_cfg.bidilconv
+            )
+        else:
+            raise ValueError(
+                f"Unsupported diffusion model type: {self.diff_cfg.model_type}"
+            )
+
+    def forward(self, x, t, c):
+        """
+        Args:
+            x: [N, T, mel_band] of mel spectrogram
+            t: Diffusion time step with shape of [N]
+            c: [N, T, conditioner_size] of conditioner
+
+        Returns:
+            [N, T, mel_band] of mel spectrogram
+        """
+
+        assert (
+            x.size()[:-1] == c.size()[:-1]
+        ), "x mismatch with c, got \n x: {} \n c: {}".format(x.size(), c.size())
+        assert x.size(0) == t.size(
+            0
+        ), "x mismatch with t, got \n x: {} \n t: {}".format(x.size(), t.size())
+        assert t.dim() == 1, "t must be 1D tensor, got {}".format(t.dim())
+
+        N, T, mel_band = x.size()
+
+        x = x.transpose(1, 2).contiguous()  # [N, mel_band, T]
+        c = c.transpose(1, 2).contiguous()  # [N, conditioner_size, T]
+        t = self.diff_encoder(t).contiguous()  # [N, base_channel]
+
+        h = self.neural_network(x, t, c)
+        h = h.transpose(1, 2).contiguous()  # [N, T, mel_band]
+
+        assert h.size() == (
+            N,
+            T,
+            mel_band,
+        ), "h mismatch with input x, got \n h: {} \n x: {}".format(
+            h.size(), (N, T, mel_band)
+        )
+        return h
diff --git a/models/svc/transformer/__init__.py b/models/svc/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/svc/transformer/conformer.py b/models/svc/transformer/conformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e48019cfc17d5f3825ce989f4852cec55fe1daa
--- /dev/null
+++ b/models/svc/transformer/conformer.py
@@ -0,0 +1,405 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import torch
+import numpy as np
+import torch.nn as nn
+from utils.util import convert_pad_shape
+
+
+class BaseModule(torch.nn.Module):
+    def __init__(self):
+        super(BaseModule, self).__init__()
+
+    @property
+    def nparams(self):
+        """
+        Returns number of trainable parameters of the module.
+        """
+        num_params = 0
+        for name, param in self.named_parameters():
+            if param.requires_grad:
+                num_params += np.prod(param.detach().cpu().numpy().shape)
+        return num_params
+
+    def relocate_input(self, x: list):
+        """
+        Relocates provided tensors to the same device set for the module.
+        """
+        device = next(self.parameters()).device
+        for i in range(len(x)):
+            if isinstance(x[i], torch.Tensor) and x[i].device != device:
+                x[i] = x[i].to(device)
+        return x
+
+
+class LayerNorm(BaseModule):
+    def __init__(self, channels, eps=1e-4):
+        super(LayerNorm, self).__init__()
+        self.channels = channels
+        self.eps = eps
+
+        self.gamma = torch.nn.Parameter(torch.ones(channels))
+        self.beta = torch.nn.Parameter(torch.zeros(channels))
+
+    def forward(self, x):
+        n_dims = len(x.shape)
+        mean = torch.mean(x, 1, keepdim=True)
+        variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
+
+        x = (x - mean) * torch.rsqrt(variance + self.eps)
+
+        shape = [1, -1] + [1] * (n_dims - 2)
+        x = x * self.gamma.view(*shape) + self.beta.view(*shape)
+        return x
+
+
+class ConvReluNorm(BaseModule):
+    def __init__(
+        self,
+        in_channels,
+        hidden_channels,
+        out_channels,
+        kernel_size,
+        n_layers,
+        p_dropout,
+        eps=1e-5,
+    ):
+        super(ConvReluNorm, self).__init__()
+        self.in_channels = in_channels
+        self.hidden_channels = hidden_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.n_layers = n_layers
+        self.p_dropout = p_dropout
+        self.eps = eps
+
+        self.conv_layers = torch.nn.ModuleList()
+        self.conv_layers.append(
+            torch.nn.Conv1d(
+                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
+            )
+        )
+        self.relu_drop = torch.nn.Sequential(
+            torch.nn.ReLU(), torch.nn.Dropout(p_dropout)
+        )
+        for _ in range(n_layers - 1):
+            self.conv_layers.append(
+                torch.nn.Conv1d(
+                    hidden_channels,
+                    hidden_channels,
+                    kernel_size,
+                    padding=kernel_size // 2,
+                )
+            )
+        self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
+        self.proj.weight.data.zero_()
+        self.proj.bias.data.zero_()
+
+    def forward(self, x, x_mask):
+        for i in range(self.n_layers):
+            x = self.conv_layers[i](x * x_mask)
+            x = self.instance_norm(x, x_mask)
+            x = self.relu_drop(x)
+        x = self.proj(x)
+        return x * x_mask
+
+    def instance_norm(self, x, mask, return_mean_std=False):
+        mean, std = self.calc_mean_std(x, mask)
+        x = (x - mean) / std
+        if return_mean_std:
+            return x, mean, std
+        else:
+            return x
+
+    def calc_mean_std(self, x, mask=None):
+        x = x * mask
+        B, C = x.shape[:2]
+        mn = x.view(B, C, -1).mean(-1)
+        sd = (x.view(B, C, -1).var(-1) + self.eps).sqrt()
+        mn = mn.view(B, C, *((len(x.shape) - 2) * [1]))
+        sd = sd.view(B, C, *((len(x.shape) - 2) * [1]))
+        return mn, sd
+
+
+class MultiHeadAttention(BaseModule):
+    def __init__(
+        self,
+        channels,
+        out_channels,
+        n_heads,
+        window_size=None,
+        heads_share=True,
+        p_dropout=0.0,
+        proximal_bias=False,
+        proximal_init=False,
+    ):
+        super(MultiHeadAttention, self).__init__()
+        assert channels % n_heads == 0
+
+        self.channels = channels
+        self.out_channels = out_channels
+        self.n_heads = n_heads
+        self.window_size = window_size
+        self.heads_share = heads_share
+        self.proximal_bias = proximal_bias
+        self.p_dropout = p_dropout
+        self.attn = None
+
+        self.k_channels = channels // n_heads
+        self.conv_q = torch.nn.Conv1d(channels, channels, 1)
+        self.conv_k = torch.nn.Conv1d(channels, channels, 1)
+        self.conv_v = torch.nn.Conv1d(channels, channels, 1)
+        if window_size is not None:
+            n_heads_rel = 1 if heads_share else n_heads
+            rel_stddev = self.k_channels**-0.5
+            self.emb_rel_k = torch.nn.Parameter(
+                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                * rel_stddev
+            )
+            self.emb_rel_v = torch.nn.Parameter(
+                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                * rel_stddev
+            )
+        self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
+        self.drop = torch.nn.Dropout(p_dropout)
+
+        torch.nn.init.xavier_uniform_(self.conv_q.weight)
+        torch.nn.init.xavier_uniform_(self.conv_k.weight)
+        if proximal_init:
+            self.conv_k.weight.data.copy_(self.conv_q.weight.data)
+            self.conv_k.bias.data.copy_(self.conv_q.bias.data)
+        torch.nn.init.xavier_uniform_(self.conv_v.weight)
+
+    def forward(self, x, c, attn_mask=None):
+        q = self.conv_q(x)
+        k = self.conv_k(c)
+        v = self.conv_v(c)
+
+        x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+        x = self.conv_o(x)
+        return x
+
+    def attention(self, query, key, value, mask=None):
+        b, d, t_s, t_t = (*key.size(), query.size(2))
+        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
+        if self.window_size is not None:
+            assert (
+                t_s == t_t
+            ), "Relative attention is only available for self-attention."
+            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+            rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
+            rel_logits = self._relative_position_to_absolute_position(rel_logits)
+            scores_local = rel_logits / math.sqrt(self.k_channels)
+            scores = scores + scores_local
+        if self.proximal_bias:
+            assert t_s == t_t, "Proximal bias is only available for self-attention."
+            scores = scores + self._attention_bias_proximal(t_s).to(
+                device=scores.device, dtype=scores.dtype
+            )
+        if mask is not None:
+            scores = scores.masked_fill(mask == 0, -1e4)
+        p_attn = torch.nn.functional.softmax(scores, dim=-1)
+        p_attn = self.drop(p_attn)
+        output = torch.matmul(p_attn, value)
+        if self.window_size is not None:
+            relative_weights = self._absolute_position_to_relative_position(p_attn)
+            value_relative_embeddings = self._get_relative_embeddings(
+                self.emb_rel_v, t_s
+            )
+            output = output + self._matmul_with_relative_values(
+                relative_weights, value_relative_embeddings
+            )
+        output = output.transpose(2, 3).contiguous().view(b, d, t_t)
+        return output, p_attn
+
+    def _matmul_with_relative_values(self, x, y):
+        ret = torch.matmul(x, y.unsqueeze(0))
+        return ret
+
+    def _matmul_with_relative_keys(self, x, y):
+        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+        return ret
+
+    def _get_relative_embeddings(self, relative_embeddings, length):
+        pad_length = max(length - (self.window_size + 1), 0)
+        slice_start_position = max((self.window_size + 1) - length, 0)
+        slice_end_position = slice_start_position + 2 * length - 1
+        if pad_length > 0:
+            padded_relative_embeddings = torch.nn.functional.pad(
+                relative_embeddings,
+                convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+            )
+        else:
+            padded_relative_embeddings = relative_embeddings
+        used_relative_embeddings = padded_relative_embeddings[
+            :, slice_start_position:slice_end_position
+        ]
+        return used_relative_embeddings
+
+    def _relative_position_to_absolute_position(self, x):
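+        # Convert relative-position logits [b, h, l, 2l - 1] into absolute-position scores
+        # [b, h, l, l] via the standard pad-and-reshape trick used in relative attention.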
+        batch, heads, length, _ = x.size()
+        x = torch.nn.functional.pad(
+            x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])
+        )
+        x_flat = x.view([batch, heads, length * 2 * length])
+        x_flat = torch.nn.functional.pad(
+            x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
+        )
+        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
+            :, :, :length, length - 1 :
+        ]
+        return x_final
+
+    def _absolute_position_to_relative_position(self, x):
+        batch, heads, length, _ = x.size()
+        x = torch.nn.functional.pad(
+            x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
+        )
+        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
+        x_flat = torch.nn.functional.pad(
+            x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]])
+        )
+        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+        return x_final
+
+    def _attention_bias_proximal(self, length):
+        r = torch.arange(length, dtype=torch.float32)
+        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class FFN(BaseModule):
+    def __init__(
+        self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0
+    ):
+        super(FFN, self).__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.filter_channels = filter_channels
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+
+        self.conv_1 = torch.nn.Conv1d(
+            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
+        )
+        self.conv_2 = torch.nn.Conv1d(
+            filter_channels, out_channels, kernel_size, padding=kernel_size // 2
+        )
+        self.drop = torch.nn.Dropout(p_dropout)
+
+    def forward(self, x, x_mask):
+        x = self.conv_1(x * x_mask)
+        x = torch.relu(x)
+        x = self.drop(x)
+        x = self.conv_2(x * x_mask)
+        return x * x_mask
+
+
+class Encoder(BaseModule):
+    def __init__(
+        self,
+        hidden_channels,
+        filter_channels,
+        n_heads=2,
+        n_layers=6,
+        kernel_size=3,
+        p_dropout=0.1,
+        window_size=4,
+        **kwargs
+    ):
+        super(Encoder, self).__init__()
+        self.hidden_channels = hidden_channels
+        self.filter_channels = filter_channels
+        self.n_heads = n_heads
+        self.n_layers = n_layers
+        self.kernel_size = kernel_size
+        self.p_dropout = p_dropout
+        self.window_size = window_size
+
+        self.drop = torch.nn.Dropout(p_dropout)
+        self.attn_layers = torch.nn.ModuleList()
+        self.norm_layers_1 = torch.nn.ModuleList()
+        self.ffn_layers = torch.nn.ModuleList()
+        self.norm_layers_2 = torch.nn.ModuleList()
+        for _ in range(self.n_layers):
+            self.attn_layers.append(
+                MultiHeadAttention(
+                    hidden_channels,
+                    hidden_channels,
+                    n_heads,
+                    window_size=window_size,
+                    p_dropout=p_dropout,
+                )
+            )
+            self.norm_layers_1.append(LayerNorm(hidden_channels))
+            self.ffn_layers.append(
+                FFN(
+                    hidden_channels,
+                    hidden_channels,
+                    filter_channels,
+                    kernel_size,
+                    p_dropout=p_dropout,
+                )
+            )
+            self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+    def forward(self, x, x_mask):
+        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+        for i in range(self.n_layers):
+            x = x * x_mask
+            y = self.attn_layers[i](x, x, attn_mask)
+            y = self.drop(y)
+            x = self.norm_layers_1[i](x + y)
+            y = self.ffn_layers[i](x, x_mask)
+            y = self.drop(y)
+            x = self.norm_layers_2[i](x + y)
+        x = x * x_mask
+        return x
+
+
+class Conformer(BaseModule):
+    def __init__(self, cfg):
+        super().__init__()
+        self.cfg = cfg
+        self.n_heads = self.cfg.n_heads
+        self.n_layers = self.cfg.n_layers
+        self.hidden_channels = self.cfg.input_dim
+        self.filter_channels = self.cfg.filter_channels
+        self.output_dim = self.cfg.output_dim
+        self.dropout = self.cfg.dropout
+
+        self.conformer_encoder = Encoder(
+            self.hidden_channels,
+            self.filter_channels,
+            n_heads=self.n_heads,
+            n_layers=self.n_layers,
+            kernel_size=3,
+            p_dropout=self.dropout,
+            window_size=4,
+        )
+        self.projection = nn.Conv1d(self.hidden_channels, self.output_dim, 1)
+
+    def forward(self, x, x_mask):
+        """
+        Args:
+            x: (N, seq_len, input_dim)
+        Returns:
+            output: (N, seq_len, output_dim)
+        """
+        # (N, input_dim, seq_len): channel-first layout for the convolutional encoder
+        x = x.transpose(1, 2)
+        x_mask = x_mask.transpose(1, 2)
+        output = self.conformer_encoder(x, x_mask)
+        # (N, output_dim, seq_len) -> (N, seq_len, output_dim)
+        output = self.projection(output)
+        output = output.transpose(1, 2)
+        return output
diff --git a/models/svc/transformer/transformer.py b/models/svc/transformer/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd3cdb6c2d0fc93534d005b9f67a3058c9185c60
--- /dev/null
+++ b/models/svc/transformer/transformer.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import torch
+import torch.nn as nn
+from torch.nn import TransformerEncoder, TransformerEncoderLayer
+
+
+class Transformer(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.cfg = cfg
+
+        dropout = self.cfg.dropout
+        nhead = self.cfg.n_heads
+        nlayers = self.cfg.n_layers
+        input_dim = self.cfg.input_dim
+        output_dim = self.cfg.output_dim
+
+        d_model = input_dim
+        self.pos_encoder = PositionalEncoding(d_model, dropout)
+        encoder_layers = TransformerEncoderLayer(
+            d_model, nhead, dropout=dropout, batch_first=True
+        )
+        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
+
+        self.output_mlp = nn.Linear(d_model, output_dim)
+
+    def forward(self, x, mask=None):
+        """
+        Args:
+            x: (N, seq_len, input_dim)
+        Returns:
+            output: (N, seq_len, output_dim)
+        """
+        # (N, seq_len, d_model)
+        src = self.pos_encoder(x)
+        # model_stats["pos_embedding"] = x
+        # (N, seq_len, d_model)
+        output = self.transformer_encoder(src)
+        # (N, seq_len, output_dim)
+        output = self.output_mlp(output)
+        return output
+
+
+class PositionalEncoding(nn.Module):
+    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
+        super().__init__()
+        self.dropout = nn.Dropout(p=dropout)
+
+        position = torch.arange(max_len).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
+        )
+
+        # Assume that x is (seq_len, N, d)
+        # pe = torch.zeros(max_len, 1, d_model)
+        # pe[:, 0, 0::2] = torch.sin(position * div_term)
+        # pe[:, 0, 1::2] = torch.cos(position * div_term)
+
+        # Assume that x in (N, seq_len, d)
+        pe = torch.zeros(1, max_len, d_model)
+        pe[0, :, 0::2] = torch.sin(position * div_term)
+        pe[0, :, 1::2] = torch.cos(position * div_term)
+
+        self.register_buffer("pe", pe)
+
+    def forward(self, x):
+        """
+        Args:
+            x: Tensor, shape [N, seq_len, d]
+        """
+        # Old: Assume that x is (seq_len, N, d), and self.pe is (max_len, 1, d_model)
+        # x = x + self.pe[: x.size(0)]
+
+        # Now: self.pe is (1, max_len, d)
+        x = x + self.pe[:, : x.size(1), :]
+
+        return self.dropout(x)
diff --git a/models/svc/transformer/transformer_inference.py b/models/svc/transformer/transformer_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6299c532aec6cb9283ee87ee9f0142f0b5c981b
--- /dev/null
+++ b/models/svc/transformer/transformer_inference.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import time
+import numpy as np
+import torch
+from tqdm import tqdm
+import torch.nn as nn
+from collections import OrderedDict
+
+from models.svc.base import SVCInference
+from modules.encoder.condition_encoder import ConditionEncoder
+from models.svc.transformer.transformer import Transformer
+from models.svc.transformer.conformer import Conformer
+
+
+class TransformerInference(SVCInference):
+    def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
+        SVCInference.__init__(self, args, cfg, infer_type)
+
+    def _build_model(self):
+        self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
+        self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
+        self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
+        if self.cfg.model.transformer.type == "transformer":
+            self.acoustic_mapper = Transformer(self.cfg.model.transformer)
+        elif self.cfg.model.transformer.type == "conformer":
+            self.acoustic_mapper = Conformer(self.cfg.model.transformer)
+        else:
+            raise NotImplementedError
+        model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
+        return model
+
+    def _inference_each_batch(self, batch_data):
+        device = self.accelerator.device
+        for k, v in batch_data.items():
+            batch_data[k] = v.to(device)
+
+        condition = self.condition_encoder(batch_data)
+        y_pred = self.acoustic_mapper(condition, batch_data["mask"])
+
+        return y_pred
diff --git a/models/svc/transformer/transformer_trainer.py b/models/svc/transformer/transformer_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3633078475d26e708280bc354f091bb9ab01ae45
--- /dev/null
+++ b/models/svc/transformer/transformer_trainer.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from models.svc.base import SVCTrainer
+from modules.encoder.condition_encoder import ConditionEncoder
+from models.svc.transformer.transformer import Transformer
+from models.svc.transformer.conformer import Conformer
+from utils.ssim import SSIM
+
+
+class TransformerTrainer(SVCTrainer):
+    def __init__(self, args, cfg):
+        SVCTrainer.__init__(self, args, cfg)
+        self.ssim_loss = SSIM()
+
+    def _build_model(self):
+        self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
+        self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
+        self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
+        if self.cfg.model.transformer.type == "transformer":
+            self.acoustic_mapper = Transformer(self.cfg.model.transformer)
+        elif self.cfg.model.transformer.type == "conformer":
+            self.acoustic_mapper = Conformer(self.cfg.model.transformer)
+        else:
+            raise NotImplementedError
+        model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
+        return model
+
+    def _forward_step(self, batch):
+        total_loss = 0
+        device = self.accelerator.device
+        mel = batch["mel"]
+        mask = batch["mask"]
+
+        condition = self.condition_encoder(batch)
+        mel_pred = self.acoustic_mapper(condition, mask)
+
+        l1_loss = torch.sum(torch.abs(mel_pred - mel) * batch["mask"]) / torch.sum(
+            batch["mask"]
+        )
+        self._check_nan(l1_loss, mel_pred, mel)
+        total_loss += l1_loss
+        ssim_loss = self.ssim_loss(mel_pred, mel)
+        ssim_loss = torch.sum(ssim_loss * batch["mask"]) / torch.sum(batch["mask"])
+        self._check_nan(ssim_loss, mel_pred, mel)
+        total_loss += ssim_loss
+
+        return total_loss
diff --git a/models/vocoders/autoregressive/autoregressive_vocoder_dataset.py b/models/vocoders/autoregressive/autoregressive_vocoder_dataset.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/vocoders/autoregressive/autoregressive_vocoder_inference.py b/models/vocoders/autoregressive/autoregressive_vocoder_inference.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/vocoders/autoregressive/autoregressive_vocoder_trainer.py b/models/vocoders/autoregressive/autoregressive_vocoder_trainer.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/vocoders/autoregressive/wavenet/conv.py b/models/vocoders/autoregressive/wavenet/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..a095aad5d7203f6e5fb5a4d585b894e34dbe63c7
--- /dev/null
+++ b/models/vocoders/autoregressive/wavenet/conv.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torch import nn
+from torch.nn import functional as F
+
+
+class Conv1d(nn.Conv1d):
+    """Extended nn.Conv1d for incremental dilated convolutions"""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.clear_buffer()
+        self._linearized_weight = None
+        self.register_backward_hook(self._clear_linearized_weight)
+
+    def incremental_forward(self, input):
+        # input (B, T, C)
+        # run forward pre hooks
+        for hook in self._forward_pre_hooks.values():
+            hook(self, input)
+
+        # reshape weight
+        weight = self._get_linearized_weight()
+        kw = self.kernel_size[0]
+        dilation = self.dilation[0]
+
+        bsz = input.size(0)
+        if kw > 1:
+            input = input.data
+            if self.input_buffer is None:
+                self.input_buffer = input.new(
+                    bsz, kw + (kw - 1) * (dilation - 1), input.size(2)
+                )
+                self.input_buffer.zero_()
+            else:
+                # shift buffer
+                self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone()
+            # append next input
+            self.input_buffer[:, -1, :] = input[:, -1, :]
+            input = self.input_buffer
+            if dilation > 1:
+                input = input[:, 0::dilation, :].contiguous()
+        output = F.linear(input.view(bsz, -1), weight, self.bias)
+        return output.view(bsz, 1, -1)
+
+    def clear_buffer(self):
+        self.input_buffer = None
+
+    def _get_linearized_weight(self):
+        if self._linearized_weight is None:
+            kw = self.kernel_size[0]
+            # nn.Conv1d
+            if self.weight.size() == (self.out_channels, self.in_channels, kw):
+                weight = self.weight.transpose(1, 2).contiguous()
+            else:
+                # fairseq.modules.conv_tbc.ConvTBC
+                weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous()
+            assert weight.size() == (self.out_channels, kw, self.in_channels)
+            self._linearized_weight = weight.view(self.out_channels, -1)
+        return self._linearized_weight
+
+    def _clear_linearized_weight(self, *args):
+        self._linearized_weight = None
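+
+
+if __name__ == "__main__":
+    # A minimal sanity sketch (toy sizes, assumed for illustration): a causally
+    # padded dilated conv evaluated frame-by-frame with incremental_forward
+    # should match the offline forward pass once the trailing padding is cut.
+    import torch
+
+    conv = Conv1d(4, 8, kernel_size=3, dilation=2, padding=(3 - 1) * 2)
+    x = torch.randn(1, 4, 16)  # (B, C, T)
+    offline = conv(x)[:, :, : x.size(-1)]  # keep only the first T (causal) outputs
+    conv.clear_buffer()
+    steps = [
+        conv.incremental_forward(x[:, :, t : t + 1].transpose(1, 2))  # (B, 1, C) per step
+        for t in range(x.size(-1))
+    ]
+    incremental = torch.cat(steps, dim=1).transpose(1, 2)  # back to (B, C, T)
+    print(torch.allclose(offline, incremental, atol=1e-4))  # True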
diff --git a/models/vocoders/autoregressive/wavenet/modules.py b/models/vocoders/autoregressive/wavenet/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..13d51e52a50af3bc1f7fe9627aeae8d2b1b28b7d
--- /dev/null
+++ b/models/vocoders/autoregressive/wavenet/modules.py
@@ -0,0 +1,152 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import math
+
+from torch import nn
+from torch.nn import functional as F
+
+from .conv import Conv1d as conv_Conv1d
+
+
+def Conv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
+    m = conv_Conv1d(in_channels, out_channels, kernel_size, **kwargs)
+    nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
+    if m.bias is not None:
+        nn.init.constant_(m.bias, 0)
+    return nn.utils.weight_norm(m)
+
+
+def Conv1d1x1(in_channels, out_channels, bias=True):
+    return Conv1d(
+        in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias
+    )
+
+
+def _conv1x1_forward(conv, x, is_incremental):
+    if is_incremental:
+        x = conv.incremental_forward(x)
+    else:
+        x = conv(x)
+    return x
+
+
+class ResidualConv1dGLU(nn.Module):
+    """Residual dilated conv1d + Gated linear unit
+
+    Args:
+        residual_channels (int): Residual input / output channels
+        gate_channels (int): Gated activation channels.
+        kernel_size (int): Kernel size of convolution layers.
+        skip_out_channels (int): Skip connection channels. If None, set to same
+          as ``residual_channels``.
+        cin_channels (int): Local conditioning channels. If a negative value is
+          set, local conditioning is disabled.
+        dropout (float): Dropout probability.
+        padding (int): Padding for convolution layers. If None, proper padding
+          is computed based on the dilation and kernel_size.
+        dilation (int): Dilation factor.
+    """
+
+    def __init__(
+        self,
+        residual_channels,
+        gate_channels,
+        kernel_size,
+        skip_out_channels=None,
+        cin_channels=-1,
+        dropout=1 - 0.95,
+        padding=None,
+        dilation=1,
+        causal=True,
+        bias=True,
+        *args,
+        **kwargs,
+    ):
+        super(ResidualConv1dGLU, self).__init__()
+        self.dropout = dropout
+
+        if skip_out_channels is None:
+            skip_out_channels = residual_channels
+        if padding is None:
+            # no future time stamps available
+            if causal:
+                padding = (kernel_size - 1) * dilation
+            else:
+                padding = (kernel_size - 1) // 2 * dilation
+        self.causal = causal
+
+        self.conv = Conv1d(
+            residual_channels,
+            gate_channels,
+            kernel_size,
+            padding=padding,
+            dilation=dilation,
+            bias=bias,
+            *args,
+            **kwargs,
+        )
+
+        # mel conditioning
+        self.conv1x1c = Conv1d1x1(cin_channels, gate_channels, bias=False)
+
+        gate_out_channels = gate_channels // 2
+        self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias)
+        self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_out_channels, bias=bias)
+
+    def forward(self, x, c=None):
+        return self._forward(x, c, False)
+
+    def incremental_forward(self, x, c=None):
+        return self._forward(x, c, True)
+
+    def clear_buffer(self):
+        for c in [
+            self.conv,
+            self.conv1x1_out,
+            self.conv1x1_skip,
+            self.conv1x1c,
+        ]:
+            if c is not None:
+                c.clear_buffer()
+
+    def _forward(self, x, c, is_incremental):
+        """Forward
+
+        Args:
+            x (Tensor): B x C x T
+            c (Tensor): B x C x T, Mel conditioning features
+        Returns:
+            Tensor: output
+        """
+        residual = x
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        if is_incremental:
+            splitdim = -1
+            x = self.conv.incremental_forward(x)
+        else:
+            splitdim = 1
+            x = self.conv(x)
+            # remove future time steps
+            x = x[:, :, : residual.size(-1)] if self.causal else x
+
+        a, b = x.split(x.size(splitdim) // 2, dim=splitdim)
+
+        assert self.conv1x1c is not None
+        c = _conv1x1_forward(self.conv1x1c, c, is_incremental)
+        ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
+        a, b = a + ca, b + cb
+
+        x = torch.tanh(a) * torch.sigmoid(b)
+
+        # For skip connection
+        s = _conv1x1_forward(self.conv1x1_skip, x, is_incremental)
+
+        # For residual connection
+        x = _conv1x1_forward(self.conv1x1_out, x, is_incremental)
+
+        x = (x + residual) * math.sqrt(0.5)
+        return x, s
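+
+
+if __name__ == "__main__":
+    # A minimal shape sketch (toy sizes, assumed for illustration): the block
+    # keeps the residual channel count and additionally emits a skip tensor.
+    block = ResidualConv1dGLU(
+        residual_channels=64, gate_channels=128, kernel_size=3, cin_channels=80
+    )
+    x = torch.randn(2, 64, 100)  # residual stream (B, C, T)
+    c = torch.randn(2, 80, 100)  # mel conditioning (B, cin_channels, T)
+    out, skip = block(x, c)
+    print(out.shape, skip.shape)  # both torch.Size([2, 64, 100])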
diff --git a/models/vocoders/autoregressive/wavenet/upsample.py b/models/vocoders/autoregressive/wavenet/upsample.py
new file mode 100644
index 0000000000000000000000000000000000000000..b664302cd56545f1709a4f1874ebadd8e9375a9c
--- /dev/null
+++ b/models/vocoders/autoregressive/wavenet/upsample.py
@@ -0,0 +1,109 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+import numpy as np
+
+from torch import nn
+from torch.nn import functional as F
+
+
+class Stretch2d(nn.Module):
+    def __init__(self, x_scale, y_scale, mode="nearest"):
+        super(Stretch2d, self).__init__()
+        self.x_scale = x_scale
+        self.y_scale = y_scale
+        self.mode = mode
+
+    def forward(self, x):
+        return F.interpolate(
+            x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode
+        )
+
+
+def _get_activation(upsample_activation):
+    nonlinear = getattr(nn, upsample_activation)
+    return nonlinear
+
+
+class UpsampleNetwork(nn.Module):
+    def __init__(
+        self,
+        upsample_scales,
+        upsample_activation="none",
+        upsample_activation_params={},
+        mode="nearest",
+        freq_axis_kernel_size=1,
+        cin_pad=0,
+        cin_channels=128,
+    ):
+        super(UpsampleNetwork, self).__init__()
+        self.up_layers = nn.ModuleList()
+        total_scale = np.prod(upsample_scales)
+        self.indent = cin_pad * total_scale
+        for scale in upsample_scales:
+            freq_axis_padding = (freq_axis_kernel_size - 1) // 2
+            k_size = (freq_axis_kernel_size, scale * 2 + 1)
+            padding = (freq_axis_padding, scale)
+            stretch = Stretch2d(scale, 1, mode)
+            conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False)
+            conv.weight.data.fill_(1.0 / np.prod(k_size))
+            conv = nn.utils.weight_norm(conv)
+            self.up_layers.append(stretch)
+            self.up_layers.append(conv)
+            if upsample_activation != "none":
+                nonlinear = _get_activation(upsample_activation)
+                self.up_layers.append(nonlinear(**upsample_activation_params))
+
+    def forward(self, c):
+        """
+        Args:
+            c : B x C x T
+        """
+
+        # B x 1 x C x T
+        c = c.unsqueeze(1)
+        for f in self.up_layers:
+            c = f(c)
+        # B x C x T
+        c = c.squeeze(1)
+
+        if self.indent > 0:
+            c = c[:, :, self.indent : -self.indent]
+        return c
+
+
+class ConvInUpsampleNetwork(nn.Module):
+    def __init__(
+        self,
+        upsample_scales,
+        upsample_activation="none",
+        upsample_activation_params={},
+        mode="nearest",
+        freq_axis_kernel_size=1,
+        cin_pad=0,
+        cin_channels=128,
+    ):
+        super(ConvInUpsampleNetwork, self).__init__()
+        # To capture wide-context information in conditional features
+        # meaningless if cin_pad == 0
+        ks = 2 * cin_pad + 1
+        self.conv_in = nn.Conv1d(
+            cin_channels, cin_channels, kernel_size=ks, padding=cin_pad, bias=False
+        )
+        self.upsample = UpsampleNetwork(
+            upsample_scales,
+            upsample_activation,
+            upsample_activation_params,
+            mode,
+            freq_axis_kernel_size,
+            cin_pad=cin_pad,
+            cin_channels=cin_channels,
+        )
+
+    def forward(self, c):
+        c_up = self.upsample(self.conv_in(c))
+        return c_up
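+
+
+if __name__ == "__main__":
+    # A minimal shape sketch with assumed toy values: upsample scales whose
+    # product equals the hop size (here 4 * 4 * 4 * 4 = 256) stretch an 80-dim
+    # mel sequence of T frames to 256 * T conditioning steps.
+    net = ConvInUpsampleNetwork(upsample_scales=[4, 4, 4, 4], cin_channels=80)
+    mel = torch.randn(1, 80, 50)
+    print(net(mel).shape)  # torch.Size([1, 80, 12800])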
diff --git a/models/vocoders/autoregressive/wavenet/wavenet.py b/models/vocoders/autoregressive/wavenet/wavenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..d63f22c2600fd0f83e5bdf339ebb121b3d2f35e6
--- /dev/null
+++ b/models/vocoders/autoregressive/wavenet/wavenet.py
@@ -0,0 +1,170 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+from torch import nn
+from torch.nn import functional as F
+
+from .modules import Conv1d1x1, ResidualConv1dGLU
+from .upsample import ConvInUpsampleNetwork
+
+
+def receptive_field_size(
+    total_layers, num_cycles, kernel_size, dilation=lambda x: 2**x
+):
+    """Compute receptive field size
+
+    Args:
+        total_layers (int): total layers
+        num_cycles (int): cycles
+        kernel_size (int): kernel size
+        dilation (lambda): lambda to compute dilation factor. ``lambda x : 1``
+          to disable dilated convolution.
+
+    Returns:
+        int: receptive field size in samples
+
+    """
+    assert total_layers % num_cycles == 0
+
+    layers_per_cycle = total_layers // num_cycles
+    dilations = [dilation(i % layers_per_cycle) for i in range(total_layers)]
+    return (kernel_size - 1) * sum(dilations) + 1
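+
+
+# Worked example (for intuition only): receptive_field_size(24, 4, 3) expands to
+# the dilation cycle 1, 2, 4, ..., 32 repeated over 4 cycles, so the receptive
+# field is (3 - 1) * 4 * (1 + 2 + 4 + 8 + 16 + 32) + 1 = 505 samples.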
+
+
+class WaveNet(nn.Module):
+    """The WaveNet model that supports local and global conditioning.
+
+    Args:
+        out_channels (int): Output channels. If the input type is a mu-law
+          quantized one-hot vector, this must equal the number of quantize
+          channels. Otherwise, it is num_mixtures x 3 (pi, mu, log_scale).
+        layers (int): Number of total layers
+        stacks (int): Number of dilation cycles
+        residual_channels (int): Residual input / output channels
+        gate_channels (int): Gated activation channels.
+        skip_out_channels (int): Skip connection channels.
+        kernel_size (int): Kernel size of convolution layers.
+        dropout (float): Dropout probability.
+        input_dim (int): Number of mel-spec dimension.
+        upsample_scales (list): List of upsample scales.
+          ``np.prod(upsample_scales)`` must equal the hop size. Used only if
+          upsample_conditional_features is enabled.
+        freq_axis_kernel_size (int): Freq-axis kernel_size for transposed
+          convolution layers for upsampling. If you only care about time-axis
+          upsampling, set this to 1.
+        scalar_input (bool): If True, scalar input ([-1, 1]) is expected; otherwise,
+          a quantized one-hot vector is expected.
+    """
+
+    def __init__(self, cfg):
+        super(WaveNet, self).__init__()
+        self.cfg = cfg
+        self.scalar_input = self.cfg.VOCODER.SCALAR_INPUT
+        self.out_channels = self.cfg.VOCODER.OUT_CHANNELS
+        self.cin_channels = self.cfg.VOCODER.INPUT_DIM
+        self.residual_channels = self.cfg.VOCODER.RESIDUAL_CHANNELS
+        self.layers = self.cfg.VOCODER.LAYERS
+        self.stacks = self.cfg.VOCODER.STACKS
+        self.gate_channels = self.cfg.VOCODER.GATE_CHANNELS
+        self.kernel_size = self.cfg.VOCODER.KERNEL_SIZE
+        self.skip_out_channels = self.cfg.VOCODER.SKIP_OUT_CHANNELS
+        self.dropout = self.cfg.VOCODER.DROPOUT
+        self.upsample_scales = self.cfg.VOCODER.UPSAMPLE_SCALES
+        self.mel_frame_pad = self.cfg.VOCODER.MEL_FRAME_PAD
+
+        assert self.layers % self.stacks == 0
+
+        layers_per_stack = self.layers // self.stacks
+        if self.scalar_input:
+            self.first_conv = Conv1d1x1(1, self.residual_channels)
+        else:
+            self.first_conv = Conv1d1x1(self.out_channels, self.residual_channels)
+
+        self.conv_layers = nn.ModuleList()
+        for layer in range(self.layers):
+            dilation = 2 ** (layer % layers_per_stack)
+            conv = ResidualConv1dGLU(
+                self.residual_channels,
+                self.gate_channels,
+                kernel_size=self.kernel_size,
+                skip_out_channels=self.skip_out_channels,
+                bias=True,
+                dilation=dilation,
+                dropout=self.dropout,
+                cin_channels=self.cin_channels,
+            )
+            self.conv_layers.append(conv)
+
+        self.last_conv_layers = nn.ModuleList(
+            [
+                nn.ReLU(inplace=True),
+                Conv1d1x1(self.skip_out_channels, self.skip_out_channels),
+                nn.ReLU(inplace=True),
+                Conv1d1x1(self.skip_out_channels, self.out_channels),
+            ]
+        )
+
+        self.upsample_net = ConvInUpsampleNetwork(
+            upsample_scales=self.upsample_scales,
+            cin_pad=self.mel_frame_pad,
+            cin_channels=self.cin_channels,
+        )
+
+        self.receptive_field = receptive_field_size(
+            self.layers, self.stacks, self.kernel_size
+        )
+
+    def forward(self, x, mel, softmax=False):
+        """Forward step
+
+        Args:
+            x (Tensor): One-hot encoded audio signal, shape (B x C x T)
+            mel (Tensor): Local conditioning features,
+              shape (B x cin_channels x T)
+            softmax (bool): Whether applies softmax or not.
+
+        Returns:
+            Tensor: output, shape B x out_channels x T
+        """
+        B, _, T = x.size()
+
+        mel = self.upsample_net(mel)
+        assert mel.shape[-1] == x.shape[-1]
+
+        x = self.first_conv(x)
+        skips = 0
+        for f in self.conv_layers:
+            x, h = f(x, mel)
+            skips += h
+        skips *= math.sqrt(1.0 / len(self.conv_layers))
+
+        x = skips
+        for f in self.last_conv_layers:
+            x = f(x)
+
+        x = F.softmax(x, dim=1) if softmax else x
+
+        return x
+
+    def clear_buffer(self):
+        self.first_conv.clear_buffer()
+        for f in self.conv_layers:
+            f.clear_buffer()
+        for f in self.last_conv_layers:
+            try:
+                f.clear_buffer()
+            except AttributeError:
+                pass
+
+    def make_generation_fast_(self):
+        def remove_weight_norm(m):
+            try:
+                nn.utils.remove_weight_norm(m)
+            except ValueError:  # this module didn't have weight norm
+                return
+
+        self.apply(remove_weight_norm)
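+
+
+if __name__ == "__main__":
+    # A minimal smoke-test sketch with an assumed toy configuration (the real
+    # values come from the framework config); it only verifies the conditioning
+    # and output shapes.
+    from types import SimpleNamespace
+
+    import torch
+
+    demo_cfg = SimpleNamespace(
+        VOCODER=SimpleNamespace(
+            SCALAR_INPUT=True,
+            OUT_CHANNELS=30,  # 10 mixtures x 3 (pi, mu, log_scale)
+            INPUT_DIM=80,
+            RESIDUAL_CHANNELS=64,
+            LAYERS=12,
+            STACKS=2,
+            GATE_CHANNELS=128,
+            KERNEL_SIZE=3,
+            SKIP_OUT_CHANNELS=64,
+            DROPOUT=0.05,
+            UPSAMPLE_SCALES=[4, 4, 4, 4],  # product = hop size of 256
+            MEL_FRAME_PAD=2,
+        )
+    )
+    model = WaveNet(demo_cfg)
+    mel = torch.randn(1, 80, 20 + 2 * 2)  # 2 extra context frames on each side
+    x = torch.randn(1, 1, 20 * 256)  # scalar waveform input
+    print(model(x, mel).shape)  # torch.Size([1, 30, 5120])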
diff --git a/models/vocoders/autoregressive/wavernn/wavernn.py b/models/vocoders/autoregressive/wavernn/wavernn.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7475fa8fe8b4575bf714e615349582ff98bbc27
--- /dev/null
+++ b/models/vocoders/autoregressive/wavernn/wavernn.py
@@ -0,0 +1,188 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import numpy as np
+
+
+class ResBlock(nn.Module):
+    def __init__(self, dims):
+        super().__init__()
+        self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
+        self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False)
+        self.batch_norm1 = nn.BatchNorm1d(dims)
+        self.batch_norm2 = nn.BatchNorm1d(dims)
+
+    def forward(self, x):
+        residual = x
+        x = self.conv1(x)
+        x = self.batch_norm1(x)
+        x = F.relu(x)
+        x = self.conv2(x)
+        x = self.batch_norm2(x)
+        x = x + residual
+        return x
+
+
+class MelResNet(nn.Module):
+    def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims, pad):
+        super().__init__()
+        kernel_size = pad * 2 + 1
+        self.conv_in = nn.Conv1d(
+            in_dims, compute_dims, kernel_size=kernel_size, bias=False
+        )
+        self.batch_norm = nn.BatchNorm1d(compute_dims)
+        self.layers = nn.ModuleList()
+        for i in range(res_blocks):
+            self.layers.append(ResBlock(compute_dims))
+        self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1)
+
+    def forward(self, x):
+        x = self.conv_in(x)
+        x = self.batch_norm(x)
+        x = F.relu(x)
+        for f in self.layers:
+            x = f(x)
+        x = self.conv_out(x)
+        return x
+
+
+class Stretch2d(nn.Module):
+    def __init__(self, x_scale, y_scale):
+        super().__init__()
+        self.x_scale = x_scale
+        self.y_scale = y_scale
+
+    def forward(self, x):
+        b, c, h, w = x.size()
+        x = x.unsqueeze(-1).unsqueeze(3)
+        x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale)
+        return x.view(b, c, h * self.y_scale, w * self.x_scale)
+
+
+class UpsampleNetwork(nn.Module):
+    def __init__(
+        self, feat_dims, upsample_scales, compute_dims, res_blocks, res_out_dims, pad
+    ):
+        super().__init__()
+        total_scale = np.cumprod(upsample_scales)[-1]
+        self.indent = pad * total_scale
+        self.resnet = MelResNet(res_blocks, feat_dims, compute_dims, res_out_dims, pad)
+        self.resnet_stretch = Stretch2d(total_scale, 1)
+        self.up_layers = nn.ModuleList()
+        for scale in upsample_scales:
+            kernel_size = (1, scale * 2 + 1)
+            padding = (0, scale)
+            stretch = Stretch2d(scale, 1)
+            conv = nn.Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
+            conv.weight.data.fill_(1.0 / kernel_size[1])
+            self.up_layers.append(stretch)
+            self.up_layers.append(conv)
+
+    def forward(self, m):
+        aux = self.resnet(m).unsqueeze(1)
+        aux = self.resnet_stretch(aux)
+        aux = aux.squeeze(1)
+        m = m.unsqueeze(1)
+        for f in self.up_layers:
+            m = f(m)
+        m = m.squeeze(1)[:, :, self.indent : -self.indent]
+        return m.transpose(1, 2), aux.transpose(1, 2)
+
+
+class WaveRNN(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+
+        self.cfg = cfg
+        self.pad = self.cfg.VOCODER.MEL_FRAME_PAD
+
+        if self.cfg.VOCODER.MODE == "mu_law_quantize":
+            self.n_classes = 2**self.cfg.VOCODER.BITS
+        else:
+            # "mu_law" and any other non-quantized mode use a 30-way output
+            self.n_classes = 30
+
+        self._to_flatten = []
+
+        self.rnn_dims = self.cfg.VOCODER.RNN_DIMS
+        self.aux_dims = self.cfg.VOCODER.RES_OUT_DIMS // 4
+        self.hop_length = self.cfg.VOCODER.HOP_LENGTH
+        self.fc_dims = self.cfg.VOCODER.FC_DIMS
+        self.upsample_factors = self.cfg.VOCODER.UPSAMPLE_FACTORS
+        self.feat_dims = self.cfg.VOCODER.INPUT_DIM
+        self.compute_dims = self.cfg.VOCODER.COMPUTE_DIMS
+        self.res_out_dims = self.cfg.VOCODER.RES_OUT_DIMS
+        self.res_blocks = self.cfg.VOCODER.RES_BLOCKS
+
+        self.upsample = UpsampleNetwork(
+            self.feat_dims,
+            self.upsample_factors,
+            self.compute_dims,
+            self.res_blocks,
+            self.res_out_dims,
+            self.pad,
+        )
+        self.I = nn.Linear(self.feat_dims + self.aux_dims + 1, self.rnn_dims)
+
+        self.rnn1 = nn.GRU(self.rnn_dims, self.rnn_dims, batch_first=True)
+        self.rnn2 = nn.GRU(
+            self.rnn_dims + self.aux_dims, self.rnn_dims, batch_first=True
+        )
+        self._to_flatten += [self.rnn1, self.rnn2]
+
+        self.fc1 = nn.Linear(self.rnn_dims + self.aux_dims, self.fc_dims)
+        self.fc2 = nn.Linear(self.fc_dims + self.aux_dims, self.fc_dims)
+        self.fc3 = nn.Linear(self.fc_dims, self.n_classes)
+
+        self.num_params()
+
+        self._flatten_parameters()
+
+    def forward(self, x, mels):
+        device = next(self.parameters()).device
+
+        self._flatten_parameters()
+
+        batch_size = x.size(0)
+        h1 = torch.zeros(1, batch_size, self.rnn_dims, device=device)
+        h2 = torch.zeros(1, batch_size, self.rnn_dims, device=device)
+        mels, aux = self.upsample(mels)
+
+        aux_idx = [self.aux_dims * i for i in range(5)]
+        a1 = aux[:, :, aux_idx[0] : aux_idx[1]]
+        a2 = aux[:, :, aux_idx[1] : aux_idx[2]]
+        a3 = aux[:, :, aux_idx[2] : aux_idx[3]]
+        a4 = aux[:, :, aux_idx[3] : aux_idx[4]]
+
+        x = torch.cat([x.unsqueeze(-1), mels, a1], dim=2)
+        x = self.I(x)
+        res = x
+        x, _ = self.rnn1(x, h1)
+
+        x = x + res
+        res = x
+        x = torch.cat([x, a2], dim=2)
+        x, _ = self.rnn2(x, h2)
+
+        x = x + res
+        x = torch.cat([x, a3], dim=2)
+        x = F.relu(self.fc1(x))
+
+        x = torch.cat([x, a4], dim=2)
+        x = F.relu(self.fc2(x))
+        return self.fc3(x)
+
+    def num_params(self, print_out=True):
+        parameters = filter(lambda p: p.requires_grad, self.parameters())
+        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
+        if print_out:
+            print("Trainable Parameters: %.3fM" % parameters)
+        return parameters
+
+    def _flatten_parameters(self):
+        [m.flatten_parameters() for m in self._to_flatten]
diff --git a/models/vocoders/diffusion/diffusion_vocoder_dataset.py b/models/vocoders/diffusion/diffusion_vocoder_dataset.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/vocoders/diffusion/diffusion_vocoder_inference.py b/models/vocoders/diffusion/diffusion_vocoder_inference.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/vocoders/diffusion/diffusion_vocoder_trainer.py b/models/vocoders/diffusion/diffusion_vocoder_trainer.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/vocoders/diffusion/diffwave/diffwave.py b/models/vocoders/diffusion/diffwave/diffwave.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9379b0b622c6da8a754f2cc87fd7723eacfa995
--- /dev/null
+++ b/models/vocoders/diffusion/diffwave/diffwave.py
@@ -0,0 +1,173 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from math import sqrt
+
+
+Linear = nn.Linear
+ConvTranspose2d = nn.ConvTranspose2d
+
+
+def Conv1d(*args, **kwargs):
+    layer = nn.Conv1d(*args, **kwargs)
+    nn.init.kaiming_normal_(layer.weight)
+    return layer
+
+
+@torch.jit.script
+def silu(x):
+    return x * torch.sigmoid(x)
+
+
+class DiffusionEmbedding(nn.Module):
+    def __init__(self, max_steps):
+        super().__init__()
+        self.register_buffer(
+            "embedding", self._build_embedding(max_steps), persistent=False
+        )
+        self.projection1 = Linear(128, 512)
+        self.projection2 = Linear(512, 512)
+
+    def forward(self, diffusion_step):
+        if diffusion_step.dtype in [torch.int32, torch.int64]:
+            x = self.embedding[diffusion_step]
+        else:
+            x = self._lerp_embedding(diffusion_step)
+        x = self.projection1(x)
+        x = silu(x)
+        x = self.projection2(x)
+        x = silu(x)
+        return x
+
+    def _lerp_embedding(self, t):
+        low_idx = torch.floor(t).long()
+        high_idx = torch.ceil(t).long()
+        low = self.embedding[low_idx]
+        high = self.embedding[high_idx]
+        return low + (high - low) * (t - low_idx)
+
+    def _build_embedding(self, max_steps):
+        steps = torch.arange(max_steps).unsqueeze(1)  # [T,1]
+        dims = torch.arange(64).unsqueeze(0)  # [1,64]
+        table = steps * 10.0 ** (dims * 4.0 / 63.0)  # [T,64]
+        table = torch.cat([torch.sin(table), torch.cos(table)], dim=1)
+        return table
+
+
+class SpectrogramUpsampler(nn.Module):
+    def __init__(self, upsample_factors):
+        super().__init__()
+        self.conv1 = ConvTranspose2d(
+            1,
+            1,
+            [3, upsample_factors[0] * 2],
+            stride=[1, upsample_factors[0]],
+            padding=[1, upsample_factors[0] // 2],
+        )
+        self.conv2 = ConvTranspose2d(
+            1,
+            1,
+            [3, upsample_factors[1] * 2],
+            stride=[1, upsample_factors[1]],
+            padding=[1, upsample_factors[1] // 2],
+        )
+
+    def forward(self, x):
+        x = torch.unsqueeze(x, 1)
+        x = self.conv1(x)
+        x = F.leaky_relu(x, 0.4)
+        x = self.conv2(x)
+        x = F.leaky_relu(x, 0.4)
+        x = torch.squeeze(x, 1)
+        return x
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, n_mels, residual_channels, dilation):
+        super().__init__()
+        self.dilated_conv = Conv1d(
+            residual_channels,
+            2 * residual_channels,
+            3,
+            padding=dilation,
+            dilation=dilation,
+        )
+        self.diffusion_projection = Linear(512, residual_channels)
+
+        self.conditioner_projection = Conv1d(n_mels, 2 * residual_channels, 1)
+
+        self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1)
+
+    def forward(self, x, diffusion_step, conditioner):
+        diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
+        y = x + diffusion_step
+
+        conditioner = self.conditioner_projection(conditioner)
+        y = self.dilated_conv(y) + conditioner
+
+        gate, filter = torch.chunk(y, 2, dim=1)
+        y = torch.sigmoid(gate) * torch.tanh(filter)
+
+        y = self.output_projection(y)
+        residual, skip = torch.chunk(y, 2, dim=1)
+        return (x + residual) / sqrt(2.0), skip
+
+
+class DiffWave(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.cfg = cfg
+        self.cfg.VOCODER.NOISE_SCHEDULE = np.linspace(
+            self.cfg.VOCODER.NOISE_SCHEDULE_FACTORS[0],
+            self.cfg.VOCODER.NOISE_SCHEDULE_FACTORS[1],
+            self.cfg.VOCODER.NOISE_SCHEDULE_FACTORS[2],
+        ).tolist()
+        self.input_projection = Conv1d(1, self.cfg.VOCODER.RESIDUAL_CHANNELS, 1)
+        self.diffusion_embedding = DiffusionEmbedding(
+            len(self.cfg.VOCODER.NOISE_SCHEDULE)
+        )
+        self.spectrogram_upsampler = SpectrogramUpsampler(
+            self.cfg.VOCODER.UPSAMPLE_FACTORS
+        )
+
+        self.residual_layers = nn.ModuleList(
+            [
+                ResidualBlock(
+                    self.cfg.VOCODER.INPUT_DIM,
+                    self.cfg.VOCODER.RESIDUAL_CHANNELS,
+                    2 ** (i % self.cfg.VOCODER.DILATION_CYCLE_LENGTH),
+                )
+                for i in range(self.cfg.VOCODER.RESIDUAL_LAYERS)
+            ]
+        )
+        self.skip_projection = Conv1d(
+            self.cfg.VOCODER.RESIDUAL_CHANNELS, self.cfg.VOCODER.RESIDUAL_CHANNELS, 1
+        )
+        self.output_projection = Conv1d(self.cfg.VOCODER.RESIDUAL_CHANNELS, 1, 1)
+        nn.init.zeros_(self.output_projection.weight)
+
+    def forward(self, audio, diffusion_step, spectrogram):
+        x = audio.unsqueeze(1)
+        x = self.input_projection(x)
+        x = F.relu(x)
+
+        diffusion_step = self.diffusion_embedding(diffusion_step)
+        spectrogram = self.spectrogram_upsampler(spectrogram)
+
+        skip = None
+        for layer in self.residual_layers:
+            x, skip_connection = layer(x, diffusion_step, spectrogram)
+            skip = skip_connection if skip is None else skip_connection + skip
+
+        x = skip / sqrt(len(self.residual_layers))
+        x = self.skip_projection(x)
+        x = F.relu(x)
+        x = self.output_projection(x)
+        return x
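+
+
+if __name__ == "__main__":
+    # A minimal smoke-test sketch with an assumed toy configuration (the real
+    # values come from the framework config); it only verifies tensor shapes.
+    from types import SimpleNamespace
+
+    demo_cfg = SimpleNamespace(
+        VOCODER=SimpleNamespace(
+            NOISE_SCHEDULE_FACTORS=[1e-4, 0.05, 50],
+            RESIDUAL_CHANNELS=64,
+            UPSAMPLE_FACTORS=[16, 16],  # product = hop size of 256
+            INPUT_DIM=80,
+            DILATION_CYCLE_LENGTH=5,
+            RESIDUAL_LAYERS=10,
+        )
+    )
+    model = DiffWave(demo_cfg)
+    audio = torch.randn(1, 256 * 20)  # 20 mel frames * hop size 256
+    spectrogram = torch.randn(1, 80, 20)
+    step = torch.tensor([10])
+    print(model(audio, step, spectrogram).shape)  # torch.Size([1, 1, 5120])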
diff --git a/models/vocoders/dsp/world/world.py b/models/vocoders/dsp/world/world.py
new file mode 100644
index 0000000000000000000000000000000000000000..59f28e8e896f883fe6ce243dfb7f254e78fd09c6
--- /dev/null
+++ b/models/vocoders/dsp/world/world.py
@@ -0,0 +1,183 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# 1. Extract WORLD features including F0, AP, SP
+# 2. Transform between SP and MCEP
+import torchaudio
+import pyworld as pw
+import numpy as np
+import torch
+import diffsptk
+import os
+from tqdm import tqdm
+import pickle
+import json
+import re
+
+from cuhkszsvc.configs.config_parse import get_wav_path, get_wav_file_path
+from utils.io import has_existed
+
+
+def get_mcep_params(fs):
+    """Hyperparameters of transformation between SP and MCEP
+
+    Reference:
+        https://github.com/CSTR-Edinburgh/merlin/blob/master/misc/scripts/vocoder/world_v2/copy_synthesis.sh
+
+    """
+    if fs in [44100, 48000]:
+        fft_size = 2048
+        alpha = 0.77
+    elif fs == 16000:
+        fft_size = 1024
+        alpha = 0.58
+    else:
+        raise ValueError("Unsupported sampling rate for MCEP extraction: {}".format(fs))
+    return fft_size, alpha
+
+
+def extract_world_features(wave_file, fs, frameshift):
+    # waveform: (1, seq)
+    waveform, sample_rate = torchaudio.load(wave_file)
+    if sample_rate != fs:
+        waveform = torchaudio.functional.resample(
+            waveform, orig_freq=sample_rate, new_freq=fs
+        )
+    # x: (seq,)
+    x = np.array(torch.clamp(waveform[0], -1.0, 1.0), dtype=np.double)
+
+    _f0, t = pw.dio(x, fs, frame_period=frameshift)  # raw pitch extractor
+    f0 = pw.stonemask(x, _f0, t, fs)  # pitch refinement
+    sp = pw.cheaptrick(x, f0, t, fs)  # extract smoothed spectrogram
+    ap = pw.d4c(x, f0, t, fs)  # extract aperiodicity
+
+    return f0, sp, ap, fs
+
+
+def sp2mcep(x, mcsize, fs):
+    fft_size, alpha = get_mcep_params(fs)
+    x = torch.as_tensor(x, dtype=torch.float)
+
+    tmp = diffsptk.ScalarOperation("SquareRoot")(x)
+    tmp = diffsptk.ScalarOperation("Multiplication", 32768.0)(tmp)
+    mgc = diffsptk.MelCepstralAnalysis(
+        cep_order=mcsize - 1, fft_length=fft_size, alpha=alpha, n_iter=1
+    )(tmp)
+    return mgc.numpy()
+
+
+def mcep2sp(x, mcsize, fs):
+    fft_size, alpha = get_mcep_params(fs)
+    x = torch.as_tensor(x, dtype=torch.float)
+
+    tmp = diffsptk.MelGeneralizedCepstrumToSpectrum(
+        alpha=alpha,
+        cep_order=mcsize - 1,
+        fft_length=fft_size,
+    )(x)
+    tmp = diffsptk.ScalarOperation("Division", 32768.0)(tmp)
+    sp = diffsptk.ScalarOperation("Power", 2)(tmp)
+    return sp.double().numpy()
+
+
+def extract_mcep_features_of_dataset(
+    output_path, dataset_path, dataset, mcsize, fs, frameshift, splits=None
+):
+    output_dir = os.path.join(output_path, dataset, "mcep/{}".format(fs))
+
+    if not splits:
+        splits = ["train", "test"] if dataset != "m4singer" else ["test"]
+
+    for dataset_type in splits:
+        print("-" * 20)
+        print("Dataset: {}, {}".format(dataset, dataset_type))
+
+        output_file = os.path.join(output_dir, "{}.pkl".format(dataset_type))
+        if has_existed(output_file):
+            continue
+
+        # Extract SP features
+        print("\nExtracting SP featuers...")
+        sp_features = get_world_features_of_dataset(
+            output_path, dataset_path, dataset, dataset_type, fs, frameshift
+        )
+
+        # SP to MCEP
+        print("\nTransform SP to MCEP...")
+        mcep_features = [sp2mcep(sp, mcsize=mcsize, fs=fs) for sp in tqdm(sp_features)]
+
+        # Save
+        os.makedirs(output_dir, exist_ok=True)
+        with open(output_file, "wb") as f:
+            pickle.dump(mcep_features, f)
+
+
+def get_world_features_of_dataset(
+    output_path,
+    dataset_path,
+    dataset,
+    dataset_type,
+    fs,
+    frameshift,
+    save_sp_feature=False,
+):
+    data_dir = os.path.join(output_path, dataset)
+    wave_dir = get_wav_path(dataset_path, dataset)
+
+    # Dataset
+    dataset_file = os.path.join(data_dir, "{}.json".format(dataset_type))
+    if not os.path.exists(dataset_file):
+        print("File {} has not existed.".format(dataset_file))
+        return None
+
+    with open(dataset_file, "r") as f:
+        datasets = json.load(f)
+
+    # Save dir
+    f0_dir = os.path.join(output_path, dataset, "f0")
+    os.makedirs(f0_dir, exist_ok=True)
+
+    # Extract
+    f0_features = []
+    sp_features = []
+    for utt in tqdm(datasets):
+        wave_file = get_wav_file_path(dataset, wave_dir, utt)
+        f0, sp, _, _ = extract_world_features(wave_file, fs, frameshift)
+
+        sp_features.append(sp)
+        f0_features.append(f0)
+
+    # Save sp
+    if save_sp_feature:
+        sp_dir = os.path.join(output_path, dataset, "sp")
+        os.makedirs(sp_dir, exist_ok=True)
+        with open(os.path.join(sp_dir, "{}.pkl".format(dataset_type)), "wb") as f:
+            pickle.dump(sp_features, f)
+
+    # F0 statistics
+    f0_statistics_file = os.path.join(f0_dir, "{}_f0.pkl".format(dataset_type))
+    f0_statistics(f0_features, f0_statistics_file)
+
+    return sp_features
+
+
+def f0_statistics(f0_features, path):
+    print("\nF0 statistics...")
+
+    total_f0 = []
+    for f0 in tqdm(f0_features):
+        total_f0 += [f for f in f0 if f != 0]
+
+    mean = sum(total_f0) / len(total_f0)
+    print("Min = {}, Max = {}, Mean = {}".format(min(total_f0), max(total_f0), mean))
+
+    with open(path, "wb") as f:
+        pickle.dump([mean, total_f0], f)
+
+
+def world_synthesis(f0, sp, ap, fs, frameshift):
+    y = pw.synthesize(
+        f0, sp, ap, fs, frame_period=frameshift
+    )  # synthesize an utterance using the parameters
+    return y
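+
+
+if __name__ == "__main__":
+    # A minimal analysis/synthesis sketch. "example.wav" is a placeholder path,
+    # and fs=44100, a 10 ms frameshift, and mcsize=60 are assumed toy values.
+    fs, frameshift, mcsize = 44100, 10.0, 60
+    f0, sp, ap, fs = extract_world_features("example.wav", fs, frameshift)
+    mgc = sp2mcep(sp, mcsize=mcsize, fs=fs)  # SP -> MCEP
+    sp_rec = mcep2sp(mgc, mcsize=mcsize, fs=fs)  # MCEP -> SP
+    y = world_synthesis(f0, sp_rec, ap, fs, frameshift)  # re-synthesized waveform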
diff --git a/models/vocoders/flow/flow_vocoder_dataset.py b/models/vocoders/flow/flow_vocoder_dataset.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/vocoders/flow/flow_vocoder_inference.py b/models/vocoders/flow/flow_vocoder_inference.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/vocoders/flow/flow_vocoder_trainer.py b/models/vocoders/flow/flow_vocoder_trainer.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/vocoders/flow/waveglow/waveglow.py b/models/vocoders/flow/waveglow/waveglow.py
new file mode 100644
index 0000000000000000000000000000000000000000..13e2a1bf8f5e3c3d47a031ceec87e4ff111cd5fe
--- /dev/null
+++ b/models/vocoders/flow/waveglow/waveglow.py
@@ -0,0 +1,249 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch.autograd import Variable
+import torch.nn.functional as F
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+    n_channels_int = n_channels[0]
+    in_act = input_a + input_b
+    t_act = torch.tanh(in_act[:, :n_channels_int, :])
+    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+    acts = t_act * s_act
+    return acts
+
+
+class Invertible1x1Conv(torch.nn.Module):
+    """
+    The layer outputs both the convolution and the log determinant
+    of its weight matrix. If reverse=True, it performs the convolution with the
+    inverse weight matrix.
+    """
+
+    def __init__(self, c):
+        super(Invertible1x1Conv, self).__init__()
+        self.conv = torch.nn.Conv1d(
+            c, c, kernel_size=1, stride=1, padding=0, bias=False
+        )
+
+        # Sample a random orthonormal matrix to initialize weights
+        W = torch.linalg.qr(torch.FloatTensor(c, c).normal_())[0]
+
+        # Ensure determinant is 1.0 not -1.0
+        if torch.det(W) < 0:
+            W[:, 0] = -1 * W[:, 0]
+        W = W.view(c, c, 1)
+        self.conv.weight.data = W
+
+    def forward(self, z, reverse=False):
+        # shape
+        batch_size, group_size, n_of_groups = z.size()
+
+        W = self.conv.weight.squeeze()
+
+        if reverse:
+            if not hasattr(self, "W_inverse"):
+                # Reverse computation
+                W_inverse = W.float().inverse()
+                W_inverse = Variable(W_inverse[..., None])
+                if z.type() == "torch.cuda.HalfTensor":
+                    W_inverse = W_inverse.half()
+                self.W_inverse = W_inverse
+            z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
+            return z
+        else:
+            # Forward computation
+            log_det_W = batch_size * n_of_groups * torch.logdet(W)
+            z = self.conv(z)
+            return z, log_det_W
+
+
+class WN(torch.nn.Module):
+    """
+    This is the WaveNet-like layer for the affine coupling. The primary difference
+    from WaveNet is that the convolutions need not be causal. There is also no
+    dilation size reset; the dilation only doubles on each layer.
+    """
+
+    def __init__(
+        self, n_in_channels, n_mel_channels, n_layers, n_channels, kernel_size
+    ):
+        super(WN, self).__init__()
+        assert kernel_size % 2 == 1
+        assert n_channels % 2 == 0
+        self.n_layers = n_layers
+        self.n_channels = n_channels
+        self.in_layers = torch.nn.ModuleList()
+        self.res_skip_layers = torch.nn.ModuleList()
+
+        start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
+        start = torch.nn.utils.weight_norm(start, name="weight")
+        self.start = start
+
+        # Initializing last layer to 0 makes the affine coupling layers
+        # do nothing at first.  This helps with training stability
+        end = torch.nn.Conv1d(n_channels, 2 * n_in_channels, 1)
+        end.weight.data.zero_()
+        end.bias.data.zero_()
+        self.end = end
+
+        cond_layer = torch.nn.Conv1d(n_mel_channels, 2 * n_channels * n_layers, 1)
+        self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
+
+        for i in range(n_layers):
+            dilation = 2**i
+            padding = int((kernel_size * dilation - dilation) / 2)
+            in_layer = torch.nn.Conv1d(
+                n_channels,
+                2 * n_channels,
+                kernel_size,
+                dilation=dilation,
+                padding=padding,
+            )
+            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
+            self.in_layers.append(in_layer)
+
+            # last one is not necessary
+            if i < n_layers - 1:
+                res_skip_channels = 2 * n_channels
+            else:
+                res_skip_channels = n_channels
+            res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
+            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
+            self.res_skip_layers.append(res_skip_layer)
+
+    def forward(self, forward_input):
+        audio, spect = forward_input
+        audio = self.start(audio)
+        output = torch.zeros_like(audio)
+        n_channels_tensor = torch.IntTensor([self.n_channels])
+
+        spect = self.cond_layer(spect)
+
+        for i in range(self.n_layers):
+            spect_offset = i * 2 * self.n_channels
+            acts = fused_add_tanh_sigmoid_multiply(
+                self.in_layers[i](audio),
+                spect[:, spect_offset : spect_offset + 2 * self.n_channels, :],
+                n_channels_tensor,
+            )
+
+            res_skip_acts = self.res_skip_layers[i](acts)
+            if i < self.n_layers - 1:
+                audio = audio + res_skip_acts[:, : self.n_channels, :]
+                output = output + res_skip_acts[:, self.n_channels :, :]
+            else:
+                output = output + res_skip_acts
+
+        return self.end(output)
+
+
+class WaveGlow(torch.nn.Module):
+    def __init__(self, cfg):
+        super(WaveGlow, self).__init__()
+
+        self.cfg = cfg
+
+        self.upsample = torch.nn.ConvTranspose1d(
+            self.cfg.VOCODER.INPUT_DIM,
+            self.cfg.VOCODER.INPUT_DIM,
+            1024,
+            stride=256,
+        )
+        assert self.cfg.VOCODER.N_GROUP % 2 == 0
+        self.n_flows = self.cfg.VOCODER.N_FLOWS
+        self.n_group = self.cfg.VOCODER.N_GROUP
+        self.n_early_every = self.cfg.VOCODER.N_EARLY_EVERY
+        self.n_early_size = self.cfg.VOCODER.N_EARLY_SIZE
+        self.WN = torch.nn.ModuleList()
+        self.convinv = torch.nn.ModuleList()
+
+        n_half = int(self.cfg.VOCODER.N_GROUP / 2)
+
+        # Set up layers with the right sizes based on how many dimensions
+        # have been output already
+        n_remaining_channels = self.cfg.VOCODER.N_GROUP
+        for k in range(self.cfg.VOCODER.N_FLOWS):
+            if k % self.n_early_every == 0 and k > 0:
+                n_half = n_half - int(self.n_early_size / 2)
+                n_remaining_channels = n_remaining_channels - self.n_early_size
+            self.convinv.append(Invertible1x1Conv(n_remaining_channels))
+            self.WN.append(
+                WN(
+                    n_half,
+                    self.cfg.VOCODER.INPUT_DIM * self.cfg.VOCODER.N_GROUP,
+                    self.cfg.VOCODER.N_LAYERS,
+                    self.cfg.VOCODER.N_CHANNELS,
+                    self.cfg.VOCODER.KERNEL_SIZE,
+                )
+            )
+        self.n_remaining_channels = n_remaining_channels  # Useful during inference
+
+    def forward(self, forward_input):
+        """
+        forward_input[0] = mel_spectrogram:  batch x n_mel_channels x frames
+        forward_input[1] = audio: batch x time
+        """
+        spect, audio = forward_input
+
+        #  Upsample spectrogram to size of audio
+        spect = self.upsample(spect)
+        assert spect.size(2) >= audio.size(1)
+        if spect.size(2) > audio.size(1):
+            spect = spect[:, :, : audio.size(1)]
+
+        spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
+        spect = (
+            spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1)
+        )
+
+        audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1)
+        output_audio = []
+        log_s_list = []
+        log_det_W_list = []
+
+        for k in range(self.n_flows):
+            if k % self.n_early_every == 0 and k > 0:
+                output_audio.append(audio[:, : self.n_early_size, :])
+                audio = audio[:, self.n_early_size :, :]
+
+            audio, log_det_W = self.convinv[k](audio)
+            log_det_W_list.append(log_det_W)
+
+            n_half = int(audio.size(1) / 2)
+            audio_0 = audio[:, :n_half, :]
+            audio_1 = audio[:, n_half:, :]
+
+            output = self.WN[k]((audio_0, spect))
+            log_s = output[:, n_half:, :]
+            b = output[:, :n_half, :]
+            audio_1 = torch.exp(log_s) * audio_1 + b
+            log_s_list.append(log_s)
+
+            audio = torch.cat([audio_0, audio_1], 1)
+
+        output_audio.append(audio)
+        return torch.cat(output_audio, 1), log_s_list, log_det_W_list
+
+    @staticmethod
+    def remove_weightnorm(model):
+        waveglow = model
+        for WN in waveglow.WN:
+            WN.start = torch.nn.utils.remove_weight_norm(WN.start)
+            WN.in_layers = remove(WN.in_layers)
+            WN.cond_layer = torch.nn.utils.remove_weight_norm(WN.cond_layer)
+            WN.res_skip_layers = remove(WN.res_skip_layers)
+        return waveglow
+
+
+def remove(conv_list):
+    new_conv_list = torch.nn.ModuleList()
+    for old_conv in conv_list:
+        old_conv = torch.nn.utils.remove_weight_norm(old_conv)
+        new_conv_list.append(old_conv)
+    return new_conv_list
diff --git a/models/vocoders/gan/discriminator/mpd.py b/models/vocoders/gan/discriminator/mpd.py
new file mode 100644
index 0000000000000000000000000000000000000000..f28711d18847a106a998cab90871fe6303a4fd08
--- /dev/null
+++ b/models/vocoders/gan/discriminator/mpd.py
@@ -0,0 +1,269 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv2d, Conv1d
+from torch.nn.utils import weight_norm, spectral_norm
+from torch import nn
+from modules.vocoder_blocks import *
+
+LRELU_SLOPE = 0.1
+
+
+class DiscriminatorP(torch.nn.Module):
+    def __init__(self, cfg, period, kernel_size=5, stride=3, use_spectral_norm=False):
+        super(DiscriminatorP, self).__init__()
+        self.period = period
+        self.d_mult = cfg.model.mpd.discriminator_channel_mult_factor
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        self.convs = nn.ModuleList(
+            [
+                norm_f(
+                    Conv2d(
+                        1,
+                        int(32 * self.d_mult),
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        int(32 * self.d_mult),
+                        int(128 * self.d_mult),
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        int(128 * self.d_mult),
+                        int(512 * self.d_mult),
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        int(512 * self.d_mult),
+                        int(1024 * self.d_mult),
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        int(1024 * self.d_mult),
+                        int(1024 * self.d_mult),
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(2, 0),
+                    )
+                ),
+            ]
+        )
+        self.conv_post = norm_f(
+            Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0))
+        )
+
+    def forward(self, x):
+        fmap = []
+
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0:  # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad
+
+        x = x.view(b, c, t // self.period, self.period)
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            fmap.append(x)
+
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class MultiPeriodDiscriminator(torch.nn.Module):
+    def __init__(self, cfg):
+        super(MultiPeriodDiscriminator, self).__init__()
+        self.mpd_reshapes = cfg.model.mpd.mpd_reshapes
+        print("mpd_reshapes: {}".format(self.mpd_reshapes))
+        discriminators = [
+            DiscriminatorP(cfg, rs, use_spectral_norm=cfg.model.mpd.use_spectral_norm)
+            for rs in self.mpd_reshapes
+        ]
+        self.discriminators = nn.ModuleList(discriminators)
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+# TODO: merge with DiscriminatorP (lmxue, yicheng)
+class DiscriminatorP_vits(torch.nn.Module):
+    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+        super(DiscriminatorP_vits, self).__init__()
+        self.period = period
+        self.use_spectral_norm = use_spectral_norm
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        self.convs = nn.ModuleList(
+            [
+                norm_f(
+                    Conv2d(
+                        1,
+                        32,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        32,
+                        128,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        128,
+                        512,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        512,
+                        1024,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                ),
+                norm_f(
+                    Conv2d(
+                        1024,
+                        1024,
+                        (kernel_size, 1),
+                        1,
+                        padding=(get_padding(kernel_size, 1), 0),
+                    )
+                ),
+            ]
+        )
+        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+    def forward(self, x):
+        fmap = []
+
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0:  # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class DiscriminatorS(torch.nn.Module):
+    def __init__(self, use_spectral_norm=False):
+        super(DiscriminatorS, self).__init__()
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+        self.convs = nn.ModuleList(
+            [
+                norm_f(Conv1d(1, 16, 15, 1, padding=7)),
+                norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
+                norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
+                norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
+                norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
+                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+            ]
+        )
+        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+    def forward(self, x):
+        fmap = []
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+# TODO: merge with MultiPeriodDiscriminator (lmxue, yicheng)
+class MultiPeriodDiscriminator_vits(torch.nn.Module):
+    def __init__(self, use_spectral_norm=False):
+        super(MultiPeriodDiscriminator_vits, self).__init__()
+        periods = [2, 3, 5, 7, 11]
+
+        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
+        discs = discs + [
+            DiscriminatorP_vits(i, use_spectral_norm=use_spectral_norm) for i in periods
+        ]
+        self.discriminators = nn.ModuleList(discs)
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+        for i, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            y_d_gs.append(y_d_g)
+            fmap_rs.append(fmap_r)
+            fmap_gs.append(fmap_g)
+
+        outputs = {
+            "y_d_hat_r": y_d_rs,
+            "y_d_hat_g": y_d_gs,
+            "fmap_rs": fmap_rs,
+            "fmap_gs": fmap_gs,
+        }
+
+        return outputs
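+
+
+if __name__ == "__main__":
+    # A minimal shape sketch with assumed toy tensors (a batch of two 1-second
+    # 16 kHz clips): every sub-discriminator returns logits and feature maps.
+    mpd = MultiPeriodDiscriminator_vits()
+    y, y_hat = torch.randn(2, 1, 16000), torch.randn(2, 1, 16000)
+    outputs = mpd(y, y_hat)
+    print(len(outputs["y_d_hat_r"]), len(outputs["fmap_gs"]))  # 6 6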
diff --git a/models/vocoders/gan/discriminator/mrd.py b/models/vocoders/gan/discriminator/mrd.py
new file mode 100644
index 0000000000000000000000000000000000000000..38ee80bfbf82b6aa63c80dbc2c6ffed8cb50a924
--- /dev/null
+++ b/models/vocoders/gan/discriminator/mrd.py
@@ -0,0 +1,160 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn.utils import weight_norm, spectral_norm
+
+LRELU_SLOPE = 0.1
+
+
+# This code is a refined MRD adapted from BigVGAN under the MIT License
+# https://github.com/NVIDIA/BigVGAN
+
+
+class DiscriminatorR(nn.Module):
+    def __init__(self, cfg, resolution):
+        super().__init__()
+
+        self.resolution = resolution
+        assert (
+            len(self.resolution) == 3
+        ), "MRD layer requires list with len=3, got {}".format(self.resolution)
+        self.lrelu_slope = LRELU_SLOPE
+
+        norm_f = (
+            weight_norm if cfg.model.mrd.use_spectral_norm == False else spectral_norm
+        )
+        if cfg.model.mrd.mrd_override:
+            print(
+                "INFO: overriding MRD use_spectral_norm as {}".format(
+                    cfg.model.mrd.mrd_use_spectral_norm
+                )
+            )
+            norm_f = (
+                weight_norm
+                if cfg.model.mrd.mrd_use_spectral_norm == False
+                else spectral_norm
+            )
+        self.d_mult = cfg.model.mrd.discriminator_channel_mult_factor
+        if cfg.model.mrd.mrd_override:
+            print(
+                "INFO: overriding mrd channel multiplier as {}".format(
+                    cfg.model.mrd.mrd_channel_mult
+                )
+            )
+            self.d_mult = cfg.model.mrd.mrd_channel_mult
+
+        self.convs = nn.ModuleList(
+            [
+                norm_f(nn.Conv2d(1, int(32 * self.d_mult), (3, 9), padding=(1, 4))),
+                norm_f(
+                    nn.Conv2d(
+                        int(32 * self.d_mult),
+                        int(32 * self.d_mult),
+                        (3, 9),
+                        stride=(1, 2),
+                        padding=(1, 4),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        int(32 * self.d_mult),
+                        int(32 * self.d_mult),
+                        (3, 9),
+                        stride=(1, 2),
+                        padding=(1, 4),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        int(32 * self.d_mult),
+                        int(32 * self.d_mult),
+                        (3, 9),
+                        stride=(1, 2),
+                        padding=(1, 4),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        int(32 * self.d_mult),
+                        int(32 * self.d_mult),
+                        (3, 3),
+                        padding=(1, 1),
+                    )
+                ),
+            ]
+        )
+        self.conv_post = norm_f(
+            nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1))
+        )
+
+    def forward(self, x):
+        fmap = []
+
+        x = self.spectrogram(x)
+        x = x.unsqueeze(1)
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, self.lrelu_slope)
+            fmap.append(x)
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+    def spectrogram(self, x):
+        n_fft, hop_length, win_length = self.resolution
+        x = F.pad(
+            x,
+            (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
+            mode="reflect",
+        )
+        x = x.squeeze(1)
+        x = torch.stft(
+            x,
+            n_fft=n_fft,
+            hop_length=hop_length,
+            win_length=win_length,
+            center=False,
+            return_complex=True,
+        )
+        x = torch.view_as_real(x)  # [B, F, TT, 2]
+        mag = torch.norm(x, p=2, dim=-1)  # [B, F, TT]
+
+        return mag
+
+
+class MultiResolutionDiscriminator(nn.Module):
+    def __init__(self, cfg, debug=False):
+        super().__init__()
+        self.resolutions = cfg.model.mrd.resolutions
+        assert (
+            len(self.resolutions) == 3
+        ), "MRD requires list of list with len=3, each element having a list with len=3. got {}".format(
+            self.resolutions
+        )
+        self.discriminators = nn.ModuleList(
+            [DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
+        )
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+
+        for i, d in enumerate(self.discriminators):
+            y_d_r, fmap_r = d(x=y)
+            y_d_g, fmap_g = d(x=y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
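+
+
+if __name__ == "__main__":
+    # Minimal usage sketch. The cfg values below are illustrative assumptions that
+    # mirror the attributes read in DiscriminatorR; they are not prescribed defaults.
+    from types import SimpleNamespace
+
+    _cfg = SimpleNamespace(
+        model=SimpleNamespace(
+            mrd=SimpleNamespace(
+                resolutions=[[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
+                use_spectral_norm=False,
+                mrd_override=False,
+                discriminator_channel_mult_factor=1,
+            )
+        )
+    )
+    mrd = MultiResolutionDiscriminator(_cfg)
+    wav = torch.randn(2, 1, 8192)  # [B, 1, T]
+    y_d_rs, y_d_gs, fmap_rs, fmap_gs = mrd(wav, wav)
+    print(len(y_d_rs), y_d_rs[0].shape)  # one entry per resolution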
diff --git a/models/vocoders/gan/discriminator/msd.py b/models/vocoders/gan/discriminator/msd.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c1556aea581878dcbe10f7a3bdebc33a4972e2c
--- /dev/null
+++ b/models/vocoders/gan/discriminator/msd.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv1d, AvgPool1d
+from torch.nn.utils import weight_norm, spectral_norm
+from modules.vocoder_blocks import *
+
+
+LRELU_SLOPE = 0.1
+
+
+class DiscriminatorS(nn.Module):
+    def __init__(self, use_spectral_norm=False):
+        super(DiscriminatorS, self).__init__()
+
+        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+
+        self.convs = nn.ModuleList(
+            [
+                norm_f(Conv1d(1, 128, 15, 1, padding=7)),
+                norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
+                norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
+                norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
+                norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
+                norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
+                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+            ]
+        )
+
+        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+    def forward(self, x):
+        fmap = []
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            fmap.append(x)
+
+        x = self.conv_post(x)
+        fmap.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, fmap
+
+
+class MultiScaleDiscriminator(nn.Module):
+    def __init__(self, cfg):
+        super(MultiScaleDiscriminator, self).__init__()
+
+        self.cfg = cfg
+
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorS(use_spectral_norm=True),
+                DiscriminatorS(),
+                DiscriminatorS(),
+            ]
+        )
+
+        self.meanpools = nn.ModuleList(
+            [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
+        )
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+
+        for i, d in enumerate(self.discriminators):
+            if i != 0:
+                y = self.meanpools[i - 1](y)
+                y_hat = self.meanpools[i - 1](y_hat)
+            y_d_r, fmap_r = d(y)
+            y_d_g, fmap_g = d(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
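+
+
+if __name__ == "__main__":
+    # Minimal usage sketch. MultiScaleDiscriminator only stores cfg, so a
+    # placeholder config object is enough for this illustration.
+    msd = MultiScaleDiscriminator(cfg=None)
+    wav = torch.randn(2, 1, 8192)  # [B, 1, T]
+    y_d_rs, y_d_gs, fmap_rs, fmap_gs = msd(wav, wav)
+    print(len(y_d_rs))  # 3 scales: raw, 2x and 4x average-pooled waveforms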
diff --git a/models/vocoders/gan/discriminator/mssbcqtd.py b/models/vocoders/gan/discriminator/mssbcqtd.py
new file mode 100644
index 0000000000000000000000000000000000000000..213de5441754944a360707e99a3734ad035d9077
--- /dev/null
+++ b/models/vocoders/gan/discriminator/mssbcqtd.py
@@ -0,0 +1,182 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from modules.vocoder_blocks import *
+
+from einops import rearrange
+import torchaudio.transforms as T
+
+from nnAudio import features
+
+LRELU_SLOPE = 0.1
+
+
+class DiscriminatorCQT(nn.Module):
+    def __init__(self, cfg, hop_length, n_octaves, bins_per_octave):
+        super(DiscriminatorCQT, self).__init__()
+        self.cfg = cfg
+
+        self.filters = cfg.model.mssbcqtd.filters
+        self.max_filters = cfg.model.mssbcqtd.max_filters
+        self.filters_scale = cfg.model.mssbcqtd.filters_scale
+        self.kernel_size = (3, 9)
+        self.dilations = cfg.model.mssbcqtd.dilations
+        self.stride = (1, 2)
+
+        self.in_channels = cfg.model.mssbcqtd.in_channels
+        self.out_channels = cfg.model.mssbcqtd.out_channels
+        self.fs = cfg.preprocess.sample_rate
+        self.hop_length = hop_length
+        self.n_octaves = n_octaves
+        self.bins_per_octave = bins_per_octave
+
+        self.cqt_transform = features.cqt.CQT2010v2(
+            sr=self.fs * 2,
+            hop_length=self.hop_length,
+            n_bins=self.bins_per_octave * self.n_octaves,
+            bins_per_octave=self.bins_per_octave,
+            output_format="Complex",
+            pad_mode="constant",
+        )
+
+        self.conv_pres = nn.ModuleList()
+        for i in range(self.n_octaves):
+            self.conv_pres.append(
+                NormConv2d(
+                    self.in_channels * 2,
+                    self.in_channels * 2,
+                    kernel_size=self.kernel_size,
+                    padding=get_2d_padding(self.kernel_size),
+                )
+            )
+
+        self.convs = nn.ModuleList()
+
+        self.convs.append(
+            NormConv2d(
+                self.in_channels * 2,
+                self.filters,
+                kernel_size=self.kernel_size,
+                padding=get_2d_padding(self.kernel_size),
+            )
+        )
+
+        in_chs = min(self.filters_scale * self.filters, self.max_filters)
+        for i, dilation in enumerate(self.dilations):
+            out_chs = min(
+                (self.filters_scale ** (i + 1)) * self.filters, self.max_filters
+            )
+            self.convs.append(
+                NormConv2d(
+                    in_chs,
+                    out_chs,
+                    kernel_size=self.kernel_size,
+                    stride=self.stride,
+                    dilation=(dilation, 1),
+                    padding=get_2d_padding(self.kernel_size, (dilation, 1)),
+                    norm="weight_norm",
+                )
+            )
+            in_chs = out_chs
+        out_chs = min(
+            (self.filters_scale ** (len(self.dilations) + 1)) * self.filters,
+            self.max_filters,
+        )
+        self.convs.append(
+            NormConv2d(
+                in_chs,
+                out_chs,
+                kernel_size=(self.kernel_size[0], self.kernel_size[0]),
+                padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
+                norm="weight_norm",
+            )
+        )
+
+        self.conv_post = NormConv2d(
+            out_chs,
+            self.out_channels,
+            kernel_size=(self.kernel_size[0], self.kernel_size[0]),
+            padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
+            norm="weight_norm",
+        )
+
+        self.activation = torch.nn.LeakyReLU(negative_slope=LRELU_SLOPE)
+        self.resample = T.Resample(orig_freq=self.fs, new_freq=self.fs * 2)
+
+    def forward(self, x):
+        fmap = []
+
+        x = self.resample(x)
+
+        z = self.cqt_transform(x)
+
+        z_amplitude = z[:, :, :, 0].unsqueeze(1)
+        z_phase = z[:, :, :, 1].unsqueeze(1)
+
+        z = torch.cat([z_amplitude, z_phase], dim=1)
+        z = rearrange(z, "b c w t -> b c t w")
+
+        latent_z = []
+        for i in range(self.n_octaves):
+            latent_z.append(
+                self.conv_pres[i](
+                    z[
+                        :,
+                        :,
+                        :,
+                        i * self.bins_per_octave : (i + 1) * self.bins_per_octave,
+                    ]
+                )
+            )
+        latent_z = torch.cat(latent_z, dim=-1)
+
+        for i, l in enumerate(self.convs):
+            latent_z = l(latent_z)
+
+            latent_z = self.activation(latent_z)
+            fmap.append(latent_z)
+
+        latent_z = self.conv_post(latent_z)
+
+        return latent_z, fmap
+
+
+class MultiScaleSubbandCQTDiscriminator(nn.Module):
+    def __init__(self, cfg):
+        super(MultiScaleSubbandCQTDiscriminator, self).__init__()
+
+        self.cfg = cfg
+
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorCQT(
+                    cfg,
+                    hop_length=cfg.model.mssbcqtd.hop_lengths[i],
+                    n_octaves=cfg.model.mssbcqtd.n_octaves[i],
+                    bins_per_octave=cfg.model.mssbcqtd.bins_per_octaves[i],
+                )
+                for i in range(len(cfg.model.mssbcqtd.hop_lengths))
+            ]
+        )
+
+    def forward(self, y, y_hat):
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+
+        for disc in self.discriminators:
+            y_d_r, fmap_r = disc(y)
+            y_d_g, fmap_g = disc(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
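+
+
+if __name__ == "__main__":
+    # Minimal usage sketch. All cfg values are illustrative assumptions chosen so
+    # that each CQT transform is valid (hop divisible by 2**(n_octaves - 1)); they
+    # are not the project's prescribed settings.
+    from types import SimpleNamespace
+
+    _cfg = SimpleNamespace(
+        preprocess=SimpleNamespace(sample_rate=24000),
+        model=SimpleNamespace(
+            mssbcqtd=SimpleNamespace(
+                filters=32,
+                max_filters=1024,
+                filters_scale=1,
+                dilations=[1, 2, 4],
+                in_channels=1,
+                out_channels=1,
+                hop_lengths=[512, 256, 256],
+                n_octaves=[9, 9, 9],
+                bins_per_octaves=[24, 36, 48],
+            )
+        ),
+    )
+    cqtd = MultiScaleSubbandCQTDiscriminator(_cfg)
+    wav = torch.randn(2, 1, 24000)  # [B, 1, T], one second at 24 kHz
+    y_d_rs, y_d_gs, fmap_rs, fmap_gs = cqtd(wav, wav)
+    print(len(y_d_rs), y_d_rs[0].shape)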
diff --git a/models/vocoders/gan/discriminator/msstftd.py b/models/vocoders/gan/discriminator/msstftd.py
new file mode 100644
index 0000000000000000000000000000000000000000..83dedb78848d2d73ac667e7a191f05de1ed7bf21
--- /dev/null
+++ b/models/vocoders/gan/discriminator/msstftd.py
@@ -0,0 +1,226 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is adapted from Meta's EnCodec under the MIT License
+# https://github.com/facebookresearch/encodec
+
+"""MS-STFT discriminator, provided here for reference."""
+
+import typing as tp
+
+import torchaudio
+import torch
+from torch import nn
+from einops import rearrange
+
+from modules.vocoder_blocks import *
+
+
+FeatureMapType = tp.List[torch.Tensor]
+LogitsType = torch.Tensor
+DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]]
+
+
+def get_2d_padding(
+    kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1)
+):
+    return (
+        ((kernel_size[0] - 1) * dilation[0]) // 2,
+        ((kernel_size[1] - 1) * dilation[1]) // 2,
+    )
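+# e.g. get_2d_padding((3, 9)) == (1, 4) and get_2d_padding((3, 9), (2, 1)) == (2, 4),
+# i.e. "same"-style padding for the (possibly dilated) 2d kernels used below.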
+
+
+class DiscriminatorSTFT(nn.Module):
+    """STFT sub-discriminator.
+    Args:
+        filters (int): Number of filters in convolutions
+        in_channels (int): Number of input channels. Default: 1
+        out_channels (int): Number of output channels. Default: 1
+        n_fft (int): Size of FFT for each scale. Default: 1024
+        hop_length (int): Length of hop between STFT windows for each scale. Default: 256
+        kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)``
+        stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)``
+        dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]``
+        win_length (int): Window size for each scale. Default: 1024
+        normalized (bool): Whether to normalize by magnitude after stft. Default: True
+        norm (str): Normalization method. Default: `'weight_norm'`
+        activation (str): Activation function. Default: `'LeakyReLU'`
+        activation_params (dict): Parameters to provide to the activation function.
+        growth (int): Growth factor for the filters. Default: 1
+    """
+
+    def __init__(
+        self,
+        filters: int,
+        in_channels: int = 1,
+        out_channels: int = 1,
+        n_fft: int = 1024,
+        hop_length: int = 256,
+        win_length: int = 1024,
+        max_filters: int = 1024,
+        filters_scale: int = 1,
+        kernel_size: tp.Tuple[int, int] = (3, 9),
+        dilations: tp.List = [1, 2, 4],
+        stride: tp.Tuple[int, int] = (1, 2),
+        normalized: bool = True,
+        norm: str = "weight_norm",
+        activation: str = "LeakyReLU",
+        activation_params: dict = {"negative_slope": 0.2},
+    ):
+        super().__init__()
+        assert len(kernel_size) == 2
+        assert len(stride) == 2
+        self.filters = filters
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.normalized = normalized
+        self.activation = getattr(torch.nn, activation)(**activation_params)
+        self.spec_transform = torchaudio.transforms.Spectrogram(
+            n_fft=self.n_fft,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            window_fn=torch.hann_window,
+            normalized=self.normalized,
+            center=False,
+            pad_mode=None,
+            power=None,
+        )
+        spec_channels = 2 * self.in_channels
+        self.convs = nn.ModuleList()
+        self.convs.append(
+            NormConv2d(
+                spec_channels,
+                self.filters,
+                kernel_size=kernel_size,
+                padding=get_2d_padding(kernel_size),
+            )
+        )
+        in_chs = min(filters_scale * self.filters, max_filters)
+        for i, dilation in enumerate(dilations):
+            out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters)
+            self.convs.append(
+                NormConv2d(
+                    in_chs,
+                    out_chs,
+                    kernel_size=kernel_size,
+                    stride=stride,
+                    dilation=(dilation, 1),
+                    padding=get_2d_padding(kernel_size, (dilation, 1)),
+                    norm=norm,
+                )
+            )
+            in_chs = out_chs
+        out_chs = min(
+            (filters_scale ** (len(dilations) + 1)) * self.filters, max_filters
+        )
+        self.convs.append(
+            NormConv2d(
+                in_chs,
+                out_chs,
+                kernel_size=(kernel_size[0], kernel_size[0]),
+                padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+                norm=norm,
+            )
+        )
+        self.conv_post = NormConv2d(
+            out_chs,
+            self.out_channels,
+            kernel_size=(kernel_size[0], kernel_size[0]),
+            padding=get_2d_padding((kernel_size[0], kernel_size[0])),
+            norm=norm,
+        )
+
+    def forward(self, x: torch.Tensor):
+        """Discriminator STFT Module is the sub module of MultiScaleSTFTDiscriminator.
+
+        Args:
+            x (torch.Tensor): input tensor of shape [B, 1, Time]
+
+        Returns:
+            z: z is the output of the last convolutional layer of shape
+            fmap: fmap is the list of feature maps of every convolutional layer of shape
+        """
+        fmap = []
+        z = self.spec_transform(x)  # [B, 2, Freq, Frames, 2]
+        z = torch.cat([z.real, z.imag], dim=1)
+        z = rearrange(z, "b c w t -> b c t w")
+        for i, layer in enumerate(self.convs):
+            z = layer(z)
+
+            z = self.activation(z)
+            fmap.append(z)
+        z = self.conv_post(z)
+        return z, fmap
+
+
+class MultiScaleSTFTDiscriminator(nn.Module):
+    """Multi-Scale STFT (MS-STFT) discriminator.
+    Args:
+        filters (int): Number of filters in convolutions
+        in_channels (int): Number of input channels. Default: 1
+        out_channels (int): Number of output channels. Default: 1
+        n_ffts (Sequence[int]): Size of FFT for each scale
+        hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale
+        win_lengths (Sequence[int]): Window size for each scale
+        **kwargs: additional args for DiscriminatorSTFT
+    """
+
+    def __init__(
+        self,
+        cfg,
+        in_channels: int = 1,
+        out_channels: int = 1,
+        n_ffts: tp.List[int] = [1024, 2048, 512],
+        hop_lengths: tp.List[int] = [256, 512, 256],
+        win_lengths: tp.List[int] = [1024, 2048, 512],
+        **kwargs,
+    ):
+        super().__init__()
+        self.cfg = cfg
+        assert len(n_ffts) == len(hop_lengths) == len(win_lengths)
+        self.discriminators = nn.ModuleList(
+            [
+                DiscriminatorSTFT(
+                    filters=self.cfg.model.msstftd.filters,
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    n_fft=n_ffts[i],
+                    win_length=win_lengths[i],
+                    hop_length=hop_lengths[i],
+                    **kwargs,
+                )
+                for i in range(len(n_ffts))
+            ]
+        )
+        self.num_discriminators = len(self.discriminators)
+
+    def forward(self, y, y_hat):
+        """Run every STFT sub-discriminator on the real and generated waveforms.
+
+        Args:
+            y (torch.Tensor): real waveform
+            y_hat (torch.Tensor): generated waveform
+
+        Returns:
+            y_d_rs, y_d_gs: each sub-discriminator's logits for real / generated audio
+            fmap_rs, fmap_gs: the corresponding lists of per-layer feature maps
+        """
+        y_d_rs = []
+        y_d_gs = []
+        fmap_rs = []
+        fmap_gs = []
+
+        for disc in self.discriminators:
+            y_d_r, fmap_r = disc(y)
+            y_d_g, fmap_g = disc(y_hat)
+            y_d_rs.append(y_d_r)
+            fmap_rs.append(fmap_r)
+            y_d_gs.append(y_d_g)
+            fmap_gs.append(fmap_g)
+
+        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
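+
+
+if __name__ == "__main__":
+    # Minimal usage sketch; cfg.model.msstftd.filters is the only config field read
+    # here, and 32 is an illustrative assumption rather than a prescribed default.
+    from types import SimpleNamespace
+
+    _cfg = SimpleNamespace(model=SimpleNamespace(msstftd=SimpleNamespace(filters=32)))
+    msstftd = MultiScaleSTFTDiscriminator(_cfg)
+    wav = torch.randn(2, 1, 8192)  # [B, 1, T]
+    y_d_rs, y_d_gs, fmap_rs, fmap_gs = msstftd(wav, wav)
+    print(len(y_d_rs), y_d_rs[0].shape)  # one entry per (n_fft, hop, win) scale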
diff --git a/models/vocoders/gan/gan_vocoder_dataset.py b/models/vocoders/gan/gan_vocoder_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bf87c371647a44fb5bcae33701eda65616e5fd7
--- /dev/null
+++ b/models/vocoders/gan/gan_vocoder_dataset.py
@@ -0,0 +1,205 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import random
+
+import numpy as np
+
+from torch.nn import functional as F
+
+from torch.nn.utils.rnn import pad_sequence
+from utils.data_utils import *
+from models.vocoders.vocoder_dataset import VocoderDataset
+
+
+class GANVocoderDataset(VocoderDataset):
+    def __init__(self, cfg, dataset, is_valid=False):
+        """
+        Args:
+            cfg: config
+            dataset: dataset name
+            is_valid: whether to use train or valid dataset
+        """
+        super().__init__(cfg, dataset, is_valid)
+
+        eval_index = random.randint(0, len(self.metadata) - 1)
+        eval_utt_info = self.metadata[eval_index]
+        eval_utt = "{}_{}".format(eval_utt_info["Dataset"], eval_utt_info["Uid"])
+        self.eval_audio = np.load(self.utt2audio_path[eval_utt])
+        if cfg.preprocess.use_mel:
+            self.eval_mel = np.load(self.utt2mel_path[eval_utt])
+        if cfg.preprocess.use_frame_pitch:
+            self.eval_pitch = np.load(self.utt2frame_pitch_path[eval_utt])
+
+    def __getitem__(self, index):
+        utt_info = self.metadata[index]
+
+        dataset = utt_info["Dataset"]
+        uid = utt_info["Uid"]
+        utt = "{}_{}".format(dataset, uid)
+
+        single_feature = dict()
+
+        if self.cfg.preprocess.use_mel:
+            mel = np.load(self.utt2mel_path[utt])
+            assert mel.shape[0] == self.cfg.preprocess.n_mel
+
+            if "target_len" not in single_feature.keys():
+                single_feature["target_len"] = mel.shape[1]
+
+            if single_feature["target_len"] <= self.cfg.preprocess.cut_mel_frame:
+                mel = np.pad(
+                    mel,
+                    ((0, 0), (0, self.cfg.preprocess.cut_mel_frame - mel.shape[-1])),
+                    mode="constant",
+                )
+            else:
+                if "start" not in single_feature.keys():
+                    start = random.randint(
+                        0, mel.shape[-1] - self.cfg.preprocess.cut_mel_frame
+                    )
+                    end = start + self.cfg.preprocess.cut_mel_frame
+                    single_feature["start"] = start
+                    single_feature["end"] = end
+                mel = mel[:, single_feature["start"] : single_feature["end"]]
+            single_feature["mel"] = mel
+
+        if self.cfg.preprocess.use_frame_pitch:
+            frame_pitch = np.load(self.utt2frame_pitch_path[utt])
+            if "target_len" not in single_feature.keys():
+                single_feature["target_len"] = len(frame_pitch)
+            aligned_frame_pitch = align_length(
+                frame_pitch, single_feature["target_len"]
+            )
+
+            if single_feature["target_len"] <= self.cfg.preprocess.cut_mel_frame:
+                # Pad the frame-level pitch up to cut_mel_frame frames
+                aligned_frame_pitch = np.pad(
+                    aligned_frame_pitch,
+                    (
+                        0,
+                        self.cfg.preprocess.cut_mel_frame
+                        - aligned_frame_pitch.shape[-1],
+                    ),
+                    mode="constant",
+                )
+            else:
+                if "start" not in single_feature.keys():
+                    start = random.randint(
+                        0,
+                        aligned_frame_pitch.shape[-1]
+                        - self.cfg.preprocess.cut_mel_frame,
+                    )
+                    end = start + self.cfg.preprocess.cut_mel_frame
+                    single_feature["start"] = start
+                    single_feature["end"] = end
+                aligned_frame_pitch = aligned_frame_pitch[
+                    single_feature["start"] : single_feature["end"]
+                ]
+            single_feature["frame_pitch"] = aligned_frame_pitch
+
+        if self.cfg.preprocess.use_audio:
+            audio = np.load(self.utt2audio_path[utt])
+
+            assert "target_len" in single_feature.keys()
+
+            if (
+                audio.shape[-1]
+                <= self.cfg.preprocess.cut_mel_frame * self.cfg.preprocess.hop_size
+            ):
+                audio = np.pad(
+                    audio,
+                    (
+                        (
+                            0,
+                            self.cfg.preprocess.cut_mel_frame
+                            * self.cfg.preprocess.hop_size
+                            - audio.shape[-1],
+                        )
+                    ),
+                    mode="constant",
+                )
+            else:
+                if "start" not in single_feature.keys():
+                    audio = audio[
+                        0 : self.cfg.preprocess.cut_mel_frame
+                        * self.cfg.preprocess.hop_size
+                    ]
+                else:
+                    audio = audio[
+                        single_feature["start"]
+                        * self.cfg.preprocess.hop_size : single_feature["end"]
+                        * self.cfg.preprocess.hop_size,
+                    ]
+            single_feature["audio"] = audio
+
+        if self.cfg.preprocess.use_amplitude_phase:
+            logamp = np.load(self.utt2logamp_path[utt])
+            pha = np.load(self.utt2pha_path[utt])
+            rea = np.load(self.utt2rea_path[utt])
+            imag = np.load(self.utt2imag_path[utt])
+
+            assert "target_len" in single_feature.keys()
+
+            if single_feature["target_len"] <= self.cfg.preprocess.cut_mel_frame:
+                logamp = np.pad(
+                    logamp,
+                    ((0, 0), (0, self.cfg.preprocess.cut_mel_frame - mel.shape[-1])),
+                    mode="constant",
+                )
+                pha = np.pad(
+                    pha,
+                    ((0, 0), (0, self.cfg.preprocess.cut_mel_frame - mel.shape[-1])),
+                    mode="constant",
+                )
+                rea = np.pad(
+                    rea,
+                    ((0, 0), (0, self.cfg.preprocess.cut_mel_frame - mel.shape[-1])),
+                    mode="constant",
+                )
+                imag = np.pad(
+                    imag,
+                    ((0, 0), (0, self.cfg.preprocess.cut_mel_frame - mel.shape[-1])),
+                    mode="constant",
+                )
+            else:
+                logamp = logamp[:, single_feature["start"] : single_feature["end"]]
+                pha = pha[:, single_feature["start"] : single_feature["end"]]
+                rea = rea[:, single_feature["start"] : single_feature["end"]]
+                imag = imag[:, single_feature["start"] : single_feature["end"]]
+            single_feature["logamp"] = logamp
+            single_feature["pha"] = pha
+            single_feature["rea"] = rea
+            single_feature["imag"] = imag
+
+        return single_feature
+
+
+class GANVocoderCollator(object):
+    """Zero-pads model inputs and targets based on number of frames per step"""
+
+    def __init__(self, cfg):
+        self.cfg = cfg
+
+    def __call__(self, batch):
+        packed_batch_features = dict()
+
+        # mel: [b, n_mels, frame]
+        # frame_pitch: [b, frame]
+        # audios: [b, frame * hop_size]
+
+        for key in batch[0].keys():
+            if key in ["target_len", "start", "end"]:
+                continue
+            else:
+                values = [torch.from_numpy(b[key]) for b in batch]
+                packed_batch_features[key] = pad_sequence(
+                    values, batch_first=True, padding_value=0
+                )
+
+        return packed_batch_features
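+
+
+if __name__ == "__main__":
+    # Minimal sketch of the collator contract with hand-made items (80 mel bins,
+    # 32 frames, hop size 256 are illustrative assumptions). Since __getitem__
+    # already crops/pads every feature to a fixed length, pad_sequence here
+    # effectively just stacks the batch; "target_len" is skipped.
+    _batch = [
+        {
+            "mel": np.random.randn(80, 32).astype(np.float32),
+            "frame_pitch": np.random.randn(32).astype(np.float32),
+            "audio": np.random.randn(32 * 256).astype(np.float32),
+            "target_len": 32,
+        }
+        for _ in range(2)
+    ]
+    _features = GANVocoderCollator(cfg=None)(_batch)
+    print({k: v.shape for k, v in _features.items()})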
diff --git a/models/vocoders/gan/gan_vocoder_inference.py b/models/vocoders/gan/gan_vocoder_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb69631662dedf4fc73a29f675f0a4bc361b03ec
--- /dev/null
+++ b/models/vocoders/gan/gan_vocoder_inference.py
@@ -0,0 +1,110 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from utils.util import pad_mels_to_tensors, pad_f0_to_tensors
+
+
+def vocoder_inference(cfg, model, mels, f0s=None, device=None, fast_inference=False):
+    """Inference the vocoder
+    Args:
+        mels: A tensor of mel-specs with the shape (batch_size, num_mels, frames)
+    Returns:
+        audios: A tensor of audios with the shape (batch_size, seq_len)
+    """
+    model.eval()
+
+    with torch.no_grad():
+        mels = mels.to(device)
+        if f0s is not None:
+            f0s = f0s.to(device)
+
+        if f0s is None and not cfg.preprocess.extract_amplitude_phase:
+            output = model.forward(mels)
+        elif cfg.preprocess.extract_amplitude_phase:
+            (
+                _,
+                _,
+                _,
+                _,
+                output,
+            ) = model.forward(mels)
+        else:
+            output = model.forward(mels, f0s)
+
+        return output.squeeze(1).detach().cpu()
+
+
+def synthesis_audios(cfg, model, mels, f0s=None, batch_size=None, fast_inference=False):
+    """Inference the vocoder
+    Args:
+        mels: A list of mel-specs
+    Returns:
+        audios: A list of audios
+    """
+    # Get the device
+    device = next(model.parameters()).device
+
+    audios = []
+
+    # Pad the given list into tensors
+    mel_batches, mel_frames = pad_mels_to_tensors(mels, batch_size)
+    if f0s is not None:
+        f0_batches = pad_f0_to_tensors(f0s, batch_size)
+
+    if f0s is None:
+        for mel_batch, mel_frame in zip(mel_batches, mel_frames):
+            for i in range(mel_batch.shape[0]):
+                mel = mel_batch[i]
+                frame = mel_frame[i]
+                audio = vocoder_inference(
+                    cfg,
+                    model,
+                    mel.unsqueeze(0),
+                    device=device,
+                    fast_inference=fast_inference,
+                ).squeeze(0)
+
+                # # Apply fade_out to make the sound more natural
+                # fade_out = torch.linspace(
+                #     1, 0, steps=20 * model.cfg.preprocess.hop_size
+                # ).cpu()
+
+                # calculate the audio length
+                audio_length = frame * model.cfg.preprocess.hop_size
+                audio = audio[:audio_length]
+
+                # audio[-20 * model.cfg.preprocess.hop_size :] *= fade_out
+
+                audios.append(audio)
+    else:
+        for mel_batch, f0_batch, mel_frame in zip(mel_batches, f0_batches, mel_frames):
+            for i in range(mel_batch.shape[0]):
+                mel = mel_batch[i]
+                f0 = f0_batch[i]
+                frame = mel_frame[i]
+                audio = vocoder_inference(
+                    cfg,
+                    model,
+                    mel.unsqueeze(0),
+                    f0s=f0.unsqueeze(0),
+                    device=device,
+                    fast_inference=fast_inference,
+                ).squeeze(0)
+
+                # # Apply fade_out to make the sound more natural
+                # fade_out = torch.linspace(
+                #     1, 0, steps=20 * model.cfg.preprocess.hop_size
+                # ).cpu()
+
+                # calculate the audio length
+                audio_length = frame * model.cfg.preprocess.hop_size
+                audio = audio[:audio_length]
+
+                # audio[-20 * model.cfg.preprocess.hop_size :] *= fade_out
+
+                audios.append(audio)
+    return audios
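+
+
+if __name__ == "__main__":
+    # Minimal sketch with a stand-in generator; the mel/hop sizes and cfg fields
+    # below are illustrative assumptions, not project defaults.
+    import torch.nn as nn
+    from types import SimpleNamespace
+
+    class _ToyVocoder(nn.Module):
+        def __init__(self, n_mel=80, hop_size=256):
+            super().__init__()
+            self.up = nn.ConvTranspose1d(n_mel, 1, kernel_size=hop_size, stride=hop_size)
+
+        def forward(self, mels):
+            return self.up(mels)  # [B, 1, frames * hop_size]
+
+    _cfg = SimpleNamespace(
+        preprocess=SimpleNamespace(extract_amplitude_phase=False, hop_size=256)
+    )
+    wav = vocoder_inference(_cfg, _ToyVocoder(), torch.randn(1, 80, 32), device="cpu")
+    print(wav.shape)  # torch.Size([1, 8192])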
diff --git a/models/vocoders/gan/gan_vocoder_trainer.py b/models/vocoders/gan/gan_vocoder_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fb9c8f03a7de14d0162bfd671d33b76890293a5
--- /dev/null
+++ b/models/vocoders/gan/gan_vocoder_trainer.py
@@ -0,0 +1,1112 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import sys
+import time
+import torch
+import json
+import itertools
+import accelerate
+import torch.distributed as dist
+import torch.nn.functional as F
+from tqdm import tqdm
+from torch.nn.parallel import DistributedDataParallel
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.tensorboard import SummaryWriter
+
+from torch.optim import AdamW
+from torch.optim.lr_scheduler import ExponentialLR
+
+from librosa.filters import mel as librosa_mel_fn
+
+from accelerate.logging import get_logger
+from pathlib import Path
+
+from utils.io import save_audio
+from utils.data_utils import *
+from utils.util import (
+    Logger,
+    ValueWindow,
+    remove_older_ckpt,
+    set_all_random_seed,
+    save_config,
+)
+from utils.mel import extract_mel_features
+from models.vocoders.vocoder_trainer import VocoderTrainer
+from models.vocoders.gan.gan_vocoder_dataset import (
+    GANVocoderDataset,
+    GANVocoderCollator,
+)
+
+from models.vocoders.gan.generator.bigvgan import BigVGAN
+from models.vocoders.gan.generator.hifigan import HiFiGAN
+from models.vocoders.gan.generator.melgan import MelGAN
+from models.vocoders.gan.generator.nsfhifigan import NSFHiFiGAN
+from models.vocoders.gan.generator.apnet import APNet
+
+from models.vocoders.gan.discriminator.mpd import MultiPeriodDiscriminator
+from models.vocoders.gan.discriminator.mrd import MultiResolutionDiscriminator
+from models.vocoders.gan.discriminator.mssbcqtd import MultiScaleSubbandCQTDiscriminator
+from models.vocoders.gan.discriminator.msd import MultiScaleDiscriminator
+from models.vocoders.gan.discriminator.msstftd import MultiScaleSTFTDiscriminator
+
+from models.vocoders.gan.gan_vocoder_inference import vocoder_inference
+
+supported_generators = {
+    "bigvgan": BigVGAN,
+    "hifigan": HiFiGAN,
+    "melgan": MelGAN,
+    "nsfhifigan": NSFHiFiGAN,
+    "apnet": APNet,
+}
+
+supported_discriminators = {
+    "mpd": MultiPeriodDiscriminator,
+    "msd": MultiScaleDiscriminator,
+    "mrd": MultiResolutionDiscriminator,
+    "msstftd": MultiScaleSTFTDiscriminator,
+    "mssbcqtd": MultiScaleSubbandCQTDiscriminator,
+}
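+
+# _build_model() below instantiates cfg.model.generator (one key of
+# supported_generators) together with one discriminator per entry of
+# cfg.model.discriminators, e.g. "bigvgan" with ["mpd", "msd", "mssbcqtd"]
+# (illustrative values, not prescribed defaults).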
+
+
+class GANVocoderTrainer(VocoderTrainer):
+    def __init__(self, args, cfg):
+        super().__init__()
+
+        self.args = args
+        self.cfg = cfg
+
+        cfg.exp_name = args.exp_name
+
+        # Init accelerator
+        self._init_accelerator()
+        self.accelerator.wait_for_everyone()
+
+        # Init logger
+        with self.accelerator.main_process_first():
+            self.logger = get_logger(args.exp_name, log_level=args.log_level)
+
+        self.logger.info("=" * 56)
+        self.logger.info("||\t\t" + "New training process started." + "\t\t||")
+        self.logger.info("=" * 56)
+        self.logger.info("\n")
+        self.logger.debug(f"Using {args.log_level.upper()} logging level.")
+        self.logger.info(f"Experiment name: {args.exp_name}")
+        self.logger.info(f"Experiment directory: {self.exp_dir}")
+        self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
+        if self.accelerator.is_main_process:
+            os.makedirs(self.checkpoint_dir, exist_ok=True)
+        self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
+
+        # Init training status
+        self.batch_count: int = 0
+        self.step: int = 0
+        self.epoch: int = 0
+
+        self.max_epoch = (
+            self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
+        )
+        self.logger.info(
+            "Max epoch: {}".format(
+                self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
+            )
+        )
+
+        # Check potential errors
+        if self.accelerator.is_main_process:
+            self._check_basic_configs()
+            self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
+            self.checkpoints_path = [
+                [] for _ in range(len(self.save_checkpoint_stride))
+            ]
+            self.run_eval = self.cfg.train.run_eval
+
+        # Set random seed
+        with self.accelerator.main_process_first():
+            start = time.monotonic_ns()
+            self._set_random_seed(self.cfg.train.random_seed)
+            end = time.monotonic_ns()
+            self.logger.debug(
+                f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
+            )
+            self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
+
+        # Build dataloader
+        with self.accelerator.main_process_first():
+            self.logger.info("Building dataset...")
+            start = time.monotonic_ns()
+            self.train_dataloader, self.valid_dataloader = self._build_dataloader()
+            end = time.monotonic_ns()
+            self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
+
+        # Build model
+        with self.accelerator.main_process_first():
+            self.logger.info("Building model...")
+            start = time.monotonic_ns()
+            self.generator, self.discriminators = self._build_model()
+            end = time.monotonic_ns()
+            self.logger.debug(self.generator)
+            for _, discriminator in self.discriminators.items():
+                self.logger.debug(discriminator)
+            self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
+            self.logger.info(f"Model parameters: {self._count_parameters()/1e6:.2f}M")
+
+        # Build optimizers and schedulers
+        with self.accelerator.main_process_first():
+            self.logger.info("Building optimizer and scheduler...")
+            start = time.monotonic_ns()
+            (
+                self.generator_optimizer,
+                self.discriminator_optimizer,
+            ) = self._build_optimizer()
+            (
+                self.generator_scheduler,
+                self.discriminator_scheduler,
+            ) = self._build_scheduler()
+            end = time.monotonic_ns()
+            self.logger.info(
+                f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
+            )
+
+        # Accelerator preparing
+        self.logger.info("Initializing accelerate...")
+        start = time.monotonic_ns()
+        (
+            self.train_dataloader,
+            self.valid_dataloader,
+            self.generator,
+            self.generator_optimizer,
+            self.discriminator_optimizer,
+            self.generator_scheduler,
+            self.discriminator_scheduler,
+        ) = self.accelerator.prepare(
+            self.train_dataloader,
+            self.valid_dataloader,
+            self.generator,
+            self.generator_optimizer,
+            self.discriminator_optimizer,
+            self.generator_scheduler,
+            self.discriminator_scheduler,
+        )
+        for key, discriminator in self.discriminators.items():
+            self.discriminators[key] = self.accelerator.prepare_model(discriminator)
+        end = time.monotonic_ns()
+        self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")
+
+        # Build criterions
+        with self.accelerator.main_process_first():
+            self.logger.info("Building criterion...")
+            start = time.monotonic_ns()
+            self.criterions = self._build_criterion()
+            end = time.monotonic_ns()
+            self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")
+
+        # Resume checkpoints
+        with self.accelerator.main_process_first():
+            if args.resume_type:
+                self.logger.info("Resuming from checkpoint...")
+                start = time.monotonic_ns()
+                ckpt_path = Path(args.checkpoint)
+                if self._is_valid_pattern(ckpt_path.parts[-1]):
+                    ckpt_path = self._load_model(
+                        None, args.checkpoint, args.resume_type
+                    )
+                else:
+                    ckpt_path = self._load_model(
+                        args.checkpoint, resume_type=args.resume_type
+                    )
+                end = time.monotonic_ns()
+                self.logger.info(
+                    f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
+                )
+                self.checkpoints_path = json.load(
+                    open(os.path.join(ckpt_path, "ckpts.json"), "r")
+                )
+
+            self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
+            if self.accelerator.is_main_process:
+                os.makedirs(self.checkpoint_dir, exist_ok=True)
+            self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
+
+        # Save config
+        self.config_save_path = os.path.join(self.exp_dir, "args.json")
+
+    def _build_dataset(self):
+        return GANVocoderDataset, GANVocoderCollator
+
+    def _build_criterion(self):
+        class feature_criterion(torch.nn.Module):
+            def __init__(self, cfg):
+                super(feature_criterion, self).__init__()
+                self.cfg = cfg
+                self.l1Loss = torch.nn.L1Loss(reduction="mean")
+                self.l2Loss = torch.nn.MSELoss(reduction="mean")
+                self.relu = torch.nn.ReLU()
+
+            def __call__(self, fmap_r, fmap_g):
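+                # Feature-matching loss: accumulate the L1 distance between the
+                # discriminators' intermediate feature maps for real (fmap_r) and
+                # generated (fmap_g) audio, then scale (and, for "codec",
+                # normalize) per generator family below.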
+                loss = 0
+
+                if self.cfg.model.generator in [
+                    "hifigan",
+                    "nsfhifigan",
+                    "bigvgan",
+                    "apnet",
+                ]:
+                    for dr, dg in zip(fmap_r, fmap_g):
+                        for rl, gl in zip(dr, dg):
+                            loss += torch.mean(torch.abs(rl - gl))
+
+                    loss = loss * 2
+                elif self.cfg.model.generator in ["melgan"]:
+                    for dr, dg in zip(fmap_r, fmap_g):
+                        for rl, gl in zip(dr, dg):
+                            loss += self.l1Loss(rl, gl)
+
+                    loss = loss * 10
+                elif self.cfg.model.generator in ["codec"]:
+                    for dr, dg in zip(fmap_r, fmap_g):
+                        for rl, gl in zip(dr, dg):
+                            loss = loss + self.l1Loss(rl, gl) / torch.mean(
+                                torch.abs(rl)
+                            )
+
+                    KL_scale = len(fmap_r) * len(fmap_r[0])
+
+                    loss = 3 * loss / KL_scale
+                else:
+                    raise NotImplementedError
+
+                return loss
+
+        class discriminator_criterion(torch.nn.Module):
+            def __init__(self, cfg):
+                super(discriminator_criterion, self).__init__()
+                self.cfg = cfg
+                self.l1Loss = torch.nn.L1Loss(reduction="mean")
+                self.l2Loss = torch.nn.MSELoss(reduction="mean")
+                self.relu = torch.nn.ReLU()
+
+            def __call__(self, disc_real_outputs, disc_generated_outputs):
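+                # Discriminator objective, chosen per generator family below:
+                #   LSGAN (hifigan / nsfhifigan / bigvgan / apnet):
+                #       L_D = E[(1 - D(y))^2] + E[D(y_hat)^2]
+                #   hinge (melgan, and codec averaged over discriminators):
+                #       L_D = E[relu(1 - D(y))] + E[relu(1 + D(y_hat))]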
+                loss = 0
+                r_losses = []
+                g_losses = []
+
+                if self.cfg.model.generator in [
+                    "hifigan",
+                    "nsfhifigan",
+                    "bigvgan",
+                    "apnet",
+                ]:
+                    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+                        r_loss = torch.mean((1 - dr) ** 2)
+                        g_loss = torch.mean(dg**2)
+                        loss += r_loss + g_loss
+                        r_losses.append(r_loss.item())
+                        g_losses.append(g_loss.item())
+                elif self.cfg.model.generator in ["melgan"]:
+                    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+                        r_loss = torch.mean(self.relu(1 - dr))
+                        g_loss = torch.mean(self.relu(1 + dg))
+                        loss = loss + r_loss + g_loss
+                        r_losses.append(r_loss.item())
+                        g_losses.append(g_loss.item())
+                elif self.cfg.model.generator in ["codec"]:
+                    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+                        r_loss = torch.mean(self.relu(1 - dr))
+                        g_loss = torch.mean(self.relu(1 + dg))
+                        loss = loss + r_loss + g_loss
+                        r_losses.append(r_loss.item())
+                        g_losses.append(g_loss.item())
+
+                    loss = loss / len(disc_real_outputs)
+                else:
+                    raise NotImplementedError
+
+                return loss, r_losses, g_losses
+
+        class generator_criterion(torch.nn.Module):
+            def __init__(self, cfg):
+                super(generator_criterion, self).__init__()
+                self.cfg = cfg
+                self.l1Loss = torch.nn.L1Loss(reduction="mean")
+                self.l2Loss = torch.nn.MSELoss(reduction="mean")
+                self.relu = torch.nn.ReLU()
+
+            def __call__(self, disc_outputs):
+                loss = 0
+                gen_losses = []
+
+                if self.cfg.model.generator in [
+                    "hifigan",
+                    "nsfhifigan",
+                    "bigvgan",
+                    "apnet",
+                ]:
+                    for dg in disc_outputs:
+                        l = torch.mean((1 - dg) ** 2)
+                        gen_losses.append(l)
+                        loss += l
+                elif self.cfg.model.generator in ["melgan"]:
+                    for dg in disc_outputs:
+                        l = -torch.mean(dg)
+                        gen_losses.append(l)
+                        loss += l
+                elif self.cfg.model.generator in ["codec"]:
+                    for dg in disc_outputs:
+                        l = torch.mean(self.relu(1 - dg)) / len(disc_outputs)
+                        gen_losses.append(l)
+                        loss += l
+                else:
+                    raise NotImplementedError
+
+                return loss, gen_losses
+
+        class mel_criterion(torch.nn.Module):
+            def __init__(self, cfg):
+                super(mel_criterion, self).__init__()
+                self.cfg = cfg
+                self.l1Loss = torch.nn.L1Loss(reduction="mean")
+                self.l2Loss = torch.nn.MSELoss(reduction="mean")
+                self.relu = torch.nn.ReLU()
+
+            def __call__(self, y_gt, y_pred):
+                loss = 0
+
+                if self.cfg.model.generator in [
+                    "hifigan",
+                    "nsfhifigan",
+                    "bigvgan",
+                    "melgan",
+                    "codec",
+                    "apnet",
+                ]:
+                    y_gt_mel = extract_mel_features(y_gt, self.cfg.preprocess)
+                    y_pred_mel = extract_mel_features(
+                        y_pred.squeeze(1), self.cfg.preprocess
+                    )
+
+                    loss = self.l1Loss(y_gt_mel, y_pred_mel) * 45
+                else:
+                    raise NotImplementedError
+
+                return loss
+
+        class wav_criterion(torch.nn.Module):
+            def __init__(self, cfg):
+                super(wav_criterion, self).__init__()
+                self.cfg = cfg
+                self.l1Loss = torch.nn.L1Loss(reduction="mean")
+                self.l2Loss = torch.nn.MSELoss(reduction="mean")
+                self.relu = torch.nn.ReLU()
+
+            def __call__(self, y_gt, y_pred):
+                loss = 0
+
+                if self.cfg.model.generator in [
+                    "hifigan",
+                    "nsfhifigan",
+                    "bigvgan",
+                    "apnet",
+                ]:
+                    loss = self.l2Loss(y_gt, y_pred.squeeze(1)) * 100
+                elif self.cfg.model.generator in ["melgan"]:
+                    loss = self.l1Loss(y_gt, y_pred.squeeze(1)) / 10
+                elif self.cfg.model.generator in ["codec"]:
+                    loss = self.l1Loss(y_gt, y_pred.squeeze(1)) + self.l2Loss(
+                        y_gt, y_pred.squeeze(1)
+                    )
+                    loss /= 10
+                else:
+                    raise NotImplementedError
+
+                return loss
+
+        class phase_criterion(torch.nn.Module):
+            def __init__(self, cfg):
+                super(phase_criterion, self).__init__()
+                self.cfg = cfg
+                self.l1Loss = torch.nn.L1Loss(reduction="mean")
+                self.l2Loss = torch.nn.MSELoss(reduction="mean")
+                self.relu = torch.nn.ReLU()
+
+            def __call__(self, phase_gt, phase_pred):
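+                # APNet-style anti-wrapping phase losses: instantaneous phase (IP),
+                # group delay (GD, finite difference along the frequency axis) and
+                # phase time difference (PTD, finite difference along the time axis),
+                # each measured through -cos(.) so 2*pi phase wraps are not penalized.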
+                n_fft = self.cfg.preprocess.n_fft
+                frames = phase_gt.size()[-1]
+
+                GD_matrix = (
+                    torch.triu(torch.ones(n_fft // 2 + 1, n_fft // 2 + 1), diagonal=1)
+                    - torch.triu(torch.ones(n_fft // 2 + 1, n_fft // 2 + 1), diagonal=2)
+                    - torch.eye(n_fft // 2 + 1)
+                )
+                GD_matrix = GD_matrix.to(phase_pred.device)
+
+                GD_r = torch.matmul(phase_gt.permute(0, 2, 1), GD_matrix)
+                GD_g = torch.matmul(phase_pred.permute(0, 2, 1), GD_matrix)
+
+                PTD_matrix = (
+                    torch.triu(torch.ones(frames, frames), diagonal=1)
+                    - torch.triu(torch.ones(frames, frames), diagonal=2)
+                    - torch.eye(frames)
+                )
+                PTD_matrix = PTD_matrix.to(phase_pred.device)
+
+                PTD_r = torch.matmul(phase_gt, PTD_matrix)
+                PTD_g = torch.matmul(phase_pred, PTD_matrix)
+
+                IP_loss = torch.mean(-torch.cos(phase_gt - phase_pred))
+                GD_loss = torch.mean(-torch.cos(GD_r - GD_g))
+                PTD_loss = torch.mean(-torch.cos(PTD_r - PTD_g))
+
+                return 100 * (IP_loss + GD_loss + PTD_loss)
+
+        class amplitude_criterion(torch.nn.Module):
+            def __init__(self, cfg):
+                super(amplitude_criterion, self).__init__()
+                self.cfg = cfg
+                self.l1Loss = torch.nn.L1Loss(reduction="mean")
+                self.l2Loss = torch.nn.MSELoss(reduction="mean")
+                self.relu = torch.nn.ReLU()
+
+            def __call__(self, log_amplitude_gt, log_amplitude_pred):
+                amplitude_loss = self.l2Loss(log_amplitude_gt, log_amplitude_pred)
+
+                return 45 * amplitude_loss
+
+        class consistency_criterion(torch.nn.Module):
+            def __init__(self, cfg):
+                super(consistency_criterion, self).__init__()
+                self.cfg = cfg
+                self.l1Loss = torch.nn.L1Loss(reduction="mean")
+                self.l2Loss = torch.nn.MSELoss(reduction="mean")
+                self.relu = torch.nn.ReLU()
+
+            def __call__(
+                self,
+                rea_gt,
+                rea_pred,
+                rea_pred_final,
+                imag_gt,
+                imag_pred,
+                imag_pred_final,
+            ):
+                C_loss = torch.mean(
+                    torch.mean(
+                        (rea_pred - rea_pred_final) ** 2
+                        + (imag_pred - imag_pred_final) ** 2,
+                        (1, 2),
+                    )
+                )
+
+                L_R = self.l1Loss(rea_gt, rea_pred)
+                L_I = self.l1Loss(imag_gt, imag_pred)
+
+                return 20 * (C_loss + 2.25 * (L_R + L_I))
+
+        criterions = dict()
+        for key in self.cfg.train.criterions:
+            if key == "feature":
+                criterions["feature"] = feature_criterion(self.cfg)
+            elif key == "discriminator":
+                criterions["discriminator"] = discriminator_criterion(self.cfg)
+            elif key == "generator":
+                criterions["generator"] = generator_criterion(self.cfg)
+            elif key == "mel":
+                criterions["mel"] = mel_criterion(self.cfg)
+            elif key == "wav":
+                criterions["wav"] = wav_criterion(self.cfg)
+            elif key == "phase":
+                criterions["phase"] = phase_criterion(self.cfg)
+            elif key == "amplitude":
+                criterions["amplitude"] = amplitude_criterion(self.cfg)
+            elif key == "consistency":
+                criterions["consistency"] = consistency_criterion(self.cfg)
+            else:
+                raise NotImplementedError
+
+        return criterions
+
+    def _build_model(self):
+        generator = supported_generators[self.cfg.model.generator](self.cfg)
+        discriminators = dict()
+        for key in self.cfg.model.discriminators:
+            discriminators[key] = supported_discriminators[key](self.cfg)
+
+        return generator, discriminators
+
+    def _build_optimizer(self):
+        optimizer_params_generator = [dict(params=self.generator.parameters())]
+        generator_optimizer = AdamW(
+            optimizer_params_generator,
+            lr=self.cfg.train.adamw.lr,
+            betas=(self.cfg.train.adamw.adam_b1, self.cfg.train.adamw.adam_b2),
+        )
+
+        optimizer_params_discriminator = []
+        for discriminator in self.discriminators.keys():
+            optimizer_params_discriminator.append(
+                dict(params=self.discriminators[discriminator].parameters())
+            )
+        discriminator_optimizer = AdamW(
+            optimizer_params_discriminator,
+            lr=self.cfg.train.adamw.lr,
+            betas=(self.cfg.train.adamw.adam_b1, self.cfg.train.adamw.adam_b2),
+        )
+
+        return generator_optimizer, discriminator_optimizer
+
+    def _build_scheduler(self):
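+        # last_epoch restores the decay index so a resumed run continues from the already-decayed learning rate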
+        discriminator_scheduler = ExponentialLR(
+            self.discriminator_optimizer,
+            gamma=self.cfg.train.exponential_lr.lr_decay,
+            last_epoch=self.epoch - 1,
+        )
+
+        generator_scheduler = ExponentialLR(
+            self.generator_optimizer,
+            gamma=self.cfg.train.exponential_lr.lr_decay,
+            last_epoch=self.epoch - 1,
+        )
+
+        return generator_scheduler, discriminator_scheduler
+
+    def train_loop(self):
+        """Training process"""
+        self.accelerator.wait_for_everyone()
+
+        # Dump config
+        if self.accelerator.is_main_process:
+            self._dump_cfg(self.config_save_path)
+        self.generator.train()
+        for key in self.discriminators.keys():
+            self.discriminators[key].train()
+        self.generator_optimizer.zero_grad()
+        self.discriminator_optimizer.zero_grad()
+
+        # Sync and start training
+        self.accelerator.wait_for_everyone()
+        while self.epoch < self.max_epoch:
+            self.logger.info("\n")
+            self.logger.info("-" * 32)
+            self.logger.info("Epoch {}: ".format(self.epoch))
+
+            # Train and Validate
+            train_total_loss, train_losses = self._train_epoch()
+            for key, loss in train_losses.items():
+                self.logger.info("  |- Train/{} Loss: {:.6f}".format(key, loss))
+                self.accelerator.log(
+                    {"Epoch/Train {} Loss".format(key): loss},
+                    step=self.epoch,
+                )
+            valid_total_loss, valid_losses = self._valid_epoch()
+            for key, loss in valid_losses.items():
+                self.logger.info("  |- Valid/{} Loss: {:.6f}".format(key, loss))
+                self.accelerator.log(
+                    {"Epoch/Valid {} Loss".format(key): loss},
+                    step=self.epoch,
+                )
+            self.accelerator.log(
+                {
+                    "Epoch/Train Total Loss": train_total_loss,
+                    "Epoch/Valid Total Loss": valid_total_loss,
+                },
+                step=self.epoch,
+            )
+
+            # Update scheduler
+            self.accelerator.wait_for_everyone()
+            self.generator_scheduler.step()
+            self.discriminator_scheduler.step()
+
+            # Check save checkpoint interval
+            run_eval = False
+            if self.accelerator.is_main_process:
+                save_checkpoint = False
+                for i, num in enumerate(self.save_checkpoint_stride):
+                    if self.epoch % num == 0:
+                        save_checkpoint = True
+                        run_eval |= self.run_eval[i]
+
+            # Save checkpoints
+            self.accelerator.wait_for_everyone()
+            if self.accelerator.is_main_process and save_checkpoint:
+                path = os.path.join(
+                    self.checkpoint_dir,
+                    "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+                        self.epoch, self.step, valid_total_loss
+                    ),
+                )
+                self.accelerator.save_state(path)
+                json.dump(
+                    self.checkpoints_path,
+                    open(os.path.join(path, "ckpts.json"), "w"),
+                    ensure_ascii=False,
+                    indent=4,
+                )
+
+            # Save eval audios
+            self.accelerator.wait_for_everyone()
+            if self.accelerator.is_main_process and run_eval:
+                for i in range(len(self.valid_dataloader.dataset.eval_audios)):
+                    if self.cfg.preprocess.use_frame_pitch:
+                        eval_audio = self._inference(
+                            self.valid_dataloader.dataset.eval_mels[i],
+                            eval_pitch=self.valid_dataloader.dataset.eval_pitchs[i],
+                            use_pitch=True,
+                        )
+                    else:
+                        eval_audio = self._inference(
+                            self.valid_dataloader.dataset.eval_mels[i]
+                        )
+                    path = os.path.join(
+                        self.checkpoint_dir,
+                        "epoch-{:04d}_step-{:07d}_loss-{:.6f}_eval_audio_{}.wav".format(
+                            self.epoch,
+                            self.step,
+                            valid_total_loss,
+                            self.valid_dataloader.dataset.eval_dataset_names[i],
+                        ),
+                    )
+                    path_gt = os.path.join(
+                        self.checkpoint_dir,
+                        "epoch-{:04d}_step-{:07d}_loss-{:.6f}_eval_audio_{}_gt.wav".format(
+                            self.epoch,
+                            self.step,
+                            valid_total_loss,
+                            self.valid_dataloader.dataset.eval_dataset_names[i],
+                        ),
+                    )
+                    save_audio(path, eval_audio, self.cfg.preprocess.sample_rate)
+                    save_audio(
+                        path_gt,
+                        self.valid_dataloader.dataset.eval_audios[i],
+                        self.cfg.preprocess.sample_rate,
+                    )
+
+            self.accelerator.wait_for_everyone()
+
+            self.epoch += 1
+
+        # Finish training
+        self.accelerator.wait_for_everyone()
+        path = os.path.join(
+            self.checkpoint_dir,
+            "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
+                self.epoch, self.step, valid_total_loss
+            ),
+        )
+        self.accelerator.save_state(path)
+
+    def _train_epoch(self):
+        """Training epoch. Should return average loss of a batch (sample) over
+        one epoch. See ``train_loop`` for usage.
+        """
+        self.generator.train()
+        for key, _ in self.discriminators.items():
+            self.discriminators[key].train()
+
+        epoch_losses: dict = {}
+        epoch_total_loss: int = 0
+
+        for batch in tqdm(
+            self.train_dataloader,
+            desc=f"Training Epoch {self.epoch}",
+            unit="batch",
+            colour="GREEN",
+            leave=False,
+            dynamic_ncols=True,
+            smoothing=0.04,
+            disable=not self.accelerator.is_main_process,
+        ):
+            # Get losses
+            total_loss, losses = self._train_step(batch)
+            self.batch_count += 1
+
+            # Log info
+            if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
+                self.accelerator.log(
+                    {
+                        "Step/Generator Learning Rate": self.generator_optimizer.param_groups[
+                            0
+                        ][
+                            "lr"
+                        ],
+                        "Step/Discriminator Learning Rate": self.discriminator_optimizer.param_groups[
+                            0
+                        ][
+                            "lr"
+                        ],
+                    },
+                    step=self.step,
+                )
+                for key, _ in losses.items():
+                    self.accelerator.log(
+                        {
+                            "Step/Train {} Loss".format(key): losses[key],
+                        },
+                        step=self.step,
+                    )
+
+                if not epoch_losses:
+                    epoch_losses = losses
+                else:
+                    for key, value in losses.items():
+                        epoch_losses[key] += value
+                epoch_total_loss += total_loss
+                self.step += 1
+
+        # Get and log total losses
+        self.accelerator.wait_for_everyone()
+        epoch_total_loss = (
+            epoch_total_loss
+            / len(self.train_dataloader)
+            * self.cfg.train.gradient_accumulation_step
+        )
+        for key in epoch_losses.keys():
+            epoch_losses[key] = (
+                epoch_losses[key]
+                / len(self.train_dataloader)
+                * self.cfg.train.gradient_accumulation_step
+            )
+        return epoch_total_loss, epoch_losses
+
+    def _train_step(self, data):
+        """Training forward step. Should return average loss of a sample over
+        one batch. Provoke ``_forward_step`` is recommended except for special case.
+        See ``_train_epoch`` for usage.
+        """
+        # Init losses
+        train_losses = {}
+        total_loss = 0
+
+        generator_losses = {}
+        generator_total_loss = 0
+        discriminator_losses = {}
+        discriminator_total_loss = 0
+
+        # Use input feature to get predictions
+        mel_input = data["mel"]
+        audio_gt = data["audio"]
+
+        if self.cfg.preprocess.extract_amplitude_phase:
+            logamp_gt = data["logamp"]
+            pha_gt = data["pha"]
+            rea_gt = data["rea"]
+            imag_gt = data["imag"]
+
+        if self.cfg.preprocess.use_frame_pitch:
+            pitch_input = data["frame_pitch"].float()
+            audio_pred = self.generator.forward(mel_input, pitch_input)
+        elif self.cfg.preprocess.extract_amplitude_phase:
+            (
+                logamp_pred,
+                pha_pred,
+                rea_pred,
+                imag_pred,
+                audio_pred,
+            ) = self.generator.forward(mel_input)
+            from utils.mel import amplitude_phase_spectrum
+
+            _, _, rea_pred_final, imag_pred_final = amplitude_phase_spectrum(
+                audio_pred.squeeze(1), self.cfg.preprocess
+            )
+        else:
+            audio_pred = self.generator.forward(mel_input)
+
+        # Calculate and BP Discriminator losses
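+        # audio_pred is detached below so the discriminator update does not backpropagate into the generator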
+        self.discriminator_optimizer.zero_grad()
+        for key, _ in self.discriminators.items():
+            y_r, y_g, _, _ = self.discriminators[key].forward(
+                audio_gt.unsqueeze(1), audio_pred.detach()
+            )
+            (
+                discriminator_losses["{}_discriminator".format(key)],
+                _,
+                _,
+            ) = self.criterions["discriminator"](y_r, y_g)
+            discriminator_total_loss += discriminator_losses[
+                "{}_discriminator".format(key)
+            ]
+
+        self.accelerator.backward(discriminator_total_loss)
+        self.discriminator_optimizer.step()
+
+        # Calculate and BP Generator losses
+        self.generator_optimizer.zero_grad()
+        for key, _ in self.discriminators.items():
+            y_r, y_g, f_r, f_g = self.discriminators[key].forward(
+                audio_gt.unsqueeze(1), audio_pred
+            )
+            generator_losses["{}_feature".format(key)] = self.criterions["feature"](
+                f_r, f_g
+            )
+            generator_losses["{}_generator".format(key)], _ = self.criterions[
+                "generator"
+            ](y_g)
+            generator_total_loss += generator_losses["{}_feature".format(key)]
+            generator_total_loss += generator_losses["{}_generator".format(key)]
+
+        if "mel" in self.criterions.keys():
+            generator_losses["mel"] = self.criterions["mel"](audio_gt, audio_pred)
+            generator_total_loss += generator_losses["mel"]
+
+        if "wav" in self.criterions.keys():
+            generator_losses["wav"] = self.criterions["wav"](audio_gt, audio_pred)
+            generator_total_loss += generator_losses["wav"]
+
+        if "amplitude" in self.criterions.keys():
+            generator_losses["amplitude"] = self.criterions["amplitude"](
+                logamp_gt, logamp_pred
+            )
+            generator_total_loss += generator_losses["amplitude"]
+
+        if "phase" in self.criterions.keys():
+            generator_losses["phase"] = self.criterions["phase"](pha_gt, pha_pred)
+            generator_total_loss += generator_losses["phase"]
+
+        if "consistency" in self.criterions.keys():
+            generator_losses["consistency"] = self.criterions["consistency"](
+                rea_gt, rea_pred, rea_pred_final, imag_gt, imag_pred, imag_pred_final
+            )
+            generator_total_loss += generator_losses["consistency"]
+
+        self.accelerator.backward(generator_total_loss)
+        self.generator_optimizer.step()
+
+        # Get the total losses
+        total_loss = discriminator_total_loss + generator_total_loss
+        train_losses.update(discriminator_losses)
+        train_losses.update(generator_losses)
+
+        for key, _ in train_losses.items():
+            train_losses[key] = train_losses[key].item()
+
+        return total_loss.item(), train_losses
+
+    def _valid_epoch(self):
+        """Testing epoch. Should return average loss of a batch (sample) over
+        one epoch. See ``train_loop`` for usage.
+        """
+        self.generator.eval()
+        for key, _ in self.discriminators.items():
+            self.discriminators[key].eval()
+
+        epoch_losses: dict = {}
+        epoch_total_loss: int = 0
+
+        for batch in tqdm(
+            self.valid_dataloader,
+            desc=f"Validating Epoch {self.epoch}",
+            unit="batch",
+            colour="GREEN",
+            leave=False,
+            dynamic_ncols=True,
+            smoothing=0.04,
+            disable=not self.accelerator.is_main_process,
+        ):
+            # Get losses
+            total_loss, losses = self._valid_step(batch)
+
+            # Log info
+            for key, _ in losses.items():
+                self.accelerator.log(
+                    {
+                        "Step/Valid {} Loss".format(key): losses[key],
+                    },
+                    step=self.step,
+                )
+
+            if not epoch_losses:
+                epoch_losses = losses
+            else:
+                for key, value in losses.items():
+                    epoch_losses[key] += value
+            epoch_total_loss += total_loss
+
+        # Get and log total losses
+        self.accelerator.wait_for_everyone()
+        epoch_total_loss = epoch_total_loss / len(self.valid_dataloader)
+        for key in epoch_losses.keys():
+            epoch_losses[key] = epoch_losses[key] / len(self.valid_dataloader)
+        return epoch_total_loss, epoch_losses
+
+    def _valid_step(self, data):
+        """Testing forward step. Should return average loss of a sample over
+        one batch. Provoke ``_forward_step`` is recommended except for special case.
+        See ``_test_epoch`` for usage.
+        """
+        # Init losses
+        valid_losses = {}
+        total_loss = 0
+
+        generator_losses = {}
+        generator_total_loss = 0
+        discriminator_losses = {}
+        discriminator_total_loss = 0
+
+        # Use feature inputs to get the predicted audio
+        mel_input = data["mel"]
+        audio_gt = data["audio"]
+
+        if self.cfg.preprocess.extract_amplitude_phase:
+            logamp_gt = data["logamp"]
+            pha_gt = data["pha"]
+            rea_gt = data["rea"]
+            imag_gt = data["imag"]
+
+        if self.cfg.preprocess.use_frame_pitch:
+            pitch_input = data["frame_pitch"].float()
+            audio_pred = self.generator.forward(mel_input, pitch_input)
+        elif self.cfg.preprocess.extract_amplitude_phase:
+            (
+                logamp_pred,
+                pha_pred,
+                rea_pred,
+                imag_pred,
+                audio_pred,
+            ) = self.generator.forward(mel_input)
+            from utils.mel import amplitude_phase_spectrum
+
+            _, _, rea_pred_final, imag_pred_final = amplitude_phase_spectrum(
+                audio_pred.squeeze(1), self.cfg.preprocess
+            )
+        else:
+            audio_pred = self.generator.forward(mel_input)
+
+        # Get Discriminator losses
+        for key, _ in self.discriminators.items():
+            y_r, y_g, _, _ = self.discriminators[key].forward(
+                audio_gt.unsqueeze(1), audio_pred
+            )
+            (
+                discriminator_losses["{}_discriminator".format(key)],
+                _,
+                _,
+            ) = self.criterions["discriminator"](y_r, y_g)
+            discriminator_total_loss += discriminator_losses[
+                "{}_discriminator".format(key)
+            ]
+
+        for key, _ in self.discriminators.items():
+            y_r, y_g, f_r, f_g = self.discriminators[key].forward(
+                audio_gt.unsqueeze(1), audio_pred
+            )
+            generator_losses["{}_feature".format(key)] = self.criterions["feature"](
+                f_r, f_g
+            )
+            generator_losses["{}_generator".format(key)], _ = self.criterions[
+                "generator"
+            ](y_g)
+            generator_total_loss += generator_losses["{}_feature".format(key)]
+            generator_total_loss += generator_losses["{}_generator".format(key)]
+
+        if "mel" in self.criterions.keys():
+            generator_losses["mel"] = self.criterions["mel"](audio_gt, audio_pred)
+            generator_total_loss += generator_losses["mel"]
+        if "mel" in self.criterions.keys():
+            generator_losses["mel"] = self.criterions["mel"](audio_gt, audio_pred)
+            generator_total_loss += generator_losses["mel"]
+
+        if "wav" in self.criterions.keys():
+            generator_losses["wav"] = self.criterions["wav"](audio_gt, audio_pred)
+            generator_total_loss += generator_losses["wav"]
+        if "wav" in self.criterions.keys():
+            generator_losses["wav"] = self.criterions["wav"](audio_gt, audio_pred)
+            generator_total_loss += generator_losses["wav"]
+
+        if "amplitude" in self.criterions.keys():
+            generator_losses["amplitude"] = self.criterions["amplitude"](
+                logamp_gt, logamp_pred
+            )
+            generator_total_loss += generator_losses["amplitude"]
+
+        if "phase" in self.criterions.keys():
+            generator_losses["phase"] = self.criterions["phase"](pha_gt, pha_pred)
+            generator_total_loss += generator_losses["phase"]
+
+        if "consistency" in self.criterions.keys():
+            generator_losses["consistency"] = self.criterions["consistency"](
+                rea_gt,
+                rea_pred,
+                rea_pred_final,
+                imag_gt,
+                imag_pred,
+                imag_pred_final,
+            )
+            generator_total_loss += generator_losses["consistency"]
+
+        total_loss = discriminator_total_loss + generator_total_loss
+        valid_losses.update(discriminator_losses)
+        valid_losses.update(generator_losses)
+
+        for item in valid_losses:
+            valid_losses[item] = valid_losses[item].item()
+
+        return total_loss.item(), valid_losses
+
+    def _inference(self, eval_mel, eval_pitch=None, use_pitch=False):
+        """Inference during training for test audios."""
+        if use_pitch:
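+            # Align the frame-level F0 sequence to the number of mel frames before vocoding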
+            eval_pitch = align_length(eval_pitch, eval_mel.shape[1])
+            eval_audio = vocoder_inference(
+                self.cfg,
+                self.generator,
+                torch.from_numpy(eval_mel).unsqueeze(0),
+                f0s=torch.from_numpy(eval_pitch).unsqueeze(0).float(),
+                device=next(self.generator.parameters()).device,
+            ).squeeze(0)
+        else:
+            eval_audio = vocoder_inference(
+                self.cfg,
+                self.generator,
+                torch.from_numpy(eval_mel).unsqueeze(0),
+                device=next(self.generator.parameters()).device,
+            ).squeeze(0)
+        return eval_audio
+
+    def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"):
+        """Load model from checkpoint. If checkpoint_path is None, it will
+        load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
+        None, it will load the checkpoint specified by checkpoint_path. **Only use this
+        method after** ``accelerator.prepare()``.
+        """
+        if checkpoint_path is None:
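+            # Pick the latest checkpoint by parsing the epoch number out of the
+            # "epoch-xxxx_step-xxxxxxx_loss-xxxxxx" directory name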
+            ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
+            ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
+            checkpoint_path = ls[0]
+        if resume_type == "resume":
+            self.accelerator.load_state(checkpoint_path)
+            self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
+            self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
+        elif resume_type == "finetune":
+            accelerate.load_checkpoint_and_dispatch(
+                self.accelerator.unwrap_model(self.generator),
+                os.path.join(checkpoint_path, "pytorch_model.bin"),
+            )
+            for key, _ in self.discriminators.items():
+                accelerate.load_checkpoint_and_dispatch(
+                    self.accelerator.unwrap_model(self.discriminators[key]),
+                    os.path.join(checkpoint_path, "pytorch_model.bin"),
+                )
+            self.logger.info("Load model weights for finetune SUCCESS!")
+        else:
+            raise ValueError("Unsupported resume type: {}".format(resume_type))
+        return checkpoint_path
+
+    def _count_parameters(self):
+        result = sum(p.numel() for p in self.generator.parameters())
+        for _, discriminator in self.discriminators.items():
+            result += sum(p.numel() for p in discriminator.parameters())
+        return result
diff --git a/models/vocoders/gan/generator/apnet.py b/models/vocoders/gan/generator/apnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f39dd6b01f6be5a6bdd2e04ca822a4a9c9b4c9b4
--- /dev/null
+++ b/models/vocoders/gan/generator/apnet.py
@@ -0,0 +1,395 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from modules.vocoder_blocks import *
+
+LRELU_SLOPE = 0.1
+
+
+class ISTFT(nn.Module):
+    """
+    Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
+    windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
+    See issue: https://github.com/pytorch/pytorch/issues/62323
+    Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
+    The NOLA constraint is met as we trim padded samples anyway.
+
+    Args:
+        n_fft (int): Size of Fourier transform.
+        hop_length (int): The distance between neighboring sliding window frames.
+        win_length (int): The size of window frame and STFT filter.
+        padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
+    """
+
+    def __init__(
+        self,
+        n_fft: int,
+        hop_length: int,
+        win_length: int,
+        padding: str = "same",
+    ):
+        super().__init__()
+        if padding not in ["center", "same"]:
+            raise ValueError("Padding must be 'center' or 'same'.")
+        self.padding = padding
+        self.n_fft = n_fft
+        self.hop_length = hop_length
+        self.win_length = win_length
+
+    def forward(self, spec: torch.Tensor, window) -> torch.Tensor:
+        """
+        Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
+
+        Args:
+            spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
+                            N is the number of frequency bins, and T is the number of time frames.
+            window (Tensor): Synthesis window applied to each reconstructed frame.
+
+        Returns:
+            Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
+        """
+        if self.padding == "center":
+            # Fallback to pytorch native implementation
+            return torch.istft(
+                spec,
+                self.n_fft,
+                self.hop_length,
+                self.win_length,
+                window,
+                center=True,
+            )
+        elif self.padding == "same":
+            pad = (self.win_length - self.hop_length) // 2
+        else:
+            raise ValueError("Padding must be 'center' or 'same'.")
+
+        assert spec.dim() == 3, "Expected a 3D tensor as input"
+        B, N, T = spec.shape
+
+        # Inverse FFT
+        ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
+        ifft = ifft * window[None, :, None]
+
+        # Overlap and Add
+        output_size = (T - 1) * self.hop_length + self.win_length
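+        # fold() sums the windowed frames at hop_length intervals (overlap-add); the "same" padding is trimmed below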
+        y = torch.nn.functional.fold(
+            ifft,
+            output_size=(1, output_size),
+            kernel_size=(1, self.win_length),
+            stride=(1, self.hop_length),
+        )[:, 0, 0, pad:-pad]
+
+        # Window envelope
+        window_sq = window.square().expand(1, T, -1).transpose(1, 2)
+        window_envelope = torch.nn.functional.fold(
+            window_sq,
+            output_size=(1, output_size),
+            kernel_size=(1, self.win_length),
+            stride=(1, self.hop_length),
+        ).squeeze()[pad:-pad]
+
+        # Normalize
+        assert (window_envelope > 1e-11).all()
+        y = y / window_envelope
+
+        return y
+
+
+# The ASP and PSP modules are adopted from APNet under the MIT License
+# https://github.com/YangAi520/APNet/blob/main/models.py
+
+class ASPResBlock(torch.nn.Module):
+    def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super(ASPResBlock, self).__init__()
+        self.cfg = cfg
+        self.convs1 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[2],
+                        padding=get_padding(kernel_size, dilation[2]),
+                    )
+                ),
+            ]
+        )
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+            ]
+        )
+        self.convs2.apply(init_weights)
+
+    def forward(self, x):
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            xt = c1(xt)
+            xt = F.leaky_relu(xt, LRELU_SLOPE)
+            xt = c2(xt)
+            x = xt + x
+        return x
+
+
+class PSPResBlock(torch.nn.Module):
+    def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super(PSPResBlock, self).__init__()
+        self.cfg = cfg
+        self.convs1 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[2],
+                        padding=get_padding(kernel_size, dilation[2]),
+                    )
+                ),
+            ]
+        )
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+            ]
+        )
+        self.convs2.apply(init_weights)
+
+    def forward(self, x):
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            xt = c1(xt)
+            xt = F.leaky_relu(xt, LRELU_SLOPE)
+            xt = c2(xt)
+            x = xt + x
+        return x
+
+
+class APNet(torch.nn.Module):
+    def __init__(self, cfg):
+        super(APNet, self).__init__()
+        self.cfg = cfg
+        self.ASP_num_kernels = len(cfg.model.apnet.ASP_resblock_kernel_sizes)
+        self.PSP_num_kernels = len(cfg.model.apnet.PSP_resblock_kernel_sizes)
+
+        self.ASP_input_conv = weight_norm(
+            Conv1d(
+                cfg.preprocess.n_mel,
+                cfg.model.apnet.ASP_channel,
+                cfg.model.apnet.ASP_input_conv_kernel_size,
+                1,
+                padding=get_padding(cfg.model.apnet.ASP_input_conv_kernel_size, 1),
+            )
+        )
+        self.PSP_input_conv = weight_norm(
+            Conv1d(
+                cfg.preprocess.n_mel,
+                cfg.model.apnet.PSP_channel,
+                cfg.model.apnet.PSP_input_conv_kernel_size,
+                1,
+                padding=get_padding(cfg.model.apnet.PSP_input_conv_kernel_size, 1),
+            )
+        )
+
+        self.ASP_ResNet = nn.ModuleList()
+        for j, (k, d) in enumerate(
+            zip(
+                cfg.model.apnet.ASP_resblock_kernel_sizes,
+                cfg.model.apnet.ASP_resblock_dilation_sizes,
+            )
+        ):
+            self.ASP_ResNet.append(ASPResBlock(cfg, cfg.model.apnet.ASP_channel, k, d))
+
+        self.PSP_ResNet = nn.ModuleList()
+        for j, (k, d) in enumerate(
+            zip(
+                cfg.model.apnet.PSP_resblock_kernel_sizes,
+                cfg.model.apnet.PSP_resblock_dilation_sizes,
+            )
+        ):
+            self.PSP_ResNet.append(PSPResBlock(cfg, cfg.model.apnet.PSP_channel, k, d))
+
+        self.ASP_output_conv = weight_norm(
+            Conv1d(
+                cfg.model.apnet.ASP_channel,
+                cfg.preprocess.n_fft // 2 + 1,
+                cfg.model.apnet.ASP_output_conv_kernel_size,
+                1,
+                padding=get_padding(cfg.model.apnet.ASP_output_conv_kernel_size, 1),
+            )
+        )
+        self.PSP_output_R_conv = weight_norm(
+            Conv1d(
+                cfg.model.apnet.PSP_channel,
+                cfg.preprocess.n_fft // 2 + 1,
+                cfg.model.apnet.PSP_output_R_conv_kernel_size,
+                1,
+                padding=get_padding(cfg.model.apnet.PSP_output_R_conv_kernel_size, 1),
+            )
+        )
+        self.PSP_output_I_conv = weight_norm(
+            Conv1d(
+                cfg.model.apnet.PSP_channel,
+                cfg.preprocess.n_fft // 2 + 1,
+                cfg.model.apnet.PSP_output_I_conv_kernel_size,
+                1,
+                padding=get_padding(cfg.model.apnet.PSP_output_I_conv_kernel_size, 1),
+            )
+        )
+
+        self.iSTFT = ISTFT(
+            self.cfg.preprocess.n_fft,
+            hop_length=self.cfg.preprocess.hop_size,
+            win_length=self.cfg.preprocess.win_size,
+        )
+
+        self.ASP_output_conv.apply(init_weights)
+        self.PSP_output_R_conv.apply(init_weights)
+        self.PSP_output_I_conv.apply(init_weights)
+
+    def forward(self, mel):
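+        # ASP branch: predict the log-amplitude spectrum from the mel spectrogram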
+        logamp = self.ASP_input_conv(mel)
+        logamps = None
+        for j in range(self.ASP_num_kernels):
+            if logamps is None:
+                logamps = self.ASP_ResNet[j](logamp)
+            else:
+                logamps += self.ASP_ResNet[j](logamp)
+        logamp = logamps / self.ASP_num_kernels
+        logamp = F.leaky_relu(logamp)
+        logamp = self.ASP_output_conv(logamp)
+
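+        # PSP branch: predict the phase spectrum via parallel real/imag projections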
+        pha = self.PSP_input_conv(mel)
+        phas = None
+        for j in range(self.PSP_num_kernels):
+            if phas is None:
+                phas = self.PSP_ResNet[j](pha)
+            else:
+                phas += self.PSP_ResNet[j](pha)
+        pha = phas / self.PSP_num_kernels
+        pha = F.leaky_relu(pha)
+        R = self.PSP_output_R_conv(pha)
+        I = self.PSP_output_I_conv(pha)
+
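+        # atan2 of the two projections yields a properly wrapped phase in (-pi, pi]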
+        pha = torch.atan2(I, R)
+
+        rea = torch.exp(logamp) * torch.cos(pha)
+        imag = torch.exp(logamp) * torch.sin(pha)
+
+        spec = torch.cat((rea.unsqueeze(-1), imag.unsqueeze(-1)), -1)
+
+        spec = torch.view_as_complex(spec)
+
+        audio = self.iSTFT.forward(
+            spec, torch.hann_window(self.cfg.preprocess.win_size).to(mel.device)
+        )
+
+        return logamp, pha, rea, imag, audio.unsqueeze(1)
diff --git a/models/vocoders/gan/generator/bigvgan.py b/models/vocoders/gan/generator/bigvgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..205f8d697b9a49aa4ac3f4c46015fef861d4fe1b
--- /dev/null
+++ b/models/vocoders/gan/generator/bigvgan.py
@@ -0,0 +1,344 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+import torch.nn as nn
+
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn.utils import weight_norm, remove_weight_norm
+
+from modules.vocoder_blocks import *
+from modules.activation_functions import *
+from modules.anti_aliasing import *
+
+LRELU_SLOPE = 0.1
+
+# The AMPBlock Module is adopted from BigVGAN under the MIT License
+# https://github.com/NVIDIA/BigVGAN
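+# Each AMP block pairs its convolutions with periodic Snake/SnakeBeta activations wrapped in Activation1d for anti-aliasing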
+
+class AMPBlock1(torch.nn.Module):
+    def __init__(
+        self, cfg, channels, kernel_size=3, dilation=(1, 3, 5), activation=None
+    ):
+        super(AMPBlock1, self).__init__()
+        self.cfg = cfg
+
+        self.convs1 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[2],
+                        padding=get_padding(kernel_size, dilation[2]),
+                    )
+                ),
+            ]
+        )
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+            ]
+        )
+        self.convs2.apply(init_weights)
+
+        self.num_layers = len(self.convs1) + len(
+            self.convs2
+        )  # total number of conv layers
+
+        if (
+            activation == "snake"
+        ):  # periodic nonlinearity with snake function and anti-aliasing
+            self.activations = nn.ModuleList(
+                [
+                    Activation1d(
+                        activation=Snake(
+                            channels, alpha_logscale=cfg.model.bigvgan.snake_logscale
+                        )
+                    )
+                    for _ in range(self.num_layers)
+                ]
+            )
+        elif (
+            activation == "snakebeta"
+        ):  # periodic nonlinearity with snakebeta function and anti-aliasing
+            self.activations = nn.ModuleList(
+                [
+                    Activation1d(
+                        activation=SnakeBeta(
+                            channels, alpha_logscale=cfg.model.bigvgan.snake_logscale
+                        )
+                    )
+                    for _ in range(self.num_layers)
+                ]
+            )
+        else:
+            raise NotImplementedError(
+                "activation incorrectly specified. check the config file and look for 'activation'."
+            )
+
+    def forward(self, x):
+        acts1, acts2 = self.activations[::2], self.activations[1::2]
+        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
+            xt = a1(x)
+            xt = c1(xt)
+            xt = a2(xt)
+            xt = c2(xt)
+            x = xt + x
+
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_weight_norm(l)
+        for l in self.convs2:
+            remove_weight_norm(l)
+
+
+class AMPBlock2(torch.nn.Module):
+    def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3), activation=None):
+        super(AMPBlock2, self).__init__()
+        self.cfg = cfg
+
+        self.convs = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+            ]
+        )
+        self.convs.apply(init_weights)
+
+        self.num_layers = len(self.convs)  # total number of conv layers
+
+        if (
+            activation == "snake"
+        ):  # periodic nonlinearity with snake function and anti-aliasing
+            self.activations = nn.ModuleList(
+                [
+                    Activation1d(
+                        activation=Snake(
+                            channels, alpha_logscale=cfg.model.bigvgan.snake_logscale
+                        )
+                    )
+                    for _ in range(self.num_layers)
+                ]
+            )
+        elif (
+            activation == "snakebeta"
+        ):  # periodic nonlinearity with snakebeta function and anti-aliasing
+            self.activations = nn.ModuleList(
+                [
+                    Activation1d(
+                        activation=SnakeBeta(
+                            channels, alpha_logscale=cfg.model.bigvgan.snake_logscale
+                        )
+                    )
+                    for _ in range(self.num_layers)
+                ]
+            )
+        else:
+            raise NotImplementedError(
+                "activation incorrectly specified. check the config file and look for 'activation'."
+            )
+
+    def forward(self, x):
+        for c, a in zip(self.convs, self.activations):
+            xt = a(x)
+            xt = c(xt)
+            x = xt + x
+
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs:
+            remove_weight_norm(l)
+
+
+class BigVGAN(torch.nn.Module):
+    def __init__(self, cfg):
+        super(BigVGAN, self).__init__()
+        self.cfg = cfg
+
+        self.num_kernels = len(cfg.model.bigvgan.resblock_kernel_sizes)
+        self.num_upsamples = len(cfg.model.bigvgan.upsample_rates)
+
+        # Conv pre to boost channels
+        self.conv_pre = weight_norm(
+            Conv1d(
+                cfg.preprocess.n_mel,
+                cfg.model.bigvgan.upsample_initial_channel,
+                7,
+                1,
+                padding=3,
+            )
+        )
+
+        resblock = AMPBlock1 if cfg.model.bigvgan.resblock == "1" else AMPBlock2
+
+        # Upsamplers
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(
+            zip(
+                cfg.model.bigvgan.upsample_rates,
+                cfg.model.bigvgan.upsample_kernel_sizes,
+            )
+        ):
+            self.ups.append(
+                nn.ModuleList(
+                    [
+                        weight_norm(
+                            ConvTranspose1d(
+                                cfg.model.bigvgan.upsample_initial_channel // (2**i),
+                                cfg.model.bigvgan.upsample_initial_channel
+                                // (2 ** (i + 1)),
+                                k,
+                                u,
+                                padding=(k - u) // 2,
+                            )
+                        )
+                    ]
+                )
+            )
+
+        # Res Blocks with AMP and Anti-aliasing
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = cfg.model.bigvgan.upsample_initial_channel // (2 ** (i + 1))
+            for j, (k, d) in enumerate(
+                zip(
+                    cfg.model.bigvgan.resblock_kernel_sizes,
+                    cfg.model.bigvgan.resblock_dilation_sizes,
+                )
+            ):
+                self.resblocks.append(
+                    resblock(cfg, ch, k, d, activation=cfg.model.bigvgan.activation)
+                )
+
+        # Conv post for result
+        if cfg.model.bigvgan.activation == "snake":
+            activation_post = Snake(ch, alpha_logscale=cfg.model.bigvgan.snake_logscale)
+            self.activation_post = Activation1d(activation=activation_post)
+        elif cfg.model.bigvgan.activation == "snakebeta":
+            activation_post = SnakeBeta(
+                ch, alpha_logscale=cfg.model.bigvgan.snake_logscale
+            )
+            self.activation_post = Activation1d(activation=activation_post)
+        else:
+            raise NotImplementedError(
+                "activation incorrectly specified. check the config file and look for 'activation'."
+            )
+
+        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+
+        # Weight Norm
+        for i in range(len(self.ups)):
+            self.ups[i].apply(init_weights)
+        self.conv_post.apply(init_weights)
+
+    def forward(self, x):
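+        # mel (B, n_mel, T) -> waveform: conv_pre, then alternating upsampling stages,
+        # each followed by the averaged multi-receptive-field AMP blocks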
+        x = self.conv_pre(x)
+
+        for i in range(self.num_upsamples):
+            for i_up in range(len(self.ups[i])):
+                x = self.ups[i][i_up](x)
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+
+        x = self.activation_post(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
+
+    def remove_weight_norm(self):
+        print("Removing weight norm...")
+        for l in self.ups:
+            for l_i in l:
+                remove_weight_norm(l_i)
+        for l in self.resblocks:
+            l.remove_weight_norm()
+        remove_weight_norm(self.conv_pre)
+        remove_weight_norm(self.conv_post)
diff --git a/models/vocoders/gan/generator/hifigan.py b/models/vocoders/gan/generator/hifigan.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f5f32498f5eb6441db787b0ae204a1eeff36aa3
--- /dev/null
+++ b/models/vocoders/gan/generator/hifigan.py
@@ -0,0 +1,449 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn.utils import weight_norm, remove_weight_norm
+from modules.vocoder_blocks import *
+
+
+LRELU_SLOPE = 0.1
+
+
+class ResBlock1(torch.nn.Module):
+    def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super(ResBlock1, self).__init__()
+        self.cfg = cfg
+        self.convs1 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[2],
+                        padding=get_padding(kernel_size, dilation[2]),
+                    )
+                ),
+            ]
+        )
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+            ]
+        )
+        self.convs2.apply(init_weights)
+
+    def forward(self, x):
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            xt = c1(xt)
+            xt = F.leaky_relu(xt, LRELU_SLOPE)
+            xt = c2(xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_weight_norm(l)
+        for l in self.convs2:
+            remove_weight_norm(l)
+
+
+class ResBlock2(torch.nn.Module):
+    def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3)):
+        super(ResBlock2, self).__init__()
+        self.cfg = cfg
+        self.convs = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+            ]
+        )
+        self.convs.apply(init_weights)
+
+    def forward(self, x):
+        for c in self.convs:
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            xt = c(xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs:
+            remove_weight_norm(l)
+
+
+class HiFiGAN(torch.nn.Module):
+    def __init__(self, cfg):
+        super(HiFiGAN, self).__init__()
+        self.cfg = cfg
+        self.num_kernels = len(self.cfg.model.hifigan.resblock_kernel_sizes)
+        self.num_upsamples = len(self.cfg.model.hifigan.upsample_rates)
+        self.conv_pre = weight_norm(
+            Conv1d(
+                cfg.preprocess.n_mel,
+                self.cfg.model.hifigan.upsample_initial_channel,
+                7,
+                1,
+                padding=3,
+            )
+        )
+        resblock = ResBlock1 if self.cfg.model.hifigan.resblock == "1" else ResBlock2
+
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(
+            zip(
+                self.cfg.model.hifigan.upsample_rates,
+                self.cfg.model.hifigan.upsample_kernel_sizes,
+            )
+        ):
+            self.ups.append(
+                weight_norm(
+                    ConvTranspose1d(
+                        self.cfg.model.hifigan.upsample_initial_channel // (2**i),
+                        self.cfg.model.hifigan.upsample_initial_channel
+                        // (2 ** (i + 1)),
+                        k,
+                        u,
+                        padding=(k - u) // 2,
+                    )
+                )
+            )
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = self.cfg.model.hifigan.upsample_initial_channel // (2 ** (i + 1))
+            for j, (k, d) in enumerate(
+                zip(
+                    self.cfg.model.hifigan.resblock_kernel_sizes,
+                    self.cfg.model.hifigan.resblock_dilation_sizes,
+                )
+            ):
+                self.resblocks.append(resblock(self.cfg, ch, k, d))
+
+        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+        self.ups.apply(init_weights)
+        self.conv_post.apply(init_weights)
+
+    def forward(self, x):
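+        # Each stage: LeakyReLU -> transposed-conv upsampling -> average over the multi-receptive-field ResBlocks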
+        x = self.conv_pre(x)
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            x = self.ups[i](x)
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+        x = F.leaky_relu(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
+
+    def remove_weight_norm(self):
+        print("Removing weight norm...")
+        for l in self.ups:
+            remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()
+        remove_weight_norm(self.conv_pre)
+        remove_weight_norm(self.conv_post)
+
+
+# todo: merge with ResBlock1 (lmxue, yicheng)
+class ResBlock1_vits(torch.nn.Module):
+    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super(ResBlock1_vits, self).__init__()
+        self.convs1 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[2],
+                        padding=get_padding(kernel_size, dilation[2]),
+                    )
+                ),
+            ]
+        )
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+            ]
+        )
+        self.convs2.apply(init_weights)
+
+    def forward(self, x, x_mask=None):
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            if x_mask is not None:
+                xt = xt * x_mask
+            xt = c1(xt)
+            xt = F.leaky_relu(xt, LRELU_SLOPE)
+            if x_mask is not None:
+                xt = xt * x_mask
+            xt = c2(xt)
+            x = xt + x
+        if x_mask is not None:
+            x = x * x_mask
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_weight_norm(l)
+        for l in self.convs2:
+            remove_weight_norm(l)
+
+
+# todo: merge with ResBlock2 (lmxue, yicheng)
+class ResBlock2_vits(torch.nn.Module):
+    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
+        super(ResBlock2_vits, self).__init__()
+        self.convs = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+            ]
+        )
+        self.convs.apply(init_weights)
+
+    def forward(self, x, x_mask=None):
+        for c in self.convs:
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            if x_mask is not None:
+                xt = xt * x_mask
+            xt = c(xt)
+            x = xt + x
+        if x_mask is not None:
+            x = x * x_mask
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs:
+            remove_weight_norm(l)
+
+
+# todo: merge with HiFiGAN (lmxue, yicheng)
+class HiFiGAN_vits(torch.nn.Module):
+    def __init__(
+        self,
+        initial_channel,
+        resblock,
+        resblock_kernel_sizes,
+        resblock_dilation_sizes,
+        upsample_rates,
+        upsample_initial_channel,
+        upsample_kernel_sizes,
+        gin_channels=0,
+    ):
+        super(HiFiGAN_vits, self).__init__()
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_rates)
+        self.conv_pre = Conv1d(
+            initial_channel, upsample_initial_channel, 7, 1, padding=3
+        )
+        resblock = ResBlock1_vits if resblock == "1" else ResBlock2_vits
+
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            self.ups.append(
+                weight_norm(
+                    ConvTranspose1d(
+                        upsample_initial_channel // (2**i),
+                        upsample_initial_channel // (2 ** (i + 1)),
+                        k,
+                        u,
+                        padding=(k - u) // 2,
+                    )
+                )
+            )
+
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = upsample_initial_channel // (2 ** (i + 1))
+            for j, (k, d) in enumerate(
+                zip(resblock_kernel_sizes, resblock_dilation_sizes)
+            ):
+                self.resblocks.append(resblock(ch, k, d))
+
+        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
+        self.ups.apply(init_weights)
+
+        if gin_channels != 0:
+            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
+
+    def forward(self, x, g=None):
+        x = self.conv_pre(x)
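+        # Optional global conditioning (e.g. a speaker embedding of width
+        # gin_channels) is projected to the conv_pre output width and added.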
+        if g is not None:
+            x = x + self.cond(g)
+
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            x = self.ups[i](x)
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+        x = F.leaky_relu(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.ups:
+            remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()
diff --git a/models/vocoders/gan/generator/melgan.py b/models/vocoders/gan/generator/melgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca90d684ca1f5a0a813db4192540adac0cee2558
--- /dev/null
+++ b/models/vocoders/gan/generator/melgan.py
@@ -0,0 +1,103 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from torch.nn.utils import weight_norm
+
+# This code is adapted from MelGAN under the MIT License
+# https://github.com/descriptinc/melgan-neurips
+
+def weights_init(m):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(0.0, 0.02)
+    elif classname.find("BatchNorm2d") != -1:
+        m.weight.data.normal_(1.0, 0.02)
+        m.bias.data.fill_(0)
+
+
+def WNConv1d(*args, **kwargs):
+    return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+class ResnetBlock(nn.Module):
+    def __init__(self, dim, dilation=1):
+        super().__init__()
+        self.block = nn.Sequential(
+            nn.LeakyReLU(0.2),
+            nn.ReflectionPad1d(dilation),
+            WNConv1d(dim, dim, kernel_size=3, dilation=dilation),
+            nn.LeakyReLU(0.2),
+            WNConv1d(dim, dim, kernel_size=1),
+        )
+        self.shortcut = WNConv1d(dim, dim, kernel_size=1)
+
+    def forward(self, x):
+        return self.shortcut(x) + self.block(x)
+
+
+class MelGAN(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+
+        self.cfg = cfg
+
+        self.hop_length = np.prod(self.cfg.model.melgan.ratios)
+        mult = int(2 ** len(self.cfg.model.melgan.ratios))
+
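+        # The stack upsamples mel frames to waveform samples: the total
+        # upsampling factor (hop size) is the product of `ratios`, and the
+        # channel width starts at mult * ngf and halves after every stage.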
+        model = [
+            nn.ReflectionPad1d(3),
+            WNConv1d(
+                self.cfg.preprocess.n_mel,
+                mult * self.cfg.model.melgan.ngf,
+                kernel_size=7,
+                padding=0,
+            ),
+        ]
+
+        # Upsample to raw audio scale
+        for i, r in enumerate(self.cfg.model.melgan.ratios):
+            model += [
+                nn.LeakyReLU(0.2),
+                WNConvTranspose1d(
+                    mult * self.cfg.model.melgan.ngf,
+                    mult * self.cfg.model.melgan.ngf // 2,
+                    kernel_size=r * 2,
+                    stride=r,
+                    padding=r // 2 + r % 2,
+                    output_padding=r % 2,
+                ),
+            ]
+
+            for j in range(self.cfg.model.melgan.n_residual_layers):
+                model += [
+                    ResnetBlock(mult * self.cfg.model.melgan.ngf // 2, dilation=3**j)
+                ]
+
+            mult //= 2
+
+        model += [
+            nn.LeakyReLU(0.2),
+            nn.ReflectionPad1d(3),
+            WNConv1d(self.cfg.model.melgan.ngf, 1, kernel_size=7, padding=0),
+            nn.Tanh(),
+        ]
+
+        self.model = nn.Sequential(*model)
+        self.apply(weights_init)
+
+    def forward(self, x):
+        return self.model(x)
diff --git a/models/vocoders/gan/generator/nsfhifigan.py b/models/vocoders/gan/generator/nsfhifigan.py
new file mode 100644
index 0000000000000000000000000000000000000000..8deb4ce7a9348f4a03bfef55fd21c634b3f25a78
--- /dev/null
+++ b/models/vocoders/gan/generator/nsfhifigan.py
@@ -0,0 +1,281 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+
+from modules.neural_source_filter import *
+from modules.vocoder_blocks import *
+
+
+LRELU_SLOPE = 0.1
+
+
+class ResBlock1(nn.Module):
+    def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super(ResBlock1, self).__init__()
+        self.cfg = cfg
+
+        self.convs1 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[2],
+                        padding=get_padding(kernel_size, dilation[2]),
+                    )
+                ),
+            ]
+        )
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+            ]
+        )
+        self.convs2.apply(init_weights)
+
+    def forward(self, x):
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            xt = c1(xt)
+            xt = F.leaky_relu(xt, LRELU_SLOPE)
+            xt = c2(xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_weight_norm(l)
+        for l in self.convs2:
+            remove_weight_norm(l)
+
+
+class ResBlock2(nn.Module):
+    def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3)):
+        super(ResBlock2, self).__init__()
+        self.cfg = cfg
+
+        self.convs = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+            ]
+        )
+        self.convs.apply(init_weights)
+
+    def forward(self, x):
+        for c in self.convs:
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            xt = c(xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs:
+            remove_weight_norm(l)
+
+# This NSF module is adapted from Xin Wang's NSF under the MIT License
+# https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts
+
+class SourceModuleHnNSF(nn.Module):
+    def __init__(
+        self, fs, harmonic_num=0, amp=0.1, noise_std=0.003, voiced_threshold=0
+    ):
+        super(SourceModuleHnNSF, self).__init__()
+
+        self.amp = amp
+        self.noise_std = noise_std
+        self.l_sin_gen = SineGen(fs, harmonic_num, amp, noise_std, voiced_threshold)
+
+        self.l_linear = nn.Linear(harmonic_num + 1, 1)
+        self.l_tanh = nn.Tanh()
+
+    def forward(self, x, upp):
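+        # SineGen (see modules.neural_source_filter) turns the frame-level F0
+        # contour, upsampled by a factor of `upp`, into harmonic_num + 1 sine
+        # harmonics; a linear layer + tanh merges them into one excitation channel.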
+        sine_wavs, uv, _ = self.l_sin_gen(x, upp)
+        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+        return sine_merge
+
+
+class NSFHiFiGAN(nn.Module):
+    def __init__(self, cfg):
+        super(NSFHiFiGAN, self).__init__()
+
+        self.cfg = cfg
+        self.num_kernels = len(self.cfg.model.nsfhifigan.resblock_kernel_sizes)
+        self.num_upsamples = len(self.cfg.model.nsfhifigan.upsample_rates)
+        self.m_source = SourceModuleHnNSF(
+            fs=self.cfg.preprocess.sample_rate,
+            harmonic_num=self.cfg.model.nsfhifigan.harmonic_num,
+        )
+        self.noise_convs = nn.ModuleList()
+        self.conv_pre = weight_norm(
+            Conv1d(
+                self.cfg.preprocess.n_mel,
+                self.cfg.model.nsfhifigan.upsample_initial_channel,
+                7,
+                1,
+                padding=3,
+            )
+        )
+
+        resblock = ResBlock1 if self.cfg.model.nsfhifigan.resblock == "1" else ResBlock2
+
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(
+            zip(
+                self.cfg.model.nsfhifigan.upsample_rates,
+                self.cfg.model.nsfhifigan.upsample_kernel_sizes,
+            )
+        ):
+            c_cur = self.cfg.model.nsfhifigan.upsample_initial_channel // (2 ** (i + 1))
+            self.ups.append(
+                weight_norm(
+                    ConvTranspose1d(
+                        self.cfg.model.nsfhifigan.upsample_initial_channel // (2**i),
+                        self.cfg.model.nsfhifigan.upsample_initial_channel
+                        // (2 ** (i + 1)),
+                        k,
+                        u,
+                        padding=(k - u) // 2,
+                    )
+                )
+            )
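+            # Downsample the full-rate harmonic source to this stage's temporal
+            # resolution: stride_f0 is the product of the remaining upsample
+            # rates; the last stage is already at sample rate, so a 1x1 conv suffices.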
+            if i + 1 < len(self.cfg.model.nsfhifigan.upsample_rates):
+                stride_f0 = int(
+                    np.prod(self.cfg.model.nsfhifigan.upsample_rates[i + 1 :])
+                )
+                self.noise_convs.append(
+                    Conv1d(
+                        1,
+                        c_cur,
+                        kernel_size=stride_f0 * 2,
+                        stride=stride_f0,
+                        padding=stride_f0 // 2,
+                    )
+                )
+            else:
+                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
+
+        self.resblocks = nn.ModuleList()
+        ch = self.cfg.model.nsfhifigan.upsample_initial_channel
+        for i in range(len(self.ups)):
+            ch //= 2
+            for j, (k, d) in enumerate(
+                zip(
+                    self.cfg.model.nsfhifigan.resblock_kernel_sizes,
+                    self.cfg.model.nsfhifigan.resblock_dilation_sizes,
+                )
+            ):
+                self.resblocks.append(resblock(cfg, ch, k, d))
+
+        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+
+        self.ups.apply(init_weights)
+        self.conv_post.apply(init_weights)
+        self.upp = int(np.prod(self.cfg.model.nsfhifigan.upsample_rates))
+
+    def forward(self, x, f0):
+        har_source = self.m_source(f0, self.upp).transpose(1, 2)
+        x = self.conv_pre(x)
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            x = self.ups[i](x)
+            x_source = self.noise_convs[i](har_source)
+
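+            # Align feature and excitation lengths (the transposed convolutions
+            # and the strided source convolutions may differ by a few samples)
+            # before mixing in the harmonic source.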
+            length = min(x.shape[-1], x_source.shape[-1])
+            x = x[:, :, :length]
+            x_source = x_source[:, :, :length]
+
+            x = x + x_source
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.resblocks[i * self.num_kernels + j](x)
+                else:
+                    xs += self.resblocks[i * self.num_kernels + j](x)
+            x = xs / self.num_kernels
+        x = F.leaky_relu(x)
+        x = self.conv_post(x)
+        x = torch.tanh(x)
+
+        return x
diff --git a/models/vocoders/gan/generator/sifigan.py b/models/vocoders/gan/generator/sifigan.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/vocoders/vocoder_dataset.py b/models/vocoders/vocoder_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..7df17b97ba7a4f770f01971324126eca4a2db272
--- /dev/null
+++ b/models/vocoders/vocoder_dataset.py
@@ -0,0 +1,264 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Iterable
+import torch
+import numpy as np
+import torch.utils.data
+from torch.nn.utils.rnn import pad_sequence
+from utils.data_utils import *
+from torch.utils.data import ConcatDataset, Dataset
+
+
+class VocoderDataset(torch.utils.data.Dataset):
+    def __init__(self, cfg, dataset, is_valid=False):
+        """
+        Args:
+            cfg: config
+            dataset: dataset name
+            is_valid: whether to load the valid split (True) or the train split (False)
+        """
+        assert isinstance(dataset, str)
+
+        processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
+
+        meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
+        self.metafile_path = os.path.join(processed_data_dir, meta_file)
+        self.metadata = self.get_metadata()
+
+        self.data_root = processed_data_dir
+        self.cfg = cfg
+
+        if cfg.preprocess.use_audio:
+            self.utt2audio_path = {}
+            for utt_info in self.metadata:
+                dataset = utt_info["Dataset"]
+                uid = utt_info["Uid"]
+                utt = "{}_{}".format(dataset, uid)
+
+                self.utt2audio_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.audio_dir,
+                    uid + ".npy",
+                )
+        elif cfg.preprocess.use_label:
+            self.utt2label_path = {}
+            for utt_info in self.metadata:
+                dataset = utt_info["Dataset"]
+                uid = utt_info["Uid"]
+                utt = "{}_{}".format(dataset, uid)
+
+                self.utt2label_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.label_dir,
+                    uid + ".npy",
+                )
+        elif cfg.preprocess.use_one_hot:
+            self.utt2one_hot_path = {}
+            for utt_info in self.metadata:
+                dataset = utt_info["Dataset"]
+                uid = utt_info["Uid"]
+                utt = "{}_{}".format(dataset, uid)
+
+                self.utt2one_hot_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.one_hot_dir,
+                    uid + ".npy",
+                )
+
+        if cfg.preprocess.use_mel:
+            self.utt2mel_path = {}
+            for utt_info in self.metadata:
+                dataset = utt_info["Dataset"]
+                uid = utt_info["Uid"]
+                utt = "{}_{}".format(dataset, uid)
+
+                self.utt2mel_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.mel_dir,
+                    uid + ".npy",
+                )
+
+        if cfg.preprocess.use_frame_pitch:
+            self.utt2frame_pitch_path = {}
+            for utt_info in self.metadata:
+                dataset = utt_info["Dataset"]
+                uid = utt_info["Uid"]
+                utt = "{}_{}".format(dataset, uid)
+
+                self.utt2frame_pitch_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.pitch_dir,
+                    uid + ".npy",
+                )
+
+        if cfg.preprocess.use_uv:
+            self.utt2uv_path = {}
+            for utt_info in self.metadata:
+                dataset = utt_info["Dataset"]
+                uid = utt_info["Uid"]
+                utt = "{}_{}".format(dataset, uid)
+                self.utt2uv_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.uv_dir,
+                    uid + ".npy",
+                )
+
+        if cfg.preprocess.use_amplitude_phase:
+            self.utt2logamp_path = {}
+            self.utt2pha_path = {}
+            self.utt2rea_path = {}
+            self.utt2imag_path = {}
+            for utt_info in self.metadata:
+                dataset = utt_info["Dataset"]
+                uid = utt_info["Uid"]
+                utt = "{}_{}".format(dataset, uid)
+                self.utt2logamp_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.log_amplitude_dir,
+                    uid + ".npy",
+                )
+                self.utt2pha_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.phase_dir,
+                    uid + ".npy",
+                )
+                self.utt2rea_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.real_dir,
+                    uid + ".npy",
+                )
+                self.utt2imag_path[utt] = os.path.join(
+                    cfg.preprocess.processed_dir,
+                    dataset,
+                    cfg.preprocess.imaginary_dir,
+                    uid + ".npy",
+                )
+
+    def __getitem__(self, index):
+        utt_info = self.metadata[index]
+
+        dataset = utt_info["Dataset"]
+        uid = utt_info["Uid"]
+        utt = "{}_{}".format(dataset, uid)
+
+        single_feature = dict()
+
+        if self.cfg.preprocess.use_mel:
+            mel = np.load(self.utt2mel_path[utt])
+            assert mel.shape[0] == self.cfg.preprocess.n_mel  # [n_mels, T]
+
+            if "target_len" not in single_feature.keys():
+                single_feature["target_len"] = mel.shape[1]
+
+            single_feature["mel"] = mel
+
+        if self.cfg.preprocess.use_frame_pitch:
+            frame_pitch = np.load(self.utt2frame_pitch_path[utt])
+
+            if "target_len" not in single_feature.keys():
+                single_feature["target_len"] = len(frame_pitch)
+
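+            # Pad or truncate the F0 contour so it matches target_len
+            # (the mel frame count when mels are used).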
+            aligned_frame_pitch = align_length(
+                frame_pitch, single_feature["target_len"]
+            )
+
+            single_feature["frame_pitch"] = aligned_frame_pitch
+
+        if self.cfg.preprocess.use_audio:
+            audio = np.load(self.utt2audio_path[utt])
+
+            single_feature["audio"] = audio
+
+        return single_feature
+
+    def get_metadata(self):
+        with open(self.metafile_path, "r", encoding="utf-8") as f:
+            metadata = json.load(f)
+
+        return metadata
+
+    def get_dataset_name(self):
+        return self.metadata[0]["Dataset"]
+
+    def __len__(self):
+        return len(self.metadata)
+
+
+class VocoderConcatDataset(ConcatDataset):
+    def __init__(self, datasets: Iterable[Dataset], full_audio_inference=False):
+        """Concatenate a series of datasets with their random inference audio merged."""
+        super().__init__(datasets)
+
+        self.cfg = self.datasets[0].cfg
+
+        self.metadata = []
+
+        # Merge metadata
+        for dataset in self.datasets:
+            self.metadata += dataset.metadata
+
+        # Merge random inference features
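+        # Note: this assumes each sub-dataset exposes eval_audio (and eval_mel /
+        # eval_pitch when those features are enabled); the base VocoderDataset
+        # above does not set these attributes itself.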
+        if full_audio_inference:
+            self.eval_audios = []
+            self.eval_dataset_names = []
+            if self.cfg.preprocess.use_mel:
+                self.eval_mels = []
+            if self.cfg.preprocess.use_frame_pitch:
+                self.eval_pitchs = []
+            for dataset in self.datasets:
+                self.eval_audios.append(dataset.eval_audio)
+                self.eval_dataset_names.append(dataset.get_dataset_name())
+                if self.cfg.preprocess.use_mel:
+                    self.eval_mels.append(dataset.eval_mel)
+                if self.cfg.preprocess.use_frame_pitch:
+                    self.eval_pitchs.append(dataset.eval_pitch)
+
+
+class VocoderCollator(object):
+    """Zero-pads model inputs and targets based on number of frames per step"""
+
+    def __init__(self, cfg):
+        self.cfg = cfg
+
+    def __call__(self, batch):
+        packed_batch_features = dict()
+
+        # mel: [b, n_mels, frame]
+        # frame_pitch: [b, frame]
+        # audios: [b, frame * hop_size]
+
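+        # "target_len" is collated into a LongTensor of frame counts plus a
+        # binary "mask" of shape [b, max_frame, 1]; all other features are
+        # zero-padded to the longest item in the batch.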
+        for key in batch[0].keys():
+            if key == "target_len":
+                packed_batch_features["target_len"] = torch.LongTensor(
+                    [b["target_len"] for b in batch]
+                )
+                masks = [
+                    torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
+                ]
+                packed_batch_features["mask"] = pad_sequence(
+                    masks, batch_first=True, padding_value=0
+                )
+            elif key == "mel":
+                values = [torch.from_numpy(b[key]).T for b in batch]
+                packed_batch_features[key] = pad_sequence(
+                    values, batch_first=True, padding_value=0
+                )
+            else:
+                values = [torch.from_numpy(b[key]) for b in batch]
+                packed_batch_features[key] = pad_sequence(
+                    values, batch_first=True, padding_value=0
+                )
+
+        return packed_batch_features
diff --git a/models/vocoders/vocoder_inference.py b/models/vocoders/vocoder_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfd09ee6aa44c544c51a62c0e014cca4260cc6a8
--- /dev/null
+++ b/models/vocoders/vocoder_inference.py
@@ -0,0 +1,488 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import torch
+import json
+import json5
+import time
+import accelerate
+import random
+import numpy as np
+import shutil
+
+from pathlib import Path
+from tqdm import tqdm
+from glob import glob
+from accelerate.logging import get_logger
+from torch.utils.data import DataLoader
+
+from models.vocoders.vocoder_dataset import (
+    VocoderDataset,
+    VocoderCollator,
+    VocoderConcatDataset,
+)
+
+from models.vocoders.gan.generator import bigvgan, hifigan, melgan, nsfhifigan, apnet
+from models.vocoders.flow.waveglow import waveglow
+from models.vocoders.diffusion.diffwave import diffwave
+from models.vocoders.autoregressive.wavenet import wavenet
+from models.vocoders.autoregressive.wavernn import wavernn
+from models.vocoders.gan import gan_vocoder_inference
+from utils.io import save_audio
+
+_vocoders = {
+    "diffwave": diffwave.DiffWave,
+    "wavernn": wavernn.WaveRNN,
+    "wavenet": wavenet.WaveNet,
+    "waveglow": waveglow.WaveGlow,
+    "nsfhifigan": nsfhifigan.NSFHiFiGAN,
+    "bigvgan": bigvgan.BigVGAN,
+    "hifigan": hifigan.HiFiGAN,
+    "melgan": melgan.MelGAN,
+    "apnet": apnet.APNet,
+}
+
+_vocoder_infer_funcs = {
+    # "world": world_inference.synthesis_audios,
+    # "wavernn": wavernn_inference.synthesis_audios,
+    # "wavenet": wavenet_inference.synthesis_audios,
+    # "diffwave": diffwave_inference.synthesis_audios,
+    "nsfhifigan": gan_vocoder_inference.synthesis_audios,
+    "bigvgan": gan_vocoder_inference.synthesis_audios,
+    "melgan": gan_vocoder_inference.synthesis_audios,
+    "hifigan": gan_vocoder_inference.synthesis_audios,
+    "apnet": gan_vocoder_inference.synthesis_audios,
+}
+
+
+class VocoderInference(object):
+    def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
+        super().__init__()
+
+        start = time.monotonic_ns()
+        self.args = args
+        self.cfg = cfg
+        self.infer_type = infer_type
+
+        # Init accelerator
+        self.accelerator = accelerate.Accelerator()
+        self.accelerator.wait_for_everyone()
+
+        # Get logger
+        with self.accelerator.main_process_first():
+            self.logger = get_logger("inference", log_level=args.log_level)
+
+        # Log some info
+        self.logger.info("=" * 56)
+        self.logger.info("||\t\t" + "New inference process started." + "\t\t||")
+        self.logger.info("=" * 56)
+        self.logger.info("\n")
+
+        self.vocoder_dir = args.vocoder_dir
+        self.logger.debug(f"Vocoder dir: {args.vocoder_dir}")
+
+        os.makedirs(args.output_dir, exist_ok=True)
+        if os.path.exists(os.path.join(args.output_dir, "pred")):
+            shutil.rmtree(os.path.join(args.output_dir, "pred"))
+        if os.path.exists(os.path.join(args.output_dir, "gt")):
+            shutil.rmtree(os.path.join(args.output_dir, "gt"))
+        os.makedirs(os.path.join(args.output_dir, "pred"), exist_ok=True)
+        os.makedirs(os.path.join(args.output_dir, "gt"), exist_ok=True)
+
+        # Set random seed
+        with self.accelerator.main_process_first():
+            start = time.monotonic_ns()
+            self._set_random_seed(self.cfg.train.random_seed)
+            end = time.monotonic_ns()
+            self.logger.debug(
+                f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
+            )
+            self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
+
+        # Setup inference mode
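+        # "infer_from_dataset" reuses the preprocessed datasets listed in
+        # args.infer_datasets; the other two modes build a temporary "tmp"
+        # dataset under preprocess.processed_dir from raw features or audio.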
+        if self.infer_type == "infer_from_dataset":
+            self.cfg.dataset = self.args.infer_datasets
+        elif self.infer_type == "infer_from_feature":
+            self._build_tmp_dataset_from_feature()
+            self.cfg.dataset = ["tmp"]
+        elif self.infer_type == "infer_from_audio":
+            self._build_tmp_dataset_from_audio()
+            self.cfg.dataset = ["tmp"]
+
+        # Setup data loader
+        with self.accelerator.main_process_first():
+            self.logger.info("Building dataset...")
+            start = time.monotonic_ns()
+            self.test_dataloader = self._build_dataloader()
+            end = time.monotonic_ns()
+            self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
+
+        # Build model
+        with self.accelerator.main_process_first():
+            self.logger.info("Building model...")
+            start = time.monotonic_ns()
+            self.model = self._build_model()
+            end = time.monotonic_ns()
+            self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms")
+
+        # Init with accelerate
+        self.logger.info("Initializing accelerate...")
+        start = time.monotonic_ns()
+        self.accelerator = accelerate.Accelerator()
+        (self.model, self.test_dataloader) = self.accelerator.prepare(
+            self.model, self.test_dataloader
+        )
+        end = time.monotonic_ns()
+        self.accelerator.wait_for_everyone()
+        self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms")
+
+        with self.accelerator.main_process_first():
+            self.logger.info("Loading checkpoint...")
+            start = time.monotonic_ns()
+            if os.path.isdir(args.vocoder_dir):
+                if os.path.isdir(os.path.join(args.vocoder_dir, "checkpoint")):
+                    self._load_model(os.path.join(args.vocoder_dir, "checkpoint"))
+                else:
+                    self._load_model(os.path.join(args.vocoder_dir))
+            else:
+                self._load_model(os.path.join(args.vocoder_dir))
+            end = time.monotonic_ns()
+            self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms")
+
+        self.model.eval()
+        self.accelerator.wait_for_everyone()
+
+    def _build_tmp_dataset_from_feature(self):
+        if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")):
+            shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
+
+        utts = []
+        mels = glob(os.path.join(self.args.feature_folder, "mels", "*.npy"))
+        for i, mel in enumerate(mels):
+            uid = mel.split("/")[-1].split(".")[0]
+            utt = {"Dataset": "tmp", "Uid": uid, "index": i}
+            utts.append(utt)
+
+        os.makedirs(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
+        with open(
+            os.path.join(self.cfg.preprocess.processed_dir, "tmp", "test.json"), "w"
+        ) as f:
+            json.dump(utts, f)
+
+        meta_info = {"dataset": "tmp", "test": {"size": len(utts)}}
+
+        with open(
+            os.path.join(self.cfg.preprocess.processed_dir, "tmp", "meta_info.json"),
+            "w",
+        ) as f:
+            json.dump(meta_info, f)
+
+        features = glob(os.path.join(self.args.feature_folder, "*"))
+        for feature in features:
+            feature_name = feature.split("/")[-1]
+            if os.path.isfile(feature):
+                continue
+            shutil.copytree(
+                os.path.join(self.args.feature_folder, feature_name),
+                os.path.join(self.cfg.preprocess.processed_dir, "tmp", feature_name),
+            )
+
+    def _build_tmp_dataset_from_audio(self):
+        if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")):
+            shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
+
+        utts = []
+        audios = glob(os.path.join(self.args.audio_folder, "*"))
+        for i, audio in enumerate(audios):
+            uid = audio.split("/")[-1].split(".")[0]
+            utt = {"Dataset": "tmp", "Uid": uid, "index": i, "Path": audio}
+            utts.append(utt)
+
+        os.makedirs(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
+        with open(
+            os.path.join(self.cfg.preprocess.processed_dir, "tmp", "test.json"), "w"
+        ) as f:
+            json.dump(utts, f)
+
+        meta_info = {"dataset": "tmp", "test": {"size": len(utts)}}
+
+        with open(
+            os.path.join(self.cfg.preprocess.processed_dir, "tmp", "meta_info.json"),
+            "w",
+        ) as f:
+            json.dump(meta_info, f)
+
+        from processors import acoustic_extractor
+
+        acoustic_extractor.extract_utt_acoustic_features_serial(
+            utts, os.path.join(self.cfg.preprocess.processed_dir, "tmp"), self.cfg
+        )
+
+    def _build_test_dataset(self):
+        return VocoderDataset, VocoderCollator
+
+    def _build_model(self):
+        model = _vocoders[self.cfg.model.generator](self.cfg)
+        return model
+
+    def _build_dataloader(self):
+        """Build dataloader which merges a series of datasets."""
+        Dataset, Collator = self._build_test_dataset()
+
+        datasets_list = []
+        for dataset in self.cfg.dataset:
+            subdataset = Dataset(self.cfg, dataset, is_valid=True)
+            datasets_list.append(subdataset)
+        test_dataset = VocoderConcatDataset(datasets_list, full_audio_inference=False)
+        test_collate = Collator(self.cfg)
+        test_batch_size = min(self.cfg.inference.batch_size, len(test_dataset))
+        test_dataloader = DataLoader(
+            test_dataset,
+            collate_fn=test_collate,
+            num_workers=1,
+            batch_size=test_batch_size,
+            shuffle=False,
+        )
+        self.test_batch_size = test_batch_size
+        self.test_dataset = test_dataset
+        return test_dataloader
+
+    def _load_model(self, checkpoint_dir, from_multi_gpu=False):
+        """Load model weights. If a directory is given, load the latest
+        checkpoint inside it; if a file path is given, load that checkpoint
+        (.pt file) directly.
+        **Only use this method after** ``accelerator.prepare()``.
+        """
+        if os.path.isdir(checkpoint_dir):
+            if "epoch" in checkpoint_dir and "step" in checkpoint_dir:
+                checkpoint_path = checkpoint_dir
+            else:
+                # Load the latest accelerator state dicts
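+                # Checkpoint folders are named like
+                # "epoch-XXXX_step-XXXXXXX_loss-X.XXXXXX"; sort by the epoch
+                # number embedded in the name and load the newest one.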
+                ls = [
+                    str(i)
+                    for i in Path(checkpoint_dir).glob("*")
+                    if not "audio" in str(i)
+                ]
+                ls.sort(
+                    key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True
+                )
+                checkpoint_path = ls[0]
+            accelerate.load_checkpoint_and_dispatch(
+                self.accelerator.unwrap_model(self.model),
+                os.path.join(checkpoint_path, "pytorch_model.bin"),
+            )
+            return str(checkpoint_path)
+        else:
+            # Load old .pt checkpoints
+            if self.cfg.model.generator in [
+                "bigvgan",
+                "hifigan",
+                "melgan",
+                "nsfhifigan",
+            ]:
+                ckpt = torch.load(
+                    checkpoint_dir,
+                    map_location=torch.device("cuda")
+                    if torch.cuda.is_available()
+                    else torch.device("cpu"),
+                )
+                if from_multi_gpu:
+                    pretrained_generator_dict = ckpt["generator_state_dict"]
+                    generator_dict = self.model.state_dict()
+
+                    new_generator_dict = {
+                        k.split("module.")[-1]: v
+                        for k, v in pretrained_generator_dict.items()
+                        if (
+                            k.split("module.")[-1] in generator_dict
+                            and v.shape == generator_dict[k.split("module.")[-1]].shape
+                        )
+                    }
+
+                    generator_dict.update(new_generator_dict)
+
+                    self.model.load_state_dict(generator_dict)
+                else:
+                    self.model.load_state_dict(ckpt["generator_state_dict"])
+            else:
+                self.model.load_state_dict(torch.load(checkpoint_dir)["state_dict"])
+            return str(checkpoint_dir)
+
+    def inference(self):
+        """Inference via batches"""
+        for i, batch in tqdm(enumerate(self.test_dataloader)):
+            if self.cfg.preprocess.use_frame_pitch:
+                audio_pred = self.model.forward(
+                    batch["mel"].transpose(-1, -2), batch["frame_pitch"].float()
+                ).cpu()
+            elif self.cfg.preprocess.extract_amplitude_phase:
+                audio_pred = self.model.forward(batch["mel"].transpose(-1, -2))[-1]
+            else:
+                audio_pred = (
+                    self.model.forward(batch["mel"].transpose(-1, -2)).detach().cpu()
+                )
+            audio_ls = audio_pred.chunk(self.test_batch_size)
+            audio_gt_ls = batch["audio"].cpu().chunk(self.test_batch_size)
+            length_ls = batch["target_len"].cpu().chunk(self.test_batch_size)
+            j = 0
+            for it, it_gt, l in zip(audio_ls, audio_gt_ls, length_ls):
+                l = l.item()
+                it = it.squeeze(0).squeeze(0)[: l * self.cfg.preprocess.hop_size]
+                it_gt = it_gt.squeeze(0)[: l * self.cfg.preprocess.hop_size]
+                uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
+                save_audio(
+                    os.path.join(self.args.output_dir, "pred", "{}.wav").format(uid),
+                    it,
+                    self.cfg.preprocess.sample_rate,
+                )
+                save_audio(
+                    os.path.join(self.args.output_dir, "gt", "{}.wav").format(uid),
+                    it_gt,
+                    self.cfg.preprocess.sample_rate,
+                )
+                j += 1
+
+        if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")):
+            shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp"))
+
+    def _set_random_seed(self, seed):
+        """Set random seed for all possible random modules."""
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.random.manual_seed(seed)
+
+    def _count_parameters(self, model):
+        return sum(p.numel() for p in model.parameters())
+
+    def _dump_cfg(self, path):
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+        json5.dump(
+            self.cfg,
+            open(path, "w"),
+            indent=4,
+            sort_keys=True,
+            ensure_ascii=False,
+            quote_keys=True,
+        )
+
+
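+# A minimal usage sketch (the config loader and paths are illustrative, not
+# defined in this file):
+#   cfg = load_config("pretrained/bigvgan/args.json")  # hypothetical helper
+#   vocoder = load_nnvocoder(cfg, "bigvgan", "pretrained/bigvgan/400000.pt")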
+def load_nnvocoder(
+    cfg,
+    vocoder_name,
+    weights_file,
+    from_multi_gpu=False,
+):
+    """Load the specified vocoder.
+    cfg: the vocoder config.
+    weights_file: a checkpoint folder or a .pt file path.
+    from_multi_gpu: if True, strip the "module." prefix (added by multi-GPU
+        training) from the state-dict keys before loading.
+    """
+    print("Loading vocoder from weights file: {}".format(weights_file))
+
+    # Build model
+    model = _vocoders[vocoder_name](cfg)
+    if not os.path.isdir(weights_file):
+        # Load from .pt file
+        if vocoder_name in ["bigvgan", "hifigan", "melgan", "nsfhifigan"]:
+            ckpt = torch.load(
+                weights_file,
+                map_location=torch.device("cuda")
+                if torch.cuda.is_available()
+                else torch.device("cpu"),
+            )
+            if from_multi_gpu:
+                pretrained_generator_dict = ckpt["generator_state_dict"]
+                generator_dict = model.state_dict()
+
+                new_generator_dict = {
+                    k.split("module.")[-1]: v
+                    for k, v in pretrained_generator_dict.items()
+                    if (
+                        k.split("module.")[-1] in generator_dict
+                        and v.shape == generator_dict[k.split("module.")[-1]].shape
+                    )
+                }
+
+                generator_dict.update(new_generator_dict)
+
+                model.load_state_dict(generator_dict)
+            else:
+                model.load_state_dict(ckpt["generator_state_dict"])
+        else:
+            model.load_state_dict(torch.load(weights_file)["state_dict"])
+    else:
+        # Load from accelerator state dict
+        weights_file = os.path.join(weights_file, "checkpoint")
+    ls = [str(i) for i in Path(weights_file).glob("*") if "audio" not in str(i)]
+        ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
+        checkpoint_path = ls[0]
+        accelerator = accelerate.Accelerator()
+        model = accelerator.prepare(model)
+        accelerator.load_state(checkpoint_path)
+
+    if torch.cuda.is_available():
+        model = model.cuda()
+
+    model = model.eval()
+    return model
+
+
+def tensorize(data, device, n_samples):
+    """
+    data: a list of numpy array
+    """
+    assert type(data) == list
+    if n_samples:
+        data = data[:n_samples]
+    data = [torch.as_tensor(x, device=device) for x in data]
+    return data
+
+
+def synthesis(
+    cfg,
+    vocoder_weight_file,
+    n_samples,
+    pred,
+    f0s=None,
+    batch_size=64,
+    fast_inference=False,
+):
+    """Synthesize audios with a given vocoder from a series of given features.
+    cfg: vocoder config.
+    vocoder_weight_file: a folder of accelerator state dicts or a path to a .pt file.
+    pred: a list of numpy arrays, [(seq_len1, acoustic_features_dim), (seq_len2, acoustic_features_dim), ...].
+    f0s: an optional list of frame-level F0 arrays, required by source-filter vocoders such as nsfhifigan.
+    """
+
+    vocoder_name = cfg.model.generator
+
+    print("Synthesizing audios using {} vocoder...".format(vocoder_name))
+
+    ###### TODO: World Vocoder Refactor ######
+    # if vocoder_name == "world":
+    #     world_inference.synthesis_audios(
+    #         cfg, dataset_name, split, n_samples, pred, save_dir, tag
+    #     )
+    #     return
+
+    # ====== Loading neural vocoder model ======
+    vocoder = load_nnvocoder(
+        cfg, vocoder_name, weights_file=vocoder_weight_file, from_multi_gpu=True
+    )
+    device = next(vocoder.parameters()).device
+
+    # ====== Inference for predicted acoustic features ======
+    # pred: (frame_len, n_mels) -> (n_mels, frame_len)
+    mels_pred = tensorize([p.T for p in pred], device, n_samples)
+    print("For predicted mels, #sample = {}...".format(len(mels_pred)))
+    audios_pred = _vocoder_infer_funcs[vocoder_name](
+        cfg,
+        vocoder,
+        mels_pred,
+        f0s=f0s,
+        batch_size=batch_size,
+        fast_inference=fast_inference,
+    )
+    return audios_pred
diff --git a/models/vocoders/vocoder_sampler.py b/models/vocoders/vocoder_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d29f88a291dcf7386cadaeae0d990c8e76ebf98
--- /dev/null
+++ b/models/vocoders/vocoder_sampler.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import random
+
+from torch.utils.data import ConcatDataset, Dataset
+from torch.utils.data.sampler import (
+    BatchSampler,
+    RandomSampler,
+    Sampler,
+    SequentialSampler,
+)
+
+
+class ScheduledSampler(Sampler):
+    """A sampler that samples data from a given concat-dataset.
+
+    Args:
+        concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets
+        batch_size (int): batch size
+        holistic_shuffle (bool): whether to shuffle the whole dataset or not
+        logger (logging.Logger): logger to print warning message
+
+    Usage:
+        For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True:
+        >>> list(ScheduledSampler(ConcatDataset([[0, 1, 2], [3, 4, 5], [6, 7, 8]]), 3, False))
+        [3, 4, 5, 0, 1, 2, 6, 7, 8]
+    """
+
+    def __init__(
+        self, concat_dataset, batch_size, holistic_shuffle, logger=None, type="train"
+    ):
+        # Note: the `type` argument shadows the builtin type(), so the error
+        # messages below report the class name via __class__ instead.
+        if not isinstance(concat_dataset, ConcatDataset):
+            raise ValueError(
+                "concat_dataset must be an instance of ConcatDataset, but got {}".format(
+                    concat_dataset.__class__.__name__
+                )
+            )
+        if not isinstance(batch_size, int):
+            raise ValueError(
+                "batch_size must be an integer, but got {}".format(
+                    batch_size.__class__.__name__
+                )
+            )
+        if not isinstance(holistic_shuffle, bool):
+            raise ValueError(
+                "holistic_shuffle must be a boolean, but got {}".format(
+                    holistic_shuffle.__class__.__name__
+                )
+            )
+
+        self.concat_dataset = concat_dataset
+        self.batch_size = batch_size
+        self.holistic_shuffle = holistic_shuffle
+
+        affected_dataset_name = []
+        affected_dataset_len = []
+        for dataset in concat_dataset.datasets:
+            dataset_len = len(dataset)
+            dataset_name = dataset.get_dataset_name()
+            if dataset_len < batch_size:
+                affected_dataset_name.append(dataset_name)
+                affected_dataset_len.append(dataset_len)
+
+        self.type = type
+        for dataset_name, dataset_len in zip(
+            affected_dataset_name, affected_dataset_len
+        ):
+            if not type == "valid":
+                logger.warning(
+                    "The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format(
+                        type, dataset_name, dataset_len, batch_size
+                    )
+                )
+
+    def __len__(self):
+        # number of indices yielded per epoch (full drop-last batches per dataset times batch size)
+        num_of_batches = sum(
+            [
+                math.floor(len(dataset) / self.batch_size)
+                for dataset in self.concat_dataset.datasets
+            ]
+        )
+        return num_of_batches * self.batch_size
+
+    def __iter__(self):
+        iters = []
+        for dataset in self.concat_dataset.datasets:
+            iters.append(
+                RandomSampler(dataset).__iter__()
+                if self.holistic_shuffle
+                else SequentialSampler(dataset).__iter__()
+            )
+        init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1]
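+        # Offset each dataset's local indices into the global index space of
+        # the ConcatDataset using its cumulative sizes.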
+        output_batches = []
+        for dataset_idx in range(len(self.concat_dataset.datasets)):
+            cur_batch = []
+            for idx in iters[dataset_idx]:
+                cur_batch.append(idx + init_indices[dataset_idx])
+                if len(cur_batch) == self.batch_size:
+                    output_batches.append(cur_batch)
+                    cur_batch = []
+            # Keep the leftover (smaller than batch_size) batch only for
+            # validation; in training it is dropped, i.e. forced drop_last.
+            if self.type == "valid" and len(cur_batch) > 0:
+                output_batches.append(cur_batch)
+                cur_batch = []
+        # Shuffle whole batches across datasets, then flatten into indices.
+        random.shuffle(output_batches)
+        output_indices = [item for sublist in output_batches for item in sublist]
+        return iter(output_indices)
+
+
+def build_samplers(concat_dataset: Dataset, cfg, logger, type):
+    sampler = ScheduledSampler(
+        concat_dataset,
+        cfg.train.batch_size,
+        cfg.train.sampler.holistic_shuffle,
+        logger,
+        type,
+    )
+    batch_sampler = BatchSampler(
+        sampler,
+        cfg.train.batch_size,
+        cfg.train.sampler.drop_last if not type == "valid" else False,
+    )
+    return sampler, batch_sampler
diff --git a/models/vocoders/vocoder_trainer.py b/models/vocoders/vocoder_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5821e735a64f07fcf9c782712670e24ce6a91c04
--- /dev/null
+++ b/models/vocoders/vocoder_trainer.py
@@ -0,0 +1,180 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import random
+from pathlib import Path
+import re
+
+import accelerate
+import json5
+import numpy as np
+import torch
+from accelerate.utils import ProjectConfiguration
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from models.vocoders.vocoder_dataset import VocoderConcatDataset
+from models.vocoders.vocoder_sampler import build_samplers
+
+
+class VocoderTrainer:
+    def __init__(self):
+        super().__init__()
+
+    def _init_accelerator(self):
+        """Initialize the accelerator components."""
+        self.exp_dir = os.path.join(
+            os.path.abspath(self.cfg.log_dir), self.args.exp_name
+        )
+        project_config = ProjectConfiguration(
+            project_dir=self.exp_dir, logging_dir=os.path.join(self.exp_dir, "log")
+        )
+        self.accelerator = accelerate.Accelerator(
+            gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
+            log_with=self.cfg.train.tracker,
+            project_config=project_config,
+        )
+        if self.accelerator.is_main_process:
+            os.makedirs(project_config.project_dir, exist_ok=True)
+            os.makedirs(project_config.logging_dir, exist_ok=True)
+        with self.accelerator.main_process_first():
+            self.accelerator.init_trackers(self.args.exp_name)
+
+    def _build_dataset(self):
+        pass
+
+    def _build_criterion(self):
+        pass
+
+    def _build_model(self):
+        pass
+
+    def _build_dataloader(self):
+        """Build dataloader which merges a series of datasets."""
+        # Build dataset instance for each dataset and combine them by ConcatDataset
+        Dataset, Collator = self._build_dataset()
+
+        # Build train set
+        datasets_list = []
+        for dataset in self.cfg.dataset:
+            subdataset = Dataset(self.cfg, dataset, is_valid=False)
+            datasets_list.append(subdataset)
+        train_dataset = VocoderConcatDataset(datasets_list, full_audio_inference=True)
+        train_collate = Collator(self.cfg)
+        _, batch_sampler = build_samplers(train_dataset, self.cfg, self.logger, "train")
+        train_loader = DataLoader(
+            train_dataset,
+            collate_fn=train_collate,
+            batch_sampler=batch_sampler,
+            num_workers=self.cfg.train.dataloader.num_worker,
+            pin_memory=self.cfg.train.dataloader.pin_memory,
+        )
+
+        # Build valid set
+        datasets_list = []
+        for dataset in self.cfg.dataset:
+            subdataset = Dataset(self.cfg, dataset, is_valid=True)
+            datasets_list.append(subdataset)
+        valid_dataset = VocoderConcatDataset(datasets_list, full_audio_inference=True)
+        valid_collate = Collator(self.cfg)
+        _, batch_sampler = build_samplers(valid_dataset, self.cfg, self.logger, "valid")
+        valid_loader = DataLoader(
+            valid_dataset,
+            collate_fn=valid_collate,
+            batch_sampler=batch_sampler,
+            num_workers=self.cfg.train.dataloader.num_worker,
+            pin_memory=self.cfg.train.dataloader.pin_memory,
+        )
+        return train_loader, valid_loader
+
+    def _build_optimizer(self):
+        pass
+
+    def _build_scheduler(self):
+        pass
+
+    def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"):
+        """Load model from checkpoint. If a folder is given, it will
+        load the latest checkpoint in checkpoint_dir. If a path is given,
+        it will load the checkpoint specified by checkpoint_path.
+        **Only use this method after** ``accelerator.prepare()``.
+        """
+        if checkpoint_path is None:
+            ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
+            ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
+            checkpoint_path = ls[0]
+        if resume_type == "resume":
+            self.accelerator.load_state(checkpoint_path)
+        elif resume_type == "finetune":
+            accelerate.load_checkpoint_and_dispatch(
+                self.accelerator.unwrap_model(self.model),
+                os.path.join(checkpoint_path, "pytorch_model.bin"),
+            )
+            self.logger.info("Load model weights for finetune SUCCESS!")
+        else:
+            raise ValueError("Unsupported resume type: {}".format(resume_type))
+        self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
+        self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
+        return checkpoint_path
+
+    def train_loop(self):
+        pass
+
+    def _train_epoch(self):
+        pass
+
+    def _valid_epoch(self):
+        pass
+
+    def _train_step(self):
+        pass
+
+    def _valid_step(self):
+        pass
+
+    def _inference(self):
+        pass
+
+    def _set_random_seed(self, seed):
+        """Set random seed for all possible random modules."""
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.random.manual_seed(seed)
+
+    def _check_nan(self, loss):
+        if torch.any(torch.isnan(loss)):
+            self.logger.fatal("Fatal Error: NaN!")
+            self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True)
+
+    def _check_basic_configs(self):
+        if self.cfg.train.gradient_accumulation_step <= 0:
+            self.logger.fatal("Invalid gradient_accumulation_step value!")
+            self.logger.error(
+                f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
+            )
+            self.accelerator.end_training()
+            raise ValueError(
+                f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
+            )
+
+    def _count_parameters(self):
+        pass
+
+    def _dump_cfg(self, path):
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+        json5.dump(
+            self.cfg,
+            open(path, "w"),
+            indent=4,
+            sort_keys=True,
+            ensure_ascii=False,
+            quote_keys=True,
+        )
+
+    def _is_valid_pattern(self, directory_name):
+        directory_name = str(directory_name)
+        pattern = r"^epoch-\d{4}_step-\d{7}_loss-\d{1}\.\d{6}"
+        return re.match(pattern, directory_name) is not None
diff --git a/pretrained/bigvgan/400000.pt b/pretrained/bigvgan/400000.pt
new file mode 100755
index 0000000000000000000000000000000000000000..a36c956e753aef0862753c496ade62d23ab5c906
--- /dev/null
+++ b/pretrained/bigvgan/400000.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:989df2350b502e1175cdb1d204d9f81c27ddf97fe1919db4fa2605631e4cab1d
+size 1846939571
diff --git a/pretrained/bigvgan/args.json b/pretrained/bigvgan/args.json
new file mode 100644
index 0000000000000000000000000000000000000000..06878c8cc6e7c667b107836a51c1d577c15fa7b1
--- /dev/null
+++ b/pretrained/bigvgan/args.json
@@ -0,0 +1,235 @@
+{
+    "base_config": "egs/vocoder/gan/exp_config_base.json",
+    "exp_name": "bigvgan_large",
+    "inference": {
+        "batch_size": 1,
+    },
+    "model": {
+        "bigvgan": {
+            "activation": "snakebeta",
+            "resblock": "1",
+            "resblock_dilation_sizes": [
+                [
+                    1,
+                    3,
+                    5,
+                ],
+                [
+                    1,
+                    3,
+                    5,
+                ],
+                [
+                    1,
+                    3,
+                    5,
+                ],
+            ],
+            "resblock_kernel_sizes": [
+                3,
+                7,
+                11,
+            ],
+            "snake_logscale": true,
+            "upsample_initial_channel": 1536,
+            "upsample_kernel_sizes": [
+                8,
+                8,
+                4,
+                4,
+                4,
+                4,
+            ],
+            "upsample_rates": [
+                4,
+                4,
+                2,
+                2,
+                2,
+                2,
+            ],
+        },
+        "discriminators": [
+            "mpd",
+            "msstftd",
+        ],
+        "generator": "bigvgan",
+        "mpd": {
+            "discriminator_channel_multi": 1,
+            "mpd_reshapes": [
+                2,
+                3,
+                5,
+                7,
+                11,
+            ],
+            "use_spectral_norm": false,
+        },
+        "mrd": {
+            "discriminator_channel_multi": 1,
+            "mrd_override": false,
+            "resolutions": [
+                [
+                    1024,
+                    120,
+                    600,
+                ],
+                [
+                    2048,
+                    240,
+                    1200,
+                ],
+                [
+                    512,
+                    50,
+                    240,
+                ],
+            ],
+            "use_spectral_norm": false,
+        },
+        "msstftd": {
+            "filters": 32,
+        },
+    },
+    "model_type": "GANVocoder",
+    "preprocess": {
+        "audio_dir": "audios",
+        "bits": 8,
+        "contentvec_dir": "contentvec",
+        "cut_mel_frame": 32,
+        "data_augment": false,
+        "dur_dir": "durs",
+        "duration_dir": "duration",
+        "emo2id": "emo2id.json",
+        "energy_dir": "energys",
+        "energy_extract_mode": "from_mel",
+        "energy_norm": false,
+        "extract_audio": true,
+        "extract_contentvec_feature": false,
+        "extract_duration": false,
+        "extract_energy": false,
+        "extract_label": false,
+        "extract_mcep": false,
+        "extract_mel": true,
+        "extract_mert_feature": false,
+        "extract_one_hot": false,
+        "extract_pitch": false,
+        "extract_uv": false,
+        "extract_wenet_feature": false,
+        "extract_whisper_feature": false,
+        "f0_max": 1100,
+        "f0_min": 50,
+        "file_lst": "file.lst",
+        "fmax": 12000,
+        "fmin": 0,
+        "hop_size": 256,
+        "is_mu_law": false,
+        "lab_dir": "labs",
+        "label_dir": "labels",
+        "mcep_dir": "mcep",
+        "mel_dir": "mels",
+        "mel_min_max_norm": false,
+        "min_level_db": -115,
+        "n_fft": 1024,
+        "n_mel": 100,
+        "num_silent_frames": 8,
+        "phone_seq_file": "phone_seq_file",
+        "pitch_bin": 256,
+        "pitch_dir": "pitches",
+        "pitch_extractor": "parselmouth",
+        "pitch_max": 1100.0,
+        "pitch_min": 50.0,
+        "pitch_norm": false,
+        "processed_dir": "processed_data",
+        "ref_level_db": 20,
+        "sample_rate": 24000,
+        "spk2id": "singers.json",
+        "train_file": "train.json",
+        "trim_fft_size": 512,
+        "trim_hop_size": 128,
+        "trim_silence": false,
+        "trim_top_db": 30,
+        "trimmed_wav_dir": "trimmed_wavs",
+        "use_audio": true,
+        "use_dur": false,
+        "use_emoid": false,
+        "use_frame_duration": false,
+        "use_frame_energy": false,
+        "use_frame_pitch": false,
+        "use_lab": false,
+        "use_label": false,
+        "use_log_scale_energy": false,
+        "use_log_scale_pitch": false,
+        "use_mel": true,
+        "use_one_hot": false,
+        "use_phn_seq": false,
+        "use_phone_duration": false,
+        "use_phone_energy": false,
+        "use_phone_pitch": false,
+        "use_spkid": false,
+        "use_uv": false,
+        "use_wav": false,
+        "use_wenet": false,
+        "utt2emo": "utt2emo",
+        "utt2spk": "utt2spk",
+        "uv_dir": "uvs",
+        "valid_file": "test.json",
+        "wav_dir": "wavs",
+        "wenet_dir": "wenet",
+        "win_size": 1024,
+    },
+    "supported_model_type": [
+        "GANVocoder",
+        "Fastspeech2",
+        "DiffSVC",
+        "Transformer",
+        "EDM",
+        "CD",
+    ],
+    "train": {
+        "adamw": {
+            "adam_b1": 0.8,
+            "adam_b2": 0.99,
+            "lr": 0.0002,
+        },
+        "batch_size": 4,
+        "criterions": [
+            "feature",
+            "discriminator",
+            "generator",
+            "mel",
+        ],
+        "dataloader": {
+            "num_worker": 4,
+            "pin_memory": true,
+        },
+        "ddp": true,
+        "epochs": 50000,
+        "exponential_lr": {
+            "lr_decay": 0.999,
+        },
+        "gradient_accumulation_step": 1,
+        "keep_checkpoint_max": 5,
+        "max_epoch": 1000000,
+        "max_steps": 1000000,
+        "multi_speaker_training": false,
+        "random_seed": 114514,
+        "run_eval": [
+            true,
+        ],
+        "sampler": {
+            "drop_last": true,
+            "holistic_shuffle": true,
+        },
+        "save_checkpoint_stride": [
+            200,
+        ],
+        "save_checkpoints_steps": 10000,
+        "save_summary_steps": 500,
+        "total_training_steps": 50000,
+        "tracker": [
+            "tensorboard",
+        ],
+        "valid_interval": 10000,
+    },
+}
\ No newline at end of file
diff --git a/pretrained/contentvec/README.md b/pretrained/contentvec/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6ea10938244c7282355be035c9489efa5bf08bdd
--- /dev/null
+++ b/pretrained/contentvec/README.md
@@ -0,0 +1,5 @@
+# Download
+
+- [Link](https://github.com/auspicious3000/contentvec)
+- Model: `ContentVec_legacy`
+- Classes: 500
diff --git a/pretrained/contentvec/checkpoint_best_legacy_500.pt b/pretrained/contentvec/checkpoint_best_legacy_500.pt
new file mode 100755
index 0000000000000000000000000000000000000000..9a2f13fb9c7047dff746e2d5d88c0d0a5aecf643
--- /dev/null
+++ b/pretrained/contentvec/checkpoint_best_legacy_500.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:60d936ec5a566776fc392e69ad8b630d14eb588111233fe313436e200a7b187b
+size 1330114945
diff --git a/pretrained/whisper/medium.pt b/pretrained/whisper/medium.pt
new file mode 100644
index 0000000000000000000000000000000000000000..8aca41c710014a3d39774cd7592fa086177c672f
--- /dev/null
+++ b/pretrained/whisper/medium.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1
+size 1528008539
diff --git a/processors/__init__.py b/processors/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/processors/acoustic_extractor.py b/processors/acoustic_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c56486336b28a8ae4560d927cf8eb633b7c4513
--- /dev/null
+++ b/processors/acoustic_extractor.py
@@ -0,0 +1,864 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import torch
+import numpy as np
+
+import json
+from tqdm import tqdm
+from sklearn.preprocessing import StandardScaler
+from utils.io import save_feature, save_txt
+from utils.util import has_existed
+from utils.tokenizer import extract_encodec_token
+from utils.stft import TacotronSTFT
+from utils.dsp import compress, audio_to_label
+from utils.data_utils import remove_outlier
+from preprocessors.metadata import replace_augment_name
+from scipy.interpolate import interp1d
+
+ZERO = 1e-12
+
+
+def extract_utt_acoustic_features_parallel(metadata, dataset_output, cfg, n_workers=1):
+    """Extract acoustic features from utterances using muliprocess
+
+    Args:
+        metadata (dict): dictionary that stores data in train.json and test.json files
+        dataset_output (str): directory to store acoustic features
+        cfg (dict): dictionary that stores configurations
+        n_workers (int, optional): num of processes to extract features in parallel. Defaults to 1.
+
+    Returns:
+        None: acoustic features are saved to disk rather than returned
+    """
+    for utt in tqdm(metadata):
+        if cfg.task_type == "tts":
+            extract_utt_acoustic_features_tts(dataset_output, cfg, utt)
+        if cfg.task_type == "svc":
+            extract_utt_acoustic_features_svc(dataset_output, cfg, utt)
+        if cfg.task_type == "vocoder":
+            extract_utt_acoustic_features_vocoder(dataset_output, cfg, utt)
+        if cfg.task_type == "tta":
+            extract_utt_acoustic_features_tta(dataset_output, cfg, utt)
+
+
+def avg_phone_feature(feature, duration, interpolation=False):
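+    """Average a frame-level feature over phones, given per-phone durations.
+
+    Args:
+        feature: (n_frames,) frame-level array such as pitch or energy
+        duration: frame counts per phone; sum(duration) should not exceed n_frames
+        interpolation: if True, unvoiced (zero) frames are linearly interpolated first
+
+    Returns:
+        (len(duration),) phone-level feature. For example (hypothetical values),
+        feature=[1., 2., 3., 4., 5.] with duration=[3, 2] yields [2.0, 4.5].
+    """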
+    feature = feature[: sum(duration)]
+    if interpolation:
+        nonzero_ids = np.where(feature != 0)[0]
+        interp_fn = interp1d(
+            nonzero_ids,
+            feature[nonzero_ids],
+            fill_value=(feature[nonzero_ids[0]], feature[nonzero_ids[-1]]),
+            bounds_error=False,
+        )
+        feature = interp_fn(np.arange(0, len(feature)))
+
+    # Phoneme-level average
+    pos = 0
+    for i, d in enumerate(duration):
+        if d > 0:
+            feature[i] = np.mean(feature[pos : pos + d])
+        else:
+            feature[i] = 0
+        pos += d
+    feature = feature[: len(duration)]
+    return feature
+
+
+def extract_utt_acoustic_features_serial(metadata, dataset_output, cfg):
+    """Extract acoustic features from utterances (in single process)
+
+    Args:
+        metadata (dict): dictionary that stores data in train.json and test.json files
+        dataset_output (str): directory to store acoustic features
+        cfg (dict): dictionary that stores configurations
+
+    """
+    for utt in tqdm(metadata):
+        if cfg.task_type == "tts":
+            extract_utt_acoustic_features_tts(dataset_output, cfg, utt)
+        if cfg.task_type == "svc":
+            extract_utt_acoustic_features_svc(dataset_output, cfg, utt)
+        if cfg.task_type == "vocoder":
+            extract_utt_acoustic_features_vocoder(dataset_output, cfg, utt)
+        if cfg.task_type == "tta":
+            extract_utt_acoustic_features_tta(dataset_output, cfg, utt)
+
+
+def __extract_utt_acoustic_features(dataset_output, cfg, utt):
+    """Extract acoustic features from utterances (in single process)
+
+    Args:
+        dataset_output (str): directory to store acoustic features
+        cfg (dict): dictionary that stores configurations
+        utt (dict): utterance info including dataset, singer, uid:{singer}_{song}_{index},
+                    path to utterance, duration, utterance index
+
+    """
+    from utils import audio, f0, world, duration
+
+    uid = utt["Uid"]
+    wav_path = utt["Path"]
+    if os.path.exists(os.path.join(dataset_output, cfg.preprocess.raw_data)):
+        wav_path = os.path.join(
+            dataset_output, cfg.preprocess.raw_data, utt["Singer"], uid + ".wav"
+        )
+
+    with torch.no_grad():
+        # Load audio data into tensor with sample rate of the config file
+        wav_torch, _ = audio.load_audio_torch(wav_path, cfg.preprocess.sample_rate)
+        wav = wav_torch.cpu().numpy()
+
+        # extract features
+        if cfg.preprocess.extract_duration:
+            durations, phones, start, end = duration.get_duration(
+                utt, wav, cfg.preprocess
+            )
+            save_feature(dataset_output, cfg.preprocess.duration_dir, uid, durations)
+            save_txt(dataset_output, cfg.preprocess.lab_dir, uid, phones)
+            wav = wav[start:end].astype(np.float32)
+            wav_torch = torch.from_numpy(wav).to(wav_torch.device)
+
+        if cfg.preprocess.extract_linear_spec:
+            from utils.mel import extract_linear_features
+
+            linear = extract_linear_features(wav_torch.unsqueeze(0), cfg.preprocess)
+            save_feature(
+                dataset_output, cfg.preprocess.linear_dir, uid, linear.cpu().numpy()
+            )
+
+        if cfg.preprocess.extract_mel:
+            from utils.mel import extract_mel_features
+
+            if cfg.preprocess.mel_extract_mode == "taco":
+                _stft = TacotronSTFT(
+                    sampling_rate=cfg.preprocess.sample_rate,
+                    win_length=cfg.preprocess.win_size,
+                    hop_length=cfg.preprocess.hop_size,
+                    filter_length=cfg.preprocess.n_fft,
+                    n_mel_channels=cfg.preprocess.n_mel,
+                    mel_fmin=cfg.preprocess.fmin,
+                    mel_fmax=cfg.preprocess.fmax,
+                )
+                mel = extract_mel_features(
+                    wav_torch.unsqueeze(0), cfg.preprocess, taco=True, _stft=_stft
+                )
+                if cfg.preprocess.extract_duration:
+                    mel = mel[:, : sum(durations)]
+            else:
+                mel = extract_mel_features(wav_torch.unsqueeze(0), cfg.preprocess)
+            save_feature(dataset_output, cfg.preprocess.mel_dir, uid, mel.cpu().numpy())
+
+        if cfg.preprocess.extract_energy:
+            if (
+                cfg.preprocess.energy_extract_mode == "from_mel"
+                and cfg.preprocess.extract_mel
+            ):
+                energy = (mel.exp() ** 2).sum(0).sqrt().cpu().numpy()
+            elif cfg.preprocess.energy_extract_mode == "from_waveform":
+                energy = audio.energy(wav, cfg.preprocess)
+            elif cfg.preprocess.energy_extract_mode == "from_tacotron_stft":
+                _stft = TacotronSTFT(
+                    sampling_rate=cfg.preprocess.sample_rate,
+                    win_length=cfg.preprocess.win_size,
+                    hop_length=cfg.preprocess.hop_size,
+                    filter_length=cfg.preprocess.n_fft,
+                    n_mel_channels=cfg.preprocess.n_mel,
+                    mel_fmin=cfg.preprocess.fmin,
+                    mel_fmax=cfg.preprocess.fmax,
+                )
+                _, energy = audio.get_energy_from_tacotron(wav, _stft)
+            else:
+                assert cfg.preprocess.energy_extract_mode in [
+                    "from_mel",
+                    "from_waveform",
+                    "from_tacotron_stft",
+                ], f"{cfg.preprocess.energy_extract_mode} not in supported energy_extract_mode [from_mel, from_waveform, from_tacotron_stft]"
+            if cfg.preprocess.extract_duration:
+                energy = energy[: sum(durations)]
+                phone_energy = avg_phone_feature(energy, durations)
+                save_feature(
+                    dataset_output, cfg.preprocess.phone_energy_dir, uid, phone_energy
+                )
+
+            save_feature(dataset_output, cfg.preprocess.energy_dir, uid, energy)
+
+        if cfg.preprocess.extract_pitch:
+            pitch = f0.get_f0(wav, cfg.preprocess)
+            if cfg.preprocess.extract_duration:
+                pitch = pitch[: sum(durations)]
+                phone_pitch = avg_phone_feature(pitch, durations, interpolation=True)
+                save_feature(
+                    dataset_output, cfg.preprocess.phone_pitch_dir, uid, phone_pitch
+                )
+            save_feature(dataset_output, cfg.preprocess.pitch_dir, uid, pitch)
+
+            if cfg.preprocess.extract_uv:
+                assert isinstance(pitch, np.ndarray)
+                uv = pitch != 0
+                save_feature(dataset_output, cfg.preprocess.uv_dir, uid, uv)
+
+        if cfg.preprocess.extract_audio:
+            save_feature(dataset_output, cfg.preprocess.audio_dir, uid, wav)
+
+        if cfg.preprocess.extract_label:
+            if cfg.preprocess.is_mu_law:
+                # compress audio
+                wav = compress(wav, cfg.preprocess.bits)
+            label = audio_to_label(wav, cfg.preprocess.bits)
+            save_feature(dataset_output, cfg.preprocess.label_dir, uid, label)
+
+        if cfg.preprocess.extract_acoustic_token:
+            if cfg.preprocess.acoustic_token_extractor == "Encodec":
+                codes = extract_encodec_token(wav_path)
+                save_feature(
+                    dataset_output, cfg.preprocess.acoustic_token_dir, uid, codes
+                )
+
+
+def extract_utt_acoustic_features_tts(dataset_output, cfg, utt):
+    __extract_utt_acoustic_features(dataset_output, cfg, utt)
+
+
+def extract_utt_acoustic_features_svc(dataset_output, cfg, utt):
+    __extract_utt_acoustic_features(dataset_output, cfg, utt)
+
+
+def extract_utt_acoustic_features_tta(dataset_output, cfg, utt):
+    __extract_utt_acoustic_features(dataset_output, cfg, utt)
+
+
+def extract_utt_acoustic_features_vocoder(dataset_output, cfg, utt):
+    """Extract acoustic features from utterances (in single process)
+
+    Args:
+        dataset_output (str): directory to store acoustic features
+        cfg (dict): dictionary that stores configurations
+        utt (dict): utterance info including dataset, singer, uid:{singer}_{song}_{index},
+                    path to utterance, duration, utterance index
+
+    """
+    from utils import audio, f0, world, duration
+
+    uid = utt["Uid"]
+    wav_path = utt["Path"]
+
+    with torch.no_grad():
+        # Load audio data into tensor with sample rate of the config file
+        wav_torch, _ = audio.load_audio_torch(wav_path, cfg.preprocess.sample_rate)
+        wav = wav_torch.cpu().numpy()
+
+        # extract features
+        if cfg.preprocess.extract_mel:
+            from utils.mel import extract_mel_features
+
+            mel = extract_mel_features(wav_torch.unsqueeze(0), cfg.preprocess)
+            save_feature(dataset_output, cfg.preprocess.mel_dir, uid, mel.cpu().numpy())
+
+        if cfg.preprocess.extract_energy:
+            if (
+                cfg.preprocess.energy_extract_mode == "from_mel"
+                and cfg.preprocess.extract_mel
+            ):
+                energy = (mel.exp() ** 2).sum(0).sqrt().cpu().numpy()
+            elif cfg.preprocess.energy_extract_mode == "from_waveform":
+                energy = audio.energy(wav, cfg.preprocess)
+            else:
+                assert cfg.preprocess.energy_extract_mode in [
+                    "from_mel",
+                    "from_waveform",
+                ], f"{cfg.preprocess.energy_extract_mode} not in supported energy_extract_mode [from_mel, from_waveform, from_tacotron_stft]"
+
+            save_feature(dataset_output, cfg.preprocess.energy_dir, uid, energy)
+
+        if cfg.preprocess.extract_pitch:
+            pitch = f0.get_f0(wav, cfg.preprocess)
+            save_feature(dataset_output, cfg.preprocess.pitch_dir, uid, pitch)
+
+            if cfg.preprocess.extract_uv:
+                assert isinstance(pitch, np.ndarray)
+                uv = pitch != 0
+                save_feature(dataset_output, cfg.preprocess.uv_dir, uid, uv)
+
+        if cfg.preprocess.extract_audio:
+            save_feature(dataset_output, cfg.preprocess.audio_dir, uid, wav)
+
+        if cfg.preprocess.extract_label:
+            if cfg.preprocess.is_mu_law:
+                # compress audio
+                wav = compress(wav, cfg.preprocess.bits)
+            label = audio_to_label(wav, cfg.preprocess.bits)
+            save_feature(dataset_output, cfg.preprocess.label_dir, uid, label)
+
+
+def cal_normalized_mel(mel, dataset_name, cfg):
+    mel_min, mel_max = load_mel_extrema(cfg, dataset_name)
+    mel_norm = normalize_mel_channel(mel, mel_min, mel_max)
+    return mel_norm
+
+
+def cal_mel_min_max(dataset, output_path, cfg, metadata=None):
+    dataset_output = os.path.join(output_path, dataset)
+
+    if metadata is None:
+        metadata = []
+        for dataset_type in ["train", "test"] if "eval" not in dataset else ["test"]:
+            dataset_file = os.path.join(dataset_output, "{}.json".format(dataset_type))
+            with open(dataset_file, "r") as f:
+                metadata.extend(json.load(f))
+
+    tmp_mel_min = []
+    tmp_mel_max = []
+    for item in metadata:
+        mel_path = os.path.join(
+            dataset_output, cfg.preprocess.mel_dir, item["Uid"] + ".npy"
+        )
+        if not os.path.exists(mel_path):
+            continue
+        mel = np.load(mel_path)
+        if mel.shape[0] != cfg.preprocess.n_mel:
+            mel = mel.T
+        # mel: (n_mels, T)
+        assert mel.shape[0] == cfg.preprocess.n_mel
+
+        tmp_mel_min.append(np.min(mel, axis=-1))
+        tmp_mel_max.append(np.max(mel, axis=-1))
+
+    mel_min = np.min(tmp_mel_min, axis=0)
+    mel_max = np.max(tmp_mel_max, axis=0)
+
+    # save mel min max data
+    mel_min_max_dir = os.path.join(dataset_output, cfg.preprocess.mel_min_max_stats_dir)
+    os.makedirs(mel_min_max_dir, exist_ok=True)
+
+    mel_min_path = os.path.join(mel_min_max_dir, "mel_min.npy")
+    mel_max_path = os.path.join(mel_min_max_dir, "mel_max.npy")
+    np.save(mel_min_path, mel_min)
+    np.save(mel_max_path, mel_max)
+
+
+def denorm_for_pred_mels(cfg, dataset_name, split, pred):
+    """
+    Args:
+        pred: a list whose every element is (frame_len, n_mels)
+    Return:
+        similar like pred
+    """
+    mel_min, mel_max = load_mel_extrema(cfg.preprocess, dataset_name)
+    recovered_mels = [
+        denormalize_mel_channel(mel.T, mel_min, mel_max).T for mel in pred
+    ]
+
+    return recovered_mels
+
+
+def load_mel_extrema(cfg, dataset_name):
+    data_dir = os.path.join(cfg.processed_dir, dataset_name, cfg.mel_min_max_stats_dir)
+
+    min_file = os.path.join(data_dir, "mel_min.npy")
+    max_file = os.path.join(data_dir, "mel_max.npy")
+
+    mel_min = np.load(min_file)
+    mel_max = np.load(max_file)
+
+    return mel_min, mel_max
+
+
+def denormalize_mel_channel(mel, mel_min, mel_max):
+    mel_min = np.expand_dims(mel_min, -1)
+    mel_max = np.expand_dims(mel_max, -1)
+    return (mel + 1) / 2 * (mel_max - mel_min + ZERO) + mel_min
+
+
+def normalize_mel_channel(mel, mel_min, mel_max):
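+    # Per-channel min-max scaling of a (n_mels, T) mel to [-1, 1];
+    # ZERO keeps the denominator non-zero for constant channels.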
+    mel_min = np.expand_dims(mel_min, -1)
+    mel_max = np.expand_dims(mel_max, -1)
+    return (mel - mel_min) / (mel_max - mel_min + ZERO) * 2 - 1
+
+
+def normalize(dataset, feat_dir, cfg):
+    dataset_output = os.path.join(cfg.preprocess.processed_dir, dataset)
+    print(f"normalize {feat_dir}")
+
+    max_value = np.finfo(np.float64).min
+    min_value = np.finfo(np.float64).max
+
+    scaler = StandardScaler()
+    feat_files = os.listdir(os.path.join(dataset_output, feat_dir))
+
+    for feat_file in tqdm(feat_files):
+        feat_file = os.path.join(dataset_output, feat_dir, feat_file)
+        if not feat_file.endswith(".npy"):
+            continue
+        feat = np.load(feat_file)
+        max_value = max(max_value, max(feat))
+        min_value = min(min_value, min(feat))
+        scaler.partial_fit(feat.reshape((-1, 1)))
+    mean = scaler.mean_[0]
+    std = scaler.scale_[0]
+    stat = np.array([min_value, max_value, mean, std])
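+    # stat layout is [min, max, mean, std]; load_normalized() unpacks it in this order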
+    stat_npy = os.path.join(dataset_output, f"{feat_dir}_stat.npy")
+    np.save(stat_npy, stat)
+    return mean, std, min_value, max_value
+
+
+def load_normalized(feat_dir, dataset_name, cfg):
+    dataset_output = os.path.join(cfg.preprocess.processed_dir, dataset_name)
+    stat_npy = os.path.join(dataset_output, f"{feat_dir}_stat.npy")
+    min_value, max_value, mean, std = np.load(stat_npy)
+    return mean, std, min_value, max_value
+
+
+def cal_pitch_statistics_svc(dataset, output_path, cfg, metadata=None):
+    # path of dataset
+    dataset_dir = os.path.join(output_path, dataset)
+    save_dir = os.path.join(dataset_dir, cfg.preprocess.pitch_dir)
+    os.makedirs(save_dir, exist_ok=True)
+    if has_existed(os.path.join(save_dir, "statistics.json")):
+        return
+
+    if metadata is None:
+        # load singers and ids
+        singers = json.load(open(os.path.join(dataset_dir, "singers.json"), "r"))
+
+        # combine train and test metadata
+        metadata = []
+        for dataset_type in ["train", "test"] if "eval" not in dataset else ["test"]:
+            dataset_file = os.path.join(dataset_dir, "{}.json".format(dataset_type))
+            with open(dataset_file, "r") as f:
+                metadata.extend(json.load(f))
+    else:
+        singers = list(set([item["Singer"] for item in metadata]))
+        singers = {
+            "{}_{}".format(dataset, name): idx for idx, name in enumerate(singers)
+        }
+
+    # use different scalers for each singer
+    pitch_scalers = [[] for _ in range(len(singers))]
+    total_pitch_scalers = [[] for _ in range(len(singers))]
+
+    for utt_info in tqdm(metadata, desc="Loading F0..."):
+        # utt = f'{utt_info["Dataset"]}_{utt_info["Uid"]}'
+        singer = utt_info["Singer"]
+        pitch_path = os.path.join(
+            dataset_dir, cfg.preprocess.pitch_dir, utt_info["Uid"] + ".npy"
+        )
+        # total_pitch contains all pitch including unvoiced frames
+        if not os.path.exists(pitch_path):
+            continue
+        total_pitch = np.load(pitch_path)
+        assert len(total_pitch) > 0
+        # pitch contains only voiced frames
+        pitch = total_pitch[total_pitch != 0]
+        spkid = singers[f"{replace_augment_name(dataset)}_{singer}"]
+
+        # update pitch scalers
+        pitch_scalers[spkid].extend(pitch.tolist())
+        # update total pitch scalers
+        total_pitch_scalers[spkid].extend(total_pitch.tolist())
+
+    # save pitch statistics for each singer in dict
+    sta_dict = {}
+    for singer in tqdm(singers, desc="Singers statistics"):
+        spkid = singers[singer]
+        # voiced pitch statistics
+        mean, std, min, max, median = (
+            np.mean(pitch_scalers[spkid]),
+            np.std(pitch_scalers[spkid]),
+            np.min(pitch_scalers[spkid]),
+            np.max(pitch_scalers[spkid]),
+            np.median(pitch_scalers[spkid]),
+        )
+
+        # total pitch statistics
+        mean_t, std_t, min_t, max_t, median_t = (
+            np.mean(total_pitch_scalers[spkid]),
+            np.std(total_pitch_scalers[spkid]),
+            np.min(total_pitch_scalers[spkid]),
+            np.max(total_pitch_scalers[spkid]),
+            np.median(total_pitch_scalers[spkid]),
+        )
+        sta_dict[singer] = {
+            "voiced_positions": {
+                "mean": mean,
+                "std": std,
+                "median": median,
+                "min": min,
+                "max": max,
+            },
+            "total_positions": {
+                "mean": mean_t,
+                "std": std_t,
+                "median": median_t,
+                "min": min_t,
+                "max": max_t,
+            },
+        }
+
+    # save statistics
+    with open(os.path.join(save_dir, "statistics.json"), "w") as f:
+        json.dump(sta_dict, f, indent=4, ensure_ascii=False)
+
+
+def cal_pitch_statistics(dataset, output_path, cfg):
+    # path of dataset
+    dataset_dir = os.path.join(output_path, dataset)
+    if cfg.preprocess.use_phone_pitch:
+        pitch_dir = cfg.preprocess.phone_pitch_dir
+    else:
+        pitch_dir = cfg.preprocess.pitch_dir
+    save_dir = os.path.join(dataset_dir, pitch_dir)
+
+    os.makedirs(save_dir, exist_ok=True)
+    if has_existed(os.path.join(save_dir, "statistics.json")):
+        return
+    # load singers and ids
+    singers = json.load(open(os.path.join(dataset_dir, "singers.json"), "r"))
+
+    # combine train and test metadata
+    metadata = []
+    for dataset_type in ["train", "test"] if "eval" not in dataset else ["test"]:
+        dataset_file = os.path.join(dataset_dir, "{}.json".format(dataset_type))
+        with open(dataset_file, "r") as f:
+            metadata.extend(json.load(f))
+
+    # use different scalers for each singer
+    pitch_scalers = [[] for _ in range(len(singers))]
+    total_pitch_scalers = [[] for _ in range(len(singers))]
+
+    for utt_info in metadata:
+        utt = f'{utt_info["Dataset"]}_{utt_info["Uid"]}'
+        singer = utt_info["Singer"]
+        pitch_path = os.path.join(dataset_dir, pitch_dir, utt_info["Uid"] + ".npy")
+        # total_pitch contains all pitch including unvoiced frames
+        if not os.path.exists(pitch_path):
+            continue
+        total_pitch = np.load(pitch_path)
+        assert len(total_pitch) > 0
+        # pitch contains only voiced frames
+        if cfg.preprocess.pitch_remove_outlier:
+            pitch = remove_outlier(total_pitch)
+        else:
+            pitch = total_pitch[total_pitch != 0]
+        spkid = singers[f"{replace_augment_name(dataset)}_{singer}"]
+
+        # update pitch scalers
+        pitch_scalers[spkid].extend(pitch.tolist())
+        # update total pitch scalers
+        total_pitch_scalers[spkid].extend(total_pitch.tolist())
+
+    # save pitch statistics for each singer in dict
+    sta_dict = {}
+    for singer in singers:
+        spkid = singers[singer]
+        # voiced pitch statistics
+        mean, std, min, max, median = (
+            np.mean(pitch_scalers[spkid]),
+            np.std(pitch_scalers[spkid]),
+            np.min(pitch_scalers[spkid]),
+            np.max(pitch_scalers[spkid]),
+            np.median(pitch_scalers[spkid]),
+        )
+
+        # total pitch statistics
+        mean_t, std_t, min_t, max_t, median_t = (
+            np.mean(total_pitch_scalers[spkid]),
+            np.std(total_pitch_scalers[spkid]),
+            np.min(total_pitch_scalers[spkid]),
+            np.max(total_pitch_scalers[spkid]),
+            np.median(total_pitch_scalers[spkid]),
+        )
+        sta_dict[singer] = {
+            "voiced_positions": {
+                "mean": mean,
+                "std": std,
+                "median": median,
+                "min": min,
+                "max": max,
+            },
+            "total_positions": {
+                "mean": mean_t,
+                "std": std_t,
+                "median": median_t,
+                "min": min_t,
+                "max": max_t,
+            },
+        }
+
+    # save statistics
+    with open(os.path.join(save_dir, "statistics.json"), "w") as f:
+        json.dump(sta_dict, f, indent=4, ensure_ascii=False)
+
+
+def cal_energy_statistics(dataset, output_path, cfg):
+    # path of dataset
+    dataset_dir = os.path.join(output_path, dataset)
+    if cfg.preprocess.use_phone_energy:
+        energy_dir = cfg.preprocess.phone_energy_dir
+    else:
+        energy_dir = cfg.preprocess.energy_dir
+    save_dir = os.path.join(dataset_dir, energy_dir)
+    os.makedirs(save_dir, exist_ok=True)
+    print(os.path.join(save_dir, "statistics.json"))
+    if has_existed(os.path.join(save_dir, "statistics.json")):
+        return
+    # load singers and ids
+    singers = json.load(open(os.path.join(dataset_dir, "singers.json"), "r"))
+
+    # combine train and test metadata
+    metadata = []
+    for dataset_type in ["train", "test"] if "eval" not in dataset else ["test"]:
+        dataset_file = os.path.join(dataset_dir, "{}.json".format(dataset_type))
+        with open(dataset_file, "r") as f:
+            metadata.extend(json.load(f))
+
+    # use different scalers for each singer
+    energy_scalers = [[] for _ in range(len(singers))]
+    total_energy_scalers = [[] for _ in range(len(singers))]
+
+    for utt_info in metadata:
+        utt = f'{utt_info["Dataset"]}_{utt_info["Uid"]}'
+        singer = utt_info["Singer"]
+        energy_path = os.path.join(dataset_dir, energy_dir, utt_info["Uid"] + ".npy")
+        # total_energy contains all energy including unvoiced frames
+        if not os.path.exists(energy_path):
+            continue
+        total_energy = np.load(energy_path)
+        assert len(total_energy) > 0
+        # energy contains only voiced frames
+        if cfg.preprocess.energy_remove_outlier:
+            energy = remove_outlier(total_energy)
+        else:
+            energy = total_energy[total_energy != 0]
+        spkid = singers[f"{replace_augment_name(dataset)}_{singer}"]
+
+        # update energy scalers
+        energy_scalers[spkid].extend(energy.tolist())
+        # update total energy scalers
+        total_energy_scalers[spkid].extend(total_energy.tolist())
+
+    # save energy statistics for each singer in dict
+    sta_dict = {}
+    for singer in singers:
+        spkid = singers[singer]
+        # voiced energy statistics
+        mean, std, min, max, median = (
+            np.mean(energy_scalers[spkid]),
+            np.std(energy_scalers[spkid]),
+            np.min(energy_scalers[spkid]),
+            np.max(energy_scalers[spkid]),
+            np.median(energy_scalers[spkid]),
+        )
+
+        # total energy statistics
+        mean_t, std_t, min_t, max_t, median_t = (
+            np.mean(total_energy_scalers[spkid]),
+            np.std(total_energy_scalers[spkid]),
+            np.min(total_energy_scalers[spkid]),
+            np.max(total_energy_scalers[spkid]),
+            np.median(total_energy_scalers[spkid]),
+        )
+        sta_dict[singer] = {
+            "voiced_positions": {
+                "mean": mean,
+                "std": std,
+                "median": median,
+                "min": min,
+                "max": max,
+            },
+            "total_positions": {
+                "mean": mean_t,
+                "std": std_t,
+                "median": median_t,
+                "min": min_t,
+                "max": max_t,
+            },
+        }
+
+    # save statistics
+    with open(os.path.join(save_dir, "statistics.json"), "w") as f:
+        json.dump(sta_dict, f, indent=4, ensure_ascii=False)
+
+
+def copy_acoustic_features(metadata, dataset_dir, src_dataset_dir, cfg):
+    """Copy acoustic features from src_dataset_dir to dataset_dir
+
+    Args:
+        metadata (dict): dictionary that stores data in train.json and test.json files
+        dataset_dir (str): directory to store acoustic features
+        src_dataset_dir (str): directory to store acoustic features
+        cfg (dict): dictionary that stores configurations
+
+    """
+
+    if cfg.preprocess.extract_mel:
+        if not has_existed(os.path.join(dataset_dir, cfg.preprocess.mel_dir)):
+            os.makedirs(
+                os.path.join(dataset_dir, cfg.preprocess.mel_dir), exist_ok=True
+            )
+            print(
+                "Copying mel features from {} to {}...".format(
+                    src_dataset_dir, dataset_dir
+                )
+            )
+            for utt_info in tqdm(metadata):
+                src_mel_path = os.path.join(
+                    src_dataset_dir, cfg.preprocess.mel_dir, utt_info["Uid"] + ".npy"
+                )
+                dst_mel_path = os.path.join(
+                    dataset_dir, cfg.preprocess.mel_dir, utt_info["Uid"] + ".npy"
+                )
+                # create soft-links
+                if not os.path.exists(dst_mel_path):
+                    os.symlink(src_mel_path, dst_mel_path)
+    if cfg.preprocess.extract_energy:
+        if not has_existed(os.path.join(dataset_dir, cfg.preprocess.energy_dir)):
+            os.makedirs(
+                os.path.join(dataset_dir, cfg.preprocess.energy_dir), exist_ok=True
+            )
+            print(
+                "Copying energy features from {} to {}...".format(
+                    src_dataset_dir, dataset_dir
+                )
+            )
+            for utt_info in tqdm(metadata):
+                src_energy_path = os.path.join(
+                    src_dataset_dir, cfg.preprocess.energy_dir, utt_info["Uid"] + ".npy"
+                )
+                dst_energy_path = os.path.join(
+                    dataset_dir, cfg.preprocess.energy_dir, utt_info["Uid"] + ".npy"
+                )
+                # create soft-links
+                if not os.path.exists(dst_energy_path):
+                    os.symlink(src_energy_path, dst_energy_path)
+    if cfg.preprocess.extract_pitch:
+        if not has_existed(os.path.join(dataset_dir, cfg.preprocess.pitch_dir)):
+            os.makedirs(
+                os.path.join(dataset_dir, cfg.preprocess.pitch_dir), exist_ok=True
+            )
+            print(
+                "Copying pitch features from {} to {}...".format(
+                    src_dataset_dir, dataset_dir
+                )
+            )
+            for utt_info in tqdm(metadata):
+                src_pitch_path = os.path.join(
+                    src_dataset_dir, cfg.preprocess.pitch_dir, utt_info["Uid"] + ".npy"
+                )
+                dst_pitch_path = os.path.join(
+                    dataset_dir, cfg.preprocess.pitch_dir, utt_info["Uid"] + ".npy"
+                )
+                # create soft-links
+                if not os.path.exists(dst_pitch_path):
+                    os.symlink(src_pitch_path, dst_pitch_path)
+        if cfg.preprocess.extract_uv:
+            if not has_existed(os.path.join(dataset_dir, cfg.preprocess.uv_dir)):
+                os.makedirs(
+                    os.path.join(dataset_dir, cfg.preprocess.uv_dir), exist_ok=True
+                )
+                print(
+                    "Copying uv features from {} to {}...".format(
+                        src_dataset_dir, dataset_dir
+                    )
+                )
+                for utt_info in tqdm(metadata):
+                    src_uv_path = os.path.join(
+                        src_dataset_dir, cfg.preprocess.uv_dir, utt_info["Uid"] + ".npy"
+                    )
+                    dst_uv_path = os.path.join(
+                        dataset_dir, cfg.preprocess.uv_dir, utt_info["Uid"] + ".npy"
+                    )
+                    # create soft-links
+                    if not os.path.exists(dst_uv_path):
+                        os.symlink(src_uv_path, dst_uv_path)
+    if cfg.preprocess.extract_audio:
+        if not has_existed(os.path.join(dataset_dir, cfg.preprocess.audio_dir)):
+            os.makedirs(
+                os.path.join(dataset_dir, cfg.preprocess.audio_dir), exist_ok=True
+            )
+            print(
+                "Copying audio features from {} to {}...".format(
+                    src_dataset_dir, dataset_dir
+                )
+            )
+            for utt_info in tqdm(metadata):
+                src_audio_path = os.path.join(
+                    src_dataset_dir, cfg.preprocess.audio_dir, utt_info["Uid"] + ".npy"
+                )
+                dst_audio_path = os.path.join(
+                    dataset_dir, cfg.preprocess.audio_dir, utt_info["Uid"] + ".npy"
+                )
+                # create soft-links
+                if not os.path.exists(dst_audio_path):
+                    os.symlink(src_audio_path, dst_audio_path)
+    if cfg.preprocess.extract_label:
+        if not has_existed(os.path.join(dataset_dir, cfg.preprocess.label_dir)):
+            os.makedirs(
+                os.path.join(dataset_dir, cfg.preprocess.label_dir), exist_ok=True
+            )
+            print(
+                "Copying label features from {} to {}...".format(
+                    src_dataset_dir, dataset_dir
+                )
+            )
+            for utt_info in tqdm(metadata):
+                src_label_path = os.path.join(
+                    src_dataset_dir, cfg.preprocess.label_dir, utt_info["Uid"] + ".npy"
+                )
+                dst_label_path = os.path.join(
+                    dataset_dir, cfg.preprocess.label_dir, utt_info["Uid"] + ".npy"
+                )
+                # create soft-links
+                if not os.path.exists(dst_label_path):
+                    os.symlink(src_label_path, dst_label_path)
+
+
+def align_duration_mel(dataset, output_path, cfg):
+    print("align the duration and mel")
+
+    dataset_dir = os.path.join(output_path, dataset)
+    metadata = []
+    for dataset_type in ["train", "test"] if "eval" not in dataset else ["test"]:
+        dataset_file = os.path.join(dataset_dir, "{}.json".format(dataset_type))
+        with open(dataset_file, "r") as f:
+            metadata.extend(json.load(f))
+
+    utt2dur = {}
+    for index in tqdm(range(len(metadata))):
+        utt_info = metadata[index]
+        dataset = utt_info["Dataset"]
+        uid = utt_info["Uid"]
+        utt = "{}_{}".format(dataset, uid)
+
+        mel_path = os.path.join(dataset_dir, cfg.preprocess.mel_dir, uid + ".npy")
+        mel = np.load(mel_path).transpose(1, 0)
+        duration_path = os.path.join(
+            dataset_dir, cfg.preprocess.duration_dir, uid + ".npy"
+        )
+        duration = np.load(duration_path)
+        if sum(duration) != mel.shape[0]:
+            duration_sum = sum(duration)
+            mel_len = mel.shape[0]
+            mismatch = abs(duration_sum - mel_len)
+            assert mismatch <= 5, "duration and mel length mismatch!"
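+            # If the summed duration is longer than the mel, trim frames from the
+            # trailing phones; if shorter, pad the missing frames onto the last phone.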
+            cloned = np.array(duration, copy=True)
+            if duration_sum > mel_len:
+                for j in range(1, len(duration) - 1):
+                    if mismatch == 0:
+                        break
+                    dur_val = cloned[-j]
+                    if dur_val >= mismatch:
+                        cloned[-j] -= mismatch
+                        mismatch -= dur_val
+                        break
+                    else:
+                        cloned[-j] = 0
+                        mismatch -= dur_val
+
+            elif duration_sum < mel_len:
+                cloned[-1] += mismatch
+            duration = cloned
+        utt2dur[utt] = duration
+        np.save(duration_path, duration)
+
+    return utt2dur
diff --git a/processors/content_extractor.py b/processors/content_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..c034b6bf6aac19eb3ab912661110c8066aa3119b
--- /dev/null
+++ b/processors/content_extractor.py
@@ -0,0 +1,540 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import torch
+import numpy as np
+import yaml
+import copy
+from tqdm import tqdm
+from torchaudio.compliance import kaldi
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import DataLoader
+from fairseq import checkpoint_utils
+from transformers import AutoModel, Wav2Vec2FeatureExtractor
+
+from utils.io_optim import (
+    TorchaudioDataset,
+    LibrosaDataset,
+    FFmpegDataset,
+    collate_batch,
+)
+from modules import whisper_extractor as whisper
+from modules.wenet_extractor.utils.init_model import init_model
+from modules.wenet_extractor.utils.checkpoint import load_checkpoint
+
+"""
+    Extractor for content features
+    1. whisper
+    2. contentvec
+    3. wenet
+    4. mert
+
+    Pipeline:
+        in preprocess.py:
+            call extract_utt_content_features() to extract content features for each utterance
+            extract_utt_content_features() involves the following steps:
+                1. load the model (whisper, contentvec, wenet, mert)
+                2. extract the content features
+                3. save the content features into files
+        in svc_dataset.py:
+            call offline_align() to align the content features to the given target length
+
+"""
+
+"""
+    Extractor Usage:
+        1. initialize an instance of extractor
+            extractor = WhisperExtractor(cfg)
+        2. load the specified model
+            extractor.load_model()
+        3. extract the content features
+            extractor.extract_content_features(wavs, lens) for a batch of waveforms
+        4. save the content features
+            extractor.save_feature(utt, content_feature) for single utterance
+"""
+
+
+class BaseExtractor:
+    def __init__(self, cfg):
+        self.cfg = cfg
+        self.extractor_type = None
+        self.model = None
+
+    def offline_align(self, content, target_len):
+        """
+        args:
+            content: (source_len, dim)
+            target_len: target length
+        return:
+            mapped_feature: (target_len, dim)
+        """
+        target_hop = self.cfg.preprocess.hop_size
+
+        assert self.extractor_type in ["whisper", "contentvec", "wenet"]
+        if self.extractor_type == "whisper":
+            source_hop = (
+                self.cfg.preprocess.whisper_frameshift
+                * self.cfg.preprocess.whisper_downsample_rate
+                * self.cfg.preprocess.sample_rate
+            )
+        elif self.extractor_type == "contentvec":
+            source_hop = (
+                self.cfg.preprocess.contentvec_frameshift
+                * self.cfg.preprocess.sample_rate
+            )
+        elif self.extractor_type == "wenet":
+            source_hop = (
+                self.cfg.preprocess.wenet_frameshift
+                * self.cfg.preprocess.wenet_downsample_rate
+                * self.cfg.preprocess.sample_rate
+            )
+        source_hop = int(source_hop)
+        factor = np.gcd(source_hop, target_hop)
+        source_hop //= factor
+        target_hop //= factor
+
+        # (source_len, 256)
+        _, width = content.shape
+        # slice the content from padded feature
+        source_len = min(target_len * target_hop // source_hop + 1, len(content))
+
+        # const ~= target_len * target_hop
+        const = source_len * source_hop // target_hop * target_hop
+
+        # (source_len * source_hop, dim)
+        up_sampling_feats = np.repeat(content, source_hop, axis=0)
+        # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim)
+        down_sampling_feats = np.average(
+            up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1
+        )
+
+        err = abs(target_len - len(down_sampling_feats))
+        if err > 8:
+            # keep track of the largest alignment error observed so far
+            err_log_dir = os.path.join(
+                self.cfg.preprocess.processed_dir, "align_max_err.log"
+            )
+            try:
+                with open(err_log_dir, "r") as f:
+                    err_num = int(f.read())
+            except (FileNotFoundError, ValueError):
+                with open(err_log_dir, "w") as f:
+                    f.write("0")
+                err_num = 0
+            if err > err_num:
+                with open(err_log_dir, "w") as f:
+                    f.write(str(err))
+
+        if len(down_sampling_feats) < target_len:
+            # (1, dim) -> (err, dim)
+            end = down_sampling_feats[-1][None, :].repeat(err, axis=0)
+            down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0)
+
+        # (target_len, dim)
+        mapped_feature = down_sampling_feats[:target_len]
+
+        return mapped_feature
+
+    def save_feature(self, utt, content_feature):
+        """Save a single utternace to path {cfg.preprocess.processed_dir}
+
+        Args:
+            utt (dict): one item in metadata, containing information for one utterance
+            content_feature (tensor): content feature of one utterance
+        """
+        uid = utt["Uid"]
+        assert self.extractor_type is not None
+        out_dir = os.path.join(
+            self.cfg.preprocess.processed_dir, utt["Dataset"], self.extractor_type
+        )
+        os.makedirs(out_dir, exist_ok=True)
+        save_path = os.path.join(out_dir, uid + ".npy")
+        # only keep effective parts
+        duration = utt["Duration"]
+        if self.extractor_type == "whisper":
+            frameshift = (
+                self.cfg.preprocess.whisper_frameshift
+                * self.cfg.preprocess.whisper_downsample_rate
+            )  # 20ms
+        elif self.extractor_type == "contentvec":
+            frameshift = self.cfg.preprocess.contentvec_frameshift  # 20ms
+        elif self.extractor_type == "wenet":
+            frameshift = (
+                self.cfg.preprocess.wenet_frameshift
+                * self.cfg.preprocess.wenet_downsample_rate
+            )  # 40ms
+        elif self.extractor_type == "mert":
+            frameshift = self.cfg.preprocess.mert_frameshift
+        else:
+            raise NotImplementedError
+        # calculate the number of valid frames
+        num_frames = int(np.ceil((duration - frameshift) / frameshift)) + 1
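+        # e.g. a 2.005 s utterance with a 20 ms frameshift keeps
+        # ceil((2.005 - 0.02) / 0.02) + 1 = 101 frames (hypothetical numbers)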
+        # (num_frames, dim) -> (valid_frames, dim)
+        assert (
+            len(content_feature.shape) == 2
+        ), "content feature shape error, it should be (num_frames, dim)"
+        content_feature = content_feature[:num_frames, :]
+        np.save(save_path, content_feature.cpu().detach().numpy())
+
+
+class WhisperExtractor(BaseExtractor):
+    def __init__(self, config):
+        super(WhisperExtractor, self).__init__(config)
+        self.extractor_type = "whisper"
+
+    def load_model(self):
+        # load whisper checkpoint
+        print("Loading Whisper Model...")
+
+        checkpoint_file = (
+            self.cfg.preprocess.whisper_model_path
+            if "whisper_model_path" in self.cfg.preprocess
+            else None
+        )
+        model = whisper.load_model(
+            self.cfg.preprocess.whisper_model, checkpoint_file=checkpoint_file
+        )
+        if torch.cuda.is_available():
+            print("Using GPU...\n")
+            model = model.cuda()
+        else:
+            print("Using CPU...\n")
+
+        self.model = model.eval()
+
+    def extract_content_features(self, wavs, lens):
+        """extract content features from a batch of dataloader
+        Args:
+            wavs: tensor (batch_size, T)
+            lens: list
+        """
+        # wavs: (batch, max_len)
+        wavs = whisper.pad_or_trim(wavs)
+        # batch_mel: (batch, 80, 3000)
+        batch_mel = whisper.log_mel_spectrogram(wavs).to(self.model.device)
+        with torch.no_grad():
+            # (batch, 1500, 1024)
+            features = self.model.embed_audio(batch_mel)
+        return features
+
+
+class ContentvecExtractor(BaseExtractor):
+    def __init__(self, cfg):
+        super(ContentvecExtractor, self).__init__(cfg)
+        self.extractor_type = "contentvec"
+
+    def load_model(self):
+        assert self.model is None
+        # Load model
+        ckpt_path = self.cfg.preprocess.contentvec_file
+        print("Load Contentvec Model...")
+
+        models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
+            [ckpt_path],
+            suffix="",
+        )
+        model = models[0]
+        model.eval()
+
+        if torch.cuda.is_available():
+            # print("Using GPU...\n")
+            model = model.cuda()
+
+        self.model = model
+
+    def extract_content_features(self, wavs, lens):
+        """extract content features from a batch of dataloader
+        Args:
+            wavs: tensor (batch, T)
+            lens: list
+        """
+        device = next(self.model.parameters()).device
+        wavs = wavs.to(device)  # (batch, max_len)
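+        # samples that are exactly zero are treated as padding
+        # (an approximation, since genuine audio samples can also be zero)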
+        padding_mask = torch.eq(wavs, torch.zeros_like(wavs)).to(device)
+        with torch.no_grad():
+            logits = self.model.extract_features(
+                source=wavs, padding_mask=padding_mask, output_layer=12
+            )
+            # feats: (batch, T, 256)
+            feats = self.model.final_proj(logits[0])
+        return feats
+
+
+class WenetExtractor(BaseExtractor):
+    def __init__(self, config):
+        super(WenetExtractor, self).__init__(config)
+        self.extractor_type = "wenet"
+
+    def load_model(self):
+        wenet_cfg = self.cfg.preprocess.wenet_config
+        wenet_model_path = self.cfg.preprocess.wenet_model_path
+        # load Wenet config
+        with open(wenet_cfg, "r") as w:
+            wenet_configs = yaml.load(w, Loader=yaml.FullLoader)
+        self.extract_conf = copy.deepcopy(wenet_configs["dataset_conf"])
+        print("Loading Wenet Model...")
+        self.model = init_model(wenet_configs)
+        load_checkpoint(self.model, wenet_model_path)
+
+        if torch.cuda.is_available():
+            print("Using GPU...\n")
+            self.model = self.model.cuda()
+        else:
+            print("Using CPU...\n")
+
+        self.model = self.model.eval()
+
+    def extract_content_features(self, wavs, lens):
+        """extract content features from a batch of dataloader
+        Args:
+            wavs: tensor
+            lens: list
+        """
+        feats_list = []
+        lengths_list = []
+
+        device = next(self.model.parameters()).device
+        # Extract fbank/mfcc features by kaldi
+        assert self.extract_conf is not None, "load model first!"
+        feats_type = self.extract_conf.get("feats_type", "fbank")
+        assert feats_type in ["fbank", "mfcc"]
+
+        for idx, wav in enumerate(wavs):
+            # wav: (T)
+            wav = wav[: lens[idx]].to(device)
+
+            # pad one frame to compensate for the frame cut off after feature extraction
+            pad_tensor = torch.zeros(160, device=wav.device)
+            wav = torch.cat((wav, pad_tensor), dim=-1)
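+            # Kaldi-style feature extraction expects samples in 16-bit integer range, hence the scaling below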
+            wav *= 1 << 15
+
+            wav = wav.unsqueeze(0)  # (T) -> (1, T)
+            if feats_type == "fbank":
+                fbank_conf = self.extract_conf.get("fbank_conf", {})
+                feat = kaldi.fbank(
+                    wav,
+                    sample_frequency=16000,
+                    num_mel_bins=fbank_conf["num_mel_bins"],
+                    frame_length=fbank_conf["frame_length"],
+                    frame_shift=fbank_conf["frame_shift"],
+                    dither=fbank_conf["dither"],
+                )
+            elif feats_type == "mfcc":
+                mfcc_conf = self.extract_conf.get("mfcc", {})
+                feat = kaldi.mfcc(
+                    wav,
+                    sample_frequency=16000,
+                    num_mel_bins=mfcc_conf["num_mel_bins"],
+                    frame_length=mfcc_conf["frame_length"],
+                    frame_shift=mfcc_conf["frame_shift"],
+                    dither=mfcc_conf["dither"],
+                    num_ceps=mfcc_conf.get("num_ceps", 40),
+                    high_freq=mfcc_conf.get("high_freq", 0.0),
+                    low_freq=mfcc_conf.get("low_freq", 20.0),
+                )
+            feats_list.append(feat)
+            lengths_list.append(feat.shape[0])
+
+        feats_lengths = torch.tensor(lengths_list, dtype=torch.int32).to(device)
+        feats_tensor = pad_sequence(feats_list, batch_first=True).to(
+            device
+        )  # (batch, len, 80)
+
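+        # decoding_chunk_size=-1 with simulate_streaming=False requests full-context (non-streaming) encoding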
+        features = self.model.encoder_extractor(
+            feats_tensor,
+            feats_lengths,
+            decoding_chunk_size=-1,
+            num_decoding_left_chunks=-1,
+            simulate_streaming=False,
+        )
+        return features
+
+
+class MertExtractor(BaseExtractor):
+    def __init__(self, cfg):
+        super(MertExtractor, self).__init__(cfg)
+        self.extractor_type = "mert"
+        self.preprocessor = None
+
+    def load_model(self):
+        assert self.model is None
+        assert self.preprocessor is None
+
+        print("Loading MERT Model: ...", self.cfg.preprocess.mert_model)
+
+        local_mert_path = "/mnt/workspace/fangzihao/acce/Amphion/pretrained/MERT"
+
+        model_name = self.cfg.preprocess.mert_model
+        model = AutoModel.from_pretrained(local_mert_path, trust_remote_code=True)
+
+        if torch.cuda.is_available():
+            model = model.cuda()
+        preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(
+            model_name, trust_remote_code=True
+        )
+
+        self.model = model
+        self.preprocessor = preprocessor
+
+    def extract_content_features(self, wavs, lens):
+        """extract content features from a batch of dataloader
+        Args:
+            wavs: tensor (batch, T)
+            lens: list
+        """
+        with torch.no_grad():
+            sample_rate = self.preprocessor.sampling_rate
+            device = next(self.model.parameters()).device
+            assert (
+                sample_rate == self.cfg.preprocess.mert_sample_rate
+            ), "mert sample rate mismatch, expected {}, got {}".format(
+                self.cfg.preprocess.mert_sample_rate, sample_rate
+            )
+            mert_features = []
+            # wav: (len)
+            for wav in wavs:
+                # {input_values: tensor, attention_mask: tensor}
+                inputs = self.preprocessor(
+                    wav, sampling_rate=sample_rate, return_tensors="pt"
+                ).to(device)
+
+                outputs = self.model(**inputs, output_hidden_states=True)
+                # (25 layers, time steps, 1024 feature_dim); stacked for reference but unused below
+                all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
+                # (1, frame_len, 1024) -> (frame_len, 1024)
+                feature = outputs.hidden_states[
+                    self.cfg.preprocess.mert_feature_layer
+                ].squeeze(0)
+                mert_features.append(feature)
+
+        return mert_features
+
+
+def extract_utt_content_features_dataloader(cfg, metadata, num_workers):
+    dataset_name = metadata[0]["Dataset"]
+
+    if cfg.preprocess.extract_whisper_feature:
+        feat_dir = os.path.join(cfg.preprocess.processed_dir, dataset_name, "whisper")
+        os.makedirs(feat_dir, exist_ok=True)
+        feat_files_num = len(os.listdir(feat_dir))
+
+        if feat_files_num != len(metadata):
+            whisper_waveforms = FFmpegDataset(
+                cfg, dataset_name, cfg.preprocess.whisper_sample_rate, metadata=metadata
+            )
+            data_loader = DataLoader(
+                whisper_waveforms,
+                num_workers=num_workers,
+                shuffle=False,
+                pin_memory=cfg.preprocess.pin_memory,
+                batch_size=cfg.preprocess.content_feature_batch_size,
+                collate_fn=collate_batch,
+                drop_last=False,
+            )
+            extractor = WhisperExtractor(cfg)
+            extractor.load_model()
+            for batch_idx, items in enumerate(tqdm(data_loader)):
+                _metadata, wavs, lens = items
+
+                batch_content_features = extractor.extract_content_features(
+                    wavs,
+                    lens,
+                )
+                for index, utt in enumerate(_metadata):
+                    extractor.save_feature(utt, batch_content_features[index])
+
+    if cfg.preprocess.extract_contentvec_feature:
+        feat_dir = os.path.join(
+            cfg.preprocess.processed_dir, dataset_name, "contentvec"
+        )
+        os.makedirs(feat_dir, exist_ok=True)
+        feat_files_num = len(os.listdir(feat_dir))
+
+        if feat_files_num != len(metadata):
+            contentvec_waveforms = LibrosaDataset(
+                cfg,
+                dataset_name,
+                cfg.preprocess.contentvec_sample_rate,
+                metadata=metadata,
+            )
+            data_loader = DataLoader(
+                contentvec_waveforms,
+                num_workers=num_workers,
+                shuffle=False,
+                pin_memory=cfg.preprocess.pin_memory,
+                batch_size=cfg.preprocess.content_feature_batch_size,
+                collate_fn=collate_batch,
+                drop_last=False,
+            )
+            extractor = ContentvecExtractor(cfg)
+            extractor.load_model()
+            for batch_idx, items in enumerate(tqdm(data_loader)):
+                _metadata, wavs, lens = items
+
+                batch_content_features = extractor.extract_content_features(wavs, lens)
+                for index, utt in enumerate(_metadata):
+                    extractor.save_feature(utt, batch_content_features[index])
+
+    if cfg.preprocess.extract_wenet_feature:
+        feat_dir = os.path.join(cfg.preprocess.processed_dir, dataset_name, "wenet")
+        os.makedirs(feat_dir, exist_ok=True)
+        feat_files_num = len(os.listdir(feat_dir))
+
+        if feat_files_num != len(metadata):
+            wenet_waveforms = TorchaudioDataset(
+                cfg, dataset_name, cfg.preprocess.wenet_sample_rate, metadata=metadata
+            )
+            data_loader = DataLoader(
+                wenet_waveforms,
+                num_workers=num_workers,
+                shuffle=False,
+                pin_memory=cfg.preprocess.pin_memory,
+                batch_size=cfg.preprocess.content_feature_batch_size,
+                collate_fn=collate_batch,
+                drop_last=False,
+            )
+            extractor = WenetExtractor(cfg)
+            extractor.load_model()
+            for batch_idx, items in enumerate(tqdm(data_loader)):
+                _metadata, wavs, lens = items
+
+                batch_content_features = extractor.extract_content_features(
+                    wavs,
+                    lens,
+                )
+                for index, utt in enumerate(_metadata):
+                    extractor.save_feature(utt, batch_content_features[index])
+
+    if cfg.preprocess.extract_mert_feature:
+        feat_dir = os.path.join(cfg.preprocess.processed_dir, dataset_name, "mert")
+        os.makedirs(feat_dir, exist_ok=True)
+        feat_files_num = len(os.listdir(feat_dir))
+
+        if feat_files_num != len(metadata):
+            mert_waveforms = TorchaudioDataset(
+                cfg, dataset_name, cfg.preprocess.mert_sample_rate, metadata=metadata
+            )
+            data_loader = DataLoader(
+                mert_waveforms,
+                num_workers=num_workers,
+                shuffle=False,
+                pin_memory=cfg.preprocess.pin_memory,
+                batch_size=cfg.preprocess.content_feature_batch_size,
+                collate_fn=collate_batch,
+                drop_last=False,
+            )
+            extractor = MertExtractor(cfg)
+            extractor.load_model()
+            for batch_idx, items in enumerate(tqdm(data_loader)):
+                _metadata, wavs, lens = items
+
+                batch_content_features = extractor.extract_content_features(
+                    wavs,
+                    lens,
+                )
+                for index, utt in enumerate(_metadata):
+                    extractor.save_feature(utt, batch_content_features[index])
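+
+
+# Illustrative usage (assuming `cfg` is the loaded preprocess config and `metadata`
+# is the list of utterance dicts from the dataset's metadata json):
+#   extract_utt_content_features_dataloader(cfg, metadata, num_workers=4)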
diff --git a/processors/data_augment.py b/processors/data_augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fc183361d4bcfd454693ee0b7ffdd9758c09312
--- /dev/null
+++ b/processors/data_augment.py
@@ -0,0 +1,378 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import random
+import os
+import json
+
+import numpy as np
+import parselmouth
+import torch
+import torchaudio
+from tqdm import tqdm
+
+from audiomentations import TimeStretch
+
+from pedalboard import (
+    Pedalboard,
+    HighShelfFilter,
+    LowShelfFilter,
+    PeakFilter,
+    PitchShift,
+)
+
+from utils.util import has_existed
+
+PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT = 0.0
+PRAAT_CHANGEGENDER_FORMANTSHIFTRATIO_DEFAULT = 1.0
+PRAAT_CHANGEGENDER_PITCHSHIFTRATIO_DEFAULT = 1.0
+PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT = 1.0
+PRAAT_CHANGEGENDER_DURATIONFACTOR_DEFAULT = 1.0
+
+
+def wav_to_Sound(wav, sr: int) -> parselmouth.Sound:
+    """Convert a waveform to a parselmouth.Sound object
+
+    Args:
+        wav (np.ndarray/torch.Tensor): waveform of shape (n_channels, n_samples)
+        sr (int, optional): sampling rate.
+
+    Returns:
+        parselmouth.Sound: a parselmouth.Sound object
+    """
+    assert wav.shape == (1, len(wav[0])), "wav must be of shape (1, n_samples)"
+    sound = None
+    if isinstance(wav, np.ndarray):
+        sound = parselmouth.Sound(wav[0], sampling_frequency=sr)
+    elif isinstance(wav, torch.Tensor):
+        sound = parselmouth.Sound(wav[0].numpy(), sampling_frequency=sr)
+    assert sound is not None, "wav must be either np.ndarray or torch.Tensor"
+    return sound
+
+
+def get_pitch_median(wav, sr: int):
+    """Get the median pitch of a waveform
+
+    Args:
+        wav (np.ndarray/torch.Tensor): waveform of shape (n_channels, n_samples)
+        sr (int, optional): sampling rate.
+
+    Returns:
+        parselmouth.Pitch, float: a parselmouth.Pitch object and the median pitch
+    """
+    if not isinstance(wav, parselmouth.Sound):
+        sound = wav_to_Sound(wav, sr)
+    else:
+        sound = wav
+    pitch_median = PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT
+
+    # To Pitch: Time step(s)(standard value: 0.0), Pitch floor (Hz)(standard value: 75), Pitch ceiling (Hz)(standard value: 600.0)
+    pitch = parselmouth.praat.call(sound, "To Pitch", 0.8 / 75, 75, 600)
+    # Get quantile: From time (s), To time (s), Quantile(0.5 is then the 50% quantile, i.e., the median), Units (Hertz or Bark)
+    pitch_median = parselmouth.praat.call(pitch, "Get quantile", 0.0, 0.0, 0.5, "Hertz")
+
+    return pitch, pitch_median
+
+
+def change_gender(
+    sound,
+    pitch=None,
+    formant_shift_ratio: float = PRAAT_CHANGEGENDER_FORMANTSHIFTRATIO_DEFAULT,
+    new_pitch_median: float = PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT,
+    pitch_range_ratio: float = PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT,
+    duration_factor: float = PRAAT_CHANGEGENDER_DURATIONFACTOR_DEFAULT,
+) -> parselmouth.Sound:
+    """Invoke change gender function in praat
+
+    Args:
+        sound (parselmouth.Sound): a parselmouth.Sound object
+        pitch (parselmouth.Pitch, optional): a parselmouth.Pitch object. Defaults to None.
+        formant_shift_ratio (float, optional): formant shift ratio. A value of 1.0 means no change. Greater than 1.0 shifts formants upward; less than 1.0 shifts them downward.
+        new_pitch_median (float, optional): new pitch median.
+        pitch_range_ratio (float, optional): pitch range ratio. A value of 1.0 means no change. Greater than 1.0 means higher pitch range. Less than 1.0 means lower pitch range.
+        duration_factor (float, optional): duration factor. A value of 1.0 means no change. Greater than 1.0 means longer duration. Less than 1.0 means shorter duration.
+
+    Returns:
+        parselmouth.Sound: a parselmouth.Sound object
+    """
+    if pitch is None:
+        new_sound = parselmouth.praat.call(
+            sound,
+            "Change gender",
+            75,
+            600,
+            formant_shift_ratio,
+            new_pitch_median,
+            pitch_range_ratio,
+            duration_factor,
+        )
+    else:
+        new_sound = parselmouth.praat.call(
+            (sound, pitch),
+            "Change gender",
+            formant_shift_ratio,
+            new_pitch_median,
+            pitch_range_ratio,
+            duration_factor,
+        )
+    return new_sound
+
+
+def apply_formant_and_pitch_shift(
+    sound: parselmouth.Sound,
+    formant_shift_ratio: float = PRAAT_CHANGEGENDER_FORMANTSHIFTRATIO_DEFAULT,
+    pitch_shift_ratio: float = PRAAT_CHANGEGENDER_PITCHSHIFTRATIO_DEFAULT,
+    pitch_range_ratio: float = PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT,
+    duration_factor: float = PRAAT_CHANGEGENDER_DURATIONFACTOR_DEFAULT,
+) -> parselmouth.Sound:
+    """use Praat "Changer gender" command to manipulate pitch and formant
+    "Change gender": Praat -> Sound Object -> Convert -> Change gender
+    refer to Help of Praat for more details
+    # https://github.com/YannickJadoul/Parselmouth/issues/25#issuecomment-608632887 might help
+    """
+    pitch = None
+    new_pitch_median = PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT
+    if pitch_shift_ratio != 1.0:
+        pitch, pitch_median = get_pitch_median(sound, sound.sampling_frequency)
+        new_pitch_median = pitch_median * pitch_shift_ratio
+
+        # refer to https://github.com/praat/praat/issues/1926#issuecomment-974909408
+        pitch_minimum = parselmouth.praat.call(
+            pitch, "Get minimum", 0.0, 0.0, "Hertz", "Parabolic"
+        )
+        new_median = pitch_median * pitch_shift_ratio
+        scaled_minimum = pitch_minimum * pitch_shift_ratio
+        result_minimum = new_median + (scaled_minimum - new_median) * pitch_range_ratio
+        if result_minimum < 0:
+            new_pitch_median = PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT
+            pitch_range_ratio = PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT
+
+        if math.isnan(new_pitch_median):
+            new_pitch_median = PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT
+            pitch_range_ratio = PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT
+
+    new_sound = change_gender(
+        sound,
+        pitch,
+        formant_shift_ratio,
+        new_pitch_median,
+        pitch_range_ratio,
+        duration_factor,
+    )
+    return new_sound
+
+
+# Function used in EQ
+def pedalboard_equalizer(wav: np.ndarray, sr: int) -> np.ndarray:
+    """Use pedalboard to do equalizer"""
+    board = Pedalboard()
+
+    cutoff_low_freq = 60
+    cutoff_high_freq = 10000
+
+    q_min = 2
+    q_max = 5
+
+    random_all_freq = True
+    num_filters = 10
+    if random_all_freq:
+        key_freqs = [random.uniform(1, 12000) for _ in range(num_filters)]
+    else:
+        key_freqs = [
+            power_ratio(float(z) / (num_filters - 1), cutoff_low_freq, cutoff_high_freq)
+            for z in range(num_filters)
+        ]
+    q_values = [
+        power_ratio(random.uniform(0, 1), q_min, q_max) for _ in range(num_filters)
+    ]
+    gains = [random.uniform(-12, 12) for _ in range(num_filters)]
+    # low-shelving filter
+    board.append(
+        LowShelfFilter(
+            cutoff_frequency_hz=key_freqs[0], gain_db=gains[0], q=q_values[0]
+        )
+    )
+    # peaking filters
+    for i in range(1, 9):
+        board.append(
+            PeakFilter(
+                cutoff_frequency_hz=key_freqs[i], gain_db=gains[i], q=q_values[i]
+            )
+        )
+    # high-shelving filter
+    board.append(
+        HighShelfFilter(
+            cutoff_frequency_hz=key_freqs[9], gain_db=gains[9], q=q_values[9]
+        )
+    )
+
+    # Apply the pedalboard to the audio
+    processed_audio = board(wav, sr)
+    return processed_audio
+
+
+def power_ratio(r: float, a: float, b: float):
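+    """Geometric interpolation between a and b: a * (b / a) ** r."""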
+    return a * math.pow((b / a), r)
+
+
+def audiomentations_time_stretch(wav: np.ndarray, sr: int) -> np.ndarray:
+    """Use audiomentations to do time stretch"""
+    transform = TimeStretch(
+        min_rate=0.8, max_rate=1.25, leave_length_unchanged=False, p=1.0
+    )
+    augmented_wav = transform(wav, sample_rate=sr)
+    return augmented_wav
+
+
+def formant_and_pitch_shift(
+    sound: parselmouth.Sound, fs: bool, ps: bool
+) -> parselmouth.Sound:
+    """ """
+    formant_shift_ratio = PRAAT_CHANGEGENDER_FORMANTSHIFTRATIO_DEFAULT
+    pitch_shift_ratio = PRAAT_CHANGEGENDER_PITCHSHIFTRATIO_DEFAULT
+    pitch_range_ratio = PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT
+
+    assert fs != ps, "exactly one of fs and ps must be True"
+
+    if fs:
+        formant_shift_ratio = random.uniform(1.0, 1.4)
+        use_reciprocal = random.uniform(-1, 1) > 0
+        if use_reciprocal:
+            formant_shift_ratio = 1.0 / formant_shift_ratio
+        # only use praat to change formant
+        new_sound = apply_formant_and_pitch_shift(
+            sound,
+            formant_shift_ratio=formant_shift_ratio,
+        )
+        return new_sound
+
+    if ps:
+        board = Pedalboard()
+        board.append(PitchShift(random.uniform(-12, 12)))
+        wav_numpy = sound.values
+        wav_numpy = board(wav_numpy, sound.sampling_frequency)
+        # use pedalboard to change pitch
+        new_sound = parselmouth.Sound(
+            wav_numpy, sampling_frequency=sound.sampling_frequency
+        )
+        return new_sound
+
+
+def wav_manipulation(
+    wav: torch.Tensor,
+    sr: int,
+    aug_type: str = "None",
+    formant_shift: bool = False,
+    pitch_shift: bool = False,
+    time_stretch: bool = False,
+    equalizer: bool = False,
+) -> torch.Tensor:
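+    """Apply the selected augmentation (formant shift, pitch shift,
+    time stretch, or equalizer) to a waveform of shape (1, n_samples)
+    and return the augmented waveform.
+    """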
+    assert aug_type == "None" or aug_type in [
+        "formant_shift",
+        "pitch_shift",
+        "time_stretch",
+        "equalizer",
+    ], "aug_type must be one of formant_shift, pitch_shift, time_stretch, equalizer"
+
+    assert aug_type == "None" or (
+        formant_shift == False
+        and pitch_shift == False
+        and time_stretch == False
+        and equalizer == False
+    ), "if aug_type is specified, other argument must be False"
+
+    if aug_type != "None":
+        if aug_type == "formant_shift":
+            formant_shift = True
+        if aug_type == "pitch_shift":
+            pitch_shift = True
+        if aug_type == "equalizer":
+            equalizer = True
+        if aug_type == "time_stretch":
+            time_stretch = True
+
+    wav_numpy = wav.numpy()
+
+    if equalizer:
+        wav_numpy = pedalboard_equalizer(wav_numpy, sr)
+
+    if time_stretch:
+        wav_numpy = audiomentations_time_stretch(wav_numpy, sr)
+
+    sound = wav_to_Sound(wav_numpy, sr)
+
+    if formant_shift or pitch_shift:
+        sound = formant_and_pitch_shift(sound, formant_shift, pitch_shift)
+
+    wav = torch.from_numpy(sound.values).float()
+    # shape (1, n_samples)
+    return wav
+
+
+def augment_dataset(cfg, dataset) -> list:
+    """Augment dataset with formant_shift, pitch_shift, time_stretch, equalizer
+
+    Args:
+        cfg (dict): configuration
+        dataset (str): dataset name
+
+    Returns:
+        list: augmented dataset names
+    """
+    # load metadata
+    dataset_path = os.path.join(cfg.preprocess.processed_dir, dataset)
+    split = ["train", "test"] if "eval" not in dataset else ["test"]
+    augment_datasets = []
+    aug_types = [
+        "formant_shift" if cfg.preprocess.use_formant_shift else None,
+        "pitch_shift" if cfg.preprocess.use_pitch_shift else None,
+        "time_stretch" if cfg.preprocess.use_time_stretch else None,
+        "equalizer" if cfg.preprocess.use_equalizer else None,
+    ]
+    aug_types = filter(None, aug_types)
+    for aug_type in aug_types:
+        print("Augmenting {} with {}...".format(dataset, aug_type))
+        new_dataset = dataset + "_" + aug_type
+        augment_datasets.append(new_dataset)
+        new_dataset_path = os.path.join(cfg.preprocess.processed_dir, new_dataset)
+
+        for dataset_type in split:
+            metadata_path = os.path.join(dataset_path, "{}.json".format(dataset_type))
+            augmented_metadata = []
+            new_metadata_path = os.path.join(
+                new_dataset_path, "{}.json".format(dataset_type)
+            )
+            os.makedirs(new_dataset_path, exist_ok=True)
+            new_dataset_wav_dir = os.path.join(new_dataset_path, "wav")
+            os.makedirs(new_dataset_wav_dir, exist_ok=True)
+
+            if has_existed(new_metadata_path):
+                continue
+
+            with open(metadata_path, "r") as f:
+                metadata = json.load(f)
+
+            for utt in tqdm(metadata):
+                original_wav_path = utt["Path"]
+                original_wav, sr = torchaudio.load(original_wav_path)
+                new_wav = wav_manipulation(original_wav, sr, aug_type=aug_type)
+                new_wav_path = os.path.join(new_dataset_wav_dir, utt["Uid"] + ".wav")
+                torchaudio.save(new_wav_path, new_wav, sr)
+                new_utt = {
+                    "Dataset": utt["Dataset"] + "_" + aug_type,
+                    "index": utt["index"],
+                    "Singer": utt["Singer"],
+                    "Uid": utt["Uid"],
+                    "Path": new_wav_path,
+                    "Duration": utt["Duration"],
+                }
+                augmented_metadata.append(new_utt)
+            with open(new_metadata_path, "w") as f:
+                json.dump(augmented_metadata, f, indent=4, ensure_ascii=False)
+    return augment_datasets
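+
+
+# Illustrative usage (the dataset name is hypothetical):
+#   augmented_names = augment_dataset(cfg, "vocalist_l1")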
diff --git a/processors/phone_extractor.py b/processors/phone_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0c53a79decf6c8c8e5ee68c5d8e05d878564f6a
--- /dev/null
+++ b/processors/phone_extractor.py
@@ -0,0 +1,142 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+from tqdm import tqdm
+from text.g2p_module import G2PModule, LexiconModule
+from text.symbol_table import SymbolTable
+
+class phoneExtractor:
+    """Extract phone sequences from text."""
+
+    def __init__(self, cfg, dataset_name=None, phone_symbol_file=None):
+        """
+        Args:
+            cfg: config
+            dataset_name: name of dataset
+        """
+        self.cfg = cfg
+
+        # phone symbols dict
+        self.phone_symbols = set()
+
+        # phone symbols dict file
+        if phone_symbol_file is not None:
+            self.phone_symbols_file = phone_symbol_file
+        elif dataset_name is not None:
+            self.dataset_name = dataset_name
+            self.phone_symbols_file = os.path.join(
+                cfg.preprocess.processed_dir, dataset_name, cfg.preprocess.symbols_dict
+            )
+
+        # initialize g2p module
+        if cfg.preprocess.phone_extractor in [
+            "espeak",
+            "pypinyin",
+            "pypinyin_initials_finals",
+        ]:
+            self.g2p_module = G2PModule(backend=cfg.preprocess.phone_extractor)
+        elif cfg.preprocess.phone_extractor == "lexicon":
+            assert cfg.preprocess.lexicon_path != ""
+            self.g2p_module = LexiconModule(cfg.preprocess.lexicon_path)
+        else:
+            raise NotImplementedError(
+                "Unsupported phone extractor: {}".format(
+                    cfg.preprocess.phone_extractor
+                )
+            )
+
+    def extract_phone(self, text):
+        """Extract the phone sequence from text.
+
+        Args:
+            text: text of the utterance
+
+        Returns:
+            phone_seq: list of phones of the utterance
+        """
+        if self.cfg.preprocess.phone_extractor in [
+            "espeak",
+            "pypinyin",
+            "pypinyin_initials_finals",
+        ]:
+            text = text.replace("”", '"').replace("“", '"')
+            phone = self.g2p_module.g2p_conversion(text=text)
+            self.phone_symbols.update(phone)
+            phone_seq = [phn for phn in phone]
+
+        elif self.cfg.preprocess.phone_extractor == "lexicon":
+            phone_seq = self.g2p_module.g2p_conversion(text)
+            phone = phone_seq
+            if not isinstance(phone_seq, list):
+                phone_seq = phone_seq.split()
+
+        return phone_seq
+
+    def save_dataset_phone_symbols_to_table(self):
+        # load and merge saved phone symbols
+        if os.path.exists(self.phone_symbols_file):
+            phone_symbol_dict_saved = SymbolTable.from_file(
+                self.phone_symbols_file
+            )._sym2id.keys()
+            self.phone_symbols.update(set(phone_symbol_dict_saved))
+
+        # save phone symbols
+        phone_symbol_dict = SymbolTable()
+        for s in sorted(list(self.phone_symbols)):
+            phone_symbol_dict.add(s)
+        phone_symbol_dict.to_file(self.phone_symbols_file)
+
+
+def extract_utt_phone_sequence(cfg, metadata):
+    """Extract the phone sequence for each utterance in the metadata.
+
+    Args:
+        cfg: config
+        metadata: list of dict, each dict contains "Uid" and "Text"
+    """
+    dataset_name = cfg.dataset[0]
+
+    # output path
+    out_path = os.path.join(
+        cfg.preprocess.processed_dir, dataset_name, cfg.preprocess.phone_dir
+    )
+    os.makedirs(out_path, exist_ok=True)
+
+    phone_extractor = phoneExtractor(cfg, dataset_name)
+
+    for utt in tqdm(metadata):
+        uid = utt["Uid"]
+        text = utt["Text"]
+
+        phone_seq = phone_extractor.extract_phone(text)
+
+        phone_path = os.path.join(out_path, uid + ".phone")
+        with open(phone_path, "w") as fin:
+            fin.write(" ".join(phone_seq))
+
+    if cfg.preprocess.phone_extractor != "lexicon":
+        phone_extractor.save_dataset_phone_symbols_to_table()
+
+
+def save_all_dataset_phone_symbols_to_table(cfg, dataset):
+    # phone symbols dict
+    phone_symbols = set()
+
+    for dataset_name in dataset:
+        phone_symbols_file = os.path.join(
+            cfg.preprocess.processed_dir, dataset_name, cfg.preprocess.symbols_dict
+        )
+
+        # load and merge saved phone symbols
+        assert os.path.exists(phone_symbols_file)
+        phone_symbol_dict_saved = SymbolTable.from_file(
+            phone_symbols_file
+        )._sym2id.keys()
+        phone_symbols.update(set(phone_symbol_dict_saved))
+
+    # save all phone symbols to each dataset
+    phone_symbol_dict = SymbolTable()
+    for s in sorted(list(phone_symbols)):
+        phone_symbol_dict.add(s)
+    for dataset_name in dataset:
+        phone_symbols_file = os.path.join(
+            cfg.preprocess.processed_dir, dataset_name, cfg.preprocess.symbols_dict
+        )
+        phone_symbol_dict.to_file(phone_symbols_file)
\ No newline at end of file
diff --git a/utils/HyperParams/__init__.py b/utils/HyperParams/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..706e31b1c499f11548d08d38e1c3091aeb2dadaa
--- /dev/null
+++ b/utils/HyperParams/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .hps import HyperParams
diff --git a/utils/HyperParams/hps.py b/utils/HyperParams/hps.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc6f474c4e28d092ea8ba2cdcf7233322c5ff281
--- /dev/null
+++ b/utils/HyperParams/hps.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+class HyperParams:
+    """The class to store hyperparameters. The key is case-insensitive.
+
+    Args:
+        *args: a list of dict or HyperParams.
+        **kwargs: a list of key-value pairs.
+    """
+
+    def __init__(self, **kwargs):
+        for k, v in kwargs.items():
+            if isinstance(v, dict):
+                v = HyperParams(**v)
+            self[k] = v
+
+    def keys(self):
+        return self.__dict__.keys()
+
+    def items(self):
+        return self.__dict__.items()
+
+    def values(self):
+        return self.__dict__.values()
+
+    def __len__(self):
+        return len(self.__dict__)
+
+    def __getitem__(self, key):
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        return setattr(self, key, value)
+
+    def __contains__(self, key):
+        return key in self.__dict__
+
+    def __repr__(self):
+        return self.__dict__.__repr__()
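+
+
+# Illustrative usage (hypothetical keys):
+#   hps = HyperParams(**{"model": {"hidden_size": 256}, "train": {"lr": 1e-4}})
+#   hps.model.hidden_size   -> 256
+#   hps["train"]["lr"]      -> 0.0001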
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/utils/audio.py b/utils/audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..374d50915cafa106a3035e77b99adc96e8484f0b
--- /dev/null
+++ b/utils/audio.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import numpy as np
+from numpy import linalg as LA
+import librosa
+import soundfile as sf
+import librosa.filters
+
+
+def load_audio_torch(wave_file, fs):
+    """Load audio data into torch tensor
+
+    Args:
+        wave_file (str): path to wave file
+        fs (int): sample rate
+
+    Returns:
+        audio (tensor): audio data in tensor
+        fs (int): sample rate
+    """
+
+    audio, sample_rate = librosa.load(wave_file, sr=fs, mono=True)
+    # audio: (T,)
+    assert len(audio) > 2
+
+    # Check the audio type (for the soundfile loading backend) - float, 8-bit or 16-bit
+    if np.issubdtype(audio.dtype, np.integer):
+        max_mag = -np.iinfo(audio.dtype).min
+    else:
+        max_mag = max(np.amax(audio), -np.amin(audio))
+        max_mag = (
+            (2**31) + 1
+            if max_mag > (2**15)
+            else ((2**15) + 1 if max_mag > 1.01 else 1.0)
+        )
+
+    # Normalize the audio
+    audio = torch.FloatTensor(audio.astype(np.float32)) / max_mag
+
+    if (torch.isnan(audio) | torch.isinf(audio)).any():
+        return [], sample_rate or fs or 48000
+
+    # Resample the audio to our target samplerate
+    if fs is not None and fs != sample_rate:
+        audio = torch.from_numpy(
+            librosa.core.resample(audio.numpy(), orig_sr=sample_rate, target_sr=fs)
+        )
+        sample_rate = fs
+
+    return audio, sample_rate
+
+
+def _stft(y, cfg):
+    return librosa.stft(
+        y=y, n_fft=cfg.n_fft, hop_length=cfg.hop_size, win_length=cfg.win_size
+    )
+
+
+def energy(wav, cfg):
+    D = _stft(wav, cfg)
+    magnitudes = np.abs(D).T  # (T, F)
+    return LA.norm(magnitudes, axis=1)
+
+
+def get_energy_from_tacotron(audio, _stft):
+    audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
+    audio = torch.autograd.Variable(audio, requires_grad=False)
+    mel, energy = _stft.mel_spectrogram(audio)
+    energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
+    return mel, energy
diff --git a/utils/audio_slicer.py b/utils/audio_slicer.py
new file mode 100644
index 0000000000000000000000000000000000000000..28474596b42c8f8215b878a80967112960d0c9e0
--- /dev/null
+++ b/utils/audio_slicer.py
@@ -0,0 +1,476 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import json
+import numpy as np
+from tqdm import tqdm
+import torch
+import torchaudio
+
+from utils.io import save_audio
+from utils.audio import load_audio_torch
+
+
+# This function is obtained from librosa.
+def get_rms(
+    y,
+    *,
+    frame_length=2048,
+    hop_length=512,
+    pad_mode="constant",
+):
+    padding = (int(frame_length // 2), int(frame_length // 2))
+    y = np.pad(y, padding, mode=pad_mode)
+
+    axis = -1
+    # put our new within-frame axis at the end for now
+    out_strides = y.strides + tuple([y.strides[axis]])
+    # Reduce the shape on the framing axis
+    x_shape_trimmed = list(y.shape)
+    x_shape_trimmed[axis] -= frame_length - 1
+    out_shape = tuple(x_shape_trimmed) + tuple([frame_length])
+    xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides)
+    if axis < 0:
+        target_axis = axis - 1
+    else:
+        target_axis = axis + 1
+    xw = np.moveaxis(xw, -1, target_axis)
+    # Downsample along the target axis
+    slices = [slice(None)] * xw.ndim
+    slices[axis] = slice(0, None, hop_length)
+    x = xw[tuple(slices)]
+
+    # Calculate power
+    power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True)
+
+    return np.sqrt(power)
+
+
+class Slicer:
+    """
+    Copy from: https://github.com/openvpi/audio-slicer/blob/main/slicer2.py
+    """
+
+    def __init__(
+        self,
+        sr: int,
+        threshold: float = -40.0,
+        min_length: int = 5000,
+        min_interval: int = 300,
+        hop_size: int = 10,
+        max_sil_kept: int = 5000,
+    ):
+        if not min_length >= min_interval >= hop_size:
+            raise ValueError(
+                "The following condition must be satisfied: min_length >= min_interval >= hop_size"
+            )
+        if not max_sil_kept >= hop_size:
+            raise ValueError(
+                "The following condition must be satisfied: max_sil_kept >= hop_size"
+            )
+        min_interval = sr * min_interval / 1000
+        self.threshold = 10 ** (threshold / 20.0)
+        self.hop_size = round(sr * hop_size / 1000)
+        self.win_size = min(round(min_interval), 4 * self.hop_size)
+        self.min_length = round(sr * min_length / 1000 / self.hop_size)
+        self.min_interval = round(min_interval / self.hop_size)
+        self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
+
+    def _apply_slice(self, waveform, begin, end):
+        begin = begin * self.hop_size
+        if len(waveform.shape) > 1:
+            end = min(waveform.shape[1], end * self.hop_size)
+            return waveform[:, begin:end], begin, end
+        else:
+            end = min(waveform.shape[0], end * self.hop_size)
+            return waveform[begin:end], begin, end
+
+    # @timeit
+    def slice(self, waveform, return_chunks_positions=False):
+        if len(waveform.shape) > 1:
+            # (#channel, wave_len) -> (wave_len)
+            samples = waveform.mean(axis=0)
+        else:
+            samples = waveform
+        if samples.shape[0] <= self.min_length:
+            return [waveform]
+        rms_list = get_rms(
+            y=samples, frame_length=self.win_size, hop_length=self.hop_size
+        ).squeeze(0)
+        sil_tags = []
+        silence_start = None
+        clip_start = 0
+        for i, rms in enumerate(rms_list):
+            # Keep looping while frame is silent.
+            if rms < self.threshold:
+                # Record start of silent frames.
+                if silence_start is None:
+                    silence_start = i
+                continue
+            # Keep looping while frame is not silent and silence start has not been recorded.
+            if silence_start is None:
+                continue
+            # Clear recorded silence start if interval is not enough or clip is too short
+            is_leading_silence = silence_start == 0 and i > self.max_sil_kept
+            need_slice_middle = (
+                i - silence_start >= self.min_interval
+                and i - clip_start >= self.min_length
+            )
+            if not is_leading_silence and not need_slice_middle:
+                silence_start = None
+                continue
+            # Need slicing. Record the range of silent frames to be removed.
+            if i - silence_start <= self.max_sil_kept:
+                pos = rms_list[silence_start : i + 1].argmin() + silence_start
+                if silence_start == 0:
+                    sil_tags.append((0, pos))
+                else:
+                    sil_tags.append((pos, pos))
+                clip_start = pos
+            elif i - silence_start <= self.max_sil_kept * 2:
+                pos = rms_list[
+                    i - self.max_sil_kept : silence_start + self.max_sil_kept + 1
+                ].argmin()
+                pos += i - self.max_sil_kept
+                pos_l = (
+                    rms_list[
+                        silence_start : silence_start + self.max_sil_kept + 1
+                    ].argmin()
+                    + silence_start
+                )
+                pos_r = (
+                    rms_list[i - self.max_sil_kept : i + 1].argmin()
+                    + i
+                    - self.max_sil_kept
+                )
+                if silence_start == 0:
+                    sil_tags.append((0, pos_r))
+                    clip_start = pos_r
+                else:
+                    sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
+                    clip_start = max(pos_r, pos)
+            else:
+                pos_l = (
+                    rms_list[
+                        silence_start : silence_start + self.max_sil_kept + 1
+                    ].argmin()
+                    + silence_start
+                )
+                pos_r = (
+                    rms_list[i - self.max_sil_kept : i + 1].argmin()
+                    + i
+                    - self.max_sil_kept
+                )
+                if silence_start == 0:
+                    sil_tags.append((0, pos_r))
+                else:
+                    sil_tags.append((pos_l, pos_r))
+                clip_start = pos_r
+            silence_start = None
+        # Deal with trailing silence.
+        total_frames = rms_list.shape[0]
+        if (
+            silence_start is not None
+            and total_frames - silence_start >= self.min_interval
+        ):
+            silence_end = min(total_frames, silence_start + self.max_sil_kept)
+            pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
+            sil_tags.append((pos, total_frames + 1))
+        # Apply and return slices.
+        if len(sil_tags) == 0:
+            return [waveform]
+        else:
+            chunks = []
+            chunks_pos_of_waveform = []
+
+            if sil_tags[0][0] > 0:
+                chunk, begin, end = self._apply_slice(waveform, 0, sil_tags[0][0])
+                chunks.append(chunk)
+                chunks_pos_of_waveform.append((begin, end))
+
+            for i in range(len(sil_tags) - 1):
+                chunk, begin, end = self._apply_slice(
+                    waveform, sil_tags[i][1], sil_tags[i + 1][0]
+                )
+                chunks.append(chunk)
+                chunks_pos_of_waveform.append((begin, end))
+
+            if sil_tags[-1][1] < total_frames:
+                chunk, begin, end = self._apply_slice(
+                    waveform, sil_tags[-1][1], total_frames
+                )
+                chunks.append(chunk)
+                chunks_pos_of_waveform.append((begin, end))
+
+            return (
+                chunks
+                if not return_chunks_positions
+                else (
+                    chunks,
+                    chunks_pos_of_waveform,
+                )
+            )
+
+
+def split_utterances_from_audio(
+    wav_file,
+    output_dir,
+    max_duration_of_utterance=10.0,
+    min_interval=300,
+    db_threshold=-40,
+):
+    """
+    Split a long audio into utterances according to the silence (VAD).
+
+    max_duration_of_utterance (second):
+        The maximum duration of every utterance (seconds)
+    min_interval (millisecond):
+        The smaller min_interval is, the more sliced audio clips this script is likely to generate.
+    """
+    print("File:", wav_file.split("/")[-1])
+    waveform, fs = torchaudio.load(wav_file)
+
+    slicer = Slicer(sr=fs, min_interval=min_interval, threshold=db_threshold)
+    chunks, positions = slicer.slice(waveform, return_chunks_positions=True)
+
+    durations = [(end - begin) / fs for begin, end in positions]
+    print(
+        "Slicer's min silence part is {}ms, min and max duration of sliced utterances is {}s and {}s".format(
+            min_interval, min(durations), max(durations)
+        )
+    )
+
+    res_chunks, res_positions = [], []
+    for i, chunk in enumerate(chunks):
+        if len(chunk.shape) == 1:
+            chunk = chunk[None, :]
+
+        begin, end = positions[i]
+        assert end - begin == chunk.shape[-1]
+
+        max_wav_len = max_duration_of_utterance * fs
+        if chunk.shape[-1] <= max_wav_len:
+            res_chunks.append(chunk)
+            res_positions.append(positions[i])
+        else:
+            # TODO: to reserve overlapping and conduct fade-in, fade-out
+
+            # Get segments number
+            number = 2
+            while chunk.shape[-1] // number >= max_wav_len:
+                number += 1
+            seg_len = chunk.shape[-1] // number
+
+            # Split
+            for num in range(number):
+                s = seg_len * num
+                t = min(s + seg_len, chunk.shape[-1])
+
+                seg_begin = begin + s
+                seg_end = begin + t
+
+                res_chunks.append(chunk[:, s:t])
+                res_positions.append((seg_begin, seg_end))
+
+    # Save utterances
+    os.makedirs(output_dir, exist_ok=True)
+    res = {"fs": int(fs)}
+    for i, chunk in enumerate(res_chunks):
+        filename = "{:04d}.wav".format(i)
+        res[filename] = [int(p) for p in res_positions[i]]
+        save_audio(os.path.join(output_dir, filename), chunk, fs)
+
+    # Save positions
+    with open(os.path.join(output_dir, "positions.json"), "w") as f:
+        json.dump(res, f, indent=4, ensure_ascii=False)
+    return res
+
+
+def is_silence(
+    wavform,
+    fs,
+    threshold=-40.0,
+    min_interval=300,
+    hop_size=10,
+    min_length=5000,
+):
+    """
+    Detect whether the given waveform is silence
+
+    wavform: (T, )
+    """
+    threshold = 10 ** (threshold / 20.0)
+
+    hop_size = round(fs * hop_size / 1000)
+    win_size = min(round(min_interval), 4 * hop_size)
+    min_length = round(fs * min_length / 1000 / hop_size)
+
+    if wavform.shape[0] <= min_length:
+        return True
+
+    # (#Frame,)
+    rms_array = get_rms(y=wavform, frame_length=win_size, hop_length=hop_size).squeeze(
+        0
+    )
+    return (rms_array < threshold).all()
+
+
+def split_audio(
+    wav_file, target_sr, output_dir, max_duration_of_segment=10.0, overlap_duration=1.0
+):
+    """
+    Split a long audio into segments.
+
+    target_sr:
+        The target sampling rate to save the segments.
+    max_duration_of_segment (second):
+        The maximum duration of every segment (seconds)
+    overlap_duration (second):
+        Each segment has "overlap_duration" seconds of overlap with its previous and next segment
+    """
+    # (#channel, T) -> (T,)
+    waveform, fs = torchaudio.load(wav_file)
+    waveform = torchaudio.functional.resample(
+        waveform, orig_freq=fs, new_freq=target_sr
+    )
+    waveform = torch.mean(waveform, dim=0)
+
+    # waveform, _ = load_audio_torch(wav_file, target_sr)
+    assert len(waveform.shape) == 1
+
+    assert overlap_duration < max_duration_of_segment
+    length = int(max_duration_of_segment * target_sr)
+    stride = int((max_duration_of_segment - overlap_duration) * target_sr)
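+    # Consecutive segments overlap by `overlap_duration` seconds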
+    chunks = []
+    for i in range(0, len(waveform), stride):
+        # (length,)
+        chunks.append(waveform[i : i + length])
+        if i + length >= len(waveform):
+            break
+
+    # Save segments
+    os.makedirs(output_dir, exist_ok=True)
+    results = []
+    for i, chunk in enumerate(chunks):
+        uid = "{:04d}".format(i)
+        filename = os.path.join(output_dir, "{}.wav".format(uid))
+        results.append(
+            {"Uid": uid, "Path": filename, "Duration": len(chunk) / target_sr}
+        )
+        save_audio(
+            filename,
+            chunk,
+            target_sr,
+            turn_up=not is_silence(chunk, target_sr),
+            add_silence=False,
+        )
+
+    return results
+
+
+def merge_segments_torchaudio(wav_files, fs, output_path, overlap_duration=1.0):
+    """Merge the given wav_files (may have overlaps) into a long audio
+
+    fs:
+        The sampling rate of the wav files.
+    output_path:
+        The output path to save the merged audio.
+    overlap_duration (float, optional):
+        Each segment has "overlap duration" (second) overlap with its previous and next segment. Defaults to 1.0.
+    """
+
+    waveforms = []
+    for file in wav_files:
+        # (T,)
+        waveform, _ = load_audio_torch(file, fs)
+        waveforms.append(waveform)
+
+    if len(waveforms) == 1:
+        save_audio(output_path, waveforms[0], fs, add_silence=False, turn_up=False)
+        return
+
+    overlap_len = int(overlap_duration * fs)
+    fade_out = torchaudio.transforms.Fade(fade_out_len=overlap_len)
+    fade_in = torchaudio.transforms.Fade(fade_in_len=overlap_len)
+    fade_in_and_out = torchaudio.transforms.Fade(
+        fade_in_len=overlap_len, fade_out_len=overlap_len
+    )
+
+    segments_lens = [len(wav) for wav in waveforms]
+    merged_waveform_len = sum(segments_lens) - overlap_len * (len(waveforms) - 1)
+    merged_waveform = torch.zeros(merged_waveform_len)
+
+    start = 0
+    for index, wav in enumerate(
+        tqdm(waveforms, desc="Merge for {}".format(output_path))
+    ):
+        wav_len = len(wav)
+
+        if index == 0:
+            wav = fade_out(wav)
+        elif index == len(waveforms) - 1:
+            wav = fade_in(wav)
+        else:
+            wav = fade_in_and_out(wav)
+
+        merged_waveform[start : start + wav_len] = wav
+        start += wav_len - overlap_len
+
+    save_audio(output_path, merged_waveform, fs, add_silence=False, turn_up=True)
+
+
+def merge_segments_encodec(wav_files, fs, output_path, overlap_duration=1.0):
+    """Merge the given wav_files (may have overlaps) into a long audio
+
+    fs:
+        The sampling rate of the wav files.
+    output_path:
+        The output path to save the merged audio.
+    overlap_duration (float, optional):
+        Each segment has "overlap duration" (second) overlap with its previous and next segment. Defaults to 1.0.
+    """
+
+    waveforms = []
+    for file in wav_files:
+        # (T,)
+        waveform, _ = load_audio_torch(file, fs)
+        waveforms.append(waveform)
+
+    if len(waveforms) == 1:
+        save_audio(output_path, waveforms[0], fs, add_silence=False, turn_up=False)
+        return
+
+    device = waveforms[0].device
+    dtype = waveforms[0].dtype
+    shape = waveforms[0].shape[:-1]
+
+    overlap_len = int(overlap_duration * fs)
+    segments_lens = [len(wav) for wav in waveforms]
+    merged_waveform_len = sum(segments_lens) - overlap_len * (len(waveforms) - 1)
+
+    sum_weight = torch.zeros(merged_waveform_len, device=device, dtype=dtype)
+    out = torch.zeros(*shape, merged_waveform_len, device=device, dtype=dtype)
+    offset = 0
+
+    for frame in waveforms:
+        frame_length = frame.size(-1)
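+        # Triangular weights (~0 at the edges, 0.5 at the center) give a linear cross-fade for overlap-add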
+        t = torch.linspace(0, 1, frame_length + 2, device=device, dtype=torch.float32)[
+            1:-1
+        ]
+        weight = 0.5 - (t - 0.5).abs()
+        weighted_frame = frame * weight
+
+        cur = out[..., offset : offset + frame_length]
+        cur += weighted_frame[..., : cur.size(-1)]
+        out[..., offset : offset + frame_length] = cur
+
+        cur = sum_weight[offset : offset + frame_length]
+        cur += weight[..., : cur.size(-1)]
+        sum_weight[offset : offset + frame_length] = cur
+
+        offset += frame_length - overlap_len
+
+    assert sum_weight.min() > 0
+    merged_waveform = out / sum_weight
+    save_audio(output_path, merged_waveform, fs, add_silence=False, turn_up=True)
diff --git a/utils/data_utils.py b/utils/data_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7976d050f01990c8a98a37d48be67bc68695a6c3
--- /dev/null
+++ b/utils/data_utils.py
@@ -0,0 +1,575 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+
+import numpy as np
+from scipy.interpolate import interp1d
+from tqdm import tqdm
+from sklearn.preprocessing import StandardScaler
+
+
+def load_content_feature_path(meta_data, processed_dir, feat_dir):
+    utt2feat_path = {}
+    for utt_info in meta_data:
+        utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
+        feat_path = os.path.join(
+            processed_dir, utt_info["Dataset"], feat_dir, f'{utt_info["Uid"]}.npy'
+        )
+        utt2feat_path[utt] = feat_path
+
+    return utt2feat_path
+
+
+def load_source_content_feature_path(meta_data, feat_dir):
+    utt2feat_path = {}
+    for utt in meta_data:
+        feat_path = os.path.join(feat_dir, f"{utt}.npy")
+        utt2feat_path[utt] = feat_path
+
+    return utt2feat_path
+
+
+def get_spk_map(spk2id_path, utt2spk_path):
+    utt2spk = {}
+    with open(spk2id_path, "r") as spk2id_file:
+        spk2id = json.load(spk2id_file)
+    with open(utt2spk_path, encoding="utf-8") as f:
+        for line in f.readlines():
+            utt, spk = line.strip().split("\t")
+            utt2spk[utt] = spk
+    return spk2id, utt2spk
+
+
+def get_target_f0_median(f0_dir):
+    total_f0 = []
+    for utt in os.listdir(f0_dir):
+        if not utt.endswith(".npy"):
+            continue
+        f0_feat_path = os.path.join(f0_dir, utt)
+        f0 = np.load(f0_feat_path)
+        total_f0 += f0.tolist()
+
+    total_f0 = np.array(total_f0)
+    voiced_position = np.where(total_f0 != 0)
+    return np.median(total_f0[voiced_position])
+
+
+def get_conversion_f0_factor(source_f0, target_median, source_median=None):
+    """Align the median between source f0 and target f0
+
+    Note: Here we use multiplication, whose factor is target_median/source_median
+
+    Reference: Frequency and pitch interval
+    http://blog.ccyg.studio/article/be12c2ee-d47c-4098-9782-ca76da3035e4/
+    """
+    if source_median is None:
+        voiced_position = np.where(source_f0 != 0)
+        source_median = np.median(source_f0[voiced_position])
+    factor = target_median / source_median
+    return source_median, factor
+
+
+def transpose_key(frame_pitch, trans_key):
+    # Transpose by user's argument
+    print("Transpose key = {} ...\n".format(trans_key))
+
+    transed_pitch = frame_pitch * 2 ** (trans_key / 12)
+    return transed_pitch
+
+
+def pitch_shift_to_target(frame_pitch, target_pitch_median, source_pitch_median=None):
+    # Loading F0 Base (median) and shift
+    source_pitch_median, factor = get_conversion_f0_factor(
+        frame_pitch, target_pitch_median, source_pitch_median
+    )
+    print(
+        "Auto transposing: source f0 median = {:.1f}, target f0 median = {:.1f}, factor = {:.2f}".format(
+            source_pitch_median, target_pitch_median, factor
+        )
+    )
+    transed_pitch = frame_pitch * factor
+    return transed_pitch
+
+
+def load_frame_pitch(
+    meta_data,
+    processed_dir,
+    pitch_dir,
+    use_log_scale=False,
+    return_norm=False,
+    interoperate=False,
+    utt2spk=None,
+):
+    utt2pitch = {}
+    utt2uv = {}
+    if utt2spk is None:
+        pitch_scaler = StandardScaler()
+        for utt_info in meta_data:
+            utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
+            pitch_path = os.path.join(
+                processed_dir, utt_info["Dataset"], pitch_dir, f'{utt_info["Uid"]}.npy'
+            )
+            pitch = np.load(pitch_path)
+            assert len(pitch) > 0
+            uv = pitch != 0
+            utt2uv[utt] = uv
+            if use_log_scale:
+                nonzero_idxes = np.where(pitch != 0)[0]
+                pitch[nonzero_idxes] = np.log(pitch[nonzero_idxes])
+            utt2pitch[utt] = pitch
+            pitch_scaler.partial_fit(pitch.reshape(-1, 1))
+
+        mean, std = pitch_scaler.mean_[0], pitch_scaler.scale_[0]
+        if return_norm:
+            for utt_info in meta_data:
+                utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
+                pitch = utt2pitch[utt]
+                normalized_pitch = (pitch - mean) / std
+                utt2pitch[utt] = normalized_pitch
+        pitch_statistic = {"mean": mean, "std": std}
+    else:
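+        # When a speaker mapping is given, pitch statistics and normalization are computed per speaker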
+        spk2utt = {}
+        pitch_statistic = []
+        for utt_info in meta_data:
+            utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
+            if utt2spk[utt] not in spk2utt:
+                spk2utt[utt2spk[utt]] = []
+            spk2utt[utt2spk[utt]].append(utt)
+
+        for spk in spk2utt:
+            pitch_scaler = StandardScaler()
+            for utt in spk2utt[spk]:
+                dataset = utt.split("_")[0]
+                uid = "_".join(utt.split("_")[1:])
+                pitch_path = os.path.join(
+                    processed_dir, dataset, pitch_dir, f"{uid}.npy"
+                )
+                pitch = np.load(pitch_path)
+                assert len(pitch) > 0
+                uv = pitch != 0
+                utt2uv[utt] = uv
+                if use_log_scale:
+                    nonzero_idxes = np.where(pitch != 0)[0]
+                    pitch[nonzero_idxes] = np.log(pitch[nonzero_idxes])
+                utt2pitch[utt] = pitch
+                pitch_scaler.partial_fit(pitch.reshape(-1, 1))
+
+            mean, std = pitch_scaler.mean_[0], pitch_scaler.scale_[0]
+            if return_norm:
+                for utt in spk2utt[spk]:
+                    pitch = utt2pitch[utt]
+                    normalized_pitch = (pitch - mean) / std
+                    utt2pitch[utt] = normalized_pitch
+            pitch_statistic.append({"spk": spk, "mean": mean, "std": std})
+
+    return utt2pitch, utt2uv, pitch_statistic
+
+
+# discard
+def load_phone_pitch(
+    meta_data,
+    processed_dir,
+    pitch_dir,
+    utt2dur,
+    use_log_scale=False,
+    return_norm=False,
+    interoperate=True,
+    utt2spk=None,
+):
+    print("Load Phone Pitch")
+    utt2pitch = {}
+    utt2uv = {}
+    if utt2spk is None:
+        pitch_scaler = StandardScaler()
+        for utt_info in tqdm(meta_data):
+            utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
+            pitch_path = os.path.join(
+                processed_dir, utt_info["Dataset"], pitch_dir, f'{utt_info["Uid"]}.npy'
+            )
+            frame_pitch = np.load(pitch_path)
+            assert len(frame_pitch) > 0
+            uv = frame_pitch != 0
+            utt2uv[utt] = uv
+            phone_pitch = phone_average_pitch(frame_pitch, utt2dur[utt], interoperate)
+            if use_log_scale:
+                nonzero_idxes = np.where(phone_pitch != 0)[0]
+                phone_pitch[nonzero_idxes] = np.log(phone_pitch[nonzero_idxes])
+            utt2pitch[utt] = phone_pitch
+            pitch_scaler.partial_fit(remove_outlier(phone_pitch).reshape(-1, 1))
+
+        mean, std = pitch_scaler.mean_[0], pitch_scaler.scale_[0]
+        max_value = np.finfo(np.float64).min
+        min_value = np.finfo(np.float64).max
+        if return_norm:
+            for utt_info in meta_data:
+                utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
+                pitch = utt2pitch[utt]
+                normalized_pitch = (pitch - mean) / std
+                max_value = max(max_value, max(normalized_pitch))
+                min_value = min(min_value, min(normalized_pitch))
+                utt2pitch[utt] = normalized_pitch
+                phone_normalized_pitch_path = os.path.join(
+                    processed_dir,
+                    utt_info["Dataset"],
+                    "phone_level_" + pitch_dir,
+                    f'{utt_info["Uid"]}.npy',
+                )
+        pitch_statistic = {
+            "mean": mean,
+            "std": std,
+            "min_value": min_value,
+            "max_value": max_value,
+        }
+    else:
+        spk2utt = {}
+        pitch_statistic = []
+        for utt_info in tqdm(meta_data):
+            utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
+            if utt2spk[utt] not in spk2utt:
+                spk2utt[utt2spk[utt]] = []
+            spk2utt[utt2spk[utt]].append(utt)
+
+        for spk in spk2utt:
+            pitch_scaler = StandardScaler()
+            for utt in spk2utt[spk]:
+                dataset = utt.split("_")[0]
+                uid = "_".join(utt.split("_")[1:])
+                pitch_path = os.path.join(
+                    processed_dir, dataset, pitch_dir, f"{uid}.npy"
+                )
+                frame_pitch = np.load(pitch_path)
+                assert len(frame_pitch) > 0
+                uv = frame_pitch != 0
+                utt2uv[utt] = uv
+                phone_pitch = phone_average_pitch(
+                    frame_pitch, utt2dur[utt], interoperate
+                )
+                if use_log_scale:
+                    nonzero_idxes = np.where(phone_pitch != 0)[0]
+                    phone_pitch[nonzero_idxes] = np.log(phone_pitch[nonzero_idxes])
+                utt2pitch[utt] = phone_pitch
+                pitch_scaler.partial_fit(remove_outlier(phone_pitch).reshape(-1, 1))
+
+            mean, std = pitch_scaler.mean_[0], pitch_scaler.scale_[0]
+            max_value = np.finfo(np.float64).min
+            min_value = np.finfo(np.float64).max
+
+            if return_norm:
+                for utt in spk2utt[spk]:
+                    pitch = utt2pitch[utt]
+                    normalized_pitch = (pitch - mean) / std
+                    max_value = max(max_value, max(normalized_pitch))
+                    min_value = min(min_value, min(normalized_pitch))
+                    utt2pitch[utt] = normalized_pitch
+            pitch_statistic.append(
+                {
+                    "spk": spk,
+                    "mean": mean,
+                    "std": std,
+                    "min_value": min_value,
+                    "max_value": max_value,
+                }
+            )
+
+    return utt2pitch, utt2uv, pitch_statistic
+
+
+def phone_average_pitch(pitch, dur, interoperate=False):
+    pos = 0
+
+    if interoperate:
+        nonzero_ids = np.where(pitch != 0)[0]
+        interp_fn = interp1d(
+            nonzero_ids,
+            pitch[nonzero_ids],
+            fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]),
+            bounds_error=False,
+        )
+        pitch = interp_fn(np.arange(0, len(pitch)))
+    phone_pitch = np.zeros(len(dur))
+
+    for i, d in enumerate(dur):
+        d = int(d)
+        if d > 0 and pos < len(pitch):
+            phone_pitch[i] = np.mean(pitch[pos : pos + d])
+        else:
+            phone_pitch[i] = 0
+        pos += d
+    return phone_pitch
+
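+
+# A minimal usage sketch: averaging a made-up frame-level F0 contour into
+# phone-level values, given hypothetical per-phone durations (in frames).
+def _demo_phone_average_pitch():
+    frame_pitch = np.array([100.0, 110.0, 120.0, 0.0, 0.0, 200.0])
+    durations = [3, 2, 1]  # three phones covering 3 + 2 + 1 frames
+    phone_pitch = phone_average_pitch(frame_pitch, durations)
+    assert np.allclose(phone_pitch, [110.0, 0.0, 200.0])
+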
+
+def load_energy(
+    meta_data,
+    processed_dir,
+    energy_dir,
+    use_log_scale=False,
+    return_norm=False,
+    utt2spk=None,
+):
+    utt2energy = {}
+    if utt2spk is None:
+        for utt_info in meta_data:
+            utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
+            energy_path = os.path.join(
+                processed_dir, utt_info["Dataset"], energy_dir, f'{utt_info["Uid"]}.npy'
+            )
+            if not os.path.exists(energy_path):
+                continue
+            energy = np.load(energy_path)
+            assert len(energy) > 0
+
+            if use_log_scale:
+                nonzero_idxes = np.where(energy != 0)[0]
+                energy[nonzero_idxes] = np.log(energy[nonzero_idxes])
+            utt2energy[utt] = energy
+
+        # Load the pre-computed statistics so that mean/std are always defined,
+        # and read both of them from the same dataset/singer entry.
+        with open(
+            os.path.join(
+                processed_dir, utt_info["Dataset"], energy_dir, "statistics.json"
+            )
+        ) as f:
+            stats = json.load(f)
+            stats_key = utt_info["Dataset"] + "_" + utt_info["Singer"]
+            mean = stats[stats_key]["voiced_positions"]["mean"]
+            std = stats[stats_key]["voiced_positions"]["std"]
+
+        if return_norm:
+            for utt in utt2energy.keys():
+                energy = utt2energy[utt]
+                normalized_energy = (energy - mean) / std
+                utt2energy[utt] = normalized_energy
+
+        energy_statistic = {"mean": mean, "std": std}
+    else:
+        spk2utt = {}
+        energy_statistic = []
+        for utt_info in meta_data:
+            utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
+            if utt2spk[utt] not in spk2utt:
+                spk2utt[utt2spk[utt]] = []
+            spk2utt[utt2spk[utt]].append(utt)
+
+        for spk in spk2utt:
+            energy_scaler = StandardScaler()
+            for utt in spk2utt[spk]:
+                dataset = utt.split("_")[0]
+                uid = "_".join(utt.split("_")[1:])
+                energy_path = os.path.join(
+                    processed_dir, dataset, energy_dir, f"{uid}.npy"
+                )
+                if not os.path.exists(energy_path):
+                    continue
+                frame_energy = np.load(energy_path)
+                assert len(frame_energy) > 0
+
+                if use_log_scale:
+                    nonzero_idxes = np.where(frame_energy != 0)[0]
+                    frame_energy[nonzero_idxes] = np.log(frame_energy[nonzero_idxes])
+                utt2energy[utt] = frame_energy
+                energy_scaler.partial_fit(frame_energy.reshape(-1, 1))
+
+            mean, std = energy_scaler.mean_[0], energy_scaler.scale_[0]
+            if return_norm:
+                for utt in spk2utt[spk]:
+                    energy = utt2energy[utt]
+                    normalized_energy = (energy - mean) / std
+                    utt2energy[utt] = normalized_energy
+            energy_statistic.append({"spk": spk, "mean": mean, "std": std})
+
+    return utt2energy, energy_statistic
+
+
+def load_frame_energy(
+    meta_data,
+    processed_dir,
+    energy_dir,
+    use_log_scale=False,
+    return_norm=False,
+    interoperate=False,
+    utt2spk=None,
+):
+    utt2energy = {}
+    if utt2spk is None:
+        energy_scaler = StandardScaler()
+        for utt_info in meta_data:
+            utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
+            energy_path = os.path.join(
+                processed_dir, utt_info["Dataset"], energy_dir, f'{utt_info["Uid"]}.npy'
+            )
+            frame_energy = np.load(energy_path)
+            assert len(frame_energy) > 0
+
+            if use_log_scale:
+                nonzero_idxes = np.where(frame_energy != 0)[0]
+                frame_energy[nonzero_idxes] = np.log(frame_energy[nonzero_idxes])
+            utt2energy[utt] = frame_energy
+            energy_scaler.partial_fit(frame_energy.reshape(-1, 1))
+
+        mean, std = energy_scaler.mean_[0], energy_scaler.scale_[0]
+        if return_norm:
+            for utt_info in meta_data:
+                utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
+                energy = utt2energy[utt]
+                normalized_energy = (energy - mean) / std
+                utt2energy[utt] = normalized_energy
+        energy_statistic = {"mean": mean, "std": std}
+
+    else:
+        spk2utt = {}
+        energy_statistic = []
+        for utt_info in meta_data:
+            utt = utt_info["Dataset"] + "_" + utt_info["Uid"]
+            if utt2spk[utt] not in spk2utt:
+                spk2utt[utt2spk[utt]] = []
+            spk2utt[utt2spk[utt]].append(utt)
+
+        for spk in spk2utt:
+            energy_scaler = StandardScaler()
+            for utt in spk2utt[spk]:
+                dataset = utt.split("_")[0]
+                uid = "_".join(utt.split("_")[1:])
+                energy_path = os.path.join(
+                    processed_dir, dataset, energy_dir, f"{uid}.npy"
+                )
+                frame_energy = np.load(energy_path)
+                assert len(frame_energy) > 0
+
+                if use_log_scale:
+                    nonzero_idxes = np.where(frame_energy != 0)[0]
+                    frame_energy[nonzero_idxes] = np.log(frame_energy[nonzero_idxes])
+                utt2energy[utt] = frame_energy
+                energy_scaler.partial_fit(frame_energy.reshape(-1, 1))
+
+            mean, std = energy_scaler.mean_[0], energy_scaler.scale_[0]
+            if return_norm:
+                for utt in spk2utt[spk]:
+                    energy = utt2energy[utt]
+                    normalized_energy = (energy - mean) / std
+                    utt2energy[utt] = normalized_energy
+            energy_statistic.append({"spk": spk, "mean": mean, "std": std})
+
+    return utt2energy, energy_statistic
+
+
+def align_length(feature, target_len, pad_value=0.0):
+    feature_len = feature.shape[-1]
+    dim = len(feature.shape)
+    # align 2-D data
+    if dim == 2:
+        if target_len > feature_len:
+            feature = np.pad(
+                feature,
+                ((0, 0), (0, target_len - feature_len)),
+                constant_values=pad_value,
+            )
+        else:
+            feature = feature[:, :target_len]
+    # align 1-D data
+    elif dim == 1:
+        if target_len > feature_len:
+            feature = np.pad(
+                feature, (0, target_len - feature_len), constant_values=pad_value
+            )
+        else:
+            feature = feature[:target_len]
+    else:
+        raise NotImplementedError
+    return feature
+
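+
+# A minimal usage sketch: align_length pads with pad_value or truncates along
+# the last axis; the shapes below are made up.
+def _demo_align_length():
+    mel = np.ones((80, 95))  # (n_mels, frame_len)
+    assert align_length(mel, 100).shape == (80, 100)  # pad 5 frames of zeros
+    assert align_length(mel, 90).shape == (80, 90)  # keep the first 90 frames
+    f0 = np.ones(95)  # 1-D features are handled the same way
+    assert align_length(f0, 100).shape == (100,)
+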
+
+def align_whisper_feauture_length(
+    feature, target_len, fast_mapping=True, source_hop=320, target_hop=256
+):
+    factor = np.gcd(source_hop, target_hop)
+    source_hop //= factor
+    target_hop //= factor
+    # print(
+    #     "Mapping source's {} frames => target's {} frames".format(
+    #         target_hop, source_hop
+    #     )
+    # )
+
+    max_source_len = 1500
+    target_len = min(target_len, max_source_len * source_hop // target_hop)
+
+    width = feature.shape[-1]
+
+    if fast_mapping:
+        source_len = target_len * target_hop // source_hop + 1
+        feature = feature[:source_len]
+
+    else:
+        source_len = max_source_len
+
+    # const ~= target_len * target_hop
+    const = source_len * source_hop // target_hop * target_hop
+
+    # (source_len * source_hop, dim)
+    up_sampling_feats = np.repeat(feature, source_hop, axis=0)
+    # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim)
+    down_sampling_feats = np.average(
+        up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1
+    )
+    assert len(down_sampling_feats) >= target_len
+
+    # (target_len, dim)
+    feat = down_sampling_feats[:target_len]
+
+    return feat
+
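+
+# A minimal usage sketch: mapping content features extracted with a 320-sample
+# hop onto a mel spectrogram with a 256-sample hop. The feature dimension and
+# frame counts below are made up.
+def _demo_align_whisper_feature():
+    whisper_feats = np.random.randn(1500, 1024)
+    target_len = 1000  # hypothetical number of mel frames
+    aligned = align_whisper_feauture_length(whisper_feats, target_len)
+    assert aligned.shape == (target_len, 1024)
+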
+
+def align_content_feature_length(feature, target_len, source_hop=320, target_hop=256):
+    factor = np.gcd(source_hop, target_hop)
+    source_hop //= factor
+    target_hop //= factor
+    # print(
+    #     "Mapping source's {} frames => target's {} frames".format(
+    #         target_hop, source_hop
+    #     )
+    # )
+
+    # (source_len, 256)
+    source_len, width = feature.shape
+
+    # const ~= target_len * target_hop
+    const = source_len * source_hop // target_hop * target_hop
+
+    # (source_len * source_hop, dim)
+    up_sampling_feats = np.repeat(feature, source_hop, axis=0)
+    # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim)
+    down_sampling_feats = np.average(
+        up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1
+    )
+
+    err = abs(target_len - len(down_sampling_feats))
+    if err > 4:  # tolerate small rounding mismatches between hop sizes
+        print("target_len:", target_len)
+        print("raw feature:", feature.shape)
+        print("up_sampling:", up_sampling_feats.shape)
+        print("down_sampling_feats:", down_sampling_feats.shape)
+        raise ValueError("Content feature length mismatches the target length")
+    if len(down_sampling_feats) < target_len:
+        # (1, dim) -> (err, dim)
+        end = down_sampling_feats[-1][None, :].repeat(err, axis=0)
+        down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0)
+
+    # (target_len, dim)
+    feat = down_sampling_feats[:target_len]
+
+    return feat
+
+
+def remove_outlier(values):
+    values = np.array(values)
+    p25 = np.percentile(values, 25)
+    p75 = np.percentile(values, 75)
+    lower = p25 - 1.5 * (p75 - p25)
+    upper = p75 + 1.5 * (p75 - p25)
+    normal_indices = np.logical_and(values > lower, values < upper)
+    return values[normal_indices]
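+
+
+# A minimal usage sketch: remove_outlier keeps values within 1.5 IQR of the
+# 25th/75th percentiles; the numbers below are made up.
+def _demo_remove_outlier():
+    values = np.array([10.0, 11.0, 12.0, 13.0, 14.0, 100.0])
+    kept = remove_outlier(values)
+    assert 100.0 not in kept and len(kept) == 5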
diff --git a/utils/distribution.py b/utils/distribution.py
new file mode 100644
index 0000000000000000000000000000000000000000..de3000e99194f7e848712d8b4cb77c988f098fd2
--- /dev/null
+++ b/utils/distribution.py
@@ -0,0 +1,270 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from torch.distributions import Normal
+
+
+def log_sum_exp(x):
+    """numerically stable log_sum_exp implementation that prevents overflow"""
+    # TF ordering
+    axis = len(x.size()) - 1
+    m, _ = torch.max(x, dim=axis)
+    m2, _ = torch.max(x, dim=axis, keepdim=True)
+    return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))
+
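+
+# A minimal usage sketch: log_sum_exp over the last axis should match
+# torch.logsumexp; the tensor below is random.
+def _demo_log_sum_exp():
+    x = torch.randn(2, 3, 4)
+    assert torch.allclose(log_sum_exp(x), torch.logsumexp(x, dim=-1), atol=1e-6)
+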
+
+def discretized_mix_logistic_loss(
+    y_hat, y, num_classes=256, log_scale_min=-7.0, reduce=True
+):
+    """Discretized mixture of logistic distributions loss
+
+    Note that it is assumed that input is scaled to [-1, 1].
+
+    Args:
+        y_hat (Tensor): Predicted output (B x C x T)
+        y (Tensor): Target (B x T x 1).
+        num_classes (int): Number of classes
+        log_scale_min (float): Log scale minimum value
+        reduce (bool): If True, the losses are averaged or summed for each
+          minibatch.
+
+    Returns
+        Tensor: loss
+    """
+    assert y_hat.dim() == 3
+    assert y_hat.size(1) % 3 == 0
+    nr_mix = y_hat.size(1) // 3
+
+    # (B x T x C)
+    y_hat = y_hat.transpose(1, 2)
+
+    # unpack parameters. (B, T, num_mixtures) x 3
+    logit_probs = y_hat[:, :, :nr_mix]
+    means = y_hat[:, :, nr_mix : 2 * nr_mix]
+    log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min)
+
+    # B x T x 1 -> B x T x num_mixtures
+    y = y.expand_as(means)
+
+    centered_y = y - means
+    inv_stdv = torch.exp(-log_scales)
+    plus_in = inv_stdv * (centered_y + 1.0 / (num_classes - 1))
+    cdf_plus = torch.sigmoid(plus_in)
+    min_in = inv_stdv * (centered_y - 1.0 / (num_classes - 1))
+    cdf_min = torch.sigmoid(min_in)
+
+    # log probability for edge case of 0 (before scaling)
+    # equivalent: torch.log(torch.sigmoid(plus_in))
+    log_cdf_plus = plus_in - F.softplus(plus_in)
+
+    # log probability for edge case of 255 (before scaling)
+    # equivalent: (1 - torch.sigmoid(min_in)).log()
+    log_one_minus_cdf_min = -F.softplus(min_in)
+
+    # probability for all other cases
+    cdf_delta = cdf_plus - cdf_min
+
+    mid_in = inv_stdv * centered_y
+    # log probability in the center of the bin, to be used in extreme cases
+    # (not actually used in our code)
+    log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in)
+
+    # tf equivalent
+    """
+    log_probs = tf.where(x < -0.999, log_cdf_plus,
+                         tf.where(x > 0.999, log_one_minus_cdf_min,
+                                  tf.where(cdf_delta > 1e-5,
+                                           tf.log(tf.maximum(cdf_delta, 1e-12)),
+                                           log_pdf_mid - np.log(127.5))))
+    """
+    # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
+    # for num_classes=65536 case? 1e-7? not sure..
+    inner_inner_cond = (cdf_delta > 1e-5).float()
+
+    inner_inner_out = inner_inner_cond * torch.log(
+        torch.clamp(cdf_delta, min=1e-12)
+    ) + (1.0 - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2))
+    inner_cond = (y > 0.999).float()
+    inner_out = (
+        inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out
+    )
+    cond = (y < -0.999).float()
+    log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out
+
+    log_probs = log_probs + F.log_softmax(logit_probs, -1)
+
+    if reduce:
+        return -torch.sum(log_sum_exp(log_probs))
+    else:
+        return -log_sum_exp(log_probs).unsqueeze(-1)
+
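+
+# A minimal usage sketch: calling the loss with random tensors. With 10 mixture
+# components the channel axis must hold 3 * 10 = 30 values per frame (logits,
+# means, log scales); all sizes below are made up.
+def _demo_discretized_mix_logistic_loss():
+    B, T, nr_mix = 2, 100, 10
+    y_hat = torch.randn(B, 3 * nr_mix, T)
+    y = torch.empty(B, T, 1).uniform_(-1.0, 1.0)
+    loss = discretized_mix_logistic_loss(y_hat, y, num_classes=256)
+    assert loss.dim() == 0  # reduce=True returns a scalar
+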
+
+def to_one_hot(tensor, n, fill_with=1.0):
+    # perform one-hot encoding with respect to the last axis
+    one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_()
+    if tensor.is_cuda:
+        one_hot = one_hot.cuda()
+    one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with)
+    return one_hot
+
+
+def sample_from_discretized_mix_logistic(y, log_scale_min=-7.0, clamp_log_scale=False):
+    """
+    Sample from discretized mixture of logistic distributions
+
+    Args:
+        y (Tensor): B x C x T
+        log_scale_min (float): Log scale minimum value
+
+    Returns:
+        Tensor: sample in range of [-1, 1].
+    """
+    assert y.size(1) % 3 == 0
+    nr_mix = y.size(1) // 3
+
+    # B x T x C
+    y = y.transpose(1, 2)
+    logit_probs = y[:, :, :nr_mix]
+
+    # sample mixture indicator from softmax
+    temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
+    temp = logit_probs.data - torch.log(-torch.log(temp))
+    _, argmax = temp.max(dim=-1)
+
+    # (B, T) -> (B, T, nr_mix)
+    one_hot = to_one_hot(argmax, nr_mix)
+    # select logistic parameters
+    means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1)
+    log_scales = torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1)
+    if clamp_log_scale:
+        log_scales = torch.clamp(log_scales, min=log_scale_min)
+    # sample from logistic & clip to interval
+    # we don't actually round to the nearest 8bit value when sampling
+    u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5)
+    x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1.0 - u))
+
+    x = torch.clamp(torch.clamp(x, min=-1.0), max=1.0)
+
+    return x
+
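+
+# A minimal usage sketch: drawing one sample per frame from random mixture
+# parameters; the shapes below are made up.
+def _demo_sample_from_discretized_mix_logistic():
+    B, T, nr_mix = 2, 100, 10
+    y = torch.randn(B, 3 * nr_mix, T)
+    x = sample_from_discretized_mix_logistic(y)
+    assert x.shape == (B, T) and float(x.min()) >= -1.0 and float(x.max()) <= 1.0
+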
+
+# We could easily define a discretized version of the Gaussian loss; however,
+# we use the continuous version, the same as https://clarinet-demo.github.io/
+def mix_gaussian_loss(y_hat, y, log_scale_min=-7.0, reduce=True):
+    """Mixture of continuous gaussian distributions loss
+
+    Note that it is assumed that input is scaled to [-1, 1].
+
+    Args:
+        y_hat (Tensor): Predicted output (B x C x T)
+        y (Tensor): Target (B x T x 1).
+        log_scale_min (float): Log scale minimum value
+        reduce (bool): If True, the losses are averaged or summed for each
+          minibatch.
+    Returns
+        Tensor: loss
+    """
+    assert y_hat.dim() == 3
+    C = y_hat.size(1)
+    if C == 2:
+        nr_mix = 1
+    else:
+        assert y_hat.size(1) % 3 == 0
+        nr_mix = y_hat.size(1) // 3
+
+    # (B x T x C)
+    y_hat = y_hat.transpose(1, 2)
+
+    # unpack parameters.
+    if C == 2:
+        # special case for C == 2, just for compatibility
+        logit_probs = None
+        means = y_hat[:, :, 0:1]
+        log_scales = torch.clamp(y_hat[:, :, 1:2], min=log_scale_min)
+    else:
+        #  (B, T, num_mixtures) x 3
+        logit_probs = y_hat[:, :, :nr_mix]
+        means = y_hat[:, :, nr_mix : 2 * nr_mix]
+        log_scales = torch.clamp(
+            y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min
+        )
+
+    # B x T x 1 -> B x T x num_mixtures
+    y = y.expand_as(means)
+
+    centered_y = y - means
+    dist = Normal(loc=0.0, scale=torch.exp(log_scales))
+    # do we need to add a trick to avoid log(0)?
+    log_probs = dist.log_prob(centered_y)
+
+    if nr_mix > 1:
+        log_probs = log_probs + F.log_softmax(logit_probs, -1)
+
+    if reduce:
+        if nr_mix == 1:
+            return -torch.sum(log_probs)
+        else:
+            return -torch.sum(log_sum_exp(log_probs))
+    else:
+        if nr_mix == 1:
+            return -log_probs
+        else:
+            return -log_sum_exp(log_probs).unsqueeze(-1)
+
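+
+# A minimal usage sketch: the single-Gaussian case (C == 2, one mean channel
+# and one log-scale channel); the sizes below are made up.
+def _demo_mix_gaussian_loss():
+    B, T = 2, 100
+    y_hat = torch.randn(B, 2, T)
+    y = torch.empty(B, T, 1).uniform_(-1.0, 1.0)
+    loss = mix_gaussian_loss(y_hat, y)
+    assert loss.dim() == 0  # reduce=True returns a scalar
+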
+
+def sample_from_mix_gaussian(y, log_scale_min=-7.0):
+    """
+    Sample from a mixture of Gaussian distributions
+    Args:
+        y (Tensor): B x C x T
+        log_scale_min (float): Log scale minimum value
+    Returns:
+        Tensor: sample in range of [-1, 1].
+    """
+    C = y.size(1)
+    if C == 2:
+        nr_mix = 1
+    else:
+        assert y.size(1) % 3 == 0
+        nr_mix = y.size(1) // 3
+
+    # B x T x C
+    y = y.transpose(1, 2)
+
+    if C == 2:
+        logit_probs = None
+    else:
+        logit_probs = y[:, :, :nr_mix]
+
+    if nr_mix > 1:
+        # sample mixture indicator from softmax
+        temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
+        temp = logit_probs.data - torch.log(-torch.log(temp))
+        _, argmax = temp.max(dim=-1)
+
+        # (B, T) -> (B, T, nr_mix)
+        one_hot = to_one_hot(argmax, nr_mix)
+
+        # Select means and log scales
+        means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1)
+        log_scales = torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1)
+    else:
+        if C == 2:
+            means, log_scales = y[:, :, 0], y[:, :, 1]
+        elif C == 3:
+            means, log_scales = y[:, :, 1], y[:, :, 2]
+        else:
+            assert False, "shouldn't happen"
+
+    scales = torch.exp(log_scales)
+    dist = Normal(loc=means, scale=scales)
+    x = dist.sample()
+
+    x = torch.clamp(x, min=-1.0, max=1.0)
+    return x
diff --git a/utils/dsp.py b/utils/dsp.py
new file mode 100644
index 0000000000000000000000000000000000000000..18f9466f6b12e5539ce221f86030f5114ccdb503
--- /dev/null
+++ b/utils/dsp.py
@@ -0,0 +1,97 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+
+# ZERO = 1e-12
+
+
+def gaussian_normalize_mel_channel(mel, mu, sigma):
+    """
+    Normalize each mel channel to a standard normal distribution
+
+    Args:
+        mel: (n_mels, frame_len)
+        mu: (n_mels,), mean value
+        sigma: (n_mels,), standard deviation
+    Return:
+        Tensor like mel
+    """
+    mu = np.expand_dims(mu, -1)
+    sigma = np.expand_dims(sigma, -1)
+    return (mel - mu) / sigma
+
+
+def de_gaussian_normalize_mel_channel(mel, mu, sigma):
+    """
+
+    Args:
+        mel: (n_mels, frame_len)
+        mu: (n_mels,), mean value
+        sigma: (n_mels,), standard deviation
+    Return:
+        Tensor like mel
+    """
+    mu = np.expand_dims(mu, -1)
+    sigma = np.expand_dims(sigma, -1)
+    return sigma * mel + mu
+
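+
+# A minimal usage sketch: normalizing and de-normalizing a made-up mel
+# spectrogram recovers the original values (up to floating point error).
+def _demo_mel_normalization_round_trip():
+    mel = np.random.randn(80, 120)
+    mu, sigma = mel.mean(axis=1), mel.std(axis=1)
+    normalized = gaussian_normalize_mel_channel(mel, mu, sigma)
+    recovered = de_gaussian_normalize_mel_channel(normalized, mu, sigma)
+    assert np.allclose(recovered, mel)
+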
+
+def decompress(audio_compressed, bits):
+    mu = 2**bits - 1
+    audio = np.sign(audio_compressed) / mu * ((1 + mu) ** np.abs(audio_compressed) - 1)
+    return audio
+
+
+def compress(audio, bits):
+    mu = 2**bits - 1
+    audio_compressed = np.sign(audio) * np.log(1 + mu * np.abs(audio)) / np.log(mu + 1)
+    return audio_compressed
+
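+
+# A minimal usage sketch: mu-law companding round trip at 8 bits on a made-up
+# waveform in [-1, 1].
+def _demo_mu_law_round_trip():
+    audio = np.linspace(-1.0, 1.0, 11)
+    recovered = decompress(compress(audio, bits=8), bits=8)
+    assert np.allclose(recovered, audio)
+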
+
+def label_to_audio(quant, bits):
+    classes = 2**bits
+    audio = 2 * quant / (classes - 1.0) - 1.0
+    return audio
+
+
+def audio_to_label(audio, bits):
+    """Normalized audio data tensor to digit array
+
+    Args:
+        audio (tensor): audio data
+        bits (int): data bits
+
+    Returns:
+        array<int>: digit array of audio data
+    """
+    classes = 2**bits
+    # initialize an increasing array with values from -1 to 1
+    bins = np.linspace(-1, 1, classes)
+    # change value in audio tensor to digits
+    quant = np.digitize(audio, bins) - 1
+    return quant
+
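+
+# A minimal usage sketch: 8-bit label quantization round trip on a made-up
+# waveform; the reconstruction error is bounded by one quantization step.
+def _demo_audio_label_round_trip():
+    audio = np.linspace(-1.0, 1.0, 101)
+    recovered = label_to_audio(audio_to_label(audio, bits=8), bits=8)
+    assert np.max(np.abs(recovered - audio)) <= 2.0 / (2**8 - 1)
+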
+
+def label_to_onehot(x, bits):
+    """Converts a class vector (integers) to binary class matrix.
+    Args:
+        x: class vector to be converted into a matrix
+            (integers from 0 to 2**bits - 1).
+        bits: data bits; the total number of classes is 2**bits.
+    Returns:
+        A binary matrix representation of the input. The classes axis
+        is placed last.
+    """
+    classes = 2**bits
+
+    result = torch.zeros((x.shape[0], classes), dtype=torch.float32)
+    for i in range(x.shape[0]):
+        result[i, x[i]] = 1
+
+    output_shape = x.shape + (classes,)
+    output = torch.reshape(result, output_shape)
+    return output
diff --git a/utils/duration.py b/utils/duration.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9544b40b88c68b4e1df33ab1c81a6196a43111e
--- /dev/null
+++ b/utils/duration.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import os
+import tgt
+
+
+def get_alignment(tier, cfg):
+    sample_rate = cfg["sample_rate"]
+    hop_size = cfg["hop_size"]
+
+    sil_phones = ["sil", "sp", "spn"]
+
+    phones = []
+    durations = []
+    start_time = 0
+    end_time = 0
+    end_idx = 0
+
+    for t in tier._objects:
+        s, e, p = t.start_time, t.end_time, t.text
+
+        # Trim leading silences
+        if phones == []:
+            if p in sil_phones:
+                continue
+            else:
+                start_time = s
+
+        if p not in sil_phones:
+            # For ordinary phones
+            phones.append(p)
+            end_time = e
+            end_idx = len(phones)
+        else:
+            # For silent phones
+            phones.append(p)
+
+        durations.append(
+            int(
+                np.round(e * sample_rate / hop_size)
+                - np.round(s * sample_rate / hop_size)
+            )
+        )
+
+    # Trim trailing silences
+    phones = phones[:end_idx]
+    durations = durations[:end_idx]
+
+    return phones, durations, start_time, end_time
+
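+
+# A minimal usage sketch of get_alignment on a tiny hand-made tier:
+# SimpleNamespace stands in for a tgt interval tier, and the phones, timings
+# and config values below are made up.
+def _demo_get_alignment():
+    from types import SimpleNamespace
+
+    intervals = [
+        SimpleNamespace(start_time=0.00, end_time=0.10, text="sil"),
+        SimpleNamespace(start_time=0.10, end_time=0.30, text="HH"),
+        SimpleNamespace(start_time=0.30, end_time=0.50, text="AY"),
+        SimpleNamespace(start_time=0.50, end_time=0.60, text="sp"),
+    ]
+    tier = SimpleNamespace(_objects=intervals)
+    cfg = {"sample_rate": 24000, "hop_size": 256}
+    phones, durations, start, end = get_alignment(tier, cfg)
+    assert phones == ["HH", "AY"] and len(durations) == len(phones)
+    assert start == 0.10 and end == 0.50
+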
+
+def get_duration(utt, wav, cfg):
+    speaker = utt["Singer"]
+    basename = utt["Uid"]
+    dataset = utt["Dataset"]
+    sample_rate = cfg["sample_rate"]
+
+    # print(cfg.processed_dir, dataset, speaker, basename)
+    wav_path = os.path.join(
+        cfg.processed_dir, dataset, "raw_data", speaker, "{}.wav".format(basename)
+    )
+    text_path = os.path.join(
+        cfg.processed_dir, dataset, "raw_data", speaker, "{}.lab".format(basename)
+    )
+    tg_path = os.path.join(
+        cfg.processed_dir, dataset, "TextGrid", speaker, "{}.TextGrid".format(basename)
+    )
+
+    # Read raw text
+    with open(text_path, "r") as f:
+        raw_text = f.readline().strip("\n")
+
+    # Get alignments
+    textgrid = tgt.io.read_textgrid(tg_path)
+    phone, duration, start, end = get_alignment(
+        textgrid.get_tier_by_name("phones"), cfg
+    )
+    text = "{" + " ".join(phone) + "}"
+    if start >= end:
+        return None
+
+    return duration, text, int(sample_rate * start), int(sample_rate * end)
diff --git a/utils/f0.py b/utils/f0.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcd95396a6b08d8ffa1a310c5bee6f0d8b556796
--- /dev/null
+++ b/utils/f0.py
@@ -0,0 +1,299 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import librosa
+import numpy as np
+import torch
+import parselmouth
+import torchcrepe
+import pyworld as pw
+
+
+def get_bin_index(f0, m, M, n_bins, use_log_scale):
+    """
+    WARNING: deprecated; to be removed.
+
+    Args:
+        f0: tensor whose shape is (N, frame_len)
+    Returns:
+        index: tensor whose shape is same to f0
+    """
+    raw_f0 = f0.clone()
+    raw_m, raw_M = m, M
+
+    if use_log_scale:
+        f0[torch.where(f0 == 0)] = 1
+        f0 = torch.log(f0)
+        m, M = float(np.log(m)), float(np.log(M))
+
+    # Set normal index in [1, n_bins - 1]
+    width = (M + 1e-7 - m) / (n_bins - 1)
+    index = (f0 - m) // width + 1
+    # Set unvoiced frames to 0. Therefore, the vocabulary is [0, n_bins - 1], whose size is n_bins
+    index[torch.where(f0 == 0)] = 0
+
+    # TODO: Boundary check (special: to judge whether 0 for unvoiced)
+    if torch.any(raw_f0 > raw_M):
+        print("F0 Warning: too high f0: {}".format(raw_f0[torch.where(raw_f0 > raw_M)]))
+        index[torch.where(raw_f0 > raw_M)] = n_bins - 1
+    if torch.any(raw_f0 < raw_m):
+        print("F0 Warning: too low f0: {}".format(raw_f0[torch.where(f0 < m)]))
+        index[torch.where(f0 < m)] = 0
+
+    return torch.as_tensor(index, dtype=torch.long, device=f0.device)
+
+
+def f0_to_coarse(f0, pitch_bin, pitch_min, pitch_max):
+    ## TODO: Figure out the detail of this function
+
+    f0_mel_min = 1127 * np.log(1 + pitch_min / 700)
+    f0_mel_max = 1127 * np.log(1 + pitch_max / 700)
+
+    is_torch = isinstance(f0, torch.Tensor)
+    f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700)
+    f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (pitch_bin - 2) / (
+        f0_mel_max - f0_mel_min
+    ) + 1
+
+    f0_mel[f0_mel <= 1] = 1
+    f0_mel[f0_mel > pitch_bin - 1] = pitch_bin - 1
+    f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int32)
+    assert f0_coarse.max() <= pitch_bin - 1 and f0_coarse.min() >= 1, (
+        f0_coarse.max(),
+        f0_coarse.min(),
+    )
+    return f0_coarse
+
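+
+# A minimal usage sketch: quantizing a made-up F0 contour into 256 coarse bins
+# (unvoiced frames map to bin 1).
+def _demo_f0_to_coarse():
+    f0 = np.array([0.0, 110.0, 220.0, 440.0])
+    coarse = f0_to_coarse(f0, pitch_bin=256, pitch_min=50.0, pitch_max=1100.0)
+    assert coarse[0] == 1 and coarse.min() >= 1 and coarse.max() <= 255
+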
+
+def interpolate(f0):
+    """Interpolate the unvoiced part. Thus the f0 can be passed to a subtractive synthesizer.
+    Args:
+        f0: A numpy array of shape (seq_len,)
+    Returns:
+        f0: Interpolated f0 of shape (seq_len,)
+        uv: Unvoiced part of shape (seq_len,)
+    """
+    uv = f0 == 0
+    if len(f0[~uv]) > 0:
+        # interpolate the unvoiced f0
+        f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])
+        uv = uv.astype("float")
+        uv = np.min(np.array([uv[:-2], uv[1:-1], uv[2:]]), axis=0)
+        uv = np.pad(uv, (1, 1))
+    return f0, uv
+
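+
+# A minimal usage sketch: filling the unvoiced (zero) frames of a made-up F0
+# contour by linear interpolation; the returned uv mask keeps the same length.
+def _demo_interpolate_f0():
+    f0 = np.array([0.0, 100.0, 0.0, 0.0, 200.0, 0.0])
+    f0_interp, uv = interpolate(f0.copy())
+    assert np.all(f0_interp > 0) and len(uv) == len(f0)
+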
+
+def get_log_f0(f0):
+    f0[np.where(f0 == 0)] = 1
+    log_f0 = np.log(f0)
+    return log_f0
+
+
+# ========== Methods ==========
+
+
+def get_f0_features_using_pyin(audio, cfg):
+    """Using pyin to extract the f0 feature.
+    Args:
+        audio: numpy array of the waveform
+        cfg: config with sample_rate, win_size, hop_size, f0_min and f0_max
+    Returns:
+        f0: numpy array of shape (frame_len,)
+    """
+    f0, voiced_flag, voiced_probs = librosa.pyin(
+        y=audio,
+        fmin=cfg.f0_min,
+        fmax=cfg.f0_max,
+        sr=cfg.sample_rate,
+        win_length=cfg.win_size,
+        hop_length=cfg.hop_size,
+    )
+    # pyin marks unvoiced frames as NaN; set them to 0
+    f0[~voiced_flag] = 0
+    return f0
+
+
+def get_f0_features_using_parselmouth(audio, cfg, speed=1):
+    """Using parselmouth to extract the f0 feature.
+    Args:
+        audio: numpy array of the waveform
+        cfg: config with sample_rate, hop_size, f0_min, f0_max and pitch_bin
+        speed: speed factor applied to the hop size (default=1)
+    Returns:
+        f0: numpy array of shape (frame_len,)
+        pitch_coarse: numpy array of shape (frame_len,)
+    """
+    hop_size = int(np.round(cfg.hop_size * speed))
+
+    # Calculate the time step for pitch extraction
+    time_step = hop_size / cfg.sample_rate * 1000
+
+    f0 = (
+        parselmouth.Sound(audio, cfg.sample_rate)
+        .to_pitch_ac(
+            time_step=time_step / 1000,
+            voicing_threshold=0.6,
+            pitch_floor=cfg.f0_min,
+            pitch_ceiling=cfg.f0_max,
+        )
+        .selected_array["frequency"]
+    )
+
+    # Pad the pitch to the mel_len
+    # pad_size = (int(len(audio) // hop_size) - len(f0) + 1) // 2
+    # f0 = np.pad(f0, [[pad_size, mel_len - len(f0) - pad_size]], mode="constant")
+
+    # Get the coarse part
+    pitch_coarse = f0_to_coarse(f0, cfg.pitch_bin, cfg.f0_min, cfg.f0_max)
+    return f0, pitch_coarse
+
+
+def get_f0_features_using_dio(audio, cfg):
+    """Using dio to extract the f0 feature.
+    Args:
+        audio: numpy array of the waveform
+        cfg: config with sample_rate, hop_size, f0_min and f0_max
+    Returns:
+        f0: numpy array of shape (frame_len,)
+    """
+    # Get the raw f0
+    _f0, t = pw.dio(
+        audio.astype("double"),
+        cfg.sample_rate,
+        f0_floor=cfg.f0_min,
+        f0_ceil=cfg.f0_max,
+        channels_in_octave=2,
+        frame_period=(1000 * cfg.hop_size / cfg.sample_rate),
+    )
+    # Get the f0
+    f0 = pw.stonemask(audio.astype("double"), _f0, t, cfg.sample_rate)
+    return f0
+
+
+def get_f0_features_using_harvest(audio, mel_len, fs, hop_length, f0_min, f0_max):
+    """Using harvest to extract the f0 feature.
+    Args:
+        audio
+        mel_len
+        fs
+        hop_length
+        f0_min
+        f0_max
+    Returns:
+        f0: numpy array of shape (frame_len,)
+    """
+    f0, _ = pw.harvest(
+        audio.astype("double"),
+        fs,
+        f0_floor=f0_min,
+        f0_ceil=f0_max,
+        frame_period=(1000 * hop_length / fs),
+    )
+    f0 = f0.astype("float")[:mel_len]
+    return f0
+
+
+def get_f0_features_using_crepe(
+    audio, mel_len, fs, hop_length, hop_length_new, f0_min, f0_max, threshold=0.3
+):
+    """Using torchcrepe to extract the f0 feature.
+    Args:
+        audio
+        mel_len
+        fs
+        hop_length
+        hop_length_new
+        f0_min
+        f0_max
+        threshold(default=0.3)
+    Returns:
+        f0: numpy array of shape (frame_len,)
+    """
+    # Currently, crepe only supports 16khz audio
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    audio_16k = librosa.resample(audio, orig_sr=fs, target_sr=16000)
+    audio_16k_torch = torch.FloatTensor(audio_16k).unsqueeze(0).to(device)
+
+    # Get the raw pitch
+    f0, pd = torchcrepe.predict(
+        audio_16k_torch,
+        16000,
+        hop_length_new,
+        f0_min,
+        f0_max,
+        pad=True,
+        model="full",
+        batch_size=1024,
+        device=device,
+        return_periodicity=True,
+    )
+
+    # Filter, de-silence, set up threshold for unvoiced part
+    pd = torchcrepe.filter.median(pd, 3)
+    pd = torchcrepe.threshold.Silence(-60.0)(pd, audio_16k_torch, 16000, hop_length_new)
+    f0 = torchcrepe.threshold.At(threshold)(f0, pd)
+    f0 = torchcrepe.filter.mean(f0, 3)
+
+    # Convert unvoiced part to 0hz
+    f0 = torch.where(torch.isnan(f0), torch.full_like(f0, 0), f0)
+
+    # Interpolate f0
+    nzindex = torch.nonzero(f0[0]).squeeze()
+    f0 = torch.index_select(f0[0], dim=0, index=nzindex).cpu().numpy()
+    time_org = 0.005 * nzindex.cpu().numpy()
+    time_frame = np.arange(mel_len) * hop_length / fs
+    f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1])
+    return f0
+
+
+def get_f0(audio, cfg):
+    if cfg.pitch_extractor == "dio":
+        f0 = get_f0_features_using_dio(audio, cfg)
+    elif cfg.pitch_extractor == "pyin":
+        f0 = get_f0_features_using_pyin(audio, cfg)
+    elif cfg.pitch_extractor == "parselmouth":
+        f0, _ = get_f0_features_using_parselmouth(audio, cfg)
+    # elif cfg.data.f0_extractor == 'cwt': # todo
+    else:
+        raise NotImplementedError(
+            "Unsupported pitch extractor: {}".format(cfg.pitch_extractor)
+        )
+
+    return f0
+
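+
+# A minimal usage sketch: extracting F0 from one second of a synthetic 220 Hz
+# sine with the "dio" extractor. The cfg below only carries the fields this
+# module reads; the values are made up.
+def _demo_get_f0():
+    from types import SimpleNamespace
+
+    cfg = SimpleNamespace(
+        pitch_extractor="dio",
+        sample_rate=24000,
+        hop_size=256,
+        f0_min=50,
+        f0_max=1100,
+    )
+    t = np.arange(cfg.sample_rate) / cfg.sample_rate
+    audio = 0.5 * np.sin(2 * np.pi * 220.0 * t)
+    f0 = get_f0(audio, cfg)
+    # For a clean sine, the voiced frames should be close to 220 Hz.
+    assert f0.ndim == 1 and np.all(f0 >= 0)
+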
+
+def get_cents(f0_hz):
+    """
+    F_{cent} = 1200 * log2 (F/440)
+
+    Reference:
+        APSIPA'17, Perceptual Evaluation of Singing Quality
+    """
+    voiced_f0 = f0_hz[f0_hz != 0]
+    return 1200 * np.log2(voiced_f0 / 440)
+
+
+def get_pitch_derivatives(f0_hz):
+    """
+    f0_hz: (,T)
+    """
+    f0_cent = get_cents(f0_hz)
+    return f0_cent[1:] - f0_cent[:-1]
+
+
+def get_pitch_sub_median(f0_hz):
+    """
+    f0_hz: (,T)
+    """
+    f0_cent = get_cents(f0_hz)
+    return f0_cent - np.median(f0_cent)
diff --git a/utils/hparam.py b/utils/hparam.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5dd35c6a3158b0aaf8d936dba139a030a48bc62
--- /dev/null
+++ b/utils/hparam.py
@@ -0,0 +1,659 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/training/python/training/hparam.py  pylint: disable=line-too-long
+"""Hyperparameter values."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import json
+import numbers
+import re
+import six
+
+# Define the regular expression for parsing a single clause of the input
+# (delimited by commas).  A legal clause looks like:
+#   <variable name>[<index>]? = <rhs>
+# where <rhs> is either a single token or [] enclosed list of tokens.
+# For example:  "var[1] = a" or "x = [1,2,3]"
+PARAM_RE = re.compile(
+    r"""
+  (?P<name>[a-zA-Z][\w\.]*)      # variable name: "var" or "x"
+  (\[\s*(?P<index>\d+)\s*\])?  # (optional) index: "1" or None
+  \s*=\s*
+  ((?P<val>[^,\[]*)            # single value: "a" or None
+   |
+   \[(?P<vals>[^\]]*)\])       # list of values: None or "1,2,3"
+  ($|,\s*)""",
+    re.VERBOSE,
+)
+
+
+def _parse_fail(name, var_type, value, values):
+    """Helper function for raising a value error for bad assignment."""
+    raise ValueError(
+        "Could not parse hparam '%s' of type '%s' with value '%s' in %s"
+        % (name, var_type.__name__, value, values)
+    )
+
+
+def _reuse_fail(name, values):
+    """Helper function for raising a value error for reuse of name."""
+    raise ValueError("Multiple assignments to variable '%s' in %s" % (name, values))
+
+
+def _process_scalar_value(name, parse_fn, var_type, m_dict, values, results_dictionary):
+    """Update results_dictionary with a scalar value.
+
+    Used to update the results_dictionary to be returned by parse_values when
+    encountering a clause with a scalar RHS (e.g.  "s=5" or "arr[0]=5".)
+
+    Mutates results_dictionary.
+
+    Args:
+      name: Name of variable in assignment ("s" or "arr").
+      parse_fn: Function for parsing the actual value.
+      var_type: Type of named variable.
+      m_dict: Dictionary constructed from regex parsing.
+        m_dict['val']: RHS value (scalar)
+        m_dict['index']: List index value (or None)
+      values: Full expression being parsed
+      results_dictionary: The dictionary being updated for return by the parsing
+        function.
+
+    Raises:
+      ValueError: If the name has already been used.
+    """
+    try:
+        parsed_value = parse_fn(m_dict["val"])
+    except ValueError:
+        _parse_fail(name, var_type, m_dict["val"], values)
+
+    # If no index is provided
+    if not m_dict["index"]:
+        if name in results_dictionary:
+            _reuse_fail(name, values)
+        results_dictionary[name] = parsed_value
+    else:
+        if name in results_dictionary:
+            # The name has already been used as a scalar, then it
+            # will be in this dictionary and map to a non-dictionary.
+            if not isinstance(results_dictionary.get(name), dict):
+                _reuse_fail(name, values)
+        else:
+            results_dictionary[name] = {}
+
+        index = int(m_dict["index"])
+        # Make sure the index position hasn't already been assigned a value.
+        if index in results_dictionary[name]:
+            _reuse_fail("{}[{}]".format(name, index), values)
+        results_dictionary[name][index] = parsed_value
+
+
+def _process_list_value(name, parse_fn, var_type, m_dict, values, results_dictionary):
+    """Update results_dictionary from a list of values.
+
+    Used to update results_dictionary to be returned by parse_values when
+    encountering a clause with a list RHS (e.g.  "arr=[1,2,3]".)
+
+    Mutates results_dictionary.
+
+    Args:
+      name: Name of variable in assignment ("arr").
+      parse_fn: Function for parsing individual values.
+      var_type: Type of named variable.
+      m_dict: Dictionary constructed from regex parsing.
+        m_dict['val']: RHS value (scalar)
+      values: Full expression being parsed
+      results_dictionary: The dictionary being updated for return by the parsing
+        function.
+
+    Raises:
+      ValueError: If the name has an index or the values cannot be parsed.
+    """
+    if m_dict["index"] is not None:
+        raise ValueError("Assignment of a list to a list index.")
+    elements = filter(None, re.split("[ ,]", m_dict["vals"]))
+    # Make sure the name hasn't already been assigned a value
+    if name in results_dictionary:
+        raise _reuse_fail(name, values)
+    try:
+        results_dictionary[name] = [parse_fn(e) for e in elements]
+    except ValueError:
+        _parse_fail(name, var_type, m_dict["vals"], values)
+
+
+def _cast_to_type_if_compatible(name, param_type, value):
+    """Cast hparam to the provided type, if compatible.
+
+    Args:
+      name: Name of the hparam to be cast.
+      param_type: The type of the hparam.
+      value: The value to be cast, if compatible.
+
+    Returns:
+      The result of casting `value` to `param_type`.
+
+    Raises:
+      ValueError: If the type of `value` is not compatible with param_type.
+        * If `param_type` is a string type, but `value` is not.
+        * If `param_type` is a boolean, but `value` is not, or vice versa.
+        * If `param_type` is an integer type, but `value` is not.
+        * If `param_type` is a float type, but `value` is not a numeric type.
+    """
+    fail_msg = "Could not cast hparam '%s' of type '%s' from value %r" % (
+        name,
+        param_type,
+        value,
+    )
+
+    # Some callers use None, for which we can't do any casting/checking. :(
+    if issubclass(param_type, type(None)):
+        return value
+
+    # Avoid converting a non-string type to a string.
+    if issubclass(param_type, (six.string_types, six.binary_type)) and not isinstance(
+        value, (six.string_types, six.binary_type)
+    ):
+        raise ValueError(fail_msg)
+
+    # Avoid converting a number or string type to a boolean or vice versa.
+    if issubclass(param_type, bool) != isinstance(value, bool):
+        raise ValueError(fail_msg)
+
+    # Avoid converting float to an integer (the reverse is fine).
+    if issubclass(param_type, numbers.Integral) and not isinstance(
+        value, numbers.Integral
+    ):
+        raise ValueError(fail_msg)
+
+    # Avoid converting a non-numeric type to a numeric type.
+    if issubclass(param_type, numbers.Number) and not isinstance(value, numbers.Number):
+        raise ValueError(fail_msg)
+
+    return param_type(value)
+
+
+def parse_values(values, type_map, ignore_unknown=False):
+    """Parses hyperparameter values from a string into a python map.
+
+    `values` is a string containing comma-separated `name=value` pairs.
+    For each pair, the value of the hyperparameter named `name` is set to
+    `value`.
+
+    If a hyperparameter name appears multiple times in `values`, a ValueError
+    is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2').
+
+    If a hyperparameter name appears in both an index assignment and a scalar
+    assignment, a ValueError is raised.  (e.g. 'a=[1,2,3],a[0] = 1').
+
+    The hyperparameter name may contain '.' symbols, which will result in an
+    attribute name that is only accessible through the getattr and setattr
+    functions.  (And must first be explicitly added through add_hparam.)
+
+    WARNING: Use of '.' in your variable names is allowed, but is not well
+    supported and not recommended.
+
+    The `value` in `name=value` must follow the syntax according to the
+    type of the parameter:
+
+    *  Scalar integer: A Python-parsable integer value.  E.g.: 1,
+       100, -12.
+    *  Scalar float: A Python-parsable floating point value.  E.g.: 1.0,
+       -.54e89.
+    *  Boolean: Either true or false.
+    *  Scalar string: A non-empty sequence of characters, excluding comma,
+       spaces, and square brackets.  E.g.: foo, bar_1.
+    *  List: A comma separated list of scalar values of the parameter type
+       enclosed in square brackets.  E.g.: [1,2,3], [1.0,1e-12], [high,low].
+
+    When index assignment is used, the corresponding type_map key should be the
+    list name.  E.g. for "arr[1]=0" the type_map must have the key "arr" (not
+    "arr[1]").
+
+    Args:
+      values: String.  Comma separated list of `name=value` pairs where
+        'value' must follow the syntax described above.
+      type_map: A dictionary mapping hyperparameter names to types.  Note every
+        parameter name in values must be a key in type_map.  The values must
+        conform to the types indicated, where a value V is said to conform to a
+        type T if either V has type T, or V is a list of elements of type T.
+        Hence, for a multidimensional parameter 'x' taking float values,
+        'x=[0.1,0.2]' will parse successfully if type_map['x'] = float.
+      ignore_unknown: Bool. Whether values that are missing a type in type_map
+        should be ignored. If set to True, a ValueError will not be raised for
+        unknown hyperparameter type.
+
+    Returns:
+      A python map mapping each name to either:
+      * A scalar value.
+      * A list of scalar values.
+      * A dictionary mapping index numbers to scalar values.
+      (e.g. "x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}}")
+
+    Raises:
+      ValueError: If there is a problem with input.
+      * If `values` cannot be parsed.
+      * If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]').
+      * If the same rvalue is assigned two different values (e.g. 'a=1,a=2',
+        'a[1]=1,a[1]=2', or 'a=1,a=[1]')
+    """
+    results_dictionary = {}
+    pos = 0
+    while pos < len(values):
+        m = PARAM_RE.match(values, pos)
+        if not m:
+            raise ValueError("Malformed hyperparameter value: %s" % values[pos:])
+        # Check that there is a comma between parameters and move past it.
+        pos = m.end()
+        # Parse the values.
+        m_dict = m.groupdict()
+        name = m_dict["name"]
+        if name not in type_map:
+            if ignore_unknown:
+                continue
+            raise ValueError("Unknown hyperparameter type for %s" % name)
+        type_ = type_map[name]
+
+        # Set up correct parsing function (depending on whether type_ is a bool)
+        if type_ == bool:
+
+            def parse_bool(value):
+                if value in ["true", "True"]:
+                    return True
+                elif value in ["false", "False"]:
+                    return False
+                else:
+                    try:
+                        return bool(int(value))
+                    except ValueError:
+                        _parse_fail(name, type_, value, values)
+
+            parse = parse_bool
+        else:
+            parse = type_
+
+        # If a single value is provided
+        if m_dict["val"] is not None:
+            _process_scalar_value(
+                name, parse, type_, m_dict, values, results_dictionary
+            )
+
+        # If the assigned value is a list:
+        elif m_dict["vals"] is not None:
+            _process_list_value(name, parse, type_, m_dict, values, results_dictionary)
+
+        else:  # Not assigned a list or value
+            _parse_fail(name, type_, "", values)
+
+    return results_dictionary
+
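+
+# A minimal usage sketch: parsing a comma-separated hyperparameter string, as
+# described in the docstring above; the names and types below are made up.
+def _demo_parse_values():
+    type_map = {"x": int, "L": int, "arr": int, "use_gpu": bool}
+    parsed = parse_values("x=5,L=[1,2],arr[1]=3,use_gpu=true", type_map)
+    assert parsed == {"x": 5, "L": [1, 2], "arr": {1: 3}, "use_gpu": True}
+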
+
+class HParams(object):
+    """Class to hold a set of hyperparameters as name-value pairs.
+
+    A `HParams` object holds hyperparameters used to build and train a model,
+    such as the number of hidden units in a neural net layer or the learning rate
+    to use when training.
+
+    You first create a `HParams` object by specifying the names and values of the
+    hyperparameters.
+
+    To make them easily accessible the parameter names are added as direct
+    attributes of the class.  A typical usage is as follows:
+
+    ```python
+    # Create a HParams object specifying names and values of the model
+    # hyperparameters:
+    hparams = HParams(learning_rate=0.1, num_hidden_units=100)
+
+    # The hyperparameters are available as attributes of the HParams object:
+    hparams.learning_rate ==> 0.1
+    hparams.num_hidden_units ==> 100
+    ```
+
+    Hyperparameters have a type, which is inferred from the type of the value
+    passed at construction time.  The currently supported types are: integer,
+    float, boolean, string, and list of integer, float, boolean, or string.
+
+    You can override hyperparameter values by calling the
+    [`parse()`](#HParams.parse) method, passing a string of comma separated
+    `name=value` pairs.  This is intended to make it possible to override
+    any hyperparameter values from a single command-line flag to which
+    the user passes 'hyper-param=value' pairs.  It avoids having to define
+    one flag for each hyperparameter.
+
+    The syntax expected for each value depends on the type of the parameter.
+    See `parse()` for a description of the syntax.
+
+    Example:
+
+    ```python
+    # Define a command line flag to pass name=value pairs.
+    # For example using argparse:
+    import argparse
+    parser = argparse.ArgumentParser(description='Train my model.')
+    parser.add_argument('--hparams', type=str,
+                        help='Comma separated list of "name=value" pairs.')
+    args = parser.parse_args()
+    ...
+    def my_program():
+      # Create a HParams object specifying the names and values of the
+      # model hyperparameters:
+      hparams = tf.HParams(learning_rate=0.1, num_hidden_units=100,
+                           activations=['relu', 'tanh'])
+
+      # Override hyperparameters values by parsing the command line
+      hparams.parse(args.hparams)
+
+      # If the user passed `--hparams=learning_rate=0.3` on the command line
+      # then 'hparams' has the following attributes:
+      hparams.learning_rate ==> 0.3
+      hparams.num_hidden_units ==> 100
+      hparams.activations ==> ['relu', 'tanh']
+
+      # If the hyperparameters are in json format use parse_json:
+      hparams.parse_json('{"learning_rate": 0.3, "activations": "relu"}')
+    ```
+    """
+
+    _HAS_DYNAMIC_ATTRIBUTES = True  # Required for pytype checks.
+
+    def __init__(self, model_structure=None, **kwargs):
+        """Create an instance of `HParams` from keyword arguments.
+
+        The keyword arguments specify name-values pairs for the hyperparameters.
+        The parameter types are inferred from the type of the values passed.
+
+        The parameter names are added as attributes of `HParams` object, so they
+        can be accessed directly with the dot notation `hparams._name_`.
+
+        Example:
+
+        ```python
+        # Define 3 hyperparameters: 'learning_rate' is a float parameter,
+        # 'num_hidden_units' an integer parameter, and 'activation' a string
+        # parameter.
+        hparams = tf.HParams(
+            learning_rate=0.1, num_hidden_units=100, activation='relu')
+
+        hparams.activation ==> 'relu'
+        ```
+
+        Note that a few names are reserved and cannot be used as hyperparameter
+        names.  If you use one of the reserved names, the constructor raises a
+        `ValueError`.
+
+        Args:
+          model_structure: An instance of ModelStructure, defining the feature
+            crosses to be used in the Trial.
+          **kwargs: Key-value pairs where the key is the hyperparameter name and
+            the value is the value for the parameter.
+
+        Raises:
+          ValueError: If both `hparam_def` and initialization values are provided,
+            or if one of the arguments is invalid.
+
+        """
+        # Register the hyperparameters and their type in _hparam_types.
+        # This simplifies the implementation of parse().
+        # _hparam_types maps the parameter name to a tuple (type, bool).
+        # The type value is the type of the parameter for scalar hyperparameters,
+        # or the type of the list elements for multidimensional hyperparameters.
+        # The bool value is True if the value is a list, False otherwise.
+        self._hparam_types = {}
+        self._model_structure = model_structure
+        for name, value in six.iteritems(kwargs):
+            self.add_hparam(name, value)
+
+    def add_hparam(self, name, value):
+        """Adds {name, value} pair to hyperparameters.
+
+        Args:
+          name: Name of the hyperparameter.
+          value: Value of the hyperparameter. Can be one of the following types:
+            int, float, string, int list, float list, or string list.
+
+        Raises:
+          ValueError: if one of the arguments is invalid.
+        """
+        # Keys in kwargs are unique, but 'name' could be the name of a pre-existing
+        # attribute of this object.  In that case we refuse to use it as a
+        # hyperparameter name.
+        if getattr(self, name, None) is not None:
+            raise ValueError("Hyperparameter name is reserved: %s" % name)
+        if isinstance(value, (list, tuple)):
+            if not value:
+                raise ValueError(
+                    "Multi-valued hyperparameters cannot be empty: %s" % name
+                )
+            self._hparam_types[name] = (type(value[0]), True)
+        else:
+            self._hparam_types[name] = (type(value), False)
+        setattr(self, name, value)
+
+    def set_hparam(self, name, value):
+        """Set the value of an existing hyperparameter.
+
+        This function verifies that the type of the value matches the type of the
+        existing hyperparameter.
+
+        Args:
+          name: Name of the hyperparameter.
+          value: New value of the hyperparameter.
+
+        Raises:
+          KeyError: If the hyperparameter doesn't exist.
+          ValueError: If there is a type mismatch.
+        """
+        param_type, is_list = self._hparam_types[name]
+        if isinstance(value, list):
+            if not is_list:
+                raise ValueError(
+                    "Must not pass a list for single-valued parameter: %s" % name
+                )
+            setattr(
+                self,
+                name,
+                [_cast_to_type_if_compatible(name, param_type, v) for v in value],
+            )
+        else:
+            if is_list:
+                raise ValueError(
+                    "Must pass a list for multi-valued parameter: %s." % name
+                )
+            setattr(self, name, _cast_to_type_if_compatible(name, param_type, value))
+
+    def del_hparam(self, name):
+        """Removes the hyperparameter with key 'name'.
+
+        Does nothing if it isn't present.
+
+        Args:
+          name: Name of the hyperparameter.
+        """
+        if hasattr(self, name):
+            delattr(self, name)
+            del self._hparam_types[name]
+
+    def parse(self, values):
+        """Override existing hyperparameter values, parsing new values from a string.
+
+        See parse_values for more detail on the allowed format for values.
+
+        Args:
+          values: String.  Comma separated list of `name=value` pairs where 'value'
+            must follow the syntax described above.
+
+        Returns:
+          The `HParams` instance.
+
+        Raises:
+          ValueError: If `values` cannot be parsed or a hyperparameter in `values`
+            doesn't exist.
+        """
+        type_map = {}
+        for name, t in self._hparam_types.items():
+            param_type, _ = t
+            type_map[name] = param_type
+
+        values_map = parse_values(values, type_map)
+        return self.override_from_dict(values_map)
+
+    def override_from_dict(self, values_dict):
+        """Override existing hyperparameter values, parsing new values from a dictionary.
+
+        Args:
+          values_dict: Dictionary of name:value pairs.
+
+        Returns:
+          The `HParams` instance.
+
+        Raises:
+          KeyError: If a hyperparameter in `values_dict` doesn't exist.
+          ValueError: If `values_dict` cannot be parsed.
+        """
+        for name, value in values_dict.items():
+            self.set_hparam(name, value)
+        return self
+
+    def set_model_structure(self, model_structure):
+        self._model_structure = model_structure
+
+    def get_model_structure(self):
+        return self._model_structure
+
+    def to_json(self, indent=None, separators=None, sort_keys=False):
+        """Serializes the hyperparameters into JSON.
+
+        Args:
+          indent: If a non-negative integer, JSON array elements and object members
+            will be pretty-printed with that indent level. An indent level of 0, or
+            negative, will only insert newlines. `None` (the default) selects the
+            most compact representation.
+          separators: Optional `(item_separator, key_separator)` tuple. Default is
+            `(', ', ': ')`.
+          sort_keys: If `True`, the output dictionaries will be sorted by key.
+
+        Returns:
+          A JSON string.
+        """
+
+        def remove_callables(x):
+            """Omit callable elements from input with arbitrary nesting."""
+            if isinstance(x, dict):
+                return {
+                    k: remove_callables(v)
+                    for k, v in six.iteritems(x)
+                    if not callable(v)
+                }
+            elif isinstance(x, list):
+                return [remove_callables(i) for i in x if not callable(i)]
+            return x
+
+        return json.dumps(
+            remove_callables(self.values()),
+            indent=indent,
+            separators=separators,
+            sort_keys=sort_keys,
+        )
+
+    def parse_json(self, values_json):
+        """Override existing hyperparameter values, parsing new values from a json object.
+
+        Args:
+          values_json: String containing a json object of name:value pairs.
+
+        Returns:
+          The `HParams` instance.
+
+        Raises:
+          KeyError: If a hyperparameter in `values_json` doesn't exist.
+          ValueError: If `values_json` cannot be parsed.
+        """
+        values_map = json.loads(values_json)
+        return self.override_from_dict(values_map)
+
+    def values(self):
+        """Return the hyperparameter values as a Python dictionary.
+
+        Returns:
+          A dictionary with hyperparameter names as keys.  The values are the
+          hyperparameter values.
+        """
+        return {n: getattr(self, n) for n in self._hparam_types.keys()}
+
+    def get(self, key, default=None):
+        """Returns the value of `key` if it exists, else `default`."""
+        if key in self._hparam_types:
+            # Ensure that default is compatible with the parameter type.
+            if default is not None:
+                param_type, is_param_list = self._hparam_types[key]
+                type_str = "list<%s>" % param_type if is_param_list else str(param_type)
+                fail_msg = (
+                    "Hparam '%s' of type '%s' is incompatible with "
+                    "default=%s" % (key, type_str, default)
+                )
+
+                is_default_list = isinstance(default, list)
+                if is_param_list != is_default_list:
+                    raise ValueError(fail_msg)
+
+                try:
+                    if is_default_list:
+                        for value in default:
+                            _cast_to_type_if_compatible(key, param_type, value)
+                    else:
+                        _cast_to_type_if_compatible(key, param_type, default)
+                except ValueError as e:
+                    raise ValueError("%s. %s" % (fail_msg, e))
+
+            return getattr(self, key)
+
+        return default
+
+    def __contains__(self, key):
+        return key in self._hparam_types
+
+    def __str__(self):
+        return str(sorted(self.values().items()))
+
+    def __repr__(self):
+        return "%s(%s)" % (type(self).__name__, self.__str__())
+
+    @staticmethod
+    def _get_kind_name(param_type, is_list):
+        """Returns the field name given parameter type and is_list.
+
+        Args:
+          param_type: Data type of the hparam.
+          is_list: Whether this is a list.
+
+        Returns:
+          A string representation of the field name.
+
+        Raises:
+          ValueError: If parameter type is not recognized.
+        """
+        if issubclass(param_type, bool):
+            # This check must happen before issubclass(param_type, six.integer_types),
+            # since Python considers bool to be a subclass of int.
+            typename = "bool"
+        elif issubclass(param_type, six.integer_types):
+            # Setting 'int' and 'long' types to be 'int64' to ensure the type is
+            # compatible with both Python2 and Python3.
+            typename = "int64"
+        elif issubclass(param_type, (six.string_types, six.binary_type)):
+            # Setting 'string' and 'bytes' types to be 'bytes' to ensure the type is
+            # compatible with both Python2 and Python3.
+            typename = "bytes"
+        elif issubclass(param_type, float):
+            typename = "float"
+        else:
+            raise ValueError("Unsupported parameter type: %s" % str(param_type))
+
+        suffix = "list" if is_list else "value"
+        return "_".join([typename, suffix])
diff --git a/utils/hubert.py b/utils/hubert.py
new file mode 100644
index 0000000000000000000000000000000000000000..84b509fb9fde8485cfb504a675e5d3b7d27622ff
--- /dev/null
+++ b/utils/hubert.py
@@ -0,0 +1,155 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/svc-develop-team/so-vits-svc/blob/4.0/preprocess_hubert_f0.py
+
+import os
+import librosa
+import torch
+import numpy as np
+from fairseq import checkpoint_utils
+from tqdm import tqdm
+
+
+def load_hubert_model(hps):
+    # Load model
+    ckpt_path = hps.hubert_file
+    print("Load Hubert Model...")
+
+    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
+        [ckpt_path],
+        suffix="",
+    )
+    model = models[0]
+    model.eval()
+
+    if torch.cuda.is_available():
+        model = model.cuda()
+
+    return model
+
+
+def get_hubert_content(hmodel, wav_16k_tensor):
+    feats = wav_16k_tensor
+    if feats.dim() == 2:  # double channels
+        feats = feats.mean(-1)
+    assert feats.dim() == 1, feats.dim()
+    feats = feats.view(1, -1)
+    padding_mask = torch.BoolTensor(feats.shape).fill_(False)
+    inputs = {
+        "source": feats.to(wav_16k_tensor.device),
+        "padding_mask": padding_mask.to(wav_16k_tensor.device),
+        "output_layer": 9,  # layer 9
+    }
+    with torch.no_grad():
+        logits = hmodel.extract_features(**inputs)
+        feats = hmodel.final_proj(logits[0]).squeeze(0)
+
+    return feats
+
+
+def content_vector_encoder(model, audio_path, default_sampling_rate=16000):
+    """
+    # content vector default sr: 16000
+    """
+
+    wav16k, sr = librosa.load(audio_path, sr=default_sampling_rate)
+    device = next(model.parameters()).device
+    wav16k = torch.from_numpy(wav16k).to(device)
+
+    # (1, 256, frame_len)
+    content_feature = get_hubert_content(model, wav_16k_tensor=wav16k)
+
+    return content_feature.cpu().detach().numpy()
+
+
+def repeat_expand_2d(content, target_len):
+    """
+    content : [hubert_dim(256), src_len]
+    target: [hubert_dim(256), target_len]
+    """
+    src_len = content.shape[-1]
+    target = torch.zeros([content.shape[0], target_len], dtype=torch.float).to(
+        content.device
+    )
+    temp = torch.arange(src_len + 1) * target_len / src_len
+    current_pos = 0
+    for i in range(target_len):
+        if i < temp[current_pos + 1]:
+            target[:, i] = content[:, current_pos]
+        else:
+            current_pos += 1
+            target[:, i] = content[:, current_pos]
+
+    return target
+
+
+def get_mapped_features(raw_content_features, mapping_features):
+    """
+    Content Vector: frameshift = 20ms, hop_size = 480 in 24k
+
+    Now it's only used for mapping to bigvgan's mels (sr = 24k, hop_size = 256, frameshift ~= 10.7 ms)
+    """
+    source_hop = 480
+    target_hop = 256
+
+    factor = np.gcd(source_hop, target_hop)
+    source_hop //= factor
+    target_hop //= factor
+    print(
+        "Mapping source's {} frames => target's {} frames".format(
+            target_hop, source_hop
+        )
+    )
+
+    results = []
+    for index, mapping_feat in enumerate(tqdm(mapping_features)):
+        # mappping_feat: (mels_frame_len, n_mels)
+        target_len = len(mapping_feat)
+
+        # (source_len, 256)
+        raw_feats = raw_content_features[index][0].cpu().numpy().T
+        source_len, width = raw_feats.shape
+
+        # const ~= target_len * target_hop
+        const = source_len * source_hop // target_hop * target_hop
+
+        # (source_len * source_hop, dim)
+        up_sampling_feats = np.repeat(raw_feats, source_hop, axis=0)
+        # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim)
+        down_sampling_feats = np.average(
+            up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1
+        )
+
+        err = abs(target_len - len(down_sampling_feats))
+        if err > 3:
+            print("index:", index)
+            print("mels:", mapping_feat.shape)
+            print("raw content vector:", raw_feats.shape)
+            print("up_sampling:", up_sampling_feats.shape)
+            print("down_sampling_feats:", down_sampling_feats.shape)
+            exit()
+        if len(down_sampling_feats) < target_len:
+            # (1, dim) -> (err, dim)
+            end = down_sampling_feats[-1][None, :].repeat(err, axis=0)
+            down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0)
+
+        # (target_len, dim)
+        feats = down_sampling_feats[:target_len]
+        results.append(feats)
+
+    return results
+
+
+def extract_hubert_features_of_dataset(datasets, model, out_dir):
+    for utt in tqdm(datasets):
+        uid = utt["Uid"]
+        audio_path = utt["Path"]
+
+        content_vector_feature = content_vector_encoder(model, audio_path)  # (T, 256)
+
+        save_path = os.path.join(out_dir, uid + ".npy")
+        np.save(save_path, content_vector_feature)
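+
+
+# Shape-only sanity check for get_mapped_features, using synthetic tensors that
+# match how the function indexes its inputs (no HuBERT checkpoint needed, though
+# this module's imports, e.g. fairseq, must be installed): 100 content frames at
+# hop 480 cover 100 * 480 samples, i.e. 187 full mel frames at hop 256.
+if __name__ == "__main__":
+    fake_content = [torch.randn(1, 256, 100)]
+    fake_mels = [np.zeros((187, 80), dtype=np.float32)]
+    mapped = get_mapped_features(fake_content, fake_mels)
+    print(mapped[0].shape)  # expected: (187, 256)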
diff --git a/utils/io.py b/utils/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6e31d2fc6ebf3ca888f58dbce96bb3d6c7c2905
--- /dev/null
+++ b/utils/io.py
@@ -0,0 +1,153 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import numpy as np
+import torch
+import torchaudio
+
+
+def save_feature(process_dir, feature_dir, item, feature, overrides=True):
+    """Save features to path
+
+    Args:
+        process_dir (str): directory to store features
+        feature_dir (_type_): directory to store one type of features (mel, energy, ...)
+        item (str): uid
+        feature (tensor): feature tensor
+        overrides (bool, optional): whether to override existing files. Defaults to True.
+    """
+    process_dir = os.path.join(process_dir, feature_dir)
+    os.makedirs(process_dir, exist_ok=True)
+    out_path = os.path.join(process_dir, item + ".npy")
+
+    if overrides or not os.path.exists(out_path):
+        np.save(out_path, feature)
+
+
+def save_txt(process_dir, feature_dir, item, feature, overrides=True):
+    process_dir = os.path.join(process_dir, feature_dir)
+    os.makedirs(process_dir, exist_ok=True)
+    out_path = os.path.join(process_dir, item + ".txt")
+
+    if overrides or not os.path.exists(out_path):
+        with open(out_path, "w") as f:
+            f.writelines(feature)
+
+
+def save_audio(path, waveform, fs, add_silence=False, turn_up=False, volume_peak=0.9):
+    if turn_up:
+        # scale the waveform so that its peak reaches volume_peak
+        ratio = volume_peak / max(waveform.max(), abs(waveform.min()))
+        waveform = waveform * ratio
+
+    if add_silence:
+        silence_len = fs // 20
+        silence = np.zeros((silence_len,), dtype=waveform.dtype)
+        result = np.concatenate([silence, waveform, silence])
+        waveform = result
+
+    waveform = torch.as_tensor(waveform, dtype=torch.float32, device="cpu")
+    if len(waveform.size()) == 1:
+        waveform = waveform[None, :]
+    elif waveform.size(0) != 1:
+        # Stereo to mono
+        waveform = torch.mean(waveform, dim=0, keepdim=True)
+    torchaudio.save(path, waveform, fs, encoding="PCM_S", bits_per_sample=16)
+
+
+async def async_load_audio(path, sample_rate: int = 24000):
+    r"""
+    Args:
+        path: The source loading path.
+        sample_rate: The target sample rate, will automatically resample if necessary.
+
+    Returns:
+        waveform: The waveform object. Should be [1 x sequence_len].
+    """
+
+    async def use_torchaudio_load(path):
+        return torchaudio.load(path)
+
+    waveform, sr = await use_torchaudio_load(path)
+    waveform = torch.mean(waveform, dim=0, keepdim=True)
+
+    if sr != sample_rate:
+        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
+
+    if torch.any(torch.isnan(waveform)) or torch.any(torch.isinf(waveform)):
+        raise ValueError("NaN or Inf found in waveform.")
+    return waveform
+
+
+async def async_save_audio(
+    path,
+    waveform,
+    sample_rate: int = 24000,
+    add_silence: bool = False,
+    volume_peak: float = 0.9,
+):
+    r"""
+    Args:
+        path: The target saving path.
+        waveform: The waveform object. Should be [n_channel x sequence_len].
+        sample_rate: Sample rate.
+        add_silence: If ``true``, concat 0.05s silence to beginning and end.
+        volume_peak: Turn up volume for larger number, vice versa.
+    """
+
+    async def use_torchaudio_save(path, waveform, sample_rate):
+        torchaudio.save(
+            path, waveform, sample_rate, encoding="PCM_S", bits_per_sample=16
+        )
+
+    waveform = torch.as_tensor(waveform, device="cpu", dtype=torch.float32)
+    shape = waveform.size()[:-1]
+
+    ratio = abs(volume_peak) / max(waveform.max(), abs(waveform.min()))
+    waveform = waveform * ratio
+
+    if add_silence:
+        silence_len = sample_rate // 20
+        silence = torch.zeros((*shape, silence_len), dtype=waveform.dtype)
+        waveform = torch.cat((silence, waveform, silence), dim=-1)
+
+    if waveform.dim() == 1:
+        waveform = waveform[None]
+
+    await use_torchaudio_save(path, waveform, sample_rate)
+
+
+def load_mel_extrema(cfg, dataset_name, split):
+    dataset_dir = os.path.join(
+        cfg.OUTPUT_PATH,
+        "preprocess/{}_version".format(cfg.data.process_version),
+        dataset_name,
+    )
+
+    min_file = os.path.join(
+        dataset_dir,
+        "mel_min_max",
+        split.split("_")[-1],
+        "mel_min.npy",
+    )
+    max_file = os.path.join(
+        dataset_dir,
+        "mel_min_max",
+        split.split("_")[-1],
+        "mel_max.npy",
+    )
+    mel_min = np.load(min_file)
+    mel_max = np.load(max_file)
+    return mel_min, mel_max
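+
+
+# Minimal usage sketch for save_audio (illustrative: writes a 1-second 440 Hz
+# tone, peak-normalized and padded with 50 ms of silence, to a temporary file):
+if __name__ == "__main__":
+    import tempfile
+
+    sr = 24000
+    t = np.arange(sr) / sr
+    tone = (0.5 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)
+    out_path = os.path.join(tempfile.gettempdir(), "save_audio_demo.wav")
+    save_audio(out_path, tone, sr, add_silence=True, turn_up=True)
+    print("Wrote", out_path)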
diff --git a/utils/io_optim.py b/utils/io_optim.py
new file mode 100644
index 0000000000000000000000000000000000000000..942619d625d4c8e2d00ae1255421ceaf6ab39986
--- /dev/null
+++ b/utils/io_optim.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torchaudio
+import json
+import os
+import numpy as np
+import librosa
+from torch.nn.utils.rnn import pad_sequence
+from modules import whisper_extractor as whisper
+
+
+class TorchaudioDataset(torch.utils.data.Dataset):
+    def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None):
+        """
+        Args:
+            cfg: config
+            dataset: dataset name
+
+        """
+        assert isinstance(dataset, str)
+
+        self.sr = sr
+        self.cfg = cfg
+
+        if metadata is None:
+            self.train_metadata_path = os.path.join(
+                cfg.preprocess.processed_dir, dataset, cfg.preprocess.train_file
+            )
+            self.valid_metadata_path = os.path.join(
+                cfg.preprocess.processed_dir, dataset, cfg.preprocess.valid_file
+            )
+            self.metadata = self.get_metadata()
+        else:
+            self.metadata = metadata
+
+        if accelerator is not None:
+            self.device = accelerator.device
+        elif torch.cuda.is_available():
+            self.device = torch.device("cuda")
+        else:
+            self.device = torch.device("cpu")
+
+    def get_metadata(self):
+        metadata = []
+        with open(self.train_metadata_path, "r", encoding="utf-8") as t:
+            metadata.extend(json.load(t))
+        with open(self.valid_metadata_path, "r", encoding="utf-8") as v:
+            metadata.extend(json.load(v))
+        return metadata
+
+    def __len__(self):
+        return len(self.metadata)
+
+    def __getitem__(self, index):
+        utt_info = self.metadata[index]
+        wav_path = utt_info["Path"]
+
+        wav, sr = torchaudio.load(wav_path)
+
+        # resample
+        if sr != self.sr:
+            wav = torchaudio.functional.resample(wav, sr, self.sr)
+        # downmixing
+        if wav.shape[0] > 1:
+            wav = torch.mean(wav, dim=0, keepdim=True)
+        assert wav.shape[0] == 1
+        wav = wav.squeeze(0)
+        # record the length of wav without padding
+        length = wav.shape[0]
+        # wav: (T)
+        return utt_info, wav, length
+
+
+class LibrosaDataset(TorchaudioDataset):
+    def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None):
+        super().__init__(cfg, dataset, sr, accelerator, metadata)
+
+    def __getitem__(self, index):
+        utt_info = self.metadata[index]
+        wav_path = utt_info["Path"]
+
+        wav, _ = librosa.load(wav_path, sr=self.sr)
+        # wav: (T)
+        wav = torch.from_numpy(wav)
+
+        # record the length of wav without padding
+        length = wav.shape[0]
+        return utt_info, wav, length
+
+
+class FFmpegDataset(TorchaudioDataset):
+    def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None):
+        super().__init__(cfg, dataset, sr, accelerator, metadata)
+
+    def __getitem__(self, index):
+        utt_info = self.metadata[index]
+        wav_path = utt_info["Path"]
+
+        # wav: (T,)
+        wav = whisper.load_audio(wav_path)  # sr = 16000
+        # convert to torch tensor
+        wav = torch.from_numpy(wav)
+        # record the length of wav without padding
+        length = wav.shape[0]
+
+        return utt_info, wav, length
+
+
+def collate_batch(batch_list):
+    """
+    Args:
+        batch_list: list of (metadata, wav, length)
+    """
+    metadata = [item[0] for item in batch_list]
+    # wavs: (B, T)
+    wavs = pad_sequence([item[1] for item in batch_list], batch_first=True)
+    lens = [item[2] for item in batch_list]
+
+    return metadata, wavs, lens
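+
+
+# Minimal sketch of the batch layout collate_batch expects (synthetic items;
+# real items come from the dataset classes above, assuming this module's own
+# imports resolve):
+if __name__ == "__main__":
+    fake_batch = [
+        ({"Uid": "utt1", "Path": "a.wav"}, torch.randn(16000), 16000),
+        ({"Uid": "utt2", "Path": "b.wav"}, torch.randn(8000), 8000),
+    ]
+    metadata, wavs, lens = collate_batch(fake_batch)
+    print(wavs.shape, lens)  # torch.Size([2, 16000]) [16000, 8000]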
diff --git a/utils/mel.py b/utils/mel.py
new file mode 100644
index 0000000000000000000000000000000000000000..d32d38226dfe6c0162527b3136a392b9168c7f06
--- /dev/null
+++ b/utils/mel.py
@@ -0,0 +1,283 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from librosa.filters import mel as librosa_mel_fn
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+    return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def spectral_normalize_torch(magnitudes):
+    output = dynamic_range_compression_torch(magnitudes)
+    return output
+
+
+def extract_linear_features(y, cfg, center=False):
+    if torch.min(y) < -1.0:
+        print("min value is ", torch.min(y))
+    if torch.max(y) > 1.0:
+        print("max value is ", torch.max(y))
+
+    global hann_window
+    hann_window[str(y.device)] = torch.hann_window(cfg.win_size).to(y.device)
+
+    y = torch.nn.functional.pad(
+        y.unsqueeze(1),
+        (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)),
+        mode="reflect",
+    )
+    y = y.squeeze(1)
+
+    # complex tensor as default, then use view_as_real for future pytorch compatibility
+    spec = torch.stft(
+        y,
+        cfg.n_fft,
+        hop_length=cfg.hop_size,
+        win_length=cfg.win_size,
+        window=hann_window[str(y.device)],
+        center=center,
+        pad_mode="reflect",
+        normalized=False,
+        onesided=True,
+        return_complex=True,
+    )
+    spec = torch.view_as_real(spec)
+    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+    spec = torch.squeeze(spec, 0)
+    return spec
+
+
+def mel_spectrogram_torch(y, cfg, center=False):
+    if torch.min(y) < -1.0:
+        print("min value is ", torch.min(y))
+    if torch.max(y) > 1.0:
+        print("max value is ", torch.max(y))
+
+    global mel_basis, hann_window
+    if str(cfg.fmax) + "_" + str(y.device) not in mel_basis:
+        mel = librosa_mel_fn(
+            sr=cfg.sample_rate,
+            n_fft=cfg.n_fft,
+            n_mels=cfg.n_mel,
+            fmin=cfg.fmin,
+            fmax=cfg.fmax,
+        )
+        mel_basis[str(cfg.fmax) + "_" + str(y.device)] = (
+            torch.from_numpy(mel).float().to(y.device)
+        )
+        hann_window[str(y.device)] = torch.hann_window(cfg.win_size).to(y.device)
+
+    y = torch.nn.functional.pad(
+        y.unsqueeze(1),
+        (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)),
+        mode="reflect",
+    )
+    y = y.squeeze(1)
+
+    spec = torch.stft(
+        y,
+        cfg.n_fft,
+        hop_length=cfg.hop_size,
+        win_length=cfg.win_size,
+        window=hann_window[str(y.device)],
+        center=center,
+        pad_mode="reflect",
+        normalized=False,
+        onesided=True,
+        return_complex=True,
+    )
+
+    spec = torch.view_as_real(spec)
+    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+
+    spec = torch.matmul(mel_basis[str(cfg.fmax) + "_" + str(y.device)], spec)
+    spec = spectral_normalize_torch(spec)
+
+    return spec
+
+
+mel_basis = {}
+hann_window = {}
+
+
+def extract_mel_features(
+    y,
+    cfg,
+    center=False
+    # n_fft, n_mel, sampling_rate, hop_size, win_size, fmin, fmax, center=False
+):
+    """Extract mel features
+
+    Args:
+        y (tensor): audio data in tensor
+        cfg (dict): configuration in cfg.preprocess
+        center (bool, optional): In STFT, whether t-th frame is centered at time t*hop_length. Defaults to False.
+
+    Returns:
+        tensor: a tensor containing the mel feature calculated based on STFT result
+    """
+    if torch.min(y) < -1.0:
+        print("min value is ", torch.min(y))
+    if torch.max(y) > 1.0:
+        print("max value is ", torch.max(y))
+
+    global mel_basis, hann_window
+    if str(cfg.fmax) + "_" + str(y.device) not in mel_basis:
+        mel = librosa_mel_fn(
+            sr=cfg.sample_rate,
+            n_fft=cfg.n_fft,
+            n_mels=cfg.n_mel,
+            fmin=cfg.fmin,
+            fmax=cfg.fmax,
+        )
+        mel_basis[str(cfg.fmax) + "_" + str(y.device)] = (
+            torch.from_numpy(mel).float().to(y.device)
+        )
+        hann_window[str(y.device)] = torch.hann_window(cfg.win_size).to(y.device)
+
+    y = torch.nn.functional.pad(
+        y.unsqueeze(1),
+        (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)),
+        mode="reflect",
+    )
+    y = y.squeeze(1)
+
+    # complex tensor as default, then use view_as_real for future pytorch compatibility
+    spec = torch.stft(
+        y,
+        cfg.n_fft,
+        hop_length=cfg.hop_size,
+        win_length=cfg.win_size,
+        window=hann_window[str(y.device)],
+        center=center,
+        pad_mode="reflect",
+        normalized=False,
+        onesided=True,
+        return_complex=True,
+    )
+    spec = torch.view_as_real(spec)
+    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+
+    spec = torch.matmul(mel_basis[str(cfg.fmax) + "_" + str(y.device)], spec)
+    spec = spectral_normalize_torch(spec)
+
+    return spec.squeeze(0)
+
+
+def extract_mel_features_tts(
+    y,
+    cfg,
+    center=False,
+    taco=False,
+    _stft=None,
+):
+    """Extract mel features
+
+    Args:
+        y (tensor): audio data in tensor
+        cfg (dict): configuration in cfg.preprocess
+        center (bool, optional): In STFT, whether t-th frame is centered at time t*hop_length. Defaults to False.
+        taco: use tacotron mel
+
+    Returns:
+        tensor: a tensor containing the mel feature calculated based on STFT result
+    """
+    if not taco:
+        if torch.min(y) < -1.0:
+            print("min value is ", torch.min(y))
+        if torch.max(y) > 1.0:
+            print("max value is ", torch.max(y))
+
+        global mel_basis, hann_window
+        if str(cfg.fmax) + "_" + str(y.device) not in mel_basis:
+            mel = librosa_mel_fn(
+                sr=cfg.sample_rate,
+                n_fft=cfg.n_fft,
+                n_mels=cfg.n_mel,
+                fmin=cfg.fmin,
+                fmax=cfg.fmax,
+            )
+            mel_basis[str(cfg.fmax) + "_" + str(y.device)] = (
+                torch.from_numpy(mel).float().to(y.device)
+            )
+            hann_window[str(y.device)] = torch.hann_window(cfg.win_size).to(y.device)
+
+        y = torch.nn.functional.pad(
+            y.unsqueeze(1),
+            (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)),
+            mode="reflect",
+        )
+        y = y.squeeze(1)
+
+        # complex tensor as default, then use view_as_real for future pytorch compatibility
+        spec = torch.stft(
+            y,
+            cfg.n_fft,
+            hop_length=cfg.hop_size,
+            win_length=cfg.win_size,
+            window=hann_window[str(y.device)],
+            center=center,
+            pad_mode="reflect",
+            normalized=False,
+            onesided=True,
+            return_complex=True,
+        )
+        spec = torch.view_as_real(spec)
+        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+
+        spec = torch.matmul(mel_basis[str(cfg.fmax) + "_" + str(y.device)], spec)
+        spec = spectral_normalize_torch(spec)
+        spec = spec.squeeze(0)
+    else:
+        audio = torch.clip(y, -1, 1)
+        audio = torch.autograd.Variable(audio, requires_grad=False)
+        spec, energy = _stft.mel_spectrogram(audio)
+        spec = torch.squeeze(spec, 0)
+
+    return spec.squeeze(0)
+
+
+def amplitude_phase_spectrum(y, cfg):
+    hann_window = torch.hann_window(cfg.win_size).to(y.device)
+
+    y = torch.nn.functional.pad(
+        y.unsqueeze(1),
+        (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)),
+        mode="reflect",
+    )
+    y = y.squeeze(1)
+
+    stft_spec = torch.stft(
+        y,
+        cfg.n_fft,
+        hop_length=cfg.hop_size,
+        win_length=cfg.win_size,
+        window=hann_window,
+        center=False,
+        return_complex=True,
+    )
+
+    stft_spec = torch.view_as_real(stft_spec)
+    if stft_spec.size()[0] == 1:
+        stft_spec = stft_spec.squeeze(0)
+
+    if len(list(stft_spec.size())) == 4:
+        rea = stft_spec[:, :, :, 0]  # [batch_size, n_fft//2+1, frames]
+        imag = stft_spec[:, :, :, 1]  # [batch_size, n_fft//2+1, frames]
+    else:
+        rea = stft_spec[:, :, 0]  # [n_fft//2+1, frames]
+        imag = stft_spec[:, :, 1]  # [n_fft//2+1, frames]
+
+    log_amplitude = torch.log(
+        torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5
+    )  # [n_fft//2+1, frames]
+    phase = torch.atan2(imag, rea)  # [n_fft//2+1, frames]
+
+    return log_amplitude, phase, rea, imag
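+
+
+# Minimal usage sketch for extract_mel_features (the cfg values below are
+# illustrative placeholders, not Amphion defaults):
+if __name__ == "__main__":
+    from types import SimpleNamespace
+
+    demo_cfg = SimpleNamespace(
+        sample_rate=24000,
+        n_fft=1024,
+        hop_size=256,
+        win_size=1024,
+        n_mel=80,
+        fmin=0,
+        fmax=12000,
+    )
+    y = torch.rand(1, 24000) * 2 - 1  # (batch, samples), values in [-1, 1)
+    mel = extract_mel_features(y, demo_cfg)
+    print(mel.shape)  # (n_mel, frames) -> torch.Size([80, 93])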
diff --git a/utils/mert.py b/utils/mert.py
new file mode 100644
index 0000000000000000000000000000000000000000..4181429feb36f5013bafafff9505b0b9571485b4
--- /dev/null
+++ b/utils/mert.py
@@ -0,0 +1,139 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://huggingface.co/m-a-p/MERT-v1-330M
+
+import torch
+from tqdm import tqdm
+import numpy as np
+
+from transformers import Wav2Vec2FeatureExtractor
+from transformers import AutoModel
+import torchaudio
+import torchaudio.transforms as T
+from sklearn.preprocessing import StandardScaler
+
+
+def mert_encoder(model, processor, audio_path, hps):
+    """
+    # mert default sr: 24000
+    """
+    with torch.no_grad():
+        resample_rate = processor.sampling_rate
+        device = next(model.parameters()).device
+
+        input_audio, sampling_rate = torchaudio.load(audio_path)
+        input_audio = input_audio.squeeze()
+
+        if sampling_rate != resample_rate:
+            resampler = T.Resample(sampling_rate, resample_rate)
+            input_audio = resampler(input_audio)
+
+        inputs = processor(
+            input_audio, sampling_rate=resample_rate, return_tensors="pt"
+        ).to(
+            device
+        )  # {input_values: tensor, attention_mask: tensor}
+
+        outputs = model(**inputs, output_hidden_states=True)  # list: len is 25
+
+    # [25 layer, Time steps, 1024 feature_dim]
+    # all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
+    # mert_features.append(all_layer_hidden_states)
+
+    feature = outputs.hidden_states[
+        hps.mert_feature_layer
+    ].squeeze()  # [1, frame len, 1024] ->  [frame len, 1024]
+
+    return feature.cpu().detach().numpy()
+
+
+def mert_features_normalization(raw_mert_features):
+    normalized_mert_features = list()
+
+    mert_features = np.array(raw_mert_features)
+    scaler = StandardScaler().fit(mert_features)
+    for raw_mert_feature in raw_mert_features:
+        normalized_mert_feature = scaler.transform(raw_mert_feature)
+        normalized_mert_features.append(normalized_mert_feature)
+    return normalized_mert_features
+
+
+def get_mapped_mert_features(raw_mert_features, mapping_features, fast_mapping=True):
+    source_hop = 320
+    target_hop = 256
+
+    factor = np.gcd(source_hop, target_hop)
+    source_hop //= factor
+    target_hop //= factor
+    print(
+        "Mapping source's {} frames => target's {} frames".format(
+            target_hop, source_hop
+        )
+    )
+
+    mert_features = []
+    for index, mapping_feat in enumerate(tqdm(mapping_features)):
+        # mapping_feat: (mels_frame_len, n_mels)
+        target_len = mapping_feat.shape[0]
+
+        # (frame_len, 1024)
+        raw_feats = raw_mert_features[index].cpu().numpy()
+        source_len, width = raw_feats.shape
+
+        # const ~= target_len * target_hop
+        const = source_len * source_hop // target_hop * target_hop
+
+        # (source_len * source_hop, dim)
+        up_sampling_feats = np.repeat(raw_feats, source_hop, axis=0)
+        # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim)
+        down_sampling_feats = np.average(
+            up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1
+        )
+
+        err = abs(target_len - len(down_sampling_feats))
+        if err > 3:
+            print("index:", index)
+            print("mels:", mapping_feat.shape)
+            print("raw mert vector:", raw_feats.shape)
+            print("up_sampling:", up_sampling_feats.shape)
+            print("const:", const)
+            print("down_sampling_feats:", down_sampling_feats.shape)
+            exit()
+        if len(down_sampling_feats) < target_len:
+            # (1, dim) -> (err, dim)
+            end = down_sampling_feats[-1][None, :].repeat(err, axis=0)
+            down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0)
+
+        # (target_len, dim)
+        feats = down_sampling_feats[:target_len]
+        mert_features.append(feats)
+
+    return mert_features
+
+
+def load_mert_model(hps):
+    print("Loading MERT Model: ", hps.mert_model)
+
+    # Load model
+    model_name = hps.mert_model
+    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
+
+    if torch.cuda.is_available():
+        model = model.cuda()
+
+    # model = model.eval()
+
+    preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(
+        model_name, trust_remote_code=True
+    )
+    return model, preprocessor
+
+
+# loading the corresponding preprocessor config
+# def load_preprocessor (model_name="m-a-p/MERT-v1-330M"):
+#     print('load_preprocessor...')
+#     preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(model_name,trust_remote_code=True)
+#     return preprocessor
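+
+
+# Usage sketch (comment-only: it would download the MERT checkpoint from the
+# Hugging Face Hub; the `hps` fields mirror what the functions above read, and
+# the audio path is hypothetical):
+#
+#   from types import SimpleNamespace
+#   hps = SimpleNamespace(mert_model="m-a-p/MERT-v1-330M", mert_feature_layer=-1)
+#   model, processor = load_mert_model(hps)
+#   feature = mert_encoder(model, processor, "/path/to/song.wav", hps)
+#   print(feature.shape)  # (frame_len, 1024)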
diff --git a/utils/model_summary.py b/utils/model_summary.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec72b0d17869dc0886a9eb665efedb70bb307bbf
--- /dev/null
+++ b/utils/model_summary.py
@@ -0,0 +1,74 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import humanfriendly
+import numpy as np
+import torch
+
+
+def get_human_readable_count(number: int) -> str:
+    """Return human_readable_count
+
+    Originated from:
+    https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/memory.py
+
+    Abbreviates an integer number with K, M, B, T for thousands, millions,
+    billions and trillions, respectively.
+    Examples:
+        >>> get_human_readable_count(123)
+        '123.00  '
+        >>> get_human_readable_count(1234)  # (one thousand)
+        '1.23 K'
+        >>> get_human_readable_count(2e6)   # (two million)
+        '2.00 M'
+        >>> get_human_readable_count(3e9)   # (three billion)
+        '3.00 B'
+        >>> get_human_readable_count(4e12)  # (four trillion)
+        '4.00 T'
+        >>> get_human_readable_count(5e15)  # (more than a trillion)
+        '5000.00 T'
+    Args:
+        number: a positive integer number
+    Return:
+        A string formatted according to the pattern described above.
+    """
+    assert number >= 0
+    labels = [" ", "K", "M", "B", "T"]
+    num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1)
+    num_groups = int(np.ceil(num_digits / 3))
+    num_groups = min(num_groups, len(labels))
+    shift = -3 * (num_groups - 1)
+    number = number * (10**shift)
+    index = num_groups - 1
+    return f"{number:.2f} {labels[index]}"
+
+
+def to_bytes(dtype) -> int:
+    return int(str(dtype)[-2:]) // 8
+
+
+def model_summary(model: torch.nn.Module) -> str:
+    message = "Model structure:\n"
+    message += str(model)
+    tot_params = sum(p.numel() for p in model.parameters())
+    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    percent_trainable = "{:.1f}".format(num_params * 100.0 / tot_params)
+    tot_params = get_human_readable_count(tot_params)
+    num_params = get_human_readable_count(num_params)
+    message += "\n\nModel summary:\n"
+    message += f"    Class Name: {model.__class__.__name__}\n"
+    message += f"    Total Number of model parameters: {tot_params}\n"
+    message += (
+        f"    Number of trainable parameters: {num_params} ({percent_trainable}%)\n"
+    )
+    num_bytes = humanfriendly.format_size(
+        sum(
+            p.numel() * to_bytes(p.dtype) for p in model.parameters() if p.requires_grad
+        )
+    )
+    message += f"    Size: {num_bytes}\n"
+    dtype = next(iter(model.parameters())).dtype
+    message += f"    Type: {dtype}"
+    return message
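+
+
+# Quick self-check (illustrative; any torch.nn.Module works):
+if __name__ == "__main__":
+    print(model_summary(torch.nn.Linear(256, 80)))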
diff --git a/utils/prompt_preparer.py b/utils/prompt_preparer.py
new file mode 100644
index 0000000000000000000000000000000000000000..dba833ee72cd661bbc178ffc4735e05a4e4949c9
--- /dev/null
+++ b/utils/prompt_preparer.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+
+class PromptPreparer:
+    def prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes):
+        if self.prefix_mode == 0:
+            y_emb, prefix_len = self._handle_prefix_mode_0(y, codes, nar_stage)
+        elif self.prefix_mode == 1:
+            y_emb, prefix_len = self._handle_prefix_mode_1(y, y_lens, codes, nar_stage)
+        elif self.prefix_mode in [2, 4]:
+            y_emb, prefix_len = self._handle_prefix_mode_2_4(y, y_lens, codes, nar_stage, y_prompts_codes)
+        else:
+            raise ValueError("Invalid prefix mode")
+
+        return y_emb, prefix_len
+
+    def _handle_prefix_mode_0(self, y, codes, nar_stage):
+        prefix_len = 0
+        y_emb = self.nar_audio_embeddings[0](y)
+        for j in range(1, nar_stage):
+            y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j])
+        return y_emb, prefix_len
+
+    def _handle_prefix_mode_1(self, y, y_lens, codes, nar_stage):
+        int_low = (0.25 * y_lens.min()).type(torch.int64).item()
+        prefix_len = torch.randint(int_low, int_low * 2, size=()).item()
+        prefix_len = min(prefix_len, 225) 
+
+        y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
+        y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
+        for j in range(1, self.num_quantizers):
+            y_prompts += self.nar_audio_embeddings[j](
+                codes[:, :prefix_len, j]
+            )
+            if j < nar_stage:
+                y_emb += self.nar_audio_embeddings[j](
+                    codes[:, prefix_len:, j]
+                )
+        y_emb = torch.concat([y_prompts, y_emb], axis=1)
+        return y_emb, prefix_len
+
+    def _handle_prefix_mode_2_4(self, y, y_lens, codes, nar_stage, y_prompts_codes):
+        if self.prefix_mode == 2:
+            prefix_len = min(225, int(0.25 * y_lens.min().item()))
+
+            y_prompts_codes = []
+            for b in range(codes.shape[0]):
+                start = self.rng.randint(0, y_lens[b].item() - prefix_len)
+                y_prompts_codes.append(
+                    torch.clone(codes[b, start : start + prefix_len])
+                )
+                codes[
+                    b, start : start + prefix_len, nar_stage
+                ] = self.audio_token_num
+            y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
+        else:
+            prefix_len = y_prompts_codes.shape[1]
+
+        y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
+        y_emb = self.nar_audio_embeddings[0](y)
+        for j in range(1, self.num_quantizers):
+            y_prompts += self.nar_audio_embeddings[j](
+                y_prompts_codes[..., j]
+            )
+            if j < nar_stage:
+                y_emb += self.nar_audio_embeddings[j](codes[..., j])
+        y_emb = torch.concat([y_prompts, y_emb], axis=1)
+        
+        return y_emb, prefix_len
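+
+
+# Note: PromptPreparer is intended as a mixin. The host model is expected to
+# provide `self.prefix_mode`, `self.nar_audio_embeddings` (one embedding per
+# quantizer), `self.num_quantizers`, `self.audio_token_num`, and `self.rng`
+# (used by prefix mode 2). This summary is inferred from the method bodies above.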
diff --git a/utils/ssim.py b/utils/ssim.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0b95007b2c225f3cae869f0653a733d1d92043a
--- /dev/null
+++ b/utils/ssim.py
@@ -0,0 +1,80 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from https://github.com/Po-Hsun-Su/pytorch-ssim
+
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable
+from math import exp
+
+
+def gaussian(window_size, sigma):
+    gauss = torch.Tensor(
+        [
+            exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2))
+            for x in range(window_size)
+        ]
+    )
+    return gauss / gauss.sum()
+
+
+def create_window(window_size, channel):
+    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
+    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+    window = Variable(
+        _2D_window.expand(channel, 1, window_size, window_size).contiguous()
+    )
+    return window
+
+
+def _ssim(img1, img2, window, window_size, channel, size_average=True):
+    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
+    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
+
+    mu1_sq = mu1.pow(2)
+    mu2_sq = mu2.pow(2)
+    mu1_mu2 = mu1 * mu2
+
+    sigma1_sq = (
+        F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
+    )
+    sigma2_sq = (
+        F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
+    )
+    sigma12 = (
+        F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)
+        - mu1_mu2
+    )
+
+    C1 = 0.01**2
+    C2 = 0.03**2
+
+    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
+        (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
+    )
+
+    if size_average:
+        return ssim_map.mean()
+    else:
+        return ssim_map.mean(1)
+
+
+class SSIM(torch.nn.Module):
+    def __init__(self, window_size=11, size_average=True):
+        super(SSIM, self).__init__()
+        self.window_size = window_size
+        self.size_average = size_average
+        self.channel = 1
+        self.window = create_window(window_size, self.channel)
+
+    def forward(self, fake, real, bias=6.0):
+        fake = fake[:, None, :, :] + bias  # [B, 1, T, n_mels]
+        real = real[:, None, :, :] + bias  # [B, 1, T, n_mels]
+        self.window = self.window.to(dtype=fake.dtype, device=fake.device)
+        loss = 1 - _ssim(
+            fake, real, self.window, self.window_size, self.channel, self.size_average
+        )
+        return loss
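+
+
+# Quick usage sketch (illustrative shapes: batch of 2, 100 frames, 80 mel bins):
+if __name__ == "__main__":
+    criterion = SSIM()
+    fake_mel = torch.rand(2, 100, 80)
+    real_mel = torch.rand(2, 100, 80)
+    print(criterion(fake_mel, real_mel))  # scalar loss, 1 - SSIM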
diff --git a/utils/stft.py b/utils/stft.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcec4c84ace0cc40d65361316222b090428cc391
--- /dev/null
+++ b/utils/stft.py
@@ -0,0 +1,278 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+from scipy.signal import get_window
+from librosa.util import pad_center, tiny
+from librosa.filters import mel as librosa_mel_fn
+
+import torch
+import numpy as np
+import librosa.util as librosa_util
+from scipy.signal import get_window
+
+
+def window_sumsquare(
+    window,
+    n_frames,
+    hop_length,
+    win_length,
+    n_fft,
+    dtype=np.float32,
+    norm=None,
+):
+    """
+    # from librosa 0.6
+    Compute the sum-square envelope of a window function at a given hop length.
+
+    This is used to estimate modulation effects induced by windowing
+    observations in short-time fourier transforms.
+
+    Parameters
+    ----------
+    window : string, tuple, number, callable, or list-like
+        Window specification, as in `get_window`
+
+    n_frames : int > 0
+        The number of analysis frames
+
+    hop_length : int > 0
+        The number of samples to advance between frames
+
+    win_length : [optional]
+        The length of the window function.  By default, this matches `n_fft`.
+
+    n_fft : int > 0
+        The length of each analysis frame.
+
+    dtype : np.dtype
+        The data type of the output
+
+    Returns
+    -------
+    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
+        The sum-squared envelope of the window function
+    """
+    if win_length is None:
+        win_length = n_fft
+
+    n = n_fft + hop_length * (n_frames - 1)
+    x = np.zeros(n, dtype=dtype)
+
+    # Compute the squared window at the desired length
+    win_sq = get_window(window, win_length, fftbins=True)
+    win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
+    win_sq = librosa_util.pad_center(win_sq, size=n_fft)
+
+    # Fill the envelope
+    for i in range(n_frames):
+        sample = i * hop_length
+        x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
+    return x
+
+
+def griffin_lim(magnitudes, stft_fn, n_iters=30):
+    """
+    PARAMS
+    ------
+    magnitudes: spectrogram magnitudes
+    stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
+    """
+
+    angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
+    angles = angles.astype(np.float32)
+    angles = torch.autograd.Variable(torch.from_numpy(angles))
+    signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
+
+    for i in range(n_iters):
+        _, angles = stft_fn.transform(signal)
+        signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
+    return signal
+
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+    """
+    PARAMS
+    ------
+    C: compression factor
+    """
+    return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+    """
+    PARAMS
+    ------
+    C: compression factor used to compress
+    """
+    return torch.exp(x) / C
+
+
+class STFT(torch.nn.Module):
+    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
+
+    def __init__(self, filter_length, hop_length, win_length, window="hann"):
+        super(STFT, self).__init__()
+        self.filter_length = filter_length
+        self.hop_length = hop_length
+        self.win_length = win_length
+        self.window = window
+        self.forward_transform = None
+        scale = self.filter_length / self.hop_length
+        fourier_basis = np.fft.fft(np.eye(self.filter_length))
+
+        cutoff = int((self.filter_length / 2 + 1))
+        fourier_basis = np.vstack(
+            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
+        )
+
+        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
+        inverse_basis = torch.FloatTensor(
+            np.linalg.pinv(scale * fourier_basis).T[:, None, :]
+        )
+
+        if window is not None:
+            assert filter_length >= win_length
+            # get window and zero center pad it to filter_length
+            fft_window = get_window(window, win_length, fftbins=True)
+            fft_window = pad_center(fft_window, size=filter_length)
+            fft_window = torch.from_numpy(fft_window).float()
+
+            # window the bases
+            forward_basis *= fft_window
+            inverse_basis *= fft_window
+
+        self.register_buffer("forward_basis", forward_basis.float())
+        self.register_buffer("inverse_basis", inverse_basis.float())
+
+    def transform(self, input_data):
+        num_batches = input_data.size(0)
+        num_samples = input_data.size(1)
+
+        self.num_samples = num_samples
+
+        # similar to librosa, reflect-pad the input
+        input_data = input_data.view(num_batches, 1, num_samples)
+        input_data = F.pad(
+            input_data.unsqueeze(1),
+            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
+            mode="reflect",
+        )
+        input_data = input_data.squeeze(1)
+
+        forward_transform = F.conv1d(
+            input_data.cuda(),
+            torch.autograd.Variable(self.forward_basis, requires_grad=False).cuda(),
+            stride=self.hop_length,
+            padding=0,
+        ).cpu()
+
+        cutoff = int((self.filter_length / 2) + 1)
+        real_part = forward_transform[:, :cutoff, :]
+        imag_part = forward_transform[:, cutoff:, :]
+
+        magnitude = torch.sqrt(real_part**2 + imag_part**2)
+        phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
+
+        return magnitude, phase
+
+    def inverse(self, magnitude, phase):
+        recombine_magnitude_phase = torch.cat(
+            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
+        )
+
+        inverse_transform = F.conv_transpose1d(
+            recombine_magnitude_phase,
+            torch.autograd.Variable(self.inverse_basis, requires_grad=False),
+            stride=self.hop_length,
+            padding=0,
+        )
+
+        if self.window is not None:
+            window_sum = window_sumsquare(
+                self.window,
+                magnitude.size(-1),
+                hop_length=self.hop_length,
+                win_length=self.win_length,
+                n_fft=self.filter_length,
+                dtype=np.float32,
+            )
+            # remove modulation effects
+            approx_nonzero_indices = torch.from_numpy(
+                np.where(window_sum > tiny(window_sum))[0]
+            )
+            window_sum = torch.autograd.Variable(
+                torch.from_numpy(window_sum), requires_grad=False
+            )
+            window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
+            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
+                approx_nonzero_indices
+            ]
+
+            # scale by hop ratio
+            inverse_transform *= float(self.filter_length) / self.hop_length
+
+        inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
+        inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
+
+        return inverse_transform
+
+    def forward(self, input_data):
+        self.magnitude, self.phase = self.transform(input_data)
+        reconstruction = self.inverse(self.magnitude, self.phase)
+        return reconstruction
+
+
+class TacotronSTFT(torch.nn.Module):
+    def __init__(
+        self,
+        filter_length,
+        hop_length,
+        win_length,
+        n_mel_channels,
+        sampling_rate,
+        mel_fmin,
+        mel_fmax,
+    ):
+        super(TacotronSTFT, self).__init__()
+        self.n_mel_channels = n_mel_channels
+        self.sampling_rate = sampling_rate
+        self.stft_fn = STFT(filter_length, hop_length, win_length)
+        mel_basis = librosa_mel_fn(
+            sr=sampling_rate,
+            n_fft=filter_length,
+            n_mels=n_mel_channels,
+            fmin=mel_fmin,
+            fmax=mel_fmax,
+        )
+        mel_basis = torch.from_numpy(mel_basis).float()
+        self.register_buffer("mel_basis", mel_basis)
+
+    def spectral_normalize(self, magnitudes):
+        output = dynamic_range_compression(magnitudes)
+        return output
+
+    def spectral_de_normalize(self, magnitudes):
+        output = dynamic_range_decompression(magnitudes)
+        return output
+
+    def mel_spectrogram(self, y):
+        """Computes mel-spectrograms from a batch of waves
+        PARAMS
+        ------
+        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
+
+        RETURNS
+        -------
+        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
+        """
+        assert torch.min(y.data) >= -1
+        assert torch.max(y.data) <= 1
+
+        magnitudes, phases = self.stft_fn.transform(y)
+        magnitudes = magnitudes.data
+        mel_output = torch.matmul(self.mel_basis, magnitudes)
+        mel_output = self.spectral_normalize(mel_output)
+        energy = torch.norm(magnitudes, dim=1)
+
+        return mel_output, energy
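+
+
+# Usage sketch for TacotronSTFT (comment-only: STFT.transform moves tensors to
+# CUDA, so a GPU is assumed; the parameter values are illustrative):
+#
+#   stft = TacotronSTFT(filter_length=1024, hop_length=256, win_length=1024,
+#                       n_mel_channels=80, sampling_rate=22050,
+#                       mel_fmin=0.0, mel_fmax=8000.0)
+#   mel, energy = stft.mel_spectrogram(wav)  # wav: FloatTensor (B, T) in [-1, 1]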
diff --git a/utils/symbol_table.py b/utils/symbol_table.py
new file mode 100644
index 0000000000000000000000000000000000000000..730ffe7a8018f80f662e260542a859a6f6f74a47
--- /dev/null
+++ b/utils/symbol_table.py
@@ -0,0 +1,313 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from 
+# https://github.com/lifeiteng/vall-e/blob/9c69096d603ce13174fb5cb025f185e2e9b36ac7/valle/utils/symbol_table.py 
+
+from dataclasses import dataclass
+from dataclasses import field
+from typing import Dict
+from typing import Generic
+from typing import List
+from typing import Optional
+from typing import TypeVar
+from typing import Union
+
+Symbol = TypeVar('Symbol')
+
+
+@dataclass(repr=False)
+class SymbolTable(Generic[Symbol]):
+    '''SymbolTable that maps symbol IDs, found on the FSA arcs to
+    actual objects. These objects can be arbitrary Python objects
+    that can serve as keys in a dictionary (i.e. they need to be
+    hashable and immutable).
+
+    The SymbolTable can only be read to/written from disk if the
+    symbols are strings.
+    '''
+    _id2sym: Dict[int, Symbol] = field(default_factory=dict)
+    '''Map an integer to a symbol.
+    '''
+
+    _sym2id: Dict[Symbol, int] = field(default_factory=dict)
+    '''Map a symbol to an integer.
+    '''
+
+    _next_available_id: int = 1
+    '''A helper internal field that helps adding new symbols
+    to the table efficiently.
+    '''
+
+    eps: Symbol = '<eps>'
+    '''Null symbol, always mapped to index 0.
+    '''
+
+    def __post_init__(self):
+        assert all(self._sym2id[sym] == idx for idx, sym in self._id2sym.items())
+        assert all(self._id2sym[idx] == sym for sym, idx in self._sym2id.items())
+        assert 0 not in self._id2sym or self._id2sym[0] == self.eps
+
+        self._next_available_id = max(self._id2sym, default=0) + 1
+        self._id2sym.setdefault(0, self.eps)
+        self._sym2id.setdefault(self.eps, 0)
+
+
+    @staticmethod
+    def from_str(s: str) -> 'SymbolTable':
+        '''Build a symbol table from a string.
+
+        The string consists of lines. Every line has two fields separated
+        by space(s), tab(s) or both. The first field is the symbol and the
+        second the integer id of the symbol.
+
+        Args:
+          s:
+            The input string with the format described above.
+        Returns:
+          An instance of :class:`SymbolTable`.
+        '''
+        id2sym: Dict[int, str] = dict()
+        sym2id: Dict[str, int] = dict()
+
+        for line in s.split('\n'):
+            fields = line.split()
+            if len(fields) == 0:
+                continue  # skip empty lines
+            assert len(fields) == 2, \
+                    f'Expect a line with 2 fields. Given: {len(fields)}'
+            sym, idx = fields[0], int(fields[1])
+            assert sym not in sym2id, f'Duplicated symbol {sym}'
+            assert idx not in id2sym, f'Duplicated id {idx}'
+            id2sym[idx] = sym
+            sym2id[sym] = idx
+
+        eps = id2sym.get(0, '<eps>')
+
+        return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=eps)
+
+    @staticmethod
+    def from_file(filename: str) -> 'SymbolTable':
+        '''Build a symbol table from file.
+
+        Every line in the symbol table file has two fields separated by
+        space(s), tab(s) or both. The following is an example file:
+
+        .. code-block::
+
+            <eps> 0
+            a 1
+            b 2
+            c 3
+
+        Args:
+          filename:
+            Name of the symbol table file. Its format is documented above.
+
+        Returns:
+          An instance of :class:`SymbolTable`.
+
+        '''
+        with open(filename, 'r', encoding='utf-8') as f:
+            return SymbolTable.from_str(f.read().strip())
+
+    def to_str(self) -> str:
+        '''
+        Returns:
+          Return a string representation of this object. You can pass
+          it to the method ``from_str`` to recreate an identical object.
+        '''
+        s = ''
+        for idx, symbol in sorted(self._id2sym.items()):
+            s += f'{symbol} {idx}\n'
+        return s
+
+    def to_file(self, filename: str):
+        '''Serialize the SymbolTable to a file.
+
+        Every line in the symbol table file has two fields separated by
+        space(s), tab(s) or both. The following is an example file:
+
+        .. code-block::
+
+            <eps> 0
+            a 1
+            b 2
+            c 3
+
+        Args:
+          filename:
+            Name of the symbol table file. Its format is documented above.
+        '''
+        with open(filename, 'w') as f:
+            for idx, symbol in sorted(self._id2sym.items()):
+                print(symbol, idx, file=f)
+
+    def add(self, symbol: Symbol, index: Optional[int] = None) -> int:
+        '''Add a new symbol to the SymbolTable.
+
+        Args:
+            symbol:
+                The symbol to be added.
+            index:
+                Optional int id to which the symbol should be assigned.
+                If it is not available, a ValueError will be raised.
+
+        Returns:
+            The int id to which the symbol has been assigned.
+        '''
+        # Already in the table? Return its ID.
+        if symbol in self._sym2id:
+            return self._sym2id[symbol]
+        # Specific ID not provided - use next available.
+        if index is None:
+            index = self._next_available_id
+        # Specific ID provided but not available.
+        if index in self._id2sym:
+            raise ValueError(f"Cannot assign id '{index}' to '{symbol}' - "
+                             f"already occupied by {self._id2sym[index]}")
+        self._sym2id[symbol] = index
+        self._id2sym[index] = symbol
+
+        # Update next available ID if needed
+        if self._next_available_id <= index:
+            self._next_available_id = index + 1
+
+        return index
+
+    def get(self, k: Union[int, Symbol]) -> Union[Symbol, int]:
+        '''Get a symbol for an id or get an id for a symbol
+
+        Args:
+          k:
+            If it is an id, it tries to find the symbol corresponding
+            to the id; if it is a symbol, it tries to find the id
+            corresponding to the symbol.
+
+        Returns:
+          An id or a symbol depending on the given `k`.
+        '''
+        if isinstance(k, int):
+            return self._id2sym[k]
+        else:
+            return self._sym2id[k]
+
+    def merge(self, other: 'SymbolTable') -> 'SymbolTable':
+        '''Create a union of two SymbolTables.
+        Raises an AssertionError if the same IDs are occupied by
+        different symbols.
+
+        Args:
+            other:
+                A symbol table to merge with ``self``.
+
+        Returns:
+            A new symbol table.
+        '''
+        self._check_compatible(other)
+        return SymbolTable(
+            _id2sym={**self._id2sym, **other._id2sym},
+            _sym2id={**self._sym2id, **other._sym2id},
+            eps=self.eps
+        )
+
+    def _check_compatible(self, other: 'SymbolTable') -> None:
+        # Epsilon compatibility
+        assert self.eps == other.eps, f'Mismatched epsilon symbol: ' \
+                                      f'{self.eps} != {other.eps}'
+        # IDs compatibility
+        common_ids = set(self._id2sym).intersection(other._id2sym)
+        for idx in common_ids:
+            assert self[idx] == other[idx], f'ID conflict for id: {idx}, ' \
+                                            f'self[idx] = "{self[idx]}", ' \
+                                            f'other[idx] = "{other[idx]}"'
+        # Symbols compatibility
+        common_symbols = set(self._sym2id).intersection(other._sym2id)
+        for sym in common_symbols:
+            assert self[sym] == other[sym], f'ID conflict for symbol: {sym}, ' \
+                                            f'self[sym] = "{self[sym]}", ' \
+                                            f'other[sym] = "{other[sym]}"'
+
+    def __getitem__(self, item: Union[int, Symbol]) -> Union[Symbol, int]:
+        return self.get(item)
+
+    def __contains__(self, item: Union[int, Symbol]) -> bool:
+        if isinstance(item, int):
+            return item in self._id2sym
+        else:
+            return item in self._sym2id
+
+    def __len__(self) -> int:
+        return len(self._id2sym)
+
+    def __eq__(self, other: 'SymbolTable') -> bool:
+        if len(self) != len(other):
+            return False
+
+        for s in self.symbols:
+            if self[s] != other[s]:
+                return False
+
+        return True
+
+    @property
+    def ids(self) -> List[int]:
+        '''Returns a list of integer IDs corresponding to the symbols.
+        '''
+        ans = list(self._id2sym.keys())
+        ans.sort()
+        return ans
+
+    @property
+    def symbols(self) -> List[Symbol]:
+        '''Returns a list of symbols (e.g., strings) corresponding to
+        the integer IDs.
+        '''
+        ans = list(self._sym2id.keys())
+        ans.sort()
+        return ans
+
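+# Illustrative usage of SymbolTable (a sketch; the string and ids below are made up):
+#
+#     table = SymbolTable.from_str("<eps> 0\na 1\nb 2")
+#     table.add("c")        # -> 3 (next available id)
+#     table["a"]            # -> 1
+#     table[2]              # -> "b"
+#     "a" in table          # -> True
+#     table.to_str()        # round-trips through from_str()
+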
+
+class TextToken:
+    def __init__(
+        self,
+        text_tokens: List[str],
+        add_eos: bool = True,
+        add_bos: bool = True,
+        pad_symbol: str = "<pad>",
+        bos_symbol: str = "<bos>",
+        eos_symbol: str = "<eos>",
+    ):
+        self.pad_symbol = pad_symbol
+        self.add_eos = add_eos
+        self.add_bos = add_bos
+        self.bos_symbol = bos_symbol
+        self.eos_symbol = eos_symbol
+
+        unique_tokens = [pad_symbol]
+        if add_bos:
+            unique_tokens.append(bos_symbol)
+        if add_eos:
+            unique_tokens.append(eos_symbol)
+        unique_tokens.extend(sorted(text_tokens))
+
+        self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
+        self.idx2token = unique_tokens
+
+    def get_token_id_seq(self, text):
+        tokens_seq = [p for p in text]
+        seq = (
+            ([self.bos_symbol] if self.add_bos else [])
+            + tokens_seq
+            + ([self.eos_symbol] if self.add_eos else [])
+        )
+
+        token_ids = [self.token2idx[token] for token in seq]
+        token_lens = len(tokens_seq) + self.add_eos + self.add_bos
+
+        return token_ids, token_lens
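+
+
+# Illustrative usage of TextToken (a sketch; the toy vocabulary is made up):
+#
+#     tt = TextToken(["a", "b", "c"])
+#     # token2idx: {"<pad>": 0, "<bos>": 1, "<eos>": 2, "a": 3, "b": 4, "c": 5}
+#     ids, length = tt.get_token_id_seq("ab")
+#     # ids -> [1, 3, 4, 2], length -> 4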
\ No newline at end of file
diff --git a/utils/tokenizer.py b/utils/tokenizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8971432bdcc2f3f2920775bb3397c90f2a91f8b8
--- /dev/null
+++ b/utils/tokenizer.py
@@ -0,0 +1,151 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This code is modified from 
+# https://github.com/lifeiteng/vall-e/blob/9c69096d603ce13174fb5cb025f185e2e9b36ac7/valle/data/tokenizer.py
+
+import re
+from typing import Any, Dict, List, Optional, Pattern, Union
+
+import torch
+import torchaudio
+from encodec import EncodecModel
+from encodec.utils import convert_audio
+
+
+
+class AudioTokenizer:
+    """EnCodec audio tokenizer for encoding and decoding audio.
+
+    Attributes:
+        device: The device on which the codec model is loaded.
+        codec: The pretrained EnCodec model.
+        sample_rate: Sample rate of the model.
+        channels: Number of audio channels in the model.
+    """
+
+    def __init__(self, device: Any = None) -> None:
+        model = EncodecModel.encodec_model_24khz()
+        model.set_target_bandwidth(6.0)
+        remove_encodec_weight_norm(model)
+
+        if not device:
+            device = torch.device("cpu")
+            if torch.cuda.is_available():
+                device = torch.device("cuda:0")
+
+        self._device = device
+
+        self.codec = model.to(device)
+        self.sample_rate = model.sample_rate
+        self.channels = model.channels
+
+    @property
+    def device(self):
+        return self._device
+
+    def encode(self, wav: torch.Tensor) -> torch.Tensor:
+        """Encode the audio waveform.
+
+        Args:
+            wav: A tensor representing the audio waveform.
+
+        Returns:
+            A tensor representing the encoded audio.
+        """
+        return self.codec.encode(wav.to(self.device))
+
+    def decode(self, frames: torch.Tensor) -> torch.Tensor:
+        """Decode the encoded audio frames.
+
+        Args:
+            frames: A tensor representing the encoded audio frames.
+
+        Returns:
+            A tensor representing the decoded audio waveform.
+        """
+        return self.codec.decode(frames)
+
+
+
+def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str):
+    """
+    Tokenize the audio waveform using the given AudioTokenizer.
+
+    Args:
+        tokenizer: An instance of AudioTokenizer.
+        audio_path: Path to the audio file.
+
+    Returns:
+        A tensor of encoded frames from the audio.
+
+    Raises:
+        FileNotFoundError: If the audio file is not found.
+        RuntimeError: If there's an error processing the audio data.
+    """
+    try:
+        # Load and preprocess the audio waveform
+        wav, sr = torchaudio.load(audio_path)
+        wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels)
+        wav = wav.unsqueeze(0)
+
+        # Extract discrete codes from EnCodec
+        with torch.no_grad():
+            encoded_frames = tokenizer.encode(wav)
+        return encoded_frames
+    except FileNotFoundError:
+        raise FileNotFoundError(f"Audio file not found at {audio_path}")
+    except Exception as e:
+        raise RuntimeError(f"Error processing audio data: {e}")
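+
+
+# Illustrative usage (a sketch; "song.wav" is a placeholder path):
+#
+#     tokenizer = AudioTokenizer()                    # uses cuda:0 if available
+#     frames = tokenize_audio(tokenizer, "song.wav")  # list of (codes, scale) frames
+#     codes = frames[0][0]                            # (B, n_q, T) discrete codes
+#     wav_rec = tokenizer.decode(frames)              # reconstructed waveform at 24 kHz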
+
+
+
+def remove_encodec_weight_norm(model):
+    from encodec.modules import SConv1d
+    from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock
+    from torch.nn.utils import remove_weight_norm
+
+    encoder = model.encoder.model
+    for key in encoder._modules:
+        if isinstance(encoder._modules[key], SEANetResnetBlock):
+            remove_weight_norm(encoder._modules[key].shortcut.conv.conv)
+            block_modules = encoder._modules[key].block._modules
+            for skey in block_modules:
+                if isinstance(block_modules[skey], SConv1d):
+                    remove_weight_norm(block_modules[skey].conv.conv)
+        elif isinstance(encoder._modules[key], SConv1d):
+            remove_weight_norm(encoder._modules[key].conv.conv)
+
+    decoder = model.decoder.model
+    for key in decoder._modules:
+        if isinstance(decoder._modules[key], SEANetResnetBlock):
+            remove_weight_norm(decoder._modules[key].shortcut.conv.conv)
+            block_modules = decoder._modules[key].block._modules
+            for skey in block_modules:
+                if isinstance(block_modules[skey], SConv1d):
+                    remove_weight_norm(block_modules[skey].conv.conv)
+        elif isinstance(decoder._modules[key], SConvTranspose1d):
+            remove_weight_norm(decoder._modules[key].convtr.convtr)
+        elif isinstance(decoder._modules[key], SConv1d):
+            remove_weight_norm(decoder._modules[key].conv.conv)
+
+
+def extract_encodec_token(wav_path):
+    model = EncodecModel.encodec_model_24khz()
+    model.set_target_bandwidth(6.0)
+
+    wav, sr = torchaudio.load(wav_path)
+    wav = convert_audio(wav, sr, model.sample_rate, model.channels)
+    wav = wav.unsqueeze(0)
+    if torch.cuda.is_available():
+        model = model.cuda()
+        wav = wav.cuda()
+    with torch.no_grad():
+        encoded_frames = model.encode(wav)
+        codes_ = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)  # [B, n_q, T]
+        codes = codes_.cpu().numpy()[0, :, :].T  # [T, 8]
+
+        return codes
\ No newline at end of file
diff --git a/utils/topk_sampling.py b/utils/topk_sampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..03f405c15506609bbd571818123cf7f0db7f1f4b
--- /dev/null
+++ b/utils/topk_sampling.py
@@ -0,0 +1,87 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import torch.nn.functional as F
+
+
+# This function is modified from https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
+def top_k_top_p_filtering(
+    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
+):
+    """
+    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
+
+    Args:
+        logits (torch.Tensor): Logits distribution with shape (batch size, vocabulary size).
+        top_k (int, optional): Keep only top k tokens with highest probability (top-k filtering).
+                               Set to 0 to disable. Defaults to 0.
+        top_p (float, optional): Keep the top tokens with a cumulative probability >= top_p (nucleus filtering).
+                                 Must be between 0 and 1, inclusive. Defaults to 1.0.
+        filter_value (float, optional): The value to assign to filtered logits. Defaults to -float('Inf').
+        min_tokens_to_keep (int, optional): Ensure that at least this number of tokens are kept per batch example.
+                                            Defaults to 1.
+
+    Returns:
+        torch.Tensor: The filtered logits.
+    """
+    """
+        Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
+        Make sure we keep at least min_tokens_to_keep per batch example in the output
+        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
+    """
+    if top_k > 0:
+        # Apply top-k filtering
+        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))
+        indices_to_remove = logits < torch.topk(logits, top_k).values[..., -1, None]
+        logits[indices_to_remove] = filter_value
+
+    if top_p < 1.0:
+        # Apply top-p filtering
+        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+        # Create a mask to remove tokens with cumulative probability above the top_p threshold
+        sorted_indices_to_remove = cumulative_probs > top_p
+        if min_tokens_to_keep > 1:
+            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
+        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+        sorted_indices_to_remove[..., 0] = 0
+
+        # Scatter sorted tensors back to original indexing
+        indices_to_remove = sorted_indices_to_remove.scatter(
+            1, sorted_indices, sorted_indices_to_remove
+        )
+        logits[indices_to_remove] = filter_value
+
+    return logits
+
+
+def topk_sampling(logits, top_k=50, top_p=1.0, temperature=1.0):
+    """
+    Perform top-k and top-p sampling on logits.
+
+    Args:
+        logits (torch.Tensor): The logits to sample from.
+        top_k (int, optional): The number of highest probability tokens to keep for top-k filtering.
+                               Must be a positive integer. Defaults to 50.
+        top_p (float, optional): The cumulative probability threshold for nucleus sampling.
+                                 Must be between 0 and 1. Defaults to 1.0.
+        temperature (float, optional): The scaling factor to adjust the logits distribution.
+                                       Must be strictly positive. Defaults to 1.0.
+
+    Returns:
+        torch.Tensor: The sampled token.
+    """
+
+    # Adjust logits using temperature
+    if temperature != 1.0:
+        logits = logits / temperature
+
+    # Top-p/top-k filtering
+    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
+
+    # Sample from the filtered distribution
+    token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
+    return token
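+
+
+# Illustrative usage (a sketch; the shapes are made up):
+#
+#     logits = torch.randn(2, 1024)  # (batch, vocab)
+#     next_token = topk_sampling(logits.clone(), top_k=50, top_p=0.9, temperature=1.0)
+#     next_token.shape               # torch.Size([2, 1])
+#
+# Note that the filtering writes into `logits` in place, hence the clone() above.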
diff --git a/utils/trainer_utils.py b/utils/trainer_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5d9ad794864aa3ee0c49a86e9be293f69442886
--- /dev/null
+++ b/utils/trainer_utils.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+
+def check_nan(logger, loss, y_pred, y_gt):
+    if torch.any(torch.isnan(loss)):
+        logger.info("out has nan: ", torch.any(torch.isnan(y_pred)))
+        logger.info("y_gt has nan: ", torch.any(torch.isnan(y_gt)))
+        logger.info("out: ", y_pred)
+        logger.info("y_gt: ", y_gt)
+        logger.info("loss = {:.4f}\n".format(loss.item()))
+        exit()
diff --git a/utils/util.py b/utils/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..096d9129cd0dab4798f205ab7fdc690303ef5d37
--- /dev/null
+++ b/utils/util.py
@@ -0,0 +1,688 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import collections
+import glob
+import os
+import random
+import time
+import argparse
+from collections import OrderedDict
+
+import json5
+import numpy as np
+from torch.nn import functional as F
+
+
+try:
+    from ruamel.yaml import YAML as yaml
+except ImportError:
+    from ruamel_yaml import YAML as yaml
+
+import torch
+
+from utils.hparam import HParams
+import logging
+from logging import handlers
+
+
+def str2bool(v):
+    """Used in argparse.ArgumentParser.add_argument to indicate
+    that a type is a bool type and user can enter
+
+        - yes, true, t, y, 1, to represent True
+        - no, false, f, n, 0, to represent False
+
+    See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse  # noqa
+    """
+    if isinstance(v, bool):
+        return v
+    if v.lower() in ("yes", "true", "t", "y", "1"):
+        return True
+    elif v.lower() in ("no", "false", "f", "n", "0"):
+        return False
+    else:
+        raise argparse.ArgumentTypeError("Boolean value expected.")
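+
+
+# Illustrative usage (a sketch; "--use_ema" is a made-up flag):
+#
+#     parser = argparse.ArgumentParser()
+#     parser.add_argument("--use_ema", type=str2bool, default=False)
+#     parser.parse_args(["--use_ema", "yes"]).use_ema   # -> True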
+
+
+def find_checkpoint_of_mapper(mapper_ckpt_dir):
+    mapper_ckpts = glob.glob(os.path.join(mapper_ckpt_dir, "ckpts/*.pt"))
+
+    # Select the max steps. Sort by (length, name) so that, e.g.,
+    # "model.ckpt-1000.pt" ranks after "model.ckpt-999.pt".
+    mapper_ckpts.sort(key=lambda ckpt: (len(ckpt), ckpt))
+    mapper_weights_file = mapper_ckpts[-1]
+    return mapper_weights_file
+
+
+def pad_f0_to_tensors(f0s, batched=None):
+    # Initialize
+    tensors = []
+
+    if batched is None:
+        # Get the max frame for padding
+        size = -1
+        for f0 in f0s:
+            size = max(size, f0.shape[-1])
+
+        tensor = torch.zeros(len(f0s), size)
+
+        for i, f0 in enumerate(f0s):
+            tensor[i, : f0.shape[-1]] = f0[:]
+
+        tensors.append(tensor)
+    else:
+        start = 0
+        while start + batched - 1 < len(f0s):
+            end = start + batched - 1
+
+            # Get the max frame for padding
+            size = -1
+            for i in range(start, end + 1):
+                size = max(size, f0s[i].shape[-1])
+
+            tensor = torch.zeros(batched, size)
+
+            for i in range(start, end + 1):
+                tensor[i - start, : f0s[i].shape[-1]] = f0s[i][:]
+
+            tensors.append(tensor)
+
+            start = start + batched
+
+        if start != len(f0s):
+            end = len(f0s)
+
+            # Get the max frame for padding
+            size = -1
+            for i in range(start, end):
+                size = max(size, f0s[i].shape[-1])
+
+            tensor = torch.zeros(len(f0s) - start, size)
+
+            for i in range(start, end):
+                tensor[i - start, : f0s[i].shape[-1]] = f0s[i][:]
+
+            tensors.append(tensor)
+
+    return tensors
+
+
+def pad_mels_to_tensors(mels, batched=None):
+    """
+    Args:
+        mels: A list of mel-specs
+    Returns:
+        tensors: A list of tensors containing the batched mel-specs
+        mel_frames: A list of tensors containing the frames of the original mel-specs
+    """
+    # Initialize
+    tensors = []
+    mel_frames = []
+
+    # Split mel-specs into batches to avoid cuda memory exceed
+    if batched is None:
+        # Get the max frame for padding
+        size = -1
+        for mel in mels:
+            size = max(size, mel.shape[-1])
+
+        tensor = torch.zeros(len(mels), mels[0].shape[0], size)
+        mel_frame = torch.zeros(len(mels), dtype=torch.int32)
+
+        for i, mel in enumerate(mels):
+            tensor[i, :, : mel.shape[-1]] = mel[:]
+            mel_frame[i] = mel.shape[-1]
+
+        tensors.append(tensor)
+        mel_frames.append(mel_frame)
+    else:
+        start = 0
+        while start + batched - 1 < len(mels):
+            end = start + batched - 1
+
+            # Get the max frame for padding
+            size = -1
+            for i in range(start, end + 1):
+                size = max(size, mels[i].shape[-1])
+
+            tensor = torch.zeros(batched, mels[0].shape[0], size)
+            mel_frame = torch.zeros(batched, dtype=torch.int32)
+
+            for i in range(start, end + 1):
+                tensor[i - start, :, : mels[i].shape[-1]] = mels[i][:]
+                mel_frame[i - start] = mels[i].shape[-1]
+
+            tensors.append(tensor)
+            mel_frames.append(mel_frame)
+
+            start = start + batched
+
+        if start != len(mels):
+            end = len(mels)
+
+            # Get the max frame for padding
+            size = -1
+            for i in range(start, end):
+                size = max(size, mels[i].shape[-1])
+
+            tensor = torch.zeros(len(mels) - start, mels[0].shape[0], size)
+            mel_frame = torch.zeros(len(mels) - start, dtype=torch.int32)
+
+            for i in range(start, end):
+                tensor[i - start, :, : mels[i].shape[-1]] = mels[i][:]
+                mel_frame[i - start] = mels[i].shape[-1]
+
+            tensors.append(tensor)
+            mel_frames.append(mel_frame)
+
+    return tensors, mel_frames
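+
+
+# Illustrative usage (a sketch; the mel shapes below are made up):
+#
+#     mels = [torch.randn(80, 120), torch.randn(80, 95), torch.randn(80, 140)]
+#     tensors, frames = pad_mels_to_tensors(mels, batched=2)
+#     tensors[0].shape   # torch.Size([2, 80, 120]) -- padded to the batch's longest clip
+#     tensors[1].shape   # torch.Size([1, 80, 140])
+#     frames[0]          # tensor([120,  95], dtype=torch.int32)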
+
+
+def load_model_config(args):
+    """Load model configurations (in args.json under checkpoint directory)
+
+    Args:
+        args (ArgumentParser): arguments to run bins/preprocess.py
+
+    Returns:
+        dict: dictionary that stores model configurations
+    """
+    if args.checkpoint_dir is None:
+        assert args.checkpoint_file is not None
+        checkpoint_dir = os.path.split(args.checkpoint_file)[0]
+    else:
+        checkpoint_dir = args.checkpoint_dir
+    config_path = os.path.join(checkpoint_dir, "args.json")
+    print("config_path: ", config_path)
+
+    config = load_config(config_path)
+    return config
+
+
+def remove_and_create(dir):
+    if os.path.exists(dir):
+        os.system("rm -r {}".format(dir))
+    os.makedirs(dir, exist_ok=True)
+
+
+def has_existed(path, warning=False):
+    if not warning:
+        return os.path.exists(path)
+
+    if os.path.exists(path):
+        answer = input(
+            "The path {} has existed. \nInput 'y' (or hit Enter) to skip it, and input 'n' to re-write it [y/n]\n".format(
+                path
+            )
+        )
+        if not answer == "n":
+            return True
+
+    return False
+
+
+def remove_older_ckpt(saved_model_name, checkpoint_dir, max_to_keep=5):
+    if os.path.exists(os.path.join(checkpoint_dir, "checkpoint")):
+        with open(os.path.join(checkpoint_dir, "checkpoint"), "r") as f:
+            ckpts = [x.strip() for x in f.readlines()]
+    else:
+        ckpts = []
+    ckpts.append(saved_model_name)
+    for item in ckpts[:-max_to_keep]:
+        if os.path.exists(os.path.join(checkpoint_dir, item)):
+            os.remove(os.path.join(checkpoint_dir, item))
+    with open(os.path.join(checkpoint_dir, "checkpoint"), "w") as f:
+        for item in ckpts[-max_to_keep:]:
+            f.write("{}\n".format(item))
+
+
+def set_all_random_seed(seed: int):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.random.manual_seed(seed)
+
+
+def save_checkpoint(
+    args,
+    generator,
+    g_optimizer,
+    step,
+    discriminator=None,
+    d_optimizer=None,
+    max_to_keep=5,
+):
+    saved_model_name = "model.ckpt-{}.pt".format(step)
+    checkpoint_path = os.path.join(args.checkpoint_dir, saved_model_name)
+
+    if discriminator and d_optimizer:
+        torch.save(
+            {
+                "generator": generator.state_dict(),
+                "discriminator": discriminator.state_dict(),
+                "g_optimizer": g_optimizer.state_dict(),
+                "d_optimizer": d_optimizer.state_dict(),
+                "global_step": step,
+            },
+            checkpoint_path,
+        )
+    else:
+        torch.save(
+            {
+                "generator": generator.state_dict(),
+                "g_optimizer": g_optimizer.state_dict(),
+                "global_step": step,
+            },
+            checkpoint_path,
+        )
+
+    print("Saved checkpoint: {}".format(checkpoint_path))
+
+    if os.path.exists(os.path.join(args.checkpoint_dir, "checkpoint")):
+        with open(os.path.join(args.checkpoint_dir, "checkpoint"), "r") as f:
+            ckpts = [x.strip() for x in f.readlines()]
+    else:
+        ckpts = []
+    ckpts.append(saved_model_name)
+    for item in ckpts[:-max_to_keep]:
+        if os.path.exists(os.path.join(args.checkpoint_dir, item)):
+            os.remove(os.path.join(args.checkpoint_dir, item))
+    with open(os.path.join(args.checkpoint_dir, "checkpoint"), "w") as f:
+        for item in ckpts[-max_to_keep:]:
+            f.write("{}\n".format(item))
+
+
+def attempt_to_restore(
+    generator, g_optimizer, checkpoint_dir, discriminator=None, d_optimizer=None
+):
+    checkpoint_list = os.path.join(checkpoint_dir, "checkpoint")
+    if os.path.exists(checkpoint_list):
+        checkpoint_filename = open(checkpoint_list).readlines()[-1].strip()
+        checkpoint_path = os.path.join(checkpoint_dir, "{}".format(checkpoint_filename))
+        print("Restore from {}".format(checkpoint_path))
+        checkpoint = torch.load(checkpoint_path, map_location="cpu")
+        if generator:
+            if not list(generator.state_dict().keys())[0].startswith("module."):
+                raw_dict = checkpoint["generator"]
+                clean_dict = OrderedDict()
+                for k, v in raw_dict.items():
+                    if k.startswith("module."):
+                        clean_dict[k[7:]] = v
+                    else:
+                        clean_dict[k] = v
+                generator.load_state_dict(clean_dict)
+            else:
+                generator.load_state_dict(checkpoint["generator"])
+        if g_optimizer:
+            g_optimizer.load_state_dict(checkpoint["g_optimizer"])
+        global_step = 100000
+        if discriminator and "discriminator" in checkpoint.keys():
+            discriminator.load_state_dict(checkpoint["discriminator"])
+            global_step = checkpoint["global_step"]
+            print("restore discriminator")
+        if d_optimizer and "d_optimizer" in checkpoint.keys():
+            d_optimizer.load_state_dict(checkpoint["d_optimizer"])
+            print("restore d_optimizer...")
+    else:
+        global_step = 0
+    return global_step
+
+
+class ExponentialMovingAverage(object):
+    def __init__(self, decay):
+        self.decay = decay
+        self.shadow = {}
+
+    def register(self, name, val):
+        self.shadow[name] = val.clone()
+
+    def update(self, name, x):
+        assert name in self.shadow
+        update_delta = self.shadow[name] - x
+        self.shadow[name] -= (1.0 - self.decay) * update_delta
+
+
+def apply_moving_average(model, ema):
+    for name, param in model.named_parameters():
+        if name in ema.shadow:
+            ema.update(name, param.data)
+
+
+def register_model_to_ema(model, ema):
+    for name, param in model.named_parameters():
+        if param.requires_grad:
+            ema.register(name, param.data)
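+
+
+# Illustrative usage (a sketch; `model` stands for any torch.nn.Module):
+#
+#     ema = ExponentialMovingAverage(decay=0.999)
+#     register_model_to_ema(model, ema)
+#     ...                                  # after each optimizer step:
+#     apply_moving_average(model, ema)     # shadow <- decay * shadow + (1 - decay) * param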
+
+
+class YParams(HParams):
+    def __init__(self, yaml_file):
+        if not os.path.exists(yaml_file):
+            raise IOError("yaml file: {} does not exist".format(yaml_file))
+        super().__init__()
+        self.d = collections.OrderedDict()
+        with open(yaml_file) as fp:
+            for _, v in yaml().load(fp).items():
+                for k1, v1 in v.items():
+                    try:
+                        if self.get(k1):
+                            self.set_hparam(k1, v1)
+                        else:
+                            self.add_hparam(k1, v1)
+                        self.d[k1] = v1
+                    except Exception:
+                        import traceback
+
+                        print(traceback.format_exc())
+
+    # @property
+    def get_elements(self):
+        return self.d.items()
+
+
+def override_config(base_config, new_config):
+    """Update new configurations in the original dict with the new dict
+
+    Args:
+        base_config (dict): original dict to be overridden
+        new_config (dict): dict with new configurations
+
+    Returns:
+        dict: updated configuration dict
+    """
+    for k, v in new_config.items():
+        if type(v) == dict:
+            if k not in base_config.keys():
+                base_config[k] = {}
+            base_config[k] = override_config(base_config[k], v)
+        else:
+            base_config[k] = v
+    return base_config
+
+
+def get_lowercase_keys_config(cfg):
+    """Change all keys in cfg to lower case
+
+    Args:
+        cfg (dict): dictionary that stores configurations
+
+    Returns:
+        dict: dictionary that stores configurations
+    """
+    updated_cfg = dict()
+    for k, v in cfg.items():
+        if type(v) == dict:
+            v = get_lowercase_keys_config(v)
+        updated_cfg[k.lower()] = v
+    return updated_cfg
+
+
+def _load_config(config_fn, lowercase=False):
+    """Load configurations into a dictionary
+
+    Args:
+        config_fn (str): path to configuration file
+        lowercase (bool, optional): whether changing keys to lower case. Defaults to False.
+
+    Returns:
+        dict: dictionary that stores configurations
+    """
+    with open(config_fn, "r") as f:
+        data = f.read()
+    config_ = json5.loads(data)
+    if "base_config" in config_:
+        # load configurations from new path
+        p_config_path = os.path.join(os.getenv("WORK_DIR"), config_["base_config"])
+        p_config_ = _load_config(p_config_path)
+        config_ = override_config(p_config_, config_)
+    if lowercase:
+        # change keys in config_ to lower case
+        config_ = get_lowercase_keys_config(config_)
+    return config_
+
+
+def load_config(config_fn, lowercase=False):
+    """Load configurations into a dictionary
+
+    Args:
+        config_fn (str): path to configuration file
+        lowercase (bool, optional): whether to change keys to lower case. Defaults to False.
+
+    Returns:
+        JsonHParams: an object that stores configurations
+    """
+    config_ = _load_config(config_fn, lowercase=lowercase)
+    # create an JsonHParams object with configuration dict
+    cfg = JsonHParams(**config_)
+    return cfg
+
+
+def save_config(save_path, cfg):
+    """Save configurations into a json file
+
+    Args:
+        save_path (str): path to save configurations
+        cfg (dict): dictionary that stores configurations
+    """
+    with open(save_path, "w") as f:
+        json5.dump(
+            cfg, f, ensure_ascii=False, indent=4, quote_keys=True, sort_keys=True
+        )
+
+
+class JsonHParams:
+    def __init__(self, **kwargs):
+        for k, v in kwargs.items():
+            if type(v) == dict:
+                v = JsonHParams(**v)
+            self[k] = v
+
+    def keys(self):
+        return self.__dict__.keys()
+
+    def items(self):
+        return self.__dict__.items()
+
+    def values(self):
+        return self.__dict__.values()
+
+    def __len__(self):
+        return len(self.__dict__)
+
+    def __getitem__(self, key):
+        return getattr(self, key)
+
+    def __setitem__(self, key, value):
+        return setattr(self, key, value)
+
+    def __contains__(self, key):
+        return key in self.__dict__
+
+    def __repr__(self):
+        return self.__dict__.__repr__()
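+
+
+# Illustrative usage of override_config and JsonHParams (a sketch; the keys are made up):
+#
+#     base = {"model": {"hidden": 256, "layers": 4}, "lr": 1e-4}
+#     new = {"model": {"hidden": 512}, "batch_size": 16}
+#     merged = override_config(base, new)
+#     # -> {"model": {"hidden": 512, "layers": 4}, "lr": 1e-4, "batch_size": 16}
+#     cfg = JsonHParams(**merged)
+#     cfg.model.hidden   # -> 512 (nested dicts become attribute-style objects)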
+
+
+class ValueWindow:
+    def __init__(self, window_size=100):
+        self._window_size = window_size
+        self._values = []
+
+    def append(self, x):
+        self._values = self._values[-(self._window_size - 1) :] + [x]
+
+    @property
+    def sum(self):
+        return sum(self._values)
+
+    @property
+    def count(self):
+        return len(self._values)
+
+    @property
+    def average(self):
+        return self.sum / max(1, self.count)
+
+    def reset(self):
+        self._values = []
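+
+
+# Illustrative usage (a sketch):
+#
+#     window = ValueWindow(window_size=3)
+#     for v in [1.0, 2.0, 3.0, 4.0]:
+#         window.append(v)
+#     window.average   # -> 3.0 (mean of the last 3 values: 2.0, 3.0, 4.0)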
+
+
+class Logger(object):
+    def __init__(
+        self,
+        filename,
+        level="info",
+        when="D",
+        backCount=10,
+        fmt="%(asctime)s : %(message)s",
+    ):
+        self.level_relations = {
+            "debug": logging.DEBUG,
+            "info": logging.INFO,
+            "warning": logging.WARNING,
+            "error": logging.ERROR,
+            "crit": logging.CRITICAL,
+        }
+        if level == "debug":
+            fmt = "%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s"
+        self.logger = logging.getLogger(filename)
+        format_str = logging.Formatter(fmt)
+        self.logger.setLevel(self.level_relations.get(level))
+        sh = logging.StreamHandler()
+        sh.setFormatter(format_str)
+        th = handlers.TimedRotatingFileHandler(
+            filename=filename, when=when, backupCount=backCount, encoding="utf-8"
+        )
+        th.setFormatter(format_str)
+        self.logger.addHandler(sh)
+        self.logger.addHandler(th)
+        self.logger.info(
+            "==========================New Starting Here=============================="
+        )
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+def slice_segments(x, ids_str, segment_size=4):
+    ret = torch.zeros_like(x[:, :, :segment_size])
+    for i in range(x.size(0)):
+        idx_str = ids_str[i]
+        idx_end = idx_str + segment_size
+        ret[i] = x[i, :, idx_str:idx_end]
+    return ret
+
+
+def rand_slice_segments(x, x_lengths=None, segment_size=4):
+    b, d, t = x.size()
+    if x_lengths is None:
+        x_lengths = t
+    ids_str_max = x_lengths - segment_size + 1
+    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+    ret = slice_segments(x, ids_str, segment_size)
+    return ret, ids_str
+
+
+def subsequent_mask(length):
+    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
+    return mask
+
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+    n_channels_int = n_channels[0]
+    in_act = input_a + input_b
+    t_act = torch.tanh(in_act[:, :n_channels_int, :])
+    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+    acts = t_act * s_act
+    return acts
+
+
+def convert_pad_shape(pad_shape):
+    l = pad_shape[::-1]
+    pad_shape = [item for sublist in l for item in sublist]
+    return pad_shape
+
+
+def sequence_mask(length, max_length=None):
+    if max_length is None:
+        max_length = length.max()
+    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+    return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def generate_path(duration, mask):
+    """
+    duration: [b, 1, t_x]
+    mask: [b, 1, t_y, t_x]
+    """
+    device = duration.device
+
+    b, _, t_y, t_x = mask.shape
+    cum_duration = torch.cumsum(duration, -1)
+
+    cum_duration_flat = cum_duration.view(b * t_x)
+    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+    path = path.view(b, t_x, t_y)
+    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+    path = path.unsqueeze(1).transpose(2, 3) * mask
+    return path
+
+
+def clip_grad_value_(parameters, clip_value, norm_type=2):
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    parameters = list(filter(lambda p: p.grad is not None, parameters))
+    norm_type = float(norm_type)
+    if clip_value is not None:
+        clip_value = float(clip_value)
+
+    total_norm = 0
+    for p in parameters:
+        param_norm = p.grad.data.norm(norm_type)
+        total_norm += param_norm.item() ** norm_type
+        if clip_value is not None:
+            p.grad.data.clamp_(min=-clip_value, max=clip_value)
+    total_norm = total_norm ** (1.0 / norm_type)
+    return total_norm
+
+
+def get_current_time():
+    # Local wall-clock time as a human-readable string, e.g. "2023-11-28 12:00:00"
+    return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+
+
+def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
+    """
+    Args:
+      lengths:
+        A 1-D tensor containing sentence lengths.
+      max_len:
+        The length of masks.
+    Returns:
+      Return a 2-D bool tensor, where masked positions
+      are filled with `True` and non-masked positions are
+      filled with `False`.
+
+    >>> lengths = torch.tensor([1, 3, 2, 5])
+    >>> make_pad_mask(lengths)
+    tensor([[False,  True,  True,  True,  True],
+            [False, False, False,  True,  True],
+            [False, False,  True,  True,  True],
+            [False, False, False, False, False]])
+    """
+    assert lengths.ndim == 1, lengths.ndim
+    max_len = max(max_len, lengths.max())
+    n = lengths.size(0)
+    seq_range = torch.arange(0, max_len, device=lengths.device)
+    expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
+
+    return expaned_lengths >= lengths.unsqueeze(-1)
+
diff --git a/utils/whisper.py b/utils/whisper.py
new file mode 100644
index 0000000000000000000000000000000000000000..16462c7e84f2ce71d4fd5e57832d705fd07b95ca
--- /dev/null
+++ b/utils/whisper.py
@@ -0,0 +1,165 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import os
+import pickle
+from tqdm import tqdm
+import numpy as np
+
+from modules import whisper_extractor as whisper
+
+
+def whisper_encoder_batch(model, audio_paths):
+    batch = len(audio_paths)
+    batch_mel = torch.zeros((batch, 80, 3000), dtype=torch.float32, device=model.device)
+
+    for i, audio_path in enumerate(audio_paths):
+        # (48000,)
+        audio = whisper.load_audio(str(audio_path))
+        audio = whisper.pad_or_trim(audio)
+
+        # (80, 3000)
+        mel = whisper.log_mel_spectrogram(audio).to(model.device)
+        batch_mel[i] = mel
+
+    with torch.no_grad():
+        # (batch, 1500, 1024)
+        features = model.embed_audio(batch_mel)
+
+    return features.cpu().detach().numpy()
+
+
+def whisper_encoder(model, audio_path):
+    audio = whisper.load_audio(str(audio_path))
+    audio = whisper.pad_or_trim(audio)
+
+    # (80, 3000)
+    mel = whisper.log_mel_spectrogram(audio).to(model.device).unsqueeze(0)
+
+    with torch.no_grad():
+        # (1, 1500, 1024) -> # (1500, 1024)
+        features = model.embed_audio(mel).squeeze(0)
+
+    return features.cpu().detach().numpy()
+
+
+def get_mapped_whisper_features(
+    raw_whisper_features, mapping_features, fast_mapping=True
+):
+    """
+    Whisper: frameshift = 20ms (30s audio -> 1500 frames), hop_size = 480 in 24k
+    # Ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/model.py#L136
+
+    Now it's only used for mapping to bigvgan's mels (sr = 24k, hop_size = 256, frameshift ~= 10.7 ms)
+    """
+    source_hop = 480
+    target_hop = 256
+
+    factor = np.gcd(source_hop, target_hop)
+    source_hop //= factor
+    target_hop //= factor
+    print(
+        "Mapping source's {} frames => target's {} frames".format(
+            target_hop, source_hop
+        )
+    )
+
+    max_source_len = 1500
+    whisper_features = []
+    for index, mapping_feat in enumerate(tqdm(mapping_features)):
+        # mapping_feat: (mels_frame_len, n_mels)
+        target_len = mapping_feat.shape[0]
+        # The max target_len is 2812
+        target_len = min(target_len, max_source_len * source_hop // target_hop)
+
+        # (1500, dim)
+        raw_feats = raw_whisper_features[index]
+        width = raw_feats.shape[-1]
+
+        if fast_mapping:
+            source_len = target_len * target_hop // source_hop + 1
+            raw_feats = raw_feats[:source_len]
+        else:
+            source_len = max_source_len
+
+        # const ~= target_len * target_hop
+        const = source_len * source_hop // target_hop * target_hop
+
+        # (source_len * source_hop, dim)
+        up_sampling_feats = np.repeat(raw_feats, source_hop, axis=0)
+        # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim)
+        down_sampling_feats = np.average(
+            up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1
+        )
+        assert len(down_sampling_feats) >= target_len
+
+        # (target_len, dim)
+        feats = down_sampling_feats[:target_len]
+        whisper_features.append(feats)
+
+    return whisper_features
+
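+# Worked example of the hop mapping above (illustrative numbers):
+#     source_hop=480, target_hop=256, gcd=32 -> reduced hops 15 and 8, i.e. 8 whisper
+#     frames cover the same duration as 15 mel frames. For a mel with 1000 frames:
+#     source_len = 1000 * 8 // 15 + 1 = 534 whisper frames are kept, each repeated
+#     15x (8010 rows); the first 8008 rows are averaged in groups of 8 -> 1001 frames,
+#     and the first 1000 of them are returned.
+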
+
+def load_whisper_model(hps):
+    print("Loading Whisper Model: ", hps.whisper_model)
+    model = whisper.load_model(hps.whisper_model)
+    if torch.cuda.is_available():
+        model = model.cuda()
+
+    model = model.eval()
+    return model
+
+
+def load_target_acoustic_features(
+    output_path, dataset, acoustic_features_name, acoustic_features_fs, dataset_type
+):
+    mapping_dir = os.path.join(
+        output_path,
+        dataset,
+        "{}/{}".format(acoustic_features_name, acoustic_features_fs),
+    )
+    with open(os.path.join(mapping_dir, "{}.pkl".format(dataset_type)), "rb") as f:
+        mapping_features = pickle.load(f)
+
+    # Mels: (n_mels, frame_len) -> (frame_len, n_mels)
+    if acoustic_features_name == "mels":
+        print("Transposing mel features...")
+        mapping_features = [feat.T for feat in mapping_features]
+
+    print(
+        "Mapping to the acoustic features {}, #sz = {}, feats[0] is {}".format(
+            acoustic_features_name, len(mapping_features), mapping_features[0].shape
+        )
+    )
+    return mapping_features
+
+
+def extract_whisper_features_of_dataset(
+    datasets,
+    model,
+    batch_size,
+    out_dir,
+):
+    audio_paths = [utt["Path"] for utt in datasets]
+    if len(audio_paths) < batch_size:
+        batch_size = len(audio_paths)
+
+    start, end = 0, 0
+    while end < len(audio_paths):
+        # Raw features: (batch_size, 1500, dim)
+        start = end
+        end = start + batch_size
+        tmp_raw_whisper_features = whisper_encoder_batch(model, audio_paths[start:end])
+
+        # Mapping to acoustic features' lengths
+        for index, utt in enumerate(tqdm(datasets[start:end])):
+            uid = utt["Uid"]
+            raw_whisper_feature = tmp_raw_whisper_features[index]
+
+            save_path = os.path.join(out_dir, uid + ".npy")
+            np.save(save_path, raw_whisper_feature)
+
+        print("{}/{} Done...".format(end, len(audio_paths)))
diff --git a/utils/world.py b/utils/world.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce5f61bd9b571607fd83da6b22283757e67201da
--- /dev/null
+++ b/utils/world.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# 1. Extract WORLD features including F0, AP, SP
+# 2. Transform between SP and MCEP
+import torchaudio
+import pyworld as pw
+import numpy as np
+import torch
+import diffsptk
+import os
+from tqdm import tqdm
+import pickle
+
+
+def get_mcep_params(fs):
+    """Hyperparameters of transformation between SP and MCEP
+
+    Reference:
+        https://github.com/CSTR-Edinburgh/merlin/blob/master/misc/scripts/vocoder/world_v2/copy_synthesis.sh
+
+    """
+    if fs in [44100, 48000]:
+        fft_size = 2048
+        alpha = 0.77
+    elif fs in [16000]:
+        fft_size = 1024
+        alpha = 0.58
+    else:
+        raise ValueError("Unsupported sampling rate for MCEP: {}".format(fs))
+    return fft_size, alpha
+
+
+def extract_world_features(waveform, fs, frameshift=10):
+    # waveform: (1, seq)
+    # x: (seq,)
+    x = np.array(waveform, dtype=np.double).reshape(-1)
+
+    _f0, t = pw.dio(x, fs, frame_period=frameshift)  # raw pitch extractor
+    f0 = pw.stonemask(x, _f0, t, fs)  # pitch refinement
+    sp = pw.cheaptrick(x, f0, t, fs)  # extract smoothed spectrogram
+    ap = pw.d4c(x, f0, t, fs)  # extract aperiodicity
+
+    return f0, sp, ap, fs
+
+
+def sp2mcep(x, mcsize, fs):
+    fft_size, alpha = get_mcep_params(fs)
+    x = torch.as_tensor(x, dtype=torch.float)
+
+    tmp = diffsptk.ScalarOperation("SquareRoot")(x)
+    tmp = diffsptk.ScalarOperation("Multiplication", 32768.0)(tmp)
+    mgc = diffsptk.MelCepstralAnalysis(
+        cep_order=mcsize - 1, fft_length=fft_size, alpha=alpha, n_iter=1
+    )(tmp)
+    return mgc.numpy()
+
+
+def mcep2sp(x, mcsize, fs):
+    fft_size, alpha = get_mcep_params(fs)
+    x = torch.as_tensor(x, dtype=torch.float)
+
+    tmp = diffsptk.MelGeneralizedCepstrumToSpectrum(
+        alpha=alpha,
+        cep_order=mcsize - 1,
+        fft_length=fft_size,
+    )(x)
+    tmp = diffsptk.ScalarOperation("Division", 32768.0)(tmp)
+    sp = diffsptk.ScalarOperation("Power", 2)(tmp)
+    return sp.double().numpy()
+
+
+def f0_statistics(f0_features, path):
+    print("\nF0 statistics...")
+
+    total_f0 = []
+    for f0 in tqdm(f0_features):
+        total_f0 += [f for f in f0 if f != 0]
+
+    mean = sum(total_f0) / len(total_f0)
+    print("Min = {}, Max = {}, Mean = {}".format(min(total_f0), max(total_f0), mean))
+
+    with open(path, "wb") as f:
+        pickle.dump([mean, total_f0], f)
+
+
+def world_synthesis(f0, sp, ap, fs, frameshift):
+    y = pw.synthesize(
+        f0, sp, ap, fs, frame_period=frameshift
+    )  # synthesize an utterance using the parameters
+    return y
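+
+
+# Illustrative analysis/synthesis round trip (a sketch; "a.wav" and mcsize=60 are
+# placeholder choices, and the file is assumed to be mono at 16/44.1/48 kHz):
+#
+#     waveform, fs = torchaudio.load("a.wav")
+#     f0, sp, ap, fs = extract_world_features(waveform, fs, frameshift=10)
+#     mcep = sp2mcep(sp, mcsize=60, fs=fs)       # compress SP into 60-dim MCEP
+#     sp_rec = mcep2sp(mcep, mcsize=60, fs=fs)   # approximate inverse transform
+#     y = world_synthesis(f0, sp_rec, ap, fs, frameshift=10)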